Merge "Refactored query filters."

This commit is contained in:
Jenkins 2013-06-20 20:44:00 +00:00 committed by Gerrit Code Review
commit 21e5bd0539
5 changed files with 34 additions and 16 deletions

View File

@ -190,6 +190,10 @@ class InvalidUUID(Invalid):
message = _("Expected a uuid but received %(uuid)s.")
class InvalidIdentity(Invalid):
message = _("Expected an uuid or int but received %(identity)s.")
class InvalidMAC(Invalid):
message = _("Expected a MAC address but received %(mac)s.")

View File

@ -297,7 +297,7 @@ def is_valid_boolstr(val):
def is_valid_mac(address):
"""Verify the format of a MAC addres."""
m = "[0-9a-f]{2}([-:])[0-9a-f]{2}(\\1[0-9a-f]{2}){4}$"
if re.match(m, address.lower()):
if isinstance(address, str) and re.match(m, address.lower()):
return True
return False

View File

@ -58,24 +58,38 @@ def model_query(model, *args, **kwargs):
return query
def add_uuid_filter(query, value):
def add_identity_filter(query, value):
"""Adds an identity filter to a query.
Filters results by ID, if supplied value is a valid integer.
Otherwise attempts to filter results by UUID.
:param query: Initial query to add filter to.
:param value: Value for filtering results by.
:return: Modified query.
"""
if utils.is_int_like(value):
return query.filter_by(id=value)
elif uuidutils.is_uuid_like(value):
return query.filter_by(uuid=value)
else:
raise exception.InvalidUUID(uuid=value)
raise exception.InvalidIdentity(identity=value)
def add_port_filter(query, value):
if utils.is_int_like(value):
return query.filter_by(id=value)
elif utils.is_valid_mac(value):
"""Adds a port-specific filter to a query.
Filters results by address, if supplied value is a valid MAC
address. Otherwise attempts to filter results by identity.
:param query: Initial query to add filter to.
:param value: Value for filtering results by.
:return: Modified query.
"""
if utils.is_valid_mac(value):
return query.filter_by(address=value)
elif uuidutils.is_uuid_like(value):
return add_uuid_filter(query, value)
else:
raise exception.InvalidMAC(mac=value)
return add_identity_filter(query, value)
class Connection(api.Connection):
@ -109,7 +123,7 @@ class Connection(api.Connection):
# only if needed to determine the cause of an error.
for node in nodes:
query = model_query(models.Node, session=session)
query = add_uuid_filter(query, node)
query = add_identity_filter(query, node)
# Be optimistic and assume we usually get a reservation.
count = query.filter_by(reservation=None).\
@ -135,7 +149,7 @@ class Connection(api.Connection):
# only if needed to determine the cause of an error.
for node in nodes:
query = model_query(models.Node, session=session)
query = add_uuid_filter(query, node)
query = add_identity_filter(query, node)
# be optimistic and assume we usually release a reservation
count = query.filter_by(reservation=tag).\
@ -160,7 +174,7 @@ class Connection(api.Connection):
@objects.objectify(objects.Node)
def get_node(self, node):
query = model_query(models.Node)
query = add_uuid_filter(query, node)
query = add_identity_filter(query, node)
try:
result = query.one()
@ -188,7 +202,7 @@ class Connection(api.Connection):
session = get_session()
with session.begin():
query = model_query(models.Node, session=session)
query = add_uuid_filter(query, node)
query = add_identity_filter(query, node)
count = query.delete()
if count != 1:
@ -199,7 +213,7 @@ class Connection(api.Connection):
session = get_session()
with session.begin():
query = model_query(models.Node, session=session)
query = add_uuid_filter(query, node)
query = add_identity_filter(query, node)
print "Updating with %s." % values
count = query.update(values, synchronize_session='fetch')

View File

@ -64,7 +64,7 @@ class DbNodeTestCase(base.DbTestCase):
self.assertRaises(exception.NodeNotFound,
self.dbapi.get_node,
'12345678-9999-0000-aaaa-123456789012')
self.assertRaises(exception.InvalidUUID,
self.assertRaises(exception.InvalidIdentity,
self.dbapi.get_node, 'not-a-uuid')
def test_get_node_by_instance(self):

View File

@ -52,7 +52,7 @@ class DbPortTestCase(base.DbTestCase):
self.dbapi.get_port, 99)
self.assertRaises(exception.PortNotFound,
self.dbapi.get_port, 'aa:bb:cc:dd:ee:ff')
self.assertRaises(exception.InvalidMAC,
self.assertRaises(exception.InvalidIdentity,
self.dbapi.get_port, 'not-a-mac')
def test_get_ports_by_node_id(self):