Merge "Authentication for websocket"

This commit is contained in:
Jenkins 2015-08-14 12:22:03 +00:00 committed by Gerrit Code Review
commit a777714075
8 changed files with 300 additions and 62 deletions

View File

@ -13,6 +13,11 @@
# the License.
from zaqar.api.v1_1 import endpoints
from zaqar.api.v1_1 import request as schema_validator
from zaqar.common.api import request
from zaqar.common.api import response
from zaqar.common import errors
class Handler(object):
@ -28,3 +33,40 @@ class Handler(object):
def process_request(self, req):
# FIXME(vkmc): Control API version
return getattr(self.v1_1_endpoints, req._action)(req)
@staticmethod
def validate_request(payload, req):
"""Validate a request and its payload against a schema.
:return: a Response object if validation failed, None otherwise.
"""
try:
action = payload.get('action')
validator = schema_validator.RequestSchema()
is_valid = validator.validate(action=action, body=payload)
except errors.InvalidAction as ex:
body = {'error': str(ex)}
headers = {'status': 400}
return response.Response(req, body, headers)
else:
if not is_valid:
body = {'error': 'Schema validation failed.'}
headers = {'status': 400}
return response.Response(req, body, headers)
def create_response(self, code, body, req=None):
if req is None:
req = self.create_request()
headers = {'status': code}
return response.Response(req, body, headers)
@staticmethod
def create_request(payload=None):
if payload is None:
payload = {}
action = payload.get('action')
body = payload.get('body', {})
headers = payload.get('headers')
return request.Request(action=action, body=body,
headers=headers, api="v1.1")

View File

@ -65,6 +65,17 @@ class RequestSchema(api.Api):
'required': ['action', 'headers'],
'admin': True,
},
'authenticate': {
'properties': {
'action': {'enum': ['authenticate']},
'headers': {
'type': 'object',
'properties': headers,
'required': ['X-Project-ID', 'X-Auth-Token']
}
},
'required': ['action', 'headers'],
},
# Queues
'queue_list': {

View File

@ -0,0 +1,20 @@
[DEFAULT]
auth_strategy = keystone
[drivers]
# Transport driver to use (string value)
transport = websocket
# Storage driver to use (string value)
message_store = mongodb
[drivers:management_store:mongodb]
# Mongodb Connection URI
uri = mongodb://127.0.0.1:27017
[drivers:message_store:mongodb]
# Mongodb Connection URI
uri = mongodb://127.0.0.1:27017

View File

@ -0,0 +1,120 @@
# Copyright (c) 2015 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import uuid
from keystonemiddleware import auth_token
import mock
from zaqar.tests.unit.transport.websocket import base
from zaqar.tests.unit.transport.websocket import utils as test_utils
class AuthTest(base.V1_1Base):
config_file = "websocket_mongodb_keystone_auth.conf"
def setUp(self):
super(AuthTest, self).setUp()
self.protocol = self.transport.factory()
self.default_message_ttl = 3600
self.project_id = '7e55e1a7e'
self.headers = {
'Client-ID': str(uuid.uuid4()),
'X-Project-ID': self.project_id
}
auth_mock = mock.patch.object(auth_token.AuthProtocol, '__call__')
self.addCleanup(auth_mock.stop)
self.auth = auth_mock.start()
self.env = {'keystone.token_info': {
'token': {'expires_at': '2035-08-05T15:16:33.603700+00:00'}}}
def test_post(self):
headers = self.headers.copy()
headers['X-Auth-Token'] = 'mytoken1'
req = json.dumps({'action': 'authenticate', 'headers': headers})
msg_mock = mock.patch.object(self.protocol, 'sendMessage')
self.addCleanup(msg_mock.stop)
msg_mock = msg_mock.start()
self.protocol.onMessage(req, False)
# Didn't send the response yet
self.assertEqual(0, msg_mock.call_count)
self.assertEqual(1, self.auth.call_count)
responses = []
self.protocol._auth_start(self.env, lambda x, y: responses.append(x))
self.assertEqual(1, len(responses))
self.assertEqual('200 OK', responses[0])
def test_post_between_auth(self):
headers = self.headers.copy()
headers['X-Auth-Token'] = 'mytoken1'
req = json.dumps({'action': 'authenticate', 'headers': headers})
msg_mock = mock.patch.object(self.protocol, 'sendMessage')
self.addCleanup(msg_mock.stop)
msg_mock = msg_mock.start()
self.protocol.onMessage(req, False)
req = test_utils.create_request("queue_list", {}, self.headers)
self.protocol.onMessage(req, False)
self.assertEqual(1, msg_mock.call_count)
resp = json.loads(msg_mock.call_args[0][0])
self.assertEqual(resp['headers']['status'], 403)
def test_failed_auth(self):
msg_mock = mock.patch.object(self.protocol, 'sendMessage')
self.addCleanup(msg_mock.stop)
msg_mock = msg_mock.start()
self.protocol._auth_response('401 error', 'Failed')
self.assertEqual(1, msg_mock.call_count)
resp = json.loads(msg_mock.call_args[0][0])
self.assertEqual(resp['headers']['status'], 401)
def test_reauth(self):
headers = self.headers.copy()
headers['X-Auth-Token'] = 'mytoken1'
req = json.dumps({'action': 'authenticate', 'headers': headers})
msg_mock = mock.patch.object(self.protocol, 'sendMessage')
self.addCleanup(msg_mock.stop)
msg_mock = msg_mock.start()
self.protocol.onMessage(req, False)
self.assertEqual(1, self.auth.call_count)
responses = []
self.protocol._auth_start(self.env, lambda x, y: responses.append(x))
self.assertEqual(1, len(responses))
handle = self.protocol._deauth_handle
self.assertIsNotNone(handle)
headers = self.headers.copy()
headers['X-Auth-Token'] = 'mytoken2'
req = json.dumps({'action': 'authenticate', 'headers': headers})
self.protocol.onMessage(req, False)
self.protocol._auth_start(self.env, lambda x, y: responses.append(x))
self.assertNotEqual(handle, self.protocol._deauth_handle)
self.assertEqual(2, len(responses))
self.assertIn('cancelled', repr(handle))
self.assertNotIn('cancelled', repr(self.protocol._deauth_handle))

View File

@ -583,5 +583,5 @@ class MessagesBaseTest(base.V1_1Base):
self.assertIn('error', response['body'])
self.assertEqual({'status': 400}, response['headers'])
self.assertEqual(
{'action': None, 'api': None, 'body': None, 'headers': {}},
{'action': None, 'api': 'v1.1', 'body': {}, 'headers': {}},
response['request'])

View File

@ -23,6 +23,8 @@ except ImportError:
from zaqar.common import decorators
from zaqar.i18n import _
from zaqar.transport import auth
from zaqar.transport import base
from zaqar.transport.websocket import factory
@ -48,22 +50,33 @@ def _config_options():
return [(_WS_GROUP, _WS_OPTIONS)]
class Driver(object):
class Driver(base.DriverBase):
def __init__(self, conf, api, cache):
self._conf = conf
super(Driver, self).__init__(conf, None, None, None)
self._api = api
self._cache = cache
self._conf.register_opts(_WS_OPTIONS, group=_WS_GROUP)
self._ws_conf = self._conf[_WS_GROUP]
if self._conf.auth_strategy:
auth_strategy = auth.strategy(self._conf.auth_strategy)
self._auth_strategy = lambda app: auth_strategy.install(
app, self._conf)
else:
self._auth_strategy = None
@decorators.lazy_property(write=False)
def factory(self):
uri = 'ws://' + self._ws_conf.bind + ':' + str(self._ws_conf.port)
return factory.ProtocolFactory(
uri, debug=self._ws_conf.debug, handler=self._api,
external_port=self._ws_conf.external_port)
uri,
debug=self._ws_conf.debug,
handler=self._api,
external_port=self._ws_conf.external_port,
auth_strategy=self._auth_strategy,
loop=asyncio.get_event_loop())
def listen(self):
"""Self-host using 'bind' and 'port' from the WS config group."""

View File

@ -22,12 +22,15 @@ class ProtocolFactory(websocket.WebSocketServerFactory):
protocol = protocol.MessagingProtocol
def __init__(self, uri, debug, handler, external_port):
def __init__(self, uri, debug, handler, external_port, auth_strategy,
loop):
websocket.WebSocketServerFactory.__init__(
self, url=uri, debug=debug, externalPort=external_port)
self._handler = handler
self._auth_strategy = auth_strategy
self._loop = loop
def __call__(self):
proto = self.protocol(self._handler)
proto = self.protocol(self._handler, self._auth_strategy, self._loop)
proto.factory = self
return proto

View File

@ -13,24 +13,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from autobahn.asyncio import websocket
from oslo_log import log as logging
import datetime
import json
from zaqar.api.v1_1 import request as schema_validator
from zaqar.common.api import request
from zaqar.common.api import response
from zaqar.common import errors
from autobahn.asyncio import websocket
from oslo_log import log as logging
from oslo_utils import timeutils
import pytz
LOG = logging.getLogger(__name__)
_EPOCH = datetime.datetime(1970, 1, 1, tzinfo=pytz.UTC)
class MessagingProtocol(websocket.WebSocketServerProtocol):
def __init__(self, handler):
_fake_env = {
'REQUEST_METHOD': 'POST',
'SERVER_NAME': 'zaqar',
'SERVER_PORT': 80,
'SERVER_PROTOCOL': 'HTTP/1.1',
'PATH_INFO': '/',
'SCRIPT_NAME': '',
'wsgi.url_scheme': 'http'
}
def __init__(self, handler, auth_strategy, loop):
websocket.WebSocketServerProtocol.__init__(self)
self._handler = handler
self._auth_strategy = auth_strategy
self._loop = loop
self._authentified = False
self._auth_app = None
self._deauth_handle = None
def onConnect(self, request):
print("Client connecting: {0}".format(request.peer))
@ -43,61 +58,75 @@ class MessagingProtocol(websocket.WebSocketServerProtocol):
# TODO(vkmc): Binary support will be added in the next cycle
# For now, we are returning an invalid request response
print("Binary message received: {0} bytes".format(len(payload)))
req = self._dummy_request()
body = {'error': 'Schema validation failed.'}
headers = {'status': 400}
resp = response.Response(req, body, headers)
return resp
else:
try:
print("Text message received: {0}".format(payload))
pl = json.loads(payload)
req = self._create_request(pl)
resp = (self._validate_request(pl, req) or
self._handler.process_request(req))
except ValueError as ex:
LOG.exception(ex)
req = self._dummy_request()
body = {'error': str(ex)}
headers = {'status': 400}
resp = response.Response(req, body, headers)
resp = self._handler.create_response(400, body)
return self._send_response(resp)
try:
print("Text message received: {0}".format(payload))
payload = json.loads(payload)
except ValueError as ex:
LOG.exception(ex)
body = {'error': str(ex)}
resp = self._handler.create_response(400, body)
return self._send_response(resp)
resp_json = json.dumps(resp.get_response())
self.sendMessage(resp_json, isBinary)
req = self._handler.create_request(payload)
resp = self._handler.validate_request(payload, req)
if resp is None:
if self._auth_strategy and not self._authentified:
if self._auth_app or payload.get('action') != 'authenticate':
body = {'error': 'Not authentified.'}
resp = self._handler.create_response(403, body)
else:
return self._authenticate(payload)
elif payload.get('action') == 'authenticate':
return self._authenticate(payload)
else:
resp = self._handler.process_request(req)
return self._send_response(resp)
def onClose(self, wasClean, code, reason):
print("WebSocket connection closed: {0}".format(reason))
@staticmethod
def _create_request(pl):
action = pl.get('action')
body = pl.get('body', {})
headers = pl.get('headers')
def _authenticate(self, payload):
self._auth_app = self._auth_strategy(self._auth_start)
env = self._fake_env.copy()
env.update(
(self._header_to_env_var(key), value)
for key, value in payload.get('headers').items())
self._auth_app(env, self._auth_response)
return request.Request(action=action, body=body,
headers=headers, api="v1.1")
def _auth_start(self, env, start_response):
self._authentified = True
self._auth_app = None
expire = env['keystone.token_info']['token']['expires_at']
expire_time = timeutils.parse_isotime(expire)
timestamp = (expire_time - _EPOCH).total_seconds()
if self._deauth_handle is not None:
self._deauth_handle.cancel()
self._deauth_handle = self._loop.call_at(
timestamp, self._deauthenticate)
@staticmethod
def _validate_request(pl, req):
try:
action = pl.get('action')
validator = schema_validator.RequestSchema()
is_valid = validator.validate(action=action, body=pl)
except errors.InvalidAction as ex:
body = {'error': str(ex)}
headers = {'status': 400}
resp = response.Response(req, body, headers)
return resp
start_response('200 OK', [])
def _deauthenticate(self):
self._authentified = False
self.sendClose(403, 'Authentication expired.')
def _auth_response(self, status, message):
code = int(status.split()[0])
if code != 200:
body = {'error': 'Authentication failed.'}
resp = self._handler.create_response(code, body)
self._send_response(resp)
else:
if not is_valid:
body = {'error': 'Schema validation failed.'}
headers = {'status': 400}
resp = response.Response(req, body, headers)
return resp
body = {'message': 'Authentified.'}
resp = self._handler.create_response(200, body)
self._send_response(resp)
return None
def _header_to_env_var(self, key):
return 'HTTP_%s' % key.replace('-', '_').upper()
@staticmethod
def _dummy_request():
action = None
return request.Request(action)
def _send_response(self, resp):
resp_json = json.dumps(resp.get_response())
self.sendMessage(resp_json, False)