diff --git a/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java b/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java
index f35a4efe36b..91085bf8d14 100644
--- a/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java
+++ b/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java
@@ -499,7 +499,7 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
SocketChannel ch1 = null;
try {
ch1 = SocketChannel.open(new InetSocketAddress(addr, Port.value()));
- ch1.configureBlocking(true); // make sure we are working at blocking mode
+ ch1.configureBlocking(false);
ch1.socket().setKeepAlive(true);
ch1.socket().setSoTimeout(60 * 1000);
try {
@@ -507,8 +507,11 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
sslEngine = sslContext.createSSLEngine(ip, Port.value());
sslEngine.setUseClientMode(true);
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
-
- Link.doHandshake(ch1, sslEngine, true);
+ sslEngine.beginHandshake();
+ if (!Link.doHandshake(ch1, sslEngine, true)) {
+ ch1.close();
+ throw new IOException("SSL handshake failed!");
+ }
s_logger.info("SSL: Handshake done");
} catch (final Exception e) {
ch1.close();
diff --git a/utils/pom.xml b/utils/pom.xml
index 206eb1896a6..9e2358680f3 100755
--- a/utils/pom.xml
+++ b/utils/pom.xml
@@ -208,7 +208,6 @@
com/cloud/utils/testcase/*TestCase*
com/cloud/utils/db/*Test*
- com/cloud/utils/testcase/NioTest.java
diff --git a/utils/src/main/java/com/cloud/utils/nio/Link.java b/utils/src/main/java/com/cloud/utils/nio/Link.java
index 6d6306a53b8..f297d52c077 100644
--- a/utils/src/main/java/com/cloud/utils/nio/Link.java
+++ b/utils/src/main/java/com/cloud/utils/nio/Link.java
@@ -19,36 +19,32 @@
package com.cloud.utils.nio;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.net.InetSocketAddress;
-import java.net.SocketTimeoutException;
-import java.nio.ByteBuffer;
-import java.nio.channels.Channels;
-import java.nio.channels.ClosedChannelException;
-import java.nio.channels.ReadableByteChannel;
-import java.nio.channels.SelectionKey;
-import java.nio.channels.SocketChannel;
-import java.security.GeneralSecurityException;
-import java.security.KeyStore;
-import java.util.concurrent.ConcurrentLinkedQueue;
+import com.cloud.utils.PropertiesUtil;
+import com.cloud.utils.db.DbProperties;
+import org.apache.cloudstack.utils.security.SSLUtils;
+import org.apache.log4j.Logger;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
+import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
-
-import org.apache.cloudstack.utils.security.SSLUtils;
-import org.apache.log4j.Logger;
-
-import com.cloud.utils.PropertiesUtil;
-import com.cloud.utils.db.DbProperties;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.InetSocketAddress;
+import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.SocketChannel;
+import java.security.GeneralSecurityException;
+import java.security.KeyStore;
+import java.util.concurrent.ConcurrentLinkedQueue;
/**
*/
@@ -453,115 +449,185 @@ public class Link {
return sslContext;
}
- public static void doHandshake(SocketChannel ch, SSLEngine sslEngine, boolean isClient) throws IOException {
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("SSL: begin Handshake, isClient: " + isClient);
+ public static ByteBuffer enlargeBuffer(ByteBuffer buffer, final int sessionProposedCapacity) {
+ if (buffer == null || sessionProposedCapacity < 0) {
+ return buffer;
}
-
- SSLEngineResult engResult;
- SSLSession sslSession = sslEngine.getSession();
- HandshakeStatus hsStatus;
- ByteBuffer in_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
- ByteBuffer in_appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
- ByteBuffer out_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
- ByteBuffer out_appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
- int count;
- ch.socket().setSoTimeout(60 * 1000);
- InputStream inStream = ch.socket().getInputStream();
- // Use readCh to make sure the timeout on reading is working
- ReadableByteChannel readCh = Channels.newChannel(inStream);
-
- if (isClient) {
- hsStatus = SSLEngineResult.HandshakeStatus.NEED_WRAP;
+ if (sessionProposedCapacity > buffer.capacity()) {
+ buffer = ByteBuffer.allocate(sessionProposedCapacity);
} else {
- hsStatus = SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
+ buffer = ByteBuffer.allocate(buffer.capacity() * 2);
}
+ return buffer;
+ }
- while (hsStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("SSL: Handshake status " + hsStatus);
- }
- engResult = null;
- if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
- out_pkgBuf.clear();
- out_appBuf.clear();
- out_appBuf.put("Hello".getBytes());
- engResult = sslEngine.wrap(out_appBuf, out_pkgBuf);
- out_pkgBuf.flip();
- int remain = out_pkgBuf.limit();
- while (remain != 0) {
- remain -= ch.write(out_pkgBuf);
- if (remain < 0) {
- throw new IOException("Too much bytes sent?");
- }
- }
- } else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
- in_appBuf.clear();
- // One packet may contained multiply operation
- if (in_pkgBuf.position() == 0 || !in_pkgBuf.hasRemaining()) {
- in_pkgBuf.clear();
- count = 0;
- try {
- count = readCh.read(in_pkgBuf);
- } catch (SocketTimeoutException ex) {
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("Handshake reading time out! Cut the connection");
- }
- count = -1;
- }
- if (count == -1) {
- throw new IOException("Connection closed with -1 on reading size.");
- }
- in_pkgBuf.flip();
- }
- engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
- ByteBuffer tmp_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
- int loop_count = 0;
- while (engResult.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
- // The client is too slow? Cut it and let it reconnect
- if (loop_count > 10) {
- throw new IOException("Too many times in SSL BUFFER_UNDERFLOW, disconnect guest.");
- }
- // We need more packets to complete this operation
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("SSL: Buffer underflowed, getting more packets");
- }
- tmp_pkgBuf.clear();
- count = ch.read(tmp_pkgBuf);
- if (count == -1) {
- throw new IOException("Connection closed with -1 on reading size.");
- }
- tmp_pkgBuf.flip();
-
- in_pkgBuf.mark();
- in_pkgBuf.position(in_pkgBuf.limit());
- in_pkgBuf.limit(in_pkgBuf.limit() + tmp_pkgBuf.limit());
- in_pkgBuf.put(tmp_pkgBuf);
- in_pkgBuf.reset();
-
- in_appBuf.clear();
- engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
- loop_count++;
- }
- } else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK) {
- Runnable run;
- while ((run = sslEngine.getDelegatedTask()) != null) {
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("SSL: Running delegated task!");
- }
- run.run();
- }
- } else if (hsStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
- throw new IOException("NOT a handshaking!");
- }
- if (engResult != null && engResult.getStatus() != SSLEngineResult.Status.OK) {
- throw new IOException("Fail to handshake! " + engResult.getStatus());
- }
- if (engResult != null)
- hsStatus = engResult.getHandshakeStatus();
- else
- hsStatus = sslEngine.getHandshakeStatus();
+ public static ByteBuffer handleBufferUnderflow(final SSLEngine engine, ByteBuffer buffer) {
+ if (engine == null || buffer == null) {
+ return buffer;
}
+ if (buffer.position() < buffer.limit()) {
+ return buffer;
+ }
+ ByteBuffer replaceBuffer = enlargeBuffer(buffer, engine.getSession().getPacketBufferSize());
+ buffer.flip();
+ replaceBuffer.put(buffer);
+ return replaceBuffer;
+ }
+
+ private static boolean doHandshakeUnwrap(final SocketChannel socketChannel, final SSLEngine sslEngine,
+ ByteBuffer peerAppData, ByteBuffer peerNetData, final int appBufferSize) throws IOException {
+ if (socketChannel == null || sslEngine == null || peerAppData == null || peerNetData == null || appBufferSize < 0) {
+ return false;
+ }
+ if (socketChannel.read(peerNetData) < 0) {
+ if (sslEngine.isInboundDone() && sslEngine.isOutboundDone()) {
+ return false;
+ }
+ try {
+ sslEngine.closeInbound();
+ } catch (SSLException e) {
+ s_logger.warn("This SSL engine was forced to close inbound due to end of stream.");
+ }
+ sslEngine.closeOutbound();
+ // After closeOutbound the engine will be set to WRAP state,
+ // in order to try to send a close message to the client.
+ return true;
+ }
+ peerNetData.flip();
+ SSLEngineResult result = null;
+ try {
+ result = sslEngine.unwrap(peerNetData, peerAppData);
+ peerNetData.compact();
+ } catch (SSLException sslException) {
+ s_logger.error("SSL error occurred while processing unwrap data: " + sslException.getMessage());
+ sslEngine.closeOutbound();
+ return true;
+ }
+ switch (result.getStatus()) {
+ case OK:
+ break;
+ case BUFFER_OVERFLOW:
+ // Will occur when peerAppData's capacity is smaller than the data derived from peerNetData's unwrap.
+ peerAppData = enlargeBuffer(peerAppData, appBufferSize);
+ break;
+ case BUFFER_UNDERFLOW:
+ // Will occur either when no data was read from the peer or when the peerNetData buffer
+ // was too small to hold all peer's data.
+ peerNetData = handleBufferUnderflow(sslEngine, peerNetData);
+ break;
+ case CLOSED:
+ if (sslEngine.isOutboundDone()) {
+ return false;
+ } else {
+ sslEngine.closeOutbound();
+ break;
+ }
+ default:
+ throw new IllegalStateException("Invalid SSL status: " + result.getStatus());
+ }
+ return true;
+ }
+
+ private static boolean doHandshakeWrap(final SocketChannel socketChannel, final SSLEngine sslEngine,
+ ByteBuffer myAppData, ByteBuffer myNetData, ByteBuffer peerNetData,
+ final int netBufferSize) throws IOException {
+ if (socketChannel == null || sslEngine == null || myNetData == null || peerNetData == null
+ || myAppData == null || netBufferSize < 0) {
+ return false;
+ }
+ myNetData.clear();
+ SSLEngineResult result = null;
+ try {
+ result = sslEngine.wrap(myAppData, myNetData);
+ } catch (SSLException sslException) {
+ s_logger.error("SSL error occurred while processing wrap data: " + sslException.getMessage());
+ sslEngine.closeOutbound();
+ return true;
+ }
+ switch (result.getStatus()) {
+ case OK :
+ myNetData.flip();
+ while (myNetData.hasRemaining()) {
+ socketChannel.write(myNetData);
+ }
+ break;
+ case BUFFER_OVERFLOW:
+ // Will occur if there is not enough space in myNetData buffer to write all the data
+ // that would be generated by the method wrap. Since myNetData is set to session's packet
+ // size we should not get to this point because SSLEngine is supposed to produce messages
+ // smaller or equal to that, but a general handling would be the following:
+ myNetData = enlargeBuffer(myNetData, netBufferSize);
+ break;
+ case BUFFER_UNDERFLOW:
+ throw new SSLException("Buffer underflow occurred after a wrap. We should not reach here.");
+ case CLOSED:
+ try {
+ myNetData.flip();
+ while (myNetData.hasRemaining()) {
+ socketChannel.write(myNetData);
+ }
+ // At this point the handshake status will probably be NEED_UNWRAP
+ // so we make sure that peerNetData is clear to read.
+ peerNetData.clear();
+ } catch (Exception e) {
+ s_logger.error("Failed to send server's CLOSE message due to socket channel's failure.");
+ }
+ break;
+ default:
+ throw new IllegalStateException("Invalid SSL status: " + result.getStatus());
+ }
+ return true;
+ }
+
+ public static boolean doHandshake(final SocketChannel socketChannel, final SSLEngine sslEngine, final boolean isClient) throws IOException {
+ if (socketChannel == null || sslEngine == null) {
+ return false;
+ }
+ final int appBufferSize = sslEngine.getSession().getApplicationBufferSize();
+ final int netBufferSize = sslEngine.getSession().getPacketBufferSize();
+ ByteBuffer myAppData = ByteBuffer.allocate(appBufferSize);
+ ByteBuffer peerAppData = ByteBuffer.allocate(appBufferSize);
+ ByteBuffer myNetData = ByteBuffer.allocate(netBufferSize);
+ ByteBuffer peerNetData = ByteBuffer.allocate(netBufferSize);
+
+ final long startTimeMills = System.currentTimeMillis();
+
+ HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
+ while (handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED
+ && handshakeStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
+ final long timeTaken = System.currentTimeMillis() - startTimeMills;
+ if (timeTaken > 60000L) {
+ s_logger.warn("SSL Handshake has taken more than 60s to connect to: " + socketChannel.getRemoteAddress() +
+ ". Please investigate this connection.");
+ return false;
+ }
+ switch (handshakeStatus) {
+ case NEED_UNWRAP:
+ if (!doHandshakeUnwrap(socketChannel, sslEngine, peerAppData, peerNetData, appBufferSize)) {
+ return false;
+ }
+ break;
+ case NEED_WRAP:
+ if (!doHandshakeWrap(socketChannel, sslEngine, myAppData, myNetData, peerNetData, netBufferSize)) {
+ return false;
+ }
+ break;
+ case NEED_TASK:
+ Runnable task;
+ while ((task = sslEngine.getDelegatedTask()) != null) {
+ new Thread(task).run();
+ }
+ break;
+ case FINISHED:
+ break;
+ case NOT_HANDSHAKING:
+ break;
+ default:
+ throw new IllegalStateException("Invalid SSL status: " + handshakeStatus);
+ }
+ handshakeStatus = sslEngine.getHandshakeStatus();
+ }
+ return true;
}
}
diff --git a/utils/src/main/java/com/cloud/utils/nio/NioClient.java b/utils/src/main/java/com/cloud/utils/nio/NioClient.java
index d989f306c55..dc4f670de12 100644
--- a/utils/src/main/java/com/cloud/utils/nio/NioClient.java
+++ b/utils/src/main/java/com/cloud/utils/nio/NioClient.java
@@ -36,7 +36,6 @@ public class NioClient extends NioConnection {
private static final Logger s_logger = Logger.getLogger(NioClient.class);
protected String _host;
- protected String _bindAddress;
protected SocketChannel _clientConnection;
public NioClient(final String name, final String host, final int port, final int workers, final HandlerFactory factory) {
@@ -44,10 +43,6 @@ public class NioClient extends NioConnection {
_host = host;
}
- public void setBindAddress(final String ipAddress) {
- _bindAddress = ipAddress;
- }
-
@Override
protected void init() throws IOException {
_selector = Selector.open();
@@ -55,33 +50,25 @@ public class NioClient extends NioConnection {
try {
_clientConnection = SocketChannel.open();
- _clientConnection.configureBlocking(true);
+
s_logger.info("Connecting to " + _host + ":" + _port);
-
- if (_bindAddress != null) {
- s_logger.info("Binding outbound interface at " + _bindAddress);
-
- final InetSocketAddress bindAddr = new InetSocketAddress(_bindAddress, 0);
- _clientConnection.socket().bind(bindAddr);
- }
-
final InetSocketAddress peerAddr = new InetSocketAddress(_host, _port);
_clientConnection.connect(peerAddr);
-
- SSLEngine sslEngine = null;
- // Begin SSL handshake in BLOCKING mode
- _clientConnection.configureBlocking(true);
+ _clientConnection.configureBlocking(false);
final SSLContext sslContext = Link.initSSLContext(true);
- sslEngine = sslContext.createSSLEngine(_host, _port);
+ SSLEngine sslEngine = sslContext.createSSLEngine(_host, _port);
sslEngine.setUseClientMode(true);
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
-
- Link.doHandshake(_clientConnection, sslEngine, true);
+ sslEngine.beginHandshake();
+ if (!Link.doHandshake(_clientConnection, sslEngine, true)) {
+ s_logger.error("SSL Handshake failed while connecting to host: " + _host + " port: " + _port);
+ _selector.close();
+ throw new IOException("SSL Handshake failed while connecting to host: " + _host + " port: " + _port);
+ }
s_logger.info("SSL: Handshake done");
s_logger.info("Connected to " + _host + ":" + _port);
- _clientConnection.configureBlocking(false);
final Link link = new Link(peerAddr, this);
link.setSSLEngine(sslEngine);
final SelectionKey key = _clientConnection.register(_selector, SelectionKey.OP_READ);
diff --git a/utils/src/main/java/com/cloud/utils/nio/NioConnection.java b/utils/src/main/java/com/cloud/utils/nio/NioConnection.java
index 249f512d9c9..6fdb4736ac7 100644
--- a/utils/src/main/java/com/cloud/utils/nio/NioConnection.java
+++ b/utils/src/main/java/com/cloud/utils/nio/NioConnection.java
@@ -19,8 +19,13 @@
package com.cloud.utils.nio;
-import static com.cloud.utils.AutoCloseableUtil.closeAutoCloseable;
+import com.cloud.utils.concurrency.NamedThreadFactory;
+import com.cloud.utils.exception.NioConnectionException;
+import org.apache.cloudstack.utils.security.SSLUtils;
+import org.apache.log4j.Logger;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.net.ConnectException;
import java.net.InetSocketAddress;
@@ -44,14 +49,7 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
-import javax.net.ssl.SSLContext;
-import javax.net.ssl.SSLEngine;
-
-import org.apache.cloudstack.utils.security.SSLUtils;
-import org.apache.log4j.Logger;
-
-import com.cloud.utils.concurrency.NamedThreadFactory;
-import com.cloud.utils.exception.NioConnectionException;
+import static com.cloud.utils.AutoCloseableUtil.closeAutoCloseable;
/**
* NioConnection abstracts the NIO socket operations. The Java implementation
@@ -71,6 +69,7 @@ public abstract class NioConnection implements Callable {
protected HandlerFactory _factory;
protected String _name;
protected ExecutorService _executor;
+ protected ExecutorService _sslHandshakeExecutor;
public NioConnection(final String name, final int port, final int workers, final HandlerFactory factory) {
_name = name;
@@ -79,6 +78,7 @@ public abstract class NioConnection implements Callable {
_port = port;
_factory = factory;
_executor = new ThreadPoolExecutor(workers, 5 * workers, 1, TimeUnit.DAYS, new LinkedBlockingQueue(), new NamedThreadFactory(name + "-Handler"));
+ _sslHandshakeExecutor = Executors.newCachedThreadPool(new NamedThreadFactory(name + "-SSLHandshakeHandler"));
}
public void start() throws NioConnectionException {
@@ -185,8 +185,9 @@ public abstract class NioConnection implements Callable {
protected void accept(final SelectionKey key) throws IOException {
final ServerSocketChannel serverSocketChannel = (ServerSocketChannel)key.channel();
-
final SocketChannel socketChannel = serverSocketChannel.accept();
+ socketChannel.configureBlocking(false);
+
final Socket socket = socketChannel.socket();
socket.setKeepAlive(true);
@@ -194,43 +195,52 @@ public abstract class NioConnection implements Callable {
s_logger.trace("Connection accepted for " + socket);
}
- // Begin SSL handshake in BLOCKING mode
- socketChannel.configureBlocking(true);
-
- SSLEngine sslEngine = null;
+ final SSLEngine sslEngine;
try {
final SSLContext sslContext = Link.initSSLContext(false);
sslEngine = sslContext.createSSLEngine();
sslEngine.setUseClientMode(false);
sslEngine.setNeedClientAuth(false);
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
-
- Link.doHandshake(socketChannel, sslEngine, false);
-
+ final NioConnection nioConnection = this;
+ _sslHandshakeExecutor.submit(new Runnable() {
+ @Override
+ public void run() {
+ _selector.wakeup();
+ try {
+ sslEngine.beginHandshake();
+ if (!Link.doHandshake(socketChannel, sslEngine, false)) {
+ throw new IOException("SSL handshake timed out with " + socketChannel.getRemoteAddress());
+ }
+ if (s_logger.isTraceEnabled()) {
+ s_logger.trace("SSL: Handshake done");
+ }
+ final InetSocketAddress saddr = (InetSocketAddress)socket.getRemoteSocketAddress();
+ final Link link = new Link(saddr, nioConnection);
+ link.setSSLEngine(sslEngine);
+ link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link));
+ final Task task = _factory.create(Task.Type.CONNECT, link, null);
+ registerLink(saddr, link);
+ _executor.submit(task);
+ } catch (IOException e) {
+ if (s_logger.isTraceEnabled()) {
+ s_logger.trace("Connection closed due to failure: " + e.getMessage());
+ }
+ closeAutoCloseable(socket, "accepting socket");
+ closeAutoCloseable(socketChannel, "accepting socketChannel");
+ } finally {
+ _selector.wakeup();
+ }
+ }
+ });
} catch (final Exception e) {
if (s_logger.isTraceEnabled()) {
- s_logger.trace("Socket " + socket + " closed on read. Probably -1 returned: " + e.getMessage());
+ s_logger.trace("Connection closed due to failure: " + e.getMessage());
}
+ closeAutoCloseable(socket, "accepting socket");
closeAutoCloseable(socketChannel, "accepting socketChannel");
- closeAutoCloseable(socket, "opened socket");
- return;
- }
-
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("SSL: Handshake done");
- }
- socketChannel.configureBlocking(false);
- final InetSocketAddress saddr = (InetSocketAddress)socket.getRemoteSocketAddress();
- final Link link = new Link(saddr, this);
- link.setSSLEngine(sslEngine);
- link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link));
- final Task task = _factory.create(Task.Type.CONNECT, link, null);
- registerLink(saddr, link);
-
- try {
- _executor.submit(task);
- } catch (final Exception e) {
- s_logger.warn("Exception occurred when submitting the task", e);
+ } finally {
+ _selector.wakeup();
}
}
diff --git a/utils/src/main/java/com/cloud/utils/nio/NioServer.java b/utils/src/main/java/com/cloud/utils/nio/NioServer.java
index 539c2bb13d8..13d5476cba0 100644
--- a/utils/src/main/java/com/cloud/utils/nio/NioServer.java
+++ b/utils/src/main/java/com/cloud/utils/nio/NioServer.java
@@ -43,6 +43,10 @@ public class NioServer extends NioConnection {
_links = new WeakHashMap(1024);
}
+ public int getPort() {
+ return _serverSocket.socket().getLocalPort();
+ }
+
@Override
protected void init() throws IOException {
_selector = SelectorProvider.provider().openSelector();
@@ -53,9 +57,9 @@ public class NioServer extends NioConnection {
_localAddr = new InetSocketAddress(_port);
_serverSocket.socket().bind(_localAddr);
- _serverSocket.register(_selector, SelectionKey.OP_ACCEPT, null);
+ _serverSocket.register(_selector, SelectionKey.OP_ACCEPT);
- s_logger.info("NioConnection started and listening on " + _localAddr.toString());
+ s_logger.info("NioConnection started and listening on " + _serverSocket.socket().getLocalSocketAddress());
}
@Override
diff --git a/utils/src/test/java/com/cloud/utils/testcase/NioTest.java b/utils/src/test/java/com/cloud/utils/testcase/NioTest.java
index d8510cfcac2..20c31e21b1a 100644
--- a/utils/src/test/java/com/cloud/utils/testcase/NioTest.java
+++ b/utils/src/test/java/com/cloud/utils/testcase/NioTest.java
@@ -19,14 +19,7 @@
package com.cloud.utils.testcase;
-import java.nio.channels.ClosedChannelException;
-import java.util.Random;
-
-import junit.framework.TestCase;
-
-import org.apache.log4j.Logger;
-import org.junit.Assert;
-
+import com.cloud.utils.concurrency.NamedThreadFactory;
import com.cloud.utils.exception.NioConnectionException;
import com.cloud.utils.nio.HandlerFactory;
import com.cloud.utils.nio.Link;
@@ -34,131 +27,200 @@ import com.cloud.utils.nio.NioClient;
import com.cloud.utils.nio.NioServer;
import com.cloud.utils.nio.Task;
import com.cloud.utils.nio.Task.Type;
+import org.apache.log4j.Logger;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.nio.channels.ClosedChannelException;
+import java.nio.channels.Selector;
+import java.nio.channels.SocketChannel;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
/**
- *
- *
- *
- *
+ * NioTest demonstrates that NioServer can function without getting its main IO
+ * loop blocked when an aggressive or malicious client connects to the server but
+ * fail to participate in SSL handshake. In this test, we run bunch of clients
+ * that send a known payload to the server, to which multiple malicious clients
+ * also try to connect and hang.
+ * A malicious client could cause denial-of-service if the server's main IO loop
+ * along with SSL handshake was blocking. A passing tests shows that NioServer
+ * can still function in case of connection load and that the main IO loop along
+ * with SSL handshake is non-blocking with some internal timeout mechanism.
*/
-public class NioTest extends TestCase {
+public class NioTest {
- private static final Logger s_logger = Logger.getLogger(NioTest.class);
+ private static final Logger LOGGER = Logger.getLogger(NioTest.class);
- private NioServer _server;
- private NioClient _client;
+ // Test should fail in due time instead of looping forever
+ private static final int TESTTIMEOUT = 300000;
- private Link _clientLink;
+ final private int totalTestCount = 5;
+ private int completedTestCount = 0;
- private int _testCount;
- private int _completedCount;
+ private NioServer server;
+ private List clients = new ArrayList<>();
+ private List maliciousClients = new ArrayList<>();
+
+ private ExecutorService clientExecutor = Executors.newFixedThreadPool(totalTestCount, new NamedThreadFactory("NioClientHandler"));;
+ private ExecutorService maliciousExecutor = Executors.newFixedThreadPool(5*totalTestCount, new NamedThreadFactory("MaliciousNioClientHandler"));;
+
+ private Random randomGenerator = new Random();
+ private byte[] testBytes;
private boolean isTestsDone() {
boolean result;
synchronized (this) {
- result = _testCount == _completedCount;
+ result = totalTestCount == completedTestCount;
}
return result;
}
- private void getOneMoreTest() {
- synchronized (this) {
- _testCount++;
- }
- }
-
private void oneMoreTestDone() {
synchronized (this) {
- _completedCount++;
+ completedTestCount++;
}
}
- @Override
+ @Before
public void setUp() {
- s_logger.info("Test");
+ LOGGER.info("Setting up Benchmark Test");
- _testCount = 0;
- _completedCount = 0;
+ completedTestCount = 0;
+ testBytes = new byte[1000000];
+ randomGenerator.nextBytes(testBytes);
- _server = new NioServer("NioTestServer", 7777, 5, new NioTestServer());
+ server = new NioServer("NioTestServer", 0, 1, new NioTestServer());
try {
- _server.start();
+ server.start();
} catch (final NioConnectionException e) {
- fail(e.getMessage());
+ Assert.fail(e.getMessage());
}
- _client = new NioClient("NioTestServer", "127.0.0.1", 7777, 5, new NioTestClient());
- try {
- _client.start();
- } catch (final NioConnectionException e) {
- fail(e.getMessage());
- }
-
- while (_clientLink == null) {
- try {
- s_logger.debug("Link is not up! Waiting ...");
- Thread.sleep(1000);
- } catch (final InterruptedException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
+ for (int i = 0; i < totalTestCount; i++) {
+ for (int j = 0; j < 4; j++) {
+ final NioClient maliciousClient = new NioMaliciousClient("NioMaliciousTestClient-" + i, "127.0.0.1", server.getPort(), 1, new NioMaliciousTestClient());
+ maliciousClients.add(maliciousClient);
+ maliciousExecutor.submit(new ThreadedNioClient(maliciousClient));
}
+ final NioClient client = new NioClient("NioTestClient-" + i, "127.0.0.1", server.getPort(), 1, new NioTestClient());
+ clients.add(client);
+ clientExecutor.submit(new ThreadedNioClient(client));
}
}
- @Override
+ @After
public void tearDown() {
- while (!isTestsDone()) {
- try {
- s_logger.debug(_completedCount + "/" + _testCount + " tests done. Waiting for completion");
- Thread.sleep(1000);
- } catch (final InterruptedException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
- }
stopClient();
stopServer();
}
protected void stopClient() {
- _client.stop();
- s_logger.info("Client stopped.");
+ for (NioClient client : clients) {
+ client.stop();
+ }
+ for (NioClient maliciousClient : maliciousClients) {
+ maliciousClient.stop();
+ }
+ LOGGER.info("Clients stopped.");
}
protected void stopServer() {
- _server.stop();
- s_logger.info("Server stopped.");
+ server.stop();
+ LOGGER.info("Server stopped.");
}
- protected void setClientLink(final Link link) {
- _clientLink = link;
- }
-
- Random randomGenerator = new Random();
-
- byte[] _testBytes;
-
+ @Test(timeout=TESTTIMEOUT)
public void testConnection() {
- _testBytes = new byte[1000000];
- randomGenerator.nextBytes(_testBytes);
- try {
- getOneMoreTest();
- _clientLink.send(_testBytes);
- s_logger.info("Client: Data sent");
- getOneMoreTest();
- _clientLink.send(_testBytes);
- s_logger.info("Client: Data sent");
- } catch (final ClosedChannelException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
+ while (!isTestsDone()) {
+ try {
+ LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done. Waiting for completion");
+ Thread.sleep(1000);
+ } catch (final InterruptedException e) {
+ Assert.fail(e.getMessage());
+ }
}
+ LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done.");
}
protected void doServerProcess(final byte[] data) {
oneMoreTestDone();
- Assert.assertArrayEquals(_testBytes, data);
- s_logger.info("Verify done.");
+ Assert.assertArrayEquals(testBytes, data);
+ LOGGER.info("Verify data received by server done.");
+ }
+
+ public byte[] getTestBytes() {
+ return testBytes;
+ }
+
+ public class ThreadedNioClient implements Runnable {
+ final private NioClient client;
+ ThreadedNioClient(final NioClient client) {
+ this.client = client;
+ }
+
+ @Override
+ public void run() {
+ try {
+ client.start();
+ } catch (NioConnectionException e) {
+ Assert.fail(e.getMessage());
+ }
+ }
+ }
+
+ public class NioMaliciousClient extends NioClient {
+
+ public NioMaliciousClient(String name, String host, int port, int workers, HandlerFactory factory) {
+ super(name, host, port, workers, factory);
+ }
+
+ @Override
+ protected void init() throws IOException {
+ _selector = Selector.open();
+ try {
+ _clientConnection = SocketChannel.open();
+ LOGGER.info("Connecting to " + _host + ":" + _port);
+ final InetSocketAddress peerAddr = new InetSocketAddress(_host, _port);
+ _clientConnection.connect(peerAddr);
+ // This is done on purpose, the malicious client would connect
+ // to the server and then do nothing, hence using a large sleep value
+ Thread.sleep(Long.MAX_VALUE);
+ } catch (final IOException e) {
+ _selector.close();
+ throw e;
+ } catch (InterruptedException e) {
+ LOGGER.debug(e.getMessage());
+ }
+ }
+ }
+
+ public class NioMaliciousTestClient implements HandlerFactory {
+
+ @Override
+ public Task create(final Type type, final Link link, final byte[] data) {
+ return new NioMaliciousTestClientHandler(type, link, data);
+ }
+
+ public class NioMaliciousTestClientHandler extends Task {
+
+ public NioMaliciousTestClientHandler(final Type type, final Link link, final byte[] data) {
+ super(type, link, data);
+ }
+
+ @Override
+ public void doTask(final Task task) {
+ LOGGER.info("Malicious Client: Received task " + task.getType().toString());
+ }
+ }
}
public class NioTestClient implements HandlerFactory {
@@ -177,18 +239,23 @@ public class NioTest extends TestCase {
@Override
public void doTask(final Task task) {
if (task.getType() == Task.Type.CONNECT) {
- s_logger.info("Client: Received CONNECT task");
- setClientLink(task.getLink());
+ LOGGER.info("Client: Received CONNECT task");
+ try {
+ LOGGER.info("Sending data to server");
+ task.getLink().send(getTestBytes());
+ } catch (ClosedChannelException e) {
+ LOGGER.error(e.getMessage());
+ e.printStackTrace();
+ }
} else if (task.getType() == Task.Type.DATA) {
- s_logger.info("Client: Received DATA task");
+ LOGGER.info("Client: Received DATA task");
} else if (task.getType() == Task.Type.DISCONNECT) {
- s_logger.info("Client: Received DISCONNECT task");
+ LOGGER.info("Client: Received DISCONNECT task");
stopClient();
} else if (task.getType() == Task.Type.OTHER) {
- s_logger.info("Client: Received OTHER task");
+ LOGGER.info("Client: Received OTHER task");
}
}
-
}
}
@@ -208,15 +275,15 @@ public class NioTest extends TestCase {
@Override
public void doTask(final Task task) {
if (task.getType() == Task.Type.CONNECT) {
- s_logger.info("Server: Received CONNECT task");
+ LOGGER.info("Server: Received CONNECT task");
} else if (task.getType() == Task.Type.DATA) {
- s_logger.info("Server: Received DATA task");
+ LOGGER.info("Server: Received DATA task");
doServerProcess(task.getData());
} else if (task.getType() == Task.Type.DISCONNECT) {
- s_logger.info("Server: Received DISCONNECT task");
+ LOGGER.info("Server: Received DISCONNECT task");
stopServer();
} else if (task.getType() == Task.Type.OTHER) {
- s_logger.info("Server: Received OTHER task");
+ LOGGER.info("Server: Received OTHER task");
}
}