1 package org.apache.mina.filter.ssl;
2
import static org.junit.Assert.fail;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.Security;
import java.util.Deque;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.TrustManagerFactory;

import org.apache.mina.core.buffer.IoBuffer;
import org.junit.Ignore;
import org.junit.Test;
26
27 public class SslEngineTest
28 {
29 private BlockingDeque<ByteBuffer> clientQueue = new LinkedBlockingDeque<>();
30 private BlockingDeque<ByteBuffer> serverQueue = new LinkedBlockingDeque<>();
31
32 private class Handshaker implements Runnable {
33 private SSLEngine sslEngine;
34 private ByteBuffer workBuffer;
35 private ByteBuffer emptyBuffer= ByteBuffer.allocate(0);
36
37 private void push(Deque<ByteBuffer> queue, ByteBuffer buffer) {
38 ByteBuffer result = ByteBuffer.allocate(buffer.capacity());
39 result.put(buffer);
40 queue.addFirst(result);
41 }
42
43 public void run()
44 {
45 HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
46 SSLEngineResult result;
47
48 try
49 {
50 while (handshakeStatus != HandshakeStatus.FINISHED) {
51 switch (handshakeStatus)
52 {
53 case NEED_TASK:
54 break;
55
56 case NEED_UNWRAP:
57
58
59
60 ByteBuffer data = serverQueue.takeLast();
61 result = sslEngine.unwrap(data, workBuffer);
62
63 while (result.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
64
65
66 result = sslEngine.unwrap(data, workBuffer);
67 }
68
69 handshakeStatus = sslEngine.getHandshakeStatus();
70 break;
71
72 case NEED_WRAP:
73 case NOT_HANDSHAKING:
74 result = sslEngine.wrap(emptyBuffer, workBuffer);
75
76 workBuffer.flip();
77
78 if (workBuffer.hasRemaining()) {
79 push(clientQueue, workBuffer);
80 workBuffer.clear();
81 }
82
83 handshakeStatus = result.getHandshakeStatus();
84
85 break;
86
87 case FINISHED:
88
89 }
90 }
91 }
92 catch ( SSLException e )
93 {
94
95 e.printStackTrace();
96 }
97 catch ( InterruptedException e )
98 {
99
100 e.printStackTrace();
101 }
102 }
103
104 public Handshaker(SSLEngine sslEngine) {
105 this.sslEngine = sslEngine;
106 int packetBufferSize = sslEngine.getSession().getPacketBufferSize();
107 workBuffer = ByteBuffer.allocate(packetBufferSize);
108 }
109 }
110
111
112 private static final String KEY_MANAGER_FACTORY_ALGORITHM;
113
114 static {
115 String algorithm = Security.getProperty("ssl.KeyManagerFactory.algorithm");
116 if (algorithm == null) {
117 algorithm = KeyManagerFactory.getDefaultAlgorithm();
118 }
119
120 KEY_MANAGER_FACTORY_ALGORITHM = algorithm;
121 }
122
123
124 private IoBuffer inNetBufferClient;
125
126
127 private IoBuffer outNetBufferClient;
128
129
130 private IoBuffer inNetBufferServer;
131
132
133 private IoBuffer outNetBufferServer;
134
135 private final IoBuffer emptyBuffer = IoBuffer.allocate(0);
136
137
138 private static SSLContext createSSLContext() throws IOException, GeneralSecurityException {
139 char[] passphrase = "password".toCharArray();
140
141 SSLContext ctx = SSLContext.getInstance("TLS");
142 KeyManagerFactory kmf = KeyManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
143 TrustManagerFactory tmf = TrustManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
144
145 KeyStore ks = KeyStore.getInstance("JKS");
146 KeyStore ts = KeyStore.getInstance("JKS");
147
148 ks.load(SslTest.class.getResourceAsStream("keystore.sslTest"), passphrase);
149 ts.load(SslTest.class.getResourceAsStream("truststore.sslTest"), passphrase);
150
151 kmf.init(ks, passphrase);
152 tmf.init(ts);
153 ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
154
155 return ctx;
156 }
157
158
159
160
161
162
163 private SSLEngineResult unwrap(SSLEngine sslEngine, IoBuffer inBuffer, IoBuffer outBuffer) throws SSLException {
164
165 if (outBuffer == null) {
166 outBuffer = IoBuffer.allocate(inBuffer.remaining());
167 } else {
168
169 outBuffer.expand(inBuffer.remaining());
170 }
171
172 SSLEngineResult res;
173 Status status;
174 HandshakeStatus localHandshakeStatus;
175
176 do {
177
178 res = sslEngine.unwrap(inBuffer.buf(), outBuffer.buf());
179 status = res.getStatus();
180
181
182 localHandshakeStatus = res.getHandshakeStatus();
183
184 if (status == SSLEngineResult.Status.BUFFER_OVERFLOW) {
185
186
187 int newCapacity = sslEngine.getSession().getApplicationBufferSize();
188
189 if (inBuffer.remaining() >= newCapacity) {
190
191
192 throw new SSLException("SSL buffer overflow");
193 }
194
195 inBuffer.expand(newCapacity);
196 continue;
197 }
198 } while (((status == SSLEngineResult.Status.OK) || (status == SSLEngineResult.Status.BUFFER_OVERFLOW))
199 && ((localHandshakeStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) ||
200 (localHandshakeStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP)));
201
202 return res;
203 }
204
205
206 private SSLEngineResult.Status unwrapHandshake(SSLEngine sslEngine, IoBuffer appBuffer, IoBuffer netBuffer) throws SSLException {
207
208 if ((appBuffer == null) || !appBuffer.hasRemaining()) {
209
210 return SSLEngineResult.Status.BUFFER_UNDERFLOW;
211 }
212
213 SSLEngineResult res = unwrap(sslEngine, appBuffer, netBuffer);
214 HandshakeStatus handshakeStatus = res.getHandshakeStatus();
215
216
217
218
219
220 if ((handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED)
221 && (res.getStatus() == SSLEngineResult.Status.OK)
222 && appBuffer.hasRemaining()) {
223 res = unwrap(sslEngine, appBuffer, netBuffer);
224
225
226 if (appBuffer.hasRemaining()) {
227 appBuffer.compact();
228 } else {
229 appBuffer.free();
230 appBuffer = null;
231 }
232 } else {
233
234 if (appBuffer.hasRemaining()) {
235 appBuffer.compact();
236 } else {
237 appBuffer.free();
238 appBuffer = null;
239 }
240 }
241
242 return res.getStatus();
243 }
244
245
246 boolean isInboundDone(SSLEngine sslEngine) {
247 return sslEngine == null || sslEngine.isInboundDone();
248 }
249
250
251 boolean isOutboundDone(SSLEngine sslEngine) {
252 return sslEngine == null || sslEngine.isOutboundDone();
253 }
254
255
256
257
258
259 HandshakeStatus handshake(SSLEngine sslEngine, IoBuffer appBuffer, IoBuffer netBuffer ) throws SSLException {
260 SSLEngineResult result;
261 HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
262
263 for (;;) {
264 switch (handshakeStatus) {
265 case FINISHED:
266
267 return handshakeStatus;
268
269 case NEED_TASK:
270
271 break;
272
273 case NEED_UNWRAP:
274
275 SSLEngineResult.Status status = unwrapHandshake(sslEngine, appBuffer, netBuffer);
276 handshakeStatus = sslEngine.getHandshakeStatus();
277
278 return handshakeStatus;
279
280 case NEED_WRAP:
281 result = sslEngine.wrap(emptyBuffer.buf(), netBuffer.buf());
282
283 while ( result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW ) {
284 netBuffer.capacity(netBuffer.capacity() << 1);
285 netBuffer.limit(netBuffer.capacity());
286
287 result = sslEngine.wrap(emptyBuffer.buf(), netBuffer.buf());
288 }
289
290 netBuffer.flip();
291 return result.getHandshakeStatus();
292
293 case NOT_HANDSHAKING:
294 result = sslEngine.wrap(emptyBuffer.buf(), netBuffer.buf());
295
296 while ( result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW ) {
297 netBuffer.capacity(netBuffer.capacity() << 1);
298 netBuffer.limit(netBuffer.capacity());
299
300 result = sslEngine.wrap(emptyBuffer.buf(), netBuffer.buf());
301 }
302
303 netBuffer.flip();
304 handshakeStatus = result.getHandshakeStatus();
305 return handshakeStatus;
306
307 default:
308 throw new IllegalStateException("error");
309 }
310 }
311 }
312
313
314
315
316
317 private SSLEngineResult.HandshakeStatus doTasks(SSLEngine sslEngine) {
318
319
320
321
322 Runnable runnable;
323 while ((runnable = sslEngine.getDelegatedTask()) != null) {
324
325
326 runnable.run();
327 }
328 return sslEngine.getHandshakeStatus();
329 }
330
331
332 private HandshakeStatus handshake(SSLEngine sslEngine, HandshakeStatus expected,
333 IoBuffer inBuffer, IoBuffer outBuffer, boolean dumpBuffer) throws SSLException {
334 HandshakeStatus handshakeStatus = handshake(sslEngine, inBuffer, outBuffer);
335
336 if ( handshakeStatus != expected) {
337 fail();
338 }
339
340 if (dumpBuffer) {
341 System.out.println("Message:" + outBuffer);
342 }
343
344 return handshakeStatus;
345 }
346
347
348 @Test
349 @Ignore
350 public void testSSL() throws Exception {
351
352 SSLContext sslContextClient = createSSLContext();
353 SSLEngine sslEngineClient = sslContextClient.createSSLEngine();
354 int packetBufferSize = sslEngineClient.getSession().getPacketBufferSize();
355 inNetBufferClient = IoBuffer.allocate(packetBufferSize).setAutoExpand(true);
356 outNetBufferClient = IoBuffer.allocate(packetBufferSize).setAutoExpand(true);
357
358 sslEngineClient.setUseClientMode(true);
359
360
361 SSLContext sslContextServer = createSSLContext();
362 SSLEngine sslEngineServer = sslContextServer.createSSLEngine();
363 packetBufferSize = sslEngineServer.getSession().getPacketBufferSize();
364 inNetBufferServer = IoBuffer.allocate(packetBufferSize).setAutoExpand(true);
365 outNetBufferServer = IoBuffer.allocate(packetBufferSize).setAutoExpand(true);
366
367 sslEngineServer.setUseClientMode(false);
368
369 Handshaker handshakerClient = new Handshaker( sslEngineClient );
370 Handshaker handshakerServer = new Handshaker( sslEngineServer );
371
372 handshakerServer.run();
373
374 HandshakeStatus handshakeStatusClient = sslEngineClient.getHandshakeStatus();
375 HandshakeStatus handshakeStatusServer = sslEngineServer.getHandshakeStatus();
376
377
378
379 handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_UNWRAP,
380 null, outNetBufferServer, false);
381
382
383
384
385 handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NEED_UNWRAP,
386 null, outNetBufferClient, true);
387
388
389
390 handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_TASK,
391 outNetBufferClient, outNetBufferServer, false);
392
393
394 handshakeStatusServer = doTasks(sslEngineServer);
395
396
397 if ( handshakeStatusServer != HandshakeStatus.NEED_WRAP) {
398 fail();
399 }
400
401
402
403 outNetBufferServer.clear();
404 handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_UNWRAP,
405 null, outNetBufferServer, true);
406
407
408
409 handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NEED_TASK,
410 outNetBufferServer, inNetBufferClient, false);
411
412
413 handshakeStatusClient = doTasks(sslEngineClient);
414
415
416
417 if ( handshakeStatusClient != HandshakeStatus.NEED_WRAP) {
418 fail();
419 }
420
421
422 outNetBufferClient.clear();
423 handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NEED_WRAP,
424 null, outNetBufferClient, true);
425
426
427
428 outNetBufferServer.clear();
429 handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_TASK,
430 outNetBufferClient, outNetBufferServer, false);
431
432
433 handshakeStatusServer = doTasks(sslEngineServer);
434
435
436 if ( handshakeStatusServer != HandshakeStatus.NEED_UNWRAP) {
437 fail();
438 }
439
440
441
442 outNetBufferClient.clear();
443 handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NEED_WRAP,
444 null, outNetBufferClient, true);
445
446
447
448 outNetBufferServer.clear();
449 handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_UNWRAP,
450 outNetBufferClient, outNetBufferServer, false);
451
452
453
454 outNetBufferClient.clear();
455 handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NEED_UNWRAP,
456 null, outNetBufferClient, true);
457
458
459
460 outNetBufferServer.clear();
461 handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_WRAP,
462 outNetBufferClient, outNetBufferServer, false);
463
464
465 handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.NEED_WRAP,
466 null, outNetBufferServer, true);
467
468
469
470 outNetBufferClient.clear();
471 handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NEED_UNWRAP,
472 outNetBufferServer, outNetBufferClient, false);
473
474
475
476 outNetBufferServer.clear();
477 handshakeStatusServer = handshake(sslEngineServer, HandshakeStatus.FINISHED,
478 null, outNetBufferServer, true);
479
480
481
482 outNetBufferClient.clear();
483 handshakeStatusClient = handshake(sslEngineClient, HandshakeStatus.NOT_HANDSHAKING,
484 outNetBufferServer, outNetBufferClient, false);
485 }
486 }