CLOUDSTACK-9348: Improve Nio SSH handshake buffers

Use a holder class to pass buffers, fixes potential leak.

Signed-off-by: Rohit Yadav <rohit.yadav@shapeblue.com>
This commit is contained in:
Rohit Yadav 2017-11-30 16:51:48 +05:30
parent 893f2af31f
commit d0005d8353
4 changed files with 58 additions and 20 deletions

View File

@ -519,7 +519,7 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
sslEngine.setUseClientMode(true); sslEngine.setUseClientMode(true);
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols())); sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
sslEngine.beginHandshake(); sslEngine.beginHandshake();
if (!Link.doHandshake(ch1, sslEngine, true)) { if (!Link.doHandshake(ch1, sslEngine)) {
ch1.close(); ch1.close();
throw new IOException(String.format("SSL: Handshake failed with peer management server '%s' on %s:%d ", peerName, ip, port)); throw new IOException(String.format("SSL: Handshake failed with peer management server '%s' on %s:%d ", peerName, ip, port));
} }

View File

@ -32,6 +32,8 @@ import java.security.GeneralSecurityException;
import java.security.KeyStore; import java.security.KeyStore;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
@ -462,7 +464,7 @@ public class Link {
return buffer; return buffer;
} }
public static ByteBuffer handleBufferUnderflow(final SSLEngine engine, ByteBuffer buffer) { public static ByteBuffer handleBufferUnderflow(final SSLEngine engine, final ByteBuffer buffer) {
if (engine == null || buffer == null) { if (engine == null || buffer == null) {
return buffer; return buffer;
} }
@ -475,14 +477,14 @@ public class Link {
return replaceBuffer; return replaceBuffer;
} }
private static boolean doHandshakeUnwrap(final SocketChannel socketChannel, final SSLEngine sslEngine, private static HandshakeHolder doHandshakeUnwrap(final SocketChannel socketChannel, final SSLEngine sslEngine,
ByteBuffer peerAppData, ByteBuffer peerNetData, final int appBufferSize) throws IOException { ByteBuffer peerAppData, ByteBuffer peerNetData, final int appBufferSize) throws IOException {
if (socketChannel == null || sslEngine == null || peerAppData == null || peerNetData == null || appBufferSize < 0) { if (socketChannel == null || sslEngine == null || peerAppData == null || peerNetData == null || appBufferSize < 0) {
return false; return new HandshakeHolder(peerAppData, peerNetData, false);
} }
if (socketChannel.read(peerNetData) < 0) { if (socketChannel.read(peerNetData) < 0) {
if (sslEngine.isInboundDone() && sslEngine.isOutboundDone()) { if (sslEngine.isInboundDone() && sslEngine.isOutboundDone()) {
return false; return new HandshakeHolder(peerAppData, peerNetData, false);
} }
try { try {
sslEngine.closeInbound(); sslEngine.closeInbound();
@ -492,7 +494,7 @@ public class Link {
sslEngine.closeOutbound(); sslEngine.closeOutbound();
// After closeOutbound the engine will be set to WRAP state, // After closeOutbound the engine will be set to WRAP state,
// in order to try to send a close message to the client. // in order to try to send a close message to the client.
return true; return new HandshakeHolder(peerAppData, peerNetData, true);
} }
peerNetData.flip(); peerNetData.flip();
SSLEngineResult result = null; SSLEngineResult result = null;
@ -503,7 +505,10 @@ public class Link {
s_logger.error(String.format("SSL error caught during unwrap data: %s, for local address=%s, remote address=%s. The client may have invalid ca-certificates.", s_logger.error(String.format("SSL error caught during unwrap data: %s, for local address=%s, remote address=%s. The client may have invalid ca-certificates.",
sslException.getMessage(), socketChannel.getLocalAddress(), socketChannel.getRemoteAddress())); sslException.getMessage(), socketChannel.getLocalAddress(), socketChannel.getRemoteAddress()));
sslEngine.closeOutbound(); sslEngine.closeOutbound();
return false; return new HandshakeHolder(peerAppData, peerNetData, true);
}
if (result == null) {
return new HandshakeHolder(peerAppData, peerNetData, false);
} }
switch (result.getStatus()) { switch (result.getStatus()) {
case OK: case OK:
@ -519,23 +524,23 @@ public class Link {
break; break;
case CLOSED: case CLOSED:
if (sslEngine.isOutboundDone()) { if (sslEngine.isOutboundDone()) {
return false; return new HandshakeHolder(peerAppData, peerNetData, false);
} else { } else {
sslEngine.closeOutbound(); sslEngine.closeOutbound();
break;
} }
break;
default: default:
throw new IllegalStateException("Invalid SSL status: " + result.getStatus()); throw new IllegalStateException("Invalid SSL status: " + result.getStatus());
} }
return true; return new HandshakeHolder(peerAppData, peerNetData, true);
} }
private static boolean doHandshakeWrap(final SocketChannel socketChannel, final SSLEngine sslEngine, private static HandshakeHolder doHandshakeWrap(final SocketChannel socketChannel, final SSLEngine sslEngine,
ByteBuffer myAppData, ByteBuffer myNetData, ByteBuffer peerNetData, ByteBuffer myAppData, ByteBuffer myNetData, ByteBuffer peerNetData,
final int netBufferSize) throws IOException { final int netBufferSize) throws IOException {
if (socketChannel == null || sslEngine == null || myNetData == null || peerNetData == null if (socketChannel == null || sslEngine == null || myNetData == null || peerNetData == null
|| myAppData == null || netBufferSize < 0) { || myAppData == null || netBufferSize < 0) {
return false; return new HandshakeHolder(myAppData, myNetData, false);
} }
myNetData.clear(); myNetData.clear();
SSLEngineResult result = null; SSLEngineResult result = null;
@ -545,7 +550,10 @@ public class Link {
s_logger.error(String.format("SSL error caught during wrap data: %s, for local address=%s, remote address=%s.", s_logger.error(String.format("SSL error caught during wrap data: %s, for local address=%s, remote address=%s.",
sslException.getMessage(), socketChannel.getLocalAddress(), socketChannel.getRemoteAddress())); sslException.getMessage(), socketChannel.getLocalAddress(), socketChannel.getRemoteAddress()));
sslEngine.closeOutbound(); sslEngine.closeOutbound();
return false; return new HandshakeHolder(myAppData, myNetData, true);
}
if (result == null) {
return new HandshakeHolder(myAppData, myNetData, false);
} }
switch (result.getStatus()) { switch (result.getStatus()) {
case OK : case OK :
@ -579,10 +587,10 @@ public class Link {
default: default:
throw new IllegalStateException("Invalid SSL status: " + result.getStatus()); throw new IllegalStateException("Invalid SSL status: " + result.getStatus());
} }
return true; return new HandshakeHolder(myAppData, myNetData, true);
} }
public static boolean doHandshake(final SocketChannel socketChannel, final SSLEngine sslEngine, final boolean isClient) throws IOException { public static boolean doHandshake(final SocketChannel socketChannel, final SSLEngine sslEngine) throws IOException {
if (socketChannel == null || sslEngine == null) { if (socketChannel == null || sslEngine == null) {
return false; return false;
} }
@ -593,6 +601,7 @@ public class Link {
ByteBuffer myNetData = ByteBuffer.allocate(netBufferSize); ByteBuffer myNetData = ByteBuffer.allocate(netBufferSize);
ByteBuffer peerNetData = ByteBuffer.allocate(netBufferSize); ByteBuffer peerNetData = ByteBuffer.allocate(netBufferSize);
final Executor executor = Executors.newSingleThreadExecutor();
final long startTimeMills = System.currentTimeMillis(); final long startTimeMills = System.currentTimeMillis();
HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
@ -606,12 +615,17 @@ public class Link {
} }
switch (handshakeStatus) { switch (handshakeStatus) {
case NEED_UNWRAP: case NEED_UNWRAP:
if (!doHandshakeUnwrap(socketChannel, sslEngine, peerAppData, peerNetData, appBufferSize)) { final HandshakeHolder unwrapResult = doHandshakeUnwrap(socketChannel, sslEngine, peerAppData, peerNetData, appBufferSize);
peerAppData = unwrapResult.getAppDataBuffer();
peerNetData = unwrapResult.getNetDataBuffer();
if (!unwrapResult.isSuccess()) {
return false; return false;
} }
break; break;
case NEED_WRAP: case NEED_WRAP:
if (!doHandshakeWrap(socketChannel, sslEngine, myAppData, myNetData, peerNetData, netBufferSize)) { final HandshakeHolder wrapResult = doHandshakeWrap(socketChannel, sslEngine, myAppData, myNetData, peerNetData, netBufferSize);
myNetData = wrapResult.getNetDataBuffer();
if (!wrapResult.isSuccess()) {
return false; return false;
} }
break; break;
@ -621,7 +635,7 @@ public class Link {
if (s_logger.isTraceEnabled()) { if (s_logger.isTraceEnabled()) {
s_logger.trace("SSL: Running delegated task!"); s_logger.trace("SSL: Running delegated task!");
} }
task.run(); executor.execute(task);
} }
break; break;
case FINISHED: case FINISHED:
@ -636,4 +650,28 @@ public class Link {
return true; return true;
} }
private static class HandshakeHolder {
private ByteBuffer appData;
private ByteBuffer netData;
private boolean success = true;
HandshakeHolder(ByteBuffer appData, ByteBuffer netData, boolean success) {
this.appData = appData;
this.netData = netData;
this.success = success;
}
ByteBuffer getAppDataBuffer() {
return appData;
}
ByteBuffer getNetDataBuffer() {
return netData;
}
boolean isSuccess() {
return success;
}
}
} }

View File

@ -61,7 +61,7 @@ public class NioClient extends NioConnection {
sslEngine.setUseClientMode(true); sslEngine.setUseClientMode(true);
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols())); sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
sslEngine.beginHandshake(); sslEngine.beginHandshake();
if (!Link.doHandshake(_clientConnection, sslEngine, true)) { if (!Link.doHandshake(_clientConnection, sslEngine)) {
s_logger.error("SSL Handshake failed while connecting to host: " + _host + " port: " + _port); s_logger.error("SSL Handshake failed while connecting to host: " + _host + " port: " + _port);
_selector.close(); _selector.close();
throw new IOException("SSL Handshake failed while connecting to host: " + _host + " port: " + _port); throw new IOException("SSL Handshake failed while connecting to host: " + _host + " port: " + _port);

View File

@ -213,7 +213,7 @@ public abstract class NioConnection implements Callable<Boolean> {
_selector.wakeup(); _selector.wakeup();
try { try {
sslEngine.beginHandshake(); sslEngine.beginHandshake();
if (!Link.doHandshake(socketChannel, sslEngine, false)) { if (!Link.doHandshake(socketChannel, sslEngine)) {
throw new IOException("SSL handshake timed out with " + socketChannel.getRemoteAddress()); throw new IOException("SSL handshake timed out with " + socketChannel.getRemoteAddress());
} }
if (s_logger.isTraceEnabled()) { if (s_logger.isTraceEnabled()) {