Refactor auth_token cache

There are very different paths for when we are encrypting tokens before
saving them to memcache. Extract the encrypted strategy out to its own
class to see the similarities.

Change-Id: I5bf5dae3143f15a840a580a2e01ba67c50911276
This commit is contained in:
Jamie Lennox 2014-07-08 10:50:09 +10:00 committed by Brant Knudson
parent e77a7a225b
commit b99291cbdc

View File

@ -661,8 +661,6 @@ class AuthProtocol(object):
val = '%s/revoked.pem' % self._signing_dirname
self._revoked_file_name = val
self._memcache_security_strategy = (
self._conf_get('memcache_security_strategy'))
self._token_cache = self._token_cache_factory()
self._token_revocation_list_prop = None
self._token_revocation_list_fetched_time_prop = None
@ -1288,14 +1286,13 @@ class AuthProtocol(object):
return identity_server
def _token_cache_factory(self):
token_cache = _TokenCache(
self._LOG,
security_strategy = self._conf_get('memcache_security_strategy')
cache_kwargs = dict(
cache_time=int(self._conf_get('token_cache_time')),
hash_algorithms=self._conf_get('hash_algorithms'),
env_cache_name=self._conf_get('cache'),
memcached_servers=self._conf_get('memcached_servers'),
memcache_security_strategy=self._memcache_security_strategy,
memcache_secret_key=self._conf_get('memcache_secret_key'),
use_advanced_pool=self._conf_get('memcache_use_advanced_pool'),
memcache_pool_dead_retry=self._conf_get(
'memcache_pool_dead_retry'),
@ -1305,8 +1302,16 @@ class AuthProtocol(object):
memcache_pool_conn_get_timeout=self._conf_get(
'memcache_pool_conn_get_timeout'),
memcache_pool_socket_timeout=self._conf_get(
'memcache_pool_socket_timeout'))
return token_cache
'memcache_pool_socket_timeout'),
)
if security_strategy:
return _SecureTokenCache(self._LOG,
security_strategy,
self._conf_get('memcache_secret_key'),
**cache_kwargs)
else:
return _TokenCache(self._LOG, **cache_kwargs)
class _CachePool(list):
@ -1630,7 +1635,6 @@ class _TokenCache(object):
def __init__(self, log, cache_time=None, hash_algorithms=None,
env_cache_name=None, memcached_servers=None,
memcache_security_strategy=None, memcache_secret_key=None,
use_advanced_pool=False, memcache_pool_dead_retry=None,
memcache_pool_maxsize=None, memcache_pool_unused_timeout=None,
memcache_pool_conn_get_timeout=None,
@ -1647,18 +1651,9 @@ class _TokenCache(object):
self._memcache_pool_conn_get_timeout = memcache_pool_conn_get_timeout
self._memcache_pool_socket_timeout = memcache_pool_socket_timeout
# memcache value treatment, ENCRYPT or MAC
self._memcache_security_strategy = memcache_security_strategy
if self._memcache_security_strategy is not None:
self._memcache_security_strategy = (
self._memcache_security_strategy.upper())
self._memcache_secret_key = memcache_secret_key
self._cache_pool = None
self._initialized = False
self._assert_valid_memcache_protection_config()
def _get_cache_pool(self, cache, memcache_servers, use_advanced_pool=False,
memcache_dead_retry=None, memcache_pool_maxsize=None,
memcache_pool_unused_timeout=None,
@ -1738,15 +1733,57 @@ class _TokenCache(object):
self._LOG.debug('Marking token as unauthorized in cache')
self._cache_store(token_id, self._INVALID_INDICATOR)
def _assert_valid_memcache_protection_config(self):
if self._memcache_security_strategy:
if self._memcache_security_strategy not in ('MAC', 'ENCRYPT'):
raise ConfigurationError('memcache_security_strategy must be '
'ENCRYPT or MAC')
if not self._memcache_secret_key:
raise ConfigurationError('memcache_secret_key must be defined '
'when a memcache_security_strategy '
'is defined')
def _get_cache_key(self, token_id):
"""Get a unique key for this token id.
Turn the token_id into something that can uniquely identify that token
in a key value store.
As this is generally the first function called in a key lookup this
function also returns a context object. This context object is not
modified or used by the Cache object but is passed back on subsequent
functions so that decryption or other data can be shared throughout a
cache lookup.
:param str token_id: The unique token id.
:returns: A tuple of a string key and an implementation specific
context object
"""
# NOTE(jamielennox): in the basic implementation there is no need for
# a context so just pass None as it will only get passed back later.
unused_context = None
return self._CACHE_KEY_TEMPLATE % token_id, unused_context
def _deserialize(self, data, context):
"""Deserialize data from the cache back into python objects.
Take data retrieved from the cache and return an appropriate python
dictionary.
:param str data: The data retrieved from the cache.
:param object context: The context that was returned from
_get_cache_key.
:returns: The python object that was saved.
"""
# memory cache will handle deserialization for us
return data
def _serialize(self, data, context):
"""Serialize data so that it can be saved to the cache.
Take python objects and serialize them so that they can be saved into
the cache.
:param object data: The data to be cached.
:param object context: The context that was returned from
_get_cache_key.
:returns: The python object that was saved.
"""
# memory cache will handle serialization for us
return data
def _cache_get(self, token_id):
"""Return token information from cache.
@ -1759,45 +1796,22 @@ class _TokenCache(object):
# Nothing to do
return
if self._memcache_security_strategy is None:
key = self._CACHE_KEY_TEMPLATE % token_id
key, context = self._get_cache_key(token_id)
with self._cache_pool.reserve() as cache:
serialized = cache.get(key)
else:
secret_key = self._memcache_secret_key
if isinstance(secret_key, six.string_types):
secret_key = secret_key.encode('utf-8')
security_strategy = self._memcache_security_strategy
if isinstance(security_strategy, six.string_types):
security_strategy = security_strategy.encode('utf-8')
keys = memcache_crypt.derive_keys(
token_id,
secret_key,
security_strategy)
cache_key = self._CACHE_KEY_TEMPLATE % (
memcache_crypt.get_cache_key(keys))
with self._cache_pool.reserve() as cache:
raw_cached = cache.get(cache_key)
try:
# unprotect_data will return None if raw_cached is None
serialized = memcache_crypt.unprotect_data(keys,
raw_cached)
except Exception:
msg = 'Failed to decrypt/verify cache data'
self._LOG.exception(msg)
# this should have the same effect as data not
# found in cache
serialized = None
if serialized is None:
return None
data = self._deserialize(serialized, context)
# Note that _INVALID_INDICATOR and (data, expires) are the only
# valid types of serialized cache entries, so there is not
# a collision with jsonutils.loads(serialized) == None.
if not isinstance(serialized, six.string_types):
serialized = serialized.decode('utf-8')
cached = jsonutils.loads(serialized)
if not isinstance(data, six.string_types):
data = data.decode('utf-8')
cached = jsonutils.loads(data)
if cached == self._INVALID_INDICATOR:
self._LOG.debug('Cached Token is marked unauthorized')
raise InvalidToken('Token authorization failed')
@ -1827,29 +1841,68 @@ class _TokenCache(object):
data may be _INVALID_INDICATOR or a tuple like (data, expires)
"""
serialized_data = jsonutils.dumps(data)
if isinstance(serialized_data, six.text_type):
serialized_data = serialized_data.encode('utf-8')
if self._memcache_security_strategy is None:
cache_key = self._CACHE_KEY_TEMPLATE % token_id
data_to_store = serialized_data
else:
secret_key = self._memcache_secret_key
if isinstance(secret_key, six.string_types):
secret_key = secret_key.encode('utf-8')
security_strategy = self._memcache_security_strategy
if isinstance(security_strategy, six.string_types):
security_strategy = security_strategy.encode('utf-8')
keys = memcache_crypt.derive_keys(
token_id, secret_key, security_strategy)
cache_key = memcache_crypt.get_cache_key(keys)
cache_key = self._CACHE_KEY_TEMPLATE % cache_key
data_to_store = memcache_crypt.protect_data(keys, serialized_data)
data = jsonutils.dumps(data)
if isinstance(data, six.text_type):
data = data.encode('utf-8')
cache_key, context = self._get_cache_key(token_id)
data_to_store = self._serialize(data, context)
with self._cache_pool.reserve() as cache:
cache.set(cache_key, data_to_store, time=self._cache_time)
class _SecureTokenCache(_TokenCache):
"""A token cache that stores tokens encrypted.
A more secure version of _TokenCache that will encrypt tokens before
caching them.
"""
def __init__(self, log, security_strategy, secret_key, **kwargs):
super(_SecureTokenCache, self).__init__(log, **kwargs)
security_strategy = security_strategy.upper()
if security_strategy not in ('MAC', 'ENCRYPT'):
raise ConfigurationError('memcache_security_strategy must be '
'ENCRYPT or MAC')
if not secret_key:
raise ConfigurationError('memcache_secret_key must be defined '
'when a memcache_security_strategy '
'is defined')
if isinstance(security_strategy, six.string_types):
security_strategy = security_strategy.encode('utf-8')
if isinstance(secret_key, six.string_types):
secret_key = secret_key.encode('utf-8')
self._security_strategy = security_strategy
self._secret_key = secret_key
def _get_cache_key(self, token_id):
context = memcache_crypt.derive_keys(token_id,
self._secret_key,
self._security_strategy)
key = self._CACHE_KEY_TEMPLATE % memcache_crypt.get_cache_key(context)
return key, context
def _deserialize(self, data, context):
try:
# unprotect_data will return None if raw_cached is None
return memcache_crypt.unprotect_data(context, data)
except Exception:
msg = 'Failed to decrypt/verify cache data'
self._LOG.exception(msg)
# this should have the same effect as data not
# found in cache
return None
def _serialize(self, data, context):
return memcache_crypt.protect_data(context, data)
def filter_factory(global_conf, **local_conf):
"""Returns a WSGI filter app for use with paste.deploy."""
conf = global_conf.copy()