View Javadoc
1   package org.apache.mina.filter.ssl;
2   
3   import static org.junit.Assert.fail;
4   
5   import java.io.IOException;
6   import java.nio.ByteBuffer;
7   import java.security.GeneralSecurityException;
8   import java.security.KeyStore;
9   import java.security.Security;
10  import java.util.Deque;
11  import java.util.concurrent.BlockingDeque;
12  import java.util.concurrent.LinkedBlockingDeque;
13  
14  import javax.net.ssl.KeyManagerFactory;
15  import javax.net.ssl.SSLContext;
16  import javax.net.ssl.SSLEngine;
17  import javax.net.ssl.SSLEngineResult;
18  import javax.net.ssl.SSLException;
19  import javax.net.ssl.SSLEngineResult.HandshakeStatus;
20  import javax.net.ssl.SSLEngineResult.Status;
21  import javax.net.ssl.TrustManagerFactory;
22  
23  import org.apache.mina.core.buffer.IoBuffer;
24  import org.junit.Ignore;
25  import org.junit.Test;
26  
27  public class SslEngineTest
28  {
29      private BlockingDeque<ByteBuffer> clientQueue = new LinkedBlockingDeque<>(); 
30      private BlockingDeque<ByteBuffer> serverQueue = new LinkedBlockingDeque<>(); 
31  
32      private class Handshaker implements Runnable {
33          private SSLEngine sslEngine;
34          private ByteBuffer workBuffer;
35          private ByteBuffer emptyBuffer= ByteBuffer.allocate(0);
36          
37          private void push(Deque<ByteBuffer> queue, ByteBuffer buffer) {
38              ByteBuffer result = ByteBuffer.allocate(buffer.capacity());
39              result.put(buffer);
40              queue.addFirst(result);
41          }
42          
43          public void run()
44          {
45              HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
46              SSLEngineResult result;
47  
48              try
49              {
50                  while (handshakeStatus != HandshakeStatus.FINISHED) {
51                      switch (handshakeStatus)
52                      {
53                          case NEED_TASK:
54                              break;
55                              
56                          case NEED_UNWRAP:
57                              // The SSLEngine waits for some input.
58                              // We may have received too few data (TCP fragmentation)
59                              // 
60                              ByteBuffer data = serverQueue.takeLast();
61                              result = sslEngine.unwrap(data, workBuffer);
62                              
63                              while (result.getStatus()  == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
64                                  // We need more data, until then, wait.
65                                  //ByteBuffer data = serverQueue.takeLast();
66                                  result = sslEngine.unwrap(data, workBuffer);
67                              }
68                              
69                              handshakeStatus = sslEngine.getHandshakeStatus();
70                              break;
71  
72                          case NEED_WRAP:
73                          case NOT_HANDSHAKING:
74                              result = sslEngine.wrap(emptyBuffer, workBuffer);
75      
76                              workBuffer.flip();
77                              
78                              if (workBuffer.hasRemaining()) {
79                                  push(clientQueue, workBuffer);
80                                  workBuffer.clear();
81                              }
82                              
83                              handshakeStatus = result.getHandshakeStatus();
84                              
85                              break;
86      
87                          case FINISHED:
88                          
89                      }
90                  }
91              }
92              catch ( SSLException e )
93              {
94                  // TODO Auto-generated catch block
95                  e.printStackTrace();
96              }
97              catch ( InterruptedException e )
98              {
99                  // TODO Auto-generated catch block
100                 e.printStackTrace();
101             }
102         }
103         
104         public Handshaker(SSLEngine sslEngine) {
105             this.sslEngine = sslEngine;
106             int packetBufferSize = sslEngine.getSession().getPacketBufferSize();
107             workBuffer = ByteBuffer.allocate(packetBufferSize);
108         }
109     }
110     
111     /** A JVM independant KEY_MANAGER_FACTORY algorithm */
112     private static final String KEY_MANAGER_FACTORY_ALGORITHM;
113 
114     static {
115         String algorithm = Security.getProperty("ssl.KeyManagerFactory.algorithm");
116         if (algorithm == null) {
117             algorithm = KeyManagerFactory.getDefaultAlgorithm();
118         }
119 
120         KEY_MANAGER_FACTORY_ALGORITHM = algorithm;
121     }
122 
123     /** App data buffer for the client SSLEngine*/
124     private IoBuffer inNetBufferClient;
125 
126     /** Net data buffer for the client SSLEngine */
127     private IoBuffer outNetBufferClient;
128 
129     /** App data buffer for the server SSLEngine */
130     private IoBuffer inNetBufferServer;
131 
132     /** Net data buffer for the server SSLEngine */
133     private IoBuffer outNetBufferServer;
134     
135     private final IoBuffer emptyBuffer = IoBuffer.allocate(0);
136 
137 
138     private static SSLContext createSSLContext() throws IOException, GeneralSecurityException {
139         char[] passphrase = "password".toCharArray();
140 
141         SSLContext ctx = SSLContext.getInstance("TLS");
142         KeyManagerFactory kmf = KeyManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
143         TrustManagerFactory tmf = TrustManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
144 
145         KeyStore ks = KeyStore.getInstance("JKS");
146         KeyStore ts = KeyStore.getInstance("JKS");
147 
148         ks.load(SslTest.class.getResourceAsStream("keystore.sslTest"), passphrase);
149         ts.load(SslTest.class.getResourceAsStream("truststore.sslTest"), passphrase);
150 
151         kmf.init(ks, passphrase);
152         tmf.init(ts);
153         ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
154 
155         return ctx;
156     }
157 
158     
159     /**
160      * Decrypt the incoming buffer and move the decrypted data to an
161      * application buffer.
162      */
163     private SSLEngineResult unwrap(SSLEngine sslEngine, IoBuffer inBuffer, IoBuffer outBuffer) throws SSLException {
164         // We first have to create the application buffer if it does not exist
165         if (outBuffer == null) {
166             outBuffer = IoBuffer.allocate(inBuffer.remaining());
167         } else {
168             // We already have one, just add the new data into it
169             outBuffer.expand(inBuffer.remaining());
170         }
171 
172         SSLEngineResult res;
173         Status status;
174         HandshakeStatus localHandshakeStatus;
175 
176         do {
177             // Decode the incoming data
178             res = sslEngine.unwrap(inBuffer.buf(), outBuffer.buf());
179             status = res.getStatus();
180 
181             // We can be processing the Handshake
182             localHandshakeStatus = res.getHandshakeStatus();
183 
184             if (status == SSLEngineResult.Status.BUFFER_OVERFLOW) {
185                 // We have to grow the target buffer, it's too small.
186                 // Then we can call the unwrap method again
187                 int newCapacity = sslEngine.getSession().getApplicationBufferSize();
188                 
189                 if (inBuffer.remaining() >= newCapacity) {
190                     // The buffer is already larger than the max buffer size suggested by the SSL engine.
191                     // Raising it any more will not make sense and it will end up in an endless loop. Throwing an error is safer
192                     throw new SSLException("SSL buffer overflow");
193                 }
194 
195                 inBuffer.expand(newCapacity);
196                 continue;
197             }
198         } while (((status == SSLEngineResult.Status.OK) || (status == SSLEngineResult.Status.BUFFER_OVERFLOW))
199                 && ((localHandshakeStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) || 
200                         (localHandshakeStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP)));
201 
202         return res;
203     }
204 
205     
206     private SSLEngineResult.Status unwrapHandshake(SSLEngine sslEngine, IoBuffer appBuffer, IoBuffer netBuffer) throws SSLException {
207         // Prepare the net data for reading.
208         if ((appBuffer == null) || !appBuffer.hasRemaining()) {
209             // Need more data.
210             return SSLEngineResult.Status.BUFFER_UNDERFLOW;
211         }
212 
213         SSLEngineResult res = unwrap(sslEngine, appBuffer, netBuffer);
214         HandshakeStatus handshakeStatus = res.getHandshakeStatus();
215 
216         //checkStatus(res);
217 
218         // If handshake finished, no data was produced, and the status is still
219         // ok, try to unwrap more
220         if ((handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED)
221                 && (res.getStatus() == SSLEngineResult.Status.OK)
222                 && appBuffer.hasRemaining()) {
223             res = unwrap(sslEngine, appBuffer, netBuffer);
224 
225             // prepare to be written again
226             if (appBuffer.hasRemaining()) {
227                 appBuffer.compact();
228             } else {
229                 appBuffer.free();
230                 appBuffer = null;
231             }
232         } else {
233             // prepare to be written again
234             if (appBuffer.hasRemaining()) {
235                 appBuffer.compact();
236             } else {
237                 appBuffer.free();
238                 appBuffer = null;
239             }
240         }
241 
242         return res.getStatus();
243     }
244 
245     
246     /* no qualifier */boolean isInboundDone(SSLEngine sslEngine) {
247         return sslEngine == null || sslEngine.isInboundDone();
248     }
249 
250     
251     /* no qualifier */boolean isOutboundDone(SSLEngine sslEngine) {
252         return sslEngine == null || sslEngine.isOutboundDone();
253     }
254 
255     
256     /**
257      * Perform any handshaking processing.
258      */
259     /* no qualifier */HandshakeStatus handshake(SSLEngine sslEngine, IoBuffer appBuffer, IoBuffer netBuffer ) throws SSLException {
260         SSLEngineResult result;
261         HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
262 
263         for (;;) {
264             switch (handshakeStatus) {
265             case FINISHED:
266                 //handshakeComplete = true;
267                 return handshakeStatus;
268 
269             case NEED_TASK:
270                 //handshakeStatus = doTasks();
271                 break;
272 
273             case NEED_UNWRAP:
274                 // we need more data read
275                 SSLEngineResult.Status status = unwrapHandshake(sslEngine, appBuffer, netBuffer);
276                 handshakeStatus = sslEngine.getHandshakeStatus();
277 
278                 return handshakeStatus;
279 
280             case NEED_WRAP:
281                 result = sslEngine.wrap(emptyBuffer.buf(), netBuffer.buf());
282 
283                 while ( result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW ) {
284                     netBuffer.capacity(netBuffer.capacity() << 1);
285                     netBuffer.limit(netBuffer.capacity());
286 
287                     result = sslEngine.wrap(emptyBuffer.buf(), netBuffer.buf());
288                 }
289 
290                 netBuffer.flip();
291                 return result.getHandshakeStatus();
292 
293             case NOT_HANDSHAKING:
294                 result = sslEngine.wrap(emptyBuffer.buf(), netBuffer.buf());
295 
296                 while ( result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW ) {
297                     netBuffer.capacity(netBuffer.capacity() << 1);
298                     netBuffer.limit(netBuffer.capacity());
299 
300                     result = sslEngine.wrap(emptyBuffer.buf(), netBuffer.buf());
301                 }
302 
303                 netBuffer.flip();
304                 handshakeStatus = result.getHandshakeStatus();
305                 return handshakeStatus;
306 
307             default:
308                 throw new IllegalStateException("error");
309             }
310         }
311     }
312 
313     
314     /**
315      * Do all the outstanding handshake tasks in the current Thread.
316      */
317     private SSLEngineResult.HandshakeStatus doTasks(SSLEngine sslEngine) {
318         /*
319          * We could run this in a separate thread, but I don't see the need for
320          * this when used from SSLFilter. Use thread filters in MINA instead?
321          */
322         Runnable runnable;
323         while ((runnable = sslEngine.getDelegatedTask()) != null) {
324             //Thread thread = new Thread(runnable);
325             //thread.start();
326             runnable.run();
327         }
328         return sslEngine.getHandshakeStatus();
329     }
330     
331     
332     private HandshakeStatus handshake(SSLEngine sslEngine, HandshakeStatus expected, 
333         IoBuffer inBuffer, IoBuffer outBuffer, boolean dumpBuffer) throws SSLException {
334         HandshakeStatus handshakeStatus = handshake(sslEngine, inBuffer, outBuffer);
335 
336         if ( handshakeStatus != expected) {
337             fail();
338         }
339         
340         if (dumpBuffer) {
341             System.out.println("Message:" + outBuffer);
342         }
343         
344         return handshakeStatus;
345     }
346 
347     
348     @Test
349     @Ignore
350     public void testSSL() throws Exception {
351         // Initialise the client SSLEngine
352         SSLContext sslContextClient = createSSLContext();
353         SSLEngine sslEngineClient = sslContextClient.createSSLEngine();
354         int packetBufferSize = sslEngineClient.getSession().getPacketBufferSize();
355         inNetBufferClient = IoBuffer.allocate(packetBufferSize).setAutoExpand(true);
356         outNetBufferClient = IoBuffer.allocate(packetBufferSize).setAutoExpand(true);
357         
358         sslEngineClient.setUseClientMode(true);
359 
360         // Initialise the Server SSLEngine
361         SSLContext sslContextServer = createSSLContext();
362         SSLEngine sslEngineServer = sslContextServer.createSSLEngine();
363         packetBufferSize = sslEngineServer.getSession().getPacketBufferSize();
364         inNetBufferServer = IoBuffer.allocate(packetBufferSize).setAutoExpand(true);
365         outNetBufferServer = IoBuffer.allocate(packetBufferSize).setAutoExpand(true);
366         
367         sslEngineServer.setUseClientMode(false);
368         
369         Handshaker handshakerClient = new Handshaker( sslEngineClient );
370         Handshaker handshakerServer = new Handshaker( sslEngineServer );
371         
372         handshakerServer.run();
373 
374         HandshakeStatus handshakeStatusClient = sslEngineClient.getHandshakeStatus();
375         HandshakeStatus handshakeStatusServer = sslEngineServer.getHandshakeStatus();
376 
377     // <<< Server
378         // Start the server
379         handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_UNWRAP, 
380             null, outNetBufferServer, false);
381         
382     // >>> Client
383         // Now start the client, which will generate a CLIENT_HELLO,
384         // stored into the outNetBufferClient
385         handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NEED_UNWRAP, 
386             null, outNetBufferClient, true);
387         
388     // <<< Server
389         // Process the CLIENT_HELLO on the server
390         handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_TASK, 
391             outNetBufferClient, outNetBufferServer, false);
392 
393         // Process the tasks on the server, prepare the SERVER_HELLO message
394         handshakeStatusServer = doTasks(sslEngineServer);
395         
396         // We should be ready to generate the SERVER_HELLO message
397         if ( handshakeStatusServer != HandshakeStatus.NEED_WRAP) {
398             fail();
399         }
400         
401         // Get the SERVER_HELLO message, with all the associated messages
402         // ([Certificate], [ServerKeyExchange], [CertificateRequest], ServerHelloDone)
403         outNetBufferServer.clear();
404         handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_UNWRAP, 
405             null, outNetBufferServer, true);
406         
407     // >>> Client
408         // Process the SERVER_HELLO message on the client
409         handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NEED_TASK, 
410             outNetBufferServer, inNetBufferClient, false);
411     
412         // Prepare the client response
413         handshakeStatusClient = doTasks(sslEngineClient);
414     
415         // We should get back the Client messages ([Certificate],
416         // ClientKeyExchange, [CertificateVerify])
417         if ( handshakeStatusClient != HandshakeStatus.NEED_WRAP) {
418             fail();
419         }
420     
421         // Generate the [Certificate], ClientKeyExchange, [CertificateVerify] messages
422         outNetBufferClient.clear();
423         handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NEED_WRAP, 
424             null, outNetBufferClient, true);
425         
426     // <<< Server
427         // Process the CLIENT_KEY_EXCHANGE on the server
428         outNetBufferServer.clear();
429         handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_TASK, 
430             outNetBufferClient, outNetBufferServer, false);
431 
432         // Do the controls
433         handshakeStatusServer = doTasks(sslEngineServer);
434         
435         // The server is waiting for more
436         if ( handshakeStatusServer != HandshakeStatus.NEED_UNWRAP) {
437             fail();
438         }
439 
440     // >>> Client
441         // The CHANGE_CIPHER_SPEC message generation
442         outNetBufferClient.clear();
443         handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NEED_WRAP, 
444             null, outNetBufferClient, true);
445 
446     // <<< Server
447         // Process the CHANGE_CIPHER_SPEC on the server
448         outNetBufferServer.clear();
449         handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_UNWRAP, 
450             outNetBufferClient, outNetBufferServer, false);
451 
452     // >>> Client
453         // Generate the FINISHED message on thee client
454         outNetBufferClient.clear();
455         handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NEED_UNWRAP, 
456             null, outNetBufferClient, true);
457 
458     // <<< Server
459         // Process the client FINISHED message
460         outNetBufferServer.clear();
461         handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_WRAP, 
462             outNetBufferClient, outNetBufferServer, false);
463 
464         // Generate the CHANGE_CIPHER_SPEC message on the server
465         handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_WRAP, 
466             null, outNetBufferServer, true);
467         
468     // >>> Client
469         // Process the server CHANGE_SCIPHER_SPEC message on the client
470         outNetBufferClient.clear();
471         handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NEED_UNWRAP, 
472             outNetBufferServer, outNetBufferClient, false);
473 
474     // <<< Server
475         // Generate the server FINISHED message
476         outNetBufferServer.clear();
477         handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.FINISHED, 
478             null, outNetBufferServer, true);
479 
480     // >>> Client
481         // Process the server FINISHED message on the client
482         outNetBufferClient.clear();
483         handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NOT_HANDSHAKING, 
484             outNetBufferServer, outNetBufferClient, false);
485     }
486 }