Guarantee FIFO/once-and-only-once delivery when using MongoDB

This patch modifies the way message markers are generated and
used, so that Marconi can guarantee FIFO for a single message
producer posting to a single queue (FIFO for multiple producers
is not guaranteed). At the same time, these changes guarantee
that observer clients will receive messages once and only once,
by removing race conditions inherent in timestamp-based markers.
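
To illustrate the approach, here is a minimal, self-contained
sketch of per-queue monotonic markers (the names ToyQueue, post,
and list are hypothetical; this is not the driver code itself):

    import itertools

    class ToyQueue(object):
        """Toy model of per-queue monotonic pagination markers."""

        def __init__(self):
            self._messages = []  # list of (marker, body) pairs
            self._counter = itertools.count(1)

        def post(self, body):
            marker = next(self._counter)
            self._messages.append((marker, body))
            return marker

        def list(self, marker=0, limit=10):
            # Clients paginate by echoing back the last marker they
            # saw. Markers are dense and monotonic for a single
            # producer, so no message is skipped or observed twice,
            # unlike timestamp markers, which can collide under
            # concurrency.
            return [m for m in self._messages if m[0] > marker][:limit]

    q = ToyQueue()
    for body in ('a', 'b', 'c'):
        q.post(body)

    page1 = q.list(limit=2)              # [(1, 'a'), (2, 'b')]
    page2 = q.list(marker=page1[-1][0])  # [(3, 'c')]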

Along the way, some technical debt was also addressed, particularly
regarding style, and some minor optimizations were made.

A new service, marconi-mongo-gc, was also added; it is required
when using the MongoDB storage driver for Marconi. This service
is necessary because the FIFO/pagination algorithm depends on the
a priori assumption that at least the most recent message always
remains in each queue, a requirement that arises from the
particular constraints imposed by MongoDB's semantics.
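
In rough terms, the invariant the GC preserves can be sketched
with pymongo, modeled on this patch's _remove_expired() (the
helper name collect_garbage is hypothetical; the field names
'q', 'e', and 'k' follow the driver's document mapping):

    import datetime

    def collect_garbage(col, queue_id):
        # Keep the newest message (highest marker 'k'); the next
        # marker is derived from it, so it must never be removed.
        head = col.find_one({'q': queue_id},
                            sort=[('k', -1)],
                            fields={'_id': 1})
        if head is None:
            return  # queue is empty or was deleted in parallel

        # Remove everything that has expired, except the head
        col.remove({'q': queue_id,
                    'e': {'$lte': datetime.datetime.utcnow()},
                    '_id': {'$ne': head['_id']}})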

Note: While implementing this blueprint, many calculated tradeoffs
were made in an attempt to balance performance, risk, readability,
and maintainability. The goal was to create a reasonable baseline
implementation that can be iterated upon pending comprehensive
system and performance testing. Due to the many subtleties of
solving the FIFO/once-and-only-once problem for the MongoDB
driver, future contributors should exercise extreme caution when
modifying the algorithm introduced in this patch.

Changes include:
* Align text in comments
* Add a monotonic counter to queue and message documents
* Convert markers to use the monotonic counter
* Handle DuplicateKeyError on message insert
* Return resources in the body as a response to a message POST
* Add MongoDB driver claim tests
* Return 503 when no messages were enqueued due to a marker conflict
* Add a backoff sleep between retries
* Add the marconi-mongo-gc service, which is now required when
  using the MongoDB storage driver

Implements: blueprint message-pagination
Change-Id: Ifa0bb9e1bc393545adc4c804d14c6eb2df01848c
Author: kgriffs
Date:   2013-05-21 15:54:12 -03:00
Parent: c91c724a05
Commit: bfd29252f5
22 changed files with 1139 additions and 297 deletions


@@ -1,12 +1,12 @@
[DEFAULT] [DEFAULT]
; debug = False ; debug = False
; verbose = False ; verbose = False
; auth_strategy = ; auth_strategy =
[drivers] [drivers]
;transport = marconi.transport.wsgi, marconi.transport.zmq # Transport driver module (e.g., marconi.transport.wsgi, marconi.transport.zmq)
transport = marconi.transport.wsgi transport = marconi.transport.wsgi
;storage = marconi.storage.mongodb, marconi.storage.sqlite # Storage driver module (e.g., marconi.storage.mongodb, marconi.storage.sqlite)
storage = marconi.storage.mongodb storage = marconi.storage.mongodb
[drivers:transport:wsgi] [drivers:transport:wsgi]
@@ -19,3 +19,30 @@ port = 8888
[drivers:storage:mongodb] [drivers:storage:mongodb]
uri = mongodb://db1.example.net,db2.example.net:2500/?replicaSet=test&ssl=true&w=majority uri = mongodb://db1.example.net,db2.example.net:2500/?replicaSet=test&ssl=true&w=majority
database = marconi database = marconi
# Maximum number of times to retry a failed operation. Currently
# only used for retrying a message post.
;max_attempts = 1000
# Maximum sleep interval between retries (actual sleep time
# increases linearly according to number of attempts performed).
;max_retry_sleep = 0.1
# Maximum jitter interval, to be added to the sleep interval, in
# order to decrease probability that parallel requests will retry
# at the same instant.
;max_retry_jitter = 0.005
# Frequency of message garbage collections, in seconds
;gc_interval = 5 * 60
# Threshold of number of expired messages to reach in a given
# queue, before performing the GC. Useful for reducing frequent
# locks on the DB for non-busy queues, or for worker queues
# which process jobs quickly enough to keep the number of in-
# flight messages low.
#
# Note: The higher this number, the larger the memory-mapped DB
# files will be.
;gc_threshold = 1000
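
As a worked example, the retry sleep described above ramps
linearly with the attempt number, plus random jitter (a sketch
of the formula using the defaults shown here, not the driver's
exact code):

    import random

    def backoff_sleep(attempt, max_attempts=1000,
                      max_sleep=0.1, max_jitter=0.005):
        # Sleep ramps linearly from 0 toward max_sleep, plus jitter
        # so that parallel clients do not retry in lock-step.
        ratio = float(attempt) / max_attempts
        return ratio * max_sleep + random.random() * max_jitter

    # attempt 0 sleeps ~0 s; attempt 500 sleeps ~0.05 s,
    # plus up to 5 ms of jitter in either case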

marconi/cmd/gc.py (new file, 103 lines)

@@ -0,0 +1,103 @@
# Copyright (c) 2013 Rackspace Hosting, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
#
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import random
import sys
import termios
import time
from marconi import bootstrap
from marconi.common import config
from marconi.openstack.common import log as logging
PROJECT_CFG = config.project('marconi')
LOG = logging.getLogger(__name__)
def _fail(returncode, ex):
"""Handles terminal errors.
:param returncode: process return code to pass to sys.exit
:param ex: the error that occurred
"""
LOG.exception(ex)
sys.stderr.write('ERROR: %s\n' % ex)
sys.exit(returncode)
def _enable_echo(enable):
"""Enables or disables terminal echo.
:param enable: pass True to enable echo, False to disable
"""
fd = sys.stdin.fileno()
new = termios.tcgetattr(fd)
if enable:
new[3] |= termios.ECHO
else:
new[3] &= ~termios.ECHO
termios.tcsetattr(fd, termios.TCSANOW, new)
def run():
"""Entry point to start marconi-gc.
Operators should run 2-3 instances on different
boxes for fault-tolerance.
Note: This call blocks until the process is killed
or interrupted.
"""
atexit.register(_enable_echo, True)
_enable_echo(False)
try:
logging.setup('marconi')
PROJECT_CFG.load(args=sys.argv[1:])
info = _('Starting marconi-gc')
print(info + _('. Use CTRL+C to exit...\n'))
LOG.info(info)
boot = bootstrap.Bootstrap(cli_args=sys.argv[1:])
storage_driver = boot.storage
gc_interval = storage_driver.gc_interval
# NOTE(kgriffs): Don't want all garbage collector
# instances running at the same time (will peg the DB).
offset = random.random() * gc_interval
time.sleep(offset)
while True:
storage_driver.gc()
time.sleep(gc_interval)
except NotImplementedError as ex:
print('The configured storage driver does not support GC.\n')
LOG.exception(ex)
print('')
except KeyboardInterrupt:
LOG.info('Terminating marconi-gc')
except Exception as ex:
_fail(1, ex)


@@ -17,17 +17,17 @@
A config variable `foo` is a read-only property accessible through A config variable `foo` is a read-only property accessible through
cfg.foo CFG.foo
, where `cfg` is either a global configuration accessible through , where `CFG` is either a global configuration accessible through
cfg = config.project('marconi').from_options( CFG = config.project('marconi').from_options(
foo=("bar", "usage"), foo=("bar", "usage"),
...) ...)
, or a local configuration associated with a namespace , or a local configuration associated with a namespace
cfg = config.namespace('drivers:transport:wsgi').from_options( CFG = config.namespace('drivers:transport:wsgi').from_options(
port=80, port=80,
...) ...)
@@ -43,8 +43,8 @@ sections named by their associated namespaces.
To load the configurations from a file: To load the configurations from a file:
cfg_handle = config.project('marconi') PROJECT_CFG = config.project('marconi')
cfg_handle.load(filename="/path/to/example.conf") PROJECT_CFG.load(filename="/path/to/example.conf")
A call to `.load` without a filename looks up for the default ones: A call to `.load` without a filename looks up for the default ones:
@@ -54,7 +54,7 @@ A call to `.load` without a filename looks up for the default ones:
Global config variables, if any, can also be read from the command line Global config variables, if any, can also be read from the command line
arguments: arguments:
cfg_handle.load(filename="example.conf", args=sys.argv[1:]) PROJECT_CFG.load(filename="example.conf", args=sys.argv[1:])
""" """
from oslo.config import cfg from oslo.config import cfg


@@ -16,3 +16,7 @@
class InvalidDriver(Exception): class InvalidDriver(Exception):
pass pass
class PatternNotFound(Exception):
"""A string did not match the expected pattern or regex."""


@@ -21,6 +21,29 @@ import abc
class DriverBase: class DriverBase:
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def gc(self):
"""Runs a garbage collection operation.
Called periodically by mongo-gc to trigger removal
of expired resources from the storage provider.
If GC is supported by a given driver, the driver
MUST override this method.
"""
raise NotImplementedError
@property
def gc_interval(self):
"""Returns the GC interval, in seconds.
Used by mongo-gc to determine how often to
call driver.gc().
If GC is supported by a given driver, the driver
MUST override this method.
"""
raise NotImplementedError
@abc.abstractproperty @abc.abstractproperty
def queue_controller(self): def queue_controller(self):
"""Returns storage's queues controller.""" """Returns storage's queues controller."""


@@ -22,6 +22,45 @@ class NotPermitted(Exception):
pass pass
class Conflict(Exception):
"""Resource could not be created due to a conflict
with an existing resource.
"""
class MalformedID(ValueError):
"""ID was malformed."""
class MalformedMarker(ValueError):
"""Pagination marker was malformed."""
class MessageConflict(Conflict):
def __init__(self, queue, project, message_ids):
"""Initializes the error with contextual information.
:param queue: name of the queue to which the message was posted
:param project: name of the project to which the queue belongs
:param message_ids: list of IDs for messages successfully
posted. Note that these must be in the same order as the
list of messages originally submitted to be enqueued.
"""
msg = (_("Message could not be enqueued due to a conflict "
"with another message that is already in "
"queue %(queue)s for project %(project)s") %
dict(queue=queue, project=project))
super(MessageConflict, self).__init__(msg)
self._succeeded_ids = message_ids
@property
def succeeded_ids(self):
return self._succeeded_ids
class QueueDoesNotExist(DoesNotExist): class QueueDoesNotExist(DoesNotExist):
def __init__(self, name, project): def __init__(self, name, project):
@@ -34,7 +73,7 @@ class MessageDoesNotExist(DoesNotExist):
def __init__(self, mid, queue, project): def __init__(self, mid, queue, project):
msg = (_("Message %(mid)s does not exist in " msg = (_("Message %(mid)s does not exist in "
"queue %(queue)s of project %(project)s") % "queue %(queue)s for project %(project)s") %
dict(mid=mid, queue=queue, project=project)) dict(mid=mid, queue=queue, project=project))
super(MessageDoesNotExist, self).__init__(msg) super(MessageDoesNotExist, self).__init__(msg)
@@ -43,7 +82,7 @@ class ClaimDoesNotExist(DoesNotExist):
def __init__(self, cid, queue, project): def __init__(self, cid, queue, project):
msg = (_("Claim %(cid)s does not exist in " msg = (_("Claim %(cid)s does not exist in "
"queue %(queue)s of project %(project)s") % "queue %(queue)s for project %(project)s") %
dict(cid=cid, queue=queue, project=project)) dict(cid=cid, queue=queue, project=project))
super(ClaimDoesNotExist, self).__init__(msg) super(ClaimDoesNotExist, self).__init__(msg)


@@ -2,6 +2,5 @@
from marconi.storage.mongodb import driver from marconi.storage.mongodb import driver
# Hoist classes into package namespace # Hoist classes into package namespace
Driver = driver.Driver Driver = driver.Driver


@@ -22,25 +22,34 @@ Field Mappings:
updated and documented in each class. updated and documented in each class.
""" """
import collections
import datetime import datetime
import time
from bson import objectid from bson import objectid
import pymongo.errors
import marconi.openstack.common.log as logging
from marconi.openstack.common import timeutils from marconi.openstack.common import timeutils
from marconi import storage from marconi import storage
from marconi.storage import exceptions from marconi.storage import exceptions
from marconi.storage.mongodb import options
from marconi.storage.mongodb import utils from marconi.storage.mongodb import utils
LOG = logging.getLogger(__name__)
class QueueController(storage.QueueBase): class QueueController(storage.QueueBase):
"""Implements queue resource operations using MongoDB. """Implements queue resource operations using MongoDB.
Queues: Queues:
Name Field Name Field
---------------- ------------------
project -> p name -> n
metadata -> m project -> p
name -> n counter -> c
metadata -> m
""" """
@@ -55,6 +64,34 @@ class QueueController(storage.QueueBase):
# as specific project, for example. Order Matters! # as specific project, for example. Order Matters!
self._col.ensure_index([("p", 1), ("n", 1)], unique=True) self._col.ensure_index([("p", 1), ("n", 1)], unique=True)
#-----------------------------------------------------------------------
# Helpers
#-----------------------------------------------------------------------
def _get(self, name, project=None, fields={"m": 1, "_id": 0}):
queue = self._col.find_one({"p": project, "n": name}, fields=fields)
if queue is None:
raise exceptions.QueueDoesNotExist(name, project)
return queue
def _get_id(self, name, project=None):
"""Just like the `get` method, but only returns the queue's id
:returns: Queue's `ObjectId`
"""
queue = self._get(name, project, fields=["_id"])
return queue.get("_id")
def _get_ids(self):
"""Returns a generator producing a list of all queue IDs."""
cursor = self._col.find({}, fields={"_id": 1})
return (doc["_id"] for doc in cursor)
#-----------------------------------------------------------------------
# Interface
#-----------------------------------------------------------------------
def list(self, project=None, marker=None, def list(self, project=None, marker=None,
limit=10, detailed=False): limit=10, detailed=False):
query = {"p": project} query = {"p": project}
@@ -80,20 +117,6 @@ class QueueController(storage.QueueBase):
yield normalizer(cursor) yield normalizer(cursor)
yield marker_name["next"] yield marker_name["next"]
def _get(self, name, project=None, fields={"m": 1, "_id": 0}):
queue = self._col.find_one({"p": project, "n": name}, fields=fields)
if queue is None:
raise exceptions.QueueDoesNotExist(name, project)
return queue
def get_id(self, name, project=None):
"""Just like the `get` method, but only returns the queue's id
:returns: Queue's `ObjectId`
"""
queue = self._get(name, project, fields=["_id"])
return queue.get("_id")
def get(self, name, project=None): def get(self, name, project=None):
queue = self._get(name, project) queue = self._get(name, project)
return queue.get("m", {}) return queue.get("m", {})
@@ -102,7 +125,7 @@ class QueueController(storage.QueueBase):
super(QueueController, self).upsert(name, metadata, project) super(QueueController, self).upsert(name, metadata, project)
rst = self._col.update({"p": project, "n": name}, rst = self._col.update({"p": project, "n": name},
{"$set": {"m": metadata}}, {"$set": {"m": metadata, "c": 1}},
multi=False, multi=False,
upsert=True, upsert=True,
manipulate=False) manipulate=False)
@@ -110,14 +133,14 @@
return not rst["updatedExisting"] return not rst["updatedExisting"]
def delete(self, name, project=None): def delete(self, name, project=None):
self.driver.message_controller.purge_queue(name, project) self.driver.message_controller._purge_queue(name, project)
self._col.remove({"p": project, "n": name}) self._col.remove({"p": project, "n": name})
def stats(self, name, project=None): def stats(self, name, project=None):
qid = self.get_id(name, project) queue_id = self._get_id(name, project)
msg_ctrl = self.driver.message_controller controller = self.driver.message_controller
active = msg_ctrl.active(qid) active = controller.active(queue_id)
claimed = msg_ctrl.claimed(qid) claimed = controller.claimed(queue_id)
return { return {
"actions": 0, "actions": 0,
@@ -135,28 +158,28 @@ class MessageController(storage.MessageBase):
"""Implements message resource operations using MongoDB. """Implements message resource operations using MongoDB.
Messages: Messages:
Name Field Name Field
---------------- -----------------
queue -> q queue_id -> q
expires -> e expires -> e
ttl -> t ttl -> t
uuid -> u uuid -> u
claim -> c claim -> c
marker -> k
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(MessageController, self).__init__(*args, **kwargs) super(MessageController, self).__init__(*args, **kwargs)
# Cache for convenience and performance (avoids extra lookups and
# recreating the range for every request.)
self._queue_controller = self.driver.queue_controller
self._db = self.driver.db
self._retry_range = range(options.CFG.max_attempts)
# Make sure indexes exist before, # Make sure indexes exist before,
# doing anything. # doing anything.
self._col = self.driver.db["messages"] self._col = self._db["messages"]
# NOTE(flaper87): Let's make sure we clean up
# expired messages. Notice that TTL indexes run
# a clean up thread every minute, this means that
# every message would have an implicit 1min grace
# if we don't filter them out in the active method.
self._col.ensure_index("e", background=True,
expireAfterSeconds=0)
# NOTE(flaper87): This index is used mostly in the # NOTE(flaper87): This index is used mostly in the
# active method but some parts of it are used in # active method but some parts of it are used in
@@ -165,33 +188,184 @@ class MessageController(storage.MessageBase):
# beginning of the index. # beginning of the index.
# * e: Together with q is used for getting a # * e: Together with q is used for getting a
# specific message. (see `get`) # specific message. (see `get`)
self._col.ensure_index([("q", 1), active_fields = [
("e", 1), ("q", 1),
("c.e", 1), ("e", 1),
("_id", -1)], background=True) ("c.e", 1),
("k", 1),
("_id", -1),
]
# Indexes used for claims self._col.ensure_index(active_fields,
self._col.ensure_index([("q", 1), name="active",
("c.id", 1), background=True)
("c.e", 1),
("_id", -1)], background=True)
def _get_queue_id(self, queue, project): # Index used for claims
queue_controller = self.driver.queue_controller claimed_fields = [
return queue_controller.get_id(queue, project) ("q", 1),
("c.id", 1),
("c.e", 1),
("_id", -1),
]
self._col.ensure_index(claimed_fields,
name="claimed",
background=True)
# Index used for _next_marker() and also to ensure
# uniqueness.
#
# NOTE(kgriffs): This index must be unique so that
# inserting a message with the same marker to the
# same queue will fail; this is used to detect a
# race condition which can cause an observer client
# to miss a message when there is more than one
# producer posting messages to the same queue, in
# parallel.
self._col.ensure_index([("q", 1), ("k", -1)],
name="queue_marker",
unique=True,
background=True)
#-----------------------------------------------------------------------
# Helpers
#-----------------------------------------------------------------------
def _get_queue_id(self, queue, project=None):
return self._queue_controller._get_id(queue, project)
def _get_queue_ids(self):
return self._queue_controller._get_ids()
def _next_marker(self, queue_id):
"""Retrieves the next message marker for a given queue.
This helper is used to generate monotonic pagination
markers that are saved as part of the message
document. Simply taking the max of the current message
markers works, since Marconi always leaves the most recent
message in the queue (new queues always return 1).
Note 1: Markers are scoped per-queue and so are *not*
globally unique or globally ordered.
Note 2: If two or more requests to this method are made
in parallel, this method will return the same
marker. This is done intentionally so that the caller
can detect a parallel message post, allowing it to
mitigate race conditions between producer and
observer clients.
:param queue_id: queue ID
:returns: next message marker as an integer
"""
document = self._col.find_one({"q": queue_id},
sort=[("k", -1)],
fields={"k": 1, "_id": 0})
# NOTE(kgriffs): this approach is faster than using "or"
return 1 if document is None else (document["k"] + 1)
def _backoff_sleep(self, attempt):
"""Sleep between retries using a jitter algorithm.
Mitigates thrashing between multiple parallel requests, and
creates backpressure on clients to slow down the rate
at which they submit requests.
:param attempt: current attempt number, zero-based
"""
seconds = utils.calculate_backoff(attempt, options.CFG.max_attempts,
options.CFG.max_retry_sleep,
options.CFG.max_retry_jitter)
time.sleep(seconds)
def _count_expired(self, queue_id):
"""Counts the number of expired messages in a queue.
:param queue_id: id for the queue to stat
"""
query = {
"q": queue_id,
"e": {"$lte": timeutils.utcnow()},
}
return self._col.find(query).count()
def _remove_expired(self, queue_id):
"""Removes all expired messages except for the most recent
in each queue.
This method is used in lieu of mongo's TTL index since we
must always leave at least one message in the queue for
calculating the next marker.
Note that expired messages are only removed if their count
exceeds options.CFG.gc_threshold.
:param queue_id: id for the queue from which to remove
expired messages
"""
if options.CFG.gc_threshold <= self._count_expired(queue_id):
# Get the message with the highest marker, and leave
# it in the queue
head = self._col.find_one({"q": queue_id},
sort=[("k", -1)],
fields={"_id": 1})
if head is None:
# Assume queue was just deleted via a parallel request
LOG.warning(_("Queue %s is empty or missing.") % queue_id)
return
query = {
"q": queue_id,
"e": {"$lte": timeutils.utcnow()},
"_id": {"$ne": head["_id"]}
}
self._col.remove(query)
def _purge_queue(self, queue, project=None):
"""Removes all messages from the queue.
Warning: Only use this when deleting the queue; otherwise
you can cause a side effect of resetting the marker counter
which can cause clients to miss tons of messages.
If the queue does not exist, this method fails silently.
:param queue: name of the queue to purge
:param project: name of the project to which the queue belongs
"""
try:
qid = self._get_queue_id(queue, project)
self._col.remove({"q": qid}, w=0)
except exceptions.QueueDoesNotExist:
pass
#-----------------------------------------------------------------------
# Interface
#-----------------------------------------------------------------------
def all(self): def all(self):
return self._col.find() return self._col.find()
def active(self, queue, marker=None, echo=False, def active(self, queue_id, marker=None, echo=False,
client_uuid=None, fields=None): client_uuid=None, fields=None):
now = timeutils.utcnow() now = timeutils.utcnow()
query = { query = {
# Messages must belong to this queue # Messages must belong to this queue
"q": utils.to_oid(queue), "q": utils.to_oid(queue_id),
# The messages can not be expired
"e": {"$gt": now}, "e": {"$gt": now},
# Include messages that are part of expired claims
"c.e": {"$lte": now}, "c.e": {"$lte": now},
} }
@@ -202,17 +376,17 @@ class MessageController(storage.MessageBase):
query["u"] = {"$ne": client_uuid} query["u"] = {"$ne": client_uuid}
if marker: if marker:
query["_id"] = {"$gt": utils.to_oid(marker)} query["k"] = {"$gt": marker}
return self._col.find(query, fields=fields) return self._col.find(query, fields=fields)
def claimed(self, queue, claim_id=None, expires=None, limit=None): def claimed(self, queue_id, claim_id=None, expires=None, limit=None):
query = { query = {
"c.id": claim_id, "c.id": claim_id,
"c.e": {"$gt": expires or timeutils.utcnow()}, "c.e": {"$gt": expires or timeutils.utcnow()},
"q": utils.to_oid(queue), "q": utils.to_oid(queue_id),
} }
if not claim_id: if not claim_id:
# lookup over c.id to use the index # lookup over c.id to use the index
query["c.id"] = {"$ne": None} query["c.id"] = {"$ne": None}
@@ -243,18 +417,50 @@ class MessageController(storage.MessageBase):
cid = utils.to_oid(claim_id) cid = utils.to_oid(claim_id)
except ValueError: except ValueError:
return return
self._col.update({"c.id": cid}, self._col.update({"c.id": cid},
{"$set": {"c": {"id": None, "e": 0}}}, {"$set": {"c": {"id": None, "e": 0}}},
upsert=False, multi=True) upsert=False, multi=True)
def remove_expired(self, project=None):
"""Removes all expired messages except for the most recent
in each queue.
This method is used in lieu of mongo's TTL index since we
must always leave at least one message in the queue for
calculating the next marker.
Warning: This method is expensive, since it must perform
separate queries for each queue, due to the requirement that
it must leave at least one message in each queue, and it
is impractical to send a huge list of _id's to filter out
in a single call. That being said, this is somewhat mitigated
by the gc_threshold configuration option, which reduces the
frequency at which the DB is locked for non-busy queues. Also,
since .remove is run on each queue separately, this reduces
the duration that any given lock is held, avoiding blocking
regular writes.
"""
# TODO(kgriffs): Optimize first by batching the .removes, second
# by setting a "last inserted ID" in the queue collection for
# each message inserted (TBD, may cause problematic side-effect),
# and third, by changing the marker algorithm such that it no
# longer depends on retaining the last message in the queue!
for id in self._get_queue_ids():
self._remove_expired(id)
def list(self, queue, project=None, marker=None, def list(self, queue, project=None, marker=None,
limit=10, echo=False, client_uuid=None): limit=10, echo=False, client_uuid=None):
try: if marker is not None:
qid = self._get_queue_id(queue, project) try:
messages = self.active(qid, marker, echo, client_uuid) marker = int(marker)
except ValueError: except ValueError:
return raise exceptions.MalformedMarker()
qid = self._get_queue_id(queue, project)
messages = self.active(qid, marker, echo, client_uuid)
messages = messages.limit(limit).sort("_id") messages = messages.limit(limit).sort("_id")
marker_id = {} marker_id = {}
@@ -264,7 +470,7 @@ class MessageController(storage.MessageBase):
def denormalizer(msg): def denormalizer(msg):
oid = msg["_id"] oid = msg["_id"]
age = now - utils.oid_utc(oid) age = now - utils.oid_utc(oid)
marker_id['next'] = oid marker_id['next'] = msg["k"]
return { return {
"id": str(oid), "id": str(oid),
@@ -277,15 +483,10 @@ class MessageController(storage.MessageBase):
yield str(marker_id['next']) yield str(marker_id['next'])
def get(self, queue, message_id, project=None): def get(self, queue, message_id, project=None):
mid = utils.to_oid(message_id)
# Base query, always check expire time
try:
mid = utils.to_oid(message_id)
except ValueError:
raise exceptions.MessageDoesNotExist(message_id, queue, project)
now = timeutils.utcnow() now = timeutils.utcnow()
# Base query, always check expire time
query = { query = {
"q": self._get_queue_id(queue, project), "q": self._get_queue_id(queue, project),
"e": {"$gt": now}, "e": {"$gt": now},
@@ -308,33 +509,114 @@ class MessageController(storage.MessageBase):
} }
def post(self, queue, messages, client_uuid, project=None): def post(self, queue, messages, client_uuid, project=None):
qid = self._get_queue_id(queue, project)
now = timeutils.utcnow() now = timeutils.utcnow()
queue_id = self._get_queue_id(queue, project)
def denormalizer(messages): # Set the next basis marker for the first attempt.
for msg in messages: next_marker = self._next_marker(queue_id)
ttl = int(msg["ttl"])
expires = now + datetime.timedelta(seconds=ttl)
yield { # Results are aggregated across all attempts
"t": ttl, # NOTE(kgriffs): lazy instantiation
"q": qid, aggregated_results = None
"e": expires,
"u": client_uuid,
"c": {"id": None, "e": now},
"b": msg['body'] if 'body' in msg else {}
}
ids = self._col.insert(denormalizer(messages)) # NOTE(kgriffs): This avoids iterating over messages twice,
return map(str, ids) # since pymongo internally will iterate over them all to
# encode as bson before submitting to mongod. By using a
# generator, we can produce each message only once,
# as needed by pymongo. At the same time, each message is
# cached in case we need to retry any of them.
message_gen = (
{
"t": message["ttl"],
"q": queue_id,
"e": now + datetime.timedelta(seconds=message["ttl"]),
"u": client_uuid,
"c": {"id": None, "e": now},
"b": message["body"] if "body" in message else {},
"k": next_marker + index,
}
for index, message in enumerate(messages)
)
prepared_messages, cached_messages = utils.cached_gen(message_gen)
# Use a retry range for sanity, although we expect
# to rarely, if ever, reach the maximum number of
# retries.
for attempt in self._retry_range:
try:
ids = self._col.insert(prepared_messages)
# NOTE(kgriffs): Only use aggregated results if we must,
# which saves some cycles on the happy path.
if aggregated_results:
aggregated_results.extend(ids)
ids = aggregated_results
return map(str, ids)
except pymongo.errors.DuplicateKeyError as ex:
# Try again with the remaining messages
# TODO(kgriffs): Record stats of how often retries happen,
# and how many attempts, on average, are required to insert
# messages.
# NOTE(kgriffs): Slice prepared_messages. We have to interpret
# the error message to get the duplicate key, which gives
# us the marker that had a dupe, allowing us to extrapolate
# how many messages were consumed, since markers are monotonic
# counters.
duplicate_marker = utils.dup_marker_from_error(str(ex))
failed_index = duplicate_marker - next_marker
# First time here, convert the deque to a list
# to support slicing.
if isinstance(cached_messages, collections.deque):
cached_messages = list(cached_messages)
# Put the successful ones' IDs into aggregated_results.
succeeded_messages = cached_messages[:failed_index]
succeeded_ids = [msg["_id"] for msg in succeeded_messages]
# Results are aggregated across all attempts
if aggregated_results is None:
aggregated_results = succeeded_ids
else:
aggregated_results.extend(succeeded_ids)
# Retry the remaining messages with a new sequence
# of markers.
prepared_messages = cached_messages[failed_index:]
next_marker = self._next_marker(queue_id)
for index, message in enumerate(prepared_messages):
message["k"] = next_marker + index
self._backoff_sleep(attempt)
except Exception as ex:
# TODO(kgriffs): Query the DB to get the last marker that
# made it, and extrapolate from there to figure out what
# needs to be retried. Definitely retry on AutoReconnect;
# other types of errors TBD.
LOG.exception(ex)
raise
message = _("Hit maximum number of attempts (%(max)s) for queue "
"%(id)s in project %(project)s")
message %= dict(max=options.CFG.max_attempts, id=queue_id,
project=project)
LOG.warning(message)
succeeded_ids = map(str, aggregated_results)
raise exceptions.MessageConflict(queue, project, succeeded_ids)
def delete(self, queue, message_id, project=None, claim=None): def delete(self, queue, message_id, project=None, claim=None):
try: try:
try: mid = utils.to_oid(message_id)
mid = utils.to_oid(message_id)
except ValueError:
return
query = { query = {
"q": self._get_queue_id(queue, project), "q": self._get_queue_id(queue, project),
@@ -349,10 +631,7 @@ class MessageController(storage.MessageBase):
if message is None: if message is None:
return return
try: cid = utils.to_oid(claim)
cid = utils.to_oid(claim)
except ValueError:
raise exceptions.ClaimNotPermitted(message_id, claim)
if not ("c" in message and if not ("c" in message and
message["c"]["id"] == cid and message["c"]["id"] == cid and
@@ -365,13 +644,6 @@ class MessageController(storage.MessageBase):
except exceptions.QueueDoesNotExist: except exceptions.QueueDoesNotExist:
pass pass
def purge_queue(self, queue, project=None):
try:
qid = self._get_queue_id(queue, project)
self._col.remove({"q": qid}, w=0)
except exceptions.QueueDoesNotExist:
pass
class ClaimController(storage.ClaimBase): class ClaimController(storage.ClaimBase):
"""Implements claim resource operations using MongoDB. """Implements claim resource operations using MongoDB.
@@ -396,7 +668,7 @@ class ClaimController(storage.ClaimBase):
def _get_queue_id(self, queue, project): def _get_queue_id(self, queue, project):
queue_controller = self.driver.queue_controller queue_controller = self.driver.queue_controller
return queue_controller.get_id(queue, project) return queue_controller._get_id(queue, project)
def get(self, queue, claim_id, project=None): def get(self, queue, claim_id, project=None):
msg_ctrl = self.driver.message_controller msg_ctrl = self.driver.message_controller


@@ -18,36 +18,46 @@
import pymongo import pymongo
import pymongo.errors import pymongo.errors
from marconi.common import config from marconi.openstack.common import log as logging
from marconi import storage from marconi import storage
from marconi.storage.mongodb import controllers from marconi.storage.mongodb import controllers
from marconi.storage.mongodb import options
options = { LOG = logging.getLogger(__name__)
"uri": None,
"database": "marconi",
}
cfg = config.namespace('drivers:storage:mongodb').from_options(**options)
class Driver(storage.DriverBase): class Driver(storage.DriverBase):
def __init__(self): def __init__(self):
# Lazy instantiation
self._database = None self._database = None
@property @property
def db(self): def db(self):
"""Property for lazy instantiation of mongodb's database.""" """Property for lazy instantiation of mongodb's database."""
if not self._database: if self._database is None:
if cfg.uri and 'replicaSet' in cfg.uri: if options.CFG.uri and 'replicaSet' in options.CFG.uri:
conn = pymongo.MongoReplicaSetClient(cfg.uri) conn = pymongo.MongoReplicaSetClient(options.CFG.uri)
else: else:
conn = pymongo.MongoClient(cfg.uri) conn = pymongo.MongoClient(options.CFG.uri)
self._database = conn[cfg.database] self._database = conn[options.CFG.database]
return self._database return self._database
def gc(self):
LOG.info("Performing garbage collection.")
try:
self.message_controller.remove_expired()
except pymongo.errors.ConnectionFailure as ex:
# Better luck next time...
LOG.exception(ex)
@property
def gc_interval(self):
return options.CFG.gc_interval
@property @property
def queue_controller(self): def queue_controller(self):
return controllers.QueueController(self) return controllers.QueueController(self)


@@ -0,0 +1,57 @@
# Copyright (c) 2013 Rackspace Hosting, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
#
# See the License for the specific language governing permissions and
# limitations under the License.
"""MongoDB storage driver configuration options."""
from marconi.common import config
OPTIONS = {
# Connection string
"uri": None,
# Database name
#TODO(kgriffs): Consider local sharding across DBs to mitigate
# per-DB locking latency.
"database": "marconi",
# Maximum number of times to retry a failed operation. Currently
# only used for retrying a message post.
"max_attempts": 1000,
# Maximum sleep interval between retries (actual sleep time
# increases linearly according to number of attempts performed).
"max_retry_sleep": 0.1,
# Maximum jitter interval, to be added to the sleep interval, in
# order to decrease probability that parallel requests will retry
# at the same instant.
"max_retry_jitter": 0.005,
# Frequency of message garbage collections, in seconds
"gc_interval": 5 * 60,
# Threshold of number of expired messages to reach in a given
# queue, before performing the GC. Useful for reducing frequent
# locks on the DB for non-busy queues, or for worker queues
# which process jobs quickly enough to keep the number of in-
# flight messages low.
#
# Note: The higher this number, the larger the memory-mapped DB
# files will be.
"gc_threshold": 1000,
}
CFG = config.namespace('drivers:storage:mongodb').from_options(**OPTIONS)


@@ -13,27 +13,118 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import random
import re
from bson import errors as berrors from bson import errors as berrors
from bson import objectid from bson import objectid
from marconi.common import exceptions
from marconi.openstack.common import timeutils from marconi.openstack.common import timeutils
from marconi.storage import exceptions as storage_exceptions
DUP_MARKER_REGEX = re.compile(r"\$queue_marker\s+dup key: { : [^:]+: (\d+)")
def dup_marker_from_error(error_message):
"""Extracts the duplicate marker from a MongoDB error string.
:param error_message: raw error message string returned
by mongod on a duplicate key error.
:raises: marconi.common.exceptions.PatternNotFound
:returns: extracted marker as an integer
"""
match = DUP_MARKER_REGEX.search(error_message)
if match is None:
description = (_("Error message could not be parsed: %s") %
error_message)
raise exceptions.PatternNotFound(description)
return int(match.groups()[0])
def cached_gen(iterable):
"""Converts the iterable into a caching generator.
Returns a proxy that yields each item of iterable, while at
the same time caching those items in a deque.
:param iterable: an iterable to wrap in a caching generator
:returns: (proxy(iterable), cached_items)
"""
cached_items = collections.deque()
def generator(iterable):
for item in iterable:
cached_items.append(item)
yield item
return (generator(iterable), cached_items)
def calculate_backoff(attempt, max_attempts, max_sleep, max_jitter=0):
"""Calculates backoff time, in seconds, when retrying an operation.
This function calculates a simple linear backoff time with
optional jitter, useful for retrying a request under high
concurrency.
The result may be passed directly into time.sleep() in order to
mitigate stampeding herd syndrome and introduce backpressure towards
the clients, slowing them down.
:param attempt: current value of the attempt counter (zero-based)
:param max_attempts: maximum number of attempts that will be tried
:param max_sleep: maximum sleep value to apply before jitter, assumed
to be seconds. Fractional seconds are supported to 1 ms
granularity.
:param max_jitter: maximum jitter value to add to the baseline sleep
time. Actual value will be chosen randomly.
:raises: ValueError
:returns: float representing the number of seconds to sleep, within
the interval [0, max_sleep), determined linearly according to
the ratio attempt / max_attempts, with optional jitter.
"""
if max_attempts < 0:
raise ValueError("max_attempts must be >= 0")
if max_sleep < 0:
raise ValueError("max_sleep must be >= 0")
if max_jitter < 0:
raise ValueError("max_jitter must be >= 0")
if not (0 <= attempt < max_attempts):
raise ValueError("attempt value is out of range")
ratio = float(attempt) / float(max_attempts)
backoff_sec = ratio * max_sleep
jitter_sec = random.random() * max_jitter
return backoff_sec + jitter_sec
def to_oid(obj): def to_oid(obj):
"""Creates a new ObjectId based on the input. """Creates a new ObjectId based on the input.
Raises ValueError when TypeError or InvalidId Raises MalformedID when TypeError or berrors.InvalidId
is raised by the ObjectID class. is raised by the ObjectID class.
:param obj: Anything that can be passed as an :param obj: Anything that can be passed as an
input to `objectid.ObjectId` input to `objectid.ObjectId`
:raises: ValueError
:raises: MalformedID
""" """
try: try:
return objectid.ObjectId(obj) return objectid.ObjectId(obj)
except (TypeError, berrors.InvalidId): except (TypeError, berrors.InvalidId):
msg = _("Wrong id %s") % obj msg = _("Wrong id %s") % obj
raise ValueError(msg) raise storage_exceptions.MalformedID(msg)
def oid_utc(oid): def oid_utc(oid):


@@ -160,52 +160,49 @@ class Message(base.MessageBase):
'body': content, 'body': content,
} }
except (_NoResult, _BadID): except _NoResult:
raise exceptions.MessageDoesNotExist(message_id, queue, project) raise exceptions.MessageDoesNotExist(message_id, queue, project)
def list(self, queue, project, marker=None, def list(self, queue, project, marker=None,
limit=10, echo=False, client_uuid=None): limit=10, echo=False, client_uuid=None):
with self.driver('deferred'): with self.driver('deferred'):
try: sql = '''
sql = ''' select id, content, ttl, julianday() * 86400.0 - created
select id, content, ttl, julianday() * 86400.0 - created from Messages
from Messages where ttl > julianday() * 86400.0 - created
where ttl > julianday() * 86400.0 - created and qid = ?'''
and qid = ?''' args = [_get_qid(self.driver, queue, project)]
args = [_get_qid(self.driver, queue, project)]
if not echo:
sql += '''
and client != ?'''
args += [client_uuid]
if marker:
sql += '''
and id > ?'''
args += [_marker_decode(marker)]
if not echo:
sql += ''' sql += '''
limit ?''' and client != ?'''
args += [limit] args += [client_uuid]
records = self.driver.run(sql, *args) if marker:
marker_id = {} sql += '''
and id > ?'''
args += [_marker_decode(marker)]
def it(): sql += '''
for id, content, ttl, age in records: limit ?'''
marker_id['next'] = id args += [limit]
yield {
'id': _msgid_encode(id),
'ttl': ttl,
'age': int(age),
'body': content,
}
yield it() records = self.driver.run(sql, *args)
yield _marker_encode(marker_id['next']) marker_id = {}
except _BadID: def it():
return for id, content, ttl, age in records:
marker_id['next'] = id
yield {
'id': _msgid_encode(id),
'ttl': ttl,
'age': int(age),
'body': content,
}
yield it()
yield _marker_encode(marker_id['next'])
def post(self, queue, messages, client_uuid, project): def post(self, queue, messages, client_uuid, project):
with self.driver('immediate'): with self.driver('immediate'):
@@ -238,52 +235,44 @@ class Message(base.MessageBase):
return map(_msgid_encode, range(unused, my['newid'])) return map(_msgid_encode, range(unused, my['newid']))
def delete(self, queue, message_id, project, claim=None): def delete(self, queue, message_id, project, claim=None):
try: id = _msgid_decode(message_id)
id = _msgid_decode(message_id)
if not claim: if not claim:
self.driver.run('''
delete from Messages
where id = ?
and qid = (select id from Queues
where project = ? and name = ?)
''', id, project, queue)
return
with self.driver('immediate'):
message_exists, = self.driver.get('''
select count(M.id)
from Queues as Q join Messages as M
on qid = Q.id
where ttl > julianday() * 86400.0 - created
and M.id = ? and project = ? and name = ?
''', id, project, queue)
if not message_exists:
return
self.__delete_claimed(id, claim)
except _BadID:
pass
def __delete_claimed(self, id, claim):
# Precondition: id exists in a specific queue
try:
self.driver.run(''' self.driver.run('''
delete from Messages delete from Messages
where id = ? where id = ?
and id in (select msgid and qid = (select id from Queues
from Claims join Locked where project = ? and name = ?)
on id = cid ''', id, project, queue)
where ttl > julianday() * 86400.0 - created return
and id = ?)
''', id, _cid_decode(claim))
if not self.driver.affected: with self.driver('immediate'):
raise exceptions.ClaimNotPermitted(_msgid_encode(id), claim) message_exists, = self.driver.get('''
select count(M.id)
from Queues as Q join Messages as M
on qid = Q.id
where ttl > julianday() * 86400.0 - created
and M.id = ? and project = ? and name = ?
''', id, project, queue)
except _BadID: if not message_exists:
return
self.__delete_claimed(id, claim)
def __delete_claimed(self, id, claim):
# Precondition: id exists in a specific queue
self.driver.run('''
delete from Messages
where id = ?
and id in (select msgid
from Claims join Locked
on id = cid
where ttl > julianday() * 86400.0 - created
and id = ?)
''', id, _cid_decode(claim))
if not self.driver.affected:
raise exceptions.ClaimNotPermitted(_msgid_encode(id), claim) raise exceptions.ClaimNotPermitted(_msgid_encode(id), claim)
@@ -332,7 +321,7 @@ class Claim(base.ClaimBase):
self.__get(id) self.__get(id)
) )
except (_NoResult, _BadID): except (_NoResult, exceptions.MalformedID):
raise exceptions.ClaimDoesNotExist(claim_id, queue, project) raise exceptions.ClaimDoesNotExist(claim_id, queue, project)
def create(self, queue, metadata, project, limit=10): def create(self, queue, metadata, project, limit=10):
@@ -386,30 +375,29 @@ class Claim(base.ClaimBase):
def update(self, queue, claim_id, metadata, project): def update(self, queue, claim_id, metadata, project):
try: try:
id = _cid_decode(claim_id) id = _cid_decode(claim_id)
except exceptions.MalformedID:
with self.driver('deferred'):
# still delay the cleanup here
self.driver.run('''
update Claims
set created = julianday() * 86400.0,
ttl = ?
where ttl > julianday() * 86400.0 - created
and id = ?
and qid = (select id from Queues
where project = ? and name = ?)
''', metadata['ttl'], id, project, queue)
if not self.driver.affected:
raise exceptions.ClaimDoesNotExist(claim_id,
queue,
project)
self.__update_claimed(id, metadata['ttl'])
except _BadID:
raise exceptions.ClaimDoesNotExist(claim_id, queue, project) raise exceptions.ClaimDoesNotExist(claim_id, queue, project)
with self.driver('deferred'):
# still delay the cleanup here
self.driver.run('''
update Claims
set created = julianday() * 86400.0,
ttl = ?
where ttl > julianday() * 86400.0 - created
and id = ?
and qid = (select id from Queues
where project = ? and name = ?)
''', metadata['ttl'], id, project, queue)
if not self.driver.affected:
raise exceptions.ClaimDoesNotExist(claim_id,
queue,
project)
self.__update_claimed(id, metadata['ttl'])
def __update_claimed(self, cid, ttl): def __update_claimed(self, cid, ttl):
# Precondition: cid is not expired # Precondition: cid is not expired
self.driver.run(''' self.driver.run('''
@@ -423,25 +411,22 @@ class Claim(base.ClaimBase):
def delete(self, queue, claim_id, project): def delete(self, queue, claim_id, project):
try: try:
self.driver.run(''' cid = _cid_decode(claim_id)
delete from Claims except exceptions.MalformedID:
where id = ? return
and qid = (select id from Queues
where project = ? and name = ?)
''', _cid_decode(claim_id), project, queue)
except _BadID: self.driver.run('''
pass delete from Claims
where id = ?
and qid = (select id from Queues
where project = ? and name = ?)
''', cid, project, queue)
class _NoResult(Exception): class _NoResult(Exception):
pass pass
class _BadID(Exception):
pass
def _get_qid(driver, queue, project): def _get_qid(driver, queue, project):
try: try:
return driver.get(''' return driver.get('''
@@ -469,7 +454,7 @@ def _msgid_decode(id):
return int(id, 16) ^ 0x5c693a53 return int(id, 16) ^ 0x5c693a53
except ValueError: except ValueError:
raise _BadID raise exceptions.MalformedID()
def _marker_encode(id): def _marker_encode(id):
@@ -481,7 +466,7 @@ def _marker_decode(id):
return int(id, 8) ^ 0x3c96a355 return int(id, 8) ^ 0x3c96a355
except ValueError: except ValueError:
raise _BadID raise exceptions.MalformedMarker()
def _cid_encode(id): def _cid_encode(id):
@@ -493,4 +478,4 @@ def _cid_decode(id):
return int(id, 16) ^ 0x63c9a59c return int(id, 16) ^ 0x63c9a59c
except ValueError: except ValueError:
raise _BadID raise exceptions.MalformedID()


@@ -23,15 +23,14 @@ from marconi.common import config
from marconi import storage from marconi import storage
from marconi.storage.sqlite import controllers from marconi.storage.sqlite import controllers
CFG = config.namespace('drivers:storage:sqlite').from_options(
cfg = config.namespace('drivers:storage:sqlite').from_options(
database=':memory:') database=':memory:')
class Driver(storage.DriverBase): class Driver(storage.DriverBase):
def __init__(self): def __init__(self):
self.__path = cfg.database self.__path = CFG.database
self.__conn = sqlite3.connect(self.__path, self.__conn = sqlite3.connect(self.__path,
detect_types=sqlite3.PARSE_DECLTYPES) detect_types=sqlite3.PARSE_DECLTYPES)
self.__db = self.__conn.cursor() self.__db = self.__conn.cursor()


@@ -8,3 +8,4 @@ port = 8888
[drivers:storage:mongodb] [drivers:storage:mongodb]
uri = "mongodb://127.0.0.1:27017" uri = "mongodb://127.0.0.1:27017"
database = "marconi_test" database = "marconi_test"
gc_threshold = 100


@@ -243,31 +243,48 @@ class MessageControllerTest(ControllerBaseTest):
project=self.project) project=self.project)
self.assertEquals(countof['messages']['free'], 0) self.assertEquals(countof['messages']['free'], 0)
def test_illformed_id(self): def test_bad_id(self):
# any ill-formed IDs should be regarded as non-existing ones. # A malformed ID should result in an error. This
# doesn't hurt anything, since an attacker could just
# read the source code anyway to find out how IDs are
# implemented. Plus, if someone is just trying to
# get a message that they don't own, they would
# more likely just list the messages, not try to
# guess an ID of an arbitrary message.
self.queue_controller.upsert('unused', {}, '480924') queue = 'foo'
self.controller.delete('unused', 'illformed', '480924') project = '480924'
self.queue_controller.upsert(queue, {}, project)
msgs = list(self.controller.list('unused', '480924', bad_message_id = 'xyz'
marker='illformed')) with testing.expect(exceptions.MalformedID):
self.controller.delete(queue, bad_message_id, project)
self.assertEquals(len(msgs), 0) with testing.expect(exceptions.MalformedID):
self.controller.get(queue, bad_message_id, project)
with testing.expect(exceptions.DoesNotExist): def test_bad_claim_id(self):
self.controller.get('unused', 'illformed', '480924')
def test_illformed_claim(self):
self.queue_controller.upsert('unused', {}, '480924') self.queue_controller.upsert('unused', {}, '480924')
[msgid] = self.controller.post('unused', [msgid] = self.controller.post('unused',
[{'body': {}, 'ttl': 10}], [{'body': {}, 'ttl': 10}],
project='480924', project='480924',
client_uuid='unused') client_uuid='unused')
with testing.expect(exceptions.NotPermitted): bad_claim_id = '; DROP TABLE queues'
with testing.expect(exceptions.MalformedID):
self.controller.delete('unused', msgid, self.controller.delete('unused', msgid,
project='480924', project='480924',
claim='illformed') claim=bad_claim_id)
def test_bad_marker(self):
queue = 'foo'
project = '480924'
self.queue_controller.upsert(queue, {}, project)
bad_marker = 'xyz'
func = self.controller.list
results = func(queue, project, marker=bad_marker)
self.assertRaises(exceptions.MalformedMarker, results.next)
class ClaimControllerTest(ControllerBaseTest): class ClaimControllerTest(ControllerBaseTest):
@@ -378,7 +395,7 @@ def _insert_fixtures(controller, queue_name, project=None,
def messages(): def messages():
for n in xrange(num): for n in xrange(num):
yield { yield {
"ttl": 60, "ttl": 120,
"body": { "body": {
"event": "Event number %s" % n "event": "Event number %s" % n
}} }}


@@ -14,17 +14,70 @@
# limitations under the License. # limitations under the License.
import os import os
import random
import time import time
from marconi.common import config from testtools import matchers
from marconi.common import exceptions
from marconi import storage from marconi import storage
from marconi.storage import mongodb from marconi.storage import mongodb
from marconi.storage.mongodb import controllers from marconi.storage.mongodb import controllers
from marconi.storage.mongodb import options as mongodb_options
from marconi.storage.mongodb import utils
from marconi.tests.storage import base from marconi.tests.storage import base
from marconi.tests import util as testing from marconi.tests import util as testing
cfg = config.namespace("drivers:storage:mongodb").from_options() class MongodbUtilsTest(testing.TestBase):
def test_dup_marker_from_error(self):
error_message = ("E11000 duplicate key error index: "
"marconi.messages.$queue_marker dup key: "
"{ : ObjectId('51adff46b100eb85d8a93a2d'), : 3 }")
marker = utils.dup_marker_from_error(error_message)
self.assertEquals(marker, 3)
error_message = ("E11000 duplicate key error index: "
"marconi.messages.$x_y dup key: "
"{ : ObjectId('51adff46b100eb85d8a93a2d'), : 3 }")
self.assertRaises(exceptions.PatternNotFound,
utils.dup_marker_from_error, error_message)
error_message = ("E11000 duplicate key error index: "
"marconi.messages.$queue_marker dup key: "
"{ : ObjectId('51adff46b100eb85d8a93a2d') }")
self.assertRaises(exceptions.PatternNotFound,
utils.dup_marker_from_error, error_message)
def test_calculate_backoff(self):
sec = utils.calculate_backoff(0, 10, 2, 0)
self.assertEquals(sec, 0)
sec = utils.calculate_backoff(9, 10, 2, 0)
self.assertEquals(sec, 1.8)
sec = utils.calculate_backoff(4, 10, 2, 0)
self.assertEquals(sec, 0.8)
sec = utils.calculate_backoff(4, 10, 2, 1)
if sec != 0.8:
self.assertThat(sec, matchers.GreaterThan(0.8))
self.assertThat(sec, matchers.LessThan(1.8))
self.assertRaises(ValueError, utils.calculate_backoff, 0, 10, -2, -1)
self.assertRaises(ValueError, utils.calculate_backoff, 0, 10, -2, 0)
self.assertRaises(ValueError, utils.calculate_backoff, 0, 10, 2, -1)
self.assertRaises(ValueError, utils.calculate_backoff, -2, -10, 2, 0)
self.assertRaises(ValueError, utils.calculate_backoff, 2, -10, 2, 0)
self.assertRaises(ValueError, utils.calculate_backoff, -2, 10, 2, 0)
self.assertRaises(ValueError, utils.calculate_backoff, -1, 10, 2, 0)
self.assertRaises(ValueError, utils.calculate_backoff, 10, 10, 2, 0)
self.assertRaises(ValueError, utils.calculate_backoff, 11, 10, 2, 0)
class MongodbDriverTest(testing.TestBase): class MongodbDriverTest(testing.TestBase):
@@ -39,7 +92,7 @@ class MongodbDriverTest(testing.TestBase):
def test_db_instance(self): def test_db_instance(self):
driver = mongodb.Driver() driver = mongodb.Driver()
db = driver.db db = driver.db
self.assertEquals(db.name, cfg.database) self.assertEquals(db.name, mongodb_options.CFG.database)
class MongodbQueueTests(base.QueueControllerTest): class MongodbQueueTests(base.QueueControllerTest):
@@ -66,7 +119,7 @@ class MongodbQueueTests(base.QueueControllerTest):
def test_messages_purged(self): def test_messages_purged(self):
queue_name = "test" queue_name = "test"
self.controller.upsert(queue_name, {}) self.controller.upsert(queue_name, {})
qid = self.controller.get_id(queue_name) qid = self.controller._get_id(queue_name)
self.message_controller.post(queue_name, self.message_controller.post(queue_name,
[{"ttl": 60}], [{"ttl": 60}],
1234) 1234)
@@ -91,12 +144,92 @@ class MongodbMessageTests(base.MessageControllerTest):
self.controller._col.drop() self.controller._col.drop()
super(MongodbMessageTests, self).tearDown() super(MongodbMessageTests, self).tearDown()
def _count_expired(self, queue, project=None):
queue_id = self.queue_controller._get_id(queue, project)
return self.controller._count_expired(queue_id)
def test_indexes(self): def test_indexes(self):
col = self.controller._col col = self.controller._col
indexes = col.index_information() indexes = col.index_information()
#self.assertIn("e_1", indexes) self.assertIn("active", indexes)
self.assertIn("q_1_e_1_c.e_1__id_-1", indexes) self.assertIn("claimed", indexes)
self.assertIn("q_1_c.id_1_c.e_1__id_-1", indexes) self.assertIn("queue_marker", indexes)
def test_next_marker(self):
queue_name = "marker_test"
iterations = 10
self.queue_controller.upsert(queue_name, {})
queue_id = self.queue_controller._get_id(queue_name)
seed_marker1 = self.controller._next_marker(queue_id)
self.assertEqual(seed_marker1, 1, "First marker is 1")
for i in range(iterations):
self.controller.post(queue_name, [{"ttl": 60}], "uuid")
marker1 = self.controller._next_marker(queue_id)
marker2 = self.controller._next_marker(queue_id)
marker3 = self.controller._next_marker(queue_id)
self.assertEqual(marker1, marker2)
self.assertEqual(marker2, marker3)
self.assertEqual(marker1, i+2)
+    def test_remove_expired(self):
+        num_projects = 10
+        num_queues = 10
+        total_queues = num_projects * num_queues
+        gc_threshold = mongodb_options.CFG.gc_threshold
+        messages_per_queue = gc_threshold
+        nogc_messages_per_queue = gc_threshold - 1
+
+        projects = ["gc-test-project-%s" % i for i in range(num_projects)]
+        queue_names = ["gc-test-%s" % i for i in range(num_queues)]
+        client_uuid = "b623c53c-cf75-11e2-84e1-a1187188419e"
+        messages = [{"ttl": 0, "body": str(i)}
+                    for i in range(messages_per_queue)]
+
+        for project in projects:
+            for queue in queue_names:
+                self.queue_controller.upsert(queue, {}, project)
+                self.controller.post(queue, messages, client_uuid, project)
+
+        # Add one that should not be gc'd due to being under threshold
+        self.queue_controller.upsert("nogc-test", {}, "nogc-test-project")
+        nogc_messages = [{"ttl": 0, "body": str(i)}
+                         for i in range(nogc_messages_per_queue)]
+        self.controller.post("nogc-test", nogc_messages,
+                             client_uuid, "nogc-test-project")
+
+        total_expired = sum(
+            self._count_expired(queue, project)
+            for queue in queue_names
+            for project in projects)
+
+        self.assertEquals(total_expired, total_queues * messages_per_queue)
+        self.controller.remove_expired()
+
+        # Make sure the messages in this queue were not gc'd since
+        # the count was under the threshold.
+        self.assertEquals(
+            self._count_expired("nogc-test", "nogc-test-project"),
+            len(nogc_messages))
+
+        total_expired = sum(
+            self._count_expired(queue, project)
+            for queue in queue_names
+            for project in projects)
+
+        # Expect that the most recent message for each queue
+        # will not be removed.
+        self.assertEquals(total_expired, total_queues)
+
+        # Sanity-check that the most recent message is the
+        # one remaining in the queue.
+        queue = random.choice(queue_names)
+        queue_id = self.queue_controller._get_id(queue, project)
+        message = self.driver.db.messages.find_one({"q": queue_id})
+
+        self.assertEquals(message["k"], messages_per_queue)
 class MongodbClaimTests(base.ClaimControllerTest):
View File
@@ -17,8 +17,8 @@ from marconi.common import config
 from marconi.tests import util as testing

-cfg_handle = config.project()
-cfg = cfg_handle.from_options(
+PROJECT_CONFIG = config.project()
+CFG = PROJECT_CONFIG.from_options(
     without_help=3,
     with_help=(None, "nonsense"))
@@ -27,11 +27,11 @@ class TestConfig(testing.TestBase):

     def test_cli(self):
         args = ['--with_help', 'sense']

-        cfg_handle.load(self.conf_path('wsgi_sqlite.conf'), args)
-        self.assertEquals(cfg.with_help, 'sense')
+        PROJECT_CONFIG.load(self.conf_path('wsgi_sqlite.conf'), args)
+        self.assertEquals(CFG.with_help, 'sense')

-        cfg_handle.load(args=[])
-        self.assertEquals(cfg.with_help, None)
+        PROJECT_CONFIG.load(args=[])
+        self.assertEquals(CFG.with_help, None)

     def test_wrong_type(self):
         ns = config.namespace('local')
View File
@@ -14,10 +14,14 @@
 # limitations under the License.

 import json
+import os
+
+import pymongo

 import falcon
 from falcon import testing

+from marconi.common import config
 from marconi.tests.transport.wsgi import base
@@ -194,6 +198,23 @@ class ClaimsBaseTest(base.TestBase):

         super(ClaimsBaseTest, self).tearDown()

+class ClaimsMongoDBTests(ClaimsBaseTest):
+
+    config_filename = 'wsgi_mongodb.conf'
+
+    def setUp(self):
+        if not os.environ.get("MONGODB_TEST_LIVE"):
+            self.skipTest("No MongoDB instance running")
+
+        super(ClaimsMongoDBTests, self).setUp()
+        self.cfg = config.namespace("drivers:storage:mongodb").from_options()
+
+    def tearDown(self):
+        conn = pymongo.MongoClient(self.cfg.uri)
+        conn.drop_database(self.cfg.database)
+
+        super(ClaimsMongoDBTests, self).tearDown()
+
 class ClaimsSQLiteTests(ClaimsBaseTest):

     config_filename = 'wsgi_sqlite.conf'
View File
@@ -14,6 +14,7 @@
 # limitations under the License.

 import json
+import os

 import falcon
 from falcon import testing
@@ -35,6 +36,13 @@ class MessagesBaseTest(base.TestBase):
             'Client-ID': '30387f00',
         }

+    def tearDown(self):
+        env = testing.create_environ('/v1/480924/queues/fizbit',
+                                     method="DELETE")
+        self.app(env, self.srmock)
+
+        super(MessagesBaseTest, self).tearDown()
+
     def test_post(self):
         doc = '''
         [
@@ -43,15 +51,23 @@ class MessagesBaseTest(base.TestBase):
             {"body": [1, 3], "ttl": 30}
         ]
         '''

-        env = testing.create_environ('/v1/480924/queues/fizbit/messages',
+        path = '/v1/480924/queues/fizbit/messages'
+        env = testing.create_environ(path,
                                      method="POST",
                                      body=doc,
                                      headers=self.headers)

-        self.app(env, self.srmock)
+        body = self.app(env, self.srmock)
         self.assertEquals(self.srmock.status, falcon.HTTP_201)

         msg_ids = self._get_msg_ids(self.srmock.headers_dict)

+        body = json.loads(body[0])
+        expected_resources = [unicode(path + '/' + id) for id in msg_ids]
+        self.assertEquals(expected_resources, body['resources'])
+        self.assertFalse(body['partial'])
+
         real_msgs = json.loads(doc)
         self.assertEquals(len(msg_ids), len(real_msgs))
@@ -59,16 +75,16 @@ class MessagesBaseTest(base.TestBase):
         lookup = dict([(m['ttl'], m['body']) for m in real_msgs])

         for msg_id in msg_ids:
-            env = testing.create_environ('/v1/480924/queues/fizbit/messages/'
-                                         + msg_id, method="GET")
+            message_uri = path + '/' + msg_id
+            env = testing.create_environ(message_uri, method="GET")
             body = self.app(env, self.srmock)

             self.assertEquals(self.srmock.status, falcon.HTTP_200)
             self.assertEquals(self.srmock.headers_dict['Content-Location'],
-                              env['PATH_INFO'])
+                              message_uri)

             msg = json.loads(body[0])
-            self.assertEquals(msg['href'], env['PATH_INFO'])
+            self.assertEquals(msg['href'], message_uri)
             self.assertEquals(msg['body'], lookup[msg['ttl']])

         self._post_messages('/v1/480924/queues/nonexistent/messages')
@@ -175,6 +191,16 @@ class MessagesBaseTest(base.TestBase):
         body = self.app(env, self.srmock)
         self.assertEquals(self.srmock.status, falcon.HTTP_404)

+    def test_list_with_bad_marker(self):
+        self._post_messages('/v1/480924/queues/fizbit/messages', repeat=5)
+
+        query_string = 'limit=3&echo=true&marker=sfhlsfdjh2048'
+        env = testing.create_environ('/v1/480924/queues/fizbit/messages',
+                                     query_string=query_string,
+                                     headers=self.headers)
+
+        self.app(env, self.srmock)
+        self.assertEqual(self.srmock.status, falcon.HTTP_400)
+
     def test_no_uuid(self):
         env = testing.create_environ('/v1/480924/queues/fizbit/messages',
                                      method="POST",
@@ -189,13 +215,6 @@ class MessagesBaseTest(base.TestBase):
         self.app(env, self.srmock)
         self.assertEquals(self.srmock.status, falcon.HTTP_400)

-    def tearDown(self):
-        env = testing.create_environ('/v1/480924/queues/fizbit',
-                                     method="DELETE")
-        self.app(env, self.srmock)
-
-        super(MessagesBaseTest, self).tearDown()
-
     def _post_messages(self, target, repeat=1):
         doc = json.dumps([{"body": 239, "ttl": 30}] * repeat)
@@ -214,6 +233,17 @@ class MessagesSQLiteTests(MessagesBaseTest):

     config_filename = 'wsgi_sqlite.conf'

+class MessagesMongoDBTests(MessagesBaseTest):
+
+    config_filename = 'wsgi_mongodb.conf'
+
+    def setUp(self):
+        if not os.environ.get("MONGODB_TEST_LIVE"):
+            self.skipTest("No MongoDB instance running")
+
+        super(MessagesMongoDBTests, self).setUp()
+
 class MessagesFaultyDriverTests(base.TestBase):

     config_filename = 'wsgi_faulty.conf'
View File
@@ -18,8 +18,7 @@ import testtools

 from marconi.common import config

-cfg = config.project()
+CFG = config.project()

 class TestBase(testtools.TestCase):

@@ -48,8 +47,8 @@ class TestBase(testtools.TestCase):

         :returns: Project's config object.
         """
-        cfg.load(filename=self.conf_path(filename))
-        return cfg
+        CFG.load(filename=self.conf_path(filename))
+        return CFG

     def _my_dir(self):
         return os.path.abspath(os.path.dirname(__file__))
View File
@@ -57,6 +57,8 @@ class CollectionResource(object):

         messages = itertools.chain((first_message,), messages)

         # Enqueue the messages
+        partial = False
+
         try:
             message_ids = self.message_controller.post(
                 queue_name,
@@ -66,15 +68,30 @@ class CollectionResource(object):

         except storage_exceptions.DoesNotExist:
             raise falcon.HTTPNotFound()

+        except storage_exceptions.MessageConflict as ex:
+            LOG.exception(ex)
+            partial = True
+            message_ids = ex.succeeded_ids
+
+            if not message_ids:
+                #TODO(kgriffs): Include error code that is different
+                # from the code used in the generic case, below.
+                description = _('No messages could be enqueued.')
+                raise wsgi_exceptions.HTTPServiceUnavailable(description)
+
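This handler relies on the storage layer attaching the IDs of any messages that did make it into the queue before the marker conflict was detected. Inferred purely from the usage above, the exception presumably looks something like this sketch (the real definition lives in the storage exceptions module and may differ):

class MessageConflict(Exception):

    def __init__(self, queue, project, message_ids):
        super(MessageConflict, self).__init__(
            'Message could not be enqueued due to a marker conflict '
            'in queue %s for project %s' % (queue, project))
        self._succeeded_ids = message_ids

    @property
    def succeeded_ids(self):
        # IDs of any messages enqueued before the conflict occurred
        return self._succeeded_ids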
         except Exception as ex:
             LOG.exception(ex)
             description = _('Messages could not be enqueued.')
             raise wsgi_exceptions.HTTPServiceUnavailable(description)

-        #TODO(kgriffs): Optimize
-        resource = ','.join([id.encode('utf-8') for id in message_ids])
-        resp.location = req.path + '/' + resource
+        # Prepare the response
         resp.status = falcon.HTTP_201

+        resource = ','.join(message_ids)
+        resp.location = req.path + '/' + resource
+
+        hrefs = [req.path + '/' + id for id in message_ids]
+        body = {"resources": hrefs, "partial": partial}
+        resp.body = helpers.to_json(body)
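With this change, a successful POST answers 201 with both a Location header and a JSON body; when only part of a batch was enqueued, partial is true and resources lists the messages that survived. An illustrative exchange (message IDs invented for the example):

POST /v1/480924/queues/fizbit/messages

HTTP/1.1 201 Created
Location: /v1/480924/queues/fizbit/messages/50b68a50d6f5b8c8a7c62b01,50b68a50d6f5b8c8a7c62b02

{
    "resources": [
        "/v1/480924/queues/fizbit/messages/50b68a50d6f5b8c8a7c62b01",
        "/v1/480924/queues/fizbit/messages/50b68a50d6f5b8c8a7c62b02"
    ],
    "partial": false
}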
     def on_get(self, req, resp, project_id, queue_name):
         uuid = req.get_header('Client-ID', required=True)
@@ -98,7 +115,19 @@ class CollectionResource(object):
             messages = list(cursor)

         except storage_exceptions.DoesNotExist:
-            raise falcon.HTTPNotFound
+            raise falcon.HTTPNotFound()

+        except storage_exceptions.MalformedMarker:
+            title = _('Invalid query string parameter')
+            description = _('The value for the query string '
+                            'parameter "marker" could not be '
+                            'parsed. We recommend using the '
+                            '"next" URI from a previous '
+                            'request directly, rather than '
+                            'constructing the URI manually. ')
+
+            raise falcon.HTTPBadRequest(title, description)
+
         except Exception as ex:
             LOG.exception(ex)
             description = _('Messages could not be listed.')
View File
@@ -40,6 +40,9 @@ setuptools.setup(
     cmdclass=common_setup.get_cmdclass(),
     entry_points={
         'console_scripts':
-        ['marconi-server = marconi.cmd.server:run']
+        [
+            'marconi-server = marconi.cmd.server:run',
+            'marconi-gc = marconi.cmd.gc:run'
+        ]
     }
 )
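The new marconi-gc console script points at marconi.cmd.gc:run, which is not shown in this excerpt. Given the gc-interval option in the sample config, the runner is presumably a simple sleep loop around the driver's garbage collector; a hypothetical skeleton (every name below other than the entry point itself is an assumption):

import time

from marconi.common import config
from marconi import storage  # assumed module layout

CFG = config.namespace('drivers:storage:mongodb').from_options(
    gc_interval=5 * 60)  # seconds; assumed default

def run():
    driver = storage.load_driver(CFG)  # hypothetical factory

    while True:
        driver.message_controller.remove_expired()
        time.sleep(CFG.gc_interval)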