View Javadoc
1   /*
2    *  Licensed to the Apache Software Foundation (ASF) under one
3    *  or more contributor license agreements.  See the NOTICE file
4    *  distributed with this work for additional information
5    *  regarding copyright ownership.  The ASF licenses this file
6    *  to you under the Apache License, Version 2.0 (the
7    *  "License"); you may not use this file except in compliance
8    *  with the License.  You may obtain a copy of the License at
9    *
10   *    http://www.apache.org/licenses/LICENSE-2.0
11   *
12   *  Unless required by applicable law or agreed to in writing,
13   *  software distributed under the License is distributed on an
14   *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   *  KIND, either express or implied.  See the License for the
16   *  specific language governing permissions and limitations
17   *  under the License.
18   *
19   */package org.apache.mina.filter.ssl;
20  
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.Security;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManagerFactory;

import org.apache.mina.core.filterchain.DefaultIoFilterChainBuilder;
import org.apache.mina.core.service.IoHandlerAdapter;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.filter.codec.ProtocolCodecFilter;
import org.apache.mina.filter.codec.textline.TextLineCodecFactory;
import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
import org.apache.mina.util.AvailablePortFinder;
import org.junit.Test;
45  
46  /**
47   * Test a SSL session where the connection is established and closed twice. It should be
48   * processed correctly (Test for DIRMINA-650)
49   *
50   * @author <a href="http://mina.apache.org">Apache MINA Project</a>
51   */
52  public class SslTest {
53      /** A static port used for his test, chosen to avoid collisions */
54      private static final int port = AvailablePortFinder.getNextAvailable(5555);
55  
56      private static Exception clientError = null;
57  
58      private static InetAddress address;
59  
60      private static SSLSocketFactory factory;
61      
62      private static NioSocketAcceptor acceptor;
63  
64      /** A JVM independant KEY_MANAGER_FACTORY algorithm */
65      private static final String KEY_MANAGER_FACTORY_ALGORITHM;
66  
67      static {
68          String algorithm = Security.getProperty("ssl.KeyManagerFactory.algorithm");
69          if (algorithm == null) {
70              algorithm = KeyManagerFactory.getDefaultAlgorithm();
71          }
72  
73          KEY_MANAGER_FACTORY_ALGORITHM = algorithm;
74      }
75  
76      private static class TestHandler extends IoHandlerAdapter {
77          public void messageReceived(IoSession session, Object message) throws Exception {
78              String line = (String) message;
79  
80              if (line.startsWith("hello")) {
81                  //System.out.println("Server got: 'hello', waiting for 'send'");
82                  Thread.sleep(1500);
83              } else if (line.startsWith("send")) {
84                  //System.out.println("Server got: 'send', sending 'data'");
85                  StringBuilder sb = new StringBuilder();
86                  
87                  for ( int i = 0; i < 10000; i++) {
88                      sb.append('A');
89                  }
90                      
91                  session.write(sb.toString());
92                  session.closeOnFlush();
93              }
94          }
95      }
96  
97      /**
98       * Starts a Server with the SSL Filter and a simple text line 
99       * protocol codec filter
100      */
101     private static void startServer() throws Exception {
102         acceptor = new NioSocketAcceptor();
103 
104         acceptor.setReuseAddress(true);
105         DefaultIoFilterChainBuilder filters = acceptor.getFilterChain();
106 
107         // Inject the SSL filter
108         SslFilter sslFilter = new SslFilter(createSSLContext());
109         filters.addLast("sslFilter", sslFilter);
110         sslFilter.setNeedClientAuth(true);
111 
112         // Inject the TestLine codec filter
113         filters.addLast("text", new ProtocolCodecFilter(new TextLineCodecFactory()));
114 
115         acceptor.setHandler(new TestHandler());
116         acceptor.bind(new InetSocketAddress(port));
117     }
118     
119     private static void stopServer() {
120         acceptor.dispose();
121     }
122 
123     /**
124      * Starts a client which will connect twice using SSL
125      */
126     private static void startClient() throws Exception {
127         address = InetAddress.getByName("localhost");
128 
129         SSLContext context = createSSLContext();
130         factory = context.getSocketFactory();
131 
132         connectAndSend();
133 
134         // This one will throw a SocketTimeoutException if DIRMINA-650 is not fixed
135         connectAndSend();
136     }
137 
138     private static void connectAndSend() throws Exception {
139         Socket parent = new Socket(address, port);
140         Socket socket = factory.createSocket(parent, address.getCanonicalHostName(), port, false);
141 
142         //System.out.println("Client sending: hello");
143         socket.getOutputStream().write("hello                      \n".getBytes());
144         socket.getOutputStream().flush();
145         socket.setSoTimeout(1000000);
146 
147         //System.out.println("Client sending: send");
148         socket.getOutputStream().write("send\n".getBytes());
149         socket.getOutputStream().flush();
150 
151         BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
152         String line = in.readLine();
153         //System.out.println("Client got: " + line);
154         socket.close();
155     }
156 
157     private static SSLContext createSSLContext() throws IOException, GeneralSecurityException {
158         char[] passphrase = "password".toCharArray();
159 
160         SSLContext ctx = SSLContext.getInstance("TLS");
161         KeyManagerFactory kmf = KeyManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
162         TrustManagerFactory tmf = TrustManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
163 
164         KeyStore ks = KeyStore.getInstance("JKS");
165         KeyStore ts = KeyStore.getInstance("JKS");
166 
167         ks.load(SslTest.class.getResourceAsStream("keystore.sslTest"), passphrase);
168         ts.load(SslTest.class.getResourceAsStream("truststore.sslTest"), passphrase);
169 
170         kmf.init(ks, passphrase);
171         tmf.init(ts);
172         ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
173 
174         return ctx;
175     }
176 
177     @Test
178     public void testSSL() throws Exception {
179         try {
180             startServer();
181     
182             Thread t = new Thread() {
183                 public void run() {
184                     try {
185                         startClient();
186                     } catch (Exception e) {
187                         clientError = e;
188                     }
189                 }
190             };
191             t.start();
192             t.join();
193             
194             if (clientError != null) {
195                 throw clientError;
196             }
197         } finally {
198             stopServer();
199         }
200     }
201     
202     
203     @Test
204     public void unsecureClientTryToConnectoToSecureServer() throws Exception {
205         try {
206             startServer(); // Start Server with SSLFilter
207     
208             //Now start a client without any SSL
209             Thread t = new Thread() {
210                 @Override
211                 public void run() {
212                     try {
213                         address = InetAddress.getByName("localhost");
214     
215                         Socket socket = new Socket(address, port);
216                         socket.setSoTimeout(10000);
217     
218                         String response = null;
219     
220                         while (response == null) {
221                             try {
222                                 System.out.println(socket.isConnected());
223                                 // System.out.println("Client sending: hello");
224                                 socket.getOutputStream().write("hello                      \n".getBytes());
225                                 socket.getOutputStream().flush();
226                                 socket.setSoTimeout(1000);
227     
228                                 // System.out.println("Client sending: send");
229                                 socket.getOutputStream().write("send\n".getBytes());
230                                 socket.getOutputStream().flush();
231     
232                                 BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
233                                 String line = "";
234                                 
235                                 while ((line = in.readLine()) != null) {
236                                     response = response + line;
237                                 }
238                             } catch (SocketTimeoutException timeout) {
239                                 // donothing
240                                 timeout.printStackTrace();
241                             }
242                         }
243                         
244                         if (response.contains("AAAAAAA")){
245                             throw new IllegalStateException("getting response:" + response);
246                         }
247                         
248                         // System.out.println("Client got: " + line);
249                         socket.close();
250                     } catch (Exception e) {
251                         clientError = e;
252                     }
253                 }
254             };
255             
256             t.start();
257             t.join();
258             
259             if (clientError != null) {
260                 throw clientError;
261             }
262         } finally {
263             stopServer();
264         }
265     }
266 }