diff --git a/server/src/com/cloud/server/ManagementServerImpl.java b/server/src/com/cloud/server/ManagementServerImpl.java index 781c915b123..564a59584cd 100644 --- a/server/src/com/cloud/server/ManagementServerImpl.java +++ b/server/src/com/cloud/server/ManagementServerImpl.java @@ -3642,9 +3642,9 @@ public class ManagementServerImpl extends ManagerBase implements ManagementServe * @throws InvalidParameterValueException */ private void checkForKeyByPublicKey(final RegisterSSHKeyPairCmd cmd, final Account owner) throws InvalidParameterValueException { - final SSHKeyPairVO existingPair = _sshKeyPairDao.findByPublicKey(owner.getAccountId(), owner.getDomainId(), cmd.getPublicKey()); + final SSHKeyPairVO existingPair = _sshKeyPairDao.findByPublicKey(owner.getAccountId(), owner.getDomainId(), getPublicKeyFromKeyKeyMaterial(cmd.getPublicKey())); if (existingPair != null) { - throw new InvalidParameterValueException("A key pair with name '" + cmd.getPublicKey() + "' already exists for this account."); + throw new InvalidParameterValueException("A key pair with key '" + cmd.getPublicKey() + "' already exists for this account."); } } @@ -3674,7 +3674,7 @@ public class ManagementServerImpl extends ManagerBase implements ManagementServe * @return * @throws InvalidParameterValueException */ - private String getPublicKeyFromKeyKeyMaterial(final String key) throws InvalidParameterValueException { + protected String getPublicKeyFromKeyKeyMaterial(final String key) throws InvalidParameterValueException { final String publicKey = SSHKeysHelper.getPublicKeyFromKeyMaterial(key); if (publicKey == null) { diff --git a/server/test/com/cloud/server/ManagementServerImplTest.java b/server/test/com/cloud/server/ManagementServerImplTest.java index 1e530e63725..ffaff8f68aa 100644 --- a/server/test/com/cloud/server/ManagementServerImplTest.java +++ b/server/test/com/cloud/server/ManagementServerImplTest.java @@ -16,6 +16,7 @@ // under the License. package com.cloud.server; +import com.cloud.user.SSHKeyPair; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; @@ -23,6 +24,10 @@ import org.mockito.Mockito; import org.mockito.Spy; import org.mockito.runners.MockitoJUnitRunner; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.any; + import org.apache.cloudstack.api.command.user.ssh.RegisterSSHKeyPairCmd; import com.cloud.exception.InvalidParameterValueException; @@ -35,33 +40,71 @@ public class ManagementServerImplTest { @Mock RegisterSSHKeyPairCmd regCmd; + @Mock SSHKeyPairVO existingPair; + @Mock Account account; + @Mock SSHKeyPairDao sshKeyPairDao; ManagementServerImpl ms = new ManagementServerImpl(); + + @Mock + SSHKeyPair sshKeyPair; + @Spy ManagementServerImpl spy; @Test(expected = InvalidParameterValueException.class) - public void testExistingPairRegistration() { + public void testDuplicateRegistraitons(){ String accountName = "account"; - String publicKeyString = "very public"; - // setup owner with domainid + String publicKeyString = "ssh-rsa very public"; + String publicKeyMaterial = spy.getPublicKeyFromKeyKeyMaterial(publicKeyString); + Mockito.doReturn(account).when(spy).getCaller(); Mockito.doReturn(account).when(spy).getOwner(regCmd); - // mock _sshKeyPairDao.findByName to return null + Mockito.doNothing().when(spy).checkForKeyByName(regCmd, account); - // mock _sshKeyPairDao.findByPublicKey to return a known object Mockito.doReturn(accountName).when(regCmd).getAccountName(); + Mockito.doReturn(publicKeyString).when(regCmd).getPublicKey(); Mockito.doReturn("name").when(regCmd).getName(); + spy._sshKeyPairDao = sshKeyPairDao; Mockito.doReturn(1L).when(account).getAccountId(); Mockito.doReturn(1L).when(account).getDomainId(); - Mockito.doReturn(existingPair).when(sshKeyPairDao).findByPublicKey(1L, 1L, publicKeyString); + Mockito.doReturn(Mockito.mock(SSHKeyPairVO.class)).when(sshKeyPairDao).persist(any(SSHKeyPairVO.class)); + + when(sshKeyPairDao.findByName(1L, 1L, "name")).thenReturn(null).thenReturn(null); + when(sshKeyPairDao.findByPublicKey(1L, 1L, publicKeyMaterial)).thenReturn(null).thenReturn(existingPair); + + spy.registerSSHKeyPair(regCmd); spy.registerSSHKeyPair(regCmd); } + @Test + public void testSuccess(){ + String accountName = "account"; + String publicKeyString = "ssh-rsa very public"; + String publicKeyMaterial = spy.getPublicKeyFromKeyKeyMaterial(publicKeyString); + + Mockito.doReturn(1L).when(account).getAccountId(); + Mockito.doReturn(1L).when(account).getAccountId(); + spy._sshKeyPairDao = sshKeyPairDao; + + + //Mocking the DAO object functions - NO object found in DB + Mockito.doReturn(Mockito.mock(SSHKeyPairVO.class)).when(sshKeyPairDao).findByPublicKey(1L, 1L,publicKeyMaterial); + Mockito.doReturn(Mockito.mock(SSHKeyPairVO.class)).when(sshKeyPairDao).findByName(1L, 1L, accountName); + Mockito.doReturn(Mockito.mock(SSHKeyPairVO.class)).when(sshKeyPairDao).persist(any(SSHKeyPairVO.class)); + + //Mocking the User Params + Mockito.doReturn(accountName).when(regCmd).getName(); + Mockito.doReturn(publicKeyString).when(regCmd).getPublicKey(); + Mockito.doReturn(account).when(spy).getOwner(regCmd); + + spy.registerSSHKeyPair(regCmd); + Mockito.verify(spy, Mockito.times(3)).getPublicKeyFromKeyKeyMaterial(anyString()); + } } diff --git a/setup/db/db/schema-452to460.sql b/setup/db/db/schema-452to460.sql index 0abd4f80408..74d5d92200e 100644 --- a/setup/db/db/schema-452to460.sql +++ b/setup/db/db/schema-452to460.sql @@ -354,6 +354,10 @@ CREATE VIEW `cloud`.`user_vm_view` AS left join `cloud`.`user_vm_details` `custom_ram_size` ON (((`custom_ram_size`.`vm_id` = `cloud`.`vm_instance`.`id`) and (`custom_ram_size`.`name` = 'memory'))); +---Additional checks to ensure duplicate keys are not registered and remove the previously stored duplicate keys. +DELETE `s1` FROM `ssh_keypairs` `s1`, `ssh_keypairs` `s2` WHERE `s1`.`id` > `s2`.`id` AND `s1`.`public_key` = `s2`.`public_key` AND `s1`.`account_id` = `s2`.`account_id`; +ALTER TABLE `ssh_keypairs` ADD UNIQUE `unique_index`(`fingerprint`,`account_id`); + -- ovm3 stuff INSERT INTO `cloud`.`guest_os_hypervisor` (hypervisor_type, guest_os_name, guest_os_id) VALUES ("Ovm3", 'Sun Solaris 10(32-bit)', 79); INSERT INTO `cloud`.`guest_os_hypervisor` (hypervisor_type, guest_os_name, guest_os_id) VALUES ("Ovm3", 'Sun Solaris 10(64-bit)', 80);