diff --git a/swift/common/memcached.py b/swift/common/memcached.py
index bb359539ae..db73e7b455 100644
--- a/swift/common/memcached.py
+++ b/swift/common/memcached.py
@@ -47,6 +47,7 @@ http://github.com/memcached/memcached/blob/1.4.2/doc/protocol.txt
import six.moves.cPickle as pickle
import json
import logging
+import re
import time
from bisect import bisect
from swift import gettext_ as _
@@ -101,23 +102,57 @@ class MemcachePoolTimeout(Timeout):
class MemcacheConnPool(Pool):
- """Connection pool for Memcache Connections"""
+ """
+ Connection pool for Memcache Connections
+
+ The *server* parameter can be a hostname, an IPv4 address, or an IPv6
+ address with an optional port. If an IPv6 address is specified it **must**
+ be enclosed in [], like *[::1]* or *[::1]:11211*. This follows the accepted
+ prescription for IPv6 host literals:
+ https://tools.ietf.org/html/rfc3986#section-3.2.2.
+
+ Examples:
+
+ * memcache.local:11211
+ * 127.0.0.1:11211
+ * [::1]:11211
+ * [::1]
+ """
+ IPV6_RE = re.compile("^\[(?P
.*)\](:(?P[0-9]+))?$")
def __init__(self, server, size, connect_timeout):
Pool.__init__(self, max_size=size)
self.server = server
self._connect_timeout = connect_timeout
- def create(self):
- if ':' in self.server:
- host, port = self.server.split(':')
+ def _get_addr(self):
+ port = DEFAULT_MEMCACHED_PORT
+ # IPv6 addresses must be between '[]'
+ if self.server.startswith('['):
+ match = MemcacheConnPool.IPV6_RE.match(self.server)
+ if not match:
+ raise ValueError("Invalid IPv6 address: %s" % self.server)
+ host = match.group('address')
+ port = match.group('port') or port
else:
- host = self.server
- port = DEFAULT_MEMCACHED_PORT
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ if ':' in self.server:
+ tokens = self.server.split(':')
+ if len(tokens) > 2:
+ raise ValueError("IPv6 addresses must be between '[]'")
+ host, port = tokens
+ else:
+ host = self.server
+ return (host, port)
+
+ def create(self):
+ host, port = self._get_addr()
+ addrs = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
+ socket.SOCK_STREAM)
+ family, socktype, proto, canonname, sockaddr = addrs[0]
+ sock = socket.socket(family, socket.SOCK_STREAM)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
with Timeout(self._connect_timeout):
- sock.connect((host, int(port)))
+ sock.connect(sockaddr)
return (sock.makefile(), sock)
def get(self):
diff --git a/test/unit/common/test_memcached.py b/test/unit/common/test_memcached.py
index 1490c02852..d226bd4c9c 100644
--- a/test/unit/common/test_memcached.py
+++ b/test/unit/common/test_memcached.py
@@ -182,9 +182,121 @@ class TestMemcached(unittest.TestCase):
one = False
if peeripport == sock2ipport:
two = False
+ self.assertEqual(len(memcache_client._errors[sock1ipport]), 0)
+ self.assertEqual(len(memcache_client._errors[sock2ip]), 0)
finally:
memcached.DEFAULT_MEMCACHED_PORT = orig_port
+ def test_get_conns_v6(self):
+ if not socket.has_ipv6:
+ return
+ try:
+ sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+ sock.bind(('::1', 0, 0, 0))
+ sock.listen(1)
+ sock_addr = sock.getsockname()
+ server_socket = '[%s]:%s' % (sock_addr[0], sock_addr[1])
+ memcache_client = memcached.MemcacheRing([server_socket])
+ key = uuid4().hex
+ for conn in memcache_client._get_conns(key):
+ peer_sockaddr = conn[2].getpeername()
+ peer_socket = '[%s]:%s' % (peer_sockaddr[0], peer_sockaddr[1])
+ self.assertEqual(peer_socket, server_socket)
+ self.assertEqual(len(memcache_client._errors[server_socket]), 0)
+ finally:
+ sock.close()
+
+ def test_get_conns_v6_default(self):
+ if not socket.has_ipv6:
+ return
+ try:
+ sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+ sock.bind(('::1', 0))
+ sock.listen(1)
+ sock_addr = sock.getsockname()
+ server_socket = '[%s]:%s' % (sock_addr[0], sock_addr[1])
+ server_host = '[%s]' % sock_addr[0]
+ memcached.DEFAULT_MEMCACHED_PORT = sock_addr[1]
+ memcache_client = memcached.MemcacheRing([server_host])
+ key = uuid4().hex
+ for conn in memcache_client._get_conns(key):
+ peer_sockaddr = conn[2].getpeername()
+ peer_socket = '[%s]:%s' % (peer_sockaddr[0], peer_sockaddr[1])
+ self.assertEqual(peer_socket, server_socket)
+ self.assertEqual(len(memcache_client._errors[server_host]), 0)
+ finally:
+ sock.close()
+
+ def test_get_conns_bad_v6(self):
+ if not socket.has_ipv6:
+ return
+ try:
+ sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+ sock.bind(('::1', 0))
+ sock.listen(1)
+ sock_addr = sock.getsockname()
+ # IPv6 address with missing [] is invalid
+ server_socket = '%s:%s' % (sock_addr[0], sock_addr[1])
+ memcache_client = memcached.MemcacheRing([server_socket])
+ key = uuid4().hex
+ for conn in memcache_client._get_conns(key):
+ peer_sockaddr = conn[2].getpeername()
+ peer_socket = '[%s]:%s' % (peer_sockaddr[0], peer_sockaddr[1])
+ self.assertEqual(peer_socket, server_socket)
+ # Expect a parsing error when creating the socket
+ self.assertEqual(len(memcache_client._errors[server_socket]), 1)
+ finally:
+ sock.close()
+
+ def test_get_conns_hostname(self):
+ with patch('swift.common.memcached.socket.getaddrinfo') as addrinfo:
+ try:
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.bind(('127.0.0.1', 0))
+ sock.listen(1)
+ sock_addr = sock.getsockname()
+ fqdn = socket.getfqdn()
+ server_socket = '%s:%s' % (fqdn, sock_addr[1])
+ addrinfo.return_value = [(socket.AF_INET,
+ socket.SOCK_STREAM, 0, '',
+ ('127.0.0.1', sock_addr[1]))]
+ memcache_client = memcached.MemcacheRing([server_socket])
+ key = uuid4().hex
+ for conn in memcache_client._get_conns(key):
+ peer_sockaddr = conn[2].getpeername()
+ peer_socket = '%s:%s' % (peer_sockaddr[0],
+ peer_sockaddr[1])
+ self.assertEqual(peer_socket,
+ '127.0.0.1:%d' % sock_addr[1])
+ self.assertEqual(len(memcache_client._errors[server_socket]),
+ 0)
+ finally:
+ sock.close()
+
+ def test_get_conns_hostname6(self):
+ with patch('swift.common.memcached.socket.getaddrinfo') as addrinfo:
+ try:
+ sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+ sock.bind(('::1', 0))
+ sock.listen(1)
+ sock_addr = sock.getsockname()
+ fqdn = socket.getfqdn()
+ server_socket = '%s:%s' % (fqdn, sock_addr[1])
+ addrinfo.return_value = [(socket.AF_INET6,
+ socket.SOCK_STREAM, 0, '',
+ ('::1', sock_addr[1]))]
+ memcache_client = memcached.MemcacheRing([server_socket])
+ key = uuid4().hex
+ for conn in memcache_client._get_conns(key):
+ peer_sockaddr = conn[2].getpeername()
+ peer_socket = '[%s]:%s' % (peer_sockaddr[0],
+ peer_sockaddr[1])
+ self.assertEqual(peer_socket, '[::1]:%d' % sock_addr[1])
+ self.assertEqual(len(memcache_client._errors[server_socket]),
+ 0)
+ finally:
+ sock.close()
+
def test_set_get(self):
memcache_client = memcached.MemcacheRing(['1.2.3.4:11211'])
mock = MockMemcached()
@@ -349,6 +461,13 @@ class TestMemcached(unittest.TestCase):
def test_connection_pooling(self):
with patch('swift.common.memcached.socket') as mock_module:
+ def mock_getaddrinfo(host, port, family=socket.AF_INET,
+ socktype=socket.SOCK_STREAM, proto=0,
+ flags=0):
+ return [(family, socktype, proto, '', (host, port))]
+
+ mock_module.getaddrinfo = mock_getaddrinfo
+
# patch socket, stub socket.socket, mock sock
mock_sock = mock_module.socket.return_value
@@ -462,5 +581,27 @@ class TestMemcached(unittest.TestCase):
finally:
memcached.MemcacheConnPool = orig_conn_pool
+ def test_connection_pool_parser(self):
+ default = memcached.DEFAULT_MEMCACHED_PORT
+ addrs = [('1.2.3.4', '1.2.3.4', default),
+ ('1.2.3.4:5000', '1.2.3.4', 5000),
+ ('[dead:beef::1]', 'dead:beef::1', default),
+ ('[dead:beef::1]:5000', 'dead:beef::1', 5000),
+ ('example.com', 'example.com', default),
+ ('example.com:5000', 'example.com', 5000),
+ ('foo.1-2-3.bar.com:5000', 'foo.1-2-3.bar.com', 5000),
+ ('1.2.3.4:10:20', None, None),
+ ('dead:beef::1:5000', None, None)]
+
+ for addr, expected_host, expected_port in addrs:
+ pool = memcached.MemcacheConnPool(addr, 1, 0)
+ if expected_host:
+ host, port = pool._get_addr()
+ self.assertEqual(expected_host, host)
+ self.assertEqual(expected_port, int(port))
+ else:
+ with self.assertRaises(ValueError):
+ pool._get_addr()
+
if __name__ == '__main__':
unittest.main()