From 1f00b5375076f0bf20888d3342ad459597b20f62 Mon Sep 17 00:00:00 2001 From: Ramana Raja Date: Tue, 20 Jan 2015 12:51:57 +0530 Subject: [PATCH] utils: Allow discovery of private key in ~/.ssh A SSHPool class object, used to hold ssh connections, fails to authenticate to the server in its create () method if the path to SSH private key or user password is not passed during object creation. The create () method does not allow private key to be discovered in the default ~/.ssh folder when trying to authenticate or connect to the SSH server. Instead allow auto discovery of the key in the paramiko SSHClient's connect method (called in create () method) if the path to private key or password is not provided. Also make the following minor cleanups in the create() method of SSHPool class: - pass the path to the key file input directly to the SSHClient's connect () method using the appropriate parameter, keyfile, instead of picking key from key file and passing that as 'pkey' parameter. - restrict the try except block to only the steps involved in connecting to the SSH server. - reraise the exception that would be raised by the Paramiko library instead of raising own exception. initialization method of SSHPool class: - rename the attribute privatekey as path_to_private_key to accurately reflect what it refers to, the path of the private key file. Change-Id: I590702d97086d33245894fd686250e75e8e359f2 Closes-Bug: #1412782 --- manila/exception.py | 4 ++++ manila/tests/test_utils.py | 44 ++++++++++++++++++++++++++------------ manila/utils.py | 28 +++++++++++++----------- 3 files changed, 49 insertions(+), 27 deletions(-) diff --git a/manila/exception.py b/manila/exception.py index 0a3b1748cf..8cda8ea5b2 100644 --- a/manila/exception.py +++ b/manila/exception.py @@ -477,3 +477,7 @@ class GaneshaCommandFailure(ProcessExecutionError): class InvalidSqliteDB(Invalid): message = _("Invalid Sqlite database.") + + +class SSHException(ManilaException): + message = _("Exception in SSH protocol negotiation or logic.") diff --git a/manila/tests/test_utils.py b/manila/tests/test_utils.py index 47a9186586..b5fd800d6d 100644 --- a/manila/tests/test_utils.py +++ b/manila/tests/test_utils.py @@ -415,7 +415,7 @@ class FakeSSHClient(object): pass def connect(self, ip, port=22, username=None, password=None, - pkey=None, timeout=10): + key_filename=None, look_for_keys=None, timeout=10): pass def get_transport(self): @@ -473,28 +473,44 @@ class SSHPoolTestCase(test.TestCase): fake_ssh_client.connect.assert_called_once_with( "127.0.0.1", port=22, username="test", - password="test", pkey=None, timeout=10) + password="test", key_filename=None, look_for_keys=False, + timeout=10) def test_create_ssh_with_key(self): - key = os.path.expanduser("fake_key") + path_to_private_key = "/fakepath/to/privatekey" fake_ssh_client = mock.Mock() ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test", - privatekey="fake_key") + privatekey="/fakepath/to/privatekey") with mock.patch.object(paramiko, "SSHClient", return_value=fake_ssh_client): - with mock.patch.object(paramiko.RSAKey, "from_private_key_file", - return_value=key) as from_private_key_mock: - - ssh_pool.create() - from_private_key_mock.assert_called_once_with(key) - fake_ssh_client.connect.assert_called_once_with( - "127.0.0.1", port=22, username="test", - password=None, pkey=key, timeout=10) + ssh_pool.create() + fake_ssh_client.connect.assert_called_once_with( + "127.0.0.1", port=22, username="test", password=None, + key_filename=path_to_private_key, look_for_keys=False, + timeout=10) def test_create_ssh_with_nothing(self): + fake_ssh_client = mock.Mock() ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test") - with mock.patch.object(paramiko, "SSHClient"): - self.assertRaises(paramiko.SSHException, ssh_pool.create) + with mock.patch.object(paramiko, "SSHClient", + return_value=fake_ssh_client): + ssh_pool.create() + fake_ssh_client.connect.assert_called_once_with( + "127.0.0.1", port=22, username="test", password=None, + key_filename=None, look_for_keys=True, + timeout=10) + + def test_create_ssh_error_connecting(self): + attrs = {'connect.side_effect': paramiko.SSHException, } + fake_ssh_client = mock.Mock(**attrs) + ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test") + with mock.patch.object(paramiko, "SSHClient", + return_value=fake_ssh_client): + self.assertRaises(exception.SSHException, ssh_pool.create) + fake_ssh_client.connect.assert_called_once_with( + "127.0.0.1", port=22, username="test", password=None, + key_filename=None, look_for_keys=True, + timeout=10) def test_closed_reopend_ssh_connections(self): with mock.patch.object(paramiko, "SSHClient", diff --git a/manila/utils.py b/manila/utils.py index 88cbd7b4ff..117da0b0b6 100644 --- a/manila/utils.py +++ b/manila/utils.py @@ -80,26 +80,27 @@ class SSHPool(pools.Pool): self.login = login self.password = password self.conn_timeout = conn_timeout if conn_timeout else None - self.privatekey = privatekey + self.path_to_private_key = privatekey super(SSHPool, self).__init__(*args, **kwargs) def create(self): + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + look_for_keys = True + if self.path_to_private_key: + self.path_to_private_key = os.path.expanduser( + self.path_to_private_key) + look_for_keys = False + elif self.password: + look_for_keys = False try: - ssh = paramiko.SSHClient() - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - if self.privatekey: - pkfile = os.path.expanduser(self.privatekey) - self.privatekey = paramiko.RSAKey.from_private_key_file(pkfile) - elif not self.password: - msg = _("Specify a password or private_key") - raise exception.ManilaException(msg) ssh.connect(self.ip, port=self.port, username=self.login, password=self.password, - pkey=self.privatekey, + key_filename=self.path_to_private_key, + look_for_keys=look_for_keys, timeout=self.conn_timeout) - # Paramiko by default sets the socket timeout to 0.1 seconds, # ignoring what we set thru the sshclient. This doesn't help for # keeping long lived connections. Hence we have to bypass it, by @@ -113,9 +114,10 @@ class SSHPool(pools.Pool): transport.set_keepalive(self.conn_timeout) return ssh except Exception as e: - msg = _("Error connecting via ssh: %s") % e + msg = _("Check whether private key or password are correctly " + "set. Error connecting via ssh: %s") % e LOG.error(msg) - raise paramiko.SSHException(msg) + raise exception.SSHException(msg) def get(self): """Return an item from the pool, when one is available.