Open PTY on channel before command execution

Change-Id: I1dc49ad66c00d3a99a3070b0aaa708bad287e885
Closes-bug: #1607402
This commit is contained in:
Alexey Stepanov 2016-08-12 12:11:37 +03:00
parent 944f953caf
commit 642962f4e1
2 changed files with 53 additions and 9 deletions

View File

@ -555,7 +555,7 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
self,
command, verbose=False, timeout=None,
error_info=None,
expected=None, raise_on_err=True):
expected=None, raise_on_err=True, **kwargs):
"""Execute command and check for return code
:type command: str
@ -578,7 +578,7 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
else code
for code in expected
]
ret = self.execute(command, verbose, timeout)
ret = self.execute(command, verbose, timeout, **kwargs)
if ret['exit_code'] not in expected:
message = (
"{append}Command '{cmd}' returned exit code {code!s} while "
@ -607,7 +607,7 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
self,
command, verbose=False, timeout=None,
error_info=None,
raise_on_err=True):
raise_on_err=True, **kwargs):
"""Execute command expecting return code 0 and empty STDERR
:type command: str
@ -620,7 +620,7 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
"""
ret = self.check_call(
command, verbose, timeout=timeout,
error_info=error_info, raise_on_err=raise_on_err)
error_info=error_info, raise_on_err=raise_on_err, **kwargs)
if ret['stderr']:
message = (
"{append}Command '{cmd}' STDERR while not expected\n"
@ -644,7 +644,7 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
@classmethod
def execute_together(
cls, remotes, command, expected=None, raise_on_err=True):
cls, remotes, command, expected=None, raise_on_err=True, **kwargs):
"""Execute command on multiple remotes in async mode
:type remotes: list
@ -658,7 +658,7 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
futures = {}
errors = {}
for remote in set(remotes): # Use distinct remotes
chan, _, _, _ = remote.execute_async(command)
chan, _, _, _ = remote.execute_async(command, **kwargs)
futures[remote] = chan
for remote, chan in futures.items():
ret = chan.recv_exit_status()
@ -715,7 +715,7 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
result.stderr_brief
))
def execute(self, command, verbose=False, timeout=None):
def execute(self, command, verbose=False, timeout=None, **kwargs):
"""Execute command and wait for return code
:type command: str
@ -724,7 +724,7 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
:rtype: ExecResult
:raises: TimeoutError
"""
chan, _, stderr, stdout = self.execute_async(command)
chan, _, stderr, stdout = self.execute_async(command, **kwargs)
result = self.__exec_command(command, chan, stdout, stderr, timeout)
@ -751,7 +751,7 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
return result
def execute_async(self, command, timeout=None):
def execute_async(self, command, timeout=None, get_pty=False):
"""Execute command in async mode and return channel with IO objects
:type command: str
@ -761,6 +761,15 @@ class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
logger.debug("Executing command: '{}'".format(command.rstrip()))
chan = self._ssh.get_transport().open_session(timeout=timeout)
if get_pty:
# Open PTY
chan.get_pty(
term='vt100',
width=80, height=24,
width_pixels=0, height_pixels=0
)
stdin = chan.makefile('wb')
stdout = chan.makefile('rb')
stderr = chan.makefile_stderr('rb')

View File

@ -940,6 +940,41 @@ class TestExecute(TestCase):
logger.mock_calls
)
def test_execute_async_pty(self, client, policy, logger):
chan = mock.Mock()
open_session = mock.Mock(return_value=chan)
transport = mock.Mock()
transport.attach_mock(open_session, 'open_session')
get_transport = mock.Mock(return_value=transport)
_ssh = mock.Mock()
_ssh.attach_mock(get_transport, 'get_transport')
client.return_value = _ssh
ssh = self.get_ssh()
# noinspection PyTypeChecker
result = ssh.execute_async(command=command, get_pty=True)
get_transport.assert_called_once()
open_session.assert_called_once()
self.assertIn(chan, result)
chan.assert_has_calls((
mock.call.get_pty(
term='vt100',
width=80, height=24,
width_pixels=0, height_pixels=0
),
mock.call.makefile('wb'),
mock.call.makefile('rb'),
mock.call.makefile_stderr('rb'),
mock.call.exec_command('{}\n'.format(command))
))
self.assertIn(
mock.call.debug(
"Executing command: '{}'".format(command.rstrip())),
logger.mock_calls
)
def test_execute_async_sudo(self, client, policy, logger):
chan = mock.Mock()
open_session = mock.Mock(return_value=chan)