TwoPhaseCommitInterceptor.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.catalina.tribes.group.interceptors;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import org.apache.catalina.tribes.ChannelException;
import org.apache.catalina.tribes.ChannelMessage;
import org.apache.catalina.tribes.Member;
import org.apache.catalina.tribes.UniqueId;
import org.apache.catalina.tribes.group.ChannelInterceptorBase;
import org.apache.catalina.tribes.group.InterceptorPayload;
import org.apache.catalina.tribes.util.Arrays;
import org.apache.catalina.tribes.util.StringManager;
import org.apache.catalina.tribes.util.UUIDGenerator;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;

/**
 * Interceptor that delivers a message in two phases.
 * <p>
 * <b>Sending side.</b> Every outgoing message whose options are accepted by
 * {@code okToProcess} is sent as-is. It is then followed by a small
 * confirmation message whose body is {@link #START_DATA}, the original
 * message's unique id, and {@link #END_DATA}.
 * <p>
 * <b>Receiving side.</b> Non-confirmation messages are buffered in
 * {@link #messages}. The buffered original is passed up the interceptor
 * stack only once its confirmation arrives.
 * <p>
 * <b>Expiry.</b> Buffered originals whose confirmation never arrives are
 * dropped by {@link #heartbeat()} once they are older than
 * {@link #getExpire()} milliseconds.
 * <p>
 * <b>Locking.</b> {@link #messages} is guarded by synchronizing on the map
 * itself. Subclasses that access it directly must do the same.
 */
public class TwoPhaseCommitInterceptor extends ChannelInterceptorBase {

    /** Fixed marker written before the original message id in a confirmation body. */
    private static final byte[] START_DATA = new byte[] {113, 1, -58, 2, -34, -60, 75, -78, -101, -12, 32, -29, 32, 111, -40, 4};
    /** Fixed marker written after the original message id in a confirmation body. */
    private static final byte[] END_DATA = new byte[] {54, -13, 90, 110, 47, -31, 75, -24, -81, -29, 36, 52, -58, 77, -110, 56};
    private static final Log log = LogFactory.getLog(TwoPhaseCommitInterceptor.class);
    protected static final StringManager sm = StringManager.getManager(TwoPhaseCommitInterceptor.class);

    /**
     * Received originals awaiting their confirmation, keyed by unique id.
     * Guarded by synchronizing on this map.
     */
    protected final HashMap<UniqueId, MapEntry> messages = new HashMap<>();
    /** Age in milliseconds after which an unconfirmed original is discarded. */
    protected long expire = 1000 * 60; //one minute expiration
    /** Whether the confirmation is built from {@code msg.deepclone()} (true) or {@code msg.clone()} (false). */
    protected boolean deepclone = true;

    /**
     * Sends {@code msg} and, if two-phase commit applies to it, follows it
     * with a confirmation message that carries the original's unique id.
     * <p>
     * The caller's {@code payload} is attached only to the confirmation; the
     * original is sent with a {@code null} payload. Messages not accepted by
     * {@code okToProcess} are passed through unchanged with their payload.
     *
     * @param destination the members to send to
     * @param msg         the message to send
     * @param payload     the payload, attached to the confirmation only
     * @throws ChannelException if a send fails; if the first send fails, no
     *                          confirmation is sent
     */
    @Override
    public void sendMessage(Member[] destination, ChannelMessage msg, InterceptorPayload payload) throws
        ChannelException {
        //todo, optimize, if destination.length==1, then we can do
        //msg.setOptions(msg.getOptions() & (~getOptionFlag())
        //and just send one message
        if (okToProcess(msg.getOptions()) ) {
            super.sendMessage(destination, msg, null);
            // Copy the original id before the confirmation's id is regenerated below.
            // NOTE(review): if clone() shares the id array with msg (depends on the
            // ChannelMessage implementation, not visible here), regenerating in place
            // would otherwise make the confirmation reference its own new id.
            byte[] originalId = msg.getUniqueId().clone();
            ChannelMessage confirmation;
            if ( deepclone ) {
                confirmation = (ChannelMessage)msg.deepclone();
            } else {
                confirmation = (ChannelMessage)msg.clone();
            }
            // Replace the body with START_DATA + original id + END_DATA and give
            // the confirmation its own fresh unique id.
            confirmation.getMessage().reset();
            UUIDGenerator.randomUUID(false,confirmation.getUniqueId(),0);
            confirmation.getMessage().append(START_DATA,0,START_DATA.length);
            confirmation.getMessage().append(originalId,0,originalId.length);
            confirmation.getMessage().append(END_DATA,0,END_DATA.length);
            super.sendMessage(destination,confirmation,payload);
        } else {
            // Two-phase commit does not apply: pass the message through as-is.
            // Clearing the option flag (msg.setOptions(msg.getOptions() & ~getOptionFlag()))
            // is not done, because it would not work when this interceptor's flag is 0.
            super.sendMessage(destination, msg, payload);
        }
    }

    /**
     * Buffers originals and delivers each one when its confirmation arrives.
     * <p>
     * Messages not accepted by {@code okToProcess} are passed straight up.
     * A confirmation whose original is not buffered (never received, already
     * delivered, or expired) is logged at WARN and dropped.
     *
     * @param msg the received message
     */
    @Override
    public void messageReceived(ChannelMessage msg) {
        if (!okToProcess(msg.getOptions())) {
            super.messageReceived(msg);
            return;
        }
        if (isConfirmation(msg)) {
            UniqueId id = new UniqueId(msg.getMessage().getBytesDirect(),START_DATA.length,msg.getUniqueId().length);
            MapEntry original;
            // Single atomic remove: a duplicate confirmation cannot deliver twice.
            synchronized (messages) {
                original = messages.remove(id);
            }
            if ( original != null ) {
                // Delivered outside the lock so upper layers never run while holding it.
                super.messageReceived(original.msg);
            } else {
                log.warn(sm.getString("twoPhaseCommitInterceptor.originalMessage.missing", Arrays.toString(id.getBytes())));
            }
        } else {
            UniqueId id = new UniqueId(msg.getUniqueId());
            MapEntry entry = new MapEntry((ChannelMessage)msg.deepclone(),id,System.currentTimeMillis());
            synchronized (messages) {
                messages.put(id,entry);
            }
        }
    }

    /**
     * Tests whether {@code msg} has the confirmation layout written by
     * {@link #sendMessage}: {@code START_DATA}, an id of the same length as
     * {@code msg}'s own unique id, then {@code END_DATA}, and nothing else.
     */
    private boolean isConfirmation(ChannelMessage msg) {
        int idLength = msg.getUniqueId().length;
        if (msg.getMessage().getLength() != START_DATA.length + idLength + END_DATA.length) {
            return false;
        }
        byte[] data = msg.getMessage().getBytesDirect();
        return Arrays.contains(data,0,START_DATA,0,START_DATA.length) &&
               Arrays.contains(data,START_DATA.length+idLength,END_DATA,0,END_DATA.length);
    }

    /**
     * @return {@code true} if confirmations are built with
     *         {@code deepclone()}, {@code false} if with {@code clone()}
     */
    public boolean getDeepclone() {
        return deepclone;
    }

    /**
     * @return age in milliseconds after which an unconfirmed original is discarded
     */
    public long getExpire() {
        return expire;
    }

    /**
     * @param deepclone {@code true} to build confirmations with
     *                  {@code deepclone()}, {@code false} to use {@code clone()}
     */
    public void setDeepclone(boolean deepclone) {
        this.deepclone = deepclone;
    }

    /**
     * @param expire age in milliseconds after which an unconfirmed original is discarded
     */
    public void setExpire(long expire) {
        this.expire = expire;
    }

    /**
     * Discards buffered originals older than {@link #getExpire()} ms, logging
     * each one at INFO, then delegates to the superclass heartbeat.
     * <p>
     * Failures are logged at WARN rather than propagated. The superclass
     * heartbeat always runs.
     */
    @Override
    public void heartbeat() {
        try {
            long now = System.currentTimeMillis();
            List<MapEntry> expired = new ArrayList<>();
            synchronized (messages) {
                for (Iterator<MapEntry> it = messages.values().iterator(); it.hasNext();) {
                    MapEntry entry = it.next();
                    if (entry.expired(now, expire)) {
                        it.remove();
                        expired.add(entry);
                    }
                }
            }
            // Log outside the lock so receiver threads are not blocked on logging.
            for (MapEntry entry : expired) {
                log.info(sm.getString("twoPhaseCommitInterceptor.expiredMessage", entry.id));
            }
        } catch ( Exception x ) {
            log.warn(sm.getString("twoPhaseCommitInterceptor.heartbeat.failed"),x);
        } finally {
            super.heartbeat();
        }
    }

    /**
     * A buffered original message together with its id and the time, in
     * epoch milliseconds, at which it was received.
     */
    public static class MapEntry {
        public final ChannelMessage msg;
        public final UniqueId id;
        public final long timestamp;

        public MapEntry(ChannelMessage msg, UniqueId id, long timestamp) {
            this.msg = msg;
            this.id = id;
            this.timestamp = timestamp;
        }

        /**
         * @param now        current time in epoch milliseconds
         * @param expiration maximum allowed age in milliseconds
         * @return {@code true} if strictly more than {@code expiration} ms
         *         have elapsed since {@link #timestamp}
         */
        public boolean expired(long now, long expiration) {
            return (now - timestamp ) > expiration;
        }

    }

}