Refactoring error handling in Refstack API

All exited refstack excetions should be transformed to
valid HTTPError with descriptive messages

Closes-Bug: #1485715
Change-Id: I5da363a2daa84846ed862e5316c22f1b9cf8d2ec
This commit is contained in:
sslypushenko 2015-08-26 15:28:08 +03:00
parent 1ca17dde52
commit c60fa7d8d2
15 changed files with 155 additions and 146 deletions

View File

@ -24,10 +24,12 @@ from oslo_config import cfg
from oslo_log import log from oslo_log import log
from oslo_log import loggers from oslo_log import loggers
import pecan import pecan
import six
import webob import webob
from refstack.api import exceptions as api_exc
from refstack.api import utils as api_utils from refstack.api import utils as api_utils
from refstack.common import validators from refstack import db
LOG = log.getLogger(__name__) LOG = log.getLogger(__name__)
@ -121,19 +123,29 @@ class JSONErrorHook(pecan.hooks.PecanHook):
if isinstance(exc, webob.exc.HTTPRedirection): if isinstance(exc, webob.exc.HTTPRedirection):
return return
elif isinstance(exc, webob.exc.HTTPError): elif isinstance(exc, webob.exc.HTTPError):
status_code = exc.status_int return webob.Response(
body = {'title': exc.title} body=json.dumps({'code': exc.status_int,
elif isinstance(exc, validators.ValidationError): 'title': exc.title}),
status=exc.status_int,
content_type='application/json'
)
title = None
if isinstance(exc, api_exc.ValidationError):
status_code = 400 status_code = 400
body = {'title': exc.title} elif isinstance(exc, api_exc.ParseInputsError):
status_code = 400
elif isinstance(exc, db.NotFound):
status_code = 404
elif isinstance(exc, db.Duplication):
status_code = 409
else: else:
LOG.exception(exc) LOG.exception(exc)
status_code = 500 status_code = 500
body = {'title': 'Internal Server Error'} title = 'Internal Server Error'
body['code'] = status_code body = {'title': title or exc.args[0], 'code': status_code}
if self.debug: if self.debug:
body['detail'] = str(exc) body['detail'] = six.text_type(exc)
return webob.Response( return webob.Response(
body=json.dumps(body), body=json.dumps(body),
status=status_code, status=status_code,

View File

@ -24,8 +24,8 @@ from six.moves.urllib import parse
from refstack import db from refstack import db
from refstack.api import constants as const from refstack.api import constants as const
from refstack.api import utils as api_utils from refstack.api import utils as api_utils
from refstack.api import validators
from refstack.api.controllers import validation from refstack.api.controllers import validation
from refstack.common import validators
LOG = log.getLogger(__name__) LOG = log.getLogger(__name__)
@ -130,13 +130,10 @@ class ResultsController(validation.BaseRestControllerWithValidation):
const.SIGNED const.SIGNED
] ]
try: filters = api_utils.parse_input_params(expected_input_params)
filters = api_utils.parse_input_params(expected_input_params) records_count = db.get_test_records_count(filters)
records_count = db.get_test_records_count(filters) page_number, total_pages_number = \
page_number, total_pages_number = \ api_utils.get_page_number(records_count)
api_utils.get_page_number(records_count)
except api_utils.ParseInputsError as ex:
pecan.abort(400, 'Reason: %s' % ex)
try: try:
per_page = CONF.api.results_per_page per_page = CONF.api.results_per_page

View File

@ -20,8 +20,8 @@ from pecan import rest
from pecan.secure import secure from pecan.secure import secure
from refstack.api import utils as api_utils from refstack.api import utils as api_utils
from refstack.api import validators
from refstack.api.controllers import validation from refstack.api.controllers import validation
from refstack.common import validators
from refstack import db from refstack import db

View File

@ -0,0 +1,46 @@
#
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Refstack API exceptions."""
class ParseInputsError(Exception):
"""Raise if input params are invalid."""
pass
class ValidationError(Exception):
"""Raise if request doesn't pass trough validation process."""
def __init__(self, title, exc=None):
"""Init."""
super(ValidationError, self).__init__(title)
self.exc = exc
self.title = title
self.details = "%s(%s: %s)" % (self.title,
self.exc.__class__.__name__,
str(self.exc)) \
if self.exc else self.title
def __repr__(self):
"""Repr method."""
return self.details
def __str__(self):
"""Str method."""
return self.__repr__()

View File

@ -31,18 +31,12 @@ from six.moves.urllib import parse
from refstack import db from refstack import db
from refstack.api import constants as const from refstack.api import constants as const
from refstack.api import exceptions as api_exc
LOG = log.getLogger(__name__) LOG = log.getLogger(__name__)
CONF = cfg.CONF CONF = cfg.CONF
class ParseInputsError(Exception):
"""Raise if input params are invalid."""
pass
def _get_input_params_from_request(expected_params): def _get_input_params_from_request(expected_params):
"""Get input parameters from request. """Get input parameters from request.
@ -77,18 +71,16 @@ def parse_input_params(expected_input_params):
try: try:
filters[key] = timeutils.parse_strtime(value, date_fmt) filters[key] = timeutils.parse_strtime(value, date_fmt)
except (ValueError, TypeError) as exc: except (ValueError, TypeError) as exc:
raise ParseInputsError('Invalid date format: %(exc)s' raise api_exc.ParseInputsError(
% {'exc': exc}) 'Invalid date format: %(exc)s' % {'exc': exc})
start_date = filters.get(const.START_DATE) start_date = filters.get(const.START_DATE)
end_date = filters.get(const.END_DATE) end_date = filters.get(const.END_DATE)
if start_date and end_date: if start_date and end_date:
if start_date > end_date: if start_date > end_date:
raise ParseInputsError('Invalid dates: %(start)s ' raise api_exc.ParseInputsError(
'more than %(end)s' % { 'Invalid dates: %(start)s more than %(end)s'
'start': const.START_DATE, '' % {'start': const.START_DATE, 'end': const.END_DATE})
'end': const.END_DATE
})
if const.SIGNED in filters: if const.SIGNED in filters:
if is_authenticated(): if is_authenticated():
filters[const.OPENID] = get_user_id() filters[const.OPENID] = get_user_id()
@ -97,8 +89,8 @@ def parse_input_params(expected_input_params):
for pk in get_user_public_keys() for pk in get_user_public_keys()
] ]
else: else:
raise ParseInputsError('To see signed test ' raise api_exc.ParseInputsError(
'results you need to authenticate') 'To see signed test results you need to authenticate')
return filters return filters
@ -129,19 +121,21 @@ def get_page_number(records_count):
try: try:
page_number = int(page_number) page_number = int(page_number)
except (ValueError, TypeError): except (ValueError, TypeError):
raise ParseInputsError('Invalid page number: The page number can not ' raise api_exc.ParseInputsError(
'be converted to an integer') 'Invalid page number: The page number can not be converted to '
'an integer')
if page_number == 1: if page_number == 1:
return (page_number, total_pages) return (page_number, total_pages)
if page_number <= 0: if page_number <= 0:
raise ParseInputsError('Invalid page number: ' raise api_exc.ParseInputsError('Invalid page number: '
'The page number less or equal zero.') 'The page number less or equal zero.')
if page_number > total_pages: if page_number > total_pages:
raise ParseInputsError('Invalid page number: The page number ' raise api_exc.ParseInputsError(
'is greater than the total number of pages.') 'Invalid page number: '
'The page number is greater than the total number of pages.')
return (page_number, total_pages) return (page_number, total_pages)

View File

@ -24,32 +24,11 @@ from Crypto.Hash import SHA256
from Crypto.PublicKey import RSA from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5 from Crypto.Signature import PKCS1_v1_5
from refstack.api import exceptions as api_exc
ext_format_checker = jsonschema.FormatChecker() ext_format_checker = jsonschema.FormatChecker()
class ValidationError(Exception):
"""Raise if request doesn't pass trough validation process."""
def __init__(self, title, exc=None):
"""Init."""
super(ValidationError, self).__init__(title)
self.exc = exc
self.title = title
self.details = "%s(%s: %s)" % (self.title,
self.exc.__class__.__name__,
str(self.exc)) \
if self.exc else self.title
def __repr__(self):
"""Repr method."""
return self.details
def __str__(self):
"""Str method."""
return self.__repr__()
def is_uuid(inst): def is_uuid(inst):
"""Check that inst is a uuid_hex string.""" """Check that inst is a uuid_hex string."""
try: try:
@ -86,12 +65,13 @@ class BaseValidator(object):
try: try:
body = json.loads(request.body) body = json.loads(request.body)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
raise ValidationError('Malformed request', e) raise api_exc.ValidationError('Malformed request', e)
try: try:
jsonschema.validate(body, self.schema) jsonschema.validate(body, self.schema)
except jsonschema.ValidationError as e: except jsonschema.ValidationError as e:
raise ValidationError('Request doesn''t correspond to schema', e) raise api_exc.ValidationError(
'Request doesn''t correspond to schema', e)
class TestResultValidator(BaseValidator): class TestResultValidator(BaseValidator):
@ -132,17 +112,17 @@ class TestResultValidator(BaseValidator):
try: try:
sign = binascii.a2b_hex(request.headers.get('X-Signature', '')) sign = binascii.a2b_hex(request.headers.get('X-Signature', ''))
except (binascii.Error, TypeError) as e: except (binascii.Error, TypeError) as e:
raise ValidationError('Malformed signature', e) raise api_exc.ValidationError('Malformed signature', e)
try: try:
key = RSA.importKey(request.headers.get('X-Public-Key', '')) key = RSA.importKey(request.headers.get('X-Public-Key', ''))
except (binascii.Error, ValueError) as e: except (binascii.Error, ValueError) as e:
raise ValidationError('Malformed public key', e) raise api_exc.ValidationError('Malformed public key', e)
signer = PKCS1_v1_5.new(key) signer = PKCS1_v1_5.new(key)
data_hash = SHA256.new() data_hash = SHA256.new()
data_hash.update(request.body.encode('utf-8')) data_hash.update(request.body.encode('utf-8'))
if not signer.verify(data_hash, sign): if not signer.verify(data_hash, sign):
raise ValidationError('Signature verification failed') raise api_exc.ValidationError('Signature verification failed')
@staticmethod @staticmethod
def assert_id(_id): def assert_id(_id):
@ -172,19 +152,19 @@ class PubkeyValidator(BaseValidator):
if key_format not in ('ssh-dss', 'ssh-rsa', if key_format not in ('ssh-dss', 'ssh-rsa',
'pgp-sign-rsa', 'pgp-sign-dss'): 'pgp-sign-rsa', 'pgp-sign-dss'):
raise ValidationError('Public key has unsupported format') raise api_exc.ValidationError('Public key has unsupported format')
try: try:
sign = binascii.a2b_hex(body['self_signature']) sign = binascii.a2b_hex(body['self_signature'])
except (binascii.Error, TypeError) as e: except (binascii.Error, TypeError) as e:
raise ValidationError('Malformed signature', e) raise api_exc.ValidationError('Malformed signature', e)
try: try:
key = RSA.importKey(body['raw_key']) key = RSA.importKey(body['raw_key'])
except (binascii.Error, ValueError) as e: except (binascii.Error, ValueError) as e:
raise ValidationError('Malformed public key', e) raise api_exc.ValidationError('Malformed public key', e)
signer = PKCS1_v1_5.new(key) signer = PKCS1_v1_5.new(key)
data_hash = SHA256.new() data_hash = SHA256.new()
data_hash.update('signature'.encode('utf-8')) data_hash.update('signature'.encode('utf-8'))
if not signer.verify(data_hash, sign): if not signer.verify(data_hash, sign):
raise ValidationError('Signature verification failed') raise api_exc.ValidationError('Signature verification failed')

View File

@ -1,15 +0,0 @@
# Copyright (c) 2015 Mirantis, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Refstack common package."""

View File

@ -37,6 +37,7 @@ IMPL = db_api.DBAPI.from_config(cfg.CONF, backend_mapping=_BACKEND_MAPPING,
lazy=True) lazy=True)
NotFound = IMPL.NotFound NotFound = IMPL.NotFound
Duplication = IMPL.Duplication
def store_results(results): def store_results(results):

View File

@ -23,7 +23,6 @@ import uuid
from oslo_config import cfg from oslo_config import cfg
from oslo_db import options as db_options from oslo_db import options as db_options
from oslo_db.sqlalchemy import session as db_session from oslo_db.sqlalchemy import session as db_session
from oslo_db import exception as oslo_db_exc
import six import six
from refstack.api import constants as api_const from refstack.api import constants as api_const
@ -41,11 +40,14 @@ class NotFound(Exception):
"""Raise if item not found in db.""" """Raise if item not found in db."""
def __init__(self, model, details=None): pass
"""Init."""
self.model = model
title = details if details else ''.join((model, ' not found.')) class Duplication(Exception):
super(NotFound, self).__init__(title)
"""Raise if unique constraint violates."""
pass
def _create_facade_lazily(): def _create_facade_lazily():
@ -136,7 +138,7 @@ def get_test(test_id, allowed_keys=None):
filter_by(id=test_id). \ filter_by(id=test_id). \
first() first()
if not test_info: if not test_info:
raise NotFound('Test', 'Test result %s not found' % test_id) raise NotFound('Test result %s not found' % test_id)
return _to_dict(test_info, allowed_keys) return _to_dict(test_info, allowed_keys)
@ -152,7 +154,7 @@ def delete_test(test_id):
.filter_by(test_id=test_id).delete() .filter_by(test_id=test_id).delete()
session.delete(test) session.delete(test)
else: else:
raise NotFound('Test', 'Test result %s not found' % test_id) raise NotFound('Test result %s not found' % test_id)
def get_test_meta_key(test_id, key, default=None): def get_test_meta_key(test_id, key, default=None):
@ -190,8 +192,7 @@ def delete_test_meta_item(test_id, key):
with session.begin(): with session.begin():
session.delete(meta_item) session.delete(meta_item)
else: else:
raise NotFound('TestMeta', raise NotFound('Metadata key %s '
'Metadata key %s '
'not found for test run %s' % (key, test_id)) 'not found for test run %s' % (key, test_id))
@ -263,7 +264,7 @@ def user_get(user_openid):
session = get_session() session = get_session()
user = session.query(models.User).filter_by(openid=user_openid).first() user = session.query(models.User).filter_by(openid=user_openid).first()
if user is None: if user is None:
raise NotFound('User', 'User with OpenID %s not found' % user_openid) raise NotFound('User with OpenID %s not found' % user_openid)
return user return user
@ -302,8 +303,7 @@ def store_pubkey(pubkey_info):
if not pubkeys_collision: if not pubkeys_collision:
pubkey.save(session) pubkey.save(session)
else: else:
raise oslo_db_exc.DBDuplicateEntry(columns=['pubkeys.pubkey'], raise Duplication('Public key already exists.')
value=pubkey.pubkey)
return pubkey.id return pubkey.id

View File

@ -23,7 +23,7 @@ from oslo_config import fixture as config_fixture
import six import six
import webtest.app import webtest.app
from refstack.common import validators from refstack.api import validators
from refstack.tests import api from refstack.tests import api
FAKE_TESTS_RESULT = { FAKE_TESTS_RESULT = {

View File

@ -26,7 +26,7 @@ from six.moves.urllib import parse
import webob.exc import webob.exc
from refstack.api import constants as const from refstack.api import constants as const
from refstack.api import utils as api_utils from refstack.api import exceptions as api_exc
from refstack.api.controllers import auth from refstack.api.controllers import auth
from refstack.api.controllers import capabilities from refstack.api.controllers import capabilities
from refstack.api.controllers import results from refstack.api.controllers import results
@ -184,8 +184,8 @@ class ResultsControllerTestCase(BaseControllerTestCase):
@mock.patch('refstack.api.utils.parse_input_params') @mock.patch('refstack.api.utils.parse_input_params')
def test_get_failed_in_parse_input_params(self, parse_inputs): def test_get_failed_in_parse_input_params(self, parse_inputs):
parse_inputs.side_effect = api_utils.ParseInputsError() parse_inputs.side_effect = api_exc.ParseInputsError()
self.assertRaises(webob.exc.HTTPError, self.assertRaises(api_exc.ParseInputsError,
self.controller.get) self.controller.get)
@mock.patch('refstack.db.get_test_records_count') @mock.patch('refstack.db.get_test_records_count')
@ -193,8 +193,8 @@ class ResultsControllerTestCase(BaseControllerTestCase):
def test_get_failed_in_get_test_records_number(self, def test_get_failed_in_get_test_records_number(self,
parse_inputs, parse_inputs,
db_get_count): db_get_count):
db_get_count.side_effect = api_utils.ParseInputsError() db_get_count.side_effect = api_exc.ParseInputsError()
self.assertRaises(webob.exc.HTTPError, self.assertRaises(api_exc.ParseInputsError,
self.controller.get) self.controller.get)
@mock.patch('refstack.db.get_test_records_count') @mock.patch('refstack.db.get_test_records_count')
@ -205,8 +205,8 @@ class ResultsControllerTestCase(BaseControllerTestCase):
parse_input, parse_input,
db_get_count): db_get_count):
get_page.side_effect = api_utils.ParseInputsError() get_page.side_effect = api_exc.ParseInputsError()
self.assertRaises(webob.exc.HTTPError, self.assertRaises(api_exc.ParseInputsError,
self.controller.get) self.controller.get)
@mock.patch('refstack.db.get_test_records') @mock.patch('refstack.db.get_test_records')

View File

@ -24,6 +24,7 @@ from six.moves.urllib import parse
from webob import exc from webob import exc
from refstack.api import constants as const from refstack.api import constants as const
from refstack.api import exceptions as api_exc
from refstack.api import utils as api_utils from refstack.api import utils as api_utils
from refstack import db from refstack import db
@ -97,7 +98,7 @@ class APIUtilsTestCase(base.BaseTestCase):
expected_params = mock.Mock() expected_params = mock.Mock()
mock_get_input.return_value = raw_filters mock_get_input.return_value = raw_filters
mock_strtime.side_effect = ValueError() mock_strtime.side_effect = ValueError()
self.assertRaises(api_utils.ParseInputsError, self.assertRaises(api_exc.ParseInputsError,
api_utils.parse_input_params, api_utils.parse_input_params,
expected_params) expected_params)
@ -115,7 +116,7 @@ class APIUtilsTestCase(base.BaseTestCase):
expected_params = mock.Mock() expected_params = mock.Mock()
mock_get_input.return_value = raw_filters mock_get_input.return_value = raw_filters
self.assertRaises(api_utils.ParseInputsError, self.assertRaises(api_exc.ParseInputsError,
api_utils.parse_input_params, api_utils.parse_input_params,
expected_params) expected_params)
@ -135,7 +136,7 @@ class APIUtilsTestCase(base.BaseTestCase):
} }
expected_params = mock.Mock() expected_params = mock.Mock()
mock_get_input.return_value = raw_filters mock_get_input.return_value = raw_filters
self.assertRaises(api_utils.ParseInputsError, self.assertRaises(api_exc.ParseInputsError,
api_utils.parse_input_params, expected_params) api_utils.parse_input_params, expected_params)
@mock.patch.object(api_utils, '_get_input_params_from_request') @mock.patch.object(api_utils, '_get_input_params_from_request')
@ -225,7 +226,7 @@ class APIUtilsTestCase(base.BaseTestCase):
const.PAGE: 'abc' const.PAGE: 'abc'
} }
self.assertRaises(api_utils.ParseInputsError, self.assertRaises(api_exc.ParseInputsError,
api_utils.get_page_number, api_utils.get_page_number,
total_records) total_records)
@ -256,7 +257,7 @@ class APIUtilsTestCase(base.BaseTestCase):
const.PAGE: '-1' const.PAGE: '-1'
} }
self.assertRaises(api_utils.ParseInputsError, self.assertRaises(api_exc.ParseInputsError,
api_utils.get_page_number, api_utils.get_page_number,
total_records) total_records)
@ -271,7 +272,7 @@ class APIUtilsTestCase(base.BaseTestCase):
const.PAGE: '100' const.PAGE: '100'
} }
self.assertRaises(api_utils.ParseInputsError, self.assertRaises(api_exc.ParseInputsError,
api_utils.get_page_number, api_utils.get_page_number,
total_records) total_records)

View File

@ -24,7 +24,7 @@ import pecan
import webob import webob
from refstack.api import app from refstack.api import app
from refstack.common import validators from refstack.api import exceptions as api_exc
def get_response_kwargs(response_mock): def get_response_kwargs(response_mock):
@ -65,19 +65,12 @@ class JSONErrorHookTestCase(base.BaseTestCase):
expected_body={'code': exc.status_int, 'title': exc.title} expected_body={'code': exc.status_int, 'title': exc.title}
) )
self.CONF.set_override('app_dev_mode', True, 'api')
self._on_error(
response, exc, expected_status_code=exc.status,
expected_body={'code': exc.status_int, 'title': exc.title,
'detail': str(exc)}
)
@mock.patch.object(webob, 'Response') @mock.patch.object(webob, 'Response')
def test_on_error_with_validation_error(self, response): def test_on_error_with_validation_error(self, response):
self.CONF.set_override('app_dev_mode', False, 'api') self.CONF.set_override('app_dev_mode', False, 'api')
exc = mock.Mock(spec=validators.ValidationError, exc = mock.MagicMock(spec=api_exc.ValidationError,
title='No No No!') title='No No No!')
exc.args = ('No No No!',)
self._on_error( self._on_error(
response, exc, expected_status_code=400, response, exc, expected_status_code=400,
expected_body={'code': 400, 'title': exc.title} expected_body={'code': 400, 'title': exc.title}

View File

@ -20,7 +20,6 @@ import hashlib
import six import six
import mock import mock
from oslo_config import fixture as config_fixture from oslo_config import fixture as config_fixture
from oslo_db import exception as oslo_db_exc
from oslotest import base from oslotest import base
import sqlalchemy.orm import sqlalchemy.orm
@ -551,7 +550,7 @@ class DBBackendTestCase(base.BaseTestCase):
.filter_by.return_value\ .filter_by.return_value\
.filter_by.return_value\ .filter_by.return_value\
.all.return_value = mock_pubkey .all.return_value = mock_pubkey
self.assertRaises(oslo_db_exc.DBDuplicateEntry, self.assertRaises(db.Duplication,
db.store_pubkey, pubkey_info) db.store_pubkey, pubkey_info)
@mock.patch.object(api, 'get_session') @mock.patch.object(api, 'get_session')

View File

@ -25,14 +25,15 @@ import mock
from oslotest import base from oslotest import base
import six import six
from refstack.common import validators from refstack.api import exceptions as api_exc
from refstack.api import validators
class ValidatorsTestCase(base.BaseTestCase): class ValidatorsTestCase(base.BaseTestCase):
"""Test case for validator's helpers.""" """Test case for validator's helpers."""
def test_str_validation_error(self): def test_str_validation_error(self):
err = validators.ValidationError( err = api_exc.ValidationError(
'Something went wrong!', 'Something went wrong!',
AttributeError("'NoneType' object has no attribute 'a'") AttributeError("'NoneType' object has no attribute 'a'")
) )
@ -42,7 +43,7 @@ class ValidatorsTestCase(base.BaseTestCase):
'AttributeError', 'AttributeError',
"'NoneType' object has no attribute 'a'" "'NoneType' object has no attribute 'a'"
), str(err)) ), str(err))
err = validators.ValidationError( err = api_exc.ValidationError(
'Something went wrong again!' 'Something went wrong again!'
) )
self.assertEqual('Something went wrong again!', str(err)) self.assertEqual('Something went wrong again!', str(err))
@ -100,7 +101,7 @@ class TestResultValidatorTestCase(base.BaseTestCase):
request.body = json.dumps(self.FAKE_JSON) request.body = json.dumps(self.FAKE_JSON)
data_hash = SHA256.new() data_hash = SHA256.new()
data_hash.update(request.body.encode('utf-8')) data_hash.update(request.body.encode('utf-8'))
key = RSA.generate(4096) key = RSA.generate(1024)
signer = PKCS1_v1_5.new(key) signer = PKCS1_v1_5.new(key)
sign = signer.sign(data_hash) sign = signer.sign(data_hash)
request.headers = { request.headers = {
@ -112,12 +113,12 @@ class TestResultValidatorTestCase(base.BaseTestCase):
def test_validation_fail_no_json(self): def test_validation_fail_no_json(self):
wrong_request = mock.Mock() wrong_request = mock.Mock()
wrong_request.body = 'foo' wrong_request.body = 'foo'
self.assertRaises(validators.ValidationError, self.assertRaises(api_exc.ValidationError,
self.validator.validate, self.validator.validate,
wrong_request) wrong_request)
try: try:
self.validator.validate(wrong_request) self.validator.validate(wrong_request)
except validators.ValidationError as e: except api_exc.ValidationError as e:
self.assertIsInstance(e.exc, ValueError) self.assertIsInstance(e.exc, ValueError)
def test_validation_fail(self): def test_validation_fail(self):
@ -125,12 +126,12 @@ class TestResultValidatorTestCase(base.BaseTestCase):
wrong_request.body = json.dumps({ wrong_request.body = json.dumps({
'foo': 'bar' 'foo': 'bar'
}) })
self.assertRaises(validators.ValidationError, self.assertRaises(api_exc.ValidationError,
self.validator.validate, self.validator.validate,
wrong_request) wrong_request)
try: try:
self.validator.validate(wrong_request) self.validator.validate(wrong_request)
except validators.ValidationError as e: except api_exc.ValidationError as e:
self.assertIsInstance(e.exc, jsonschema.ValidationError) self.assertIsInstance(e.exc, jsonschema.ValidationError)
@mock.patch('jsonschema.validate') @mock.patch('jsonschema.validate')
@ -145,7 +146,7 @@ class TestResultValidatorTestCase(base.BaseTestCase):
'X-Signature': binascii.b2a_hex('fake_sign'.encode('utf-8')), 'X-Signature': binascii.b2a_hex('fake_sign'.encode('utf-8')),
'X-Public-Key': key.publickey().exportKey('OpenSSH') 'X-Public-Key': key.publickey().exportKey('OpenSSH')
} }
self.assertRaises(validators.ValidationError, self.assertRaises(api_exc.ValidationError,
self.validator.validate, self.validator.validate,
request) request)
request.headers = { request.headers = {
@ -154,7 +155,7 @@ class TestResultValidatorTestCase(base.BaseTestCase):
} }
try: try:
self.validator.validate(request) self.validator.validate(request)
except validators.ValidationError as e: except api_exc.ValidationError as e:
self.assertEqual(e.title, self.assertEqual(e.title,
'Signature verification failed') 'Signature verification failed')
@ -164,7 +165,7 @@ class TestResultValidatorTestCase(base.BaseTestCase):
} }
try: try:
self.validator.validate(request) self.validator.validate(request)
except validators.ValidationError as e: except api_exc.ValidationError as e:
self.assertIsInstance(e.exc, TypeError) self.assertIsInstance(e.exc, TypeError)
request.headers = { request.headers = {
@ -173,7 +174,7 @@ class TestResultValidatorTestCase(base.BaseTestCase):
} }
try: try:
self.validator.validate(request) self.validator.validate(request)
except validators.ValidationError as e: except api_exc.ValidationError as e:
self.assertIsInstance(e.exc, ValueError) self.assertIsInstance(e.exc, ValueError)
@ -207,12 +208,12 @@ class PubkeyValidatorTestCase(base.BaseTestCase):
def test_validation_fail_no_json(self): def test_validation_fail_no_json(self):
wrong_request = mock.Mock() wrong_request = mock.Mock()
wrong_request.body = 'foo' wrong_request.body = 'foo'
self.assertRaises(validators.ValidationError, self.assertRaises(api_exc.ValidationError,
self.validator.validate, self.validator.validate,
wrong_request) wrong_request)
try: try:
self.validator.validate(wrong_request) self.validator.validate(wrong_request)
except validators.ValidationError as e: except api_exc.ValidationError as e:
self.assertIsInstance(e.exc, ValueError) self.assertIsInstance(e.exc, ValueError)
def test_validation_fail(self): def test_validation_fail(self):
@ -220,12 +221,12 @@ class PubkeyValidatorTestCase(base.BaseTestCase):
wrong_request.body = json.dumps({ wrong_request.body = json.dumps({
'foo': 'bar' 'foo': 'bar'
}) })
self.assertRaises(validators.ValidationError, self.assertRaises(api_exc.ValidationError,
self.validator.validate, self.validator.validate,
wrong_request) wrong_request)
try: try:
self.validator.validate(wrong_request) self.validator.validate(wrong_request)
except validators.ValidationError as e: except api_exc.ValidationError as e:
self.assertIsInstance(e.exc, jsonschema.ValidationError) self.assertIsInstance(e.exc, jsonschema.ValidationError)
@mock.patch('jsonschema.validate') @mock.patch('jsonschema.validate')
@ -239,7 +240,7 @@ class PubkeyValidatorTestCase(base.BaseTestCase):
request.body = json.dumps(body) request.body = json.dumps(body)
try: try:
self.validator.validate(request) self.validator.validate(request)
except validators.ValidationError as e: except api_exc.ValidationError as e:
self.assertEqual(e.title, self.assertEqual(e.title,
'Signature verification failed') 'Signature verification failed')
@ -251,7 +252,7 @@ class PubkeyValidatorTestCase(base.BaseTestCase):
request.body = json.dumps(body) request.body = json.dumps(body)
try: try:
self.validator.validate(request) self.validator.validate(request)
except validators.ValidationError as e: except api_exc.ValidationError as e:
self.assertEqual(e.title, self.assertEqual(e.title,
'Public key has unsupported format') 'Public key has unsupported format')
@ -263,7 +264,7 @@ class PubkeyValidatorTestCase(base.BaseTestCase):
request.body = json.dumps(body) request.body = json.dumps(body)
try: try:
self.validator.validate(request) self.validator.validate(request)
except validators.ValidationError as e: except api_exc.ValidationError as e:
self.assertEqual(e.title, self.assertEqual(e.title,
'Malformed signature') 'Malformed signature')
@ -275,6 +276,6 @@ class PubkeyValidatorTestCase(base.BaseTestCase):
request.body = json.dumps(body) request.body = json.dumps(body)
try: try:
self.validator.validate(request) self.validator.validate(request)
except validators.ValidationError as e: except api_exc.ValidationError as e:
self.assertEqual(e.title, self.assertEqual(e.title,
'Malformed public key') 'Malformed public key')