diff --git a/zun/db/sqlalchemy/api.py b/zun/db/sqlalchemy/api.py index 69ff77ef9..257cc92f4 100644 --- a/zun/db/sqlalchemy/api.py +++ b/zun/db/sqlalchemy/api.py @@ -129,7 +129,7 @@ class Connection(object): return query - def _add_filters(self, query, filters=None, filter_names=None): + def _add_filters(self, query, model, filters=None, filter_names=None): """Generic way to add filters to a Zun model""" if not filters: return query @@ -141,10 +141,11 @@ class Connection(object): if name in filters: value = filters[name] if isinstance(value, list): - column = getattr(models.Container, name) + column = getattr(model, name) query = query.filter(column.in_(value)) else: - query = query.filter_by(**{name: value}) + column = getattr(model, name) + query = query.filter(column == value) return query @@ -153,7 +154,7 @@ class Connection(object): 'memory', 'host', 'task_state', 'status', 'auto_remove', 'uuid', 'capsule_id'] - return self._add_filters(query, filters=filters, + return self._add_filters(query, models.Container, filters=filters, filter_names=filter_names) def list_containers(self, context, filters=None, limit=None, @@ -257,7 +258,7 @@ class Connection(object): def _add_volume_mappings_filters(self, query, filters): filter_names = ['project_id', 'user_id', 'cinder_volume_id', 'container_path', 'container_uuid', 'volume_provider'] - return self._add_filters(query, filters=filters, + return self._add_filters(query, models.VolumeMapping, filters=filters, filter_names=filter_names) def list_volume_mappings(self, context, filters=None, limit=None, @@ -368,7 +369,7 @@ class Connection(object): def _add_zun_service_filters(self, query, filters): filter_names = ['disabled', 'host', 'binary', 'project_id', 'user_id'] - return self._add_filters(query, filters=filters, + return self._add_filters(query, models.ZunService, filters=filters, filter_names=filter_names) def list_zun_services(self, filters=None, limit=None, marker=None, @@ -429,7 +430,7 @@ class Connection(object): def _add_image_filters(self, query, filters): filter_names = ['repo', 'project_id', 'user_id', 'size'] - return self._add_filters(query, filters=filters, + return self._add_filters(query, models.Image, filters=filters, filter_names=filter_names) def list_images(self, context, filters=None, limit=None, marker=None, @@ -460,7 +461,8 @@ class Connection(object): def _add_resource_providers_filters(self, query, filters): filter_names = ['name', 'root_provider', 'parent_provider', 'can_host'] - return self._add_filters(query, filters=filters, + return self._add_filters(query, models.ResourceProvider, + filters=filters, filter_names=filter_names) def list_resource_providers(self, context, filters=None, limit=None, @@ -608,7 +610,7 @@ class Connection(object): filter_names = ['resource_provider_id', 'resource_class_id', 'total', 'reserved', 'min_unit', 'max_unit', 'step_size', 'allocation_ratio', 'is_nested'] - return self._add_filters(query, filters=filters, + return self._add_filters(query, models.Inventory, filters=filters, filter_names=filter_names) def list_inventories(self, context, filters=None, limit=None, @@ -668,7 +670,7 @@ class Connection(object): def _add_allocations_filters(self, query, filters): filter_names = ['resource_provider_id', 'resource_class_id', 'consumer_id', 'used', 'is_nested'] - return self._add_filters(query, filters=filters, + return self._add_filters(query, models.Allocation, filters=filters, filter_names=filter_names) def list_allocations(self, context, filters=None, limit=None, @@ -726,7 +728,7 @@ class Connection(object): def _add_compute_nodes_filters(self, query, filters): filter_names = ['hostname'] - return self._add_filters(query, filters=filters, + return self._add_filters(query, models.ComputeNode, filters=filters, filter_names=filter_names) def list_compute_nodes(self, context, filters=None, limit=None, @@ -879,7 +881,7 @@ class Connection(object): def _add_capsules_filters(self, query, filters): # filter_names = ['uuid', 'project_id', 'user_id', 'containers'] filter_names = ['uuid', 'project_id', 'user_id'] - return self._add_filters(query, filters=filters, + return self._add_filters(query, models.Capsule, filters=filters, filter_names=filter_names) def get_pci_device_by_addr(self, node_id, dev_addr): @@ -1246,7 +1248,7 @@ class Connection(object): def _add_exec_instances_filters(self, query, filters): filter_names = ['container_id', 'exec_id', 'token'] - return self._add_filters(query, filters=filters, + return self._add_filters(query, models.ExecInstance, filters=filters, filter_names=filter_names) def create_exec_instance(self, context, values):