001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.wicket.protocol.ws;
018
019import java.io.IOException;
020import java.util.ArrayList;
021import java.util.Enumeration;
022import java.util.List;
023
024import javax.servlet.FilterChain;
025import javax.servlet.ServletException;
026import javax.servlet.http.HttpServletRequest;
027import javax.servlet.http.HttpServletResponse;
028
029import org.apache.wicket.ThreadContext;
030import org.apache.wicket.protocol.http.WebApplication;
031import org.apache.wicket.protocol.http.WicketFilter;
032import org.apache.wicket.request.cycle.RequestCycle;
033import org.apache.wicket.request.http.WebResponse;
034import org.apache.wicket.util.string.Strings;
035
036/**
037 * An extension of WicketFilter that is used to check whether
038 * the processed HttpServletRequest needs to upgrade its protocol
039 * from HTTP to something else
040 *
041 * @since 6.0
042 */
043public class AbstractUpgradeFilter extends WicketFilter
044{
045        public AbstractUpgradeFilter()
046        {
047                super();
048        }
049
050        public AbstractUpgradeFilter(WebApplication application)
051        {
052                super(application);
053        }
054
055        @Override
056        protected boolean processRequestCycle(final RequestCycle requestCycle, final WebResponse webResponse,
057                        final HttpServletRequest httpServletRequest, final HttpServletResponse httpServletResponse,
058                        final FilterChain chain)
059                throws IOException, ServletException
060        {
061                ThreadContext.setRequestCycle(requestCycle);
062                if (acceptWebSocket(httpServletRequest, httpServletResponse) || httpServletResponse.isCommitted())
063                {
064                        return true;
065                }
066                
067                return super.processRequestCycle(requestCycle, webResponse, httpServletRequest, httpServletResponse, chain);
068        }
069
070        protected boolean acceptWebSocket(HttpServletRequest req, HttpServletResponse resp)
071                        throws ServletException, IOException
072        {
073                // Information required to send the server handshake message
074                String key;
075                String subProtocol = null;
076
077                if (!headerContainsToken(req, "Upgrade", "websocket"))
078                {
079                        return false;
080                }
081
082                if (!headerContainsToken(req, "Connection", "upgrade"))
083                {
084                        resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
085                        return false;
086                }
087
088                if (!headerContainsToken(req, "Sec-websocket-version", "13"))
089                {
090                        resp.setStatus(HttpServletResponse.SC_BAD_REQUEST); // http://tools.ietf.org/html/rfc6455#section-4.4
091                        resp.setHeader("Sec-WebSocket-Version", "13");
092                        return false;
093                }
094
095                key = req.getHeader("Sec-WebSocket-Key");
096                if (key == null)
097                {
098                        resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
099                        return false;
100                }
101
102                String origin = req.getHeader("Origin");
103                if (!verifyOrigin(origin))
104                {
105                        resp.sendError(HttpServletResponse.SC_FORBIDDEN);
106                        return false;
107                }
108
109                List<String> subProtocols = getTokensFromHeader(req, "Sec-WebSocket-Protocol-Client");
110                if (!subProtocols.isEmpty())
111                {
112                        subProtocol = selectSubProtocol(subProtocols);
113                }
114
115                if (subProtocol != null)
116                {
117                        resp.setHeader("Sec-WebSocket-Protocol", subProtocol);
118                }
119
120                return true;
121        }
122
123        /*
124         * This only works for tokens. Quoted strings need more sophisticated
125         * parsing.
126         */
127        private boolean headerContainsToken(HttpServletRequest req, String headerName, String target)
128        {
129                Enumeration<String> headers = req.getHeaders(headerName);
130                while (headers.hasMoreElements()) {
131                        String header = headers.nextElement();
132                        String[] tokens = Strings.split(header, ',');
133                        for (String token : tokens) {
134                                if (target.equalsIgnoreCase(token.trim())) {
135                                        return true;
136                                }
137                        }
138                }
139                return false;
140        }
141
142        /*
143         * This only works for tokens. Quoted strings need more sophisticated
144         * parsing.
145         */
146        protected List<String> getTokensFromHeader(HttpServletRequest req, String headerName)
147        {
148                List<String> result = new ArrayList<>();
149
150                Enumeration<String> headers = req.getHeaders(headerName);
151                while (headers.hasMoreElements()) {
152                        String header = headers.nextElement();
153                        String[] tokens = Strings.split(header, ',');
154                        for (String token : tokens) {
155                                result.add(token.trim());
156                        }
157                }
158                return result;
159        }
160
161        /**
162         * Intended to be overridden by sub-classes that wish to verify the origin
163         * of a WebSocket request before processing it.
164         *
165         * @param origin    The value of the origin header from the request which
166         *                  may be <code>null</code>
167         *
168         * @return  <code>true</code> to accept the request. <code>false</code> to
169         *          reject it. This default implementation always returns
170         *          <code>true</code>.
171         */
172        protected boolean verifyOrigin(String origin)
173        {
174                return true;
175        }
176
177        /**
178         * Intended to be overridden by sub-classes that wish to select a
179         * sub-protocol if the client provides a list of supported protocols.
180         *
181         * @param subProtocols  The list of sub-protocols supported by the client
182         *                      in client preference order. The server is under no
183         *                      obligation to respect the declared preference
184         * @return  <code>null</code> if no sub-protocol is selected or the name of
185         *          the protocol which <b>must</b> be one of the protocols listed by
186         *          the client. This default implementation always returns
187         *          <code>null</code>.
188         */
189        protected String selectSubProtocol(List<String> subProtocols)
190        {
191                return null;
192        }
193}