diff --git a/tests/test_flask.py b/tests/test_flask.py index b1702b9..da9e1c6 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -9,6 +9,11 @@ class Model(Base): name = text +class Criterion(Base): + op = text + attr = text + value = text + test_app = Flask(__name__) @@ -25,9 +30,13 @@ def divide_by_zero(): @test_app.route('/models') -@signature([Model]) -def list_models(): - return [Model(name='first')] +@signature([Model], [Criterion]) +def list_models(q=None): + if q: + name = q[0].value + else: + name = 'first' + return [Model(name=name)] @test_app.route('/models/') @@ -63,6 +72,14 @@ class FlaskrTestCase(unittest.TestCase): resp = self.app.get('/models') assert resp.status_code == 200 + def test_array_parameter(self): + resp = self.app.get('/models?q.op=%3D&q.attr=name&q.value=second') + assert resp.status_code == 200 + print resp.data + self.assertEquals( + resp.data, '[{"name": "second"}]' + ) + def test_post_model(self): resp = self.app.post('/models', data={"body.name": "test"}) assert resp.status_code == 200 diff --git a/wsme/rest/args.py b/wsme/rest/args.py index f88ec3b..7769e18 100644 --- a/wsme/rest/args.py +++ b/wsme/rest/args.py @@ -81,11 +81,19 @@ def from_params(datatype, params, path, hit_paths): @from_params.when_type(ArrayType) def array_from_params(datatype, params, path, hit_paths): + if hasattr(params, 'getall'): + # webob multidict + def getall(params, path): + return params.getall(path) + elif hasattr(params, 'getlist'): + # werkzeug multidict + def getall(params, path): # noqa + return params.getlist(path) if path in params: hit_paths.add(path) return [ from_param(datatype.item_type, value) - for value in params.getall(path)] + for value in getall(params, path)] if iscomplex(datatype.item_type): attributes = set() @@ -99,7 +107,7 @@ def array_from_params(datatype, params, path, hit_paths): for attrdef in list_attributes(datatype.item_type): attrpath = '%s.%s' % (path, attrdef.key) hit_paths.add(attrpath) - attrvalues = params.getall(attrpath) + attrvalues = getall(params, attrpath) if len(value) < len(attrvalues): value[-1:] = [ datatype.item_type() @@ -158,7 +166,9 @@ def args_from_args(funcdef, args, kwargs): newargs.append(from_param(argdef.datatype, arg)) newkwargs = {} for argname, value in kwargs.items(): - newkwargs[argname] = from_param(funcdef.get_arg(argname).datatype, value) + newkwargs[argname] = from_param( + funcdef.get_arg(argname).datatype, value + ) return newargs, newkwargs