Merge "Don't gather host keys for non ssh connections"

This commit is contained in:
Zuul 2018-04-13 15:20:56 +00:00 committed by Gerrit Code Review
commit 07327ab296
5 changed files with 30 additions and 18 deletions

View File

@ -196,18 +196,23 @@ class NodeLauncher(threading.Thread, stats.StatsReporter):
self._node.interface_ip, self._node.public_ipv4,
self._node.public_ipv6))
# Get the SSH public keys for the new node and record in ZooKeeper
# wait and scan the new node and record in ZooKeeper
host_keys = []
if self._pool.host_key_checking:
try:
self.log.debug(
"Gathering host keys for node %s", self._node.id)
host_keys = utils.keyscan(
interface_ip, timeout=self._provider_config.boot_timeout)
if not host_keys:
# only gather host keys if the connection type is ssh
gather_host_keys = connection_type == 'ssh'
host_keys = utils.nodescan(
interface_ip,
timeout=self._provider_config.boot_timeout,
gather_hostkeys=gather_host_keys)
if gather_host_keys and not host_keys:
raise exceptions.LaunchKeyscanException(
"Unable to gather host keys")
except exceptions.SSHTimeoutException:
except exceptions.ConnectionTimeoutException:
self.logConsole(self._node.external_id, self._node.hostname)
raise

View File

@ -16,7 +16,7 @@ import logging
from nodepool import exceptions
from nodepool.driver import Provider
from nodepool.nodeutils import keyscan
from nodepool.nodeutils import nodescan
class StaticNodeError(Exception):
@ -36,11 +36,12 @@ class StaticNodeProvider(Provider):
def checkHost(self, node):
# Check node is reachable
try:
keys = keyscan(node["name"],
port=node["ssh-port"],
timeout=node["timeout"])
except exceptions.SSHTimeoutException:
raise StaticNodeError("%s: SSHTimeoutException" % node["name"])
keys = nodescan(node["name"],
port=node["ssh-port"],
timeout=node["timeout"])
except exceptions.ConnectionTimeoutException:
raise StaticNodeError(
"%s: ConnectionTimeoutException" % node["name"])
# Check node host-key
if set(node["host-key"]).issubset(set(keys)):

View File

@ -49,7 +49,7 @@ class TimeoutException(Exception):
pass
class SSHTimeoutException(TimeoutException):
class ConnectionTimeoutException(TimeoutException):
statsd_key = 'error.ssh'

View File

@ -57,14 +57,17 @@ def set_node_ip(node):
"Unable to find public IP of server")
def keyscan(ip, port=22, timeout=60):
def nodescan(ip, port=22, timeout=60, gather_hostkeys=True):
'''
Scan the IP address for public SSH keys.
Keys are returned formatted as: "<type> <base64_string>"
'''
if 'fake' in ip:
return ['ssh-rsa FAKEKEY']
if gather_hostkeys:
return ['ssh-rsa FAKEKEY']
else:
return []
addrinfo = socket.getaddrinfo(ip, port)[0]
family = addrinfo[0]
@ -73,16 +76,18 @@ def keyscan(ip, port=22, timeout=60):
keys = []
key = None
for count in iterate_timeout(
timeout, exceptions.SSHTimeoutException, "ssh access"):
timeout, exceptions.ConnectionTimeoutException,
"connection on port %s" % port):
sock = None
t = None
try:
sock = socket.socket(family, socket.SOCK_STREAM)
sock.settimeout(timeout)
sock.connect(sockaddr)
t = paramiko.transport.Transport(sock)
t.start_client(timeout=timeout)
key = t.get_remote_server_key()
if gather_hostkeys:
t = paramiko.transport.Transport(sock)
t.start_client(timeout=timeout)
key = t.get_remote_server_key()
break
except socket.error as e:
if e.errno not in [errno.ECONNREFUSED, errno.EHOSTUNREACH, None]:

View File

@ -952,6 +952,7 @@ class TestLauncher(tests.DBTestCase):
self.assertEqual(len(nodes), 1)
self.assertEqual('zuul', nodes[0].username)
self.assertEqual('winrm', nodes[0].connection_type)
self.assertEqual(nodes[0].host_keys, [])
def test_unmanaged_image_provider_name(self):
"""