diff --git a/ironic/db/sqlalchemy/api.py b/ironic/db/sqlalchemy/api.py index 3e5aaa1170..5e46f6889b 100644 --- a/ironic/db/sqlalchemy/api.py +++ b/ironic/db/sqlalchemy/api.py @@ -75,8 +75,14 @@ def _wrap_session(session): return session -def _get_node_query_with_tags(): - return model_query(models.Node).options(joinedload('tags')) +def _get_node_query_with_all(): + """Return a query object for the Node model joined with all relevant fields. + + :returns: a query object. + """ + return model_query(models.Node)\ + .options(joinedload('tags'))\ + .options(joinedload('traits')) def model_query(model, *args, **kwargs): @@ -270,7 +276,7 @@ class Connection(api.Connection): def get_node_list(self, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = _get_node_query_with_tags() + query = _get_node_query_with_all() query = self._add_nodes_filters(query, filters) return _paginate_query(models.Node, limit, marker, sort_key, sort_dir, query) @@ -278,7 +284,7 @@ class Connection(api.Connection): @oslo_db_api.retry_on_deadlock def reserve_node(self, tag, node_id): with _session_for_write(): - query = _get_node_query_with_tags() + query = _get_node_query_with_all() query = add_identity_filter(query, node_id) # be optimistic and assume we usually create a reservation count = query.filter_by(reservation=None).update( @@ -353,7 +359,7 @@ class Connection(api.Connection): return node def get_node_by_id(self, node_id): - query = _get_node_query_with_tags() + query = _get_node_query_with_all() query = query.filter_by(id=node_id) try: return query.one() @@ -361,7 +367,7 @@ class Connection(api.Connection): raise exception.NodeNotFound(node=node_id) def get_node_by_uuid(self, node_uuid): - query = _get_node_query_with_tags() + query = _get_node_query_with_all() query = query.filter_by(uuid=node_uuid) try: return query.one() @@ -369,7 +375,7 @@ class Connection(api.Connection): raise exception.NodeNotFound(node=node_uuid) def get_node_by_name(self, node_name): - query = _get_node_query_with_tags() + query = _get_node_query_with_all() query = query.filter_by(name=node_name) try: return query.one() @@ -380,7 +386,7 @@ class Connection(api.Connection): if not uuidutils.is_uuid_like(instance): raise exception.InvalidUUID(uuid=instance) - query = _get_node_query_with_tags() + query = _get_node_query_with_all() query = query.filter_by(instance_uuid=instance) try: diff --git a/ironic/db/sqlalchemy/models.py b/ironic/db/sqlalchemy/models.py index 3a1e7506ae..8573fef71d 100644 --- a/ironic/db/sqlalchemy/models.py +++ b/ironic/db/sqlalchemy/models.py @@ -290,3 +290,9 @@ class NodeTrait(Base): node_id = Column(Integer, ForeignKey('nodes.id'), primary_key=True, nullable=False) trait = Column(String(255), primary_key=True, nullable=False) + node = orm.relationship( + "Node", + backref='traits', + primaryjoin='and_(NodeTrait.node_id == Node.id)', + foreign_keys=node_id + ) diff --git a/ironic/tests/unit/db/test_nodes.py b/ironic/tests/unit/db/test_nodes.py index ff09b62d9f..01dea51aaa 100644 --- a/ironic/tests/unit/db/test_nodes.py +++ b/ironic/tests/unit/db/test_nodes.py @@ -66,27 +66,36 @@ class DbNodeTestCase(base.DbTestCase): def test_get_node_by_id(self): node = utils.create_test_node() self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) + self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) res = self.dbapi.get_node_by_id(node.id) self.assertEqual(node.id, res.id) self.assertEqual(node.uuid, res.uuid) self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags]) + self.assertItemsEqual(['trait1', 'trait2'], + [trait.trait for trait in res.traits]) def test_get_node_by_uuid(self): node = utils.create_test_node() self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) + self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) res = self.dbapi.get_node_by_uuid(node.uuid) self.assertEqual(node.id, res.id) self.assertEqual(node.uuid, res.uuid) self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags]) + self.assertItemsEqual(['trait1', 'trait2'], + [trait.trait for trait in res.traits]) def test_get_node_by_name(self): node = utils.create_test_node() self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) + self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) res = self.dbapi.get_node_by_name(node.name) self.assertEqual(node.id, res.id) self.assertEqual(node.uuid, res.uuid) self.assertEqual(node.name, res.name) self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags]) + self.assertItemsEqual(['trait1', 'trait2'], + [trait.trait for trait in res.traits]) def test_get_node_that_does_not_exist(self): self.assertRaises(exception.NodeNotFound, @@ -233,6 +242,7 @@ class DbNodeTestCase(base.DbTestCase): six.assertCountEqual(self, uuids, res_uuids) for r in res: self.assertEqual([], r.tags) + self.assertEqual([], r.traits) def test_get_node_list_with_filters(self): ch1 = utils.create_test_chassis(uuid=uuidutils.generate_uuid()) @@ -289,10 +299,13 @@ class DbNodeTestCase(base.DbTestCase): node = utils.create_test_node( instance_uuid='12345678-9999-0000-aaaa-123456789012') self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) + self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) res = self.dbapi.get_node_by_instance(node.instance_uuid) self.assertEqual(node.uuid, res.uuid) self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags]) + self.assertItemsEqual(['trait1', 'trait2'], + [trait.trait for trait in res.traits]) def test_get_node_by_instance_wrong_uuid(self): utils.create_test_node( @@ -515,6 +528,7 @@ class DbNodeTestCase(base.DbTestCase): def test_reserve_node(self): node = utils.create_test_node() self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) + self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) uuid = node.uuid r1 = 'fake-reservation' @@ -522,6 +536,8 @@ class DbNodeTestCase(base.DbTestCase): # reserve the node res = self.dbapi.reserve_node(r1, uuid) self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags]) + self.assertItemsEqual(['trait1', 'trait2'], + [trait.trait for trait in res.traits]) # check reservation res = self.dbapi.get_node_by_uuid(uuid)