Removed old db implementation

This commit is contained in:
Jedrzej Nowak 2015-11-17 13:01:21 +01:00
parent 7fa43cc6bc
commit 1807219376
24 changed files with 44 additions and 2343 deletions

19
examples/bootstrap/example-bootstrap.py Normal file → Executable file
View File

@ -10,11 +10,7 @@ from solar.core import signals
from solar.core import validation
from solar.core.resource import virtual_resource as vr
from solar import errors
from solar.interfaces.db import get_db
db = get_db()
from solar.dblayer.model import ModelMeta
@click.group()
@ -23,9 +19,7 @@ def main():
def setup_resources():
db.clear()
signals.Connections.clear()
ModelMeta.remove_all()
node2 = vr.create('node2', 'resources/ro_node/', {
'ip': '10.0.0.4',
@ -61,7 +55,7 @@ def deploy():
setup_resources()
# run
resources = map(resource.wrap_resource, db.get_list(collection=db.COLLECTIONS.resource))
resources = resource.load_all()
resources = {r.name: r for r in resources}
for name in resources_to_run:
@ -76,7 +70,7 @@ def deploy():
@click.command()
def undeploy():
resources = map(resource.wrap_resource, db.get_list(collection=db.COLLECTIONS.resource))
resources = resource.load_all()
resources = {r.name: r for r in resources}
for name in reversed(resources_to_run):
@ -85,10 +79,7 @@ def undeploy():
except errors.SolarError as e:
print 'WARNING: %s' % str(e)
db.clear()
signals.Connections.clear()
ModelMeta.remove_all()
main.add_command(deploy)
main.add_command(undeploy)

View File

@ -19,11 +19,9 @@ from solar.core import actions
from solar.core.resource import virtual_resource as vr
from solar.core import resource
from solar.core import signals
from solar.interfaces.db import get_db
from solar.dblayer.model import ModelMeta
from solar.core.resource_provider import GitProvider, RemoteZipProvider
import resources_compiled
@ -34,9 +32,7 @@ def main():
@click.command()
def deploy():
db = get_db()
db.clear()
ModelMeta.remove_all()
signals.Connections.clear()
node1 = resources_compiled.RoNodeResource('node1', None, {})
@ -75,18 +71,16 @@ def deploy():
@click.command()
def undeploy():
db = get_db()
ModelMeta.remove_all()
resources = map(resource.wrap_resource, db.get_list(collection=db.COLLECTIONS.resource))
resources = resource.load_all()
resources = {r.name: r for r in resources}
actions.resource_action(resources['openstack_rabbitmq_user'], 'remove')
actions.resource_action(resources['openstack_vhost'], 'remove')
actions.resource_action(resources['rabbitmq_service1'], 'remove')
db.clear()
signals.Connections.clear()
ModelMeta.remove_all()
main.add_command(deploy)

View File

@ -4,15 +4,11 @@ import time
from solar.core import signals
from solar.core.resource import virtual_resource as vr
from solar.interfaces.db import get_db
db = get_db()
from solar.dblayer.model import ModelMeta
def run():
db.clear()
ModelMeta.remove_all()
resources = vr.create('nodes', 'templates/nodes_with_transports.yaml', {'count': 2})
nodes = [x for x in resources if x.name.startswith('node')]

View File

@ -1,10 +1,8 @@
from solar.core.resource import virtual_resource as vr
from solar.interfaces.db import get_db
from solar.dblayer.model import ModelMeta
import yaml
db = get_db()
STORAGE = {'objects_ceph': True,
'osd_pool_size': 2,
@ -34,7 +32,7 @@ NETWORK_METADATA = yaml.load("""
def deploy():
db.clear()
ModelMeta.remove_all()
resources = vr.create('nodes', 'templates/nodes.yaml', {'count': 2})
first_node, second_node = [x for x in resources if x.name.startswith('node')]
first_transp = next(x for x in resources if x.name.startswith('transport'))

8
examples/lxc/example-lxc.py Normal file → Executable file
View File

@ -12,10 +12,10 @@ import click
from solar.core import signals
from solar.core.resource import virtual_resource as vr
from solar.interfaces.db import get_db
from solar.system_log import change
from solar.cli import orch
from solar.dblayer.model import ModelMeta
@click.group()
def main():
@ -43,9 +43,7 @@ def lxc_template(idx):
@click.command()
def deploy():
db = get_db()
db.clear()
signals.Connections.clear()
ModelMeta.remove_all()
node1 = vr.create('nodes', 'templates/nodes.yaml', {})[0]
seed = vr.create('nodes', 'templates/seed_node.yaml', {})[0]

View File

@ -8,9 +8,7 @@ from solar.core import signals
from solar.core import validation
from solar.core.resource import virtual_resource as vr
from solar import events as evapi
from solar.interfaces.db import get_db
from solar.dblayer.model import ModelMeta
PROFILE = False
#PROFILE = True
@ -35,8 +33,6 @@ if PROFILE:
# Official puppet manifests, not fuel-library
db = get_db()
@click.group()
def main():
@ -247,7 +243,7 @@ def setup_neutron(node, librarian, rabbitmq_service, openstack_rabbitmq_user, op
return {'neutron_puppet': neutron_puppet}
def setup_neutron_api(node, mariadb_service, admin_user, keystone_puppet, services_tenant, neutron_puppet):
# NEUTRON PLUGIN AND NEUTRON API (SERVER)
# NEUTRON PLUGIN AND NEUTRON API (SERVER)
neutron_plugins_ml2 = vr.create('neutron_plugins_ml2', 'resources/neutron_plugins_ml2_puppet', {})[0]
node.connect(neutron_plugins_ml2)
@ -830,7 +826,7 @@ def create_compute(node):
@click.command()
def create_all():
db.clear()
ModelMeta.remove_all()
r = prepare_nodes(2)
r.update(create_controller('node0'))
r.update(create_compute('node1'))
@ -856,7 +852,7 @@ def add_controller(node):
@click.command()
def clear():
db.clear()
ModelMeta.remove_all()
if __name__ == '__main__':

7
examples/riak/riaks-template.py Normal file → Executable file
View File

@ -8,16 +8,13 @@ import click
import sys
from solar.core import resource
from solar.interfaces.db import get_db
from solar import template
db = get_db()
from solar.dblayer.model import ModelMeta
def setup_riak():
db.clear()
ModelMeta.remove_all()
nodes = template.nodes_from('templates/riak_nodes.yaml')
riak_services = nodes.on_each(

View File

@ -22,18 +22,13 @@ from solar import errors
from solar.dblayer.model import ModelMeta
from solar.interfaces.db import get_db
from solar.events.controls import React, Dep
from solar.events.api import add_event
from solar.dblayer.solar_models import Resource
# db = get_db()
def setup_riak():
# db.clear()
ModelMeta.remove_all()
resources = vr.create('nodes', 'templates/nodes.yaml', {'count': 3})

View File

@ -5,16 +5,11 @@ import time
from solar.core import resource
from solar.core import signals
from solar.core.resource import virtual_resource as vr
from solar.interfaces.db import get_db
db = get_db()
from solar.dblayer.model import ModelMeta
def run():
db.clear()
ModelMeta.remove_all()
node = vr.create('node', 'resources/ro_node', {'name': 'first' + str(time.time()),
'ip': '10.0.0.3',

View File

@ -2,15 +2,11 @@ import time
from solar.core.resource import virtual_resource as vr
from solar import errors
from solar.interfaces.db import get_db
db = get_db()
from solar.dblayer.model import ModelMeta
def run():
db.clear()
ModelMeta.remove_all()
node = vr.create('node', 'resources/ro_node', {'name': 'first' + str(time.time()),
'ip': '10.0.0.3',

View File

@ -34,7 +34,6 @@ from solar.core.tags_set_parser import Expression
from solar.core.resource import virtual_resource as vr
from solar.core.log import log
from solar import errors
from solar.interfaces import orm
from solar import utils
from solar.cli import base
@ -78,25 +77,26 @@ def init_actions():
@click.option('-d', '--dry-run', default=False, is_flag=True)
@click.option('-m', '--dry-run-mapping', default='{}')
def run(dry_run_mapping, dry_run, action, tags):
if dry_run:
dry_run_executor = executors.DryRunExecutor(mapping=json.loads(dry_run_mapping))
raise NotImplementedError("Not yet implemented")
# if dry_run:
# dry_run_executor = executors.DryRunExecutor(mapping=json.loads(dry_run_mapping))
resources = filter(
lambda r: Expression(tags, r.tags).evaluate(),
orm.DBResource.all()
)
# resources = filter(
# lambda r: Expression(tags, r.tags).evaluate(),
# orm.DBResource.all()
# )
for r in resources:
resource_obj = sresource.load(r['id'])
actions.resource_action(resource_obj, action)
# for r in resources:
# resource_obj = sresource.load(r['id'])
# actions.resource_action(resource_obj, action)
if dry_run:
click.echo('EXECUTED:')
for key in dry_run_executor.executed:
click.echo('{}: {}'.format(
click.style(dry_run_executor.compute_hash(key), fg='green'),
str(key)
))
# if dry_run:
# click.echo('EXECUTED:')
# for key in dry_run_executor.executed:
# click.echo('{}: {}'.format(
# click.style(dry_run_executor.compute_hash(key), fg='green'),
# str(key)
# ))
def init_cli_connect():

View File

@ -25,7 +25,6 @@ from solar.core import resource as sresource
from solar.core.resource import virtual_resource as vr
from solar.core.log import log
from solar import errors
from solar.interfaces import orm
from solar import utils
from solar.cli import executors
@ -120,7 +119,6 @@ def clear_all():
click.echo('Clearing all resources and connections')
ModelMeta.remove_all()
# orm.db.clear()
@resource.command()
@click.argument('name')

View File

@ -22,7 +22,6 @@ import os
from solar import utils
from solar.core import validation
from solar.interfaces import orm
from solar.core import signals
from solar.events import api

View File

@ -16,7 +16,6 @@
import networkx
from solar.core.log import log
# from solar.interfaces import orm
from solar.dblayer.solar_models import Resource as DBResource
@ -143,91 +142,6 @@ def get_mapping(emitter, receiver, mapping=None):
def connect(emitter, receiver, mapping=None):
emitter.connect(receiver, mapping)
# def connect(emitter, receiver, mapping=None):
# if mapping is None:
# mapping = guess_mapping(emitter, receiver)
# # XXX: we didn't agree on that "reverse" there
# location_and_transports(emitter, receiver, mapping)
# if isinstance(mapping, set):
# mapping = {src: src for src in mapping}
# for src, dst in mapping.items():
# if not isinstance(dst, list):
# dst = [dst]
# for d in dst:
# connect_single(emitter, src, receiver, d)
# def connect_single(emitter, src, receiver, dst):
# if ':' in dst:
# return connect_multi(emitter, src, receiver, dst)
# # Disconnect all receiver inputs
# # Check if receiver input is of list type first
# emitter_input = emitter.resource_inputs()[src]
# receiver_input = receiver.resource_inputs()[dst]
# if emitter_input.id == receiver_input.id:
# raise Exception(
# 'Trying to connect {} to itself, this is not possible'.format(
# emitter_input.id)
# )
# if not receiver_input.is_list:
# receiver_input.receivers.delete_all_incoming(receiver_input)
# # Check for cycles
# # TODO: change to get_paths after it is implemented in drivers
# if emitter_input in receiver_input.receivers.as_set():
# raise Exception('Prevented creating a cycle on %s::%s' % (emitter.name,
# emitter_input.name))
# log.debug('Connecting {}::{} -> {}::{}'.format(
# emitter.name, emitter_input.name, receiver.name, receiver_input.name
# ))
# emitter_input.receivers.add(receiver_input)
# def connect_multi(emitter, src, receiver, dst):
# receiver_input_name, receiver_input_key = dst.split(':')
# if '|' in receiver_input_key:
# receiver_input_key, receiver_input_tag = receiver_input_key.split('|')
# else:
# receiver_input_tag = None
# emitter_input = emitter.resource_inputs()[src]
# receiver_input = receiver.resource_inputs()[receiver_input_name]
# if not receiver_input.is_list or receiver_input_tag:
# receiver_input.receivers.delete_all_incoming(
# receiver_input,
# destination_key=receiver_input_key,
# tag=receiver_input_tag
# )
# # We can add default tag now
# receiver_input_tag = receiver_input_tag or emitter.name
# # NOTE: make sure that receiver.args[receiver_input] is of dict type
# if not receiver_input.is_hash:
# raise Exception(
# 'Receiver input {} must be a hash or a list of hashes'.format(receiver_input_name)
# )
# log.debug('Connecting {}::{} -> {}::{}[{}], tag={}'.format(
# emitter.name, emitter_input.name, receiver.name, receiver_input.name,
# receiver_input_key,
# receiver_input_tag
# ))
# emitter_input.receivers.add_hash(
# receiver_input,
# receiver_input_key,
# tag=receiver_input_tag
# )
def disconnect_receiver_by_input(receiver, input_name):
# input_node = receiver.resource_inputs()[input_name]
@ -236,12 +150,6 @@ def disconnect_receiver_by_input(receiver, input_name):
receiver.db_obj.inputs.disconnect(input_name)
# def disconnect(emitter, receiver):
# for emitter_input in emitter.resource_inputs().values():
# for receiver_input in receiver.resource_inputs().values():
# emitter_input.receivers.remove(receiver_input)
def detailed_connection_graph(start_with=None, end_with=None, details=False):
from solar.core.resource import Resource, load_all

View File

@ -18,7 +18,6 @@ __all__ = ['add_dep', 'add_react', 'Dep', 'React', 'add_event']
import networkx as nx
from solar.core.log import log
from solar.interfaces import orm
from solar.events.controls import Dep, React, StateChange
from solar.dblayer.solar_models import Resource

View File

@ -1,38 +0,0 @@
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import importlib
# Registry of available backends: name -> (module path, class name).
# Modules are imported lazily so unused drivers (and their dependencies,
# e.g. py2neo) are never loaded.
db_backends = {
    'neo4j_db': ('solar.interfaces.db.neo4j', 'Neo4jDB'),
    'redis_db': ('solar.interfaces.db.redis_db', 'RedisDB'),
    'fakeredis_db': ('solar.interfaces.db.redis_db', 'FakeRedisDB'),
    'redis_graph_db': ('solar.interfaces.db.redis_graph_db', 'RedisGraphDB'),
    'fakeredis_graph_db': ('solar.interfaces.db.redis_graph_db', 'FakeRedisGraphDB'),
}

# Should be retrieved from config
CURRENT_DB = 'redis_graph_db'
#CURRENT_DB = 'neo4j_db'

# Lazily created driver singletons, one per backend name.
_INSTANCES = {}

# Legacy module-level handle (first driver ever created); kept because
# existing callers may read it directly.
DB = None


def get_db(backend=CURRENT_DB):
    """Return the lazily-created singleton driver for *backend*.

    Previously the first backend ever requested was cached in the global
    ``DB`` and returned for every subsequent call, silently ignoring the
    ``backend`` argument.  Instances are now cached per backend name, so
    different backends can coexist while repeated calls still reuse the
    same object.

    :param backend: key into ``db_backends``
    :raises KeyError: if *backend* is not a registered name
    """
    global DB
    instance = _INSTANCES.get(backend)
    if instance is None:
        import_path, klass = db_backends[backend]
        module = importlib.import_module(import_path)
        instance = getattr(module, klass)()
        _INSTANCES[backend] = instance
    if DB is None:
        # Preserve the historical module attribute for legacy readers.
        DB = instance
    return instance

View File

@ -1,236 +0,0 @@
import abc
from enum import Enum
from functools import partial
class Node(object):
    """In-memory wrapper for one graph node.

    Carries the owning db driver, the node's uid, its labels and a plain
    dict of properties.
    """

    def __init__(self, db, uid, labels, properties):
        self.db = db
        self.uid = uid
        self.labels = labels
        self.properties = properties

    @property
    def collection(self):
        # The node's first label names the collection it belongs to.
        first_label = next(iter(self.labels))
        return getattr(BaseGraphDB.COLLECTIONS, first_label)
class Relation(object):
    """In-memory wrapper for one graph edge between two ``Node`` objects."""

    def __init__(self, db, start_node, end_node, properties):
        self.properties = properties
        self.start_node = start_node
        self.end_node = end_node
        self.db = db
class DBObjectMeta(abc.ABCMeta):
    """Metaclass that converts raw backend records into wrapper objects.

    Subclasses of ``BaseGraphDB`` provide ``node_db_to_object`` /
    ``relation_db_to_object`` converters; this metaclass wraps the read
    methods listed below so every result is converted before being
    returned.  Callers may pass ``db_convert=False`` to any wrapped
    method to get the raw backend record instead.
    """

    # Tuples of: function name, is-multi (i.e. returns a list)
    node_db_read_methods = [
        ('all', True),
        ('create', False),
        ('get', False),
        ('get_or_create', False),
    ]
    relation_db_read_methods = [
        ('all_relations', True),
        ('create_relation', False),
        ('get_relations', True),
        ('get_relation', False),
        ('get_or_create_relation', False),
    ]

    def __new__(cls, name, parents, dct):
        def from_db_list_decorator(converting_func, method):
            # Wraps a method that returns a list of raw records;
            # converts each element (Python 2 map returns a list).
            def wrapper(self, *args, **kwargs):
                db_convert = kwargs.pop('db_convert', True)
                result = method(self, *args, **kwargs)
                if db_convert:
                    return map(partial(converting_func, self), result)
                return result
            return wrapper

        def from_db_decorator(converting_func, method):
            # Wraps a method that returns a single raw record (or None,
            # which is passed through unchanged).
            def wrapper(self, *args, **kwargs):
                db_convert = kwargs.pop('db_convert', True)
                result = method(self, *args, **kwargs)
                if result is None:
                    return
                if db_convert:
                    return converting_func(self, result)
                return result
            return wrapper

        # Resolve the converters on the class being built (or inherited).
        node_db_to_object = cls.find_method(
            'node_db_to_object', name, parents, dct
        )
        relation_db_to_object = cls.find_method(
            'relation_db_to_object', name, parents, dct
        )

        # Node conversions
        for method_name, is_list in cls.node_db_read_methods:
            method = cls.find_method(method_name, name, parents, dct)
            if is_list:
                func = from_db_list_decorator
            else:
                func = from_db_decorator
            # Handle subclasses: the ``_wrapped`` marker prevents an
            # inherited, already-wrapped method from being wrapped twice.
            if not getattr(method, '_wrapped', None):
                dct[method_name] = func(node_db_to_object, method)
                setattr(dct[method_name], '_wrapped', True)

        # Relation conversions
        for method_name, is_list in cls.relation_db_read_methods:
            method = cls.find_method(method_name, name, parents, dct)
            if is_list:
                func = from_db_list_decorator
            else:
                func = from_db_decorator
            # Handle subclasses (same double-wrap guard as above).
            if not getattr(method, '_wrapped', None):
                dct[method_name] = func(relation_db_to_object, method)
                setattr(dct[method_name], '_wrapped', True)

        return super(DBObjectMeta, cls).__new__(cls, name, parents, dct)

    @classmethod
    def find_method(cls, method_name, class_name, parents, dict):
        """Look up *method_name* in the class dict, then in each parent.

        NOTE(review): ``getattr(parent, method_name)`` has no default, so
        a parent lacking the attribute raises AttributeError before the
        NotImplementedError below can be reached — confirm intent.
        """
        if method_name in dict:
            return dict[method_name]
        for parent in parents:
            method = getattr(parent, method_name)
            if method:
                return method
        raise NotImplementedError(
            '"{}" method not implemented in class {}'.format(
                method_name, class_name
            )
        )
class BaseGraphDB(object):
    """Abstract interface for the graph storage backends.

    Data model: named nodes grouped into ``COLLECTIONS``, linked by typed
    relations (``RELATION_TYPES``).  Concrete drivers implement the
    abstract methods below; ``DBObjectMeta`` then wraps the read methods
    so results come back as ``Node``/``Relation`` wrapper objects.
    """

    # Python 2 style metaclass declaration.
    __metaclass__ = DBObjectMeta

    COLLECTIONS = Enum(
        'Collections',
        'input resource state_data state_log plan_node plan_graph events stage_log commit_log resource_events'
    )
    DEFAULT_COLLECTION = COLLECTIONS.resource
    RELATION_TYPES = Enum(
        'RelationTypes',
        'input_to_input resource_input plan_edge graph_to_node resource_event commited'
    )
    DEFAULT_RELATION = RELATION_TYPES.resource_input

    @staticmethod
    def node_db_to_object(node_db):
        """Convert node DB object to Node object."""

    @staticmethod
    def object_to_node_db(node_obj):
        """Convert Node object to node DB object."""

    @staticmethod
    def relation_db_to_object(relation_db):
        """Convert relation DB object to Relation object."""

    @staticmethod
    def object_to_relation_db(relation_obj):
        """Convert Relation object to relation DB object."""

    @abc.abstractmethod
    def all(self, collection=DEFAULT_COLLECTION):
        """Return all elements (nodes) of type `collection`."""

    @abc.abstractmethod
    def all_relations(self, type_=DEFAULT_RELATION):
        """Return all relations of type `type_`."""

    @abc.abstractmethod
    def clear(self):
        """Clear the whole DB."""

    @abc.abstractmethod
    def clear_collection(self, collection=DEFAULT_COLLECTION):
        """Clear all elements (nodes) of type `collection`."""

    @abc.abstractmethod
    def create(self, name, properties={}, collection=DEFAULT_COLLECTION):
        """Create element (node) with given name, args, of type `collection`.

        NOTE(review): the mutable default ``properties={}`` is shared
        across calls if an implementation mutates it — confirm that
        implementations copy it first (Neo4jDB does a deepcopy).
        """

    @abc.abstractmethod
    def delete(self, name, collection=DEFAULT_COLLECTION):
        """Delete element with given name, of type `collection`."""

    @abc.abstractmethod
    def create_relation(self,
                        source,
                        dest,
                        properties={},
                        type_=DEFAULT_RELATION):
        """
        Create relation (connection) of type `type_` from source to dest with
        given args.
        """

    @abc.abstractmethod
    def get(self, name, collection=DEFAULT_COLLECTION):
        """Fetch element with given name and collection type."""

    @abc.abstractmethod
    def get_or_create(self,
                      name,
                      properties={},
                      collection=DEFAULT_COLLECTION):
        """
        Fetch or create element (if not exists) with given name, args of type
        `collection`.
        """

    @abc.abstractmethod
    def delete_relations(self,
                         source=None,
                         dest=None,
                         type_=DEFAULT_RELATION,
                         has_properties=None):
        """Delete all relations of type `type_` from source to dest."""

    @abc.abstractmethod
    def get_relations(self,
                      source=None,
                      dest=None,
                      type_=DEFAULT_RELATION,
                      has_properties=None):
        """Fetch all relations of type `type_` from source to dest.

        NOTE that this function must return only direct relations (edges)
        between vertices `source` and `dest` of type `type_`.

        If you want all PATHS between `source` and `dest`, write another
        method for this (`get_paths`)."""

    @abc.abstractmethod
    def get_relation(self, source, dest, type_=DEFAULT_RELATION):
        """Fetch relations with given source, dest and type_."""

    @abc.abstractmethod
    def get_or_create_relation(self,
                               source,
                               dest,
                               properties={},
                               type_=DEFAULT_RELATION):
        """Fetch or create relation with given args."""

View File

@ -1,205 +0,0 @@
import json
from copy import deepcopy
import py2neo
from solar.core import log
from .base import BaseGraphDB, Node, Relation
class Neo4jDB(BaseGraphDB):
    """``BaseGraphDB`` driver backed by a Neo4j server via py2neo.

    Collections map to node labels, relation types to relationship types,
    and a node's ``name`` property is used as its identifier.
    """

    DB = {
        'host': 'localhost',
        'port': 7474,
    }
    NEO4J_CLIENT = py2neo.Graph

    def __init__(self):
        # Connect to the Neo4j REST endpoint.
        self._r = self.NEO4J_CLIENT('http://{host}:{port}/db/data/'.format(
            **self.DB
        ))

    def node_db_to_object(self, node_db):
        """Convert a py2neo node into a ``Node`` wrapper."""
        return Node(
            self,
            node_db.properties['name'],
            node_db.labels,
            # Neo4j Node.properties is some strange PropertySet, use dict instead
            dict(**node_db.properties)
        )

    def relation_db_to_object(self, relation_db):
        """Convert a py2neo relationship into a ``Relation`` wrapper."""
        return Relation(
            self,
            self.node_db_to_object(relation_db.start_node),
            self.node_db_to_object(relation_db.end_node),
            relation_db.properties
        )

    def all(self, collection=BaseGraphDB.DEFAULT_COLLECTION):
        """Return every node carrying the collection's label."""
        return [
            r.n for r in self._r.cypher.execute(
                'MATCH (n:%(collection)s) RETURN n' % {
                    'collection': collection.name,
                }
            )
        ]

    def all_relations(self, type_=BaseGraphDB.DEFAULT_RELATION):
        """Return every relationship of the given type."""
        return [
            r.r for r in self._r.cypher.execute(
                *self._relations_query(
                    source=None, dest=None, type_=type_
                )
            )
        ]

    def clear(self):
        """Wipe the entire database."""
        log.log.debug('Clearing whole DB')
        self._r.delete_all()

    def clear_collection(self, collection=BaseGraphDB.DEFAULT_COLLECTION):
        """Delete every node in *collection*."""
        log.log.debug('Clearing collection %s', collection.name)
        # TODO: make single DELETE query
        self._r.delete([r.n for r in self.all(collection=collection)])

    def create(self, name, properties={}, collection=BaseGraphDB.DEFAULT_COLLECTION):
        """Create a node labelled *collection* with the given properties.

        The properties dict is deep-copied before the ``name`` key is
        injected, so the caller's dict (and the mutable default) is never
        mutated.
        """
        log.log.debug(
            'Creating %s, name %s with properties %s',
            collection.name,
            name,
            properties
        )
        properties = deepcopy(properties)
        properties['name'] = name
        n = py2neo.Node(collection.name, **properties)
        return self._r.create(n)[0]

    def create_relation(self,
                        source,
                        dest,
                        properties={},
                        type_=BaseGraphDB.DEFAULT_RELATION):
        """Create a relationship of *type_* between two existing nodes.

        Both endpoints are re-fetched as raw py2neo nodes
        (``db_convert=False``) so the relationship binds to server-side
        objects.
        """
        log.log.debug(
            'Creating %s from %s to %s with properties %s',
            type_.name,
            source.properties['name'],
            dest.properties['name'],
            properties
        )
        s = self.get(
            source.properties['name'],
            collection=source.collection,
            db_convert=False
        )
        d = self.get(
            dest.properties['name'],
            collection=dest.collection,
            db_convert=False
        )
        r = py2neo.Relationship(s, type_.name, d, **properties)
        self._r.create(r)
        return r

    def _get_query(self, name, collection=BaseGraphDB.DEFAULT_COLLECTION):
        # Returns (cypher, params); the label is interpolated but the
        # name is passed as a proper query parameter.
        return 'MATCH (n:%(collection)s {name:{name}}) RETURN n' % {
            'collection': collection.name,
        }, {
            'name': name,
        }

    def get(self, name, collection=BaseGraphDB.DEFAULT_COLLECTION):
        """Return the first node matching *name* in *collection*, or None."""
        query, kwargs = self._get_query(name, collection=collection)
        res = self._r.cypher.execute(query, kwargs)
        if res:
            return res[0].n

    def get_or_create(self,
                      name,
                      properties={},
                      collection=BaseGraphDB.DEFAULT_COLLECTION):
        """Fetch the node, updating its properties if they differ;
        create it when absent."""
        n = self.get(name, collection=collection, db_convert=False)
        if n:
            if properties != n.properties:
                n.properties.update(properties)
                n.push()  # flush the property changes back to the server
            return n
        return self.create(name, properties=properties, collection=collection)

    def _relations_query(self,
                         source=None,
                         dest=None,
                         type_=BaseGraphDB.DEFAULT_RELATION,
                         query_type='RETURN'):
        # Build a MATCH over (n)-[r:TYPE]->(m); omitting source/dest
        # leaves the corresponding endpoint unconstrained.  query_type
        # selects RETURN (fetch) vs DELETE (remove).
        kwargs = {}
        source_query = '(n)'
        if source:
            source_query = '(n {name:{source_name}})'
            kwargs['source_name'] = source.properties['name']
        dest_query = '(m)'
        if dest:
            dest_query = '(m {name:{dest_name}})'
            kwargs['dest_name'] = dest.properties['name']
        rel_query = '[r:%(type_)s]' % {'type_': type_.name}
        query = ('MATCH %(source_query)s-%(rel_query)s->'
                 '%(dest_query)s %(query_type)s r' % {
                     'dest_query': dest_query,
                     'query_type': query_type,
                     'rel_query': rel_query,
                     'source_query': source_query,
                 })
        return query, kwargs

    def delete_relations(self,
                         source=None,
                         dest=None,
                         type_=BaseGraphDB.DEFAULT_RELATION):
        """Delete all relations of *type_* from source to dest."""
        query, kwargs = self._relations_query(
            source=source, dest=dest, type_=type_, query_type='DELETE'
        )
        self._r.cypher.execute(query, kwargs)

    def get_relations(self,
                      source=None,
                      dest=None,
                      type_=BaseGraphDB.DEFAULT_RELATION):
        """Fetch all relations of *type_* from source to dest."""
        query, kwargs = self._relations_query(
            source=source, dest=dest, type_=type_
        )
        res = self._r.cypher.execute(query, kwargs)
        return [r.r for r in res]

    def get_relation(self, source, dest, type_=BaseGraphDB.DEFAULT_RELATION):
        """Return the first matching relation, or None."""
        rel = self.get_relations(source=source, dest=dest, type_=type_)
        if rel:
            return rel[0]

    def get_or_create_relation(self,
                               source,
                               dest,
                               properties={},
                               type_=BaseGraphDB.DEFAULT_RELATION):
        """Fetch the relation, updating its properties if they differ;
        create it when absent."""
        rel = self.get_relations(source=source, dest=dest, type_=type_)
        if rel:
            r = rel[0]
            if properties != r.properties:
                r.properties.update(properties)
                r.push()  # flush the property changes back to the server
            return r
        return self.create_relation(source, dest, properties=properties, type_=type_)

View File

@ -1,156 +0,0 @@
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from enum import Enum
try:
import ujson as json
except ImportError:
import json
import redis
import fakeredis
class RedisDB(object):
    """Flat key/value store on top of Redis.

    Objects are JSON-serialized and stored under ``<collection>:<uid>``
    keys; ``COLLECTIONS`` enumerates the supported namespaces.
    """

    COLLECTIONS = Enum(
        'Collections',
        'connection resource state_data state_log events'
    )
    DB = {
        'host': 'localhost',
        'port': 6379,
    }
    # Overridden by FakeRedisDB with an in-memory client for tests.
    REDIS_CLIENT = redis.StrictRedis

    def __init__(self):
        self._r = self.REDIS_CLIENT(**self.DB)
        self.entities = {}

    def read(self, uid, collection=COLLECTIONS.resource):
        """Return the stored object, or None if the key does not exist."""
        try:
            return json.loads(
                self._r.get(self._make_key(collection, uid))
            )
        except TypeError:
            # GET returned None -> key is missing.
            return None

    def get_list(self, collection=COLLECTIONS.resource):
        """Yield every object stored in *collection*.

        The GETs are queued on a pipeline and executed in one round-trip.
        (The previous code opened a pipeline but issued plain GETs against
        the client, so the pipeline was never actually used.)
        """
        key_glob = self._make_key(collection, '*')
        keys = self._r.keys(key_glob)
        with self._r.pipeline() as pipe:
            for key in keys:
                pipe.get(key)
            values = pipe.execute()
        for value in values:
            yield json.loads(value)

    def save(self, uid, data, collection=COLLECTIONS.resource):
        """JSON-encode *data* and store it under the collection key."""
        ret = self._r.set(
            self._make_key(collection, uid),
            json.dumps(data)
        )
        return ret

    def save_list(self, lst, collection=COLLECTIONS.resource):
        """Store many (uid, data) pairs in a single pipelined transaction."""
        with self._r.pipeline() as pipe:
            pipe.multi()
            for uid, data in lst:
                key = self._make_key(collection, uid)
                pipe.set(key, json.dumps(data))
            pipe.execute()

    def clear(self):
        """Wipe the whole Redis database."""
        self._r.flushdb()

    def get_ordered_hash(self, collection):
        """Return an ``OrderedHash`` view over *collection*."""
        return OrderedHash(self._r, collection)

    def clear_collection(self, collection=COLLECTIONS.resource):
        """Delete every key in *collection*.

        DEL takes individual key names; the previous code passed the
        Python list itself as a single argument, which is not a valid
        key name and deleted nothing.
        """
        key_glob = self._make_key(collection, '*')
        keys = self._r.keys(key_glob)
        if keys:
            self._r.delete(*keys)

    def delete(self, uid, collection=COLLECTIONS.resource):
        """Delete one object from *collection*."""
        self._r.delete(self._make_key(collection, uid))

    def _make_key(self, collection, _id):
        # Accept either a COLLECTIONS member or a plain string name.
        if isinstance(collection, self.COLLECTIONS):
            collection = collection.name
        # NOTE: hiera-redis backend depends on this!
        return '{0}:{1}'.format(collection, _id)
class OrderedHash(object):
    """A Redis hash whose entries remember insertion order.

    Values are JSON-encoded into hash ``collection``; insertion order is
    tracked in a sorted set ``<collection>:order`` whose scores come from
    an ever-increasing counter key ``<collection>:incr``.
    """

    def __init__(self, client, collection):
        self.r = client
        self.collection = collection
        self.order_counter = '{}:incr'.format(collection)
        self.order = '{}:order'.format(collection)

    def add(self, items):
        """Append (key, value) pairs, JSON-encoding each value."""
        pipe = self.r.pipeline()
        for key, value in items:
            # The counter INCR runs outside the pipeline so each entry
            # gets a distinct, monotonically increasing score.
            count = self.r.incr(self.order_counter)
            pipe.zadd(self.order, count, key)
            pipe.hset(self.collection, key, json.dumps(value))
        pipe.execute()

    def rem(self, keys):
        """Remove the given keys from both the hash and the order index."""
        pipe = self.r.pipeline()
        for key in keys:
            pipe.zrem(self.order, key)
            pipe.hdel(self.collection, key)
        pipe.execute()

    def get(self, key):
        """Return the decoded value for *key*, or None if absent."""
        value = self.r.hget(self.collection, key)
        if value:
            return json.loads(value)
        return None

    def update(self, key, value):
        # Overwrites the value in place; the ordering entry is untouched.
        self.r.hset(self.collection, key, json.dumps(value))

    def clean(self):
        """Remove every entry."""
        self.rem(self.r.zrange(self.order, 0, -1))

    def rem_left(self, n=1):
        # NOTE(review): despite the name, ZREVRANGE walks highest scores
        # first, so this removes the n *newest* entries — confirm intent.
        self.rem(self.r.zrevrange(self.order, 0, n-1))

    def reverse(self, n=1):
        """Return the n newest values, newest first."""
        result = []
        for key in self.r.zrevrange(self.order, 0, n-1):
            result.append(self.get(key))
        return result

    def list(self, n=0):
        """Return the n oldest values, oldest first.

        With the default ``n=0`` the end index becomes -1, so the whole
        collection is returned.
        """
        result = []
        for key in self.r.zrange(self.order, 0, n-1):
            result.append(self.get(key))
        return result
class FakeRedisDB(RedisDB):
    """``RedisDB`` backed by an in-memory fakeredis client (for tests)."""

    REDIS_CLIENT = fakeredis.FakeStrictRedis

View File

@ -1,300 +0,0 @@
try:
import ujson as json
except ImportError:
import json
import redis
import fakeredis
from .base import BaseGraphDB, Node, Relation
from .redis_db import OrderedHash
class RedisGraphDB(BaseGraphDB):
DB = {
'host': 'localhost',
'port': 6379,
}
REDIS_CLIENT = redis.StrictRedis
def __init__(self):
self._r = self.REDIS_CLIENT(**self.DB)
self.entities = {}
def node_db_to_object(self, node_db):
if isinstance(node_db, Node):
return node_db
return Node(
self,
node_db['name'],
[node_db['collection']],
node_db['properties']
)
def relation_db_to_object(self, relation_db):
if isinstance(relation_db, Relation):
return relation_db
if relation_db['type_'] == BaseGraphDB.RELATION_TYPES.input_to_input.name:
source_collection = BaseGraphDB.COLLECTIONS.input
dest_collection = BaseGraphDB.COLLECTIONS.input
elif relation_db['type_'] == BaseGraphDB.RELATION_TYPES.resource_input.name:
source_collection = BaseGraphDB.COLLECTIONS.resource
dest_collection = BaseGraphDB.COLLECTIONS.input
elif relation_db['type_'] == BaseGraphDB.RELATION_TYPES.resource_event.name:
source_collection = BaseGraphDB.COLLECTIONS.resource_events
dest_collection = BaseGraphDB.COLLECTIONS.events
source = self.get(relation_db['source'], collection=source_collection)
dest = self.get(relation_db['dest'], collection=dest_collection)
return Relation(
self,
source,
dest,
relation_db['properties']
)
def all(self, collection=BaseGraphDB.DEFAULT_COLLECTION):
"""Return all elements (nodes) of type `collection`."""
key_glob = self._make_collection_key(collection, '*')
for result in self._all(key_glob):
yield result
def all_relations(self, type_=BaseGraphDB.DEFAULT_RELATION):
"""Return all relations of type `type_`."""
key_glob = self._make_relation_key(type_, '*')
for result in self._all(key_glob):
yield result
def _all(self, key_glob):
keys = self._r.keys(key_glob)
with self._r.pipeline() as pipe:
pipe.multi()
values = [self._r.get(key) for key in keys]
pipe.execute()
for value in values:
yield json.loads(value)
def clear(self):
"""Clear the whole DB."""
self._r.flushdb()
def clear_collection(self, collection=BaseGraphDB.DEFAULT_COLLECTION):
"""Clear all elements (nodes) of type `collection`."""
key_glob = self._make_collection_key(collection, '*')
self._r.delete(self._r.keys(key_glob))
def create(self, name, properties={}, collection=BaseGraphDB.DEFAULT_COLLECTION):
"""Create element (node) with given name, properties, of type `collection`."""
if isinstance(collection, self.COLLECTIONS):
collection = collection.name
properties = {
'name': name,
'properties': properties,
'collection': collection,
}
self._r.set(
self._make_collection_key(collection, name),
json.dumps(properties)
)
return properties
def create_relation(self,
source,
dest,
properties={},
type_=BaseGraphDB.DEFAULT_RELATION):
"""
Create relation (connection) of type `type_` from source to dest with
given properties.
"""
return self.create_relation_str(
source.uid, dest.uid, properties, type_=type_)
def create_relation_str(self, source, dest,
properties={}, type_=BaseGraphDB.DEFAULT_RELATION):
if isinstance(type_, self.RELATION_TYPES):
type_ = type_.name
uid = self._make_relation_uid(source, dest)
properties = {
'source': source,
'dest': dest,
'properties': properties,
'type_': type_,
}
self._r.set(
self._make_relation_key(type_, uid),
json.dumps(properties)
)
return properties
def get(self, name, collection=BaseGraphDB.DEFAULT_COLLECTION,
return_empty=False):
"""Fetch element with given name and collection type."""
try:
collection_key = self._make_collection_key(collection, name)
item = self._r.get(collection_key)
if not item and return_empty:
return item
return json.loads(item)
except TypeError:
raise KeyError(collection_key)
def delete(self, name, collection=BaseGraphDB.DEFAULT_COLLECTION):
keys = self._r.keys(self._make_collection_key(collection, name))
if keys:
self._r.delete(*keys)
def get_or_create(self,
name,
properties={},
collection=BaseGraphDB.DEFAULT_COLLECTION):
"""
Fetch or create element (if not exists) with given name, properties of
type `collection`.
"""
try:
return self.get(name, collection=collection)
except KeyError:
return self.create(name, properties=properties, collection=collection)
def _relations_glob(self,
                    source=None,
                    dest=None,
                    type_=BaseGraphDB.DEFAULT_RELATION):
    """Redis key pattern matching relations from `source` to `dest`.

    A side given as None becomes a '*' wildcard; otherwise its `.uid`
    is used.
    """
    source_part = '*' if source is None else source.uid
    dest_part = '*' if dest is None else dest.uid
    return self._make_relation_key(
        type_, self._make_relation_uid(source_part, dest_part))
def delete_relations(self,
                     source=None,
                     dest=None,
                     type_=BaseGraphDB.DEFAULT_RELATION,
                     has_properties=None):
    """Delete all relations of type `type_` from source to dest.

    When `has_properties` is given, only relations whose properties
    match every given key/value pair are removed.
    """
    glob = self._relations_glob(source=source, dest=dest, type_=type_)
    keys = self._r.keys(glob)
    if not keys:
        return
    if not has_properties:
        self._r.delete(*keys)
        # Everything matching the glob is gone; nothing left to filter.
        return
    rels = self.get_relations(
        source=source, dest=dest, type_=type_, has_properties=has_properties
    )
    # Materialize first so we do not delete keys under the generator.
    for r in list(rels):
        # get_relations yields plain dicts; the previous code accessed
        # r.start_node / r.end_node, which raised AttributeError here.
        key = self._make_relation_key(
            type_, self._make_relation_uid(r['source'], r['dest']))
        self._r.delete(key)
def get_relations(self,
                  source=None,
                  dest=None,
                  type_=BaseGraphDB.DEFAULT_RELATION,
                  has_properties=None):
    """Fetch all relations of type `type_` from source to dest.

    Yields relation documents (dicts with 'source', 'dest', 'properties',
    'type_'). `source`/`dest` of None act as wildcards; `has_properties`
    additionally filters on exact property values.
    """
    glob = self._relations_glob(source=source, dest=dest, type_=type_)

    def check_has_properties(r):
        # Every requested key must be present with exactly the given value.
        if has_properties:
            for k, v in has_properties.items():
                if not r['properties'].get(k) == v:
                    return False
        return True

    for r in self._all(glob):
        # Glob is primitive, we must filter stuff correctly here
        if source and r['source'] != source.uid:
            continue
        if dest and r['dest'] != dest.uid:
            continue
        if not check_has_properties(r):
            continue
        yield r
def get_relation(self, source, dest, type_=BaseGraphDB.DEFAULT_RELATION):
    """Fetch the relation with given source, dest and type_.

    Raises KeyError when no such relation exists.
    """
    # BUG FIX: the uid must be built with _make_relation_uid ('src-dst'),
    # matching how create_relation_str stores the key; the previous call
    # to _make_relation_key produced 'src:dst' and never matched.
    uid = self._make_relation_uid(source.uid, dest.uid)
    try:
        return json.loads(
            self._r.get(self._make_relation_key(type_, uid))
        )
    except TypeError:
        # self._r.get returned None -> relation is absent.
        raise KeyError
def get_or_create_relation(self,
                           source,
                           dest,
                           properties=None,
                           type_=BaseGraphDB.DEFAULT_RELATION):
    """Fetch the relation; create it with `properties` when missing."""
    properties = properties or {}
    try:
        return self.get_relation(source, dest, type_=type_)
    except KeyError:
        return self.create_relation(
            source, dest, properties=properties, type_=type_)
def _make_collection_key(self, collection, _id):
    """Redis key for document `_id` in `collection`: '<collection>:<_id>'."""
    if isinstance(collection, self.COLLECTIONS):
        collection = collection.name

    # NOTE: hiera-redis backend depends on this!
    return '{0}:{1}'.format(collection, _id)
def _make_relation_uid(self, source, dest):
"""
There can be only one relation from source to dest, that's why
this function works.
"""
return '{0}-{1}'.format(source, dest)
def _make_relation_key(self, type_, _id):
    """Redis key for relation `_id` of type `type_`: '<type_>:<_id>'."""
    if isinstance(type_, self.RELATION_TYPES):
        type_ = type_.name

    # NOTE: hiera-redis backend depends on this!
    return '{0}:{1}'.format(type_, _id)
def get_ordered_hash(self, collection):
    """Return an OrderedHash over `collection`, backed by this redis client."""
    return OrderedHash(self._r, collection)
class FakeRedisGraphDB(RedisGraphDB):
    """In-memory variant for tests: same API, but backed by fakeredis
    instead of a real Redis server."""

    REDIS_CLIENT = fakeredis.FakeStrictRedis

View File

@ -1,735 +0,0 @@
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import inspect
import networkx
import uuid
from solar import errors
from solar.core import validation
from solar.interfaces.db import base
from solar.interfaces.db import get_db
import os
# USE_CACHE could be set only from CLI
USE_CACHE = int(os.getenv("USE_CACHE", 0))
db = get_db()
from functools import wraps
def _delete_from(store):
def _wrp(key):
try:
del store[key]
except KeyError:
pass
return _wrp
def cache_me(store):
    """Decorator factory: memoize a method's result in `store`, keyed by
    the object's `id` attribute.

    Only active when USE_CACHE is set; otherwise the undecorated function
    is returned (the cache attributes are still attached so callers can
    use _cache_del unconditionally).
    """
    def _inner(f):
        # attaching to functions even when no cache enabled for consistency
        f._cache_store = store
        f._cache_del = _delete_from(store)

        @wraps(f)
        def _inner2(obj, *args, **kwargs):
            try:
                return store[obj.id]
            except KeyError:
                pass
            val = f(obj, *args, **kwargs)
            # NOTE(review): results with an empty .value for
            # location_id/transports_id inputs, list results, and falsy
            # .value results in general are deliberately NOT cached —
            # presumably because they may be populated later; confirm
            # before changing this.
            if obj.id.startswith('location_id'):
                if not val.value:
                    return val
            if obj.id.startswith('transports_id'):
                if not val.value:
                    return val
            if isinstance(val, list):
                return val
            else:
                if not val.value:
                    return val
            store[obj.id] = val
            return val

        if USE_CACHE:
            return _inner2
        else:
            return f
    return _inner
class DBField(object):
    """A single named value stored on a DBObject.

    Concrete subclasses are produced by `db_field`, which overrides the
    class attributes below.
    """

    is_primary = False
    schema = None
    schema_in_field = None
    default_value = None

    def __init__(self, name, value=None):
        self.name = name
        # Fall back to the class-level default when no value is supplied.
        self.value = self.default_value if value is None else value

    def __eq__(self, other):
        return self.name == other.name and self.value == other.value

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return hash('{}:{}'.format(self.name, self.value))

    def validate(self):
        """Raise errors.ValidationError when the value violates the schema."""
        if self.schema is None:
            return

        found = validation.validate_input(self.value, schema=self.schema)
        if found:
            raise errors.ValidationError(
                '"{}": {}'.format(self.name, found[0]))
def db_field(schema=None,
             schema_in_field=None,
             default_value=None,
             is_primary=False):
    """Definition for the DB field.

    schema - simple schema according to the one in solar.core.validation
    schema_in_field - if you don't want to fix schema, you can specify
        another field in DBObject that will represent the schema used
        for validation of this field
    is_primary - only one field in db object can be primary. This key is used
        for creating key in the DB
    """
    attrs = {
        'is_primary': is_primary,
        'schema': schema,
        'schema_in_field': schema_in_field,
    }
    # Preserve DBField's class-level default when none is supplied.
    if default_value is not None:
        attrs['default_value'] = default_value
    return type('DBFieldX', (DBField,), attrs)
class DBRelatedField(object):
    """One typed bundle of graph relations from a source DB object to
    objects of `destination_db_class`.

    Concrete subclasses are produced by `db_related_field`, which fills in
    `relation_type` and `destination_db_class`; `source_db_class` is set
    by DBObjectMeta when the owning class is created.
    """

    source_db_class = None       # filled in by DBObjectMeta
    destination_db_class = None  # filled in by db_related_field
    relation_type = None         # filled in by db_related_field

    def __init__(self, name, source_db_object):
        self.name = name
        self.source_db_object = source_db_object

    @classmethod
    def graph(self):
        """Whole-relation-type view as a networkx MultiDiGraph.

        NOTE(review): accesses r.start_node/.end_node, matching the old
        graph-DB relation objects; the redis backend yields plain dicts —
        verify before relying on this.
        """
        relations = db.get_relations(type_=self.relation_type)

        g = networkx.MultiDiGraph()

        for r in relations:
            source = self.source_db_class(**r.start_node.properties)
            dest = self.destination_db_class(**r.end_node.properties)
            properties = r.properties.copy()
            g.add_edge(
                source,
                dest,
                attr_dict=properties
            )

        return g

    def all(self):
        """All outgoing relations of this type from the source object."""
        source_db_node = self.source_db_object._db_node

        if source_db_node is None:
            return []

        return db.get_relations(source=source_db_node,
                                type_=self.relation_type)

    def all_by_dest(self, destination_db_object):
        """All relations of this type arriving at `destination_db_object`."""
        destination_db_node = destination_db_object._db_node

        if destination_db_node is None:
            return set()

        return db.get_relations(dest=destination_db_node,
                                type_=self.relation_type)

    def add(self, *destination_db_objects):
        """Connect the source object to each given destination (idempotent)."""
        for dest in destination_db_objects:
            if not isinstance(dest, self.destination_db_class):
                raise errors.SolarError(
                    'Object {} is of incompatible type {}.'.format(
                        dest, self.destination_db_class
                    )
                )

            db.get_or_create_relation(
                self.source_db_object._db_node,
                dest._db_node,
                properties={},
                type_=self.relation_type
            )

    def add_hash(self, destination_db_object, destination_key, tag=None):
        """Connect to a hash-typed destination, recording which key inside
        the hash this relation feeds (and an optional grouping tag)."""
        if not isinstance(destination_db_object, self.destination_db_class):
            raise errors.SolarError(
                'Object {} is of incompatible type {}.'.format(
                    destination_db_object, self.destination_db_class
                )
            )

        db.get_or_create_relation(
            self.source_db_object._db_node,
            destination_db_object._db_node,
            properties={'destination_key': destination_key, 'tag': tag},
            type_=self.relation_type
        )

    def remove(self, *destination_db_objects):
        """Drop the relation from the source to each given destination."""
        for dest in destination_db_objects:
            db.delete_relations(
                source=self.source_db_object._db_node,
                dest=dest._db_node,
                type_=self.relation_type
            )

    def as_set(self):
        """
        Return DB objects that are destinations for self.source_db_object.
        """
        relations = self.all()

        ret = set()

        for rel in relations:
            ret.add(
                self.destination_db_class(**rel.end_node.properties)
            )

        return ret

    def as_list(self):
        # Like as_set, but keeps duplicates and relation order.
        relations = self.all()

        ret = []

        for rel in relations:
            ret.append(
                self.destination_db_class(**rel.end_node.properties)
            )

        return ret

    def sources(self, destination_db_object):
        """
        Reverse of self.as_set, i.e. for given destination_db_object,
        return source DB objects.
        """
        relations = self.all_by_dest(destination_db_object)

        ret = set()

        for rel in relations:
            ret.add(
                self.source_db_class(**rel.start_node.properties)
            )

        return ret

    def delete_all_incoming(self,
                            destination_db_object,
                            destination_key=None,
                            tag=None):
        """
        Delete all relations for which destination_db_object is an end node.

        If object is a hash, you can additionally specify the dst_key argument.
        Then only connections that are destinations of dst_key will be deleted.

        Same with tag.
        """
        properties = {}
        if destination_key is not None:
            properties['destination_key'] = destination_key
        if tag is not None:
            properties['tag'] = tag

        db.delete_relations(
            dest=destination_db_object._db_node,
            type_=self.relation_type,
            has_properties=properties or None
        )
def db_related_field(relation_type, destination_db_class):
    """Build a DBRelatedField subclass bound to `relation_type` and
    `destination_db_class` (the owning class is wired in later by
    DBObjectMeta)."""
    return type('DBRelatedFieldX', (DBRelatedField,), {
        'relation_type': relation_type,
        'destination_db_class': destination_db_class,
    })
class DBObjectMeta(type):
    """Metaclass that collects DBField / DBRelatedField declarations into
    cls._meta and enforces exactly one primary field per class."""

    def __new__(cls, name, parents, dct):
        # `_collection` is mandatory on every concrete DBObject subclass.
        collection = dct.get('_collection')
        if not collection:
            raise NotImplementedError('Collection is required.')

        dct['_meta'] = {}
        dct['_meta']['fields'] = {}
        dct['_meta']['related_to'] = {}

        has_primary = False

        for field_name, field_klass in dct.items():
            if not inspect.isclass(field_klass):
                continue
            if issubclass(field_klass, DBField):
                dct['_meta']['fields'][field_name] = field_klass

                if field_klass.is_primary:
                    if has_primary:
                        raise errors.SolarError('Object cannot have 2 primary fields.')

                    has_primary = True

                    dct['_meta']['primary'] = field_name
            elif issubclass(field_klass, DBRelatedField):
                dct['_meta']['related_to'][field_name] = field_klass

        if not has_primary:
            raise errors.SolarError('Object needs to have a primary field.')

        klass = super(DBObjectMeta, cls).__new__(cls, name, parents, dct)

        # Support for self-references in relations
        for field_name, field_klass in klass._meta['related_to'].items():
            field_klass.source_db_class = klass

            # A destination given as a string naming this very class is
            # resolved to the class object now that it exists.
            if field_klass.destination_db_class == klass.__name__:
                field_klass.destination_db_class = klass

        return klass
class DBObject(object):
    """Base class for DB-backed documents.

    Subclasses declare DBField / DBRelatedField class attributes (collected
    into cls._meta by DBObjectMeta) and set `_collection` to the target
    collection enum value.
    """

    # Enum from BaseGraphDB.COLLECTIONS
    _collection = None

    def __init__(self, **kwargs):
        # Reject keyword arguments that do not match a declared field.
        wrong_fields = set(kwargs) - set(self._meta['fields'])
        if wrong_fields:
            raise errors.SolarError(
                'Unknown fields {}'.format(wrong_fields)
            )

        self._fields = {}

        for field_name, field_klass in self._meta['fields'].items():
            value = kwargs.get(field_name, field_klass.default_value)

            self._fields[field_name] = field_klass(field_name, value=value)

        self._related_to = {}

        for field_name, field_klass in self._meta['related_to'].items():
            inst = field_klass(field_name, self)
            self._related_to[field_name] = inst

        self._update_values()

    def __eq__(self, inst):
        # NOTE: don't compare related fields
        self._update_fields_values()
        return self._fields == inst._fields

    def __ne__(self, inst):
        return not self.__eq__(inst)

    def __hash__(self):
        return hash(self._db_key)

    def _update_fields_values(self):
        """Copy values from self to self._fields."""
        for field in self._fields.values():
            field.value = getattr(self, field.name)

    def _update_values(self):
        """
        Reverse of _update_fields_values, i.e. copy values from self._fields to
        self."""
        for field in self._fields.values():
            setattr(self, field.name, field.value)

        for field in self._related_to.values():
            setattr(self, field.name, field)

    @property
    def _db_key(self):
        """Key for the DB document (in KV-store).

        You can overwrite this with custom keys. A missing primary value
        is auto-filled with a fresh uuid on first access.
        """
        if not self._primary_field.value:
            setattr(self, self._primary_field.name, unicode(uuid.uuid4()))
        self._update_fields_values()
        return self._primary_field.value

    @property
    def _primary_field(self):
        # The DBField instance declared with is_primary=True.
        return self._fields[self._meta['primary']]

    @property
    def _db_node(self):
        """Raw DB document for this object, or None when not yet saved."""
        try:
            return db.get(self._db_key, collection=self._collection)
        except KeyError:
            return

    def validate(self):
        """Validate every field, resolving schema_in_field indirection first."""
        self._update_fields_values()
        for field in self._fields.values():
            if field.schema_in_field is not None:
                field.schema = self._fields[field.schema_in_field].value
            field.validate()

    def to_dict(self):
        """Plain dict mapping field name -> current value."""
        self._update_fields_values()
        return {
            f.name: f.value for f in self._fields.values()
        }

    @classmethod
    def load(cls, key):
        # NOTE(review): for the redis backend db.get returns a plain dict;
        # the `.properties` attribute access here presumably targets the
        # old graph-node API — verify against the active backend.
        r = db.get(key, collection=cls._collection)

        return cls(**r.properties)

    @classmethod
    def load_all(cls):
        """Instantiate every document in this class's collection."""
        rs = db.all(collection=cls._collection)

        return [cls(**r.properties) for r in rs]

    def save(self):
        """Persist the current field values under self._db_key."""
        db.create(
            self._db_key,
            properties=self.to_dict(),
            collection=self._collection
        )

    def delete(self):
        """Remove this document from its collection."""
        db.delete(
            self._db_key,
            collection=self._collection
        )
class DBResourceInput(DBObject):
    """A single input slot on a resource.

    Inputs are linked to each other with input_to_input relations so
    values propagate through connections; `backtrack_value` follows
    those links upstream to resolve the effective value.
    """

    __metaclass__ = DBObjectMeta
    _collection = base.BaseGraphDB.COLLECTIONS.input

    id = db_field(schema='str!', is_primary=True)
    name = db_field(schema='str!')
    schema = db_field()
    value = db_field(schema_in_field='schema')
    is_list = db_field(schema='bool!', default_value=False)
    is_hash = db_field(schema='bool!', default_value=False)

    receivers = db_related_field(base.BaseGraphDB.RELATION_TYPES.input_to_input,
                                 'DBResourceInput')

    @property
    def resource(self):
        """Owning DBResource, found via the resource_input relation."""
        return DBResource(
            **db.get_relations(
                dest=self._db_node,
                type_=base.BaseGraphDB.RELATION_TYPES.resource_input
            )[0].start_node.properties
        )

    def save(self):
        # Invalidate the memoized backtracked value before persisting.
        self.backtrack_value_emitter._cache_del(self.id)
        return super(DBResourceInput, self).save()

    def delete(self):
        # Drop both directions of input_to_input links before removing
        # the document itself.
        db.delete_relations(
            source=self._db_node,
            type_=base.BaseGraphDB.RELATION_TYPES.input_to_input
        )
        db.delete_relations(
            dest=self._db_node,
            type_=base.BaseGraphDB.RELATION_TYPES.input_to_input
        )
        self.backtrack_value_emitter._cache_del(self.id)
        super(DBResourceInput, self).delete()

    def edges(self):
        """Yield (source_input, dest_input, relation_properties) for every
        input_to_input edge touching this input, in both directions.

        NOTE(review): `out + incoming` requires list results; the redis
        backend's get_relations is a generator — verify against the
        active backend.
        """
        out = db.get_relations(
            source=self._db_node,
            type_=base.BaseGraphDB.RELATION_TYPES.input_to_input)
        incoming = db.get_relations(
            dest=self._db_node,
            type_=base.BaseGraphDB.RELATION_TYPES.input_to_input)
        for relation in out + incoming:
            meta = relation.properties
            source = DBResourceInput(**relation.start_node.properties)
            dest = DBResourceInput(**relation.end_node.properties)
            yield source, dest, meta

    def check_other_val(self, other_val=None):
        """Resolve `other_val` (the name of a sibling input on the same
        resource) to its backtracked value; return self when not given."""
        if not other_val:
            return self

        res = self.resource
        # TODO: needs to be refactored a lot to be more effective.
        # We don't have way of getting single input / value for given resource.
        inps = {i.name: i for i in res.inputs.as_set()}
        correct_input = inps[other_val]
        return correct_input.backtrack_value()

    @cache_me({})
    def backtrack_value_emitter(self, level=None, other_val=None):
        """Walk receiver links upstream and return the input object(s)
        that provide this input's value (not the raw value itself).

        `level` bounds the recursion depth; `other_val` redirects to a
        sibling input. Returns a DBResourceInput, a list, or a dict
        depending on is_list/is_hash.
        """
        # TODO: this is actually just fetching head element in linked list
        #       so this whole algorithm can be moved to the db backend probably
        # TODO: cycle detection?
        # TODO: write this as a Cypher query? Move to DB?
        if level is not None and other_val is not None:
            raise Exception("Not supported yet")

        if level == 0:
            return self

        def backtrack_func(i):
            if level is None:
                return i.backtrack_value_emitter(other_val=other_val)

            return i.backtrack_value_emitter(level=level - 1, other_val=other_val)

        inputs = self.receivers.sources(self)
        relations = self.receivers.all_by_dest(self)
        source_class = self.receivers.source_db_class

        # No upstream connections: this input's own value is authoritative.
        if not inputs:
            return self.check_other_val(other_val)
        # if lazy_val is None:
        #     return self.value
        # print self.resource.name
        # print [x.name for x in self.resource.inputs.as_set()]
        # _input = next(x for x in self.resource.inputs.as_set() if x.name == lazy_val)
        # return _input.backtrack_value()
        # # return self.value
        if self.is_list:
            if not self.is_hash:
                return [backtrack_func(i) for i in inputs]

            # NOTE: we return a list of values, but we need to group them
            #       hence this dict here
            # NOTE: grouping is done by resource.name by default, but this
            #       can be overwritten by the 'tag' property in relation
            ret = {}

            for r in relations:
                source = source_class(**r.start_node.properties)
                tag = r.properties['tag']
                ret.setdefault(tag, {})
                key = r.properties['destination_key']
                value = backtrack_func(source)

                ret[tag].update({key: value})

            return ret.values()
        elif self.is_hash:
            ret = self.value or {}
            for r in relations:
                source = source_class(
                    **r.start_node.properties
                )
                # NOTE: hard way to do this, what if there are more relations
                #       and some of them do have destination_key while others
                #       don't?
                if 'destination_key' not in r.properties:
                    return backtrack_func(source)
                key = r.properties['destination_key']
                ret[key] = backtrack_func(source)
            return ret

        # Scalar input: exactly one upstream source is expected.
        return backtrack_func(inputs.pop())

    def parse_backtracked_value(self, v):
        """Recursively collapse the emitter's input-object tree into
        plain values."""
        if isinstance(v, DBResourceInput):
            return v.value

        if isinstance(v, list):
            return [self.parse_backtracked_value(vv) for vv in v]

        if isinstance(v, dict):
            return {
                k: self.parse_backtracked_value(vv) for k, vv in v.items()
            }

        return v

    def backtrack_value(self, other_val=None):
        """Final resolved value of this input after following connections."""
        return self.parse_backtracked_value(self.backtrack_value_emitter(other_val=other_val))
class DBEvent(DBObject):
    """Deployment-event edge: when `parent`'s `parent_action` reaches
    `state`, trigger `child`'s `child_action`."""

    __metaclass__ = DBObjectMeta

    _collection = base.BaseGraphDB.COLLECTIONS.events

    id = db_field(is_primary=True)
    parent = db_field(schema='str!')
    parent_action = db_field(schema='str!')
    etype = db_field('str!')  # first positional argument of db_field is `schema`
    state = db_field('str')
    child = db_field('str')
    child_action = db_field('str')

    def delete(self):
        # Detach from the owning resource_events node before removing
        # the event document itself.
        db.delete_relations(
            dest=self._db_node,
            type_=base.BaseGraphDB.RELATION_TYPES.resource_event
        )
        super(DBEvent, self).delete()
class DBResourceEvents(DBObject):
    """Per-resource container node that groups that resource's DBEvents."""

    __metaclass__ = DBObjectMeta

    _collection = base.BaseGraphDB.COLLECTIONS.resource_events

    id = db_field(schema='str!', is_primary=True)
    events = db_related_field(base.BaseGraphDB.RELATION_TYPES.resource_event,
                              DBEvent)

    @classmethod
    def get_or_create(cls, name):
        """Load the events container for `name`, creating it when absent."""
        r = db.get_or_create(
            name,
            properties={'id': name},
            collection=cls._collection)

        return cls(**r.properties)
class DBCommitedState(DBObject):
    """Snapshot of a resource's last committed (deployed) state —
    inputs, connections and tags as of the last successful run."""

    __metaclass__ = DBObjectMeta

    _collection = base.BaseGraphDB.COLLECTIONS.state_data

    id = db_field(schema='str!', is_primary=True)
    inputs = db_field(schema={}, default_value={})
    connections = db_field(schema=[], default_value=[])
    base_path = db_field(schema='str')
    tags = db_field(schema=[], default_value=[])
    state = db_field(schema='str', default_value='removed')

    @classmethod
    def get_or_create(cls, name):
        """Load the committed state for `name`, creating it when absent."""
        r = db.get_or_create(
            name,
            properties={'id': name},
            collection=cls._collection)

        return cls(**r.properties)
class DBResource(DBObject):
    """DB document for a deployable resource, plus its input relations."""

    __metaclass__ = DBObjectMeta

    _collection = base.BaseGraphDB.COLLECTIONS.resource

    id = db_field(schema='str', is_primary=True)
    name = db_field(schema='str!')
    actions_path = db_field(schema='str')
    actions = db_field(schema={}, default_value={})
    base_name = db_field(schema='str')
    base_path = db_field(schema='str')
    handler = db_field(schema='str')  # one of: {'ansible_playbook', 'ansible_template', 'puppet', etc}
    puppet_module = db_field(schema='str')
    version = db_field(schema='str')
    tags = db_field(schema=[], default_value=[])
    meta_inputs = db_field(schema={}, default_value={})
    state = db_field(schema='str')

    inputs = db_related_field(base.BaseGraphDB.RELATION_TYPES.resource_input,
                              DBResourceInput)

    def add_input(self, name, schema, value):
        """Create and attach a DBResourceInput for this resource; the
        is_list/is_hash flags are derived from the schema's shape."""
        # NOTE: Inputs need to have uuid added because there can be many
        #       inputs with the same name
        uid = '{}-{}'.format(name, uuid.uuid4())
        input = DBResourceInput(id=uid,
                                name=name,
                                schema=schema,
                                value=value,
                                is_list=isinstance(schema, list),
                                is_hash=isinstance(schema, dict) or (isinstance(schema, list) and len(schema) > 0 and isinstance(schema[0], dict)))
        input.save()

        self.inputs.add(input)

    def add_event(self, action, state, etype, child, child_action):
        """Attach a new DBEvent triggered by this resource's `action`.

        NOTE(review): DBResource declares no `events` related field, so
        `self.events` below looks like it would raise AttributeError —
        this method probably belongs on DBResourceEvents; verify.
        """
        event = DBEvent(
            parent=self.name,
            parent_action=action,
            state=state,
            etype=etype,
            child=child,
            child_action=child_action
        )
        event.save()

        self.events.add(event)

    def delete(self):
        """Delete the resource together with all of its inputs."""
        for input in self.inputs.as_set():
            self.inputs.remove(input)
            input.delete()
        super(DBResource, self).delete()

    def graph(self):
        """networkx MultiDiGraph of this resource's input connections."""
        mdg = networkx.MultiDiGraph()
        for input in self.inputs.as_list():
            mdg.add_edges_from(input.edges())
        return mdg

    def add_tags(self, *tags):
        """Union the given tags into self.tags and persist."""
        self.tags = list(set(self.tags) | set(tags))
        self.save()

    def remove_tags(self, *tags):
        """Subtract the given tags from self.tags and persist."""
        self.tags = list(set(self.tags) - set(tags))
        self.save()
# TODO: remove this
if __name__ == '__main__':
    # Ad-hoc manual smoke check: `name` is declared 'str!', so validate()
    # presumably raises ValidationError for name=1 — not a real test.
    r = DBResource(name=1)
    r.validate()

View File

@ -15,7 +15,6 @@
from solar.system_log import data
from solar.dblayer.solar_models import CommitedResource
from dictdiffer import patch
from solar.interfaces import orm
from solar.core.resource import resource
from .consts import CHANGES

View File

@ -1,488 +0,0 @@
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from .base import BaseResourceTest
from solar.core import resource
from solar.core import signals
from solar import errors
from solar.interfaces import orm
from solar.interfaces.db import base
class TestORM(BaseResourceTest):
    """Unit tests for orm.DBObject's field-declaration machinery."""

    def test_no_collection_defined(self):
        with self.assertRaisesRegexp(NotImplementedError, 'Collection is required.'):
            class TestDBObject(orm.DBObject):
                __metaclass__ = orm.DBObjectMeta

    def test_has_primary(self):
        with self.assertRaisesRegexp(errors.SolarError, 'Object needs to have a primary field.'):
            class TestDBObject(orm.DBObject):
                _collection = base.BaseGraphDB.COLLECTIONS.resource
                __metaclass__ = orm.DBObjectMeta

                test1 = orm.db_field(schema='str')

    def test_no_multiple_primaries(self):
        with self.assertRaisesRegexp(errors.SolarError, 'Object cannot have 2 primary fields.'):
            class TestDBObject(orm.DBObject):
                _collection = base.BaseGraphDB.COLLECTIONS.resource
                __metaclass__ = orm.DBObjectMeta

                test1 = orm.db_field(schema='str', is_primary=True)
                test2 = orm.db_field(schema='str', is_primary=True)

    def test_primary_field(self):
        class TestDBObject(orm.DBObject):
            _collection = base.BaseGraphDB.COLLECTIONS.resource
            __metaclass__ = orm.DBObjectMeta

            test1 = orm.db_field(schema='str', is_primary=True)

        t = TestDBObject(test1='abc')
        self.assertEqual('test1', t._primary_field.name)
        self.assertEqual('abc', t._db_key)

        # A missing primary value is auto-filled (with a uuid) on key access.
        t = TestDBObject()
        self.assertIsNotNone(t._db_key)
        self.assertIsNotNone(t.test1)

    def test_default_value(self):
        class TestDBObject(orm.DBObject):
            _collection = base.BaseGraphDB.COLLECTIONS.resource
            __metaclass__ = orm.DBObjectMeta

            test1 = orm.db_field(schema='str',
                                 is_primary=True,
                                 default_value='1')

        t = TestDBObject()

        self.assertEqual('1', t.test1)

    def test_field_validation(self):
        class TestDBObject(orm.DBObject):
            _collection = base.BaseGraphDB.COLLECTIONS.resource
            __metaclass__ = orm.DBObjectMeta

            id = orm.db_field(schema='str', is_primary=True)

        t = TestDBObject(id=1)

        with self.assertRaises(errors.ValidationError):
            t.validate()

        t = TestDBObject(id='1')
        t.validate()

    def test_dynamic_schema_validation(self):
        # `value`'s schema is taken at validation time from the `schema` field.
        class TestDBObject(orm.DBObject):
            _collection = base.BaseGraphDB.COLLECTIONS.resource
            __metaclass__ = orm.DBObjectMeta

            id = orm.db_field(schema='str', is_primary=True)
            schema = orm.db_field()
            value = orm.db_field(schema_in_field='schema')

        t = TestDBObject(id='1', schema='str', value=1)

        with self.assertRaises(errors.ValidationError):
            t.validate()

        self.assertEqual(t._fields['value'].schema, t._fields['schema'].value)

        t = TestDBObject(id='1', schema='int', value=1)
        t.validate()

        self.assertEqual(t._fields['value'].schema, t._fields['schema'].value)

    def test_unknown_fields(self):
        class TestDBObject(orm.DBObject):
            _collection = base.BaseGraphDB.COLLECTIONS.resource
            __metaclass__ = orm.DBObjectMeta

            id = orm.db_field(schema='str', is_primary=True)

        with self.assertRaisesRegexp(errors.SolarError, 'Unknown fields .*iid'):
            TestDBObject(iid=1)

    def test_equality(self):
        # Objects compare by their field values, not identity.
        class TestDBObject(orm.DBObject):
            _collection = base.BaseGraphDB.COLLECTIONS.resource
            __metaclass__ = orm.DBObjectMeta

            id = orm.db_field(schema='str', is_primary=True)
            test = orm.db_field(schema='str')

        t1 = TestDBObject(id='1', test='test')

        t2 = TestDBObject(id='2', test='test')
        self.assertNotEqual(t1, t2)

        t2 = TestDBObject(id='1', test='test2')
        self.assertNotEqual(t1, t2)

        t2 = TestDBObject(id='1', test='test')
        self.assertEqual(t1, t2)
class TestORMRelation(BaseResourceTest):
    """Tests for db_related_field connections between DB objects."""

    def test_children_value(self):
        class TestDBRelatedObject(orm.DBObject):
            _collection = base.BaseGraphDB.COLLECTIONS.input
            __metaclass__ = orm.DBObjectMeta

            id = orm.db_field(schema='str', is_primary=True)

        class TestDBObject(orm.DBObject):
            _collection = base.BaseGraphDB.COLLECTIONS.resource
            __metaclass__ = orm.DBObjectMeta

            id = orm.db_field(schema='str', is_primary=True)
            related = orm.db_related_field(
                base.BaseGraphDB.RELATION_TYPES.resource_input,
                TestDBRelatedObject
            )

        r1 = TestDBRelatedObject(id='1')
        r1.save()

        r2 = TestDBRelatedObject(id='2')
        r2.save()

        o = TestDBObject(id='a')
        o.save()

        self.assertSetEqual(o.related.as_set(), set())

        # add/remove should be idempotent set-like operations.
        o.related.add(r1)
        self.assertSetEqual(o.related.as_set(), {r1})

        o.related.add(r2)
        self.assertSetEqual(o.related.as_set(), {r1, r2})

        o.related.remove(r2)
        self.assertSetEqual(o.related.as_set(), {r1})

        o.related.add(r2)
        self.assertSetEqual(o.related.as_set(), {r1, r2})

        o.related.remove(r1, r2)
        self.assertSetEqual(o.related.as_set(), set())

        o.related.add(r1, r2)
        self.assertSetEqual(o.related.as_set(), {r1, r2})

        with self.assertRaisesRegexp(errors.SolarError, '.*incompatible type.*'):
            o.related.add(o)

    def test_relation_to_self(self):
        # Destination given as the class's own name (string) must resolve
        # to the class itself.
        class TestDBObject(orm.DBObject):
            _collection = base.BaseGraphDB.COLLECTIONS.input
            __metaclass__ = orm.DBObjectMeta

            id = orm.db_field(schema='str', is_primary=True)
            related = orm.db_related_field(
                base.BaseGraphDB.RELATION_TYPES.input_to_input,
                'TestDBObject'
            )

        o1 = TestDBObject(id='1')
        o1.save()

        o2 = TestDBObject(id='2')
        o2.save()

        # NOTE(review): o3 reuses id='2' (same key as o2) — probably meant
        # id='3'; the assertions still pass because o3 == o2. Verify intent.
        o3 = TestDBObject(id='2')
        o3.save()

        o1.related.add(o2)
        o2.related.add(o3)

        self.assertEqual(o1.related.as_set(), {o2})
        self.assertEqual(o2.related.as_set(), {o3})
class TestResourceORM(BaseResourceTest):
    """Persistence tests for orm.DBResource."""

    def test_save(self):
        r = orm.DBResource(id='test1', name='test1', base_path='x')

        r.save()

        rr = resource.load(r.id)

        self.assertEqual(r, rr.db_obj)

    def test_add_input(self):
        r = orm.DBResource(id='test1', name='test1', base_path='x')
        r.save()

        r.add_input('ip', 'str!', '10.0.0.2')

        self.assertEqual(len(r.inputs.as_set()), 1)

    def test_delete_resource(self):
        # NOTE(review): creates a resource with an input but never calls
        # delete() or asserts anything — the test looks unfinished.
        r = orm.DBResource(id='test1', name='test1', base_path='x')
        r.save()

        r.add_input('ip', 'str!', '10.0.0.2')
class TestResourceInputORM(BaseResourceTest):
    """Tests for DBResourceInput.backtrack_value_emitter across scalar,
    list, hash and list-of-hash inputs."""

    def test_backtrack_simple(self):
        sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
  value:
    schema: str!
    value:
""")
        sample1 = self.create_resource(
            'sample1', sample_meta_dir, {'value': 'x'}
        )
        sample2 = self.create_resource(
            'sample2', sample_meta_dir, {'value': 'y'}
        )
        sample3 = self.create_resource(
            'sample3', sample_meta_dir, {'value': 'z'}
        )
        vi = sample2.resource_inputs()['value']

        # Unconnected input resolves to itself.
        self.assertEqual(vi.backtrack_value_emitter(), vi)

        # sample1 -> sample2
        signals.connect(sample1, sample2)
        self.assertEqual(vi.backtrack_value_emitter(),
                         sample1.resource_inputs()['value'])

        # sample3 -> sample1 -> sample2
        signals.connect(sample3, sample1)
        self.assertEqual(vi.backtrack_value_emitter(),
                         sample3.resource_inputs()['value'])

        # sample2 disconnected
        signals.disconnect(sample1, sample2)
        self.assertEqual(vi.backtrack_value_emitter(), vi)

    def test_backtrack_list(self):
        sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
  value:
    schema: str!
    value:
""")
        sample_list_meta_dir = self.make_resource_meta("""
id: sample_list
handler: ansible
version: 1.0.0
input:
  values:
    schema: [str!]
    value:
""")
        sample_list = self.create_resource(
            'sample_list', sample_list_meta_dir
        )
        vi = sample_list.resource_inputs()['values']

        sample1 = self.create_resource(
            'sample1', sample_meta_dir, {'value': 'x'}
        )
        sample2 = self.create_resource(
            'sample2', sample_meta_dir, {'value': 'y'}
        )
        sample3 = self.create_resource(
            'sample3', sample_meta_dir, {'value': 'z'}
        )

        self.assertEqual(vi.backtrack_value_emitter(), vi)

        # [sample1] -> sample_list
        signals.connect(sample1, sample_list, {'value': 'values'})
        self.assertEqual(vi.backtrack_value_emitter(),
                         [sample1.resource_inputs()['value']])

        # [sample3, sample1] -> sample_list
        signals.connect(sample3, sample_list, {'value': 'values'})
        self.assertSetEqual(set(vi.backtrack_value_emitter()),
                            set([sample1.resource_inputs()['value'],
                                 sample3.resource_inputs()['value']]))

        # sample2 disconnected
        signals.disconnect(sample1, sample_list)
        self.assertEqual(vi.backtrack_value_emitter(),
                         [sample3.resource_inputs()['value']])

    def test_backtrack_dict(self):
        sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
  value:
    schema: str!
    value:
""")
        sample_dict_meta_dir = self.make_resource_meta("""
id: sample_dict
handler: ansible
version: 1.0.0
input:
  value:
    schema: {a: str!}
    value:
""")
        sample_dict = self.create_resource(
            'sample_dict', sample_dict_meta_dir
        )
        vi = sample_dict.resource_inputs()['value']

        sample1 = self.create_resource(
            'sample1', sample_meta_dir, {'value': 'x'}
        )
        sample2 = self.create_resource(
            'sample2', sample_meta_dir, {'value': 'z'}
        )

        self.assertEqual(vi.backtrack_value_emitter(), vi)

        # {a: sample1} -> sample_dict
        signals.connect(sample1, sample_dict, {'value': 'value:a'})
        self.assertDictEqual(vi.backtrack_value_emitter(),
                             {'a': sample1.resource_inputs()['value']})

        # {a: sample2} -> sample_dict
        signals.connect(sample2, sample_dict, {'value': 'value:a'})
        self.assertDictEqual(vi.backtrack_value_emitter(),
                             {'a': sample2.resource_inputs()['value']})

        # sample2 disconnected
        signals.disconnect(sample2, sample_dict)
        self.assertEqual(vi.backtrack_value_emitter(), vi)

    def test_backtrack_dict_list(self):
        sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
  value:
    schema: str!
    value:
""")
        sample_dict_list_meta_dir = self.make_resource_meta("""
id: sample_dict_list
handler: ansible
version: 1.0.0
input:
  value:
    schema: [{a: str!}]
    value:
""")
        sample_dict_list = self.create_resource(
            'sample_dict', sample_dict_list_meta_dir
        )
        vi = sample_dict_list.resource_inputs()['value']

        sample1 = self.create_resource(
            'sample1', sample_meta_dir, {'value': 'x'}
        )
        sample2 = self.create_resource(
            'sample2', sample_meta_dir, {'value': 'y'}
        )
        sample3 = self.create_resource(
            'sample3', sample_meta_dir, {'value': 'z'}
        )

        self.assertEqual(vi.backtrack_value_emitter(), vi)

        # [{a: sample1}] -> sample_dict_list
        signals.connect(sample1, sample_dict_list, {'value': 'value:a'})
        self.assertListEqual(vi.backtrack_value_emitter(),
                             [{'a': sample1.resource_inputs()['value']}])

        # [{a: sample1}, {a: sample3}] -> sample_dict_list
        signals.connect(sample3, sample_dict_list, {'value': 'value:a'})
        self.assertItemsEqual(vi.backtrack_value_emitter(),
                              [{'a': sample1.resource_inputs()['value']},
                               {'a': sample3.resource_inputs()['value']}])

        # [{a: sample1}, {a: sample2}] -> sample_dict_list
        # (the '|sample3' tag groups sample2 into sample3's slot)
        signals.connect(sample2, sample_dict_list, {'value': 'value:a|sample3'})
        self.assertItemsEqual(vi.backtrack_value_emitter(),
                              [{'a': sample1.resource_inputs()['value']},
                               {'a': sample2.resource_inputs()['value']}])

        # sample2 disconnected
        signals.disconnect(sample2, sample_dict_list)
        self.assertEqual(vi.backtrack_value_emitter(),
                         [{'a': sample1.resource_inputs()['value']}])
class TestEventORM(BaseResourceTest):
    """Persistence tests for orm.DBEvent / orm.DBResourceEvents."""

    def test_return_emtpy_set(self):
        r = orm.DBResourceEvents(id='test1')
        r.save()

        self.assertEqual(r.events.as_set(), set())

    def test_save_and_load_by_parent(self):
        ev = orm.DBEvent(
            parent='n1',
            parent_action='run',
            state='success',
            child_action='run',
            child='n2',
            etype='dependency')
        ev.save()

        rst = orm.DBEvent.load(ev.id)
        self.assertEqual(rst, ev)

    def test_save_several(self):
        ev = orm.DBEvent(
            parent='n1',
            parent_action='run',
            state='success',
            child_action='run',
            child='n2',
            etype='dependency')
        ev.save()

        ev1 = orm.DBEvent(
            parent='n1',
            parent_action='run',
            state='success',
            child_action='run',
            child='n3',
            etype='dependency')
        ev1.save()

        self.assertEqual(len(orm.DBEvent.load_all()), 2)

    def test_removal_of_event(self):
        r = orm.DBResourceEvents(id='test1')
        r.save()

        ev = orm.DBEvent(
            parent='test1',
            parent_action='run',
            state='success',
            child_action='run',
            child='n2',
            etype='dependency')
        ev.save()
        r.events.add(ev)

        self.assertEqual(r.events.as_set(), {ev})

        # Deleting the event must also detach it from its container.
        ev.delete()

        r = orm.DBResourceEvents.load('test1')
        self.assertEqual(r.events.as_set(), set())