Make the keycloak middleware thread safe

After moving to uWSGI keycloak middleware was found to be not
thread safe since we hold the token nad relam in an attribute
of a class which initialized once.
So if there are several threads
accessing the class there might have race conditions
lets move them to be parameters on the stack

Change-Id: I5325c7f08f6567d46189d4efdfbc9b0f6f5a8916
This commit is contained in:
Eyal 2019-07-11 11:03:20 +03:00
parent 7c9cc804db
commit 5e2f6ec9ef

View File

@ -63,55 +63,54 @@ class KeycloakAuth(base.ConfigurableMiddleware):
self._get_system_ca_file() self._get_system_ca_file()
self.user_info_endpoint_url = self._conf_get('user_info_endpoint_url', self.user_info_endpoint_url = self._conf_get('user_info_endpoint_url',
KEYCLOAK_GROUP) KEYCLOAK_GROUP)
self.decoded = {}
@property @property
def reject_auth_headers(self): def reject_auth_headers(self):
header_val = 'Keycloak uri=\'%s\'' % self.auth_url header_val = 'Keycloak uri=\'%s\'' % self.auth_url
return [('WWW-Authenticate', header_val)] return [('WWW-Authenticate', header_val)]
@property @staticmethod
def roles(self): def roles(decoded):
return ','.join(self.decoded['realm_access']['roles']) \ return ','.join(decoded['realm_access']['roles']) \
if 'realm_access' in self.decoded else '' if 'realm_access' in decoded else ''
@property @staticmethod
def realm_name(self): def realm_name(decoded):
# Get user realm from parsed token # Get user realm from parsed token
# Format is "iss": "http://<host>:<port>/auth/realms/<realm_name>", # Format is "iss": "http://<host>:<port>/auth/realms/<realm_name>",
__, __, realm_name = self.decoded['iss'].strip().rpartition('/realms/') __, __, realm_name = decoded['iss'].strip().rpartition('/realms/')
return realm_name return realm_name
def process_request(self, req): def process_request(self, req):
self._authenticate(req) self._authenticate(req)
def _authenticate(self, req): def _authenticate(self, req):
self.token = req.headers.get('X-Auth-Token') decoded = {}
if self.token: token = req.headers.get('X-Auth-Token')
self._decode() if token:
decoded = self._decode(token)
else: else:
message = 'Auth token must be provided in "X-Auth-Token" header.' message = 'Auth token must be provided in "X-Auth-Token" header.'
self._unauthorized(message) self._unauthorized(message)
self.call_keycloak() self.call_keycloak(token, decoded)
self._set_req_headers(req) self._set_req_headers(req, decoded)
def _decode(self): def _decode(self, token):
try: try:
self.decoded = jwt.decode(self.token, algorithms=['RS256'], return jwt.decode(token, algorithms=['RS256'], verify=False)
verify=False)
except jwt.DecodeError: except jwt.DecodeError:
message = "Token can't be decoded because of wrong format." message = "Token can't be decoded because of wrong format."
self._unauthorized(message) self._unauthorized(message)
def call_keycloak(self): def call_keycloak(self, token, decoded):
if self.user_info_endpoint_url.startswith(('http://', 'https://')): if self.user_info_endpoint_url.startswith(('http://', 'https://')):
endpoint = self.user_info_endpoint_url endpoint = self.user_info_endpoint_url
else: else:
endpoint = ('%s' + self.user_info_endpoint_url) % \ endpoint = ('%s' + self.user_info_endpoint_url) % \
(self.auth_url, self.realm_name) (self.auth_url, self.realm_name(decoded))
headers = {'Authorization': 'Bearer %s' % self.token} headers = {'Authorization': 'Bearer %s' % token}
verify = None verify = None
if urllib.parse.urlparse(endpoint).scheme == "https": if urllib.parse.urlparse(endpoint).scheme == "https":
verify = False if self.insecure else self.cafile verify = False if self.insecure else self.cafile
@ -123,10 +122,10 @@ class KeycloakAuth(base.ConfigurableMiddleware):
if not resp.ok: if not resp.ok:
abort(resp.status_code, resp.reason) abort(resp.status_code, resp.reason)
def _set_req_headers(self, req): def _set_req_headers(self, req, decoded):
req.headers['X-Identity-Status'] = 'Confirmed' req.headers['X-Identity-Status'] = 'Confirmed'
req.headers['X-Roles'] = self.roles req.headers['X-Roles'] = self.roles(decoded)
req.headers["X-Project-Id"] = self.realm_name req.headers["X-Project-Id"] = self.realm_name(decoded)
def _unauthorized(self, message): def _unauthorized(self, message):
body = {'error': { body = {'error': {