OrderInterceptor.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.catalina.tribes.group.interceptors;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import org.apache.catalina.tribes.ChannelException;
import org.apache.catalina.tribes.ChannelMessage;
import org.apache.catalina.tribes.Member;
import org.apache.catalina.tribes.group.ChannelInterceptorBase;
import org.apache.catalina.tribes.group.InterceptorPayload;
import org.apache.catalina.tribes.io.XByteBuffer;
import org.apache.catalina.tribes.util.StringManager;


/**
 * The order interceptor guarantees that messages are received in the same order they were sent. This interceptor works
 * best with the ack=true setting. <br>
 * There is no point in using this with the replicationMode="fastasynchqueue" as this mode guarantees ordering.<br>
 * If you are using the mode ack=false replicationMode=pooled, and have a lot of concurrent threads, this interceptor
 * can really slow you down, as many messages will be completely out of order and the queue might become rather large.
 * If this is the case, then you might want to set the value OrderInterceptor.maxQueue = 25 (meaning that we will never
 * keep more than 25 messages in our queue) <br>
 * <b>Configuration Options</b><br>
 * OrderInterceptor.expire=&lt;milliseconds&gt; - if a message arrives out of order, how long before we act on it
 * <b>default=3000ms</b><br>
 * OrderInterceptor.maxQueue=&lt;max queue size&gt; - how much the queue can grow to ensure ordering. This setting is
 * useful to avoid OutOfMemoryErrors <b>default=Integer.MAX_VALUE</b><br>
 * OrderInterceptor.forwardExpired=&lt;boolean&gt; - this flag tells the interceptor what to do when a message has
 * expired or the queue has grown larger than the maxQueue value. true means that the message is sent up the stack to
 * the receiver, which will then receive an out of order message; false means forget the message and reset the message
 * counter. <b>default=true</b>
 */
public class OrderInterceptor extends ChannelInterceptorBase {
    protected static final StringManager sm = StringManager.getManager(OrderInterceptor.class);

    /** Last sequence number handed out per destination member. Guarded by {@link #outLock}. */
    private final Map<Member,Counter> outcounter = new HashMap<>();
    /** Next sequence number expected per sending member. Guarded by {@link #inLock}. */
    private final Map<Member,Counter> incounter = new HashMap<>();
    /** Head of the sorted list of messages still waiting for their turn, per sending member. Guarded by {@link #inLock}. */
    private final Map<Member,MessageOrder> incoming = new HashMap<>();
    private long expire = 3000;
    private boolean forwardExpired = true;
    private int maxQueue = Integer.MAX_VALUE;

    final ReentrantReadWriteLock inLock = new ReentrantReadWriteLock(true);
    final ReentrantReadWriteLock outLock = new ReentrantReadWriteLock(true);

    /**
     * Sends the message to each destination individually, with a per-destination sequence number appended to the
     * payload. Messages whose options do not select this interceptor are passed down unchanged.
     *
     * @param destination the members to send to
     * @param msg         the message; its buffer is temporarily extended by the sequence number for each send
     * @param payload     passed through to the next interceptor
     *
     * @throws ChannelException the first failure encountered, with the faulty members of any later failures added
     */
    @Override
    public void sendMessage(Member[] destination, ChannelMessage msg, InterceptorPayload payload)
            throws ChannelException {
        if (!okToProcess(msg.getOptions())) {
            super.sendMessage(destination, msg, payload);
            return;
        }
        ChannelException cx = null;
        for (Member member : destination) {
            try {
                int nr = 0;
                outLock.writeLock().lock();
                try {
                    nr = incCounter(member);
                } finally {
                    outLock.writeLock().unlock();
                }
                // reduce byte copy: append the number to the shared buffer rather than cloning the message
                msg.getMessage().append(nr);
                try {
                    getNext().sendMessage(new Member[] { member }, msg, payload);
                } finally {
                    // strip the number again so the next destination gets its own
                    msg.getMessage().trim(4);
                }
            } catch (ChannelException x) {
                if (cx == null) {
                    cx = x;
                }
                cx.addFaultyMember(x.getFaultyMembers());
            }
        } // for
        if (cx != null) {
            throw cx;
        }
    }

    /**
     * Reads the sequence number from the last four bytes of the message, strips it, and hands a deep clone of the
     * message to the ordering queue. Messages whose options do not select this interceptor are passed up unchanged.
     *
     * @param msg the received message
     */
    @Override
    public void messageReceived(ChannelMessage msg) {
        if (!okToProcess(msg.getOptions())) {
            super.messageReceived(msg);
            return;
        }
        int msgnr = XByteBuffer.toInt(msg.getMessage().getBytesDirect(), msg.getMessage().getLength() - 4);
        msg.getMessage().trim(4);
        MessageOrder order = new MessageOrder(msgnr, (ChannelMessage) msg.deepclone());
        inLock.writeLock().lock();
        try {
            if (processIncoming(order)) {
                // an expiry moved the counter forward, so queued messages may now be deliverable
                processLeftOvers(msg.getAddress(), false);
            }
        } finally {
            inLock.writeLock().unlock();
        }
    }

    /**
     * Re-runs {@link #processIncoming(MessageOrder)} on whatever is still queued for the member. Reads and modifies
     * the incoming state without locking, so the caller must hold the write lock of {@code inLock}.
     *
     * @param member the sending member whose queue should be processed
     * @param force  if {@code true}, the member's incoming counter is first set to {@code Integer.MAX_VALUE} so that
     *                   every queued message is delivered
     */
    protected void processLeftOvers(Member member, boolean force) {
        MessageOrder tmp = incoming.get(member);
        if (force) {
            Counter cnt = getInCounter(member);
            cnt.setCounter(Integer.MAX_VALUE);
        }
        if (tmp != null) {
            processIncoming(tmp);
        }
    }

    /**
     * Merges the given message into the sender's queue, delivers every queued message whose number is not ahead of the
     * expected counter, then delivers or drops (depending on {@link #getForwardExpired()}) messages that have expired
     * or, if the remaining queue has reached {@code maxQueue}, all remaining messages. Reads and modifies the incoming
     * state without locking, so the caller must hold the write lock of {@code inLock}.
     *
     * @param order MessageOrder
     *
     * @return boolean - true if a message expired and was processed
     */
    protected boolean processIncoming(MessageOrder order) {
        boolean result = false;
        Member member = order.getMessage().getAddress();
        Counter cnt = getInCounter(member);

        MessageOrder tmp = incoming.get(member);
        if (tmp != null) {
            order = MessageOrder.add(tmp, order);
        }

        while ((order != null) && (order.getMsgNr() <= cnt.getCounter())) {
            // Right on target: advance the expected number. Numbers below the counter are delivered without
            // touching it (the loop condition rules out numbers above it).
            if (order.getMsgNr() == cnt.getCounter()) {
                cnt.inc();
            }
            super.messageReceived(order.getMessage());
            order.setMessage(null);
            order = order.next;
        }
        MessageOrder head = order;
        MessageOrder prev = null;
        tmp = order;
        // flag to empty out the queue when it is at least maxQueue long
        boolean empty = order != null ? order.getCount() >= maxQueue : false;
        while (tmp != null) {
            // process expired messages or empty out the queue
            if (tmp.isExpired(expire) || empty) {
                // reset the head
                if (tmp == head) {
                    head = tmp.next;
                }
                // skip ahead: everything up to and including this number is considered handled
                cnt.setCounter(tmp.getMsgNr() + 1);
                if (getForwardExpired()) {
                    super.messageReceived(tmp.getMessage());
                }
                tmp.setMessage(null);
                tmp = tmp.next;
                // unlink the handled node from the surviving part of the list
                if (prev != null) {
                    prev.next = tmp;
                }
                result = true;
            } else {
                prev = tmp;
                tmp = tmp.next;
            }
        }
        if (head == null) {
            incoming.remove(member);
        } else {
            incoming.put(member, head);
        }
        return result;
    }

    @Override
    public void memberAdded(Member member) {
        // notify upwards
        super.memberAdded(member);
    }

    /**
     * Discards the ordering state kept for the member, delivering any messages still queued for it first, and then
     * notifies the interceptors above.
     *
     * @param member the member that disappeared
     */
    @Override
    public void memberDisappeared(Member member) {
        // The counter maps are plain HashMaps that senders and receivers only touch under outLock / inLock,
        // so the reset has to take the same locks.
        outLock.writeLock().lock();
        try {
            outcounter.remove(member);
        } finally {
            outLock.writeLock().unlock();
        }
        inLock.writeLock().lock();
        try {
            try {
                // clear the remaining queue
                processLeftOvers(member, true);
            } finally {
                // The forced flush pins the counter at Integer.MAX_VALUE. Remove it afterwards so the entry is
                // not leaked and ordering starts afresh should the member come back.
                incounter.remove(member);
            }
        } finally {
            inLock.writeLock().unlock();
        }
        // notify upwards
        super.memberDisappeared(member);
    }

    /**
     * Increments and returns the outgoing sequence number for the member. Accesses {@code outcounter} without
     * locking; {@link #sendMessage} calls it while holding the write lock of {@code outLock}.
     *
     * @param mbr the destination member
     *
     * @return the incremented sequence number
     */
    protected int incCounter(Member mbr) {
        Counter cnt = getOutCounter(mbr);
        return cnt.inc();
    }

    /**
     * Returns the incoming counter for the member, creating it with the value 1 if there is none yet.
     *
     * @param mbr the sending member
     *
     * @return the member's incoming counter, never {@code null}
     */
    protected Counter getInCounter(Member mbr) {
        Counter cnt = incounter.get(mbr);
        if (cnt == null) {
            cnt = new Counter();
            cnt.inc(); // always start at 1 for incoming
            incounter.put(mbr, cnt);
        }
        return cnt;
    }

    /**
     * Returns the outgoing counter for the member, creating it with the value 0 if there is none yet.
     *
     * @param mbr the destination member
     *
     * @return the member's outgoing counter, never {@code null}
     */
    protected Counter getOutCounter(Member mbr) {
        Counter cnt = outcounter.get(mbr);
        if (cnt == null) {
            cnt = new Counter();
            outcounter.put(mbr, cnt);
        }
        return cnt;
    }

    /** Mutable sequence counter backed by an {@link AtomicInteger}, starting at 0. */
    protected static class Counter {
        private final AtomicInteger value = new AtomicInteger(0);

        public int getCounter() {
            return value.get();
        }

        public void setCounter(int counter) {
            this.value.set(counter);
        }

        /**
         * @return the value after incrementing by one
         */
        public int inc() {
            return value.addAndGet(1);
        }
    }

    /**
     * Node of a singly linked list of received messages, kept sorted by ascending message number by
     * {@link #add(MessageOrder, MessageOrder)}. Each node records the time it was created for expiry checks.
     */
    protected static class MessageOrder {
        private final long received = System.currentTimeMillis();
        private MessageOrder next;
        private final int msgNr;
        private ChannelMessage msg = null;

        public MessageOrder(int msgNr, ChannelMessage msg) {
            this.msgNr = msgNr;
            this.msg = msg;
        }

        /**
         * @param expireTime maximum age in milliseconds
         *
         * @return {@code true} if more than {@code expireTime} milliseconds have passed since this node was created
         */
        public boolean isExpired(long expireTime) {
            return (System.currentTimeMillis() - received) > expireTime;
        }

        public ChannelMessage getMessage() {
            return msg;
        }

        public void setMessage(ChannelMessage msg) {
            this.msg = msg;
        }

        public void setNext(MessageOrder order) {
            this.next = order;
        }

        public MessageOrder getNext() {
            return next;
        }

        /**
         * @return the number of nodes in the list starting at, and including, this node
         */
        public int getCount() {
            int counter = 1;
            MessageOrder tmp = next;
            while (tmp != null) {
                counter++;
                tmp = tmp.next;
            }
            return counter;
        }

        /**
         * Inserts {@code add} into the list starting at {@code head}, keeping it sorted by ascending message number.
         *
         * @param head current head of the list, may be {@code null}
         * @param add  node to insert, may be {@code null}
         *
         * @return the head of the resulting list
         *
         * @throws ArithmeticException if the list already contains a different node with the same message number
         */
        @SuppressWarnings("null") // prev cannot be null
        public static MessageOrder add(MessageOrder head, MessageOrder add) {
            if (head == null) {
                return add;
            }
            if (add == null) {
                return head;
            }
            if (head == add) {
                return add;
            }

            if (head.getMsgNr() > add.getMsgNr()) {
                add.next = head;
                return add;
            }

            // advance to the first node that is not smaller than add, or to the last node
            MessageOrder iter = head;
            MessageOrder prev = null;
            while (iter.getMsgNr() < add.getMsgNr() && (iter.next != null)) {
                prev = iter;
                iter = iter.next;
            }
            if (iter.getMsgNr() < add.getMsgNr()) {
                // add after
                add.next = iter.next;
                iter.next = add;
            } else if (iter.getMsgNr() > add.getMsgNr()) {
                // add before
                prev.next = add; // prev cannot be null here, warning suppressed
                add.next = iter;

            } else {
                throw new ArithmeticException(sm.getString("orderInterceptor.messageAdded.sameCounter"));
            }

            return head;
        }

        public int getMsgNr() {
            return msgNr;
        }


    }

    public void setExpire(long expire) {
        this.expire = expire;
    }

    public void setForwardExpired(boolean forwardExpired) {
        this.forwardExpired = forwardExpired;
    }

    public void setMaxQueue(int maxQueue) {
        this.maxQueue = maxQueue;
    }

    public long getExpire() {
        return expire;
    }

    public boolean getForwardExpired() {
        return forwardExpired;
    }

    public int getMaxQueue() {
        return maxQueue;
    }

}