diff --git a/zaqar/api/handler.py b/zaqar/api/handler.py index 16a7b73df..d3f1712f0 100644 --- a/zaqar/api/handler.py +++ b/zaqar/api/handler.py @@ -13,6 +13,11 @@ # the License. from zaqar.api.v1_1 import endpoints +from zaqar.api.v1_1 import request as schema_validator + +from zaqar.common.api import request +from zaqar.common.api import response +from zaqar.common import errors class Handler(object): @@ -28,3 +33,40 @@ class Handler(object): def process_request(self, req): # FIXME(vkmc): Control API version return getattr(self.v1_1_endpoints, req._action)(req) + + @staticmethod + def validate_request(payload, req): + """Validate a request and its payload against a schema. + + :return: a Response object if validation failed, None otherwise. + """ + try: + action = payload.get('action') + validator = schema_validator.RequestSchema() + is_valid = validator.validate(action=action, body=payload) + except errors.InvalidAction as ex: + body = {'error': str(ex)} + headers = {'status': 400} + return response.Response(req, body, headers) + else: + if not is_valid: + body = {'error': 'Schema validation failed.'} + headers = {'status': 400} + return response.Response(req, body, headers) + + def create_response(self, code, body, req=None): + if req is None: + req = self.create_request() + headers = {'status': code} + return response.Response(req, body, headers) + + @staticmethod + def create_request(payload=None): + if payload is None: + payload = {} + action = payload.get('action') + body = payload.get('body', {}) + headers = payload.get('headers') + + return request.Request(action=action, body=body, + headers=headers, api="v1.1") diff --git a/zaqar/api/v1_1/request.py b/zaqar/api/v1_1/request.py index 7cfbc3710..929fa7117 100644 --- a/zaqar/api/v1_1/request.py +++ b/zaqar/api/v1_1/request.py @@ -65,6 +65,17 @@ class RequestSchema(api.Api): 'required': ['action', 'headers'], 'admin': True, }, + 'authenticate': { + 'properties': { + 'action': {'enum': ['authenticate']}, + 'headers': { + 'type': 'object', + 'properties': headers, + 'required': ['X-Project-ID', 'X-Auth-Token'] + } + }, + 'required': ['action', 'headers'], + }, # Queues 'queue_list': { diff --git a/zaqar/tests/etc/websocket_mongodb_keystone_auth.conf b/zaqar/tests/etc/websocket_mongodb_keystone_auth.conf new file mode 100644 index 000000000..64e98f0f3 --- /dev/null +++ b/zaqar/tests/etc/websocket_mongodb_keystone_auth.conf @@ -0,0 +1,20 @@ +[DEFAULT] +auth_strategy = keystone + +[drivers] + +# Transport driver to use (string value) +transport = websocket + +# Storage driver to use (string value) +message_store = mongodb + +[drivers:management_store:mongodb] + +# Mongodb Connection URI +uri = mongodb://127.0.0.1:27017 + +[drivers:message_store:mongodb] + +# Mongodb Connection URI +uri = mongodb://127.0.0.1:27017 diff --git a/zaqar/tests/unit/transport/websocket/v1_1/test_auth.py b/zaqar/tests/unit/transport/websocket/v1_1/test_auth.py new file mode 100644 index 000000000..7b8ab771e --- /dev/null +++ b/zaqar/tests/unit/transport/websocket/v1_1/test_auth.py @@ -0,0 +1,120 @@ +# Copyright (c) 2015 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import uuid + +from keystonemiddleware import auth_token +import mock + +from zaqar.tests.unit.transport.websocket import base +from zaqar.tests.unit.transport.websocket import utils as test_utils + + +class AuthTest(base.V1_1Base): + + config_file = "websocket_mongodb_keystone_auth.conf" + + def setUp(self): + super(AuthTest, self).setUp() + self.protocol = self.transport.factory() + + self.default_message_ttl = 3600 + + self.project_id = '7e55e1a7e' + self.headers = { + 'Client-ID': str(uuid.uuid4()), + 'X-Project-ID': self.project_id + } + auth_mock = mock.patch.object(auth_token.AuthProtocol, '__call__') + self.addCleanup(auth_mock.stop) + self.auth = auth_mock.start() + self.env = {'keystone.token_info': { + 'token': {'expires_at': '2035-08-05T15:16:33.603700+00:00'}}} + + def test_post(self): + headers = self.headers.copy() + headers['X-Auth-Token'] = 'mytoken1' + req = json.dumps({'action': 'authenticate', 'headers': headers}) + + msg_mock = mock.patch.object(self.protocol, 'sendMessage') + self.addCleanup(msg_mock.stop) + msg_mock = msg_mock.start() + self.protocol.onMessage(req, False) + + # Didn't send the response yet + self.assertEqual(0, msg_mock.call_count) + + self.assertEqual(1, self.auth.call_count) + responses = [] + self.protocol._auth_start(self.env, lambda x, y: responses.append(x)) + + self.assertEqual(1, len(responses)) + self.assertEqual('200 OK', responses[0]) + + def test_post_between_auth(self): + headers = self.headers.copy() + headers['X-Auth-Token'] = 'mytoken1' + req = json.dumps({'action': 'authenticate', 'headers': headers}) + + msg_mock = mock.patch.object(self.protocol, 'sendMessage') + self.addCleanup(msg_mock.stop) + msg_mock = msg_mock.start() + self.protocol.onMessage(req, False) + + req = test_utils.create_request("queue_list", {}, self.headers) + self.protocol.onMessage(req, False) + + self.assertEqual(1, msg_mock.call_count) + resp = json.loads(msg_mock.call_args[0][0]) + self.assertEqual(resp['headers']['status'], 403) + + def test_failed_auth(self): + msg_mock = mock.patch.object(self.protocol, 'sendMessage') + self.addCleanup(msg_mock.stop) + msg_mock = msg_mock.start() + self.protocol._auth_response('401 error', 'Failed') + self.assertEqual(1, msg_mock.call_count) + resp = json.loads(msg_mock.call_args[0][0]) + self.assertEqual(resp['headers']['status'], 401) + + def test_reauth(self): + headers = self.headers.copy() + headers['X-Auth-Token'] = 'mytoken1' + req = json.dumps({'action': 'authenticate', 'headers': headers}) + + msg_mock = mock.patch.object(self.protocol, 'sendMessage') + self.addCleanup(msg_mock.stop) + msg_mock = msg_mock.start() + self.protocol.onMessage(req, False) + + self.assertEqual(1, self.auth.call_count) + responses = [] + self.protocol._auth_start(self.env, lambda x, y: responses.append(x)) + + self.assertEqual(1, len(responses)) + handle = self.protocol._deauth_handle + self.assertIsNotNone(handle) + + headers = self.headers.copy() + headers['X-Auth-Token'] = 'mytoken2' + req = json.dumps({'action': 'authenticate', 'headers': headers}) + self.protocol.onMessage(req, False) + self.protocol._auth_start(self.env, lambda x, y: responses.append(x)) + + self.assertNotEqual(handle, self.protocol._deauth_handle) + self.assertEqual(2, len(responses)) + self.assertIn('cancelled', repr(handle)) + self.assertNotIn('cancelled', repr(self.protocol._deauth_handle)) diff --git a/zaqar/tests/unit/transport/websocket/v1_1/test_messages.py b/zaqar/tests/unit/transport/websocket/v1_1/test_messages.py index f7f3a2dd9..a3777ef79 100644 --- a/zaqar/tests/unit/transport/websocket/v1_1/test_messages.py +++ b/zaqar/tests/unit/transport/websocket/v1_1/test_messages.py @@ -583,5 +583,5 @@ class MessagesBaseTest(base.V1_1Base): self.assertIn('error', response['body']) self.assertEqual({'status': 400}, response['headers']) self.assertEqual( - {'action': None, 'api': None, 'body': None, 'headers': {}}, + {'action': None, 'api': 'v1.1', 'body': {}, 'headers': {}}, response['request']) diff --git a/zaqar/transport/websocket/driver.py b/zaqar/transport/websocket/driver.py index 3a3bb73ce..531732086 100644 --- a/zaqar/transport/websocket/driver.py +++ b/zaqar/transport/websocket/driver.py @@ -23,6 +23,8 @@ except ImportError: from zaqar.common import decorators from zaqar.i18n import _ +from zaqar.transport import auth +from zaqar.transport import base from zaqar.transport.websocket import factory @@ -48,22 +50,33 @@ def _config_options(): return [(_WS_GROUP, _WS_OPTIONS)] -class Driver(object): +class Driver(base.DriverBase): def __init__(self, conf, api, cache): - self._conf = conf + super(Driver, self).__init__(conf, None, None, None) self._api = api self._cache = cache self._conf.register_opts(_WS_OPTIONS, group=_WS_GROUP) self._ws_conf = self._conf[_WS_GROUP] + if self._conf.auth_strategy: + auth_strategy = auth.strategy(self._conf.auth_strategy) + self._auth_strategy = lambda app: auth_strategy.install( + app, self._conf) + else: + self._auth_strategy = None + @decorators.lazy_property(write=False) def factory(self): uri = 'ws://' + self._ws_conf.bind + ':' + str(self._ws_conf.port) return factory.ProtocolFactory( - uri, debug=self._ws_conf.debug, handler=self._api, - external_port=self._ws_conf.external_port) + uri, + debug=self._ws_conf.debug, + handler=self._api, + external_port=self._ws_conf.external_port, + auth_strategy=self._auth_strategy, + loop=asyncio.get_event_loop()) def listen(self): """Self-host using 'bind' and 'port' from the WS config group.""" diff --git a/zaqar/transport/websocket/factory.py b/zaqar/transport/websocket/factory.py index 50f82b6dd..4cf5812a9 100644 --- a/zaqar/transport/websocket/factory.py +++ b/zaqar/transport/websocket/factory.py @@ -22,12 +22,15 @@ class ProtocolFactory(websocket.WebSocketServerFactory): protocol = protocol.MessagingProtocol - def __init__(self, uri, debug, handler, external_port): + def __init__(self, uri, debug, handler, external_port, auth_strategy, + loop): websocket.WebSocketServerFactory.__init__( self, url=uri, debug=debug, externalPort=external_port) self._handler = handler + self._auth_strategy = auth_strategy + self._loop = loop def __call__(self): - proto = self.protocol(self._handler) + proto = self.protocol(self._handler, self._auth_strategy, self._loop) proto.factory = self return proto diff --git a/zaqar/transport/websocket/protocol.py b/zaqar/transport/websocket/protocol.py index 69218ce70..06b237a80 100644 --- a/zaqar/transport/websocket/protocol.py +++ b/zaqar/transport/websocket/protocol.py @@ -13,24 +13,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -from autobahn.asyncio import websocket -from oslo_log import log as logging - +import datetime import json -from zaqar.api.v1_1 import request as schema_validator -from zaqar.common.api import request -from zaqar.common.api import response -from zaqar.common import errors +from autobahn.asyncio import websocket +from oslo_log import log as logging +from oslo_utils import timeutils +import pytz LOG = logging.getLogger(__name__) +_EPOCH = datetime.datetime(1970, 1, 1, tzinfo=pytz.UTC) + class MessagingProtocol(websocket.WebSocketServerProtocol): - def __init__(self, handler): + _fake_env = { + 'REQUEST_METHOD': 'POST', + 'SERVER_NAME': 'zaqar', + 'SERVER_PORT': 80, + 'SERVER_PROTOCOL': 'HTTP/1.1', + 'PATH_INFO': '/', + 'SCRIPT_NAME': '', + 'wsgi.url_scheme': 'http' + } + + def __init__(self, handler, auth_strategy, loop): websocket.WebSocketServerProtocol.__init__(self) self._handler = handler + self._auth_strategy = auth_strategy + self._loop = loop + self._authentified = False + self._auth_app = None + self._deauth_handle = None def onConnect(self, request): print("Client connecting: {0}".format(request.peer)) @@ -43,61 +58,75 @@ class MessagingProtocol(websocket.WebSocketServerProtocol): # TODO(vkmc): Binary support will be added in the next cycle # For now, we are returning an invalid request response print("Binary message received: {0} bytes".format(len(payload))) - req = self._dummy_request() body = {'error': 'Schema validation failed.'} - headers = {'status': 400} - resp = response.Response(req, body, headers) - return resp - else: - try: - print("Text message received: {0}".format(payload)) - pl = json.loads(payload) - req = self._create_request(pl) - resp = (self._validate_request(pl, req) or - self._handler.process_request(req)) - except ValueError as ex: - LOG.exception(ex) - req = self._dummy_request() - body = {'error': str(ex)} - headers = {'status': 400} - resp = response.Response(req, body, headers) + resp = self._handler.create_response(400, body) + return self._send_response(resp) + try: + print("Text message received: {0}".format(payload)) + payload = json.loads(payload) + except ValueError as ex: + LOG.exception(ex) + body = {'error': str(ex)} + resp = self._handler.create_response(400, body) + return self._send_response(resp) - resp_json = json.dumps(resp.get_response()) - self.sendMessage(resp_json, isBinary) + req = self._handler.create_request(payload) + resp = self._handler.validate_request(payload, req) + if resp is None: + if self._auth_strategy and not self._authentified: + if self._auth_app or payload.get('action') != 'authenticate': + body = {'error': 'Not authentified.'} + resp = self._handler.create_response(403, body) + else: + return self._authenticate(payload) + elif payload.get('action') == 'authenticate': + return self._authenticate(payload) + else: + resp = self._handler.process_request(req) + return self._send_response(resp) def onClose(self, wasClean, code, reason): print("WebSocket connection closed: {0}".format(reason)) - @staticmethod - def _create_request(pl): - action = pl.get('action') - body = pl.get('body', {}) - headers = pl.get('headers') + def _authenticate(self, payload): + self._auth_app = self._auth_strategy(self._auth_start) + env = self._fake_env.copy() + env.update( + (self._header_to_env_var(key), value) + for key, value in payload.get('headers').items()) + self._auth_app(env, self._auth_response) - return request.Request(action=action, body=body, - headers=headers, api="v1.1") + def _auth_start(self, env, start_response): + self._authentified = True + self._auth_app = None + expire = env['keystone.token_info']['token']['expires_at'] + expire_time = timeutils.parse_isotime(expire) + timestamp = (expire_time - _EPOCH).total_seconds() + if self._deauth_handle is not None: + self._deauth_handle.cancel() + self._deauth_handle = self._loop.call_at( + timestamp, self._deauthenticate) - @staticmethod - def _validate_request(pl, req): - try: - action = pl.get('action') - validator = schema_validator.RequestSchema() - is_valid = validator.validate(action=action, body=pl) - except errors.InvalidAction as ex: - body = {'error': str(ex)} - headers = {'status': 400} - resp = response.Response(req, body, headers) - return resp + start_response('200 OK', []) + + def _deauthenticate(self): + self._authentified = False + self.sendClose(403, 'Authentication expired.') + + def _auth_response(self, status, message): + code = int(status.split()[0]) + if code != 200: + body = {'error': 'Authentication failed.'} + resp = self._handler.create_response(code, body) + self._send_response(resp) else: - if not is_valid: - body = {'error': 'Schema validation failed.'} - headers = {'status': 400} - resp = response.Response(req, body, headers) - return resp + body = {'message': 'Authentified.'} + resp = self._handler.create_response(200, body) + self._send_response(resp) - return None + def _header_to_env_var(self, key): + return 'HTTP_%s' % key.replace('-', '_').upper() - @staticmethod - def _dummy_request(): - action = None - return request.Request(action) + def _send_response(self, resp): + resp_json = json.dumps(resp.get_response()) + self.sendMessage(resp_json, False)