James E. Blair 43678bf4c1 Update AWS driver to use statemachine framework
This updates the aws driver to use the statemachine framework which
should be able to scale to a much higher number of parallel operations
than the standard thread-per-node model.  It is also simpler and
easier to maintain.  Several new features are added to bring it to
parity with other drivers.

The unit tests are changed minimally so that they continue to serve
as regression tests for the new framework.  Following changes will
revise the tests and add new tests for the additional functionality.

Change-Id: I8968667f927c82641460debeccd04e0511eb86a9
2022-02-22 17:06:07 -08:00

638 lines
24 KiB
Python

# Copyright 2018 Red Hat
# Copyright 2022 Acme Gating, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import json
import logging
import math
import cachetools.func
import urllib.parse
import time
import re
import boto3
from nodepool.driver.utils import QuotaInformation, RateLimiter
from nodepool.driver import statemachine
def tag_dict_to_list(tagdict):
    """Convert a {key: value} dict into the AWS list-of-dicts tag form.

    Values are coerced to strings since the EC2 API requires them.
    """
    # TODO: validate tag values are strings in config and deprecate
    # non-string values.
    tags = []
    for key, value in tagdict.items():
        tags.append({"Key": key, "Value": str(value)})
    return tags
def tag_list_to_dict(taglist):
    """Convert the AWS list-of-dicts tag form into a plain dict.

    A None taglist (e.g. an untagged resource) yields an empty dict.
    """
    if taglist is None:
        return {}
    result = {}
    for tag in taglist:
        result[tag["Key"]] = tag["Value"]
    return result
class AwsInstance(statemachine.Instance):
    """Adapt a boto3 EC2 instance to the statemachine Instance interface.

    Snapshots the addressing and tag information of the given boto3
    instance at construction time.
    """

    def __init__(self, instance, quota):
        super().__init__()
        self.external_id = instance.id
        self.metadata = tag_list_to_dict(instance.tags)
        self.private_ipv4 = instance.private_ip_address
        self.private_ipv6 = None
        self.public_ipv4 = instance.public_ip_address
        self.public_ipv6 = None
        self.az = ''
        self.quota = quota
        # Only the first network interface is examined for a public
        # IPv6 address.
        for interface in instance.network_interfaces[:1]:
            addresses = interface.ipv6_addresses
            if addresses:
                self.public_ipv6 = addresses[0]['Ipv6Address']
        # Prefer public over private addresses when choosing the
        # address used to connect to the node.
        candidates = (self.public_ipv4, self.public_ipv6,
                      self.private_ipv4, self.private_ipv6)
        self.interface_ip = next(
            (addr for addr in candidates if addr), None)

    def getQuotaInformation(self):
        """Return the QuotaInformation captured at construction."""
        return self.quota
class AwsResource(statemachine.Resource):
    """A record of a provider resource, used for leak detection.

    :param metadata: dict of nodepool tags found on the resource.
    :param type: the resource kind string ('instance', 'volume',
        'ami', 'snapshot' or 'object'; see listResources).
    :param id: the AWS identifier used to delete the resource.
    """

    def __init__(self, metadata, type, id):
        super().__init__(metadata)
        self.type = type
        self.id = id
class AwsDeleteStateMachine(statemachine.StateMachine):
    """State machine which deletes an EC2 instance.

    advance() is called repeatedly by the driver; each call performs a
    small, non-blocking step until self.complete is set.
    """

    # State names.  NIC/PIP/DISK states are declared but never entered
    # here -- presumably carried over from another driver's state
    # machine; TODO(review): confirm they can be removed.
    VM_DELETING = 'deleting vm'
    NIC_DELETING = 'deleting nic'
    PIP_DELETING = 'deleting pip'
    DISK_DELETING = 'deleting disk'
    COMPLETE = 'complete'

    def __init__(self, adapter, external_id):
        super().__init__()
        self.adapter = adapter
        self.external_id = external_id

    def advance(self):
        """Advance deletion one step.

        The chain of plain 'if' (not 'elif') statements is deliberate:
        a single call may fall through several states when steps
        complete immediately.
        """
        if self.state == self.START:
            # Issue the terminate call; keep the instance handle so we
            # can poll it until it reaches the terminated state.
            self.instance = self.adapter._deleteInstance(
                self.external_id)
            self.state = self.VM_DELETING

        if self.state == self.VM_DELETING:
            # _refreshDelete returns None once the instance is
            # terminated or no longer listed.
            self.instance = self.adapter._refreshDelete(self.instance)
            if self.instance is None:
                self.state = self.COMPLETE

        if self.state == self.COMPLETE:
            self.complete = True
class AwsCreateStateMachine(statemachine.StateMachine):
    """State machine which creates an EC2 instance.

    advance() is called repeatedly by the driver; each call performs a
    small, non-blocking step.  Creation is retried (up to `retries`
    times) if the instance goes straight to 'terminated'.  On success,
    advance() returns an AwsInstance.
    """

    INSTANCE_CREATING = 'creating instance'
    INSTANCE_RETRY = 'retrying instance creation'
    COMPLETE = 'complete'

    def __init__(self, adapter, hostname, label, image_external_id,
                 metadata, retries):
        super().__init__()
        self.adapter = adapter
        self.retries = retries
        self.attempts = 0
        self.image_external_id = image_external_id
        self.metadata = metadata
        # Instance tags: label tags, overridden by nodepool metadata,
        # plus the Name tag.  NOTE: dict.copy() is falsy only when the
        # dict is empty, so the 'or {}' is redundant but harmless.
        self.tags = label.tags.copy() or {}
        self.tags.update(metadata)
        self.tags['Name'] = hostname
        self.hostname = hostname
        self.label = label
        self.public_ipv4 = None
        self.public_ipv6 = None
        self.nic = None
        self.instance = None

    def advance(self):
        """Advance creation one step; returns AwsInstance when done.

        The chain of plain 'if' (not 'elif') statements is deliberate:
        a single call may fall through several states when steps
        complete immediately.
        """
        if self.state == self.START:
            self.external_id = self.hostname
            self.instance = self.adapter._createInstance(
                self.label, self.image_external_id,
                self.tags, self.hostname)
            self.state = self.INSTANCE_CREATING

        if self.state == self.INSTANCE_CREATING:
            self.quota = self.adapter._getQuotaForInstanceType(
                self.instance.instance_type)
            self.instance = self.adapter._refresh(self.instance)

            if self.instance.state["Name"].lower() == "running":
                self.state = self.COMPLETE
            elif self.instance.state["Name"].lower() == "terminated":
                # The instance died at boot; delete it and retry unless
                # we're out of attempts.
                if self.attempts >= self.retries:
                    raise Exception("Too many retries")
                self.attempts += 1
                self.instance = self.adapter._deleteInstance(
                    self.external_id)
                self.state = self.INSTANCE_RETRY
            else:
                # Still pending; check again on the next call.
                return

        if self.state == self.INSTANCE_RETRY:
            # Wait for the failed instance to be fully gone, then start
            # the whole process over.
            self.instance = self.adapter._refreshDelete(self.instance)
            if self.instance is None:
                self.state = self.START
                return

        if self.state == self.COMPLETE:
            self.complete = True
            return AwsInstance(self.instance, self.quota)
class AwsAdapter(statemachine.Adapter):
log = logging.getLogger("nodepool.driver.aws.AwsAdapter")
    def __init__(self, provider_config):
        """Create an adapter for one AWS provider.

        :param provider_config: the provider configuration (name,
            region_name, profile_name, rate, object_storage, ...).
        """
        self.provider = provider_config
        # The standard rate limit, this might be 1 request per second
        self.rate_limiter = RateLimiter(self.provider.name,
                                        self.provider.rate)
        # Non mutating requests can be made more often at 10x the rate
        # of mutating requests by default.
        self.non_mutating_rate_limiter = RateLimiter(
            self.provider.name,
            self.provider.rate * 10.0)
        # Maps json-serialized image filters to resolved AMI ids;
        # entries live for 5 minutes.
        self.image_id_by_filter_cache = cachetools.TTLCache(
            maxsize=8192, ttl=(5 * 60))
        self.aws = boto3.Session(
            region_name=self.provider.region_name,
            profile_name=self.provider.profile_name)
        self.ec2 = self.aws.resource('ec2')
        self.ec2_client = self.aws.client("ec2")
        self.s3 = self.aws.resource('s3')
        self.s3_client = self.aws.client('s3')
        self.aws_quotas = self.aws.client("service-quotas")
        # In listResources, we reconcile AMIs which appear to be
        # imports but have no nodepool tags, however it's possible
        # that these aren't nodepool images.  If we determine that's
        # the case, we'll add their ids here so we don't waste our
        # time on that again.
        self.not_our_images = set()
        self.not_our_snapshots = set()
    def getCreateStateMachine(self, hostname, label,
                              image_external_id, metadata, retries):
        """Return a state machine which will create one instance."""
        return AwsCreateStateMachine(self, hostname, label,
                                     image_external_id, metadata, retries)
    def getDeleteStateMachine(self, external_id):
        """Return a state machine which will delete one instance."""
        return AwsDeleteStateMachine(self, external_id)
    def listResources(self):
        """Yield an AwsResource for every resource we may have created.

        Used by the launcher for leak detection.  Terminated instances
        and deleted volumes/AMIs/snapshots are skipped.  As a side
        effect, this also reconciles tags on imported AMIs and their
        snapshots (see _tagAmis / _tagSnapshots).
        """
        self._tagAmis()
        self._tagSnapshots()
        for instance in self._listInstances():
            if instance.state["Name"].lower() == "terminated":
                continue
            yield AwsResource(tag_list_to_dict(instance.tags),
                              'instance', instance.id)
        for volume in self._listVolumes():
            if volume.state.lower() == "deleted":
                continue
            yield AwsResource(tag_list_to_dict(volume.tags),
                              'volume', volume.id)
        for ami in self._listAmis():
            if ami.state.lower() == "deleted":
                continue
            yield AwsResource(tag_list_to_dict(ami.tags),
                              'ami', ami.id)
        for snap in self._listSnapshots():
            if snap.state.lower() == "deleted":
                continue
            yield AwsResource(tag_list_to_dict(snap.tags),
                              'snapshot', snap.id)
        if self.provider.object_storage:
            # Image upload objects carry their metadata as S3 object
            # tags; these require one API call per object.
            for obj in self._listObjects():
                with self.non_mutating_rate_limiter:
                    tags = self.s3_client.get_object_tagging(
                        Bucket=obj.bucket_name, Key=obj.key)
                yield AwsResource(tag_list_to_dict(tags['TagSet']),
                                  'object', obj.key)
def deleteResource(self, resource):
self.log.info(f"Deleting leaked {resource.type}: {resource.id}")
if resource.type == 'instance':
self._deleteInstance(resource.id)
if resource.type == 'volume':
self._deleteVolume(resource.id)
if resource.type == 'ami':
self._deleteAmi(resource.id)
if resource.type == 'snapshot':
self._deleteSnapshot(resource.id)
if resource.type == 'object':
self._deleteObject(resource.id)
def listInstances(self):
for instance in self._listInstances():
if instance.state["Name"].lower() == "terminated":
continue
quota = self._getQuotaForInstanceType(instance.instance_type)
yield AwsInstance(instance, quota)
    def getQuotaLimits(self):
        """Return the provider-wide quota limits.

        Queries the Service Quotas API for quota code L-1216C47A
        (presumably the "Running On-Demand Standard instances" vCPU
        limit -- confirm against AWS docs); every other resource limit
        defaults to infinity.
        """
        with self.non_mutating_rate_limiter:
            response = self.aws_quotas.get_service_quota(
                ServiceCode='ec2',
                QuotaCode='L-1216C47A'
            )
        cores = response['Quota']['Value']
        return QuotaInformation(cores=cores,
                                default=math.inf)
    def getQuotaForLabel(self, label):
        """Return the QuotaInformation one node of this label consumes."""
        return self._getQuotaForInstanceType(label.instance_type)
    def uploadImage(self, provider_image, image_name, filename,
                    image_format, metadata, md5, sha256):
        """Upload a local diskimage file to AWS and register it as an AMI.

        The file is uploaded to the provider's S3 bucket, imported via
        the EC2 import_image API (polling every 30s until the task
        completes), and the S3 object is then removed.  The resulting
        AMI and its snapshot are tagged on a best-effort basis.

        :returns: the AMI id of the imported image.
        :raises Exception: if the import task did not complete.
        """
        self.log.debug(f"Uploading image {image_name}")

        # Upload image to S3
        bucket_name = self.provider.object_storage['bucket-name']
        bucket = self.s3.Bucket(bucket_name)
        object_filename = f'{image_name}.{image_format}'
        extra_args = {'Tagging': urllib.parse.urlencode(metadata)}
        with open(filename, "rb") as fobj:
            with self.rate_limiter:
                bucket.upload_fileobj(fobj, object_filename,
                                      ExtraArgs=extra_args)

        # Import image as AMI
        self.log.debug(f"Importing {image_name}")
        import_image_task = self.ec2_client.import_image(
            Architecture=provider_image.architecture,
            DiskContainers=[
                {
                    'Format': image_format,
                    'UserBucket': {
                        'S3Bucket': bucket_name,
                        'S3Key': object_filename,
                    }
                },
            ],
            TagSpecifications=[
                {
                    'ResourceType': 'import-image-task',
                    'Tags': tag_dict_to_list(metadata),
                },
            ]
        )
        task_id = import_image_task['ImportTaskId']

        paginator = self.ec2_client.get_paginator(
            'describe_import_image_tasks')
        done = False
        while not done:
            # Poll slowly; imports typically take minutes.
            time.sleep(30)
            with self.non_mutating_rate_limiter:
                for page in paginator.paginate(ImportTaskIds=[task_id]):
                    for task in page['ImportImageTasks']:
                        if task['Status'].lower() in ('completed', 'deleted'):
                            done = True
                            break
        # NOTE: 'task' intentionally leaks from the loop above and is
        # used below; since we paginate on a single task id, the last
        # task seen is the task we imported.

        self.log.debug(f"Deleting {image_name} from S3")
        with self.rate_limiter:
            self.s3.Object(bucket_name, object_filename).delete()

        if task['Status'].lower() != 'completed':
            raise Exception(f"Error uploading image: {task}")

        # Tag the AMI.  Best effort: _tagAmis reconciles later if this
        # fails.
        try:
            with self.non_mutating_rate_limiter:
                ami = self.ec2.Image(task['ImageId'])
            with self.rate_limiter:
                ami.create_tags(Tags=task['Tags'])
        except Exception:
            self.log.exception("Error tagging AMI:")

        # Tag the snapshot.  Best effort: _tagSnapshots reconciles
        # later if this fails.
        try:
            with self.non_mutating_rate_limiter:
                snap = self.ec2.Snapshot(
                    task['SnapshotDetails'][0]['SnapshotId'])
            with self.rate_limiter:
                snap.create_tags(Tags=task['Tags'])
        except Exception:
            self.log.exception("Error tagging snapshot:")

        self.log.debug(f"Upload of {image_name} complete as {task['ImageId']}")
        # Last task returned from paginator above
        return task['ImageId']
    def deleteImage(self, external_id):
        """Delete an uploaded image.

        NOTE(review): as written this only logs -- the actual AMI
        deregistration and snapshot removal do not appear here; confirm
        whether they live elsewhere or are missing.
        """
        self.log.debug(f"Deleting image {external_id}")
# Local implementation below
    def _tagAmis(self):
        """Copy nodepool tags onto imported AMIs.

        There is no way to tag imported AMIs, so this routine
        "eventually" tags them.  We look for any AMIs without tags
        which correspond to import tasks, and we copy the tags from
        those import tasks to the AMI.
        """
        for ami in self._listAmis():
            if (ami.name.startswith('import-ami-') and
                not ami.tags and
                ami.id not in self.not_our_images):
                # This image was imported but has no tags, which means
                # it's either not a nodepool image, or it's a new one
                # which doesn't have tags yet.  Copy over any tags
                # from the import task; otherwise, mark it as an image
                # we can ignore in future runs.
                task = self._getImportImageTask(ami.name)
                # NOTE(review): _getImportImageTask can return None if
                # the import task record is gone; confirm that cannot
                # happen while the AMI is still untagged.
                tags = tag_list_to_dict(task.get('Tags'))
                if (tags.get('nodepool_provider_name') == self.provider.name):
                    # Copy over tags
                    self.log.debug(
                        f"Copying tags from import task {ami.name} to AMI")
                    with self.rate_limiter:
                        ami.create_tags(Tags=task['Tags'])
                else:
                    self.not_our_images.add(ami.id)
    def _tagSnapshots(self):
        """Copy nodepool tags onto snapshots created by AMI imports.

        See comments for _tagAmis.
        """
        for snap in self._listSnapshots():
            if ('import-ami-' in snap.description and
                not snap.tags and
                snap.id not in self.not_our_snapshots):
                # Recover the import task id which EC2 embeds in the
                # generated snapshot description.
                match = re.match(r'.*?(import-ami-\w*)', snap.description)
                if not match:
                    self.not_our_snapshots.add(snap.id)
                    continue
                task_id = match.group(1)
                task = self._getImportImageTask(task_id)
                # NOTE(review): _getImportImageTask can return None if
                # the import task record is gone; confirm that cannot
                # happen while the snapshot is still untagged.
                tags = tag_list_to_dict(task.get('Tags'))
                if (tags.get('nodepool_provider_name') == self.provider.name):
                    # Copy over tags
                    self.log.debug(
                        f"Copying tags from import task {task_id} to snapshot")
                    with self.rate_limiter:
                        snap.create_tags(Tags=task['Tags'])
                else:
                    self.not_our_snapshots.add(snap.id)
def _getImportImageTask(self, task_id):
paginator = self.ec2_client.get_paginator(
'describe_import_image_tasks')
with self.non_mutating_rate_limiter:
for page in paginator.paginate(ImportTaskIds=[task_id]):
for task in page['ImportImageTasks']:
# Return the first and only task
return task
def _getQuotaForInstanceType(self, instance_type):
itype = self._getInstanceType(instance_type)
cores = itype['InstanceTypes'][0]['VCpuInfo']['DefaultCores']
ram = itype['InstanceTypes'][0]['MemoryInfo']['SizeInMiB']
return QuotaInformation(cores=cores,
ram=ram,
instances=1)
    @cachetools.func.lru_cache(maxsize=None)
    def _getInstanceType(self, instance_type):
        """Describe an EC2 instance type; cached for the process lifetime.

        NOTE: lru_cache on a method also keys on self; that's fine for
        a single long-lived adapter instance, but would retain adapters
        if they were created repeatedly.
        """
        with self.non_mutating_rate_limiter:
            self.log.debug(
                f"Getting information for instance type {instance_type}")
            return self.ec2_client.describe_instance_types(
                InstanceTypes=[instance_type])
def _refresh(self, obj):
for instance in self._listInstances():
if instance.id == obj.id:
return instance
def _refreshDelete(self, obj):
if obj is None:
return obj
for instance in self._listInstances():
if instance.id == obj.id:
if instance.state["Name"].lower() == "terminated":
return None
return instance
return None
@cachetools.func.ttl_cache(maxsize=1, ttl=10)
def _listInstances(self):
with self.non_mutating_rate_limiter:
return self.ec2.instances.all()
@cachetools.func.ttl_cache(maxsize=1, ttl=10)
def _listVolumes(self):
with self.non_mutating_rate_limiter:
return self.ec2.volumes.all()
@cachetools.func.ttl_cache(maxsize=1, ttl=10)
def _listAmis(self):
with self.non_mutating_rate_limiter:
return self.ec2.images.filter(Owners=['self'])
@cachetools.func.ttl_cache(maxsize=1, ttl=10)
def _listSnapshots(self):
with self.non_mutating_rate_limiter:
return self.ec2.snapshots.filter(OwnerIds=['self'])
@cachetools.func.ttl_cache(maxsize=1, ttl=10)
def _listObjects(self):
bucket_name = self.provider.object_storage.get('bucket-name')
if not bucket_name:
return []
bucket = self.s3.Bucket(bucket_name)
with self.non_mutating_rate_limiter:
return bucket.objects.all()
def _getLatestImageIdByFilters(self, image_filters):
# Normally we would decorate this method, but our cache key is
# complex, so we serialize it to JSON and manage the cache
# ourselves.
cache_key = json.dumps(image_filters)
val = self.image_id_by_filter_cache.get(cache_key)
if val:
return val
with self.non_mutating_rate_limiter:
res = self.ec2_client.describe_images(
Filters=image_filters
).get("Images")
images = sorted(
res,
key=lambda k: k["CreationDate"],
reverse=True
)
if not images:
raise Exception(
"No cloud-image (AMI) matches supplied image filters")
else:
val = images[0].get("ImageId")
self.image_id_by_filter_cache[cache_key] = val
return val
def _getImageId(self, cloud_image):
image_id = cloud_image.image_id
image_filters = cloud_image.image_filters
if image_filters is not None:
return self._getLatestImageIdByFilters(image_filters)
return image_id
    @cachetools.func.lru_cache(maxsize=None)
    def _getImage(self, image_id):
        """Return the boto3 Image handle for an AMI id (cached).

        NOTE: lru_cache on a method also keys on self; acceptable for a
        single long-lived adapter instance.
        """
        with self.non_mutating_rate_limiter:
            return self.ec2.Image(image_id)
    def _createInstance(self, label, image_external_id,
                        tags, hostname):
        """Boot a new EC2 instance for the given label.

        :param label: the label config (instance type, pool, volumes...).
        :param image_external_id: AMI id of a diskimage build, or None
            to resolve the label's cloud-image configuration instead.
        :param tags: dict of tags applied to the instance and volume.
        :param hostname: used here only for the log message.
        :returns: a fresh boto3 Instance handle.
        """
        if image_external_id:
            image_id = image_external_id
        else:
            image_id = self._getImageId(label.cloud_image)

        args = dict(
            ImageId=image_id,
            MinCount=1,
            MaxCount=1,
            KeyName=label.key_name,
            EbsOptimized=label.ebs_optimized,
            InstanceType=label.instance_type,
            NetworkInterfaces=[{
                'AssociatePublicIpAddress': label.pool.public_ipv4,
                'DeviceIndex': 0}],
            TagSpecifications=[
                {
                    'ResourceType': 'instance',
                    'Tags': tag_dict_to_list(tags),
                },
                {
                    'ResourceType': 'volume',
                    'Tags': tag_dict_to_list(tags),
                },
            ]
        )

        # Optional pool/label settings.
        if label.pool.security_group_id:
            args['NetworkInterfaces'][0]['Groups'] = [
                label.pool.security_group_id
            ]
        if label.pool.subnet_id:
            args['NetworkInterfaces'][0]['SubnetId'] = label.pool.subnet_id
        if label.pool.public_ipv6:
            args['NetworkInterfaces'][0]['Ipv6AddressCount'] = 1
        if label.userdata:
            args['UserData'] = label.userdata
        if label.iam_instance_profile:
            # An instance profile may be referenced by name or by ARN.
            if 'name' in label.iam_instance_profile:
                args['IamInstanceProfile'] = {
                    'Name': label.iam_instance_profile['name']
                }
            elif 'arn' in label.iam_instance_profile:
                args['IamInstanceProfile'] = {
                    'Arn': label.iam_instance_profile['arn']
                }

        # Default block device mapping parameters are embedded in AMIs.
        # We might need to supply our own mapping before launching the
        # instance.  We basically want to make sure DeleteOnTermination
        # is true and be able to set the volume type and size.
        image = self._getImage(image_id)
        # TODO: Flavors can also influence whether or not the VM spawns with a
        # volume -- we basically need to ensure DeleteOnTermination is true.
        # However, leaked volume detection may mitigate this.
        if hasattr(image, 'block_device_mappings'):
            bdm = image.block_device_mappings
            # Only the first (root) mapping is adjusted; other mappings
            # keep their AMI defaults.
            mapping = bdm[0]
            if 'Ebs' in mapping:
                mapping['Ebs']['DeleteOnTermination'] = True
                if label.volume_size:
                    mapping['Ebs']['VolumeSize'] = label.volume_size
                if label.volume_type:
                    mapping['Ebs']['VolumeType'] = label.volume_type
                # If the AMI is a snapshot, we cannot supply an "encrypted"
                # parameter
                if 'Encrypted' in mapping['Ebs']:
                    del mapping['Ebs']['Encrypted']
                args['BlockDeviceMappings'] = [mapping]

        with self.rate_limiter:
            self.log.debug(f"Creating VM {hostname}")
            instances = self.ec2.create_instances(**args)
            return self.ec2.Instance(instances[0].id)
def _deleteInstance(self, external_id):
for instance in self._listInstances():
if instance.id == external_id:
break
else:
self.log.warning(f"Instance not found when deleting {external_id}")
return None
with self.rate_limiter:
self.log.debug(f"Deleting instance {external_id}")
instance.terminate()
return instance
def _deleteVolume(self, external_id):
for volume in self._listVolumes():
if volume.id == external_id:
break
else:
self.log.warning(f"Volume not found when deleting {external_id}")
return None
with self.rate_limiter:
self.log.debug(f"Deleting volume {external_id}")
volume.delete()
return volume
def _deleteAmi(self, external_id):
for ami in self._listAmis():
if ami.id == external_id:
break
else:
self.log.warning(f"AMI not found when deleting {external_id}")
return None
with self.rate_limiter:
self.log.debug(f"Deleting AMI {external_id}")
ami.deregister()
return ami
def _deleteSnapshot(self, external_id):
for snap in self._listSnapshots():
if snap.id == external_id:
break
else:
self.log.warning(f"Snapshot not found when deleting {external_id}")
return None
with self.rate_limiter:
self.log.debug(f"Deleting Snapshot {external_id}")
snap.delete()
return snap
    def _deleteObject(self, external_id):
        """Delete an object from the provider's image-upload S3 bucket."""
        bucket_name = self.provider.object_storage.get('bucket-name')
        with self.rate_limiter:
            self.log.debug(f"Deleting object {external_id}")
            self.s3.Object(bucket_name, external_id).delete()