Merge pull request #155 from CGenie/cgenie/graph-db-hash-type

Add support for hash inputs, enable multiple inputs to connect into o…
This commit is contained in:
Jędrzej Nowak 2015-09-16 11:47:12 +02:00
commit 2cd5d948c2
5 changed files with 256 additions and 34 deletions

View File

@ -38,8 +38,9 @@ class Resource(object):
_metadata = {} _metadata = {}
# Create # Create
@dispatch(str, str, dict) @dispatch(str, str)
def __init__(self, name, base_path, args, tags=None, virtual_resource=None): def __init__(self, name, base_path, args=None, tags=None, virtual_resource=None):
args = args or {}
self.name = name self.name = name
if base_path: if base_path:
metadata = read_meta(base_path) metadata = read_meta(base_path)
@ -86,7 +87,8 @@ class Resource(object):
k: v for k, v in ret.items() if os.path.isfile(v) k: v for k, v in ret.items() if os.path.isfile(v)
} }
def create_inputs(self, args): def create_inputs(self, args=None):
args = args or {}
for name, v in self.db_obj.meta_inputs.items(): for name, v in self.db_obj.meta_inputs.items():
value = args.get(name, v.get('value')) value = args.get(name, v.get('value'))

View File

@ -24,7 +24,8 @@ from solar.core import resource
from solar.core import signals from solar.core import signals
def create(name, base_path, args={}, virtual_resource=None): def create(name, base_path, args=None, virtual_resource=None):
args = args or {}
if isinstance(base_path, provider.BaseProvider): if isinstance(base_path, provider.BaseProvider):
base_path = base_path.directory base_path = base_path.directory
@ -47,12 +48,13 @@ def create(name, base_path, args={}, virtual_resource=None):
return rs return rs
def create_resource(name, base_path, args={}, virtual_resource=None): def create_resource(name, base_path, args=None, virtual_resource=None):
args = args or {}
if isinstance(base_path, provider.BaseProvider): if isinstance(base_path, provider.BaseProvider):
base_path = base_path.directory base_path = base_path.directory
r = resource.Resource( r = resource.Resource(
name, base_path, args, tags=[], virtual_resource=virtual_resource name, base_path, args=args, tags=[], virtual_resource=virtual_resource
) )
return r return r

View File

@ -73,6 +73,9 @@ def connect(emitter, receiver, mapping={}, events=None):
def connect_single(emitter, src, receiver, dst): def connect_single(emitter, src, receiver, dst):
if ':' in dst:
return connect_multi(emitter, src, receiver, dst)
# Disconnect all receiver inputs # Disconnect all receiver inputs
# Check if receiver input is of list type first # Check if receiver input is of list type first
emitter_input = emitter.resource_inputs()[src] emitter_input = emitter.resource_inputs()[src]
@ -98,6 +101,25 @@ def connect_single(emitter, src, receiver, dst):
emitter_input.receivers.add(receiver_input) emitter_input.receivers.add(receiver_input)
def connect_multi(emitter, src, receiver, dst):
receiver_input_name, receiver_input_key = dst.split(':')
emitter_input = emitter.resource_inputs()[src]
receiver_input = receiver.resource_inputs()[receiver_input_name]
# NOTE: make sure that receiver.args[receiver_input] is of dict type
if not receiver_input.is_hash:
raise Exception(
'Receiver input {} must be a hash or a list of hashes'.format(receiver_input_name)
)
log.debug('Connecting {}::{} -> {}::{}[{}]'.format(
emitter.name, emitter_input.name, receiver.name, receiver_input.name,
receiver_input_key
))
emitter_input.receivers.add_hash(receiver_input, receiver_input_key)
def disconnect_receiver_by_input(receiver, input_name): def disconnect_receiver_by_input(receiver, input_name):
input_node = receiver.resource_inputs()[input_name] input_node = receiver.resource_inputs()[input_name]

View File

@ -89,6 +89,24 @@ class DBRelatedField(object):
self.name = name self.name = name
self.source_db_object = source_db_object self.source_db_object = source_db_object
def all(self):
source_db_node = self.source_db_object._db_node
if source_db_node is None:
return []
return db.get_relations(source=source_db_node,
type_=self.relation_type)
def all_by_dest(self, destination_db_object):
destination_db_node = destination_db_object._db_node
if destination_db_node is None:
return set()
return db.get_relations(dest=destination_db_node,
type_=self.relation_type)
def add(self, *destination_db_objects): def add(self, *destination_db_objects):
for dest in destination_db_objects: for dest in destination_db_objects:
if not isinstance(dest, self.destination_db_class): if not isinstance(dest, self.destination_db_class):
@ -105,6 +123,21 @@ class DBRelatedField(object):
type_=self.relation_type type_=self.relation_type
) )
def add_hash(self, destination_db_object, destination_key):
if not isinstance(destination_db_object, self.destination_db_class):
raise errors.SolarError(
'Object {} is of incompatible type {}.'.format(
destination_db_object, self.destination_db_class
)
)
db.get_or_create_relation(
self.source_db_object._db_node,
destination_db_object._db_node,
properties={'destination_key': destination_key},
type_=self.relation_type
)
def remove(self, *destination_db_objects): def remove(self, *destination_db_objects):
for dest in destination_db_objects: for dest in destination_db_objects:
db.delete_relations( db.delete_relations(
@ -119,13 +152,7 @@ class DBRelatedField(object):
Return DB objects that are destinations for self.source_db_object. Return DB objects that are destinations for self.source_db_object.
""" """
source_db_node = self.source_db_object._db_node relations = self.all()
if source_db_node is None:
return set()
relations = db.get_relations(source=source_db_node,
type_=self.relation_type)
ret = set() ret = set()
@ -142,13 +169,7 @@ class DBRelatedField(object):
return source DB objects. return source DB objects.
""" """
destination_db_node = destination_db_object._db_node relations = self.all_by_dest(destination_db_object)
if destination_db_node is None:
return set()
relations = db.get_relations(dest=destination_db_node,
type_=self.relation_type)
ret = set() ret = set()
@ -339,23 +360,64 @@ class DBResourceInput(DBObject):
name = db_field(schema='str!') name = db_field(schema='str!')
schema = db_field() schema = db_field()
value = db_field(schema_in_field='schema') value = db_field(schema_in_field='schema')
is_list = db_field(schema='bool') is_list = db_field(schema='bool!', default_value=False)
is_hash = db_field(schema='bool!', default_value=False)
receivers = db_related_field(base.BaseGraphDB.RELATION_TYPES.input_to_input, receivers = db_related_field(base.BaseGraphDB.RELATION_TYPES.input_to_input,
'DBResourceInput') 'DBResourceInput')
@property
def resource(self):
return DBResource(
**db.get_relations(
dest=self._db_node,
type_=base.BaseGraphDB.RELATION_TYPES.resource_input
)[0].start_node.properties
)
def backtrack_value(self): def backtrack_value(self):
# TODO: this is actually just fetching head element in linked list # TODO: this is actually just fetching head element in linked list
# so this whole algorithm can be moved to the db backend probably # so this whole algorithm can be moved to the db backend probably
# TODO: cycle detection? # TODO: cycle detection?
# TODO: write this as a Cypher query? Move to DB? # TODO: write this as a Cypher query? Move to DB?
inputs = self.receivers.sources(self) inputs = self.receivers.sources(self)
relations = self.receivers.all_by_dest(self)
source_class = self.receivers.source_db_class
if not inputs: if not inputs:
return self.value return self.value
if self.is_list: if self.is_list:
return [i.backtrack_value() for i in inputs] if not self.is_hash:
return [i.backtrack_value() for i in inputs]
# NOTE: we return a list of values, but we need to group them
# by resource name, hence this dict here
ret = {}
for r in relations:
source = source_class(**r.start_node.properties)
ret.setdefault(source.resource.name, {})
key = r.properties['destination_key']
value = source.backtrack_value()
ret[source.resource.name].update({key: value})
return ret.values()
elif self.is_hash:
ret = self.value or {}
for r in relations:
source = source_class(
**r.start_node.properties
)
# NOTE: hard way to do this, what if there are more relations
# and some of them do have destination_key while others
# don't?
if 'destination_key' not in r.properties:
return source.backtrack_value()
key = r.properties['destination_key']
ret[key] = source.backtrack_value()
return ret
return inputs.pop().backtrack_value() return inputs.pop().backtrack_value()
@ -387,7 +449,8 @@ class DBResource(DBObject):
name=name, name=name,
schema=schema, schema=schema,
value=value, value=value,
is_list=isinstance(schema, list)) is_list=isinstance(schema, list),
is_hash=isinstance(schema, dict) or (isinstance(schema, list) and len(schema) > 0 and isinstance(schema[0], dict)))
input.save() input.save()
self.inputs.add(input) self.inputs.add(input)

View File

@ -457,10 +457,10 @@ input:
'sample2', sample_meta_dir, {'ip': '10.0.0.2', 'port': 1001} 'sample2', sample_meta_dir, {'ip': '10.0.0.2', 'port': 1001}
) )
list_input = self.create_resource( list_input = self.create_resource(
'list-input', list_input_meta_dir, {} 'list-input', list_input_meta_dir,
) )
list_input_nested = self.create_resource( list_input_nested = self.create_resource(
'list-input-nested', list_input_nested_meta_dir, {} 'list-input-nested', list_input_nested_meta_dir,
) )
xs.connect(sample1, list_input, mapping={'ip': 'ips', 'port': 'ports'}) xs.connect(sample1, list_input, mapping={'ip': 'ips', 'port': 'ports'})
@ -490,16 +490,19 @@ input:
) )
''' class TestHashInput(base.BaseResourceTest):
class TestMultiInput(base.BaseResourceTest): def test_hash_input_basic(self):
def test_multi_input(self):
sample_meta_dir = self.make_resource_meta(""" sample_meta_dir = self.make_resource_meta("""
id: sample id: sample
handler: ansible handler: ansible
version: 1.0.0 version: 1.0.0
input: input:
ip: ip:
schema: str!
value:
port: port:
schema: int!
value:
""") """)
receiver_meta_dir = self.make_resource_meta(""" receiver_meta_dir = self.make_resource_meta("""
id: receiver id: receiver
@ -507,17 +510,147 @@ handler: ansible
version: 1.0.0 version: 1.0.0
input: input:
server: server:
schema: {ip: str!, port: int!}
""") """)
sample = self.create_resource( sample = self.create_resource(
'sample', sample_meta_dir, {'ip': '10.0.0.1', 'port': '5000'} 'sample', sample_meta_dir, args={'ip': '10.0.0.1', 'port': 5000}
) )
receiver = self.create_resource( receiver = self.create_resource(
'receiver', receiver_meta_dir, {'server': None} 'receiver', receiver_meta_dir
) )
xs.connect(sample, receiver, mapping={'ip, port': 'server'}) xs.connect(sample, receiver, mapping={'ip': 'server:ip', 'port': 'server:port'})
self.assertItemsEqual( self.assertDictEqual(
(sample.args['ip'], sample.args['port']), {'ip': sample.args['ip'], 'port': sample.args['port']},
receiver.args['server'],
)
sample.update({'ip': '10.0.0.2'})
self.assertDictEqual(
{'ip': sample.args['ip'], 'port': sample.args['port']},
receiver.args['server'],
)
def test_hash_input_mixed(self):
sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
ip:
schema: str!
value:
port:
schema: int!
value:
""")
receiver_meta_dir = self.make_resource_meta("""
id: receiver
handler: ansible
version: 1.0.0
input:
server:
schema: {ip: str!, port: int!}
""")
sample = self.create_resource(
'sample', sample_meta_dir, args={'ip': '10.0.0.1', 'port': 5000}
)
receiver = self.create_resource(
'receiver', receiver_meta_dir, args={'server': {'port': 5001}}
)
xs.connect(sample, receiver, mapping={'ip': 'server:ip'})
self.assertDictEqual(
{'ip': sample.args['ip'], 'port': 5001},
receiver.args['server'],
)
sample.update({'ip': '10.0.0.2'})
self.assertDictEqual(
{'ip': sample.args['ip'], 'port': 5001},
receiver.args['server'],
)
def test_hash_input_with_list(self):
sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
ip:
schema: str!
value:
port:
schema: int!
value:
""")
receiver_meta_dir = self.make_resource_meta("""
id: receiver
handler: ansible
version: 1.0.0
input:
server:
schema: [{ip: str!, port: int!}]
""")
sample1 = self.create_resource(
'sample1', sample_meta_dir, args={'ip': '10.0.0.1', 'port': 5000}
)
receiver = self.create_resource(
'receiver', receiver_meta_dir
)
xs.connect(sample1, receiver, mapping={'ip': 'server:ip', 'port': 'server:port'})
self.assertItemsEqual(
[{'ip': sample1.args['ip'], 'port': sample1.args['port']}],
receiver.args['server'],
)
sample2 = self.create_resource(
'sample2', sample_meta_dir, args={'ip': '10.0.0.2', 'port': 5001}
)
xs.connect(sample2, receiver, mapping={'ip': 'server:ip', 'port': 'server:port'})
self.assertItemsEqual(
[{'ip': sample1.args['ip'], 'port': sample1.args['port']},
{'ip': sample2.args['ip'], 'port': sample2.args['port']}],
receiver.args['server'],
)
xs.disconnect(sample1, receiver)
self.assertItemsEqual(
[{'ip': sample2.args['ip'], 'port': sample2.args['port']}],
receiver.args['server'],
)
def test_hash_input_with_multiple_connections(self):
sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
ip:
schema: str!
value:
""")
receiver_meta_dir = self.make_resource_meta("""
id: receiver
handler: ansible
version: 1.0.0
input:
ip:
schema: str!
value:
server:
schema: {ip: str!}
""")
sample = self.create_resource(
'sample', sample_meta_dir, args={'ip': '10.0.0.1'}
)
receiver = self.create_resource(
'receiver', receiver_meta_dir
)
xs.connect(sample, receiver, mapping={'ip': ['ip', 'server:ip']})
self.assertEqual(
sample.args['ip'],
receiver.args['ip']
)
self.assertDictEqual(
{'ip': sample.args['ip']},
receiver.args['server'], receiver.args['server'],
) )
'''