SSHClient: execute and derivatives: implement timeout

Use chan.status_event.wait(timeout) to wait for exit code instead of chan.recv_exit_status(), which waits infinite

Change-Id: Ia313b41aafaa2e1e24236a91548db37840f92568
(cherry picked from commit 2056518)
This commit is contained in:
Alexey Stepanov 2016-07-06 13:07:36 +03:00
parent 828b77bd9c
commit e80d24770d
2 changed files with 183 additions and 24 deletions

View File

@ -26,6 +26,7 @@ import paramiko
import six
from devops.error import DevopsCalledProcessError
from devops.error import TimeoutError
from devops.helpers.retry import retry
from devops import logger
@ -228,8 +229,8 @@ class _MemorizedSSH(type):
if hash((cls, host, port, auth)) == hash(cls.__cache[key]):
ssh = cls.__cache[key]
try:
ssh.execute('cd ~')
except (paramiko.SSHException, AttributeError):
ssh.execute('cd ~', timeout=5)
except (paramiko.SSHException, AttributeError, TimeoutError):
logger.debug('Reconnect {}'.format(ssh))
ssh.reconnect()
return ssh
@ -533,12 +534,13 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
def check_call(
self,
command, verbose=False,
command, verbose=False, timeout=None,
expected=None, raise_on_err=True):
"""Execute command and check for return code
:type command: str
:type verbose: bool
:type timeout: int
:type expected: list
:type raise_on_err: bool
:rtype: dict
@ -546,7 +548,7 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
"""
if expected is None:
expected = [0]
ret = self.execute(command, verbose)
ret = self.execute(command, verbose, timeout)
if ret['exit_code'] not in expected:
message = (
"Command '{cmd}' returned exit code {code} while "
@ -570,16 +572,21 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
stderr=ret['stderr_str'])
return ret
def check_stderr(self, command, verbose=False, raise_on_err=True):
def check_stderr(
self,
command, verbose=False, timeout=None,
raise_on_err=True):
"""Execute command expecting return code 0 and empty STDERR
:type command: str
:type verbose: bool
:type timeout: int
:type raise_on_err: bool
:rtype: dict
:raises: DevopsCalledProcessError
"""
ret = self.check_call(command, verbose, raise_on_err=raise_on_err)
ret = self.check_call(
command, verbose, timeout=timeout, raise_on_err=raise_on_err)
if ret['stderr']:
message = (
"Command '{cmd}' STDERR while not expected\n"
@ -626,19 +633,64 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
if errors and raise_on_err:
raise DevopsCalledProcessError(command, errors)
def execute(self, command, verbose=False):
@classmethod
def __get_channel_exit_status(
cls, command, channel, stdout, stderr, timeout):
"""Get exit status from channel with timeout
:type command: str
:type channel: paramiko.channel.Channel
:type stdout: file
:type stderr: file
:type timeout: int
:rtype: int
:raises: TimeoutError
"""
channel.status_event.wait(timeout)
if channel.status_event.is_set():
return channel.exit_status
else:
stdout_lst = stdout.readlines()
stderr_lst = stderr.readlines()
channel.close()
status_tmpl = (
'Wait for {0} during {1}s: no return code!\n'
'\tSTDOUT:\n'
'{2}\n'
'\tSTDERR"\n'
'{3}')
logger.debug(
status_tmpl.format(
command, timeout,
cls._get_str_from_list(stdout_lst),
cls._get_str_from_list(stderr_lst)
)
)
raise TimeoutError(
status_tmpl.format(
command, timeout,
cls._get_str_from_list(stdout_lst[-5:]), # 5 last lines
cls._get_str_from_list(stderr_lst[-5:]) # 5 last lines
))
def execute(self, command, verbose=False, timeout=None):
"""Execute command and wait for return code
:type command: str
:type verbose: bool
:type timeout: int
:rtype: dict
:raises: TimeoutError
"""
chan, _, stderr, stdout = self.execute_async(command)
# noinspection PyDictCreation
result = {
'exit_code': chan.recv_exit_status()
'exit_code': self.__get_channel_exit_status(
command, chan, stdout, stderr, timeout)
}
result['stdout'] = stdout.readlines()
result['stderr'] = stderr.readlines()
@ -710,14 +762,18 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
hostname,
cmd,
auth=None,
target_port=22):
target_port=22,
timeout=None
):
"""Execute command on remote host through currently connected host
:type hostname: str
:type cmd: str
:type auth: SSHAuth
:type target_port: int
:type timeout: int
:rtype: dict
:raises: TimeoutError
"""
if auth is None:
auth = self.auth
@ -740,14 +796,16 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
channel.exec_command(cmd)
# TODO(astepanov): make a logic for controlling channel state
# noinspection PyDictCreation
result = {}
result['exit_code'] = channel.recv_exit_status()
result = {
'exit_code': self.__get_channel_exit_status(
cmd, channel, stdout, stderr, timeout)
}
result['stdout'] = stdout.readlines()
result['stderr'] = stderr.readlines()
channel.close()
intermediate_channel.close()
result['stdout_str'] = self._get_str_from_list(result['stdout'])
result['stderr_str'] = self._get_str_from_list(result['stderr'])

View File

@ -29,6 +29,7 @@ import paramiko
from six.moves import cStringIO
from devops.error import DevopsCalledProcessError
from devops.error import TimeoutError
from devops.helpers.ssh_client import SSHAuth
from devops.helpers.ssh_client import SSHClient
@ -1061,19 +1062,29 @@ class TestExecute(TestCase):
chan = mock.Mock()
recv_exit_status = mock.Mock(return_value=exit_code)
chan.attach_mock(recv_exit_status, 'recv_exit_status')
wait = mock.Mock()
status_event = mock.Mock()
status_event.attach_mock(wait, 'wait')
chan.attach_mock(status_event, 'status_event')
chan.configure_mock(exit_status=exit_code)
return chan, '', stderr, stdout
@mock.patch(
'devops.helpers.ssh_client.SSHClient.execute_async')
def test_execute(self, execute_async, client, policy, logger):
chan, _stdin, stderr, stdout = self.get_patched_execute_async_retval()
is_set = mock.Mock(return_value=True)
chan.status_event.attach_mock(is_set, 'is_set')
execute_async.return_value = chan, _stdin, stderr, stdout
stderr_lst = stderr.readlines()
stdout_lst = stdout.readlines()
expected = {
'exit_code': chan.recv_exit_status(),
'exit_code': 0,
'stderr': stderr_lst,
'stdout': stdout_lst,
'stderr_str': b''.join(stderr_lst).strip().decode(
@ -1093,7 +1104,8 @@ class TestExecute(TestCase):
)
execute_async.assert_called_once_with(command)
chan.assert_has_calls((
mock.call.recv_exit_status(),
mock.call.status_event.wait(None),
mock.call.status_event.is_set(),
mock.call.close()))
logger.assert_has_calls((
mock.call.info(
@ -1110,6 +1122,82 @@ class TestExecute(TestCase):
)),
))
@mock.patch(
'devops.helpers.ssh_client.SSHClient.execute_async')
def test_execute_timeout(self, execute_async, client, policy, logger):
exit_code = 0
chan, _stdin, stderr, stdout = self.get_patched_execute_async_retval()
is_set = mock.Mock(return_value=True)
chan.status_event.attach_mock(is_set, 'is_set')
execute_async.return_value = chan, _stdin, stderr, stdout
stderr_lst = stderr.readlines()
stdout_lst = stdout.readlines()
expected = {
'exit_code': exit_code,
'stderr': stderr_lst,
'stdout': stdout_lst,
'stderr_str': b''.join(stderr_lst).strip().decode(
encoding='utf-8'),
'stdout_str': b''.join(stdout_lst).strip().decode(
encoding='utf-8')}
ssh = self.get_ssh()
logger.reset_mock()
result = ssh.execute(command=command, verbose=True, timeout=1)
self.assertEqual(
result,
expected
)
execute_async.assert_called_once_with(command)
chan.assert_has_calls((
mock.call.status_event.wait(1),
mock.call.status_event.is_set(),
mock.call.close()))
logger.assert_has_calls((
mock.call.info(
'{cmd} execution results:\n'
'Exit code: {code}\n'
'STDOUT:\n'
'{stdout}\n'
'STDERR:\n'
'{stderr}'.format(
cmd=command,
code=result['exit_code'],
stdout=result['stdout_str'],
stderr=result['stderr_str']
)),
))
@mock.patch(
'devops.helpers.ssh_client.SSHClient.execute_async')
def test_execute_timeout_fail(self, execute_async, client, policy, logger):
chan, _stdin, stderr, stdout = self.get_patched_execute_async_retval()
is_set = mock.Mock(return_value=False)
chan.status_event.attach_mock(is_set, 'is_set')
execute_async.return_value = chan, _stdin, stderr, stdout
ssh = self.get_ssh()
logger.reset_mock()
with self.assertRaises(TimeoutError):
ssh.execute(command=command, verbose=True, timeout=1)
execute_async.assert_called_once_with(command)
chan.assert_has_calls((
mock.call.status_event.wait(1),
mock.call.status_event.is_set(),
mock.call.close()))
@mock.patch(
'devops.helpers.ssh_client.SSHClient.execute_async')
def test_execute_together(self, execute_async, client, policy, logger):
@ -1163,8 +1251,8 @@ class TestExecute(TestCase):
ssh = self.get_ssh()
result = ssh.check_call(command=command, verbose=verbose)
execute.assert_called_once_with(command, verbose)
result = ssh.check_call(command=command, verbose=verbose, timeout=None)
execute.assert_called_once_with(command, verbose, None)
self.assertEqual(result, return_value)
exit_code = 1
@ -1172,8 +1260,8 @@ class TestExecute(TestCase):
execute.reset_mock()
execute.return_value = return_value
with self.assertRaises(DevopsCalledProcessError):
ssh.check_call(command=command, verbose=verbose)
execute.assert_called_once_with(command, verbose)
ssh.check_call(command=command, verbose=verbose, timeout=None)
execute.assert_called_once_with(command, verbose, None)
@mock.patch(
'devops.helpers.ssh_client.SSHClient.check_call')
@ -1192,9 +1280,10 @@ class TestExecute(TestCase):
ssh = self.get_ssh()
result = ssh.check_stderr(
command=command, verbose=verbose, raise_on_err=raise_on_err)
command=command, verbose=verbose, timeout=None,
raise_on_err=raise_on_err)
check_call.assert_called_once_with(
command, verbose, raise_on_err=raise_on_err)
command, verbose, timeout=None, raise_on_err=raise_on_err)
self.assertEqual(result, return_value)
return_value['stderr_str'] = '0\n1'
@ -1204,9 +1293,10 @@ class TestExecute(TestCase):
check_call.return_value = return_value
with self.assertRaises(DevopsCalledProcessError):
ssh.check_stderr(
command=command, verbose=verbose, raise_on_err=raise_on_err)
command=command, verbose=verbose, timeout=None,
raise_on_err=raise_on_err)
check_call.assert_called_once_with(
command, verbose, raise_on_err=raise_on_err)
command, verbose, timeout=None, raise_on_err=raise_on_err)
@mock.patch('devops.helpers.ssh_client.logger', autospec=True)
@ -1252,6 +1342,15 @@ class TestExecuteThrowHost(TestCase):
open_session = mock.Mock(return_value=channel)
transport.attach_mock(open_session, 'open_session')
wait = mock.Mock()
status_event = mock.Mock()
status_event.attach_mock(wait, 'wait')
channel.attach_mock(status_event, 'status_event')
channel.configure_mock(exit_status=exit_code)
is_set = mock.Mock(return_value=True)
channel.status_event.attach_mock(is_set, 'is_set')
return (
open_session, transport, channel, get_transport,
open_channel, intermediate_channel
@ -1296,7 +1395,8 @@ class TestExecuteThrowHost(TestCase):
mock.call.makefile('rb'),
mock.call.makefile_stderr('rb'),
mock.call.exec_command('ls ~ '),
mock.call.recv_exit_status(),
mock.call.status_event.wait(None),
mock.call.status_event.is_set(),
mock.call.close()
))
@ -1344,7 +1444,8 @@ class TestExecuteThrowHost(TestCase):
mock.call.makefile('rb'),
mock.call.makefile_stderr('rb'),
mock.call.exec_command('ls ~ '),
mock.call.recv_exit_status(),
mock.call.status_event.wait(None),
mock.call.status_event.is_set(),
mock.call.close()
))