diff --git a/stacktach/dbapi.py b/stacktach/dbapi.py index 6488bfd..98003a8 100644 --- a/stacktach/dbapi.py +++ b/stacktach/dbapi.py @@ -2,6 +2,7 @@ import decimal import functools import json +from django.db.models import FieldDoesNotExist from django.forms.models import model_to_dict from django.http import HttpResponse from django.http import HttpResponseBadRequest @@ -51,7 +52,7 @@ def api_call(func): @api_call def list_usage_launches(request): - filter_args = _get_filter_args(request) + filter_args = _get_filter_args(models.InstanceUsage, request) if len(filter_args) > 0: objects = models.InstanceUsage.objects.filter(**filter_args) @@ -69,7 +70,7 @@ def get_usage_launch(request, launch_id): @api_call def list_usage_deletes(request): - filter_args = _get_filter_args(request) + filter_args = _get_filter_args(models.InstanceDeletes, request) if len(filter_args) > 0: objects = models.InstanceDeletes.objects.filter(**filter_args) @@ -87,7 +88,7 @@ def get_usage_delete(request, delete_id): @api_call def list_usage_exists(request): - filter_args = _get_filter_args(request) + filter_args = _get_filter_args(models.InstanceExists, request) if len(filter_args) > 0: objects = models.InstanceExists.objects.filter(**filter_args) @@ -109,7 +110,15 @@ def _get_model_by_id(klass, model_id): return model_dict -def _get_filter_args(request): +def _check_has_field(klass, field_name): + try: + klass._meta.get_field_by_name(field_name) + except FieldDoesNotExist: + msg = "No such field '%s'." % field_name + raise BadRequestException(msg) + + +def _get_filter_args(klass, request): filter_args = {} if 'instance' in request.GET: filter_args['instance'] = request.GET['instance'] @@ -118,6 +127,7 @@ def _get_filter_args(request): if key.endswith('_min'): k = key[0:-4] + _check_has_field(klass, k) try: filter_args['%s__gte' % k] = utils.str_time_to_unix(value) except AttributeError: @@ -125,6 +135,7 @@ def _get_filter_args(request): raise BadRequestException(message=msg) elif key.endswith('_max'): k = key[0:-4] + _check_has_field(klass, k) try: filter_args['%s__lte' % k] = utils.str_time_to_unix(value) except AttributeError: diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index 9e97502..8d18e08 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -1,6 +1,7 @@ import datetime import unittest +from django.db.models import FieldDoesNotExist import mox from stacktach import dbapi @@ -15,18 +16,27 @@ class StacktachRawParsingTestCase(unittest.TestCase): def tearDown(self): self.mox.UnsetStubs() + def make_fake_model(self): + fake_model = self.mox.CreateMockAnything() + fake_meta = self.mox.CreateMockAnything() + fake_model._meta = fake_meta + return fake_model + def test_get_filter_args(self): start_time = datetime.datetime.utcnow() start_decimal = utils.decimal_utc(start_time) end_time = start_time + datetime.timedelta(days=1) end_decimal = utils.decimal_utc(end_time) fake_request = self.mox.CreateMockAnything() + fake_model = self.make_fake_model() + fake_model._meta.get_field_by_name('launched_at') + fake_model._meta.get_field_by_name('launched_at') fake_request.GET = {'instance': INSTANCE_ID_1, 'launched_at_min': str(start_time), 'launched_at_max': str(end_time)} self.mox.ReplayAll() - filter_args = dbapi._get_filter_args(fake_request) + filter_args = dbapi._get_filter_args(fake_model, fake_request) self.mox.VerifyAll() self.assertEquals(filter_args['instance'], INSTANCE_ID_1) @@ -38,19 +48,37 @@ class StacktachRawParsingTestCase(unittest.TestCase): def test_get_filter_args_bad_min_value(self): fake_request = self.mox.CreateMockAnything() fake_request.GET = {'launched_at_min': 'obviouslybaddatetime'} + fake_model = self.make_fake_model() + fake_model._meta.get_field_by_name('launched_at') self.mox.ReplayAll() self.assertRaises(dbapi.BadRequestException, dbapi._get_filter_args, - fake_request) + fake_model, fake_request) self.mox.VerifyAll() def test_get_filter_args_bad_max_value(self): fake_request = self.mox.CreateMockAnything() fake_request.GET = {'launched_at_max': 'obviouslybaddatetime'} + fake_model = self.make_fake_model() + fake_model._meta.get_field_by_name('launched_at') self.mox.ReplayAll() self.assertRaises(dbapi.BadRequestException, dbapi._get_filter_args, - fake_request) + fake_model, fake_request) + + self.mox.VerifyAll() + + def test_get_filter_args_bad_range_key(self): + start_time = datetime.datetime.utcnow() + fake_request = self.mox.CreateMockAnything() + fake_request.GET = {'somebadfield_max': str(start_time)} + fake_model = self.make_fake_model() + fake_model._meta.get_field_by_name('somebadfield')\ + .AndRaise(FieldDoesNotExist()) + self.mox.ReplayAll() + + self.assertRaises(dbapi.BadRequestException, dbapi._get_filter_args, + fake_model, fake_request) self.mox.VerifyAll()