TLSClientHelloExtractor.java

/*
 *  Licensed to the Apache Software Foundation (ASF) under one or more
 *  contributor license agreements.  See the NOTICE file distributed with
 *  this work for additional information regarding copyright ownership.
 *  The ASF licenses this file to You under the Apache License, Version 2.0
 *  (the "License"); you may not use this file except in compliance with
 *  the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package org.apache.tomcat.util.net;

import java.io.IOException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.util.buf.HexUtils;
import org.apache.tomcat.util.http.parser.HttpParser;
import org.apache.tomcat.util.net.openssl.ciphers.Cipher;
import org.apache.tomcat.util.res.StringManager;

/**
 * This class extracts the SNI host name and ALPN protocols from a TLS
 * client-hello message.
 */
public class TLSClientHelloExtractor {

    private static final Log log = LogFactory.getLog(TLSClientHelloExtractor.class);
    private static final StringManager sm = StringManager.getManager(TLSClientHelloExtractor.class);

    private final ExtractorResult result;
    private final List<Cipher> clientRequestedCiphers;
    private final List<String> clientRequestedCipherNames;
    private final String sniValue;
    private final List<String> clientRequestedApplicationProtocols;
    private final List<String> clientRequestedProtocols;

    private static final int TLS_RECORD_HEADER_LEN = 5;

    private static final int TLS_EXTENSION_SERVER_NAME = 0;
    private static final int TLS_EXTENSION_ALPN = 16;
    private static final int TLS_EXTENSION_SUPPORTED_VERSION = 43;

    public static byte[] USE_TLS_RESPONSE = ("HTTP/1.1 400 \r\n" +
            "Content-Type: text/plain;charset=UTF-8\r\n" +
            "Connection: close\r\n" +
            "\r\n" +
            "Bad Request\r\n" +
            "This combination of host and port requires TLS.\r\n").getBytes(StandardCharsets.UTF_8);


    /**
     * Creates the instance of the parser and processes the provided buffer. The
     * buffer position and limit will be modified during the execution of this
     * method but they will be returned to the original values before the method
     * exits.
     *
     * @param netInBuffer The buffer containing the TLS data to process
     * @throws IOException If the client hello message is malformed
     */
    public TLSClientHelloExtractor(ByteBuffer netInBuffer) throws IOException {
        // Buffer is in write mode at this point. Record the current position so
        // the buffer state can be restored at the end of this method.
        int pos = netInBuffer.position();
        int limit = netInBuffer.limit();
        ExtractorResult result = ExtractorResult.NOT_PRESENT;
        List<Cipher> clientRequestedCiphers = new ArrayList<>();
        List<String> clientRequestedCipherNames = new ArrayList<>();
        List<String> clientRequestedApplicationProtocols = new ArrayList<>();
        List<String> clientRequestedProtocols = new ArrayList<>();
        String sniValue = null;
        try {
            // Switch to read mode.
            netInBuffer.flip();

            // A complete TLS record header is required before we can figure out
            // how many bytes there are in the record.
            if (!isAvailable(netInBuffer, TLS_RECORD_HEADER_LEN)) {
                result = handleIncompleteRead(netInBuffer);
                return;
            }

            if (!isTLSHandshake(netInBuffer)) {
                // Is the client trying to use clear text HTTP?
                if (isHttp(netInBuffer)) {
                    result = ExtractorResult.NON_SECURE;
                }
                return;
            }

            if (!isAllRecordAvailable(netInBuffer)) {
                result = handleIncompleteRead(netInBuffer);
                return;
            }

            if (!isClientHello(netInBuffer)) {
                return;
            }

            if (!isAllClientHelloAvailable(netInBuffer)) {
                // Client hello didn't fit into single TLS record.
                // Treat this as not present.
                log.warn(sm.getString("sniExtractor.clientHelloTooBig"));
                return;
            }

            // Protocol Version
            String legacyVersion = readProtocol(netInBuffer);
            // Random
            skipBytes(netInBuffer, 32);
            // Session ID (single byte for length)
            skipBytes(netInBuffer, (netInBuffer.get() & 0xFF));

            // Cipher Suites
            // (2 bytes for length, each cipher ID is 2 bytes)
            int cipherCount = netInBuffer.getChar() / 2;
            for (int i = 0; i < cipherCount; i++) {
                char cipherId = netInBuffer.getChar();
                Cipher c = Cipher.valueOf(cipherId);
                // Some clients transmit grease values (see RFC 8701)
                if (c == null) {
                    clientRequestedCipherNames.add("Unknown(0x" + HexUtils.toHexString(cipherId) + ")");
                } else {
                    clientRequestedCiphers.add(c);
                    clientRequestedCipherNames.add(c.name());
                }
            }

            // Compression methods (single byte for length)
            skipBytes(netInBuffer, (netInBuffer.get() & 0xFF));

            if (!netInBuffer.hasRemaining()) {
                // No more data means no extensions present
                return;
            }

            // Extension length
            skipBytes(netInBuffer, 2);
            // Read the extensions until we run out of data or find the data
            // we need
            while (netInBuffer.hasRemaining() && (sniValue == null ||
                    clientRequestedApplicationProtocols.isEmpty() || clientRequestedProtocols.isEmpty())) {
                // Extension type is two byte
                char extensionType = netInBuffer.getChar();
                // Extension size is another two bytes
                char extensionDataSize = netInBuffer.getChar();
                switch (extensionType) {
                case TLS_EXTENSION_SERVER_NAME: {
                    sniValue = readSniExtension(netInBuffer);
                    break;
                }
                case TLS_EXTENSION_ALPN:
                    readAlpnExtension(netInBuffer, clientRequestedApplicationProtocols);
                    break;
                case TLS_EXTENSION_SUPPORTED_VERSION:
                    readSupportedVersions(netInBuffer, clientRequestedProtocols);
                    break;
                default: {
                    skipBytes(netInBuffer, extensionDataSize);
                }
                }
            }
            if (clientRequestedProtocols.isEmpty()) {
                clientRequestedProtocols.add(legacyVersion);
            }
            result = ExtractorResult.COMPLETE;
        } catch (BufferUnderflowException | IllegalArgumentException e) {
            throw new IOException(sm.getString("sniExtractor.clientHelloInvalid"), e);
        } finally {
            this.result = result;
            this.clientRequestedCiphers = clientRequestedCiphers;
            this.clientRequestedCipherNames = clientRequestedCipherNames;
            this.clientRequestedApplicationProtocols = clientRequestedApplicationProtocols;
            this.sniValue = sniValue;
            this.clientRequestedProtocols = clientRequestedProtocols;
            // Whatever happens, return the buffer to its original state
            netInBuffer.limit(limit);
            netInBuffer.position(pos);
        }
    }


    public ExtractorResult getResult() {
        return result;
    }


    /**
     * @return The SNI value provided by the client converted to lower case if
     *         not already lower case.
     */
    public String getSNIValue() {
        if (result == ExtractorResult.COMPLETE) {
            return sniValue;
        } else {
            throw new IllegalStateException(sm.getString("sniExtractor.tooEarly"));
        }
    }


    public List<Cipher> getClientRequestedCiphers() {
        if (result == ExtractorResult.COMPLETE || result == ExtractorResult.NOT_PRESENT) {
            return clientRequestedCiphers;
        } else {
            throw new IllegalStateException(sm.getString("sniExtractor.tooEarly"));
        }
    }


    public List<String> getClientRequestedCipherNames() {
        if (result == ExtractorResult.COMPLETE || result == ExtractorResult.NOT_PRESENT) {
            return clientRequestedCipherNames;
        } else {
            throw new IllegalStateException(sm.getString("sniExtractor.tooEarly"));
        }
    }


    public List<String> getClientRequestedApplicationProtocols() {
        if (result == ExtractorResult.COMPLETE || result == ExtractorResult.NOT_PRESENT) {
            return clientRequestedApplicationProtocols;
        } else {
            throw new IllegalStateException(sm.getString("sniExtractor.tooEarly"));
        }
    }


    public List<String> getClientRequestedProtocols() {
        if (result == ExtractorResult.COMPLETE || result == ExtractorResult.NOT_PRESENT) {
            return clientRequestedProtocols;
        } else {
            throw new IllegalStateException(sm.getString("sniExtractor.tooEarly"));
        }
    }


    private static ExtractorResult handleIncompleteRead(ByteBuffer bb) {
        if (bb.limit() == bb.capacity()) {
            // Buffer not big enough
            return ExtractorResult.UNDERFLOW;
        } else {
            // Need to read more data into buffer
            return ExtractorResult.NEED_READ;
        }
    }


    private static boolean isAvailable(ByteBuffer bb, int size) {
        if (bb.remaining() < size) {
            bb.position(bb.limit());
            return false;
        }
        return true;
    }


    private static boolean isTLSHandshake(ByteBuffer bb) {
        // For a TLS client hello the first byte must be 22 - handshake
        if (bb.get() != 22) {
            return false;
        }
        // Next two bytes are major/minor version. We need at least 3.1.
        byte b2 = bb.get();
        byte b3 = bb.get();
        if (b2 < 3 || b2 == 3 && b3 == 0) {
            return false;
        }
        return true;
    }


    private static boolean isHttp(ByteBuffer bb) {
        // Based on code in Http11InputBuffer
        // Note: The actual request is not important. This code only checks that
        //       the buffer contains a correctly formatted HTTP request line.
        //       The method, target and protocol are not validated.
        byte chr = 0;
        bb.position(0);

        // Skip blank lines
        do {
            if (!bb.hasRemaining()) {
                return false;
            }
            chr = bb.get();
        } while (chr == '\r' || chr == '\n');

        // Read the method
        do {
            if (!HttpParser.isToken(chr) || !bb.hasRemaining()) {
                return false;
            }
            chr = bb.get();
        } while (chr != ' ' && chr != '\t');

        // Whitespace between method and target
        while (chr == ' ' || chr == '\t') {
            if (!bb.hasRemaining()) {
                return false;
            }
            chr = bb.get();
        }

        // Read the target
        while (chr != ' ' && chr != '\t') {
            if (HttpParser.isNotRequestTarget(chr) || !bb.hasRemaining()) {
                return false;
            }
            chr = bb.get();
        }

        // Whitespace between target and protocol
        while (chr == ' ' || chr == '\t') {
            if (!bb.hasRemaining()) {
                return false;
            }
            chr = bb.get();
        }

        // Read protocol
        do {
            if (!HttpParser.isHttpProtocol(chr) || !bb.hasRemaining()) {
                return false;
            }
            chr = bb.get();

        } while (chr != '\r' && chr != '\n');

        return true;
    }


    private static boolean isAllRecordAvailable(ByteBuffer bb) {
        // Next two bytes (unsigned) are the size of the record. We need all of
        // it.
        int size = bb.getChar();
        return isAvailable(bb, size);
    }


    private static boolean isClientHello(ByteBuffer bb) {
        // Client hello is handshake type 1
        if (bb.get() == 1) {
            return true;
        }
        return false;
    }


    private static boolean isAllClientHelloAvailable(ByteBuffer bb) {
        // Next three bytes (unsigned) are the size of the client hello. We need
        // all of it.
        int size = ((bb.get() & 0xFF) << 16) + ((bb.get() & 0xFF) << 8) + (bb.get() & 0xFF);
        return isAvailable(bb, size);
    }


    private static void skipBytes(ByteBuffer bb, int size) {
        bb.position(bb.position() + size);
    }


    private static String readProtocol(ByteBuffer bb) {
        char protocol = bb.getChar();
        switch (protocol) {
            case 0x0300: {
                return Constants.SSL_PROTO_SSLv3;
            }
            case 0x0301: {
                return Constants.SSL_PROTO_TLSv1_0;
            }
            case 0x0302: {
                return Constants.SSL_PROTO_TLSv1_1;
            }
            case 0x0303: {
                return Constants.SSL_PROTO_TLSv1_2;
            }
            case 0x0304: {
                return Constants.SSL_PROTO_TLSv1_3;
            }
            default:
                return "Unknown(0x" + HexUtils.toHexString(protocol) + ")";
        }
    }


    private static String readSniExtension(ByteBuffer bb) {
        // First 2 bytes are size of server name list (only expecting one)
        // Next byte is type (0 for hostname)
        skipBytes(bb, 3);
        // Next 2 bytes are length of host name
        char serverNameSize = bb.getChar();
        byte[] serverNameBytes = new byte[serverNameSize];
        bb.get(serverNameBytes);
        return new String(serverNameBytes, StandardCharsets.UTF_8).toLowerCase(Locale.ENGLISH);
    }


    private static void readAlpnExtension(ByteBuffer bb, List<String> protocolNames) {
        // First 2 bytes are size of the protocol list
        char toRead = bb.getChar();
        byte[] inputBuffer = new byte[255];
        while (toRead > 0) {
            // Each list entry has one byte for length followed by a string of
            // that length
            int len = bb.get() & 0xFF;
            bb.get(inputBuffer, 0, len);
            protocolNames.add(new String(inputBuffer, 0, len, StandardCharsets.UTF_8));
            toRead--;
            toRead -= len;
        }
    }


    private static void readSupportedVersions(ByteBuffer bb, List<String> protocolNames) {
        // First byte is the size of the list in bytes
        int count = (bb.get() & 0xFF) / 2;
        // Then the list of protocols
        for (int i = 0; i < count; i++) {
            protocolNames.add(readProtocol(bb));
        }
    }


    public enum ExtractorResult {
        COMPLETE,
        NOT_PRESENT,
        UNDERFLOW,
        NEED_READ,
        NON_SECURE
    }
}