Merge "Authentication for websocket"
This commit is contained in:
commit
a777714075
@ -13,6 +13,11 @@
|
||||
# the License.
|
||||
|
||||
from zaqar.api.v1_1 import endpoints
|
||||
from zaqar.api.v1_1 import request as schema_validator
|
||||
|
||||
from zaqar.common.api import request
|
||||
from zaqar.common.api import response
|
||||
from zaqar.common import errors
|
||||
|
||||
|
||||
class Handler(object):
|
||||
@ -28,3 +33,40 @@ class Handler(object):
|
||||
def process_request(self, req):
|
||||
# FIXME(vkmc): Control API version
|
||||
return getattr(self.v1_1_endpoints, req._action)(req)
|
||||
|
||||
@staticmethod
|
||||
def validate_request(payload, req):
|
||||
"""Validate a request and its payload against a schema.
|
||||
|
||||
:return: a Response object if validation failed, None otherwise.
|
||||
"""
|
||||
try:
|
||||
action = payload.get('action')
|
||||
validator = schema_validator.RequestSchema()
|
||||
is_valid = validator.validate(action=action, body=payload)
|
||||
except errors.InvalidAction as ex:
|
||||
body = {'error': str(ex)}
|
||||
headers = {'status': 400}
|
||||
return response.Response(req, body, headers)
|
||||
else:
|
||||
if not is_valid:
|
||||
body = {'error': 'Schema validation failed.'}
|
||||
headers = {'status': 400}
|
||||
return response.Response(req, body, headers)
|
||||
|
||||
def create_response(self, code, body, req=None):
|
||||
if req is None:
|
||||
req = self.create_request()
|
||||
headers = {'status': code}
|
||||
return response.Response(req, body, headers)
|
||||
|
||||
@staticmethod
|
||||
def create_request(payload=None):
|
||||
if payload is None:
|
||||
payload = {}
|
||||
action = payload.get('action')
|
||||
body = payload.get('body', {})
|
||||
headers = payload.get('headers')
|
||||
|
||||
return request.Request(action=action, body=body,
|
||||
headers=headers, api="v1.1")
|
||||
|
@ -65,6 +65,17 @@ class RequestSchema(api.Api):
|
||||
'required': ['action', 'headers'],
|
||||
'admin': True,
|
||||
},
|
||||
'authenticate': {
|
||||
'properties': {
|
||||
'action': {'enum': ['authenticate']},
|
||||
'headers': {
|
||||
'type': 'object',
|
||||
'properties': headers,
|
||||
'required': ['X-Project-ID', 'X-Auth-Token']
|
||||
}
|
||||
},
|
||||
'required': ['action', 'headers'],
|
||||
},
|
||||
|
||||
# Queues
|
||||
'queue_list': {
|
||||
|
20
zaqar/tests/etc/websocket_mongodb_keystone_auth.conf
Normal file
20
zaqar/tests/etc/websocket_mongodb_keystone_auth.conf
Normal file
@ -0,0 +1,20 @@
|
||||
[DEFAULT]
|
||||
auth_strategy = keystone
|
||||
|
||||
[drivers]
|
||||
|
||||
# Transport driver to use (string value)
|
||||
transport = websocket
|
||||
|
||||
# Storage driver to use (string value)
|
||||
message_store = mongodb
|
||||
|
||||
[drivers:management_store:mongodb]
|
||||
|
||||
# Mongodb Connection URI
|
||||
uri = mongodb://127.0.0.1:27017
|
||||
|
||||
[drivers:message_store:mongodb]
|
||||
|
||||
# Mongodb Connection URI
|
||||
uri = mongodb://127.0.0.1:27017
|
120
zaqar/tests/unit/transport/websocket/v1_1/test_auth.py
Normal file
120
zaqar/tests/unit/transport/websocket/v1_1/test_auth.py
Normal file
@ -0,0 +1,120 @@
|
||||
# Copyright (c) 2015 Red Hat, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from keystonemiddleware import auth_token
|
||||
import mock
|
||||
|
||||
from zaqar.tests.unit.transport.websocket import base
|
||||
from zaqar.tests.unit.transport.websocket import utils as test_utils
|
||||
|
||||
|
||||
class AuthTest(base.V1_1Base):
|
||||
|
||||
config_file = "websocket_mongodb_keystone_auth.conf"
|
||||
|
||||
def setUp(self):
|
||||
super(AuthTest, self).setUp()
|
||||
self.protocol = self.transport.factory()
|
||||
|
||||
self.default_message_ttl = 3600
|
||||
|
||||
self.project_id = '7e55e1a7e'
|
||||
self.headers = {
|
||||
'Client-ID': str(uuid.uuid4()),
|
||||
'X-Project-ID': self.project_id
|
||||
}
|
||||
auth_mock = mock.patch.object(auth_token.AuthProtocol, '__call__')
|
||||
self.addCleanup(auth_mock.stop)
|
||||
self.auth = auth_mock.start()
|
||||
self.env = {'keystone.token_info': {
|
||||
'token': {'expires_at': '2035-08-05T15:16:33.603700+00:00'}}}
|
||||
|
||||
def test_post(self):
|
||||
headers = self.headers.copy()
|
||||
headers['X-Auth-Token'] = 'mytoken1'
|
||||
req = json.dumps({'action': 'authenticate', 'headers': headers})
|
||||
|
||||
msg_mock = mock.patch.object(self.protocol, 'sendMessage')
|
||||
self.addCleanup(msg_mock.stop)
|
||||
msg_mock = msg_mock.start()
|
||||
self.protocol.onMessage(req, False)
|
||||
|
||||
# Didn't send the response yet
|
||||
self.assertEqual(0, msg_mock.call_count)
|
||||
|
||||
self.assertEqual(1, self.auth.call_count)
|
||||
responses = []
|
||||
self.protocol._auth_start(self.env, lambda x, y: responses.append(x))
|
||||
|
||||
self.assertEqual(1, len(responses))
|
||||
self.assertEqual('200 OK', responses[0])
|
||||
|
||||
def test_post_between_auth(self):
|
||||
headers = self.headers.copy()
|
||||
headers['X-Auth-Token'] = 'mytoken1'
|
||||
req = json.dumps({'action': 'authenticate', 'headers': headers})
|
||||
|
||||
msg_mock = mock.patch.object(self.protocol, 'sendMessage')
|
||||
self.addCleanup(msg_mock.stop)
|
||||
msg_mock = msg_mock.start()
|
||||
self.protocol.onMessage(req, False)
|
||||
|
||||
req = test_utils.create_request("queue_list", {}, self.headers)
|
||||
self.protocol.onMessage(req, False)
|
||||
|
||||
self.assertEqual(1, msg_mock.call_count)
|
||||
resp = json.loads(msg_mock.call_args[0][0])
|
||||
self.assertEqual(resp['headers']['status'], 403)
|
||||
|
||||
def test_failed_auth(self):
|
||||
msg_mock = mock.patch.object(self.protocol, 'sendMessage')
|
||||
self.addCleanup(msg_mock.stop)
|
||||
msg_mock = msg_mock.start()
|
||||
self.protocol._auth_response('401 error', 'Failed')
|
||||
self.assertEqual(1, msg_mock.call_count)
|
||||
resp = json.loads(msg_mock.call_args[0][0])
|
||||
self.assertEqual(resp['headers']['status'], 401)
|
||||
|
||||
def test_reauth(self):
|
||||
headers = self.headers.copy()
|
||||
headers['X-Auth-Token'] = 'mytoken1'
|
||||
req = json.dumps({'action': 'authenticate', 'headers': headers})
|
||||
|
||||
msg_mock = mock.patch.object(self.protocol, 'sendMessage')
|
||||
self.addCleanup(msg_mock.stop)
|
||||
msg_mock = msg_mock.start()
|
||||
self.protocol.onMessage(req, False)
|
||||
|
||||
self.assertEqual(1, self.auth.call_count)
|
||||
responses = []
|
||||
self.protocol._auth_start(self.env, lambda x, y: responses.append(x))
|
||||
|
||||
self.assertEqual(1, len(responses))
|
||||
handle = self.protocol._deauth_handle
|
||||
self.assertIsNotNone(handle)
|
||||
|
||||
headers = self.headers.copy()
|
||||
headers['X-Auth-Token'] = 'mytoken2'
|
||||
req = json.dumps({'action': 'authenticate', 'headers': headers})
|
||||
self.protocol.onMessage(req, False)
|
||||
self.protocol._auth_start(self.env, lambda x, y: responses.append(x))
|
||||
|
||||
self.assertNotEqual(handle, self.protocol._deauth_handle)
|
||||
self.assertEqual(2, len(responses))
|
||||
self.assertIn('cancelled', repr(handle))
|
||||
self.assertNotIn('cancelled', repr(self.protocol._deauth_handle))
|
@ -583,5 +583,5 @@ class MessagesBaseTest(base.V1_1Base):
|
||||
self.assertIn('error', response['body'])
|
||||
self.assertEqual({'status': 400}, response['headers'])
|
||||
self.assertEqual(
|
||||
{'action': None, 'api': None, 'body': None, 'headers': {}},
|
||||
{'action': None, 'api': 'v1.1', 'body': {}, 'headers': {}},
|
||||
response['request'])
|
||||
|
@ -23,6 +23,8 @@ except ImportError:
|
||||
|
||||
from zaqar.common import decorators
|
||||
from zaqar.i18n import _
|
||||
from zaqar.transport import auth
|
||||
from zaqar.transport import base
|
||||
from zaqar.transport.websocket import factory
|
||||
|
||||
|
||||
@ -48,22 +50,33 @@ def _config_options():
|
||||
return [(_WS_GROUP, _WS_OPTIONS)]
|
||||
|
||||
|
||||
class Driver(object):
|
||||
class Driver(base.DriverBase):
|
||||
|
||||
def __init__(self, conf, api, cache):
|
||||
self._conf = conf
|
||||
super(Driver, self).__init__(conf, None, None, None)
|
||||
self._api = api
|
||||
self._cache = cache
|
||||
|
||||
self._conf.register_opts(_WS_OPTIONS, group=_WS_GROUP)
|
||||
self._ws_conf = self._conf[_WS_GROUP]
|
||||
|
||||
if self._conf.auth_strategy:
|
||||
auth_strategy = auth.strategy(self._conf.auth_strategy)
|
||||
self._auth_strategy = lambda app: auth_strategy.install(
|
||||
app, self._conf)
|
||||
else:
|
||||
self._auth_strategy = None
|
||||
|
||||
@decorators.lazy_property(write=False)
|
||||
def factory(self):
|
||||
uri = 'ws://' + self._ws_conf.bind + ':' + str(self._ws_conf.port)
|
||||
return factory.ProtocolFactory(
|
||||
uri, debug=self._ws_conf.debug, handler=self._api,
|
||||
external_port=self._ws_conf.external_port)
|
||||
uri,
|
||||
debug=self._ws_conf.debug,
|
||||
handler=self._api,
|
||||
external_port=self._ws_conf.external_port,
|
||||
auth_strategy=self._auth_strategy,
|
||||
loop=asyncio.get_event_loop())
|
||||
|
||||
def listen(self):
|
||||
"""Self-host using 'bind' and 'port' from the WS config group."""
|
||||
|
@ -22,12 +22,15 @@ class ProtocolFactory(websocket.WebSocketServerFactory):
|
||||
|
||||
protocol = protocol.MessagingProtocol
|
||||
|
||||
def __init__(self, uri, debug, handler, external_port):
|
||||
def __init__(self, uri, debug, handler, external_port, auth_strategy,
|
||||
loop):
|
||||
websocket.WebSocketServerFactory.__init__(
|
||||
self, url=uri, debug=debug, externalPort=external_port)
|
||||
self._handler = handler
|
||||
self._auth_strategy = auth_strategy
|
||||
self._loop = loop
|
||||
|
||||
def __call__(self):
|
||||
proto = self.protocol(self._handler)
|
||||
proto = self.protocol(self._handler, self._auth_strategy, self._loop)
|
||||
proto.factory = self
|
||||
return proto
|
||||
|
@ -13,24 +13,39 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from autobahn.asyncio import websocket
|
||||
from oslo_log import log as logging
|
||||
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from zaqar.api.v1_1 import request as schema_validator
|
||||
from zaqar.common.api import request
|
||||
from zaqar.common.api import response
|
||||
from zaqar.common import errors
|
||||
from autobahn.asyncio import websocket
|
||||
from oslo_log import log as logging
|
||||
from oslo_utils import timeutils
|
||||
import pytz
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
_EPOCH = datetime.datetime(1970, 1, 1, tzinfo=pytz.UTC)
|
||||
|
||||
|
||||
class MessagingProtocol(websocket.WebSocketServerProtocol):
|
||||
|
||||
def __init__(self, handler):
|
||||
_fake_env = {
|
||||
'REQUEST_METHOD': 'POST',
|
||||
'SERVER_NAME': 'zaqar',
|
||||
'SERVER_PORT': 80,
|
||||
'SERVER_PROTOCOL': 'HTTP/1.1',
|
||||
'PATH_INFO': '/',
|
||||
'SCRIPT_NAME': '',
|
||||
'wsgi.url_scheme': 'http'
|
||||
}
|
||||
|
||||
def __init__(self, handler, auth_strategy, loop):
|
||||
websocket.WebSocketServerProtocol.__init__(self)
|
||||
self._handler = handler
|
||||
self._auth_strategy = auth_strategy
|
||||
self._loop = loop
|
||||
self._authentified = False
|
||||
self._auth_app = None
|
||||
self._deauth_handle = None
|
||||
|
||||
def onConnect(self, request):
|
||||
print("Client connecting: {0}".format(request.peer))
|
||||
@ -43,61 +58,75 @@ class MessagingProtocol(websocket.WebSocketServerProtocol):
|
||||
# TODO(vkmc): Binary support will be added in the next cycle
|
||||
# For now, we are returning an invalid request response
|
||||
print("Binary message received: {0} bytes".format(len(payload)))
|
||||
req = self._dummy_request()
|
||||
body = {'error': 'Schema validation failed.'}
|
||||
headers = {'status': 400}
|
||||
resp = response.Response(req, body, headers)
|
||||
return resp
|
||||
else:
|
||||
try:
|
||||
print("Text message received: {0}".format(payload))
|
||||
pl = json.loads(payload)
|
||||
req = self._create_request(pl)
|
||||
resp = (self._validate_request(pl, req) or
|
||||
self._handler.process_request(req))
|
||||
except ValueError as ex:
|
||||
LOG.exception(ex)
|
||||
req = self._dummy_request()
|
||||
body = {'error': str(ex)}
|
||||
headers = {'status': 400}
|
||||
resp = response.Response(req, body, headers)
|
||||
resp = self._handler.create_response(400, body)
|
||||
return self._send_response(resp)
|
||||
try:
|
||||
print("Text message received: {0}".format(payload))
|
||||
payload = json.loads(payload)
|
||||
except ValueError as ex:
|
||||
LOG.exception(ex)
|
||||
body = {'error': str(ex)}
|
||||
resp = self._handler.create_response(400, body)
|
||||
return self._send_response(resp)
|
||||
|
||||
resp_json = json.dumps(resp.get_response())
|
||||
self.sendMessage(resp_json, isBinary)
|
||||
req = self._handler.create_request(payload)
|
||||
resp = self._handler.validate_request(payload, req)
|
||||
if resp is None:
|
||||
if self._auth_strategy and not self._authentified:
|
||||
if self._auth_app or payload.get('action') != 'authenticate':
|
||||
body = {'error': 'Not authentified.'}
|
||||
resp = self._handler.create_response(403, body)
|
||||
else:
|
||||
return self._authenticate(payload)
|
||||
elif payload.get('action') == 'authenticate':
|
||||
return self._authenticate(payload)
|
||||
else:
|
||||
resp = self._handler.process_request(req)
|
||||
return self._send_response(resp)
|
||||
|
||||
def onClose(self, wasClean, code, reason):
|
||||
print("WebSocket connection closed: {0}".format(reason))
|
||||
|
||||
@staticmethod
|
||||
def _create_request(pl):
|
||||
action = pl.get('action')
|
||||
body = pl.get('body', {})
|
||||
headers = pl.get('headers')
|
||||
def _authenticate(self, payload):
|
||||
self._auth_app = self._auth_strategy(self._auth_start)
|
||||
env = self._fake_env.copy()
|
||||
env.update(
|
||||
(self._header_to_env_var(key), value)
|
||||
for key, value in payload.get('headers').items())
|
||||
self._auth_app(env, self._auth_response)
|
||||
|
||||
return request.Request(action=action, body=body,
|
||||
headers=headers, api="v1.1")
|
||||
def _auth_start(self, env, start_response):
|
||||
self._authentified = True
|
||||
self._auth_app = None
|
||||
expire = env['keystone.token_info']['token']['expires_at']
|
||||
expire_time = timeutils.parse_isotime(expire)
|
||||
timestamp = (expire_time - _EPOCH).total_seconds()
|
||||
if self._deauth_handle is not None:
|
||||
self._deauth_handle.cancel()
|
||||
self._deauth_handle = self._loop.call_at(
|
||||
timestamp, self._deauthenticate)
|
||||
|
||||
@staticmethod
|
||||
def _validate_request(pl, req):
|
||||
try:
|
||||
action = pl.get('action')
|
||||
validator = schema_validator.RequestSchema()
|
||||
is_valid = validator.validate(action=action, body=pl)
|
||||
except errors.InvalidAction as ex:
|
||||
body = {'error': str(ex)}
|
||||
headers = {'status': 400}
|
||||
resp = response.Response(req, body, headers)
|
||||
return resp
|
||||
start_response('200 OK', [])
|
||||
|
||||
def _deauthenticate(self):
|
||||
self._authentified = False
|
||||
self.sendClose(403, 'Authentication expired.')
|
||||
|
||||
def _auth_response(self, status, message):
|
||||
code = int(status.split()[0])
|
||||
if code != 200:
|
||||
body = {'error': 'Authentication failed.'}
|
||||
resp = self._handler.create_response(code, body)
|
||||
self._send_response(resp)
|
||||
else:
|
||||
if not is_valid:
|
||||
body = {'error': 'Schema validation failed.'}
|
||||
headers = {'status': 400}
|
||||
resp = response.Response(req, body, headers)
|
||||
return resp
|
||||
body = {'message': 'Authentified.'}
|
||||
resp = self._handler.create_response(200, body)
|
||||
self._send_response(resp)
|
||||
|
||||
return None
|
||||
def _header_to_env_var(self, key):
|
||||
return 'HTTP_%s' % key.replace('-', '_').upper()
|
||||
|
||||
@staticmethod
|
||||
def _dummy_request():
|
||||
action = None
|
||||
return request.Request(action)
|
||||
def _send_response(self, resp):
|
||||
resp_json = json.dumps(resp.get_response())
|
||||
self.sendMessage(resp_json, False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user