diff --git a/trove/cluster/models.py b/trove/cluster/models.py index 2fc5b37fc4..2e990fece2 100644 --- a/trove/cluster/models.py +++ b/trove/cluster/models.py @@ -22,6 +22,7 @@ from trove.common import exception from trove.common.i18n import _ from trove.common import remote from trove.common.strategies.cluster import strategy +from trove.common import utils from trove.datastore import models as datastore_models from trove.db import models as dbmodels from trove.instance import models as inst_models @@ -96,9 +97,7 @@ class Cluster(object): def load_all(cls, context, tenant_id): db_infos = DBCluster.find_all(tenant_id=tenant_id, deleted=False) - limit = int(context.limit or Cluster.DEFAULT_LIMIT) - if limit > Cluster.DEFAULT_LIMIT: - limit = Cluster.DEFAULT_LIMIT + limit = utils.pagination_limit(context.limit, Cluster.DEFAULT_LIMIT) data_view = DBCluster.find_by_pagination('clusters', db_infos, "foo", limit=limit, marker=context.marker) diff --git a/trove/common/utils.py b/trove/common/utils.py index 85fd27ed3c..a307d8a8a6 100644 --- a/trove/common/utils.py +++ b/trove/common/utils.py @@ -54,6 +54,11 @@ ENV = jinja2.Environment(loader=jinja2.ChoiceLoader([ ])) +def pagination_limit(limit, default_limit): + limit = int(limit or default_limit) + return min(limit, default_limit) + + def create_method_args_string(*args, **kwargs): """Returns a string representation of args and keyword args. diff --git a/trove/configuration/models.py b/trove/configuration/models.py index 5e1f9178a7..3d6c081e53 100644 --- a/trove/configuration/models.py +++ b/trove/configuration/models.py @@ -54,10 +54,8 @@ class Configurations(object): LOG.debug("No configurations found for tenant %s" % context.tenant) - limit = int(context.limit or Configurations.DEFAULT_LIMIT) - if limit > Configurations.DEFAULT_LIMIT: - limit = Configurations.DEFAULT_LIMIT - + limit = utils.pagination_limit(context.limit, + Configurations.DEFAULT_LIMIT) data_view = DBConfiguration.find_by_pagination('configurations', db_info, "foo", diff --git a/trove/extensions/mysql/models.py b/trove/extensions/mysql/models.py index 0c080ca888..80ba45b456 100644 --- a/trove/extensions/mysql/models.py +++ b/trove/extensions/mysql/models.py @@ -22,6 +22,7 @@ from oslo_log import log as logging from trove.common import cfg from trove.common import exception from trove.common.remote import create_guest_client +from trove.common import utils from trove.extensions.common.models import load_and_verify from trove.extensions.common.models import RootHistory from trove.guestagent.db import models as guest_models @@ -166,8 +167,7 @@ class UserAccess(object): def load_via_context(cls, context, instance_id): """Creates guest and fetches pagination arguments from the context.""" load_and_verify(context, instance_id) - limit = int(context.limit or cls.DEFAULT_LIMIT) - limit = cls.DEFAULT_LIMIT if limit > cls.DEFAULT_LIMIT else limit + limit = utils.pagination_limit(context.limit, cls.DEFAULT_LIMIT) client = create_guest_client(context, instance_id) # The REST API standard dictates that we *NEVER* include the marker. return cls.load_with_client(client=client, limit=limit, diff --git a/trove/instance/models.py b/trove/instance/models.py index 14690505f0..ee63583d38 100644 --- a/trove/instance/models.py +++ b/trove/instance/models.py @@ -1172,9 +1172,7 @@ class Instances(object): db_infos = DBInstance.find_all(tenant_id=context.tenant, cluster_id=None, deleted=False) - limit = int(context.limit or Instances.DEFAULT_LIMIT) - if limit > Instances.DEFAULT_LIMIT: - limit = Instances.DEFAULT_LIMIT + limit = utils.pagination_limit(context.limit, Instances.DEFAULT_LIMIT) data_view = DBInstance.find_by_pagination('instances', db_infos, "foo", limit=limit, marker=context.marker) diff --git a/trove/tests/unittests/common/test_utils.py b/trove/tests/unittests/common/test_utils.py index cfda365ff2..00c14f8fd0 100644 --- a/trove/tests/unittests/common/test_utils.py +++ b/trove/tests/unittests/common/test_utils.py @@ -75,3 +75,7 @@ class TestTroveExecuteWithTimeout(trove_testtools.TestCase): self.assertEqual(1, utils.unpack_singleton([[[1]]])) self.assertEqual([[1], [2]], utils.unpack_singleton([[1], [2]])) self.assertEqual(['a', 'b'], utils.unpack_singleton(['a', 'b'])) + + def test_pagination_limit(self): + self.assertEqual(5, utils.pagination_limit(5, 9)) + self.assertEqual(5, utils.pagination_limit(9, 5))