1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.mina.filter.ssl;
20
21 import java.io.BufferedReader;
22 import java.io.IOException;
23 import java.io.InputStreamReader;
24 import java.net.InetAddress;
25 import java.net.InetSocketAddress;
26 import java.net.Socket;
27 import java.net.SocketTimeoutException;
28 import java.security.GeneralSecurityException;
29 import java.security.KeyStore;
30 import java.security.Security;
31
32 import javax.net.ssl.KeyManagerFactory;
33 import javax.net.ssl.SSLContext;
34 import javax.net.ssl.SSLSocketFactory;
35 import javax.net.ssl.TrustManagerFactory;
36
37 import org.apache.mina.core.filterchain.DefaultIoFilterChainBuilder;
38 import org.apache.mina.core.service.IoHandlerAdapter;
39 import org.apache.mina.core.session.IoSession;
40 import org.apache.mina.filter.codec.ProtocolCodecFilter;
41 import org.apache.mina.filter.codec.textline.TextLineCodecFactory;
42 import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
43 import org.apache.mina.util.AvailablePortFinder;
44 import org.junit.Test;
45
46
47
48
49
50
51
52 public class SslTest {
53
54 private static final int port = AvailablePortFinder.getNextAvailable(5555);
55
56 private static Exception clientError = null;
57
58 private static InetAddress address;
59
60 private static SSLSocketFactory factory;
61
62 private static NioSocketAcceptor acceptor;
63
64
65 private static final String KEY_MANAGER_FACTORY_ALGORITHM;
66
67 static {
68 String algorithm = Security.getProperty("ssl.KeyManagerFactory.algorithm");
69 if (algorithm == null) {
70 algorithm = KeyManagerFactory.getDefaultAlgorithm();
71 }
72
73 KEY_MANAGER_FACTORY_ALGORITHM = algorithm;
74 }
75
76 private static class TestHandler extends IoHandlerAdapter {
77 public void messageReceived(IoSession session, Object message) throws Exception {
78 String line = (String) message;
79
80 if (line.startsWith("hello")) {
81
82 Thread.sleep(1500);
83 } else if (line.startsWith("send")) {
84
85 StringBuilder sb = new StringBuilder();
86
87 for ( int i = 0; i < 10000; i++) {
88 sb.append('A');
89 }
90
91 session.write(sb.toString());
92 session.closeOnFlush();
93 }
94 }
95 }
96
97
98
99
100
101 private static void startServer() throws Exception {
102 acceptor = new NioSocketAcceptor();
103
104 acceptor.setReuseAddress(true);
105 DefaultIoFilterChainBuilder filters = acceptor.getFilterChain();
106
107
108 SslFilter sslFilter = new SslFilter(createSSLContext());
109 filters.addLast("sslFilter", sslFilter);
110 sslFilter.setNeedClientAuth(true);
111
112
113 filters.addLast("text", new ProtocolCodecFilter(new TextLineCodecFactory()));
114
115 acceptor.setHandler(new TestHandler());
116 acceptor.bind(new InetSocketAddress(port));
117 }
118
119 private static void stopServer() {
120 acceptor.dispose();
121 }
122
123
124
125
126 private static void startClient() throws Exception {
127 address = InetAddress.getByName("localhost");
128
129 SSLContext context = createSSLContext();
130 factory = context.getSocketFactory();
131
132 connectAndSend();
133
134
135 connectAndSend();
136 }
137
138 private static void connectAndSend() throws Exception {
139 Socket parent = new Socket(address, port);
140 Socket socket = factory.createSocket(parent, address.getCanonicalHostName(), port, false);
141
142
143 socket.getOutputStream().write("hello \n".getBytes());
144 socket.getOutputStream().flush();
145 socket.setSoTimeout(1000000);
146
147
148 socket.getOutputStream().write("send\n".getBytes());
149 socket.getOutputStream().flush();
150
151 BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
152 String line = in.readLine();
153
154 socket.close();
155 }
156
157 private static SSLContext createSSLContext() throws IOException, GeneralSecurityException {
158 char[] passphrase = "password".toCharArray();
159
160 SSLContext ctx = SSLContext.getInstance("TLS");
161 KeyManagerFactory kmf = KeyManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
162 TrustManagerFactory tmf = TrustManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
163
164 KeyStore ks = KeyStore.getInstance("JKS");
165 KeyStore ts = KeyStore.getInstance("JKS");
166
167 ks.load(SslTest.class.getResourceAsStream("keystore.sslTest"), passphrase);
168 ts.load(SslTest.class.getResourceAsStream("truststore.sslTest"), passphrase);
169
170 kmf.init(ks, passphrase);
171 tmf.init(ts);
172 ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
173
174 return ctx;
175 }
176
177 @Test
178 public void testSSL() throws Exception {
179 try {
180 startServer();
181
182 Thread t = new Thread() {
183 public void run() {
184 try {
185 startClient();
186 } catch (Exception e) {
187 clientError = e;
188 }
189 }
190 };
191 t.start();
192 t.join();
193
194 if (clientError != null) {
195 throw clientError;
196 }
197 } finally {
198 stopServer();
199 }
200 }
201
202
203 @Test
204 public void unsecureClientTryToConnectoToSecureServer() throws Exception {
205 try {
206 startServer();
207
208
209 Thread t = new Thread() {
210 @Override
211 public void run() {
212 try {
213 address = InetAddress.getByName("localhost");
214
215 Socket socket = new Socket(address, port);
216 socket.setSoTimeout(10000);
217
218 String response = null;
219
220 while (response == null) {
221 try {
222 System.out.println(socket.isConnected());
223
224 socket.getOutputStream().write("hello \n".getBytes());
225 socket.getOutputStream().flush();
226 socket.setSoTimeout(1000);
227
228
229 socket.getOutputStream().write("send\n".getBytes());
230 socket.getOutputStream().flush();
231
232 BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
233 String line = "";
234
235 while ((line = in.readLine()) != null) {
236 response = response + line;
237 }
238 } catch (SocketTimeoutException timeout) {
239
240 timeout.printStackTrace();
241 }
242 }
243
244 if (response.contains("AAAAAAA")){
245 throw new IllegalStateException("getting response:" + response);
246 }
247
248
249 socket.close();
250 } catch (Exception e) {
251 clientError = e;
252 }
253 }
254 };
255
256 t.start();
257 t.join();
258
259 if (clientError != null) {
260 throw clientError;
261 }
262 } finally {
263 stopServer();
264 }
265 }
266 }