diff --git a/nodepool/driver/__init__.py b/nodepool/driver/__init__.py index a8cfc1b35..8ffb40900 100644 --- a/nodepool/driver/__init__.py +++ b/nodepool/driver/__init__.py @@ -373,7 +373,7 @@ class NodeRequestHandler(NodeRequestHandlerNotifications, list if all are valid. ''' invalid = [] - valid = self.provider.getSupportedLabels() + valid = self.provider.getSupportedLabels(self.pool.name) for ntype in self.request.node_types: if ntype not in valid: invalid.append(ntype) @@ -859,8 +859,10 @@ class ProviderConfig(ConfigValue, metaclass=abc.ABCMeta): pass @abc.abstractmethod - def getSupportedLabels(self): + def getSupportedLabels(self, pool_name=None): ''' Return a set of label names supported by this provider. + + :param str pool_name: If provided, get labels for the given pool only. ''' pass diff --git a/nodepool/driver/openstack/config.py b/nodepool/driver/openstack/config.py index bf6c9f6c6..1fe030faa 100644 --- a/nodepool/driver/openstack/config.py +++ b/nodepool/driver/openstack/config.py @@ -385,8 +385,9 @@ class OpenStackProviderConfig(ProviderConfig): 'cloud-images': [provider_cloud_images], }) - def getSupportedLabels(self): + def getSupportedLabels(self, pool_name=None): labels = set() for pool in self.pools.values(): - labels.update(pool.labels.keys()) + if not pool_name or (pool.name == pool_name): + labels.update(pool.labels.keys()) return labels diff --git a/nodepool/driver/static/config.py b/nodepool/driver/static/config.py index 27e526324..48cb74941 100644 --- a/nodepool/driver/static/config.py +++ b/nodepool/driver/static/config.py @@ -109,8 +109,9 @@ class StaticProviderConfig(ProviderConfig): } return v.Schema({'pools': [pool]}) - def getSupportedLabels(self): + def getSupportedLabels(self, pool_name=None): labels = set() for pool in self.pools.values(): - labels.update(pool.labels) + if not pool_name or (pool.name == pool_name): + labels.update(pool.labels) return labels diff --git a/nodepool/driver/test/config.py b/nodepool/driver/test/config.py index 80f8a662f..067c2934c 100644 --- a/nodepool/driver/test/config.py +++ b/nodepool/driver/test/config.py @@ -57,5 +57,5 @@ class TestConfig(ProviderConfig): 'labels': [str]} return v.Schema({'pools': [pool]}) - def getSupportedLabels(self): + def getSupportedLabels(self, pool_name=None): return self.labels