mirror of
https://github.com/apache/cloudstack.git
synced 2025-10-26 08:42:29 +01:00
Merge pull request #1493 from shapeblue/nio-fix
CLOUDSTACK-9348: Use non-blocking SSL handshake in NioConnection/Link- Uses non-blocking socket config in NioClient and NioServer/NioConnection - Scalable connectivity from agents and peer clustered-management server - Removes blocking ssl handshake code with a non-blocking code - Protects from denial-of-service issues that can degrade mgmt server responsiveness due to an aggressive/malicious client - Uses separate executor services for handling connect/accept events Changes are covered the NioTest so I did not write a new test, advise how we can improve this. Further, I tried to invest time on writing a benchmark test to reproduce a degraded server but could not write it deterministic-ally (sometimes fails/passes but not always). Review, CI testing and feedback requested /cc @swill @jburwell @DaanHoogland @wido @remibergsma @rafaelweingartner @GabrielBrascher * pr/1493: CLOUDSTACK-9348: Use non-blocking SSL handshake CLOUDSTACK-9348: Unit test to demonstrate denial of service attack Signed-off-by: Will Stevens <williamstevens@gmail.com>
This commit is contained in:
commit
7ce0e10fbc
@ -499,7 +499,7 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
|
||||
SocketChannel ch1 = null;
|
||||
try {
|
||||
ch1 = SocketChannel.open(new InetSocketAddress(addr, Port.value()));
|
||||
ch1.configureBlocking(true); // make sure we are working at blocking mode
|
||||
ch1.configureBlocking(false);
|
||||
ch1.socket().setKeepAlive(true);
|
||||
ch1.socket().setSoTimeout(60 * 1000);
|
||||
try {
|
||||
@ -507,8 +507,11 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
|
||||
sslEngine = sslContext.createSSLEngine(ip, Port.value());
|
||||
sslEngine.setUseClientMode(true);
|
||||
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
|
||||
|
||||
Link.doHandshake(ch1, sslEngine, true);
|
||||
sslEngine.beginHandshake();
|
||||
if (!Link.doHandshake(ch1, sslEngine, true)) {
|
||||
ch1.close();
|
||||
throw new IOException("SSL handshake failed!");
|
||||
}
|
||||
s_logger.info("SSL: Handshake done");
|
||||
} catch (final Exception e) {
|
||||
ch1.close();
|
||||
|
||||
@ -208,7 +208,6 @@
|
||||
<excludes>
|
||||
<exclude>com/cloud/utils/testcase/*TestCase*</exclude>
|
||||
<exclude>com/cloud/utils/db/*Test*</exclude>
|
||||
<exclude>com/cloud/utils/testcase/NioTest.java</exclude>
|
||||
</excludes>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
||||
@ -19,36 +19,32 @@
|
||||
|
||||
package com.cloud.utils.nio;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.SocketTimeoutException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.Channels;
|
||||
import java.nio.channels.ClosedChannelException;
|
||||
import java.nio.channels.ReadableByteChannel;
|
||||
import java.nio.channels.SelectionKey;
|
||||
import java.nio.channels.SocketChannel;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.security.KeyStore;
|
||||
import java.util.concurrent.ConcurrentLinkedQueue;
|
||||
import com.cloud.utils.PropertiesUtil;
|
||||
import com.cloud.utils.db.DbProperties;
|
||||
import org.apache.cloudstack.utils.security.SSLUtils;
|
||||
import org.apache.log4j.Logger;
|
||||
|
||||
import javax.net.ssl.KeyManagerFactory;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLEngine;
|
||||
import javax.net.ssl.SSLEngineResult;
|
||||
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
|
||||
import javax.net.ssl.SSLException;
|
||||
import javax.net.ssl.SSLSession;
|
||||
import javax.net.ssl.TrustManager;
|
||||
import javax.net.ssl.TrustManagerFactory;
|
||||
|
||||
import org.apache.cloudstack.utils.security.SSLUtils;
|
||||
import org.apache.log4j.Logger;
|
||||
|
||||
import com.cloud.utils.PropertiesUtil;
|
||||
import com.cloud.utils.db.DbProperties;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.ClosedChannelException;
|
||||
import java.nio.channels.SelectionKey;
|
||||
import java.nio.channels.SocketChannel;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.security.KeyStore;
|
||||
import java.util.concurrent.ConcurrentLinkedQueue;
|
||||
|
||||
/**
|
||||
*/
|
||||
@ -453,115 +449,185 @@ public class Link {
|
||||
return sslContext;
|
||||
}
|
||||
|
||||
public static void doHandshake(SocketChannel ch, SSLEngine sslEngine, boolean isClient) throws IOException {
|
||||
if (s_logger.isTraceEnabled()) {
|
||||
s_logger.trace("SSL: begin Handshake, isClient: " + isClient);
|
||||
public static ByteBuffer enlargeBuffer(ByteBuffer buffer, final int sessionProposedCapacity) {
|
||||
if (buffer == null || sessionProposedCapacity < 0) {
|
||||
return buffer;
|
||||
}
|
||||
|
||||
SSLEngineResult engResult;
|
||||
SSLSession sslSession = sslEngine.getSession();
|
||||
HandshakeStatus hsStatus;
|
||||
ByteBuffer in_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
|
||||
ByteBuffer in_appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
|
||||
ByteBuffer out_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
|
||||
ByteBuffer out_appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
|
||||
int count;
|
||||
ch.socket().setSoTimeout(60 * 1000);
|
||||
InputStream inStream = ch.socket().getInputStream();
|
||||
// Use readCh to make sure the timeout on reading is working
|
||||
ReadableByteChannel readCh = Channels.newChannel(inStream);
|
||||
|
||||
if (isClient) {
|
||||
hsStatus = SSLEngineResult.HandshakeStatus.NEED_WRAP;
|
||||
if (sessionProposedCapacity > buffer.capacity()) {
|
||||
buffer = ByteBuffer.allocate(sessionProposedCapacity);
|
||||
} else {
|
||||
hsStatus = SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
|
||||
buffer = ByteBuffer.allocate(buffer.capacity() * 2);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
while (hsStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
|
||||
if (s_logger.isTraceEnabled()) {
|
||||
s_logger.trace("SSL: Handshake status " + hsStatus);
|
||||
}
|
||||
engResult = null;
|
||||
if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
|
||||
out_pkgBuf.clear();
|
||||
out_appBuf.clear();
|
||||
out_appBuf.put("Hello".getBytes());
|
||||
engResult = sslEngine.wrap(out_appBuf, out_pkgBuf);
|
||||
out_pkgBuf.flip();
|
||||
int remain = out_pkgBuf.limit();
|
||||
while (remain != 0) {
|
||||
remain -= ch.write(out_pkgBuf);
|
||||
if (remain < 0) {
|
||||
throw new IOException("Too much bytes sent?");
|
||||
}
|
||||
}
|
||||
} else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
|
||||
in_appBuf.clear();
|
||||
// One packet may contained multiply operation
|
||||
if (in_pkgBuf.position() == 0 || !in_pkgBuf.hasRemaining()) {
|
||||
in_pkgBuf.clear();
|
||||
count = 0;
|
||||
try {
|
||||
count = readCh.read(in_pkgBuf);
|
||||
} catch (SocketTimeoutException ex) {
|
||||
if (s_logger.isTraceEnabled()) {
|
||||
s_logger.trace("Handshake reading time out! Cut the connection");
|
||||
}
|
||||
count = -1;
|
||||
}
|
||||
if (count == -1) {
|
||||
throw new IOException("Connection closed with -1 on reading size.");
|
||||
}
|
||||
in_pkgBuf.flip();
|
||||
}
|
||||
engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
|
||||
ByteBuffer tmp_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
|
||||
int loop_count = 0;
|
||||
while (engResult.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
|
||||
// The client is too slow? Cut it and let it reconnect
|
||||
if (loop_count > 10) {
|
||||
throw new IOException("Too many times in SSL BUFFER_UNDERFLOW, disconnect guest.");
|
||||
}
|
||||
// We need more packets to complete this operation
|
||||
if (s_logger.isTraceEnabled()) {
|
||||
s_logger.trace("SSL: Buffer underflowed, getting more packets");
|
||||
}
|
||||
tmp_pkgBuf.clear();
|
||||
count = ch.read(tmp_pkgBuf);
|
||||
if (count == -1) {
|
||||
throw new IOException("Connection closed with -1 on reading size.");
|
||||
}
|
||||
tmp_pkgBuf.flip();
|
||||
|
||||
in_pkgBuf.mark();
|
||||
in_pkgBuf.position(in_pkgBuf.limit());
|
||||
in_pkgBuf.limit(in_pkgBuf.limit() + tmp_pkgBuf.limit());
|
||||
in_pkgBuf.put(tmp_pkgBuf);
|
||||
in_pkgBuf.reset();
|
||||
|
||||
in_appBuf.clear();
|
||||
engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
|
||||
loop_count++;
|
||||
}
|
||||
} else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK) {
|
||||
Runnable run;
|
||||
while ((run = sslEngine.getDelegatedTask()) != null) {
|
||||
if (s_logger.isTraceEnabled()) {
|
||||
s_logger.trace("SSL: Running delegated task!");
|
||||
}
|
||||
run.run();
|
||||
}
|
||||
} else if (hsStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
|
||||
throw new IOException("NOT a handshaking!");
|
||||
}
|
||||
if (engResult != null && engResult.getStatus() != SSLEngineResult.Status.OK) {
|
||||
throw new IOException("Fail to handshake! " + engResult.getStatus());
|
||||
}
|
||||
if (engResult != null)
|
||||
hsStatus = engResult.getHandshakeStatus();
|
||||
else
|
||||
hsStatus = sslEngine.getHandshakeStatus();
|
||||
public static ByteBuffer handleBufferUnderflow(final SSLEngine engine, ByteBuffer buffer) {
|
||||
if (engine == null || buffer == null) {
|
||||
return buffer;
|
||||
}
|
||||
if (buffer.position() < buffer.limit()) {
|
||||
return buffer;
|
||||
}
|
||||
ByteBuffer replaceBuffer = enlargeBuffer(buffer, engine.getSession().getPacketBufferSize());
|
||||
buffer.flip();
|
||||
replaceBuffer.put(buffer);
|
||||
return replaceBuffer;
|
||||
}
|
||||
|
||||
private static boolean doHandshakeUnwrap(final SocketChannel socketChannel, final SSLEngine sslEngine,
|
||||
ByteBuffer peerAppData, ByteBuffer peerNetData, final int appBufferSize) throws IOException {
|
||||
if (socketChannel == null || sslEngine == null || peerAppData == null || peerNetData == null || appBufferSize < 0) {
|
||||
return false;
|
||||
}
|
||||
if (socketChannel.read(peerNetData) < 0) {
|
||||
if (sslEngine.isInboundDone() && sslEngine.isOutboundDone()) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
sslEngine.closeInbound();
|
||||
} catch (SSLException e) {
|
||||
s_logger.warn("This SSL engine was forced to close inbound due to end of stream.");
|
||||
}
|
||||
sslEngine.closeOutbound();
|
||||
// After closeOutbound the engine will be set to WRAP state,
|
||||
// in order to try to send a close message to the client.
|
||||
return true;
|
||||
}
|
||||
peerNetData.flip();
|
||||
SSLEngineResult result = null;
|
||||
try {
|
||||
result = sslEngine.unwrap(peerNetData, peerAppData);
|
||||
peerNetData.compact();
|
||||
} catch (SSLException sslException) {
|
||||
s_logger.error("SSL error occurred while processing unwrap data: " + sslException.getMessage());
|
||||
sslEngine.closeOutbound();
|
||||
return true;
|
||||
}
|
||||
switch (result.getStatus()) {
|
||||
case OK:
|
||||
break;
|
||||
case BUFFER_OVERFLOW:
|
||||
// Will occur when peerAppData's capacity is smaller than the data derived from peerNetData's unwrap.
|
||||
peerAppData = enlargeBuffer(peerAppData, appBufferSize);
|
||||
break;
|
||||
case BUFFER_UNDERFLOW:
|
||||
// Will occur either when no data was read from the peer or when the peerNetData buffer
|
||||
// was too small to hold all peer's data.
|
||||
peerNetData = handleBufferUnderflow(sslEngine, peerNetData);
|
||||
break;
|
||||
case CLOSED:
|
||||
if (sslEngine.isOutboundDone()) {
|
||||
return false;
|
||||
} else {
|
||||
sslEngine.closeOutbound();
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw new IllegalStateException("Invalid SSL status: " + result.getStatus());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private static boolean doHandshakeWrap(final SocketChannel socketChannel, final SSLEngine sslEngine,
|
||||
ByteBuffer myAppData, ByteBuffer myNetData, ByteBuffer peerNetData,
|
||||
final int netBufferSize) throws IOException {
|
||||
if (socketChannel == null || sslEngine == null || myNetData == null || peerNetData == null
|
||||
|| myAppData == null || netBufferSize < 0) {
|
||||
return false;
|
||||
}
|
||||
myNetData.clear();
|
||||
SSLEngineResult result = null;
|
||||
try {
|
||||
result = sslEngine.wrap(myAppData, myNetData);
|
||||
} catch (SSLException sslException) {
|
||||
s_logger.error("SSL error occurred while processing wrap data: " + sslException.getMessage());
|
||||
sslEngine.closeOutbound();
|
||||
return true;
|
||||
}
|
||||
switch (result.getStatus()) {
|
||||
case OK :
|
||||
myNetData.flip();
|
||||
while (myNetData.hasRemaining()) {
|
||||
socketChannel.write(myNetData);
|
||||
}
|
||||
break;
|
||||
case BUFFER_OVERFLOW:
|
||||
// Will occur if there is not enough space in myNetData buffer to write all the data
|
||||
// that would be generated by the method wrap. Since myNetData is set to session's packet
|
||||
// size we should not get to this point because SSLEngine is supposed to produce messages
|
||||
// smaller or equal to that, but a general handling would be the following:
|
||||
myNetData = enlargeBuffer(myNetData, netBufferSize);
|
||||
break;
|
||||
case BUFFER_UNDERFLOW:
|
||||
throw new SSLException("Buffer underflow occurred after a wrap. We should not reach here.");
|
||||
case CLOSED:
|
||||
try {
|
||||
myNetData.flip();
|
||||
while (myNetData.hasRemaining()) {
|
||||
socketChannel.write(myNetData);
|
||||
}
|
||||
// At this point the handshake status will probably be NEED_UNWRAP
|
||||
// so we make sure that peerNetData is clear to read.
|
||||
peerNetData.clear();
|
||||
} catch (Exception e) {
|
||||
s_logger.error("Failed to send server's CLOSE message due to socket channel's failure.");
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new IllegalStateException("Invalid SSL status: " + result.getStatus());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
public static boolean doHandshake(final SocketChannel socketChannel, final SSLEngine sslEngine, final boolean isClient) throws IOException {
|
||||
if (socketChannel == null || sslEngine == null) {
|
||||
return false;
|
||||
}
|
||||
final int appBufferSize = sslEngine.getSession().getApplicationBufferSize();
|
||||
final int netBufferSize = sslEngine.getSession().getPacketBufferSize();
|
||||
ByteBuffer myAppData = ByteBuffer.allocate(appBufferSize);
|
||||
ByteBuffer peerAppData = ByteBuffer.allocate(appBufferSize);
|
||||
ByteBuffer myNetData = ByteBuffer.allocate(netBufferSize);
|
||||
ByteBuffer peerNetData = ByteBuffer.allocate(netBufferSize);
|
||||
|
||||
final long startTimeMills = System.currentTimeMillis();
|
||||
|
||||
HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
|
||||
while (handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED
|
||||
&& handshakeStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
|
||||
final long timeTaken = System.currentTimeMillis() - startTimeMills;
|
||||
if (timeTaken > 60000L) {
|
||||
s_logger.warn("SSL Handshake has taken more than 60s to connect to: " + socketChannel.getRemoteAddress() +
|
||||
". Please investigate this connection.");
|
||||
return false;
|
||||
}
|
||||
switch (handshakeStatus) {
|
||||
case NEED_UNWRAP:
|
||||
if (!doHandshakeUnwrap(socketChannel, sslEngine, peerAppData, peerNetData, appBufferSize)) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case NEED_WRAP:
|
||||
if (!doHandshakeWrap(socketChannel, sslEngine, myAppData, myNetData, peerNetData, netBufferSize)) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case NEED_TASK:
|
||||
Runnable task;
|
||||
while ((task = sslEngine.getDelegatedTask()) != null) {
|
||||
new Thread(task).run();
|
||||
}
|
||||
break;
|
||||
case FINISHED:
|
||||
break;
|
||||
case NOT_HANDSHAKING:
|
||||
break;
|
||||
default:
|
||||
throw new IllegalStateException("Invalid SSL status: " + handshakeStatus);
|
||||
}
|
||||
handshakeStatus = sslEngine.getHandshakeStatus();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -36,7 +36,6 @@ public class NioClient extends NioConnection {
|
||||
private static final Logger s_logger = Logger.getLogger(NioClient.class);
|
||||
|
||||
protected String _host;
|
||||
protected String _bindAddress;
|
||||
protected SocketChannel _clientConnection;
|
||||
|
||||
public NioClient(final String name, final String host, final int port, final int workers, final HandlerFactory factory) {
|
||||
@ -44,10 +43,6 @@ public class NioClient extends NioConnection {
|
||||
_host = host;
|
||||
}
|
||||
|
||||
public void setBindAddress(final String ipAddress) {
|
||||
_bindAddress = ipAddress;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void init() throws IOException {
|
||||
_selector = Selector.open();
|
||||
@ -55,33 +50,25 @@ public class NioClient extends NioConnection {
|
||||
|
||||
try {
|
||||
_clientConnection = SocketChannel.open();
|
||||
_clientConnection.configureBlocking(true);
|
||||
|
||||
s_logger.info("Connecting to " + _host + ":" + _port);
|
||||
|
||||
if (_bindAddress != null) {
|
||||
s_logger.info("Binding outbound interface at " + _bindAddress);
|
||||
|
||||
final InetSocketAddress bindAddr = new InetSocketAddress(_bindAddress, 0);
|
||||
_clientConnection.socket().bind(bindAddr);
|
||||
}
|
||||
|
||||
final InetSocketAddress peerAddr = new InetSocketAddress(_host, _port);
|
||||
_clientConnection.connect(peerAddr);
|
||||
|
||||
SSLEngine sslEngine = null;
|
||||
// Begin SSL handshake in BLOCKING mode
|
||||
_clientConnection.configureBlocking(true);
|
||||
_clientConnection.configureBlocking(false);
|
||||
|
||||
final SSLContext sslContext = Link.initSSLContext(true);
|
||||
sslEngine = sslContext.createSSLEngine(_host, _port);
|
||||
SSLEngine sslEngine = sslContext.createSSLEngine(_host, _port);
|
||||
sslEngine.setUseClientMode(true);
|
||||
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
|
||||
|
||||
Link.doHandshake(_clientConnection, sslEngine, true);
|
||||
sslEngine.beginHandshake();
|
||||
if (!Link.doHandshake(_clientConnection, sslEngine, true)) {
|
||||
s_logger.error("SSL Handshake failed while connecting to host: " + _host + " port: " + _port);
|
||||
_selector.close();
|
||||
throw new IOException("SSL Handshake failed while connecting to host: " + _host + " port: " + _port);
|
||||
}
|
||||
s_logger.info("SSL: Handshake done");
|
||||
s_logger.info("Connected to " + _host + ":" + _port);
|
||||
|
||||
_clientConnection.configureBlocking(false);
|
||||
final Link link = new Link(peerAddr, this);
|
||||
link.setSSLEngine(sslEngine);
|
||||
final SelectionKey key = _clientConnection.register(_selector, SelectionKey.OP_READ);
|
||||
|
||||
@ -19,8 +19,13 @@
|
||||
|
||||
package com.cloud.utils.nio;
|
||||
|
||||
import static com.cloud.utils.AutoCloseableUtil.closeAutoCloseable;
|
||||
import com.cloud.utils.concurrency.NamedThreadFactory;
|
||||
import com.cloud.utils.exception.NioConnectionException;
|
||||
import org.apache.cloudstack.utils.security.SSLUtils;
|
||||
import org.apache.log4j.Logger;
|
||||
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLEngine;
|
||||
import java.io.IOException;
|
||||
import java.net.ConnectException;
|
||||
import java.net.InetSocketAddress;
|
||||
@ -44,14 +49,7 @@ import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLEngine;
|
||||
|
||||
import org.apache.cloudstack.utils.security.SSLUtils;
|
||||
import org.apache.log4j.Logger;
|
||||
|
||||
import com.cloud.utils.concurrency.NamedThreadFactory;
|
||||
import com.cloud.utils.exception.NioConnectionException;
|
||||
import static com.cloud.utils.AutoCloseableUtil.closeAutoCloseable;
|
||||
|
||||
/**
|
||||
* NioConnection abstracts the NIO socket operations. The Java implementation
|
||||
@ -71,6 +69,7 @@ public abstract class NioConnection implements Callable<Boolean> {
|
||||
protected HandlerFactory _factory;
|
||||
protected String _name;
|
||||
protected ExecutorService _executor;
|
||||
protected ExecutorService _sslHandshakeExecutor;
|
||||
|
||||
public NioConnection(final String name, final int port, final int workers, final HandlerFactory factory) {
|
||||
_name = name;
|
||||
@ -79,6 +78,7 @@ public abstract class NioConnection implements Callable<Boolean> {
|
||||
_port = port;
|
||||
_factory = factory;
|
||||
_executor = new ThreadPoolExecutor(workers, 5 * workers, 1, TimeUnit.DAYS, new LinkedBlockingQueue<Runnable>(), new NamedThreadFactory(name + "-Handler"));
|
||||
_sslHandshakeExecutor = Executors.newCachedThreadPool(new NamedThreadFactory(name + "-SSLHandshakeHandler"));
|
||||
}
|
||||
|
||||
public void start() throws NioConnectionException {
|
||||
@ -185,8 +185,9 @@ public abstract class NioConnection implements Callable<Boolean> {
|
||||
|
||||
protected void accept(final SelectionKey key) throws IOException {
|
||||
final ServerSocketChannel serverSocketChannel = (ServerSocketChannel)key.channel();
|
||||
|
||||
final SocketChannel socketChannel = serverSocketChannel.accept();
|
||||
socketChannel.configureBlocking(false);
|
||||
|
||||
final Socket socket = socketChannel.socket();
|
||||
socket.setKeepAlive(true);
|
||||
|
||||
@ -194,43 +195,52 @@ public abstract class NioConnection implements Callable<Boolean> {
|
||||
s_logger.trace("Connection accepted for " + socket);
|
||||
}
|
||||
|
||||
// Begin SSL handshake in BLOCKING mode
|
||||
socketChannel.configureBlocking(true);
|
||||
|
||||
SSLEngine sslEngine = null;
|
||||
final SSLEngine sslEngine;
|
||||
try {
|
||||
final SSLContext sslContext = Link.initSSLContext(false);
|
||||
sslEngine = sslContext.createSSLEngine();
|
||||
sslEngine.setUseClientMode(false);
|
||||
sslEngine.setNeedClientAuth(false);
|
||||
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
|
||||
|
||||
Link.doHandshake(socketChannel, sslEngine, false);
|
||||
|
||||
final NioConnection nioConnection = this;
|
||||
_sslHandshakeExecutor.submit(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
_selector.wakeup();
|
||||
try {
|
||||
sslEngine.beginHandshake();
|
||||
if (!Link.doHandshake(socketChannel, sslEngine, false)) {
|
||||
throw new IOException("SSL handshake timed out with " + socketChannel.getRemoteAddress());
|
||||
}
|
||||
if (s_logger.isTraceEnabled()) {
|
||||
s_logger.trace("SSL: Handshake done");
|
||||
}
|
||||
final InetSocketAddress saddr = (InetSocketAddress)socket.getRemoteSocketAddress();
|
||||
final Link link = new Link(saddr, nioConnection);
|
||||
link.setSSLEngine(sslEngine);
|
||||
link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link));
|
||||
final Task task = _factory.create(Task.Type.CONNECT, link, null);
|
||||
registerLink(saddr, link);
|
||||
_executor.submit(task);
|
||||
} catch (IOException e) {
|
||||
if (s_logger.isTraceEnabled()) {
|
||||
s_logger.trace("Connection closed due to failure: " + e.getMessage());
|
||||
}
|
||||
closeAutoCloseable(socket, "accepting socket");
|
||||
closeAutoCloseable(socketChannel, "accepting socketChannel");
|
||||
} finally {
|
||||
_selector.wakeup();
|
||||
}
|
||||
}
|
||||
});
|
||||
} catch (final Exception e) {
|
||||
if (s_logger.isTraceEnabled()) {
|
||||
s_logger.trace("Socket " + socket + " closed on read. Probably -1 returned: " + e.getMessage());
|
||||
s_logger.trace("Connection closed due to failure: " + e.getMessage());
|
||||
}
|
||||
closeAutoCloseable(socket, "accepting socket");
|
||||
closeAutoCloseable(socketChannel, "accepting socketChannel");
|
||||
closeAutoCloseable(socket, "opened socket");
|
||||
return;
|
||||
}
|
||||
|
||||
if (s_logger.isTraceEnabled()) {
|
||||
s_logger.trace("SSL: Handshake done");
|
||||
}
|
||||
socketChannel.configureBlocking(false);
|
||||
final InetSocketAddress saddr = (InetSocketAddress)socket.getRemoteSocketAddress();
|
||||
final Link link = new Link(saddr, this);
|
||||
link.setSSLEngine(sslEngine);
|
||||
link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link));
|
||||
final Task task = _factory.create(Task.Type.CONNECT, link, null);
|
||||
registerLink(saddr, link);
|
||||
|
||||
try {
|
||||
_executor.submit(task);
|
||||
} catch (final Exception e) {
|
||||
s_logger.warn("Exception occurred when submitting the task", e);
|
||||
} finally {
|
||||
_selector.wakeup();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -43,6 +43,10 @@ public class NioServer extends NioConnection {
|
||||
_links = new WeakHashMap<InetSocketAddress, Link>(1024);
|
||||
}
|
||||
|
||||
public int getPort() {
|
||||
return _serverSocket.socket().getLocalPort();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void init() throws IOException {
|
||||
_selector = SelectorProvider.provider().openSelector();
|
||||
@ -53,9 +57,9 @@ public class NioServer extends NioConnection {
|
||||
_localAddr = new InetSocketAddress(_port);
|
||||
_serverSocket.socket().bind(_localAddr);
|
||||
|
||||
_serverSocket.register(_selector, SelectionKey.OP_ACCEPT, null);
|
||||
_serverSocket.register(_selector, SelectionKey.OP_ACCEPT);
|
||||
|
||||
s_logger.info("NioConnection started and listening on " + _localAddr.toString());
|
||||
s_logger.info("NioConnection started and listening on " + _serverSocket.socket().getLocalSocketAddress());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -19,14 +19,7 @@
|
||||
|
||||
package com.cloud.utils.testcase;
|
||||
|
||||
import java.nio.channels.ClosedChannelException;
|
||||
import java.util.Random;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
|
||||
import org.apache.log4j.Logger;
|
||||
import org.junit.Assert;
|
||||
|
||||
import com.cloud.utils.concurrency.NamedThreadFactory;
|
||||
import com.cloud.utils.exception.NioConnectionException;
|
||||
import com.cloud.utils.nio.HandlerFactory;
|
||||
import com.cloud.utils.nio.Link;
|
||||
@ -34,131 +27,200 @@ import com.cloud.utils.nio.NioClient;
|
||||
import com.cloud.utils.nio.NioServer;
|
||||
import com.cloud.utils.nio.Task;
|
||||
import com.cloud.utils.nio.Task.Type;
|
||||
import org.apache.log4j.Logger;
|
||||
import org.junit.After;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.channels.ClosedChannelException;
|
||||
import java.nio.channels.Selector;
|
||||
import java.nio.channels.SocketChannel;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
* NioTest demonstrates that NioServer can function without getting its main IO
|
||||
* loop blocked when an aggressive or malicious client connects to the server but
|
||||
* fail to participate in SSL handshake. In this test, we run bunch of clients
|
||||
* that send a known payload to the server, to which multiple malicious clients
|
||||
* also try to connect and hang.
|
||||
* A malicious client could cause denial-of-service if the server's main IO loop
|
||||
* along with SSL handshake was blocking. A passing tests shows that NioServer
|
||||
* can still function in case of connection load and that the main IO loop along
|
||||
* with SSL handshake is non-blocking with some internal timeout mechanism.
|
||||
*/
|
||||
|
||||
public class NioTest extends TestCase {
|
||||
public class NioTest {
|
||||
|
||||
private static final Logger s_logger = Logger.getLogger(NioTest.class);
|
||||
private static final Logger LOGGER = Logger.getLogger(NioTest.class);
|
||||
|
||||
private NioServer _server;
|
||||
private NioClient _client;
|
||||
// Test should fail in due time instead of looping forever
|
||||
private static final int TESTTIMEOUT = 300000;
|
||||
|
||||
private Link _clientLink;
|
||||
final private int totalTestCount = 5;
|
||||
private int completedTestCount = 0;
|
||||
|
||||
private int _testCount;
|
||||
private int _completedCount;
|
||||
private NioServer server;
|
||||
private List<NioClient> clients = new ArrayList<>();
|
||||
private List<NioClient> maliciousClients = new ArrayList<>();
|
||||
|
||||
private ExecutorService clientExecutor = Executors.newFixedThreadPool(totalTestCount, new NamedThreadFactory("NioClientHandler"));;
|
||||
private ExecutorService maliciousExecutor = Executors.newFixedThreadPool(5*totalTestCount, new NamedThreadFactory("MaliciousNioClientHandler"));;
|
||||
|
||||
private Random randomGenerator = new Random();
|
||||
private byte[] testBytes;
|
||||
|
||||
private boolean isTestsDone() {
|
||||
boolean result;
|
||||
synchronized (this) {
|
||||
result = _testCount == _completedCount;
|
||||
result = totalTestCount == completedTestCount;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private void getOneMoreTest() {
|
||||
synchronized (this) {
|
||||
_testCount++;
|
||||
}
|
||||
}
|
||||
|
||||
private void oneMoreTestDone() {
|
||||
synchronized (this) {
|
||||
_completedCount++;
|
||||
completedTestCount++;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@Before
|
||||
public void setUp() {
|
||||
s_logger.info("Test");
|
||||
LOGGER.info("Setting up Benchmark Test");
|
||||
|
||||
_testCount = 0;
|
||||
_completedCount = 0;
|
||||
completedTestCount = 0;
|
||||
testBytes = new byte[1000000];
|
||||
randomGenerator.nextBytes(testBytes);
|
||||
|
||||
_server = new NioServer("NioTestServer", 7777, 5, new NioTestServer());
|
||||
server = new NioServer("NioTestServer", 0, 1, new NioTestServer());
|
||||
try {
|
||||
_server.start();
|
||||
server.start();
|
||||
} catch (final NioConnectionException e) {
|
||||
fail(e.getMessage());
|
||||
Assert.fail(e.getMessage());
|
||||
}
|
||||
|
||||
_client = new NioClient("NioTestServer", "127.0.0.1", 7777, 5, new NioTestClient());
|
||||
try {
|
||||
_client.start();
|
||||
} catch (final NioConnectionException e) {
|
||||
fail(e.getMessage());
|
||||
}
|
||||
|
||||
while (_clientLink == null) {
|
||||
try {
|
||||
s_logger.debug("Link is not up! Waiting ...");
|
||||
Thread.sleep(1000);
|
||||
} catch (final InterruptedException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
for (int i = 0; i < totalTestCount; i++) {
|
||||
for (int j = 0; j < 4; j++) {
|
||||
final NioClient maliciousClient = new NioMaliciousClient("NioMaliciousTestClient-" + i, "127.0.0.1", server.getPort(), 1, new NioMaliciousTestClient());
|
||||
maliciousClients.add(maliciousClient);
|
||||
maliciousExecutor.submit(new ThreadedNioClient(maliciousClient));
|
||||
}
|
||||
final NioClient client = new NioClient("NioTestClient-" + i, "127.0.0.1", server.getPort(), 1, new NioTestClient());
|
||||
clients.add(client);
|
||||
clientExecutor.submit(new ThreadedNioClient(client));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@After
|
||||
public void tearDown() {
|
||||
while (!isTestsDone()) {
|
||||
try {
|
||||
s_logger.debug(_completedCount + "/" + _testCount + " tests done. Waiting for completion");
|
||||
Thread.sleep(1000);
|
||||
} catch (final InterruptedException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
stopClient();
|
||||
stopServer();
|
||||
}
|
||||
|
||||
protected void stopClient() {
|
||||
_client.stop();
|
||||
s_logger.info("Client stopped.");
|
||||
for (NioClient client : clients) {
|
||||
client.stop();
|
||||
}
|
||||
for (NioClient maliciousClient : maliciousClients) {
|
||||
maliciousClient.stop();
|
||||
}
|
||||
LOGGER.info("Clients stopped.");
|
||||
}
|
||||
|
||||
protected void stopServer() {
|
||||
_server.stop();
|
||||
s_logger.info("Server stopped.");
|
||||
server.stop();
|
||||
LOGGER.info("Server stopped.");
|
||||
}
|
||||
|
||||
protected void setClientLink(final Link link) {
|
||||
_clientLink = link;
|
||||
}
|
||||
|
||||
Random randomGenerator = new Random();
|
||||
|
||||
byte[] _testBytes;
|
||||
|
||||
@Test(timeout=TESTTIMEOUT)
|
||||
public void testConnection() {
|
||||
_testBytes = new byte[1000000];
|
||||
randomGenerator.nextBytes(_testBytes);
|
||||
try {
|
||||
getOneMoreTest();
|
||||
_clientLink.send(_testBytes);
|
||||
s_logger.info("Client: Data sent");
|
||||
getOneMoreTest();
|
||||
_clientLink.send(_testBytes);
|
||||
s_logger.info("Client: Data sent");
|
||||
} catch (final ClosedChannelException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
while (!isTestsDone()) {
|
||||
try {
|
||||
LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done. Waiting for completion");
|
||||
Thread.sleep(1000);
|
||||
} catch (final InterruptedException e) {
|
||||
Assert.fail(e.getMessage());
|
||||
}
|
||||
}
|
||||
LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done.");
|
||||
}
|
||||
|
||||
protected void doServerProcess(final byte[] data) {
|
||||
oneMoreTestDone();
|
||||
Assert.assertArrayEquals(_testBytes, data);
|
||||
s_logger.info("Verify done.");
|
||||
Assert.assertArrayEquals(testBytes, data);
|
||||
LOGGER.info("Verify data received by server done.");
|
||||
}
|
||||
|
||||
public byte[] getTestBytes() {
|
||||
return testBytes;
|
||||
}
|
||||
|
||||
public class ThreadedNioClient implements Runnable {
|
||||
final private NioClient client;
|
||||
ThreadedNioClient(final NioClient client) {
|
||||
this.client = client;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
client.start();
|
||||
} catch (NioConnectionException e) {
|
||||
Assert.fail(e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class NioMaliciousClient extends NioClient {
|
||||
|
||||
public NioMaliciousClient(String name, String host, int port, int workers, HandlerFactory factory) {
|
||||
super(name, host, port, workers, factory);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void init() throws IOException {
|
||||
_selector = Selector.open();
|
||||
try {
|
||||
_clientConnection = SocketChannel.open();
|
||||
LOGGER.info("Connecting to " + _host + ":" + _port);
|
||||
final InetSocketAddress peerAddr = new InetSocketAddress(_host, _port);
|
||||
_clientConnection.connect(peerAddr);
|
||||
// This is done on purpose, the malicious client would connect
|
||||
// to the server and then do nothing, hence using a large sleep value
|
||||
Thread.sleep(Long.MAX_VALUE);
|
||||
} catch (final IOException e) {
|
||||
_selector.close();
|
||||
throw e;
|
||||
} catch (InterruptedException e) {
|
||||
LOGGER.debug(e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class NioMaliciousTestClient implements HandlerFactory {
|
||||
|
||||
@Override
|
||||
public Task create(final Type type, final Link link, final byte[] data) {
|
||||
return new NioMaliciousTestClientHandler(type, link, data);
|
||||
}
|
||||
|
||||
public class NioMaliciousTestClientHandler extends Task {
|
||||
|
||||
public NioMaliciousTestClientHandler(final Type type, final Link link, final byte[] data) {
|
||||
super(type, link, data);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void doTask(final Task task) {
|
||||
LOGGER.info("Malicious Client: Received task " + task.getType().toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class NioTestClient implements HandlerFactory {
|
||||
@ -177,18 +239,23 @@ public class NioTest extends TestCase {
|
||||
@Override
|
||||
public void doTask(final Task task) {
|
||||
if (task.getType() == Task.Type.CONNECT) {
|
||||
s_logger.info("Client: Received CONNECT task");
|
||||
setClientLink(task.getLink());
|
||||
LOGGER.info("Client: Received CONNECT task");
|
||||
try {
|
||||
LOGGER.info("Sending data to server");
|
||||
task.getLink().send(getTestBytes());
|
||||
} catch (ClosedChannelException e) {
|
||||
LOGGER.error(e.getMessage());
|
||||
e.printStackTrace();
|
||||
}
|
||||
} else if (task.getType() == Task.Type.DATA) {
|
||||
s_logger.info("Client: Received DATA task");
|
||||
LOGGER.info("Client: Received DATA task");
|
||||
} else if (task.getType() == Task.Type.DISCONNECT) {
|
||||
s_logger.info("Client: Received DISCONNECT task");
|
||||
LOGGER.info("Client: Received DISCONNECT task");
|
||||
stopClient();
|
||||
} else if (task.getType() == Task.Type.OTHER) {
|
||||
s_logger.info("Client: Received OTHER task");
|
||||
LOGGER.info("Client: Received OTHER task");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -208,15 +275,15 @@ public class NioTest extends TestCase {
|
||||
@Override
|
||||
public void doTask(final Task task) {
|
||||
if (task.getType() == Task.Type.CONNECT) {
|
||||
s_logger.info("Server: Received CONNECT task");
|
||||
LOGGER.info("Server: Received CONNECT task");
|
||||
} else if (task.getType() == Task.Type.DATA) {
|
||||
s_logger.info("Server: Received DATA task");
|
||||
LOGGER.info("Server: Received DATA task");
|
||||
doServerProcess(task.getData());
|
||||
} else if (task.getType() == Task.Type.DISCONNECT) {
|
||||
s_logger.info("Server: Received DISCONNECT task");
|
||||
LOGGER.info("Server: Received DISCONNECT task");
|
||||
stopServer();
|
||||
} else if (task.getType() == Task.Type.OTHER) {
|
||||
s_logger.info("Server: Received OTHER task");
|
||||
LOGGER.info("Server: Received OTHER task");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user