Server support for instance module feature

This changeset handles the details of applying,
removing, listing and retrieving 'modules'
from Trove instances.
See https://review.openstack.org/#/c/290177 for
the corresponding troveclient changes.

Scenario tests have been extended to cover the
new functionality.  These tests can be run by:
./redstack int-tests --group=module

A sample module type 'driver' - ping - is included
that simply parses the module contents for a
message=Text string and returns the 'Text' as the
status message.  If no 'message=' tag is found, then
the driver reports an error message.

Due to time constraints, a few unimplemented
parts/tests of the blueprint have been triaged as bugs
and are scheduled to be fixed before mitaka-rc1.
These include:
Vertica license module driver:
    https://bugs.launchpad.net/trove/+bug/1554898
Incomplete module-instances command:
    https://bugs.launchpad.net/trove/+bug/1554900
Incomplete 'live-update' of modules:
    https://bugs.launchpad.net/trove/+bug/1554903

Co-Authored-by: Peter Stachowski <peter@tesora.com>
Co-Authored-by: Simon Chang <schang@tesora.com>

Partially Implements: blueprint module-management
Change-Id: Ia8d3ff2f4560a6d997df99d41012ea61fb0096f7
Depends-On: If62f5e51d4628cc6a8b10303d5c3893b3bd5057e
This commit is contained in:
Peter Stachowski 2016-03-08 00:24:41 -05:00
parent f7cda9912d
commit 7d33401ee3
40 changed files with 2038 additions and 389 deletions

View File

@ -36,6 +36,9 @@ trove.api.extensions =
mysql = trove.extensions.routes.mysql:Mysql mysql = trove.extensions.routes.mysql:Mysql
security_group = trove.extensions.routes.security_group:Security_group security_group = trove.extensions.routes.security_group:Security_group
trove.guestagent.module.drivers =
ping = trove.guestagent.module.drivers.ping_driver:PingDriver
# These are for backwards compatibility with Havana notification_driver configuration values # These are for backwards compatibility with Havana notification_driver configuration values
oslo.messaging.notify.drivers = oslo.messaging.notify.drivers =
trove.openstack.common.notifier.log_notifier = oslo_messaging.notify._impl_log:LogDriver trove.openstack.common.notifier.log_notifier = oslo_messaging.notify._impl_log:LogDriver

View File

@ -116,6 +116,18 @@ class API(wsgi.Router):
controller=instance_resource, controller=instance_resource,
action="guest_log_action", action="guest_log_action",
conditions={'method': ['POST']}) conditions={'method': ['POST']})
mapper.connect("/{tenant_id}/instances/{id}/modules",
controller=instance_resource,
action="module_list",
conditions={'method': ['GET']})
mapper.connect("/{tenant_id}/instances/{id}/modules",
controller=instance_resource,
action="module_apply",
conditions={'method': ['POST']})
mapper.connect("/{tenant_id}/instances/{id}/modules/{module_id}",
controller=instance_resource,
action="module_remove",
conditions={'method': ['DELETE']})
def _cluster_router(self, mapper): def _cluster_router(self, mapper):
cluster_resource = ClusterController().create_resource() cluster_resource = ClusterController().create_resource()
@ -211,6 +223,10 @@ class API(wsgi.Router):
controller=modules_resource, controller=modules_resource,
action="delete", action="delete",
conditions={'method': ['DELETE']}) conditions={'method': ['DELETE']})
mapper.connect("/{tenant_id}/modules/{id}/instances",
controller=modules_resource,
action="instances",
conditions={'method': ['GET']})
def _configurations_router(self, mapper): def _configurations_router(self, mapper):
parameters_resource = ParametersController().create_resource() parameters_resource = ParametersController().create_resource()

View File

@ -207,6 +207,19 @@ configuration_id = {
] ]
} }
module_list = {
"type": "array",
"minItems": 0,
"items": {
"type": "object",
"required": ["id"],
"additionalProperties": True,
"properties": {
"id": uuid,
}
}
}
cluster = { cluster = {
"create": { "create": {
"type": "object", "type": "object",
@ -238,7 +251,8 @@ cluster = {
"flavorRef": flavorref, "flavorRef": flavorref,
"volume": volume, "volume": volume,
"nics": nics, "nics": nics,
"availability_zone": non_empty_string "availability_zone": non_empty_string,
"modules": module_list,
} }
} }
} }
@ -334,7 +348,8 @@ instance = {
"version": non_empty_string "version": non_empty_string
} }
}, },
"nics": nics "nics": nics,
"modules": module_list
} }
} }
} }
@ -528,10 +543,10 @@ guest_log = {
} }
} }
module_non_empty_string = { module_contents = {
"type": "string", "type": "string",
"minLength": 1, "minLength": 1,
"maxLength": 65535, "maxLength": 16777215,
"pattern": "^.*.+.*$" "pattern": "^.*.+.*$"
} }
@ -548,7 +563,7 @@ module = {
"properties": { "properties": {
"name": non_empty_string, "name": non_empty_string,
"module_type": non_empty_string, "module_type": non_empty_string,
"contents": module_non_empty_string, "contents": module_contents,
"description": non_empty_string, "description": non_empty_string,
"datastore": { "datastore": {
"type": "object", "type": "object",
@ -577,7 +592,7 @@ module = {
"properties": { "properties": {
"name": non_empty_string, "name": non_empty_string,
"type": non_empty_string, "type": non_empty_string,
"contents": module_non_empty_string, "contents": module_contents,
"description": non_empty_string, "description": non_empty_string,
"datastore": { "datastore": {
"type": "object", "type": "object",
@ -595,6 +610,24 @@ module = {
} }
} }
}, },
"apply": {
"name": "module:apply",
"type": "object",
"required": ["modules"],
"properties": {
"modules": module_list,
}
},
"list": {
"name": "module:list",
"type": "object",
"required": [],
"properties": {
"module": uuid,
"from_guest": boolean_string,
"include_contents": boolean_string
}
},
} }
configuration = { configuration = {

View File

@ -403,8 +403,9 @@ common_opts = [
'become alive.'), 'become alive.'),
cfg.StrOpt('module_aes_cbc_key', default='module_aes_cbc_key', cfg.StrOpt('module_aes_cbc_key', default='module_aes_cbc_key',
help='OpenSSL aes_cbc key for module encryption.'), help='OpenSSL aes_cbc key for module encryption.'),
cfg.StrOpt('module_types', default='test, hidden_test', cfg.ListOpt('module_types', default=['ping'],
help='A list of module types supported.'), help='A list of module types supported. A module type '
'corresponds to the name of a ModuleDriver.'),
cfg.StrOpt('guest_log_container_name', cfg.StrOpt('guest_log_container_name',
default='database_logs', default='database_logs',
help='Name of container that stores guest log components.'), help='Name of container that stores guest log components.'),

View File

@ -0,0 +1,62 @@
# Copyright 2016 Tesora, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# Encryption/decryption handling
from Crypto.Cipher import AES
from Crypto import Random
import hashlib
from trove.common import stream_codecs
IV_BIT_COUNT = 16
def encode_data(data):
return stream_codecs.Base64Codec().serialize(data)
def decode_data(data):
return stream_codecs.Base64Codec().deserialize(data)
# Pad the data string to an multiple of pad_size
def pad_for_encryption(data, pad_size=IV_BIT_COUNT):
pad_count = pad_size - (len(data) % pad_size)
return data + chr(pad_count) * pad_count
# Unpad the data string by stripping off excess characters
def unpad_after_decryption(data):
return data[:len(data) - ord(data[-1])]
def encrypt_data(data, key, iv_bit_count=IV_BIT_COUNT):
md5_key = hashlib.md5(key).hexdigest()
iv = encode_data(Random.new().read(iv_bit_count))[:iv_bit_count]
aes = AES.new(md5_key, AES.MODE_CBC, iv)
data = pad_for_encryption(data, iv_bit_count)
encrypted = aes.encrypt(data)
return iv + encrypted
def decrypt_data(data, key, iv_bit_count=IV_BIT_COUNT):
md5_key = hashlib.md5(key).hexdigest()
iv = data[:iv_bit_count]
aes = AES.new(md5_key, AES.MODE_CBC, bytes(iv))
decrypted = aes.decrypt(bytes(data[iv_bit_count:]))
return unpad_after_decryption(decrypted)

View File

@ -15,6 +15,7 @@
import abc import abc
import ast import ast
import base64
import csv import csv
import json import json
import re import re
@ -259,7 +260,7 @@ class PropertiesCodec(StreamCodec):
SKIP_INIT_SPACE = True SKIP_INIT_SPACE = True
def __init__(self, delimiter=' ', comment_markers=('#'), def __init__(self, delimiter=' ', comment_markers=('#'),
unpack_singletons=True, string_mappings={}): unpack_singletons=True, string_mappings=None):
""" """
:param delimiter: A one-character used to separate fields. :param delimiter: A one-character used to separate fields.
:type delimiter: string :type delimiter: string
@ -280,7 +281,7 @@ class PropertiesCodec(StreamCodec):
""" """
self._delimiter = delimiter self._delimiter = delimiter
self._comment_markers = comment_markers self._comment_markers = comment_markers
self._string_converter = StringConverter(string_mappings) self._string_converter = StringConverter(string_mappings or {})
self._unpack_singletons = unpack_singletons self._unpack_singletons = unpack_singletons
def serialize(self, dict_data): def serialize(self, dict_data):
@ -366,6 +367,30 @@ class PropertiesCodec(StreamCodec):
return container return container
class KeyValueCodec(PropertiesCodec):
"""
Read/write data from/into a simple key=value file.
key1=value1
key2=value2
key3=value3
...
The above file content would be represented as:
{'key1': 'value1',
'key2': 'value2',
'key3': 'value3',
...
}
"""
def __init__(self, delimiter='=', comment_markers=('#'),
unpack_singletons=True, string_mappings=None):
super(KeyValueCodec, self).__init__(
delimiter=delimiter, comment_markers=comment_markers,
unpack_singletons=unpack_singletons,
string_mappings=string_mappings)
class JsonCodec(StreamCodec): class JsonCodec(StreamCodec):
def serialize(self, dict_data): def serialize(self, dict_data):
@ -373,3 +398,28 @@ class JsonCodec(StreamCodec):
def deserialize(self, stream): def deserialize(self, stream):
return json.load(six.StringIO(stream)) return json.load(six.StringIO(stream))
class Base64Codec(StreamCodec):
"""Serialize (encode) and deserialize (decode) using the base64 codec.
To read binary data from a file and b64encode it, used the decode=False
flag on operating_system's read calls. Use encode=False to decode
binary data before writing to a file as well.
"""
def serialize(self, data):
try:
# py27str - if we've got text data, this should encode it
# py27aa/py34aa - if we've got a bytearray, this should work too
encoded = str(base64.b64encode(data).decode('utf-8'))
except TypeError:
# py34str - convert to bytes first, then we can encode
data_bytes = bytes([ord(item) for item in data])
encoded = base64.b64encode(data_bytes).decode('utf-8')
return encoded
def deserialize(self, stream):
# py27 & py34 seem to understand bytearray the same
return bytearray([item for item in base64.b64decode(stream)])

View File

@ -14,12 +14,8 @@
# under the License. # under the License.
"""I totally stole most of this from melange, thx guys!!!""" """I totally stole most of this from melange, thx guys!!!"""
import base64
import collections import collections
from Crypto.Cipher import AES
from Crypto import Random
import datetime import datetime
import hashlib
import inspect import inspect
import os import os
import shutil import shutil
@ -331,44 +327,3 @@ def is_collection(item):
""" """
return (isinstance(item, collections.Iterable) and return (isinstance(item, collections.Iterable) and
not isinstance(item, types.StringTypes)) not isinstance(item, types.StringTypes))
# Encryption/decryption handling methods
IV_BIT_COUNT = 16
def encode_string(data_str):
byte_array = bytearray(data_str)
return base64.b64encode(byte_array)
def decode_string(data_str):
return base64.b64decode(data_str)
# Pad the data string to an multiple of pad_size
def pad_for_encryption(data_str, pad_size=IV_BIT_COUNT):
pad_count = pad_size - (len(data_str) % pad_size)
return data_str + chr(pad_count) * pad_count
# Unpad the data string by stripping off excess characters
def unpad_after_decryption(data_str):
return data_str[:len(data_str) - ord(data_str[-1])]
def encrypt_string(data_str, key, iv_bit_count=IV_BIT_COUNT):
md5_key = hashlib.md5(key).hexdigest()
iv = encode_string(Random.new().read(iv_bit_count))[:iv_bit_count]
aes = AES.new(md5_key, AES.MODE_CBC, iv)
data_str = pad_for_encryption(data_str, iv_bit_count)
encrypted_str = aes.encrypt(data_str)
return iv + encrypted_str
def decrypt_string(data_str, key, iv_bit_count=IV_BIT_COUNT):
md5_key = hashlib.md5(key).hexdigest()
iv = data_str[:iv_bit_count]
aes = AES.new(md5_key, AES.MODE_CBC, iv)
decrypted_str = aes.decrypt(data_str[iv_bit_count:])
return unpad_after_decryption(decrypted_str)

View File

@ -320,6 +320,7 @@ class Controller(object):
exception.ReplicaSourceDeleteForbidden, exception.ReplicaSourceDeleteForbidden,
exception.BackupTooLarge, exception.BackupTooLarge,
exception.ModuleAccessForbidden, exception.ModuleAccessForbidden,
exception.ModuleAppliedToInstance,
], ],
webob.exc.HTTPBadRequest: [ webob.exc.HTTPBadRequest: [
exception.InvalidModelError, exception.InvalidModelError,
@ -330,7 +331,6 @@ class Controller(object):
exception.UserAlreadyExists, exception.UserAlreadyExists,
exception.LocalStorageNotSpecified, exception.LocalStorageNotSpecified,
exception.ModuleAlreadyExists, exception.ModuleAlreadyExists,
exception.ModuleAppliedToInstance,
], ],
webob.exc.HTTPNotFound: [ webob.exc.HTTPNotFound: [
exception.NotFound, exception.NotFound,

View File

@ -36,7 +36,7 @@ modules = Table(
Column('id', String(length=64), primary_key=True, nullable=False), Column('id', String(length=64), primary_key=True, nullable=False),
Column('name', String(length=255), nullable=False), Column('name', String(length=255), nullable=False),
Column('type', String(length=255), nullable=False), Column('type', String(length=255), nullable=False),
Column('contents', Text(), nullable=False), Column('contents', Text(length=16777215), nullable=False),
Column('description', String(length=255)), Column('description', String(length=255)),
Column('tenant_id', String(length=64), nullable=True), Column('tenant_id', String(length=64), nullable=True),
Column('datastore_id', String(length=64), nullable=True), Column('datastore_id', String(length=64), nullable=True),

View File

@ -227,7 +227,8 @@ class API(object):
def prepare(self, memory_mb, packages, databases, users, def prepare(self, memory_mb, packages, databases, users,
device_path='/dev/vdb', mount_point='/mnt/volume', device_path='/dev/vdb', mount_point='/mnt/volume',
backup_info=None, config_contents=None, root_password=None, backup_info=None, config_contents=None, root_password=None,
overrides=None, cluster_config=None, snapshot=None): overrides=None, cluster_config=None, snapshot=None,
modules=None):
"""Make an asynchronous call to prepare the guest """Make an asynchronous call to prepare the guest
as a database container optionally includes a backup id for restores as a database container optionally includes a backup id for restores
""" """
@ -246,7 +247,7 @@ class API(object):
device_path=device_path, mount_point=mount_point, device_path=device_path, mount_point=mount_point,
backup_info=backup_info, config_contents=config_contents, backup_info=backup_info, config_contents=config_contents,
root_password=root_password, overrides=overrides, root_password=root_password, overrides=overrides,
cluster_config=cluster_config, snapshot=snapshot) cluster_config=cluster_config, snapshot=snapshot, modules=modules)
def _create_guest_queue(self): def _create_guest_queue(self):
"""Call to construct, start and immediately stop rpc server in order """Call to construct, start and immediately stop rpc server in order
@ -434,7 +435,7 @@ class API(object):
LOG.debug("Retrieving guest log list for %s.", self.id) LOG.debug("Retrieving guest log list for %s.", self.id)
result = self._call("guest_log_list", AGENT_HIGH_TIMEOUT, result = self._call("guest_log_list", AGENT_HIGH_TIMEOUT,
self.version_cap) self.version_cap)
LOG.debug("guest_log_list 1 returns %s", result) LOG.debug("guest_log_list returns %s", result)
return result return result
def guest_log_action(self, log_name, enable, disable, publish, discard): def guest_log_action(self, log_name, enable, disable, publish, discard):
@ -443,3 +444,21 @@ class API(object):
self.version_cap, log_name=log_name, self.version_cap, log_name=log_name,
enable=enable, disable=disable, enable=enable, disable=disable,
publish=publish, discard=discard) publish=publish, discard=discard)
def module_list(self, include_contents):
LOG.debug("Querying modules on %s (contents: %s).",
self.id, include_contents)
result = self._call("module_list", AGENT_HIGH_TIMEOUT,
self.version_cap,
include_contents=include_contents)
return result
def module_apply(self, modules):
LOG.debug("Applying modules to %s.", self.id)
return self._call("module_apply", AGENT_HIGH_TIMEOUT,
self.version_cap, modules=modules)
def module_remove(self, module):
LOG.debug("Removing modules from %s.", self.id)
return self._call("module_remove", AGENT_HIGH_TIMEOUT,
self.version_cap, module=module)

View File

@ -33,31 +33,36 @@ DEBIAN = 'debian'
SUSE = 'suse' SUSE = 'suse'
def read_file(path, codec=IdentityCodec(), as_root=False): def read_file(path, codec=IdentityCodec(), as_root=False, decode=True):
""" """
Read a file into a Python data structure Read a file into a Python data structure
digestible by 'write_file'. digestible by 'write_file'.
:param path Path to the read config file. :param path: Path to the read config file.
:type path string :type path: string
:param codec: A codec used to deserialize the data. :param codec: A codec used to transform the data.
:type codec: StreamCodec :type codec: StreamCodec
:returns: A dictionary of key-value pairs.
:param as_root: Execute as root. :param as_root: Execute as root.
:type as_root: boolean :type as_root: boolean
:param decode: Should the codec decode the data.
:type decode: boolean
:returns: A dictionary of key-value pairs.
:raises: :class:`UnprocessableEntity` if file doesn't exist. :raises: :class:`UnprocessableEntity` if file doesn't exist.
:raises: :class:`UnprocessableEntity` if codec not given. :raises: :class:`UnprocessableEntity` if codec not given.
""" """
if path and exists(path, is_directory=False, as_root=as_root): if path and exists(path, is_directory=False, as_root=as_root):
if as_root: if as_root:
return _read_file_as_root(path, codec) return _read_file_as_root(path, codec, decode=decode)
with open(path, 'r') as fp: with open(path, 'rb') as fp:
return codec.deserialize(fp.read()) if decode:
return codec.deserialize(fp.read())
return codec.serialize(fp.read())
raise exception.UnprocessableEntity(_("File does not exist: %s") % path) raise exception.UnprocessableEntity(_("File does not exist: %s") % path)
@ -92,22 +97,27 @@ def exists(path, is_directory=False, as_root=False):
return found return found
def _read_file_as_root(path, codec): def _read_file_as_root(path, codec, decode=True):
"""Read a file as root. """Read a file as root.
:param path Path to the written file. :param path Path to the written file.
:type path string :type path string
:param codec: A codec used to serialize the data. :param codec: A codec used to transform the data.
:type codec: StreamCodec :type codec: StreamCodec
:param decode: Should the codec decode the data.
:type decode: boolean
""" """
with tempfile.NamedTemporaryFile() as fp: with tempfile.NamedTemporaryFile() as fp:
copy(path, fp.name, force=True, as_root=True) copy(path, fp.name, force=True, as_root=True)
chmod(fp.name, FileMode.ADD_READ_ALL(), as_root=True) chmod(fp.name, FileMode.ADD_READ_ALL(), as_root=True)
return codec.deserialize(fp.read()) if decode:
return codec.deserialize(fp.read())
return codec.serialize(fp.read())
def write_file(path, data, codec=IdentityCodec(), as_root=False): def write_file(path, data, codec=IdentityCodec(), as_root=False, encode=True):
"""Write data into file using a given codec. """Write data into file using a given codec.
Overwrite any existing contents. Overwrite any existing contents.
The written file can be read back into its original The written file can be read back into its original
@ -119,25 +129,31 @@ def write_file(path, data, codec=IdentityCodec(), as_root=False):
:param data: An object representing the file contents. :param data: An object representing the file contents.
:type data: object :type data: object
:param codec: A codec used to serialize the data. :param codec: A codec used to transform the data.
:type codec: StreamCodec :type codec: StreamCodec
:param as_root: Execute as root. :param as_root: Execute as root.
:type as_root: boolean :type as_root: boolean
:param encode: Should the codec encode the data.
:type encode: boolean
:raises: :class:`UnprocessableEntity` if path not given. :raises: :class:`UnprocessableEntity` if path not given.
""" """
if path: if path:
if as_root: if as_root:
_write_file_as_root(path, data, codec) _write_file_as_root(path, data, codec, encode=encode)
else: else:
with open(path, 'w', 0) as fp: with open(path, 'wb', 0) as fp:
fp.write(codec.serialize(data)) if encode:
fp.write(codec.serialize(data))
else:
fp.write(codec.deserialize(data))
else: else:
raise exception.UnprocessableEntity(_("Invalid path: %s") % path) raise exception.UnprocessableEntity(_("Invalid path: %s") % path)
def _write_file_as_root(path, data, codec): def _write_file_as_root(path, data, codec, encode=True):
"""Write a file as root. Overwrite any existing contents. """Write a file as root. Overwrite any existing contents.
:param path Path to the written file. :param path Path to the written file.
@ -146,13 +162,19 @@ def _write_file_as_root(path, data, codec):
:param data: An object representing the file contents. :param data: An object representing the file contents.
:type data: StreamCodec :type data: StreamCodec
:param codec: A codec used to serialize the data. :param codec: A codec used to transform the data.
:type codec: StreamCodec :type codec: StreamCodec
:param encode: Should the codec encode the data.
:type encode: boolean
""" """
# The files gets removed automatically once the managing object goes # The files gets removed automatically once the managing object goes
# out of scope. # out of scope.
with tempfile.NamedTemporaryFile('w', 0, delete=False) as fp: with tempfile.NamedTemporaryFile('wb', 0, delete=False) as fp:
fp.write(codec.serialize(data)) if encode:
fp.write(codec.serialize(data))
else:
fp.write(codec.deserialize(data))
fp.close() # Release the resource before proceeding. fp.close() # Release the resource before proceeding.
copy(fp.name, path, force=True, as_root=True) copy(fp.name, path, force=True, as_root=True)

View File

@ -30,6 +30,8 @@ from trove.guestagent.common import operating_system
from trove.guestagent.common.operating_system import FileMode from trove.guestagent.common.operating_system import FileMode
from trove.guestagent import dbaas from trove.guestagent import dbaas
from trove.guestagent import guest_log from trove.guestagent import guest_log
from trove.guestagent.module import driver_manager
from trove.guestagent.module import module_manager
from trove.guestagent.strategies import replication as repl_strategy from trove.guestagent.strategies import replication as repl_strategy
from trove.guestagent import volume from trove.guestagent import volume
@ -73,6 +75,9 @@ class Manager(periodic_task.PeriodicTasks):
self._guest_log_cache = None self._guest_log_cache = None
self._guest_log_defs = None self._guest_log_defs = None
# Module
self.module_driver_manager = driver_manager.ModuleDriverManager()
@property @property
def manager_name(self): def manager_name(self):
"""This returns the passed-in name of the manager.""" """This returns the passed-in name of the manager."""
@ -251,22 +256,24 @@ class Manager(periodic_task.PeriodicTasks):
def prepare(self, context, packages, databases, memory_mb, users, def prepare(self, context, packages, databases, memory_mb, users,
device_path=None, mount_point=None, backup_info=None, device_path=None, mount_point=None, backup_info=None,
config_contents=None, root_password=None, overrides=None, config_contents=None, root_password=None, overrides=None,
cluster_config=None, snapshot=None): cluster_config=None, snapshot=None, modules=None):
"""Set up datastore on a Guest Instance.""" """Set up datastore on a Guest Instance."""
with EndNotification(context, instance_id=CONF.guest_id): with EndNotification(context, instance_id=CONF.guest_id):
self._prepare(context, packages, databases, memory_mb, users, self._prepare(context, packages, databases, memory_mb, users,
device_path, mount_point, backup_info, device_path, mount_point, backup_info,
config_contents, root_password, overrides, config_contents, root_password, overrides,
cluster_config, snapshot) cluster_config, snapshot, modules)
def _prepare(self, context, packages, databases, memory_mb, users, def _prepare(self, context, packages, databases, memory_mb, users,
device_path=None, mount_point=None, backup_info=None, device_path, mount_point, backup_info,
config_contents=None, root_password=None, overrides=None, config_contents, root_password, overrides,
cluster_config=None, snapshot=None): cluster_config, snapshot, modules):
LOG.info(_("Starting datastore prepare for '%s'.") % self.manager) LOG.info(_("Starting datastore prepare for '%s'.") % self.manager)
self.status.begin_install() self.status.begin_install()
post_processing = True if cluster_config else False post_processing = True if cluster_config else False
try: try:
# Since all module handling is common, don't pass it down to the
# individual 'do_prepare' methods.
self.do_prepare(context, packages, databases, memory_mb, self.do_prepare(context, packages, databases, memory_mb,
users, device_path, mount_point, backup_info, users, device_path, mount_point, backup_info,
config_contents, root_password, overrides, config_contents, root_password, overrides,
@ -291,6 +298,17 @@ class Manager(periodic_task.PeriodicTasks):
LOG.info(_("Completed setup of '%s' datastore successfully.") % LOG.info(_("Completed setup of '%s' datastore successfully.") %
self.manager) self.manager)
# The following block performs additional instance initialization.
# Failures will be recorded, but won't stop the provisioning
# or change the instance state.
try:
if modules:
LOG.info(_("Applying modules (called from 'prepare')."))
self.module_apply(context, modules)
LOG.info(_('Module apply completed.'))
except Exception as ex:
LOG.exception(_("An error occurred applying modules: "
"%s") % ex.message)
# The following block performs single-instance initialization. # The following block performs single-instance initialization.
# Failures will be recorded, but won't stop the provisioning # Failures will be recorded, but won't stop the provisioning
# or change the instance state. # or change the instance state.
@ -595,6 +613,66 @@ class Manager(periodic_task.PeriodicTasks):
LOG.debug("Set log file '%s' as readable" % log_file) LOG.debug("Set log file '%s' as readable" % log_file)
return log_file return log_file
################
# Module related
################
def module_list(self, context, include_contents=False):
LOG.info(_("Getting list of modules."))
results = module_manager.ModuleManager.read_module_results(
is_admin=context.is_admin, include_contents=include_contents)
LOG.info(_("Returning list of modules: %s") % results)
return results
def module_apply(self, context, modules=None):
LOG.info(_("Applying modules."))
results = []
for module_data in modules:
module = module_data['module']
id = module.get('id', None)
module_type = module.get('type', None)
name = module.get('name', None)
tenant = module.get('tenant', None)
datastore = module.get('datastore', None)
ds_version = module.get('datastore_version', None)
contents = module.get('contents', None)
md5 = module.get('md5', None)
auto_apply = module.get('auto_apply', True)
visible = module.get('visible', True)
if not name:
raise AttributeError(_("Module name not specified"))
if not contents:
raise AttributeError(_("Module contents not specified"))
driver = self.module_driver_manager.get_driver(module_type)
if not driver:
raise exception.ModuleTypeNotFound(
_("No driver implemented for module type '%s'") %
module_type)
result = module_manager.ModuleManager.apply_module(
driver, module_type, name, tenant, datastore, ds_version,
contents, id, md5, auto_apply, visible)
results.append(result)
LOG.info(_("Returning list of modules: %s") % results)
return results
def module_remove(self, context, module=None):
LOG.info(_("Removing module."))
module = module['module']
id = module.get('id', None)
module_type = module.get('type', None)
name = module.get('name', None)
datastore = module.get('datastore', None)
ds_version = module.get('datastore_version', None)
if not name:
raise AttributeError(_("Module name not specified"))
driver = self.module_driver_manager.get_driver(module_type)
if not driver:
raise exception.ModuleTypeNotFound(
_("No driver implemented for module type '%s'") %
module_type)
module_manager.ModuleManager.remove_module(
driver, module_type, id, name, datastore, ds_version)
LOG.info(_("Deleted module: %s") % name)
############### ###############
# Not Supported # Not Supported
############### ###############

View File

View File

@ -0,0 +1,96 @@
# Copyright 2016 Tesora, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
from oslo_log import log as logging
import stevedore
from trove.common import base_exception as exception
from trove.common import cfg
from trove.common.i18n import _
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
class ModuleDriverManager(object):
MODULE_DRIVER_NAMESPACE = 'trove.guestagent.module.drivers'
def __init__(self):
LOG.info(_('Initializing module driver manager.'))
self._drivers = {}
self._module_types = [mt.lower() for mt in CONF.module_types]
self._load_drivers()
def _load_drivers(self):
manager = stevedore.enabled.EnabledExtensionManager(
namespace=self.MODULE_DRIVER_NAMESPACE,
check_func=self._check_extension,
invoke_on_load=True,
invoke_kwds={})
try:
manager.map(self.add_driver_extension)
except stevedore.exception.NoMatches:
LOG.info(_("No module drivers loaded"))
def _check_extension(self, extension):
"""Checks for required methods in driver objects."""
driver = extension.obj
supported = False
try:
LOG.info(_('Loading Module driver: %s'), driver.get_type())
if driver.get_type() != driver.get_type().lower():
raise AttributeError(_("Driver 'type' must be lower-case"))
LOG.debug(' description: %s', driver.get_description())
LOG.debug(' updated : %s', driver.get_updated())
required_attrs = ['apply', 'remove']
for attr in required_attrs:
if not hasattr(driver, attr):
raise AttributeError(
_("Driver '%(type)s' missing attribute: %(attr)s")
% {'type': driver.get_type(), 'attr': attr})
if driver.get_type() in self._module_types:
supported = True
else:
LOG.info(_("Driver '%s' not supported, skipping"),
driver.get_type)
except AttributeError as ex:
LOG.exception(_("Exception loading module driver: %s"),
unicode(ex))
return supported
def add_driver_extension(self, extension):
# Add a module driver from the extension.
# If the stevedore manager is changed to one that doesn't
# check the extension driver, then it should be done manually here
# by calling self._check_extension(extension)
driver = extension.obj
driver_type = driver.get_type()
LOG.info(_('Loaded module driver: %s'), driver_type)
if driver_type in self._drivers:
raise exception.Error(_("Found duplicate driver: %s") %
driver_type)
self._drivers[driver_type] = driver
def get_driver(self, driver_type):
found = None
if driver_type in self._drivers:
found = self._drivers[driver_type]
return found

View File

@ -0,0 +1,72 @@
# Copyright 2016 Tesora, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
import abc
import six
from oslo_log import log as logging
from trove.common import cfg
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
@six.add_metaclass(abc.ABCMeta)
class ModuleDriver(object):
"""Base class that defines the contract for module drivers.
Note that you don't have to derive from this class to have a valid
driver; it is purely a convenience.
"""
def get_type(self):
"""This is used when setting up a module in Trove, and is here for
code clarity. It just returns the name of the driver.
"""
return self.get_name()
def get_name(self):
"""Attempt to generate a usable name based on the class name. If
overridden, must be in lower-case.
"""
return self.__class__.__name__.lower().replace(
'driver', '').replace(' ', '_')
@abc.abstractmethod
def get_description(self):
"""Description for the driver."""
pass
@abc.abstractmethod
def get_updated(self):
"""Date the driver was last updated."""
pass
@abc.abstractmethod
def apply(self, name, datastore, ds_version, data_file):
"""Apply the data to the guest instance. Return status and message
as a tupple.
"""
return False, "Not a concrete driver"
@abc.abstractmethod
def remove(self, name, datastore, ds_version, data_file):
"""Remove the data from the guest instance. Return status and message
as a tupple.
"""
return False, "Not a concrete driver"

View File

@ -0,0 +1,73 @@
# Copyright 2016 Tesora, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
from datetime import date
from oslo_log import log as logging
from trove.common import cfg
from trove.common.i18n import _
from trove.common import stream_codecs
from trove.guestagent.common import operating_system
from trove.guestagent.module.drivers import module_driver
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
class PingDriver(module_driver.ModuleDriver):
"""Concrete module to show implementation and functionality. Responds
like an actual module driver, but does nothing except return the
value of the message key in the contents file. For example, if the file
contains 'message=Hello' then the message returned by module-apply will
be 'Hello.'
"""
def get_type(self):
return 'ping'
def get_description(self):
return "Ping Guestagent Module Driver"
def get_updated(self):
return date(2016, 3, 4)
def apply(self, name, datastore, ds_version, data_file):
success = False
message = "Message not found in contents file"
try:
data = operating_system.read_file(
data_file, codec=stream_codecs.KeyValueCodec())
for key, value in data.items():
if 'message' == key.lower():
success = True
message = value
break
except Exception:
# assume we couldn't read the file, because there was some
# issue with it (for example, it's a binary file). Just log
# it and drive on.
LOG.error(_("Could not extract contents from '%s' - possibly "
"a binary file?") % name)
return success, message
def _is_binary(self, data_str):
bool(data_str.translate(None, self.TEXT_CHARS))
def remove(self, name, datastore, ds_version, data_file):
return True, ""

View File

@ -0,0 +1,218 @@
# Copyright 2016 Tesora, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
import datetime
import os
from oslo_log import log as logging
from trove.common import cfg
from trove.common import exception
from trove.common.i18n import _
from trove.common import stream_codecs
from trove.guestagent.common import guestagent_utils
from trove.guestagent.common import operating_system
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
class ModuleManager():
"""This is a Manager utility class (mixin) for managing module-related
tasks.
"""
MODULE_APPLY_TO_ALL = 'all'
MODULE_BASE_DIR = guestagent_utils.build_file_path('~', 'modules')
MODULE_CONTENTS_FILENAME = 'contents.dat'
MODULE_RESULT_FILENAME = 'result.json'
@classmethod
def get_current_timestamp(cls):
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@classmethod
def apply_module(cls, driver, module_type, name, tenant,
datastore, ds_version, contents, module_id, md5,
auto_apply, visible):
tenant = tenant or cls.MODULE_APPLY_TO_ALL
datastore = datastore or cls.MODULE_APPLY_TO_ALL
ds_version = ds_version or cls.MODULE_APPLY_TO_ALL
module_dir = cls.build_module_dir(module_type, module_id)
data_file = cls.write_module_contents(module_dir, contents, md5)
applied = True
message = None
now = cls.get_current_timestamp()
default_result = cls.build_default_result(
module_type, name, tenant, datastore,
ds_version, module_id, md5, auto_apply, visible, now)
result = cls.read_module_result(module_dir, default_result)
try:
applied, message = driver.apply(
name, datastore, ds_version, data_file)
except Exception as ex:
LOG.exception(_("Could not apply module '%s'") % name)
applied = False
message = ex.message
finally:
status = 'OK' if applied else 'ERROR'
admin_only = (not visible or tenant == cls.MODULE_APPLY_TO_ALL or
auto_apply)
result['status'] = status
result['message'] = message
result['updated'] = now
result['id'] = module_id
result['md5'] = md5
result['tenant'] = tenant
result['auto_apply'] = auto_apply
result['visible'] = visible
result['admin_only'] = admin_only
cls.write_module_result(module_dir, result)
return result
@classmethod
def build_module_dir(cls, module_type, module_id):
sub_dir = os.path.join(module_type, module_id)
module_dir = guestagent_utils.build_file_path(
cls.MODULE_BASE_DIR, sub_dir)
if not operating_system.exists(module_dir, is_directory=True):
operating_system.create_directory(module_dir, force=True)
return module_dir
@classmethod
def write_module_contents(cls, module_dir, contents, md5):
contents_file = cls.build_contents_filename(module_dir)
operating_system.write_file(contents_file, contents,
codec=stream_codecs.Base64Codec(),
encode=False)
return contents_file
@classmethod
def build_contents_filename(cls, module_dir):
contents_file = guestagent_utils.build_file_path(
module_dir, cls.MODULE_CONTENTS_FILENAME)
return contents_file
@classmethod
def build_default_result(cls, module_type, name, tenant,
datastore, ds_version, module_id, md5,
auto_apply, visible, now):
admin_only = (not visible or tenant == cls.MODULE_APPLY_TO_ALL or
auto_apply)
result = {
'type': module_type,
'name': name,
'datastore': datastore,
'datastore_version': ds_version,
'tenant': tenant,
'id': module_id,
'md5': md5,
'status': None,
'message': None,
'created': now,
'updated': now,
'removed': None,
'auto_apply': auto_apply,
'visible': visible,
'admin_only': admin_only,
'contents': None,
}
return result
@classmethod
def read_module_result(cls, result_file, default=None):
result_file = cls.get_result_filename(result_file)
result = default
try:
result = operating_system.read_file(
result_file, codec=stream_codecs.JsonCodec())
except Exception:
if not result:
LOG.exception(_("Could not find module result in %s") %
result_file)
raise
return result
@classmethod
def get_result_filename(cls, file_or_dir):
result_file = file_or_dir
if operating_system.exists(file_or_dir, is_directory=True):
result_file = guestagent_utils.build_file_path(
file_or_dir, cls.MODULE_RESULT_FILENAME)
return result_file
@classmethod
def write_module_result(cls, result_file, result):
result_file = cls.get_result_filename(result_file)
operating_system.write_file(
result_file, result, codec=stream_codecs.JsonCodec())
@classmethod
def read_module_results(cls, is_admin=False, include_contents=False):
"""Read all the module results on the guest and return a list
of them.
"""
results = []
pattern = cls.MODULE_RESULT_FILENAME
result_files = operating_system.list_files_in_directory(
cls.MODULE_BASE_DIR, recursive=True, pattern=pattern)
for result_file in result_files:
result = cls.read_module_result(result_file)
if (not result.get('removed') and
(is_admin or result.get('visible'))):
if include_contents:
codec = stream_codecs.Base64Codec()
if not is_admin and result.get('admin_only'):
contents = (
"Must be admin to retrieve contents for module %s"
% result.get('name', 'Unknown'))
result['contents'] = codec.serialize(contents)
else:
contents_dir = os.path.dirname(result_file)
contents_file = cls.build_contents_filename(
contents_dir)
result['contents'] = operating_system.read_file(
contents_file, codec=codec, decode=False)
results.append(result)
return results
@classmethod
def remove_module(cls, driver, module_type, module_id, name,
datastore, ds_version):
datastore = datastore or cls.MODULE_APPLY_TO_ALL
ds_version = ds_version or cls.MODULE_APPLY_TO_ALL
module_dir = cls.build_module_dir(module_type, module_id)
contents_file = cls.build_contents_filename(module_dir)
if not operating_system.exists(cls.get_result_filename(module_dir)):
raise exception.NotFound(
_("Module '%s' has not been applied") % name)
try:
removed, message = driver.remove(
name, datastore, ds_version, contents_file)
cls.remove_module_result(module_dir)
except Exception:
LOG.exception(_("Could not remove module '%s'") % name)
raise
return removed, message
@classmethod
def remove_module_result(cls, result_file):
now = cls.get_current_timestamp()
result = cls.read_module_result(result_file, None)
result['removed'] = now
cls.write_module_result(result_file, result)

View File

@ -43,6 +43,8 @@ from trove.db import models as dbmodels
from trove.extensions.security_group.models import SecurityGroup from trove.extensions.security_group.models import SecurityGroup
from trove.instance.tasks import InstanceTask from trove.instance.tasks import InstanceTask
from trove.instance.tasks import InstanceTasks from trove.instance.tasks import InstanceTasks
from trove.module import models as module_models
from trove.module import views as module_views
from trove.quota.quota import run_with_quotas from trove.quota.quota import run_with_quotas
from trove.taskmanager import api as task_api from trove.taskmanager import api as task_api
@ -672,7 +674,7 @@ class Instance(BuiltInstance):
datastore, datastore_version, volume_size, backup_id, datastore, datastore_version, volume_size, backup_id,
availability_zone=None, nics=None, availability_zone=None, nics=None,
configuration_id=None, slave_of_id=None, cluster_config=None, configuration_id=None, slave_of_id=None, cluster_config=None,
replica_count=None, volume_type=None): replica_count=None, volume_type=None, modules=None):
call_args = { call_args = {
'name': name, 'name': name,
@ -798,6 +800,23 @@ class Instance(BuiltInstance):
if cluster_config: if cluster_config:
call_args['cluster_id'] = cluster_config.get("id", None) call_args['cluster_id'] = cluster_config.get("id", None)
if not modules:
modules = []
module_ids = [mod['id'] for mod in modules]
modules = module_models.Modules.load_by_ids(context, module_ids)
auto_apply_modules = module_models.Modules.load_auto_apply(
context, datastore.id, datastore_version.id)
for aa_module in auto_apply_modules:
if aa_module.id not in module_ids:
modules.append(aa_module)
module_list = []
for module in modules:
module.contents = module_models.Module.deprocess_contents(
module.contents)
module_info = module_views.DetailedModuleView(module).data(
include_contents=True)
module_list.append(module_info)
def _create_resources(): def _create_resources():
if cluster_config: if cluster_config:
@ -825,6 +844,7 @@ class Instance(BuiltInstance):
{'tenant': context.tenant, 'db': db_info.id}) {'tenant': context.tenant, 'db': db_info.id})
instance_id = db_info.id instance_id = db_info.id
cls.add_instance_modules(context, instance_id, modules)
instance_name = name instance_name = name
ids.append(instance_id) ids.append(instance_id)
names.append(instance_name) names.append(instance_name)
@ -866,7 +886,7 @@ class Instance(BuiltInstance):
datastore_version.manager, datastore_version.packages, datastore_version.manager, datastore_version.packages,
volume_size, backup_id, availability_zone, root_password, volume_size, backup_id, availability_zone, root_password,
nics, overrides, slave_of_id, cluster_config, nics, overrides, slave_of_id, cluster_config,
volume_type=volume_type) volume_type=volume_type, modules=module_list)
return SimpleInstance(context, db_info, service_status, return SimpleInstance(context, db_info, service_status,
root_password) root_password)
@ -874,6 +894,12 @@ class Instance(BuiltInstance):
with StartNotification(context, **call_args): with StartNotification(context, **call_args):
return run_with_quotas(context.tenant, deltas, _create_resources) return run_with_quotas(context.tenant, deltas, _create_resources)
@classmethod
def add_instance_modules(cls, context, instance_id, modules):
for module in modules:
module_models.InstanceModule.create(
context, instance_id, module.id, module.md5)
def get_flavor(self): def get_flavor(self):
client = create_nova_client(self.context) client = create_nova_client(self.context)
return client.flavors.get(self.flavor_id) return client.flavors.get(self.flavor_id)
@ -1177,7 +1203,7 @@ class Instances(object):
DEFAULT_LIMIT = CONF.instances_page_size DEFAULT_LIMIT = CONF.instances_page_size
@staticmethod @staticmethod
def load(context, include_clustered): def load(context, include_clustered, instance_ids=None):
def load_simple_instance(context, db, status, **kwargs): def load_simple_instance(context, db, status, **kwargs):
return SimpleInstance(context, db, status) return SimpleInstance(context, db, status)
@ -1186,14 +1212,18 @@ class Instances(object):
raise TypeError("Argument context not defined.") raise TypeError("Argument context not defined.")
client = create_nova_client(context) client = create_nova_client(context)
servers = client.servers.list() servers = client.servers.list()
query_opts = {'tenant_id': context.tenant,
if include_clustered: 'deleted': False}
db_infos = DBInstance.find_all(tenant_id=context.tenant, if not include_clustered:
deleted=False) query_opts['cluster_id'] = None
if instance_ids and len(instance_ids) > 1:
raise exception.DatastoreOperationNotSupported(
operation='module-instances', datastore='current')
db_infos = DBInstance.query().filter_by(**query_opts)
else: else:
db_infos = DBInstance.find_all(tenant_id=context.tenant, if instance_ids:
cluster_id=None, query_opts['id'] = instance_ids[0]
deleted=False) db_infos = DBInstance.find_all(**query_opts)
limit = utils.pagination_limit(context.limit, Instances.DEFAULT_LIMIT) limit = utils.pagination_limit(context.limit, Instances.DEFAULT_LIMIT)
data_view = DBInstance.find_by_pagination('instances', db_infos, "foo", data_view = DBInstance.find_by_pagination('instances', db_infos, "foo",
limit=limit, limit=limit,

View File

@ -34,6 +34,8 @@ from trove.datastore import models as datastore_models
from trove.extensions.mysql.common import populate_users from trove.extensions.mysql.common import populate_users
from trove.extensions.mysql.common import populate_validated_databases from trove.extensions.mysql.common import populate_validated_databases
from trove.instance import models, views from trove.instance import models, views
from trove.module import models as module_models
from trove.module import views as module_views
CONF = cfg.CONF CONF = cfg.CONF
@ -205,14 +207,22 @@ class InstanceController(wsgi.Controller):
"'%(tenant_id)s'"), "'%(tenant_id)s'"),
{'instance_id': id, 'tenant_id': tenant_id}) {'instance_id': id, 'tenant_id': tenant_id})
LOG.debug("req : '%s'\n\n", req) LOG.debug("req : '%s'\n\n", req)
# TODO(hub-cap): turn this into middleware
context = req.environ[wsgi.CONTEXT_KEY] context = req.environ[wsgi.CONTEXT_KEY]
instance = models.load_any_instance(context, id) instance = models.load_any_instance(context, id)
context.notification = notification.DBaaSInstanceDelete(context, context.notification = notification.DBaaSInstanceDelete(
request=req) context, request=req)
with StartNotification(context, instance_id=instance.id): with StartNotification(context, instance_id=instance.id):
marker = 'foo'
while marker:
instance_modules, marker = module_models.InstanceModules.load(
context, instance_id=id)
for instance_module in instance_modules:
instance_module = module_models.InstanceModule.load(
context, instance_module['instance_id'],
instance_module['module_id'])
module_models.InstanceModule.delete(
context, instance_module)
instance.delete() instance.delete()
# TODO(cp16net): need to set the return code correctly
return wsgi.Result(None, 202) return wsgi.Result(None, 202)
def create(self, req, body, tenant_id): def create(self, req, body, tenant_id):
@ -262,6 +272,7 @@ class InstanceController(wsgi.Controller):
slave_of_id = body['instance'].get('replica_of') slave_of_id = body['instance'].get('replica_of')
replica_count = body['instance'].get('replica_count') replica_count = body['instance'].get('replica_count')
modules = body['instance'].get('modules')
instance = models.Instance.create(context, name, flavor_id, instance = models.Instance.create(context, name, flavor_id,
image_id, databases, users, image_id, databases, users,
datastore, datastore_version, datastore, datastore_version,
@ -269,7 +280,8 @@ class InstanceController(wsgi.Controller):
availability_zone, nics, availability_zone, nics,
configuration, slave_of_id, configuration, slave_of_id,
replica_count=replica_count, replica_count=replica_count,
volume_type=volume_type) volume_type=volume_type,
modules=modules)
view = views.InstanceDetailView(instance, req=req) view = views.InstanceDetailView(instance, req=req)
return wsgi.Result(view.data(), 200) return wsgi.Result(view.data(), 200)
@ -396,3 +408,66 @@ class InstanceController(wsgi.Controller):
guest_log = client.guest_log_action(log_name, enable, disable, guest_log = client.guest_log_action(log_name, enable, disable,
publish, discard) publish, discard)
return wsgi.Result({'log': guest_log}, 200) return wsgi.Result({'log': guest_log}, 200)
def module_list(self, req, tenant_id, id):
"""Return information about modules on an instance."""
context = req.environ[wsgi.CONTEXT_KEY]
instance = models.Instance.load(context, id)
if not instance:
raise exception.NotFound(uuid=id)
from_guest = bool(req.GET.get('from_guest', '').lower())
include_contents = bool(req.GET.get('include_contents', '').lower())
if from_guest:
return self._module_list_guest(
context, id, include_contents=include_contents)
else:
return self._module_list(
context, id, include_contents=include_contents)
def _module_list_guest(self, context, id, include_contents):
"""Return information about modules on an instance."""
client = create_guest_client(context, id)
result_list = client.module_list(include_contents)
return wsgi.Result({'modules': result_list}, 200)
def _module_list(self, context, id, include_contents):
"""Return information about instnace modules."""
client = create_guest_client(context, id)
result_list = client.module_list(include_contents)
return wsgi.Result({'modules': result_list}, 200)
def module_apply(self, req, body, tenant_id, id):
"""Apply modules to an instance."""
context = req.environ[wsgi.CONTEXT_KEY]
instance = models.Instance.load(context, id)
if not instance:
raise exception.NotFound(uuid=id)
module_ids = [mod['id'] for mod in body.get('modules', [])]
modules = module_models.Modules.load_by_ids(context, module_ids)
module_list = []
for module in modules:
module.contents = module_models.Module.deprocess_contents(
module.contents)
module_info = module_views.DetailedModuleView(module).data(
include_contents=True)
module_list.append(module_info)
client = create_guest_client(context, id)
result_list = client.module_apply(module_list)
models.Instance.add_instance_modules(context, id, modules)
return wsgi.Result({'modules': result_list}, 200)
def module_remove(self, req, tenant_id, id, module_id):
"""Remove module from an instance."""
context = req.environ[wsgi.CONTEXT_KEY]
instance = models.Instance.load(context, id)
if not instance:
raise exception.NotFound(uuid=id)
module = module_models.Module.load(context, module_id)
module_info = module_views.DetailedModuleView(module).data()
client = create_guest_client(context, id)
client.module_remove(module_info)
instance_module = module_models.InstanceModule.load(
context, instance_id=id, module_id=module_id)
if instance_module:
module_models.InstanceModule.delete(context, instance_module)
return wsgi.Result(None, 200)

View File

@ -18,14 +18,15 @@
from datetime import datetime from datetime import datetime
import hashlib import hashlib
from sqlalchemy.sql.expression import or_
from trove.common import cfg from trove.common import cfg
from trove.common import crypto_utils
from trove.common import exception from trove.common import exception
from trove.common.i18n import _ from trove.common.i18n import _
from trove.common import utils from trove.common import utils
from trove.datastore import models as datastore_models from trove.datastore import models as datastore_models
from trove.db import models from trove.db import models
from trove.instance import models as instances_models
from oslo_log import log as logging from oslo_log import log as logging
@ -38,32 +39,92 @@ class Modules(object):
DEFAULT_LIMIT = CONF.modules_page_size DEFAULT_LIMIT = CONF.modules_page_size
ENCRYPT_KEY = CONF.module_aes_cbc_key ENCRYPT_KEY = CONF.module_aes_cbc_key
VALID_MODULE_TYPES = CONF.module_types VALID_MODULE_TYPES = [mt.lower() for mt in CONF.module_types]
MATCH_ALL_NAME = 'all' MATCH_ALL_NAME = 'all'
@staticmethod @staticmethod
def load(context): def load(context, datastore=None):
if context is None: if context is None:
raise TypeError("Argument context not defined.") raise TypeError("Argument context not defined.")
elif id is None: elif id is None:
raise TypeError("Argument is not defined.") raise TypeError("Argument is not defined.")
query_opts = {'deleted': False}
if datastore:
if datastore.lower() == Modules.MATCH_ALL_NAME:
datastore = None
query_opts['datastore_id'] = datastore
if context.is_admin: if context.is_admin:
db_info = DBModule.find_all(deleted=False) db_info = DBModule.find_all(**query_opts)
if db_info.count() == 0: if db_info.count() == 0:
LOG.debug("No modules found for admin user") LOG.debug("No modules found for admin user")
else: else:
db_info = DBModule.find_all( # build a query manually, since we need current tenant
tenant_id=context.tenant, visible=True, deleted=False) # plus the 'all' tenant ones
query_opts['visible'] = True
db_info = DBModule.query().filter_by(**query_opts)
db_info = db_info.filter(or_(DBModule.tenant_id == context.tenant,
DBModule.tenant_id.is_(None)))
if db_info.count() == 0: if db_info.count() == 0:
LOG.debug("No modules found for tenant %s" % context.tenant) LOG.debug("No modules found for tenant %s" % context.tenant)
modules = db_info.all()
return modules
limit = utils.pagination_limit( @staticmethod
context.limit, Modules.DEFAULT_LIMIT) def load_auto_apply(context, datastore_id, datastore_version_id):
data_view = DBModule.find_by_pagination( """Return all the auto-apply modules for the given criteria."""
'modules', db_info, 'foo', limit=limit, marker=context.marker) if context is None:
next_marker = data_view.next_page_marker raise TypeError("Argument context not defined.")
return data_view.collection, next_marker elif id is None:
raise TypeError("Argument is not defined.")
query_opts = {'deleted': False,
'auto_apply': True}
db_info = DBModule.query().filter_by(**query_opts)
db_info = Modules.add_tenant_filter(db_info, context.tenant)
db_info = Modules.add_datastore_filter(db_info, datastore_id)
db_info = Modules.add_ds_version_filter(db_info, datastore_version_id)
if db_info.count() == 0:
LOG.debug("No auto-apply modules found for tenant %s" %
context.tenant)
modules = db_info.all()
return modules
@staticmethod
def add_tenant_filter(query, tenant_id):
return query.filter(or_(DBModule.tenant_id == tenant_id,
DBModule.tenant_id.is_(None)))
@staticmethod
def add_datastore_filter(query, datastore_id):
return query.filter(or_(DBModule.datastore_id == datastore_id,
DBModule.datastore_id.is_(None)))
@staticmethod
def add_ds_version_filter(query, datastore_version_id):
return query.filter(or_(
DBModule.datastore_version_id == datastore_version_id,
DBModule.datastore_version_id.is_(None)))
@staticmethod
def load_by_ids(context, module_ids):
"""Return all the modules for the given ids. Screens out the ones
for other tenants, unless the user is admin.
"""
if context is None:
raise TypeError("Argument context not defined.")
elif id is None:
raise TypeError("Argument is not defined.")
modules = []
if module_ids:
query_opts = {'deleted': False}
db_info = DBModule.query().filter_by(**query_opts)
if not context.is_admin:
db_info = Modules.add_tenant_filter(db_info, context.tenant)
db_info = db_info.filter(DBModule.id.in_(module_ids))
modules = db_info.all()
return modules
class Module(object): class Module(object):
@ -76,7 +137,8 @@ class Module(object):
def create(context, name, module_type, contents, def create(context, name, module_type, contents,
description, tenant_id, datastore, description, tenant_id, datastore,
datastore_version, auto_apply, visible, live_update): datastore_version, auto_apply, visible, live_update):
if module_type not in Modules.VALID_MODULE_TYPES: if module_type.lower() not in Modules.VALID_MODULE_TYPES:
LOG.error("Valid module types: %s" % Modules.VALID_MODULE_TYPES)
raise exception.ModuleTypeNotFound(module_type=module_type) raise exception.ModuleTypeNotFound(module_type=module_type)
Module.validate_action( Module.validate_action(
context, 'create', tenant_id, auto_apply, visible) context, 'create', tenant_id, auto_apply, visible)
@ -92,7 +154,7 @@ class Module(object):
md5, processed_contents = Module.process_contents(contents) md5, processed_contents = Module.process_contents(contents)
module = DBModule.create( module = DBModule.create(
name=name, name=name,
type=module_type, type=module_type.lower(),
contents=processed_contents, contents=processed_contents,
description=description, description=description,
tenant_id=tenant_id, tenant_id=tenant_id,
@ -156,9 +218,16 @@ class Module(object):
@staticmethod @staticmethod
def process_contents(contents): def process_contents(contents):
md5 = hashlib.md5(contents).hexdigest() md5 = hashlib.md5(contents).hexdigest()
encrypted_contents = utils.encrypt_string( encrypted_contents = crypto_utils.encrypt_data(
contents, Modules.ENCRYPT_KEY) contents, Modules.ENCRYPT_KEY)
return md5, utils.encode_string(encrypted_contents) return md5, crypto_utils.encode_data(encrypted_contents)
# Do the reverse to 'deprocess' the contents
@staticmethod
def deprocess_contents(processed_contents):
encrypted_contents = crypto_utils.decode_data(processed_contents)
return crypto_utils.decrypt_data(
encrypted_contents, Modules.ENCRYPT_KEY)
@staticmethod @staticmethod
def delete(context, module): def delete(context, module):
@ -173,46 +242,46 @@ class Module(object):
@staticmethod @staticmethod
def enforce_live_update(module_id, live_update, md5): def enforce_live_update(module_id, live_update, md5):
if not live_update: if not live_update:
instances = DBInstanceModules.find_all( instances = DBInstanceModule.find_all(
id=module_id, md5=md5, deleted=False).all() module_id=module_id, md5=md5, deleted=False).all()
if instances: if instances:
raise exception.ModuleAppliedToInstance() raise exception.ModuleAppliedToInstance()
@staticmethod @staticmethod
def load(context, module_id): def load(context, module_id):
module = None
try: try:
if context.is_admin: if context.is_admin:
return DBModule.find_by(id=module_id, deleted=False) module = DBModule.find_by(id=module_id, deleted=False)
else: else:
return DBModule.find_by( module = DBModule.find_by(
id=module_id, tenant_id=context.tenant, visible=True, id=module_id, tenant_id=context.tenant, visible=True,
deleted=False) deleted=False)
except exception.ModelNotFoundError: except exception.ModelNotFoundError:
# See if we have the module in the 'all' tenant section # See if we have the module in the 'all' tenant section
if not context.is_admin: if not context.is_admin:
try: try:
return DBModule.find_by( module = DBModule.find_by(
id=module_id, tenant_id=None, visible=True, id=module_id, tenant_id=None, visible=True,
deleted=False) deleted=False)
except exception.ModelNotFoundError: except exception.ModelNotFoundError:
pass # fall through to the raise below pass # fall through to the raise below
if not module:
msg = _("Module with ID %s could not be found.") % module_id msg = _("Module with ID %s could not be found.") % module_id
raise exception.ModelNotFoundError(msg) raise exception.ModelNotFoundError(msg)
# Save the encrypted contents in case we need to put it back
# when updating the record
module.encrypted_contents = module.contents
module.contents = Module.deprocess_contents(module.contents)
return module
@staticmethod @staticmethod
def update(context, module, original_module): def update(context, module, original_module):
Module.enforce_live_update( Module.enforce_live_update(
original_module.id, original_module.live_update, original_module.id, original_module.live_update,
original_module.md5) original_module.md5)
do_update = False
if module.contents != original_module.contents:
md5, processed_contents = Module.process_contents(module.contents)
do_update = (original_module.live_update and
md5 != original_module.md5)
module.md5 = md5
module.contents = processed_contents
else:
module.contents = original_module.contents
# we don't allow any changes to 'admin'-type modules, even if # we don't allow any changes to 'admin'-type modules, even if
# the values changed aren't the admin ones. # the values changed aren't the admin ones.
access_tenant_id = (None if (original_module.tenant_id is None or access_tenant_id = (None if (original_module.tenant_id is None or
@ -225,6 +294,14 @@ class Module(object):
access_tenant_id, access_auto_apply, access_visible) access_tenant_id, access_auto_apply, access_visible)
ds_id, ds_ver_id = Module.validate_datastore( ds_id, ds_ver_id = Module.validate_datastore(
module.datastore_id, module.datastore_version_id) module.datastore_id, module.datastore_version_id)
if module.contents != original_module.contents:
md5, processed_contents = Module.process_contents(module.contents)
module.md5 = md5
module.contents = processed_contents
else:
# on load the contents were decrypted, so
# we need to put the encrypted contents back before we update
module.contents = original_module.encrypted_contents
if module.datastore_id: if module.datastore_id:
module.datastore_id = ds_id module.datastore_id = ds_id
if module.datastore_version_id: if module.datastore_version_id:
@ -232,27 +309,73 @@ class Module(object):
module.updated = datetime.utcnow() module.updated = datetime.utcnow()
DBModule.save(module) DBModule.save(module)
if do_update:
Module.reapply_on_all_instances(context, module)
class InstanceModules(object):
@staticmethod @staticmethod
def reapply_on_all_instances(context, module): def load(context, instance_id=None, module_id=None, md5=None):
"""Reapply a module on all its instances, if required.""" selection = {'deleted': False}
if module.live_update: if instance_id:
instance_modules = DBInstanceModules.find_all( selection['instance_id'] = instance_id
id=module.id, deleted=False).all() if module_id:
selection['module_id'] = module_id
if md5:
selection['md5'] = md5
db_info = DBInstanceModule.find_all(**selection)
if db_info.count() == 0:
LOG.debug("No instance module records found")
LOG.debug( limit = utils.pagination_limit(
"All instances with module '%s' applied: %s" context.limit, Modules.DEFAULT_LIMIT)
% (module.id, instance_modules)) data_view = DBInstanceModule.find_by_pagination(
'modules', db_info, 'foo', limit=limit, marker=context.marker)
next_marker = data_view.next_page_marker
return data_view.collection, next_marker
for instance_module in instance_modules:
if instance_module.md5 != module.md5: class InstanceModule(object):
LOG.debug("Applying module '%s' to instance: %s"
% (module.id, instance_module.instance_id)) def __init__(self, context, instance_id, module_id):
instance = instances_models.Instance.load( self.context = context
context, instance_module.instance_id) self.instance_id = instance_id
instance.apply_module(module) self.module_id = module_id
@staticmethod
def create(context, instance_id, module_id, md5):
instance_module = DBInstanceModule.create(
instance_id=instance_id,
module_id=module_id,
md5=md5)
return instance_module
@staticmethod
def delete(context, instance_module):
instance_module.deleted = True
instance_module.deleted_at = datetime.utcnow()
instance_module.save()
@staticmethod
def load(context, instance_id, module_id, deleted=False):
instance_module = None
try:
instance_module = DBInstanceModule.find_by(
instance_id=instance_id, module_id=module_id, deleted=deleted)
except exception.ModelNotFoundError:
pass
return instance_module
@staticmethod
def update(context, instance_module):
instance_module.updated = datetime.utcnow()
DBInstanceModule.save(instance_module)
class DBInstanceModule(models.DatabaseModelBase):
_data_fields = [
'id', 'instance_id', 'module_id', 'md5', 'created',
'updated', 'deleted', 'deleted_at']
class DBModule(models.DatabaseModelBase): class DBModule(models.DatabaseModelBase):
@ -263,11 +386,5 @@ class DBModule(models.DatabaseModelBase):
'md5', 'created', 'updated', 'deleted', 'deleted_at'] 'md5', 'created', 'updated', 'deleted', 'deleted_at']
class DBInstanceModules(models.DatabaseModelBase):
_data_fields = [
'id', 'instance_id', 'module_id', 'md5', 'created',
'updated', 'deleted', 'deleted_at']
def persisted_models(): def persisted_models():
return {'modules': DBModule, 'instance_modules': DBInstanceModules} return {'modules': DBModule, 'instance_modules': DBInstanceModule}

View File

@ -23,6 +23,9 @@ from trove.common import cfg
from trove.common.i18n import _ from trove.common.i18n import _
from trove.common import pagination from trove.common import pagination
from trove.common import wsgi from trove.common import wsgi
from trove.datastore import models as datastore_models
from trove.instance import models as instance_models
from trove.instance import views as instance_views
from trove.module import models from trove.module import models
from trove.module import views from trove.module import views
@ -37,20 +40,22 @@ class ModuleController(wsgi.Controller):
def index(self, req, tenant_id): def index(self, req, tenant_id):
context = req.environ[wsgi.CONTEXT_KEY] context = req.environ[wsgi.CONTEXT_KEY]
modules, marker = models.Modules.load(context) datastore = req.GET.get('datastore', '')
if datastore and datastore.lower() != models.Modules.MATCH_ALL_NAME:
ds, ds_ver = datastore_models.get_datastore_version(
type=datastore)
datastore = ds.id
modules = models.Modules.load(context, datastore=datastore)
view = views.ModulesView(modules) view = views.ModulesView(modules)
paged = pagination.SimplePaginatedDataView(req.url, 'modules', return wsgi.Result(view.data(), 200)
view, marker)
return wsgi.Result(paged.data(), 200)
def show(self, req, tenant_id, id): def show(self, req, tenant_id, id):
LOG.info(_("Showing module %s") % id) LOG.info(_("Showing module %s") % id)
context = req.environ[wsgi.CONTEXT_KEY] context = req.environ[wsgi.CONTEXT_KEY]
module = models.Module.load(context, id) module = models.Module.load(context, id)
module.instance_count = models.DBInstanceModules.find_all( module.instance_count = len(models.InstanceModules.load(
id=module.id, md5=module.md5, context, module_id=module.id, md5=module.md5))
deleted=False).count()
return wsgi.Result( return wsgi.Result(
views.DetailedModuleView(module).data(), 200) views.DetailedModuleView(module).data(), 200)
@ -121,3 +126,24 @@ class ModuleController(wsgi.Controller):
models.Module.update(context, module, original_module) models.Module.update(context, module, original_module)
view_data = views.DetailedModuleView(module) view_data = views.DetailedModuleView(module)
return wsgi.Result(view_data.data(), 200) return wsgi.Result(view_data.data(), 200)
def instances(self, req, tenant_id, id):
LOG.info(_("Getting instances for module %s") % id)
context = req.environ[wsgi.CONTEXT_KEY]
instance_modules, marker = models.InstanceModules.load(
context, module_id=id)
if instance_modules:
instance_ids = [inst_mod.instance_id
for inst_mod in instance_modules]
include_clustered = (
req.GET.get('include_clustered', '').lower() == 'true')
instances, marker = instance_models.Instances.load(
context, include_clustered, instance_ids=instance_ids)
else:
instances = []
marker = None
view = instance_views.InstancesView(instances, req=req)
paged = pagination.SimplePaginatedDataView(req.url, 'instances',
view, marker)
return wsgi.Result(paged.data(), 200)

View File

@ -38,6 +38,7 @@ class ModuleView(object):
datastore_version_id=self.module.datastore_version_id, datastore_version_id=self.module.datastore_version_id,
auto_apply=self.module.auto_apply, auto_apply=self.module.auto_apply,
md5=self.module.md5, md5=self.module.md5,
visible=self.module.visible,
created=self.module.created, created=self.module.created,
updated=self.module.updated) updated=self.module.updated)
# add extra data to make results more legible # add extra data to make results more legible
@ -91,11 +92,12 @@ class DetailedModuleView(ModuleView):
def __init__(self, module): def __init__(self, module):
super(DetailedModuleView, self).__init__(module) super(DetailedModuleView, self).__init__(module)
def data(self): def data(self, include_contents=False):
return_value = super(DetailedModuleView, self).data() return_value = super(DetailedModuleView, self).data()
module_dict = return_value["module"] module_dict = return_value["module"]
module_dict["visible"] = self.module.visible
module_dict["live_update"] = self.module.live_update module_dict["live_update"] = self.module.live_update
if hasattr(self.module, 'instance_count'): if hasattr(self.module, 'instance_count'):
module_dict["instance_count"] = self.module.instance_count module_dict["instance_count"] = self.module.instance_count
if include_contents:
module_dict['contents'] = self.module.contents
return {"module": module_dict} return {"module": module_dict}

View File

@ -151,10 +151,10 @@ class API(object):
packages, volume_size, backup_id=None, packages, volume_size, backup_id=None,
availability_zone=None, root_password=None, availability_zone=None, root_password=None,
nics=None, overrides=None, slave_of_id=None, nics=None, overrides=None, slave_of_id=None,
cluster_config=None, volume_type=None): cluster_config=None, volume_type=None,
modules=None):
LOG.debug("Making async call to create instance %s " % instance_id) LOG.debug("Making async call to create instance %s " % instance_id)
self._cast("create_instance", self.version_cap, self._cast("create_instance", self.version_cap,
instance_id=instance_id, name=name, instance_id=instance_id, name=name,
flavor=self._transform_obj(flavor), flavor=self._transform_obj(flavor),
@ -171,7 +171,8 @@ class API(object):
overrides=overrides, overrides=overrides,
slave_of_id=slave_of_id, slave_of_id=slave_of_id,
cluster_config=cluster_config, cluster_config=cluster_config,
volume_type=volume_type) volume_type=volume_type,
modules=modules)
def create_cluster(self, cluster_id): def create_cluster(self, cluster_id):
LOG.debug("Making async call to create cluster %s " % cluster_id) LOG.debug("Making async call to create cluster %s " % cluster_id)

View File

@ -277,7 +277,7 @@ class Manager(periodic_task.PeriodicTasks):
datastore_manager, packages, volume_size, datastore_manager, packages, volume_size,
availability_zone, root_password, nics, availability_zone, root_password, nics,
overrides, slave_of_id, backup_id, overrides, slave_of_id, backup_id,
volume_type): volume_type, modules):
if type(instance_id) in [list]: if type(instance_id) in [list]:
ids = instance_id ids = instance_id
@ -307,7 +307,8 @@ class Manager(periodic_task.PeriodicTasks):
flavor, image_id, databases, users, datastore_manager, flavor, image_id, databases, users, datastore_manager,
packages, volume_size, replica_backup_id, packages, volume_size, replica_backup_id,
availability_zone, root_passwords[replica_index], availability_zone, root_passwords[replica_index],
nics, overrides, None, snapshot, volume_type) nics, overrides, None, snapshot, volume_type,
modules)
replicas.append(instance_tasks) replicas.append(instance_tasks)
except Exception: except Exception:
# if it's the first replica, then we shouldn't continue # if it's the first replica, then we shouldn't continue
@ -328,7 +329,7 @@ class Manager(periodic_task.PeriodicTasks):
image_id, databases, users, datastore_manager, image_id, databases, users, datastore_manager,
packages, volume_size, backup_id, availability_zone, packages, volume_size, backup_id, availability_zone,
root_password, nics, overrides, slave_of_id, root_password, nics, overrides, slave_of_id,
cluster_config, volume_type): cluster_config, volume_type, modules):
if slave_of_id: if slave_of_id:
self._create_replication_slave(context, instance_id, name, self._create_replication_slave(context, instance_id, name,
flavor, image_id, databases, users, flavor, image_id, databases, users,
@ -336,7 +337,7 @@ class Manager(periodic_task.PeriodicTasks):
volume_size, volume_size,
availability_zone, root_password, availability_zone, root_password,
nics, overrides, slave_of_id, nics, overrides, slave_of_id,
backup_id, volume_type) backup_id, volume_type, modules)
else: else:
if type(instance_id) in [list]: if type(instance_id) in [list]:
raise AttributeError(_( raise AttributeError(_(
@ -347,7 +348,7 @@ class Manager(periodic_task.PeriodicTasks):
volume_size, backup_id, volume_size, backup_id,
availability_zone, root_password, availability_zone, root_password,
nics, overrides, cluster_config, nics, overrides, cluster_config,
None, volume_type) None, volume_type, modules)
timeout = (CONF.restore_usage_timeout if backup_id timeout = (CONF.restore_usage_timeout if backup_id
else CONF.usage_timeout) else CONF.usage_timeout)
instance_tasks.wait_for_instance(timeout, flavor) instance_tasks.wait_for_instance(timeout, flavor)
@ -356,7 +357,7 @@ class Manager(periodic_task.PeriodicTasks):
image_id, databases, users, datastore_manager, image_id, databases, users, datastore_manager,
packages, volume_size, backup_id, availability_zone, packages, volume_size, backup_id, availability_zone,
root_password, nics, overrides, slave_of_id, root_password, nics, overrides, slave_of_id,
cluster_config, volume_type): cluster_config, volume_type, modules):
with EndNotification(context, with EndNotification(context,
instance_id=(instance_id[0] instance_id=(instance_id[0]
if type(instance_id) is list if type(instance_id) is list
@ -366,7 +367,7 @@ class Manager(periodic_task.PeriodicTasks):
datastore_manager, packages, volume_size, datastore_manager, packages, volume_size,
backup_id, availability_zone, backup_id, availability_zone,
root_password, nics, overrides, slave_of_id, root_password, nics, overrides, slave_of_id,
cluster_config, volume_type) cluster_config, volume_type, modules)
def update_overrides(self, context, instance_id, overrides): def update_overrides(self, context, instance_id, overrides):
instance_tasks = models.BuiltInstanceTasks.load(context, instance_id) instance_tasks = models.BuiltInstanceTasks.load(context, instance_id)

View File

@ -366,7 +366,8 @@ class FreshInstanceTasks(FreshInstance, NotifyMixin, ConfigurationMixin):
def create_instance(self, flavor, image_id, databases, users, def create_instance(self, flavor, image_id, databases, users,
datastore_manager, packages, volume_size, datastore_manager, packages, volume_size,
backup_id, availability_zone, root_password, nics, backup_id, availability_zone, root_password, nics,
overrides, cluster_config, snapshot, volume_type): overrides, cluster_config, snapshot, volume_type,
modules):
# It is the caller's responsibility to ensure that # It is the caller's responsibility to ensure that
# FreshInstanceTasks.wait_for_instance is called after # FreshInstanceTasks.wait_for_instance is called after
# create_instance to ensure that the proper usage event gets sent # create_instance to ensure that the proper usage event gets sent
@ -440,7 +441,7 @@ class FreshInstanceTasks(FreshInstance, NotifyMixin, ConfigurationMixin):
packages, databases, users, backup_info, packages, databases, users, backup_info,
config.config_contents, root_password, config.config_contents, root_password,
overrides, overrides,
cluster_config, snapshot) cluster_config, snapshot, modules)
if root_password: if root_password:
self.report_root_enabled() self.report_root_enabled()
@ -922,7 +923,8 @@ class FreshInstanceTasks(FreshInstance, NotifyMixin, ConfigurationMixin):
def _guest_prepare(self, flavor_ram, volume_info, def _guest_prepare(self, flavor_ram, volume_info,
packages, databases, users, backup_info=None, packages, databases, users, backup_info=None,
config_contents=None, root_password=None, config_contents=None, root_password=None,
overrides=None, cluster_config=None, snapshot=None): overrides=None, cluster_config=None, snapshot=None,
modules=None):
LOG.debug("Entering guest_prepare") LOG.debug("Entering guest_prepare")
# Now wait for the response from the create to do additional work # Now wait for the response from the create to do additional work
self.guest.prepare(flavor_ram, packages, databases, users, self.guest.prepare(flavor_ram, packages, databases, users,
@ -933,7 +935,7 @@ class FreshInstanceTasks(FreshInstance, NotifyMixin, ConfigurationMixin):
root_password=root_password, root_password=root_password,
overrides=overrides, overrides=overrides,
cluster_config=cluster_config, cluster_config=cluster_config,
snapshot=snapshot) snapshot=snapshot, modules=modules)
def _create_dns_entry(self): def _create_dns_entry(self):
dns_support = CONF.trove_dns_support dns_support = CONF.trove_dns_support

View File

@ -222,7 +222,7 @@ class FakeGuest(object):
def prepare(self, memory_mb, packages, databases, users, device_path=None, def prepare(self, memory_mb, packages, databases, users, device_path=None,
mount_point=None, backup_info=None, config_contents=None, mount_point=None, backup_info=None, config_contents=None,
root_password=None, overrides=None, cluster_config=None, root_password=None, overrides=None, cluster_config=None,
snapshot=None): snapshot=None, modules=None):
from trove.guestagent.models import AgentHeartBeat from trove.guestagent.models import AgentHeartBeat
from trove.instance.models import DBInstance from trove.instance.models import DBInstance
from trove.instance.models import InstanceServiceStatus from trove.instance.models import InstanceServiceStatus
@ -361,6 +361,15 @@ class FakeGuest(object):
def backup_required_for_replication(self): def backup_required_for_replication(self):
return True return True
def module_list(self, context, include_contents=False):
return []
def module_apply(self, context, modules=None):
return []
def module_remove(self, context, module=None):
pass
def get_or_create(id): def get_or_create(id):
if id not in DB: if id not in DB:

View File

@ -163,7 +163,7 @@ module_groups = list(instance_create_groups)
module_groups.extend([module_group.GROUP]) module_groups.extend([module_group.GROUP])
module_create_groups = list(base_groups) module_create_groups = list(base_groups)
module_create_groups.extend([module_group.GROUP_MODULE, module_create_groups.extend([module_group.GROUP_MODULE_CREATE,
module_group.GROUP_MODULE_DELETE]) module_group.GROUP_MODULE_DELETE])
replication_groups = list(instance_create_groups) replication_groups = list(instance_create_groups)

View File

@ -20,13 +20,13 @@ from trove.tests.scenario.groups import instance_create_group
from trove.tests.scenario.groups.test_group import TestGroup from trove.tests.scenario.groups.test_group import TestGroup
GROUP = "scenario.module_all_group" GROUP = "scenario.module_group"
GROUP_MODULE = "scenario.module_group" GROUP_MODULE_CREATE = "scenario.module_create_group"
GROUP_MODULE_DELETE = "scenario.module_delete_group"
GROUP_INSTANCE_MODULE = "scenario.instance_module_group" GROUP_INSTANCE_MODULE = "scenario.instance_module_group"
GROUP_MODULE_DELETE = "scenario.module_delete_group"
@test(groups=[GROUP, GROUP_MODULE]) @test(groups=[GROUP, GROUP_MODULE_CREATE])
class ModuleGroup(TestGroup): class ModuleGroup(TestGroup):
"""Test Module functionality.""" """Test Module functionality."""
@ -34,251 +34,291 @@ class ModuleGroup(TestGroup):
super(ModuleGroup, self).__init__( super(ModuleGroup, self).__init__(
'module_runners', 'ModuleRunner') 'module_runners', 'ModuleRunner')
@test(groups=[GROUP, GROUP_MODULE]) @test(groups=[GROUP, GROUP_MODULE_CREATE])
def module_delete_existing(self): def module_delete_existing(self):
"""Delete all previous test modules.""" """Delete all previous test modules."""
self.test_runner.run_module_delete_existing() self.test_runner.run_module_delete_existing()
@test(groups=[GROUP, GROUP_MODULE]) @test(groups=[GROUP, GROUP_MODULE_CREATE])
def module_create_bad_type(self): def module_create_bad_type(self):
"""Ensure create module fails with invalid type.""" """Ensure create module with invalid type fails."""
self.test_runner.run_module_create_bad_type() self.test_runner.run_module_create_bad_type()
@test(groups=[GROUP, GROUP_MODULE]) @test(groups=[GROUP, GROUP_MODULE_CREATE])
def module_create_non_admin_auto(self): def module_create_non_admin_auto(self):
"""Ensure create auto_apply module fails for non-admin.""" """Ensure create auto_apply module for non-admin fails."""
self.test_runner.run_module_create_non_admin_auto() self.test_runner.run_module_create_non_admin_auto()
@test(groups=[GROUP, GROUP_MODULE]) @test(groups=[GROUP, GROUP_MODULE_CREATE])
def module_create_non_admin_all_tenant(self): def module_create_non_admin_all_tenant(self):
"""Ensure create all tenant module fails for non-admin.""" """Ensure create all tenant module for non-admin fails."""
self.test_runner.run_module_create_non_admin_all_tenant() self.test_runner.run_module_create_non_admin_all_tenant()
@test(groups=[GROUP, GROUP_MODULE]) @test(groups=[GROUP, GROUP_MODULE_CREATE])
def module_create_non_admin_hidden(self): def module_create_non_admin_hidden(self):
"""Ensure create hidden module fails for non-admin.""" """Ensure create hidden module for non-admin fails."""
self.test_runner.run_module_create_non_admin_hidden() self.test_runner.run_module_create_non_admin_hidden()
@test(groups=[GROUP, GROUP_MODULE]) @test(groups=[GROUP, GROUP_MODULE_CREATE])
def module_create_bad_datastore(self): def module_create_bad_datastore(self):
"""Ensure create module fails with invalid datastore.""" """Ensure create module with invalid datastore fails."""
self.test_runner.run_module_create_bad_datastore() self.test_runner.run_module_create_bad_datastore()
@test(groups=[GROUP, GROUP_MODULE]) @test(groups=[GROUP, GROUP_MODULE_CREATE])
def module_create_bad_datastore_version(self): def module_create_bad_datastore_version(self):
"""Ensure create module fails with invalid datastore_version.""" """Ensure create module with invalid datastore_version fails."""
self.test_runner.run_module_create_bad_datastore_version() self.test_runner.run_module_create_bad_datastore_version()
@test(groups=[GROUP, GROUP_MODULE]) @test(groups=[GROUP, GROUP_MODULE_CREATE])
def module_create_missing_datastore(self): def module_create_missing_datastore(self):
"""Ensure create module fails with missing datastore.""" """Ensure create module with missing datastore fails."""
self.test_runner.run_module_create_missing_datastore() self.test_runner.run_module_create_missing_datastore()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
runs_after=[module_delete_existing]) runs_after=[module_delete_existing])
def module_create(self): def module_create(self):
"""Check that create module works.""" """Check that create module works."""
self.test_runner.run_module_create() self.test_runner.run_module_create()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create]) depends_on=[module_create])
def module_create_dupe(self): def module_create_dupe(self):
"""Ensure create with duplicate info fails.""" """Ensure create with duplicate info fails."""
self.test_runner.run_module_create_dupe() self.test_runner.run_module_create_dupe()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
runs_after=[module_create])
def module_create_bin(self):
"""Check that create module with binary contents works."""
self.test_runner.run_module_create_bin()
@test(groups=[GROUP, GROUP_MODULE_CREATE],
runs_after=[module_create_bin])
def module_create_bin2(self):
"""Check that create module with other binary contents works."""
self.test_runner.run_module_create_bin2()
@test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create]) depends_on=[module_create])
def module_show(self): def module_show(self):
"""Check that show module works.""" """Check that show module works."""
self.test_runner.run_module_show() self.test_runner.run_module_show()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create]) depends_on=[module_create])
def module_show_unauth_user(self): def module_show_unauth_user(self):
"""Ensure that show module for unauth user fails.""" """Ensure that show module for unauth user fails."""
self.test_runner.run_module_show_unauth_user() self.test_runner.run_module_show_unauth_user()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create]) depends_on=[module_create, module_create_bin, module_create_bin2])
def module_list(self): def module_list(self):
"""Check that list modules works.""" """Check that list modules works."""
self.test_runner.run_module_list() self.test_runner.run_module_list()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create]) depends_on=[module_create, module_create_bin, module_create_bin2])
def module_list_unauth_user(self): def module_list_unauth_user(self):
"""Ensure that list module for unauth user fails.""" """Ensure that list module for unauth user fails."""
self.test_runner.run_module_list_unauth_user() self.test_runner.run_module_list_unauth_user()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create], depends_on=[module_create, module_create_bin, module_create_bin2],
runs_after=[module_list]) runs_after=[module_list])
def module_create_admin_all(self): def module_create_admin_all(self):
"""Check that create module works with all admin options.""" """Check that create module works with all admin options."""
self.test_runner.run_module_create_admin_all() self.test_runner.run_module_create_admin_all()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create], depends_on=[module_create, module_create_bin, module_create_bin2],
runs_after=[module_create_admin_all]) runs_after=[module_create_admin_all])
def module_create_admin_hidden(self): def module_create_admin_hidden(self):
"""Check that create module works with hidden option.""" """Check that create module works with hidden option."""
self.test_runner.run_module_create_admin_hidden() self.test_runner.run_module_create_admin_hidden()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create], depends_on=[module_create, module_create_bin, module_create_bin2],
runs_after=[module_create_admin_hidden]) runs_after=[module_create_admin_hidden])
def module_create_admin_auto(self): def module_create_admin_auto(self):
"""Check that create module works with auto option.""" """Check that create module works with auto option."""
self.test_runner.run_module_create_admin_auto() self.test_runner.run_module_create_admin_auto()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create], depends_on=[module_create, module_create_bin, module_create_bin2],
runs_after=[module_create_admin_auto]) runs_after=[module_create_admin_auto])
def module_create_admin_live_update(self): def module_create_admin_live_update(self):
"""Check that create module works with live-update option.""" """Check that create module works with live-update option."""
self.test_runner.run_module_create_admin_live_update() self.test_runner.run_module_create_admin_live_update()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create], depends_on=[module_create, module_create_bin, module_create_bin2],
runs_after=[module_create_admin_live_update]) runs_after=[module_create_admin_live_update])
def module_create_datastore(self):
"""Check that create module with datastore works."""
self.test_runner.run_module_create_datastore()
@test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create, module_create_bin, module_create_bin2],
runs_after=[module_create_datastore])
def module_create_ds_version(self):
"""Check that create module with ds version works."""
self.test_runner.run_module_create_ds_version()
@test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create, module_create_bin, module_create_bin2],
runs_after=[module_create_ds_version])
def module_create_all_tenant(self): def module_create_all_tenant(self):
"""Check that create 'all' tenants with datastore module works.""" """Check that create 'all' tenants with datastore module works."""
self.test_runner.run_module_create_all_tenant() self.test_runner.run_module_create_all_tenant()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create], depends_on=[module_create, module_create_bin, module_create_bin2],
runs_after=[module_create_all_tenant, module_list_unauth_user]) runs_after=[module_create_all_tenant, module_list_unauth_user])
def module_create_different_tenant(self): def module_create_different_tenant(self):
"""Check that create with same name on different tenant works.""" """Check that create with same name on different tenant works."""
self.test_runner.run_module_create_different_tenant() self.test_runner.run_module_create_different_tenant()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create_all_tenant], depends_on=[module_create_all_tenant],
runs_after=[module_create_different_tenant]) runs_after=[module_create_different_tenant])
def module_list_again(self): def module_list_again(self):
"""Check that list modules skips invisible modules.""" """Check that list modules skips invisible modules."""
self.test_runner.run_module_list_again() self.test_runner.run_module_list_again()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create_ds_version],
runs_after=[module_list_again])
def module_list_ds(self):
"""Check that list modules by datastore works."""
self.test_runner.run_module_list_ds()
@test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create_ds_version],
runs_after=[module_list_ds])
def module_list_ds_all(self):
"""Check that list modules by all datastores works."""
self.test_runner.run_module_list_ds_all()
@test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create_admin_hidden]) depends_on=[module_create_admin_hidden])
def module_show_invisible(self): def module_show_invisible(self):
"""Ensure that show invisible module for non-admin fails.""" """Ensure that show invisible module for non-admin fails."""
self.test_runner.run_module_show_invisible() self.test_runner.run_module_show_invisible()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create_all_tenant], depends_on=[module_create_all_tenant],
runs_after=[module_create_different_tenant]) runs_after=[module_create_different_tenant])
def module_list_admin(self): def module_list_admin(self):
"""Check that list modules for admin works.""" """Check that list modules for admin works."""
self.test_runner.run_module_list_admin() self.test_runner.run_module_list_admin()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_create], depends_on=[module_create],
runs_after=[module_show]) runs_after=[module_show])
def module_update(self): def module_update(self):
"""Check that update module works.""" """Check that update module works."""
self.test_runner.run_module_update() self.test_runner.run_module_update()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update]) depends_on=[module_update])
def module_update_same_contents(self): def module_update_same_contents(self):
"""Check that update module with same contents works.""" """Check that update module with same contents works."""
self.test_runner.run_module_update_same_contents() self.test_runner.run_module_update_same_contents()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update], depends_on=[module_update],
runs_after=[module_update_same_contents]) runs_after=[module_update_same_contents])
def module_update_auto_toggle(self): def module_update_auto_toggle(self):
"""Check that update module works for auto apply toggle.""" """Check that update module works for auto apply toggle."""
self.test_runner.run_module_update_auto_toggle() self.test_runner.run_module_update_auto_toggle()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update], depends_on=[module_update],
runs_after=[module_update_auto_toggle]) runs_after=[module_update_auto_toggle])
def module_update_all_tenant_toggle(self): def module_update_all_tenant_toggle(self):
"""Check that update module works for all tenant toggle.""" """Check that update module works for all tenant toggle."""
self.test_runner.run_module_update_all_tenant_toggle() self.test_runner.run_module_update_all_tenant_toggle()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update], depends_on=[module_update],
runs_after=[module_update_all_tenant_toggle]) runs_after=[module_update_all_tenant_toggle])
def module_update_invisible_toggle(self): def module_update_invisible_toggle(self):
"""Check that update module works for invisible toggle.""" """Check that update module works for invisible toggle."""
self.test_runner.run_module_update_invisible_toggle() self.test_runner.run_module_update_invisible_toggle()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update], depends_on=[module_update],
runs_after=[module_update_invisible_toggle]) runs_after=[module_update_invisible_toggle])
def module_update_unauth(self): def module_update_unauth(self):
"""Ensure update module fails for unauth user.""" """Ensure update module for unauth user fails."""
self.test_runner.run_module_update_unauth() self.test_runner.run_module_update_unauth()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update], depends_on=[module_update],
runs_after=[module_update_invisible_toggle]) runs_after=[module_update_invisible_toggle])
def module_update_non_admin_auto(self): def module_update_non_admin_auto(self):
"""Ensure update module to auto_apply fails for non-admin.""" """Ensure update module to auto_apply for non-admin fails."""
self.test_runner.run_module_update_non_admin_auto() self.test_runner.run_module_update_non_admin_auto()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update], depends_on=[module_update],
runs_after=[module_update_invisible_toggle]) runs_after=[module_update_invisible_toggle])
def module_update_non_admin_auto_off(self): def module_update_non_admin_auto_off(self):
"""Ensure update module to auto_apply off fails for non-admin.""" """Ensure update module to auto_apply off for non-admin fails."""
self.test_runner.run_module_update_non_admin_auto_off() self.test_runner.run_module_update_non_admin_auto_off()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update], depends_on=[module_update],
runs_after=[module_update_invisible_toggle]) runs_after=[module_update_invisible_toggle])
def module_update_non_admin_auto_any(self): def module_update_non_admin_auto_any(self):
"""Ensure any update module to auto_apply fails for non-admin.""" """Ensure any update module to auto_apply for non-admin fails."""
self.test_runner.run_module_update_non_admin_auto_any() self.test_runner.run_module_update_non_admin_auto_any()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update], depends_on=[module_update],
runs_after=[module_update_invisible_toggle]) runs_after=[module_update_invisible_toggle])
def module_update_non_admin_all_tenant(self): def module_update_non_admin_all_tenant(self):
"""Ensure update module to all tenant fails for non-admin.""" """Ensure update module to all tenant for non-admin fails."""
self.test_runner.run_module_update_non_admin_all_tenant() self.test_runner.run_module_update_non_admin_all_tenant()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update], depends_on=[module_update],
runs_after=[module_update_invisible_toggle]) runs_after=[module_update_invisible_toggle])
def module_update_non_admin_all_tenant_off(self): def module_update_non_admin_all_tenant_off(self):
"""Ensure update module to all tenant off fails for non-admin.""" """Ensure update module to all tenant off for non-admin fails."""
self.test_runner.run_module_update_non_admin_all_tenant_off() self.test_runner.run_module_update_non_admin_all_tenant_off()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update], depends_on=[module_update],
runs_after=[module_update_invisible_toggle]) runs_after=[module_update_invisible_toggle])
def module_update_non_admin_all_tenant_any(self): def module_update_non_admin_all_tenant_any(self):
"""Ensure any update module to all tenant fails for non-admin.""" """Ensure any update module to all tenant for non-admin fails."""
self.test_runner.run_module_update_non_admin_all_tenant_any() self.test_runner.run_module_update_non_admin_all_tenant_any()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update], depends_on=[module_update],
runs_after=[module_update_invisible_toggle]) runs_after=[module_update_invisible_toggle])
def module_update_non_admin_invisible(self): def module_update_non_admin_invisible(self):
"""Ensure update module to invisible fails for non-admin.""" """Ensure update module to invisible for non-admin fails."""
self.test_runner.run_module_update_non_admin_invisible() self.test_runner.run_module_update_non_admin_invisible()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update], depends_on=[module_update],
runs_after=[module_update_invisible_toggle]) runs_after=[module_update_invisible_toggle])
def module_update_non_admin_invisible_off(self): def module_update_non_admin_invisible_off(self):
"""Ensure update module to invisible off fails for non-admin.""" """Ensure update module to invisible off for non-admin fails."""
self.test_runner.run_module_update_non_admin_invisible_off() self.test_runner.run_module_update_non_admin_invisible_off()
@test(groups=[GROUP, GROUP_MODULE], @test(groups=[GROUP, GROUP_MODULE_CREATE],
depends_on=[module_update], depends_on=[module_update],
runs_after=[module_update_invisible_toggle]) runs_after=[module_update_invisible_toggle])
def module_update_non_admin_invisible_any(self): def module_update_non_admin_invisible_any(self):
"""Ensure any update module to invisible fails for non-admin.""" """Ensure any update module to invisible for non-admin fails."""
self.test_runner.run_module_update_non_admin_invisible_any() self.test_runner.run_module_update_non_admin_invisible_any()
@test(depends_on_groups=[instance_create_group.GROUP, @test(depends_on_groups=[instance_create_group.GROUP,
GROUP_MODULE], GROUP_MODULE_CREATE],
groups=[GROUP, GROUP_INSTANCE_MODULE]) groups=[GROUP, GROUP_INSTANCE_MODULE])
class ModuleInstanceGroup(TestGroup): class ModuleInstanceGroup(TestGroup):
"""Test Instance Module functionality.""" """Test Instance Module functionality."""
@ -287,8 +327,118 @@ class ModuleInstanceGroup(TestGroup):
super(ModuleInstanceGroup, self).__init__( super(ModuleInstanceGroup, self).__init__(
'module_runners', 'ModuleRunner') 'module_runners', 'ModuleRunner')
@test(groups=[GROUP, GROUP_INSTANCE_MODULE])
def module_list_instance_empty(self):
"""Check that the instance has no modules associated."""
self.test_runner.run_module_list_instance_empty()
@test(depends_on_groups=[GROUP_MODULE], @test(groups=[GROUP, GROUP_INSTANCE_MODULE],
runs_after=[module_list_instance_empty])
def module_instances_empty(self):
"""Check that the module hasn't been applied to any instances."""
self.test_runner.run_module_instances_empty()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
runs_after=[module_instances_empty])
def module_query_empty(self):
"""Check that the instance has no modules applied."""
self.test_runner.run_module_query_empty()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
runs_after=[module_query_empty])
def module_apply(self):
"""Check that module-apply works."""
self.test_runner.run_module_apply()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
depends_on=[module_apply])
def module_list_instance_after_apply(self):
"""Check that the instance has one module associated."""
self.test_runner.run_module_list_instance_after_apply()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
depends_on=[module_apply])
def module_query_after_apply(self):
"""Check that module-query works."""
self.test_runner.run_module_query_after_apply()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
depends_on=[module_apply],
runs_after=[module_query_after_apply])
def create_inst_with_mods(self):
"""Check that creating an instance with modules works."""
self.test_runner.run_create_inst_with_mods()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
depends_on=[module_apply])
def module_delete_applied(self):
"""Ensure that deleting an applied module fails."""
self.test_runner.run_module_delete_applied()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
depends_on=[module_apply],
runs_after=[module_list_instance_after_apply,
module_query_after_apply])
def module_remove(self):
"""Check that module-remove works."""
self.test_runner.run_module_remove()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
depends_on=[module_remove])
def module_query_empty_after(self):
"""Check that the instance has no modules applied after remove."""
self.test_runner.run_module_query_empty()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
depends_on=[create_inst_with_mods],
runs_after=[module_query_empty_after])
def wait_for_inst_with_mods(self):
"""Wait for create instance with modules to finish."""
self.test_runner.run_wait_for_inst_with_mods()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
depends_on=[wait_for_inst_with_mods])
def module_query_after_inst_create(self):
"""Check that module-query works on new instance."""
self.test_runner.run_module_query_after_inst_create()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
depends_on=[wait_for_inst_with_mods],
runs_after=[module_query_after_inst_create])
def module_retrieve_after_inst_create(self):
"""Check that module-retrieve works on new instance."""
self.test_runner.run_module_retrieve_after_inst_create()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
depends_on=[wait_for_inst_with_mods],
runs_after=[module_retrieve_after_inst_create])
def module_query_after_inst_create_admin(self):
"""Check that module-query works for admin."""
self.test_runner.run_module_query_after_inst_create_admin()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
depends_on=[wait_for_inst_with_mods],
runs_after=[module_query_after_inst_create_admin])
def module_retrieve_after_inst_create_admin(self):
"""Check that module-retrieve works for admin."""
self.test_runner.run_module_retrieve_after_inst_create_admin()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
depends_on=[wait_for_inst_with_mods],
runs_after=[module_retrieve_after_inst_create_admin])
def module_delete_auto_applied(self):
"""Ensure that module-delete on auto-applied module fails."""
self.test_runner.run_module_delete_auto_applied()
@test(groups=[GROUP, GROUP_INSTANCE_MODULE],
depends_on=[wait_for_inst_with_mods],
runs_after=[module_delete_auto_applied])
def delete_inst_with_mods(self):
"""Check that instance with module can be deleted."""
self.test_runner.run_delete_inst_with_mods()
@test(depends_on_groups=[GROUP_MODULE_CREATE],
groups=[GROUP, GROUP_MODULE_DELETE]) groups=[GROUP, GROUP_MODULE_DELETE])
class ModuleDeleteGroup(TestGroup): class ModuleDeleteGroup(TestGroup):
"""Test Module Delete functionality.""" """Test Module Delete functionality."""
@ -329,16 +479,16 @@ class ModuleDeleteGroup(TestGroup):
runs_after=[module_delete_auto_by_non_admin]) runs_after=[module_delete_auto_by_non_admin])
def module_delete(self): def module_delete(self):
"""Check that delete module works.""" """Check that delete module works."""
self.test_runner.run_module_delete_auto_by_non_admin()
@test(groups=[GROUP, GROUP_MODULE_DELETE],
runs_after=[module_delete])
def module_delete_all(self):
"""Check that delete module works for admin."""
self.test_runner.run_module_delete() self.test_runner.run_module_delete()
@test(groups=[GROUP, GROUP_MODULE_DELETE], @test(groups=[GROUP, GROUP_MODULE_DELETE],
runs_after=[module_delete_all]) runs_after=[module_delete])
def module_delete_existing(self): def module_delete_admin(self):
"""Check that delete module works for admin."""
self.test_runner.run_module_delete_admin()
@test(groups=[GROUP, GROUP_MODULE_DELETE],
runs_after=[module_delete_admin])
def module_delete_remaining(self):
"""Delete all remaining test modules.""" """Delete all remaining test modules."""
self.test_runner.run_module_delete_existing() self.test_runner.run_module_delete_existing()

View File

@ -439,4 +439,4 @@ class TestHelper(object):
############## ##############
def get_valid_module_type(self): def get_valid_module_type(self):
"""Return a valid module type.""" """Return a valid module type."""
return "test" return "Ping"

View File

@ -96,5 +96,5 @@ class InstanceActionsRunner(TestRunner):
expected_http_code) expected_http_code)
instance = self.get_instance(instance_id) instance = self.get_instance(instance_id)
self.assert_equal(int(instance.flavor['id']), resize_flavor.id, self.assert_equal(instance.flavor['id'], resize_flavor.id,
'Unexpected resize flavor_id') 'Unexpected resize flavor_id')

View File

@ -14,22 +14,35 @@
# under the License. # under the License.
# #
import Crypto.Random
from proboscis import SkipTest from proboscis import SkipTest
import tempfile
from troveclient.compat import exceptions from troveclient.compat import exceptions
from trove.common import utils from trove.guestagent.common import guestagent_utils
from trove.guestagent.common import operating_system
from trove.module import models from trove.module import models
from trove.tests.scenario.runners.test_runners import TestRunner from trove.tests.scenario.runners.test_runners import TestRunner
# Variables here are set up to be used across multiple groups, # Variables here are set up to be used across multiple groups,
# since each group will instantiate a new runner # since each group will instantiate a new runner
random_data = Crypto.Random.new().read(20)
test_modules = [] test_modules = []
module_count_prior_to_create = 0 module_count_prior_to_create = 0
module_ds_count_prior_to_create = 0
module_ds_all_count_prior_to_create = 0
module_all_tenant_count_prior_to_create = 0
module_auto_apply_count_prior_to_create = 0
module_admin_count_prior_to_create = 0 module_admin_count_prior_to_create = 0
module_other_count_prior_to_create = 0 module_other_count_prior_to_create = 0
module_create_count = 0 module_create_count = 0
module_ds_create_count = 0
module_ds_all_create_count = 0
module_all_tenant_create_count = 0
module_auto_apply_create_count = 0
module_admin_create_count = 0 module_admin_create_count = 0
module_other_create_count = 0 module_other_create_count = 0
@ -42,11 +55,17 @@ class ModuleRunner(TestRunner):
super(ModuleRunner, self).__init__( super(ModuleRunner, self).__init__(
sleep_time=10, timeout=self.TIMEOUT_MODULE_APPLY) sleep_time=10, timeout=self.TIMEOUT_MODULE_APPLY)
self.MODULE_CONTENTS_PATTERN = 'Message=%s\n'
self.MODULE_MESSAGE_PATTERN = 'Hello World from: %s'
self.MODULE_NAME = 'test_module_1' self.MODULE_NAME = 'test_module_1'
self.MODULE_DESC = 'test description' self.MODULE_DESC = 'test description'
self.MODULE_CONTENTS = utils.encode_string( self.MODULE_NEG_CONTENTS = 'contents for negative tests'
'mode=echo\nkey=mysecretkey\n') self.MODULE_BINARY_SUFFIX = '_bin_auto'
self.MODULE_BINARY_SUFFIX2 = self.MODULE_BINARY_SUFFIX + '_2'
self.MODULE_BINARY_CONTENTS = random_data
self.MODULE_BINARY_CONTENTS2 = '\x00\xFF\xea\x9c\x11\xfeok\xb1\x8ax'
self.mod_inst_id = None
self.temp_module = None self.temp_module = None
self._module_type = None self._module_type = None
@ -62,6 +81,57 @@ class ModuleRunner(TestRunner):
SkipTest("No main module created") SkipTest("No main module created")
return test_modules[0] return test_modules[0]
def build_module_args(self, extra=None):
extra = extra or ''
name = self.MODULE_NAME + extra
desc = self.MODULE_DESC + extra.replace('_', ' ')
cont = self.get_module_contents(name)
return name, desc, cont
def get_module_contents(self, name=None):
message = self.get_module_message(name=name)
return self.MODULE_CONTENTS_PATTERN % message
def get_module_message(self, name=None):
name = name or self.MODULE_NAME
return self.MODULE_MESSAGE_PATTERN % name
def _find_invisible_module(self):
def _match(mod):
return not mod.visible and mod.tenant_id and not mod.auto_apply
return self._find_module(_match, "Could not find invisible module")
def _find_module(self, match_fn, not_found_message, find_all=False):
found = [] if find_all else None
for test_module in test_modules:
if match_fn(test_module):
if find_all:
found.append(test_module)
else:
found = test_module
break
if not found:
self.fail(not_found_message)
return found
def _find_auto_apply_module(self):
def _match(mod):
return mod.auto_apply and mod.tenant_id and mod.visible
return self._find_module(_match, "Could not find auto-apply module")
def _find_all_tenant_module(self):
def _match(mod):
return mod.tenant_id is None and mod.visible
return self._find_module(_match, "Could not find all tenant module")
def _find_all_auto_apply_modules(self, visible=None):
def _match(mod):
return mod.auto_apply and (
visible is None or mod.visible == visible)
return self._find_module(
_match, "Could not find all auto apply modules", find_all=True)
# Tests start here
def run_module_delete_existing(self): def run_module_delete_existing(self):
modules = self.admin_client.modules.list() modules = self.admin_client.modules.list()
for module in modules: for module in modules:
@ -74,7 +144,7 @@ class ModuleRunner(TestRunner):
self.assert_raises( self.assert_raises(
expected_exception, expected_http_code, expected_exception, expected_http_code,
self.auth_client.modules.create, self.auth_client.modules.create,
self.MODULE_NAME, 'invalid-type', self.MODULE_CONTENTS) self.MODULE_NAME, 'invalid-type', self.MODULE_NEG_CONTENTS)
def run_module_create_non_admin_auto( def run_module_create_non_admin_auto(
self, expected_exception=exceptions.Forbidden, self, expected_exception=exceptions.Forbidden,
@ -82,7 +152,7 @@ class ModuleRunner(TestRunner):
self.assert_raises( self.assert_raises(
expected_exception, expected_http_code, expected_exception, expected_http_code,
self.auth_client.modules.create, self.auth_client.modules.create,
self.MODULE_NAME, self.module_type, self.MODULE_CONTENTS, self.MODULE_NAME, self.module_type, self.MODULE_NEG_CONTENTS,
auto_apply=True) auto_apply=True)
def run_module_create_non_admin_all_tenant( def run_module_create_non_admin_all_tenant(
@ -91,7 +161,7 @@ class ModuleRunner(TestRunner):
self.assert_raises( self.assert_raises(
expected_exception, expected_http_code, expected_exception, expected_http_code,
self.auth_client.modules.create, self.auth_client.modules.create,
self.MODULE_NAME, self.module_type, self.MODULE_CONTENTS, self.MODULE_NAME, self.module_type, self.MODULE_NEG_CONTENTS,
all_tenants=True) all_tenants=True)
def run_module_create_non_admin_hidden( def run_module_create_non_admin_hidden(
@ -100,7 +170,7 @@ class ModuleRunner(TestRunner):
self.assert_raises( self.assert_raises(
expected_exception, expected_http_code, expected_exception, expected_http_code,
self.auth_client.modules.create, self.auth_client.modules.create,
self.MODULE_NAME, self.module_type, self.MODULE_CONTENTS, self.MODULE_NAME, self.module_type, self.MODULE_NEG_CONTENTS,
visible=False) visible=False)
def run_module_create_bad_datastore( def run_module_create_bad_datastore(
@ -109,7 +179,7 @@ class ModuleRunner(TestRunner):
self.assert_raises( self.assert_raises(
expected_exception, expected_http_code, expected_exception, expected_http_code,
self.auth_client.modules.create, self.auth_client.modules.create,
self.MODULE_NAME, self.module_type, self.MODULE_CONTENTS, self.MODULE_NAME, self.module_type, self.MODULE_NEG_CONTENTS,
datastore='bad-datastore') datastore='bad-datastore')
def run_module_create_bad_datastore_version( def run_module_create_bad_datastore_version(
@ -118,7 +188,7 @@ class ModuleRunner(TestRunner):
self.assert_raises( self.assert_raises(
expected_exception, expected_http_code, expected_exception, expected_http_code,
self.auth_client.modules.create, self.auth_client.modules.create,
self.MODULE_NAME, self.module_type, self.MODULE_CONTENTS, self.MODULE_NAME, self.module_type, self.MODULE_NEG_CONTENTS,
datastore=self.instance_info.dbaas_datastore, datastore=self.instance_info.dbaas_datastore,
datastore_version='bad-datastore-version') datastore_version='bad-datastore-version')
@ -128,26 +198,42 @@ class ModuleRunner(TestRunner):
self.assert_raises( self.assert_raises(
expected_exception, expected_http_code, expected_exception, expected_http_code,
self.auth_client.modules.create, self.auth_client.modules.create,
self.MODULE_NAME, self.module_type, self.MODULE_CONTENTS, self.MODULE_NAME, self.module_type, self.MODULE_NEG_CONTENTS,
datastore_version=self.instance_info.dbaas_datastore_version) datastore_version=self.instance_info.dbaas_datastore_version)
def run_module_create(self): def run_module_create(self):
# Necessary to test that the count increases. # Necessary to test that the count increases.
global module_count_prior_to_create global module_count_prior_to_create
global module_ds_count_prior_to_create
global module_ds_all_count_prior_to_create
global module_all_tenant_count_prior_to_create
global module_auto_apply_count_prior_to_create
global module_admin_count_prior_to_create global module_admin_count_prior_to_create
global module_other_count_prior_to_create global module_other_count_prior_to_create
module_count_prior_to_create = len( module_count_prior_to_create = len(
self.auth_client.modules.list()) self.auth_client.modules.list())
module_ds_count_prior_to_create = len(
self.auth_client.modules.list(
datastore=self.instance_info.dbaas_datastore))
module_ds_all_count_prior_to_create = len(
self.auth_client.modules.list(
datastore=models.Modules.MATCH_ALL_NAME))
module_all_tenant_count_prior_to_create = len(
self.unauth_client.modules.list())
module_auto_apply_count_prior_to_create = len(
[module for module in self.admin_client.modules.list()
if module.auto_apply])
module_admin_count_prior_to_create = len( module_admin_count_prior_to_create = len(
self.admin_client.modules.list()) self.admin_client.modules.list())
module_other_count_prior_to_create = len( module_other_count_prior_to_create = len(
self.unauth_client.modules.list()) self.unauth_client.modules.list())
name, description, contents = self.build_module_args()
self.assert_module_create( self.assert_module_create(
self.auth_client, self.auth_client,
name=self.MODULE_NAME, name=name,
module_type=self.module_type, module_type=self.module_type,
contents=self.MODULE_CONTENTS, contents=contents,
description=self.MODULE_DESC) description=description)
def assert_module_create(self, client, name=None, module_type=None, def assert_module_create(self, client, name=None, module_type=None,
contents=None, description=None, contents=None, description=None,
@ -163,15 +249,27 @@ class ModuleRunner(TestRunner):
auto_apply=auto_apply, auto_apply=auto_apply,
live_update=live_update, visible=visible) live_update=live_update, visible=visible)
global module_create_count global module_create_count
global module_ds_create_count
global module_ds_all_create_count
global module_auto_apply_create_count
global module_all_tenant_create_count
global module_admin_create_count global module_admin_create_count
global module_other_create_count global module_other_create_count
if (client == self.auth_client or if (client == self.auth_client or
(client == self.admin_client and visible)): (client == self.admin_client and visible)):
module_create_count += 1 module_create_count += 1
if datastore:
module_ds_create_count += 1
else:
module_ds_all_create_count += 1
elif not visible: elif not visible:
module_admin_create_count += 1 module_admin_create_count += 1
else: else:
module_other_create_count += 1 module_other_create_count += 1
if all_tenants and visible:
module_all_tenant_create_count += 1
if auto_apply and visible:
module_auto_apply_create_count += 1
global test_modules global test_modules
test_modules.append(result) test_modules.append(result)
@ -179,7 +277,8 @@ class ModuleRunner(TestRunner):
tenant = models.Modules.MATCH_ALL_NAME tenant = models.Modules.MATCH_ALL_NAME
if not all_tenants: if not all_tenants:
tenant, tenant_id = self.get_client_tenant(client) tenant, tenant_id = self.get_client_tenant(client)
# TODO(peterstac) we don't support tenant name yet ... # If we find a way to grab the tenant name in the module
# stuff, the line below can be removed
tenant = tenant_id tenant = tenant_id
datastore = datastore or models.Modules.MATCH_ALL_NAME datastore = datastore or models.Modules.MATCH_ALL_NAME
datastore_version = datastore_version or models.Modules.MATCH_ALL_NAME datastore_version = datastore_version or models.Modules.MATCH_ALL_NAME
@ -192,7 +291,8 @@ class ModuleRunner(TestRunner):
expected_tenant_id=tenant_id, expected_tenant_id=tenant_id,
expected_datastore=datastore, expected_datastore=datastore,
expected_ds_version=datastore_version, expected_ds_version=datastore_version,
expected_auto_apply=auto_apply) expected_auto_apply=auto_apply,
expected_contents=contents)
def validate_module(self, module, validate_all=False, def validate_module(self, module, validate_all=False,
expected_name=None, expected_name=None,
@ -216,7 +316,7 @@ class ModuleRunner(TestRunner):
self.assert_equal(expected_name, module.name, self.assert_equal(expected_name, module.name,
'Unexpected module name') 'Unexpected module name')
if expected_module_type: if expected_module_type:
self.assert_equal(expected_module_type, module.type, self.assert_equal(expected_module_type.lower(), module.type,
'Unexpected module type') 'Unexpected module type')
if expected_description: if expected_description:
self.assert_equal(expected_description, module.description, self.assert_equal(expected_description, module.description,
@ -258,7 +358,31 @@ class ModuleRunner(TestRunner):
self.assert_raises( self.assert_raises(
expected_exception, expected_http_code, expected_exception, expected_http_code,
self.auth_client.modules.create, self.auth_client.modules.create,
self.MODULE_NAME, self.module_type, self.MODULE_CONTENTS) self.MODULE_NAME, self.module_type, self.MODULE_NEG_CONTENTS)
def run_module_create_bin(self):
name, description, contents = self.build_module_args(
self.MODULE_BINARY_SUFFIX)
self.assert_module_create(
self.admin_client,
name=name,
module_type=self.module_type,
contents=self.MODULE_BINARY_CONTENTS,
description=description,
auto_apply=True,
visible=False)
def run_module_create_bin2(self):
name, description, contents = self.build_module_args(
self.MODULE_BINARY_SUFFIX2)
self.assert_module_create(
self.admin_client,
name=name,
module_type=self.module_type,
contents=self.MODULE_BINARY_CONTENTS2,
description=description,
auto_apply=True,
visible=False)
def run_module_show(self): def run_module_show(self):
test_module = self.main_test_module test_module = self.main_test_module
@ -291,9 +415,12 @@ class ModuleRunner(TestRunner):
self.auth_client, self.auth_client,
module_count_prior_to_create + module_create_count) module_count_prior_to_create + module_create_count)
def assert_module_list(self, client, expected_count, def assert_module_list(self, client, expected_count, datastore=None,
skip_validation=False): skip_validation=False):
module_list = client.modules.list() if datastore:
module_list = client.modules.list(datastore=datastore)
else:
module_list = client.modules.list()
self.assert_equal(expected_count, len(module_list), self.assert_equal(expected_count, len(module_list),
"Wrong number of modules for list") "Wrong number of modules for list")
if not skip_validation: if not skip_validation:
@ -312,71 +439,99 @@ class ModuleRunner(TestRunner):
expected_auto_apply=test_module.auto_apply) expected_auto_apply=test_module.auto_apply)
def run_module_list_unauth_user(self): def run_module_list_unauth_user(self):
self.assert_module_list(self.unauth_client, 0) self.assert_module_list(
self.unauth_client,
module_all_tenant_count_prior_to_create +
module_all_tenant_create_count + module_other_create_count)
def run_module_create_admin_all(self): def run_module_create_admin_all(self):
name, description, contents = self.build_module_args(
'_hidden_all_tenant_auto')
self.assert_module_create( self.assert_module_create(
self.admin_client, self.admin_client,
name=self.MODULE_NAME + '_admin_apply', name=name, module_type=self.module_type, contents=contents,
module_type=self.module_type, description=description,
contents=self.MODULE_CONTENTS,
description=(self.MODULE_DESC + ' admin apply'),
all_tenants=True, all_tenants=True,
visible=False, visible=False,
auto_apply=True) auto_apply=True)
def run_module_create_admin_hidden(self): def run_module_create_admin_hidden(self):
name, description, contents = self.build_module_args('_hidden')
self.assert_module_create( self.assert_module_create(
self.admin_client, self.admin_client,
name=self.MODULE_NAME + '_hidden', name=name, module_type=self.module_type, contents=contents,
module_type=self.module_type, description=description,
contents=self.MODULE_CONTENTS,
description=self.MODULE_DESC + ' hidden',
visible=False) visible=False)
def run_module_create_admin_auto(self): def run_module_create_admin_auto(self):
name, description, contents = self.build_module_args('_auto')
self.assert_module_create( self.assert_module_create(
self.admin_client, self.admin_client,
name=self.MODULE_NAME + '_auto', name=name, module_type=self.module_type, contents=contents,
module_type=self.module_type, description=description,
contents=self.MODULE_CONTENTS,
description=self.MODULE_DESC + ' hidden',
auto_apply=True) auto_apply=True)
def run_module_create_admin_live_update(self): def run_module_create_admin_live_update(self):
name, description, contents = self.build_module_args('_live')
self.assert_module_create( self.assert_module_create(
self.admin_client, self.admin_client,
name=self.MODULE_NAME + '_live', name=name, module_type=self.module_type, contents=contents,
module_type=self.module_type, description=description,
contents=self.MODULE_CONTENTS,
description=(self.MODULE_DESC + ' live update'),
live_update=True) live_update=True)
def run_module_create_all_tenant(self): def run_module_create_datastore(self):
name, description, contents = self.build_module_args('_ds')
self.assert_module_create( self.assert_module_create(
self.admin_client, self.admin_client,
name=self.MODULE_NAME + '_all_tenant', name=name, module_type=self.module_type, contents=contents,
module_type=self.module_type, description=description,
contents=self.MODULE_CONTENTS, datastore=self.instance_info.dbaas_datastore)
description=self.MODULE_DESC + ' all tenant',
def run_module_create_ds_version(self):
name, description, contents = self.build_module_args('_ds_ver')
self.assert_module_create(
self.admin_client,
name=name, module_type=self.module_type, contents=contents,
description=description,
datastore=self.instance_info.dbaas_datastore,
datastore_version=self.instance_info.dbaas_datastore_version)
def run_module_create_all_tenant(self):
name, description, contents = self.build_module_args(
'_all_tenant_ds_ver')
self.assert_module_create(
self.admin_client,
name=name, module_type=self.module_type, contents=contents,
description=description,
all_tenants=True, all_tenants=True,
datastore=self.instance_info.dbaas_datastore, datastore=self.instance_info.dbaas_datastore,
datastore_version=self.instance_info.dbaas_datastore_version) datastore_version=self.instance_info.dbaas_datastore_version)
def run_module_create_different_tenant(self): def run_module_create_different_tenant(self):
name, description, contents = self.build_module_args()
self.assert_module_create( self.assert_module_create(
self.unauth_client, self.unauth_client,
name=self.MODULE_NAME, name=name, module_type=self.module_type, contents=contents,
module_type=self.module_type, description=description)
contents=self.MODULE_CONTENTS,
description=self.MODULE_DESC)
def run_module_list_again(self): def run_module_list_again(self):
self.assert_module_list( self.assert_module_list(
self.auth_client, self.auth_client,
# TODO(peterstac) remove the '-1' once the list is fixed to module_count_prior_to_create + module_create_count,
# include 'all' tenant modules skip_validation=True)
module_count_prior_to_create + module_create_count - 1,
def run_module_list_ds(self):
self.assert_module_list(
self.auth_client,
module_ds_count_prior_to_create + module_ds_create_count,
datastore=self.instance_info.dbaas_datastore,
skip_validation=True)
def run_module_list_ds_all(self):
self.assert_module_list(
self.auth_client,
module_ds_all_count_prior_to_create + module_ds_all_create_count,
datastore=models.Modules.MATCH_ALL_NAME,
skip_validation=True) skip_validation=True)
def run_module_show_invisible( def run_module_show_invisible(
@ -387,21 +542,6 @@ class ModuleRunner(TestRunner):
expected_exception, expected_http_code, expected_exception, expected_http_code,
self.auth_client.modules.get, module.id) self.auth_client.modules.get, module.id)
def _find_invisible_module(self):
def _match(mod):
return not mod.visible and mod.tenant_id and not mod.auto_apply
return self._find_module(_match, "Could not find invisible module")
def _find_module(self, match_fn, not_found_message):
module = None
for test_module in test_modules:
if match_fn(test_module):
module = test_module
break
if not module:
self.fail(not_found_message)
return module
def run_module_list_admin(self): def run_module_list_admin(self):
self.assert_module_list( self.assert_module_list(
self.admin_client, self.admin_client,
@ -422,7 +562,7 @@ class ModuleRunner(TestRunner):
self.assert_module_update( self.assert_module_update(
self.auth_client, self.auth_client,
self.main_test_module.id, self.main_test_module.id,
contents=self.MODULE_CONTENTS) contents=self.get_module_contents(self.main_test_module.name))
self.assert_equal(old_md5, self.main_test_module.md5, self.assert_equal(old_md5, self.main_test_module.md5,
"MD5 changed with same contents") "MD5 changed with same contents")
@ -501,11 +641,6 @@ class ModuleRunner(TestRunner):
expected_exception, expected_http_code, expected_exception, expected_http_code,
self.auth_client.modules.update, module.id, auto_apply=False) self.auth_client.modules.update, module.id, auto_apply=False)
def _find_auto_apply_module(self):
def _match(mod):
return mod.auto_apply and mod.tenant_id and mod.visible
return self._find_module(_match, "Could not find auto-apply module")
def run_module_update_non_admin_auto_any( def run_module_update_non_admin_auto_any(
self, expected_exception=exceptions.Forbidden, self, expected_exception=exceptions.Forbidden,
expected_http_code=403): expected_http_code=403):
@ -530,11 +665,6 @@ class ModuleRunner(TestRunner):
expected_exception, expected_http_code, expected_exception, expected_http_code,
self.auth_client.modules.update, module.id, all_tenants=False) self.auth_client.modules.update, module.id, all_tenants=False)
def _find_all_tenant_module(self):
def _match(mod):
return mod.tenant_id is None and mod.visible
return self._find_module(_match, "Could not find all tenant module")
def run_module_update_non_admin_all_tenant_any( def run_module_update_non_admin_all_tenant_any(
self, expected_exception=exceptions.Forbidden, self, expected_exception=exceptions.Forbidden,
expected_http_code=403): expected_http_code=403):
@ -567,6 +697,297 @@ class ModuleRunner(TestRunner):
expected_exception, expected_http_code, expected_exception, expected_http_code,
self.auth_client.modules.update, module.id, description='Upd') self.auth_client.modules.update, module.id, description='Upd')
# ModuleInstanceGroup methods
def run_module_list_instance_empty(self):
self.assert_module_list_instance(
self.auth_client, self.instance_info.id,
module_auto_apply_count_prior_to_create)
def assert_module_list_instance(self, client, instance_id, expected_count,
expected_http_code=200):
module_list = client.instances.modules(instance_id)
self.assert_client_code(expected_http_code, client)
count = len(module_list)
self.assert_equal(expected_count, count,
"Wrong number of modules from list instance")
for module in module_list:
self.validate_module(module)
def run_module_instances_empty(self):
self.assert_module_instances(
self.auth_client, self.main_test_module.id, 0)
def assert_module_instances(self, client, module_id, expected_count,
expected_http_code=200):
instance_list = client.modules.instances(module_id)
self.assert_client_code(expected_http_code, client)
count = len(instance_list)
self.assert_equal(expected_count, count,
"Wrong number of instances applied from module")
def run_module_query_empty(self):
self.assert_module_query(self.auth_client, self.instance_info.id,
module_auto_apply_count_prior_to_create)
def assert_module_query(self, client, instance_id, expected_count,
expected_http_code=200, expected_results=None):
modquery_list = client.instances.module_query(instance_id)
self.assert_client_code(expected_http_code, client)
count = len(modquery_list)
self.assert_equal(expected_count, count,
"Wrong number of modules from query")
expected_results = expected_results or {}
for modquery in modquery_list:
if modquery.name in expected_results:
expected = expected_results[modquery.name]
self.validate_module_info(
modquery,
expected_status=expected['status'],
expected_message=expected['message'])
def run_module_apply(self):
self.assert_module_apply(self.auth_client, self.instance_info.id,
self.main_test_module)
def assert_module_apply(self, client, instance_id, module,
expected_status=None, expected_message=None,
expected_contents=None,
expected_http_code=200):
module_apply_list = client.instances.module_apply(
instance_id, [module.id])
self.assert_client_code(expected_http_code, client)
admin_only = (not module.visible or module.auto_apply or
not module.tenant_id)
expected_status = expected_status or 'OK'
expected_message = (expected_message or
self.get_module_message(module.name))
for module_apply in module_apply_list:
self.validate_module_info(
module_apply,
expected_name=module.name,
expected_module_type=module.type,
expected_datastore=module.datastore,
expected_ds_version=module.datastore_version,
expected_auto_apply=module.auto_apply,
expected_visible=module.visible,
expected_admin_only=admin_only,
expected_contents=expected_contents,
expected_status=expected_status,
expected_message=expected_message)
def validate_module_info(self, module_apply,
expected_name=None,
expected_module_type=None,
expected_datastore=None,
expected_ds_version=None,
expected_auto_apply=None,
expected_visible=None,
expected_admin_only=None,
expected_contents=None,
expected_message=None,
expected_status=None):
prefix = "Module: %s -" % expected_name
if expected_name:
self.assert_equal(expected_name, module_apply.name,
'%s Unexpected module name' % prefix)
if expected_module_type:
self.assert_equal(expected_module_type, module_apply.type,
'%s Unexpected module type' % prefix)
if expected_datastore:
self.assert_equal(expected_datastore, module_apply.datastore,
'%s Unexpected datastore' % prefix)
if expected_ds_version:
self.assert_equal(expected_ds_version,
module_apply.datastore_version,
'%s Unexpected datastore version' % prefix)
if expected_auto_apply is not None:
self.assert_equal(expected_auto_apply, module_apply.auto_apply,
'%s Unexpected auto_apply' % prefix)
if expected_visible is not None:
self.assert_equal(expected_visible, module_apply.visible,
'%s Unexpected visible' % prefix)
if expected_admin_only is not None:
self.assert_equal(expected_admin_only, module_apply.admin_only,
'%s Unexpected admin_only' % prefix)
if expected_contents is not None:
self.assert_equal(expected_contents, module_apply.contents,
'%s Unexpected contents' % prefix)
if expected_message is not None:
self.assert_equal(expected_message, module_apply.message,
'%s Unexpected message' % prefix)
if expected_status is not None:
self.assert_equal(expected_status, module_apply.status,
'%s Unexpected status' % prefix)
def run_module_list_instance_after_apply(self):
self.assert_module_list_instance(
self.auth_client, self.instance_info.id, 1)
def run_module_query_after_apply(self):
expected_count = module_auto_apply_count_prior_to_create + 1
expected_results = self.create_default_query_expected_results(
[self.main_test_module])
self.assert_module_query(self.auth_client, self.instance_info.id,
expected_count=expected_count,
expected_results=expected_results)
def create_default_query_expected_results(self, modules, is_admin=False):
expected_results = {}
for module in modules:
status = 'OK'
message = self.get_module_message(module.name)
contents = self.get_module_contents(module.name)
if not is_admin and (not module.visible or module.auto_apply or
not module.tenant_id):
contents = ('Must be admin to retrieve contents for module %s'
% module.name)
elif self.MODULE_BINARY_SUFFIX in module.name:
status = 'ERROR'
message = 'Message not found in contents file'
contents = self.MODULE_BINARY_CONTENTS
if self.MODULE_BINARY_SUFFIX2 in module.name:
contents = self.MODULE_BINARY_CONTENTS2
expected_results[module.name] = {
'status': status,
'message': message,
'datastore': module.datastore,
'datastore_version': module.datastore_version,
'contents': contents,
}
return expected_results
def run_create_inst_with_mods(self, expected_http_code=200):
self.mod_inst_id = self.assert_inst_mod_create(
self.main_test_module.id, 'module_1', expected_http_code)
def assert_inst_mod_create(self, module_id, name_suffix,
expected_http_code):
inst = self.auth_client.instances.create(
self.instance_info.name + name_suffix,
self.instance_info.dbaas_flavor_href,
self.instance_info.volume,
datastore=self.instance_info.dbaas_datastore,
datastore_version=self.instance_info.dbaas_datastore_version,
nics=self.instance_info.nics,
modules=[module_id],
)
self.assert_client_code(expected_http_code)
return inst.id
def run_module_delete_applied(
self, expected_exception=exceptions.Forbidden,
expected_http_code=403):
self.assert_raises(
expected_exception, expected_http_code,
self.auth_client.modules.delete, self.main_test_module.id)
def run_module_remove(self):
self.assert_module_remove(self.auth_client, self.instance_info.id,
self.main_test_module.id)
def assert_module_remove(self, client, instance_id, module_id,
expected_http_code=200):
client.instances.module_remove(instance_id, module_id)
self.assert_client_code(expected_http_code, client)
def run_wait_for_inst_with_mods(self, expected_states=['BUILD', 'ACTIVE']):
self.assert_instance_action(self.mod_inst_id, expected_states, None)
def run_module_query_after_inst_create(self):
auto_modules = self._find_all_auto_apply_modules(visible=True)
expected_count = 1 + len(auto_modules)
expected_results = self.create_default_query_expected_results(
[self.main_test_module] + auto_modules)
self.assert_module_query(self.auth_client, self.mod_inst_id,
expected_count=expected_count,
expected_results=expected_results)
def run_module_retrieve_after_inst_create(self):
auto_modules = self._find_all_auto_apply_modules(visible=True)
expected_count = 1 + len(auto_modules)
expected_results = self.create_default_query_expected_results(
[self.main_test_module] + auto_modules)
self.assert_module_retrieve(self.auth_client, self.mod_inst_id,
expected_count=expected_count,
expected_results=expected_results)
def assert_module_retrieve(self, client, instance_id, expected_count,
expected_http_code=200, expected_results=None):
try:
temp_dir = tempfile.mkdtemp()
prefix = 'contents'
modretrieve_list = client.instances.module_retrieve(
instance_id, directory=temp_dir, prefix=prefix)
self.assert_client_code(expected_http_code, client)
count = len(modretrieve_list)
self.assert_equal(expected_count, count,
"Wrong number of modules from retrieve")
expected_results = expected_results or {}
for module_name, filename in modretrieve_list.items():
if module_name in expected_results:
expected = expected_results[module_name]
contents_name = '%s_%s_%s_%s' % (
prefix, module_name,
expected['datastore'], expected['datastore_version'])
expected_filename = guestagent_utils.build_file_path(
temp_dir, contents_name, 'dat')
self.assert_equal(expected_filename, filename,
'Unexpected retrieve filename')
if 'contents' in expected and expected['contents']:
with open(filename, 'rb') as fh:
contents = fh.read()
# convert contents into bytearray to work with py27
# and py34
contents = bytes([ord(item) for item in contents])
expected_contents = bytes(
[ord(item) for item in expected['contents']])
self.assert_equal(expected_contents, contents,
"Unexpected contents for %s" %
module_name)
finally:
operating_system.remove(temp_dir)
def run_module_query_after_inst_create_admin(self):
auto_modules = self._find_all_auto_apply_modules()
expected_count = 1 + len(auto_modules)
expected_results = self.create_default_query_expected_results(
[self.main_test_module] + auto_modules, is_admin=True)
self.assert_module_query(self.admin_client, self.mod_inst_id,
expected_count=expected_count,
expected_results=expected_results)
def run_module_retrieve_after_inst_create_admin(self):
pass
auto_modules = self._find_all_auto_apply_modules()
expected_count = 1 + len(auto_modules)
expected_results = self.create_default_query_expected_results(
[self.main_test_module] + auto_modules, is_admin=True)
self.assert_module_retrieve(self.admin_client, self.mod_inst_id,
expected_count=expected_count,
expected_results=expected_results)
def run_module_delete_auto_applied(
self, expected_exception=exceptions.Forbidden,
expected_http_code=403):
module = self._find_auto_apply_module()
self.assert_raises(
expected_exception, expected_http_code,
self.auth_client.modules.delete, module.id)
def run_delete_inst_with_mods(self, expected_last_state=['SHUTDOWN'],
expected_http_code=202):
self.assert_delete_instance(
self.mod_inst_id,
expected_last_state, expected_http_code)
def assert_delete_instance(
self, instance_id, expected_last_state, expected_http_code):
self.auth_client.instances.delete(instance_id)
self.assert_client_code(expected_http_code)
self.assert_all_gone(instance_id, expected_last_state)
# ModuleDeleteGroup methods # ModuleDeleteGroup methods
def run_module_delete_non_existent( def run_module_delete_non_existent(
self, expected_exception=exceptions.NotFound, self, expected_exception=exceptions.NotFound,

View File

@ -0,0 +1,64 @@
# Copyright 2016 Tesora, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
from Crypto import Random
from trove.common import crypto_utils
from trove.tests.unittests import trove_testtools
class TestEncryptUtils(trove_testtools.TestCase):
def setUp(self):
super(TestEncryptUtils, self).setUp()
def tearDown(self):
super(TestEncryptUtils, self).tearDown()
def test_encode_decode_string(self):
random_data = bytearray(Random.new().read(12))
data = ['abc', 'numbers01234', '\x00\xFF\x00\xFF\xFF\x00', random_data]
for datum in data:
encoded_data = crypto_utils.encode_data(datum)
decoded_data = crypto_utils.decode_data(encoded_data)
self. assertEqual(datum, decoded_data,
"Encode/decode failed")
def test_pad_unpad(self):
for size in range(1, 100):
data_str = 'a' * size
padded_str = crypto_utils.pad_for_encryption(
data_str, crypto_utils.IV_BIT_COUNT)
self.assertEqual(0, len(padded_str) % crypto_utils.IV_BIT_COUNT,
"Padding not successful")
unpadded_str = crypto_utils.unpad_after_decryption(padded_str)
self.assertEqual(data_str, unpadded_str,
"String mangled after pad/unpad")
def test_encryp_decrypt(self):
key = 'my_secure_key'
for size in range(1, 100):
orig_data = Random.new().read(size)
orig_encoded = crypto_utils.encode_data(orig_data)
encrypted = crypto_utils.encrypt_data(orig_encoded, key)
encoded = crypto_utils.encode_data(encrypted)
decoded = crypto_utils.decode_data(encoded)
decrypted = crypto_utils.decrypt_data(decoded, key)
final_decoded = crypto_utils.decode_data(decrypted)
self.assertEqual(orig_data, final_decoded,
"Decrypted data did not match original")

View File

@ -0,0 +1,42 @@
# Copyright 2016 Tesora, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
from Crypto import Random
from trove.common import stream_codecs
from trove.tests.unittests import trove_testtools
class TestStreamCodecs(trove_testtools.TestCase):
def setUp(self):
super(TestStreamCodecs, self).setUp()
def tearDown(self):
super(TestStreamCodecs, self).tearDown()
def test_serialize_deserialize_base64codec(self):
random_data = bytearray(Random.new().read(12))
data = ['abc',
'numbers01234',
random_data]
codec = stream_codecs.Base64Codec()
for datum in data:
serialized_data = codec.serialize(datum)
deserialized_data = codec.deserialize(serialized_data)
self. assertEqual(datum, deserialized_data,
"Serialize/Deserialize failed")

View File

@ -14,7 +14,6 @@
# under the License. # under the License.
# #
from Crypto import Random
from mock import Mock from mock import Mock
from testtools import ExpectedException from testtools import ExpectedException
@ -82,39 +81,3 @@ class TestTroveExecuteWithTimeout(trove_testtools.TestCase):
def test_pagination_limit(self): def test_pagination_limit(self):
self.assertEqual(5, utils.pagination_limit(5, 9)) self.assertEqual(5, utils.pagination_limit(5, 9))
self.assertEqual(5, utils.pagination_limit(9, 5)) self.assertEqual(5, utils.pagination_limit(9, 5))
def test_encode_decode_string(self):
random_data = bytearray(Random.new().read(12))
data = ['abc', 'numbers01234', '\x00\xFF\x00\xFF\xFF\x00', random_data]
for datum in data:
encoded_data = utils.encode_string(datum)
decoded_data = utils.decode_string(encoded_data)
self. assertEqual(datum, decoded_data,
"Encode/decode failed")
def test_pad_unpad(self):
for size in range(1, 100):
data_str = 'a' * size
padded_str = utils.pad_for_encryption(data_str, utils.IV_BIT_COUNT)
self.assertEqual(0, len(padded_str) % utils.IV_BIT_COUNT,
"Padding not successful")
unpadded_str = utils.unpad_after_decryption(padded_str)
self.assertEqual(data_str, unpadded_str,
"String mangled after pad/unpad")
def test_encryp_decrypt(self):
key = 'my_secure_key'
for size in range(1, 100):
orig_str = ''
for index in range(1, size):
orig_str += Random.new().read(1)
orig_encoded = utils.encode_string(orig_str)
encrypted = utils.encrypt_string(orig_encoded, key)
encoded = utils.encode_string(encrypted)
decoded = utils.decode_string(encoded)
decrypted = utils.decrypt_string(decoded, key)
final_decoded = utils.decode_string(decrypted)
self.assertEqual(orig_str, final_decoded,
"String did not match original")

View File

@ -416,7 +416,7 @@ class ApiTest(trove_testtools.TestCase):
mount_point='/mnt/opt', backup_info=None, mount_point='/mnt/opt', backup_info=None,
config_contents='cont', root_password='1-2-3-4', config_contents='cont', root_password='1-2-3-4',
overrides='override', cluster_config={'id': '2-3-4-5'}, overrides='override', cluster_config={'id': '2-3-4-5'},
snapshot=None) snapshot=None, modules=None)
@mock.patch.object(messaging, 'Target') @mock.patch.object(messaging, 'Target')
@mock.patch.object(rpc, 'get_server') @mock.patch.object(rpc, 'get_server')
@ -424,7 +424,7 @@ class ApiTest(trove_testtools.TestCase):
backup = {'id': 'backup_id_123'} backup = {'id': 'backup_id_123'}
self.api.prepare('2048', 'package1', 'db1', 'user1', '/dev/vdt', self.api.prepare('2048', 'package1', 'db1', 'user1', '/dev/vdt',
'/mnt/opt', backup, 'cont', '1-2-3-4', '/mnt/opt', backup, 'cont', '1-2-3-4',
'overrides', {"id": "2-3-4-5"}) 'overrides', {"id": "2-3-4-5"}, modules=None)
self._verify_rpc_prepare_before_cast() self._verify_rpc_prepare_before_cast()
self._verify_cast( self._verify_cast(
@ -433,7 +433,24 @@ class ApiTest(trove_testtools.TestCase):
mount_point='/mnt/opt', backup_info=backup, mount_point='/mnt/opt', backup_info=backup,
config_contents='cont', root_password='1-2-3-4', config_contents='cont', root_password='1-2-3-4',
overrides='overrides', cluster_config={'id': '2-3-4-5'}, overrides='overrides', cluster_config={'id': '2-3-4-5'},
snapshot=None) snapshot=None, modules=None)
@mock.patch.object(messaging, 'Target')
@mock.patch.object(rpc, 'get_server')
def test_prepare_with_modules(self, *args):
modules = [{'id': 'mod_id'}]
self.api.prepare('2048', 'package1', 'db1', 'user1', '/dev/vdt',
'/mnt/opt', None, 'cont', '1-2-3-4',
'overrides', {"id": "2-3-4-5"}, modules=modules)
self._verify_rpc_prepare_before_cast()
self._verify_cast(
'prepare', packages=['package1'], databases='db1',
memory_mb='2048', users='user1', device_path='/dev/vdt',
mount_point='/mnt/opt', backup_info=None,
config_contents='cont', root_password='1-2-3-4',
overrides='overrides', cluster_config={'id': '2-3-4-5'},
snapshot=None, modules=modules)
def test_upgrade(self): def test_upgrade(self):
instance_version = "v1.0.1" instance_version = "v1.0.1"

View File

@ -25,7 +25,8 @@ from testtools import ExpectedException
from trove.common import exception from trove.common import exception
from trove.common.stream_codecs import ( from trove.common.stream_codecs import (
IdentityCodec, IniCodec, JsonCodec, PropertiesCodec, YamlCodec) Base64Codec, IdentityCodec, IniCodec, JsonCodec,
KeyValueCodec, PropertiesCodec, YamlCodec)
from trove.common import utils from trove.common import utils
from trove.guestagent.common import guestagent_utils from trove.guestagent.common import guestagent_utils
from trove.guestagent.common import operating_system from trove.guestagent.common import operating_system
@ -35,6 +36,16 @@ from trove.tests.unittests import trove_testtools
class TestOperatingSystem(trove_testtools.TestCase): class TestOperatingSystem(trove_testtools.TestCase):
def test_base64_codec(self):
data = "Line 1\nLine 2\n"
self._test_file_codec(data, Base64Codec())
data = "TGluZSAxCkxpbmUgMgo="
self._test_file_codec(data, Base64Codec(), reverse_encoding=True)
data = "5Am9+y0wTwqUx39sMMBg3611FWg="
self._test_file_codec(data, Base64Codec(), reverse_encoding=True)
def test_identity_file_codec(self): def test_identity_file_codec(self):
data = ("Lorem Ipsum, Lorem Ipsum\n" data = ("Lorem Ipsum, Lorem Ipsum\n"
"Lorem Ipsum, Lorem Ipsum\n" "Lorem Ipsum, Lorem Ipsum\n"
@ -105,6 +116,13 @@ class TestOperatingSystem(trove_testtools.TestCase):
self._test_file_codec(data, PropertiesCodec( self._test_file_codec(data, PropertiesCodec(
string_mappings={'yes': True, 'no': False, "''": None})) string_mappings={'yes': True, 'no': False, "''": None}))
def test_key_value_file_codec(self):
data = {'key1': 'value1',
'key2': 'value2',
'key3': 'value3'}
self._test_file_codec(data, KeyValueCodec())
def test_json_file_codec(self): def test_json_file_codec(self):
data = {"Section1": 's1v1', data = {"Section1": 's1v1',
"Section2": {"s2k1": '1', "Section2": {"s2k1": '1',
@ -117,21 +135,31 @@ class TestOperatingSystem(trove_testtools.TestCase):
def _test_file_codec(self, data, read_codec, write_codec=None, def _test_file_codec(self, data, read_codec, write_codec=None,
expected_data=None, expected_data=None,
expected_exception=None): expected_exception=None,
reverse_encoding=False):
write_codec = write_codec or read_codec write_codec = write_codec or read_codec
with tempfile.NamedTemporaryFile() as test_file: with tempfile.NamedTemporaryFile() as test_file:
encode = True
decode = True
if reverse_encoding:
encode = False
decode = False
if expected_exception: if expected_exception:
with expected_exception: with expected_exception:
operating_system.write_file(test_file.name, data, operating_system.write_file(test_file.name, data,
codec=write_codec) codec=write_codec,
encode=encode)
operating_system.read_file(test_file.name, operating_system.read_file(test_file.name,
codec=read_codec) codec=read_codec,
decode=decode)
else: else:
operating_system.write_file(test_file.name, data, operating_system.write_file(test_file.name, data,
codec=write_codec) codec=write_codec,
encode=encode)
read = operating_system.read_file(test_file.name, read = operating_system.read_file(test_file.name,
codec=read_codec) codec=read_codec,
decode=decode)
if expected_data is not None: if expected_data is not None:
self.assertEqual(expected_data, read) self.assertEqual(expected_data, read)
else: else:

View File

@ -32,7 +32,7 @@ class CreateModuleTest(trove_testtools.TestCase):
util.init_db() util.init_db()
self.context = Mock() self.context = Mock()
self.name = "name" self.name = "name"
self.module_type = 'test' self.module_type = 'ping'
self.contents = 'my_contents\n' self.contents = 'my_contents\n'
super(CreateModuleTest, self).setUp() super(CreateModuleTest, self).setUp()

View File

@ -202,7 +202,8 @@ class TestManager(trove_testtools.TestCase):
'mysql', 'mysql-server', 2, 'mysql', 'mysql-server', 2,
'temp-backup-id', None, 'temp-backup-id', None,
'some_password', None, Mock(), 'some_password', None, Mock(),
'some-master-id', None, None) 'some-master-id', None, None,
None)
mock_tasks.get_replication_master_snapshot.assert_called_with( mock_tasks.get_replication_master_snapshot.assert_called_with(
self.context, 'some-master-id', mock_flavor, 'temp-backup-id', self.context, 'some-master-id', mock_flavor, 'temp-backup-id',
replica_number=1) replica_number=1)
@ -218,7 +219,7 @@ class TestManager(trove_testtools.TestCase):
self.context, ['id1', 'id2'], Mock(), Mock(), self.context, ['id1', 'id2'], Mock(), Mock(),
Mock(), None, None, 'mysql', 'mysql-server', 2, Mock(), None, None, 'mysql', 'mysql-server', 2,
'temp-backup-id', None, 'some_password', None, 'temp-backup-id', None, 'some_password', None,
Mock(), 'some-master-id', None, None) Mock(), 'some-master-id', None, None, None)
def test_AttributeError_create_instance(self): def test_AttributeError_create_instance(self):
self.assertRaisesRegexp( self.assertRaisesRegexp(
@ -226,7 +227,7 @@ class TestManager(trove_testtools.TestCase):
self.manager.create_instance, self.context, ['id1', 'id2'], self.manager.create_instance, self.context, ['id1', 'id2'],
Mock(), Mock(), Mock(), None, None, 'mysql', 'mysql-server', 2, Mock(), Mock(), Mock(), None, None, 'mysql', 'mysql-server', 2,
'temp-backup-id', None, 'some_password', None, Mock(), None, None, 'temp-backup-id', None, 'some_password', None, Mock(), None, None,
None) None, None)
def test_create_instance(self): def test_create_instance(self):
mock_tasks = Mock() mock_tasks = Mock()
@ -238,7 +239,8 @@ class TestManager(trove_testtools.TestCase):
mock_flavor, 'mysql-image-id', None, mock_flavor, 'mysql-image-id', None,
None, 'mysql', 'mysql-server', 2, None, 'mysql', 'mysql-server', 2,
'temp-backup-id', None, 'password', 'temp-backup-id', None, 'password',
None, mock_override, None, None, None) None, mock_override, None, None, None,
None)
mock_tasks.create_instance.assert_called_with(mock_flavor, mock_tasks.create_instance.assert_called_with(mock_flavor,
'mysql-image-id', None, 'mysql-image-id', None,
None, 'mysql', None, 'mysql',
@ -246,7 +248,7 @@ class TestManager(trove_testtools.TestCase):
'temp-backup-id', None, 'temp-backup-id', None,
'password', None, 'password', None,
mock_override, mock_override,
None, None, None) None, None, None, None)
mock_tasks.wait_for_instance.assert_called_with(36000, mock_flavor) mock_tasks.wait_for_instance.assert_called_with(36000, mock_flavor)
def test_create_cluster(self): def test_create_cluster(self):

View File

@ -379,7 +379,7 @@ class FreshInstanceTasksTest(trove_testtools.TestCase):
'Error creating security group for instance', 'Error creating security group for instance',
self.freshinstancetasks.create_instance, mock_flavor, self.freshinstancetasks.create_instance, mock_flavor,
'mysql-image-id', None, None, 'mysql', 'mysql-server', 2, 'mysql-image-id', None, None, 'mysql', 'mysql-server', 2,
None, None, None, None, Mock(), None, None, None) None, None, None, None, Mock(), None, None, None, None)
@patch.object(BaseInstance, 'update_db') @patch.object(BaseInstance, 'update_db')
@patch.object(backup_models.Backup, 'get_by_id') @patch.object(backup_models.Backup, 'get_by_id')
@ -401,7 +401,8 @@ class FreshInstanceTasksTest(trove_testtools.TestCase):
'Error creating DNS entry for instance', 'Error creating DNS entry for instance',
self.freshinstancetasks.create_instance, mock_flavor, self.freshinstancetasks.create_instance, mock_flavor,
'mysql-image-id', None, None, 'mysql', 'mysql-server', 'mysql-image-id', None, None, 'mysql', 'mysql-server',
2, Mock(), None, 'root_password', None, Mock(), None, None, None) 2, Mock(), None, 'root_password', None, Mock(), None, None, None,
None)
@patch.object(BaseInstance, 'update_db') @patch.object(BaseInstance, 'update_db')
@patch.object(taskmanager_models.FreshInstanceTasks, '_create_dns_entry') @patch.object(taskmanager_models.FreshInstanceTasks, '_create_dns_entry')
@ -427,13 +428,13 @@ class FreshInstanceTasksTest(trove_testtools.TestCase):
'mysql-server', 2, 'mysql-server', 2,
None, None, None, None, None, None, None, None,
overrides, None, None, overrides, None, None,
'volume_type') 'volume_type', None)
mock_create_secgroup.assert_called_with('mysql') mock_create_secgroup.assert_called_with('mysql')
mock_build_volume_info.assert_called_with('mysql', volume_size=2, mock_build_volume_info.assert_called_with('mysql', volume_size=2,
volume_type='volume_type') volume_type='volume_type')
mock_guest_prepare.assert_called_with( mock_guest_prepare.assert_called_with(
768, mock_build_volume_info(), 'mysql-server', None, None, None, 768, mock_build_volume_info(), 'mysql-server', None, None, None,
config_content, None, overrides, None, None) config_content, None, overrides, None, None, None)
@patch.object(trove.guestagent.api.API, 'attach_replication_slave') @patch.object(trove.guestagent.api.API, 'attach_replication_slave')
@patch.object(rpc, 'get_client') @patch.object(rpc, 'get_client')