Fixing a possible memcached race condition and refactoring the incr/decr functionality

This commit is contained in:
David Goetz 2010-10-25 20:42:28 +00:00 committed by Tarmac
commit 6e025c43fe
3 changed files with 64 additions and 35 deletions

View File

@ -168,23 +168,42 @@ class MemcacheRing(object):
def incr(self, key, delta=1, timeout=0): def incr(self, key, delta=1, timeout=0):
""" """
Increments a key which has a numeric value by delta. Increments a key which has a numeric value by delta.
If the key can't be found, it's added as delta. If the key can't be found, it's added as delta or 0 if delta < 0.
If passed a negative number, will use memcached's decr. Returns
the int stored in memcached
Note: The data memcached stores as the result of incr/decr is
an unsigned int. decr's that result in a number below 0 are
stored as 0.
:param key: key :param key: key
:param delta: amount to add to the value of key (or set as the value :param delta: amount to add to the value of key (or set as the value
if the key is not found) if the key is not found) will be cast to an int
:param timeout: ttl in memcache :param timeout: ttl in memcache
""" """
key = md5hash(key) key = md5hash(key)
command = 'incr'
if delta < 0:
command = 'decr'
delta = str(abs(int(delta)))
for (server, fp, sock) in self._get_conns(key): for (server, fp, sock) in self._get_conns(key):
try: try:
sock.sendall('incr %s %s\r\n' % (key, delta)) sock.sendall('%s %s %s\r\n' % (command, key, delta))
line = fp.readline().strip().split() line = fp.readline().strip().split()
if line[0].upper() == 'NOT_FOUND': if line[0].upper() == 'NOT_FOUND':
line[0] = str(delta) add_val = delta
sock.sendall('add %s %d %d %s noreply\r\n%s\r\n' % \ if command == 'decr':
(key, 0, timeout, len(line[0]), line[0])) add_val = '0'
ret = int(line[0].strip()) sock.sendall('add %s %d %d %s\r\n%s\r\n' % \
(key, 0, timeout, len(add_val), add_val))
line = fp.readline().strip().split()
if line[0].upper() == 'NOT_STORED':
sock.sendall('%s %s %s\r\n' % (command, key, delta))
line = fp.readline().strip().split()
ret = int(line[0].strip())
else:
ret = int(add_val)
else:
ret = int(line[0].strip())
self._return_conn(server, fp, sock) self._return_conn(server, fp, sock)
return ret return ret
except Exception, e: except Exception, e:
@ -192,29 +211,16 @@ class MemcacheRing(object):
def decr(self, key, delta=1, timeout=0): def decr(self, key, delta=1, timeout=0):
""" """
Decrements a key which has a numeric value by delta. Decrements a key which has a numeric value by delta. Calls incr with
If the key can't be found, it's added as 0. Memcached -delta.
will treat data values below 0 as 0 with incr/decr.
:param key: key :param key: key
:param delta: amount to subtract to the value of key (or set :param delta: amount to subtract to the value of key (or set the
as the value if the key is not found) value to 0 if the key is not found) will be cast to
an int
:param timeout: ttl in memcache :param timeout: ttl in memcache
""" """
key = md5hash(key) self.incr(key, delta=-delta, timeout=timeout)
for (server, fp, sock) in self._get_conns(key):
try:
sock.sendall('decr %s %s\r\n' % (key, delta))
line = fp.readline().strip().split()
if line[0].upper() == 'NOT_FOUND':
line[0] = '0'
sock.sendall('add %s %d %d %s noreply\r\n%s\r\n' %
(key, 0, timeout, len(line[0]), line[0]))
ret = int(line[0].strip())
self._return_conn(server, fp, sock)
return ret
except Exception, e:
self._exception_occurred(server, e)
def delete(self, key): def delete(self, key):
""" """

View File

@ -36,19 +36,14 @@ class FakeMemcache(object):
return True return True
def incr(self, key, delta=1, timeout=0): def incr(self, key, delta=1, timeout=0):
if delta < 0: self.store[key] = int(self.store.setdefault(key, 0)) + int(delta)
raise "Cannot incr by a negative number"
self.store[key] = int(self.store.setdefault(key, 0)) + delta
return int(self.store[key])
def decr(self, key, delta=1, timeout=0):
if delta < 0:
raise "Cannot decr by a negative number"
self.store[key] = int(self.store.setdefault(key, 0)) - delta
if self.store[key] < 0: if self.store[key] < 0:
self.store[key] = 0 self.store[key] = 0
return int(self.store[key]) return int(self.store[key])
def decr(self, key, delta=1, timeout=0):
return self.incr(key, delta=-delta, timeout=timeout)
@contextmanager @contextmanager
def soft_lock(self, key, timeout=0, retries=5): def soft_lock(self, key, timeout=0, retries=5):
yield True yield True

View File

@ -98,6 +98,17 @@ class MockMemcached(object):
self.outbuf += str(val[2]) + '\r\n' self.outbuf += str(val[2]) + '\r\n'
else: else:
self.outbuf += 'NOT_FOUND\r\n' self.outbuf += 'NOT_FOUND\r\n'
elif parts[0].lower() == 'decr':
if parts[1] in self.cache:
val = list(self.cache[parts[1]])
if int(val[2]) - int(parts[2]) > 0:
val[2] = str(int(val[2]) - int(parts[2]))
else:
val[2] = '0'
self.cache[parts[1]] = val
self.outbuf += str(val[2]) + '\r\n'
else:
self.outbuf += 'NOT_FOUND\r\n'
def readline(self): def readline(self):
if self.down: if self.down:
raise Exception('mock is down') raise Exception('mock is down')
@ -151,6 +162,23 @@ class TestMemcached(unittest.TestCase):
self.assertEquals(memcache_client.get('some_key'), '10') self.assertEquals(memcache_client.get('some_key'), '10')
memcache_client.incr('some_key', delta=1) memcache_client.incr('some_key', delta=1)
self.assertEquals(memcache_client.get('some_key'), '11') self.assertEquals(memcache_client.get('some_key'), '11')
memcache_client.incr('some_key', delta=-5)
self.assertEquals(memcache_client.get('some_key'), '6')
memcache_client.incr('some_key', delta=-15)
self.assertEquals(memcache_client.get('some_key'), '0')
def test_decr(self):
memcache_client = memcached.MemcacheRing(['1.2.3.4:11211'])
mock = MockMemcached()
memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2
memcache_client.decr('some_key', delta=5)
self.assertEquals(memcache_client.get('some_key'), '0')
memcache_client.incr('some_key', delta=15)
self.assertEquals(memcache_client.get('some_key'), '15')
memcache_client.decr('some_key', delta=4)
self.assertEquals(memcache_client.get('some_key'), '11')
memcache_client.decr('some_key', delta=15)
self.assertEquals(memcache_client.get('some_key'), '0')
def test_retry(self): def test_retry(self):
logging.getLogger().addHandler(NullLoggingHandler()) logging.getLogger().addHandler(NullLoggingHandler())