Removed old db implementation

This commit is contained in:
Jedrzej Nowak 2015-11-17 13:01:21 +01:00
parent 7fa43cc6bc
commit 1807219376
24 changed files with 44 additions and 2343 deletions

19
examples/bootstrap/example-bootstrap.py Normal file → Executable file
View File

@ -10,11 +10,7 @@ from solar.core import signals
from solar.core import validation from solar.core import validation
from solar.core.resource import virtual_resource as vr from solar.core.resource import virtual_resource as vr
from solar import errors from solar import errors
from solar.dblayer.model import ModelMeta
from solar.interfaces.db import get_db
db = get_db()
@click.group() @click.group()
@ -23,9 +19,7 @@ def main():
def setup_resources(): def setup_resources():
db.clear() ModelMeta.remove_all()
signals.Connections.clear()
node2 = vr.create('node2', 'resources/ro_node/', { node2 = vr.create('node2', 'resources/ro_node/', {
'ip': '10.0.0.4', 'ip': '10.0.0.4',
@ -61,7 +55,7 @@ def deploy():
setup_resources() setup_resources()
# run # run
resources = map(resource.wrap_resource, db.get_list(collection=db.COLLECTIONS.resource)) resources = resource.load_all()
resources = {r.name: r for r in resources} resources = {r.name: r for r in resources}
for name in resources_to_run: for name in resources_to_run:
@ -76,7 +70,7 @@ def deploy():
@click.command() @click.command()
def undeploy(): def undeploy():
resources = map(resource.wrap_resource, db.get_list(collection=db.COLLECTIONS.resource)) resources = resource.load_all()
resources = {r.name: r for r in resources} resources = {r.name: r for r in resources}
for name in reversed(resources_to_run): for name in reversed(resources_to_run):
@ -85,10 +79,7 @@ def undeploy():
except errors.SolarError as e: except errors.SolarError as e:
print 'WARNING: %s' % str(e) print 'WARNING: %s' % str(e)
db.clear() ModelMeta.remove_all()
signals.Connections.clear()
main.add_command(deploy) main.add_command(deploy)
main.add_command(undeploy) main.add_command(undeploy)

View File

@ -19,11 +19,9 @@ from solar.core import actions
from solar.core.resource import virtual_resource as vr from solar.core.resource import virtual_resource as vr
from solar.core import resource from solar.core import resource
from solar.core import signals from solar.core import signals
from solar.dblayer.model import ModelMeta
from solar.interfaces.db import get_db
from solar.core.resource_provider import GitProvider, RemoteZipProvider from solar.core.resource_provider import GitProvider, RemoteZipProvider
import resources_compiled import resources_compiled
@ -34,9 +32,7 @@ def main():
@click.command() @click.command()
def deploy(): def deploy():
db = get_db() ModelMeta.remove_all()
db.clear()
signals.Connections.clear() signals.Connections.clear()
node1 = resources_compiled.RoNodeResource('node1', None, {}) node1 = resources_compiled.RoNodeResource('node1', None, {})
@ -75,18 +71,16 @@ def deploy():
@click.command() @click.command()
def undeploy(): def undeploy():
db = get_db() ModelMeta.remove_all()
resources = map(resource.wrap_resource, db.get_list(collection=db.COLLECTIONS.resource)) resources = resource.load_all()
resources = {r.name: r for r in resources} resources = {r.name: r for r in resources}
actions.resource_action(resources['openstack_rabbitmq_user'], 'remove') actions.resource_action(resources['openstack_rabbitmq_user'], 'remove')
actions.resource_action(resources['openstack_vhost'], 'remove') actions.resource_action(resources['openstack_vhost'], 'remove')
actions.resource_action(resources['rabbitmq_service1'], 'remove') actions.resource_action(resources['rabbitmq_service1'], 'remove')
db.clear() ModelMeta.remove_all()
signals.Connections.clear()
main.add_command(deploy) main.add_command(deploy)

View File

@ -4,15 +4,11 @@ import time
from solar.core import signals from solar.core import signals
from solar.core.resource import virtual_resource as vr from solar.core.resource import virtual_resource as vr
from solar.dblayer.model import ModelMeta
from solar.interfaces.db import get_db
db = get_db()
def run(): def run():
db.clear() ModelMeta.remove_all()
resources = vr.create('nodes', 'templates/nodes_with_transports.yaml', {'count': 2}) resources = vr.create('nodes', 'templates/nodes_with_transports.yaml', {'count': 2})
nodes = [x for x in resources if x.name.startswith('node')] nodes = [x for x in resources if x.name.startswith('node')]

View File

@ -1,10 +1,8 @@
from solar.core.resource import virtual_resource as vr from solar.core.resource import virtual_resource as vr
from solar.interfaces.db import get_db from solar.dblayer.model import ModelMeta
import yaml import yaml
db = get_db()
STORAGE = {'objects_ceph': True, STORAGE = {'objects_ceph': True,
'osd_pool_size': 2, 'osd_pool_size': 2,
@ -34,7 +32,7 @@ NETWORK_METADATA = yaml.load("""
def deploy(): def deploy():
db.clear() ModelMeta.remove_all()
resources = vr.create('nodes', 'templates/nodes.yaml', {'count': 2}) resources = vr.create('nodes', 'templates/nodes.yaml', {'count': 2})
first_node, second_node = [x for x in resources if x.name.startswith('node')] first_node, second_node = [x for x in resources if x.name.startswith('node')]
first_transp = next(x for x in resources if x.name.startswith('transport')) first_transp = next(x for x in resources if x.name.startswith('transport'))

8
examples/lxc/example-lxc.py Normal file → Executable file
View File

@ -12,10 +12,10 @@ import click
from solar.core import signals from solar.core import signals
from solar.core.resource import virtual_resource as vr from solar.core.resource import virtual_resource as vr
from solar.interfaces.db import get_db
from solar.system_log import change from solar.system_log import change
from solar.cli import orch from solar.cli import orch
from solar.dblayer.model import ModelMeta
@click.group() @click.group()
def main(): def main():
@ -43,9 +43,7 @@ def lxc_template(idx):
@click.command() @click.command()
def deploy(): def deploy():
db = get_db() ModelMeta.remove_all()
db.clear()
signals.Connections.clear()
node1 = vr.create('nodes', 'templates/nodes.yaml', {})[0] node1 = vr.create('nodes', 'templates/nodes.yaml', {})[0]
seed = vr.create('nodes', 'templates/seed_node.yaml', {})[0] seed = vr.create('nodes', 'templates/seed_node.yaml', {})[0]

View File

@ -8,9 +8,7 @@ from solar.core import signals
from solar.core import validation from solar.core import validation
from solar.core.resource import virtual_resource as vr from solar.core.resource import virtual_resource as vr
from solar import events as evapi from solar import events as evapi
from solar.dblayer.model import ModelMeta
from solar.interfaces.db import get_db
PROFILE = False PROFILE = False
#PROFILE = True #PROFILE = True
@ -35,8 +33,6 @@ if PROFILE:
# Official puppet manifests, not fuel-library # Official puppet manifests, not fuel-library
db = get_db()
@click.group() @click.group()
def main(): def main():
@ -830,7 +826,7 @@ def create_compute(node):
@click.command() @click.command()
def create_all(): def create_all():
db.clear() ModelMeta.remove_all()
r = prepare_nodes(2) r = prepare_nodes(2)
r.update(create_controller('node0')) r.update(create_controller('node0'))
r.update(create_compute('node1')) r.update(create_compute('node1'))
@ -856,7 +852,7 @@ def add_controller(node):
@click.command() @click.command()
def clear(): def clear():
db.clear() ModelMeta.remove_all()
if __name__ == '__main__': if __name__ == '__main__':

7
examples/riak/riaks-template.py Normal file → Executable file
View File

@ -8,16 +8,13 @@ import click
import sys import sys
from solar.core import resource from solar.core import resource
from solar.interfaces.db import get_db
from solar import template from solar import template
from solar.dblayer.model import ModelMeta
db = get_db()
def setup_riak(): def setup_riak():
db.clear()
ModelMeta.remove_all()
nodes = template.nodes_from('templates/riak_nodes.yaml') nodes = template.nodes_from('templates/riak_nodes.yaml')
riak_services = nodes.on_each( riak_services = nodes.on_each(

View File

@ -22,18 +22,13 @@ from solar import errors
from solar.dblayer.model import ModelMeta from solar.dblayer.model import ModelMeta
from solar.interfaces.db import get_db
from solar.events.controls import React, Dep from solar.events.controls import React, Dep
from solar.events.api import add_event from solar.events.api import add_event
from solar.dblayer.solar_models import Resource from solar.dblayer.solar_models import Resource
# db = get_db()
def setup_riak(): def setup_riak():
# db.clear()
ModelMeta.remove_all() ModelMeta.remove_all()
resources = vr.create('nodes', 'templates/nodes.yaml', {'count': 3}) resources = vr.create('nodes', 'templates/nodes.yaml', {'count': 3})

View File

@ -5,16 +5,11 @@ import time
from solar.core import resource from solar.core import resource
from solar.core import signals from solar.core import signals
from solar.core.resource import virtual_resource as vr from solar.core.resource import virtual_resource as vr
from solar.dblayer.model import ModelMeta
from solar.interfaces.db import get_db
db = get_db()
def run(): def run():
db.clear() ModelMeta.remove_all()
node = vr.create('node', 'resources/ro_node', {'name': 'first' + str(time.time()), node = vr.create('node', 'resources/ro_node', {'name': 'first' + str(time.time()),
'ip': '10.0.0.3', 'ip': '10.0.0.3',

View File

@ -2,15 +2,11 @@ import time
from solar.core.resource import virtual_resource as vr from solar.core.resource import virtual_resource as vr
from solar import errors from solar import errors
from solar.dblayer.model import ModelMeta
from solar.interfaces.db import get_db
db = get_db()
def run(): def run():
db.clear() ModelMeta.remove_all()
node = vr.create('node', 'resources/ro_node', {'name': 'first' + str(time.time()), node = vr.create('node', 'resources/ro_node', {'name': 'first' + str(time.time()),
'ip': '10.0.0.3', 'ip': '10.0.0.3',

View File

@ -34,7 +34,6 @@ from solar.core.tags_set_parser import Expression
from solar.core.resource import virtual_resource as vr from solar.core.resource import virtual_resource as vr
from solar.core.log import log from solar.core.log import log
from solar import errors from solar import errors
from solar.interfaces import orm
from solar import utils from solar import utils
from solar.cli import base from solar.cli import base
@ -78,25 +77,26 @@ def init_actions():
@click.option('-d', '--dry-run', default=False, is_flag=True) @click.option('-d', '--dry-run', default=False, is_flag=True)
@click.option('-m', '--dry-run-mapping', default='{}') @click.option('-m', '--dry-run-mapping', default='{}')
def run(dry_run_mapping, dry_run, action, tags): def run(dry_run_mapping, dry_run, action, tags):
if dry_run: raise NotImplementedError("Not yet implemented")
dry_run_executor = executors.DryRunExecutor(mapping=json.loads(dry_run_mapping)) # if dry_run:
# dry_run_executor = executors.DryRunExecutor(mapping=json.loads(dry_run_mapping))
resources = filter( # resources = filter(
lambda r: Expression(tags, r.tags).evaluate(), # lambda r: Expression(tags, r.tags).evaluate(),
orm.DBResource.all() # orm.DBResource.all()
) # )
for r in resources: # for r in resources:
resource_obj = sresource.load(r['id']) # resource_obj = sresource.load(r['id'])
actions.resource_action(resource_obj, action) # actions.resource_action(resource_obj, action)
if dry_run: # if dry_run:
click.echo('EXECUTED:') # click.echo('EXECUTED:')
for key in dry_run_executor.executed: # for key in dry_run_executor.executed:
click.echo('{}: {}'.format( # click.echo('{}: {}'.format(
click.style(dry_run_executor.compute_hash(key), fg='green'), # click.style(dry_run_executor.compute_hash(key), fg='green'),
str(key) # str(key)
)) # ))
def init_cli_connect(): def init_cli_connect():

View File

@ -25,7 +25,6 @@ from solar.core import resource as sresource
from solar.core.resource import virtual_resource as vr from solar.core.resource import virtual_resource as vr
from solar.core.log import log from solar.core.log import log
from solar import errors from solar import errors
from solar.interfaces import orm
from solar import utils from solar import utils
from solar.cli import executors from solar.cli import executors
@ -120,7 +119,6 @@ def clear_all():
click.echo('Clearing all resources and connections') click.echo('Clearing all resources and connections')
ModelMeta.remove_all() ModelMeta.remove_all()
# orm.db.clear()
@resource.command() @resource.command()
@click.argument('name') @click.argument('name')

View File

@ -22,7 +22,6 @@ import os
from solar import utils from solar import utils
from solar.core import validation from solar.core import validation
from solar.interfaces import orm
from solar.core import signals from solar.core import signals
from solar.events import api from solar.events import api

View File

@ -16,7 +16,6 @@
import networkx import networkx
from solar.core.log import log from solar.core.log import log
# from solar.interfaces import orm
from solar.dblayer.solar_models import Resource as DBResource from solar.dblayer.solar_models import Resource as DBResource
@ -143,91 +142,6 @@ def get_mapping(emitter, receiver, mapping=None):
def connect(emitter, receiver, mapping=None): def connect(emitter, receiver, mapping=None):
emitter.connect(receiver, mapping) emitter.connect(receiver, mapping)
# def connect(emitter, receiver, mapping=None):
# if mapping is None:
# mapping = guess_mapping(emitter, receiver)
# # XXX: we didn't agree on that "reverse" there
# location_and_transports(emitter, receiver, mapping)
# if isinstance(mapping, set):
# mapping = {src: src for src in mapping}
# for src, dst in mapping.items():
# if not isinstance(dst, list):
# dst = [dst]
# for d in dst:
# connect_single(emitter, src, receiver, d)
# def connect_single(emitter, src, receiver, dst):
# if ':' in dst:
# return connect_multi(emitter, src, receiver, dst)
# # Disconnect all receiver inputs
# # Check if receiver input is of list type first
# emitter_input = emitter.resource_inputs()[src]
# receiver_input = receiver.resource_inputs()[dst]
# if emitter_input.id == receiver_input.id:
# raise Exception(
# 'Trying to connect {} to itself, this is not possible'.format(
# emitter_input.id)
# )
# if not receiver_input.is_list:
# receiver_input.receivers.delete_all_incoming(receiver_input)
# # Check for cycles
# # TODO: change to get_paths after it is implemented in drivers
# if emitter_input in receiver_input.receivers.as_set():
# raise Exception('Prevented creating a cycle on %s::%s' % (emitter.name,
# emitter_input.name))
# log.debug('Connecting {}::{} -> {}::{}'.format(
# emitter.name, emitter_input.name, receiver.name, receiver_input.name
# ))
# emitter_input.receivers.add(receiver_input)
# def connect_multi(emitter, src, receiver, dst):
# receiver_input_name, receiver_input_key = dst.split(':')
# if '|' in receiver_input_key:
# receiver_input_key, receiver_input_tag = receiver_input_key.split('|')
# else:
# receiver_input_tag = None
# emitter_input = emitter.resource_inputs()[src]
# receiver_input = receiver.resource_inputs()[receiver_input_name]
# if not receiver_input.is_list or receiver_input_tag:
# receiver_input.receivers.delete_all_incoming(
# receiver_input,
# destination_key=receiver_input_key,
# tag=receiver_input_tag
# )
# # We can add default tag now
# receiver_input_tag = receiver_input_tag or emitter.name
# # NOTE: make sure that receiver.args[receiver_input] is of dict type
# if not receiver_input.is_hash:
# raise Exception(
# 'Receiver input {} must be a hash or a list of hashes'.format(receiver_input_name)
# )
# log.debug('Connecting {}::{} -> {}::{}[{}], tag={}'.format(
# emitter.name, emitter_input.name, receiver.name, receiver_input.name,
# receiver_input_key,
# receiver_input_tag
# ))
# emitter_input.receivers.add_hash(
# receiver_input,
# receiver_input_key,
# tag=receiver_input_tag
# )
def disconnect_receiver_by_input(receiver, input_name): def disconnect_receiver_by_input(receiver, input_name):
# input_node = receiver.resource_inputs()[input_name] # input_node = receiver.resource_inputs()[input_name]
@ -236,12 +150,6 @@ def disconnect_receiver_by_input(receiver, input_name):
receiver.db_obj.inputs.disconnect(input_name) receiver.db_obj.inputs.disconnect(input_name)
# def disconnect(emitter, receiver):
# for emitter_input in emitter.resource_inputs().values():
# for receiver_input in receiver.resource_inputs().values():
# emitter_input.receivers.remove(receiver_input)
def detailed_connection_graph(start_with=None, end_with=None, details=False): def detailed_connection_graph(start_with=None, end_with=None, details=False):
from solar.core.resource import Resource, load_all from solar.core.resource import Resource, load_all

View File

@ -18,7 +18,6 @@ __all__ = ['add_dep', 'add_react', 'Dep', 'React', 'add_event']
import networkx as nx import networkx as nx
from solar.core.log import log from solar.core.log import log
from solar.interfaces import orm
from solar.events.controls import Dep, React, StateChange from solar.events.controls import Dep, React, StateChange
from solar.dblayer.solar_models import Resource from solar.dblayer.solar_models import Resource

View File

@ -1,38 +0,0 @@
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import importlib
db_backends = {
'neo4j_db': ('solar.interfaces.db.neo4j', 'Neo4jDB'),
'redis_db': ('solar.interfaces.db.redis_db', 'RedisDB'),
'fakeredis_db': ('solar.interfaces.db.redis_db', 'FakeRedisDB'),
'redis_graph_db': ('solar.interfaces.db.redis_graph_db', 'RedisGraphDB'),
'fakeredis_graph_db': ('solar.interfaces.db.redis_graph_db', 'FakeRedisGraphDB'),
}
CURRENT_DB = 'redis_graph_db'
#CURRENT_DB = 'neo4j_db'
DB = None
def get_db(backend=CURRENT_DB):
# Should be retrieved from config
global DB
if DB is None:
import_path, klass = db_backends[backend]
module = importlib.import_module(import_path)
DB = getattr(module, klass)()
return DB

View File

@ -1,236 +0,0 @@
import abc
from enum import Enum
from functools import partial
class Node(object):
def __init__(self, db, uid, labels, properties):
self.db = db
self.uid = uid
self.labels = labels
self.properties = properties
@property
def collection(self):
return getattr(
BaseGraphDB.COLLECTIONS,
list(self.labels)[0]
)
class Relation(object):
def __init__(self, db, start_node, end_node, properties):
self.db = db
self.start_node = start_node
self.end_node = end_node
self.properties = properties
class DBObjectMeta(abc.ABCMeta):
# Tuples of: function name, is-multi (i.e. returns a list)
node_db_read_methods = [
('all', True),
('create', False),
('get', False),
('get_or_create', False),
]
relation_db_read_methods = [
('all_relations', True),
('create_relation', False),
('get_relations', True),
('get_relation', False),
('get_or_create_relation', False),
]
def __new__(cls, name, parents, dct):
def from_db_list_decorator(converting_func, method):
def wrapper(self, *args, **kwargs):
db_convert = kwargs.pop('db_convert', True)
result = method(self, *args, **kwargs)
if db_convert:
return map(partial(converting_func, self), result)
return result
return wrapper
def from_db_decorator(converting_func, method):
def wrapper(self, *args, **kwargs):
db_convert = kwargs.pop('db_convert', True)
result = method(self, *args, **kwargs)
if result is None:
return
if db_convert:
return converting_func(self, result)
return result
return wrapper
node_db_to_object = cls.find_method(
'node_db_to_object', name, parents, dct
)
relation_db_to_object = cls.find_method(
'relation_db_to_object', name, parents, dct
)
# Node conversions
for method_name, is_list in cls.node_db_read_methods:
method = cls.find_method(method_name, name, parents, dct)
if is_list:
func = from_db_list_decorator
else:
func = from_db_decorator
# Handle subclasses
if not getattr(method, '_wrapped', None):
dct[method_name] = func(node_db_to_object, method)
setattr(dct[method_name], '_wrapped', True)
# Relation conversions
for method_name, is_list in cls.relation_db_read_methods:
method = cls.find_method(method_name, name, parents, dct)
if is_list:
func = from_db_list_decorator
else:
func = from_db_decorator
# Handle subclasses
if not getattr(method, '_wrapped', None):
dct[method_name] = func(relation_db_to_object, method)
setattr(dct[method_name], '_wrapped', True)
return super(DBObjectMeta, cls).__new__(cls, name, parents, dct)
@classmethod
def find_method(cls, method_name, class_name, parents, dict):
if method_name in dict:
return dict[method_name]
for parent in parents:
method = getattr(parent, method_name)
if method:
return method
raise NotImplementedError(
'"{}" method not implemented in class {}'.format(
method_name, class_name
)
)
class BaseGraphDB(object):
__metaclass__ = DBObjectMeta
COLLECTIONS = Enum(
'Collections',
'input resource state_data state_log plan_node plan_graph events stage_log commit_log resource_events'
)
DEFAULT_COLLECTION=COLLECTIONS.resource
RELATION_TYPES = Enum(
'RelationTypes',
'input_to_input resource_input plan_edge graph_to_node resource_event commited'
)
DEFAULT_RELATION=RELATION_TYPES.resource_input
@staticmethod
def node_db_to_object(node_db):
"""Convert node DB object to Node object."""
@staticmethod
def object_to_node_db(node_obj):
"""Convert Node object to node DB object."""
@staticmethod
def relation_db_to_object(relation_db):
"""Convert relation DB object to Relation object."""
@staticmethod
def object_to_relation_db(relation_obj):
"""Convert Relation object to relation DB object."""
@abc.abstractmethod
def all(self, collection=DEFAULT_COLLECTION):
"""Return all elements (nodes) of type `collection`."""
@abc.abstractmethod
def all_relations(self, type_=DEFAULT_RELATION):
"""Return all relations of type `type_`."""
@abc.abstractmethod
def clear(self):
"""Clear the whole DB."""
@abc.abstractmethod
def clear_collection(self, collection=DEFAULT_COLLECTION):
"""Clear all elements (nodes) of type `collection`."""
@abc.abstractmethod
def create(self, name, properties={}, collection=DEFAULT_COLLECTION):
"""Create element (node) with given name, args, of type `collection`."""
@abc.abstractmethod
def delete(self, name, collection=DEFAULT_COLLECTION):
"""Delete element with given name. of type `collection`."""
@abc.abstractmethod
def create_relation(self,
source,
dest,
properties={},
type_=DEFAULT_RELATION):
"""
Create relation (connection) of type `type_` from source to dest with
given args.
"""
@abc.abstractmethod
def get(self, name, collection=DEFAULT_COLLECTION):
"""Fetch element with given name and collection type."""
@abc.abstractmethod
def get_or_create(self,
name,
properties={},
collection=DEFAULT_COLLECTION):
"""
Fetch or create element (if not exists) with given name, args of type
`collection`.
"""
@abc.abstractmethod
def delete_relations(self,
source=None,
dest=None,
type_=DEFAULT_RELATION,
has_properties=None):
"""Delete all relations of type `type_` from source to dest."""
@abc.abstractmethod
def get_relations(self,
source=None,
dest=None,
type_=DEFAULT_RELATION,
has_properties=None):
"""Fetch all relations of type `type_` from source to dest.
NOTE that this function must return only direct relations (edges)
between vertices `source` and `dest` of type `type_`.
If you want all PATHS between `source` and `dest`, write another
method for this (`get_paths`)."""
@abc.abstractmethod
def get_relation(self, source, dest, type_=DEFAULT_RELATION):
"""Fetch relations with given source, dest and type_."""
@abc.abstractmethod
def get_or_create_relation(self,
source,
dest,
properties={},
type_=DEFAULT_RELATION):
"""Fetch or create relation with given args."""

View File

@ -1,205 +0,0 @@
import json
from copy import deepcopy
import py2neo
from solar.core import log
from .base import BaseGraphDB, Node, Relation
class Neo4jDB(BaseGraphDB):
DB = {
'host': 'localhost',
'port': 7474,
}
NEO4J_CLIENT = py2neo.Graph
def __init__(self):
self._r = self.NEO4J_CLIENT('http://{host}:{port}/db/data/'.format(
**self.DB
))
def node_db_to_object(self, node_db):
return Node(
self,
node_db.properties['name'],
node_db.labels,
# Neo4j Node.properties is some strange PropertySet, use dict instead
dict(**node_db.properties)
)
def relation_db_to_object(self, relation_db):
return Relation(
self,
self.node_db_to_object(relation_db.start_node),
self.node_db_to_object(relation_db.end_node),
relation_db.properties
)
def all(self, collection=BaseGraphDB.DEFAULT_COLLECTION):
return [
r.n for r in self._r.cypher.execute(
'MATCH (n:%(collection)s) RETURN n' % {
'collection': collection.name,
}
)
]
def all_relations(self, type_=BaseGraphDB.DEFAULT_RELATION):
return [
r.r for r in self._r.cypher.execute(
*self._relations_query(
source=None, dest=None, type_=type_
)
)
]
def clear(self):
log.log.debug('Clearing whole DB')
self._r.delete_all()
def clear_collection(self, collection=BaseGraphDB.DEFAULT_COLLECTION):
log.log.debug('Clearing collection %s', collection.name)
# TODO: make single DELETE query
self._r.delete([r.n for r in self.all(collection=collection)])
def create(self, name, properties={}, collection=BaseGraphDB.DEFAULT_COLLECTION):
log.log.debug(
'Creating %s, name %s with properties %s',
collection.name,
name,
properties
)
properties = deepcopy(properties)
properties['name'] = name
n = py2neo.Node(collection.name, **properties)
return self._r.create(n)[0]
def create_relation(self,
source,
dest,
properties={},
type_=BaseGraphDB.DEFAULT_RELATION):
log.log.debug(
'Creating %s from %s to %s with properties %s',
type_.name,
source.properties['name'],
dest.properties['name'],
properties
)
s = self.get(
source.properties['name'],
collection=source.collection,
db_convert=False
)
d = self.get(
dest.properties['name'],
collection=dest.collection,
db_convert=False
)
r = py2neo.Relationship(s, type_.name, d, **properties)
self._r.create(r)
return r
def _get_query(self, name, collection=BaseGraphDB.DEFAULT_COLLECTION):
return 'MATCH (n:%(collection)s {name:{name}}) RETURN n' % {
'collection': collection.name,
}, {
'name': name,
}
def get(self, name, collection=BaseGraphDB.DEFAULT_COLLECTION):
query, kwargs = self._get_query(name, collection=collection)
res = self._r.cypher.execute(query, kwargs)
if res:
return res[0].n
def get_or_create(self,
name,
properties={},
collection=BaseGraphDB.DEFAULT_COLLECTION):
n = self.get(name, collection=collection, db_convert=False)
if n:
if properties != n.properties:
n.properties.update(properties)
n.push()
return n
return self.create(name, properties=properties, collection=collection)
def _relations_query(self,
source=None,
dest=None,
type_=BaseGraphDB.DEFAULT_RELATION,
query_type='RETURN'):
kwargs = {}
source_query = '(n)'
if source:
source_query = '(n {name:{source_name}})'
kwargs['source_name'] = source.properties['name']
dest_query = '(m)'
if dest:
dest_query = '(m {name:{dest_name}})'
kwargs['dest_name'] = dest.properties['name']
rel_query = '[r:%(type_)s]' % {'type_': type_.name}
query = ('MATCH %(source_query)s-%(rel_query)s->'
'%(dest_query)s %(query_type)s r' % {
'dest_query': dest_query,
'query_type': query_type,
'rel_query': rel_query,
'source_query': source_query,
})
return query, kwargs
def delete_relations(self,
source=None,
dest=None,
type_=BaseGraphDB.DEFAULT_RELATION):
query, kwargs = self._relations_query(
source=source, dest=dest, type_=type_, query_type='DELETE'
)
self._r.cypher.execute(query, kwargs)
def get_relations(self,
source=None,
dest=None,
type_=BaseGraphDB.DEFAULT_RELATION):
query, kwargs = self._relations_query(
source=source, dest=dest, type_=type_
)
res = self._r.cypher.execute(query, kwargs)
return [r.r for r in res]
def get_relation(self, source, dest, type_=BaseGraphDB.DEFAULT_RELATION):
rel = self.get_relations(source=source, dest=dest, type_=type_)
if rel:
return rel[0]
def get_or_create_relation(self,
source,
dest,
properties={},
type_=BaseGraphDB.DEFAULT_RELATION):
rel = self.get_relations(source=source, dest=dest, type_=type_)
if rel:
r = rel[0]
if properties != r.properties:
r.properties.update(properties)
r.push()
return r
return self.create_relation(source, dest, properties=properties, type_=type_)

View File

@ -1,156 +0,0 @@
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from enum import Enum
try:
import ujson as json
except ImportError:
import json
import redis
import fakeredis
class RedisDB(object):
COLLECTIONS = Enum(
'Collections',
'connection resource state_data state_log events'
)
DB = {
'host': 'localhost',
'port': 6379,
}
REDIS_CLIENT = redis.StrictRedis
def __init__(self):
self._r = self.REDIS_CLIENT(**self.DB)
self.entities = {}
def read(self, uid, collection=COLLECTIONS.resource):
try:
return json.loads(
self._r.get(self._make_key(collection, uid))
)
except TypeError:
return None
def get_list(self, collection=COLLECTIONS.resource):
key_glob = self._make_key(collection, '*')
keys = self._r.keys(key_glob)
with self._r.pipeline() as pipe:
pipe.multi()
values = [self._r.get(key) for key in keys]
pipe.execute()
for value in values:
yield json.loads(value)
def save(self, uid, data, collection=COLLECTIONS.resource):
ret = self._r.set(
self._make_key(collection, uid),
json.dumps(data)
)
return ret
def save_list(self, lst, collection=COLLECTIONS.resource):
with self._r.pipeline() as pipe:
pipe.multi()
for uid, data in lst:
key = self._make_key(collection, uid)
pipe.set(key, json.dumps(data))
pipe.execute()
def clear(self):
self._r.flushdb()
def get_ordered_hash(self, collection):
return OrderedHash(self._r, collection)
def clear_collection(self, collection=COLLECTIONS.resource):
key_glob = self._make_key(collection, '*')
self._r.delete(self._r.keys(key_glob))
def delete(self, uid, collection=COLLECTIONS.resource):
self._r.delete(self._make_key(collection, uid))
def _make_key(self, collection, _id):
if isinstance(collection, self.COLLECTIONS):
collection = collection.name
# NOTE: hiera-redis backend depends on this!
return '{0}:{1}'.format(collection, _id)
class OrderedHash(object):
def __init__(self, client, collection):
self.r = client
self.collection = collection
self.order_counter = '{}:incr'.format(collection)
self.order = '{}:order'.format(collection)
def add(self, items):
pipe = self.r.pipeline()
for key, value in items:
count = self.r.incr(self.order_counter)
pipe.zadd(self.order, count, key)
pipe.hset(self.collection, key, json.dumps(value))
pipe.execute()
def rem(self, keys):
pipe = self.r.pipeline()
for key in keys:
pipe.zrem(self.order, key)
pipe.hdel(self.collection, key)
pipe.execute()
def get(self, key):
value = self.r.hget(self.collection, key)
if value:
return json.loads(value)
return None
def update(self, key, value):
self.r.hset(self.collection, key, json.dumps(value))
def clean(self):
self.rem(self.r.zrange(self.order, 0, -1))
def rem_left(self, n=1):
self.rem(self.r.zrevrange(self.order, 0, n-1))
def reverse(self, n=1):
result = []
for key in self.r.zrevrange(self.order, 0, n-1):
result.append(self.get(key))
return result
def list(self, n=0):
result = []
for key in self.r.zrange(self.order, 0, n-1):
result.append(self.get(key))
return result
class FakeRedisDB(RedisDB):
REDIS_CLIENT = fakeredis.FakeStrictRedis

View File

@ -1,300 +0,0 @@
try:
import ujson as json
except ImportError:
import json
import redis
import fakeredis
from .base import BaseGraphDB, Node, Relation
from .redis_db import OrderedHash
class RedisGraphDB(BaseGraphDB):
DB = {
'host': 'localhost',
'port': 6379,
}
REDIS_CLIENT = redis.StrictRedis
def __init__(self):
self._r = self.REDIS_CLIENT(**self.DB)
self.entities = {}
def node_db_to_object(self, node_db):
if isinstance(node_db, Node):
return node_db
return Node(
self,
node_db['name'],
[node_db['collection']],
node_db['properties']
)
def relation_db_to_object(self, relation_db):
if isinstance(relation_db, Relation):
return relation_db
if relation_db['type_'] == BaseGraphDB.RELATION_TYPES.input_to_input.name:
source_collection = BaseGraphDB.COLLECTIONS.input
dest_collection = BaseGraphDB.COLLECTIONS.input
elif relation_db['type_'] == BaseGraphDB.RELATION_TYPES.resource_input.name:
source_collection = BaseGraphDB.COLLECTIONS.resource
dest_collection = BaseGraphDB.COLLECTIONS.input
elif relation_db['type_'] == BaseGraphDB.RELATION_TYPES.resource_event.name:
source_collection = BaseGraphDB.COLLECTIONS.resource_events
dest_collection = BaseGraphDB.COLLECTIONS.events
source = self.get(relation_db['source'], collection=source_collection)
dest = self.get(relation_db['dest'], collection=dest_collection)
return Relation(
self,
source,
dest,
relation_db['properties']
)
def all(self, collection=BaseGraphDB.DEFAULT_COLLECTION):
"""Return all elements (nodes) of type `collection`."""
key_glob = self._make_collection_key(collection, '*')
for result in self._all(key_glob):
yield result
def all_relations(self, type_=BaseGraphDB.DEFAULT_RELATION):
"""Return all relations of type `type_`."""
key_glob = self._make_relation_key(type_, '*')
for result in self._all(key_glob):
yield result
def _all(self, key_glob):
keys = self._r.keys(key_glob)
with self._r.pipeline() as pipe:
pipe.multi()
values = [self._r.get(key) for key in keys]
pipe.execute()
for value in values:
yield json.loads(value)
def clear(self):
"""Clear the whole DB."""
self._r.flushdb()
def clear_collection(self, collection=BaseGraphDB.DEFAULT_COLLECTION):
"""Clear all elements (nodes) of type `collection`."""
key_glob = self._make_collection_key(collection, '*')
self._r.delete(self._r.keys(key_glob))
def create(self, name, properties={}, collection=BaseGraphDB.DEFAULT_COLLECTION):
"""Create element (node) with given name, properties, of type `collection`."""
if isinstance(collection, self.COLLECTIONS):
collection = collection.name
properties = {
'name': name,
'properties': properties,
'collection': collection,
}
self._r.set(
self._make_collection_key(collection, name),
json.dumps(properties)
)
return properties
def create_relation(self,
source,
dest,
properties={},
type_=BaseGraphDB.DEFAULT_RELATION):
"""
Create relation (connection) of type `type_` from source to dest with
given properties.
"""
return self.create_relation_str(
source.uid, dest.uid, properties, type_=type_)
def create_relation_str(self, source, dest,
properties={}, type_=BaseGraphDB.DEFAULT_RELATION):
if isinstance(type_, self.RELATION_TYPES):
type_ = type_.name
uid = self._make_relation_uid(source, dest)
properties = {
'source': source,
'dest': dest,
'properties': properties,
'type_': type_,
}
self._r.set(
self._make_relation_key(type_, uid),
json.dumps(properties)
)
return properties
def get(self, name, collection=BaseGraphDB.DEFAULT_COLLECTION,
return_empty=False):
"""Fetch element with given name and collection type."""
try:
collection_key = self._make_collection_key(collection, name)
item = self._r.get(collection_key)
if not item and return_empty:
return item
return json.loads(item)
except TypeError:
raise KeyError(collection_key)
def delete(self, name, collection=BaseGraphDB.DEFAULT_COLLECTION):
keys = self._r.keys(self._make_collection_key(collection, name))
if keys:
self._r.delete(*keys)
def get_or_create(self,
name,
properties={},
collection=BaseGraphDB.DEFAULT_COLLECTION):
"""
Fetch or create element (if not exists) with given name, properties of
type `collection`.
"""
try:
return self.get(name, collection=collection)
except KeyError:
return self.create(name, properties=properties, collection=collection)
def _relations_glob(self,
source=None,
dest=None,
type_=BaseGraphDB.DEFAULT_RELATION):
if source is None:
source = '*'
else:
source = source.uid
if dest is None:
dest = '*'
else:
dest = dest.uid
return self._make_relation_key(type_, self._make_relation_uid(source, dest))
def delete_relations(self,
source=None,
dest=None,
type_=BaseGraphDB.DEFAULT_RELATION,
has_properties=None):
"""Delete all relations of type `type_` from source to dest."""
glob = self._relations_glob(source=source, dest=dest, type_=type_)
keys = self._r.keys(glob)
if not keys:
return
if not has_properties:
self._r.delete(*keys)
rels = self.get_relations(
source=source, dest=dest, type_=type_, has_properties=has_properties
)
for r in rels:
self.delete_relations(
source=r.start_node,
dest=r.end_node,
type_=type_
)
def get_relations(self,
source=None,
dest=None,
type_=BaseGraphDB.DEFAULT_RELATION,
has_properties=None):
"""Fetch all relations of type `type_` from source to dest."""
glob = self._relations_glob(source=source, dest=dest, type_=type_)
def check_has_properties(r):
if has_properties:
for k, v in has_properties.items():
if not r['properties'].get(k) == v:
return False
return True
for r in self._all(glob):
# Glob is primitive, we must filter stuff correctly here
if source and r['source'] != source.uid:
continue
if dest and r['dest'] != dest.uid:
continue
if not check_has_properties(r):
continue
yield r
def get_relation(self, source, dest, type_=BaseGraphDB.DEFAULT_RELATION):
"""Fetch relations with given source, dest and type_."""
uid = self._make_relation_key(source.uid, dest.uid)
try:
return json.loads(
self._r.get(self._make_relation_key(type_, uid))
)
except TypeError:
raise KeyError
def get_or_create_relation(self,
source,
dest,
properties=None,
type_=BaseGraphDB.DEFAULT_RELATION):
"""Fetch or create relation with given properties."""
properties = properties or {}
try:
return self.get_relation(source, dest, type_=type_)
except KeyError:
return self.create_relation(source, dest, properties=properties, type_=type_)
def _make_collection_key(self, collection, _id):
if isinstance(collection, self.COLLECTIONS):
collection = collection.name
# NOTE: hiera-redis backend depends on this!
return '{0}:{1}'.format(collection, _id)
def _make_relation_uid(self, source, dest):
"""
There can be only one relation from source to dest, that's why
this function works.
"""
return '{0}-{1}'.format(source, dest)
def _make_relation_key(self, type_, _id):
if isinstance(type_, self.RELATION_TYPES):
type_ = type_.name
# NOTE: hiera-redis backend depends on this!
return '{0}:{1}'.format(type_, _id)
def get_ordered_hash(self, collection):
return OrderedHash(self._r, collection)
class FakeRedisGraphDB(RedisGraphDB):
REDIS_CLIENT = fakeredis.FakeStrictRedis

View File

@ -1,735 +0,0 @@
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import inspect
import networkx
import uuid
from solar import errors
from solar.core import validation
from solar.interfaces.db import base
from solar.interfaces.db import get_db
import os
# USE_CACHE could be set only from CLI
USE_CACHE = int(os.getenv("USE_CACHE", 0))
db = get_db()
from functools import wraps
def _delete_from(store):
def _wrp(key):
try:
del store[key]
except KeyError:
pass
return _wrp
def cache_me(store):
def _inner(f):
# attaching to functions even when no cache enabled for consistency
f._cache_store = store
f._cache_del = _delete_from(store)
@wraps(f)
def _inner2(obj, *args, **kwargs):
try:
return store[obj.id]
except KeyError:
pass
val = f(obj, *args, **kwargs)
if obj.id.startswith('location_id'):
if not val.value:
return val
if obj.id.startswith('transports_id'):
if not val.value:
return val
if isinstance(val, list):
return val
else:
if not val.value:
return val
store[obj.id] = val
return val
if USE_CACHE:
return _inner2
else:
return f
return _inner
class DBField(object):
is_primary = False
schema = None
schema_in_field = None
default_value = None
def __init__(self, name, value=None):
self.name = name
self.value = value
if value is None:
self.value = self.default_value
def __eq__(self, inst):
return self.name == inst.name and self.value == inst.value
def __ne__(self, inst):
return not self.__eq__(inst)
def __hash__(self):
return hash('{}:{}'.format(self.name, self.value))
def validate(self):
if self.schema is None:
return
es = validation.validate_input(self.value, schema=self.schema)
if es:
raise errors.ValidationError('"{}": {}'.format(self.name, es[0]))
def db_field(schema=None,
schema_in_field=None,
default_value=None,
is_primary=False):
"""Definition for the DB field.
schema - simple schema according to the one in solar.core.validation
schema_in_field - if you don't want to fix schema, you can specify
another field in DBObject that will represent the schema used
for validation of this field
is_primary - only one field in db object can be primary. This key is used
for creating key in the DB
"""
class DBFieldX(DBField):
pass
DBFieldX.is_primary = is_primary
DBFieldX.schema = schema
DBFieldX.schema_in_field = schema_in_field
if default_value is not None:
DBFieldX.default_value = default_value
return DBFieldX
class DBRelatedField(object):
source_db_class = None
destination_db_class = None
relation_type = None
def __init__(self, name, source_db_object):
self.name = name
self.source_db_object = source_db_object
@classmethod
def graph(self):
relations = db.get_relations(type_=self.relation_type)
g = networkx.MultiDiGraph()
for r in relations:
source = self.source_db_class(**r.start_node.properties)
dest = self.destination_db_class(**r.end_node.properties)
properties = r.properties.copy()
g.add_edge(
source,
dest,
attr_dict=properties
)
return g
def all(self):
source_db_node = self.source_db_object._db_node
if source_db_node is None:
return []
return db.get_relations(source=source_db_node,
type_=self.relation_type)
def all_by_dest(self, destination_db_object):
destination_db_node = destination_db_object._db_node
if destination_db_node is None:
return set()
return db.get_relations(dest=destination_db_node,
type_=self.relation_type)
def add(self, *destination_db_objects):
for dest in destination_db_objects:
if not isinstance(dest, self.destination_db_class):
raise errors.SolarError(
'Object {} is of incompatible type {}.'.format(
dest, self.destination_db_class
)
)
db.get_or_create_relation(
self.source_db_object._db_node,
dest._db_node,
properties={},
type_=self.relation_type
)
def add_hash(self, destination_db_object, destination_key, tag=None):
if not isinstance(destination_db_object, self.destination_db_class):
raise errors.SolarError(
'Object {} is of incompatible type {}.'.format(
destination_db_object, self.destination_db_class
)
)
db.get_or_create_relation(
self.source_db_object._db_node,
destination_db_object._db_node,
properties={'destination_key': destination_key, 'tag': tag},
type_=self.relation_type
)
def remove(self, *destination_db_objects):
for dest in destination_db_objects:
db.delete_relations(
source=self.source_db_object._db_node,
dest=dest._db_node,
type_=self.relation_type
)
def as_set(self):
"""
Return DB objects that are destinations for self.source_db_object.
"""
relations = self.all()
ret = set()
for rel in relations:
ret.add(
self.destination_db_class(**rel.end_node.properties)
)
return ret
def as_list(self):
relations = self.all()
ret = []
for rel in relations:
ret.append(
self.destination_db_class(**rel.end_node.properties)
)
return ret
def sources(self, destination_db_object):
"""
Reverse of self.as_set, i.e. for given destination_db_object,
return source DB objects.
"""
relations = self.all_by_dest(destination_db_object)
ret = set()
for rel in relations:
ret.add(
self.source_db_class(**rel.start_node.properties)
)
return ret
def delete_all_incoming(self,
destination_db_object,
destination_key=None,
tag=None):
"""
Delete all relations for which destination_db_object is an end node.
If object is a hash, you can additionally specify the dst_key argument.
Then only connections that are destinations of dst_key will be deleted.
Same with tag.
"""
properties = {}
if destination_key is not None:
properties['destination_key'] = destination_key
if tag is not None:
properties['tag'] = tag
db.delete_relations(
dest=destination_db_object._db_node,
type_=self.relation_type,
has_properties=properties or None
)
def db_related_field(relation_type, destination_db_class):
class DBRelatedFieldX(DBRelatedField):
pass
DBRelatedFieldX.relation_type = relation_type
DBRelatedFieldX.destination_db_class = destination_db_class
return DBRelatedFieldX
class DBObjectMeta(type):
def __new__(cls, name, parents, dct):
collection = dct.get('_collection')
if not collection:
raise NotImplementedError('Collection is required.')
dct['_meta'] = {}
dct['_meta']['fields'] = {}
dct['_meta']['related_to'] = {}
has_primary = False
for field_name, field_klass in dct.items():
if not inspect.isclass(field_klass):
continue
if issubclass(field_klass, DBField):
dct['_meta']['fields'][field_name] = field_klass
if field_klass.is_primary:
if has_primary:
raise errors.SolarError('Object cannot have 2 primary fields.')
has_primary = True
dct['_meta']['primary'] = field_name
elif issubclass(field_klass, DBRelatedField):
dct['_meta']['related_to'][field_name] = field_klass
if not has_primary:
raise errors.SolarError('Object needs to have a primary field.')
klass = super(DBObjectMeta, cls).__new__(cls, name, parents, dct)
# Support for self-references in relations
for field_name, field_klass in klass._meta['related_to'].items():
field_klass.source_db_class = klass
if field_klass.destination_db_class == klass.__name__:
field_klass.destination_db_class = klass
return klass
class DBObject(object):
# Enum from BaseGraphDB.COLLECTIONS
_collection = None
def __init__(self, **kwargs):
wrong_fields = set(kwargs) - set(self._meta['fields'])
if wrong_fields:
raise errors.SolarError(
'Unknown fields {}'.format(wrong_fields)
)
self._fields = {}
for field_name, field_klass in self._meta['fields'].items():
value = kwargs.get(field_name, field_klass.default_value)
self._fields[field_name] = field_klass(field_name, value=value)
self._related_to = {}
for field_name, field_klass in self._meta['related_to'].items():
inst = field_klass(field_name, self)
self._related_to[field_name] = inst
self._update_values()
def __eq__(self, inst):
# NOTE: don't compare related fields
self._update_fields_values()
return self._fields == inst._fields
def __ne__(self, inst):
return not self.__eq__(inst)
def __hash__(self):
return hash(self._db_key)
def _update_fields_values(self):
"""Copy values from self to self._fields."""
for field in self._fields.values():
field.value = getattr(self, field.name)
def _update_values(self):
"""
Reverse of _update_fields_values, i.e. copy values from self._fields to
self."""
for field in self._fields.values():
setattr(self, field.name, field.value)
for field in self._related_to.values():
setattr(self, field.name, field)
@property
def _db_key(self):
"""Key for the DB document (in KV-store).
You can overwrite this with custom keys."""
if not self._primary_field.value:
setattr(self, self._primary_field.name, unicode(uuid.uuid4()))
self._update_fields_values()
return self._primary_field.value
@property
def _primary_field(self):
return self._fields[self._meta['primary']]
@property
def _db_node(self):
try:
return db.get(self._db_key, collection=self._collection)
except KeyError:
return
def validate(self):
self._update_fields_values()
for field in self._fields.values():
if field.schema_in_field is not None:
field.schema = self._fields[field.schema_in_field].value
field.validate()
def to_dict(self):
self._update_fields_values()
return {
f.name: f.value for f in self._fields.values()
}
@classmethod
def load(cls, key):
r = db.get(key, collection=cls._collection)
return cls(**r.properties)
@classmethod
def load_all(cls):
rs = db.all(collection=cls._collection)
return [cls(**r.properties) for r in rs]
def save(self):
db.create(
self._db_key,
properties=self.to_dict(),
collection=self._collection
)
def delete(self):
db.delete(
self._db_key,
collection=self._collection
)
class DBResourceInput(DBObject):
__metaclass__ = DBObjectMeta
_collection = base.BaseGraphDB.COLLECTIONS.input
id = db_field(schema='str!', is_primary=True)
name = db_field(schema='str!')
schema = db_field()
value = db_field(schema_in_field='schema')
is_list = db_field(schema='bool!', default_value=False)
is_hash = db_field(schema='bool!', default_value=False)
receivers = db_related_field(base.BaseGraphDB.RELATION_TYPES.input_to_input,
'DBResourceInput')
@property
def resource(self):
return DBResource(
**db.get_relations(
dest=self._db_node,
type_=base.BaseGraphDB.RELATION_TYPES.resource_input
)[0].start_node.properties
)
def save(self):
self.backtrack_value_emitter._cache_del(self.id)
return super(DBResourceInput, self).save()
def delete(self):
db.delete_relations(
source=self._db_node,
type_=base.BaseGraphDB.RELATION_TYPES.input_to_input
)
db.delete_relations(
dest=self._db_node,
type_=base.BaseGraphDB.RELATION_TYPES.input_to_input
)
self.backtrack_value_emitter._cache_del(self.id)
super(DBResourceInput, self).delete()
def edges(self):
out = db.get_relations(
source=self._db_node,
type_=base.BaseGraphDB.RELATION_TYPES.input_to_input)
incoming = db.get_relations(
dest=self._db_node,
type_=base.BaseGraphDB.RELATION_TYPES.input_to_input)
for relation in out + incoming:
meta = relation.properties
source = DBResourceInput(**relation.start_node.properties)
dest = DBResourceInput(**relation.end_node.properties)
yield source, dest, meta
def check_other_val(self, other_val=None):
if not other_val:
return self
res = self.resource
# TODO: needs to be refactored a lot to be more effective.
# We don't have way of getting single input / value for given resource.
inps = {i.name: i for i in res.inputs.as_set()}
correct_input = inps[other_val]
return correct_input.backtrack_value()
@cache_me({})
def backtrack_value_emitter(self, level=None, other_val=None):
# TODO: this is actually just fetching head element in linked list
# so this whole algorithm can be moved to the db backend probably
# TODO: cycle detection?
# TODO: write this as a Cypher query? Move to DB?
if level is not None and other_val is not None:
raise Exception("Not supported yet")
if level == 0:
return self
def backtrack_func(i):
if level is None:
return i.backtrack_value_emitter(other_val=other_val)
return i.backtrack_value_emitter(level=level - 1, other_val=other_val)
inputs = self.receivers.sources(self)
relations = self.receivers.all_by_dest(self)
source_class = self.receivers.source_db_class
if not inputs:
return self.check_other_val(other_val)
# if lazy_val is None:
# return self.value
# print self.resource.name
# print [x.name for x in self.resource.inputs.as_set()]
# _input = next(x for x in self.resource.inputs.as_set() if x.name == lazy_val)
# return _input.backtrack_value()
# # return self.value
if self.is_list:
if not self.is_hash:
return [backtrack_func(i) for i in inputs]
# NOTE: we return a list of values, but we need to group them
# hence this dict here
# NOTE: grouping is done by resource.name by default, but this
# can be overwritten by the 'tag' property in relation
ret = {}
for r in relations:
source = source_class(**r.start_node.properties)
tag = r.properties['tag']
ret.setdefault(tag, {})
key = r.properties['destination_key']
value = backtrack_func(source)
ret[tag].update({key: value})
return ret.values()
elif self.is_hash:
ret = self.value or {}
for r in relations:
source = source_class(
**r.start_node.properties
)
# NOTE: hard way to do this, what if there are more relations
# and some of them do have destination_key while others
# don't?
if 'destination_key' not in r.properties:
return backtrack_func(source)
key = r.properties['destination_key']
ret[key] = backtrack_func(source)
return ret
return backtrack_func(inputs.pop())
def parse_backtracked_value(self, v):
if isinstance(v, DBResourceInput):
return v.value
if isinstance(v, list):
return [self.parse_backtracked_value(vv) for vv in v]
if isinstance(v, dict):
return {
k: self.parse_backtracked_value(vv) for k, vv in v.items()
}
return v
def backtrack_value(self, other_val=None):
return self.parse_backtracked_value(self.backtrack_value_emitter(other_val=other_val))
class DBEvent(DBObject):
__metaclass__ = DBObjectMeta
_collection = base.BaseGraphDB.COLLECTIONS.events
id = db_field(is_primary=True)
parent = db_field(schema='str!')
parent_action = db_field(schema='str!')
etype = db_field('str!')
state = db_field('str')
child = db_field('str')
child_action = db_field('str')
def delete(self):
db.delete_relations(
dest=self._db_node,
type_=base.BaseGraphDB.RELATION_TYPES.resource_event
)
super(DBEvent, self).delete()
class DBResourceEvents(DBObject):
__metaclass__ = DBObjectMeta
_collection = base.BaseGraphDB.COLLECTIONS.resource_events
id = db_field(schema='str!', is_primary=True)
events = db_related_field(base.BaseGraphDB.RELATION_TYPES.resource_event,
DBEvent)
@classmethod
def get_or_create(cls, name):
r = db.get_or_create(
name,
properties={'id': name},
collection=cls._collection)
return cls(**r.properties)
class DBCommitedState(DBObject):
__metaclass__ = DBObjectMeta
_collection = base.BaseGraphDB.COLLECTIONS.state_data
id = db_field(schema='str!', is_primary=True)
inputs = db_field(schema={}, default_value={})
connections = db_field(schema=[], default_value=[])
base_path = db_field(schema='str')
tags = db_field(schema=[], default_value=[])
state = db_field(schema='str', default_value='removed')
@classmethod
def get_or_create(cls, name):
r = db.get_or_create(
name,
properties={'id': name},
collection=cls._collection)
return cls(**r.properties)
class DBResource(DBObject):
__metaclass__ = DBObjectMeta
_collection = base.BaseGraphDB.COLLECTIONS.resource
id = db_field(schema='str', is_primary=True)
name = db_field(schema='str!')
actions_path = db_field(schema='str')
actions = db_field(schema={}, default_value={})
base_name = db_field(schema='str')
base_path = db_field(schema='str')
handler = db_field(schema='str') # one of: {'ansible_playbook', 'ansible_template', 'puppet', etc}
puppet_module = db_field(schema='str')
version = db_field(schema='str')
tags = db_field(schema=[], default_value=[])
meta_inputs = db_field(schema={}, default_value={})
state = db_field(schema='str')
inputs = db_related_field(base.BaseGraphDB.RELATION_TYPES.resource_input,
DBResourceInput)
def add_input(self, name, schema, value):
# NOTE: Inputs need to have uuid added because there can be many
# inputs with the same name
uid = '{}-{}'.format(name, uuid.uuid4())
input = DBResourceInput(id=uid,
name=name,
schema=schema,
value=value,
is_list=isinstance(schema, list),
is_hash=isinstance(schema, dict) or (isinstance(schema, list) and len(schema) > 0 and isinstance(schema[0], dict)))
input.save()
self.inputs.add(input)
def add_event(self, action, state, etype, child, child_action):
event = DBEvent(
parent=self.name,
parent_action=action,
state=state,
etype=etype,
child=child,
child_action=child_action
)
event.save()
self.events.add(event)
def delete(self):
for input in self.inputs.as_set():
self.inputs.remove(input)
input.delete()
super(DBResource, self).delete()
def graph(self):
mdg = networkx.MultiDiGraph()
for input in self.inputs.as_list():
mdg.add_edges_from(input.edges())
return mdg
def add_tags(self, *tags):
self.tags = list(set(self.tags) | set(tags))
self.save()
def remove_tags(self, *tags):
self.tags = list(set(self.tags) - set(tags))
self.save()
# TODO: remove this
if __name__ == '__main__':
r = DBResource(name=1)
r.validate()

View File

@ -15,7 +15,6 @@
from solar.system_log import data from solar.system_log import data
from solar.dblayer.solar_models import CommitedResource from solar.dblayer.solar_models import CommitedResource
from dictdiffer import patch from dictdiffer import patch
from solar.interfaces import orm
from solar.core.resource import resource from solar.core.resource import resource
from .consts import CHANGES from .consts import CHANGES

View File

@ -1,488 +0,0 @@
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from .base import BaseResourceTest
from solar.core import resource
from solar.core import signals
from solar import errors
from solar.interfaces import orm
from solar.interfaces.db import base
class TestORM(BaseResourceTest):
def test_no_collection_defined(self):
with self.assertRaisesRegexp(NotImplementedError, 'Collection is required.'):
class TestDBObject(orm.DBObject):
__metaclass__ = orm.DBObjectMeta
def test_has_primary(self):
with self.assertRaisesRegexp(errors.SolarError, 'Object needs to have a primary field.'):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
test1 = orm.db_field(schema='str')
def test_no_multiple_primaries(self):
with self.assertRaisesRegexp(errors.SolarError, 'Object cannot have 2 primary fields.'):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
test1 = orm.db_field(schema='str', is_primary=True)
test2 = orm.db_field(schema='str', is_primary=True)
def test_primary_field(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
test1 = orm.db_field(schema='str', is_primary=True)
t = TestDBObject(test1='abc')
self.assertEqual('test1', t._primary_field.name)
self.assertEqual('abc', t._db_key)
t = TestDBObject()
self.assertIsNotNone(t._db_key)
self.assertIsNotNone(t.test1)
def test_default_value(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
test1 = orm.db_field(schema='str',
is_primary=True,
default_value='1')
t = TestDBObject()
self.assertEqual('1', t.test1)
def test_field_validation(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
t = TestDBObject(id=1)
with self.assertRaises(errors.ValidationError):
t.validate()
t = TestDBObject(id='1')
t.validate()
def test_dynamic_schema_validation(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
schema = orm.db_field()
value = orm.db_field(schema_in_field='schema')
t = TestDBObject(id='1', schema='str', value=1)
with self.assertRaises(errors.ValidationError):
t.validate()
self.assertEqual(t._fields['value'].schema, t._fields['schema'].value)
t = TestDBObject(id='1', schema='int', value=1)
t.validate()
self.assertEqual(t._fields['value'].schema, t._fields['schema'].value)
def test_unknown_fields(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
with self.assertRaisesRegexp(errors.SolarError, 'Unknown fields .*iid'):
TestDBObject(iid=1)
def test_equality(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
test = orm.db_field(schema='str')
t1 = TestDBObject(id='1', test='test')
t2 = TestDBObject(id='2', test='test')
self.assertNotEqual(t1, t2)
t2 = TestDBObject(id='1', test='test2')
self.assertNotEqual(t1, t2)
t2 = TestDBObject(id='1', test='test')
self.assertEqual(t1, t2)
class TestORMRelation(BaseResourceTest):
def test_children_value(self):
class TestDBRelatedObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.input
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.resource
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
related = orm.db_related_field(
base.BaseGraphDB.RELATION_TYPES.resource_input,
TestDBRelatedObject
)
r1 = TestDBRelatedObject(id='1')
r1.save()
r2 = TestDBRelatedObject(id='2')
r2.save()
o = TestDBObject(id='a')
o.save()
self.assertSetEqual(o.related.as_set(), set())
o.related.add(r1)
self.assertSetEqual(o.related.as_set(), {r1})
o.related.add(r2)
self.assertSetEqual(o.related.as_set(), {r1, r2})
o.related.remove(r2)
self.assertSetEqual(o.related.as_set(), {r1})
o.related.add(r2)
self.assertSetEqual(o.related.as_set(), {r1, r2})
o.related.remove(r1, r2)
self.assertSetEqual(o.related.as_set(), set())
o.related.add(r1, r2)
self.assertSetEqual(o.related.as_set(), {r1, r2})
with self.assertRaisesRegexp(errors.SolarError, '.*incompatible type.*'):
o.related.add(o)
def test_relation_to_self(self):
class TestDBObject(orm.DBObject):
_collection = base.BaseGraphDB.COLLECTIONS.input
__metaclass__ = orm.DBObjectMeta
id = orm.db_field(schema='str', is_primary=True)
related = orm.db_related_field(
base.BaseGraphDB.RELATION_TYPES.input_to_input,
'TestDBObject'
)
o1 = TestDBObject(id='1')
o1.save()
o2 = TestDBObject(id='2')
o2.save()
o3 = TestDBObject(id='2')
o3.save()
o1.related.add(o2)
o2.related.add(o3)
self.assertEqual(o1.related.as_set(), {o2})
self.assertEqual(o2.related.as_set(), {o3})
class TestResourceORM(BaseResourceTest):
def test_save(self):
r = orm.DBResource(id='test1', name='test1', base_path='x')
r.save()
rr = resource.load(r.id)
self.assertEqual(r, rr.db_obj)
def test_add_input(self):
r = orm.DBResource(id='test1', name='test1', base_path='x')
r.save()
r.add_input('ip', 'str!', '10.0.0.2')
self.assertEqual(len(r.inputs.as_set()), 1)
def test_delete_resource(self):
r = orm.DBResource(id='test1', name='test1', base_path='x')
r.save()
r.add_input('ip', 'str!', '10.0.0.2')
class TestResourceInputORM(BaseResourceTest):
def test_backtrack_simple(self):
sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
value:
schema: str!
value:
""")
sample1 = self.create_resource(
'sample1', sample_meta_dir, {'value': 'x'}
)
sample2 = self.create_resource(
'sample2', sample_meta_dir, {'value': 'y'}
)
sample3 = self.create_resource(
'sample3', sample_meta_dir, {'value': 'z'}
)
vi = sample2.resource_inputs()['value']
self.assertEqual(vi.backtrack_value_emitter(), vi)
# sample1 -> sample2
signals.connect(sample1, sample2)
self.assertEqual(vi.backtrack_value_emitter(),
sample1.resource_inputs()['value'])
# sample3 -> sample1 -> sample2
signals.connect(sample3, sample1)
self.assertEqual(vi.backtrack_value_emitter(),
sample3.resource_inputs()['value'])
# sample2 disconnected
signals.disconnect(sample1, sample2)
self.assertEqual(vi.backtrack_value_emitter(), vi)
def test_backtrack_list(self):
sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
value:
schema: str!
value:
""")
sample_list_meta_dir = self.make_resource_meta("""
id: sample_list
handler: ansible
version: 1.0.0
input:
values:
schema: [str!]
value:
""")
sample_list = self.create_resource(
'sample_list', sample_list_meta_dir
)
vi = sample_list.resource_inputs()['values']
sample1 = self.create_resource(
'sample1', sample_meta_dir, {'value': 'x'}
)
sample2 = self.create_resource(
'sample2', sample_meta_dir, {'value': 'y'}
)
sample3 = self.create_resource(
'sample3', sample_meta_dir, {'value': 'z'}
)
self.assertEqual(vi.backtrack_value_emitter(), vi)
# [sample1] -> sample_list
signals.connect(sample1, sample_list, {'value': 'values'})
self.assertEqual(vi.backtrack_value_emitter(),
[sample1.resource_inputs()['value']])
# [sample3, sample1] -> sample_list
signals.connect(sample3, sample_list, {'value': 'values'})
self.assertSetEqual(set(vi.backtrack_value_emitter()),
set([sample1.resource_inputs()['value'],
sample3.resource_inputs()['value']]))
# sample2 disconnected
signals.disconnect(sample1, sample_list)
self.assertEqual(vi.backtrack_value_emitter(),
[sample3.resource_inputs()['value']])
def test_backtrack_dict(self):
sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
value:
schema: str!
value:
""")
sample_dict_meta_dir = self.make_resource_meta("""
id: sample_dict
handler: ansible
version: 1.0.0
input:
value:
schema: {a: str!}
value:
""")
sample_dict = self.create_resource(
'sample_dict', sample_dict_meta_dir
)
vi = sample_dict.resource_inputs()['value']
sample1 = self.create_resource(
'sample1', sample_meta_dir, {'value': 'x'}
)
sample2 = self.create_resource(
'sample2', sample_meta_dir, {'value': 'z'}
)
self.assertEqual(vi.backtrack_value_emitter(), vi)
# {a: sample1} -> sample_dict
signals.connect(sample1, sample_dict, {'value': 'value:a'})
self.assertDictEqual(vi.backtrack_value_emitter(),
{'a': sample1.resource_inputs()['value']})
# {a: sample2} -> sample_dict
signals.connect(sample2, sample_dict, {'value': 'value:a'})
self.assertDictEqual(vi.backtrack_value_emitter(),
{'a': sample2.resource_inputs()['value']})
# sample2 disconnected
signals.disconnect(sample2, sample_dict)
self.assertEqual(vi.backtrack_value_emitter(), vi)
def test_backtrack_dict_list(self):
sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
value:
schema: str!
value:
""")
sample_dict_list_meta_dir = self.make_resource_meta("""
id: sample_dict_list
handler: ansible
version: 1.0.0
input:
value:
schema: [{a: str!}]
value:
""")
sample_dict_list = self.create_resource(
'sample_dict', sample_dict_list_meta_dir
)
vi = sample_dict_list.resource_inputs()['value']
sample1 = self.create_resource(
'sample1', sample_meta_dir, {'value': 'x'}
)
sample2 = self.create_resource(
'sample2', sample_meta_dir, {'value': 'y'}
)
sample3 = self.create_resource(
'sample3', sample_meta_dir, {'value': 'z'}
)
self.assertEqual(vi.backtrack_value_emitter(), vi)
# [{a: sample1}] -> sample_dict_list
signals.connect(sample1, sample_dict_list, {'value': 'value:a'})
self.assertListEqual(vi.backtrack_value_emitter(),
[{'a': sample1.resource_inputs()['value']}])
# [{a: sample1}, {a: sample3}] -> sample_dict_list
signals.connect(sample3, sample_dict_list, {'value': 'value:a'})
self.assertItemsEqual(vi.backtrack_value_emitter(),
[{'a': sample1.resource_inputs()['value']},
{'a': sample3.resource_inputs()['value']}])
# [{a: sample1}, {a: sample2}] -> sample_dict_list
signals.connect(sample2, sample_dict_list, {'value': 'value:a|sample3'})
self.assertItemsEqual(vi.backtrack_value_emitter(),
[{'a': sample1.resource_inputs()['value']},
{'a': sample2.resource_inputs()['value']}])
# sample2 disconnected
signals.disconnect(sample2, sample_dict_list)
self.assertEqual(vi.backtrack_value_emitter(),
[{'a': sample1.resource_inputs()['value']}])
class TestEventORM(BaseResourceTest):
def test_return_emtpy_set(self):
r = orm.DBResourceEvents(id='test1')
r.save()
self.assertEqual(r.events.as_set(), set())
def test_save_and_load_by_parent(self):
ev = orm.DBEvent(
parent='n1',
parent_action='run',
state='success',
child_action='run',
child='n2',
etype='dependency')
ev.save()
rst = orm.DBEvent.load(ev.id)
self.assertEqual(rst, ev)
def test_save_several(self):
ev = orm.DBEvent(
parent='n1',
parent_action='run',
state='success',
child_action='run',
child='n2',
etype='dependency')
ev.save()
ev1 = orm.DBEvent(
parent='n1',
parent_action='run',
state='success',
child_action='run',
child='n3',
etype='dependency')
ev1.save()
self.assertEqual(len(orm.DBEvent.load_all()), 2)
def test_removal_of_event(self):
r = orm.DBResourceEvents(id='test1')
r.save()
ev = orm.DBEvent(
parent='test1',
parent_action='run',
state='success',
child_action='run',
child='n2',
etype='dependency')
ev.save()
r.events.add(ev)
self.assertEqual(r.events.as_set(), {ev})
ev.delete()
r = orm.DBResourceEvents.load('test1')
self.assertEqual(r.events.as_set(), set())