RpcChannel.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.catalina.tribes.group;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.apache.catalina.tribes.Channel;
import org.apache.catalina.tribes.ChannelException;
import org.apache.catalina.tribes.ChannelListener;
import org.apache.catalina.tribes.ErrorHandler;
import org.apache.catalina.tribes.Member;
import org.apache.catalina.tribes.UniqueId;
import org.apache.catalina.tribes.util.StringManager;
import org.apache.catalina.tribes.util.UUIDGenerator;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;

/**
 * A channel to handle RPC messaging
 */
public class RpcChannel implements ChannelListener {
    private static final Log log = LogFactory.getLog(RpcChannel.class);
    protected static final StringManager sm = StringManager.getManager(RpcChannel.class);

    public static final int FIRST_REPLY = 1;
    public static final int MAJORITY_REPLY = 2;
    public static final int ALL_REPLY = 3;
    public static final int NO_REPLY = 4;

    private Channel channel;
    private RpcCallback callback;
    private byte[] rpcId;
    private int replyMessageOptions = 0;

    private final ConcurrentMap<RpcCollectorKey, RpcCollector> responseMap = new ConcurrentHashMap<>();

    /**
     * Create an RPC channel. You can have several RPC channels attached to a group
     * all separated out by the uniqueness
     * @param rpcId - the unique Id for this RPC group
     * @param channel Channel
     * @param callback RpcCallback
     */
    public RpcChannel(byte[] rpcId, Channel channel, RpcCallback callback) {
        this.channel = channel;
        this.callback = callback;
        this.rpcId = rpcId;
        channel.addChannelListener(this);
    }


    /**
     * Send a message and wait for the response.
     * @param destination Member[] - the destination for the message, and the members you request a reply from
     * @param message Serializable - the message you are sending out
     * @param rpcOptions int - FIRST_REPLY, MAJORITY_REPLY or ALL_REPLY
     * @param channelOptions channel sender options
     * @param timeout long - timeout in milliseconds, if no reply is received within this time null is returned
     * @return Response[] - an array of response objects.
     * @throws ChannelException Error sending message
     */
    public Response[] send(Member[] destination,
                           Serializable message,
                           int rpcOptions,
                           int channelOptions,
                           long timeout) throws ChannelException {

        if ( destination==null || destination.length == 0 ) {
            return new Response[0];
        }

        //avoid dead lock
        int sendOptions =
            channelOptions & ~Channel.SEND_OPTIONS_SYNCHRONIZED_ACK;

        RpcCollectorKey key = new RpcCollectorKey(UUIDGenerator.randomUUID(false));
        RpcCollector collector = new RpcCollector(key,rpcOptions,destination.length);
        try {
            synchronized (collector) {
                if ( rpcOptions != NO_REPLY ) {
                    responseMap.put(key, collector);
                }
                RpcMessage rmsg = new RpcMessage(rpcId, key.id, message);
                channel.send(destination, rmsg, sendOptions);
                if ( rpcOptions != NO_REPLY ) {
                    collector.wait(timeout);
                }
            }
        } catch ( InterruptedException ix ) {
            Thread.currentThread().interrupt();
        } finally {
            responseMap.remove(key);
        }
        return collector.getResponses();
    }

    @Override
    public void messageReceived(Serializable msg, Member sender) {
        RpcMessage rmsg = (RpcMessage)msg;
        RpcCollectorKey key = new RpcCollectorKey(rmsg.uuid);
        if ( rmsg.reply ) {
            RpcCollector collector = responseMap.get(key);
            if (collector == null) {
                if (!(rmsg instanceof RpcMessage.NoRpcChannelReply)) {
                    callback.leftOver(rmsg.message, sender);
                }
            } else {
                synchronized (collector) {
                    //make sure it hasn't been removed
                    if ( responseMap.containsKey(key) ) {
                        if ( (rmsg instanceof RpcMessage.NoRpcChannelReply) ) {
                            collector.destcnt--;
                        } else {
                            collector.addResponse(rmsg.message, sender);
                        }
                        if (collector.isComplete()) {
                            collector.notifyAll();
                        }
                    } else {
                        if (! (rmsg instanceof RpcMessage.NoRpcChannelReply) ) {
                            callback.leftOver(rmsg.message, sender);
                        }
                    }
                }//synchronized
            }//end if
        } else {
            boolean finished = false;
            final ExtendedRpcCallback excallback = (callback instanceof ExtendedRpcCallback)?((ExtendedRpcCallback)callback) : null;
            boolean asyncReply = ((replyMessageOptions & Channel.SEND_OPTIONS_ASYNCHRONOUS) == Channel.SEND_OPTIONS_ASYNCHRONOUS);
            Serializable reply = callback.replyRequest(rmsg.message,sender);
            ErrorHandler handler = null;
            final Serializable request = msg;
            final Serializable response = reply;
            final Member fsender = sender;
            if (excallback!=null && asyncReply) {
                handler = new ErrorHandler() {
                    @Override
                    public void handleError(ChannelException x, UniqueId id) {
                        excallback.replyFailed(request, response, fsender, x);
                    }
                    @Override
                    public void handleCompletion(UniqueId id) {
                        excallback.replySucceeded(request, response, fsender);
                    }
                };
            }
            rmsg.reply = true;
            rmsg.message = reply;
            try {
                if (handler!=null) {
                    channel.send(new Member[] {sender}, rmsg,replyMessageOptions & ~Channel.SEND_OPTIONS_SYNCHRONIZED_ACK, handler);
                } else {
                    channel.send(new Member[] {sender}, rmsg,replyMessageOptions & ~Channel.SEND_OPTIONS_SYNCHRONIZED_ACK);
                }
                finished = true;
            } catch ( Exception x )  {
                if (excallback != null && !asyncReply) {
                    excallback.replyFailed(rmsg.message, reply, sender, x);
                } else {
                    log.error(sm.getString("rpcChannel.replyFailed"),x);
                }
            }
            if (finished && excallback != null && !asyncReply) {
                excallback.replySucceeded(rmsg.message, reply, sender);
            }
        }//end if
    }

    public void breakdown() {
        channel.removeChannelListener(this);
    }

    @Override
    public boolean accept(Serializable msg, Member sender) {
        if ( msg instanceof RpcMessage ) {
            RpcMessage rmsg = (RpcMessage)msg;
            return Arrays.equals(rmsg.rpcId,rpcId);
        } else {
            return false;
        }
    }

    public Channel getChannel() {
        return channel;
    }

    public RpcCallback getCallback() {
        return callback;
    }

    public byte[] getRpcId() {
        return rpcId;
    }

    public void setChannel(Channel channel) {
        this.channel = channel;
    }

    public void setCallback(RpcCallback callback) {
        this.callback = callback;
    }

    public void setRpcId(byte[] rpcId) {
        this.rpcId = rpcId;
    }

    public int getReplyMessageOptions() {
        return replyMessageOptions;
    }

    public void setReplyMessageOptions(int replyMessageOptions) {
        this.replyMessageOptions = replyMessageOptions;
    }

    /**
     * Class that holds all response.
     */
    public static class RpcCollector {
        public final ArrayList<Response> responses = new ArrayList<>();
        public final RpcCollectorKey key;
        public final int options;
        public int destcnt;

        public RpcCollector(RpcCollectorKey key, int options, int destcnt) {
            this.key = key;
            this.options = options;
            this.destcnt = destcnt;
        }

        public void addResponse(Serializable message, Member sender) {
            Response resp = new Response(sender,message);
            responses.add(resp);
        }

        public boolean isComplete() {
            if ( destcnt <= 0 ) {
                return true;
            }
            switch (options) {
                case ALL_REPLY:
                    return destcnt == responses.size();
                case MAJORITY_REPLY:
                    float perc = ((float)responses.size()) / ((float)destcnt);
                    return perc >= 0.50f;
                case FIRST_REPLY:
                    return responses.size()>0;
                default:
                    return false;
            }
        }

        @Override
        public int hashCode() {
            return key.hashCode();
        }

        @Override
        public boolean equals(Object o) {
            if ( o instanceof RpcCollector ) {
                RpcCollector r = (RpcCollector)o;
                return r.key.equals(this.key);
            } else {
                return false;
            }
        }

        public Response[] getResponses() {
            return responses.toArray(new Response[0]);
        }
    }

    public static class RpcCollectorKey {
        final byte[] id;
        public RpcCollectorKey(byte[] id) {
            this.id = id;
        }

        @Override
        public int hashCode() {
            return id[0]+id[1]+id[2]+id[3];
        }

        @Override
        public boolean equals(Object o) {
            if ( o instanceof RpcCollectorKey ) {
                RpcCollectorKey r = (RpcCollectorKey)o;
                return Arrays.equals(id,r.id);
            } else {
                return false;
            }
        }

    }
}