Merge "Make _get_addr() method a function in utils."

This commit is contained in:
Jenkins 2016-02-22 20:30:22 +00:00 committed by Gerrit Code Review
commit d9f500a128
4 changed files with 68 additions and 56 deletions

View File

@ -47,7 +47,6 @@ http://github.com/memcached/memcached/blob/1.4.2/doc/protocol.txt
import six.moves.cPickle as pickle
import json
import logging
import re
import time
from bisect import bisect
from swift import gettext_ as _
@ -57,7 +56,7 @@ from eventlet.green import socket
from eventlet.pools import Pool
from eventlet import Timeout
from six.moves import range
from swift.common import utils
DEFAULT_MEMCACHED_PORT = 11211
@ -106,45 +105,16 @@ class MemcacheConnPool(Pool):
Connection pool for Memcache Connections
The *server* parameter can be a hostname, an IPv4 address, or an IPv6
address with an optional port. If an IPv6 address is specified it **must**
be enclosed in [], like *[::1]* or *[::1]:11211*. This follows the accepted
prescription for `IPv6 host literals`_.
Examples::
memcache.local:11211
127.0.0.1:11211
[::1]:11211
[::1]
.. _IPv6 host literals: https://tools.ietf.org/html/rfc3986#section-3.2.2
address with an optional port. See
:func:`swift.common.utils.parse_socket_string` for details.
"""
IPV6_RE = re.compile("^\[(?P<address>.*)\](:(?P<port>[0-9]+))?$")
def __init__(self, server, size, connect_timeout):
Pool.__init__(self, max_size=size)
self.host, self.port = self._get_addr(server)
self.host, self.port = utils.parse_socket_string(
server, DEFAULT_MEMCACHED_PORT)
self._connect_timeout = connect_timeout
def _get_addr(self, server):
port = DEFAULT_MEMCACHED_PORT
# IPv6 addresses must be between '[]'
if server.startswith('['):
match = MemcacheConnPool.IPV6_RE.match(server)
if not match:
raise ValueError("Invalid IPv6 address: %s" % server)
host = match.group('address')
port = match.group('port') or port
else:
if ':' in server:
tokens = server.split(':')
if len(tokens) > 2:
raise ValueError("IPv6 addresses must be between '[]'")
host, port = tokens
else:
host = server
return (host, port)
def create(self):
addrs = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC,
socket.SOCK_STREAM)

View File

@ -112,6 +112,9 @@ SWIFT_CONF_FILE = '/etc/swift/swift.conf'
AF_ALG = getattr(socket, 'AF_ALG', 38)
F_SETPIPE_SZ = getattr(fcntl, 'F_SETPIPE_SZ', 1031)
# Used by the parse_socket_string() function to validate IPv6 addresses
IPV6_RE = re.compile("^\[(?P<address>.*)\](:(?P<port>[0-9]+))?$")
class InvalidHashPathConfigError(ValueError):
@ -1799,6 +1802,43 @@ def whataremyips(bind_ip=None):
return addresses
def parse_socket_string(socket_string, default_port):
"""
Given a string representing a socket, returns a tuple of (host, port).
Valid strings are DNS names, IPv4 addresses, or IPv6 addresses, with an
optional port. If an IPv6 address is specified it **must** be enclosed in
[], like *[::1]* or *[::1]:11211*. This follows the accepted prescription
for `IPv6 host literals`_.
Examples::
server.org
server.org:1337
127.0.0.1:1337
[::1]:1337
[::1]
.. _IPv6 host literals: https://tools.ietf.org/html/rfc3986#section-3.2.2
"""
port = default_port
# IPv6 addresses must be between '[]'
if socket_string.startswith('['):
match = IPV6_RE.match(socket_string)
if not match:
raise ValueError("Invalid IPv6 address: %s" % socket_string)
host = match.group('address')
port = match.group('port') or port
else:
if ':' in socket_string:
tokens = socket_string.split(':')
if len(tokens) > 2:
raise ValueError("IPv6 addresses must be between '[]'")
host, port = tokens
else:
host = socket_string
return (host, port)
def storage_directory(datadir, partition, name_hash):
"""
Get the storage directory

View File

@ -566,26 +566,5 @@ class TestMemcached(unittest.TestCase):
finally:
memcached.MemcacheConnPool = orig_conn_pool
def test_connection_pool_parser(self):
default = memcached.DEFAULT_MEMCACHED_PORT
addrs = [('1.2.3.4', '1.2.3.4', default),
('1.2.3.4:5000', '1.2.3.4', 5000),
('[dead:beef::1]', 'dead:beef::1', default),
('[dead:beef::1]:5000', 'dead:beef::1', 5000),
('example.com', 'example.com', default),
('example.com:5000', 'example.com', 5000),
('foo.1-2-3.bar.com:5000', 'foo.1-2-3.bar.com', 5000),
('1.2.3.4:10:20', None, None),
('dead:beef::1:5000', None, None)]
for addr, expected_host, expected_port in addrs:
if expected_host:
pool = memcached.MemcacheConnPool(addr, 1, 0)
self.assertEqual(expected_host, pool.host)
self.assertEqual(expected_port, int(pool.port))
else:
with self.assertRaises(ValueError):
memcached.MemcacheConnPool(addr, 1, 0)
if __name__ == '__main__':
unittest.main()

View File

@ -5329,5 +5329,28 @@ class TestPairs(unittest.TestCase):
(50, 60)]))
class TestSocketStringParser(unittest.TestCase):
def test_socket_string_parser(self):
default = 1337
addrs = [('1.2.3.4', '1.2.3.4', default),
('1.2.3.4:5000', '1.2.3.4', 5000),
('[dead:beef::1]', 'dead:beef::1', default),
('[dead:beef::1]:5000', 'dead:beef::1', 5000),
('example.com', 'example.com', default),
('example.com:5000', 'example.com', 5000),
('foo.1-2-3.bar.com:5000', 'foo.1-2-3.bar.com', 5000),
('1.2.3.4:10:20', None, None),
('dead:beef::1:5000', None, None)]
for addr, expected_host, expected_port in addrs:
if expected_host:
host, port = utils.parse_socket_string(addr, default)
self.assertEqual(expected_host, host)
self.assertEqual(expected_port, int(port))
else:
with self.assertRaises(ValueError):
utils.parse_socket_string(addr, default)
if __name__ == '__main__':
unittest.main()