Merge "Don't gather host keys for non ssh connections"
This commit is contained in:
commit
07327ab296
|
@ -196,18 +196,23 @@ class NodeLauncher(threading.Thread, stats.StatsReporter):
|
|||
self._node.interface_ip, self._node.public_ipv4,
|
||||
self._node.public_ipv6))
|
||||
|
||||
# Get the SSH public keys for the new node and record in ZooKeeper
|
||||
# wait and scan the new node and record in ZooKeeper
|
||||
host_keys = []
|
||||
if self._pool.host_key_checking:
|
||||
try:
|
||||
self.log.debug(
|
||||
"Gathering host keys for node %s", self._node.id)
|
||||
host_keys = utils.keyscan(
|
||||
interface_ip, timeout=self._provider_config.boot_timeout)
|
||||
if not host_keys:
|
||||
# only gather host keys if the connection type is ssh
|
||||
gather_host_keys = connection_type == 'ssh'
|
||||
host_keys = utils.nodescan(
|
||||
interface_ip,
|
||||
timeout=self._provider_config.boot_timeout,
|
||||
gather_hostkeys=gather_host_keys)
|
||||
|
||||
if gather_host_keys and not host_keys:
|
||||
raise exceptions.LaunchKeyscanException(
|
||||
"Unable to gather host keys")
|
||||
except exceptions.SSHTimeoutException:
|
||||
except exceptions.ConnectionTimeoutException:
|
||||
self.logConsole(self._node.external_id, self._node.hostname)
|
||||
raise
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ import logging
|
|||
|
||||
from nodepool import exceptions
|
||||
from nodepool.driver import Provider
|
||||
from nodepool.nodeutils import keyscan
|
||||
from nodepool.nodeutils import nodescan
|
||||
|
||||
|
||||
class StaticNodeError(Exception):
|
||||
|
@ -36,11 +36,12 @@ class StaticNodeProvider(Provider):
|
|||
def checkHost(self, node):
|
||||
# Check node is reachable
|
||||
try:
|
||||
keys = keyscan(node["name"],
|
||||
port=node["ssh-port"],
|
||||
timeout=node["timeout"])
|
||||
except exceptions.SSHTimeoutException:
|
||||
raise StaticNodeError("%s: SSHTimeoutException" % node["name"])
|
||||
keys = nodescan(node["name"],
|
||||
port=node["ssh-port"],
|
||||
timeout=node["timeout"])
|
||||
except exceptions.ConnectionTimeoutException:
|
||||
raise StaticNodeError(
|
||||
"%s: ConnectionTimeoutException" % node["name"])
|
||||
|
||||
# Check node host-key
|
||||
if set(node["host-key"]).issubset(set(keys)):
|
||||
|
|
|
@ -49,7 +49,7 @@ class TimeoutException(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class SSHTimeoutException(TimeoutException):
|
||||
class ConnectionTimeoutException(TimeoutException):
|
||||
statsd_key = 'error.ssh'
|
||||
|
||||
|
||||
|
|
|
@ -57,14 +57,17 @@ def set_node_ip(node):
|
|||
"Unable to find public IP of server")
|
||||
|
||||
|
||||
def keyscan(ip, port=22, timeout=60):
|
||||
def nodescan(ip, port=22, timeout=60, gather_hostkeys=True):
|
||||
'''
|
||||
Scan the IP address for public SSH keys.
|
||||
|
||||
Keys are returned formatted as: "<type> <base64_string>"
|
||||
'''
|
||||
if 'fake' in ip:
|
||||
return ['ssh-rsa FAKEKEY']
|
||||
if gather_hostkeys:
|
||||
return ['ssh-rsa FAKEKEY']
|
||||
else:
|
||||
return []
|
||||
|
||||
addrinfo = socket.getaddrinfo(ip, port)[0]
|
||||
family = addrinfo[0]
|
||||
|
@ -73,16 +76,18 @@ def keyscan(ip, port=22, timeout=60):
|
|||
keys = []
|
||||
key = None
|
||||
for count in iterate_timeout(
|
||||
timeout, exceptions.SSHTimeoutException, "ssh access"):
|
||||
timeout, exceptions.ConnectionTimeoutException,
|
||||
"connection on port %s" % port):
|
||||
sock = None
|
||||
t = None
|
||||
try:
|
||||
sock = socket.socket(family, socket.SOCK_STREAM)
|
||||
sock.settimeout(timeout)
|
||||
sock.connect(sockaddr)
|
||||
t = paramiko.transport.Transport(sock)
|
||||
t.start_client(timeout=timeout)
|
||||
key = t.get_remote_server_key()
|
||||
if gather_hostkeys:
|
||||
t = paramiko.transport.Transport(sock)
|
||||
t.start_client(timeout=timeout)
|
||||
key = t.get_remote_server_key()
|
||||
break
|
||||
except socket.error as e:
|
||||
if e.errno not in [errno.ECONNREFUSED, errno.EHOSTUNREACH, None]:
|
||||
|
|
|
@ -952,6 +952,7 @@ class TestLauncher(tests.DBTestCase):
|
|||
self.assertEqual(len(nodes), 1)
|
||||
self.assertEqual('zuul', nodes[0].username)
|
||||
self.assertEqual('winrm', nodes[0].connection_type)
|
||||
self.assertEqual(nodes[0].host_keys, [])
|
||||
|
||||
def test_unmanaged_image_provider_name(self):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue