diff --git a/marconi/queues/storage/sqlalchemy/messages.py b/marconi/queues/storage/sqlalchemy/messages.py index b2e5aa80a..5257ee651 100644 --- a/marconi/queues/storage/sqlalchemy/messages.py +++ b/marconi/queues/storage/sqlalchemy/messages.py @@ -27,6 +27,12 @@ from marconi.queues.storage.sqlalchemy import utils class MessageController(storage.Message): + def _and_stmt_with_ttl(self, queue_name, project): + return [tables.Queues.c.name == queue_name, + tables.Queues.c.project == project, + tables.Messages.c.ttl > + sfunc.now() - tables.Messages.c.created] + def _get(self, queue, message_id, project, count=False): if project is None: @@ -49,11 +55,10 @@ class MessageController(storage.Message): sel = sa.sql.select([sfunc.count(tables.Messages.c.id)]) sel = sel.select_from(j) - sel = sel.where(sa.and_(tables.Messages.c.id == mid, - tables.Queues.c.project == project, - tables.Queues.c.name == queue, - tables.Messages.c.ttl > - sfunc.now() - tables.Messages.c.created)) + and_stmt = [tables.Messages.c.id == mid] + and_stmt.extend(self._and_stmt_with_ttl(queue, project)) + + sel = sel.where(sa.and_(*and_stmt)) return self.driver.get(sel) except utils.NoResult: @@ -91,11 +96,8 @@ class MessageController(storage.Message): tables.Messages.c.ttl, tables.Messages.c.created]) - and_stmt = [tables.Messages.c.id.in_(message_ids), - tables.Queues.c.name == queue, - tables.Queues.c.project == project, - tables.Messages.c.ttl > - sfunc.now() - tables.Messages.c.created] + and_stmt = [tables.Messages.c.id.in_(message_ids)] + and_stmt.extend(self._and_stmt_with_ttl(queue, project)) j = sa.join(tables.Messages, tables.Queues, tables.Messages.c.qid == tables.Queues.c.id) @@ -166,8 +168,7 @@ class MessageController(storage.Message): tables.Messages.c.qid == tables.Queues.c.id) sel = sel.select_from(j) - and_clause = [tables.Queues.c.name == queue, - tables.Queues.c.project == project] + and_clause = self._and_stmt_with_ttl(queue, project) if not echo: and_clause.append(tables.Messages.c.client != str(client_uuid)) @@ -308,8 +309,7 @@ class MessageController(storage.Message): tables.Messages.c.qid == tables.Queues.c.id) sel = sel.select_from(j) - and_clause = [tables.Queues.c.name == queue_name, - tables.Queues.c.project == project] + and_clause = self._and_stmt_with_ttl(queue_name, project) and_clause.append(tables.Messages.c.cid == (None)) diff --git a/marconi/tests/queues/storage/base.py b/marconi/tests/queues/storage/base.py index 3d620af9d..e67f59125 100644 --- a/marconi/tests/queues/storage/base.py +++ b/marconi/tests/queues/storage/base.py @@ -421,23 +421,39 @@ class MessageControllerTest(ControllerBaseTest): @testing.is_slow(condition=lambda self: self.gc_interval != 0) def test_expired_messages(self): - messages = [{'body': 3.14, 'ttl': 0}] + messages = [{'body': 3.14, 'ttl': 0}, {'body': 0.618, 'ttl': 600}] client_uuid = uuid.uuid4() - [msgid] = self.controller.post(self.queue_name, messages, - project=self.project, - client_uuid=client_uuid) + [msgid_expired, msgid] = self.controller.post(self.queue_name, + messages, + project=self.project, + client_uuid=client_uuid) time.sleep(self.gc_interval) with testing.expect(storage.errors.DoesNotExist): - self.controller.get(self.queue_name, msgid, + self.controller.get(self.queue_name, msgid_expired, project=self.project) stats = self.queue_controller.stats(self.queue_name, project=self.project) - self.assertEqual(stats['messages']['free'], 0) + self.assertEqual(stats['messages']['free'], 1) + + # Make sure expired messages not return when listing + interaction = self.controller.list(self.queue_name, + project=self.project) + + messages = list(next(interaction)) + self.assertEqual(len(messages), 1) + self.assertEqual(msgid, messages[0]['id']) + + # Make sure expired messages not return when popping + messages = self.controller.pop(self.queue_name, + limit=10, + project=self.project) + self.assertEqual(len(messages), 1) + self.assertEqual(msgid, messages[0]['id']) def test_bad_id(self): # NOTE(cpp-cabrera): A malformed ID should result in an empty