Merge pull request #155 from CGenie/cgenie/graph-db-hash-type
Add support for hash inputs, enable multiple inputs to connect into o…
This commit is contained in:
commit
2cd5d948c2
@ -38,8 +38,9 @@ class Resource(object):
|
||||
_metadata = {}
|
||||
|
||||
# Create
|
||||
@dispatch(str, str, dict)
|
||||
def __init__(self, name, base_path, args, tags=None, virtual_resource=None):
|
||||
@dispatch(str, str)
|
||||
def __init__(self, name, base_path, args=None, tags=None, virtual_resource=None):
|
||||
args = args or {}
|
||||
self.name = name
|
||||
if base_path:
|
||||
metadata = read_meta(base_path)
|
||||
@ -86,7 +87,8 @@ class Resource(object):
|
||||
k: v for k, v in ret.items() if os.path.isfile(v)
|
||||
}
|
||||
|
||||
def create_inputs(self, args):
|
||||
def create_inputs(self, args=None):
|
||||
args = args or {}
|
||||
for name, v in self.db_obj.meta_inputs.items():
|
||||
value = args.get(name, v.get('value'))
|
||||
|
||||
|
@ -24,7 +24,8 @@ from solar.core import resource
|
||||
from solar.core import signals
|
||||
|
||||
|
||||
def create(name, base_path, args={}, virtual_resource=None):
|
||||
def create(name, base_path, args=None, virtual_resource=None):
|
||||
args = args or {}
|
||||
if isinstance(base_path, provider.BaseProvider):
|
||||
base_path = base_path.directory
|
||||
|
||||
@ -47,12 +48,13 @@ def create(name, base_path, args={}, virtual_resource=None):
|
||||
return rs
|
||||
|
||||
|
||||
def create_resource(name, base_path, args={}, virtual_resource=None):
|
||||
def create_resource(name, base_path, args=None, virtual_resource=None):
|
||||
args = args or {}
|
||||
if isinstance(base_path, provider.BaseProvider):
|
||||
base_path = base_path.directory
|
||||
|
||||
r = resource.Resource(
|
||||
name, base_path, args, tags=[], virtual_resource=virtual_resource
|
||||
name, base_path, args=args, tags=[], virtual_resource=virtual_resource
|
||||
)
|
||||
return r
|
||||
|
||||
|
@ -73,6 +73,9 @@ def connect(emitter, receiver, mapping={}, events=None):
|
||||
|
||||
|
||||
def connect_single(emitter, src, receiver, dst):
|
||||
if ':' in dst:
|
||||
return connect_multi(emitter, src, receiver, dst)
|
||||
|
||||
# Disconnect all receiver inputs
|
||||
# Check if receiver input is of list type first
|
||||
emitter_input = emitter.resource_inputs()[src]
|
||||
@ -98,6 +101,25 @@ def connect_single(emitter, src, receiver, dst):
|
||||
emitter_input.receivers.add(receiver_input)
|
||||
|
||||
|
||||
def connect_multi(emitter, src, receiver, dst):
|
||||
receiver_input_name, receiver_input_key = dst.split(':')
|
||||
|
||||
emitter_input = emitter.resource_inputs()[src]
|
||||
receiver_input = receiver.resource_inputs()[receiver_input_name]
|
||||
|
||||
# NOTE: make sure that receiver.args[receiver_input] is of dict type
|
||||
if not receiver_input.is_hash:
|
||||
raise Exception(
|
||||
'Receiver input {} must be a hash or a list of hashes'.format(receiver_input_name)
|
||||
)
|
||||
|
||||
log.debug('Connecting {}::{} -> {}::{}[{}]'.format(
|
||||
emitter.name, emitter_input.name, receiver.name, receiver_input.name,
|
||||
receiver_input_key
|
||||
))
|
||||
emitter_input.receivers.add_hash(receiver_input, receiver_input_key)
|
||||
|
||||
|
||||
def disconnect_receiver_by_input(receiver, input_name):
|
||||
input_node = receiver.resource_inputs()[input_name]
|
||||
|
||||
|
@ -89,6 +89,24 @@ class DBRelatedField(object):
|
||||
self.name = name
|
||||
self.source_db_object = source_db_object
|
||||
|
||||
def all(self):
|
||||
source_db_node = self.source_db_object._db_node
|
||||
|
||||
if source_db_node is None:
|
||||
return []
|
||||
|
||||
return db.get_relations(source=source_db_node,
|
||||
type_=self.relation_type)
|
||||
|
||||
def all_by_dest(self, destination_db_object):
|
||||
destination_db_node = destination_db_object._db_node
|
||||
|
||||
if destination_db_node is None:
|
||||
return set()
|
||||
|
||||
return db.get_relations(dest=destination_db_node,
|
||||
type_=self.relation_type)
|
||||
|
||||
def add(self, *destination_db_objects):
|
||||
for dest in destination_db_objects:
|
||||
if not isinstance(dest, self.destination_db_class):
|
||||
@ -105,6 +123,21 @@ class DBRelatedField(object):
|
||||
type_=self.relation_type
|
||||
)
|
||||
|
||||
def add_hash(self, destination_db_object, destination_key):
|
||||
if not isinstance(destination_db_object, self.destination_db_class):
|
||||
raise errors.SolarError(
|
||||
'Object {} is of incompatible type {}.'.format(
|
||||
destination_db_object, self.destination_db_class
|
||||
)
|
||||
)
|
||||
|
||||
db.get_or_create_relation(
|
||||
self.source_db_object._db_node,
|
||||
destination_db_object._db_node,
|
||||
properties={'destination_key': destination_key},
|
||||
type_=self.relation_type
|
||||
)
|
||||
|
||||
def remove(self, *destination_db_objects):
|
||||
for dest in destination_db_objects:
|
||||
db.delete_relations(
|
||||
@ -119,13 +152,7 @@ class DBRelatedField(object):
|
||||
Return DB objects that are destinations for self.source_db_object.
|
||||
"""
|
||||
|
||||
source_db_node = self.source_db_object._db_node
|
||||
|
||||
if source_db_node is None:
|
||||
return set()
|
||||
|
||||
relations = db.get_relations(source=source_db_node,
|
||||
type_=self.relation_type)
|
||||
relations = self.all()
|
||||
|
||||
ret = set()
|
||||
|
||||
@ -142,13 +169,7 @@ class DBRelatedField(object):
|
||||
return source DB objects.
|
||||
"""
|
||||
|
||||
destination_db_node = destination_db_object._db_node
|
||||
|
||||
if destination_db_node is None:
|
||||
return set()
|
||||
|
||||
relations = db.get_relations(dest=destination_db_node,
|
||||
type_=self.relation_type)
|
||||
relations = self.all_by_dest(destination_db_object)
|
||||
|
||||
ret = set()
|
||||
|
||||
@ -339,23 +360,64 @@ class DBResourceInput(DBObject):
|
||||
name = db_field(schema='str!')
|
||||
schema = db_field()
|
||||
value = db_field(schema_in_field='schema')
|
||||
is_list = db_field(schema='bool')
|
||||
is_list = db_field(schema='bool!', default_value=False)
|
||||
is_hash = db_field(schema='bool!', default_value=False)
|
||||
|
||||
receivers = db_related_field(base.BaseGraphDB.RELATION_TYPES.input_to_input,
|
||||
'DBResourceInput')
|
||||
|
||||
@property
|
||||
def resource(self):
|
||||
return DBResource(
|
||||
**db.get_relations(
|
||||
dest=self._db_node,
|
||||
type_=base.BaseGraphDB.RELATION_TYPES.resource_input
|
||||
)[0].start_node.properties
|
||||
)
|
||||
|
||||
def backtrack_value(self):
|
||||
# TODO: this is actually just fetching head element in linked list
|
||||
# so this whole algorithm can be moved to the db backend probably
|
||||
# TODO: cycle detection?
|
||||
# TODO: write this as a Cypher query? Move to DB?
|
||||
inputs = self.receivers.sources(self)
|
||||
relations = self.receivers.all_by_dest(self)
|
||||
source_class = self.receivers.source_db_class
|
||||
|
||||
if not inputs:
|
||||
return self.value
|
||||
|
||||
if self.is_list:
|
||||
return [i.backtrack_value() for i in inputs]
|
||||
if not self.is_hash:
|
||||
return [i.backtrack_value() for i in inputs]
|
||||
|
||||
# NOTE: we return a list of values, but we need to group them
|
||||
# by resource name, hence this dict here
|
||||
ret = {}
|
||||
|
||||
for r in relations:
|
||||
source = source_class(**r.start_node.properties)
|
||||
ret.setdefault(source.resource.name, {})
|
||||
key = r.properties['destination_key']
|
||||
value = source.backtrack_value()
|
||||
|
||||
ret[source.resource.name].update({key: value})
|
||||
|
||||
return ret.values()
|
||||
elif self.is_hash:
|
||||
ret = self.value or {}
|
||||
for r in relations:
|
||||
source = source_class(
|
||||
**r.start_node.properties
|
||||
)
|
||||
# NOTE: hard way to do this, what if there are more relations
|
||||
# and some of them do have destination_key while others
|
||||
# don't?
|
||||
if 'destination_key' not in r.properties:
|
||||
return source.backtrack_value()
|
||||
key = r.properties['destination_key']
|
||||
ret[key] = source.backtrack_value()
|
||||
return ret
|
||||
|
||||
return inputs.pop().backtrack_value()
|
||||
|
||||
@ -387,7 +449,8 @@ class DBResource(DBObject):
|
||||
name=name,
|
||||
schema=schema,
|
||||
value=value,
|
||||
is_list=isinstance(schema, list))
|
||||
is_list=isinstance(schema, list),
|
||||
is_hash=isinstance(schema, dict) or (isinstance(schema, list) and len(schema) > 0 and isinstance(schema[0], dict)))
|
||||
input.save()
|
||||
|
||||
self.inputs.add(input)
|
||||
|
@ -457,10 +457,10 @@ input:
|
||||
'sample2', sample_meta_dir, {'ip': '10.0.0.2', 'port': 1001}
|
||||
)
|
||||
list_input = self.create_resource(
|
||||
'list-input', list_input_meta_dir, {}
|
||||
'list-input', list_input_meta_dir,
|
||||
)
|
||||
list_input_nested = self.create_resource(
|
||||
'list-input-nested', list_input_nested_meta_dir, {}
|
||||
'list-input-nested', list_input_nested_meta_dir,
|
||||
)
|
||||
|
||||
xs.connect(sample1, list_input, mapping={'ip': 'ips', 'port': 'ports'})
|
||||
@ -490,16 +490,19 @@ input:
|
||||
)
|
||||
|
||||
|
||||
'''
|
||||
class TestMultiInput(base.BaseResourceTest):
|
||||
def test_multi_input(self):
|
||||
class TestHashInput(base.BaseResourceTest):
|
||||
def test_hash_input_basic(self):
|
||||
sample_meta_dir = self.make_resource_meta("""
|
||||
id: sample
|
||||
handler: ansible
|
||||
version: 1.0.0
|
||||
input:
|
||||
ip:
|
||||
schema: str!
|
||||
value:
|
||||
port:
|
||||
schema: int!
|
||||
value:
|
||||
""")
|
||||
receiver_meta_dir = self.make_resource_meta("""
|
||||
id: receiver
|
||||
@ -507,17 +510,147 @@ handler: ansible
|
||||
version: 1.0.0
|
||||
input:
|
||||
server:
|
||||
schema: {ip: str!, port: int!}
|
||||
""")
|
||||
|
||||
sample = self.create_resource(
|
||||
'sample', sample_meta_dir, {'ip': '10.0.0.1', 'port': '5000'}
|
||||
'sample', sample_meta_dir, args={'ip': '10.0.0.1', 'port': 5000}
|
||||
)
|
||||
receiver = self.create_resource(
|
||||
'receiver', receiver_meta_dir, {'server': None}
|
||||
'receiver', receiver_meta_dir
|
||||
)
|
||||
xs.connect(sample, receiver, mapping={'ip, port': 'server'})
|
||||
self.assertItemsEqual(
|
||||
(sample.args['ip'], sample.args['port']),
|
||||
xs.connect(sample, receiver, mapping={'ip': 'server:ip', 'port': 'server:port'})
|
||||
self.assertDictEqual(
|
||||
{'ip': sample.args['ip'], 'port': sample.args['port']},
|
||||
receiver.args['server'],
|
||||
)
|
||||
sample.update({'ip': '10.0.0.2'})
|
||||
self.assertDictEqual(
|
||||
{'ip': sample.args['ip'], 'port': sample.args['port']},
|
||||
receiver.args['server'],
|
||||
)
|
||||
|
||||
def test_hash_input_mixed(self):
|
||||
sample_meta_dir = self.make_resource_meta("""
|
||||
id: sample
|
||||
handler: ansible
|
||||
version: 1.0.0
|
||||
input:
|
||||
ip:
|
||||
schema: str!
|
||||
value:
|
||||
port:
|
||||
schema: int!
|
||||
value:
|
||||
""")
|
||||
receiver_meta_dir = self.make_resource_meta("""
|
||||
id: receiver
|
||||
handler: ansible
|
||||
version: 1.0.0
|
||||
input:
|
||||
server:
|
||||
schema: {ip: str!, port: int!}
|
||||
""")
|
||||
|
||||
sample = self.create_resource(
|
||||
'sample', sample_meta_dir, args={'ip': '10.0.0.1', 'port': 5000}
|
||||
)
|
||||
receiver = self.create_resource(
|
||||
'receiver', receiver_meta_dir, args={'server': {'port': 5001}}
|
||||
)
|
||||
xs.connect(sample, receiver, mapping={'ip': 'server:ip'})
|
||||
self.assertDictEqual(
|
||||
{'ip': sample.args['ip'], 'port': 5001},
|
||||
receiver.args['server'],
|
||||
)
|
||||
sample.update({'ip': '10.0.0.2'})
|
||||
self.assertDictEqual(
|
||||
{'ip': sample.args['ip'], 'port': 5001},
|
||||
receiver.args['server'],
|
||||
)
|
||||
|
||||
def test_hash_input_with_list(self):
|
||||
sample_meta_dir = self.make_resource_meta("""
|
||||
id: sample
|
||||
handler: ansible
|
||||
version: 1.0.0
|
||||
input:
|
||||
ip:
|
||||
schema: str!
|
||||
value:
|
||||
port:
|
||||
schema: int!
|
||||
value:
|
||||
""")
|
||||
receiver_meta_dir = self.make_resource_meta("""
|
||||
id: receiver
|
||||
handler: ansible
|
||||
version: 1.0.0
|
||||
input:
|
||||
server:
|
||||
schema: [{ip: str!, port: int!}]
|
||||
""")
|
||||
|
||||
sample1 = self.create_resource(
|
||||
'sample1', sample_meta_dir, args={'ip': '10.0.0.1', 'port': 5000}
|
||||
)
|
||||
receiver = self.create_resource(
|
||||
'receiver', receiver_meta_dir
|
||||
)
|
||||
xs.connect(sample1, receiver, mapping={'ip': 'server:ip', 'port': 'server:port'})
|
||||
self.assertItemsEqual(
|
||||
[{'ip': sample1.args['ip'], 'port': sample1.args['port']}],
|
||||
receiver.args['server'],
|
||||
)
|
||||
sample2 = self.create_resource(
|
||||
'sample2', sample_meta_dir, args={'ip': '10.0.0.2', 'port': 5001}
|
||||
)
|
||||
xs.connect(sample2, receiver, mapping={'ip': 'server:ip', 'port': 'server:port'})
|
||||
self.assertItemsEqual(
|
||||
[{'ip': sample1.args['ip'], 'port': sample1.args['port']},
|
||||
{'ip': sample2.args['ip'], 'port': sample2.args['port']}],
|
||||
receiver.args['server'],
|
||||
)
|
||||
xs.disconnect(sample1, receiver)
|
||||
self.assertItemsEqual(
|
||||
[{'ip': sample2.args['ip'], 'port': sample2.args['port']}],
|
||||
receiver.args['server'],
|
||||
)
|
||||
|
||||
def test_hash_input_with_multiple_connections(self):
|
||||
sample_meta_dir = self.make_resource_meta("""
|
||||
id: sample
|
||||
handler: ansible
|
||||
version: 1.0.0
|
||||
input:
|
||||
ip:
|
||||
schema: str!
|
||||
value:
|
||||
""")
|
||||
receiver_meta_dir = self.make_resource_meta("""
|
||||
id: receiver
|
||||
handler: ansible
|
||||
version: 1.0.0
|
||||
input:
|
||||
ip:
|
||||
schema: str!
|
||||
value:
|
||||
server:
|
||||
schema: {ip: str!}
|
||||
""")
|
||||
|
||||
sample = self.create_resource(
|
||||
'sample', sample_meta_dir, args={'ip': '10.0.0.1'}
|
||||
)
|
||||
receiver = self.create_resource(
|
||||
'receiver', receiver_meta_dir
|
||||
)
|
||||
xs.connect(sample, receiver, mapping={'ip': ['ip', 'server:ip']})
|
||||
self.assertEqual(
|
||||
sample.args['ip'],
|
||||
receiver.args['ip']
|
||||
)
|
||||
self.assertDictEqual(
|
||||
{'ip': sample.args['ip']},
|
||||
receiver.args['server'],
|
||||
)
|
||||
'''
|
||||
|
Loading…
Reference in New Issue
Block a user