Merge pull request #143 from CGenie/cgenie/graph-db-with-events

Migrate signals, connections to graph DB backend, add abstract DB layer (ORM)
This commit is contained in:
Dmitry Shulyak 2015-09-15 14:33:37 +03:00
commit 5386178e4d
57 changed files with 1959 additions and 1190 deletions

1
Vagrantfile vendored
View File

@ -57,6 +57,7 @@ Vagrant.configure(VAGRANTFILE_API_VERSION) do |config|
config.vm.provision "file", source: "~/.vagrant.d/insecure_private_key", destination: "/vagrant/tmp/keys/ssh_private"
config.vm.provision "file", source: "bootstrap/ansible.cfg", destination: "/home/vagrant/.ansible.cfg"
config.vm.network "private_network", ip: "10.0.0.2"
config.vm.network "forwarded_port", guest: 7474, host: 17474
config.vm.host_name = "solar-dev"
config.vm.provider :virtualbox do |v|

View File

@ -1,6 +1,6 @@
:backends:
- redis
#- yaml
- yaml
#- redis
#- json
:yaml:
:datadir: /etc/puppet/hieradata
@ -12,4 +12,5 @@
:host: localhost
:deserialize: :json
:hierarchy:
- "%{resource_name}"
- resource

View File

@ -13,7 +13,6 @@ db = get_db()
def run():
db.clear()
signals.Connections.clear()
node1 = vr.create('node1', 'resources/ro_node', {'name': 'first' + str(time.time()),
'ip': '10.0.0.3',

View File

@ -15,6 +15,18 @@ from solar import events as evapi
from solar.interfaces.db import get_db
PROFILE = False
#PROFILE = True
if PROFILE:
import StringIO
import cProfile
import pstats
pr = cProfile.Profile()
GIT_PUPPET_LIBS_URL = 'https://github.com/CGenie/puppet-libs-resource'
@ -39,7 +51,8 @@ def main():
def setup_resources():
db.clear()
signals.Connections.clear()
if PROFILE:
pr.enable()
node1, node2 = vr.create('nodes', 'templates/nodes.yml', {})
@ -581,6 +594,15 @@ def setup_resources():
'bind_port': 'glance_api_servers_port'
})
if PROFILE:
pr.disable()
s = StringIO.StringIO()
sortby = 'cumulative'
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats()
print s.getvalue()
sys.exit(0)
has_errors = False
for r in locals().values():
if not isinstance(r, resource.Resource):
@ -667,12 +689,13 @@ resources_to_run = [
'neutron_agents_ml22',
]
@click.command()
def deploy():
setup_resources()
# run
resources = map(resource.wrap_resource, db.get_list(collection=db.COLLECTIONS.resource))
resources = resource.load_all()
resources = {r.name: r for r in resources}
for name in resources_to_run:

View File

@ -26,7 +26,6 @@ db = get_db()
def setup_riak():
db.clear()
signals.Connections.clear()
nodes = vr.create('nodes', 'templates/riak_nodes.yml', {})
node1, node2, node3 = nodes

View File

@ -6,6 +6,6 @@ from solar.core.log import log
def test(resource):
log.debug('Testing apache_puppet')
requests.get(
'http://%s:%s' % (resource.args['ip'].value, 80)
'http://%s:%s' % (resource.args['ip'], 80)
)

View File

@ -6,5 +6,5 @@ from solar.core.log import log
def test(resource):
log.debug('Testing cinder_api_puppet')
requests.get(
'http://%s:%s' % (resource.args['ip'].value, resource.args['service_port'].value)
'http://%s:%s' % (resource.args['ip'], resource.args['service_port'])
)

View File

@ -6,5 +6,5 @@ from solar.core.log import log
def test(resource):
log.debug('Testing cinder_puppet')
requests.get(
'http://%s:%s' % (resource.args['ip'].value, resource.args['port'].value)
'http://%s:%s' % (resource.args['ip'], resource.args['port'])
)

View File

@ -6,7 +6,7 @@ from solar.core.log import log
def test(resource):
log.debug('Testing cinder_scheduler_puppet')
# requests.get(
# 'http://%s:%s' % (resource.args['ip'].value, resource.args['port'].value)
# 'http://%s:%s' % (resource.args['ip'], resource.args['port'])
# TODO(bogdando) figure out how to test this
# http://docs.openstack.org/developer/nova/devref/scheduler.html
# )

View File

@ -6,7 +6,7 @@ from solar.core.log import log
def test(resource):
log.debug('Testing cinder_volume_puppet')
# requests.get(
# 'http://%s:%s' % (resource.args['ip'].value, resource.args['port'].value)
# 'http://%s:%s' % (resource.args['ip'], resource.args['port'])
# TODO(bogdando) figure out how to test this
# http://docs.openstack.org/developer/nova/devref/volume.html
# )

View File

@ -10,15 +10,15 @@ def test(resource):
args = resource.args
token, _ = validation.validate_token(
keystone_host=args['keystone_host'].value,
keystone_port=args['keystone_port'].value,
keystone_host=args['keystone_host'],
keystone_port=args['keystone_port'],
user='glance_admin',
tenant='services',
password=args['keystone_password'].value,
password=args['keystone_password'],
)
images = requests.get(
'http://%s:%s/v1/images' % (resource.args['ip'].value, 9393),
'http://%s:%s/v1/images' % (resource.args['ip'], 9393),
headers={'X-Auth-Token': token}
)
assert images.json() == {'images': []}

View File

@ -7,16 +7,16 @@ from solar.core import validation
def test(resource):
log.debug('Testing glance_puppet')
requests.get(
'http://%s:%s' % (resource.args['ip'].value, resource.args['bind_port'].value)
'http://%s:%s' % (resource.args['ip'], resource.args['bind_port'])
)
#TODO(bogdando) test packages installed and filesystem store datadir created
args = resource.args
token, _ = validation.validate_token(
keystone_host=args['keystone_host'].value,
keystone_port=args['keystone_port'].value,
user=args['keystone_user'].value,
tenant=args['keystone_tenant'].value,
password=args['keystone_password'].value,
keystone_host=args['keystone_host'],
keystone_port=args['keystone_port'],
user=args['keystone_user'],
tenant=args['keystone_tenant'],
password=args['keystone_password'],
)

View File

@ -6,5 +6,5 @@ from solar.core.log import log
def test(resource):
log.debug('Testing glance_registry_puppet')
requests.get(
'http://%s:%s' % (resource.args['ip'].value, resource.args['bind_port'].value)
'http://%s:%s' % (resource.args['ip'], resource.args['bind_port'])
)

View File

@ -6,5 +6,5 @@ from solar.core.log import log
def test(resource):
log.debug('Testing haproxy_service')
requests.get(
'http://%s:%s' % (resource.args['ip'].value, resource.args['ports'].value[0]['value'][0]['value'])
'http://%s:%s' % (resource.args['ip'], resource.args['ports'][0][0])
)

View File

@ -13,8 +13,8 @@ input:
schema: str!
value:
hosts_names:
schema: [{value: str!}]
schema: [str!]
value: []
hosts_ips:
schema: [{value: str!}]
schema: [str!]
value: []

View File

@ -6,5 +6,5 @@ from solar.core.log import log
def test(resource):
log.debug('Testing keystone_puppet')
requests.get(
'http://%s:%s' % (resource.args['ip'].value, resource.args['port'].value)
'http://%s:%s' % (resource.args['ip'], resource.args['port'])
)

View File

@ -6,5 +6,5 @@ from solar.core.log import log
def test(resource):
log.debug('Testing keystone_service')
requests.get(
'http://%s:%s' % (resource.args['ip'].value, resource.args['port'].value)
'http://%s:%s' % (resource.args['ip'], resource.args['port'])
)

View File

@ -9,26 +9,26 @@ def test(resource):
log.debug('Testing keystone_service_endpoint %s', resource.name)
resp = requests.get(
'http://%s:%s/v3/services' % (resource.args['ip'].value, resource.args['keystone_admin_port'].value),
'http://%s:%s/v3/services' % (resource.args['ip'], resource.args['keystone_admin_port']),
headers={
'X-Auth-Token': resource.args['admin_token'].value,
'X-Auth-Token': resource.args['admin_token'],
}
)
resp_json = resp.json()
assert 'services' in resp_json
service = [s for s in resp_json['services'] if s['name'] == resource.args['endpoint_name'].value][0]
service = [s for s in resp_json['services'] if s['name'] == resource.args['endpoint_name']][0]
service_id = service['id']
assert service['description'] == resource.args['description'].value
assert service['description'] == resource.args['description']
log.debug('%s service: %s', resource.name, json.dumps(service, indent=2))
resp = requests.get(
'http://%s:%s/v3/endpoints' % (resource.args['ip'].value, resource.args['keystone_admin_port'].value),
'http://%s:%s/v3/endpoints' % (resource.args['ip'], resource.args['keystone_admin_port']),
headers={
'X-Auth-Token': resource.args['admin_token'].value,
'X-Auth-Token': resource.args['admin_token'],
}
)
@ -41,9 +41,8 @@ def test(resource):
if endpoint['service_id'] == service_id:
endpoints[endpoint['interface']] = endpoint
assert jinja2.Template(resource.args['adminurl'].value).render(**resource.args_dict()) == endpoints['admin']['url']
assert jinja2.Template(resource.args['internalurl'].value).render(**resource.args_dict()) == endpoints['internal']['url']
assert jinja2.Template(resource.args['publicurl'].value).render(**resource.args_dict()) == endpoints['public']['url']
assert jinja2.Template(resource.args['adminurl']).render(**resource.args) == endpoints['admin']['url']
assert jinja2.Template(resource.args['internalurl']).render(**resource.args) == endpoints['internal']['url']
assert jinja2.Template(resource.args['publicurl']).render(**resource.args) == endpoints['public']['url']
log.debug('%s endpoints: %s', resource.name, json.dumps(endpoints, indent=2))

View File

@ -5,14 +5,14 @@ from solar.core import validation
def test(resource):
log.debug('Testing keystone_user %s', resource.args['user_name'].value)
log.debug('Testing keystone_user %s', resource.args['user_name'])
args = resource.args
token, _ = validation.validate_token(
keystone_host=args['keystone_host'].value,
keystone_port=args['keystone_port'].value,
user=args['user_name'].value,
tenant=args['tenant_name'].value,
password=args['user_password'].value,
keystone_host=args['keystone_host'],
keystone_port=args['keystone_port'],
user=args['user_name'],
tenant=args['tenant_name'],
password=args['user_password'],
)

View File

@ -11,11 +11,11 @@ def test(resource):
args = resource.args
token, token_data = validation.validate_token(
keystone_host=args['auth_host'].value,
keystone_port=args['auth_port'].value,
user=args['admin_user'].value,
tenant=args['admin_tenant_name'].value,
password=args['admin_password'].value,
keystone_host=args['auth_host'],
keystone_port=args['auth_port'],
user=args['admin_user'],
tenant=args['admin_tenant_name'],
password=args['admin_password'],
)
endpoints = [
@ -89,4 +89,3 @@ def test(resource):
)
log.debug('NOVA API IMAGES: %s', images.json())

View File

@ -16,4 +16,6 @@ Fabric==1.10.2
tabulate==0.7.5
ansible
celery
mock
mock
multipledispatch==0.4.8
mock

View File

@ -313,9 +313,8 @@ def init_cli_resource():
k, v = arg.split('=')
args_parsed.update({k: v})
click.echo('Updating resource {} with args {}'.format(name, args_parsed))
all = sresource.load_all()
r = all[name]
r.update(args_parsed)
res = sresource.load(name)
res.update(args_parsed)
@resource.command()
@click.option('--check-missing-connections', default=False, is_flag=True)
@ -350,7 +349,6 @@ def run():
init_actions()
init_cli_connect()
init_cli_connections()
init_cli_deployment_config()
init_cli_resource()
main.add_command(orchestration)

View File

@ -86,8 +86,12 @@ def history(n):
@changes.command()
def test():
results = testing.test_all()
@click.option('--name', default=None)
def test(name):
if name:
results = testing.test(name)
else:
results = testing.test_all()
for name, result in results.items():
msg = '[{status}] {name} {message}'

View File

@ -25,7 +25,7 @@ _default_transports = {
}
def resource_action(resource, action):
handler = resource.metadata.get('handler', 'none')
handler = resource.db_obj.handler or 'none'
with handlers.get(handler)([resource], _default_transports) as h:
return h.action(resource, action)

View File

@ -67,10 +67,10 @@ class AnsibleTemplate(TempFileHandler):
# XXX: r.args['ssh_user'] should be something different in this case probably
inventory = '{0} ansible_connection=local user={1} {2}'
host, user = 'localhost', r.args['ssh_user'].value
host, user = 'localhost', r.args['ssh_user']
args = []
for arg in r.args:
args.append('{0}="{1}"'.format(arg, r.args[arg].value))
args.append('{0}="{1}"'.format(arg, r.args[arg]))
args = ' '.join(args)
inventory = inventory.format(host, user, args)
log.debug(inventory)

View File

@ -72,8 +72,7 @@ class TempFileHandler(BaseHandler):
def _render_action(self, resource, action):
log.debug('Rendering %s %s', resource.name, action)
action_file = resource.metadata['actions'][action]
action_file = os.path.join(resource.metadata['actions_path'], action_file)
action_file = resource.actions[action]
log.debug('action file: %s', action_file)
args = self._make_args(resource)
@ -88,7 +87,7 @@ class TempFileHandler(BaseHandler):
trg_templates_dir = None
trg_scripts_dir = None
base_path = resource.metadata['base_path']
base_path = resource.db_obj.base_path
src_templates_dir = os.path.join(base_path, 'templates')
if os.path.exists(src_templates_dir):
trg_templates_dir = os.path.join(self.dirs[resource.name], 'templates')
@ -111,7 +110,7 @@ class TempFileHandler(BaseHandler):
def _make_args(self, resource):
args = {'resource_name': resource.name}
args['resource_dir'] = resource.metadata['base_path']
args['resource_dir'] = resource.db_obj.base_path
args['templates_dir'] = 'templates/'
args['scripts_dir'] = 'scripts/'
args.update(resource.args)

View File

@ -14,6 +14,7 @@
# under the License.
import os
import yaml
from solar.core.log import log
from solar.core.handlers.base import TempFileHandler
@ -31,7 +32,7 @@ class LibrarianPuppet(object):
def install(self):
puppet_module = '{}-{}'.format(
self.organization,
self.resource.metadata['puppet_module']
self.resource.db_obj.puppet_module
)
puppetlabs = self.transport_run.run(
@ -40,7 +41,7 @@ class LibrarianPuppet(object):
)
log.debug('Puppetlabs file is: \n%s\n', puppetlabs)
git = self.resource.args['git'].value
git = self.resource.args['git']
definition = "mod '{module_name}', :git => '{repository}', :ref => '{branch}'".format(
module_name=puppet_module,
@ -98,6 +99,8 @@ class Puppet(TempFileHandler):
action_file = self._compile_action_file(resource, action_name)
log.debug('action_file: %s', action_file)
self.upload_hiera_resource(resource)
self.upload_manifests(resource)
self.prepare_templates_and_scripts(resource, action_file, '')
@ -111,7 +114,7 @@ class Puppet(TempFileHandler):
'FACTER_resource_name': resource.name,
},
use_sudo=True,
warn_only=True,
warn_only=True
)
# 0 - no changes, 2 - successfull changes
if cmd.return_code not in [0, 2]:
@ -121,19 +124,33 @@ class Puppet(TempFileHandler):
return cmd
def clone_manifests(self, resource):
git = resource.args['git'].value
git = resource.args['git']
p = GitProvider(git['repository'], branch=git['branch'])
return p.directory
def upload_hiera_resource(self, resource):
with open('/tmp/puppet_resource.yaml', 'w') as f:
f.write(yaml.dump({
resource.name: resource.to_dict()
}))
self.transport_sync.copy(
resource,
'/tmp/puppet_resource.yaml',
'/etc/puppet/hieradata/{}.yaml'.format(resource.name),
use_sudo=True
)
self.transport_sync.sync_all()
def upload_manifests(self, resource):
if 'forge' in resource.args and resource.args['forge'].value:
if 'forge' in resource.args and resource.args['forge']:
self.upload_manifests_forge(resource)
else:
self.upload_manifests_librarian(resource)
def upload_manifests_forge(self, resource):
forge = resource.args['forge'].value
forge = resource.args['forge']
# Check if module already installed
modules = self.transport_run.run(

View File

@ -1,222 +0,0 @@
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from solar.core.log import log
from solar.core import signals
from solar.interfaces.db import get_db
db = get_db()
class BaseObserver(object):
type_ = None
def __init__(self, attached_to, name, value):
"""
:param attached_to: resource.Resource
:param name:
:param value:
:return:
"""
self._attached_to_name = attached_to.name
self.name = name
self.value = value
@property
def attached_to(self):
from solar.core import resource
return resource.load(self._attached_to_name)
@property
def receivers(self):
from solar.core import resource
for receiver_name, receiver_input in signals.Connections.receivers(
self._attached_to_name,
self.name
):
yield resource.load(receiver_name).args[receiver_input]
def __repr__(self):
return '[{}:{}] {}'.format(self._attached_to_name, self.name, self.value)
def __unicode__(self):
return unicode(self.value)
def __eq__(self, other):
if isinstance(other, BaseObserver):
return self.value == other.value
return self.value == other
def notify(self, emitter):
"""
:param emitter: Observer
:return:
"""
raise NotImplementedError
def update(self, value):
"""
:param value:
:return:
"""
raise NotImplementedError
def find_receiver(self, receiver):
fltr = [r for r in self.receivers
if r._attached_to_name == receiver._attached_to_name
and r.name == receiver.name]
if fltr:
return fltr[0]
def subscribe(self, receiver):
"""
:param receiver: Observer
:return:
"""
log.debug('Subscribe %s', receiver)
# No multiple subscriptions
if self.find_receiver(receiver):
log.error('No multiple subscriptions from %s', receiver)
return
receiver.subscribed(self)
signals.Connections.add(
self.attached_to,
self.name,
receiver.attached_to,
receiver.name
)
receiver.notify(self)
def subscribed(self, emitter):
log.debug('Subscribed %s', emitter)
def unsubscribe(self, receiver):
"""
:param receiver: Observer
:return:
"""
log.debug('Unsubscribe %s', receiver)
if self.find_receiver(receiver):
receiver.unsubscribed(self)
signals.Connections.remove(
self.attached_to,
self.name,
receiver.attached_to,
receiver.name
)
# TODO: ?
#receiver.notify(self)
def unsubscribed(self, emitter):
log.debug('Unsubscribed %s', emitter)
class Observer(BaseObserver):
type_ = 'simple'
@property
def emitter(self):
from solar.core import resource
emitter = signals.Connections.emitter(self._attached_to_name, self.name)
if emitter is not None:
emitter_name, emitter_input_name = emitter
return resource.load(emitter_name).args[emitter_input_name]
def notify(self, emitter):
log.debug('Notify from %s value %s', emitter, emitter.value)
# Copy emitter's values to receiver
self.value = emitter.value
for receiver in self.receivers:
receiver.notify(self)
self.attached_to.set_args_from_dict({self.name: self.value})
def update(self, value):
log.debug('Updating to value %s', value)
self.value = value
for receiver in self.receivers:
receiver.notify(self)
self.attached_to.set_args_from_dict({self.name: self.value})
def subscribed(self, emitter):
super(Observer, self).subscribed(emitter)
# Simple observer can be attached to at most one emitter
if self.emitter is not None:
self.emitter.unsubscribe(self)
class ListObserver(BaseObserver):
type_ = 'list'
def __unicode__(self):
return unicode(self.value)
@staticmethod
def _format_value(emitter):
return {
'emitter': emitter.name,
'emitter_attached_to': emitter._attached_to_name,
'value': emitter.value,
}
def notify(self, emitter):
log.debug('Notify from %s value %s', emitter, emitter.value)
# Copy emitter's values to receiver
idx = self._emitter_idx(emitter)
self.value[idx] = self._format_value(emitter)
for receiver in self.receivers:
receiver.notify(self)
self.attached_to.set_args_from_dict({self.name: self.value})
def subscribed(self, emitter):
super(ListObserver, self).subscribed(emitter)
idx = self._emitter_idx(emitter)
if idx is None:
self.value.append(self._format_value(emitter))
self.attached_to.set_args_from_dict({self.name: self.value})
def unsubscribed(self, emitter):
"""
:param receiver: Observer
:return:
"""
log.debug('Unsubscribed emitter %s', emitter)
idx = self._emitter_idx(emitter)
self.value.pop(idx)
self.attached_to.set_args_from_dict({self.name: self.value})
for receiver in self.receivers:
receiver.notify(self)
def _emitter_idx(self, emitter):
try:
return [i for i, e in enumerate(self.value)
if e['emitter_attached_to'] == emitter._attached_to_name
][0]
except IndexError:
return
def create(type_, *args, **kwargs):
for klass in BaseObserver.__subclasses__():
if klass.type_ == type_:
return klass(*args, **kwargs)
raise NotImplementedError('No handling class for type {}'.format(type_))

View File

@ -12,21 +12,4 @@
# License for the specific language governing permissions and limitations
# under the License.
__all__ = [
'Resource',
'create',
'load',
'load_all',
'prepare_meta',
'wrap_resource',
'validate_resources',
]
from solar.core.resource.resource import Resource
from solar.core.resource.resource import load
from solar.core.resource.resource import load_all
from solar.core.resource.resource import wrap_resource
from solar.core.resource.virtual_resource import create
from solar.core.resource.virtual_resource import prepare_meta
from solar.core.resource.virtual_resource import validate_resources
from .resource import Resource, load, load_all

View File

@ -1,80 +0,0 @@
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import inflection
import os
import pprint
from solar.core import resource
from solar import utils
RESOURCE_HEADER_TEMPLATE = """
from solar.core.resource import Resource
"""
RESOURCE_CLASS_TEMPLATE = """
class {class_name}(Resource):
_metadata = {{
'actions': {meta_actions},
'actions_path': '{actions_path}',
'base_path': '{base_path}',
'input': {meta_input},
'handler': '{handler}',
}}
{input_properties}
"""
RESOURCE_INPUT_PROPERTY_TEMPLATE = """
@property
def {name}(self):
return self.args['{name}']
@{name}.setter
def {name}(self, value):
#self.args['{name}'].value = value
#self.set_args_from_dict({{'{name}': value}})
self.update({{'{name}': value}})
"""
def compile(meta):
destination_file = utils.read_config()['resources-compiled-file']
resource.prepare_meta(meta)
meta['class_name'] = '{}Resource'.format(
inflection.camelize(meta['base_name'])
)
meta['meta_actions'] = pprint.pformat(meta['actions'])
meta['meta_input'] = pprint.pformat(meta['input'])
print meta['base_name'], meta['class_name']
if not os.path.exists(destination_file):
with open(destination_file, 'w') as f:
f.write(RESOURCE_HEADER_TEMPLATE.format(**meta))
with open(destination_file, 'a') as f:
input_properties = '\n'.join(
RESOURCE_INPUT_PROPERTY_TEMPLATE.format(name=name)
for name in meta['input']
)
f.write(RESOURCE_CLASS_TEMPLATE.format(
input_properties=input_properties, **meta)
)

View File

@ -14,204 +14,126 @@
# under the License.
from copy import deepcopy
from multipledispatch import dispatch
import os
from solar.core import actions
from solar.core import observer
from solar.core import validation
from solar.interfaces import orm
from solar import utils
from solar.interfaces.db import get_db
db = get_db()
def read_meta(base_path):
base_meta_file = os.path.join(base_path, 'meta.yaml')
metadata = utils.yaml_load(base_meta_file)
metadata['version'] = '1.0.0'
metadata['base_path'] = os.path.abspath(base_path)
actions_path = os.path.join(metadata['base_path'], 'actions')
metadata['actions_path'] = actions_path
metadata['base_name'] = os.path.split(metadata['base_path'])[-1]
return metadata
class Resource(object):
_metadata = {}
def __init__(self, name, metadata, args, tags=None, virtual_resource=None):
# Create
@dispatch(str, str, dict)
def __init__(self, name, base_path, args, tags=None, virtual_resource=None):
self.name = name
if metadata:
self.metadata = metadata
if base_path:
metadata = read_meta(base_path)
else:
self.metadata = deepcopy(self._metadata)
self.metadata['id'] = name
metadata = deepcopy(self._metadata)
self.tags = tags or []
self.virtual_resource = virtual_resource
self.set_args_from_dict(args)
self.db_obj = orm.DBResource(**{
'id': name,
'name': name,
'actions_path': metadata.get('actions_path', ''),
'base_name': metadata.get('base_name', ''),
'base_path': metadata.get('base_path', ''),
'handler': metadata.get('handler', ''),
'puppet_module': metadata.get('puppet_module', ''),
'version': metadata.get('version', ''),
'meta_inputs': metadata.get('input', {})
})
self.db_obj.save()
self.create_inputs(args)
# Load
@dispatch(orm.DBResource)
def __init__(self, resource_db):
self.db_obj = resource_db
self.name = resource_db.name
# TODO: tags
self.tags = []
self.virtual_resource = None
@property
def actions(self):
return self.metadata.get('actions') or []
ret = {
os.path.splitext(p)[0]: os.path.join(
self.db_obj.actions_path, p
)
for p in os.listdir(self.db_obj.actions_path)
}
return {
k: v for k, v in ret.items() if os.path.isfile(v)
}
def create_inputs(self, args):
for name, v in self.db_obj.meta_inputs.items():
value = args.get(name, v.get('value'))
self.db_obj.add_input(name, v['schema'], value)
@property
def args(self):
ret = {}
for i in self.resource_inputs().values():
ret[i.name] = i.backtrack_value()
return ret
args = self.args_dict()
def update(self, args):
# TODO: disconnect input when it is updated and end_node
# for some input_to_input relation
resource_inputs = self.resource_inputs()
for arg_name, metadata_arg in self.metadata['input'].items():
type_ = validation.schema_input_type(metadata_arg.get('schema', 'str'))
for k, v in args.items():
i = resource_inputs[k]
i.value = v
i.save()
ret[arg_name] = observer.create(
type_, self, arg_name, args.get(arg_name)
)
def resource_inputs(self):
return {
i.name: i for i in self.db_obj.inputs.value
}
def to_dict(self):
ret = self.db_obj.to_dict()
ret['input'] = {}
for k, v in self.args.items():
ret['input'][k] = {
'value': v,
}
return ret
def args_dict(self):
raw_resource = db.read(self.name, collection=db.COLLECTIONS.resource)
if raw_resource is None:
return {}
self.metadata = raw_resource
def load(name):
r = orm.DBResource.load(name)
return Resource.get_raw_resource_args(raw_resource)
if not r:
raise Exception('Resource {} does not exist in DB'.format(name))
def set_args_from_dict(self, new_args):
args = self.args_dict()
args.update(new_args)
self.metadata['tags'] = self.tags
self.metadata['virtual_resource'] = self.virtual_resource
for k, v in args.items():
if k not in self.metadata['input']:
raise NotImplementedError(
'Argument {} not implemented for resource {}'.format(k, self)
)
if isinstance(v, dict) and 'value' in v:
v = v['value']
self.metadata['input'][k]['value'] = v
db.save(self.name, self.metadata, collection=db.COLLECTIONS.resource)
def set_args(self, args):
self.set_args_from_dict({k: v.value for k, v in args.items()})
def __repr__(self):
return ("Resource(name='{id}', metadata={metadata}, args={input}, "
"tags={tags})").format(**self.to_dict())
def color_repr(self):
import click
arg_color = 'yellow'
return ("{resource_s}({name_s}='{id}', {metadata_s}={metadata}, "
"{args_s}={input}, {tags_s}={tags})").format(
resource_s=click.style('Resource', fg='white', bold=True),
name_s=click.style('name', fg=arg_color, bold=True),
metadata_s=click.style('metadata', fg=arg_color, bold=True),
args_s=click.style('args', fg=arg_color, bold=True),
tags_s=click.style('tags', fg=arg_color, bold=True),
**self.to_dict()
)
def to_dict(self):
return {
'id': self.name,
'metadata': self.metadata,
'input': self.args_show(),
'tags': self.tags,
}
def args_show(self):
def formatter(v):
if isinstance(v, observer.ListObserver):
return v.value
elif isinstance(v, observer.Observer):
return {
'emitter': v.emitter.attached_to.name if v.emitter else None,
'value': v.value,
}
return v
return {k: formatter(v) for k, v in self.args.items()}
def add_tag(self, tag):
if tag not in self.tags:
self.tags.append(tag)
def remove_tag(self, tag):
try:
self.tags.remove(tag)
except ValueError:
pass
def notify(self, emitter):
"""Update resource's args from emitter's args.
:param emitter: Resource
:return:
"""
r_args = self.args
for key, value in emitter.args.iteritems():
r_args[key].notify(value)
def update(self, args):
"""This method updates resource's args with a simple dict.
:param args:
:return:
"""
# Update will be blocked if this resource is listening
# on some input that is to be updated -- we should only listen
# to the emitter and not be able to change the input's value
r_args = self.args
for key, value in args.iteritems():
r_args[key].update(value)
self.set_args(r_args)
def action(self, action):
if action in self.actions:
actions.resource_action(self, action)
else:
raise Exception('Uuups, action is not available')
@staticmethod
def get_raw_resource_args(raw_resource):
return {k: v.get('value') for k, v in raw_resource['input'].items()}
def wrap_resource(raw_resource):
name = raw_resource['id']
args = Resource.get_raw_resource_args(raw_resource)
tags = raw_resource.get('tags', [])
virtual_resource = raw_resource.get('virtual_resource', [])
return Resource(name, raw_resource, args, tags=tags, virtual_resource=virtual_resource)
def wrap_resource_no_value(raw_resource):
name = raw_resource['id']
args = {k: v for k, v in raw_resource['input'].items()}
tags = raw_resource.get('tags', [])
virtual_resource = raw_resource.get('virtual_resource', [])
return Resource(name, raw_resource, args, tags=tags, virtual_resource=virtual_resource)
def load(resource_name):
raw_resource = db.read(resource_name, collection=db.COLLECTIONS.resource)
if raw_resource is None:
raise KeyError(
'Resource {} does not exist'.format(resource_name)
)
return wrap_resource(raw_resource)
return Resource(r)
# TODO
def load_all():
ret = {}
for raw_resource in db.get_list(collection=db.COLLECTIONS.resource):
resource = wrap_resource(raw_resource)
ret[resource.name] = resource
return ret
return [Resource(r) for r in orm.DBResource.load_all()]

View File

@ -15,67 +15,16 @@
import os
from StringIO import StringIO
import yaml
from jinja2 import Template, Environment, meta
from solar import utils
from solar.core import validation
from solar.core.resource import load_all, Resource
from solar.core import provider
from solar.core import resource
from solar.core import signals
def create_resource(name, base_path, args, virtual_resource=None):
if isinstance(base_path, provider.BaseProvider):
base_path = base_path.directory
base_meta_file = os.path.join(base_path, 'meta.yaml')
actions_path = os.path.join(base_path, 'actions')
metadata = utils.yaml_load(base_meta_file)
metadata['id'] = name
metadata['version'] = '1.0.0'
metadata['base_path'] = os.path.abspath(base_path)
prepare_meta(metadata)
tags = metadata.get('tags', [])
resource = Resource(name, metadata, args, tags, virtual_resource)
return resource
def create_virtual_resource(vr_name, template):
resources = template['resources']
connections = []
created_resources = []
cwd = os.getcwd()
for resource in resources:
name = resource['id']
base_path = os.path.join(cwd, resource['from'])
args = resource['values']
new_resources = create(name, base_path, args, vr_name)
created_resources += new_resources
if not is_virtual(base_path):
for key, arg in args.items():
if isinstance(arg, basestring) and '::' in arg:
emitter, src = arg.split('::')
connections.append((emitter, name, {src: key}))
db = load_all()
for emitter, reciver, mapping in connections:
emitter = db[emitter]
reciver = db[reciver]
signals.connect(emitter, reciver, mapping)
return created_resources
def create(name, base_path, kwargs, virtual_resource=None):
def create(name, base_path, args={}, virtual_resource=None):
if isinstance(base_path, provider.BaseProvider):
base_path = base_path.directory
@ -85,93 +34,54 @@ def create(name, base_path, kwargs, virtual_resource=None):
)
if is_virtual(base_path):
template = _compile_file(name, base_path, kwargs)
template = _compile_file(name, base_path, args)
yaml_template = yaml.load(StringIO(template))
resources = create_virtual_resource(name, yaml_template)
rs = create_virtual_resource(name, yaml_template)
else:
resource = create_resource(name, base_path, kwargs, virtual_resource)
resources = [resource]
r = create_resource(name,
base_path,
args=args,
virtual_resource=virtual_resource)
rs = [r]
return resources
return rs
def prepare_meta(meta):
actions_path = os.path.join(meta['base_path'], 'actions')
meta['actions_path'] = actions_path
meta['base_name'] = os.path.split(meta['base_path'])[-1]
def create_resource(name, base_path, args={}, virtual_resource=None):
if isinstance(base_path, provider.BaseProvider):
base_path = base_path.directory
meta['actions'] = {}
if os.path.exists(meta['actions_path']):
for f in os.listdir(meta['actions_path']):
meta['actions'][os.path.splitext(f)[0]] = f
r = resource.Resource(
name, base_path, args, tags=[], virtual_resource=virtual_resource
)
return r
def validate_resources():
db = load_all()
all_errors = []
for r in db.values():
if not isinstance(r, Resource):
continue
def create_virtual_resource(vr_name, template):
resources = template['resources']
connections = []
created_resources = []
errors = validation.validate_resource(r)
if errors:
all_errors.append((r, errors))
return all_errors
cwd = os.getcwd()
for r in resources:
name = r['id']
base_path = os.path.join(cwd, r['from'])
args = r['values']
new_resources = create(name, base_path, args, vr_name)
created_resources += new_resources
if not is_virtual(base_path):
for key, arg in args.items():
if isinstance(arg, basestring) and '::' in arg:
emitter, src = arg.split('::')
connections.append((emitter, name, {src: key}))
def find_inputs_without_source():
"""Find resources and inputs values of which are hardcoded.
for emitter, reciver, mapping in connections:
emitter = r.load(emitter)
reciver = r.load(reciver)
signals.connect(emitter, reciver, mapping)
:return: [(resource_name, input_name)]
"""
resources = load_all()
ret = set([(r.name, input_name) for r in resources.values()
for input_name in r.args])
clients = signals.Connections.read_clients()
for dest_dict in clients.values():
for destinations in dest_dict.values():
for receiver_name, receiver_input in destinations:
try:
ret.remove((receiver_name, receiver_input))
except KeyError:
continue
return list(ret)
def find_missing_connections():
"""Find resources whose input values are duplicated
and they are not connected between each other (i.e. the values
are hardcoded, not coming from connection).
NOTE: this we could have 2 inputs of the same value living in 2 "circles".
This is not covered, we find only inputs whose value is hardcoded.
:return: [(resource_name1, input_name1, resource_name2, input_name2)]
"""
ret = set()
resources = load_all()
inputs_without_source = find_inputs_without_source()
for resource1, input1 in inputs_without_source:
r1 = resources[resource1]
v1 = r1.args[input1]
for resource2, input2 in inputs_without_source:
r2 = resources[resource2]
v2 = r2.args[input2]
if v1 == v2 and resource1 != resource2 and \
(resource2, input2, resource1, input1) not in ret:
ret.add((resource1, input1, resource2, input2))
return list(ret)
return created_resources
def _compile_file(name, path, kwargs):

View File

@ -13,106 +13,10 @@
# License for the specific language governing permissions and limitations
# under the License.
from collections import defaultdict
import itertools
import networkx as nx
from solar.core.log import log
from solar.interfaces.db import get_db
from solar.events.api import add_events
from solar.events.controls import Dependency
db = get_db()
class Connections(object):
@staticmethod
def read_clients():
"""
Returned structure is:
emitter_name:
emitter_input_name:
- - dst_name
- dst_input_name
while DB structure is:
emitter_name_key:
emitter: emitter_name
sources:
emitter_input_name:
- - dst_name
- dst_input_name
"""
ret = {}
for data in db.get_list(collection=db.COLLECTIONS.connection):
ret[data['emitter']] = data['sources']
return ret
@staticmethod
def save_clients(clients):
data = []
for emitter_name, sources in clients.items():
data.append((
emitter_name,
{
'emitter': emitter_name,
'sources': sources,
}))
db.save_list(data, collection=db.COLLECTIONS.connection)
@staticmethod
def add(emitter, src, receiver, dst):
if src not in emitter.args:
return
clients = Connections.read_clients()
# TODO: implement general circular detection, this one is simple
if [emitter.name, src] in clients.get(receiver.name, {}).get(dst, []):
raise Exception('Attempted to create cycle in dependencies. Not nice.')
clients.setdefault(emitter.name, {})
clients[emitter.name].setdefault(src, [])
if [receiver.name, dst] not in clients[emitter.name][src]:
clients[emitter.name][src].append([receiver.name, dst])
Connections.save_clients(clients)
@staticmethod
def remove(emitter, src, receiver, dst):
clients = Connections.read_clients()
clients[emitter.name][src] = [
destination for destination in clients[emitter.name][src]
if destination != [receiver.name, dst]
]
Connections.save_clients(clients)
@staticmethod
def receivers(emitter_name, emitter_input_name):
return Connections.read_clients().get(emitter_name, {}).get(
emitter_input_name, []
)
@staticmethod
def emitter(receiver_name, receiver_input_name):
for emitter_name, dest_dict in Connections.read_clients().items():
for emitter_input_name, destinations in dest_dict.items():
if [receiver_name, receiver_input_name] in destinations:
return [emitter_name, emitter_input_name]
@staticmethod
def clear():
db.clear_collection(collection=db.COLLECTIONS.connection)
def guess_mapping(emitter, receiver):
"""Guess connection mapping between emitter and receiver.
@ -140,23 +44,7 @@ def guess_mapping(emitter, receiver):
return guessed
def connect_single(emitter, src, receiver, dst):
# Disconnect all receiver inputs
# Check if receiver input is of list type first
if receiver.args[dst].type_ != 'list':
disconnect_receiver_by_input(receiver, dst)
emitter.args[src].subscribe(receiver.args[dst])
def connect(emitter, receiver, mapping=None, events=None):
# convert if needed
# TODO: handle invalid resource
# if isinstance(emitter, basestring):
# emitter = resource.load(emitter)
# if isinstance(receiver, basestring):
# receiver = resource.load(receiver)
def connect(emitter, receiver, mapping={}, events=None):
mapping = mapping or guess_mapping(emitter, receiver)
if isinstance(mapping, set):
@ -165,16 +53,12 @@ def connect(emitter, receiver, mapping=None, events=None):
return
for src, dst in mapping.items():
if isinstance(dst, list):
for d in dst:
connect_single(emitter, src, receiver, d)
continue
if not isinstance(dst, list):
dst = [dst]
connect_single(emitter, src, receiver, dst)
for d in dst:
connect_single(emitter, src, receiver, d)
# possibility to set events, when False it will NOT add events at all
# setting events to dict with `action_name`:False will not add `action_name`
# event
events_to_add = [
Dependency(emitter.name, 'run', 'success', receiver.name, 'run'),
Dependency(emitter.name, 'update', 'success', receiver.name, 'update')
@ -187,134 +71,40 @@ def connect(emitter, receiver, mapping=None, events=None):
elif events is not False:
add_events(emitter.name, events_to_add)
# receiver.save()
def connect_single(emitter, src, receiver, dst):
# Disconnect all receiver inputs
# Check if receiver input is of list type first
emitter_input = emitter.resource_inputs()[src]
receiver_input = receiver.resource_inputs()[dst]
if emitter_input.id == receiver_input.id:
raise Exception(
'Trying to connect {} to itself, this is not possible'.format(
emitter_input.id)
)
if not receiver_input.is_list:
receiver_input.receivers.delete_all_incoming(receiver_input)
# Check for cycles
# TODO: change to get_paths after it is implemented in drivers
if emitter_input in receiver_input.receivers.value:
raise Exception('Prevented creating a cycle')
log.debug('Connecting {}::{} -> {}::{}'.format(
emitter.name, emitter_input.name, receiver.name, receiver_input.name
))
emitter_input.receivers.add(receiver_input)
def disconnect_receiver_by_input(receiver, input_name):
input_node = receiver.resource_inputs()[input_name]
input_node.receivers.delete_all_incoming(input_node)
def disconnect(emitter, receiver):
# convert if needed
# TODO: handle invalid resource
# if isinstance(emitter, basestring):
# emitter = resource.load(emitter)
# if isinstance(receiver, basestring):
# receiver = resource.load(receiver)
clients = Connections.read_clients()
for src, destinations in clients[emitter.name].items():
for destination in destinations:
receiver_input = destination[1]
if receiver_input in receiver.args:
if receiver.args[receiver_input].type_ != 'list':
log.debug(
'Removing input %s from %s', receiver_input, receiver.name
)
emitter.args[src].unsubscribe(receiver.args[receiver_input])
disconnect_by_src(emitter.name, src, receiver)
def disconnect_receiver_by_input(receiver, input):
"""Find receiver connection by input and disconnect it.
:param receiver:
:param input:
:return:
"""
clients = Connections.read_clients()
for emitter_name, inputs in clients.items():
disconnect_by_src(emitter_name, input, receiver)
def disconnect_by_src(emitter_name, src, receiver):
clients = Connections.read_clients()
if src in clients[emitter_name]:
clients[emitter_name][src] = [
destination for destination in clients[emitter_name][src]
if destination[0] != receiver.name
]
Connections.save_clients(clients)
def notify(source, key, value):
from solar.core.resource import load
clients = Connections.read_clients()
if source.name not in clients:
clients[source.name] = {}
Connections.save_clients(clients)
log.debug('Notify %s %s %s %s', source.name, key, value, clients[source.name])
if key in clients[source.name]:
for client, r_key in clients[source.name][key]:
resource = load(client)
log.debug('Resource found: %s', client)
if resource:
resource.update({r_key: value})
else:
log.debug('Resource %s deleted?', client)
pass
def assign_connections(receiver, connections):
mappings = defaultdict(list)
for key, dest in connections.iteritems():
resource, r_key = dest.split('.')
mappings[resource].append([r_key, key])
for resource, r_mappings in mappings.iteritems():
connect(resource, receiver, r_mappings)
def connection_graph():
resource_dependencies = {}
clients = Connections.read_clients()
for emitter_name, destination_values in clients.items():
resource_dependencies.setdefault(emitter_name, set())
for emitter_input, receivers in destination_values.items():
resource_dependencies[emitter_name].update(
receiver[0] for receiver in receivers
)
g = nx.DiGraph()
# TODO: tags as graph node attributes
for emitter_name, receivers in resource_dependencies.items():
g.add_node(emitter_name)
g.add_nodes_from(receivers)
g.add_edges_from(
itertools.izip(
itertools.repeat(emitter_name),
receivers
)
)
return g
def detailed_connection_graph(start_with=None, end_with=None):
g = nx.MultiDiGraph()
clients = Connections.read_clients()
for emitter_name, destination_values in clients.items():
for emitter_input, receivers in destination_values.items():
for receiver_name, receiver_input in receivers:
label = '{}:{}'.format(emitter_input, receiver_input)
g.add_edge(emitter_name, receiver_name, label=label)
ret = g
if start_with is not None:
ret = g.subgraph(
nx.dfs_postorder_nodes(ret, start_with)
)
if end_with is not None:
ret = g.subgraph(
nx.dfs_postorder_nodes(ret.reverse(), end_with)
)
return ret
for emitter_input in emitter.resource_inputs().values():
for receiver_input in receiver.resource_inputs().values():
emitter_input.receivers.remove(receiver_input)

View File

@ -18,43 +18,53 @@ import traceback
from log import log
from solar.core import resource
from solar.core import signals
def test(r):
if isinstance(r, basestring):
r = resource.load(r)
log.debug('Trying {}'.format(r.name))
script_path = os.path.join(r.db_obj.base_path, 'test.py')
if not os.path.exists(script_path):
log.warning('resource {} has no tests'.format(r.name))
return {}
log.debug('File {} found'.format(script_path))
with open(script_path) as f:
module = imp.load_module(
'{}_test'.format(r.name),
f,
script_path,
('', 'r', imp.PY_SOURCE)
)
try:
module.test(r)
return {
r.name: {
'status': 'ok',
},
}
except Exception:
return {
r.name: {
'status': 'error',
'message': traceback.format_exc(),
}
}
def test_all():
results = {}
conn_graph = signals.detailed_connection_graph()
#srt = nx.topological_sort(conn_graph)
resources = resource.load_all()
for name in conn_graph:
log.debug('Trying {}'.format(name))
r = resource.load(name)
script_path = os.path.join(r.metadata['base_path'], 'test.py')
if not os.path.exists(script_path):
log.warning('resource {} has no tests'.format(name))
continue
log.debug('File {} found'.format(script_path))
with open(script_path) as f:
module = imp.load_module(
'{}_test'.format(name),
f,
script_path,
('', 'r', imp.PY_SOURCE)
)
try:
module.test(r)
results[name] = {
'status': 'ok',
}
except Exception:
results[name] = {
'status': 'error',
'message': traceback.format_exc(),
}
for r in resources:
ret = test(r)
if ret:
results.update(ret)
return results

View File

@ -28,13 +28,13 @@ class _SSHTransport(object):
def _fabric_settings(self, resource):
return {
'host_string': self._ssh_command_host(resource),
'key_filename': resource.args['ssh_key'].value,
'key_filename': resource.args['ssh_key'],
}
# TODO: maybe static/class method ?
def _ssh_command_host(self, resource):
return '{}@{}'.format(resource.args['ssh_user'].value,
resource.args['ip'].value)
return '{}@{}'.format(resource.args['ssh_user'],
resource.args['ip'])
class SSHSyncTransport(SyncTransport, _SSHTransport):

View File

@ -171,14 +171,14 @@ def validate_resource(r):
"""
ret = {}
input_schemas = r.metadata['input']
args = r.args_dict()
inputs = r.resource_inputs()
args = r.args
for input_name, input_definition in input_schemas.items():
for input_name, input_definition in inputs.items():
errors = validate_input(
args.get(input_name),
jsonschema=input_definition.get('jsonschema'),
schema=input_definition.get('schema')
#jsonschema=input_definition.get('jsonschema'),
schema=input_definition.schema
)
if errors:
ret[input_name] = errors

View File

@ -25,6 +25,10 @@ class CannotFindExtension(SolarError):
pass
class ValidationError(SolarError):
pass
class LexError(SolarError):
pass

View File

@ -19,7 +19,7 @@ import networkx as nx
from solar.core.log import log
from solar.interfaces.db import get_db
from solar.events.controls import Dep, React
from solar.events.controls import Dep, React, StateChange
db = get_db()
@ -42,7 +42,7 @@ def add_event(ev):
break
else:
rst.append(ev)
db.save(
db.create(
ev.parent_node,
[i.to_dict() for i in rst],
collection=db.COLLECTIONS.events)
@ -64,14 +64,14 @@ def add_react(parent, dep, actions, state='success'):
def remove_event(ev):
rst = all_events(ev.parent_node)
db.save(
db.create(
ev.parent_node,
[i.to_dict() for i in rst],
collection=db.COLLECTIONS.events)
def set_events(resource, lst):
db.save(
db.create(
resource,
[i.to_dict() for i in lst],
collection=db.COLLECTIONS.events)
@ -84,10 +84,11 @@ def add_events(resource, lst):
def all_events(resource):
events = db.read(resource, collection=db.COLLECTIONS.events)
events = db.get(resource, collection=db.COLLECTIONS.events,
return_empty=True, db_convert=False)
if not events:
return []
return [create_event(i) for i in events]
return [create_event(i) for i in events['properties']]
def bft_events_graph(start):

View File

@ -12,20 +12,27 @@
# License for the specific language governing permissions and limitations
# under the License.
from solar.interfaces.db.redis_db import RedisDB
from solar.interfaces.db.redis_db import FakeRedisDB
import importlib
mapping = {
'redis_db': RedisDB,
'fakeredis_db': FakeRedisDB
db_backends = {
'neo4j_db': ('solar.interfaces.db.neo4j', 'Neo4jDB'),
'redis_db': ('solar.interfaces.db.redis_db', 'RedisDB'),
'fakeredis_db': ('solar.interfaces.db.redis_db', 'FakeRedisDB'),
'redis_graph_db': ('solar.interfaces.db.redis_graph_db', 'RedisGraphDB'),
'fakeredis_graph_db': ('solar.interfaces.db.redis_graph_db', 'FakeRedisGraphDB'),
}
CURRENT_DB = 'redis_graph_db'
#CURRENT_DB = 'neo4j_db'
DB = None
def get_db():
def get_db(backend=CURRENT_DB):
# Should be retrieved from config
global DB
if DB is None:
DB = mapping['redis_db']()
import_path, klass = db_backends[backend]
module = importlib.import_module(import_path)
DB = getattr(module, klass)()
return DB

View File

@ -0,0 +1,224 @@
import abc
from enum import Enum
from functools import partial
class Node(object):
def __init__(self, db, uid, labels, properties):
self.db = db
self.uid = uid
self.labels = labels
self.properties = properties
@property
def collection(self):
return getattr(
BaseGraphDB.COLLECTIONS,
list(self.labels)[0]
)
class Relation(object):
def __init__(self, db, start_node, end_node, properties):
self.db = db
self.start_node = start_node
self.end_node = end_node
self.properties = properties
class DBObjectMeta(abc.ABCMeta):
# Tuples of: function name, is-multi (i.e. returns a list)
node_db_read_methods = [
('all', True),
('create', False),
('get', False),
('get_or_create', False),
]
relation_db_read_methods = [
('all_relations', True),
('create_relation', False),
('get_relations', True),
('get_relation', False),
('get_or_create_relation', False),
]
def __new__(cls, name, parents, dct):
def from_db_list_decorator(converting_func, method):
def wrapper(self, *args, **kwargs):
db_convert = kwargs.pop('db_convert', True)
result = method(self, *args, **kwargs)
if db_convert:
return map(partial(converting_func, self), result)
return result
return wrapper
def from_db_decorator(converting_func, method):
def wrapper(self, *args, **kwargs):
db_convert = kwargs.pop('db_convert', True)
result = method(self, *args, **kwargs)
if result is None:
return
if db_convert:
return converting_func(self, result)
return result
return wrapper
node_db_to_object = cls.find_method(
'node_db_to_object', name, parents, dct
)
relation_db_to_object = cls.find_method(
'relation_db_to_object', name, parents, dct
)
# Node conversions
for method_name, is_list in cls.node_db_read_methods:
method = cls.find_method(method_name, name, parents, dct)
if is_list:
func = from_db_list_decorator
else:
func = from_db_decorator
dct[method_name] = func(node_db_to_object, method)
# Relation conversions
for method_name, is_list in cls.relation_db_read_methods:
method = cls.find_method(method_name, name, parents, dct)
if is_list:
func = from_db_list_decorator
else:
func = from_db_decorator
dct[method_name] = func(relation_db_to_object, method)
return super(DBObjectMeta, cls).__new__(cls, name, parents, dct)
@classmethod
def find_method(cls, method_name, class_name, parents, dict):
if method_name in dict:
return dict[method_name]
for parent in parents:
method = getattr(parent, method_name)
if method:
return method
raise NotImplementedError(
'"{}" method not implemented in class {}'.format(
method_name, class_name
)
)
class BaseGraphDB(object):
__metaclass__ = DBObjectMeta
COLLECTIONS = Enum(
'Collections',
'input resource state_data state_log plan_node plan_graph events stage_log commit_log'
)
DEFAULT_COLLECTION=COLLECTIONS.resource
RELATION_TYPES = Enum(
'RelationTypes',
'input_to_input resource_input plan_edge graph_to_node'
)
DEFAULT_RELATION=RELATION_TYPES.resource_input
@staticmethod
def node_db_to_object(node_db):
"""Convert node DB object to Node object."""
@staticmethod
def object_to_node_db(node_obj):
"""Convert Node object to node DB object."""
@staticmethod
def relation_db_to_object(relation_db):
"""Convert relation DB object to Relation object."""
@staticmethod
def object_to_relation_db(relation_obj):
"""Convert Relation object to relation DB object."""
@abc.abstractmethod
def all(self, collection=DEFAULT_COLLECTION):
"""Return all elements (nodes) of type `collection`."""
@abc.abstractmethod
def all_relations(self, type_=DEFAULT_RELATION):
"""Return all relations of type `type_`."""
@abc.abstractmethod
def clear(self):
"""Clear the whole DB."""
@abc.abstractmethod
def clear_collection(self, collection=DEFAULT_COLLECTION):
"""Clear all elements (nodes) of type `collection`."""
@abc.abstractmethod
def create(self, name, properties={}, collection=DEFAULT_COLLECTION):
"""Create element (node) with given name, args, of type `collection`."""
@abc.abstractmethod
def create_relation(self,
source,
dest,
properties={},
type_=DEFAULT_RELATION):
"""
Create relation (connection) of type `type_` from source to dest with
given args.
"""
@abc.abstractmethod
def get(self, name, collection=DEFAULT_COLLECTION):
"""Fetch element with given name and collection type."""
@abc.abstractmethod
def get_or_create(self,
name,
properties={},
collection=DEFAULT_COLLECTION):
"""
Fetch or create element (if not exists) with given name, args of type
`collection`.
"""
@abc.abstractmethod
def delete_relations(self,
source=None,
dest=None,
type_=DEFAULT_RELATION):
"""Delete all relations of type `type_` from source to dest."""
@abc.abstractmethod
def get_relations(self,
source=None,
dest=None,
type_=DEFAULT_RELATION):
"""Fetch all relations of type `type_` from source to dest.
NOTE that this function must return only direct relations (edges)
between vertices `source` and `dest` of type `type_`.
If you want all PATHS between `source` and `dest`, write another
method for this (`get_paths`)."""
@abc.abstractmethod
def get_relation(self, source, dest, type_=DEFAULT_RELATION):
"""Fetch relations with given source, dest and type_."""
@abc.abstractmethod
def get_or_create_relation(self,
source,
dest,
properties={},
type_=DEFAULT_RELATION):
"""Fetch or create relation with given args."""

View File

@ -0,0 +1,205 @@
import json
from copy import deepcopy
import py2neo
from solar.core import log
from .base import BaseGraphDB, Node, Relation
class Neo4jDB(BaseGraphDB):
DB = {
'host': 'localhost',
'port': 7474,
}
NEO4J_CLIENT = py2neo.Graph
def __init__(self):
self._r = self.NEO4J_CLIENT('http://{host}:{port}/db/data/'.format(
**self.DB
))
def node_db_to_object(self, node_db):
return Node(
self,
node_db.properties['name'],
node_db.labels,
# Neo4j Node.properties is some strange PropertySet, use dict instead
dict(**node_db.properties)
)
def relation_db_to_object(self, relation_db):
return Relation(
self,
self.node_db_to_object(relation_db.start_node),
self.node_db_to_object(relation_db.end_node),
relation_db.properties
)
def all(self, collection=BaseGraphDB.DEFAULT_COLLECTION):
return [
r.n for r in self._r.cypher.execute(
'MATCH (n:%(collection)s) RETURN n' % {
'collection': collection.name,
}
)
]
def all_relations(self, type_=BaseGraphDB.DEFAULT_RELATION):
return [
r.r for r in self._r.cypher.execute(
*self._relations_query(
source=None, dest=None, type_=type_
)
)
]
def clear(self):
log.log.debug('Clearing whole DB')
self._r.delete_all()
def clear_collection(self, collection=BaseGraphDB.DEFAULT_COLLECTION):
log.log.debug('Clearing collection %s', collection.name)
# TODO: make single DELETE query
self._r.delete([r.n for r in self.all(collection=collection)])
def create(self, name, properties={}, collection=BaseGraphDB.DEFAULT_COLLECTION):
log.log.debug(
'Creating %s, name %s with properties %s',
collection.name,
name,
properties
)
properties = deepcopy(properties)
properties['name'] = name
n = py2neo.Node(collection.name, **properties)
return self._r.create(n)[0]
def create_relation(self,
source,
dest,
properties={},
type_=BaseGraphDB.DEFAULT_RELATION):
log.log.debug(
'Creating %s from %s to %s with properties %s',
type_.name,
source.properties['name'],
dest.properties['name'],
properties
)
s = self.get(
source.properties['name'],
collection=source.collection,
db_convert=False
)
d = self.get(
dest.properties['name'],
collection=dest.collection,
db_convert=False
)
r = py2neo.Relationship(s, type_.name, d, **properties)
self._r.create(r)
return r
def _get_query(self, name, collection=BaseGraphDB.DEFAULT_COLLECTION):
return 'MATCH (n:%(collection)s {name:{name}}) RETURN n' % {
'collection': collection.name,
}, {
'name': name,
}
def get(self, name, collection=BaseGraphDB.DEFAULT_COLLECTION):
query, kwargs = self._get_query(name, collection=collection)
res = self._r.cypher.execute(query, kwargs)
if res:
return res[0].n
def get_or_create(self,
name,
properties={},
collection=BaseGraphDB.DEFAULT_COLLECTION):
n = self.get(name, collection=collection, db_convert=False)
if n:
if properties != n.properties:
n.properties.update(properties)
n.push()
return n
return self.create(name, properties=properties, collection=collection)
def _relations_query(self,
source=None,
dest=None,
type_=BaseGraphDB.DEFAULT_RELATION,
query_type='RETURN'):
kwargs = {}
source_query = '(n)'
if source:
source_query = '(n {name:{source_name}})'
kwargs['source_name'] = source.properties['name']
dest_query = '(m)'
if dest:
dest_query = '(m {name:{dest_name}})'
kwargs['dest_name'] = dest.properties['name']
rel_query = '[r:%(type_)s]' % {'type_': type_.name}
query = ('MATCH %(source_query)s-%(rel_query)s->'
'%(dest_query)s %(query_type)s r' % {
'dest_query': dest_query,
'query_type': query_type,
'rel_query': rel_query,
'source_query': source_query,
})
return query, kwargs
def delete_relations(self,
source=None,
dest=None,
type_=BaseGraphDB.DEFAULT_RELATION):
query, kwargs = self._relations_query(
source=source, dest=dest, type_=type_, query_type='DELETE'
)
self._r.cypher.execute(query, kwargs)
def get_relations(self,
source=None,
dest=None,
type_=BaseGraphDB.DEFAULT_RELATION):
query, kwargs = self._relations_query(
source=source, dest=dest, type_=type_
)
res = self._r.cypher.execute(query, kwargs)
return [r.r for r in res]
def get_relation(self, source, dest, type_=BaseGraphDB.DEFAULT_RELATION):
rel = self.get_relations(source=source, dest=dest, type_=type_)
if rel:
return rel[0]
def get_or_create_relation(self,
source,
dest,
properties={},
type_=BaseGraphDB.DEFAULT_RELATION):
rel = self.get_relations(source=source, dest=dest, type_=type_)
if rel:
r = rel[0]
if properties != r.properties:
r.properties.update(properties)
r.push()
return r
return self.create_relation(source, dest, properties=properties, type_=type_)

View File

@ -77,8 +77,8 @@ class RedisDB(object):
def clear(self):
self._r.flushdb()
def get_set(self, collection):
return OrderedSet(self._r, collection)
def get_ordered_hash(self, collection):
return OrderedHash(self._r, collection)
def clear_collection(self, collection=COLLECTIONS.resource):
key_glob = self._make_key(collection, '*')
@ -96,7 +96,7 @@ class RedisDB(object):
return '{0}:{1}'.format(collection, _id)
class OrderedSet(object):
class OrderedHash(object):
def __init__(self, client, collection):
self.r = client

View File

@ -0,0 +1,266 @@
import json
import redis
import fakeredis
from solar import utils
from solar import errors
from .base import BaseGraphDB, Node, Relation
from .redis_db import OrderedHash
class RedisGraphDB(BaseGraphDB):
DB = {
'host': 'localhost',
'port': 6379,
}
REDIS_CLIENT = redis.StrictRedis
def __init__(self):
self._r = self.REDIS_CLIENT(**self.DB)
self.entities = {}
def node_db_to_object(self, node_db):
if isinstance(node_db, Node):
return node_db
return Node(
self,
node_db['name'],
[node_db['collection']],
node_db['properties']
)
def relation_db_to_object(self, relation_db):
if isinstance(relation_db, Relation):
return relation_db
if relation_db['type_'] == BaseGraphDB.RELATION_TYPES.input_to_input.name:
source_collection = BaseGraphDB.COLLECTIONS.input
dest_collection = BaseGraphDB.COLLECTIONS.input
elif relation_db['type_'] == BaseGraphDB.RELATION_TYPES.resource_input.name:
source_collection = BaseGraphDB.COLLECTIONS.resource
dest_collection = BaseGraphDB.COLLECTIONS.input
source = self.get(relation_db['source'], collection=source_collection)
dest = self.get(relation_db['dest'], collection=dest_collection)
return Relation(
self,
source,
dest,
relation_db['properties']
)
def all(self, collection=BaseGraphDB.DEFAULT_COLLECTION):
"""Return all elements (nodes) of type `collection`."""
key_glob = self._make_collection_key(collection, '*')
for result in self._all(key_glob):
yield result
def all_relations(self, type_=BaseGraphDB.DEFAULT_RELATION):
"""Return all relations of type `type_`."""
key_glob = self._make_relation_key(type_, '*')
for result in self._all(key_glob):
yield result
def _all(self, key_glob):
keys = self._r.keys(key_glob)
with self._r.pipeline() as pipe:
pipe.multi()
values = [self._r.get(key) for key in keys]
pipe.execute()
for value in values:
yield json.loads(value)
def clear(self):
"""Clear the whole DB."""
self._r.flushdb()
def clear_collection(self, collection=BaseGraphDB.DEFAULT_COLLECTION):
"""Clear all elements (nodes) of type `collection`."""
key_glob = self._make_collection_key(collection, '*')
self._r.delete(self._r.keys(key_glob))
def create(self, name, properties={}, collection=BaseGraphDB.DEFAULT_COLLECTION):
"""Create element (node) with given name, properties, of type `collection`."""
if isinstance(collection, self.COLLECTIONS):
collection = collection.name
properties = {
'name': name,
'properties': properties,
'collection': collection,
}
self._r.set(
self._make_collection_key(collection, name),
json.dumps(properties)
)
return properties
def create_relation(self,
source,
dest,
properties={},
type_=BaseGraphDB.DEFAULT_RELATION):
"""
Create relation (connection) of type `type_` from source to dest with
given properties.
"""
return self.create_relation_str(
source.uid, dest.uid, properties, type_=type_)
def create_relation_str(self, source, dest,
properties={}, type_=BaseGraphDB.DEFAULT_RELATION):
if isinstance(type_, self.RELATION_TYPES):
type_ = type_.name
uid = self._make_relation_uid(source, dest)
properties = {
'source': source,
'dest': dest,
'properties': properties,
'type_': type_,
}
self._r.set(
self._make_relation_key(type_, uid),
json.dumps(properties)
)
return properties
def get(self, name, collection=BaseGraphDB.DEFAULT_COLLECTION,
return_empty=False):
"""Fetch element with given name and collection type."""
try:
item = self._r.get(self._make_collection_key(collection, name))
if not item and return_empty:
return item
return json.loads(item)
except TypeError:
raise KeyError
def get_or_create(self,
name,
properties={},
collection=BaseGraphDB.DEFAULT_COLLECTION):
"""
Fetch or create element (if not exists) with given name, properties of
type `collection`.
"""
try:
return self.get(name, collection=collection)
except KeyError:
return self.create(name, properties=properties, collection=collection)
def _relations_glob(self,
source=None,
dest=None,
type_=BaseGraphDB.DEFAULT_RELATION):
if source is None:
source = '*'
else:
source = source.uid
if dest is None:
dest = '*'
else:
dest = dest.uid
return self._make_relation_key(type_, self._make_relation_uid(source, dest))
def delete_relations(self,
source=None,
dest=None,
type_=BaseGraphDB.DEFAULT_RELATION):
"""Delete all relations of type `type_` from source to dest."""
glob = self._relations_glob(source=source, dest=dest, type_=type_)
keys = self._r.keys(glob)
if keys:
self._r.delete(*keys)
def get_relations(self,
source=None,
dest=None,
type_=BaseGraphDB.DEFAULT_RELATION):
"""Fetch all relations of type `type_` from source to dest."""
glob = self._relations_glob(source=source, dest=dest, type_=type_)
for r in self._all(glob):
# Glob is primitive, we must filter stuff correctly here
if source and r['source'] != source.uid:
continue
if dest and r['dest'] != dest.uid:
continue
yield r
def get_relation(self, source, dest, type_=BaseGraphDB.DEFAULT_RELATION):
"""Fetch relations with given source, dest and type_."""
uid = self._make_relation_key(source.uid, dest.uid)
try:
return json.loads(
self._r.get(self._make_relation_key(type_, uid))
)
except TypeError:
raise KeyError
def get_or_create_relation(self,
source,
dest,
properties={},
type_=BaseGraphDB.DEFAULT_RELATION):
"""Fetch or create relation with given properties."""
try:
return self.get_relation(source, dest, type_=type_)
except KeyError:
return self.create_relation(source, dest, properties=properties, type_=type_)
def _make_collection_key(self, collection, _id):
if isinstance(collection, self.COLLECTIONS):
collection = collection.name
# NOTE: hiera-redis backend depends on this!
return '{0}:{1}'.format(collection, _id)
def _make_relation_uid(self, source, dest):
"""
There can be only one relation from source to dest, that's why
this function works.
"""
return '{0}-{1}'.format(source, dest)
def _make_relation_key(self, type_, _id):
if isinstance(type_, self.RELATION_TYPES):
type_ = type_.name
# NOTE: hiera-redis backend depends on this!
return '{0}:{1}'.format(type_, _id)
def get_ordered_hash(self, collection):
return OrderedHash(self._r, collection)
class FakeRedisGraphDB(RedisGraphDB):
REDIS_CLIENT = fakeredis.FakeStrictRedis

View File

@ -0,0 +1,399 @@
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import inspect
import uuid
from solar import errors
from solar.core import validation
from solar.interfaces.db import base
from solar.interfaces.db import get_db
db = get_db()
class DBField(object):
is_primary = False
schema = None
schema_in_field = None
default_value = None
def __init__(self, name, value=None):
self.name = name
self.value = value
if value is None:
self.value = self.default_value
def __eq__(self, inst):
return self.name == inst.name and self.value == inst.value
def __ne__(self, inst):
return not self.__eq__(inst)
def __hash__(self):
return hash('{}:{}'.format(self.name, self.value))
def validate(self):
if self.schema is None:
return
es = validation.validate_input(self.value, schema=self.schema)
if es:
raise errors.ValidationError('"{}": {}'.format(self.name, es[0]))
def db_field(schema=None,
schema_in_field=None,
default_value=None,
is_primary=False):
"""Definition for the DB field.
schema - simple schema according to the one in solar.core.validation
schema_in_field - if you don't want to fix schema, you can specify
another field in DBObject that will represent the schema used
for validation of this field
is_primary - only one field in db object can be primary. This key is used
for creating key in the DB
"""
class DBFieldX(DBField):
pass
DBFieldX.is_primary = is_primary
DBFieldX.schema = schema
DBFieldX.schema_in_field = schema_in_field
if default_value is not None:
DBFieldX.default_value = default_value
return DBFieldX
class DBRelatedField(object):
source_db_class = None
destination_db_class = None
relation_type = None
def __init__(self, name, source_db_object):
self.name = name
self.source_db_object = source_db_object
def add(self, *destination_db_objects):
for dest in destination_db_objects:
if not isinstance(dest, self.destination_db_class):
raise errors.SolarError(
'Object {} is of incompatible type {}.'.format(
dest, self.destination_db_class
)
)
db.get_or_create_relation(
self.source_db_object._db_node,
dest._db_node,
properties={},
type_=self.relation_type
)
def remove(self, *destination_db_objects):
for dest in destination_db_objects:
db.delete_relations(
source=self.source_db_object._db_node,
dest=dest._db_node,
type_=self.relation_type
)
@property
def value(self):
"""
Return DB objects that are destinations for self.source_db_object.
"""
source_db_node = self.source_db_object._db_node
if source_db_node is None:
return set()
relations = db.get_relations(source=source_db_node,
type_=self.relation_type)
ret = set()
for rel in relations:
ret.add(
self.destination_db_class(**rel.end_node.properties)
)
return ret
def sources(self, destination_db_object):
"""
Reverse of self.value, i.e. for given destination_db_object,
return source DB objects.
"""
destination_db_node = destination_db_object._db_node
if destination_db_node is None:
return set()
relations = db.get_relations(dest=destination_db_node,
type_=self.relation_type)
ret = set()
for rel in relations:
ret.add(
self.source_db_class(**rel.start_node.properties)
)
return ret
def delete_all_incoming(self, destination_db_object):
"""
Delete all relations for which destination_db_object is an end node.
"""
db.delete_relations(
dest=destination_db_object._db_node,
type_=self.relation_type
)
def db_related_field(relation_type, destination_db_class):
class DBRelatedFieldX(DBRelatedField):
pass
DBRelatedFieldX.relation_type = relation_type
DBRelatedFieldX.destination_db_class = destination_db_class
return DBRelatedFieldX
class DBObjectMeta(type):
def __new__(cls, name, parents, dct):
collection = dct.get('_collection')
if not collection:
raise NotImplementedError('Collection is required.')
dct['_meta'] = {}
dct['_meta']['fields'] = {}
dct['_meta']['related_to'] = {}
has_primary = False
for field_name, field_klass in dct.items():
if not inspect.isclass(field_klass):
continue
if issubclass(field_klass, DBField):
dct['_meta']['fields'][field_name] = field_klass
if field_klass.is_primary:
if has_primary:
raise errors.SolarError('Object cannot have 2 primary fields.')
has_primary = True
dct['_meta']['primary'] = field_name
elif issubclass(field_klass, DBRelatedField):
dct['_meta']['related_to'][field_name] = field_klass
if not has_primary:
raise errors.SolarError('Object needs to have a primary field.')
klass = super(DBObjectMeta, cls).__new__(cls, name, parents, dct)
# Support for self-references in relations
for field_name, field_klass in klass._meta['related_to'].items():
field_klass.source_db_class = klass
if field_klass.destination_db_class == klass.__name__:
field_klass.destination_db_class = klass
return klass
class DBObject(object):
# Enum from BaseGraphDB.COLLECTIONS
_collection = None
def __init__(self, **kwargs):
wrong_fields = set(kwargs) - set(self._meta['fields'])
if wrong_fields:
raise errors.SolarError(
'Unknown fields {}'.format(wrong_fields)
)
self._fields = {}
for field_name, field_klass in self._meta['fields'].items():
value = kwargs.get(field_name, field_klass.default_value)
self._fields[field_name] = field_klass(field_name, value=value)
self._related_to = {}
for field_name, field_klass in self._meta['related_to'].items():
inst = field_klass(field_name, self)
self._related_to[field_name] = inst
self._update_values()
def __eq__(self, inst):
# NOTE: don't compare related fields
self._update_fields_values()
return self._fields == inst._fields
def __ne__(self, inst):
return not self.__eq__(inst)
def __hash__(self):
return hash(self._db_key)
def _update_fields_values(self):
"""Copy values from self to self._fields."""
for field in self._fields.values():
field.value = getattr(self, field.name)
def _update_values(self):
"""
Reverse of _update_fields_values, i.e. copy values from self._fields to
self."""
for field in self._fields.values():
setattr(self, field.name, field.value)
for field in self._related_to.values():
setattr(self, field.name, field)
@property
def _db_key(self):
"""Key for the DB document (in KV-store).
You can overwrite this with custom keys."""
if not self._primary_field.value:
setattr(self, self._primary_field.name, unicode(uuid.uuid4()))
self._update_fields_values()
return self._primary_field.value
@property
def _primary_field(self):
return self._fields[self._meta['primary']]
@property
def _db_node(self):
try:
return db.get(self._db_key, collection=self._collection)
except KeyError:
return
def validate(self):
self._update_fields_values()
for field in self._fields.values():
if field.schema_in_field is not None:
field.schema = self._fields[field.schema_in_field].value
field.validate()
def to_dict(self):
self._update_fields_values()
return {
f.name: f.value for f in self._fields.values()
}
@classmethod
def load(cls, key):
r = db.get(key, collection=cls._collection)
return cls(**r.properties)
@classmethod
def load_all(cls):
rs = db.all(collection=cls._collection)
return [cls(**r.properties) for r in rs]
def save(self):
db.create(
self._db_key,
properties=self.to_dict(),
collection=self._collection
)
class DBResourceInput(DBObject):
__metaclass__ = DBObjectMeta
_collection = base.BaseGraphDB.COLLECTIONS.input
id = db_field(schema='str!', is_primary=True)
name = db_field(schema='str!')
schema = db_field()
value = db_field(schema_in_field='schema')
is_list = db_field(schema='bool')
receivers = db_related_field(base.BaseGraphDB.RELATION_TYPES.input_to_input,
'DBResourceInput')
def backtrack_value(self):
# TODO: this is actually just fetching head element in linked list
# so this whole algorithm can be moved to the db backend probably
# TODO: cycle detection?
# TODO: write this as a Cypher query? Move to DB?
inputs = self.receivers.sources(self)
if not inputs:
return self.value
if self.is_list:
return [i.backtrack_value() for i in inputs]
return inputs.pop().backtrack_value()
class DBResource(DBObject):
__metaclass__ = DBObjectMeta
_collection = base.BaseGraphDB.COLLECTIONS.resource
id = db_field(schema='str', is_primary=True)
name = db_field(schema='str!')
actions_path = db_field(schema='str')
base_name = db_field(schema='str')
base_path = db_field(schema='str')
handler = db_field(schema='str') # one of: {'ansible_playbook', 'ansible_template', 'puppet', etc}
puppet_module = db_field(schema='str')
version = db_field(schema='str')
tags = db_field(schema=[], default_value=[])
meta_inputs = db_field(schema={}, default_value={})
inputs = db_related_field(base.BaseGraphDB.RELATION_TYPES.resource_input,
DBResourceInput)
def add_input(self, name, schema, value):
# NOTE: Inputs need to have uuid added because there can be many
# inputs with the same name
uid = '{}-{}'.format(name, uuid.uuid4())
input = DBResourceInput(id=uid,
name=name,
schema=schema,
value=value,
is_list=isinstance(schema, list))
input.save()
self.inputs.add(input)
# TODO: remove this
if __name__ == '__main__':
r = DBResource(name=1)
r.validate()

View File

@ -12,33 +12,42 @@
# License for the specific language governing permissions and limitations
# under the License.
import json
import uuid
import networkx as nx
import redis
from solar import utils
from .traversal import states
r = redis.StrictRedis(host='10.0.0.2', port=6379, db=1)
from solar.interfaces.db import get_db
db = get_db()
def save_graph(name, graph):
# maybe it is possible to store part of information in AsyncResult backend
r.set('{}:nodes'.format(name), json.dumps(graph.node.items()))
r.set('{}:edges'.format(name), json.dumps(graph.edges(data=True)))
r.set('{}:attributes'.format(name), json.dumps(graph.graph))
uid = graph.graph['uid']
db.create(uid, graph.graph, db.COLLECTIONS.plan_graph)
for n in graph:
collection = db.COLLECTIONS.plan_node.name + ':' + uid
db.create(n, properties=graph.node[n], collection=collection)
db.create_relation_str(uid, n, type_=db.RELATION_TYPES.graph_to_node)
for u, v, properties in graph.edges(data=True):
type_ = db.RELATION_TYPES.plan_edge.name + ':' + uid
db.create_relation_str(u, v, properties, type_=type_)
def get_graph(name):
dg = nx.OrderedMultiDiGraph()
nodes = json.loads(r.get('{}:nodes'.format(name)))
edges = json.loads(r.get('{}:edges'.format(name)))
dg.graph = json.loads(r.get('{}:attributes'.format(name)))
dg.add_nodes_from(nodes)
dg.add_edges_from(edges)
def get_graph(uid):
dg = nx.MultiDiGraph()
collection = db.COLLECTIONS.plan_node.name + ':' + uid
type_ = db.RELATION_TYPES.plan_edge.name + ':' + uid
dg.graph = db.get(uid, collection=db.COLLECTIONS.plan_graph).properties
dg.add_nodes_from([(n.uid, n.properties) for n in db.all(collection=collection)])
dg.add_edges_from([(i['source'], i['dest'], i['properties'])
for i in db.all_relations(type_=type_, db_convert=False)])
return dg

View File

@ -42,17 +42,9 @@ def create_diff(staged, commited):
return list(diff(commited, staged))
def _stage_changes(staged_resources, conn_graph,
commited_resources, staged_log):
def _stage_changes(staged_resources, commited_resources, staged_log):
try:
srt = nx.topological_sort(conn_graph)
except:
for cycle in nx.simple_cycles(conn_graph):
log.debug('CYCLE: %s', cycle)
raise
for res_uid in srt:
for res_uid in staged_resources.keys():
commited_data = commited_resources.get(res_uid, {})
staged_data = staged_resources.get(res_uid, {})
@ -72,17 +64,14 @@ def _stage_changes(staged_resources, conn_graph,
def stage_changes():
log = data.SL()
log.clean()
conn_graph = signals.detailed_connection_graph()
staged = {r.name: r.args_show()
for r in resource.load_all().values()}
staged = {r.name: r.args for r in resource.load_all()}
commited = data.CD()
return _stage_changes(staged, conn_graph, commited, log)
return _stage_changes(staged, commited, log)
def send_to_orchestration():
dg = nx.MultiDiGraph()
staged = {r.name: r.args_show()
for r in resource.load_all().values()}
staged = {r.name: r.args for r in resource.load_all()}
commited = data.CD()
events = {}
changed_nodes = []

View File

@ -83,8 +83,8 @@ def details(diff):
rst = []
for type_, val, change in diff:
if type_ == 'add':
for it in change:
rst.append('++ {}: {}'.format(it[0], unwrap_add(it[1])))
for key, val in change:
rst.append('++ {}: {}'.format(key ,val))
elif type_ == 'change':
rst.append('-+ {}: {} >> {}'.format(
unwrap_change_val(val), change[0], change[1]))
@ -112,7 +112,7 @@ def unwrap_change_val(val):
class Log(object):
def __init__(self, path):
self.ordered_log = db.get_set(path)
self.ordered_log = db.get_ordered_hash(path)
def append(self, logitem):
self.ordered_log.add([(logitem.log_action, logitem.to_dict())])
@ -152,21 +152,24 @@ class Data(collections.MutableMapping):
def __init__(self, path):
self.path = path
self.store = {}
r = db.read(path, collection=db.COLLECTIONS.state_data)
r = db.get(path, collection=db.COLLECTIONS.state_data,
return_empty=True, db_convert=False)
if r:
self.store = r or self.store
self.store = r.get('properties', {})
else:
self.store = {}
def __getitem__(self, key):
return self.store[key]
def __setitem__(self, key, value):
self.store[key] = value
db.save(self.path, self.store, collection=db.COLLECTIONS.state_data)
db.create(self.path, self.store, collection=db.COLLECTIONS.state_data)
def __delitem__(self, key):
self.store.pop(key)
db.save(self.path, self.store, collection=db.COLLECTIONS.state_data)
db.create(self.path, self.store, collection=db.COLLECTIONS.state_data)
def __iter__(self):
return iter(self.store)
@ -175,4 +178,4 @@ class Data(collections.MutableMapping):
return len(self.store)
def clean(self):
db.save(self.path, {}, collection=db.COLLECTIONS.state_data)
db.create(self.path, {}, collection=db.COLLECTIONS.state_data)

View File

@ -19,7 +19,6 @@ import unittest
import yaml
from solar.core.resource import virtual_resource as vr
from solar.core import signals as xs
from solar.interfaces.db import get_db
db = get_db()
@ -32,7 +31,6 @@ class BaseResourceTest(unittest.TestCase):
def tearDown(self):
shutil.rmtree(self.storage_dir)
db.clear()
xs.Connections.clear()
def make_resource_meta(self, meta_yaml):
meta = yaml.load(meta_yaml)
@ -44,5 +42,5 @@ class BaseResourceTest(unittest.TestCase):
return path
def create_resource(self, name, src, args):
return vr.create(name, src, args)[0]
def create_resource(self, name, src, args={}):
return vr.create(name, src, args=args)[0]

View File

@ -19,7 +19,12 @@ from solar.interfaces import db
def pytest_configure():
db.DB = db.mapping['fakeredis_db']()
if db.CURRENT_DB == 'redis_graph_db':
db.DB = db.get_db(backend='fakeredis_graph_db')
elif db.CURRENT_DB == 'redis_db':
db.DB = db.get_db(backend='fakeredis_db')
else:
db.DB = db.get_db(backend=db.CURRENT_DB)
@fixture(autouse=True)
@ -29,6 +34,5 @@ def cleanup(request):
from solar.core import signals
db.get_db().clear()
signals.Connections.clear()
request.addfinalizer(fin)

View File

@ -4,6 +4,7 @@ tasks:
parameters:
type: echo
args: [10]
before: [just_fail]
- uid: just_fail
parameters:
type: error

View File

@ -97,25 +97,10 @@ def resources():
'tags': []}}
return r
@fixture
def conn_graph():
edges = [
('n.1', 'r.1', {'label': 'ip:ip'}),
('n.1', 'h.1', {'label': 'ip:ip'}),
('r.1', 'h.1', {'label': 'ip:ips'})
]
mdg = nx.MultiDiGraph()
mdg.add_edges_from(edges)
return mdg
def test_stage_changes(resources, conn_graph):
def test_stage_changes(resources):
commited = {}
log = change._stage_changes(resources, conn_graph, commited, [])
log = change._stage_changes(resources, commited, [])
assert len(log) == 3
assert [l.res for l in log] == ['n.1', 'r.1', 'h.1']
def test_resource_fixture(staged):
res = wrap_resource(staged)
assert {l.res for l in log} == {'n.1', 'r.1', 'h.1'}

View File

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
@ -17,6 +18,32 @@ from pytest import fixture
from solar.events import api as evapi
from .base import BaseResourceTest
class EventAPITest(BaseResourceTest):
def test_events_load(self):
sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
value:
schema: int
value: 0
""")
sample1 = self.create_resource(
'sample1', sample_meta_dir, {'value': 1}
)
sample2 = self.create_resource(
'sample2', sample_meta_dir, {'value': 1}
)
evapi.Dep('sample1', 'run', 'success', 'sample2', 'run'),
loaded = evapi.all_events(sample1)
@fixture
def nova_deps():

View File

@ -0,0 +1,230 @@
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from .base import BaseResourceTest
from solar.core import resource
from solar import errors
from solar.interfaces import orm
from solar.interfaces.db import base
class TestORM(BaseResourceTest):
def test_no_collection_defined(self):
with self.assertRaisesRegexp(NotImplementedError, 'Collection is required.'):
class TestDBObject(orm.DBObject):
__metaclass__ = orm.DBObjectMeta
def test_has_primary(self):
with self.assertRaisesRegexp(errors.SolarError, 'Object needs to have a primary field.'):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
test1 = orm.db_field(schema='str')
def test_no_multiple_primaries(self):
with self.assertRaisesRegexp(errors.SolarError, 'Object cannot have 2 primary fields.'):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
test1 = orm.db_field(schema='str', is_primary=True)
test2 = orm.db_field(schema='str', is_primary=True)
def test_primary_field(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
test1 = orm.db_field(schema='str', is_primary=True)
t = TestDBObject(test1='abc')
self.assertEqual('test1', t._primary_field.name)
self.assertEqual('abc', t._db_key)
t = TestDBObject()
self.assertIsNotNone(t._db_key)
self.assertIsNotNone(t.test1)
def test_default_value(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
test1 = orm.db_field(schema='str',
is_primary=True,
default_value='1')
t = TestDBObject()
self.assertEqual('1', t.test1)
def test_field_validation(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
t = TestDBObject(id=1)
with self.assertRaises(errors.ValidationError):
t.validate()
t = TestDBObject(id='1')
t.validate()
def test_dynamic_schema_validation(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
schema = orm.db_field()
value = orm.db_field(schema_in_field='schema')
t = TestDBObject(id='1', schema='str', value=1)
with self.assertRaises(errors.ValidationError):
t.validate()
self.assertEqual(t._fields['value'].schema, t._fields['schema'].value)
t = TestDBObject(id='1', schema='int', value=1)
t.validate()
self.assertEqual(t._fields['value'].schema, t._fields['schema'].value)
def test_unknown_fields(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
with self.assertRaisesRegexp(errors.SolarError, 'Unknown fields .*iid'):
TestDBObject(iid=1)
def test_equality(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
test = orm.db_field(schema='str')
t1 = TestDBObject(id='1', test='test')
t2 = TestDBObject(id='2', test='test')
self.assertNotEqual(t1, t2)
t2 = TestDBObject(id='1', test='test2')
self.assertNotEqual(t1, t2)
t2 = TestDBObject(id='1', test='test')
self.assertEqual(t1, t2)
class TestORMRelation(BaseResourceTest):
def test_children_value(self):
class TestDBRelatedObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.input
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
related = orm.db_related_field(
base.BaseGraphDB.RELATION_TYPES.resource_input,
TestDBRelatedObject
)
r1 = TestDBRelatedObject(id='1')
r1.save()
r2 = TestDBRelatedObject(id='2')
r2.save()
o = TestDBObject(id='a')
o.save()
self.assertSetEqual(o.related.value, set())
o.related.add(r1)
self.assertSetEqual(o.related.value, {r1})
o.related.add(r2)
self.assertSetEqual(o.related.value, {r1, r2})
o.related.remove(r2)
self.assertSetEqual(o.related.value, {r1})
o.related.add(r2)
self.assertSetEqual(o.related.value, {r1, r2})
o.related.remove(r1, r2)
self.assertSetEqual(o.related.value, set())
o.related.add(r1, r2)
self.assertSetEqual(o.related.value, {r1, r2})
with self.assertRaisesRegexp(errors.SolarError, '.*incompatible type.*'):
o.related.add(o)
def test_relation_to_self(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.input
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
related = orm.db_related_field(
base.BaseGraphDB.RELATION_TYPES.input_to_input,
'TestDBObject'
)
o1 = TestDBObject(id='1')
o1.save()
o2 = TestDBObject(id='2')
o2.save()
o3 = TestDBObject(id='2')
o3.save()
o1.related.add(o2)
o2.related.add(o3)
self.assertEqual(o1.related.value, {o2})
self.assertEqual(o2.related.value, {o3})
class TestResourceORM(BaseResourceTest):
def test_save(self):
r = orm.DBResource(id='test1', name='test1', base_path='x')
r.save()
rr = resource.load(r.id)
self.assertEqual(r, rr.db_obj)
def test_add_input(self):
r = orm.DBResource(id='test1', name='test1', base_path='x')
r.save()
r.add_input('ip', 'str!', '10.0.0.2')
self.assertEqual(len(r.inputs.value), 1)

View File

@ -12,8 +12,6 @@
# License for the specific language governing permissions and limitations
# under the License.
import unittest
import base
from solar.core import resource
@ -35,11 +33,11 @@ input:
sample1 = self.create_resource(
'sample1', sample_meta_dir, {'value': 1}
)
self.assertEqual(sample1.args['value'].value, 1)
self.assertEqual(sample1.args['value'], 1)
# test default value
sample2 = self.create_resource('sample2', sample_meta_dir, {})
self.assertEqual(sample2.args['value'].value, 0)
self.assertEqual(sample2.args['value'], 0)
def test_connections_recreated_after_load(self):
"""
@ -76,6 +74,22 @@ input:
sample1.update({'value': 2})
self.assertEqual(sample1.args['value'], sample2.args['value'])
def test_load(self):
sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
value:
schema: int
value: 0
""")
if __name__ == '__main__':
unittest.main()
sample = self.create_resource(
'sample', sample_meta_dir, {'value': 1}
)
sample_l = resource.load('sample')
self.assertDictEqual(sample.args, sample_l.args)
self.assertListEqual(sample.tags, sample_l.tags)

View File

@ -12,14 +12,60 @@
# License for the specific language governing permissions and limitations
# under the License.
import unittest
import base
from solar.core import signals as xs
class TestBaseInput(base.BaseResourceTest):
def test_no_self_connection(self):
sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
value:
schema: str!
value:
""")
sample = self.create_resource(
'sample', sample_meta_dir, {'value': 'x'}
)
with self.assertRaisesRegexp(
Exception,
'Trying to connect value-.* to itself'):
xs.connect(sample, sample, {'value'})
def test_no_cycles(self):
sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
value:
schema: str!
value:
""")
sample1 = self.create_resource(
'sample1', sample_meta_dir, {'value': 'x'}
)
sample2 = self.create_resource(
'sample2', sample_meta_dir, {'value': 'y'}
)
xs.connect(sample1, sample2)
with self.assertRaisesRegexp(
Exception,
'Prevented creating a cycle'):
xs.connect(sample2, sample1)
# TODO: more complex cycles
def test_input_dict_type(self):
sample_meta_dir = self.make_resource_meta("""
id: sample
@ -35,17 +81,13 @@ input:
'sample1', sample_meta_dir, {'values': {'a': 1, 'b': 2}}
)
sample2 = self.create_resource(
'sample2', sample_meta_dir, {'values': None}
'sample2', sample_meta_dir
)
xs.connect(sample1, sample2)
self.assertEqual(
sample1.args['values'],
sample2.args['values']
)
self.assertEqual(
sample2.args['values'].emitter,
sample1.args['values']
)
# Check update
sample1.update({'values': {'a': 2}})
@ -66,11 +108,11 @@ input:
sample1.args['values'],
{'a': 3}
)
self.assertEqual(
sample2.args['values'],
{'a': 2}
)
self.assertEqual(sample2.args['values'].emitter, None)
#self.assertEqual(
# sample2.args['values'],
# {'a': 2}
#)
#self.assertEqual(sample2.args['values'].emitter, None)
def test_multiple_resource_disjoint_connect(self):
sample_meta_dir = self.make_resource_meta("""
@ -79,7 +121,7 @@ handler: ansible
version: 1.0.0
input:
ip:
schema: string
schema: str
value:
port:
schema: int
@ -91,7 +133,7 @@ handler: ansible
version: 1.0.0
input:
ip:
schema: string
schema: str
value:
""")
sample_port_meta_dir = self.make_resource_meta("""
@ -111,20 +153,24 @@ input:
'sample-ip', sample_ip_meta_dir, {'ip': '10.0.0.1'}
)
sample_port = self.create_resource(
'sample-port', sample_port_meta_dir, {'port': '8000'}
'sample-port', sample_port_meta_dir, {'port': 8000}
)
self.assertNotEqual(
sample.resource_inputs()['ip'],
sample_ip.resource_inputs()['ip'],
)
xs.connect(sample_ip, sample)
xs.connect(sample_port, sample)
self.assertEqual(sample.args['ip'], sample_ip.args['ip'])
self.assertEqual(sample.args['port'], sample_port.args['port'])
self.assertEqual(
sample.args['ip'].emitter,
sample_ip.args['ip']
)
self.assertEqual(
sample.args['port'].emitter,
sample_port.args['port']
)
#self.assertEqual(
# sample.args['ip'].emitter,
# sample_ip.args['ip']
#)
#self.assertEqual(
# sample.args['port'].emitter,
# sample_port.args['port']
#)
def test_simple_observer_unsubscription(self):
sample_meta_dir = self.make_resource_meta("""
@ -133,7 +179,7 @@ handler: ansible
version: 1.0.0
input:
ip:
schema: string
schema: str
value:
""")
@ -149,17 +195,17 @@ input:
xs.connect(sample1, sample)
self.assertEqual(sample1.args['ip'], sample.args['ip'])
self.assertEqual(len(list(sample1.args['ip'].receivers)), 1)
self.assertEqual(
sample.args['ip'].emitter,
sample1.args['ip']
)
#self.assertEqual(len(list(sample1.args['ip'].receivers)), 1)
#self.assertEqual(
# sample.args['ip'].emitter,
# sample1.args['ip']
#)
xs.connect(sample2, sample)
self.assertEqual(sample2.args['ip'], sample.args['ip'])
# sample should be unsubscribed from sample1 and subscribed to sample2
self.assertEqual(len(list(sample1.args['ip'].receivers)), 0)
self.assertEqual(sample.args['ip'].emitter, sample2.args['ip'])
#self.assertEqual(len(list(sample1.args['ip'].receivers)), 0)
#self.assertEqual(sample.args['ip'].emitter, sample2.args['ip'])
sample2.update({'ip': '10.0.0.3'})
self.assertEqual(sample2.args['ip'], sample.args['ip'])
@ -220,35 +266,38 @@ input:
)
xs.connect(sample1, list_input_single, mapping={'ip': 'ips'})
self.assertEqual(
[ip['value'] for ip in list_input_single.args['ips'].value],
self.assertItemsEqual(
#[ip['value'] for ip in list_input_single.args['ips']],
list_input_single.args['ips'],
[
sample1.args['ip'],
]
)
self.assertListEqual(
[(e['emitter_attached_to'], e['emitter']) for e in list_input_single.args['ips'].value],
[(sample1.args['ip'].attached_to.name, 'ip')]
)
#self.assertListEqual(
# [(e['emitter_attached_to'], e['emitter']) for e in list_input_single.args['ips']],
# [(sample1.args['ip'].attached_to.name, 'ip')]
#)
xs.connect(sample2, list_input_single, mapping={'ip': 'ips'})
self.assertEqual(
[ip['value'] for ip in list_input_single.args['ips'].value],
self.assertItemsEqual(
#[ip['value'] for ip in list_input_single.args['ips']],
list_input_single.args['ips'],
[
sample1.args['ip'],
sample2.args['ip'],
]
)
self.assertListEqual(
[(e['emitter_attached_to'], e['emitter']) for e in list_input_single.args['ips'].value],
[(sample1.args['ip'].attached_to.name, 'ip'),
(sample2.args['ip'].attached_to.name, 'ip')]
)
#self.assertListEqual(
# [(e['emitter_attached_to'], e['emitter']) for e in list_input_single.args['ips']],
# [(sample1.args['ip'].attached_to.name, 'ip'),
# (sample2.args['ip'].attached_to.name, 'ip')]
#)
# Test update
sample2.update({'ip': '10.0.0.3'})
self.assertEqual(
[ip['value'] for ip in list_input_single.args['ips'].value],
self.assertItemsEqual(
#[ip['value'] for ip in list_input_single.args['ips']],
list_input_single.args['ips'],
[
sample1.args['ip'],
sample2.args['ip'],
@ -257,16 +306,17 @@ input:
# Test disconnect
xs.disconnect(sample2, list_input_single)
self.assertEqual(
[ip['value'] for ip in list_input_single.args['ips'].value],
self.assertItemsEqual(
#[ip['value'] for ip in list_input_single.args['ips']],
list_input_single.args['ips'],
[
sample1.args['ip'],
]
)
self.assertListEqual(
[(e['emitter_attached_to'], e['emitter']) for e in list_input_single.args['ips'].value],
[(sample1.args['ip'].attached_to.name, 'ip')]
)
#self.assertListEqual(
# [(e['emitter_attached_to'], e['emitter']) for e in list_input_single.args['ips']],
# [(sample1.args['ip'].attached_to.name, 'ip')]
#)
def test_list_input_multi(self):
sample_meta_dir = self.make_resource_meta("""
@ -295,59 +345,65 @@ input:
""")
sample1 = self.create_resource(
'sample1', sample_meta_dir, {'ip': '10.0.0.1', 'port': '1000'}
'sample1', sample_meta_dir, {'ip': '10.0.0.1', 'port': 1000}
)
sample2 = self.create_resource(
'sample2', sample_meta_dir, {'ip': '10.0.0.2', 'port': '1001'}
'sample2', sample_meta_dir, {'ip': '10.0.0.2', 'port': 1001}
)
list_input_multi = self.create_resource(
'list-input-multi', list_input_multi_meta_dir, {'ips': [], 'ports': []}
'list-input-multi', list_input_multi_meta_dir, args={'ips': [], 'ports': []}
)
xs.connect(sample1, list_input_multi, mapping={'ip': 'ips', 'port': 'ports'})
self.assertEqual(
[ip['value'] for ip in list_input_multi.args['ips'].value],
self.assertItemsEqual(
#[ip['value'] for ip in list_input_multi.args['ips']],
list_input_multi.args['ips'],
[sample1.args['ip']]
)
self.assertEqual(
[p['value'] for p in list_input_multi.args['ports'].value],
self.assertItemsEqual(
#[p['value'] for p in list_input_multi.args['ports']],
list_input_multi.args['ports'],
[sample1.args['port']]
)
xs.connect(sample2, list_input_multi, mapping={'ip': 'ips', 'port': 'ports'})
self.assertEqual(
[ip['value'] for ip in list_input_multi.args['ips'].value],
self.assertItemsEqual(
#[ip['value'] for ip in list_input_multi.args['ips']],
list_input_multi.args['ips'],
[
sample1.args['ip'],
sample2.args['ip'],
]
)
self.assertListEqual(
[(e['emitter_attached_to'], e['emitter']) for e in list_input_multi.args['ips'].value],
[(sample1.args['ip'].attached_to.name, 'ip'),
(sample2.args['ip'].attached_to.name, 'ip')]
)
self.assertEqual(
[p['value'] for p in list_input_multi.args['ports'].value],
#self.assertListEqual(
# [(e['emitter_attached_to'], e['emitter']) for e in list_input_multi.args['ips']],
# [(sample1.args['ip'].attached_to.name, 'ip'),
# (sample2.args['ip'].attached_to.name, 'ip')]
#)
self.assertItemsEqual(
#[p['value'] for p in list_input_multi.args['ports']],
list_input_multi.args['ports'],
[
sample1.args['port'],
sample2.args['port'],
]
)
self.assertListEqual(
[(e['emitter_attached_to'], e['emitter']) for e in list_input_multi.args['ports'].value],
[(sample1.args['port'].attached_to.name, 'port'),
(sample2.args['port'].attached_to.name, 'port')]
)
#self.assertListEqual(
# [(e['emitter_attached_to'], e['emitter']) for e in list_input_multi.args['ports']],
# [(sample1.args['port'].attached_to.name, 'port'),
# (sample2.args['port'].attached_to.name, 'port')]
#)
# Test disconnect
xs.disconnect(sample2, list_input_multi)
self.assertEqual(
[ip['value'] for ip in list_input_multi.args['ips'].value],
self.assertItemsEqual(
#[ip['value'] for ip in list_input_multi.args['ips']],
list_input_multi.args['ips'],
[sample1.args['ip']]
)
self.assertEqual(
[p['value'] for p in list_input_multi.args['ports'].value],
self.assertItemsEqual(
#[p['value'] for p in list_input_multi.args['ports']],
list_input_multi.args['ports'],
[sample1.args['port']]
)
@ -395,10 +451,10 @@ input:
""")
sample1 = self.create_resource(
'sample1', sample_meta_dir, {'ip': '10.0.0.1', 'port': '1000'}
'sample1', sample_meta_dir, {'ip': '10.0.0.1', 'port': 1000}
)
sample2 = self.create_resource(
'sample2', sample_meta_dir, {'ip': '10.0.0.2', 'port': '1001'}
'sample2', sample_meta_dir, {'ip': '10.0.0.2', 'port': 1001}
)
list_input = self.create_resource(
'list-input', list_input_meta_dir, {}
@ -411,23 +467,26 @@ input:
xs.connect(sample2, list_input, mapping={'ip': 'ips', 'port': 'ports'})
xs.connect(list_input, list_input_nested, mapping={'ips': 'ipss', 'ports': 'portss'})
self.assertListEqual(
[ips['value'] for ips in list_input_nested.args['ipss'].value],
[list_input.args['ips'].value]
#[ips['value'] for ips in list_input_nested.args['ipss']],
list_input_nested.args['ipss'],
[list_input.args['ips']]
)
self.assertListEqual(
[ps['value'] for ps in list_input_nested.args['portss'].value],
[list_input.args['ports'].value]
#[ps['value'] for ps in list_input_nested.args['portss']],
list_input_nested.args['portss'],
[list_input.args['ports']]
)
# Test disconnect
xs.disconnect(sample1, list_input)
self.assertListEqual(
[[ip['value'] for ip in ips['value']] for ips in list_input_nested.args['ipss'].value],
[[sample2.args['ip'].value]]
#[[ip['value'] for ip in ips['value']] for ips in list_input_nested.args['ipss']],
list_input_nested.args['ipss'],
[[sample2.args['ip']]]
)
self.assertListEqual(
[[p['value'] for p in ps['value']] for ps in list_input_nested.args['portss'].value],
[[sample2.args['port'].value]]
list_input_nested.args['portss'],
[[sample2.args['port']]]
)
@ -462,7 +521,3 @@ input:
receiver.args['server'],
)
'''
if __name__ == '__main__':
unittest.main()

View File

@ -21,21 +21,18 @@ from solar.system_log import data
def host_diff():
return [
[u'add', u'', [
[u'ip', {u'emitter': u'node1', u'value': u'10.0.0.3'}],
[u'hosts_names',
[{u'emitter_attached_to': u'riak_service1', u'emitter': u'riak_hostname', u'value': u'riak_server1.solar'},
{u'emitter_attached_to': u'riak_service2', u'emitter': u'riak_hostname', u'value': u'riak_server2.solar'},
{u'emitter_attached_to': u'riak_service3', u'emitter': u'riak_hostname', u'value': u'riak_server3.solar'}]],
[u'ssh_user', {u'emitter': None, u'value': u'vagrant'}],
[u'ssh_key', {u'emitter': u'node1', u'value': u'/vagrant/.vagrant/machines/solar-dev1/virtualbox/private_key'}],
[u'ip', u'10.0.0.3'],
[u'hosts_names', ['riak_server1.solar', 'riak_server2.solar', 'riak_server3.solar']],
[u'ssh_user', u'vagrant'],
[u'ssh_key', u'/vagrant/.vagrant/machines/solar-dev1/virtualbox/private_key'],
]]]
def test_details_for_add(host_diff):
assert data.details(host_diff) == [
'++ ip: node1::10.0.0.3',
"++ hosts_names: ['riak_hostname::riak_server1.solar', 'riak_hostname::riak_server2.solar', 'riak_hostname::riak_server3.solar']",
'++ ssh_user: vagrant', '++ ssh_key: node1::/vagrant/.vagrant/machines/solar-dev1/virtualbox/private_key']
'++ ip: 10.0.0.3',
"++ hosts_names: ['riak_server1.solar', 'riak_server2.solar', 'riak_server3.solar']",
'++ ssh_user: vagrant', '++ ssh_key: /vagrant/.vagrant/machines/solar-dev1/virtualbox/private_key']
@fixture

View File

@ -12,10 +12,9 @@
# License for the specific language governing permissions and limitations
# under the License.
import unittest
from solar.test import base
from solar import errors
from solar.core import validation as sv
@ -37,26 +36,26 @@ input:
r = self.create_resource(
'r1', sample_meta_dir, {'value': 'x', 'value-required': 'y'}
)
errors = sv.validate_resource(r)
self.assertEqual(errors, {})
es = sv.validate_resource(r)
self.assertEqual(es, {})
r = self.create_resource(
'r2', sample_meta_dir, {'value': 1, 'value-required': 'y'}
)
errors = sv.validate_resource(r)
self.assertListEqual(errors.keys(), ['value'])
es = sv.validate_resource(r)
self.assertIn('value', es)
self.assertIn('1 is not valid', es['value'][0])
r = self.create_resource(
'r3', sample_meta_dir, {'value': ''}
)
errors = sv.validate_resource(r)
self.assertListEqual(errors.keys(), ['value-required'])
es = sv.validate_resource(r)
self.assertIn('value-required', es)
self.assertIn("None is not of type 'string'", es['value-required'][0])
r = self.create_resource(
'r4', sample_meta_dir, {'value': None, 'value-required': 'y'}
)
errors = sv.validate_resource(r)
self.assertEqual(errors, {})
def test_input_int_type(self):
sample_meta_dir = self.make_resource_meta("""
@ -75,26 +74,26 @@ input:
r = self.create_resource(
'r1', sample_meta_dir, {'value': 1, 'value-required': 2}
)
errors = sv.validate_resource(r)
self.assertEqual(errors, {})
es = sv.validate_resource(r)
self.assertEqual(es, {})
r = self.create_resource(
'r2', sample_meta_dir, {'value': 'x', 'value-required': 2}
)
errors = sv.validate_resource(r)
self.assertListEqual(errors.keys(), ['value'])
es = sv.validate_resource(r)
self.assertIn('value', es)
self.assertIn("'x' is not valid", es['value'][0])
r = self.create_resource(
'r3', sample_meta_dir, {'value': 1}
)
errors = sv.validate_resource(r)
self.assertListEqual(errors.keys(), ['value-required'])
es = sv.validate_resource(r)
self.assertIn('value-required', es)
self.assertIn("None is not of type 'number'", es['value-required'][0])
r = self.create_resource(
'r4', sample_meta_dir, {'value': None, 'value-required': 2}
)
errors = sv.validate_resource(r)
self.assertEqual(errors, {})
def test_input_dict_type(self):
sample_meta_dir = self.make_resource_meta("""
@ -110,22 +109,23 @@ input:
r = self.create_resource(
'r', sample_meta_dir, {'values': {'a': 1, 'b': 2}}
)
errors = sv.validate_resource(r)
self.assertEqual(errors, {})
es = sv.validate_resource(r)
self.assertEqual(es, {})
r.update({'values': None})
errors = sv.validate_resource(r)
self.assertListEqual(errors.keys(), ['values'])
es = sv.validate_resource(r)
self.assertListEqual(es.keys(), ['values'])
r.update({'values': {'a': 1, 'c': 3}})
errors = sv.validate_resource(r)
self.assertEqual(errors, {})
es = sv.validate_resource(r)
self.assertEqual(es, {})
r = self.create_resource(
'r1', sample_meta_dir, {'values': {'b': 2}}
)
errors = sv.validate_resource(r)
self.assertListEqual(errors.keys(), ['values'])
es = sv.validate_resource(r)
self.assertIn('values', es)
self.assertIn("'a' is a required property", es['values'][0])
def test_complex_input(self):
sample_meta_dir = self.make_resource_meta("""
@ -194,6 +194,3 @@ input:
r.update({'values': {'a': 1, 'c': 3}})
errors = sv.validate_resource(r)
self.assertEqual(errors, {})
if __name__ == '__main__':
unittest.main()