diff --git a/ironic/common/exception.py b/ironic/common/exception.py index 3c846f8d02..f9da690d18 100644 --- a/ironic/common/exception.py +++ b/ironic/common/exception.py @@ -190,6 +190,10 @@ class InvalidUUID(Invalid): message = _("Expected a uuid but received %(uuid)s.") +class InvalidIdentity(Invalid): + message = _("Expected an uuid or int but received %(identity)s.") + + class InvalidMAC(Invalid): message = _("Expected a MAC address but received %(mac)s.") diff --git a/ironic/common/utils.py b/ironic/common/utils.py index a7c49c3205..2130ecabfe 100644 --- a/ironic/common/utils.py +++ b/ironic/common/utils.py @@ -297,7 +297,7 @@ def is_valid_boolstr(val): def is_valid_mac(address): """Verify the format of a MAC addres.""" m = "[0-9a-f]{2}([-:])[0-9a-f]{2}(\\1[0-9a-f]{2}){4}$" - if re.match(m, address.lower()): + if isinstance(address, str) and re.match(m, address.lower()): return True return False diff --git a/ironic/db/sqlalchemy/api.py b/ironic/db/sqlalchemy/api.py index fc6f205a84..8f7480b63b 100644 --- a/ironic/db/sqlalchemy/api.py +++ b/ironic/db/sqlalchemy/api.py @@ -58,24 +58,38 @@ def model_query(model, *args, **kwargs): return query -def add_uuid_filter(query, value): +def add_identity_filter(query, value): + """Adds an identity filter to a query. + + Filters results by ID, if supplied value is a valid integer. + Otherwise attempts to filter results by UUID. + + :param query: Initial query to add filter to. + :param value: Value for filtering results by. + :return: Modified query. + """ if utils.is_int_like(value): return query.filter_by(id=value) elif uuidutils.is_uuid_like(value): return query.filter_by(uuid=value) else: - raise exception.InvalidUUID(uuid=value) + raise exception.InvalidIdentity(identity=value) def add_port_filter(query, value): - if utils.is_int_like(value): - return query.filter_by(id=value) - elif utils.is_valid_mac(value): + """Adds a port-specific filter to a query. + + Filters results by address, if supplied value is a valid MAC + address. Otherwise attempts to filter results by identity. + + :param query: Initial query to add filter to. + :param value: Value for filtering results by. + :return: Modified query. + """ + if utils.is_valid_mac(value): return query.filter_by(address=value) - elif uuidutils.is_uuid_like(value): - return add_uuid_filter(query, value) else: - raise exception.InvalidMAC(mac=value) + return add_identity_filter(query, value) class Connection(api.Connection): @@ -109,7 +123,7 @@ class Connection(api.Connection): # only if needed to determine the cause of an error. for node in nodes: query = model_query(models.Node, session=session) - query = add_uuid_filter(query, node) + query = add_identity_filter(query, node) # Be optimistic and assume we usually get a reservation. count = query.filter_by(reservation=None).\ @@ -135,7 +149,7 @@ class Connection(api.Connection): # only if needed to determine the cause of an error. for node in nodes: query = model_query(models.Node, session=session) - query = add_uuid_filter(query, node) + query = add_identity_filter(query, node) # be optimistic and assume we usually release a reservation count = query.filter_by(reservation=tag).\ @@ -160,7 +174,7 @@ class Connection(api.Connection): @objects.objectify(objects.Node) def get_node(self, node): query = model_query(models.Node) - query = add_uuid_filter(query, node) + query = add_identity_filter(query, node) try: result = query.one() @@ -188,7 +202,7 @@ class Connection(api.Connection): session = get_session() with session.begin(): query = model_query(models.Node, session=session) - query = add_uuid_filter(query, node) + query = add_identity_filter(query, node) count = query.delete() if count != 1: @@ -199,7 +213,7 @@ class Connection(api.Connection): session = get_session() with session.begin(): query = model_query(models.Node, session=session) - query = add_uuid_filter(query, node) + query = add_identity_filter(query, node) print "Updating with %s." % values count = query.update(values, synchronize_session='fetch') diff --git a/ironic/tests/db/test_nodes.py b/ironic/tests/db/test_nodes.py index ec61e59d0e..d784e2df20 100644 --- a/ironic/tests/db/test_nodes.py +++ b/ironic/tests/db/test_nodes.py @@ -64,7 +64,7 @@ class DbNodeTestCase(base.DbTestCase): self.assertRaises(exception.NodeNotFound, self.dbapi.get_node, '12345678-9999-0000-aaaa-123456789012') - self.assertRaises(exception.InvalidUUID, + self.assertRaises(exception.InvalidIdentity, self.dbapi.get_node, 'not-a-uuid') def test_get_node_by_instance(self): diff --git a/ironic/tests/db/test_ports.py b/ironic/tests/db/test_ports.py index dd25330964..a16b613776 100644 --- a/ironic/tests/db/test_ports.py +++ b/ironic/tests/db/test_ports.py @@ -52,7 +52,7 @@ class DbPortTestCase(base.DbTestCase): self.dbapi.get_port, 99) self.assertRaises(exception.PortNotFound, self.dbapi.get_port, 'aa:bb:cc:dd:ee:ff') - self.assertRaises(exception.InvalidMAC, + self.assertRaises(exception.InvalidIdentity, self.dbapi.get_port, 'not-a-mac') def test_get_ports_by_node_id(self):