Merge "Memcached client TLS support"

This commit is contained in:
Zuul 2021-01-07 11:35:34 +00:00 committed by Gerrit Code Review
commit 8c611be876
6 changed files with 85 additions and 7 deletions

View File

@ -34,3 +34,25 @@
# #
# How many errors can accumulate before a server is temporarily ignored. # How many errors can accumulate before a server is temporarily ignored.
# error_suppression_limit = 10 # error_suppression_limit = 10
#
# (Optional) Global toggle for TLS usage when comunicating with
# the caching servers.
# tls_enabled = false
#
# (Optional) Path to a file of concatenated CA certificates in PEM
# format necessary to establish the caching server's authenticity.
# If tls_enabled is False, this option is ignored.
# tls_cafile =
#
# (Optional) Path to a single file in PEM format containing the
# client's certificate as well as any number of CA certificates
# needed to establish the certificate's authenticity. This file
# is only required when client side authentication is necessary.
# If tls_enabled is False, this option is ignored.
# tls_certfile =
#
# (Optional) Path to a single file containing the client's private
# key in. Otherwhise the private key will be taken from the file
# specified in tls_certfile. If tls_enabled is False, this option
# is ignored.
# tls_keyfile =

View File

@ -712,6 +712,10 @@ use = egg:swift#memcache
# How many errors can accumulate before a server is temporarily ignored. # How many errors can accumulate before a server is temporarily ignored.
# error_suppression_limit = 10 # error_suppression_limit = 10
# #
# (Optional) Global toggle for TLS usage when comunicating with
# the caching servers.
# tls_enabled =
#
# More options documented in memcache.conf-sample # More options documented in memcache.conf-sample
[filter:ratelimit] [filter:ratelimit]

View File

@ -127,11 +127,12 @@ class MemcacheConnPool(Pool):
:func:`swift.common.utils.parse_socket_string` for details. :func:`swift.common.utils.parse_socket_string` for details.
""" """
def __init__(self, server, size, connect_timeout): def __init__(self, server, size, connect_timeout, tls_context=None):
Pool.__init__(self, max_size=size) Pool.__init__(self, max_size=size)
self.host, self.port = utils.parse_socket_string( self.host, self.port = utils.parse_socket_string(
server, DEFAULT_MEMCACHED_PORT) server, DEFAULT_MEMCACHED_PORT)
self._connect_timeout = connect_timeout self._connect_timeout = connect_timeout
self._tls_context = tls_context
def create(self): def create(self):
addrs = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, addrs = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC,
@ -141,6 +142,9 @@ class MemcacheConnPool(Pool):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
with Timeout(self._connect_timeout): with Timeout(self._connect_timeout):
sock.connect(sockaddr) sock.connect(sockaddr)
if self._tls_context:
sock = self._tls_context.wrap_socket(sock,
server_hostname=self.host)
return (sock.makefile('rwb'), sock) return (sock.makefile('rwb'), sock)
def get(self): def get(self):
@ -159,7 +163,7 @@ class MemcacheRing(object):
def __init__(self, servers, connect_timeout=CONN_TIMEOUT, def __init__(self, servers, connect_timeout=CONN_TIMEOUT,
io_timeout=IO_TIMEOUT, pool_timeout=POOL_TIMEOUT, io_timeout=IO_TIMEOUT, pool_timeout=POOL_TIMEOUT,
tries=TRY_COUNT, allow_pickle=False, allow_unpickle=False, tries=TRY_COUNT, allow_pickle=False, allow_unpickle=False,
max_conns=2, logger=None, max_conns=2, tls_context=None, logger=None,
error_limit_count=ERROR_LIMIT_COUNT, error_limit_count=ERROR_LIMIT_COUNT,
error_limit_time=ERROR_LIMIT_TIME, error_limit_time=ERROR_LIMIT_TIME,
error_limit_duration=ERROR_LIMIT_DURATION): error_limit_duration=ERROR_LIMIT_DURATION):
@ -174,10 +178,10 @@ class MemcacheRing(object):
self._ring[md5hash('%s-%s' % (server, i))] = server self._ring[md5hash('%s-%s' % (server, i))] = server
self._tries = tries if tries <= len(servers) else len(servers) self._tries = tries if tries <= len(servers) else len(servers)
self._sorted = sorted(self._ring) self._sorted = sorted(self._ring)
self._client_cache = dict(((server, self._client_cache = dict((
MemcacheConnPool(server, max_conns, (server, MemcacheConnPool(server, max_conns, connect_timeout,
connect_timeout)) tls_context=tls_context))
for server in servers)) for server in servers))
self._connect_timeout = connect_timeout self._connect_timeout = connect_timeout
self._io_timeout = io_timeout self._io_timeout = io_timeout
self._pool_timeout = pool_timeout self._pool_timeout = pool_timeout

View File

@ -15,12 +15,13 @@
import os import os
from eventlet.green import ssl
from six.moves.configparser import ConfigParser, NoSectionError, NoOptionError from six.moves.configparser import ConfigParser, NoSectionError, NoOptionError
from swift.common.memcached import ( from swift.common.memcached import (
MemcacheRing, CONN_TIMEOUT, POOL_TIMEOUT, IO_TIMEOUT, TRY_COUNT, MemcacheRing, CONN_TIMEOUT, POOL_TIMEOUT, IO_TIMEOUT, TRY_COUNT,
ERROR_LIMIT_COUNT, ERROR_LIMIT_TIME) ERROR_LIMIT_COUNT, ERROR_LIMIT_TIME)
from swift.common.utils import get_logger from swift.common.utils import get_logger, config_true_value
class MemcacheMiddleware(object): class MemcacheMiddleware(object):
@ -87,6 +88,17 @@ class MemcacheMiddleware(object):
'pool_timeout', POOL_TIMEOUT)) 'pool_timeout', POOL_TIMEOUT))
tries = int(memcache_options.get('tries', TRY_COUNT)) tries = int(memcache_options.get('tries', TRY_COUNT))
io_timeout = float(memcache_options.get('io_timeout', IO_TIMEOUT)) io_timeout = float(memcache_options.get('io_timeout', IO_TIMEOUT))
if config_true_value(memcache_options.get('tls_enabled', 'false')):
tls_cafile = memcache_options.get('tls_cafile')
tls_certfile = memcache_options.get('tls_certfile')
tls_keyfile = memcache_options.get('tls_keyfile')
self.tls_context = ssl.create_default_context(
cafile=tls_cafile)
if tls_certfile:
self.tls_context.load_cert_chain(tls_certfile,
tls_keyfile)
else:
self.tls_context = None
error_suppression_interval = float(memcache_options.get( error_suppression_interval = float(memcache_options.get(
'error_suppression_interval', ERROR_LIMIT_TIME)) 'error_suppression_interval', ERROR_LIMIT_TIME))
error_suppression_limit = float(memcache_options.get( error_suppression_limit = float(memcache_options.get(
@ -110,6 +122,7 @@ class MemcacheMiddleware(object):
allow_pickle=(serialization_format == 0), allow_pickle=(serialization_format == 0),
allow_unpickle=(serialization_format <= 1), allow_unpickle=(serialization_format <= 1),
max_conns=max_conns, max_conns=max_conns,
tls_context=self.tls_context,
logger=self.logger, logger=self.logger,
error_limit_count=error_suppression_limit, error_limit_count=error_suppression_limit,
error_limit_time=error_suppression_interval, error_limit_time=error_suppression_interval,

View File

@ -17,6 +17,7 @@ import os
from textwrap import dedent from textwrap import dedent
import unittest import unittest
from eventlet.green import ssl
import mock import mock
from six.moves.configparser import NoSectionError, NoOptionError from six.moves.configparser import NoSectionError, NoOptionError
@ -170,6 +171,22 @@ class TestCacheMiddleware(unittest.TestCase):
self.assertEqual(app.memcache._error_limit_time, 2.5) self.assertEqual(app.memcache._error_limit_time, 2.5)
self.assertEqual(app.memcache._error_limit_duration, 2.5) self.assertEqual(app.memcache._error_limit_duration, 2.5)
def test_conf_inline_tls(self):
fake_context = mock.Mock()
with mock.patch.object(ssl, 'create_default_context',
return_value=fake_context):
with mock.patch.object(memcache, 'ConfigParser',
get_config_parser()):
memcache.MemcacheMiddleware(
FakeApp(),
{'tls_enabled': 'true',
'tls_cafile': 'cafile',
'tls_certfile': 'certfile',
'tls_keyfile': 'keyfile'})
ssl.create_default_context.assert_called_with(cafile='cafile')
fake_context.load_cert_chain.assert_called_with('certfile',
'keyfile')
def test_conf_extra_no_section(self): def test_conf_extra_no_section(self):
with mock.patch.object(memcache, 'ConfigParser', with mock.patch.object(memcache, 'ConfigParser',
get_config_parser(section='foobar')): get_config_parser(section='foobar')):
@ -333,6 +350,7 @@ class TestCacheMiddleware(unittest.TestCase):
pool_timeout = 0.5 pool_timeout = 0.5
tries = 4 tries = 4
io_timeout = 1.0 io_timeout = 1.0
tls_enabled = true
""" """
config_path = os.path.join(tempdir, 'test.conf') config_path = os.path.join(tempdir, 'test.conf')
with open(config_path, 'w') as f: with open(config_path, 'w') as f:
@ -349,6 +367,9 @@ class TestCacheMiddleware(unittest.TestCase):
self.assertEqual(memcache_ring._error_limit_count, 10) self.assertEqual(memcache_ring._error_limit_count, 10)
self.assertEqual(memcache_ring._error_limit_time, 60) self.assertEqual(memcache_ring._error_limit_time, 60)
self.assertEqual(memcache_ring._error_limit_duration, 60) self.assertEqual(memcache_ring._error_limit_duration, 60)
self.assertIsInstance(
list(memcache_ring._client_cache.values())[0]._tls_context,
ssl.SSLContext)
@with_tempdir @with_tempdir
def test_real_memcache_config(self, tempdir): def test_real_memcache_config(self, tempdir):

View File

@ -198,6 +198,20 @@ class TestMemcached(unittest.TestCase):
client = memcached.MemcacheRing([server_socket], logger=self.logger) client = memcached.MemcacheRing([server_socket], logger=self.logger)
self.assertIs(client.logger, self.logger) self.assertIs(client.logger, self.logger)
def test_tls_context_kwarg(self):
with patch('swift.common.memcached.socket.socket'):
server = '%s:%s' % ('[::1]', 11211)
client = memcached.MemcacheRing([server])
self.assertIsNone(client._client_cache[server]._tls_context)
context = mock.Mock()
client = memcached.MemcacheRing([server], tls_context=context)
self.assertIs(client._client_cache[server]._tls_context, context)
key = uuid4().hex.encode('ascii')
list(client._get_conns(key))
context.wrap_socket.assert_called_once()
def test_get_conns(self): def test_get_conns(self):
sock1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock1.bind(('127.0.0.1', 0)) sock1.bind(('127.0.0.1', 0))