WsRemoteEndpointImplServer.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.server;
import java.io.EOFException;
import java.io.IOException;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import jakarta.servlet.http.WebConnection;
import jakarta.websocket.SendHandler;
import jakarta.websocket.SendResult;
import org.apache.coyote.http11.upgrade.UpgradeInfo;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.util.net.SocketWrapperBase;
import org.apache.tomcat.util.net.SocketWrapperBase.BlockingMode;
import org.apache.tomcat.util.res.StringManager;
import org.apache.tomcat.websocket.Constants;
import org.apache.tomcat.websocket.Transformation;
import org.apache.tomcat.websocket.WsRemoteEndpointImplBase;
/**
* This is the server side {@link jakarta.websocket.RemoteEndpoint} implementation - i.e. what the server uses to send
* data to the client.
*/
public class WsRemoteEndpointImplServer extends WsRemoteEndpointImplBase {
private static final StringManager sm = StringManager.getManager(WsRemoteEndpointImplServer.class);
private final Log log = LogFactory.getLog(WsRemoteEndpointImplServer.class); // must not be static
private final SocketWrapperBase<?> socketWrapper;
private final UpgradeInfo upgradeInfo;
private final WebConnection connection;
private final WsWriteTimeout wsWriteTimeout;
private volatile SendHandler handler = null;
private volatile ByteBuffer[] buffers = null;
private volatile long timeoutExpiry = -1;
public WsRemoteEndpointImplServer(SocketWrapperBase<?> socketWrapper, UpgradeInfo upgradeInfo,
WsServerContainer serverContainer, WebConnection connection) {
this.socketWrapper = socketWrapper;
this.upgradeInfo = upgradeInfo;
this.connection = connection;
this.wsWriteTimeout = serverContainer.getTimeout();
}
@Override
protected final boolean isMasked() {
return false;
}
/**
* {@inheritDoc}
* <p>
* The close message is a special case. It needs to be blocking else implementing the clean-up that follows the
* sending of the close message gets a lot more complicated. On the server, this creates additional complications as
* a dead-lock may occur in the following scenario:
* <ol>
* <li>Application thread writes message using non-blocking</li>
* <li>Write does not complete (write logic holds message pending lock)</li>
* <li>Socket is added to poller (or equivalent) for write
* <li>Client sends close message</li>
* <li>Container processes received close message and tries to send close message in response</li>
* <li>Container holds socket lock and is blocked waiting for message pending lock</li>
* <li>Poller fires write possible event for socket</li>
* <li>Container tries to process write possible event but is blocked waiting for socket lock</li>
* <li>Processing of the WebSocket connection is dead-locked until the original message write times out</li>
* </ol>
* The purpose of this method is to break the above dead-lock. It does this by returning control of the processor to
* the socket wrapper and releasing the socket lock while waiting for the pending message write to complete.
* Normally, that would be a terrible idea as it creates the possibility that the processor is returned to the pool
* more than once under various error conditions. In this instance it is safe because these are upgrade processors
* (isUpgrade() returns {@code true}) and upgrade processors are never pooled.
* <p>
* TODO: Despite the complications it creates, it would be worth exploring the possibility of processing a received
* close frame in a non-blocking manner.
*/
@Override
protected boolean acquireMessagePartInProgressSemaphore(byte opCode, long timeoutExpiry)
throws InterruptedException {
/*
* Special handling is required only when all of the following are true:
* - A close message is being sent
* - This thread currently holds the socketWrapper lock (i.e. the thread is current processing a socket event)
*
* Special handling is only possible if the socketWrapper lock is a ReentrantLock (it will be by default)
*/
if (socketWrapper.getLock() instanceof ReentrantLock) {
ReentrantLock reentrantLock = (ReentrantLock) socketWrapper.getLock();
if (opCode == Constants.OPCODE_CLOSE && reentrantLock.isHeldByCurrentThread()) {
int socketWrapperLockCount = reentrantLock.getHoldCount();
while (!messagePartInProgress.tryAcquire()) {
if (timeoutExpiry < System.currentTimeMillis()) {
return false;
}
try {
// Release control of the processor
socketWrapper.setCurrentProcessor(connection);
// Release the per socket lock(s)
for (int i = 0; i < socketWrapperLockCount; i++) {
socketWrapper.getLock().unlock();
}
// Provide opportunity for another thread to obtain the socketWrapper lock
Thread.yield();
} finally {
// Re-obtain the per socket lock(s)
for (int i = 0; i < socketWrapperLockCount; i++) {
socketWrapper.getLock().lock();
}
// Re-take control of the processor
socketWrapper.takeCurrentProcessor();
}
}
return true;
}
}
// Skip special handling
return super.acquireMessagePartInProgressSemaphore(opCode, timeoutExpiry);
}
@Override
protected void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, ByteBuffer... buffers) {
if (socketWrapper.hasAsyncIO()) {
final boolean block = (blockingWriteTimeoutExpiry != -1);
long timeout = -1;
if (block) {
timeout = blockingWriteTimeoutExpiry - System.currentTimeMillis();
if (timeout <= 0) {
SendResult sr = new SendResult(new SocketTimeoutException());
handler.onResult(sr);
return;
}
} else {
this.handler = handler;
timeout = getSendTimeout();
if (timeout > 0) {
// Register with timeout thread
timeoutExpiry = timeout + System.currentTimeMillis();
wsWriteTimeout.register(this);
}
}
socketWrapper.write(block ? BlockingMode.BLOCK : BlockingMode.SEMI_BLOCK, timeout, TimeUnit.MILLISECONDS,
null, SocketWrapperBase.COMPLETE_WRITE_WITH_COMPLETION, new CompletionHandler<Long, Void>() {
@Override
public void completed(Long result, Void attachment) {
if (block) {
long timeout = blockingWriteTimeoutExpiry - System.currentTimeMillis();
if (timeout <= 0) {
failed(new SocketTimeoutException(), null);
} else {
handler.onResult(SENDRESULT_OK);
}
} else {
wsWriteTimeout.unregister(WsRemoteEndpointImplServer.this);
clearHandler(null, true);
}
}
@Override
public void failed(Throwable exc, Void attachment) {
if (block) {
SendResult sr = new SendResult(exc);
handler.onResult(sr);
} else {
wsWriteTimeout.unregister(WsRemoteEndpointImplServer.this);
clearHandler(exc, true);
close();
}
}
}, buffers);
} else {
if (blockingWriteTimeoutExpiry == -1) {
this.handler = handler;
this.buffers = buffers;
// This is definitely the same thread that triggered the write so a
// dispatch will be required.
onWritePossible(true);
} else {
// Blocking
try {
for (ByteBuffer buffer : buffers) {
long timeout = blockingWriteTimeoutExpiry - System.currentTimeMillis();
if (timeout <= 0) {
SendResult sr = new SendResult(new SocketTimeoutException());
handler.onResult(sr);
return;
}
socketWrapper.setWriteTimeout(timeout);
socketWrapper.write(true, buffer);
}
long timeout = blockingWriteTimeoutExpiry - System.currentTimeMillis();
if (timeout <= 0) {
SendResult sr = new SendResult(new SocketTimeoutException());
handler.onResult(sr);
return;
}
socketWrapper.setWriteTimeout(timeout);
socketWrapper.flush(true);
handler.onResult(SENDRESULT_OK);
} catch (IOException e) {
SendResult sr = new SendResult(e);
handler.onResult(sr);
}
}
}
}
@Override
protected void updateStats(long payloadLength) {
upgradeInfo.addMsgsSent(1);
upgradeInfo.addBytesSent(payloadLength);
}
public void onWritePossible(boolean useDispatch) {
// Note: Unused for async IO
ByteBuffer[] buffers = this.buffers;
if (buffers == null) {
// Servlet 3.1 will call the write listener once even if nothing
// was written
return;
}
boolean complete = false;
try {
socketWrapper.flush(false);
// If this is false there will be a call back when it is true
while (socketWrapper.isReadyForWrite()) {
complete = true;
for (ByteBuffer buffer : buffers) {
if (buffer.hasRemaining()) {
complete = false;
socketWrapper.write(false, buffer);
break;
}
}
if (complete) {
socketWrapper.flush(false);
complete = socketWrapper.isReadyForWrite();
if (complete) {
wsWriteTimeout.unregister(this);
clearHandler(null, useDispatch);
}
break;
}
}
} catch (IOException | IllegalStateException e) {
wsWriteTimeout.unregister(this);
clearHandler(e, useDispatch);
close();
}
if (!complete) {
// Async write is in progress
long timeout = getSendTimeout();
if (timeout > 0) {
// Register with timeout thread
timeoutExpiry = timeout + System.currentTimeMillis();
wsWriteTimeout.register(this);
}
}
}
@Override
protected void doClose() {
if (handler != null) {
// close() can be triggered by a wide range of scenarios. It is far
// simpler just to always use a dispatch than it is to try and track
// whether or not this method was called by the same thread that
// triggered the write
clearHandler(new EOFException(), true);
}
try {
socketWrapper.close();
} catch (Exception e) {
if (log.isInfoEnabled()) {
log.info(sm.getString("wsRemoteEndpointServer.closeFailed"), e);
}
}
wsWriteTimeout.unregister(this);
}
protected long getTimeoutExpiry() {
return timeoutExpiry;
}
/*
* Currently this is only called from the background thread so we could just call clearHandler() with useDispatch ==
* false but the method parameter was added in case other callers started to use this method to make sure that those
* callers think through what the correct value of useDispatch is for them.
*/
protected void onTimeout(boolean useDispatch) {
if (handler != null) {
clearHandler(new SocketTimeoutException(), useDispatch);
}
close();
}
@Override
protected void setTransformation(Transformation transformation) {
// Overridden purely so it is visible to other classes in this package
super.setTransformation(transformation);
}
/**
* @param t The throwable associated with any error that occurred
* @param useDispatch Should {@link SendHandler#onResult(SendResult)} be called from a new thread, keeping in mind
* the requirements of {@link jakarta.websocket.RemoteEndpoint.Async}
*/
void clearHandler(Throwable t, boolean useDispatch) {
// Setting the result marks this (partial) message as
// complete which means the next one may be sent which
// could update the value of the handler. Therefore, keep a
// local copy before signalling the end of the (partial)
// message.
SendHandler sh = handler;
handler = null;
buffers = null;
if (sh != null) {
if (useDispatch) {
OnResultRunnable r = new OnResultRunnable(sh, t);
try {
socketWrapper.execute(r);
} catch (RejectedExecutionException ree) {
// Can't use the executor so call the runnable directly.
// This may not be strictly specification compliant in all
// cases but during shutdown only close messages are going
// to be sent so there should not be the issue of nested
// calls leading to stack overflow as described in bug
// 55715. The issues with nested calls was the reason for
// the separate thread requirement in the specification.
r.run();
}
} else {
if (t == null) {
sh.onResult(new SendResult());
} else {
sh.onResult(new SendResult(t));
}
}
}
}
@Override
protected Lock getLock() {
return socketWrapper.getLock();
}
private static class OnResultRunnable implements Runnable {
private final SendHandler sh;
private final Throwable t;
private OnResultRunnable(SendHandler sh, Throwable t) {
this.sh = sh;
this.t = t;
}
@Override
public void run() {
if (t == null) {
sh.onResult(new SendResult());
} else {
sh.onResult(new SendResult(t));
}
}
}
}