Merge pull request #1493 from shapeblue/nio-fix

CLOUDSTACK-9348: Use non-blocking SSL handshake in NioConnection/Link- Uses non-blocking socket config in NioClient and NioServer/NioConnection
- Scalable connectivity from agents and peer clustered-management server
- Removes blocking ssl handshake code with a non-blocking code
- Protects from denial-of-service issues that can degrade mgmt server responsiveness
  due to an aggressive/malicious client
- Uses separate executor services for handling connect/accept events

Changes are covered the NioTest so I did not write a new test, advise how we can improve this. Further, I tried to invest time on writing a benchmark test to reproduce a degraded server but could not write it deterministic-ally (sometimes fails/passes but not always). Review, CI testing and feedback requested /cc @swill @jburwell @DaanHoogland @wido @remibergsma @rafaelweingartner @GabrielBrascher

* pr/1493:
  CLOUDSTACK-9348: Use non-blocking SSL handshake
  CLOUDSTACK-9348: Unit test to demonstrate denial of service attack

Signed-off-by: Will Stevens <williamstevens@gmail.com>
This commit is contained in:
Will Stevens 2016-05-04 10:30:58 -04:00
commit 7ce0e10fbc
7 changed files with 420 additions and 284 deletions

View File

@ -499,7 +499,7 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
SocketChannel ch1 = null;
try {
ch1 = SocketChannel.open(new InetSocketAddress(addr, Port.value()));
ch1.configureBlocking(true); // make sure we are working at blocking mode
ch1.configureBlocking(false);
ch1.socket().setKeepAlive(true);
ch1.socket().setSoTimeout(60 * 1000);
try {
@ -507,8 +507,11 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
sslEngine = sslContext.createSSLEngine(ip, Port.value());
sslEngine.setUseClientMode(true);
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
Link.doHandshake(ch1, sslEngine, true);
sslEngine.beginHandshake();
if (!Link.doHandshake(ch1, sslEngine, true)) {
ch1.close();
throw new IOException("SSL handshake failed!");
}
s_logger.info("SSL: Handshake done");
} catch (final Exception e) {
ch1.close();

View File

@ -208,7 +208,6 @@
<excludes>
<exclude>com/cloud/utils/testcase/*TestCase*</exclude>
<exclude>com/cloud/utils/db/*Test*</exclude>
<exclude>com/cloud/utils/testcase/NioTest.java</exclude>
</excludes>
</configuration>
</plugin>

View File

@ -19,36 +19,32 @@
package com.cloud.utils.nio;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.concurrent.ConcurrentLinkedQueue;
import com.cloud.utils.PropertiesUtil;
import com.cloud.utils.db.DbProperties;
import org.apache.cloudstack.utils.security.SSLUtils;
import org.apache.log4j.Logger;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import org.apache.cloudstack.utils.security.SSLUtils;
import org.apache.log4j.Logger;
import com.cloud.utils.PropertiesUtil;
import com.cloud.utils.db.DbProperties;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.concurrent.ConcurrentLinkedQueue;
/**
*/
@ -453,115 +449,185 @@ public class Link {
return sslContext;
}
public static void doHandshake(SocketChannel ch, SSLEngine sslEngine, boolean isClient) throws IOException {
if (s_logger.isTraceEnabled()) {
s_logger.trace("SSL: begin Handshake, isClient: " + isClient);
public static ByteBuffer enlargeBuffer(ByteBuffer buffer, final int sessionProposedCapacity) {
if (buffer == null || sessionProposedCapacity < 0) {
return buffer;
}
SSLEngineResult engResult;
SSLSession sslSession = sslEngine.getSession();
HandshakeStatus hsStatus;
ByteBuffer in_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
ByteBuffer in_appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
ByteBuffer out_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
ByteBuffer out_appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
int count;
ch.socket().setSoTimeout(60 * 1000);
InputStream inStream = ch.socket().getInputStream();
// Use readCh to make sure the timeout on reading is working
ReadableByteChannel readCh = Channels.newChannel(inStream);
if (isClient) {
hsStatus = SSLEngineResult.HandshakeStatus.NEED_WRAP;
if (sessionProposedCapacity > buffer.capacity()) {
buffer = ByteBuffer.allocate(sessionProposedCapacity);
} else {
hsStatus = SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
buffer = ByteBuffer.allocate(buffer.capacity() * 2);
}
return buffer;
}
while (hsStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
if (s_logger.isTraceEnabled()) {
s_logger.trace("SSL: Handshake status " + hsStatus);
}
engResult = null;
if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
out_pkgBuf.clear();
out_appBuf.clear();
out_appBuf.put("Hello".getBytes());
engResult = sslEngine.wrap(out_appBuf, out_pkgBuf);
out_pkgBuf.flip();
int remain = out_pkgBuf.limit();
while (remain != 0) {
remain -= ch.write(out_pkgBuf);
if (remain < 0) {
throw new IOException("Too much bytes sent?");
}
}
} else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
in_appBuf.clear();
// One packet may contained multiply operation
if (in_pkgBuf.position() == 0 || !in_pkgBuf.hasRemaining()) {
in_pkgBuf.clear();
count = 0;
try {
count = readCh.read(in_pkgBuf);
} catch (SocketTimeoutException ex) {
if (s_logger.isTraceEnabled()) {
s_logger.trace("Handshake reading time out! Cut the connection");
}
count = -1;
}
if (count == -1) {
throw new IOException("Connection closed with -1 on reading size.");
}
in_pkgBuf.flip();
}
engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
ByteBuffer tmp_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
int loop_count = 0;
while (engResult.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
// The client is too slow? Cut it and let it reconnect
if (loop_count > 10) {
throw new IOException("Too many times in SSL BUFFER_UNDERFLOW, disconnect guest.");
}
// We need more packets to complete this operation
if (s_logger.isTraceEnabled()) {
s_logger.trace("SSL: Buffer underflowed, getting more packets");
}
tmp_pkgBuf.clear();
count = ch.read(tmp_pkgBuf);
if (count == -1) {
throw new IOException("Connection closed with -1 on reading size.");
}
tmp_pkgBuf.flip();
in_pkgBuf.mark();
in_pkgBuf.position(in_pkgBuf.limit());
in_pkgBuf.limit(in_pkgBuf.limit() + tmp_pkgBuf.limit());
in_pkgBuf.put(tmp_pkgBuf);
in_pkgBuf.reset();
in_appBuf.clear();
engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
loop_count++;
}
} else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK) {
Runnable run;
while ((run = sslEngine.getDelegatedTask()) != null) {
if (s_logger.isTraceEnabled()) {
s_logger.trace("SSL: Running delegated task!");
}
run.run();
}
} else if (hsStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
throw new IOException("NOT a handshaking!");
}
if (engResult != null && engResult.getStatus() != SSLEngineResult.Status.OK) {
throw new IOException("Fail to handshake! " + engResult.getStatus());
}
if (engResult != null)
hsStatus = engResult.getHandshakeStatus();
else
hsStatus = sslEngine.getHandshakeStatus();
public static ByteBuffer handleBufferUnderflow(final SSLEngine engine, ByteBuffer buffer) {
if (engine == null || buffer == null) {
return buffer;
}
if (buffer.position() < buffer.limit()) {
return buffer;
}
ByteBuffer replaceBuffer = enlargeBuffer(buffer, engine.getSession().getPacketBufferSize());
buffer.flip();
replaceBuffer.put(buffer);
return replaceBuffer;
}
private static boolean doHandshakeUnwrap(final SocketChannel socketChannel, final SSLEngine sslEngine,
ByteBuffer peerAppData, ByteBuffer peerNetData, final int appBufferSize) throws IOException {
if (socketChannel == null || sslEngine == null || peerAppData == null || peerNetData == null || appBufferSize < 0) {
return false;
}
if (socketChannel.read(peerNetData) < 0) {
if (sslEngine.isInboundDone() && sslEngine.isOutboundDone()) {
return false;
}
try {
sslEngine.closeInbound();
} catch (SSLException e) {
s_logger.warn("This SSL engine was forced to close inbound due to end of stream.");
}
sslEngine.closeOutbound();
// After closeOutbound the engine will be set to WRAP state,
// in order to try to send a close message to the client.
return true;
}
peerNetData.flip();
SSLEngineResult result = null;
try {
result = sslEngine.unwrap(peerNetData, peerAppData);
peerNetData.compact();
} catch (SSLException sslException) {
s_logger.error("SSL error occurred while processing unwrap data: " + sslException.getMessage());
sslEngine.closeOutbound();
return true;
}
switch (result.getStatus()) {
case OK:
break;
case BUFFER_OVERFLOW:
// Will occur when peerAppData's capacity is smaller than the data derived from peerNetData's unwrap.
peerAppData = enlargeBuffer(peerAppData, appBufferSize);
break;
case BUFFER_UNDERFLOW:
// Will occur either when no data was read from the peer or when the peerNetData buffer
// was too small to hold all peer's data.
peerNetData = handleBufferUnderflow(sslEngine, peerNetData);
break;
case CLOSED:
if (sslEngine.isOutboundDone()) {
return false;
} else {
sslEngine.closeOutbound();
break;
}
default:
throw new IllegalStateException("Invalid SSL status: " + result.getStatus());
}
return true;
}
private static boolean doHandshakeWrap(final SocketChannel socketChannel, final SSLEngine sslEngine,
ByteBuffer myAppData, ByteBuffer myNetData, ByteBuffer peerNetData,
final int netBufferSize) throws IOException {
if (socketChannel == null || sslEngine == null || myNetData == null || peerNetData == null
|| myAppData == null || netBufferSize < 0) {
return false;
}
myNetData.clear();
SSLEngineResult result = null;
try {
result = sslEngine.wrap(myAppData, myNetData);
} catch (SSLException sslException) {
s_logger.error("SSL error occurred while processing wrap data: " + sslException.getMessage());
sslEngine.closeOutbound();
return true;
}
switch (result.getStatus()) {
case OK :
myNetData.flip();
while (myNetData.hasRemaining()) {
socketChannel.write(myNetData);
}
break;
case BUFFER_OVERFLOW:
// Will occur if there is not enough space in myNetData buffer to write all the data
// that would be generated by the method wrap. Since myNetData is set to session's packet
// size we should not get to this point because SSLEngine is supposed to produce messages
// smaller or equal to that, but a general handling would be the following:
myNetData = enlargeBuffer(myNetData, netBufferSize);
break;
case BUFFER_UNDERFLOW:
throw new SSLException("Buffer underflow occurred after a wrap. We should not reach here.");
case CLOSED:
try {
myNetData.flip();
while (myNetData.hasRemaining()) {
socketChannel.write(myNetData);
}
// At this point the handshake status will probably be NEED_UNWRAP
// so we make sure that peerNetData is clear to read.
peerNetData.clear();
} catch (Exception e) {
s_logger.error("Failed to send server's CLOSE message due to socket channel's failure.");
}
break;
default:
throw new IllegalStateException("Invalid SSL status: " + result.getStatus());
}
return true;
}
public static boolean doHandshake(final SocketChannel socketChannel, final SSLEngine sslEngine, final boolean isClient) throws IOException {
if (socketChannel == null || sslEngine == null) {
return false;
}
final int appBufferSize = sslEngine.getSession().getApplicationBufferSize();
final int netBufferSize = sslEngine.getSession().getPacketBufferSize();
ByteBuffer myAppData = ByteBuffer.allocate(appBufferSize);
ByteBuffer peerAppData = ByteBuffer.allocate(appBufferSize);
ByteBuffer myNetData = ByteBuffer.allocate(netBufferSize);
ByteBuffer peerNetData = ByteBuffer.allocate(netBufferSize);
final long startTimeMills = System.currentTimeMillis();
HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
while (handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED
&& handshakeStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
final long timeTaken = System.currentTimeMillis() - startTimeMills;
if (timeTaken > 60000L) {
s_logger.warn("SSL Handshake has taken more than 60s to connect to: " + socketChannel.getRemoteAddress() +
". Please investigate this connection.");
return false;
}
switch (handshakeStatus) {
case NEED_UNWRAP:
if (!doHandshakeUnwrap(socketChannel, sslEngine, peerAppData, peerNetData, appBufferSize)) {
return false;
}
break;
case NEED_WRAP:
if (!doHandshakeWrap(socketChannel, sslEngine, myAppData, myNetData, peerNetData, netBufferSize)) {
return false;
}
break;
case NEED_TASK:
Runnable task;
while ((task = sslEngine.getDelegatedTask()) != null) {
new Thread(task).run();
}
break;
case FINISHED:
break;
case NOT_HANDSHAKING:
break;
default:
throw new IllegalStateException("Invalid SSL status: " + handshakeStatus);
}
handshakeStatus = sslEngine.getHandshakeStatus();
}
return true;
}
}

View File

@ -36,7 +36,6 @@ public class NioClient extends NioConnection {
private static final Logger s_logger = Logger.getLogger(NioClient.class);
protected String _host;
protected String _bindAddress;
protected SocketChannel _clientConnection;
public NioClient(final String name, final String host, final int port, final int workers, final HandlerFactory factory) {
@ -44,10 +43,6 @@ public class NioClient extends NioConnection {
_host = host;
}
public void setBindAddress(final String ipAddress) {
_bindAddress = ipAddress;
}
@Override
protected void init() throws IOException {
_selector = Selector.open();
@ -55,33 +50,25 @@ public class NioClient extends NioConnection {
try {
_clientConnection = SocketChannel.open();
_clientConnection.configureBlocking(true);
s_logger.info("Connecting to " + _host + ":" + _port);
if (_bindAddress != null) {
s_logger.info("Binding outbound interface at " + _bindAddress);
final InetSocketAddress bindAddr = new InetSocketAddress(_bindAddress, 0);
_clientConnection.socket().bind(bindAddr);
}
final InetSocketAddress peerAddr = new InetSocketAddress(_host, _port);
_clientConnection.connect(peerAddr);
SSLEngine sslEngine = null;
// Begin SSL handshake in BLOCKING mode
_clientConnection.configureBlocking(true);
_clientConnection.configureBlocking(false);
final SSLContext sslContext = Link.initSSLContext(true);
sslEngine = sslContext.createSSLEngine(_host, _port);
SSLEngine sslEngine = sslContext.createSSLEngine(_host, _port);
sslEngine.setUseClientMode(true);
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
Link.doHandshake(_clientConnection, sslEngine, true);
sslEngine.beginHandshake();
if (!Link.doHandshake(_clientConnection, sslEngine, true)) {
s_logger.error("SSL Handshake failed while connecting to host: " + _host + " port: " + _port);
_selector.close();
throw new IOException("SSL Handshake failed while connecting to host: " + _host + " port: " + _port);
}
s_logger.info("SSL: Handshake done");
s_logger.info("Connected to " + _host + ":" + _port);
_clientConnection.configureBlocking(false);
final Link link = new Link(peerAddr, this);
link.setSSLEngine(sslEngine);
final SelectionKey key = _clientConnection.register(_selector, SelectionKey.OP_READ);

View File

@ -19,8 +19,13 @@
package com.cloud.utils.nio;
import static com.cloud.utils.AutoCloseableUtil.closeAutoCloseable;
import com.cloud.utils.concurrency.NamedThreadFactory;
import com.cloud.utils.exception.NioConnectionException;
import org.apache.cloudstack.utils.security.SSLUtils;
import org.apache.log4j.Logger;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.net.ConnectException;
import java.net.InetSocketAddress;
@ -44,14 +49,7 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import org.apache.cloudstack.utils.security.SSLUtils;
import org.apache.log4j.Logger;
import com.cloud.utils.concurrency.NamedThreadFactory;
import com.cloud.utils.exception.NioConnectionException;
import static com.cloud.utils.AutoCloseableUtil.closeAutoCloseable;
/**
* NioConnection abstracts the NIO socket operations. The Java implementation
@ -71,6 +69,7 @@ public abstract class NioConnection implements Callable<Boolean> {
protected HandlerFactory _factory;
protected String _name;
protected ExecutorService _executor;
protected ExecutorService _sslHandshakeExecutor;
public NioConnection(final String name, final int port, final int workers, final HandlerFactory factory) {
_name = name;
@ -79,6 +78,7 @@ public abstract class NioConnection implements Callable<Boolean> {
_port = port;
_factory = factory;
_executor = new ThreadPoolExecutor(workers, 5 * workers, 1, TimeUnit.DAYS, new LinkedBlockingQueue<Runnable>(), new NamedThreadFactory(name + "-Handler"));
_sslHandshakeExecutor = Executors.newCachedThreadPool(new NamedThreadFactory(name + "-SSLHandshakeHandler"));
}
public void start() throws NioConnectionException {
@ -185,8 +185,9 @@ public abstract class NioConnection implements Callable<Boolean> {
protected void accept(final SelectionKey key) throws IOException {
final ServerSocketChannel serverSocketChannel = (ServerSocketChannel)key.channel();
final SocketChannel socketChannel = serverSocketChannel.accept();
socketChannel.configureBlocking(false);
final Socket socket = socketChannel.socket();
socket.setKeepAlive(true);
@ -194,43 +195,52 @@ public abstract class NioConnection implements Callable<Boolean> {
s_logger.trace("Connection accepted for " + socket);
}
// Begin SSL handshake in BLOCKING mode
socketChannel.configureBlocking(true);
SSLEngine sslEngine = null;
final SSLEngine sslEngine;
try {
final SSLContext sslContext = Link.initSSLContext(false);
sslEngine = sslContext.createSSLEngine();
sslEngine.setUseClientMode(false);
sslEngine.setNeedClientAuth(false);
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
Link.doHandshake(socketChannel, sslEngine, false);
final NioConnection nioConnection = this;
_sslHandshakeExecutor.submit(new Runnable() {
@Override
public void run() {
_selector.wakeup();
try {
sslEngine.beginHandshake();
if (!Link.doHandshake(socketChannel, sslEngine, false)) {
throw new IOException("SSL handshake timed out with " + socketChannel.getRemoteAddress());
}
if (s_logger.isTraceEnabled()) {
s_logger.trace("SSL: Handshake done");
}
final InetSocketAddress saddr = (InetSocketAddress)socket.getRemoteSocketAddress();
final Link link = new Link(saddr, nioConnection);
link.setSSLEngine(sslEngine);
link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link));
final Task task = _factory.create(Task.Type.CONNECT, link, null);
registerLink(saddr, link);
_executor.submit(task);
} catch (IOException e) {
if (s_logger.isTraceEnabled()) {
s_logger.trace("Connection closed due to failure: " + e.getMessage());
}
closeAutoCloseable(socket, "accepting socket");
closeAutoCloseable(socketChannel, "accepting socketChannel");
} finally {
_selector.wakeup();
}
}
});
} catch (final Exception e) {
if (s_logger.isTraceEnabled()) {
s_logger.trace("Socket " + socket + " closed on read. Probably -1 returned: " + e.getMessage());
s_logger.trace("Connection closed due to failure: " + e.getMessage());
}
closeAutoCloseable(socket, "accepting socket");
closeAutoCloseable(socketChannel, "accepting socketChannel");
closeAutoCloseable(socket, "opened socket");
return;
}
if (s_logger.isTraceEnabled()) {
s_logger.trace("SSL: Handshake done");
}
socketChannel.configureBlocking(false);
final InetSocketAddress saddr = (InetSocketAddress)socket.getRemoteSocketAddress();
final Link link = new Link(saddr, this);
link.setSSLEngine(sslEngine);
link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link));
final Task task = _factory.create(Task.Type.CONNECT, link, null);
registerLink(saddr, link);
try {
_executor.submit(task);
} catch (final Exception e) {
s_logger.warn("Exception occurred when submitting the task", e);
} finally {
_selector.wakeup();
}
}

View File

@ -43,6 +43,10 @@ public class NioServer extends NioConnection {
_links = new WeakHashMap<InetSocketAddress, Link>(1024);
}
public int getPort() {
return _serverSocket.socket().getLocalPort();
}
@Override
protected void init() throws IOException {
_selector = SelectorProvider.provider().openSelector();
@ -53,9 +57,9 @@ public class NioServer extends NioConnection {
_localAddr = new InetSocketAddress(_port);
_serverSocket.socket().bind(_localAddr);
_serverSocket.register(_selector, SelectionKey.OP_ACCEPT, null);
_serverSocket.register(_selector, SelectionKey.OP_ACCEPT);
s_logger.info("NioConnection started and listening on " + _localAddr.toString());
s_logger.info("NioConnection started and listening on " + _serverSocket.socket().getLocalSocketAddress());
}
@Override

View File

@ -19,14 +19,7 @@
package com.cloud.utils.testcase;
import java.nio.channels.ClosedChannelException;
import java.util.Random;
import junit.framework.TestCase;
import org.apache.log4j.Logger;
import org.junit.Assert;
import com.cloud.utils.concurrency.NamedThreadFactory;
import com.cloud.utils.exception.NioConnectionException;
import com.cloud.utils.nio.HandlerFactory;
import com.cloud.utils.nio.Link;
@ -34,131 +27,200 @@ import com.cloud.utils.nio.NioClient;
import com.cloud.utils.nio.NioServer;
import com.cloud.utils.nio.Task;
import com.cloud.utils.nio.Task.Type;
import org.apache.log4j.Logger;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
*
*
*
*
* NioTest demonstrates that NioServer can function without getting its main IO
* loop blocked when an aggressive or malicious client connects to the server but
* fail to participate in SSL handshake. In this test, we run bunch of clients
* that send a known payload to the server, to which multiple malicious clients
* also try to connect and hang.
* A malicious client could cause denial-of-service if the server's main IO loop
* along with SSL handshake was blocking. A passing tests shows that NioServer
* can still function in case of connection load and that the main IO loop along
* with SSL handshake is non-blocking with some internal timeout mechanism.
*/
public class NioTest extends TestCase {
public class NioTest {
private static final Logger s_logger = Logger.getLogger(NioTest.class);
private static final Logger LOGGER = Logger.getLogger(NioTest.class);
private NioServer _server;
private NioClient _client;
// Test should fail in due time instead of looping forever
private static final int TESTTIMEOUT = 300000;
private Link _clientLink;
final private int totalTestCount = 5;
private int completedTestCount = 0;
private int _testCount;
private int _completedCount;
private NioServer server;
private List<NioClient> clients = new ArrayList<>();
private List<NioClient> maliciousClients = new ArrayList<>();
private ExecutorService clientExecutor = Executors.newFixedThreadPool(totalTestCount, new NamedThreadFactory("NioClientHandler"));;
private ExecutorService maliciousExecutor = Executors.newFixedThreadPool(5*totalTestCount, new NamedThreadFactory("MaliciousNioClientHandler"));;
private Random randomGenerator = new Random();
private byte[] testBytes;
private boolean isTestsDone() {
boolean result;
synchronized (this) {
result = _testCount == _completedCount;
result = totalTestCount == completedTestCount;
}
return result;
}
private void getOneMoreTest() {
synchronized (this) {
_testCount++;
}
}
private void oneMoreTestDone() {
synchronized (this) {
_completedCount++;
completedTestCount++;
}
}
@Override
@Before
public void setUp() {
s_logger.info("Test");
LOGGER.info("Setting up Benchmark Test");
_testCount = 0;
_completedCount = 0;
completedTestCount = 0;
testBytes = new byte[1000000];
randomGenerator.nextBytes(testBytes);
_server = new NioServer("NioTestServer", 7777, 5, new NioTestServer());
server = new NioServer("NioTestServer", 0, 1, new NioTestServer());
try {
_server.start();
server.start();
} catch (final NioConnectionException e) {
fail(e.getMessage());
Assert.fail(e.getMessage());
}
_client = new NioClient("NioTestServer", "127.0.0.1", 7777, 5, new NioTestClient());
try {
_client.start();
} catch (final NioConnectionException e) {
fail(e.getMessage());
}
while (_clientLink == null) {
try {
s_logger.debug("Link is not up! Waiting ...");
Thread.sleep(1000);
} catch (final InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
for (int i = 0; i < totalTestCount; i++) {
for (int j = 0; j < 4; j++) {
final NioClient maliciousClient = new NioMaliciousClient("NioMaliciousTestClient-" + i, "127.0.0.1", server.getPort(), 1, new NioMaliciousTestClient());
maliciousClients.add(maliciousClient);
maliciousExecutor.submit(new ThreadedNioClient(maliciousClient));
}
final NioClient client = new NioClient("NioTestClient-" + i, "127.0.0.1", server.getPort(), 1, new NioTestClient());
clients.add(client);
clientExecutor.submit(new ThreadedNioClient(client));
}
}
@Override
@After
public void tearDown() {
while (!isTestsDone()) {
try {
s_logger.debug(_completedCount + "/" + _testCount + " tests done. Waiting for completion");
Thread.sleep(1000);
} catch (final InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
stopClient();
stopServer();
}
protected void stopClient() {
_client.stop();
s_logger.info("Client stopped.");
for (NioClient client : clients) {
client.stop();
}
for (NioClient maliciousClient : maliciousClients) {
maliciousClient.stop();
}
LOGGER.info("Clients stopped.");
}
protected void stopServer() {
_server.stop();
s_logger.info("Server stopped.");
server.stop();
LOGGER.info("Server stopped.");
}
protected void setClientLink(final Link link) {
_clientLink = link;
}
Random randomGenerator = new Random();
byte[] _testBytes;
@Test(timeout=TESTTIMEOUT)
public void testConnection() {
_testBytes = new byte[1000000];
randomGenerator.nextBytes(_testBytes);
try {
getOneMoreTest();
_clientLink.send(_testBytes);
s_logger.info("Client: Data sent");
getOneMoreTest();
_clientLink.send(_testBytes);
s_logger.info("Client: Data sent");
} catch (final ClosedChannelException e) {
// TODO Auto-generated catch block
e.printStackTrace();
while (!isTestsDone()) {
try {
LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done. Waiting for completion");
Thread.sleep(1000);
} catch (final InterruptedException e) {
Assert.fail(e.getMessage());
}
}
LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done.");
}
protected void doServerProcess(final byte[] data) {
oneMoreTestDone();
Assert.assertArrayEquals(_testBytes, data);
s_logger.info("Verify done.");
Assert.assertArrayEquals(testBytes, data);
LOGGER.info("Verify data received by server done.");
}
public byte[] getTestBytes() {
return testBytes;
}
public class ThreadedNioClient implements Runnable {
final private NioClient client;
ThreadedNioClient(final NioClient client) {
this.client = client;
}
@Override
public void run() {
try {
client.start();
} catch (NioConnectionException e) {
Assert.fail(e.getMessage());
}
}
}
public class NioMaliciousClient extends NioClient {
public NioMaliciousClient(String name, String host, int port, int workers, HandlerFactory factory) {
super(name, host, port, workers, factory);
}
@Override
protected void init() throws IOException {
_selector = Selector.open();
try {
_clientConnection = SocketChannel.open();
LOGGER.info("Connecting to " + _host + ":" + _port);
final InetSocketAddress peerAddr = new InetSocketAddress(_host, _port);
_clientConnection.connect(peerAddr);
// This is done on purpose, the malicious client would connect
// to the server and then do nothing, hence using a large sleep value
Thread.sleep(Long.MAX_VALUE);
} catch (final IOException e) {
_selector.close();
throw e;
} catch (InterruptedException e) {
LOGGER.debug(e.getMessage());
}
}
}
public class NioMaliciousTestClient implements HandlerFactory {
@Override
public Task create(final Type type, final Link link, final byte[] data) {
return new NioMaliciousTestClientHandler(type, link, data);
}
public class NioMaliciousTestClientHandler extends Task {
public NioMaliciousTestClientHandler(final Type type, final Link link, final byte[] data) {
super(type, link, data);
}
@Override
public void doTask(final Task task) {
LOGGER.info("Malicious Client: Received task " + task.getType().toString());
}
}
}
public class NioTestClient implements HandlerFactory {
@ -177,18 +239,23 @@ public class NioTest extends TestCase {
@Override
public void doTask(final Task task) {
if (task.getType() == Task.Type.CONNECT) {
s_logger.info("Client: Received CONNECT task");
setClientLink(task.getLink());
LOGGER.info("Client: Received CONNECT task");
try {
LOGGER.info("Sending data to server");
task.getLink().send(getTestBytes());
} catch (ClosedChannelException e) {
LOGGER.error(e.getMessage());
e.printStackTrace();
}
} else if (task.getType() == Task.Type.DATA) {
s_logger.info("Client: Received DATA task");
LOGGER.info("Client: Received DATA task");
} else if (task.getType() == Task.Type.DISCONNECT) {
s_logger.info("Client: Received DISCONNECT task");
LOGGER.info("Client: Received DISCONNECT task");
stopClient();
} else if (task.getType() == Task.Type.OTHER) {
s_logger.info("Client: Received OTHER task");
LOGGER.info("Client: Received OTHER task");
}
}
}
}
@ -208,15 +275,15 @@ public class NioTest extends TestCase {
@Override
public void doTask(final Task task) {
if (task.getType() == Task.Type.CONNECT) {
s_logger.info("Server: Received CONNECT task");
LOGGER.info("Server: Received CONNECT task");
} else if (task.getType() == Task.Type.DATA) {
s_logger.info("Server: Received DATA task");
LOGGER.info("Server: Received DATA task");
doServerProcess(task.getData());
} else if (task.getType() == Task.Type.DISCONNECT) {
s_logger.info("Server: Received DISCONNECT task");
LOGGER.info("Server: Received DISCONNECT task");
stopServer();
} else if (task.getType() == Task.Type.OTHER) {
s_logger.info("Server: Received OTHER task");
LOGGER.info("Server: Received OTHER task");
}
}