FragmentationInterceptor.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.catalina.tribes.group.interceptors;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Set;

import org.apache.catalina.tribes.ChannelException;
import org.apache.catalina.tribes.ChannelMessage;
import org.apache.catalina.tribes.Member;
import org.apache.catalina.tribes.group.ChannelInterceptorBase;
import org.apache.catalina.tribes.group.InterceptorPayload;
import org.apache.catalina.tribes.io.XByteBuffer;
import org.apache.catalina.tribes.util.StringManager;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;

/**
 * The fragmentation interceptor splits up large messages into smaller messages and assembles them on the other end.
 * This is very useful when you don't want large messages hogging the sending sockets and smaller messages can make it
 * through. <br>
 * <b>Configuration Options</b><br>
 * FragmentationInterceptor.expire=&lt;milliseconds&gt; - how long do we keep the fragments in memory and wait for the
 * rest to arrive <b>default=60,000ms -&gt; 60seconds</b> This setting is useful to avoid OutOfMemoryErrors<br>
 * FragmentationInterceptor.maxSize=&lt;max message size&gt; - message size in bytes <b>default=1024*100 (around a tenth
 * of a MB)</b><br>
 */
public class FragmentationInterceptor extends ChannelInterceptorBase implements FragmentationInterceptorMBean {
    private static final Log log = LogFactory.getLog(FragmentationInterceptor.class);
    protected static final StringManager sm = StringManager.getManager(FragmentationInterceptor.class);

    protected final HashMap<FragKey,FragCollection> fragpieces = new HashMap<>();
    private int maxSize = 1024 * 100;
    private long expire = 1000 * 60; // one minute expiration
    protected final boolean deepclone = true;


    @Override
    public void sendMessage(Member[] destination, ChannelMessage msg, InterceptorPayload payload)
            throws ChannelException {
        int size = msg.getMessage().getLength();
        boolean frag = (size > maxSize) && okToProcess(msg.getOptions());
        if (frag) {
            frag(destination, msg, payload);
        } else {
            msg.getMessage().append(frag);
            super.sendMessage(destination, msg, payload);
        }
    }

    @Override
    public void messageReceived(ChannelMessage msg) {
        boolean isFrag = XByteBuffer.toBoolean(msg.getMessage().getBytesDirect(), msg.getMessage().getLength() - 1);
        msg.getMessage().trim(1);
        if (isFrag) {
            defrag(msg);
        } else {
            super.messageReceived(msg);
        }
    }


    public FragCollection getFragCollection(FragKey key, ChannelMessage msg) {
        FragCollection coll = fragpieces.get(key);
        if (coll == null) {
            synchronized (fragpieces) {
                coll = fragpieces.get(key);
                if (coll == null) {
                    coll = new FragCollection(msg);
                    fragpieces.put(key, coll);
                }
            }
        }
        return coll;
    }

    public void removeFragCollection(FragKey key) {
        fragpieces.remove(key);
    }

    public void defrag(ChannelMessage msg) {
        FragKey key = new FragKey(msg.getUniqueId());
        FragCollection coll = getFragCollection(key, msg);
        coll.addMessage((ChannelMessage) msg.deepclone());

        if (coll.complete()) {
            removeFragCollection(key);
            ChannelMessage complete = coll.assemble();
            super.messageReceived(complete);

        }
    }

    public void frag(Member[] destination, ChannelMessage msg, InterceptorPayload payload) throws ChannelException {
        int size = msg.getMessage().getLength();

        int count = ((size / maxSize) + (size % maxSize == 0 ? 0 : 1));
        ChannelMessage[] messages = new ChannelMessage[count];
        int remaining = size;
        for (int i = 0; i < count; i++) {
            ChannelMessage tmp = (ChannelMessage) msg.clone();
            int offset = (i * maxSize);
            int length = Math.min(remaining, maxSize);
            tmp.getMessage().clear();
            tmp.getMessage().append(msg.getMessage().getBytesDirect(), offset, length);
            // add the msg nr
            // tmp.getMessage().append(XByteBuffer.toBytes(i),0,4);
            tmp.getMessage().append(i);
            // add the total nr of messages
            // tmp.getMessage().append(XByteBuffer.toBytes(count),0,4);
            tmp.getMessage().append(count);
            // add true as the frag flag
            // byte[] flag = XByteBuffer.toBytes(true);
            // tmp.getMessage().append(flag,0,flag.length);
            tmp.getMessage().append(true);
            messages[i] = tmp;
            remaining -= length;

        }
        for (ChannelMessage message : messages) {
            super.sendMessage(destination, message, payload);
        }
    }

    @Override
    public void heartbeat() {
        try {
            Set<FragKey> set = fragpieces.keySet();
            Object[] keys = set.toArray();
            for (Object o : keys) {
                FragKey key = (FragKey) o;
                if (key != null && key.expired(getExpire())) {
                    removeFragCollection(key);
                }
            }
        } catch (Exception x) {
            if (log.isErrorEnabled()) {
                log.error(sm.getString("fragmentationInterceptor.heartbeat.failed"), x);
            }
        }
        super.heartbeat();
    }

    @Override
    public int getMaxSize() {
        return maxSize;
    }

    @Override
    public long getExpire() {
        return expire;
    }

    @Override
    public void setMaxSize(int maxSize) {
        this.maxSize = maxSize;
    }

    @Override
    public void setExpire(long expire) {
        this.expire = expire;
    }

    public static class FragCollection {
        private final long received = System.currentTimeMillis();
        private final ChannelMessage msg;
        private final XByteBuffer[] frags;

        public FragCollection(ChannelMessage msg) {
            // get the total messages
            int count = XByteBuffer.toInt(msg.getMessage().getBytesDirect(), msg.getMessage().getLength() - 4);
            frags = new XByteBuffer[count];
            this.msg = msg;
        }

        public void addMessage(ChannelMessage msg) {
            // remove the total messages
            msg.getMessage().trim(4);
            // get the msg nr
            int nr = XByteBuffer.toInt(msg.getMessage().getBytesDirect(), msg.getMessage().getLength() - 4);
            // remove the msg nr
            msg.getMessage().trim(4);
            frags[nr] = msg.getMessage();

        }

        public boolean complete() {
            boolean result = true;
            for (int i = 0; (i < frags.length) && (result); i++) {
                result = (frags[i] != null);
            }
            return result;
        }

        public ChannelMessage assemble() {
            if (!complete()) {
                throw new IllegalStateException(sm.getString("fragmentationInterceptor.fragments.missing"));
            }
            int buffersize = 0;
            for (XByteBuffer frag : frags) {
                buffersize += frag.getLength();
            }
            XByteBuffer buf = new XByteBuffer(buffersize, false);
            msg.setMessage(buf);
            for (XByteBuffer frag : frags) {
                msg.getMessage().append(frag.getBytesDirect(), 0, frag.getLength());
            }
            return msg;
        }

        public boolean expired(long expire) {
            return (System.currentTimeMillis() - received) > expire;
        }
    }

    public static class FragKey {
        private final byte[] uniqueId;
        private final long received = System.currentTimeMillis();

        public FragKey(byte[] id) {
            this.uniqueId = id;
        }

        @Override
        public int hashCode() {
            return XByteBuffer.toInt(uniqueId, 0);
        }

        @Override
        public boolean equals(Object o) {
            if (o instanceof FragKey) {
                return Arrays.equals(uniqueId, ((FragKey) o).uniqueId);
            } else {
                return false;
            }

        }

        public boolean expired(long expire) {
            return (System.currentTimeMillis() - received) > expire;
        }

    }

}