utils: Allow discovery of private key in ~/.ssh

A SSHPool class object, used to hold ssh connections, fails to
authenticate to the server in its create () method if the path to
SSH private key or user password is not passed during object creation.
The create () method does not allow private key to be discovered in
the default ~/.ssh folder when trying to authenticate or connect
to the SSH server. Instead allow auto discovery of the key in the
paramiko SSHClient's connect method (called in create () method) if the
path to private key or password is not provided.

Also make the following minor cleanups in the

create() method of SSHPool class:
- pass the path to the key file input directly to the SSHClient's
  connect () method using the appropriate parameter, keyfile, instead
  of picking key from key file and passing that as 'pkey' parameter.
- restrict the try except block to only the steps involved in
  connecting to the SSH server.
- reraise the exception that would be raised by the Paramiko library
  instead of raising own exception.

initialization method of SSHPool class:
- rename the attribute privatekey as path_to_private_key to accurately
  reflect what it refers to, the path of the private key file.

Change-Id: I590702d97086d33245894fd686250e75e8e359f2
Closes-Bug: #1412782
This commit is contained in:
Ramana Raja 2015-01-20 12:51:57 +05:30
parent b5edd4f9c9
commit 1f00b53750
3 changed files with 49 additions and 27 deletions

View File

@ -477,3 +477,7 @@ class GaneshaCommandFailure(ProcessExecutionError):
class InvalidSqliteDB(Invalid):
message = _("Invalid Sqlite database.")
class SSHException(ManilaException):
message = _("Exception in SSH protocol negotiation or logic.")

View File

@ -415,7 +415,7 @@ class FakeSSHClient(object):
pass
def connect(self, ip, port=22, username=None, password=None,
pkey=None, timeout=10):
key_filename=None, look_for_keys=None, timeout=10):
pass
def get_transport(self):
@ -473,28 +473,44 @@ class SSHPoolTestCase(test.TestCase):
fake_ssh_client.connect.assert_called_once_with(
"127.0.0.1", port=22, username="test",
password="test", pkey=None, timeout=10)
password="test", key_filename=None, look_for_keys=False,
timeout=10)
def test_create_ssh_with_key(self):
key = os.path.expanduser("fake_key")
path_to_private_key = "/fakepath/to/privatekey"
fake_ssh_client = mock.Mock()
ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test",
privatekey="fake_key")
privatekey="/fakepath/to/privatekey")
with mock.patch.object(paramiko, "SSHClient",
return_value=fake_ssh_client):
with mock.patch.object(paramiko.RSAKey, "from_private_key_file",
return_value=key) as from_private_key_mock:
ssh_pool.create()
from_private_key_mock.assert_called_once_with(key)
fake_ssh_client.connect.assert_called_once_with(
"127.0.0.1", port=22, username="test",
password=None, pkey=key, timeout=10)
ssh_pool.create()
fake_ssh_client.connect.assert_called_once_with(
"127.0.0.1", port=22, username="test", password=None,
key_filename=path_to_private_key, look_for_keys=False,
timeout=10)
def test_create_ssh_with_nothing(self):
fake_ssh_client = mock.Mock()
ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test")
with mock.patch.object(paramiko, "SSHClient"):
self.assertRaises(paramiko.SSHException, ssh_pool.create)
with mock.patch.object(paramiko, "SSHClient",
return_value=fake_ssh_client):
ssh_pool.create()
fake_ssh_client.connect.assert_called_once_with(
"127.0.0.1", port=22, username="test", password=None,
key_filename=None, look_for_keys=True,
timeout=10)
def test_create_ssh_error_connecting(self):
attrs = {'connect.side_effect': paramiko.SSHException, }
fake_ssh_client = mock.Mock(**attrs)
ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test")
with mock.patch.object(paramiko, "SSHClient",
return_value=fake_ssh_client):
self.assertRaises(exception.SSHException, ssh_pool.create)
fake_ssh_client.connect.assert_called_once_with(
"127.0.0.1", port=22, username="test", password=None,
key_filename=None, look_for_keys=True,
timeout=10)
def test_closed_reopend_ssh_connections(self):
with mock.patch.object(paramiko, "SSHClient",

View File

@ -80,26 +80,27 @@ class SSHPool(pools.Pool):
self.login = login
self.password = password
self.conn_timeout = conn_timeout if conn_timeout else None
self.privatekey = privatekey
self.path_to_private_key = privatekey
super(SSHPool, self).__init__(*args, **kwargs)
def create(self):
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
look_for_keys = True
if self.path_to_private_key:
self.path_to_private_key = os.path.expanduser(
self.path_to_private_key)
look_for_keys = False
elif self.password:
look_for_keys = False
try:
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
if self.privatekey:
pkfile = os.path.expanduser(self.privatekey)
self.privatekey = paramiko.RSAKey.from_private_key_file(pkfile)
elif not self.password:
msg = _("Specify a password or private_key")
raise exception.ManilaException(msg)
ssh.connect(self.ip,
port=self.port,
username=self.login,
password=self.password,
pkey=self.privatekey,
key_filename=self.path_to_private_key,
look_for_keys=look_for_keys,
timeout=self.conn_timeout)
# Paramiko by default sets the socket timeout to 0.1 seconds,
# ignoring what we set thru the sshclient. This doesn't help for
# keeping long lived connections. Hence we have to bypass it, by
@ -113,9 +114,10 @@ class SSHPool(pools.Pool):
transport.set_keepalive(self.conn_timeout)
return ssh
except Exception as e:
msg = _("Error connecting via ssh: %s") % e
msg = _("Check whether private key or password are correctly "
"set. Error connecting via ssh: %s") % e
LOG.error(msg)
raise paramiko.SSHException(msg)
raise exception.SSHException(msg)
def get(self):
"""Return an item from the pool, when one is available.