Don't gather host keys for non ssh connections
In case of an image with the connection type winrm we cannot scan the ssh host keys. So in case the connection type is not ssh we need to skip gathering the host keys. Change-Id: I56f308baa10d40461cf4a919bbcdc4467e85a551
This commit is contained in:
parent
ee78684521
commit
2da274e2ae
@ -194,18 +194,23 @@ class NodeLauncher(threading.Thread, stats.StatsReporter):
|
||||
self._node.interface_ip, self._node.public_ipv4,
|
||||
self._node.public_ipv6))
|
||||
|
||||
# Get the SSH public keys for the new node and record in ZooKeeper
|
||||
# wait and scan the new node and record in ZooKeeper
|
||||
host_keys = []
|
||||
if self._pool.host_key_checking:
|
||||
try:
|
||||
self.log.debug(
|
||||
"Gathering host keys for node %s", self._node.id)
|
||||
host_keys = utils.keyscan(
|
||||
interface_ip, timeout=self._provider_config.boot_timeout)
|
||||
if not host_keys:
|
||||
# only gather host keys if the connection type is ssh
|
||||
gather_host_keys = connection_type == 'ssh'
|
||||
host_keys = utils.nodescan(
|
||||
interface_ip,
|
||||
timeout=self._provider_config.boot_timeout,
|
||||
gather_hostkeys=gather_host_keys)
|
||||
|
||||
if gather_host_keys and not host_keys:
|
||||
raise exceptions.LaunchKeyscanException(
|
||||
"Unable to gather host keys")
|
||||
except exceptions.SSHTimeoutException:
|
||||
except exceptions.ConnectionTimeoutException:
|
||||
self.logConsole(self._node.external_id, self._node.hostname)
|
||||
raise
|
||||
|
||||
|
@ -16,7 +16,7 @@ import logging
|
||||
|
||||
from nodepool import exceptions
|
||||
from nodepool.driver import Provider
|
||||
from nodepool.nodeutils import keyscan
|
||||
from nodepool.nodeutils import nodescan
|
||||
|
||||
|
||||
class StaticNodeError(Exception):
|
||||
@ -36,11 +36,12 @@ class StaticNodeProvider(Provider):
|
||||
def checkHost(self, node):
|
||||
# Check node is reachable
|
||||
try:
|
||||
keys = keyscan(node["name"],
|
||||
port=node["ssh-port"],
|
||||
timeout=node["timeout"])
|
||||
except exceptions.SSHTimeoutException:
|
||||
raise StaticNodeError("%s: SSHTimeoutException" % node["name"])
|
||||
keys = nodescan(node["name"],
|
||||
port=node["ssh-port"],
|
||||
timeout=node["timeout"])
|
||||
except exceptions.ConnectionTimeoutException:
|
||||
raise StaticNodeError(
|
||||
"%s: ConnectionTimeoutException" % node["name"])
|
||||
|
||||
# Check node host-key
|
||||
if set(node["host-key"]).issubset(set(keys)):
|
||||
|
@ -49,7 +49,7 @@ class TimeoutException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SSHTimeoutException(TimeoutException):
|
||||
class ConnectionTimeoutException(TimeoutException):
|
||||
statsd_key = 'error.ssh'
|
||||
|
||||
|
||||
|
@ -57,14 +57,17 @@ def set_node_ip(node):
|
||||
"Unable to find public IP of server")
|
||||
|
||||
|
||||
def keyscan(ip, port=22, timeout=60):
|
||||
def nodescan(ip, port=22, timeout=60, gather_hostkeys=True):
|
||||
'''
|
||||
Scan the IP address for public SSH keys.
|
||||
|
||||
Keys are returned formatted as: "<type> <base64_string>"
|
||||
'''
|
||||
if 'fake' in ip:
|
||||
return ['ssh-rsa FAKEKEY']
|
||||
if gather_hostkeys:
|
||||
return ['ssh-rsa FAKEKEY']
|
||||
else:
|
||||
return []
|
||||
|
||||
addrinfo = socket.getaddrinfo(ip, port)[0]
|
||||
family = addrinfo[0]
|
||||
@ -73,16 +76,18 @@ def keyscan(ip, port=22, timeout=60):
|
||||
keys = []
|
||||
key = None
|
||||
for count in iterate_timeout(
|
||||
timeout, exceptions.SSHTimeoutException, "ssh access"):
|
||||
timeout, exceptions.ConnectionTimeoutException,
|
||||
"connection on port %s" % port):
|
||||
sock = None
|
||||
t = None
|
||||
try:
|
||||
sock = socket.socket(family, socket.SOCK_STREAM)
|
||||
sock.settimeout(timeout)
|
||||
sock.connect(sockaddr)
|
||||
t = paramiko.transport.Transport(sock)
|
||||
t.start_client(timeout=timeout)
|
||||
key = t.get_remote_server_key()
|
||||
if gather_hostkeys:
|
||||
t = paramiko.transport.Transport(sock)
|
||||
t.start_client(timeout=timeout)
|
||||
key = t.get_remote_server_key()
|
||||
break
|
||||
except socket.error as e:
|
||||
if e.errno not in [errno.ECONNREFUSED, errno.EHOSTUNREACH, None]:
|
||||
|
@ -949,6 +949,7 @@ class TestLauncher(tests.DBTestCase):
|
||||
self.assertEqual(len(nodes), 1)
|
||||
self.assertEqual('zuul', nodes[0].username)
|
||||
self.assertEqual('winrm', nodes[0].connection_type)
|
||||
self.assertEqual(nodes[0].host_keys, [])
|
||||
|
||||
def test_unmanaged_image_provider_name(self):
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user