SSHClient: execute and derivatives: implement timeout
Use chan.status_event.wait(timeout) to wait for exit code instead of chan.recv_exit_status(), which waits infinite
Change-Id: Ia313b41aafaa2e1e24236a91548db37840f92568
(cherry picked from commit 2056518
)
This commit is contained in:
parent
828b77bd9c
commit
e80d24770d
|
@ -26,6 +26,7 @@ import paramiko
|
|||
import six
|
||||
|
||||
from devops.error import DevopsCalledProcessError
|
||||
from devops.error import TimeoutError
|
||||
from devops.helpers.retry import retry
|
||||
from devops import logger
|
||||
|
||||
|
@ -228,8 +229,8 @@ class _MemorizedSSH(type):
|
|||
if hash((cls, host, port, auth)) == hash(cls.__cache[key]):
|
||||
ssh = cls.__cache[key]
|
||||
try:
|
||||
ssh.execute('cd ~')
|
||||
except (paramiko.SSHException, AttributeError):
|
||||
ssh.execute('cd ~', timeout=5)
|
||||
except (paramiko.SSHException, AttributeError, TimeoutError):
|
||||
logger.debug('Reconnect {}'.format(ssh))
|
||||
ssh.reconnect()
|
||||
return ssh
|
||||
|
@ -533,12 +534,13 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
|
|||
|
||||
def check_call(
|
||||
self,
|
||||
command, verbose=False,
|
||||
command, verbose=False, timeout=None,
|
||||
expected=None, raise_on_err=True):
|
||||
"""Execute command and check for return code
|
||||
|
||||
:type command: str
|
||||
:type verbose: bool
|
||||
:type timeout: int
|
||||
:type expected: list
|
||||
:type raise_on_err: bool
|
||||
:rtype: dict
|
||||
|
@ -546,7 +548,7 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
|
|||
"""
|
||||
if expected is None:
|
||||
expected = [0]
|
||||
ret = self.execute(command, verbose)
|
||||
ret = self.execute(command, verbose, timeout)
|
||||
if ret['exit_code'] not in expected:
|
||||
message = (
|
||||
"Command '{cmd}' returned exit code {code} while "
|
||||
|
@ -570,16 +572,21 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
|
|||
stderr=ret['stderr_str'])
|
||||
return ret
|
||||
|
||||
def check_stderr(self, command, verbose=False, raise_on_err=True):
|
||||
def check_stderr(
|
||||
self,
|
||||
command, verbose=False, timeout=None,
|
||||
raise_on_err=True):
|
||||
"""Execute command expecting return code 0 and empty STDERR
|
||||
|
||||
:type command: str
|
||||
:type verbose: bool
|
||||
:type timeout: int
|
||||
:type raise_on_err: bool
|
||||
:rtype: dict
|
||||
:raises: DevopsCalledProcessError
|
||||
"""
|
||||
ret = self.check_call(command, verbose, raise_on_err=raise_on_err)
|
||||
ret = self.check_call(
|
||||
command, verbose, timeout=timeout, raise_on_err=raise_on_err)
|
||||
if ret['stderr']:
|
||||
message = (
|
||||
"Command '{cmd}' STDERR while not expected\n"
|
||||
|
@ -626,19 +633,64 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
|
|||
if errors and raise_on_err:
|
||||
raise DevopsCalledProcessError(command, errors)
|
||||
|
||||
def execute(self, command, verbose=False):
|
||||
@classmethod
|
||||
def __get_channel_exit_status(
|
||||
cls, command, channel, stdout, stderr, timeout):
|
||||
"""Get exit status from channel with timeout
|
||||
|
||||
:type command: str
|
||||
:type channel: paramiko.channel.Channel
|
||||
:type stdout: file
|
||||
:type stderr: file
|
||||
:type timeout: int
|
||||
:rtype: int
|
||||
:raises: TimeoutError
|
||||
"""
|
||||
channel.status_event.wait(timeout)
|
||||
if channel.status_event.is_set():
|
||||
return channel.exit_status
|
||||
else:
|
||||
stdout_lst = stdout.readlines()
|
||||
stderr_lst = stderr.readlines()
|
||||
|
||||
channel.close()
|
||||
status_tmpl = (
|
||||
'Wait for {0} during {1}s: no return code!\n'
|
||||
'\tSTDOUT:\n'
|
||||
'{2}\n'
|
||||
'\tSTDERR"\n'
|
||||
'{3}')
|
||||
logger.debug(
|
||||
status_tmpl.format(
|
||||
command, timeout,
|
||||
cls._get_str_from_list(stdout_lst),
|
||||
cls._get_str_from_list(stderr_lst)
|
||||
)
|
||||
)
|
||||
raise TimeoutError(
|
||||
status_tmpl.format(
|
||||
command, timeout,
|
||||
cls._get_str_from_list(stdout_lst[-5:]), # 5 last lines
|
||||
cls._get_str_from_list(stderr_lst[-5:]) # 5 last lines
|
||||
))
|
||||
|
||||
def execute(self, command, verbose=False, timeout=None):
|
||||
"""Execute command and wait for return code
|
||||
|
||||
:type command: str
|
||||
:type verbose: bool
|
||||
:type timeout: int
|
||||
:rtype: dict
|
||||
:raises: TimeoutError
|
||||
"""
|
||||
chan, _, stderr, stdout = self.execute_async(command)
|
||||
|
||||
# noinspection PyDictCreation
|
||||
result = {
|
||||
'exit_code': chan.recv_exit_status()
|
||||
'exit_code': self.__get_channel_exit_status(
|
||||
command, chan, stdout, stderr, timeout)
|
||||
}
|
||||
|
||||
result['stdout'] = stdout.readlines()
|
||||
result['stderr'] = stderr.readlines()
|
||||
|
||||
|
@ -710,14 +762,18 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
|
|||
hostname,
|
||||
cmd,
|
||||
auth=None,
|
||||
target_port=22):
|
||||
target_port=22,
|
||||
timeout=None
|
||||
):
|
||||
"""Execute command on remote host through currently connected host
|
||||
|
||||
:type hostname: str
|
||||
:type cmd: str
|
||||
:type auth: SSHAuth
|
||||
:type target_port: int
|
||||
:type timeout: int
|
||||
:rtype: dict
|
||||
:raises: TimeoutError
|
||||
"""
|
||||
if auth is None:
|
||||
auth = self.auth
|
||||
|
@ -740,14 +796,16 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
|
|||
|
||||
channel.exec_command(cmd)
|
||||
|
||||
# TODO(astepanov): make a logic for controlling channel state
|
||||
# noinspection PyDictCreation
|
||||
result = {}
|
||||
result['exit_code'] = channel.recv_exit_status()
|
||||
result = {
|
||||
'exit_code': self.__get_channel_exit_status(
|
||||
cmd, channel, stdout, stderr, timeout)
|
||||
}
|
||||
|
||||
result['stdout'] = stdout.readlines()
|
||||
result['stderr'] = stderr.readlines()
|
||||
channel.close()
|
||||
intermediate_channel.close()
|
||||
|
||||
result['stdout_str'] = self._get_str_from_list(result['stdout'])
|
||||
result['stderr_str'] = self._get_str_from_list(result['stderr'])
|
||||
|
|
|
@ -29,6 +29,7 @@ import paramiko
|
|||
from six.moves import cStringIO
|
||||
|
||||
from devops.error import DevopsCalledProcessError
|
||||
from devops.error import TimeoutError
|
||||
from devops.helpers.ssh_client import SSHAuth
|
||||
from devops.helpers.ssh_client import SSHClient
|
||||
|
||||
|
@ -1061,19 +1062,29 @@ class TestExecute(TestCase):
|
|||
chan = mock.Mock()
|
||||
recv_exit_status = mock.Mock(return_value=exit_code)
|
||||
chan.attach_mock(recv_exit_status, 'recv_exit_status')
|
||||
|
||||
wait = mock.Mock()
|
||||
status_event = mock.Mock()
|
||||
status_event.attach_mock(wait, 'wait')
|
||||
chan.attach_mock(status_event, 'status_event')
|
||||
chan.configure_mock(exit_status=exit_code)
|
||||
|
||||
return chan, '', stderr, stdout
|
||||
|
||||
@mock.patch(
|
||||
'devops.helpers.ssh_client.SSHClient.execute_async')
|
||||
def test_execute(self, execute_async, client, policy, logger):
|
||||
chan, _stdin, stderr, stdout = self.get_patched_execute_async_retval()
|
||||
is_set = mock.Mock(return_value=True)
|
||||
chan.status_event.attach_mock(is_set, 'is_set')
|
||||
|
||||
execute_async.return_value = chan, _stdin, stderr, stdout
|
||||
|
||||
stderr_lst = stderr.readlines()
|
||||
stdout_lst = stdout.readlines()
|
||||
|
||||
expected = {
|
||||
'exit_code': chan.recv_exit_status(),
|
||||
'exit_code': 0,
|
||||
'stderr': stderr_lst,
|
||||
'stdout': stdout_lst,
|
||||
'stderr_str': b''.join(stderr_lst).strip().decode(
|
||||
|
@ -1093,7 +1104,8 @@ class TestExecute(TestCase):
|
|||
)
|
||||
execute_async.assert_called_once_with(command)
|
||||
chan.assert_has_calls((
|
||||
mock.call.recv_exit_status(),
|
||||
mock.call.status_event.wait(None),
|
||||
mock.call.status_event.is_set(),
|
||||
mock.call.close()))
|
||||
logger.assert_has_calls((
|
||||
mock.call.info(
|
||||
|
@ -1110,6 +1122,82 @@ class TestExecute(TestCase):
|
|||
)),
|
||||
))
|
||||
|
||||
@mock.patch(
|
||||
'devops.helpers.ssh_client.SSHClient.execute_async')
|
||||
def test_execute_timeout(self, execute_async, client, policy, logger):
|
||||
exit_code = 0
|
||||
|
||||
chan, _stdin, stderr, stdout = self.get_patched_execute_async_retval()
|
||||
is_set = mock.Mock(return_value=True)
|
||||
chan.status_event.attach_mock(is_set, 'is_set')
|
||||
|
||||
execute_async.return_value = chan, _stdin, stderr, stdout
|
||||
|
||||
stderr_lst = stderr.readlines()
|
||||
stdout_lst = stdout.readlines()
|
||||
|
||||
expected = {
|
||||
'exit_code': exit_code,
|
||||
'stderr': stderr_lst,
|
||||
'stdout': stdout_lst,
|
||||
'stderr_str': b''.join(stderr_lst).strip().decode(
|
||||
encoding='utf-8'),
|
||||
'stdout_str': b''.join(stdout_lst).strip().decode(
|
||||
encoding='utf-8')}
|
||||
|
||||
ssh = self.get_ssh()
|
||||
|
||||
logger.reset_mock()
|
||||
|
||||
result = ssh.execute(command=command, verbose=True, timeout=1)
|
||||
|
||||
self.assertEqual(
|
||||
result,
|
||||
expected
|
||||
)
|
||||
execute_async.assert_called_once_with(command)
|
||||
chan.assert_has_calls((
|
||||
mock.call.status_event.wait(1),
|
||||
mock.call.status_event.is_set(),
|
||||
mock.call.close()))
|
||||
logger.assert_has_calls((
|
||||
mock.call.info(
|
||||
'{cmd} execution results:\n'
|
||||
'Exit code: {code}\n'
|
||||
'STDOUT:\n'
|
||||
'{stdout}\n'
|
||||
'STDERR:\n'
|
||||
'{stderr}'.format(
|
||||
cmd=command,
|
||||
code=result['exit_code'],
|
||||
stdout=result['stdout_str'],
|
||||
stderr=result['stderr_str']
|
||||
)),
|
||||
))
|
||||
|
||||
@mock.patch(
|
||||
'devops.helpers.ssh_client.SSHClient.execute_async')
|
||||
def test_execute_timeout_fail(self, execute_async, client, policy, logger):
|
||||
|
||||
chan, _stdin, stderr, stdout = self.get_patched_execute_async_retval()
|
||||
is_set = mock.Mock(return_value=False)
|
||||
chan.status_event.attach_mock(is_set, 'is_set')
|
||||
|
||||
execute_async.return_value = chan, _stdin, stderr, stdout
|
||||
|
||||
ssh = self.get_ssh()
|
||||
|
||||
logger.reset_mock()
|
||||
|
||||
with self.assertRaises(TimeoutError):
|
||||
ssh.execute(command=command, verbose=True, timeout=1)
|
||||
|
||||
execute_async.assert_called_once_with(command)
|
||||
chan.assert_has_calls((
|
||||
mock.call.status_event.wait(1),
|
||||
mock.call.status_event.is_set(),
|
||||
mock.call.close()))
|
||||
|
||||
@mock.patch(
|
||||
'devops.helpers.ssh_client.SSHClient.execute_async')
|
||||
def test_execute_together(self, execute_async, client, policy, logger):
|
||||
|
@ -1163,8 +1251,8 @@ class TestExecute(TestCase):
|
|||
|
||||
ssh = self.get_ssh()
|
||||
|
||||
result = ssh.check_call(command=command, verbose=verbose)
|
||||
execute.assert_called_once_with(command, verbose)
|
||||
result = ssh.check_call(command=command, verbose=verbose, timeout=None)
|
||||
execute.assert_called_once_with(command, verbose, None)
|
||||
self.assertEqual(result, return_value)
|
||||
|
||||
exit_code = 1
|
||||
|
@ -1172,8 +1260,8 @@ class TestExecute(TestCase):
|
|||
execute.reset_mock()
|
||||
execute.return_value = return_value
|
||||
with self.assertRaises(DevopsCalledProcessError):
|
||||
ssh.check_call(command=command, verbose=verbose)
|
||||
execute.assert_called_once_with(command, verbose)
|
||||
ssh.check_call(command=command, verbose=verbose, timeout=None)
|
||||
execute.assert_called_once_with(command, verbose, None)
|
||||
|
||||
@mock.patch(
|
||||
'devops.helpers.ssh_client.SSHClient.check_call')
|
||||
|
@ -1192,9 +1280,10 @@ class TestExecute(TestCase):
|
|||
ssh = self.get_ssh()
|
||||
|
||||
result = ssh.check_stderr(
|
||||
command=command, verbose=verbose, raise_on_err=raise_on_err)
|
||||
command=command, verbose=verbose, timeout=None,
|
||||
raise_on_err=raise_on_err)
|
||||
check_call.assert_called_once_with(
|
||||
command, verbose, raise_on_err=raise_on_err)
|
||||
command, verbose, timeout=None, raise_on_err=raise_on_err)
|
||||
self.assertEqual(result, return_value)
|
||||
|
||||
return_value['stderr_str'] = '0\n1'
|
||||
|
@ -1204,9 +1293,10 @@ class TestExecute(TestCase):
|
|||
check_call.return_value = return_value
|
||||
with self.assertRaises(DevopsCalledProcessError):
|
||||
ssh.check_stderr(
|
||||
command=command, verbose=verbose, raise_on_err=raise_on_err)
|
||||
command=command, verbose=verbose, timeout=None,
|
||||
raise_on_err=raise_on_err)
|
||||
check_call.assert_called_once_with(
|
||||
command, verbose, raise_on_err=raise_on_err)
|
||||
command, verbose, timeout=None, raise_on_err=raise_on_err)
|
||||
|
||||
|
||||
@mock.patch('devops.helpers.ssh_client.logger', autospec=True)
|
||||
|
@ -1252,6 +1342,15 @@ class TestExecuteThrowHost(TestCase):
|
|||
open_session = mock.Mock(return_value=channel)
|
||||
transport.attach_mock(open_session, 'open_session')
|
||||
|
||||
wait = mock.Mock()
|
||||
status_event = mock.Mock()
|
||||
status_event.attach_mock(wait, 'wait')
|
||||
channel.attach_mock(status_event, 'status_event')
|
||||
channel.configure_mock(exit_status=exit_code)
|
||||
|
||||
is_set = mock.Mock(return_value=True)
|
||||
channel.status_event.attach_mock(is_set, 'is_set')
|
||||
|
||||
return (
|
||||
open_session, transport, channel, get_transport,
|
||||
open_channel, intermediate_channel
|
||||
|
@ -1296,7 +1395,8 @@ class TestExecuteThrowHost(TestCase):
|
|||
mock.call.makefile('rb'),
|
||||
mock.call.makefile_stderr('rb'),
|
||||
mock.call.exec_command('ls ~ '),
|
||||
mock.call.recv_exit_status(),
|
||||
mock.call.status_event.wait(None),
|
||||
mock.call.status_event.is_set(),
|
||||
mock.call.close()
|
||||
))
|
||||
|
||||
|
@ -1344,7 +1444,8 @@ class TestExecuteThrowHost(TestCase):
|
|||
mock.call.makefile('rb'),
|
||||
mock.call.makefile_stderr('rb'),
|
||||
mock.call.exec_command('ls ~ '),
|
||||
mock.call.recv_exit_status(),
|
||||
mock.call.status_event.wait(None),
|
||||
mock.call.status_event.is_set(),
|
||||
mock.call.close()
|
||||
))
|
||||
|
||||
|
|
Loading…
Reference in New Issue