McastServiceImpl.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.catalina.tribes.membership;
import java.io.IOException;
import java.net.BindException;
import java.net.DatagramPacket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.MulticastSocket;
import java.net.NetworkInterface;
import java.net.SocketTimeoutException;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.catalina.tribes.Channel;
import org.apache.catalina.tribes.Member;
import org.apache.catalina.tribes.MembershipListener;
import org.apache.catalina.tribes.MessageListener;
import org.apache.catalina.tribes.io.ChannelData;
import org.apache.catalina.tribes.io.XByteBuffer;
import org.apache.catalina.tribes.util.StringManager;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
/**
 * A <b>membership</b> implementation using simple multicast. This is the low level implementation behind a multicast
 * membership service: it owns the multicast socket, maintains the list of active cluster nodes, broadcasts a heartbeat
 * for the local member every {@code sendFrequency} ms and dismisses nodes that have not sent a heartbeat within the
 * expiration time.
 * <p>
 * Heartbeats are sent by a {@link SenderThread} and received by a {@link ReceiverThread}. The socket is given a read
 * timeout so that the receiver never blocks forever and can also run membership expiration. A possible future
 * improvement is to use java.nio so that a single thread can both send and receive.
 */
public class McastServiceImpl extends MembershipProviderBase {
    private static final Log log = LogFactory.getLog(McastService.class);

    protected static final int MAX_PACKET_SIZE = 65535;

    protected static final StringManager sm = StringManager.getManager(Constants.Package);

    /**
     * Run flag for the sender thread; cleared by {@link #stop(int)} so that the thread leaves its loop.
     */
    protected volatile boolean doRunSender = false;

    /**
     * Run flag for the receiver thread; cleared by {@link #stop(int)} so that the thread leaves its loop.
     */
    protected volatile boolean doRunReceiver = false;

    /**
     * Bit mask of the {@code Channel.MBR_*_SEQ} levels that are currently started.
     */
    protected volatile int startLevel = 0;

    /**
     * Socket that we intend to listen to.
     */
    protected MulticastSocket socket;

    /**
     * The local member that we broadcast over and over again.
     */
    protected final MemberImpl member;

    /**
     * The multicast address.
     */
    protected final InetAddress address;

    /**
     * The multicast port.
     */
    protected final int port;

    /**
     * The time (ms) it takes for a member to expire.
     */
    protected final long timeToExpiration;

    /**
     * How often (ms) we send out a broadcast saying we are alive; must be smaller than timeToExpiration.
     */
    protected final long sendFrequency;

    /**
     * Reuse the sendPacket, no need to create a new one every time.
     */
    protected DatagramPacket sendPacket;

    /**
     * Reuse the receivePacket, no need to create a new one every time.
     */
    protected DatagramPacket receivePacket;

    /**
     * The actual listener, for callback when members are added or disappear.
     */
    protected final MembershipListener service;

    /**
     * The actual listener for broadcast callbacks.
     */
    protected final MessageListener msgservice;

    /**
     * Thread to listen for pings.
     */
    protected ReceiverThread receiver;

    /**
     * Thread to send pings.
     */
    protected SenderThread sender;

    /**
     * Time to live for the multicast packets that are being sent out; a negative value leaves the socket default.
     */
    protected final int mcastTTL;

    /**
     * Read timeout on the mcast socket; a value &lt;= 0 is replaced by {@code sendFrequency} in {@link #setupSocket()}.
     */
    protected int mcastSoTimeout = -1;

    /**
     * Bind address.
     */
    protected final InetAddress mcastBindAddress;

    /**
     * Number of consecutive failures required before a recovery is initiated.
     */
    protected int recoveryCounter = 10;

    /**
     * The time (ms) the recovery thread sleeps between recovery attempts.
     */
    protected long recoverySleepTime = 5000;

    /**
     * Add the ability to turn on/off recovery.
     */
    protected boolean recoveryEnabled = true;

    /**
     * Disable/enable local loopback messages.
     */
    protected final boolean localLoopbackDisabled;

    private Channel channel;

    /**
     * Create a new mcast service instance.
     *
     * @param member                - the local member
     * @param sendFrequency         - the time (ms) in between pings sent out
     * @param expireTime            - the time (ms) for a member to expire
     * @param port                  - the mcast port
     * @param bind                  - the bind address (not sure this is used yet)
     * @param mcastAddress          - the mcast address
     * @param ttl                   multicast ttl that will be set on the socket
     * @param soTimeout             Socket timeout
     * @param service               - the callback service
     * @param msgservice            Message listener
     * @param localLoopbackDisabled - disable loopbackMode
     *
     * @throws IOException Init error
     */
    public McastServiceImpl(MemberImpl member, long sendFrequency, long expireTime, int port, InetAddress bind,
            InetAddress mcastAddress, int ttl, int soTimeout, MembershipListener service, MessageListener msgservice,
            boolean localLoopbackDisabled) throws IOException {
        this.member = member;
        this.address = mcastAddress;
        this.port = port;
        this.mcastSoTimeout = soTimeout;
        this.mcastTTL = ttl;
        this.mcastBindAddress = bind;
        this.timeToExpiration = expireTime;
        this.service = service;
        this.msgservice = msgservice;
        this.sendFrequency = sendFrequency;
        this.localLoopbackDisabled = localLoopbackDisabled;
        init();
    }

    /**
     * Creates and configures the socket and the reusable packets, clears the member command and creates the
     * membership if it does not exist yet. Also called by the recovery thread to re-initialise after a stop.
     *
     * @throws IOException if the socket cannot be created or configured
     */
    public void init() throws IOException {
        setupSocket();
        sendPacket = new DatagramPacket(new byte[MAX_PACKET_SIZE], MAX_PACKET_SIZE);
        sendPacket.setAddress(address);
        sendPacket.setPort(port);
        receivePacket = new DatagramPacket(new byte[MAX_PACKET_SIZE], MAX_PACKET_SIZE);
        receivePacket.setAddress(address);
        receivePacket.setPort(port);
        member.setCommand(new byte[0]);
        if (membership == null) {
            membership = new Membership(member);
        }
    }

    /**
     * Creates the multicast socket and applies loopback mode, outgoing interface, read timeout and TTL. If applying
     * any of these settings fails, the new socket is closed before the exception is rethrown so it is not leaked.
     *
     * @throws IOException if the socket cannot be created or configured
     */
    protected void setupSocket() throws IOException {
        if (mcastBindAddress != null) {
            try {
                log.info(sm.getString("mcastServiceImpl.bind", address, Integer.toString(port)));
                socket = new MulticastSocket(new InetSocketAddress(address, port));
            } catch (BindException e) {
                /*
                 * On some platforms (e.g. Linux) it is not possible to bind to the multicast address. In this case only
                 * bind to the port.
                 */
                log.info(sm.getString("mcastServiceImpl.bind.failed"));
                socket = new MulticastSocket(port);
            }
        } else {
            socket = new MulticastSocket(port);
        }
        try {
            // Hint if we want disable loop back(local machine) messages
            socket.setLoopbackMode(localLoopbackDisabled);
            if (mcastBindAddress != null) {
                if (log.isInfoEnabled()) {
                    log.info(sm.getString("mcastServiceImpl.setInterface", mcastBindAddress));
                }
                NetworkInterface networkInterface = NetworkInterface.getByInetAddress(mcastBindAddress);
                socket.setNetworkInterface(networkInterface);
            }
            // force a so timeout so that we don't block forever
            if (mcastSoTimeout <= 0) {
                mcastSoTimeout = (int) sendFrequency;
            }
            if (log.isInfoEnabled()) {
                log.info(sm.getString("mcastServiceImpl.setSoTimeout", Integer.toString(mcastSoTimeout)));
            }
            socket.setSoTimeout(mcastSoTimeout);
            if (mcastTTL >= 0) {
                if (log.isInfoEnabled()) {
                    log.info(sm.getString("mcastServiceImpl.setTTL", Integer.toString(mcastTTL)));
                }
                socket.setTimeToLive(mcastTTL);
            }
        } catch (IOException | RuntimeException e) {
            // Don't leak a bound socket when its configuration fails.
            socket.close();
            throw e;
        }
    }

    /**
     * Starts the receiver and/or sender for the requested level, joining the multicast group the first time either is
     * started, then waits {@code 2 * sendFrequency} ms for membership to establish.
     *
     * @param level bit mask of {@code Channel.MBR_RX_SEQ} and/or {@code Channel.MBR_TX_SEQ}
     *
     * @throws IOException              if joining the group or sending the first ping fails
     * @throws IllegalStateException    if the requested receiver or sender is already running
     * @throws IllegalArgumentException if the level contains neither sequence
     */
    @Override
    public synchronized void start(int level) throws IOException {
        boolean valid = false;
        if ((level & Channel.MBR_RX_SEQ) == Channel.MBR_RX_SEQ) {
            if (receiver != null) {
                throw new IllegalStateException(sm.getString("mcastServiceImpl.receive.running"));
            }
            try {
                // join only once: the sender may already have joined
                if (sender == null) {
                    socket.joinGroup(new InetSocketAddress(address, 0), null);
                }
            } catch (IOException iox) {
                log.error(sm.getString("mcastServiceImpl.unable.join"));
                throw iox;
            }
            doRunReceiver = true;
            receiver = new ReceiverThread();
            receiver.setDaemon(true);
            receiver.start();
            valid = true;
        }
        if ((level & Channel.MBR_TX_SEQ) == Channel.MBR_TX_SEQ) {
            if (sender != null) {
                throw new IllegalStateException(sm.getString("mcastServiceImpl.send.running"));
            }
            // join only once: the receiver may already have joined
            if (receiver == null) {
                socket.joinGroup(new InetSocketAddress(address, 0), null);
            }
            // make sure at least one packet gets out there
            send(false);
            doRunSender = true;
            sender = new SenderThread(sendFrequency);
            sender.setDaemon(true);
            sender.start();
            // we have started the sender, but not yet waited for membership to establish
            valid = true;
        }
        if (!valid) {
            throw new IllegalArgumentException(sm.getString("mcastServiceImpl.invalid.startLevel"));
        }
        // pause, once or twice
        waitForMembers(level);
        startLevel = (startLevel | level);
    }

    /**
     * Sleeps for two send periods so that pings from other members can arrive. If interrupted, the wait ends early
     * and the thread's interrupt status is restored for the caller.
     */
    private void waitForMembers(int level) {
        long memberwait = sendFrequency * 2;
        if (log.isInfoEnabled()) {
            log.info(sm.getString("mcastServiceImpl.waitForMembers.start", Long.toString(memberwait),
                    Integer.toString(level)));
        }
        try {
            Thread.sleep(memberwait);
        } catch (InterruptedException e) {
            // Don't lose the interrupt: the caller may want to react to it.
            Thread.currentThread().interrupt();
        }
        if (log.isInfoEnabled()) {
            log.info(sm.getString("mcastServiceImpl.waitForMembers.done", Integer.toString(level)));
        }
    }

    /**
     * Stops the receiver and/or sender for the requested level. When no level remains started, a shutdown ping is
     * sent and the group is left and the socket closed. The group is left and the socket closed even if sending the
     * shutdown ping fails.
     *
     * @param level bit mask of {@code Channel.MBR_RX_SEQ} and/or {@code Channel.MBR_TX_SEQ}
     *
     * @return {@code true} if no level is started any more
     *
     * @throws IOException              if sending the shutdown ping fails (cleanup has still been performed)
     * @throws IllegalArgumentException if the level contains neither sequence
     */
    @Override
    public synchronized boolean stop(int level) throws IOException {
        boolean valid = false;
        if ((level & Channel.MBR_RX_SEQ) == Channel.MBR_RX_SEQ) {
            valid = true;
            doRunReceiver = false;
            if (receiver != null) {
                receiver.interrupt();
            }
            receiver = null;
        }
        if ((level & Channel.MBR_TX_SEQ) == Channel.MBR_TX_SEQ) {
            valid = true;
            doRunSender = false;
            if (sender != null) {
                sender.interrupt();
            }
            sender = null;
        }
        if (!valid) {
            throw new IllegalArgumentException(sm.getString("mcastServiceImpl.invalid.stopLevel"));
        }
        startLevel = (startLevel & (~level));
        // we're shutting down, send a shutdown message and close the socket
        if (startLevel == 0) {
            try {
                // send a stop message so receivers remove this member instead of waiting for it to expire
                member.setCommand(Member.SHUTDOWN_PAYLOAD);
                send(false);
            } finally {
                // leave mcast group
                try {
                    socket.leaveGroup(new InetSocketAddress(address, 0), null);
                } catch (Exception ignore) {
                    // NO-OP
                }
                try {
                    socket.close();
                } catch (Exception ignore) {
                    // NO-OP
                }
                member.setServiceStartTime(-1);
            }
        }
        return (startLevel == 0);
    }

    /**
     * Receive a datagram packet, blocking until a packet arrives or the socket timeout elapses, dispatch it as member
     * data or as a message broadcast, and then check for expired members.
     *
     * @throws IOException Received failed
     */
    public void receive() throws IOException {
        try {
            socket.receive(receivePacket);
            // defensive: the receive buffer is MAX_PACKET_SIZE long
            if (receivePacket.getLength() > MAX_PACKET_SIZE) {
                log.error(sm.getString("mcastServiceImpl.packet.tooLong", Integer.toString(receivePacket.getLength())));
            } else {
                byte[] data = new byte[receivePacket.getLength()];
                System.arraycopy(receivePacket.getData(), receivePacket.getOffset(), data, 0, data.length);
                if (XByteBuffer.firstIndexOf(data, 0, MemberImpl.TRIBES_MBR_BEGIN) == 0) {
                    memberDataReceived(data);
                } else {
                    memberBroadcastsReceived(data);
                }
            }
        } catch (SocketTimeoutException x) {
            // do nothing, this is normal, we don't want to block forever
            // since the receive thread is the same thread
            // that does membership expiration
        }
        checkExpired();
    }

    /**
     * Handles a membership ping. A shutdown ping removes the member and triggers {@code memberDisappeared}. A ping
     * that {@code membership.memberAlive} reports as new triggers {@code memberAdded}. Listener callbacks run on the
     * executor.
     */
    private void memberDataReceived(byte[] data) {
        final Member m = MemberImpl.getMember(data);
        if (log.isTraceEnabled()) {
            log.trace("Mcast receive ping from member " + m);
        }
        Runnable t = null;
        if (Arrays.equals(m.getCommand(), Member.SHUTDOWN_PAYLOAD)) {
            if (log.isDebugEnabled()) {
                log.debug(sm.getString("mcastServiceImpl.memberShutdown", m));
            }
            membership.removeMember(m);
            t = () -> runWithThreadName("Membership-MemberDisappeared", () -> service.memberDisappeared(m));
        } else if (membership.memberAlive(m)) {
            if (log.isDebugEnabled()) {
                log.debug(sm.getString("mcastServiceImpl.memberAdd", m));
            }
            t = () -> runWithThreadName("Membership-MemberAdded", () -> service.memberAdded(m));
        }
        if (t != null) {
            executor.execute(t);
        }
    }

    /**
     * Extracts the channel messages contained in a broadcast packet. Each message not sent by the local member is
     * delivered to the message listener on the executor.
     */
    private void memberBroadcastsReceived(final byte[] b) {
        if (log.isTraceEnabled()) {
            log.trace("Mcast received broadcasts.");
        }
        XByteBuffer buffer = new XByteBuffer(b, true);
        if (buffer.countPackages(true) > 0) {
            int count = buffer.countPackages();
            final ChannelData[] data = new ChannelData[count];
            for (int i = 0; i < count; i++) {
                try {
                    data[i] = buffer.extractPackage(true);
                } catch (IllegalStateException ise) {
                    // leave the slot null; it is skipped on delivery
                    log.debug(sm.getString("mcastServiceImpl.messageError"), ise);
                }
            }
            Runnable t = () -> runWithThreadName("Membership-MemberAdded", () -> {
                for (ChannelData datum : data) {
                    try {
                        if (datum != null && !member.equals(datum.getAddress())) {
                            msgservice.messageReceived(datum);
                        }
                    } catch (Throwable t1) {
                        if (t1 instanceof ThreadDeath) {
                            throw (ThreadDeath) t1;
                        }
                        if (t1 instanceof VirtualMachineError) {
                            throw (VirtualMachineError) t1;
                        }
                        log.error(sm.getString("mcastServiceImpl.unableReceive.broadcastMessage"), t1);
                    }
                }
            });
            executor.execute(t);
        }
    }

    /**
     * Runs {@code action} with the name of the thread that actually executes it temporarily set to
     * {@code threadName}, restoring the previous name afterwards. The thread is looked up at run time, so a task
     * submitted to the executor renames the executor thread and not the thread that submitted it.
     */
    private static void runWithThreadName(String threadName, Runnable action) {
        Thread currentThread = Thread.currentThread();
        String previousName = currentThread.getName();
        try {
            currentThread.setName(threadName);
            action.run();
        } finally {
            currentThread.setName(previousName);
        }
    }

    protected final Object expiredMutex = new Object();

    /**
     * Removes the members that have exceeded {@code timeToExpiration} and notifies the listener of each one on the
     * executor. Callers from the sender and receiver threads are serialized on {@link #expiredMutex}.
     */
    protected void checkExpired() {
        synchronized (expiredMutex) {
            Member[] expired = membership.expire(timeToExpiration);
            for (final Member expiredMember : expired) {
                if (log.isDebugEnabled()) {
                    log.debug(sm.getString("mcastServiceImpl.memberExpire", expiredMember));
                }
                try {
                    Runnable t = () -> runWithThreadName("Membership-MemberExpired",
                            () -> service.memberDisappeared(expiredMember));
                    executor.execute(t);
                } catch (Exception x) {
                    log.error(sm.getString("mcastServiceImpl.memberDisappeared.failed"), x);
                }
            }
        }
    }

    /**
     * Send a ping.
     *
     * @param checkexpired <code>true</code> to check for expiration
     *
     * @throws IOException Send error
     */
    public void send(boolean checkexpired) throws IOException {
        send(checkexpired, null);
    }

    private final Object sendLock = new Object();

    /**
     * Sends {@code packet} to the multicast group, or a ping for the local member when {@code packet} is
     * {@code null}. The packet's address and port are overwritten with the multicast address and port.
     *
     * @param checkexpired {@code true} to check for expired members afterwards (only honoured for pings)
     * @param packet       the broadcast to send, or {@code null} to send a membership ping
     *
     * @throws IOException Send error
     */
    public void send(boolean checkexpired, DatagramPacket packet) throws IOException {
        checkexpired = (checkexpired && (packet == null));
        // ignore if we haven't started the sender
        // if ( (startLevel&Channel.MBR_TX_SEQ) != Channel.MBR_TX_SEQ ) return;
        if (packet == null) {
            member.inc();
            if (log.isTraceEnabled()) {
                log.trace("Mcast send ping from member " + member);
            }
            byte[] data = member.getData();
            packet = new DatagramPacket(data, data.length);
        } else if (log.isTraceEnabled()) {
            log.trace("Sending message broadcast " + packet.getLength() + " bytes from " + member);
        }
        packet.setAddress(address);
        packet.setPort(port);
        // send() is public and is called from the sender thread as well as from start()/stop(),
        // so writes to the shared socket are serialized.
        synchronized (sendLock) {
            socket.send(packet);
        }
        if (checkexpired) {
            checkExpired();
        }
    }

    public long getServiceStartTime() {
        return (member != null) ? member.getServiceStartTime() : -1L;
    }

    public int getRecoveryCounter() {
        return recoveryCounter;
    }

    public boolean isRecoveryEnabled() {
        return recoveryEnabled;
    }

    public long getRecoverySleepTime() {
        return recoverySleepTime;
    }

    public Channel getChannel() {
        return channel;
    }

    public void setChannel(Channel channel) {
        this.channel = channel;
    }

    /**
     * Returns {@code "[<channel name>]"} for use in thread names, or an empty string when no channel (or no channel
     * name) has been set.
     */
    private String channelNameSuffix() {
        Channel ch = channel;
        if (ch == null || ch.getName() == null) {
            return "";
        }
        return "[" + ch.getName() + "]";
    }

    /**
     * Thread that repeatedly calls {@link #receive()} while {@code doRunReceiver} is set. After
     * {@code recoveryCounter} consecutive failures it triggers a recovery.
     */
    public class ReceiverThread extends Thread {
        int errorCounter = 0;

        public ReceiverThread() {
            super();
            setName("Tribes-MembershipReceiver" + channelNameSuffix());
        }

        @Override
        public void run() {
            while (doRunReceiver) {
                try {
                    receive();
                    errorCounter = 0;
                } catch (ArrayIndexOutOfBoundsException ax) {
                    // we can ignore this, as it means we have an invalid package
                    // but we will log it to debug
                    if (log.isDebugEnabled()) {
                        log.debug(sm.getString("mcastServiceImpl.invalidMemberPackage"), ax);
                    }
                } catch (Exception x) {
                    if (errorCounter == 0 && doRunReceiver) {
                        log.warn(sm.getString("mcastServiceImpl.error.receiving"), x);
                    } else if (log.isDebugEnabled()) {
                        if (doRunReceiver) {
                            log.debug(sm.getString("mcastServiceImpl.error.receiving"), x);
                        } else {
                            // NOTE(review): warn-level output guarded by isDebugEnabled(); kept as-is, confirm intent
                            log.warn(sm.getString("mcastServiceImpl.error.receivingNoSleep"), x);
                        }
                    }
                    if (doRunReceiver) {
                        try {
                            sleep(500);
                        } catch (Exception ignore) {
                            // interrupted by stop(); the loop re-checks doRunReceiver
                        }
                        if ((++errorCounter) >= recoveryCounter) {
                            errorCounter = 0;
                            RecoveryThread.recover(McastServiceImpl.this);
                        }
                    }
                }
            }
        }
    }// class ReceiverThread

    /**
     * Thread that sends a ping every {@code time} ms while {@code doRunSender} is set. After
     * {@code recoveryCounter} consecutive failures it triggers a recovery.
     */
    public class SenderThread extends Thread {
        final long time;
        int errorCounter = 0;

        public SenderThread(long time) {
            this.time = time;
            setName("Tribes-MembershipSender" + channelNameSuffix());
        }

        @Override
        public void run() {
            while (doRunSender) {
                try {
                    send(true);
                    errorCounter = 0;
                } catch (Exception x) {
                    if (errorCounter == 0) {
                        log.warn(sm.getString("mcastServiceImpl.send.failed"), x);
                    } else {
                        log.debug(sm.getString("mcastServiceImpl.send.failed"), x);
                    }
                    if ((++errorCounter) >= recoveryCounter) {
                        errorCounter = 0;
                        RecoveryThread.recover(McastServiceImpl.this);
                    }
                }
                try {
                    sleep(time);
                } catch (Exception ignore) {
                    // interrupted by stop(); the loop re-checks doRunSender
                }
            }
        }
    }// class SenderThread

    /**
     * Thread that stops and restarts the service until the restart succeeds. The static {@code running} flag allows
     * at most one recovery thread at a time.
     */
    protected static class RecoveryThread extends Thread {
        private static final AtomicBoolean running = new AtomicBoolean(false);

        public static synchronized void recover(McastServiceImpl parent) {
            if (!parent.isRecoveryEnabled()) {
                return;
            }
            // Compute the name before claiming the flag so the claim is followed only by thread creation.
            String threadName = "Tribes-MembershipRecovery" + parent.channelNameSuffix();
            if (!running.compareAndSet(false, true)) {
                return;
            }
            Thread t = new RecoveryThread(parent);
            t.setName(threadName);
            t.setDaemon(true);
            t.start();
        }

        final McastServiceImpl parent;

        public RecoveryThread(McastServiceImpl parent) {
            this.parent = parent;
        }

        public boolean stopService() {
            try {
                parent.stop(Channel.MBR_RX_SEQ | Channel.MBR_TX_SEQ);
                return true;
            } catch (Exception x) {
                log.warn(sm.getString("mcastServiceImpl.recovery.stopFailed"), x);
                return false;
            }
        }

        public boolean startService() {
            try {
                parent.init();
                parent.start(Channel.MBR_RX_SEQ | Channel.MBR_TX_SEQ);
                return true;
            } catch (Exception x) {
                log.warn(sm.getString("mcastServiceImpl.recovery.startFailed"), x);
                return false;
            }
        }

        @Override
        public void run() {
            boolean success = false;
            int attempt = 0;
            try {
                while (!success) {
                    if (log.isInfoEnabled()) {
                        log.info(sm.getString("mcastServiceImpl.recovery"));
                    }
                    // non-short-circuit '&': attempt the restart even if stopping failed
                    if (stopService() & startService()) {
                        success = true;
                        if (log.isInfoEnabled()) {
                            log.info(sm.getString("mcastServiceImpl.recovery.successful"));
                        }
                    }
                    try {
                        if (!success) {
                            if (log.isInfoEnabled()) {
                                log.info(sm.getString("mcastServiceImpl.recovery.failed", Integer.toString(++attempt),
                                        Long.toString(parent.recoverySleepTime)));
                            }
                            sleep(parent.recoverySleepTime);
                        }
                    } catch (InterruptedException ignore) {
                        // keep retrying; re-interrupting here would make every following sleep fail immediately
                    }
                }
            } finally {
                running.set(false);
            }
        }
    }

    public void setRecoveryCounter(int recoveryCounter) {
        this.recoveryCounter = recoveryCounter;
    }

    public void setRecoveryEnabled(boolean recoveryEnabled) {
        this.recoveryEnabled = recoveryEnabled;
    }

    public void setRecoverySleepTime(long recoverySleepTime) {
        this.recoverySleepTime = recoverySleepTime;
    }
}