AsyncChannelWrapperSecure.java

/*
 *  Licensed to the Apache Software Foundation (ASF) under one or more
 *  contributor license agreements.  See the NOTICE file distributed with
 *  this work for additional information regarding copyright ownership.
 *  The ASF licenses this file to You under the Apache License, Version 2.0
 *  (the "License"); you may not use this file except in compliance with
 *  the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package org.apache.tomcat.websocket;

import java.io.EOFException;
import java.io.IOException;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;

import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.util.res.StringManager;

/**
 * Wraps the {@link AsynchronousSocketChannel} with SSL/TLS. This needs a lot more testing before it can be considered
 * robust.
 */
public class AsyncChannelWrapperSecure implements AsyncChannelWrapper {

    // Logger is held per instance rather than in a static field
    private final Log log = LogFactory.getLog(AsyncChannelWrapperSecure.class);
    private static final StringManager sm = StringManager.getManager(AsyncChannelWrapperSecure.class);

    // Placeholder buffer handed to wrap()/unwrap() during the handshake. The handshake code rejects any result that
    // consumed bytes from it (wrap) or produced bytes into it (unwrap), so its content is never meaningful.
    // NOTE(review): the capacity 16921 is presumably derived from the maximum TLS record size - confirm.
    private static final ByteBuffer DUMMY = ByteBuffer.allocate(16921);
    // Channel that carries the encrypted bytes
    private final AsynchronousSocketChannel socketChannel;
    // Engine that performs the handshake, encryption and decryption
    private final SSLEngine sslEngine;
    // Encrypted bytes received from the channel and not yet decrypted
    private final ByteBuffer socketReadBuffer;
    // Encrypted bytes produced by the engine and waiting to be written to the channel
    private final ByteBuffer socketWriteBuffer;
    // One thread for read, one for write
    private final ExecutorService executor = Executors.newFixedThreadPool(2, new SecureIOThreadFactory());
    // Guards that allow at most one write and at most one read to be in progress. The references are never
    // reassigned, only the contained flags change.
    private final AtomicBoolean writing = new AtomicBoolean(false);
    private final AtomicBoolean reading = new AtomicBoolean(false);

    /**
     * Creates a wrapper that encrypts and decrypts the traffic of the supplied channel with the supplied engine.
     * The two network buffers are allocated as direct buffers whose capacity is the packet buffer size reported
     * by the engine's session.
     *
     * @param socketChannel the channel used for the encrypted network I/O
     * @param sslEngine     the engine used for the handshake, encryption and decryption
     */
    public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, SSLEngine sslEngine) {
        this.socketChannel = socketChannel;
        this.sslEngine = sslEngine;

        final int packetSize = sslEngine.getSession().getPacketBufferSize();
        this.socketReadBuffer = ByteBuffer.allocateDirect(packetSize);
        this.socketWriteBuffer = ByteBuffer.allocateDirect(packetSize);
    }

    /**
     * Schedules a decrypting read into the supplied buffer on the internal executor.
     *
     * @param dst the buffer that receives the decrypted data
     *
     * @return a future that completes with the number of decrypted bytes, or whose {@code get()} throws an
     *             {@link ExecutionException} wrapping the failure
     *
     * @throws IllegalStateException if a read is already in progress
     */
    @Override
    public Future<Integer> read(ByteBuffer dst) {
        WrapperFuture<Integer, Void> pending = new WrapperFuture<>();

        // Claim the single read slot; it is released by the task once it finishes
        boolean claimed = reading.compareAndSet(false, true);
        if (!claimed) {
            throw new IllegalStateException(sm.getString("asyncChannelWrapperSecure.concurrentRead"));
        }

        executor.execute(new ReadTask(dst, pending));

        return pending;
    }

    /**
     * Schedules a decrypting read into the supplied buffer on the internal executor and reports the outcome to the
     * supplied handler. The handler is invoked from the thread that runs the read task.
     *
     * @param dst        the buffer that receives the decrypted data
     * @param attachment the object passed back to the handler
     * @param handler    notified with the number of decrypted bytes on success or with the failure cause
     *
     * @throws IllegalStateException if a read is already in progress
     */
    @Override
    public <B, A extends B> void read(ByteBuffer dst, A attachment, CompletionHandler<Integer, B> handler) {

        WrapperFuture<Integer, B> pending = new WrapperFuture<>(handler, attachment);

        // Claim the single read slot; it is released by the task once it finishes
        boolean claimed = reading.compareAndSet(false, true);
        if (!claimed) {
            throw new IllegalStateException(sm.getString("asyncChannelWrapperSecure.concurrentRead"));
        }

        executor.execute(new ReadTask(dst, pending));
    }

    /**
     * Schedules an encrypting write of the supplied buffer on the internal executor.
     *
     * @param src the buffer holding the plain text data to send
     *
     * @return a future that completes with the number of plain text bytes consumed from {@code src}
     *
     * @throws IllegalStateException if a write is already in progress
     */
    @Override
    public Future<Integer> write(ByteBuffer src) {

        // The write task counts bytes as a long; the result is narrowed for the caller below
        WrapperFuture<Long, Void> longResult = new WrapperFuture<>();

        // Claim the single write slot; it is released by the task once it finishes
        boolean claimed = writing.compareAndSet(false, true);
        if (!claimed) {
            throw new IllegalStateException(sm.getString("asyncChannelWrapperSecure.concurrentWrite"));
        }

        ByteBuffer[] single = { src };
        executor.execute(new WriteTask(single, 0, 1, longResult));

        return new LongToIntegerFuture(longResult);
    }

    /**
     * Schedules an encrypting gathering write on the internal executor and reports the outcome to the supplied
     * handler. The handler is invoked from the thread that runs the write task.
     *
     * @param srcs       the buffers holding the plain text data to send
     * @param offset     index of the first buffer in {@code srcs} to write
     * @param length     number of buffers to write, starting at {@code offset}
     * @param timeout    not referenced by this implementation
     * @param unit       not referenced by this implementation
     * @param attachment the object passed back to the handler
     * @param handler    notified with the number of plain text bytes consumed on success or with the failure cause
     *
     * @throws IllegalStateException if a write is already in progress
     */
    @Override
    public <B, A extends B> void write(ByteBuffer[] srcs, int offset, int length, long timeout, TimeUnit unit,
            A attachment, CompletionHandler<Long, B> handler) {

        WrapperFuture<Long, B> pending = new WrapperFuture<>(handler, attachment);

        // Claim the single write slot; it is released by the task once it finishes
        boolean claimed = writing.compareAndSet(false, true);
        if (!claimed) {
            throw new IllegalStateException(sm.getString("asyncChannelWrapperSecure.concurrentWrite"));
        }

        executor.execute(new WriteTask(srcs, offset, length, pending));
    }

    /**
     * Closes the underlying channel and stops the I/O executor. A failure to close the channel is logged and does
     * not prevent the executor from being shut down.
     */
    @Override
    public void close() {
        try {
            socketChannel.close();
        } catch (IOException e) {
            // Best effort close: keep going, but retain the cause so the failure can be diagnosed
            log.info(sm.getString("asyncChannelWrapperSecure.closeFail"), e);
        }
        executor.shutdownNow();
    }

    /**
     * Starts the TLS handshake on a newly created, dedicated thread.
     *
     * @return a future that completes with {@code null} once the handshake has finished, or whose {@code get()}
     *             throws an {@link java.util.concurrent.ExecutionException} wrapping the handshake failure
     */
    @Override
    public Future<Void> handshake() throws SSLException {

        WrapperFuture<Void, Void> handshakeFuture = new WrapperFuture<>();

        new WebSocketSslHandshakeThread(handshakeFuture).start();

        return handshakeFuture;
    }


    /**
     * Obtains the local address by delegating to the wrapped channel.
     *
     * @return the value returned by the wrapped channel's {@code getLocalAddress()}
     *
     * @throws IOException if the wrapped channel throws it
     */
    @Override
    public SocketAddress getLocalAddress() throws IOException {
        SocketAddress local = socketChannel.getLocalAddress();
        return local;
    }


    /**
     * Task that encrypts the selected source buffers and writes the resulting bytes to the socket channel. On
     * completion the write slot is released and the future is completed with the total number of source bytes
     * consumed; any exception releases the write slot and fails the future.
     */
    private class WriteTask implements Runnable {

        private final ByteBuffer[] srcs;
        private final int offset;
        private final int length;
        private final WrapperFuture<Long, ?> future;

        WriteTask(ByteBuffer[] srcs, int offset, int length, WrapperFuture<Long, ?> future) {
            this.srcs = srcs;
            this.offset = offset;
            this.length = length;
            this.future = future;
        }

        @Override
        public void run() {
            long consumedTotal = 0;

            try {
                for (int index = offset; index < offset + length; index++) {
                    consumedTotal += encryptAndSend(srcs[index]);
                }

                // Release the write slot. If it was not held, something else changed the state unexpectedly.
                boolean released = writing.compareAndSet(true, false);
                if (released) {
                    future.complete(Long.valueOf(consumedTotal));
                } else {
                    future.fail(new IllegalStateException(sm.getString("asyncChannelWrapperSecure.wrongStateWrite")));
                }
            } catch (Exception e) {
                writing.set(false);
                future.fail(e);
            }
        }

        /*
         * Repeatedly encrypts data from the given buffer into the socket write buffer and blocks until each batch
         * of encrypted bytes has been written to the channel. Returns the number of bytes consumed from the buffer.
         */
        private long encryptAndSend(ByteBuffer src) throws SSLException, InterruptedException, ExecutionException {
            long consumed = 0;

            while (src.hasRemaining()) {
                socketWriteBuffer.clear();

                // Encrypt the data
                SSLEngineResult result = sslEngine.wrap(src, socketWriteBuffer);
                consumed += result.bytesConsumed();

                // OK and BUFFER_OVERFLOW both mean: write out what was produced and go round again if the
                // source still has data. BUFFER_UNDERFLOW only happens on unwrap and CLOSED is unexpected.
                Status status = result.getStatus();
                if (status != Status.OK && status != Status.BUFFER_OVERFLOW) {
                    throw new IllegalStateException(sm.getString("asyncChannelWrapperSecure.statusWrap"));
                }

                // Run any tasks the engine has delegated
                if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                    for (Runnable task = sslEngine.getDelegatedTask(); task != null; task =
                            sslEngine.getDelegatedTask()) {
                        task.run();
                    }
                }

                socketWriteBuffer.flip();

                // Block until every byte produced by this wrap call has been written
                int pending = result.bytesProduced();
                while (pending > 0) {
                    Future<Integer> socketWrite = socketChannel.write(socketWriteBuffer);
                    pending -= socketWrite.get().intValue();
                }
            }

            return consumed;
        }
    }


    /**
     * Task that decrypts data into the destination buffer, reading more encrypted bytes from the socket channel
     * whenever the data already buffered is not enough to produce any output. The task loops until at least one
     * decrypted byte has been produced. On completion the read slot is released and the future is completed with
     * the number of decrypted bytes; any exception releases the read slot and fails the future exactly once.
     */
    private class ReadTask implements Runnable {

        private final ByteBuffer dest;
        private final WrapperFuture<Integer, ?> future;

        ReadTask(ByteBuffer dest, WrapperFuture<Integer, ?> future) {
            this.dest = dest;
            this.future = future;
        }

        @Override
        public void run() {
            int read = 0;

            // Set when the buffered data is insufficient and the socket must be read on the next iteration
            boolean forceRead = false;

            try {
                while (read == 0) {
                    // Switch the buffer to "fill" mode, keeping any bytes not yet decrypted
                    socketReadBuffer.compact();

                    if (forceRead) {
                        forceRead = false;
                        Future<Integer> f = socketChannel.read(socketReadBuffer);
                        Integer socketRead = f.get();
                        if (socketRead.intValue() == -1) {
                            throw new EOFException(sm.getString("asyncChannelWrapperSecure.eof"));
                        }
                    }

                    // Switch the buffer back to "drain" mode for the engine
                    socketReadBuffer.flip();

                    if (socketReadBuffer.hasRemaining()) {
                        // Decrypt the data in the buffer
                        SSLEngineResult r = sslEngine.unwrap(socketReadBuffer, dest);
                        read += r.bytesProduced();
                        Status s = r.getStatus();

                        if (s == Status.OK) {
                            // Bytes available for reading and there may be
                            // sufficient data in the socketReadBuffer to
                            // support further reads without reading from the
                            // socket
                        } else if (s == Status.BUFFER_UNDERFLOW) {
                            // There is partial data in the socketReadBuffer
                            if (read == 0) {
                                // Need more data before the partial data can be
                                // processed and some output generated
                                forceRead = true;
                            }
                            // else return the data we have and deal with the
                            // partial data on the next read
                        } else if (s == Status.BUFFER_OVERFLOW) {
                            // Not enough space in the destination buffer to
                            // store all of the data. We could use a bytes read
                            // value of -bufferSizeRequired to signal the new
                            // buffer size required but an explicit exception is
                            // clearer.
                            if (reading.compareAndSet(true, false)) {
                                throw new ReadBufferOverflowException(
                                        sslEngine.getSession().getApplicationBufferSize());
                            }
                            // The read slot was not held. Throw rather than failing the future here: failing it
                            // in place would leave the loop running and could report the failure repeatedly.
                            // The catch block below fails the future once.
                            throw new IllegalStateException(sm.getString("asyncChannelWrapperSecure.wrongStateRead"));
                        } else {
                            // Status.CLOSED - unexpected
                            throw new IllegalStateException(sm.getString("asyncChannelWrapperSecure.statusUnwrap"));
                        }

                        // Check for tasks
                        if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                            Runnable runnable = sslEngine.getDelegatedTask();
                            while (runnable != null) {
                                runnable.run();
                                runnable = sslEngine.getDelegatedTask();
                            }
                        }
                    } else {
                        // Nothing buffered at all - must read from the socket
                        forceRead = true;
                    }
                }


                // Release the read slot. If it was not held, something else changed the state unexpectedly.
                if (reading.compareAndSet(true, false)) {
                    future.complete(Integer.valueOf(read));
                } else {
                    future.fail(new IllegalStateException(sm.getString("asyncChannelWrapperSecure.wrongStateRead")));
                }
            } catch (RuntimeException | ReadBufferOverflowException | SSLException | EOFException | ExecutionException
                    | InterruptedException e) {
                reading.set(false);
                future.fail(e);
            }
        }
    }


    /**
     * Thread that drives the TLS handshake to completion using blocking waits on the channel's asynchronous
     * operations. The supplied future is completed with {@code null} when the engine reports the handshake as
     * finished and is failed with the causing exception otherwise.
     */
    private class WebSocketSslHandshakeThread extends Thread {

        private final WrapperFuture<Void, Void> hFuture;

        // Both are refreshed from every wrap/unwrap result by checkResult()
        private HandshakeStatus handshakeStatus;
        private Status resultStatus;

        WebSocketSslHandshakeThread(WrapperFuture<Void, Void> hFuture) {
            this.hFuture = hFuture;
        }

        @Override
        public void run() {
            try {
                sslEngine.beginHandshake();
                // So the first compact does the right thing
                socketReadBuffer.position(socketReadBuffer.limit());

                handshakeStatus = sslEngine.getHandshakeStatus();
                resultStatus = Status.OK;

                boolean handshaking = true;

                while (handshaking) {
                    switch (handshakeStatus) {
                        case NEED_WRAP: {
                            socketWriteBuffer.clear();
                            SSLEngineResult r = sslEngine.wrap(DUMMY, socketWriteBuffer);
                            checkResult(r, true);
                            socketWriteBuffer.flip();
                            // A single write is not guaranteed to drain the buffer, so keep writing until all
                            // of the handshake data produced by the engine has been sent
                            while (socketWriteBuffer.hasRemaining()) {
                                Future<Integer> fWrite = socketChannel.write(socketWriteBuffer);
                                fWrite.get();
                            }
                            break;
                        }
                        case NEED_UNWRAP: {
                            socketReadBuffer.compact();
                            // Only read from the socket when nothing is buffered or the engine reported that
                            // the buffered data is incomplete
                            if (socketReadBuffer.position() == 0 || resultStatus == Status.BUFFER_UNDERFLOW) {
                                Future<Integer> fRead = socketChannel.read(socketReadBuffer);
                                Integer socketRead = fRead.get();
                                if (socketRead.intValue() == -1) {
                                    // The peer closed the connection mid-handshake. No more data will arrive
                                    // so fail now rather than looping on reads that return immediately.
                                    throw new EOFException(sm.getString("asyncChannelWrapperSecure.eof"));
                                }
                            }
                            socketReadBuffer.flip();
                            SSLEngineResult r = sslEngine.unwrap(socketReadBuffer, DUMMY);
                            checkResult(r, false);
                            break;
                        }
                        case NEED_TASK: {
                            Runnable r = null;
                            while ((r = sslEngine.getDelegatedTask()) != null) {
                                r.run();
                            }
                            handshakeStatus = sslEngine.getHandshakeStatus();
                            break;
                        }
                        case FINISHED: {
                            handshaking = false;
                            break;
                        }
                        case NOT_HANDSHAKING: {
                            // Don't expect to see this during a handshake
                            throw new SSLException(sm.getString("asyncChannelWrapperSecure.notHandshaking"));
                        }
                    }
                }
            } catch (Exception e) {
                hFuture.fail(e);
                return;
            }

            hFuture.complete(null);
        }

        /**
         * Records the handshake and result status of the given engine result and validates it.
         *
         * @param result the result of the wrap or unwrap call
         * @param wrap   {@code true} if the result came from a wrap call, {@code false} for unwrap
         *
         * @throws SSLException if the status is not OK (BUFFER_UNDERFLOW is additionally accepted for unwrap), if
         *                          a wrap consumed any bytes or if an unwrap produced any bytes
         */
        private void checkResult(SSLEngineResult result, boolean wrap) throws SSLException {

            handshakeStatus = result.getHandshakeStatus();
            resultStatus = result.getStatus();

            if (resultStatus != Status.OK && (wrap || resultStatus != Status.BUFFER_UNDERFLOW)) {
                throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.notOk", resultStatus));
            }
            if (wrap && result.bytesConsumed() != 0) {
                throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.wrap"));
            }
            if (!wrap && result.bytesProduced() != 0) {
                throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.unwrap"));
            }
        }
    }


    /**
     * Minimal {@link Future} that is completed or failed explicitly and that optionally forwards the outcome to a
     * {@link CompletionHandler}. Cancellation is not supported.
     *
     * @param <T> the result type
     * @param <A> the type of the attachment passed to the handler
     */
    private static class WrapperFuture<T, A> implements Future<T> {

        private final CompletionHandler<T, A> handler;
        private final A attachment;

        private volatile T result = null;
        private volatile Throwable throwable = null;
        // Released once, by whichever of complete() or fail() is called first
        private final CountDownLatch completionLatch = new CountDownLatch(1);

        /**
         * Creates a future without a completion handler.
         */
        WrapperFuture() {
            this(null, null);
        }

        /**
         * Creates a future that notifies the given handler of the outcome.
         *
         * @param handler    the handler to notify, or {@code null} for none
         * @param attachment the object passed back to the handler
         */
        WrapperFuture(CompletionHandler<T, A> handler, A attachment) {
            this.handler = handler;
            this.attachment = attachment;
        }

        /**
         * Stores the result, releases any threads blocked in {@code get} and then notifies the handler, if any, on
         * the calling thread.
         *
         * @param result the result of the operation
         */
        public void complete(T result) {
            this.result = result;
            completionLatch.countDown();
            if (handler != null) {
                handler.completed(result, attachment);
            }
        }

        /**
         * Stores the failure, releases any threads blocked in {@code get} and then notifies the handler, if any,
         * on the calling thread.
         *
         * @param t the cause of the failure
         */
        public void fail(Throwable t) {
            throwable = t;
            completionLatch.countDown();
            if (handler != null) {
                handler.failed(t, attachment);
            }
        }

        @Override
        public final boolean cancel(boolean mayInterruptIfRunning) {
            // Could support cancellation by closing the connection
            return false;
        }

        @Override
        public final boolean isCancelled() {
            // Could support cancellation by closing the connection
            return false;
        }

        @Override
        public final boolean isDone() {
            // Done once the latch has been released by complete() or fail()
            return completionLatch.getCount() == 0;
        }

        @Override
        public T get() throws InterruptedException, ExecutionException {
            completionLatch.await();
            if (throwable != null) {
                throw new ExecutionException(throwable);
            }
            return result;
        }

        @Override
        public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
            boolean completed = completionLatch.await(timeout, unit);
            if (!completed) {
                throw new TimeoutException();
            }
            if (throwable != null) {
                throw new ExecutionException(throwable);
            }
            return result;
        }
    }

    /**
     * Adapts a {@code Future<Long>} to a {@code Future<Integer>}. Status and cancellation calls are passed straight
     * through to the wrapped future. Retrieving a result larger than {@link Integer#MAX_VALUE} fails with an
     * {@link ExecutionException}.
     */
    private static final class LongToIntegerFuture implements Future<Integer> {

        private final Future<Long> wrapped;

        LongToIntegerFuture(Future<Long> wrapped) {
            this.wrapped = wrapped;
        }

        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return wrapped.cancel(mayInterruptIfRunning);
        }

        @Override
        public boolean isCancelled() {
            return wrapped.isCancelled();
        }

        @Override
        public boolean isDone() {
            return wrapped.isDone();
        }

        @Override
        public Integer get() throws InterruptedException, ExecutionException {
            return narrow(wrapped.get());
        }

        @Override
        public Integer get(long timeout, TimeUnit unit)
                throws InterruptedException, ExecutionException, TimeoutException {
            return narrow(wrapped.get(timeout, unit));
        }

        /*
         * Converts the wrapped result to an Integer, rejecting values that do not fit.
         */
        private static Integer narrow(Long value) throws ExecutionException {
            if (value.longValue() > Integer.MAX_VALUE) {
                throw new ExecutionException(sm.getString("asyncChannelWrapperSecure.tooBig", value), null);
            }
            return Integer.valueOf(value.intValue());
        }
    }


    /**
     * Creates the daemon threads used for secure I/O. Each thread is named {@code WebSocketClient-SecureIO-N} where
     * N is a counter, private to this factory instance, that starts at 1.
     */
    private static class SecureIOThreadFactory implements ThreadFactory {

        private final AtomicInteger count = new AtomicInteger(0);

        @Override
        public Thread newThread(Runnable r) {
            Thread worker = new Thread(r);
            int sequence = count.incrementAndGet();
            worker.setName("WebSocketClient-SecureIO-" + sequence);
            // No need to set the context class loader. The threads will be
            // cleaned up when the connection is closed.
            worker.setDaemon(true);
            return worker;
        }
    }
}