Merge "Make _get_addr() method a function in utils."

This commit is contained in:
Jenkins 2016-02-22 20:30:22 +00:00 committed by Gerrit Code Review
commit d9f500a128
4 changed files with 68 additions and 56 deletions

View File

@ -47,7 +47,6 @@ http://github.com/memcached/memcached/blob/1.4.2/doc/protocol.txt
import six.moves.cPickle as pickle import six.moves.cPickle as pickle
import json import json
import logging import logging
import re
import time import time
from bisect import bisect from bisect import bisect
from swift import gettext_ as _ from swift import gettext_ as _
@ -57,7 +56,7 @@ from eventlet.green import socket
from eventlet.pools import Pool from eventlet.pools import Pool
from eventlet import Timeout from eventlet import Timeout
from six.moves import range from six.moves import range
from swift.common import utils
DEFAULT_MEMCACHED_PORT = 11211 DEFAULT_MEMCACHED_PORT = 11211
@ -106,45 +105,16 @@ class MemcacheConnPool(Pool):
Connection pool for Memcache Connections Connection pool for Memcache Connections
The *server* parameter can be a hostname, an IPv4 address, or an IPv6 The *server* parameter can be a hostname, an IPv4 address, or an IPv6
address with an optional port. If an IPv6 address is specified it **must** address with an optional port. See
be enclosed in [], like *[::1]* or *[::1]:11211*. This follows the accepted :func:`swift.common.utils.parse_socket_string` for details.
prescription for `IPv6 host literals`_.
Examples::
memcache.local:11211
127.0.0.1:11211
[::1]:11211
[::1]
.. _IPv6 host literals: https://tools.ietf.org/html/rfc3986#section-3.2.2
""" """
IPV6_RE = re.compile("^\[(?P<address>.*)\](:(?P<port>[0-9]+))?$")
def __init__(self, server, size, connect_timeout): def __init__(self, server, size, connect_timeout):
Pool.__init__(self, max_size=size) Pool.__init__(self, max_size=size)
self.host, self.port = self._get_addr(server) self.host, self.port = utils.parse_socket_string(
server, DEFAULT_MEMCACHED_PORT)
self._connect_timeout = connect_timeout self._connect_timeout = connect_timeout
def _get_addr(self, server):
port = DEFAULT_MEMCACHED_PORT
# IPv6 addresses must be between '[]'
if server.startswith('['):
match = MemcacheConnPool.IPV6_RE.match(server)
if not match:
raise ValueError("Invalid IPv6 address: %s" % server)
host = match.group('address')
port = match.group('port') or port
else:
if ':' in server:
tokens = server.split(':')
if len(tokens) > 2:
raise ValueError("IPv6 addresses must be between '[]'")
host, port = tokens
else:
host = server
return (host, port)
def create(self): def create(self):
addrs = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, addrs = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC,
socket.SOCK_STREAM) socket.SOCK_STREAM)

View File

@ -112,6 +112,9 @@ SWIFT_CONF_FILE = '/etc/swift/swift.conf'
AF_ALG = getattr(socket, 'AF_ALG', 38) AF_ALG = getattr(socket, 'AF_ALG', 38)
F_SETPIPE_SZ = getattr(fcntl, 'F_SETPIPE_SZ', 1031) F_SETPIPE_SZ = getattr(fcntl, 'F_SETPIPE_SZ', 1031)
# Used by the parse_socket_string() function to validate IPv6 addresses
IPV6_RE = re.compile("^\[(?P<address>.*)\](:(?P<port>[0-9]+))?$")
class InvalidHashPathConfigError(ValueError): class InvalidHashPathConfigError(ValueError):
@ -1799,6 +1802,43 @@ def whataremyips(bind_ip=None):
return addresses return addresses
def parse_socket_string(socket_string, default_port):
"""
Given a string representing a socket, returns a tuple of (host, port).
Valid strings are DNS names, IPv4 addresses, or IPv6 addresses, with an
optional port. If an IPv6 address is specified it **must** be enclosed in
[], like *[::1]* or *[::1]:11211*. This follows the accepted prescription
for `IPv6 host literals`_.
Examples::
server.org
server.org:1337
127.0.0.1:1337
[::1]:1337
[::1]
.. _IPv6 host literals: https://tools.ietf.org/html/rfc3986#section-3.2.2
"""
port = default_port
# IPv6 addresses must be between '[]'
if socket_string.startswith('['):
match = IPV6_RE.match(socket_string)
if not match:
raise ValueError("Invalid IPv6 address: %s" % socket_string)
host = match.group('address')
port = match.group('port') or port
else:
if ':' in socket_string:
tokens = socket_string.split(':')
if len(tokens) > 2:
raise ValueError("IPv6 addresses must be between '[]'")
host, port = tokens
else:
host = socket_string
return (host, port)
def storage_directory(datadir, partition, name_hash): def storage_directory(datadir, partition, name_hash):
""" """
Get the storage directory Get the storage directory

View File

@ -566,26 +566,5 @@ class TestMemcached(unittest.TestCase):
finally: finally:
memcached.MemcacheConnPool = orig_conn_pool memcached.MemcacheConnPool = orig_conn_pool
def test_connection_pool_parser(self):
default = memcached.DEFAULT_MEMCACHED_PORT
addrs = [('1.2.3.4', '1.2.3.4', default),
('1.2.3.4:5000', '1.2.3.4', 5000),
('[dead:beef::1]', 'dead:beef::1', default),
('[dead:beef::1]:5000', 'dead:beef::1', 5000),
('example.com', 'example.com', default),
('example.com:5000', 'example.com', 5000),
('foo.1-2-3.bar.com:5000', 'foo.1-2-3.bar.com', 5000),
('1.2.3.4:10:20', None, None),
('dead:beef::1:5000', None, None)]
for addr, expected_host, expected_port in addrs:
if expected_host:
pool = memcached.MemcacheConnPool(addr, 1, 0)
self.assertEqual(expected_host, pool.host)
self.assertEqual(expected_port, int(pool.port))
else:
with self.assertRaises(ValueError):
memcached.MemcacheConnPool(addr, 1, 0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -5329,5 +5329,28 @@ class TestPairs(unittest.TestCase):
(50, 60)])) (50, 60)]))
class TestSocketStringParser(unittest.TestCase):
def test_socket_string_parser(self):
default = 1337
addrs = [('1.2.3.4', '1.2.3.4', default),
('1.2.3.4:5000', '1.2.3.4', 5000),
('[dead:beef::1]', 'dead:beef::1', default),
('[dead:beef::1]:5000', 'dead:beef::1', 5000),
('example.com', 'example.com', default),
('example.com:5000', 'example.com', 5000),
('foo.1-2-3.bar.com:5000', 'foo.1-2-3.bar.com', 5000),
('1.2.3.4:10:20', None, None),
('dead:beef::1:5000', None, None)]
for addr, expected_host, expected_port in addrs:
if expected_host:
host, port = utils.parse_socket_string(addr, default)
self.assertEqual(expected_host, host)
self.assertEqual(expected_port, int(port))
else:
with self.assertRaises(ValueError):
utils.parse_socket_string(addr, default)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()