diff --git a/test-requirements.txt b/test-requirements.txt index f2be8cbf12..d47c55e236 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -24,4 +24,4 @@ mox3>=0.7.0 testtools>=1.4.0 discover testrepository>=0.0.18 - +pymongo>=3.0.2 diff --git a/trove/common/cfg.py b/trove/common/cfg.py index 12008f08cf..31ba19efa4 100644 --- a/trove/common/cfg.py +++ b/trove/common/cfg.py @@ -741,6 +741,10 @@ mongodb_opts = [ help='Namespace to load restore strategies from.', deprecated_name='restore_namespace', deprecated_group='DEFAULT'), + cfg.IntOpt('mongodb_port', default=27017, + help='Port for mongod and mongos instances.'), + cfg.IntOpt('configsvr_port', default=27019, + help='Port for instances running as config servers.'), ] # PostgreSQL diff --git a/trove/common/strategies/cluster/experimental/mongodb/guestagent.py b/trove/common/strategies/cluster/experimental/mongodb/guestagent.py index 1ab7bdbbdb..058c260ca0 100644 --- a/trove/common/strategies/cluster/experimental/mongodb/guestagent.py +++ b/trove/common/strategies/cluster/experimental/mongodb/guestagent.py @@ -60,3 +60,8 @@ class MongoDbGuestAgentAPI(guest_api.API): LOG.debug("Notify regarding cluster install completion") return self._call("cluster_complete", guest_api.AGENT_LOW_TIMEOUT, self.version_cap) + + def get_key(self): + LOG.debug("Requesting cluster key from guest") + return self._call("get_key", guest_api.AGENT_LOW_TIMEOUT, + self.version_cap) diff --git a/trove/common/strategies/cluster/experimental/mongodb/taskmanager.py b/trove/common/strategies/cluster/experimental/mongodb/taskmanager.py index 56de1cd3ef..5ded75603e 100644 --- a/trove/common/strategies/cluster/experimental/mongodb/taskmanager.py +++ b/trove/common/strategies/cluster/experimental/mongodb/taskmanager.py @@ -18,6 +18,7 @@ from eventlet.timeout import Timeout from trove.common import cfg from trove.common.i18n import _ from trove.common.strategies.cluster import base +from trove.common import utils from trove.instance.models import DBInstance from trove.instance.models import Instance from trove.openstack.common import log as logging @@ -94,6 +95,9 @@ class MongoDbClusterTasks(task_models.ClusterTasks): return False return True + def get_key(self, member): + return self.get_guest(member).get_key() + def create_cluster(self, context, cluster_id): LOG.debug("begin create_cluster for id: %s" % cluster_id) @@ -108,6 +112,8 @@ class MongoDbClusterTasks(task_models.ClusterTasks): if not self._all_instances_ready(instance_ids, cluster_id): return + LOG.debug("all instances in cluster %s ready." % cluster_id) + instances = [Instance.load(context, instance_id) for instance_id in instance_ids] @@ -134,11 +140,23 @@ class MongoDbClusterTasks(task_models.ClusterTasks): for instance in config_servers] LOG.debug("config server ips: %s" % config_server_ips) - LOG.debug("calling add_config_servers on query_routers") + # Give the query routers the configsvr ips to connect to. + # Create the admin user on the query routers. + # The first will create the user, and the others will just reset + # the password to the same value. + LOG.debug("calling add_config_servers on, and sending admin user " + "password to, query_routers") try: + admin_created = False + admin_password = utils.generate_random_password() for query_router in query_routers: - (self.get_guest(query_router) - .add_config_servers(config_server_ips)) + guest = self.get_guest(query_router) + guest.add_config_servers(config_server_ips) + if admin_created: + guest.store_admin_password(admin_password) + else: + guest.create_admin_user(admin_password) + admin_created = True except Exception: LOG.exception(_("error adding config servers")) self.update_statuses_on_failure(cluster_id) diff --git a/trove/guestagent/common/operating_system.py b/trove/guestagent/common/operating_system.py index 4d0ee031ed..07eec3d7ac 100644 --- a/trove/guestagent/common/operating_system.py +++ b/trove/guestagent/common/operating_system.py @@ -143,6 +143,14 @@ class FileMode(object): def SET_GRP_RW_OTH_R(cls): return cls(reset=[stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH]) # =0064 + @classmethod + def SET_USR_RO(cls): + return cls(reset=[stat.S_IRUSR]) # =0400 + + @classmethod + def SET_USR_RW(cls): + return cls(reset=[stat.S_IRUSR | stat.S_IWUSR]) # =0600 + @classmethod def ADD_READ_ALL(cls): return cls(add=[stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH]) # +0444 diff --git a/trove/guestagent/datastore/experimental/mongodb/manager.py b/trove/guestagent/datastore/experimental/mongodb/manager.py index 0905472522..57c5ea1918 100644 --- a/trove/guestagent/datastore/experimental/mongodb/manager.py +++ b/trove/guestagent/datastore/experimental/mongodb/manager.py @@ -23,8 +23,7 @@ from trove.common.i18n import _ from trove.common import instance as ds_instance from trove.guestagent import backup from trove.guestagent.common import operating_system -from trove.guestagent.datastore.experimental.mongodb import ( - service as mongo_service) +from trove.guestagent.datastore.experimental.mongodb import service from trove.guestagent.datastore.experimental.mongodb import system from trove.guestagent import dbaas from trove.guestagent import volume @@ -40,8 +39,8 @@ MANAGER = CONF.datastore_manager class Manager(periodic_task.PeriodicTasks): def __init__(self): - self.status = mongo_service.MongoDbAppStatus() - self.app = mongo_service.MongoDBApp(self.status) + self.status = service.MongoDBAppStatus() + self.app = service.MongoDBApp(self.status) @periodic_task.periodic_task(ticks_between_runs=3) def update_status(self, context): @@ -80,6 +79,7 @@ class Manager(periodic_task.PeriodicTasks): LOG.debug("Mounted the volume %(path)s as %(mount)s." % {'path': device_path, "mount": mount_point}) + self.app.secure(cluster_config) conf_changes = self.get_config_changes(cluster_config, mount_point) config_contents = self.app.update_config_contents( config_contents, conf_changes) @@ -115,8 +115,13 @@ class Manager(periodic_task.PeriodicTasks): def get_config_changes(self, cluster_config, mount_point=None): LOG.debug("Getting configuration changes.") config_changes = {} + # todo mvandijk: uncomment the following when auth is being enabled + # config_changes['auth'] = 'true' + config_changes['bind_ip'] = ','.join([netutils.get_my_ipv4(), + '127.0.0.1']) if cluster_config is not None: - config_changes['bind_ip'] = netutils.get_my_ipv4() + # todo mvandijk: uncomment the following when auth is being enabled + # config_changes['keyFile'] = self.app.get_key_file() if cluster_config["instance_type"] == "config_server": config_changes["configsvr"] = "true" elif cluster_config["instance_type"] == "member": @@ -344,3 +349,14 @@ class Manager(periodic_task.PeriodicTasks): LOG.debug("Cluster creation complete, starting status checks.") status = self.status._get_actual_db_status() self.status.set_status(status) + + def get_key(self, context): + # Return the cluster key + LOG.debug("Getting the cluster key.") + return self.app.get_key() + + def create_admin_user(self, context, password): + self.app.create_admin_user(password) + + def store_admin_password(self, context, password): + self.app.store_admin_password(password) diff --git a/trove/guestagent/datastore/experimental/mongodb/service.py b/trove/guestagent/datastore/experimental/mongodb/service.py index 50ec2dbba9..850b6f55dc 100644 --- a/trove/guestagent/datastore/experimental/mongodb/service.py +++ b/trove/guestagent/datastore/experimental/mongodb/service.py @@ -16,8 +16,10 @@ import json import os import re +import tempfile from oslo_utils import netutils +import pymongo from trove.common import cfg from trove.common import exception @@ -28,12 +30,16 @@ from trove.common import utils as utils from trove.guestagent.common import operating_system from trove.guestagent.datastore.experimental.mongodb import system from trove.guestagent.datastore import service +from trove.guestagent.db import models from trove.openstack.common import log as logging + LOG = logging.getLogger(__name__) CONF = cfg.CONF CONFIG_FILE = (operating_system. file_discovery(system.CONFIG_CANDIDATES)) +MONGODB_PORT = CONF.mongodb.mongodb_port +CONFIGSVR_PORT = CONF.mongodb.configsvr_port class MongoDBApp(object): @@ -221,7 +227,8 @@ class MongoDBApp(object): This method is used by query router (mongos) instances. """ config_contents = self._read_config() - configdb_contents = ','.join(['%s:27019' % host + configdb_contents = ','.join(['%(host)s:%(port)s' + % {'host': host, 'port': CONFIGSVR_PORT} for host in config_server_hosts]) LOG.debug("Config server list %s." % configdb_contents) # remove db path from config and update configdb @@ -251,44 +258,29 @@ class MongoDBApp(object): operating_system.remove('/etc/init/mongodb.conf', force=True, as_root=True) - def do_mongo(self, db_cmd): - cmd = ('mongo --host ' + netutils.get_my_ipv4() + - ' --quiet --eval \'printjson(%s)\'' % db_cmd) - # TODO(ramashri) see if hardcoded values can be removed - out, err = utils.execute_with_timeout(cmd, shell=True, timeout=100) - LOG.debug(out.strip()) - return (out, err) - def add_shard(self, replica_set_name, replica_set_member): """ This method is used by query router (mongos) instances. """ - cmd = 'db.adminCommand({addShard: "%s/%s:27017"})' % ( - replica_set_name, replica_set_member) - self.do_mongo(cmd) + url = "%(rs)s/%(host)s:%(port)s"\ + % {'rs': replica_set_name, + 'host': replica_set_member, + 'port': MONGODB_PORT} + MongoDBAdmin().add_shard(url) def add_members(self, members): """ This method is used by a replica-set member instance. """ - def clean_json(val): - """ - This method removes from json, values that are functions like - ISODate(), TimeStamp(). - """ - return re.sub(':\s*\w+\(\"?(.*?)\"?\)', r': "\1"', val) - def check_initiate_status(): """ This method is used to verify replica-set status. """ - out, err = self.do_mongo("rs.status()") - response = clean_json(out.strip()) - json_data = json.loads(response) + status = MongoDBAdmin().get_repl_status() - if((json_data["ok"] == 1) and - (json_data["members"][0]["stateStr"] == "PRIMARY") and - (json_data["myState"] == 1)): + if((status["ok"] == 1) and + (status["members"][0]["stateStr"] == "PRIMARY") and + (status["myState"] == 1)): return True else: return False @@ -297,16 +289,14 @@ class MongoDBApp(object): """ This method is used to verify replica-set status. """ - out, err = self.do_mongo("rs.status()") - response = clean_json(out.strip()) - json_data = json.loads(response) + status = MongoDBAdmin().get_repl_status() primary_count = 0 - if json_data["ok"] != 1: + if status["ok"] != 1: return False - if len(json_data["members"]) != (len(members) + 1): + if len(status["members"]) != (len(members) + 1): return False - for rs_member in json_data["members"]: + for rs_member in status["members"]: if rs_member["state"] not in [1, 2, 7]: return False if rs_member["health"] != 1: @@ -316,26 +306,89 @@ class MongoDBApp(object): return primary_count == 1 + # Create the admin user on this member. + # This is only necessary for setting up the replica set. + # The query router will handle requests once this set + # is added as a shard. + password = utils.generate_random_password() + self.create_admin_user(password) + # initiate replica-set - self.do_mongo("rs.initiate()") + MongoDBAdmin().rs_initiate() # TODO(ramashri) see if hardcoded values can be removed utils.poll_until(check_initiate_status, sleep_time=60, time_out=100) # add replica-set members - for member in members: - self.do_mongo('rs.add("' + member + '")') + MongoDBAdmin().rs_add_members(members) # TODO(ramashri) see if hardcoded values can be removed utils.poll_until(check_rs_status, sleep_time=60, time_out=100) - def list_databases(self): - cmd = 'db.adminCommand("listDatabases").databases' - out, err = self.do_mongo(cmd) - out.strip() - dbs = json.loads(out) - return [d['name'] for d in dbs] + def list_all_dbs(self): + return MongoDBAdmin().list_database_names() + + def db_data_size(self, db_name): + schema = models.MongoDBSchema(db_name) + return MongoDBAdmin().db_stats(schema.serialize())['dataSize'] + + def admin_cmd_auth_params(self): + return MongoDBAdmin().cmd_admin_auth_params + + def get_key_file(self): + return system.MONGO_KEY_FILE + + def get_key(self): + return open(system.MONGO_KEY_FILE).read().rstrip() + + def store_key(self, key): + """Store the cluster key.""" + LOG.debug('Storing key for MongoDB cluster.') + with tempfile.NamedTemporaryFile() as f: + f.write(key) + f.flush() + operating_system.copy(f.name, system.MONGO_KEY_FILE, + force=True, as_root=True) + operating_system.chmod(system.MONGO_KEY_FILE, + operating_system.FileMode.SET_USR_RO, + as_root=True) + operating_system.chown(system.MONGO_KEY_FILE, + system.MONGO_USER, system.MONGO_USER, + as_root=True) + + def store_admin_password(self, password): + LOG.debug('Storing admin password.') + creds = MongoDBCredentials(username=system.MONGO_ADMIN_NAME, + password=password) + creds.write(system.MONGO_ADMIN_CREDS_FILE) + return creds + + def create_admin_user(self, password): + """Create the admin user while the localhost exception is active.""" + LOG.debug('Creating the admin user.') + creds = self.store_admin_password(password) + user = models.MongoDBUser(name='admin.%s' % creds.username, + password=creds.password) + user.roles = system.MONGO_ADMIN_ROLES + user.databases = 'admin' + with MongoDBClient(user, auth=False) as client: + MongoDBAdmin().create_user(user, client=client) + LOG.debug('Created admin user.') + + def secure(self, cluster_config=None): + # Secure the server by storing the cluster key if this is a cluster + # or creating the admin user if this is a single instance. + LOG.debug('Securing MongoDB instance.') + if cluster_config: + self.store_key(cluster_config['key']) + else: + LOG.debug('Generating admin password.') + password = utils.generate_random_password() + self.start_db() + self.create_admin_user(password) + self.stop_db() + LOG.debug('MongoDB secure complete.') -class MongoDbAppStatus(service.BaseDbStatus): +class MongoDBAppStatus(service.BaseDbStatus): is_config_server = None is_query_router = None @@ -367,12 +420,13 @@ class MongoDbAppStatus(service.BaseDbStatus): if self._is_config_server() is True: status_check = (system.CMD_STATUS % (netutils.get_my_ipv4() + - ' --port 27019')) + ' --port %s' % CONFIGSVR_PORT)) else: status_check = (system.CMD_STATUS % netutils.get_my_ipv4()) - out, err = utils.execute_with_timeout(status_check, shell=True) + out, err = utils.execute_with_timeout(status_check, shell=True, + check_exit_code=[0, 1]) if not err: return ds_instance.ServiceStatuses.RUNNING else: @@ -383,3 +437,171 @@ class MongoDbAppStatus(service.BaseDbStatus): except OSError as e: LOG.exception(_("OS Error %s.") % e) return ds_instance.ServiceStatuses.SHUTDOWN + + +class MongoDBAdmin(object): + """Handles administrative tasks on MongoDB.""" + + # user is cached by making it a class attribute + admin_user = None + + def _admin_user(self): + if not type(self).admin_user: + creds = MongoDBCredentials() + creds.read(system.MONGO_ADMIN_CREDS_FILE) + user = models.MongoDBUser( + 'admin.%s' % creds.username, + creds.password + ) + user.databases = 'admin' + type(self).admin_user = user + return type(self).admin_user + + @property + def cmd_admin_auth_params(self): + """Returns a list of strings that constitute MongoDB command line + authentication parameters. + """ + user = self._admin_user() + return ['--username', user.username, + '--password', user.password, + '--authenticationDatabase', user.database.name] + + def _create_user_with_client(self, user, client): + """Run the create user.""" + client[user.database.name].add_user( + user.username, password=user.password, roles=user.roles + ) + + def create_user(self, user, client=None): + """Creates a user, authenticated on the specified database.""" + if client: + self._create_user_with_client(user, client) + else: + with MongoDBClient(self._admin_user()) as admin_client: + self._create_user_with_client(user, admin_client) + + def list_database_names(self): + """Get the list of database names.""" + with MongoDBClient(self._admin_user()) as admin_client: + return admin_client.database_names() + + def add_shard(self, url): + """Runs the addShard command.""" + with MongoDBClient(self._admin_user()) as admin_client: + admin_client.admin.command({'addShard': url}) + + def get_repl_status(self): + """Runs the replSetGetStatus command.""" + with MongoDBClient(self._admin_user()) as admin_client: + return admin_client.admin.command('replSetGetStatus') + + def rs_initiate(self): + """Runs the replSetInitiate command.""" + with MongoDBClient(self._admin_user()) as admin_client: + return admin_client.admin.command('replSetInitiate') + + def rs_add_members(self, members): + """Adds the given members to the replication set.""" + with MongoDBClient(self._admin_user()) as admin_client: + # get the current config, add the new members, then save it + config = admin_client.admin.command('replSetGetConfig')['config'] + config['version'] += 1 + next_id = max([m['_id'] for m in config['members']]) + 1 + for member in members: + config['members'].append({'_id': next_id, 'host': member}) + next_id += 1 + admin_client.admin.command('replSetReconfig', config) + + def db_stats(self, database, scale=1): + """Gets the stats for the given database.""" + with MongoDBClient(self._admin_user()) as admin_client: + db_name = models.MongoDBSchema.deserialize_schema(database).name + return admin_client[db_name].command('dbStats', scale=scale) + + +class MongoDBClient(object): + """A wrapper to manage a MongoDB connection.""" + + # engine information is cached by making it a class attribute + engine = {} + + def __init__(self, user, host=None, port=None, + auth=True): + """Get the client. Specifying host and/or port updates cached values. + :param user: (required) MongoDBUser instance + :param host: server address, defaults to localhost + :param port: server port, defaults to 27017 + :param auth: set to False to disable authentication, default True + :return: + """ + new_client = False + self._logged_in = False + if not type(self).engine: + # no engine cached + type(self).engine['host'] = (host if host else 'localhost') + type(self).engine['port'] = (port if port else MONGODB_PORT) + new_client = True + elif host or port: + LOG.debug("Updating MongoDB client.") + if host: + type(self).engine['host'] = host + if port: + type(self).engine['host'] = port + new_client = True + if new_client: + host = type(self).engine['host'] + port = type(self).engine['port'] + LOG.debug("Creating MongoDB client to %(host)s:%(port)s." + % {'host': host, 'port': port}) + type(self).engine['client'] = pymongo.MongoClient(host=host, + port=port, + connect=False) + self.session = type(self).engine['client'] + if auth: + db_name = user.database.name + LOG.debug("Authentication MongoDB client on %s." % db_name) + self._db = self.session[db_name] + self._db.authenticate(user.username, password=user.password) + self._logged_in = True + + def __enter__(self): + return self.session + + def __exit__(self, exc_type, exc_value, traceback): + LOG.debug("Disconnecting from MongoDB.") + if self._logged_in: + self._db.logout() + self.session.close() + + +class MongoDBCredentials(object): + """Handles storing/retrieving credentials. Stored as json in files.""" + + def __init__(self, username=None, password=None): + self.username = username + self.password = password + + def read(self, filename): + with open(filename) as f: + credentials = json.load(f) + self.username = credentials['username'] + self.password = credentials['password'] + + def write(self, filename): + self.clear_file(filename) + with open(filename, 'w') as f: + credentials = {'username': self.username, + 'password': self.password} + json.dump(credentials, f) + + @staticmethod + def clear_file(filename): + LOG.debug("Creating clean file %s" % filename) + if operating_system.file_discovery([filename]): + operating_system.remove(filename) + # force file creation by just opening it + open(filename, 'wb') + operating_system.chmod(filename, + operating_system.FileMode.SET_USR_RW, + as_root=True) diff --git a/trove/guestagent/datastore/experimental/mongodb/system.py b/trove/guestagent/datastore/experimental/mongodb/system.py index 2b0bf1e3ac..81e00d0f03 100644 --- a/trove/guestagent/datastore/experimental/mongodb/system.py +++ b/trove/guestagent/datastore/experimental/mongodb/system.py @@ -13,6 +13,8 @@ # License for the specific language governing permissions and limitations # under the License. +from os import path + from trove.guestagent.common import operating_system from trove.guestagent import pkg @@ -27,6 +29,13 @@ TMP_CONFIG = "/tmp/mongodb.conf.tmp" CONFIG_CANDIDATES = ["/etc/mongodb.conf", "/etc/mongod.conf"] MONGOS_UPSTART = "/etc/init/mongos.conf" TMP_MONGOS_UPSTART = "/tmp/mongos.conf.tmp" +MONGO_ADMIN_NAME = 'os_admin' +MONGO_ADMIN_ROLES = ['userAdminAnyDatabase', + 'dbAdminAnyDatabase', + 'clusterAdmin'] +MONGO_ADMIN_CREDS_FILE = path.join(path.expanduser('~'), + '.os_mongo_admin_creds.json') +MONGO_KEY_FILE = '/etc/mongo_key' MONGOS_SERVICE_CANDIDATES = ["mongos"] MONGOD_SERVICE_CANDIDATES = ["mongodb", "mongod"] MONGODB_KILL = "sudo kill %s" diff --git a/trove/guestagent/db/models.py b/trove/guestagent/db/models.py index f3290208ff..eab399effe 100644 --- a/trove/guestagent/db/models.py +++ b/trove/guestagent/db/models.py @@ -13,12 +13,14 @@ # License for the specific language governing permissions and limitations # under the License. +import abc import re import string import netaddr from trove.common import cfg +from trove.common import exception from trove.common.i18n import _ CONF = cfg.CONF @@ -31,6 +33,129 @@ class Base(object): def deserialize(self, o): self.__dict__ = o + @classmethod + def _validate_dict(cls, value): + reqs = cls._dict_requirements() + return (isinstance(value, dict) and + all(key in value for key in reqs)) + + @classmethod + @abc.abstractmethod + def _dict_requirements(cls): + """Get the dictionary requirements for a user created via + deserialization. + :returns: List of required dictionary keys. + """ + + +class DatastoreSchema(Base): + """Represents a database schema.""" + + def __init__(self): + self._name = None + self._collate = None + self._character_set = None + + @classmethod + def deserialize_schema(cls, value): + if not cls._validate_dict(value): + raise ValueError(_("Bad dictionary. Keys: %(keys)s. " + "Required: %(reqs)s") + % ({'keys': value.keys(), + 'reqs': cls._dict_requirements()})) + + schema = cls(deserializing=True) + schema.deserialize(value) + return schema + + @property + def name(self): + return self._name + + @name.setter + def name(self, value): + self._validate_schema_name(value) + self._name = value + + @property + def collate(self): + return self._collate + + @property + def character_set(self): + return self._character_set + + def _validate_schema_name(self, value): + """Perform validations on a given schema name. + :param value: Validated schema name. + :type value: string + :raises: ValueError On validation errors. + """ + if self._max_schema_name_length and (len(value) > + self._max_schema_name_length): + raise ValueError(_("Schema name '%(name)s' is too long. " + "Max length = %(max_length)d.") + % {'name': value, + 'max_length': self._max_schema_name_length}) + elif not self._is_valid_schema_name(value): + raise ValueError(_("'%s' is not a valid schema name.") % value) + + @abc.abstractproperty + def _max_schema_name_length(self): + """Return the maximum valid schema name length if any. + :returns: Maximum schema name length or None if unlimited. + """ + + @abc.abstractmethod + def _is_valid_schema_name(self, value): + """Validate a given schema name. + :param value: Validated schema name. + :type value: string + :returns: TRUE if valid, FALSE otherwise. + """ + + @classmethod + @abc.abstractmethod + def _dict_requirements(cls): + """Get the dictionary requirements for a user created via + deserialization. + :returns: List of required dictionary keys. + """ + + +class MongoDBSchema(DatastoreSchema): + """Represents the MongoDB schema and its associated properties. + + MongoDB database names are limited to 128 characters, + alphanumeric and - and _ only. + """ + + name_regex = re.compile(r'^[a-zA-Z0-9_\-]+$') + + def __init__(self, name=None, deserializing=False): + super(MongoDBSchema, self).__init__() + # need one or the other, not both, not none (!= ~ XOR) + if not (bool(deserializing) != bool(name)): + raise ValueError(_("Bad args. name: %(name)s, " + "deserializing %(deser)s.") + % ({'name': bool(name), + 'deser': bool(deserializing)})) + if not deserializing: + self.name = name + + @property + def _max_schema_name_length(self): + return 64 + + def _is_valid_schema_name(self, value): + # check against the invalid character set from + # http://docs.mongodb.org/manual/reference/limits + return not any(c in value for c in '/\. "$') + + @classmethod + def _dict_requirements(cls): + return ['_name'] + class MySQLDatabase(Base): """Represents a Database and its properties.""" @@ -349,6 +474,265 @@ class ValidatedMySQLDatabase(MySQLDatabase): self._name = value +class DatastoreUser(Base): + """Represents a datastore user.""" + + _HOSTNAME_WILDCARD = '%' + + def __init__(self): + self._name = None + self._password = None + self._host = None + self._databases = [] + + @classmethod + def deserialize_user(cls, value): + if not cls._validate_dict(value): + raise ValueError(_("Bad dictionary. Keys: %(keys)s. " + "Required: %(reqs)s") + % ({'keys': value.keys(), + 'reqs': cls._dict_requirements()})) + user = cls(deserializing=True) + user.deserialize(value) + return user + + @property + def name(self): + return self._name + + @name.setter + def name(self, value): + self._validate_user_name(value) + self._name = value + + @property + def password(self): + return self._password + + @password.setter + def password(self, value): + if self._is_valid_password(value): + self._password = value + else: + raise ValueError(_("'%s' is not a valid password.") % value) + + @property + def databases(self): + return self._databases + + @databases.setter + def databases(self, value): + mydb = self._build_database_schema(value) + self._databases.append(mydb.serialize()) + + @property + def host(self): + if self._host is None: + return self._HOSTNAME_WILDCARD + return self._host + + @host.setter + def host(self, value): + if self._is_valid_host_name(value): + self._host = value + else: + raise ValueError(_("'%s' is not a valid hostname.") % value) + + @abc.abstractmethod + def _build_database_schema(self, name): + """Build a schema for this user. + :type name: string + :type character_set: string + :type collate: string + """ + + def _validate_user_name(self, value): + """Perform validations on a given user name. + :param value: Validated user name. + :type value: string + :raises: ValueError On validation errors. + """ + if self._max_username_length and (len(value) > + self._max_username_length): + raise ValueError(_("User name '%(name)s' is too long. " + "Max length = %(max_length)d.") + % {'name': value, + 'max_length': self._max_username_length}) + elif not self._is_valid_name(value): + raise ValueError(_("'%s' is not a valid user name.") % value) + + @abc.abstractproperty + def _max_username_length(self): + """Return the maximum valid user name length if any. + :returns: Maximum user name length or None if unlimited. + """ + + @abc.abstractmethod + def _is_valid_name(self, value): + """Validate a given user name. + :param value: User name to be validated. + :type value: string + :returns: TRUE if valid, FALSE otherwise. + """ + + @abc.abstractmethod + def _is_valid_host_name(self, value): + """Validate a given host name. + :param value: Host name to be validated. + :type value: string + :returns: TRUE if valid, FALSE otherwise. + """ + + @abc.abstractmethod + def _is_valid_password(self, value): + """Validate a given password. + :param value: Password to be validated. + :type value: string + :returns: TRUE if valid, FALSE otherwise. + """ + + @classmethod + @abc.abstractmethod + def _dict_requirements(cls): + """Get the dictionary requirements for a user created via + deserialization. + :returns: List of required dictionary keys. + """ + + +class MongoDBUser(DatastoreUser): + """Represents a MongoDB user and its associated properties. + MongoDB users are identified using their namd and database. + Trove stores this as . + """ + + def __init__(self, name=None, password=None, deserializing=False): + super(MongoDBUser, self).__init__() + self._name = None + self._username = None + self._database = None + self._roles = [] + # need only one of: deserializing, name, or (name and password) + if ((not (bool(deserializing) != bool(name))) or + (bool(deserializing) and bool(password))): + raise ValueError(_("Bad args. name: %(name)s, " + "password %(pass)s, " + "deserializing %(deser)s.") + % ({'name': bool(name), + 'pass': bool(password), + 'deser': bool(deserializing)})) + if not deserializing: + self.name = name + self.password = password + + @property + def username(self): + return self._username + + @username.setter + def username(self, value): + self._update_name(username=value) + + @property + def database(self): + return MongoDBSchema.deserialize_schema(self._database) + + @database.setter + def database(self, value): + self._update_name(database=value) + + @property + def name(self): + return self._name + + @name.setter + def name(self, value): + self._update_name(name=value) + + def _update_name(self, name=None, username=None, database=None): + """Keep the name, username, and database values in sync.""" + if name: + (database, username) = self._parse_name(name) + if not (database and username): + missing = 'username' if self.database else 'database' + raise ValueError(_("MongoDB user's name missing %s.") + % missing) + else: + if username: + if not self.database: + raise ValueError(_('MongoDB user missing database.')) + database = self.database.name + else: # database + if not self.username: + raise ValueError(_('MongoDB user missing username.')) + username = self.username + name = '%s.%s' % (database, username) + self._name = name + self._username = username + self._database = self._build_database_schema(database).serialize() + + @property + def roles(self): + return self._roles + + @roles.setter + def roles(self, value): + if isinstance(value, list): + for role in value: + self._add_role(role) + else: + self._add_role(value) + + def _init_roles(self): + if '_roles' not in self.__dict__: + self._roles = [] + + @classmethod + def deserialize_user(cls, value): + user = super(MongoDBUser, cls).deserialize_user(value) + user.name = user._name + user._init_roles() + return user + + def _build_database_schema(self, name): + return MongoDBSchema(name) + + @staticmethod + def _parse_name(value): + """The name will be ., so split it.""" + parts = value.split('.', 1) + if len(parts) != 2: + raise exception.BadRequest(_( + 'MongoDB user name "%s" not in . format.' + ) % value) + return parts[0], parts[1] + + @property + def _max_username_length(self): + return None + + def _is_valid_name(self, value): + return True + + def _is_valid_host_name(self, value): + return True + + def _is_valid_password(self, value): + return True + + def _add_role(self, value): + if not self._is_valid_role(value): + raise ValueError(_('Role %s is invalid.') % value) + self._roles.append(value) + + def _is_valid_role(self, value): + return isinstance(value, dict) or isinstance(value, str) + + @classmethod + def _dict_requirements(cls): + return ['_name'] + + class MySQLUser(Base): """Represents a MySQL User and its associated properties.""" diff --git a/trove/guestagent/strategies/backup/experimental/mongo_impl.py b/trove/guestagent/strategies/backup/experimental/mongo_impl.py index fd2f251ce0..a047033f43 100644 --- a/trove/guestagent/strategies/backup/experimental/mongo_impl.py +++ b/trove/guestagent/strategies/backup/experimental/mongo_impl.py @@ -42,8 +42,9 @@ class MongoDump(base.BackupRunner): backup_cmd = 'mongodump --out ' + MONGO_DUMP_DIR def __init__(self, *args, **kwargs): - self.status = mongo_service.MongoDbAppStatus() + self.status = mongo_service.MongoDBAppStatus() self.app = mongo_service.MongoDBApp(self.status) + self.admin = mongo_service.MongoDBApp(self.status) super(MongoDump, self).__init__(*args, **kwargs) def _run_pre_backup(self): @@ -66,9 +67,12 @@ class MongoDump(base.BackupRunner): "nogroup", as_root=True) # high timeout here since mongodump can take a long time - utils.execute_with_timeout(self.backup_cmd, shell=True, - run_as_root=True, root_helper='sudo', - timeout=LARGE_TIMEOUT) + utils.execute_with_timeout( + 'mongodump', '--out', MONGO_DUMP_DIR, + *(self.app.admin_cmd_auth_params()), + run_as_root=True, root_helper='sudo', + timeout=LARGE_TIMEOUT + ) except exception.ProcessExecutionError as e: LOG.debug("Caught exception when creating the dump") self.cleanup() @@ -94,12 +98,12 @@ class MongoDump(base.BackupRunner): db.stats().dataSize. This seems to be conservative, as the actual bson output in many cases is a fair bit smaller. """ - dbstats_cmd = 'db.getSiblingDB("%s").stats().dataSize' - dbs = self.app.list_databases() + dbs = self.app.list_all_dbs() + # mongodump does not dump the content of the local database + dbs.remove('local') dbstats = dict([(d, 0) for d in dbs]) for d in dbstats: - out, err = self.app.do_mongo(dbstats_cmd % d) - dbstats[d] = int(out) + dbstats[d] = self.app.db_data_size(d) LOG.debug("Estimated size for databases: " + str(dbstats)) return sum(dbstats.values()) diff --git a/trove/guestagent/strategies/restore/experimental/mongo_impl.py b/trove/guestagent/strategies/restore/experimental/mongo_impl.py index d04e548e53..27faaf1c22 100644 --- a/trove/guestagent/strategies/restore/experimental/mongo_impl.py +++ b/trove/guestagent/strategies/restore/experimental/mongo_impl.py @@ -38,14 +38,16 @@ class MongoDump(base.RestoreRunner): def __init__(self, *args, **kwargs): super(MongoDump, self).__init__(*args, **kwargs) - self.status = mongo_service.MongoDbAppStatus() + self.status = mongo_service.MongoDBAppStatus() self.app = mongo_service.MongoDBApp(self.status) def post_restore(self): """ Restore from the directory that we untarred into """ - utils.execute_with_timeout("mongorestore", MONGO_DUMP_DIR, + params = self.app.admin_cmd_auth_params() + params.append(MONGO_DUMP_DIR) + utils.execute_with_timeout('mongorestore', *params, timeout=LARGE_TIMEOUT) operating_system.remove(MONGO_DUMP_DIR, force=True, as_root=True) diff --git a/trove/tests/unittests/guestagent/test_mongodb_cluster_manager.py b/trove/tests/unittests/guestagent/test_mongodb_cluster_manager.py index 2dbf023c7f..dd09561852 100644 --- a/trove/tests/unittests/guestagent/test_mongodb_cluster_manager.py +++ b/trove/tests/unittests/guestagent/test_mongodb_cluster_manager.py @@ -13,63 +13,73 @@ # License for the specific language governing permissions and limitations # under the License. -from mock import patch +import mock from oslo_utils import netutils -import testtools +import pymongo -from trove.common.context import TroveContext -from trove.common import instance as ds_instance -from trove.common import utils -from trove.guestagent.datastore.experimental.mongodb import ( - service as mongo_service) -from trove.guestagent.datastore.experimental.mongodb.manager import Manager -from trove.guestagent.datastore.experimental.mongodb.service import MongoDBApp -from trove.guestagent import volume +import trove.common.context as context +import trove.common.instance as ds_instance +import trove.common.utils as utils +import trove.guestagent.datastore.experimental.mongodb.manager as manager +import trove.guestagent.datastore.experimental.mongodb.service as service +import trove.guestagent.volume as volume +import trove.tests.unittests.trove_testtools as trove_testtools -class GuestAgentMongoDBClusterManagerTest(testtools.TestCase): +class GuestAgentMongoDBClusterManagerTest(trove_testtools.TestCase): def setUp(self): super(GuestAgentMongoDBClusterManagerTest, self).setUp() - self.context = TroveContext() - self.manager = Manager() + self.context = context.TroveContext() + self.manager = manager.Manager() + + self.pymongo_patch = mock.patch.object( + pymongo, 'MongoClient' + ) + self.addCleanup(self.pymongo_patch.stop) + self.pymongo_patch.start() def tearDown(self): super(GuestAgentMongoDBClusterManagerTest, self).tearDown() - @patch.object(mongo_service.MongoDbAppStatus, 'set_status') - @patch.object(MongoDBApp, 'add_members', side_effect=RuntimeError("Boom!")) + @mock.patch.object(service.MongoDBAppStatus, 'set_status') + @mock.patch.object(service.MongoDBApp, 'add_members', + side_effect=RuntimeError("Boom!")) def test_add_members_failure(self, mock_add_members, mock_set_status): members = ["test1", "test2"] self.assertRaises(RuntimeError, self.manager.add_members, self.context, members) mock_set_status.assert_called_with(ds_instance.ServiceStatuses.FAILED) - @patch.object(utils, 'poll_until') - @patch.object(MongoDBApp, 'do_mongo') - def test_add_member(self, mock_do_mongo, mock_poll): + @mock.patch.object(utils, 'poll_until') + @mock.patch.object(utils, 'generate_random_password', return_value='pwd') + @mock.patch.object(service.MongoDBApp, 'create_admin_user') + @mock.patch.object(service.MongoDBAdmin, 'rs_initiate') + @mock.patch.object(service.MongoDBAdmin, 'rs_add_members') + def test_add_member(self, mock_add, mock_initiate, + mock_user, mock_pwd, mock_poll): members = ["test1", "test2"] self.manager.add_members(self.context, members) - mock_do_mongo.assert_any_call("rs.initiate()") - mock_do_mongo.assert_any_call("rs.add(\"test1\")") - mock_do_mongo.assert_any_call("rs.add(\"test2\")") + mock_user.assert_any_call('pwd') + mock_initiate.assert_any_call() + mock_add.assert_any_call(["test1", "test2"]) - @patch.object(mongo_service.MongoDbAppStatus, 'set_status') - @patch.object(MongoDBApp, 'add_shard', side_effect=RuntimeError("Boom!")) + @mock.patch.object(service.MongoDBAppStatus, 'set_status') + @mock.patch.object(service.MongoDBApp, 'add_shard', + side_effect=RuntimeError("Boom!")) def test_add_shard_failure(self, mock_add_shard, mock_set_status): self.assertRaises(RuntimeError, self.manager.add_shard, self.context, "rs", "rs_member") mock_set_status.assert_called_with(ds_instance.ServiceStatuses.FAILED) - @patch.object(MongoDBApp, 'do_mongo') - def test_add_shard(self, mock_do_mongo): + @mock.patch.object(service.MongoDBAdmin, 'add_shard') + def test_add_shard(self, mock_add_shard): self.manager.add_shard(self.context, "rs", "rs_member") - mock_do_mongo.assert_called_with( - "db.adminCommand({addShard: \"rs/rs_member:27017\"})") + mock_add_shard.assert_called_with("rs/rs_member:27017") - @patch.object(mongo_service.MongoDbAppStatus, 'set_status') - @patch.object(MongoDBApp, 'add_config_servers', - side_effect=RuntimeError("Boom!")) + @mock.patch.object(service.MongoDBAppStatus, 'set_status') + @mock.patch.object(service.MongoDBApp, 'add_config_servers', + side_effect=RuntimeError("Boom!")) def test_add_config_server_failure(self, mock_add_config, mock_set_status): self.assertRaises(RuntimeError, self.manager.add_config_servers, @@ -77,10 +87,12 @@ class GuestAgentMongoDBClusterManagerTest(testtools.TestCase): ["cfg_server1", "cfg_server2"]) mock_set_status.assert_called_with(ds_instance.ServiceStatuses.FAILED) - @patch.object(MongoDBApp, 'start_db_with_conf_changes') - @patch.object(MongoDBApp, '_add_config_parameter', return_value="") - @patch.object(MongoDBApp, '_delete_config_parameters', return_value="") - @patch.object(MongoDBApp, '_read_config', return_value="") + @mock.patch.object(service.MongoDBApp, 'start_db_with_conf_changes') + @mock.patch.object(service.MongoDBApp, '_add_config_parameter', + return_value="") + @mock.patch.object(service.MongoDBApp, '_delete_config_parameters', + return_value="") + @mock.patch.object(service.MongoDBApp, '_read_config', return_value="") def test_add_config_servers(self, mock_read, mock_delete, mock_add, mock_start): self.manager.add_config_servers(self.context, @@ -94,64 +106,101 @@ class GuestAgentMongoDBClusterManagerTest(testtools.TestCase): "cfg_server1:27019,cfg_server2:27019") mock_start.assert_called_with("") - @patch.object(mongo_service.MongoDbAppStatus, 'set_status') - @patch.object(MongoDBApp, 'write_mongos_upstart') - @patch.object(MongoDBApp, 'reset_configuration') - @patch.object(MongoDBApp, 'update_config_contents') - @patch.object(netutils, 'get_my_ipv4', return_value="10.0.0.2") - def test_prepare_mongos(self, mock_ip_address, mock_update, mock_reset, + @mock.patch.object(service.MongoDBAppStatus, 'set_status') + @mock.patch.object(service.MongoDBApp, 'write_mongos_upstart') + @mock.patch.object(service.MongoDBApp, 'reset_configuration') + @mock.patch.object(service.MongoDBApp, 'update_config_contents') + @mock.patch.object(service.MongoDBApp, 'secure') + @mock.patch.object(service.MongoDBApp, 'get_key_file', + return_value="/test/key/file") + @mock.patch.object(netutils, 'get_my_ipv4', return_value="10.0.0.2") + def test_prepare_mongos(self, mock_ip_address, mock_key_file, + mock_secure, mock_update, mock_reset, mock_upstart, mock_set_status): - self._prepare_method("test-id-1", "query_router") - mock_update.assert_called_with(None, {'bind_ip': '10.0.0.2'}) + self._prepare_method("test-id-1", "query_router", None) + mock_update.assert_called_with(None, {'bind_ip': '10.0.0.2,127.0.0.1', + # 'keyFile': '/test/key/file'}) + }) self.assertTrue(self.manager.app.status.is_query_router) mock_set_status.assert_called_with( ds_instance.ServiceStatuses.BUILD_PENDING) - @patch.object(mongo_service.MongoDbAppStatus, 'set_status') - @patch.object(utils, 'poll_until') - @patch.object(MongoDBApp, 'start_db_with_conf_changes') - @patch.object(MongoDBApp, 'update_config_contents') - @patch.object(netutils, 'get_my_ipv4', return_value="10.0.0.3") - def test_prepare_config_server(self, mock_ip_address, mock_update, - mock_start, mock_poll, mock_set_status): - self._prepare_method("test-id-2", "config_server") + @mock.patch.object(service.MongoDBAppStatus, 'set_status') + @mock.patch.object(utils, 'poll_until') + @mock.patch.object(service.MongoDBApp, 'start_db_with_conf_changes') + @mock.patch.object(service.MongoDBApp, 'update_config_contents') + @mock.patch.object(service.MongoDBApp, 'secure') + @mock.patch.object(service.MongoDBApp, 'get_key_file', + return_value="/test/key/file") + @mock.patch.object(netutils, 'get_my_ipv4', return_value="10.0.0.3") + def test_prepare_config_server(self, mock_ip_address, mock_key_file, + mock_secure, mock_update, mock_start, + mock_poll, mock_set_status): + self._prepare_method("test-id-2", "config_server", None) mock_update.assert_called_with(None, {'configsvr': 'true', - 'bind_ip': '10.0.0.3', + 'bind_ip': '10.0.0.3,127.0.0.1', + # 'keyFile': '/test/key/file', 'dbpath': '/var/lib/mongodb'}) self.assertTrue(self.manager.app.status.is_config_server) mock_set_status.assert_called_with( ds_instance.ServiceStatuses.BUILD_PENDING) - @patch.object(mongo_service.MongoDbAppStatus, 'set_status') - @patch.object(utils, 'poll_until') - @patch.object(MongoDBApp, 'start_db_with_conf_changes') - @patch.object(MongoDBApp, 'update_config_contents') - @patch.object(netutils, 'get_my_ipv4', return_value="10.0.0.4") - def test_prepare_member(self, mock_ip_address, mock_update, mock_start, + @mock.patch.object(service.MongoDBAppStatus, 'set_status') + @mock.patch.object(utils, 'poll_until') + @mock.patch.object(service.MongoDBApp, 'start_db_with_conf_changes') + @mock.patch.object(service.MongoDBApp, 'update_config_contents') + @mock.patch.object(service.MongoDBApp, 'secure') + @mock.patch.object(service.MongoDBApp, 'get_key_file', + return_value="/test/key/file") + @mock.patch.object(netutils, 'get_my_ipv4', return_value="10.0.0.4") + def test_prepare_member(self, mock_ip_address, mock_key_file, + mock_secure, mock_update, mock_start, mock_poll, mock_set_status): - self._prepare_method("test-id-3", "member") + self._prepare_method("test-id-3", "member", None) mock_update.assert_called_with(None, - {'bind_ip': '10.0.0.4', + {'bind_ip': '10.0.0.4,127.0.0.1', + # 'keyFile': '/test/key/file', 'dbpath': '/var/lib/mongodb', 'replSet': 'rs1'}) mock_set_status.assert_called_with( ds_instance.ServiceStatuses.BUILD_PENDING) - @patch.object(volume.VolumeDevice, 'mount_points', return_value=[]) - @patch.object(volume.VolumeDevice, 'mount', return_value=None) - @patch.object(volume.VolumeDevice, 'migrate_data', return_value=None) - @patch.object(volume.VolumeDevice, 'format', return_value=None) - @patch.object(MongoDBApp, 'clear_storage') - @patch.object(MongoDBApp, 'start_db') - @patch.object(MongoDBApp, 'stop_db') - @patch.object(MongoDBApp, 'install_if_needed') - @patch.object(mongo_service.MongoDbAppStatus, 'begin_install') - def _prepare_method(self, instance_id, instance_type, *args): + @mock.patch.object(service.MongoDBAppStatus, 'set_status') + @mock.patch.object(utils, 'poll_until') + @mock.patch.object(service.MongoDBApp, 'start_db_with_conf_changes') + @mock.patch.object(service.MongoDBApp, 'update_config_contents') + @mock.patch.object(service.MongoDBApp, 'secure') + @mock.patch.object(netutils, 'get_my_ipv4', return_value="10.0.0.4") + def test_prepare_secure(self, mock_ip_address, mock_secure, + mock_update, mock_start, mock_poll, + mock_set_status): + key = "test_key" + self._prepare_method("test-id-4", "member", key) + mock_secure.assert_called_with( + {"id": "test-id-4", + "shard_id": "test_shard_id", + "instance_type": 'member', + "replica_set_name": "rs1", + "key": key} + + ) + + @mock.patch.object(volume.VolumeDevice, 'mount_points', return_value=[]) + @mock.patch.object(volume.VolumeDevice, 'mount', return_value=None) + @mock.patch.object(volume.VolumeDevice, 'migrate_data', return_value=None) + @mock.patch.object(volume.VolumeDevice, 'format', return_value=None) + @mock.patch.object(service.MongoDBApp, 'clear_storage') + @mock.patch.object(service.MongoDBApp, 'start_db') + @mock.patch.object(service.MongoDBApp, 'stop_db') + @mock.patch.object(service.MongoDBApp, 'install_if_needed') + @mock.patch.object(service.MongoDBAppStatus, 'begin_install') + def _prepare_method(self, instance_id, instance_type, key, *args): cluster_config = {"id": instance_id, "shard_id": "test_shard_id", "instance_type": instance_type, - "replica_set_name": "rs1"} + "replica_set_name": "rs1", + "key": key} # invocation self.manager.prepare(context=self.context, databases=None, diff --git a/trove/tests/unittests/guestagent/test_mongodb_manager.py b/trove/tests/unittests/guestagent/test_mongodb_manager.py index 9acd3a85cb..7ff5adfa44 100644 --- a/trove/tests/unittests/guestagent/test_mongodb_manager.py +++ b/trove/tests/unittests/guestagent/test_mongodb_manager.py @@ -12,105 +12,104 @@ # License for the specific language governing permissions and limitations # under the License. -import os +import mock +import pymongo -from mock import MagicMock -from mock import patch -import testtools - -from trove.common.context import TroveContext -from trove.common import utils -from trove.guestagent import backup -from trove.guestagent.datastore.experimental.mongodb import ( - manager as mongo_manager) -from trove.guestagent.datastore.experimental.mongodb import ( - service as mongo_service) -from trove.guestagent import volume -from trove.guestagent.volume import VolumeDevice +import trove.common.context as context +import trove.common.utils as utils +import trove.guestagent.backup as backup +import trove.guestagent.datastore.experimental.mongodb.manager as manager +import trove.guestagent.volume as volume +import trove.tests.unittests.trove_testtools as trove_testtools -class GuestAgentMongoDBManagerTest(testtools.TestCase): +class GuestAgentMongoDBManagerTest(trove_testtools.TestCase): def setUp(self): super(GuestAgentMongoDBManagerTest, self).setUp() - self.context = TroveContext() - self.manager = mongo_manager.Manager() - self.origin_MongoDbAppStatus = mongo_service.MongoDbAppStatus - self.origin_os_path_exists = os.path.exists - self.origin_format = volume.VolumeDevice.format - self.origin_migrate_data = volume.VolumeDevice.migrate_data - self.origin_mount = volume.VolumeDevice.mount - self.origin_mount_points = volume.VolumeDevice.mount_points - self.origin_stop_db = mongo_service.MongoDBApp.stop_db - self.origin_start_db = mongo_service.MongoDBApp.start_db - self.orig_exec_with_to = utils.execute_with_timeout - self.orig_backup_restore = backup.restore + self.context = context.TroveContext() + self.manager = manager.Manager() + + self.execute_with_timeout_patch = mock.patch.object( + utils, 'execute_with_timeout' + ) + self.addCleanup(self.execute_with_timeout_patch.stop) + self.execute_with_timeout_patch.start() + + self.pymongo_patch = mock.patch.object( + pymongo, 'MongoClient' + ) + self.addCleanup(self.pymongo_patch.stop) + self.pymongo_patch.start() + + self.mount_point = '/var/lib/mongodb' def tearDown(self): super(GuestAgentMongoDBManagerTest, self).tearDown() - mongo_service.MongoDbAppStatus = self.origin_MongoDbAppStatus - os.path.exists = self.origin_os_path_exists - volume.VolumeDevice.format = self.origin_format - volume.VolumeDevice.migrate_data = self.origin_migrate_data - volume.VolumeDevice.mount = self.origin_mount - volume.VolumeDevice.mount_points = self.origin_mount_points - mongo_service.MongoDBApp.stop_db = self.origin_stop_db - mongo_service.MongoDBApp.start_db = self.origin_start_db - utils.execute_with_timeout = self.orig_exec_with_to - backup.restore = self.orig_backup_restore def test_update_status(self): - self.manager.status = MagicMock() - self.manager.update_status(self.context) - self.manager.status.update.assert_any_call() + with mock.patch.object(self.manager, 'status') as status: + self.manager.update_status(self.context) + status.update.assert_any_call() - def test_prepare_from_backup(self): - self._prepare_dynamic(backup_id='backup_id_123abc') + def _prepare_method(self, databases=None, users=None, device_path=None, + mount_point=None, backup_info=None, + cluster_config=None, overrides=None, memory_mb='2048', + packages=['packages']): + """self.manager.app must be correctly mocked before calling.""" - def _prepare_dynamic(self, device_path='/dev/vdb', is_db_installed=True, - backup_id=None): + self.manager.status = mock.Mock() + self.manager.get_config_changes = mock.Mock() - # covering all outcomes is starting to cause trouble here - backup_info = {'id': backup_id, + self.manager.prepare(self.context, packages, + databases, memory_mb, users, + device_path=device_path, + mount_point=mount_point, + backup_info=backup_info, + overrides=overrides, + cluster_config=cluster_config) + + self.manager.status.begin_install.assert_any_call() + self.manager.app.install_if_needed.assert_called_with(packages) + self.manager.app.stop_db.assert_any_call() + self.manager.app.clear_storage.assert_any_call() + self.manager.get_config_changes.assert_called_with(cluster_config, + self.mount_point) + + @mock.patch.object(volume, 'VolumeDevice') + @mock.patch('os.path.exists') + def test_prepare_for_volume(self, exists, mocked_volume): + device_path = '/dev/vdb' + + self.manager.app = mock.Mock() + + self._prepare_method(device_path=device_path) + + mocked_volume().unmount_device.assert_called_with(device_path) + mocked_volume().format.assert_any_call() + mocked_volume().migrate_data.assert_called_with(self.mount_point) + mocked_volume().mount.assert_called_with(self.mount_point) + + def test_secure(self): + self.manager.app = mock.Mock() + + mock_secure = mock.Mock() + self.manager.app.secure = mock_secure + + self._prepare_method() + + mock_secure.assert_called_with(None) + + @mock.patch.object(backup, 'restore') + def test_prepare_from_backup(self, mocked_restore): + self.manager.app = mock.Mock() + + backup_info = {'id': 'backup_id_123abc', 'location': 'fake-location', - 'type': 'MongoDump', - 'checksum': 'fake-checksum'} if backup_id else None + 'type': 'MongoDBDump', + 'checksum': 'fake-checksum'} - mock_status = MagicMock() - mock_app = MagicMock() - self.manager.status = mock_status - self.manager.app = mock_app + self._prepare_method(backup_info=backup_info) - mock_status.begin_install = MagicMock(return_value=None) - volume.VolumeDevice.format = MagicMock(return_value=None) - volume.VolumeDevice.migrate_data = MagicMock(return_value=None) - volume.VolumeDevice.mount = MagicMock(return_value=None) - volume.VolumeDevice.mount_points = MagicMock(return_value=[]) - backup.restore = MagicMock(return_value=None) - - mock_app.stop_db = MagicMock(return_value=None) - mock_app.start_db = MagicMock(return_value=None) - mock_app.clear_storage = MagicMock(return_value=None) - os.path.exists = MagicMock(return_value=is_db_installed) - - with patch.object(utils, 'execute_with_timeout'): - # invocation - self.manager.prepare(context=self.context, databases=None, - packages=['package'], - memory_mb='2048', users=None, - device_path=device_path, - mount_point='/var/lib/mongodb', - backup_info=backup_info, - overrides=None, - cluster_config=None) - - # verification/assertion - mock_status.begin_install.assert_any_call() - mock_app.install_if_needed.assert_any_call(['package']) - mock_app.stop_db.assert_any_call() - VolumeDevice.format.assert_any_call() - VolumeDevice.migrate_data.assert_any_call('/var/lib/mongodb') - if backup_info: - backup.restore.assert_any_call(self.context, - backup_info, - '/var/lib/mongodb') + mocked_restore.assert_called_with(self.context, backup_info, + '/var/lib/mongodb')