diff --git a/devops/driver/libvirt/libvirt_driver.py b/devops/driver/libvirt/libvirt_driver.py index 1cdbb842..abb12c8f 100644 --- a/devops/driver/libvirt/libvirt_driver.py +++ b/devops/driver/libvirt/libvirt_driver.py @@ -967,6 +967,7 @@ class LibvirtNode(Node): def destroy(self, *args, **kwargs): if self.is_active(): self._libvirt_node.destroy() + super(LibvirtNode, self).destroy() @retry() def remove(self, *args, **kwargs): @@ -988,6 +989,7 @@ class LibvirtNode(Node): def suspend(self, *args, **kwargs): if self.is_active(): self._libvirt_node.suspend() + super(LibvirtNode, self).suspend() @retry() def resume(self, *args, **kwargs): @@ -1003,6 +1005,7 @@ class LibvirtNode(Node): :rtype : None """ self._libvirt_node.reboot() + super(LibvirtNode, self).reboot() @retry() def shutdown(self): @@ -1011,10 +1014,12 @@ class LibvirtNode(Node): :rtype : None """ self._libvirt_node.shutdown() + super(LibvirtNode, self).shutdown() @retry() def reset(self): self._libvirt_node.reset() + super(LibvirtNode, self).reset() @retry() def has_snapshot(self, name): diff --git a/devops/helpers/ntp.py b/devops/helpers/ntp.py index 6e8ff79c..ce6b3040 100644 --- a/devops/helpers/ntp.py +++ b/devops/helpers/ntp.py @@ -307,12 +307,7 @@ class GroupNtpSync(object): return self def __exit__(self, exp_type, exp_value, traceback): - for ntp in self.admin_ntps: - ntp.remote.clear() - for ntp in self.pacemaker_ntps: - ntp.remote.clear() - for ntp in self.other_ntps: - ntp.remote.clear() + pass @staticmethod def report_node_names(ntps): diff --git a/devops/helpers/ssh_client.py b/devops/helpers/ssh_client.py index d977f3dc..9319713e 100644 --- a/devops/helpers/ssh_client.py +++ b/devops/helpers/ssh_client.py @@ -177,7 +177,77 @@ class SSHAuth(object): ) -class SSHClient(object): +class _MemorizedSSH(type): + __cache = {} + + def __call__( + cls, + host, port=22, + username=None, password=None, private_keys=None, + auth=None + ): + """Main memorize method: check for cached instance and return it + + :type host: str + :type port: int + :type username: str + :type password: str + :type private_keys: list + :type auth: SSHAuth + :rtype: SSHClient + """ + if (host, port) in cls.__cache: + key = host, port + if auth is None: + auth = SSHAuth( + username=username, password=password, keys=private_keys) + if hash((cls, host, port, auth)) == hash(cls.__cache[key]): + ssh = cls.__cache[key] + try: + ssh.execute('cd ~') + except (paramiko.SSHException, AttributeError): + logger.debug('Reconnect {}'.format(ssh)) + ssh.reconnect() + return ssh + del cls.__cache[key] + return super( + _MemorizedSSH, cls).__call__( + host=host, port=port, + username=username, password=password, private_keys=private_keys, + auth=auth) + + @classmethod + def record(cls, ssh): + """Record SSH client to cache + + :type ssh: SSHClient + """ + cls.__cache[(ssh.hostname, ssh.port)] = ssh + + @classmethod + def clear_cache(cls): + """Clear cached connections for initialize new instance on next call""" + cls.__cache = {} + + @classmethod + def close_connections(cls, hostname=None): + """Close connections for selected or all cached records + + :type hostname: str + """ + if hostname is None: + keys = [key for key, ssh in cls.__cache.items() if ssh.is_alive] + else: + keys = [ + (host, port) + for (host, port), ssh + in cls.__cache.items() if host == hostname and ssh.is_alive] + # raise ValueError(keys) + for key in keys: + cls.__cache[key].close() + + +class SSHClient(six.with_metaclass(_MemorizedSSH, object)): __slots__ = [ '__hostname', '__port', '__auth', '__ssh', '__sftp', 'sudo_mode' ] @@ -243,6 +313,7 @@ class SSHClient(object): ) self.__connect() + _MemorizedSSH.record(ssh=self) if auth is None: logger.info( '{0}:{1}> SSHAuth was made from old style creds: ' @@ -269,6 +340,15 @@ class SSHClient(object): """ return self.__hostname + @property + def host(self): + """Hostname access for backward compatibility""" + warn( + 'host has been deprecated in favor of hostname', + DeprecationWarning + ) + return self.hostname + @property def port(self): """Connected remote port number @@ -277,6 +357,14 @@ class SSHClient(object): """ return self.__port + @property + def is_alive(self): + """Paramiko status: ready to use|reconnect required + + :rtype: bool + """ + return self.__ssh.get_transport() is not None + def __repr__(self): return '{cls}(host={host}, port={port}, auth={auth!r})'.format( cls=self.__class__.__name__, host=self.hostname, port=self.port, @@ -330,8 +418,8 @@ class SSHClient(object): return self.__sftp raise paramiko.SSHException('SFTP connection failed') - def clear(self): - """Clear SSH and SFTP sessions""" + def close(self): + """Close SSH and SFTP sessions""" try: self.__ssh.close() self.__sftp = None @@ -343,7 +431,37 @@ class SSHClient(object): except Exception: logger.exception("Could not close sftp connection") + @staticmethod + def clear(): + warn( + "clear is removed: use close() only if it mandatory: " + "it's automatically called on revert|shutdown|suspend|destroy", + DeprecationWarning + ) + + @classmethod + def _clear_cache(cls): + """Enforce clear memorized records""" + warn( + '_clear_cache() is dangerous and not recommended for normal use!', + Warning + ) + _MemorizedSSH.clear_cache() + + @classmethod + def close_connections(cls, hostname=None): + """Close cached connections: if hostname is not set, then close all + + :type hostname: str + """ + _MemorizedSSH.close_connections(hostname=hostname) + def __del__(self): + """Destructor helper: close channel and threads BEFORE closing others + + Due to threading in paramiko, default destructor could generate asserts + on close, so we calling channel close before closing main ssh object. + """ self.__ssh.close() self.__sftp = None @@ -351,11 +469,11 @@ class SSHClient(object): return self def __exit__(self, exc_type, exc_val, exc_tb): - self.clear() + pass def reconnect(self): - """Reconnect SSH and SFTP session""" - self.clear() + """Reconnect SSH session""" + self.close() self.__ssh = paramiko.SSHClient() self.__ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) @@ -513,6 +631,7 @@ class SSHClient(object): :rtype: tuple """ logger.debug("Executing command: '{}'".format(command.rstrip())) + chan = self._ssh.get_transport().open_session() stdin = chan.makefile('wb') stdout = chan.makefile('rb') diff --git a/devops/models/node.py b/devops/models/node.py index f2dc2d22..5b19189c 100644 --- a/devops/models/node.py +++ b/devops/models/node.py @@ -180,19 +180,20 @@ class Node(six.with_metaclass(ExtendableNodeType, ParamedModel, BaseModel)): pass def destroy(self, *args, **kwargs): - pass + self._close_remotes() def erase(self, *args, **kwargs): self.remove() def remove(self, *args, **kwargs): + self._close_remotes() self.erase_volumes() for iface in self.interfaces: iface.remove() self.delete() def suspend(self, *args, **kwargs): - pass + self._close_remotes() def resume(self, *args, **kwargs): pass @@ -201,12 +202,21 @@ class Node(six.with_metaclass(ExtendableNodeType, ParamedModel, BaseModel)): pass def revert(self, *args, **kwargs): - pass + self._close_remotes() # for fuel-qa compatibility def has_snapshot(self, *args, **kwargs): return True + def reboot(self): + pass + + def shutdown(self): + self._close_remotes() + + def reset(self): + pass + # for fuel-qa compatibility def get_snapshots(self): """Return full snapshots objects""" @@ -282,6 +292,17 @@ class Node(six.with_metaclass(ExtendableNodeType, ParamedModel, BaseModel)): username=login, password=password, private_keys=private_keys, auth=auth) + def _close_remotes(self): + """Call close cached ssh connections for current node""" + for network_name in {'admin', 'public', 'internal'}: + try: + SSHClient.close_connections( + hostname=self.get_ip_address_by_network_name(network_name)) + except BaseException: + logger.debug( + '{0}._close_remotes for {1} failed'.format( + self.name, network_name)) + def await(self, network_name, timeout=120, by_port=22): wait_pass( lambda: tcp_ping_( diff --git a/devops/tests/helpers/test_ssh_client.py b/devops/tests/helpers/test_ssh_client.py index 0d05e296..ec3df449 100644 --- a/devops/tests/helpers/test_ssh_client.py +++ b/devops/tests/helpers/test_ssh_client.py @@ -53,6 +53,9 @@ command = 'ls ~ ' class TestSSHAuth(TestCase): + def tearDown(self): + SSHClient._clear_cache() + def init_checks(self, username=None, password=None, key=None, keys=None): """shared positive init checks @@ -163,6 +166,9 @@ class TestSSHAuth(TestCase): 'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy') @mock.patch('paramiko.SSHClient', autospec=True) class TestSSHClientInit(TestCase): + def tearDown(self): + SSHClient._clear_cache() + def init_checks( self, client, policy, logger, @@ -498,7 +504,7 @@ class TestSSHClientInit(TestCase): logger.reset_mock() - ssh.clear() + ssh.close() logger.assert_has_calls(( mock.call.exception('Could not close ssh connection'), @@ -756,12 +762,92 @@ class TestSSHClientInit(TestCase): mock.call.debug('SFTP is not connected, try to connect...'), )) + def test_init_memorize(self, client, policy, logger, sleep): + port1 = 2222 + host1 = '127.0.0.2' + + # 1. Normal init + ssh01 = SSHClient(host=host) + ssh02 = SSHClient(host=host) + ssh11 = SSHClient(host=host, port=port1) + ssh12 = SSHClient(host=host, port=port1) + ssh21 = SSHClient(host=host1) + ssh22 = SSHClient(host=host1) + + self.assertTrue(ssh01 is ssh02) + self.assertTrue(ssh11 is ssh12) + self.assertTrue(ssh21 is ssh22) + self.assertFalse(ssh01 is ssh11) + self.assertFalse(ssh01 is ssh21) + self.assertFalse(ssh11 is ssh21) + + # 2. Close connections check + client.reset_mock() + ssh01.close_connections(ssh01.hostname) + client.assert_has_calls(( + mock.call().get_transport(), + mock.call().get_transport(), + mock.call().close(), + mock.call().close(), + )) + client.reset_mock() + ssh01.close_connections() + # Mock returns false-connected state, so we just count close calls + + client.assert_has_calls(( + mock.call().get_transport(), + mock.call().get_transport(), + mock.call().get_transport(), + mock.call().close(), + mock.call().close(), + mock.call().close(), + )) + + # change creds + SSHClient(host=host, auth=SSHAuth(username=username)) + + # Change back: new connection differs from old with the same creds + ssh004 = SSHAuth(host) + self.assertFalse(ssh01 is ssh004) + + @mock.patch( + 'devops.helpers.ssh_client.SSHClient.execute') + def test_init_memorize_reconnect( + self, execute, client, policy, logger, sleep): + execute.side_effect = paramiko.SSHException + SSHClient(host=host) + client.reset_mock() + policy.reset_mock() + logger.reset_mock() + SSHClient(host=host) + client.assert_called_once() + policy.assert_called_once() + + @mock.patch('devops.helpers.ssh_client.warn') + def test_init_clear(self, warn, client, policy, logger, sleep): + ssh01 = SSHClient(host=host, auth=SSHAuth()) + + ssh01.clear() + warn.assert_called_once_with( + "clear is removed: use close() only if it mandatory: " + "it's automatically called on revert|shutdown|suspend|destroy", + DeprecationWarning + ) + + self.assertNotIn( + mock.call.close(), + client.mock_calls + ) + @mock.patch('devops.helpers.ssh_client.logger', autospec=True) @mock.patch( 'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy') @mock.patch('paramiko.SSHClient', autospec=True) class TestExecute(TestCase): + def tearDown(self): + SSHClient._clear_cache() + @staticmethod def get_ssh(): """SSHClient object builder for execution tests @@ -1101,6 +1187,9 @@ class TestExecute(TestCase): @mock.patch('paramiko.SSHClient', autospec=True) @mock.patch('paramiko.Transport', autospec=True) class TestExecuteThrowHost(TestCase): + def tearDown(self): + SSHClient._clear_cache() + @staticmethod def prepare_execute_through_host(transp, client, exit_code): intermediate_channel = mock.Mock() @@ -1237,6 +1326,9 @@ class TestExecuteThrowHost(TestCase): 'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy') @mock.patch('paramiko.SSHClient', autospec=True) class TestSftp(TestCase): + def tearDown(self): + SSHClient._clear_cache() + @staticmethod def prepare_sftp_file_tests(client): _ssh = mock.Mock() diff --git a/tox.ini b/tox.ini index e583e83d..973bfb49 100644 --- a/tox.ini +++ b/tox.ini @@ -28,7 +28,7 @@ deps = -r{toxinidir}/test-requirements.txt commands = py.test --cov-config .coveragerc --cov-report html --cov=devops devops/tests - coverage report --fail-under 73 + coverage report --fail-under 74 [testenv:pep8]