ApplicationContextFacade.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.catalina.core;


import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Enumeration;
import java.util.EventListener;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import jakarta.servlet.Filter;
import jakarta.servlet.FilterRegistration;
import jakarta.servlet.RequestDispatcher;
import jakarta.servlet.Servlet;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRegistration;
import jakarta.servlet.ServletRegistration.Dynamic;
import jakarta.servlet.SessionCookieConfig;
import jakarta.servlet.SessionTrackingMode;
import jakarta.servlet.descriptor.JspConfigDescriptor;

import org.apache.catalina.Globals;
import org.apache.catalina.security.SecurityUtil;
import org.apache.tomcat.util.ExceptionUtils;


/**
 * Facade object which masks the internal <code>ApplicationContext</code> object from the web application.
 *
 * @author Remy Maucherat
 */
public class ApplicationContextFacade implements ServletContext {

    // ---------------------------------------------------------- Attributes

    /**
     * Parameter types used to look up an {@link ApplicationContext} method by name when it is invoked reflectively
     * via {@link #invokeMethod(ApplicationContext, String, Object[])}. Populated once by the constructor and only read
     * afterwards. A name with no entry yields {@code null} parameter types, which {@link Class#getMethod} treats as a
     * no-argument method.
     */
    private final Map<String,Class<?>[]> classCache;


    /**
     * Resolved {@link Method} objects keyed by method name, so each name is looked up reflectively at most once. The
     * key is the bare name, so only one overload per name can live here; overloads that need explicit parameter types
     * go through {@link #doPrivileged(String, Class[], Object[])}, which does not use this cache.
     */
    private final Map<String,Method> objectCache;


    /**
     * Wrapped application context.
     */
    private final ApplicationContext context;


    // ----------------------------------------------------------- Constructors

    /**
     * Construct a new instance of this class, associated with the specified Context instance.
     *
     * @param context The associated Context instance
     */
    public ApplicationContextFacade(ApplicationContext context) {
        super();
        this.context = context;

        classCache = new HashMap<>();
        objectCache = new ConcurrentHashMap<>();
        initClassCache();
    }


    /**
     * Registers the parameter types of every method that is invoked through the name-only
     * {@link #doPrivileged(String, Object[])} / {@link #invokeMethod(ApplicationContext, String, Object[])} path.
     */
    private void initClassCache() {
        Class<?>[] clazz = new Class[] { String.class };
        classCache.put("getContext", clazz);
        classCache.put("getMimeType", clazz);
        classCache.put("getResourcePaths", clazz);
        classCache.put("getResource", clazz);
        classCache.put("getResourceAsStream", clazz);
        classCache.put("getRequestDispatcher", clazz);
        classCache.put("getNamedDispatcher", clazz);
        classCache.put("setInitParameter", new Class[] { String.class, String.class });
        classCache.put("createServlet", new Class[] { Class.class });
        classCache.put("addServlet", new Class[] { String.class, String.class });
        classCache.put("createFilter", new Class[] { Class.class });
        classCache.put("addFilter", new Class[] { String.class, String.class });
        classCache.put("createListener", new Class[] { Class.class });
        classCache.put("addListener", clazz);
        classCache.put("getFilterRegistration", clazz);
        classCache.put("getServletRegistration", clazz);
        classCache.put("getInitParameter", clazz);
        classCache.put("setAttribute", new Class[] { String.class, Object.class });
        classCache.put("removeAttribute", clazz);
        classCache.put("getRealPath", clazz);
        classCache.put("getAttribute", clazz);
        classCache.put("log", clazz);
        classCache.put("setSessionTrackingModes", new Class[] { Set.class });
        classCache.put("addJspFile", new Class[] { String.class, String.class });
        classCache.put("declareRoles", new Class[] { String[].class });
        classCache.put("setSessionTimeout", new Class[] { int.class });
        classCache.put("setRequestCharacterEncoding", new Class[] { String.class });
        classCache.put("setResponseCharacterEncoding", new Class[] { String.class });
    }


    // ------------------------------------------------- ServletContext Methods
    //
    // Each method below either calls the wrapped context directly or, when package protection is enabled, invokes
    // the same method reflectively inside AccessController.doPrivileged.


    /**
     * {@inheritDoc}
     * <p>
     * If the wrapped context returns an {@link ApplicationContext}, its facade is returned instead so that the
     * internal object is never exposed to the web application.
     */
    @Override
    public ServletContext getContext(String uripath) {
        ServletContext theContext;
        if (SecurityUtil.isPackageProtectionEnabled()) {
            theContext = (ServletContext) doPrivileged("getContext", new Object[] { uripath });
        } else {
            theContext = context.getContext(uripath);
        }
        // instanceof is false for null, so no separate null check is needed
        if (theContext instanceof ApplicationContext) {
            theContext = ((ApplicationContext) theContext).getFacade();
        }
        return theContext;
    }


    @Override
    public int getMajorVersion() {
        return context.getMajorVersion();
    }


    @Override
    public int getMinorVersion() {
        return context.getMinorVersion();
    }


    @Override
    public String getMimeType(String file) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (String) doPrivileged("getMimeType", new Object[] { file });
        } else {
            return context.getMimeType(file);
        }
    }

    @Override
    @SuppressWarnings("unchecked") // doPrivileged() returns the correct type
    public Set<String> getResourcePaths(String path) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (Set<String>) doPrivileged("getResourcePaths", new Object[] { path });
        } else {
            return context.getResourcePaths(path);
        }
    }


    @Override
    public URL getResource(String path) throws MalformedURLException {
        if (Globals.IS_SECURITY_ENABLED) {
            try {
                return (URL) invokeMethod(context, "getResource", new Object[] { path });
            } catch (Throwable t) {
                ExceptionUtils.handleThrowable(t);
                if (t instanceof MalformedURLException) {
                    throw (MalformedURLException) t;
                }
                // Propagate unchecked exceptions exactly as the direct call below would
                if (t instanceof RuntimeException) {
                    throw (RuntimeException) t;
                }
                return null;
            }
        } else {
            return context.getResource(path);
        }
    }


    @Override
    public InputStream getResourceAsStream(String path) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (InputStream) doPrivileged("getResourceAsStream", new Object[] { path });
        } else {
            return context.getResourceAsStream(path);
        }
    }


    @Override
    public RequestDispatcher getRequestDispatcher(final String path) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (RequestDispatcher) doPrivileged("getRequestDispatcher", new Object[] { path });
        } else {
            return context.getRequestDispatcher(path);
        }
    }


    @Override
    public RequestDispatcher getNamedDispatcher(String name) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (RequestDispatcher) doPrivileged("getNamedDispatcher", new Object[] { name });
        } else {
            return context.getNamedDispatcher(name);
        }
    }


    @Override
    public void log(String msg) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            doPrivileged("log", new Object[] { msg });
        } else {
            context.log(msg);
        }
    }


    @Override
    public void log(String message, Throwable throwable) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            doPrivileged("log", new Class[] { String.class, Throwable.class }, new Object[] { message, throwable });
        } else {
            context.log(message, throwable);
        }
    }


    @Override
    public String getRealPath(String path) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (String) doPrivileged("getRealPath", new Object[] { path });
        } else {
            return context.getRealPath(path);
        }
    }


    @Override
    public String getServerInfo() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (String) doPrivileged("getServerInfo", null);
        } else {
            return context.getServerInfo();
        }
    }


    @Override
    public String getInitParameter(String name) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (String) doPrivileged("getInitParameter", new Object[] { name });
        } else {
            return context.getInitParameter(name);
        }
    }


    @Override
    @SuppressWarnings("unchecked") // doPrivileged() returns the correct type
    public Enumeration<String> getInitParameterNames() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (Enumeration<String>) doPrivileged("getInitParameterNames", null);
        } else {
            return context.getInitParameterNames();
        }
    }


    @Override
    public Object getAttribute(String name) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return doPrivileged("getAttribute", new Object[] { name });
        } else {
            return context.getAttribute(name);
        }
    }


    @Override
    @SuppressWarnings("unchecked") // doPrivileged() returns the correct type
    public Enumeration<String> getAttributeNames() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (Enumeration<String>) doPrivileged("getAttributeNames", null);
        } else {
            return context.getAttributeNames();
        }
    }


    @Override
    public void setAttribute(String name, Object object) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            doPrivileged("setAttribute", new Object[] { name, object });
        } else {
            context.setAttribute(name, object);
        }
    }


    @Override
    public void removeAttribute(String name) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            doPrivileged("removeAttribute", new Object[] { name });
        } else {
            context.removeAttribute(name);
        }
    }


    @Override
    public String getServletContextName() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (String) doPrivileged("getServletContextName", null);
        } else {
            return context.getServletContextName();
        }
    }


    @Override
    public String getContextPath() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (String) doPrivileged("getContextPath", null);
        } else {
            return context.getContextPath();
        }
    }


    @Override
    public FilterRegistration.Dynamic addFilter(String filterName, String className) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (FilterRegistration.Dynamic) doPrivileged("addFilter", new Object[] { filterName, className });
        } else {
            return context.addFilter(filterName, className);
        }
    }


    @Override
    public FilterRegistration.Dynamic addFilter(String filterName, Filter filter) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (FilterRegistration.Dynamic) doPrivileged("addFilter", new Class[] { String.class, Filter.class },
                    new Object[] { filterName, filter });
        } else {
            return context.addFilter(filterName, filter);
        }
    }


    @Override
    public FilterRegistration.Dynamic addFilter(String filterName, Class<? extends Filter> filterClass) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (FilterRegistration.Dynamic) doPrivileged("addFilter", new Class[] { String.class, Class.class },
                    new Object[] { filterName, filterClass });
        } else {
            return context.addFilter(filterName, filterClass);
        }
    }

    @Override
    @SuppressWarnings("unchecked") // doPrivileged() returns the correct type
    public <T extends Filter> T createFilter(Class<T> c) throws ServletException {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            try {
                return (T) invokeMethod(context, "createFilter", new Object[] { c });
            } catch (Throwable t) {
                ExceptionUtils.handleThrowable(t);
                if (t instanceof ServletException) {
                    throw (ServletException) t;
                }
                // Propagate unchecked exceptions exactly as the direct call below would
                if (t instanceof RuntimeException) {
                    throw (RuntimeException) t;
                }
                return null;
            }
        } else {
            return context.createFilter(c);
        }
    }


    @Override
    public FilterRegistration getFilterRegistration(String filterName) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (FilterRegistration) doPrivileged("getFilterRegistration", new Object[] { filterName });
        } else {
            return context.getFilterRegistration(filterName);
        }
    }


    @Override
    public ServletRegistration.Dynamic addServlet(String servletName, String className) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (ServletRegistration.Dynamic) doPrivileged("addServlet", new Object[] { servletName, className });
        } else {
            return context.addServlet(servletName, className);
        }
    }


    @Override
    public ServletRegistration.Dynamic addServlet(String servletName, Servlet servlet) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (ServletRegistration.Dynamic) doPrivileged("addServlet", new Class[] { String.class, Servlet.class },
                    new Object[] { servletName, servlet });
        } else {
            return context.addServlet(servletName, servlet);
        }
    }


    @Override
    public ServletRegistration.Dynamic addServlet(String servletName, Class<? extends Servlet> servletClass) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (ServletRegistration.Dynamic) doPrivileged("addServlet", new Class[] { String.class, Class.class },
                    new Object[] { servletName, servletClass });
        } else {
            return context.addServlet(servletName, servletClass);
        }
    }


    @Override
    public Dynamic addJspFile(String jspName, String jspFile) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (ServletRegistration.Dynamic) doPrivileged("addJspFile", new Object[] { jspName, jspFile });
        } else {
            return context.addJspFile(jspName, jspFile);
        }
    }


    @Override
    @SuppressWarnings("unchecked") // doPrivileged() returns the correct type
    public <T extends Servlet> T createServlet(Class<T> c) throws ServletException {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            try {
                return (T) invokeMethod(context, "createServlet", new Object[] { c });
            } catch (Throwable t) {
                ExceptionUtils.handleThrowable(t);
                if (t instanceof ServletException) {
                    throw (ServletException) t;
                }
                // Propagate unchecked exceptions exactly as the direct call below would
                if (t instanceof RuntimeException) {
                    throw (RuntimeException) t;
                }
                return null;
            }
        } else {
            return context.createServlet(c);
        }
    }


    @Override
    public ServletRegistration getServletRegistration(String servletName) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (ServletRegistration) doPrivileged("getServletRegistration", new Object[] { servletName });
        } else {
            return context.getServletRegistration(servletName);
        }
    }


    @Override
    @SuppressWarnings("unchecked") // doPrivileged() returns the correct type
    public Set<SessionTrackingMode> getDefaultSessionTrackingModes() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (Set<SessionTrackingMode>) doPrivileged("getDefaultSessionTrackingModes", null);
        } else {
            return context.getDefaultSessionTrackingModes();
        }
    }

    @Override
    @SuppressWarnings("unchecked") // doPrivileged() returns the correct type
    public Set<SessionTrackingMode> getEffectiveSessionTrackingModes() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (Set<SessionTrackingMode>) doPrivileged("getEffectiveSessionTrackingModes", null);
        } else {
            return context.getEffectiveSessionTrackingModes();
        }
    }


    @Override
    public SessionCookieConfig getSessionCookieConfig() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (SessionCookieConfig) doPrivileged("getSessionCookieConfig", null);
        } else {
            return context.getSessionCookieConfig();
        }
    }


    @Override
    public void setSessionTrackingModes(Set<SessionTrackingMode> sessionTrackingModes) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            doPrivileged("setSessionTrackingModes", new Object[] { sessionTrackingModes });
        } else {
            context.setSessionTrackingModes(sessionTrackingModes);
        }
    }


    @Override
    public boolean setInitParameter(String name, String value) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return ((Boolean) doPrivileged("setInitParameter", new Object[] { name, value })).booleanValue();
        } else {
            return context.setInitParameter(name, value);
        }
    }


    @Override
    public void addListener(Class<? extends EventListener> listenerClass) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            doPrivileged("addListener", new Class[] { Class.class }, new Object[] { listenerClass });
        } else {
            context.addListener(listenerClass);
        }
    }


    @Override
    public void addListener(String className) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            doPrivileged("addListener", new Object[] { className });
        } else {
            context.addListener(className);
        }
    }


    @Override
    public <T extends EventListener> void addListener(T t) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            doPrivileged("addListener", new Class[] { EventListener.class }, new Object[] { t });
        } else {
            context.addListener(t);
        }
    }


    @Override
    @SuppressWarnings("unchecked") // doPrivileged() returns the correct type
    public <T extends EventListener> T createListener(Class<T> c) throws ServletException {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            try {
                return (T) invokeMethod(context, "createListener", new Object[] { c });
            } catch (Throwable t) {
                ExceptionUtils.handleThrowable(t);
                if (t instanceof ServletException) {
                    throw (ServletException) t;
                }
                // Propagate unchecked exceptions exactly as the direct call below would
                if (t instanceof RuntimeException) {
                    throw (RuntimeException) t;
                }
                return null;
            }
        } else {
            return context.createListener(c);
        }
    }


    @Override
    public void declareRoles(String... roleNames) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            doPrivileged("declareRoles", new Object[] { roleNames });
        } else {
            context.declareRoles(roleNames);
        }
    }


    @Override
    public ClassLoader getClassLoader() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (ClassLoader) doPrivileged("getClassLoader", null);
        } else {
            return context.getClassLoader();
        }
    }


    @Override
    public int getEffectiveMajorVersion() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return ((Integer) doPrivileged("getEffectiveMajorVersion", null)).intValue();
        } else {
            return context.getEffectiveMajorVersion();
        }
    }


    @Override
    public int getEffectiveMinorVersion() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return ((Integer) doPrivileged("getEffectiveMinorVersion", null)).intValue();
        } else {
            return context.getEffectiveMinorVersion();
        }
    }


    @Override
    @SuppressWarnings("unchecked") // doPrivileged() returns the correct type
    public Map<String,? extends FilterRegistration> getFilterRegistrations() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (Map<String,? extends FilterRegistration>) doPrivileged("getFilterRegistrations", null);
        } else {
            return context.getFilterRegistrations();
        }
    }


    @Override
    public JspConfigDescriptor getJspConfigDescriptor() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (JspConfigDescriptor) doPrivileged("getJspConfigDescriptor", null);
        } else {
            return context.getJspConfigDescriptor();
        }
    }


    @Override
    @SuppressWarnings("unchecked") // doPrivileged() returns the correct type
    public Map<String,? extends ServletRegistration> getServletRegistrations() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (Map<String,? extends ServletRegistration>) doPrivileged("getServletRegistrations", null);
        } else {
            return context.getServletRegistrations();
        }
    }


    @Override
    public String getVirtualServerName() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (String) doPrivileged("getVirtualServerName", null);
        } else {
            return context.getVirtualServerName();
        }
    }


    @Override
    public int getSessionTimeout() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return ((Integer) doPrivileged("getSessionTimeout", null)).intValue();
        } else {
            return context.getSessionTimeout();
        }
    }


    @Override
    public void setSessionTimeout(int sessionTimeout) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            doPrivileged("setSessionTimeout", new Object[] { Integer.valueOf(sessionTimeout) });
        } else {
            context.setSessionTimeout(sessionTimeout);
        }
    }


    @Override
    public String getRequestCharacterEncoding() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (String) doPrivileged("getRequestCharacterEncoding", null);
        } else {
            return context.getRequestCharacterEncoding();
        }
    }


    @Override
    public void setRequestCharacterEncoding(String encoding) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            doPrivileged("setRequestCharacterEncoding", new Object[] { encoding });
        } else {
            context.setRequestCharacterEncoding(encoding);
        }
    }


    @Override
    public String getResponseCharacterEncoding() {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            return (String) doPrivileged("getResponseCharacterEncoding", null);
        } else {
            return context.getResponseCharacterEncoding();
        }
    }


    @Override
    public void setResponseCharacterEncoding(String encoding) {
        if (SecurityUtil.isPackageProtectionEnabled()) {
            doPrivileged("setResponseCharacterEncoding", new Object[] { encoding });
        } else {
            context.setResponseCharacterEncoding(encoding);
        }
    }


    /**
     * Use reflection to invoke the requested method on the wrapped context. The parameter types come from
     * {@link #classCache} and the resolved {@link Method} is cached in {@link #objectCache}.
     *
     * @param methodName The method to call.
     * @param params     The arguments passed to the called method; may be {@code null} for no-argument methods.
     *
     * @return the value returned by the invoked method
     *
     * @throws RuntimeException the unchecked exception thrown by the invoked method, unchanged, or a
     *                              {@code RuntimeException} wrapping (with cause) any other failure
     */
    private Object doPrivileged(final String methodName, final Object[] params) {
        try {
            return invokeMethod(context, methodName, params);
        } catch (Throwable t) {
            ExceptionUtils.handleThrowable(t);
            throw asRuntimeException(t);
        }
    }


    /**
     * Use reflection to invoke the requested method. Cache the method object to speed up the process
     *
     * @param appContext The ApplicationContext object on which the method will be invoked
     * @param methodName The method to call.
     * @param params     The arguments passed to the called method.
     *
     * @return the value returned by the invoked method
     *
     * @throws Throwable the exception thrown by the invoked method (unwrapped from any reflection or privileged-action
     *                       wrapper), or the reflection failure itself
     */
    private Object invokeMethod(ApplicationContext appContext, final String methodName, final Object[] params)
            throws Throwable {

        try {
            Method method = objectCache.get(methodName);
            if (method == null) {
                method = appContext.getClass().getMethod(methodName, classCache.get(methodName));
                // Concurrent callers may both look the method up and store an equivalent Method; that is harmless.
                objectCache.put(methodName, method);
            }

            return executeMethod(method, appContext, params);
        } catch (Exception ex) {
            handleException(ex);
            return null; // Not reached: handleException always throws
        }
    }

    /**
     * Use reflection to invoke the method with the given name and explicit parameter types. Unlike
     * {@link #invokeMethod(ApplicationContext, String, Object[])}, the resolved method is not cached; this variant
     * is used for overloads that cannot be distinguished by name alone.
     *
     * @param methodName The method to invoke.
     * @param clazz      The parameter types of the method to invoke.
     * @param params     The arguments passed to the called method.
     *
     * @return the value returned by the invoked method
     *
     * @throws RuntimeException the unchecked exception thrown by the invoked method, unchanged, or a
     *                              {@code RuntimeException} wrapping (with cause) any other failure
     */
    private Object doPrivileged(final String methodName, final Class<?>[] clazz, final Object[] params) {

        try {
            Method method = context.getClass().getMethod(methodName, clazz);
            return executeMethod(method, context, params);
        } catch (Exception ex) {
            try {
                handleException(ex);
            } catch (Throwable t) {
                ExceptionUtils.handleThrowable(t);
                throw asRuntimeException(t);
            }
            return null; // Not reached: handleException always throws
        }
    }


    /**
     * Converts the real exception from a reflective call into an unchecked exception for the privileged code path.
     * Unchecked exceptions are returned unchanged so callers see the same exception types (for example
     * {@code IllegalStateException}) as on the direct, unprivileged path. Anything else is wrapped with its cause
     * preserved.
     *
     * @param t The exception thrown by the reflective call
     *
     * @return the exception the caller should throw
     */
    private static RuntimeException asRuntimeException(Throwable t) {
        if (t instanceof RuntimeException) {
            return (RuntimeException) t;
        }
        return new RuntimeException(t.getMessage(), t);
    }


    /**
     * Executes the method of the specified <code>ApplicationContext</code>. When package protection is enabled the
     * call is wrapped in {@link AccessController#doPrivileged(PrivilegedExceptionAction)}.
     *
     * @param method  The method object to be invoked.
     * @param context The ApplicationContext object on which the method will be invoked
     * @param params  The arguments passed to the called method.
     *
     * @return the value returned by the invoked method
     *
     * @throws PrivilegedActionException  if the privileged action threw a checked exception
     * @throws IllegalAccessException     if the method is not accessible
     * @throws InvocationTargetException  if the invoked method threw an exception
     */
    private Object executeMethod(final Method method, final ApplicationContext context, final Object[] params)
            throws PrivilegedActionException, IllegalAccessException, InvocationTargetException {

        if (SecurityUtil.isPackageProtectionEnabled()) {
            return AccessController.doPrivileged(new PrivilegedExecuteMethod(method, context, params));
        } else {
            return method.invoke(context, params);
        }
    }


    /**
     * Throw the real exception: unwraps a {@link PrivilegedActionException} and then an
     * {@link InvocationTargetException} (when it has a cause) before rethrowing. This method never returns normally.
     *
     * @param ex The current exception
     *
     * @throws Throwable always; the unwrapped exception
     */
    private void handleException(Exception ex) throws Throwable {

        Throwable realException;

        if (ex instanceof PrivilegedActionException) {
            ex = ((PrivilegedActionException) ex).getException();
        }

        if (ex instanceof InvocationTargetException) {
            realException = ex.getCause();
            if (realException == null) {
                realException = ex;
            }
        } else {
            realException = ex;
        }

        throw realException;
    }


    /**
     * Privileged action that invokes a single reflective method call on an {@link ApplicationContext}.
     */
    private static class PrivilegedExecuteMethod implements PrivilegedExceptionAction<Object> {

        private final Method method;
        private final ApplicationContext context;
        private final Object[] params;

        PrivilegedExecuteMethod(Method method, ApplicationContext context, Object[] params) {
            this.method = method;
            this.context = context;
            this.params = params;
        }

        @Override
        public Object run() throws Exception {
            return method.invoke(context, params);
        }
    }
}