diff --git a/swift/common/memcached.py b/swift/common/memcached.py index b7d4a41e26..9640ac6f8f 100644 --- a/swift/common/memcached.py +++ b/swift/common/memcached.py @@ -47,7 +47,6 @@ http://github.com/memcached/memcached/blob/1.4.2/doc/protocol.txt import six.moves.cPickle as pickle import json import logging -import re import time from bisect import bisect from swift import gettext_ as _ @@ -57,7 +56,7 @@ from eventlet.green import socket from eventlet.pools import Pool from eventlet import Timeout from six.moves import range - +from swift.common import utils DEFAULT_MEMCACHED_PORT = 11211 @@ -106,45 +105,16 @@ class MemcacheConnPool(Pool): Connection pool for Memcache Connections The *server* parameter can be a hostname, an IPv4 address, or an IPv6 - address with an optional port. If an IPv6 address is specified it **must** - be enclosed in [], like *[::1]* or *[::1]:11211*. This follows the accepted - prescription for `IPv6 host literals`_. - - Examples:: - - memcache.local:11211 - 127.0.0.1:11211 - [::1]:11211 - [::1] - - .. _IPv6 host literals: https://tools.ietf.org/html/rfc3986#section-3.2.2 + address with an optional port. See + :func:`swift.common.utils.parse_socket_string` for details. """ - IPV6_RE = re.compile("^\[(?P
.*)\](:(?P[0-9]+))?$") def __init__(self, server, size, connect_timeout): Pool.__init__(self, max_size=size) - self.host, self.port = self._get_addr(server) + self.host, self.port = utils.parse_socket_string( + server, DEFAULT_MEMCACHED_PORT) self._connect_timeout = connect_timeout - def _get_addr(self, server): - port = DEFAULT_MEMCACHED_PORT - # IPv6 addresses must be between '[]' - if server.startswith('['): - match = MemcacheConnPool.IPV6_RE.match(server) - if not match: - raise ValueError("Invalid IPv6 address: %s" % server) - host = match.group('address') - port = match.group('port') or port - else: - if ':' in server: - tokens = server.split(':') - if len(tokens) > 2: - raise ValueError("IPv6 addresses must be between '[]'") - host, port = tokens - else: - host = server - return (host, port) - def create(self): addrs = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) diff --git a/swift/common/utils.py b/swift/common/utils.py index f4192393c0..e65d1f27f0 100644 --- a/swift/common/utils.py +++ b/swift/common/utils.py @@ -112,6 +112,9 @@ SWIFT_CONF_FILE = '/etc/swift/swift.conf' AF_ALG = getattr(socket, 'AF_ALG', 38) F_SETPIPE_SZ = getattr(fcntl, 'F_SETPIPE_SZ', 1031) +# Used by the parse_socket_string() function to validate IPv6 addresses +IPV6_RE = re.compile("^\[(?P
.*)\](:(?P[0-9]+))?$") + class InvalidHashPathConfigError(ValueError): @@ -1799,6 +1802,43 @@ def whataremyips(bind_ip=None): return addresses +def parse_socket_string(socket_string, default_port): + """ + Given a string representing a socket, returns a tuple of (host, port). + Valid strings are DNS names, IPv4 addresses, or IPv6 addresses, with an + optional port. If an IPv6 address is specified it **must** be enclosed in + [], like *[::1]* or *[::1]:11211*. This follows the accepted prescription + for `IPv6 host literals`_. + + Examples:: + + server.org + server.org:1337 + 127.0.0.1:1337 + [::1]:1337 + [::1] + + .. _IPv6 host literals: https://tools.ietf.org/html/rfc3986#section-3.2.2 + """ + port = default_port + # IPv6 addresses must be between '[]' + if socket_string.startswith('['): + match = IPV6_RE.match(socket_string) + if not match: + raise ValueError("Invalid IPv6 address: %s" % socket_string) + host = match.group('address') + port = match.group('port') or port + else: + if ':' in socket_string: + tokens = socket_string.split(':') + if len(tokens) > 2: + raise ValueError("IPv6 addresses must be between '[]'") + host, port = tokens + else: + host = socket_string + return (host, port) + + def storage_directory(datadir, partition, name_hash): """ Get the storage directory diff --git a/test/unit/common/test_memcached.py b/test/unit/common/test_memcached.py index 13f9f6a86a..da7fbf3875 100644 --- a/test/unit/common/test_memcached.py +++ b/test/unit/common/test_memcached.py @@ -566,26 +566,5 @@ class TestMemcached(unittest.TestCase): finally: memcached.MemcacheConnPool = orig_conn_pool - def test_connection_pool_parser(self): - default = memcached.DEFAULT_MEMCACHED_PORT - addrs = [('1.2.3.4', '1.2.3.4', default), - ('1.2.3.4:5000', '1.2.3.4', 5000), - ('[dead:beef::1]', 'dead:beef::1', default), - ('[dead:beef::1]:5000', 'dead:beef::1', 5000), - ('example.com', 'example.com', default), - ('example.com:5000', 'example.com', 5000), - ('foo.1-2-3.bar.com:5000', 'foo.1-2-3.bar.com', 5000), - ('1.2.3.4:10:20', None, None), - ('dead:beef::1:5000', None, None)] - - for addr, expected_host, expected_port in addrs: - if expected_host: - pool = memcached.MemcacheConnPool(addr, 1, 0) - self.assertEqual(expected_host, pool.host) - self.assertEqual(expected_port, int(pool.port)) - else: - with self.assertRaises(ValueError): - memcached.MemcacheConnPool(addr, 1, 0) - if __name__ == '__main__': unittest.main() diff --git a/test/unit/common/test_utils.py b/test/unit/common/test_utils.py index 2c6b782641..7d8e42a8c2 100644 --- a/test/unit/common/test_utils.py +++ b/test/unit/common/test_utils.py @@ -5329,5 +5329,28 @@ class TestPairs(unittest.TestCase): (50, 60)])) +class TestSocketStringParser(unittest.TestCase): + def test_socket_string_parser(self): + default = 1337 + addrs = [('1.2.3.4', '1.2.3.4', default), + ('1.2.3.4:5000', '1.2.3.4', 5000), + ('[dead:beef::1]', 'dead:beef::1', default), + ('[dead:beef::1]:5000', 'dead:beef::1', 5000), + ('example.com', 'example.com', default), + ('example.com:5000', 'example.com', 5000), + ('foo.1-2-3.bar.com:5000', 'foo.1-2-3.bar.com', 5000), + ('1.2.3.4:10:20', None, None), + ('dead:beef::1:5000', None, None)] + + for addr, expected_host, expected_port in addrs: + if expected_host: + host, port = utils.parse_socket_string(addr, default) + self.assertEqual(expected_host, host) + self.assertEqual(expected_port, int(port)) + else: + with self.assertRaises(ValueError): + utils.parse_socket_string(addr, default) + + if __name__ == '__main__': unittest.main()