Convert GCE to state machine driver and remove simple

GCE is the only "simple" driver since it turns out every other cloud
is not simple enough.  The state machine driver interface is just
as simple (perhaps simpler) and can accommodate a variety of clouds.
This ports the GCE driver to the statemachine interface and drops
the simple interface entirely.

Change-Id: Icfc298d83750ca31503211f920900d9207488bc2
This commit is contained in:
James E. Blair 2022-08-16 16:10:59 -07:00
parent 6320b06950
commit adab6eeb0d
9 changed files with 402 additions and 1096 deletions

View File

@ -150,66 +150,10 @@ The launch procedure usually consists of the following operations:
- Once the resource is created, READY should be stored to the node.state.
Otherwise raise an exception to restart the launch attempt.
TaskManager
-----------
If you need to use a thread-unsafe client library, or you need to
manage rate limiting in your driver, you may want to use the
:py:class:`~nodepool.driver.taskmanager.TaskManager` class. Implement
any remote API calls as tasks and invoke them by submitting the tasks
to the TaskManager. It will run them sequentially from a single
thread, and assist in rate limiting.
The :py:class:`~nodepool.driver.taskmanager.BaseTaskManagerProvider`
class is a subclass of :py:class:`~nodepool.driver.Provider` which
starts and stops a TaskManager automatically. Inherit from it to
build a Provider as described above with a TaskManager.
.. autoclass:: nodepool.driver.taskmanager.Task
:members:
.. autoclass:: nodepool.driver.taskmanager.TaskManager
:members:
.. autoclass:: nodepool.driver.taskmanager.BaseTaskManagerProvider
Simple Drivers
--------------
If your system is simple enough, you may be able to use the
SimpleTaskManagerDriver class to implement support with just a few
methods. In order to use this class, your system must create and
delete instances as a unit (without requiring multiple resource
creation calls such as volumes or floating IPs).
.. note:: This system is still in development and lacks robust support
for quotas or image building.
To use this system, you will need to implement a few subclasses.
First, create a :ref:`provider_config` subclass as you would for any
driver. Then, subclass
:py:class:`~nodepool.driver.simple.SimpleTaskManagerInstance` to map
remote instance data into a format the simple driver can understand.
Next, subclass
:py:class:`~nodepool.driver.simple.SimpleTaskManagerAdapter` to
implement the main API methods of your provider. Finally, subclass
:py:class:`~nodepool.driver.simple.SimpleTaskManagerDriver` to tie them
all together.
See the ``gce`` provider for an example.
.. autoclass:: nodepool.driver.simple.SimpleTaskManagerInstance
:members:
.. autoclass:: nodepool.driver.simple.SimpleTaskManagerAdapter
:members:
.. autoclass:: nodepool.driver.simple.SimpleTaskManagerDriver
:members:
State Machine Drivers
---------------------
.. note:: This system is still in development and lacks robust support
for quotas or image building.
To use this system, you will need to implement a few subclasses.
First, create a :ref:`provider_config` subclass as you would for any
driver.

View File

@ -1,4 +1,5 @@
# Copyright 2019 Red Hat
# Copyright 2022 Acme Gating, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,14 +15,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from nodepool.driver.simple import SimpleTaskManagerDriver
from nodepool.driver.gce.config import GCEProviderConfig
from nodepool.driver.gce.adapter import GCEAdapter
from nodepool.driver.statemachine import StateMachineDriver
from nodepool.driver.gce.config import GceProviderConfig
from nodepool.driver.gce.adapter import GceAdapter
class GCEDriver(SimpleTaskManagerDriver):
class GceDriver(StateMachineDriver):
def getProviderConfig(self, provider):
return GCEProviderConfig(self, provider)
return GceProviderConfig(self, provider)
def getAdapter(self, provider_config):
return GCEAdapter(provider_config)
return GceAdapter(provider_config)

View File

@ -1,4 +1,5 @@
# Copyright 2019 Red Hat
# Copyright 2022 Acme Gating, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@ -12,22 +13,35 @@
# License for the specific language governing permissions and limitations
# under the License.
import cachetools.func
import logging
import math
from nodepool.driver.simple import SimpleTaskManagerAdapter
from nodepool.driver.simple import SimpleTaskManagerInstance
from nodepool.driver.utils import QuotaInformation
from nodepool.driver import statemachine
from nodepool.driver.utils import QuotaInformation, RateLimiter
import googleapiclient.discovery
class GCEInstance(SimpleTaskManagerInstance):
def load(self, data):
if data['status'] == 'TERMINATED':
self.deleted = True
elif data['status'] == 'RUNNING':
self.ready = True
CACHE_TTL = 10
def gce_metadata_to_dict(metadata):
    """Flatten a GCE metadata structure into a plain dict.

    :param metadata: Mapping with an optional ``items`` list whose
        entries each carry a ``key`` and a ``value``, or None.
    :returns: A dict mapping every entry's key to its value; empty when
        *metadata* is None or has no ``items``.
    """
    flat = {}
    if metadata is not None:
        for entry in metadata.get('items', []):
            flat[entry['key']] = entry['value']
    return flat
def dict_to_gce_metadata(metadata):
    """Convert a plain dict into the GCE ``items`` metadata structure.

    :param metadata: Mapping of metadata names to values.
    :returns: A dict with a single ``items`` list holding one
        ``{'key': ..., 'value': ...}`` entry per input pair, in the
        mapping's iteration order.
    """
    return {
        'items': [
            {'key': name, 'value': value}
            for name, value in metadata.items()
        ],
    }
class GceInstance(statemachine.Instance):
def __init__(self, data, quota):
super().__init__()
self.external_id = data['name']
self.az = data['zone']
@ -38,115 +52,136 @@ class GCEInstance(SimpleTaskManagerInstance):
if len(access):
self.public_ipv4 = access[0].get('natIP')
self.interface_ip = self.public_ipv4 or self.private_ipv4
self._machine_type = data.get('_nodepool_gce_machine_type')
if data.get('metadata'):
for item in data['metadata'].get('items', []):
self.metadata[item['key']] = item['value']
self.metadata = gce_metadata_to_dict(data.get('metadata'))
self.quota = quota
def getQuotaInformation(self):
return QuotaInformation(
cores=self._machine_type['guestCpus'],
instances=1,
ram=self._machine_type['memoryMb'])
return self.quota
class GCEAdapter(SimpleTaskManagerAdapter):
log = logging.getLogger("nodepool.driver.gce.GCEAdapter")
class GceResource(statemachine.Resource):
    """A cloud resource record carrying its kind and identifier.

    :param metadata: Metadata dict handed to ``statemachine.Resource``.
    :param type: Kind of resource as a string.
    :param id: Identifier of the resource.
    """

    def __init__(self, metadata, type, id):
        super().__init__(metadata)
        self.type, self.id = type, id
def __init__(self, provider):
self.provider = provider
class GceDeleteStateMachine(statemachine.StateMachine):
    """State machine that deletes one instance and waits until it is gone.

    States run START -> INSTANCE_DELETING -> COMPLETE.  A single call to
    :py:meth:`advance` falls through as many states as it can.
    """
    INSTANCE_DELETING = 'deleting instance'
    COMPLETE = 'complete'

    def __init__(self, adapter, external_id, log):
        """
        :param adapter: Adapter providing ``_deleteInstance`` and
            ``_getInstance``.
        :param external_id: Name of the instance to delete.
        :param log: Logger stored on the state machine.
        """
        # The logger is assigned before the base initializer runs.
        self.log = log
        super().__init__()
        self.adapter = adapter
        self.external_id = external_id

    def advance(self):
        """Issue the delete request, then poll until the instance is gone."""
        if self.state == self.START:
            self.adapter._deleteInstance(self.external_id)
            self.state = self.INSTANCE_DELETING

        if self.state == self.INSTANCE_DELETING:
            instance = self.adapter._getInstance(self.external_id)
            # Either no longer listed or reported TERMINATED counts as done.
            gone = instance is None or instance['status'] == 'TERMINATED'
            if gone:
                self.state = self.COMPLETE

        if self.state == self.COMPLETE:
            self.complete = True
class GceCreateStateMachine(statemachine.StateMachine):
    """State machine that creates one instance and waits for it to run.

    States run START -> INSTANCE_CREATING -> COMPLETE.  If the instance
    is reported TERMINATED while being created, the machine goes back to
    START, up to ``retries`` times.  ``INSTANCE_RETRY`` is declared but
    not used by :py:meth:`advance`.
    """
    INSTANCE_CREATING = 'creating instance'
    INSTANCE_RETRY = 'retrying instance creation'
    COMPLETE = 'complete'

    def __init__(self, adapter, hostname, label, image_external_id,
                 metadata, retries, request, log):
        """
        :param adapter: Adapter providing ``_createInstance``,
            ``_getInstance`` and ``_getQuotaForMachineType``.
        :param hostname: Name given to the new instance.
        :param label: Label configuration passed to ``_createInstance``.
        :param image_external_id: Stored on the state machine.
        :param metadata: Metadata dict passed to ``_createInstance``.
        :param retries: How many times a TERMINATED instance is retried.
        :param request: Accepted but not stored by this class.
        :param log: Logger stored on the state machine.
        """
        # The logger is assigned before the base initializer runs.
        self.log = log
        super().__init__()
        self.adapter = adapter
        self.hostname = hostname
        self.label = label
        self.image_external_id = image_external_id
        self.metadata = metadata
        self.retries = retries
        self.attempts = 0
        # Filled in by advance() as the instance comes up.
        self.instance = None
        self.quota = None

    def advance(self):
        """Move the creation forward as far as currently possible.

        :returns: A ``GceInstance`` once the instance is RUNNING,
            otherwise None.
        :raises Exception: When the instance is TERMINATED and the retry
            budget is exhausted.
        """
        if self.state == self.START:
            self.external_id = self.hostname
            self.adapter._createInstance(
                self.hostname, self.metadata, self.label)
            self.state = self.INSTANCE_CREATING

        if self.state == self.INSTANCE_CREATING:
            record = self.adapter._getInstance(self.hostname)
            if record is None:
                # Not listed yet; poll again on the next call.
                return
            if self.quota is None:
                # machineType is a URL-like path; the last segment is
                # the machine type name.
                machine_type = record['machineType'].split('/')[-1]
                self.quota = self.adapter._getQuotaForMachineType(
                    machine_type)
            status = record['status']
            if status == 'TERMINATED':
                if self.attempts >= self.retries:
                    raise Exception("Too many retries")
                self.attempts += 1
                # NOTE(review): the retry re-issues the create call with
                # the same hostname -- confirm the terminated instance
                # does not block it.
                self.state = self.START
                return
            if status != 'RUNNING':
                return
            self.instance = record
            self.state = self.COMPLETE

        if self.state == self.COMPLETE:
            self.complete = True
            return GceInstance(self.instance, self.quota)
class GceAdapter(statemachine.Adapter):
log = logging.getLogger("nodepool.GceAdapter")
def __init__(self, provider_config):
self.provider = provider_config
self.compute = googleapiclient.discovery.build('compute', 'v1')
self._machine_types = {}
self.rate_limiter = RateLimiter(self.provider.name,
self.provider.rate)
def listInstances(self, task_manager):
servers = []
def getCreateStateMachine(self, hostname, label, image_external_id,
metadata, retries, request, log):
return GceCreateStateMachine(self, hostname, label, image_external_id,
metadata, retries, request, log)
q = self.compute.instances().list(project=self.provider.project,
zone=self.provider.zone)
with task_manager.rateLimit():
result = q.execute()
def getDeleteStateMachine(self, external_id, log):
return GceDeleteStateMachine(self, external_id, log)
for instance in result.get('items', []):
instance_type = instance['machineType'].split('/')[-1]
mtype = self._getMachineType(task_manager, instance_type)
instance['_nodepool_gce_machine_type'] = mtype
servers.append(GCEInstance(instance))
return servers
def listInstances(self):
instances = []
def deleteInstance(self, task_manager, server_id):
q = self.compute.instances().delete(project=self.provider.project,
zone=self.provider.zone,
instance=server_id)
with task_manager.rateLimit():
q.execute()
for instance in self._listInstances():
machine_type = instance['machineType'].split('/')[-1]
quota = self._getQuotaForMachineType(machine_type)
instances.append(GceInstance(instance, quota))
return instances
def _getImageId(self, task_manager, cloud_image):
image_id = cloud_image.image_id
def listResources(self):
for instance in self._listInstances():
if instance['status'] == 'TERMINATED':
continue
metadata = gce_metadata_to_dict(instance.get('metadata'))
yield GceResource(metadata, 'instance', instance['name'])
if image_id:
return image_id
def deleteResource(self, resource):
self.log.info(f"Deleting leaked {resource.type}: {resource.id}")
if resource.type == 'instance':
self._deleteInstance(resource.id)
if cloud_image.image_family:
q = self.compute.images().getFromFamily(
project=cloud_image.image_project,
family=cloud_image.image_family)
with task_manager.rateLimit():
result = q.execute()
image_id = result['selfLink']
return image_id
def _getMachineType(self, task_manager, machine_type):
if machine_type in self._machine_types:
return self._machine_types[machine_type]
q = self.compute.machineTypes().get(
project=self.provider.project,
zone=self.provider.zone,
machineType=machine_type)
with task_manager.rateLimit():
result = q.execute()
self._machine_types[machine_type] = result
return result
def createInstance(self, task_manager, hostname, metadata, label):
image_id = self._getImageId(task_manager, label.cloud_image)
disk_init = dict(sourceImage=image_id,
diskType='zones/{}/diskTypes/{}'.format(
self.provider.zone, label.volume_type),
diskSizeGb=str(label.volume_size))
disk = dict(boot=True,
autoDelete=True,
initializeParams=disk_init)
mtype = self._getMachineType(task_manager, label.instance_type)
machine_type = mtype['selfLink']
network = dict(network='global/networks/default',
accessConfigs=[dict(
type='ONE_TO_ONE_NAT',
name='External NAT')])
metadata_items = []
for (k, v) in metadata.items():
metadata_items.append(dict(key=k, value=v))
meta = dict(items=metadata_items)
args = dict(
name=hostname,
machineType=machine_type,
disks=[disk],
networkInterfaces=[network],
serviceAccounts=[],
metadata=meta)
q = self.compute.instances().insert(
project=self.provider.project,
zone=self.provider.zone,
body=args)
with task_manager.rateLimit():
q.execute()
return hostname
def getQuotaLimits(self, task_manager):
def getQuotaLimits(self):
q = self.compute.regions().get(project=self.provider.project,
region=self.provider.region)
with task_manager.rateLimit():
with self.rate_limiter:
ret = q.execute()
cores = None
@ -163,9 +198,95 @@ class GCEAdapter(SimpleTaskManagerAdapter):
instances=instances,
default=math.inf)
def getQuotaForLabel(self, task_manager, label):
mtype = self._getMachineType(task_manager, label.instance_type)
def getQuotaForLabel(self, label):
return self._getQuotaForMachineType(label.instance_type)
# Local implementation below
def _createInstance(self, hostname, metadata, label):
    """Submit an instance insert request for *hostname*.

    :param hostname: Name of the instance to create.
    :param metadata: Metadata dict; it is copied, so the caller's dict
        is not modified.
    :param label: Label configuration supplying ``cloud_image``,
        ``instance_type``, ``volume_type`` and ``volume_size``.
    """
    instance_metadata = metadata.copy()
    source_image = self._getImageId(label.cloud_image)
    disk_type = 'zones/{}/diskTypes/{}'.format(
        self.provider.zone, label.volume_type)
    boot_disk = {
        'boot': True,
        'autoDelete': True,
        'initializeParams': {
            'sourceImage': source_image,
            'diskType': disk_type,
            'diskSizeGb': str(label.volume_size),
        },
    }
    machine_type_url = self._getMachineType(label.instance_type)['selfLink']
    nat_config = {'type': 'ONE_TO_ONE_NAT', 'name': 'External NAT'}
    interface = {
        'network': 'global/networks/default',
        'accessConfigs': [nat_config],
    }
    if label.cloud_image.key:
        # Publish the configured key as "<username>:<key>" metadata.
        instance_metadata['ssh-keys'] = '{}:{}'.format(
            label.cloud_image.username,
            label.cloud_image.key)
    body = {
        'name': hostname,
        'machineType': machine_type_url,
        'disks': [boot_disk],
        'networkInterfaces': [interface],
        'serviceAccounts': [],
        'metadata': dict_to_gce_metadata(instance_metadata),
    }
    request = self.compute.instances().insert(
        project=self.provider.project,
        zone=self.provider.zone,
        body=body)
    with self.rate_limiter:
        request.execute()
def _deleteInstance(self, server_id):
    """Submit a delete request for the instance named *server_id*.

    The request targets the provider's configured project and zone and
    is executed under the adapter's rate limiter.
    """
    request = self.compute.instances().delete(
        project=self.provider.project,
        zone=self.provider.zone,
        instance=server_id)
    with self.rate_limiter:
        request.execute()
@cachetools.func.ttl_cache(maxsize=1, ttl=CACHE_TTL)
def _listInstances(self):
    """Return the raw instance records for the provider's project and zone.

    The result is memoized by ``ttl_cache`` with a TTL of ``CACHE_TTL``.

    NOTE(review): the decorator is applied on the class with maxsize=1,
    so presumably one cache slot is shared by all adapter instances --
    confirm this is intended when several providers are configured.

    :returns: The ``items`` list of the API response, or an empty list
        when the response has no ``items`` key.
    """
    request = self.compute.instances().list(
        project=self.provider.project,
        zone=self.provider.zone)
    with self.rate_limiter:
        response = request.execute()
    return response.get('items', [])
@cachetools.func.lru_cache(maxsize=None)
def _getImageId(self, cloud_image):
    """Resolve the image reference to boot from for *cloud_image*.

    An explicit ``image_id`` wins.  Otherwise, when ``image_family`` is
    set, the family is looked up in ``image_project`` and the
    ``selfLink`` of the returned image is used.  Results are memoized by
    ``lru_cache`` without a size bound.

    :returns: The image id or selfLink; the (falsy) ``image_id`` when
        neither an id nor a family is configured.
    """
    image_id = cloud_image.image_id
    if image_id or not cloud_image.image_family:
        return image_id

    request = self.compute.images().getFromFamily(
        project=cloud_image.image_project,
        family=cloud_image.image_family)
    with self.rate_limiter:
        response = request.execute()
    return response['selfLink']
def _getMachineType(self, machine_type):
    """Return the machineTypes API record for *machine_type*.

    Records are cached per adapter in ``self._machine_types`` (set up in
    ``__init__``) instead of with a method-level ``lru_cache``.  A cache
    decorator on a method stores ``self`` in a class-wide cache, which
    keeps every adapter instance alive for the life of the process.

    :param machine_type: Machine type name.
    :returns: The response of ``machineTypes().get`` for the provider's
        project and zone.  Failed lookups raise and are not cached.
    """
    try:
        return self._machine_types[machine_type]
    except KeyError:
        pass

    request = self.compute.machineTypes().get(
        project=self.provider.project,
        zone=self.provider.zone,
        machineType=machine_type)
    with self.rate_limiter:
        result = request.execute()
    self._machine_types[machine_type] = result
    return result
def _getQuotaForMachineType(self, machine_type):
    """Build the quota consumed by one instance of *machine_type*.

    :returns: A ``QuotaInformation`` with the machine type's
        ``guestCpus`` as cores, its ``memoryMb`` as ram, and one
        instance.
    """
    details = self._getMachineType(machine_type)
    cpu_count = details['guestCpus']
    memory_mb = details['memoryMb']
    return QuotaInformation(cores=cpu_count, instances=1, ram=memory_mb)
def _getInstance(self, hostname):
    """Find the listed instance record whose ``name`` equals *hostname*.

    :returns: The first matching record from ``_listInstances()``, or
        None when no record matches.
    """
    matches = (
        record for record in self._listInstances()
        if record['name'] == hostname
    )
    return next(matches, None)

View File

@ -21,7 +21,7 @@ from nodepool.driver import ConfigValue
from nodepool.driver import ProviderConfig
class ProviderCloudImage(ConfigValue):
class GceProviderCloudImage(ConfigValue):
def __init__(self):
self.name = None
self.image_id = None
@ -32,119 +32,60 @@ class ProviderCloudImage(ConfigValue):
self.connection_port = None
self.shell_type = None
def __eq__(self, other):
if isinstance(other, ProviderCloudImage):
return (self.name == other.name
and self.image_id == other.image_id
and self.username == other.username
and self.key == other.key
and self.python_path == other.python_path
and self.connection_type == other.connection_type
and self.connection_port == other.connection_port
and self.shell_type == other.shell_type)
return False
def __repr__(self):
return "<ProviderCloudImage %s>" % self.name
@property
def external_name(self):
'''Human readable version of external.'''
return self.image_id or self.name
class ProviderLabel(ConfigValue):
class GceLabel(ConfigValue):
ignore_equality = ['pool']
def __init__(self):
self.name = None
self.cloud_image = None
self.instance_type = None
self.volume_size = None
self.volume_type = None
# The ProviderPool object that owns this label.
self.pool = None
def __init__(self, label, provider_config, provider_pool):
self.name = label['name']
self.pool = provider_pool
def __eq__(self, other):
if isinstance(other, ProviderLabel):
# NOTE(Shrews): We intentionally do not compare 'pool' here
# since this causes recursive checks with ProviderPool.
return (other.name == self.name
and other.cloud_image == self.cloud_image
and other.instance_type == self.instance_type
and other.volume_size == self.volume_size
and other.volume_type == self.volume_type)
return False
def __repr__(self):
return "<ProviderLabel %s>" % self.name
cloud_image_name = label.get('cloud-image', None)
if cloud_image_name:
cloud_image = provider_config.cloud_images.get(
cloud_image_name, None)
if not cloud_image:
raise ValueError(
"cloud-image %s does not exist in provider %s"
" but is referenced in label %s" %
(cloud_image_name, provider_config.name, self.name))
self.cloud_image = cloud_image
else:
self.cloud_image = None
self.instance_type = label['instance-type']
self.volume_type = label.get('volume-type', 'pd-standard')
self.volume_size = label.get('volume-size', '10')
self.diskimage = None
class ProviderPool(ConfigPool):
class GcePool(ConfigPool):
ignore_equality = ['provider']
def __init__(self):
self.name = None
self.host_key_checking = True
self.use_internal_ip = False
self.labels = None
# The ProviderConfig object that owns this pool.
self.provider = None
# Initialize base class attributes
def __init__(self, provider_config, pool_config):
super().__init__()
self.provider = provider_config
self.load(pool_config)
def load(self, pool_config, full_config, provider):
def load(self, pool_config):
super().load(pool_config)
self.name = pool_config['name']
self.provider = provider
self.host_key_checking = bool(
pool_config.get('host-key-checking', True))
self.use_internal_ip = bool(
pool_config.get('use-internal-ip', False))
for label in pool_config.get('labels', []):
pl = ProviderLabel()
pl.name = label['name']
pl.pool = self
self.labels[pl.name] = pl
cloud_image_name = label.get('cloud-image', None)
if cloud_image_name:
cloud_image = self.provider.cloud_images.get(
cloud_image_name, None)
if not cloud_image:
raise ValueError(
"cloud-image %s does not exist in provider %s"
" but is referenced in label %s" %
(cloud_image_name, self.name, pl.name))
else:
cloud_image = None
pl.cloud_image = cloud_image
pl.instance_type = label['instance-type']
pl.volume_type = label.get('volume-type', 'pd-standard')
pl.volume_size = label.get('volume-size', '10')
full_config.labels[label['name']].pools.append(self)
def __eq__(self, other):
if isinstance(other, ProviderPool):
# NOTE(Shrews): We intentionally do not compare 'provider' here
# since this causes recursive checks with OpenStackProviderConfig.
return (super().__eq__(other)
and other.name == self.name
and other.host_key_checking == self.host_key_checking
and other.use_internal_ip == self.use_internal_ip
and other.labels == self.labels)
return False
def __repr__(self):
return "<ProviderPool %s>" % self.name
class GCEProviderConfig(ProviderConfig):
class GceProviderConfig(ProviderConfig):
def __init__(self, driver, provider):
self.driver_object = driver
self.__pools = {}
super().__init__(provider)
self._pools = {}
self.rate = None
self.region = None
self.boot_timeout = None
self.launch_retries = None
@ -152,24 +93,10 @@ class GCEProviderConfig(ProviderConfig):
self.zone = None
self.cloud_images = {}
self.rate_limit = None
super().__init__(provider)
def __eq__(self, other):
if isinstance(other, GCEProviderConfig):
return (super().__eq__(other)
and other.region == self.region
and other.pools == self.pools
and other.boot_timeout == self.boot_timeout
and other.launch_retries == self.launch_retries
and other.cloud_images == self.cloud_images
and other.project == self.project
and other.rate_limit == self.rate_limit
and other.zone == self.zone)
return False
@property
def pools(self):
return self.__pools
return self._pools
@property
def manage_images(self):
@ -182,9 +109,11 @@ class GCEProviderConfig(ProviderConfig):
pass
def load(self, config):
self.rate = self.provider.get('rate', 2)
self.region = self.provider.get('region')
self.boot_timeout = self.provider.get('boot-timeout', 60)
self.launch_retries = self.provider.get('launch-retries', 3)
self.launch_timeout = self.provider.get('launch-timeout', 3600)
self.project = self.provider.get('project')
self.zone = self.provider.get('zone')
self.rate_limit = self.provider.get('rate-limit', 1)
@ -196,7 +125,7 @@ class GCEProviderConfig(ProviderConfig):
# TODO: diskimages
for image in self.provider.get('cloud-images', []):
i = ProviderCloudImage()
i = GceProviderCloudImage()
i.name = image['name']
i.image_id = image.get('image-id', None)
i.image_project = image.get('image-project', None)
@ -212,9 +141,13 @@ class GCEProviderConfig(ProviderConfig):
self.cloud_images[i.name] = i
for pool in self.provider.get('pools', []):
pp = ProviderPool()
pp.load(pool, config, self)
self.pools[pp.name] = pp
pp = GcePool(self, pool)
self._pools[pp.name] = pp
for label in pool.get('labels', []):
pl = GceLabel(label, self, pp)
pp.labels[pl.name] = pl
config.labels[pl.name].pools.append(pp)
def getSchema(self):
pool_label = {

View File

@ -12,13 +12,14 @@
# License for the specific language governing permissions and limitations
# under the License.
import math
import logging
from kazoo import exceptions as kze
from nodepool.zk import zookeeper as zk
from nodepool.driver.simple import SimpleTaskManagerHandler
from nodepool.driver.utils import NodeLauncher
from nodepool.driver import NodeRequestHandler
from nodepool.driver.utils import NodeLauncher, QuotaInformation
class K8SLauncher(NodeLauncher):
@ -81,7 +82,126 @@ class K8SLauncher(NodeLauncher):
attempts += 1
class KubernetesNodeRequestHandler(SimpleTaskManagerHandler):
class KubernetesNodeRequestHandler(NodeRequestHandler):
    '''
    Node request handler that launches each node in a K8SLauncher thread.
    '''
    log = logging.getLogger("nodepool.driver.kubernetes."
                            "KubernetesNodeRequestHandler")
    launcher = K8SLauncher

    def __init__(self, pw, request):
        super().__init__(pw, request)
        # Launcher threads started by launch(); polled by launchesComplete().
        self._threads = []

    @property
    def alive_thread_count(self):
        '''Number of launcher threads that are still alive.'''
        return sum(1 for thread in self._threads if thread.is_alive())

    def imagesAvailable(self):
        '''
        Report whether the requested images are available.

        :returns: Always True.
        '''
        return True

    def _poolQuotaLimits(self):
        '''
        Build the pool's configured limits as a QuotaInformation.

        Missing max_cores / max_ram attributes are passed as None, with
        math.inf given as the default.
        '''
        return QuotaInformation(
            cores=getattr(self.pool, 'max_cores', None),
            instances=self.pool.max_servers,
            ram=getattr(self.pool, 'max_ram', None),
            default=math.inf)

    def hasProviderQuota(self, node_types):
        '''
        Check whether the whole node list could ever fit the provider.

        Existing nodes are not taken into account.

        :param node_types: list of node types to check
        :return: True if the node list fits into the provider, False otherwise
        '''
        needed_quota = QuotaInformation()
        for ntype in node_types:
            needed_quota.add(
                self.manager.quotaNeededByLabel(ntype, self.pool))

        # The provider-wide check only applies when the pool defines
        # ignore_provider_quota and it is false.
        check_provider = (hasattr(self.pool, 'ignore_provider_quota')
                          and not self.pool.ignore_provider_quota)
        if check_provider:
            cloud_quota = self.manager.estimatedNodepoolQuota()
            cloud_quota.subtract(needed_quota)
            if not cloud_quota.non_negative():
                return False

        pool_quota = self._poolQuotaLimits()
        pool_quota.subtract(needed_quota)
        return pool_quota.non_negative()

    def hasRemainingQuota(self, ntype):
        '''
        Check whether one more node of type ntype fits the remaining quota.

        :param ntype: node type for the quota check
        :return: True if there is enough quota, False otherwise
        '''
        needed_quota = self.manager.quotaNeededByLabel(ntype, self.pool)

        # remaining = <total nodepool quota> - <used quota> - <quota for node>
        cloud_quota = self.manager.estimatedNodepoolQuota()
        cloud_quota.subtract(
            self.manager.estimatedNodepoolQuotaUsed())
        cloud_quota.subtract(needed_quota)
        self.log.debug("Predicted remaining provider quota: %s",
                       cloud_quota)
        if not cloud_quota.non_negative():
            return False

        # Same calculation against the pool's own limits.
        pool_quota = self._poolQuotaLimits()
        pool_quota.subtract(
            self.manager.estimatedNodepoolQuotaUsed(self.pool))
        self.log.debug("Current pool quota: %s" % pool_quota)
        pool_quota.subtract(needed_quota)
        self.log.debug("Predicted remaining pool quota: %s", pool_quota)
        return pool_quota.non_negative()

    def launchesComplete(self):
        '''
        Check if all launch requests have completed.

        Launches are complete when no launcher thread is alive and every
        node in the nodeset is READY, FAILED or ABORTED.
        '''
        if not self._threads:
            return True

        # Give the NodeLaunch threads time to finish.
        if self.alive_thread_count:
            return False

        # NOTE: It is very important that NodeLauncher always sets one
        # of these states, no matter what.
        final_states = (zk.READY, zk.FAILED, zk.ABORTED)
        node_states = [node.state for node in self.nodeset]
        return all(state in final_states for state in node_states)

    def launch(self, node):
        '''Start a launcher thread for node and keep track of it.'''
        label = self.pool.labels[node.type[0]]
        thread = self.launcher(self, node, self.provider, label)
        thread.start()
        self._threads.append(thread)

View File

@ -1,624 +0,0 @@
# Copyright 2019 Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import time
import logging
import math
from nodepool.driver.taskmanager import BaseTaskManagerProvider, Task
from nodepool.driver import Driver, NodeRequestHandler
from nodepool.driver.utils import NodeLauncher, QuotaInformation, QuotaSupport
from nodepool.driver.utils import NodeDeleter
from nodepool.nodeutils import iterate_timeout, nodescan
from nodepool import exceptions
from nodepool.zk import zookeeper as zk
# Private support classes
class CreateInstanceTask(Task):
    """Task that calls the adapter's ``createInstance``."""
    name = 'create_instance'

    def main(self, manager):
        # Expects 'adapter', 'hostname', 'metadata' and 'label_config'
        # in the task arguments.
        adapter = self.args['adapter']
        return adapter.createInstance(
            manager,
            self.args['hostname'],
            self.args['metadata'],
            self.args['label_config'])
class DeleteInstanceTask(Task):
    """Task that calls the adapter's ``deleteInstance``."""
    name = 'delete_instance'

    def main(self, manager):
        # Expects 'adapter' and 'external_id' in the task arguments.
        adapter = self.args['adapter']
        return adapter.deleteInstance(manager, self.args['external_id'])
class ListInstancesTask(Task):
    """Task that calls the adapter's ``listInstances``."""
    name = 'list_instances'

    def main(self, manager):
        # Expects 'adapter' in the task arguments.
        adapter = self.args['adapter']
        return adapter.listInstances(manager)
class GetQuotaLimitsTask(Task):
    """Task that calls the adapter's ``getQuotaLimits``."""
    name = 'get_quota_limits'

    def main(self, manager):
        # Expects 'adapter' in the task arguments.
        adapter = self.args['adapter']
        return adapter.getQuotaLimits(manager)
class GetQuotaForLabelTask(Task):
    """Task that calls the adapter's ``getQuotaForLabel``."""
    name = 'get_quota_for_label'

    def main(self, manager):
        # Expects 'adapter' and 'label_config' in the task arguments.
        adapter = self.args['adapter']
        return adapter.getQuotaForLabel(manager, self.args['label_config'])
class SimpleTaskManagerLauncher(NodeLauncher):
"""The NodeLauncher implementation for the SimpleTaskManager driver
framework"""
# Copies the launch settings it needs from the provider configuration
# and resolves the pool object that owns the given label.
def __init__(self, handler, node, provider_config, provider_label):
super().__init__(handler, node, provider_config)
self.provider_name = provider_config.name
self.retries = provider_config.launch_retries
self.pool = provider_config.pools[provider_label.pool.name]
self.boot_timeout = provider_config.boot_timeout
self.label = provider_label
# Create the instance, wait for it to report ready, optionally scan its
# host keys, then record the connection details on the node and mark it
# READY.
def launch(self):
self.log.debug("Starting %s instance" % self.node.type)
attempts = 1
# The instance name is derived from the node id.
hostname = 'nodepool-' + self.node.id
tm = self.handler.manager.task_manager
adapter = self.handler.manager.adapter
# Metadata identifying the node, pool and provider for this instance.
metadata = {'nodepool_node_id': self.node.id,
'nodepool_pool_name': self.pool.name,
'nodepool_provider_name': self.provider_name}
if self.label.cloud_image.key:
metadata['ssh-keys'] = '{}:{}'.format(
self.label.cloud_image.username,
self.label.cloud_image.key)
# Submit the create task, retrying up to self.retries times.
# NOTE(review): if self.retries is 0 the loop body never runs and
# external_id is unbound when it is read after the loop.
while attempts <= self.retries:
try:
t = tm.submitTask(CreateInstanceTask(
adapter=adapter, hostname=hostname,
metadata=metadata,
label_config=self.label))
external_id = t.wait()
break
except Exception:
# This condition always holds inside the loop (same test as the
# while condition), so every failure is logged.
if attempts <= self.retries:
self.log.exception(
"Launch attempt %d/%d failed for node %s:",
attempts, self.retries, self.node.id)
# Out of attempts: re-raise the last failure.
if attempts == self.retries:
raise
attempts += 1
time.sleep(1)
# Persist the external id before waiting for the instance to boot.
self.node.external_id = external_id
self.zk.storeNode(self.node)
# Poll until the instance reports ready; iterate_timeout is given
# boot_timeout and LaunchStatusException.
for count in iterate_timeout(
self.boot_timeout, exceptions.LaunchStatusException,
"server %s creation" % external_id):
instance = self.handler.manager.getInstance(external_id)
if instance and instance.ready:
break
self.log.debug("Created instance %s", repr(instance))
# Choose which address is used to reach the node.
if self.pool.use_internal_ip:
server_ip = instance.private_ipv4
else:
server_ip = instance.interface_ip
self.node.connection_port = self.label.cloud_image.connection_port
self.node.connection_type = self.label.cloud_image.connection_type
keys = []
if self.pool.host_key_checking:
try:
# Host keys are only gathered for ssh and network_cli connections.
if (self.node.connection_type == 'ssh' or
self.node.connection_type == 'network_cli'):
gather_hostkeys = True
else:
gather_hostkeys = False
keys = nodescan(server_ip, port=self.node.connection_port,
timeout=180, gather_hostkeys=gather_hostkeys)
except Exception:
# Any scan failure is reported as a LaunchKeyscanException.
raise exceptions.LaunchKeyscanException(
"Can't scan instance %s key" % hostname)
self.log.info("Instance %s ready" % hostname)
# Record the final node details and store the node as READY.
self.node.state = zk.READY
self.node.external_id = hostname
self.node.hostname = hostname
self.node.interface_ip = server_ip
self.node.public_ipv4 = instance.public_ipv4
self.node.private_ipv4 = instance.private_ipv4
self.node.public_ipv6 = instance.public_ipv6
self.node.region = instance.region
self.node.az = instance.az
self.node.host_keys = keys
self.node.username = self.label.cloud_image.username
self.node.python_path = self.label.cloud_image.python_path
self.node.shell_type = self.label.cloud_image.shell_type
self.zk.storeNode(self.node)
self.log.info("Instance %s is ready", hostname)
class SimpleTaskManagerHandler(NodeRequestHandler):
    '''
    Node request handler that launches each node in a
    SimpleTaskManagerLauncher thread.
    '''
    log = logging.getLogger("nodepool.driver.simple."
                            "SimpleTaskManagerHandler")
    launcher = SimpleTaskManagerLauncher

    def __init__(self, pw, request):
        super().__init__(pw, request)
        # Launcher threads started by launch(); polled by launchesComplete().
        self._threads = []

    @property
    def alive_thread_count(self):
        '''Number of launcher threads that are still alive.'''
        return sum(1 for thread in self._threads if thread.is_alive())

    def imagesAvailable(self):
        '''
        Report whether the requested images are available.

        :returns: Always True.
        '''
        return True

    def _poolQuotaLimits(self):
        '''
        Build the pool's configured limits as a QuotaInformation.

        Missing max_cores / max_ram attributes are passed as None, with
        math.inf given as the default.
        '''
        return QuotaInformation(
            cores=getattr(self.pool, 'max_cores', None),
            instances=self.pool.max_servers,
            ram=getattr(self.pool, 'max_ram', None),
            default=math.inf)

    def hasProviderQuota(self, node_types):
        '''
        Check whether the whole node list could ever fit the provider.

        Existing nodes are not taken into account.

        :param node_types: list of node types to check
        :return: True if the node list fits into the provider, False otherwise
        '''
        needed_quota = QuotaInformation()
        for ntype in node_types:
            needed_quota.add(
                self.manager.quotaNeededByLabel(ntype, self.pool))

        # The provider-wide check only applies when the pool defines
        # ignore_provider_quota and it is false.
        check_provider = (hasattr(self.pool, 'ignore_provider_quota')
                          and not self.pool.ignore_provider_quota)
        if check_provider:
            cloud_quota = self.manager.estimatedNodepoolQuota()
            cloud_quota.subtract(needed_quota)
            if not cloud_quota.non_negative():
                return False

        pool_quota = self._poolQuotaLimits()
        pool_quota.subtract(needed_quota)
        return pool_quota.non_negative()

    def hasRemainingQuota(self, ntype):
        '''
        Check whether one more node of type ntype fits the remaining quota.

        :param ntype: node type for the quota check
        :return: True if there is enough quota, False otherwise
        '''
        needed_quota = self.manager.quotaNeededByLabel(ntype, self.pool)

        # remaining = <total nodepool quota> - <used quota> - <quota for node>
        cloud_quota = self.manager.estimatedNodepoolQuota()
        cloud_quota.subtract(
            self.manager.estimatedNodepoolQuotaUsed())
        cloud_quota.subtract(needed_quota)
        self.log.debug("Predicted remaining provider quota: %s",
                       cloud_quota)
        if not cloud_quota.non_negative():
            return False

        # Same calculation against the pool's own limits.
        pool_quota = self._poolQuotaLimits()
        pool_quota.subtract(
            self.manager.estimatedNodepoolQuotaUsed(self.pool))
        self.log.debug("Current pool quota: %s" % pool_quota)
        pool_quota.subtract(needed_quota)
        self.log.debug("Predicted remaining pool quota: %s", pool_quota)
        return pool_quota.non_negative()

    def launchesComplete(self):
        '''
        Check if all launch requests have completed.

        Launches are complete when no launcher thread is alive and every
        node in the nodeset is READY, FAILED or ABORTED.
        '''
        if not self._threads:
            return True

        # Give the NodeLaunch threads time to finish.
        if self.alive_thread_count:
            return False

        # NOTE: It is very important that NodeLauncher always sets one
        # of these states, no matter what.
        final_states = (zk.READY, zk.FAILED, zk.ABORTED)
        node_states = [node.state for node in self.nodeset]
        return all(state in final_states for state in node_states)

    def launch(self, node):
        '''Start a launcher thread for node and keep track of it.'''
        label = self.pool.labels[node.type[0]]
        thread = self.launcher(self, node, self.provider, label)
        thread.start()
        self._threads.append(thread)
class SimpleTaskManagerProvider(BaseTaskManagerProvider, QuotaSupport):
    """The Provider implementation for the SimpleTaskManager driver
    framework

    Remote operations are wrapped in tasks and submitted to
    ``self.task_manager``; the instance listing is cached for a few
    seconds (see :py:meth:`listNodes`).

    :param adapter: The adapter object handed to every submitted task.
    :param provider: The provider configuration object.
    """

    def __init__(self, adapter, provider):
        super().__init__(provider)
        self.adapter = adapter
        # time.monotonic() value of the last listing refresh, and the
        # listing itself.
        self.node_cache_time = 0
        self.node_cache = []
        self._zk = None

    def start(self, zk_conn):
        """Start the provider and keep the ZooKeeper connection."""
        super().start(zk_conn)
        self._zk = zk_conn

    def getRequestHandler(self, poolworker, request):
        """Return a SimpleTaskManagerHandler for ``request``."""
        return SimpleTaskManagerHandler(poolworker, request)

    def labelReady(self, label):
        """Report every label as ready."""
        return True

    def getProviderLimits(self):
        """Return the provider quota limits obtained via GetQuotaLimitsTask.

        Falls back to an unlimited QuotaInformation when the task raises
        NotImplementedError.
        """
        try:
            t = self.task_manager.submitTask(GetQuotaLimitsTask(
                adapter=self.adapter))
            return t.wait()
        except NotImplementedError:
            return QuotaInformation(
                cores=math.inf,
                instances=math.inf,
                ram=math.inf,
                default=math.inf)

    def quotaNeededByLabel(self, ntype, pool):
        """Return the quota needed for one node of label ``ntype``.

        :param ntype: Key into ``pool.labels``.
        :param pool: The pool whose label configuration is consulted.
        :return: The result of GetQuotaForLabelTask, or an empty
            QuotaInformation when the task raises NotImplementedError.
        """
        provider_label = pool.labels[ntype]
        try:
            t = self.task_manager.submitTask(GetQuotaForLabelTask(
                adapter=self.adapter, label_config=provider_label))
            return t.wait()
        except NotImplementedError:
            return QuotaInformation()

    def unmanagedQuotaUsed(self):
        '''
        Sums up the quota used by servers unmanaged by nodepool.

        :return: Calculated quota in use by unmanaged servers
        '''
        used_quota = QuotaInformation()

        node_ids = {n.id for n in self._zk.nodeIterator()}

        for server in self.listNodes():
            meta = server.metadata
            nodepool_provider_name = meta.get('nodepool_provider_name')
            if (nodepool_provider_name and
                    nodepool_provider_name == self.provider.name):
                # This provider (regardless of the launcher) owns this
                # node so it must not be accounted for unmanaged
                # quota; unless it has leaked.
                nodepool_node_id = meta.get('nodepool_node_id')
                if nodepool_node_id and nodepool_node_id in node_ids:
                    # It has not leaked.
                    continue

            try:
                qi = server.getQuotaInformation()
            except NotImplementedError:
                qi = QuotaInformation()
            used_quota.add(qi)

        return used_quota

    def startNodeCleanup(self, node):
        """Start a NodeDeleter for ``node`` and return it."""
        t = NodeDeleter(self._zk, self, node)
        t.start()
        return t

    def cleanupNode(self, external_id):
        """Delete the instance identified by ``external_id``.

        :raises exceptions.NotFound: If the instance is not listed or is
            already marked deleted.
        """
        instance = self.getInstance(external_id)
        if (not instance) or instance.deleted:
            raise exceptions.NotFound()
        t = self.task_manager.submitTask(DeleteInstanceTask(
            adapter=self.adapter, external_id=external_id))
        t.wait()

    def waitForNodeCleanup(self, external_id, timeout=600):
        """Poll until the instance is gone or marked deleted.

        ``iterate_timeout`` is given ``exceptions.ServerDeleteException``
        to signal that ``timeout`` elapsed.
        """
        for _ in iterate_timeout(
                timeout, exceptions.ServerDeleteException,
                "server %s deletion" % external_id):
            instance = self.getInstance(external_id)
            if (not instance) or instance.deleted:
                return

    def cleanupLeakedResources(self):
        """Mark instances of this provider with no known node for deletion.

        For each listed instance that carries this provider's name in its
        metadata, is not already being deleted, and whose
        ``nodepool_node_id`` is unknown to ZooKeeper, a new node record in
        DELETING state is stored.
        """
        # External ids of this provider's nodes already in DELETING
        # state.  A set keeps the per-server membership test cheap.
        deleting_ids = {
            node.external_id
            for node in self._zk.nodeIterator()
            if node.state == zk.DELETING
            and node.provider == self.provider.name
        }

        for server in self.listNodes():
            meta = server.metadata
            if meta.get('nodepool_provider_name') != self.provider.name:
                # Not our responsibility
                continue

            if server.external_id in deleting_ids:
                # Already deleting this node
                continue

            if not self._zk.getNode(meta['nodepool_node_id']):
                self.log.warning(
                    "Marking for delete leaked instance %s in %s "
                    "(unknown node id %s)",
                    server.external_id, self.provider.name,
                    meta['nodepool_node_id']
                )
                # Create an artificial node to use for deleting the server.
                node = zk.Node()
                node.external_id = server.external_id
                node.provider = self.provider.name
                node.state = zk.DELETING
                self._zk.storeNode(node)

    def listNodes(self):
        """Return the instance listing, refreshing it when stale.

        The cached listing is reused until it is more than 5 seconds
        old (measured with ``time.monotonic``).
        """
        now = time.monotonic()
        if now - self.node_cache_time > 5:
            t = self.task_manager.submitTask(ListInstancesTask(
                adapter=self.adapter))
            nodes = t.wait()
            self.node_cache = nodes
            self.node_cache_time = time.monotonic()
        return self.node_cache

    def countNodes(self, provider_name, pool_name):
        """Count listed instances tagged with the given provider and pool."""
        return sum(
            1 for n in self.listNodes()
            if n.metadata.get('nodepool_provider_name') == provider_name
            and n.metadata.get('nodepool_pool_name') == pool_name)

    def getInstance(self, external_id):
        """Return the listed instance with ``external_id``, or None."""
        for candidate in self.listNodes():
            if candidate.external_id == external_id:
                return candidate
        return None
# Public interface below
class SimpleTaskManagerInstance:
    """Represents a cloud instance

    The Simple Task Manager Driver classes use this as a standardized
    view of a remote cloud instance.  Subclass it in your driver,
    override :py:meth:`load`, and fill in as many fields as you can.

    :param data: An opaque data object to be passed to the load method.
    """

    def __init__(self, data):
        # Every field starts with a default; load() is called last so
        # it can overwrite any of them.
        self.external_id = None
        self.ready = False
        self.deleted = False
        self.interface_ip = None
        self.public_ipv4 = None
        self.public_ipv6 = None
        self.private_ipv4 = None
        self.region = None
        self.az = None
        self.metadata = {}
        self.load(data)

    def __repr__(self):
        # List only the flags which are currently set.
        flags = [flag for flag in ('ready', 'deleted')
                 if getattr(self, flag)]
        return '<{klass} {external_id} {state}>'.format(
            klass=type(self).__name__,
            external_id=self.external_id,
            state=' '.join(flags))

    def load(self, data):
        """Parse data and update this object's attributes

        :param data: An opaque data object which was passed to the
            constructor.

        Subclasses override this method and pull values out of the
        `data` parameter.

        The following attributes are required:

        * ready: bool (whether the instance is ready)
        * deleted: bool (whether the instance is in a deleted state)
        * external_id: str (the unique id of the instance)
        * interface_ip: str
        * metadata: dict

        The following are optional:

        * public_ipv4: str
        * public_ipv6: str
        * private_ipv4: str
        * az: str
        * region: str
        """
        raise NotImplementedError

    def getQuotaInformation(self):
        """Return quota information about this instance.

        The base implementation raises NotImplementedError.

        :returns: A :py:class:`QuotaInformation` object.
        """
        raise NotImplementedError
class SimpleTaskManagerAdapter:
    """Public interface for the simple TaskManager Provider

    Write these methods as plain synchronous calls and hand this class
    to the SimpleTaskManagerDriver class.

    A single long-lived connection may be set up in the initializer.
    The provider calls methods on this object from a single thread.

    Every method receives a task_manager argument, which is used to
    control rate limiting:

    .. code:: python

        with task_manager.rateLimit():
            <execute API call>
    """

    def __init__(self, provider):
        """The base initializer ignores ``provider`` and does nothing."""

    def createInstance(self, task_manager, hostname, metadata, label_config):
        """Create an instance

        :param TaskManager task_manager: An instance of
            :py:class:`~nodepool.driver.taskmanager.TaskManager`.
        :param str hostname: The intended hostname for the instance.
        :param dict metadata: A dictionary of key/value pairs that
            must be stored on the instance.
        :param ProviderLabel label_config: A LabelConfig object describing
            the instance which should be created.
        """
        raise NotImplementedError

    def deleteInstance(self, task_manager, external_id):
        """Delete an instance

        :param TaskManager task_manager: An instance of
            :py:class:`~nodepool.driver.taskmanager.TaskManager`.
        :param str external_id: The id of the cloud instance.
        """
        raise NotImplementedError

    def listInstances(self, task_manager):
        """Return a list of instances

        :param TaskManager task_manager: An instance of
            :py:class:`~nodepool.driver.taskmanager.TaskManager`.
        :returns: A list of :py:class:`SimpleTaskManagerInstance` objects.
        """
        raise NotImplementedError

    def getQuotaLimits(self, task_manager):
        """Return the quota limits for this provider

        Unless overridden, this reports a QuotaInformation whose
        default is infinite (no limits).  Override it to provide
        accurate information.

        :param TaskManager task_manager: An instance of
            :py:class:`~nodepool.driver.taskmanager.TaskManager`.
        :returns: A :py:class:`QuotaInformation` object.
        """
        unlimited = math.inf
        return QuotaInformation(default=unlimited)

    def getQuotaForLabel(self, task_manager, label_config):
        """Return information about the quota used for a label

        Unless overridden, this reports a QuotaInformation for a single
        instance; override it to return more detailed information
        including cores and RAM.

        :param TaskManager task_manager: An instance of
            :py:class:`~nodepool.driver.taskmanager.TaskManager`.
        :param ProviderLabel label_config: A LabelConfig object describing
            a label for an instance.
        :returns: A :py:class:`QuotaInformation` object.
        """
        single_instance = 1
        return QuotaInformation(instances=single_instance)
class SimpleTaskManagerDriver(Driver):
    """Subclass this to make a simple driver"""

    def getProvider(self, provider_config):
        """Return a provider.

        The provider wraps the adapter produced by
        :py:meth:`getAdapter`.  Usually this method does not need to be
        overridden.
        """
        return SimpleTaskManagerProvider(
            self.getAdapter(provider_config), provider_config)

    # Public interface

    def getProviderConfig(self, provider):
        """Instantiate a config object

        :param dict provider: A dictionary of YAML config describing
            the provider.
        :returns: A ProviderConfig instance with the parsed data.
        """
        raise NotImplementedError

    def getAdapter(self, provider_config):
        """Instantiate an adapter

        :param ProviderConfig provider_config: An instance of
            ProviderConfig previously returned by :py:meth:`getProviderConfig`.
        :returns: An instance of :py:class:`SimpleTaskManagerAdapter`
        """
        raise NotImplementedError

View File

@ -346,8 +346,7 @@ class StateMachineNodeDeleter:
class StateMachineHandler(NodeRequestHandler):
log = logging.getLogger("nodepool.driver.simple."
"StateMachineHandler")
log = logging.getLogger("nodepool.StateMachineHandler")
def __init__(self, pw, request):
super().__init__(pw, request)
@ -710,7 +709,7 @@ class StateMachineDriver(Driver):
:param ProviderConfig provider_config: An instance of
ProviderConfig previously returned by :py:meth:`getProviderConfig`.
:returns: An instance of :py:class:`SimpleTaskManagerAdapter`
:returns: An instance of :py:class:`Adapter`
"""
raise NotImplementedError()
@ -720,7 +719,7 @@ class StateMachineDriver(Driver):
class Instance:
"""Represents a cloud instance
This class is used by the Simple Task Manager Driver classes to
This class is used by the State Machine Driver classes to
represent a standardized version of a remote cloud instance.
Implement this class in your driver, override the :py:meth:`load`
method, and supply as many of the fields as possible.

View File

@ -1,188 +0,0 @@
# Copyright 2019 Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import time
import logging
import queue
import threading
from nodepool.driver import Provider
class Task:
    """Base task class for use with :py:class:`TaskManager`

    Derive from this class to write your own tasks: set the `name`
    field to the name of the task and override :py:meth:`main`.

    Keyword arguments given to the constructor are kept on `self.args`
    so that :py:meth:`main` can read them.
    """

    name = "task_name"

    def __init__(self, **kw):
        self.args = kw
        self._result = None
        self._exception = None
        self._traceback = None
        # Set by done() or exception(); wait() blocks on it.
        self._wait_event = threading.Event()

    def done(self, result):
        """Record ``result`` and wake up any waiter."""
        self._result = result
        self._wait_event.set()

    def exception(self, e):
        """Record the exception ``e`` and wake up any waiter."""
        self._exception = e
        self._wait_event.set()

    def wait(self):
        """Call this method after submitting the task to the TaskManager to
        receive the results.

        Blocks until the task has finished, then returns its result or
        re-raises the exception it recorded.
        """
        self._wait_event.wait()
        failure = self._exception
        if failure:
            raise failure
        return self._result

    def run(self, manager):
        """Execute :py:meth:`main` and record its result or exception."""
        try:
            outcome = self.main(manager)
            self.done(outcome)
        except Exception as err:
            self.exception(err)

    def main(self, manager):
        """Implement the work of the task

        :param TaskManager manager: The instance of
            :py:class:`TaskManager` running this task.

        Arguments passed to the constructor are available as `self.args`.
        """
        return None
class StopTask(Task):
    """Task which asks the manager running it to stop."""

    name = "stop_taskmanager"

    def main(self, manager):
        # TaskManager.run() leaves its loop once this flag is False.
        manager._running = False
class RateLimitContextManager:
    """Context manager which spaces out calls made through a task manager.

    On entry it sleeps until at least ``task_manager.delta`` seconds
    have passed since ``task_manager.last_ts``; on exit it stores the
    current monotonic time in ``task_manager.last_ts``.

    :param task_manager: Object providing ``last_ts`` and ``delta``.
    """

    def __init__(self, task_manager):
        self.task_manager = task_manager

    def __enter__(self):
        manager = self.task_manager
        # No previous call has been recorded, so there is nothing to
        # wait for.
        if manager.last_ts is None:
            return
        elapsed = time.monotonic() - manager.last_ts
        while elapsed < manager.delta:
            time.sleep(manager.delta - elapsed)
            elapsed = time.monotonic() - manager.last_ts

    def __exit__(self, etype, value, tb):
        # Stamp the end of this call; exceptions are not suppressed.
        self.task_manager.last_ts = time.monotonic()
class TaskManager:
    """A single-threaded task dispatcher

    This class is meant to be instantiated by a Provider in order to
    execute remote API calls from a single thread with rate limiting.

    :param str name: The name of the TaskManager (usually the provider name)
        used in logging.
    :param float rate_limit: The rate limit of the task manager expressed in
        requests per second.
    """

    log = logging.getLogger("nodepool.driver.taskmanager.TaskManager")

    def __init__(self, name, rate_limit):
        self._running = True
        self.name = name
        self.queue = queue.Queue()
        # Minimum number of seconds between two rate-limited calls.
        self.delta = 1.0 / rate_limit
        # Monotonic timestamp of the last rate-limited call, if any.
        self.last_ts = None

    def rateLimit(self):
        """Return a context manager to perform rate limiting.  Use as follows:

        .. code:: python

            with task_manager.rateLimit():
                <execute API call>
        """
        return RateLimitContextManager(self)

    def submitTask(self, task):
        """Submit a task to the task manager.

        :param Task task: An instance of a subclass of :py:class:`Task`.
        :returns: The submitted task for use in function chaining.
        """
        self.queue.put(task)
        return task

    def stop(self):
        """Stop the task manager.

        A StopTask is queued, so tasks submitted before it still run.
        """
        self.submitTask(StopTask())

    def run(self):
        """Run queued tasks one at a time until the manager is stopped.

        Any exception escaping a task is logged and re-raised, ending
        the loop.
        """
        try:
            while True:
                task = self.queue.get()
                try:
                    if not task:
                        # Falsy items are skipped without being run.
                        continue
                    self.log.debug("Manager %s running task %s (queue %s)",
                                   self.name, task.name, self.queue.qsize())
                    task.run(self)
                finally:
                    # Every get() is paired with a task_done(), even for
                    # skipped items or when the task raises, so the
                    # queue's unfinished-task count stays accurate.
                    self.queue.task_done()
                if not self._running:
                    break
        except Exception:
            self.log.exception("Task manager died")
            raise
class BaseTaskManagerProvider(Provider):
    """Subclass this to build a Provider with an included taskmanager

    A :py:class:`TaskManager` is created from the provider's ``name``
    and ``rate_limit``; :py:meth:`start` runs it in its own thread.

    :param provider: The provider configuration object.
    """

    log = logging.getLogger("nodepool.driver.taskmanager.TaskManagerProvider")

    def __init__(self, provider):
        super().__init__()
        self.provider = provider
        # Created by start(); None until then.
        self.thread = None
        self.task_manager = TaskManager(provider.name, provider.rate_limit)

    def start(self, zk_conn):
        """Start the task manager thread unless one already exists."""
        self.log.debug("Starting")
        if self.thread is not None:
            return
        self.log.debug("Starting thread")
        self.thread = threading.Thread(target=self.task_manager.run)
        self.thread.start()

    def stop(self):
        """Ask the task manager to stop, if its thread was started."""
        self.log.debug("Stopping")
        if self.thread is None:
            return
        self.log.debug("Stopping thread")
        self.task_manager.stop()

    def join(self):
        """Wait for the task manager thread to finish, if it was started."""
        self.log.debug("Joining")
        if self.thread is None:
            return
        self.thread.join()

View File

@ -276,7 +276,7 @@ class TestDriverGce(tests.DBTestCase):
self._wait_for_provider(pool, 'gcloud-provider')
with patch('nodepool.driver.simple.nodescan') as nodescan:
with patch('nodepool.driver.statemachine.nodescan') as nodescan:
nodescan.return_value = 'MOCK KEY'
req = zk.NodeRequest()
req.state = zk.REQUESTED
@ -307,7 +307,7 @@ class TestDriverGce(tests.DBTestCase):
nodescan.assert_called_with(
node.interface_ip,
port=22,
timeout=180,
timeout=60,
gather_hostkeys=True)
# A new request will be paused and for lack of quota