CLOUDSTACK-9348: Unit test to demonstrate denial of service attack

The NioConnection uses blocking handlers for various events such as connect,
accept, read, write. In case a client connects NioServer (used by
agent mgr to service agents on port 8250) but fails to participate in SSL
handshake or just sits idle, this would block the main IO/selector loop in
NioConnection. Such a client could be either malicious or aggresive.

This unit test demonstrates such a malicious client that can perform a
denial-of-service attack on NioServer that blocks it to serve any other client.

Signed-off-by: Rohit Yadav <rohit.yadav@shapeblue.com>
This commit is contained in:
Rohit Yadav 2016-04-15 00:24:53 +05:30
parent e762e27054
commit 0154da6417
2 changed files with 153 additions and 97 deletions

View File

@ -208,7 +208,6 @@
<excludes> <excludes>
<exclude>com/cloud/utils/testcase/*TestCase*</exclude> <exclude>com/cloud/utils/testcase/*TestCase*</exclude>
<exclude>com/cloud/utils/db/*Test*</exclude> <exclude>com/cloud/utils/db/*Test*</exclude>
<exclude>com/cloud/utils/testcase/NioTest.java</exclude>
</excludes> </excludes>
</configuration> </configuration>
</plugin> </plugin>

View File

@ -19,14 +19,7 @@
package com.cloud.utils.testcase; package com.cloud.utils.testcase;
import java.nio.channels.ClosedChannelException; import com.cloud.utils.concurrency.NamedThreadFactory;
import java.util.Random;
import junit.framework.TestCase;
import org.apache.log4j.Logger;
import org.junit.Assert;
import com.cloud.utils.exception.NioConnectionException; import com.cloud.utils.exception.NioConnectionException;
import com.cloud.utils.nio.HandlerFactory; import com.cloud.utils.nio.HandlerFactory;
import com.cloud.utils.nio.Link; import com.cloud.utils.nio.Link;
@ -34,131 +27,190 @@ import com.cloud.utils.nio.NioClient;
import com.cloud.utils.nio.NioServer; import com.cloud.utils.nio.NioServer;
import com.cloud.utils.nio.Task; import com.cloud.utils.nio.Task;
import com.cloud.utils.nio.Task.Type; import com.cloud.utils.nio.Task.Type;
import org.apache.log4j.Logger;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
/** import java.io.IOException;
* import java.net.InetSocketAddress;
* import java.nio.channels.ClosedChannelException;
* import java.nio.channels.Selector;
* import java.nio.channels.SocketChannel;
*/ import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class NioTest extends TestCase { public class NioTest {
private static final Logger s_logger = Logger.getLogger(NioTest.class); private static final Logger LOGGER = Logger.getLogger(NioTest.class);
private NioServer _server; final private int totalTestCount = 10;
private NioClient _client; private int completedTestCount = 0;
private Link _clientLink; private NioServer server;
private List<NioClient> clients = new ArrayList<>();
private List<NioClient> maliciousClients = new ArrayList<>();
private int _testCount; private ExecutorService clientExecutor = Executors.newFixedThreadPool(totalTestCount, new NamedThreadFactory("NioClientHandler"));;
private int _completedCount; private ExecutorService maliciousExecutor = Executors.newFixedThreadPool(5*totalTestCount, new NamedThreadFactory("MaliciousNioClientHandler"));;
private Random randomGenerator = new Random();
private byte[] testBytes;
private boolean isTestsDone() { private boolean isTestsDone() {
boolean result; boolean result;
synchronized (this) { synchronized (this) {
result = _testCount == _completedCount; result = totalTestCount == completedTestCount;
} }
return result; return result;
} }
private void getOneMoreTest() {
synchronized (this) {
_testCount++;
}
}
private void oneMoreTestDone() { private void oneMoreTestDone() {
synchronized (this) { synchronized (this) {
_completedCount++; completedTestCount++;
} }
} }
@Override @Before
public void setUp() { public void setUp() {
s_logger.info("Test"); LOGGER.info("Setting up Benchmark Test");
_testCount = 0; completedTestCount = 0;
_completedCount = 0; testBytes = new byte[1000000];
randomGenerator.nextBytes(testBytes);
_server = new NioServer("NioTestServer", 7777, 5, new NioTestServer()); // Server configured with one worker
server = new NioServer("NioTestServer", 7777, 1, new NioTestServer());
try { try {
_server.start(); server.start();
} catch (final NioConnectionException e) { } catch (final NioConnectionException e) {
fail(e.getMessage()); Assert.fail(e.getMessage());
} }
_client = new NioClient("NioTestServer", "127.0.0.1", 7777, 5, new NioTestClient()); // 5 malicious clients per valid client
try { for (int i = 0; i < totalTestCount; i++) {
_client.start(); for (int j = 0; j < 5; j++) {
} catch (final NioConnectionException e) { final NioClient maliciousClient = new NioMaliciousClient("NioMaliciousTestClient-" + i, "127.0.0.1", 7777, 1, new NioMaliciousTestClient());
fail(e.getMessage()); maliciousClients.add(maliciousClient);
} maliciousExecutor.submit(new ThreadedNioClient(maliciousClient));
while (_clientLink == null) {
try {
s_logger.debug("Link is not up! Waiting ...");
Thread.sleep(1000);
} catch (final InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} }
final NioClient client = new NioClient("NioTestClient-" + i, "127.0.0.1", 7777, 1, new NioTestClient());
clients.add(client);
clientExecutor.submit(new ThreadedNioClient(client));
} }
} }
@Override @After
public void tearDown() { public void tearDown() {
while (!isTestsDone()) {
try {
s_logger.debug(_completedCount + "/" + _testCount + " tests done. Waiting for completion");
Thread.sleep(1000);
} catch (final InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
stopClient(); stopClient();
stopServer(); stopServer();
} }
protected void stopClient() { protected void stopClient() {
_client.stop(); for (NioClient client : clients) {
s_logger.info("Client stopped."); client.stop();
}
for (NioClient maliciousClient : maliciousClients) {
maliciousClient.stop();
}
LOGGER.info("Clients stopped.");
} }
protected void stopServer() { protected void stopServer() {
_server.stop(); server.stop();
s_logger.info("Server stopped."); LOGGER.info("Server stopped.");
} }
protected void setClientLink(final Link link) { @Test
_clientLink = link;
}
Random randomGenerator = new Random();
byte[] _testBytes;
public void testConnection() { public void testConnection() {
_testBytes = new byte[1000000]; final long currentTime = System.currentTimeMillis();
randomGenerator.nextBytes(_testBytes); while (!isTestsDone()) {
try { if (System.currentTimeMillis() - currentTime > 600000) {
getOneMoreTest(); Assert.fail("Failed to complete test within 600s");
_clientLink.send(_testBytes); }
s_logger.info("Client: Data sent"); try {
getOneMoreTest(); LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done. Waiting for completion");
_clientLink.send(_testBytes); Thread.sleep(1000);
s_logger.info("Client: Data sent"); } catch (final InterruptedException e) {
} catch (final ClosedChannelException e) { Assert.fail(e.getMessage());
// TODO Auto-generated catch block }
e.printStackTrace();
} }
LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done.");
} }
protected void doServerProcess(final byte[] data) { protected void doServerProcess(final byte[] data) {
oneMoreTestDone(); oneMoreTestDone();
Assert.assertArrayEquals(_testBytes, data); Assert.assertArrayEquals(testBytes, data);
s_logger.info("Verify done."); LOGGER.info("Verify data received by server done.");
}
public byte[] getTestBytes() {
return testBytes;
}
public class ThreadedNioClient implements Runnable {
final private NioClient client;
ThreadedNioClient(final NioClient client) {
this.client = client;
}
@Override
public void run() {
try {
client.start();
} catch (NioConnectionException e) {
Assert.fail(e.getMessage());
}
}
}
public class NioMaliciousClient extends NioClient {
public NioMaliciousClient(String name, String host, int port, int workers, HandlerFactory factory) {
super(name, host, port, workers, factory);
}
@Override
protected void init() throws IOException {
_selector = Selector.open();
try {
_clientConnection = SocketChannel.open();
LOGGER.info("Connecting to " + _host + ":" + _port);
final InetSocketAddress peerAddr = new InetSocketAddress(_host, _port);
_clientConnection.connect(peerAddr);
// Hang in there don't do anything
Thread.sleep(3600000);
} catch (final IOException e) {
_selector.close();
throw e;
} catch (InterruptedException e) {
LOGGER.debug(e.getMessage());
}
}
}
public class NioMaliciousTestClient implements HandlerFactory {
@Override
public Task create(final Type type, final Link link, final byte[] data) {
return new NioMaliciousTestClientHandler(type, link, data);
}
public class NioMaliciousTestClientHandler extends Task {
public NioMaliciousTestClientHandler(final Type type, final Link link, final byte[] data) {
super(type, link, data);
}
@Override
public void doTask(final Task task) {
LOGGER.info("Malicious Client: Received task " + task.getType().toString());
}
}
} }
public class NioTestClient implements HandlerFactory { public class NioTestClient implements HandlerFactory {
@ -177,18 +229,23 @@ public class NioTest extends TestCase {
@Override @Override
public void doTask(final Task task) { public void doTask(final Task task) {
if (task.getType() == Task.Type.CONNECT) { if (task.getType() == Task.Type.CONNECT) {
s_logger.info("Client: Received CONNECT task"); LOGGER.info("Client: Received CONNECT task");
setClientLink(task.getLink()); try {
LOGGER.info("Sending data to server");
task.getLink().send(getTestBytes());
} catch (ClosedChannelException e) {
LOGGER.error(e.getMessage());
e.printStackTrace();
}
} else if (task.getType() == Task.Type.DATA) { } else if (task.getType() == Task.Type.DATA) {
s_logger.info("Client: Received DATA task"); LOGGER.info("Client: Received DATA task");
} else if (task.getType() == Task.Type.DISCONNECT) { } else if (task.getType() == Task.Type.DISCONNECT) {
s_logger.info("Client: Received DISCONNECT task"); LOGGER.info("Client: Received DISCONNECT task");
stopClient(); stopClient();
} else if (task.getType() == Task.Type.OTHER) { } else if (task.getType() == Task.Type.OTHER) {
s_logger.info("Client: Received OTHER task"); LOGGER.info("Client: Received OTHER task");
} }
} }
} }
} }
@ -208,15 +265,15 @@ public class NioTest extends TestCase {
@Override @Override
public void doTask(final Task task) { public void doTask(final Task task) {
if (task.getType() == Task.Type.CONNECT) { if (task.getType() == Task.Type.CONNECT) {
s_logger.info("Server: Received CONNECT task"); LOGGER.info("Server: Received CONNECT task");
} else if (task.getType() == Task.Type.DATA) { } else if (task.getType() == Task.Type.DATA) {
s_logger.info("Server: Received DATA task"); LOGGER.info("Server: Received DATA task");
doServerProcess(task.getData()); doServerProcess(task.getData());
} else if (task.getType() == Task.Type.DISCONNECT) { } else if (task.getType() == Task.Type.DISCONNECT) {
s_logger.info("Server: Received DISCONNECT task"); LOGGER.info("Server: Received DISCONNECT task");
stopServer(); stopServer();
} else if (task.getType() == Task.Type.OTHER) { } else if (task.getType() == Task.Type.OTHER) {
s_logger.info("Server: Received OTHER task"); LOGGER.info("Server: Received OTHER task");
} }
} }