Merge "Fail early if the memcache address is invalid."

This commit is contained in:
Jenkins 2016-02-22 19:42:20 +00:00 committed by Gerrit Code Review
commit 3e3ad5cb16
2 changed files with 31 additions and 48 deletions

View File

@ -123,31 +123,30 @@ class MemcacheConnPool(Pool):
def __init__(self, server, size, connect_timeout):
Pool.__init__(self, max_size=size)
self.server = server
self.host, self.port = self._get_addr(server)
self._connect_timeout = connect_timeout
def _get_addr(self):
def _get_addr(self, server):
port = DEFAULT_MEMCACHED_PORT
# IPv6 addresses must be between '[]'
if self.server.startswith('['):
match = MemcacheConnPool.IPV6_RE.match(self.server)
if server.startswith('['):
match = MemcacheConnPool.IPV6_RE.match(server)
if not match:
raise ValueError("Invalid IPv6 address: %s" % self.server)
raise ValueError("Invalid IPv6 address: %s" % server)
host = match.group('address')
port = match.group('port') or port
else:
if ':' in self.server:
tokens = self.server.split(':')
if ':' in server:
tokens = server.split(':')
if len(tokens) > 2:
raise ValueError("IPv6 addresses must be between '[]'")
host, port = tokens
else:
host = self.server
host = server
return (host, port)
def create(self):
host, port = self._get_addr()
addrs = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
addrs = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC,
socket.SOCK_STREAM)
family, socktype, proto, canonname, sockaddr = addrs[0]
sock = socket.socket(family, socket.SOCK_STREAM)

View File

@ -228,25 +228,10 @@ class TestMemcached(unittest.TestCase):
sock.close()
def test_get_conns_bad_v6(self):
if not socket.has_ipv6:
return
try:
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
sock.bind(('::1', 0))
sock.listen(1)
sock_addr = sock.getsockname()
with self.assertRaises(ValueError):
# IPv6 address with missing [] is invalid
server_socket = '%s:%s' % (sock_addr[0], sock_addr[1])
memcache_client = memcached.MemcacheRing([server_socket])
key = uuid4().hex
for conn in memcache_client._get_conns(key):
peer_sockaddr = conn[2].getpeername()
peer_socket = '[%s]:%s' % (peer_sockaddr[0], peer_sockaddr[1])
self.assertEqual(peer_socket, server_socket)
# Expect a parsing error when creating the socket
self.assertEqual(len(memcache_client._errors[server_socket]), 1)
finally:
sock.close()
server_socket = '%s:%s' % ('::1', 11211)
memcached.MemcacheRing([server_socket])
def test_get_conns_hostname(self):
with patch('swift.common.memcached.socket.getaddrinfo') as addrinfo:
@ -533,14 +518,14 @@ class TestMemcached(unittest.TestCase):
class MockConnectionPool(orig_conn_pool):
def get(self):
pending[self.server] += 1
conn = connections[self.server].get()
pending[self.server] -= 1
pending[self.host] += 1
conn = connections[self.host].get()
pending[self.host] -= 1
return conn
def put(self, *args, **kwargs):
connections[self.server].put(*args, **kwargs)
served[self.server] += 1
connections[self.host].put(*args, **kwargs)
served[self.host] += 1
memcached.MemcacheConnPool = MockConnectionPool
@ -554,12 +539,12 @@ class TestMemcached(unittest.TestCase):
# then move on to .4, and we'll assert all that below.
mock_conn = MagicMock(), MagicMock()
mock_conn[1].sendall = lambda x: sleep(0.2)
connections['1.2.3.5:11211'].put(mock_conn)
connections['1.2.3.5:11211'].put(mock_conn)
connections['1.2.3.5'].put(mock_conn)
connections['1.2.3.5'].put(mock_conn)
mock_conn = MagicMock(), MagicMock()
connections['1.2.3.4:11211'].put(mock_conn)
connections['1.2.3.4:11211'].put(mock_conn)
connections['1.2.3.4'].put(mock_conn)
connections['1.2.3.4'].put(mock_conn)
p = GreenPool()
for i in range(10):
@ -568,16 +553,16 @@ class TestMemcached(unittest.TestCase):
# Wait for the dust to settle.
p.waitall()
self.assertEqual(pending['1.2.3.5:11211'], 8)
self.assertEqual(pending['1.2.3.5'], 8)
self.assertEqual(len(memcache_client._errors['1.2.3.5:11211']), 8)
self.assertEqual(served['1.2.3.5:11211'], 2)
self.assertEqual(pending['1.2.3.4:11211'], 0)
self.assertEqual(served['1.2.3.5'], 2)
self.assertEqual(pending['1.2.3.4'], 0)
self.assertEqual(len(memcache_client._errors['1.2.3.4:11211']), 0)
self.assertEqual(served['1.2.3.4:11211'], 8)
self.assertEqual(served['1.2.3.4'], 8)
# and we never got more put in that we gave out
self.assertEqual(connections['1.2.3.5:11211'].qsize(), 2)
self.assertEqual(connections['1.2.3.4:11211'].qsize(), 2)
self.assertEqual(connections['1.2.3.5'].qsize(), 2)
self.assertEqual(connections['1.2.3.4'].qsize(), 2)
finally:
memcached.MemcacheConnPool = orig_conn_pool
@ -594,14 +579,13 @@ class TestMemcached(unittest.TestCase):
('dead:beef::1:5000', None, None)]
for addr, expected_host, expected_port in addrs:
pool = memcached.MemcacheConnPool(addr, 1, 0)
if expected_host:
host, port = pool._get_addr()
self.assertEqual(expected_host, host)
self.assertEqual(expected_port, int(port))
pool = memcached.MemcacheConnPool(addr, 1, 0)
self.assertEqual(expected_host, pool.host)
self.assertEqual(expected_port, int(pool.port))
else:
with self.assertRaises(ValueError):
pool._get_addr()
memcached.MemcacheConnPool(addr, 1, 0)
if __name__ == '__main__':
unittest.main()