diff --git a/nodepool/driver/openstack/handler.py b/nodepool/driver/openstack/handler.py index 6550ed494..1c02ab46f 100644 --- a/nodepool/driver/openstack/handler.py +++ b/nodepool/driver/openstack/handler.py @@ -194,18 +194,23 @@ class NodeLauncher(threading.Thread, stats.StatsReporter): self._node.interface_ip, self._node.public_ipv4, self._node.public_ipv6)) - # Get the SSH public keys for the new node and record in ZooKeeper + # wait and scan the new node and record in ZooKeeper host_keys = [] if self._pool.host_key_checking: try: self.log.debug( "Gathering host keys for node %s", self._node.id) - host_keys = utils.keyscan( - interface_ip, timeout=self._provider_config.boot_timeout) - if not host_keys: + # only gather host keys if the connection type is ssh + gather_host_keys = connection_type == 'ssh' + host_keys = utils.nodescan( + interface_ip, + timeout=self._provider_config.boot_timeout, + gather_hostkeys=gather_host_keys) + + if gather_host_keys and not host_keys: raise exceptions.LaunchKeyscanException( "Unable to gather host keys") - except exceptions.SSHTimeoutException: + except exceptions.ConnectionTimeoutException: self.logConsole(self._node.external_id, self._node.hostname) raise diff --git a/nodepool/driver/static/provider.py b/nodepool/driver/static/provider.py index 212f20648..42da57d81 100644 --- a/nodepool/driver/static/provider.py +++ b/nodepool/driver/static/provider.py @@ -16,7 +16,7 @@ import logging from nodepool import exceptions from nodepool.driver import Provider -from nodepool.nodeutils import keyscan +from nodepool.nodeutils import nodescan class StaticNodeError(Exception): @@ -36,11 +36,12 @@ class StaticNodeProvider(Provider): def checkHost(self, node): # Check node is reachable try: - keys = keyscan(node["name"], - port=node["ssh-port"], - timeout=node["timeout"]) - except exceptions.SSHTimeoutException: - raise StaticNodeError("%s: SSHTimeoutException" % node["name"]) + keys = nodescan(node["name"], + port=node["ssh-port"], + timeout=node["timeout"]) + except exceptions.ConnectionTimeoutException: + raise StaticNodeError( + "%s: ConnectionTimeoutException" % node["name"]) # Check node host-key if set(node["host-key"]).issubset(set(keys)): diff --git a/nodepool/exceptions.py b/nodepool/exceptions.py index c754e4943..44cabfe59 100755 --- a/nodepool/exceptions.py +++ b/nodepool/exceptions.py @@ -49,7 +49,7 @@ class TimeoutException(Exception): pass -class SSHTimeoutException(TimeoutException): +class ConnectionTimeoutException(TimeoutException): statsd_key = 'error.ssh' diff --git a/nodepool/nodeutils.py b/nodepool/nodeutils.py index 3c6de886b..39bfb0b7b 100755 --- a/nodepool/nodeutils.py +++ b/nodepool/nodeutils.py @@ -57,14 +57,17 @@ def set_node_ip(node): "Unable to find public IP of server") -def keyscan(ip, port=22, timeout=60): +def nodescan(ip, port=22, timeout=60, gather_hostkeys=True): ''' Scan the IP address for public SSH keys. Keys are returned formatted as: " " ''' if 'fake' in ip: - return ['ssh-rsa FAKEKEY'] + if gather_hostkeys: + return ['ssh-rsa FAKEKEY'] + else: + return [] addrinfo = socket.getaddrinfo(ip, port)[0] family = addrinfo[0] @@ -73,16 +76,18 @@ def keyscan(ip, port=22, timeout=60): keys = [] key = None for count in iterate_timeout( - timeout, exceptions.SSHTimeoutException, "ssh access"): + timeout, exceptions.ConnectionTimeoutException, + "connection on port %s" % port): sock = None t = None try: sock = socket.socket(family, socket.SOCK_STREAM) sock.settimeout(timeout) sock.connect(sockaddr) - t = paramiko.transport.Transport(sock) - t.start_client(timeout=timeout) - key = t.get_remote_server_key() + if gather_hostkeys: + t = paramiko.transport.Transport(sock) + t.start_client(timeout=timeout) + key = t.get_remote_server_key() break except socket.error as e: if e.errno not in [errno.ECONNREFUSED, errno.EHOSTUNREACH, None]: diff --git a/nodepool/tests/test_launcher.py b/nodepool/tests/test_launcher.py index 0b4d67b09..56a881358 100644 --- a/nodepool/tests/test_launcher.py +++ b/nodepool/tests/test_launcher.py @@ -949,6 +949,7 @@ class TestLauncher(tests.DBTestCase): self.assertEqual(len(nodes), 1) self.assertEqual('zuul', nodes[0].username) self.assertEqual('winrm', nodes[0].connection_type) + self.assertEqual(nodes[0].host_keys, []) def test_unmanaged_image_provider_name(self): """