diff --git a/swift/common/memcached.py b/swift/common/memcached.py index 1bfa424d02..b7d4a41e26 100644 --- a/swift/common/memcached.py +++ b/swift/common/memcached.py @@ -123,31 +123,30 @@ class MemcacheConnPool(Pool): def __init__(self, server, size, connect_timeout): Pool.__init__(self, max_size=size) - self.server = server + self.host, self.port = self._get_addr(server) self._connect_timeout = connect_timeout - def _get_addr(self): + def _get_addr(self, server): port = DEFAULT_MEMCACHED_PORT # IPv6 addresses must be between '[]' - if self.server.startswith('['): - match = MemcacheConnPool.IPV6_RE.match(self.server) + if server.startswith('['): + match = MemcacheConnPool.IPV6_RE.match(server) if not match: - raise ValueError("Invalid IPv6 address: %s" % self.server) + raise ValueError("Invalid IPv6 address: %s" % server) host = match.group('address') port = match.group('port') or port else: - if ':' in self.server: - tokens = self.server.split(':') + if ':' in server: + tokens = server.split(':') if len(tokens) > 2: raise ValueError("IPv6 addresses must be between '[]'") host, port = tokens else: - host = self.server + host = server return (host, port) def create(self): - host, port = self._get_addr() - addrs = socket.getaddrinfo(host, port, socket.AF_UNSPEC, + addrs = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) family, socktype, proto, canonname, sockaddr = addrs[0] sock = socket.socket(family, socket.SOCK_STREAM) diff --git a/test/unit/common/test_memcached.py b/test/unit/common/test_memcached.py index d226bd4c9c..13f9f6a86a 100644 --- a/test/unit/common/test_memcached.py +++ b/test/unit/common/test_memcached.py @@ -228,25 +228,10 @@ class TestMemcached(unittest.TestCase): sock.close() def test_get_conns_bad_v6(self): - if not socket.has_ipv6: - return - try: - sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - sock.bind(('::1', 0)) - sock.listen(1) - sock_addr = sock.getsockname() + with self.assertRaises(ValueError): # IPv6 address with missing [] is invalid - server_socket = '%s:%s' % (sock_addr[0], sock_addr[1]) - memcache_client = memcached.MemcacheRing([server_socket]) - key = uuid4().hex - for conn in memcache_client._get_conns(key): - peer_sockaddr = conn[2].getpeername() - peer_socket = '[%s]:%s' % (peer_sockaddr[0], peer_sockaddr[1]) - self.assertEqual(peer_socket, server_socket) - # Expect a parsing error when creating the socket - self.assertEqual(len(memcache_client._errors[server_socket]), 1) - finally: - sock.close() + server_socket = '%s:%s' % ('::1', 11211) + memcached.MemcacheRing([server_socket]) def test_get_conns_hostname(self): with patch('swift.common.memcached.socket.getaddrinfo') as addrinfo: @@ -533,14 +518,14 @@ class TestMemcached(unittest.TestCase): class MockConnectionPool(orig_conn_pool): def get(self): - pending[self.server] += 1 - conn = connections[self.server].get() - pending[self.server] -= 1 + pending[self.host] += 1 + conn = connections[self.host].get() + pending[self.host] -= 1 return conn def put(self, *args, **kwargs): - connections[self.server].put(*args, **kwargs) - served[self.server] += 1 + connections[self.host].put(*args, **kwargs) + served[self.host] += 1 memcached.MemcacheConnPool = MockConnectionPool @@ -554,12 +539,12 @@ class TestMemcached(unittest.TestCase): # then move on to .4, and we'll assert all that below. mock_conn = MagicMock(), MagicMock() mock_conn[1].sendall = lambda x: sleep(0.2) - connections['1.2.3.5:11211'].put(mock_conn) - connections['1.2.3.5:11211'].put(mock_conn) + connections['1.2.3.5'].put(mock_conn) + connections['1.2.3.5'].put(mock_conn) mock_conn = MagicMock(), MagicMock() - connections['1.2.3.4:11211'].put(mock_conn) - connections['1.2.3.4:11211'].put(mock_conn) + connections['1.2.3.4'].put(mock_conn) + connections['1.2.3.4'].put(mock_conn) p = GreenPool() for i in range(10): @@ -568,16 +553,16 @@ class TestMemcached(unittest.TestCase): # Wait for the dust to settle. p.waitall() - self.assertEqual(pending['1.2.3.5:11211'], 8) + self.assertEqual(pending['1.2.3.5'], 8) self.assertEqual(len(memcache_client._errors['1.2.3.5:11211']), 8) - self.assertEqual(served['1.2.3.5:11211'], 2) - self.assertEqual(pending['1.2.3.4:11211'], 0) + self.assertEqual(served['1.2.3.5'], 2) + self.assertEqual(pending['1.2.3.4'], 0) self.assertEqual(len(memcache_client._errors['1.2.3.4:11211']), 0) - self.assertEqual(served['1.2.3.4:11211'], 8) + self.assertEqual(served['1.2.3.4'], 8) # and we never got more put in that we gave out - self.assertEqual(connections['1.2.3.5:11211'].qsize(), 2) - self.assertEqual(connections['1.2.3.4:11211'].qsize(), 2) + self.assertEqual(connections['1.2.3.5'].qsize(), 2) + self.assertEqual(connections['1.2.3.4'].qsize(), 2) finally: memcached.MemcacheConnPool = orig_conn_pool @@ -594,14 +579,13 @@ class TestMemcached(unittest.TestCase): ('dead:beef::1:5000', None, None)] for addr, expected_host, expected_port in addrs: - pool = memcached.MemcacheConnPool(addr, 1, 0) if expected_host: - host, port = pool._get_addr() - self.assertEqual(expected_host, host) - self.assertEqual(expected_port, int(port)) + pool = memcached.MemcacheConnPool(addr, 1, 0) + self.assertEqual(expected_host, pool.host) + self.assertEqual(expected_port, int(pool.port)) else: with self.assertRaises(ValueError): - pool._get_addr() + memcached.MemcacheConnPool(addr, 1, 0) if __name__ == '__main__': unittest.main()