diff --git a/keystonemiddleware/auth_token/__init__.py b/keystonemiddleware/auth_token/__init__.py index a87b69fc..7c522e1d 100644 --- a/keystonemiddleware/auth_token/__init__.py +++ b/keystonemiddleware/auth_token/__init__.py @@ -530,6 +530,8 @@ class AuthProtocol(BaseAuthProtocol): 'include_service_catalog') self._hash_algorithms = self._conf.get('hash_algorithms') + self._auth = self._create_auth_plugin() + self._session = self._create_session() self._identity_server = self._create_identity_server() self._auth_uri = self._conf.get('auth_uri') @@ -610,6 +612,10 @@ class AuthProtocol(BaseAuthProtocol): if self._include_service_catalog: request.set_service_catalog_headers(request.token_auth.user) + if request.token_auth: + request.token_auth._auth = self._auth + request.token_auth._session = self._session + if request.service_token and request.service_token_valid: request.set_service_headers(request.token_auth.service) @@ -813,7 +819,7 @@ class AuthProtocol(BaseAuthProtocol): self._SIGNING_CA_FILE_NAME, self._identity_server.fetch_ca_cert()) - def _get_auth_plugin(self): + def _create_auth_plugin(self): # NOTE(jamielennox): Ideally this would use load_from_conf_options # however that is not possible because we have to support the override # pattern we use in _conf.get. This function therefore does a manual @@ -852,25 +858,24 @@ class AuthProtocol(BaseAuthProtocol): getter = lambda opt: self._conf.get(opt.dest, group=group) return plugin_loader.load_from_options_getter(getter) - def _create_identity_server(self): + def _create_session(self, **kwargs): # NOTE(jamielennox): Loading Session here should be exactly the # same as calling Session.load_from_conf_options(CONF, GROUP) # however we can't do that because we have to use _conf.get to # support the paste.ini options. - sess = session_loading.Session().load_from_options( - cert=self._conf.get('certfile'), - key=self._conf.get('keyfile'), - cacert=self._conf.get('cafile'), - insecure=self._conf.get('insecure'), - timeout=self._conf.get('http_connect_timeout'), - user_agent=self._conf.user_agent, - ) + kwargs.setdefault('cert', self._conf.get('certfile')) + kwargs.setdefault('key', self._conf.get('keyfile')) + kwargs.setdefault('cacert', self._conf.get('cafile')) + kwargs.setdefault('insecure', self._conf.get('insecure')) + kwargs.setdefault('timeout', self._conf.get('http_connect_timeout')) + kwargs.setdefault('user_agent', self._conf.user_agent) - auth_plugin = self._get_auth_plugin() + return session_loading.Session().load_from_options(**kwargs) + def _create_identity_server(self): adap = adapter.Adapter( - sess, - auth=auth_plugin, + self._session, + auth=self._auth, service_type='identity', interface='admin', region_name=self._conf.get('region_name'), diff --git a/keystonemiddleware/auth_token/_user_plugin.py b/keystonemiddleware/auth_token/_user_plugin.py index ccddfc5f..0d558f39 100644 --- a/keystonemiddleware/auth_token/_user_plugin.py +++ b/keystonemiddleware/auth_token/_user_plugin.py @@ -31,12 +31,17 @@ class UserAuthPlugin(base_identity.BaseIdentityPlugin): authentication plugin when communicating via a session. """ - def __init__(self, user_auth_ref, serv_auth_ref): + def __init__(self, user_auth_ref, serv_auth_ref, session=None, auth=None): super(UserAuthPlugin, self).__init__(reauthenticate=False) self.user = user_auth_ref self.service = serv_auth_ref + # NOTE(jamielennox): adding a service token requires the original + # session and auth plugin from auth_token + self._session = session + self._auth = auth + @property def has_user_token(self): """Did this authentication request contained a user auth token.""" @@ -64,3 +69,14 @@ class UserAuthPlugin(base_identity.BaseIdentityPlugin): msg.append('service: %s' % _log_format(self.service)) return ' '.join(msg) + + def get_headers(self, session, **kwargs): + headers = super(UserAuthPlugin, self).get_headers(session, **kwargs) + + if headers is not None and self._session: + token = self._session.get_token(auth=self._auth) + + if token: + headers['X-Service-Token'] = token + + return headers diff --git a/keystonemiddleware/tests/unit/auth_token/test_auth_token_middleware.py b/keystonemiddleware/tests/unit/auth_token/test_auth_token_middleware.py index 443a5927..f2b1f45f 100644 --- a/keystonemiddleware/tests/unit/auth_token/test_auth_token_middleware.py +++ b/keystonemiddleware/tests/unit/auth_token/test_auth_token_middleware.py @@ -1295,6 +1295,28 @@ class CommonAuthTokenMiddlewareTest(object): r = m(env, _start_response) self.assertEqual(text, r) + def test_auth_plugin_service_token(self): + url = 'http://test.url' + text = uuid.uuid4().hex + self.requests_mock.get(url, text=text) + + token = self.token_dict['uuid_token_default'] + resp = self.call_middleware(headers={'X-Auth-Token': token}) + + self.assertEqual(200, resp.status_int) + self.assertEqual(FakeApp.SUCCESS, resp.body) + + s = session.Session(auth=resp.request.environ['keystone.token_auth']) + + resp = s.get(url) + + self.assertEqual(text, resp.text) + self.assertEqual(200, resp.status_code) + + headers = self.requests_mock.last_request.headers + + self.assertEqual(FAKE_ADMIN_TOKEN_ID, headers['X-Service-Token']) + class V2CertDownloadMiddlewareTest(BaseAuthTokenMiddlewareTest, testresources.ResourcedTestCase):