SSHClient rework for SSH Manager integration

1. Implemented SSHAuth
2. Old initialization API has been marked as deprecated
3. SFTP is started on demand with 3 retries
4. Reworked unit test to cover 100%
5. Added docstrings
6. Remove cyclic SSH session initialization in helper
7. Code is ready for adopted memorize pattern

blueprint: sshmanager-integration

Change-Id: I49d0aa635ba3f3125ab17531c0790a0106b87fea
(cherry picked from commit aae58e2)
This commit is contained in:
Alexey Stepanov 2016-05-27 19:44:50 +03:00
parent 297b24f9f9
commit caffa38d48
6 changed files with 1072 additions and 247 deletions

9
.coveragerc Normal file
View File

@ -0,0 +1,9 @@
[run]
source =
devops
omit =
devops/tests/*
devops/migrations/*
devops/driver/dummy/*
devops/settings.py
devops/test_settings.py

View File

@ -32,6 +32,7 @@ from six.moves import xmlrpc_client
from devops.error import AuthenticationError
from devops.error import DevopsError
from devops.error import TimeoutError
from devops.helpers.ssh_client import SSHAuth
from devops.helpers.ssh_client import SSHClient
from devops import logger
from devops.settings import KEYSTONE_CREDS
@ -174,10 +175,12 @@ def get_node_remote(env, node_name, login=SSH_SLAVE_CREDENTIALS['login'],
name=node_name).interfaces[0].mac_address)
wait(lambda: tcp_ping(ip, 22), timeout=180,
timeout_msg="Node {ip} is not accessible by SSH.".format(ip=ip))
return SSHClient(ip,
username=login,
password=password,
private_keys=get_private_keys(env))
return SSHClient(
ip,
auth=SSHAuth(
username=login,
password=password,
keys=get_private_keys(env)))
def get_admin_ip(env):

View File

@ -16,6 +16,7 @@ import base64
import os
import posixpath
import stat
from warnings import warn
import paramiko
import six
@ -25,7 +26,163 @@ from devops.helpers.retry import retry
from devops import logger
class SSHAuth(object):
__slots__ = ['__username', '__password', '__key', '__keys']
def __init__(
self,
username=None, password=None, key=None, keys=None):
"""SSH authorisation object
Used to authorize SSHClient.
Single SSHAuth object is associated with single host:port.
Password and key is private, other data is read-only.
:type username: str
:type password: str
:type key: paramiko.RSAKey
:type keys: list
"""
self.__username = username
self.__password = password
self.__key = key
self.__keys = [None]
if key is not None:
self.__keys.append(key)
if keys is not None:
for key in keys:
if key not in self.__keys:
self.__keys.append(key)
@property
def username(self):
"""Username for auth
:rtype: str
"""
return self.__username
@staticmethod
def __get_public_key(key):
"""Internal method for get public key from private
:type key: paramiko.RSAKey
"""
if key is None:
return None
return '{0} {1}'.format(key.get_name(), key.get_base64())
@property
def public_key(self):
"""public key for stored private key if presents else None
:rtype: str
"""
return self.__get_public_key(self.__key)
def enter_password(self, tgt):
"""Enter password to STDIN
Note: required for 'sudo' call
:type tgt: file
:rtype: str
"""
return tgt.write('{}\n'.format(self.__password))
def connect(self, client, hostname=None, port=22, log=True):
"""Connect SSH client object using credentials
:type client:
paramiko.client.SSHClient
paramiko.transport.Transport
:type log: bool
:raises paramiko.AuthenticationException
"""
kwargs = {
'username': self.username,
'password': self.__password}
if hostname is not None:
kwargs['hostname'] = hostname
kwargs['port'] = port
keys = [self.__key]
keys.extend([k for k in self.__keys if k != self.__key])
for key in keys:
kwargs['pkey'] = key
try:
client.connect(**kwargs)
if self.__key != key:
self.__key = key
logger.debug(
'Main key has been updated, public key is: \n'
'{}'.format(self.public_key))
return
except paramiko.PasswordRequiredException:
if self.__password is None:
logger.exception('No password has been set!')
raise
else:
logger.critical(
'Unexpected PasswordRequiredException, '
'when password is set!')
raise
except paramiko.AuthenticationException:
continue
msg = 'Connection using stored authentication info failed!'
if log:
logger.exception(
'Connection using stored authentication info failed!')
raise paramiko.AuthenticationException(msg)
def __hash__(self):
return hash((
self.__class__,
self.username,
self.__password,
tuple(self.__keys)
))
def __eq__(self, other):
return hash(self) == hash(other)
def __repr__(self):
_key = (
None if self.__key is None else
'<private for pub: {}>'.format(self.public_key)
)
_keys = []
for k in self.__keys:
if k == self.__key:
continue
_keys.append(
'<private for pub: {}>'.format(
self.__get_public_key(key=k)) if k is not None else None)
return (
'{cls}(username={username}, '
'password=<*masked*>, key={key}, keys={keys})'.format(
cls=self.__class__.__name__,
username=self.username,
key=_key,
keys=_keys)
)
def __str__(self):
return (
'{cls} for {username}'.format(
cls=self.__class__.__name__,
username=self.username,
)
)
class SSHClient(object):
__slots__ = [
'__hostname', '__port', '__auth', '__ssh', '__sftp', 'sudo_mode'
]
class get_sudo(object):
"""Context manager for call commands with sudo"""
def __init__(self, ssh):
@ -37,42 +194,145 @@ class SSHClient(object):
def __exit__(self, exc_type, exc_val, exc_tb):
self.ssh.sudo_mode = False
def __init__(self, host, port=22, username=None, password=None,
private_keys=None):
self.host = str(host)
self.port = int(port)
self.username = username
self.__password = password
if not private_keys:
private_keys = []
self.__private_keys = private_keys
self.__actual_pkey = None
def __hash__(self):
return hash((
self.__class__,
self.hostname,
self.port,
self.auth))
def __init__(
self,
host, port=22,
username=None, password=None, private_keys=None,
auth=None
):
"""SSHClient helper
:type host: str
:type port: int
:type username: str
:type password: str
:type private_keys: list
:type auth: SSHAuth
"""
self.__hostname = host
self.__port = port
self.sudo_mode = False
self.sudo = self.get_sudo(self)
self._ssh = None
self.__ssh = paramiko.SSHClient()
self.__ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.__sftp = None
self.reconnect()
self.__auth = auth
if auth is None:
msg = (
'SSHClient(host={host}, port={port}, username={username}): '
'initialization by username/password/private_keys '
'is deprecated in favor of SSHAuth usage. '
'Please update your code'.format(
host=host, port=port, username=username
))
warn(msg, DeprecationWarning)
logger.debug(msg)
self.__auth = SSHAuth(
username=username,
password=password,
keys=private_keys
)
self.__connect()
if auth is None:
logger.info(
'{0}:{1}> SSHAuth was made from old style creds: '
'{2}'.format(self.hostname, self.port, self.auth))
@property
def password(self):
return self.__password
def auth(self):
"""Internal authorisation object
Attention: this public property is mainly for inheritance,
debug and information purposes.
Calls outside SSHClient and child classes is sign of incorrect design.
Change is completely disallowed.
:rtype: SSHAuth
"""
return self.__auth
@property
def private_keys(self):
return self.__private_keys
def hostname(self):
"""Connected remote host name
:rtype: str
"""
return self.__hostname
@property
def private_key(self):
return self.__actual_pkey
def host(self):
"""Hostname access for backward compatibility"""
warn(
'host has been deprecated in favor of hostname',
DeprecationWarning
)
return self.hostname
@property
def public_key(self):
if self.private_key is None:
return None
key = self.private_key
return '{0} {1}'.format(key.get_name(), key.get_base64())
def port(self):
"""Connected remote port number
:rtype: int
"""
return self.__port
@property
def is_alive(self):
"""Paramiko status: ready to use|reconnect required
:rtype: bool
"""
return self.__ssh.get_transport() is not None
def __repr__(self):
return '{cls}(host={host}, port={port}, auth={auth!r})'.format(
cls=self.__class__.__name__, host=self.hostname, port=self.port,
auth=self.auth
)
def __str__(self):
return '{cls}(host={host}, port={port}) for user {user}'.format(
cls=self.__class__.__name__, host=self.hostname, port=self.port,
user=self.auth.username
)
@property
def _ssh(self):
"""ssh client object getter for inheritance support only
Attention: ssh client object creation and change
is allowed only by __init__ and reconnect call.
:rtype: paramiko.SSHClient
"""
return self.__ssh
@retry(count=3, delay=3)
def __connect(self):
"""Main method for connection open"""
self.auth.connect(
client=self.__ssh,
hostname=self.hostname, port=self.port,
log=True)
@retry(3, delay=0)
def __connect_sftp(self):
"""SFTP connection opener"""
try:
self.__sftp = self.__ssh.open_sftp()
except paramiko.SSHException:
logger.warning('SFTP enable failed! SSH only is accessible.')
@property
def _sftp(self):
@ -82,25 +342,33 @@ class SSHClient(object):
"""
if self.__sftp is not None:
return self.__sftp
logger.warning('SFTP is not connected, try to reconnect')
self._connect_sftp()
logger.debug('SFTP is not connected, try to connect...')
self.__connect_sftp()
if self.__sftp is not None:
return self.__sftp
raise paramiko.SSHException('SFTP connection failed')
def clear(self):
if self.__sftp is not None:
try:
self.__sftp.close()
except Exception:
logger.exception("Could not close sftp connection")
"""Clear SSH and SFTP sessions"""
try:
self._ssh.close()
self.__ssh.close()
self.__sftp = None
except Exception:
logger.exception("Could not close ssh connection")
if self.__sftp is not None:
try:
self.__sftp.close()
except Exception:
logger.exception("Could not close sftp connection")
def __del__(self):
self.clear()
"""Destructor helper: close channel and threads BEFORE closing others
Due to threading in paramiko, default destructor could generate asserts
on close, so we calling channel close before closing main ssh object.
"""
self.__ssh.close()
self.__sftp = None
def __enter__(self):
return self
@ -108,39 +376,14 @@ class SSHClient(object):
def __exit__(self, exc_type, exc_val, exc_tb):
self.clear()
@retry(count=3, delay=3)
def connect(self):
logger.debug(
"Connect to '{0}:{1}' as '{2}:{3}'".format(
self.host, self.port, self.username, self.password))
for private_key in self.private_keys:
try:
self._ssh.connect(
self.host, port=self.port, username=self.username,
password=self.password, pkey=private_key)
self.__actual_pkey = private_key
return
except paramiko.AuthenticationException:
continue
if self.private_keys:
logger.error("Authentication with keys failed")
self.__actual_pkey = None
self._ssh.connect(
self.host, port=self.port, username=self.username,
password=self.password)
def _connect_sftp(self):
try:
self.__sftp = self._ssh.open_sftp()
except paramiko.SSHException:
logger.warning('SFTP enable failed! SSH only is accessible.')
def reconnect(self):
self._ssh = paramiko.SSHClient()
self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.connect()
self._connect_sftp()
"""Reconnect SSH and SFTP session"""
self.clear()
self.__ssh = paramiko.SSHClient()
self.__ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.__connect()
def check_call(
self,
@ -233,7 +476,7 @@ class SSHClient(object):
ret = chan.recv_exit_status()
chan.close()
if ret not in expected:
errors[remote.host] = ret
errors[remote.hostname] = ret
if errors and raise_on_err:
raise DevopsCalledProcessError(command, errors)
@ -293,6 +536,7 @@ class SSHClient(object):
:rtype: tuple
"""
logger.debug("Executing command: '{}'".format(command.rstrip()))
chan = self._ssh.get_transport().open_session()
stdin = chan.makefile('wb')
stdout = chan.makefile('rb')
@ -305,7 +549,7 @@ class SSHClient(object):
)
chan.exec_command(cmd)
if stdout.channel.closed is False:
stdin.write('%s\n' % self.password)
self.auth.enter_password(stdin)
stdin.flush()
else:
chan.exec_command(cmd)
@ -315,23 +559,27 @@ class SSHClient(object):
self,
hostname,
cmd,
username=None,
password=None,
key=None,
auth=None,
target_port=22):
if username is None and password is None and key is None:
username = self.username
password = self.__password
key = self.private_key
"""Execute command on remote host through currently connected host
:type hostname: str
:type cmd: str
:type auth: SSHAuth
:type target_port: int
:rtype: dict
"""
if auth is None:
auth = self.auth
intermediate_channel = self._ssh.get_transport().open_channel(
kind='direct-tcpip',
dest_addr=(hostname, target_port),
src_addr=(self.host, 0))
src_addr=(self.hostname, 0))
transport = paramiko.Transport(sock=intermediate_channel)
# start client and authenticate transport
transport.connect(username=username, password=password, pkey=key)
auth.connect(transport)
# open ssh session
channel = transport.open_session()
@ -340,7 +588,6 @@ class SSHClient(object):
stdout = channel.makefile('rb')
stderr = channel.makefile_stderr('rb')
logger.info("Executing command: {}".format(cmd))
channel.exec_command(cmd)
# TODO(astepanov): make a logic for controlling channel state
@ -478,3 +725,5 @@ class SSHClient(object):
return attrs.st_mode & stat.S_IFDIR != 0
except IOError:
return False
__all__ = ['SSHAuth', 'SSHClient']

View File

@ -1,4 +1,4 @@
# Copyright 2013 - 2015 Mirantis, Inc.
# Copyright 2013 - 2016 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@ -13,6 +13,7 @@
# under the License.
import time
from warnings import warn
from django.conf import settings
from django.db import models
@ -22,7 +23,8 @@ from paramiko import RSAKey
from devops.error import DevopsEnvironmentError
from devops.helpers.helpers import get_file_size
from devops.helpers.helpers import SSHClient
from devops.helpers.ssh_client import SSHAuth
from devops.helpers.ssh_client import SSHClient
from devops.helpers.templates import create_devops_config
from devops.helpers.templates import get_devops_config
from devops import logger
@ -402,29 +404,39 @@ class Environment(DriverModel):
:rtype : SSHClient
"""
return self.nodes().admin.remote(
self.admin_net,
login=login,
password=password)
admin = sorted(
list(self.get_nodes(role='fuel_master')),
key=lambda node: node.name
)[0]
return admin.remote(
self.admin_net, auth=SSHAuth(
username=login,
password=password))
# @logwrap
def get_ssh_to_remote(self, ip,
login=settings.SSH_SLAVE_CREDENTIALS['login'],
password=settings.SSH_SLAVE_CREDENTIALS['password']):
warn('LEGACY, for fuel-qa compatibility', DeprecationWarning)
keys = []
remote = self.get_admin_remote()
for key_string in ['/root/.ssh/id_rsa',
'/root/.ssh/bootstrap.rsa']:
if self.get_admin_remote().isfile(key_string):
with self.get_admin_remote().open(key_string) as f:
if remote.isfile(key_string):
with remote.open(key_string) as f:
keys.append(RSAKey.from_private_key(f))
return SSHClient(ip,
username=login,
password=password,
private_keys=keys)
return SSHClient(
ip,
auth=SSHAuth(
username=login,
password=password,
keys=keys))
# @logwrap
def get_ssh_to_remote_by_key(self, ip, keyfile):
@staticmethod
def get_ssh_to_remote_by_key(ip, keyfile):
warn('LEGACY, for fuel-qa compatibility', DeprecationWarning)
try:
with open(keyfile) as f:
keys = [RSAKey.from_private_key(f)]
@ -432,7 +444,7 @@ class Environment(DriverModel):
logger.warning('Loading of SSH key from file failed. Trying to use'
' SSH agent ...')
keys = Agent().get_keys()
return SSHClient(ip, private_keys=keys)
return SSHClient(ip, auth=SSHAuth(keys=keys))
def nodes(self): # migrated from EnvironmentModel.nodes()
# DEPRECATED. Please use environment.get_nodes() instead.

View File

@ -162,7 +162,9 @@ class Node(DriverModel):
network__name=name).order_by('id')[0]
return interface.address_set.get(interface=interface).ip_address
def remote(self, network_name, login, password=None, private_keys=None):
def remote(
self, network_name, login=None, password=None, private_keys=None,
auth=None):
"""Create SSH-connection to the network
:rtype : SSHClient
@ -170,7 +172,7 @@ class Node(DriverModel):
return SSHClient(
self.get_ip_address_by_network_name(network_name),
username=login,
password=password, private_keys=private_keys)
password=password, private_keys=private_keys, auth=auth)
def send_keys(self, keys):
self.driver.node_send_keys(self, keys)

View File

@ -15,6 +15,7 @@
# pylint: disable=no-self-use
import base64
from contextlib import closing
from os.path import basename
import posixpath
import stat
@ -22,9 +23,12 @@ from unittest import TestCase
import mock
import paramiko
# noinspection PyUnresolvedReferences
from six.moves import cStringIO
from six import PY2
from devops.error import DevopsCalledProcessError
from devops.helpers.ssh_client import SSHAuth
from devops.helpers.ssh_client import SSHClient
@ -50,20 +54,136 @@ command = 'ls ~ '
encoded_cmd = base64.b64encode("%s\n" % command)
class TestSSHAuth(TestCase):
def init_checks(self, username=None, password=None, key=None, keys=None):
"""shared positive init checks
:type username: str
:type password: str
:type key: paramiko.RSAKey
:type keys: list
"""
auth = SSHAuth(
username=username,
password=password,
key=key,
keys=keys
)
int_keys = [None]
if key is not None:
int_keys.append(key)
if keys is not None:
for k in keys:
if k not in int_keys:
int_keys.append(k)
self.assertEqual(auth.username, username)
with closing(cStringIO()) as tgt:
auth.enter_password(tgt)
self.assertEqual(tgt.getvalue(), '{}\n'.format(password))
self.assertEqual(
auth.public_key,
gen_public_key(key) if key is not None else None)
_key = (
None if auth.public_key is None else
'<private for pub: {}>'.format(auth.public_key)
)
_keys = []
for k in int_keys:
if k == key:
continue
_keys.append(
'<private for pub: {}>'.format(
gen_public_key(k)) if k is not None else None)
self.assertEqual(
repr(auth),
"{cls}("
"username={username}, "
"password=<*masked*>, "
"key={key}, "
"keys={keys})".format(
cls=SSHAuth.__name__,
username=auth.username,
key=_key,
keys=_keys
)
)
self.assertEqual(
str(auth),
'{cls} for {username}'.format(
cls=SSHAuth.__name__,
username=auth.username,
)
)
def test_init_username_only(self):
self.init_checks(
username=username
)
def test_init_username_password(self):
self.init_checks(
username=username,
password=password
)
def test_init_username_key(self):
self.init_checks(
username=username,
key=gen_private_keys(1).pop()
)
def test_init_username_password_key(self):
self.init_checks(
username=username,
password=password,
key=gen_private_keys(1).pop()
)
def test_init_username_password_keys(self):
self.init_checks(
username=username,
password=password,
keys=gen_private_keys(2)
)
def test_init_username_password_key_keys(self):
self.init_checks(
username=username,
password=password,
key=gen_private_keys(1).pop(),
keys=gen_private_keys(2)
)
@mock.patch('devops.helpers.retry.sleep', autospec=True)
@mock.patch('devops.helpers.ssh_client.logger', autospec=True)
@mock.patch(
'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy')
@mock.patch('paramiko.SSHClient', autospec=True)
class TestSSHClient(TestCase):
def check_defaults(
self, obj, host, port, username, password, private_keys):
self.assertEqual(obj.host, host)
self.assertEqual(obj.port, port)
self.assertEqual(obj.username, username)
self.assertEqual(obj.password, password)
self.assertEqual(obj.private_keys, private_keys)
class TestSSHClientInit(TestCase):
def init_checks(
self,
client, policy, logger,
host=None, port=22,
username=None, password=None, private_keys=None,
auth=None
):
"""shared checks for positive cases
def test_init_passwd(self, client, policy, logger):
:type client: mock.Mock
:type policy: mock.Mock
:type logger: mock.Mock
:type host: str
:type port: int
:type username: str
:type password: str
:type private_keys: list
:type auth: SSHAuth
"""
_ssh = mock.call()
ssh = SSHClient(
@ -71,150 +191,572 @@ class TestSSHClient(TestCase):
port=port,
username=username,
password=password,
private_keys=private_keys)
private_keys=private_keys,
auth=auth
)
client.assert_called_once()
policy.assert_called_once()
expected_calls = [
_ssh,
_ssh.set_missing_host_key_policy('AutoAddPolicy'),
_ssh.connect(
host, password=password,
port=port, username=username),
_ssh.open_sftp()
]
if auth is None:
if private_keys is None or len(private_keys) == 0:
logger.assert_has_calls((
mock.call.debug(
'SSHClient('
'host={host}, port={port}, username={username}): '
'initialization by username/password/private_keys '
'is deprecated in favor of SSHAuth usage. '
'Please update your code'.format(
host=host, port=port, username=username
)),
mock.call.info(
'{0}:{1}> SSHAuth was made from old style creds: '
'SSHAuth for {2}'.format(host, port, username))
))
else:
logger.assert_has_calls((
mock.call.debug(
'SSHClient('
'host={host}, port={port}, username={username}): '
'initialization by username/password/private_keys '
'is deprecated in favor of SSHAuth usage. '
'Please update your code'.format(
host=host, port=port, username=username
)),
mock.call.debug(
'Main key has been updated, public key is: \n'
'{}'.format(ssh.auth.public_key)),
mock.call.info(
'{0}:{1}> SSHAuth was made from old style creds: '
'SSHAuth for {2}'.format(host, port, username))
))
else:
logger.assert_not_called()
self.assertIn(expected_calls, client.mock_calls)
self.check_defaults(ssh, host, port, username, password, private_keys)
self.assertIsNone(ssh.private_key)
self.assertIsNone(ssh.public_key)
self.assertIn(
mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format(
host, port, username, password
)),
logger.mock_calls
)
sftp = ssh._sftp
self.assertEqual(sftp, client().open_sftp())
def test_init_keys(self, client, policy, logger):
_ssh = mock.call()
private_keys = gen_private_keys(1)
ssh = SSHClient(
host=host,
port=port,
username=username,
password=password,
private_keys=private_keys)
client.assert_called_once()
policy.assert_called_once()
expected_calls = [
_ssh,
_ssh.set_missing_host_key_policy('AutoAddPolicy'),
_ssh.connect(
host, password=password, pkey=private_keys[0],
port=port, username=username),
_ssh.open_sftp()
]
self.assertIn(expected_calls, client.mock_calls)
self.check_defaults(ssh, host, port, username, password, private_keys)
self.assertEqual(ssh.private_key, private_keys[0])
self.assertEqual(ssh.public_key, gen_public_key(private_keys[0]))
self.assertIn(
mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format(
host, port, username, password
)),
logger.mock_calls
)
def test_init_as_context(self, client, policy, logger):
_ssh = mock.call()
private_keys = gen_private_keys(1)
with SSHClient(
host=host,
port=port,
username=username,
password=password,
private_keys=private_keys) as ssh:
client.assert_called_once()
policy.assert_called_once()
expected_calls = [
_ssh,
_ssh.set_missing_host_key_policy('AutoAddPolicy'),
_ssh.connect(
host, password=password, pkey=private_keys[0],
port=port, username=username),
_ssh.open_sftp()
]
if auth is None:
if private_keys is None or len(private_keys) == 0:
pkey = None
expected_calls = [
_ssh,
_ssh.set_missing_host_key_policy('AutoAddPolicy'),
_ssh.connect(
hostname=host, password=password,
pkey=pkey,
port=port, username=username),
]
else:
pkey = private_keys[0]
expected_calls = [
_ssh,
_ssh.set_missing_host_key_policy('AutoAddPolicy'),
_ssh.connect(
hostname=host, password=password,
pkey=None,
port=port, username=username),
_ssh.connect(
hostname=host, password=password,
pkey=pkey,
port=port, username=username),
]
self.assertIn(expected_calls, client.mock_calls)
self.check_defaults(ssh, host, port, username, password,
private_keys)
self.assertEqual(
ssh.auth,
SSHAuth(
username=username,
password=password,
keys=private_keys
)
)
else:
self.assertEqual(ssh.auth, auth)
self.assertIn(
mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format(
host, port, username, password
)),
logger.mock_calls
sftp = ssh._sftp
self.assertEqual(sftp, client().open_sftp())
self.assertEqual(ssh._ssh, client())
self.assertEqual(ssh.hostname, host)
self.assertEqual(ssh.port, port)
self.assertEqual(
repr(ssh),
'{cls}(host={host}, port={port}, auth={auth!r})'.format(
cls=ssh.__class__.__name__, host=ssh.hostname,
port=ssh.port,
auth=ssh.auth
)
)
def test_init_host(self, client, policy, logger, sleep):
"""Test with host only set"""
self.init_checks(
client, policy, logger,
host=host)
def test_init_alternate_port(self, client, policy, logger, sleep):
"""Test with alternate port"""
self.init_checks(
client, policy, logger,
host=host,
port=2222
)
def test_init_username(self, client, policy, logger, sleep):
"""Test with username only set from creds"""
self.init_checks(
client, policy, logger,
host=host,
username=username
)
def test_init_username_password(self, client, policy, logger, sleep):
"""Test with username and password set from creds"""
self.init_checks(
client, policy, logger,
host=host,
username=username,
password=password
)
def test_init_fail_sftp(self, client, policy, logger):
_ssh = mock.Mock()
client.return_value = _ssh
open_sftp = mock.Mock(parent=_ssh, side_effect=paramiko.SSHException)
_ssh.attach_mock(open_sftp, 'open_sftp')
warning = mock.Mock(parent=logger)
logger.attach_mock(warning, 'warning')
ssh = SSHClient(
def test_init_username_password_empty_keys(
self, client, policy, logger, sleep):
"""Test with username, password and empty keys set from creds"""
self.init_checks(
client, policy, logger,
host=host,
port=port,
username=username,
password=password,
private_keys=private_keys)
private_keys=[]
)
def test_init_username_single_key(self, client, policy, logger, sleep):
"""Test with username and single key set from creds"""
connect = mock.Mock(
side_effect=[
paramiko.AuthenticationException, mock.Mock()
])
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
self.init_checks(
client, policy, logger,
host=host,
username=username,
private_keys=gen_private_keys(1)
)
def test_init_username_password_single_key(
self, client, policy, logger, sleep):
"""Test with username, password and single key set from creds"""
connect = mock.Mock(
side_effect=[
paramiko.AuthenticationException, mock.Mock()
])
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
self.init_checks(
client, policy, logger,
host=host,
username=username,
password=password,
private_keys=gen_private_keys(1)
)
def test_init_username_multiple_keys(self, client, policy, logger, sleep):
"""Test with username and multiple keys set from creds"""
connect = mock.Mock(
side_effect=[
paramiko.AuthenticationException, mock.Mock()
])
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
self.init_checks(
client, policy, logger,
host=host,
username=username,
private_keys=gen_private_keys(2)
)
def test_init_username_password_multiple_keys(
self, client, policy, logger, sleep):
"""Test with username, password and multiple keys set from creds"""
connect = mock.Mock(
side_effect=[
paramiko.AuthenticationException, mock.Mock()
])
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
connect = mock.Mock(
side_effect=[
paramiko.AuthenticationException, mock.Mock()
])
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
self.init_checks(
client, policy, logger,
host=host,
username=username,
password=password,
private_keys=gen_private_keys(2)
)
def test_init_auth(
self, client, policy, logger, sleep):
self.init_checks(
client, policy, logger,
host=host,
auth=SSHAuth(
username=username,
password=password,
key=gen_private_keys(1).pop()
)
)
def test_init_auth_break(
self, client, policy, logger, sleep):
self.init_checks(
client, policy, logger,
host=host,
username='Invalid',
password='Invalid',
private_keys=gen_private_keys(1),
auth=SSHAuth(
username=username,
password=password,
key=gen_private_keys(1).pop()
)
)
def test_init_context(
self, client, policy, logger, sleep):
with SSHClient(host=host, auth=SSHAuth()) as ssh:
client.assert_called_once()
policy.assert_called_once()
logger.assert_not_called()
self.assertEqual(ssh.auth, SSHAuth())
sftp = ssh._sftp
self.assertEqual(sftp, client().open_sftp())
self.assertEqual(ssh._ssh, client())
self.assertEqual(ssh.hostname, host)
self.assertEqual(ssh.port, port)
def test_init_clear_failed(
self, client, policy, logger, sleep):
"""Test reconnect
:type client: mock.Mock
:type policy: mock.Mock
:type logger: mock.Mock
"""
_ssh = mock.Mock()
_ssh.attach_mock(
mock.Mock(
side_effect=[
Exception('Mocked SSH close()'),
mock.Mock()
]),
'close')
_sftp = mock.Mock()
_sftp.attach_mock(
mock.Mock(
side_effect=[
Exception('Mocked SFTP close()'),
mock.Mock()
]),
'close')
client.return_value = _ssh
_ssh.attach_mock(mock.Mock(return_value=_sftp), 'open_sftp')
ssh = SSHClient(host=host, auth=SSHAuth())
client.assert_called_once()
policy.assert_called_once()
logger.assert_not_called()
self.assertEqual(ssh.auth, SSHAuth())
sftp = ssh._sftp
self.assertEqual(sftp, _sftp)
self.assertEqual(ssh._ssh, _ssh)
self.assertEqual(ssh.hostname, host)
self.assertEqual(ssh.port, port)
logger.reset_mock()
ssh.clear()
logger.assert_has_calls((
mock.call.exception('Could not close ssh connection'),
mock.call.exception('Could not close sftp connection'),
))
def test_init_reconnect(
self, client, policy, logger, sleep):
"""Test reconnect
:type client: mock.Mock
:type policy: mock.Mock
:type logger: mock.Mock
"""
ssh = SSHClient(host=host, auth=SSHAuth())
client.assert_called_once()
policy.assert_called_once()
logger.assert_not_called()
self.assertEqual(ssh.auth, SSHAuth())
sftp = ssh._sftp
self.assertEqual(sftp, client().open_sftp())
self.assertEqual(ssh._ssh, client())
client.reset_mock()
policy.reset_mock()
self.assertEqual(ssh.hostname, host)
self.assertEqual(ssh.port, port)
ssh.reconnect()
_ssh = mock.call()
expected_calls = [
_ssh.close(),
_ssh,
_ssh.set_missing_host_key_policy('AutoAddPolicy'),
_ssh.connect(
hostname='127.0.0.1',
password=None,
pkey=None,
port=22,
username=None),
]
self.assertIn(
expected_calls,
client.mock_calls
)
client.assert_called_once()
policy.assert_called_once()
self.check_defaults(ssh, host, port, username, password, private_keys)
logger.assert_not_called()
warning.assert_called_once_with(
'SFTP enable failed! SSH only is accessible.'
self.assertEqual(ssh.auth, SSHAuth())
sftp = ssh._sftp
self.assertEqual(sftp, client().open_sftp())
self.assertEqual(ssh._ssh, client())
def test_init_password_required(
self, client, policy, logger, sleep):
connect = mock.Mock(side_effect=paramiko.PasswordRequiredException)
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
with self.assertRaises(paramiko.PasswordRequiredException):
SSHClient(host=host, auth=SSHAuth())
logger.assert_has_calls((
mock.call.exception('No password has been set!'),
))
def test_init_password_broken(
self, client, policy, logger, sleep):
connect = mock.Mock(side_effect=paramiko.PasswordRequiredException)
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
with self.assertRaises(paramiko.PasswordRequiredException):
SSHClient(host=host, auth=SSHAuth(password=password))
logger.assert_has_calls((
mock.call.critical(
'Unexpected PasswordRequiredException, '
'when password is set!'
),
))
def test_init_auth_impossible_password(
self, client, policy, logger, sleep):
connect = mock.Mock(side_effect=paramiko.AuthenticationException)
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
with self.assertRaises(paramiko.AuthenticationException):
SSHClient(host=host, auth=SSHAuth(password=password))
logger.assert_has_calls(
(
mock.call.exception(
'Connection using stored authentication info failed!'),
) * 3
)
def test_init_auth_impossible_key(
self, client, policy, logger, sleep):
connect = mock.Mock(side_effect=paramiko.AuthenticationException)
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
with self.assertRaises(paramiko.AuthenticationException):
SSHClient(
host=host,
auth=SSHAuth(key=gen_private_keys(1).pop())
)
logger.assert_has_calls(
(
mock.call.exception(
'Connection using stored authentication info failed!'),
) * 3
)
def test_init_auth_pass_no_key(
self, client, policy, logger, sleep):
connect = mock.Mock(
side_effect=[
paramiko.AuthenticationException,
mock.Mock()
])
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
key = gen_private_keys(1).pop()
ssh = SSHClient(
host=host,
auth=SSHAuth(
username=username,
password=password,
key=key
)
)
client.assert_called_once()
policy.assert_called_once()
logger.assert_has_calls((
mock.call.debug(
'Main key has been updated, public key is: \nNone'),
))
self.assertEqual(
ssh.auth,
SSHAuth(
username=username,
password=password,
keys=[key]
)
)
sftp = ssh._sftp
self.assertEqual(sftp, client().open_sftp())
self.assertEqual(ssh._ssh, client())
def test_init_auth_brute_impossible(
self, client, policy, logger, sleep):
connect = mock.Mock(side_effect=paramiko.AuthenticationException)
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
with self.assertRaises(paramiko.AuthenticationException):
SSHClient(
host=host,
username=username,
private_keys=gen_private_keys(2))
logger.assert_has_calls(
(
mock.call.debug(
'SSHClient('
'host={host}, port={port}, username={username}): '
'initialization by username/password/private_keys '
'is deprecated in favor of SSHAuth usage. '
'Please update your code'.format(
host=host, port=port, username=username
)),
) + (
mock.call.exception(
'Connection using stored authentication info failed!'),
) * 3
)
def test_init_no_sftp(
self, client, policy, logger, sleep):
open_sftp = mock.Mock(side_effect=paramiko.SSHException)
_ssh = mock.Mock()
_ssh.attach_mock(open_sftp, 'open_sftp')
client.return_value = _ssh
ssh = SSHClient(host=host, auth=SSHAuth(password=password))
with self.assertRaises(paramiko.SSHException):
# pylint: disable=pointless-statement
# noinspection PyStatementEffect
ssh._sftp
# pylint: enable=pointless-statement
logger.assert_has_calls((
mock.call.debug('SFTP is not connected, try to connect...'),
mock.call.warning(
'SFTP enable failed! SSH only is accessible.'),
))
def test_init_sftp_repair(
self, client, policy, logger, sleep):
_sftp = mock.Mock()
open_sftp = mock.Mock(
side_effect=[
paramiko.SSHException,
_sftp, _sftp])
_ssh = mock.Mock()
_ssh.attach_mock(open_sftp, 'open_sftp')
client.return_value = _ssh
ssh = SSHClient(host=host, auth=SSHAuth(password=password))
with self.assertRaises(paramiko.SSHException):
# pylint: disable=pointless-statement
# noinspection PyStatementEffect
ssh._sftp
# pylint: enable=pointless-statement
warning.assert_has_calls([
mock.call('SFTP enable failed! SSH only is accessible.'),
mock.call('SFTP is not connected, try to reconnect'),
mock.call('SFTP enable failed! SSH only is accessible.')])
logger.reset_mock()
# Unblock sftp connection
# (reset_mock is not possible to use in this case)
_sftp = mock.Mock()
open_sftp = mock.Mock(parent=_ssh, return_value=_sftp)
_ssh.attach_mock(open_sftp, 'open_sftp')
sftp = ssh._sftp
self.assertEqual(sftp, _sftp)
self.assertEqual(sftp, open_sftp())
logger.assert_has_calls((
mock.call.debug('SFTP is not connected, try to connect...'),
))
@mock.patch('devops.helpers.ssh_client.logger', autospec=True)
@ -231,9 +773,10 @@ class TestExecute(TestCase):
return SSHClient(
host=host,
port=port,
username=username,
password=password
)
auth=SSHAuth(
username=username,
password=password
))
def test_execute_async(self, client, policy, logger):
chan = mock.Mock()
@ -331,8 +874,9 @@ class TestExecute(TestCase):
logger.mock_calls
)
@mock.patch('devops.helpers.ssh_client.SSHAuth.enter_password')
def test_execute_async_sudo_password(
self, client, policy, logger):
self, enter_password, client, policy, logger):
stdin = mock.Mock(name='stdin')
stdout = mock.Mock(name='stdout')
stdout_channel = mock.Mock()
@ -356,6 +900,7 @@ class TestExecute(TestCase):
get_transport.assert_called_once()
open_session.assert_called_once()
# raise ValueError(closed.mock_calls)
enter_password.assert_called_once_with(stdin)
stdin.assert_has_calls((mock.call.flush(), ))
self.assertIn(chan, result)
@ -373,8 +918,7 @@ class TestExecute(TestCase):
logger.mock_calls
)
@staticmethod
def get_patched_execute_async_retval(ec=0, stderr_val=True):
def get_patched_execute_async_retval(self, ec=0, stderr_val=True):
stderr = mock.Mock()
stdout = mock.Mock()
@ -472,9 +1016,10 @@ class TestExecute(TestCase):
ssh2 = SSHClient(
host=host2,
port=port,
username=username,
password=password
)
auth=SSHAuth(
username=username,
password=password
))
remotes = [ssh, ssh2]
@ -623,9 +1168,10 @@ class TestExecuteThrowHost(TestCase):
ssh = SSHClient(
host=host,
port=port,
username=username,
password=password
)
auth=SSHAuth(
username=username,
password=password
))
result = ssh.execute_through_host(target, command)
self.assertEqual(result, return_value)
@ -668,13 +1214,14 @@ class TestExecuteThrowHost(TestCase):
ssh = SSHClient(
host=host,
port=port,
username=username,
password=password
)
auth=SSHAuth(
username=username,
password=password
))
result = ssh.execute_through_host(
target, command,
username=_login, password=_password)
auth=SSHAuth(username=_login, password=_password))
self.assertEqual(result, return_value)
get_transport.assert_called_once()
open_channel.assert_called_once()
@ -709,9 +1256,10 @@ class TestSftp(TestCase):
ssh = SSHClient(
host=host,
port=port,
username=username,
password=password
)
auth=SSHAuth(
username=username,
password=password
))
return ssh, _sftp
def test_exists(self, client, policy, logger):
@ -801,9 +1349,10 @@ class TestSftp(TestCase):
ssh = SSHClient(
host=host,
port=port,
username=username,
password=password
)
auth=SSHAuth(
username=username,
password=password
))
# Path not exists
ssh.mkdir(path)
@ -825,9 +1374,10 @@ class TestSftp(TestCase):
ssh = SSHClient(
host=host,
port=port,
username=username,
password=password
)
auth=SSHAuth(
username=username,
password=password
))
# Path not exists
ssh.rm_rf(path)