bug 10135: Add SSL encryption for non-NIO link

Now Link.write() support SSL encryption. And since there is no user of
Link.read(), we comment it out.
This commit is contained in:
Sheng Yang 2011-06-04 21:27:18 -07:00
parent d9e0bcfa1e
commit ff86c865e2
5 changed files with 219 additions and 198 deletions

View File

@ -10,6 +10,8 @@ import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SocketChannel;
import javax.net.ssl.SSLEngine;
import org.apache.log4j.Logger;
import com.cloud.agent.AgentManager;
@ -126,6 +128,11 @@ public class ClusteredAgentAttache extends ConnectedAgentAttache implements Rout
}
continue;
}
SSLEngine sslEngine = s_clusteredAgentMgr.getSSLEngine(peerName);
if (sslEngine == null) {
throw new AgentUnavailableException("Unable to get SSLEngine of peer " + peerName, _id);
}
try {
if (s_logger.isDebugEnabled()) {
@ -135,7 +142,7 @@ public class ClusteredAgentAttache extends ConnectedAgentAttache implements Rout
SynchronousListener synchronous = (SynchronousListener)listener;
synchronous.setPeer(peerName);
}
Link.write(ch, req.toBytes());
Link.write(ch, req.toBytes(), sslEngine);
error = false;
return;
} catch (IOException e) {

View File

@ -27,6 +27,8 @@ import java.util.concurrent.TimeUnit;
import javax.ejb.Local;
import javax.naming.ConfigurationException;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import org.apache.log4j.Logger;
@ -89,6 +91,7 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
protected ClusterManager _clusterMgr = null;
protected HashMap<String, SocketChannel> _peers;
protected HashMap<String, SSLEngine> _sslEngines;
private final Timer _timer = new Timer("ClusteredAgentManager Timer");
@Inject
@ -106,6 +109,7 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
@Override
public boolean configure(String name, Map<String, Object> xmlParams) throws ConfigurationException {
_peers = new HashMap<String, SocketChannel>(7);
_sslEngines = new HashMap<String, SSLEngine>(7);
_nodeId = _clusterMgr.getManagementNodeId();
ConfigurationDao configDao = ComponentLocator.getCurrentLocator().getDao(ConfigurationDao.class);
@ -406,6 +410,7 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
public boolean routeToPeer(String peer, byte[] bytes) {
int i = 0;
SocketChannel ch = null;
SSLEngine sslEngine = null;
while (i++ < 5) {
ch = connectToPeer(peer, ch);
if (ch == null) {
@ -415,11 +420,16 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
}
return false;
}
sslEngine = getSSLEngine(peer);
if (sslEngine == null) {
logD(bytes, "Unable to get SSLEngine of peer: " + peer);
return false;
}
try {
if (s_logger.isDebugEnabled()) {
logD(bytes, "Routing to peer");
}
Link.write(ch, new ByteBuffer[] { ByteBuffer.wrap(bytes) });
Link.write(ch, new ByteBuffer[] { ByteBuffer.wrap(bytes) }, sslEngine);
return true;
} catch (IOException e) {
try {
@ -434,6 +444,10 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
public String findPeer(long hostId) {
return _clusterMgr.getPeerName(hostId);
}
public SSLEngine getSSLEngine(String peerName) {
return _sslEngines.get(peerName);
}
public void cancel(String peerName, long hostId, long sequence, String reason) {
CancelCommand cancel = new CancelCommand(sequence, reason);
@ -453,12 +467,14 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
}
}
_peers.remove(peerName);
_sslEngines.remove(peerName);
}
}
public SocketChannel connectToPeer(String peerName, SocketChannel prevCh) {
synchronized (_peers) {
SocketChannel ch = _peers.get(peerName);
SSLEngine sslEngine = null;
if (prevCh != null) {
try {
prevCh.close();
@ -483,10 +499,21 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
ch.configureBlocking(true); // make sure we are working at blocking mode
ch.socket().setKeepAlive(true);
ch.socket().setSoTimeout(60 * 1000);
try {
SSLContext sslContext = Link.initSSLContext(true);
sslEngine = sslContext.createSSLEngine(ip, _port);
sslEngine.setUseClientMode(true);
Link.doHandshake(ch, sslEngine, true);
s_logger.info("SSL: Handshake done");
} catch (Exception e) {
throw new IOException("SSL: Fail to init SSL! " + e);
}
if (s_logger.isDebugEnabled()) {
s_logger.debug("Connection to peer opened: " + peerName + ", ip: " + ip);
}
_peers.put(peerName, ch);
_sslEngines.put(peerName, sslEngine);
} catch (IOException e) {
s_logger.warn("Unable to connect to peer management server: " + peerName + ", ip: " + ip + " due to " + e.getMessage(), e);
return null;

View File

@ -18,6 +18,8 @@
package com.cloud.utils.nio;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
@ -26,11 +28,16 @@ import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel;
import java.security.KeyStore;
import java.util.concurrent.ConcurrentLinkedQueue;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import org.apache.log4j.Logger;
@ -82,13 +89,14 @@ public class Link {
}
/**
* No user, so comment it out.
*
* Static methods for reading from a channel in case
* you need to add a client that doesn't require nio.
* @param ch channel to read from.
* @param bytebuffer to use.
* @return bytes read
* @throws IOException if not read to completion.
*/
public static byte[] read(SocketChannel ch, ByteBuffer buff) throws IOException {
synchronized(buff) {
buff.clear();
@ -121,7 +129,44 @@ public class Link {
return output.toByteArray();
}
}
*/
private static void doWrite(SocketChannel ch, ByteBuffer[] buffers, SSLEngine sslEngine) throws IOException {
ByteBuffer pkgBuf;
SSLSession sslSession = sslEngine.getSession();
SSLEngineResult engResult;
ByteBuffer headBuf = ByteBuffer.allocate(4);
pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
engResult = sslEngine.wrap(buffers, pkgBuf);
if (engResult.getHandshakeStatus() != HandshakeStatus.FINISHED &&
engResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING &&
engResult.getStatus() != SSLEngineResult.Status.OK) {
throw new IOException("SSL: SSLEngine return bad result! " + engResult);
}
int dataRemaining = pkgBuf.position();
int headRemaining = 4;
pkgBuf.flip();
headBuf.putInt(dataRemaining);
headBuf.flip();
while (headRemaining > 0) {
if (s_logger.isTraceEnabled()) {
s_logger.trace("Writing Header " + headRemaining);
}
long count = ch.write(headBuf);
headRemaining -= count;
}
while (dataRemaining > 0) {
if (s_logger.isTraceEnabled()) {
s_logger.trace("Writing Data " + dataRemaining);
}
long count = ch.write(pkgBuf);
dataRemaining -= count;
}
}
/**
* write method to write to a socket. This method writes to completion so
@ -132,26 +177,10 @@ public class Link {
* @param buffers buffers to write.
* @throws IOException if unable to write to completion.
*/
public static void write(SocketChannel ch, ByteBuffer[] buffers) throws IOException {
public static void write(SocketChannel ch, ByteBuffer[] buffers, SSLEngine sslEngine) throws IOException {
synchronized(ch) {
int length = 0;
ByteBuffer[] buff = new ByteBuffer[buffers.length + 1];
for (int i = 0; i < buffers.length; i++) {
length += buffers[i].remaining();
buff[i + 1] = buffers[i];
}
buff[0] = ByteBuffer.allocate(4);
buff[0].putInt(length);
buff[0].flip();
long count = 0;
while (count < length + 4) {
long written = ch.write(buff);
if (written < 0) {
throw new IOException("Unable to write after " + count);
}
count += written;
}
}
doWrite(ch, buffers, sslEngine);
}
}
public byte[] read(SocketChannel ch) throws IOException {
@ -285,42 +314,10 @@ public class Link {
return true;
}
ByteBuffer pkgBuf;
SSLSession sslSession = _sslEngine.getSession();
SSLEngineResult engResult;
ByteBuffer headBuf = ByteBuffer.allocate(4);
ByteBuffer[] raw_data = new ByteBuffer[data.length - 1];
System.arraycopy(data, 1, raw_data, 0, data.length - 1);
pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
engResult = _sslEngine.wrap(raw_data, pkgBuf);
if (engResult.getHandshakeStatus() != HandshakeStatus.FINISHED &&
engResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING &&
engResult.getStatus() != SSLEngineResult.Status.OK) {
throw new IOException("SSL: SSLEngine return bad result! " + engResult);
}
int dataRemaining = pkgBuf.position();
int headRemaining = 4;
pkgBuf.flip();
headBuf.putInt(dataRemaining);
headBuf.flip();
while (headRemaining > 0) {
if (s_logger.isTraceEnabled()) {
s_logger.trace("Writing Header " + headRemaining);
}
long count = ch.write(headBuf);
headRemaining -= count;
}
while (dataRemaining > 0) {
if (s_logger.isTraceEnabled()) {
s_logger.trace("Writing Data " + dataRemaining);
}
long count = ch.write(pkgBuf);
dataRemaining -= count;
}
doWrite(ch, raw_data, _sslEngine);
}
return false;
}
@ -343,4 +340,132 @@ public class Link {
}
_connection.scheduleTask(task);
}
public static SSLContext initSSLContext(boolean isClient) throws Exception {
SSLContext sslContext = null;
KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
KeyStore ks = KeyStore.getInstance("JKS");
TrustManager[] tms;
if (!isClient) {
char[] passphrase = "vmops.com".toCharArray();
String keystorePath = "/etc/cloud/management/cloud.keystore";
if (new File(keystorePath).exists()) {
ks.load(new FileInputStream(keystorePath), passphrase);
} else {
s_logger.warn("SSL: Fail to find the generated keystore. Loading fail-safe one to continue.");
ks.load(NioConnection.class.getResourceAsStream("/cloud.keystore"), passphrase);
}
kmf.init(ks, passphrase);
tmf.init(ks);
tms = tmf.getTrustManagers();
} else {
ks.load(null, null);
kmf.init(ks, null);
tms = new TrustManager[1];
tms[0] = new TrustAllManager();
}
sslContext = SSLContext.getInstance("TLS");
sslContext.init(kmf.getKeyManagers(), tms, null);
s_logger.info("SSL: SSLcontext has been initialized");
return sslContext;
}
public static void doHandshake(SocketChannel ch, SSLEngine sslEngine,
boolean isClient) throws IOException {
s_logger.info("SSL: begin Handshake, isClient: " + isClient);
SSLEngineResult engResult;
SSLSession sslSession = sslEngine.getSession();
HandshakeStatus hsStatus;
ByteBuffer in_pkgBuf =
ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
ByteBuffer in_appBuf =
ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
ByteBuffer out_pkgBuf =
ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
ByteBuffer out_appBuf =
ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
int count;
if (isClient) {
hsStatus = SSLEngineResult.HandshakeStatus.NEED_WRAP;
} else {
hsStatus = SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
}
while (hsStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
if (s_logger.isTraceEnabled()) {
s_logger.info("SSL: Handshake status " + hsStatus);
}
engResult = null;
if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
out_pkgBuf.clear();
out_appBuf.clear();
out_appBuf.put("Hello".getBytes());
engResult = sslEngine.wrap(out_appBuf, out_pkgBuf);
out_pkgBuf.flip();
int remain = out_pkgBuf.limit();
while (remain != 0) {
remain -= ch.write(out_pkgBuf);
if (remain < 0) {
throw new IOException("Too much bytes sent?");
}
}
} else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
in_appBuf.clear();
// One packet may contained multiply operation
if (in_pkgBuf.position() == 0 || !in_pkgBuf.hasRemaining()) {
in_pkgBuf.clear();
count = ch.read(in_pkgBuf);
if (count == -1) {
throw new IOException("Connection closed with -1 on reading size.");
}
in_pkgBuf.flip();
}
engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
ByteBuffer tmp_pkgBuf =
ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
while (engResult.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
// We need more packets to complete this operation
if (s_logger.isTraceEnabled()) {
s_logger.info("SSL: Buffer overflowed, getting more packets");
}
tmp_pkgBuf.clear();
count = ch.read(tmp_pkgBuf);
tmp_pkgBuf.flip();
in_pkgBuf.mark();
in_pkgBuf.position(in_pkgBuf.limit());
in_pkgBuf.limit(in_pkgBuf.limit() + tmp_pkgBuf.limit());
in_pkgBuf.put(tmp_pkgBuf);
in_pkgBuf.reset();
in_appBuf.clear();
engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
}
} else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK) {
Runnable run;
while ((run = sslEngine.getDelegatedTask()) != null) {
if (s_logger.isTraceEnabled()) {
s_logger.info("SSL: Running delegated task!");
}
run.run();
}
} else if (hsStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
throw new IOException("NOT a handshaking!");
}
if (engResult != null && engResult.getStatus() != SSLEngineResult.Status.OK) {
throw new IOException("Fail to handshake! " + engResult.getStatus());
}
if (engResult != null)
hsStatus = engResult.getHandshakeStatus();
else
hsStatus = sslEngine.getHandshakeStatus();
}
}
}

View File

@ -71,11 +71,11 @@ public class NioClient extends NioConnection {
// Begin SSL handshake in BLOCKING mode
sch.configureBlocking(true);
SSLContext sslContext = initSSLContext(true);
SSLContext sslContext = Link.initSSLContext(true);
sslEngine = sslContext.createSSLEngine(_host, _port);
sslEngine.setUseClientMode(true);
doHandshake(sch, sslEngine, true);
Link.doHandshake(sch, sslEngine, true);
s_logger.info("SSL: Handshake done");
} catch (Exception e) {
throw new IOException("SSL: Fail to init SSL! " + e);

View File

@ -17,20 +17,16 @@
*/
package com.cloud.utils.nio;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.KeyStore;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
@ -40,19 +36,12 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import org.apache.log4j.Logger;
import com.cloud.utils.concurrency.NamedThreadFactory;
import com.cloud.utils.nio.TrustAllManager;
/**
* NioConnection abstracts the NIO socket operations. The Java implementation
@ -184,133 +173,6 @@ public abstract class NioConnection implements Runnable {
abstract void registerLink(InetSocketAddress saddr, Link link);
abstract void unregisterLink(InetSocketAddress saddr);
protected SSLContext initSSLContext(boolean isClient) throws Exception {
SSLContext sslContext = null;
KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
KeyStore ks = KeyStore.getInstance("JKS");
TrustManager[] tms;
if (!isClient) {
char[] passphrase = "vmops.com".toCharArray();
String keystorePath = "/etc/cloud/management/cloud.keystore";
if (new File(keystorePath).exists()) {
ks.load(new FileInputStream(keystorePath), passphrase);
} else {
s_logger.warn("SSL: Fail to find the generated keystore. Loading fail-safe one to continue.");
ks.load(NioConnection.class.getResourceAsStream("/cloud.keystore"), passphrase);
}
kmf.init(ks, passphrase);
tmf.init(ks);
tms = tmf.getTrustManagers();
} else {
ks.load(null, null);
kmf.init(ks, null);
tms = new TrustManager[1];
tms[0] = new TrustAllManager();
}
sslContext = SSLContext.getInstance("TLS");
sslContext.init(kmf.getKeyManagers(), tms, null);
s_logger.info("SSL: SSLcontext has been initialized");
return sslContext;
}
protected void doHandshake(SocketChannel ch, SSLEngine sslEngine,
boolean isClient) throws IOException {
s_logger.info("SSL: begin Handshake, isClient: " + isClient);
SSLEngineResult engResult;
SSLSession sslSession = sslEngine.getSession();
HandshakeStatus hsStatus;
ByteBuffer in_pkgBuf =
ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
ByteBuffer in_appBuf =
ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
ByteBuffer out_pkgBuf =
ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
ByteBuffer out_appBuf =
ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
int count;
if (isClient) {
hsStatus = SSLEngineResult.HandshakeStatus.NEED_WRAP;
} else {
hsStatus = SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
}
while (hsStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
if (s_logger.isTraceEnabled()) {
s_logger.info("SSL: Handshake status " + hsStatus);
}
engResult = null;
if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
out_pkgBuf.clear();
out_appBuf.clear();
out_appBuf.put("Hello".getBytes());
engResult = sslEngine.wrap(out_appBuf, out_pkgBuf);
out_pkgBuf.flip();
int remain = out_pkgBuf.limit();
while (remain != 0) {
remain -= ch.write(out_pkgBuf);
if (remain < 0) {
throw new IOException("Too much bytes sent?");
}
}
} else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
in_appBuf.clear();
// One packet may contained multiply operation
if (in_pkgBuf.position() == 0 || !in_pkgBuf.hasRemaining()) {
in_pkgBuf.clear();
count = ch.read(in_pkgBuf);
if (count == -1) {
throw new IOException("Connection closed with -1 on reading size.");
}
in_pkgBuf.flip();
}
engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
ByteBuffer tmp_pkgBuf =
ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
while (engResult.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
// We need more packets to complete this operation
if (s_logger.isTraceEnabled()) {
s_logger.info("SSL: Buffer overflowed, getting more packets");
}
tmp_pkgBuf.clear();
count = ch.read(tmp_pkgBuf);
tmp_pkgBuf.flip();
in_pkgBuf.mark();
in_pkgBuf.position(in_pkgBuf.limit());
in_pkgBuf.limit(in_pkgBuf.limit() + tmp_pkgBuf.limit());
in_pkgBuf.put(tmp_pkgBuf);
in_pkgBuf.reset();
in_appBuf.clear();
engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
}
} else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK) {
Runnable run;
while ((run = sslEngine.getDelegatedTask()) != null) {
if (s_logger.isTraceEnabled()) {
s_logger.info("SSL: Running delegated task!");
}
run.run();
}
} else if (hsStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
throw new IOException("NOT a handshaking!");
}
if (engResult != null && engResult.getStatus() != SSLEngineResult.Status.OK) {
throw new IOException("Fail to handshake! " + engResult.getStatus());
}
if (engResult != null)
hsStatus = engResult.getHandshakeStatus();
else
hsStatus = sslEngine.getHandshakeStatus();
}
}
protected void accept(SelectionKey key) throws IOException {
ServerSocketChannel serverSocketChannel = (ServerSocketChannel)key.channel();
@ -327,12 +189,12 @@ public abstract class NioConnection implements Runnable {
SSLEngine sslEngine = null;
try {
SSLContext sslContext = initSSLContext(false);
SSLContext sslContext = Link.initSSLContext(false);
sslEngine = sslContext.createSSLEngine();
sslEngine.setUseClientMode(false);
sslEngine.setNeedClientAuth(false);
doHandshake(socketChannel, sslEngine, false);
Link.doHandshake(socketChannel, sslEngine, false);
s_logger.info("SSL: Handshake done");
} catch (Exception e) {
throw new IOException("SSL: Fail to init SSL! " + e);