diff --git a/swift/common/memcached.py b/swift/common/memcached.py index bb359539ae..db73e7b455 100644 --- a/swift/common/memcached.py +++ b/swift/common/memcached.py @@ -47,6 +47,7 @@ http://github.com/memcached/memcached/blob/1.4.2/doc/protocol.txt import six.moves.cPickle as pickle import json import logging +import re import time from bisect import bisect from swift import gettext_ as _ @@ -101,23 +102,57 @@ class MemcachePoolTimeout(Timeout): class MemcacheConnPool(Pool): - """Connection pool for Memcache Connections""" + """ + Connection pool for Memcache Connections + + The *server* parameter can be a hostname, an IPv4 address, or an IPv6 + address with an optional port. If an IPv6 address is specified it **must** + be enclosed in [], like *[::1]* or *[::1]:11211*. This follows the accepted + prescription for IPv6 host literals: + https://tools.ietf.org/html/rfc3986#section-3.2.2. + + Examples: + + * memcache.local:11211 + * 127.0.0.1:11211 + * [::1]:11211 + * [::1] + """ + IPV6_RE = re.compile("^\[(?P
.*)\](:(?P[0-9]+))?$") def __init__(self, server, size, connect_timeout): Pool.__init__(self, max_size=size) self.server = server self._connect_timeout = connect_timeout - def create(self): - if ':' in self.server: - host, port = self.server.split(':') + def _get_addr(self): + port = DEFAULT_MEMCACHED_PORT + # IPv6 addresses must be between '[]' + if self.server.startswith('['): + match = MemcacheConnPool.IPV6_RE.match(self.server) + if not match: + raise ValueError("Invalid IPv6 address: %s" % self.server) + host = match.group('address') + port = match.group('port') or port else: - host = self.server - port = DEFAULT_MEMCACHED_PORT - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if ':' in self.server: + tokens = self.server.split(':') + if len(tokens) > 2: + raise ValueError("IPv6 addresses must be between '[]'") + host, port = tokens + else: + host = self.server + return (host, port) + + def create(self): + host, port = self._get_addr() + addrs = socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM) + family, socktype, proto, canonname, sockaddr = addrs[0] + sock = socket.socket(family, socket.SOCK_STREAM) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) with Timeout(self._connect_timeout): - sock.connect((host, int(port))) + sock.connect(sockaddr) return (sock.makefile(), sock) def get(self): diff --git a/test/unit/common/test_memcached.py b/test/unit/common/test_memcached.py index 1490c02852..d226bd4c9c 100644 --- a/test/unit/common/test_memcached.py +++ b/test/unit/common/test_memcached.py @@ -182,9 +182,121 @@ class TestMemcached(unittest.TestCase): one = False if peeripport == sock2ipport: two = False + self.assertEqual(len(memcache_client._errors[sock1ipport]), 0) + self.assertEqual(len(memcache_client._errors[sock2ip]), 0) finally: memcached.DEFAULT_MEMCACHED_PORT = orig_port + def test_get_conns_v6(self): + if not socket.has_ipv6: + return + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.bind(('::1', 0, 0, 0)) + sock.listen(1) + sock_addr = sock.getsockname() + server_socket = '[%s]:%s' % (sock_addr[0], sock_addr[1]) + memcache_client = memcached.MemcacheRing([server_socket]) + key = uuid4().hex + for conn in memcache_client._get_conns(key): + peer_sockaddr = conn[2].getpeername() + peer_socket = '[%s]:%s' % (peer_sockaddr[0], peer_sockaddr[1]) + self.assertEqual(peer_socket, server_socket) + self.assertEqual(len(memcache_client._errors[server_socket]), 0) + finally: + sock.close() + + def test_get_conns_v6_default(self): + if not socket.has_ipv6: + return + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.bind(('::1', 0)) + sock.listen(1) + sock_addr = sock.getsockname() + server_socket = '[%s]:%s' % (sock_addr[0], sock_addr[1]) + server_host = '[%s]' % sock_addr[0] + memcached.DEFAULT_MEMCACHED_PORT = sock_addr[1] + memcache_client = memcached.MemcacheRing([server_host]) + key = uuid4().hex + for conn in memcache_client._get_conns(key): + peer_sockaddr = conn[2].getpeername() + peer_socket = '[%s]:%s' % (peer_sockaddr[0], peer_sockaddr[1]) + self.assertEqual(peer_socket, server_socket) + self.assertEqual(len(memcache_client._errors[server_host]), 0) + finally: + sock.close() + + def test_get_conns_bad_v6(self): + if not socket.has_ipv6: + return + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.bind(('::1', 0)) + sock.listen(1) + sock_addr = sock.getsockname() + # IPv6 address with missing [] is invalid + server_socket = '%s:%s' % (sock_addr[0], sock_addr[1]) + memcache_client = memcached.MemcacheRing([server_socket]) + key = uuid4().hex + for conn in memcache_client._get_conns(key): + peer_sockaddr = conn[2].getpeername() + peer_socket = '[%s]:%s' % (peer_sockaddr[0], peer_sockaddr[1]) + self.assertEqual(peer_socket, server_socket) + # Expect a parsing error when creating the socket + self.assertEqual(len(memcache_client._errors[server_socket]), 1) + finally: + sock.close() + + def test_get_conns_hostname(self): + with patch('swift.common.memcached.socket.getaddrinfo') as addrinfo: + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(('127.0.0.1', 0)) + sock.listen(1) + sock_addr = sock.getsockname() + fqdn = socket.getfqdn() + server_socket = '%s:%s' % (fqdn, sock_addr[1]) + addrinfo.return_value = [(socket.AF_INET, + socket.SOCK_STREAM, 0, '', + ('127.0.0.1', sock_addr[1]))] + memcache_client = memcached.MemcacheRing([server_socket]) + key = uuid4().hex + for conn in memcache_client._get_conns(key): + peer_sockaddr = conn[2].getpeername() + peer_socket = '%s:%s' % (peer_sockaddr[0], + peer_sockaddr[1]) + self.assertEqual(peer_socket, + '127.0.0.1:%d' % sock_addr[1]) + self.assertEqual(len(memcache_client._errors[server_socket]), + 0) + finally: + sock.close() + + def test_get_conns_hostname6(self): + with patch('swift.common.memcached.socket.getaddrinfo') as addrinfo: + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.bind(('::1', 0)) + sock.listen(1) + sock_addr = sock.getsockname() + fqdn = socket.getfqdn() + server_socket = '%s:%s' % (fqdn, sock_addr[1]) + addrinfo.return_value = [(socket.AF_INET6, + socket.SOCK_STREAM, 0, '', + ('::1', sock_addr[1]))] + memcache_client = memcached.MemcacheRing([server_socket]) + key = uuid4().hex + for conn in memcache_client._get_conns(key): + peer_sockaddr = conn[2].getpeername() + peer_socket = '[%s]:%s' % (peer_sockaddr[0], + peer_sockaddr[1]) + self.assertEqual(peer_socket, '[::1]:%d' % sock_addr[1]) + self.assertEqual(len(memcache_client._errors[server_socket]), + 0) + finally: + sock.close() + def test_set_get(self): memcache_client = memcached.MemcacheRing(['1.2.3.4:11211']) mock = MockMemcached() @@ -349,6 +461,13 @@ class TestMemcached(unittest.TestCase): def test_connection_pooling(self): with patch('swift.common.memcached.socket') as mock_module: + def mock_getaddrinfo(host, port, family=socket.AF_INET, + socktype=socket.SOCK_STREAM, proto=0, + flags=0): + return [(family, socktype, proto, '', (host, port))] + + mock_module.getaddrinfo = mock_getaddrinfo + # patch socket, stub socket.socket, mock sock mock_sock = mock_module.socket.return_value @@ -462,5 +581,27 @@ class TestMemcached(unittest.TestCase): finally: memcached.MemcacheConnPool = orig_conn_pool + def test_connection_pool_parser(self): + default = memcached.DEFAULT_MEMCACHED_PORT + addrs = [('1.2.3.4', '1.2.3.4', default), + ('1.2.3.4:5000', '1.2.3.4', 5000), + ('[dead:beef::1]', 'dead:beef::1', default), + ('[dead:beef::1]:5000', 'dead:beef::1', 5000), + ('example.com', 'example.com', default), + ('example.com:5000', 'example.com', 5000), + ('foo.1-2-3.bar.com:5000', 'foo.1-2-3.bar.com', 5000), + ('1.2.3.4:10:20', None, None), + ('dead:beef::1:5000', None, None)] + + for addr, expected_host, expected_port in addrs: + pool = memcached.MemcacheConnPool(addr, 1, 0) + if expected_host: + host, port = pool._get_addr() + self.assertEqual(expected_host, host) + self.assertEqual(expected_port, int(port)) + else: + with self.assertRaises(ValueError): + pool._get_addr() + if __name__ == '__main__': unittest.main()