Merge "Make _get_addr() method a function in utils."
This commit is contained in:
commit
d9f500a128
@ -47,7 +47,6 @@ http://github.com/memcached/memcached/blob/1.4.2/doc/protocol.txt
|
||||
import six.moves.cPickle as pickle
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from bisect import bisect
|
||||
from swift import gettext_ as _
|
||||
@ -57,7 +56,7 @@ from eventlet.green import socket
|
||||
from eventlet.pools import Pool
|
||||
from eventlet import Timeout
|
||||
from six.moves import range
|
||||
|
||||
from swift.common import utils
|
||||
|
||||
DEFAULT_MEMCACHED_PORT = 11211
|
||||
|
||||
@ -106,45 +105,16 @@ class MemcacheConnPool(Pool):
|
||||
Connection pool for Memcache Connections
|
||||
|
||||
The *server* parameter can be a hostname, an IPv4 address, or an IPv6
|
||||
address with an optional port. If an IPv6 address is specified it **must**
|
||||
be enclosed in [], like *[::1]* or *[::1]:11211*. This follows the accepted
|
||||
prescription for `IPv6 host literals`_.
|
||||
|
||||
Examples::
|
||||
|
||||
memcache.local:11211
|
||||
127.0.0.1:11211
|
||||
[::1]:11211
|
||||
[::1]
|
||||
|
||||
.. _IPv6 host literals: https://tools.ietf.org/html/rfc3986#section-3.2.2
|
||||
address with an optional port. See
|
||||
:func:`swift.common.utils.parse_socket_string` for details.
|
||||
"""
|
||||
IPV6_RE = re.compile("^\[(?P<address>.*)\](:(?P<port>[0-9]+))?$")
|
||||
|
||||
def __init__(self, server, size, connect_timeout):
|
||||
Pool.__init__(self, max_size=size)
|
||||
self.host, self.port = self._get_addr(server)
|
||||
self.host, self.port = utils.parse_socket_string(
|
||||
server, DEFAULT_MEMCACHED_PORT)
|
||||
self._connect_timeout = connect_timeout
|
||||
|
||||
def _get_addr(self, server):
|
||||
port = DEFAULT_MEMCACHED_PORT
|
||||
# IPv6 addresses must be between '[]'
|
||||
if server.startswith('['):
|
||||
match = MemcacheConnPool.IPV6_RE.match(server)
|
||||
if not match:
|
||||
raise ValueError("Invalid IPv6 address: %s" % server)
|
||||
host = match.group('address')
|
||||
port = match.group('port') or port
|
||||
else:
|
||||
if ':' in server:
|
||||
tokens = server.split(':')
|
||||
if len(tokens) > 2:
|
||||
raise ValueError("IPv6 addresses must be between '[]'")
|
||||
host, port = tokens
|
||||
else:
|
||||
host = server
|
||||
return (host, port)
|
||||
|
||||
def create(self):
|
||||
addrs = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC,
|
||||
socket.SOCK_STREAM)
|
||||
|
@ -112,6 +112,9 @@ SWIFT_CONF_FILE = '/etc/swift/swift.conf'
|
||||
AF_ALG = getattr(socket, 'AF_ALG', 38)
|
||||
F_SETPIPE_SZ = getattr(fcntl, 'F_SETPIPE_SZ', 1031)
|
||||
|
||||
# Used by the parse_socket_string() function to validate IPv6 addresses
|
||||
IPV6_RE = re.compile("^\[(?P<address>.*)\](:(?P<port>[0-9]+))?$")
|
||||
|
||||
|
||||
class InvalidHashPathConfigError(ValueError):
|
||||
|
||||
@ -1799,6 +1802,43 @@ def whataremyips(bind_ip=None):
|
||||
return addresses
|
||||
|
||||
|
||||
def parse_socket_string(socket_string, default_port):
|
||||
"""
|
||||
Given a string representing a socket, returns a tuple of (host, port).
|
||||
Valid strings are DNS names, IPv4 addresses, or IPv6 addresses, with an
|
||||
optional port. If an IPv6 address is specified it **must** be enclosed in
|
||||
[], like *[::1]* or *[::1]:11211*. This follows the accepted prescription
|
||||
for `IPv6 host literals`_.
|
||||
|
||||
Examples::
|
||||
|
||||
server.org
|
||||
server.org:1337
|
||||
127.0.0.1:1337
|
||||
[::1]:1337
|
||||
[::1]
|
||||
|
||||
.. _IPv6 host literals: https://tools.ietf.org/html/rfc3986#section-3.2.2
|
||||
"""
|
||||
port = default_port
|
||||
# IPv6 addresses must be between '[]'
|
||||
if socket_string.startswith('['):
|
||||
match = IPV6_RE.match(socket_string)
|
||||
if not match:
|
||||
raise ValueError("Invalid IPv6 address: %s" % socket_string)
|
||||
host = match.group('address')
|
||||
port = match.group('port') or port
|
||||
else:
|
||||
if ':' in socket_string:
|
||||
tokens = socket_string.split(':')
|
||||
if len(tokens) > 2:
|
||||
raise ValueError("IPv6 addresses must be between '[]'")
|
||||
host, port = tokens
|
||||
else:
|
||||
host = socket_string
|
||||
return (host, port)
|
||||
|
||||
|
||||
def storage_directory(datadir, partition, name_hash):
|
||||
"""
|
||||
Get the storage directory
|
||||
|
@ -566,26 +566,5 @@ class TestMemcached(unittest.TestCase):
|
||||
finally:
|
||||
memcached.MemcacheConnPool = orig_conn_pool
|
||||
|
||||
def test_connection_pool_parser(self):
|
||||
default = memcached.DEFAULT_MEMCACHED_PORT
|
||||
addrs = [('1.2.3.4', '1.2.3.4', default),
|
||||
('1.2.3.4:5000', '1.2.3.4', 5000),
|
||||
('[dead:beef::1]', 'dead:beef::1', default),
|
||||
('[dead:beef::1]:5000', 'dead:beef::1', 5000),
|
||||
('example.com', 'example.com', default),
|
||||
('example.com:5000', 'example.com', 5000),
|
||||
('foo.1-2-3.bar.com:5000', 'foo.1-2-3.bar.com', 5000),
|
||||
('1.2.3.4:10:20', None, None),
|
||||
('dead:beef::1:5000', None, None)]
|
||||
|
||||
for addr, expected_host, expected_port in addrs:
|
||||
if expected_host:
|
||||
pool = memcached.MemcacheConnPool(addr, 1, 0)
|
||||
self.assertEqual(expected_host, pool.host)
|
||||
self.assertEqual(expected_port, int(pool.port))
|
||||
else:
|
||||
with self.assertRaises(ValueError):
|
||||
memcached.MemcacheConnPool(addr, 1, 0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -5329,5 +5329,28 @@ class TestPairs(unittest.TestCase):
|
||||
(50, 60)]))
|
||||
|
||||
|
||||
class TestSocketStringParser(unittest.TestCase):
|
||||
def test_socket_string_parser(self):
|
||||
default = 1337
|
||||
addrs = [('1.2.3.4', '1.2.3.4', default),
|
||||
('1.2.3.4:5000', '1.2.3.4', 5000),
|
||||
('[dead:beef::1]', 'dead:beef::1', default),
|
||||
('[dead:beef::1]:5000', 'dead:beef::1', 5000),
|
||||
('example.com', 'example.com', default),
|
||||
('example.com:5000', 'example.com', 5000),
|
||||
('foo.1-2-3.bar.com:5000', 'foo.1-2-3.bar.com', 5000),
|
||||
('1.2.3.4:10:20', None, None),
|
||||
('dead:beef::1:5000', None, None)]
|
||||
|
||||
for addr, expected_host, expected_port in addrs:
|
||||
if expected_host:
|
||||
host, port = utils.parse_socket_string(addr, default)
|
||||
self.assertEqual(expected_host, host)
|
||||
self.assertEqual(expected_port, int(port))
|
||||
else:
|
||||
with self.assertRaises(ValueError):
|
||||
utils.parse_socket_string(addr, default)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user