diff --git a/trove/guestagent/common/guestagent_utils.py b/trove/guestagent/common/guestagent_utils.py index dcd8894f39..d408a3491b 100644 --- a/trove/guestagent/common/guestagent_utils.py +++ b/trove/guestagent/common/guestagent_utils.py @@ -19,6 +19,8 @@ import re import six +from trove.common import pagination + def update_dict(updates, target): """Recursively update a target dictionary with given updates. @@ -120,3 +122,21 @@ def to_bytes(value): return int(round(factor * float(value))) return value + + +def paginate_list(li, limit=None, marker=None, include_marker=False): + """Paginate a list of objects based on the name attribute. + :returns: Page sublist and a marker (name of the last item). + """ + return pagination.paginate_object_list( + li, 'name', limit=limit, marker=marker, include_marker=include_marker) + + +def serialize_list(li, limit=None, marker=None, include_marker=False): + """ + Paginate (by name) and serialize a given object list. + :returns: A serialized and paginated version of a given list. + """ + page, next_name = paginate_list(li, limit=limit, marker=marker, + include_marker=include_marker) + return [item.serialize() for item in page], next_name diff --git a/trove/guestagent/datastore/experimental/cassandra/service.py b/trove/guestagent/datastore/experimental/cassandra/service.py index 7d8d06bac5..37a9dda3ac 100644 --- a/trove/guestagent/datastore/experimental/cassandra/service.py +++ b/trove/guestagent/datastore/experimental/cassandra/service.py @@ -29,7 +29,6 @@ from trove.common import cfg from trove.common import exception from trove.common.i18n import _ from trove.common import instance as rd_instance -from trove.common import pagination from trove.common.stream_codecs import IniCodec from trove.common.stream_codecs import PropertiesCodec from trove.common.stream_codecs import SafeYamlCodec @@ -840,9 +839,9 @@ class CassandraAdmin(object): List all non-superuser accounts. Omit names on the ignored list. Return an empty set if None. """ - users = [user.serialize() for user in - self._get_listed_users(self.client)] - return pagination.paginate_list(users, limit, marker, include_marker) + return guestagent_utils.serialize_list( + self._get_listed_users(self.client), + limit=limit, marker=marker, include_marker=include_marker) def _get_listed_users(self, client): """ @@ -1093,10 +1092,9 @@ class CassandraAdmin(object): def list_databases(self, context, limit=None, marker=None, include_marker=False): - databases = [keyspace.serialize() for keyspace - in self._get_available_keyspaces(self.client)] - return pagination.paginate_list(databases, limit, marker, - include_marker) + return guestagent_utils.serialize_list( + self._get_available_keyspaces(self.client), + limit=limit, marker=marker, include_marker=include_marker) def _get_available_keyspaces(self, client): """ diff --git a/trove/guestagent/datastore/experimental/mongodb/service.py b/trove/guestagent/datastore/experimental/mongodb/service.py index 6b4332147a..becffb2375 100644 --- a/trove/guestagent/datastore/experimental/mongodb/service.py +++ b/trove/guestagent/datastore/experimental/mongodb/service.py @@ -23,7 +23,6 @@ from trove.common import cfg from trove.common import exception from trove.common.i18n import _ from trove.common import instance as ds_instance -from trove.common import pagination from trove.common.stream_codecs import JsonCodec, SafeYamlCodec from trove.common import utils as utils from trove.guestagent.common.configuration import ConfigurationManager @@ -591,10 +590,11 @@ class MongoDBAdmin(object): user = models.MongoDBUser(name=user_info['_id']) user.roles = user_info['roles'] if self._is_modifiable_user(user.name): - users.append(user.serialize()) + users.append(user) LOG.debug('users = ' + str(users)) - return pagination.paginate_list(users, limit, marker, - include_marker) + return guestagent_utils.serialize_list( + users, + limit=limit, marker=marker, include_marker=include_marker) def change_passwords(self, users): with MongoDBClient(self._admin_user()) as admin_client: @@ -726,11 +726,12 @@ class MongoDBAdmin(object): for hidden in cfg.get_ignored_dbs(): if hidden in db_names: db_names.remove(hidden) - databases = [models.MongoDBSchema(db_name).serialize() + databases = [models.MongoDBSchema(db_name) for db_name in db_names] LOG.debug('databases = ' + str(databases)) - return pagination.paginate_list(databases, limit, marker, - include_marker) + return guestagent_utils.serialize_list( + databases, + limit=limit, marker=marker, include_marker=include_marker) def add_shard(self, url): """Runs the addShard command.""" diff --git a/trove/guestagent/datastore/experimental/postgresql/service/database.py b/trove/guestagent/datastore/experimental/postgresql/service/database.py index 1b174dcbe5..944236abbe 100644 --- a/trove/guestagent/datastore/experimental/postgresql/service/database.py +++ b/trove/guestagent/datastore/experimental/postgresql/service/database.py @@ -18,7 +18,7 @@ from oslo_log import log as logging from trove.common import cfg from trove.common.i18n import _ from trove.common.notification import EndNotification -from trove.common import pagination +from trove.guestagent.common import guestagent_utils from trove.guestagent.datastore.experimental.postgresql import pgutil from trove.guestagent.db import models @@ -97,9 +97,9 @@ class PgSqlDatabase(object): """List all databases on the instance. Return a paginated list of serialized Postgres databases. """ - page, next_name = pagination.paginate_object_list( - self._get_databases(), 'name', limit, marker, include_marker) - return [db.serialize() for db in page], next_name + return guestagent_utils.serialize_list( + self._get_databases(), + limit=limit, marker=marker, include_marker=include_marker) def _get_databases(self): """Return all non-system Postgres databases on the instance.""" diff --git a/trove/guestagent/datastore/experimental/postgresql/service/users.py b/trove/guestagent/datastore/experimental/postgresql/service/users.py index 8cfb864aae..81bcd1499d 100644 --- a/trove/guestagent/datastore/experimental/postgresql/service/users.py +++ b/trove/guestagent/datastore/experimental/postgresql/service/users.py @@ -19,8 +19,8 @@ from trove.common import cfg from trove.common import exception from trove.common.i18n import _ from trove.common.notification import EndNotification -from trove.common import pagination from trove.common import utils +from trove.guestagent.common import guestagent_utils from trove.guestagent.datastore.experimental.postgresql import pgutil from trove.guestagent.datastore.experimental.postgresql.service.access import ( PgSqlAccess) @@ -131,9 +131,9 @@ class PgSqlUsers(PgSqlAccess): """List all users on the instance along with their access permissions. Return a paginated list of serialized Postgres users. """ - page, next_name = pagination.paginate_object_list( - self._get_users(context), 'name', limit, marker, include_marker) - return [db.serialize() for db in page], next_name + return guestagent_utils.serialize_list( + self._get_users(context), + limit=limit, marker=marker, include_marker=include_marker) def _get_users(self, context): """Return all non-system Postgres users on the instance.""" diff --git a/trove/tests/unittests/guestagent/test_guestagent_utils.py b/trove/tests/unittests/guestagent/test_guestagent_utils.py index 16360b883d..3c599dc274 100644 --- a/trove/tests/unittests/guestagent/test_guestagent_utils.py +++ b/trove/tests/unittests/guestagent/test_guestagent_utils.py @@ -13,6 +13,10 @@ # License for the specific language governing permissions and limitations # under the License. +from mock import Mock +from mock import patch + +from trove.common import pagination from trove.guestagent.common import guestagent_utils from trove.tests.unittests import trove_testtools @@ -145,3 +149,29 @@ class TestGuestagentUtils(trove_testtools.TestCase): self.assertEqual('Hello!', guestagent_utils.to_bytes('Hello!')) self.assertEqual('', guestagent_utils.to_bytes('')) self.assertIsNone(guestagent_utils.to_bytes(None)) + + @patch.object(pagination, 'paginate_object_list') + def test_paginate_list(self, paginate_obj_mock): + limit = Mock() + marker = Mock() + include_marker = Mock() + test_list = [Mock(), Mock(), Mock()] + guestagent_utils.paginate_list( + test_list, + limit=limit, marker=marker, include_marker=include_marker) + paginate_obj_mock.assert_called_once_with( + test_list, 'name', + limit=limit, marker=marker, include_marker=include_marker) + + def test_serialize_list(self): + test_list = [Mock(), Mock(), Mock()] + with patch.object(guestagent_utils, 'paginate_list', + return_value=(test_list[:2], test_list[-2]) + ) as paginate_lst_mock: + _, next_name = guestagent_utils.serialize_list(test_list) + paginate_lst_mock.assert_called_once_with( + test_list, + limit=None, marker=None, include_marker=False) + for item in paginate_lst_mock.return_value[0]: + item.serialize.assert_called_once_with() + self.assertEqual(paginate_lst_mock.return_value[1], next_name) diff --git a/trove/tests/unittests/guestagent/test_mongodb_manager.py b/trove/tests/unittests/guestagent/test_mongodb_manager.py index f1b68e9f09..4fe4823965 100644 --- a/trove/tests/unittests/guestagent/test_mongodb_manager.py +++ b/trove/tests/unittests/guestagent/test_mongodb_manager.py @@ -241,7 +241,8 @@ class GuestAgentMongoDBManagerTest(DatastoreManagerTest): users, next_marker = self.manager.list_users(self.context) self.assertIsNone(next_marker) - self.assertEqual(sorted([user1, user2]), users) + self.assertEqual(sorted([user1, user2], key=lambda x: x['_name']), + users) @mock.patch.object(service.MongoDBAdmin, 'create_validated_user') @mock.patch.object(utils, 'generate_random_password', @@ -345,16 +346,14 @@ class GuestAgentMongoDBManagerTest(DatastoreManagerTest): 'db0', 'db1', 'db2', 'db3']) mocked_client().__enter__().database_names = mocked_list - marker = models.MongoDBSchema('db1').serialize() dbs, next_marker = self.manager.list_databases( - self.context, limit=2, marker=marker, include_marker=True) + self.context, limit=2, marker='db1', include_marker=True) mocked_list.assert_any_call() self.assertEqual([models.MongoDBSchema('db1').serialize(), models.MongoDBSchema('db2').serialize()], dbs) - self.assertEqual(models.MongoDBSchema('db2').serialize(), - next_marker) + self.assertEqual('db2', next_marker) @mock.patch.object(service, 'MongoDBClient') @mock.patch.object(service.MongoDBAdmin, '_admin_user')