Fail early if the memcache address is invalid.

In cases when the memcache address is invalid, we should fail early.
This patch addresses the cases when the IPv6 addresses are not
enclosed in "[]". It does not, however, fix the case of an invalid
hostname. These improvements could also be added to the _get_addr()
method.

Change-Id: I4743dcda45a1fc1640989325c4a2e1fea591fc69
This commit is contained in:
Timur Alperovich 2016-01-08 14:54:56 -08:00
parent 3347646023
commit 0647aea9c5
2 changed files with 31 additions and 48 deletions

View File

@ -123,31 +123,30 @@ class MemcacheConnPool(Pool):
def __init__(self, server, size, connect_timeout):
Pool.__init__(self, max_size=size)
self.server = server
self.host, self.port = self._get_addr(server)
self._connect_timeout = connect_timeout
def _get_addr(self):
def _get_addr(self, server):
port = DEFAULT_MEMCACHED_PORT
# IPv6 addresses must be between '[]'
if self.server.startswith('['):
match = MemcacheConnPool.IPV6_RE.match(self.server)
if server.startswith('['):
match = MemcacheConnPool.IPV6_RE.match(server)
if not match:
raise ValueError("Invalid IPv6 address: %s" % self.server)
raise ValueError("Invalid IPv6 address: %s" % server)
host = match.group('address')
port = match.group('port') or port
else:
if ':' in self.server:
tokens = self.server.split(':')
if ':' in server:
tokens = server.split(':')
if len(tokens) > 2:
raise ValueError("IPv6 addresses must be between '[]'")
host, port = tokens
else:
host = self.server
host = server
return (host, port)
def create(self):
host, port = self._get_addr()
addrs = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
addrs = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC,
socket.SOCK_STREAM)
family, socktype, proto, canonname, sockaddr = addrs[0]
sock = socket.socket(family, socket.SOCK_STREAM)

View File

@ -228,25 +228,10 @@ class TestMemcached(unittest.TestCase):
sock.close()
def test_get_conns_bad_v6(self):
if not socket.has_ipv6:
return
try:
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
sock.bind(('::1', 0))
sock.listen(1)
sock_addr = sock.getsockname()
with self.assertRaises(ValueError):
# IPv6 address with missing [] is invalid
server_socket = '%s:%s' % (sock_addr[0], sock_addr[1])
memcache_client = memcached.MemcacheRing([server_socket])
key = uuid4().hex
for conn in memcache_client._get_conns(key):
peer_sockaddr = conn[2].getpeername()
peer_socket = '[%s]:%s' % (peer_sockaddr[0], peer_sockaddr[1])
self.assertEqual(peer_socket, server_socket)
# Expect a parsing error when creating the socket
self.assertEqual(len(memcache_client._errors[server_socket]), 1)
finally:
sock.close()
server_socket = '%s:%s' % ('::1', 11211)
memcached.MemcacheRing([server_socket])
def test_get_conns_hostname(self):
with patch('swift.common.memcached.socket.getaddrinfo') as addrinfo:
@ -533,14 +518,14 @@ class TestMemcached(unittest.TestCase):
class MockConnectionPool(orig_conn_pool):
def get(self):
pending[self.server] += 1
conn = connections[self.server].get()
pending[self.server] -= 1
pending[self.host] += 1
conn = connections[self.host].get()
pending[self.host] -= 1
return conn
def put(self, *args, **kwargs):
connections[self.server].put(*args, **kwargs)
served[self.server] += 1
connections[self.host].put(*args, **kwargs)
served[self.host] += 1
memcached.MemcacheConnPool = MockConnectionPool
@ -554,12 +539,12 @@ class TestMemcached(unittest.TestCase):
# then move on to .4, and we'll assert all that below.
mock_conn = MagicMock(), MagicMock()
mock_conn[1].sendall = lambda x: sleep(0.2)
connections['1.2.3.5:11211'].put(mock_conn)
connections['1.2.3.5:11211'].put(mock_conn)
connections['1.2.3.5'].put(mock_conn)
connections['1.2.3.5'].put(mock_conn)
mock_conn = MagicMock(), MagicMock()
connections['1.2.3.4:11211'].put(mock_conn)
connections['1.2.3.4:11211'].put(mock_conn)
connections['1.2.3.4'].put(mock_conn)
connections['1.2.3.4'].put(mock_conn)
p = GreenPool()
for i in range(10):
@ -568,16 +553,16 @@ class TestMemcached(unittest.TestCase):
# Wait for the dust to settle.
p.waitall()
self.assertEqual(pending['1.2.3.5:11211'], 8)
self.assertEqual(pending['1.2.3.5'], 8)
self.assertEqual(len(memcache_client._errors['1.2.3.5:11211']), 8)
self.assertEqual(served['1.2.3.5:11211'], 2)
self.assertEqual(pending['1.2.3.4:11211'], 0)
self.assertEqual(served['1.2.3.5'], 2)
self.assertEqual(pending['1.2.3.4'], 0)
self.assertEqual(len(memcache_client._errors['1.2.3.4:11211']), 0)
self.assertEqual(served['1.2.3.4:11211'], 8)
self.assertEqual(served['1.2.3.4'], 8)
# and we never got more put in that we gave out
self.assertEqual(connections['1.2.3.5:11211'].qsize(), 2)
self.assertEqual(connections['1.2.3.4:11211'].qsize(), 2)
self.assertEqual(connections['1.2.3.5'].qsize(), 2)
self.assertEqual(connections['1.2.3.4'].qsize(), 2)
finally:
memcached.MemcacheConnPool = orig_conn_pool
@ -594,14 +579,13 @@ class TestMemcached(unittest.TestCase):
('dead:beef::1:5000', None, None)]
for addr, expected_host, expected_port in addrs:
pool = memcached.MemcacheConnPool(addr, 1, 0)
if expected_host:
host, port = pool._get_addr()
self.assertEqual(expected_host, host)
self.assertEqual(expected_port, int(port))
pool = memcached.MemcacheConnPool(addr, 1, 0)
self.assertEqual(expected_host, pool.host)
self.assertEqual(expected_port, int(pool.port))
else:
with self.assertRaises(ValueError):
pool._get_addr()
memcached.MemcacheConnPool(addr, 1, 0)
if __name__ == '__main__':
unittest.main()