Redis: get rid of global CLIENTS variable

Now Connections are read from Redis on demand.
This commit is contained in:
Przemyslaw Kaminski 2015-06-10 16:02:14 +02:00
parent ac74bf73fc
commit 8aa0f6247a
6 changed files with 117 additions and 102 deletions

View File

@ -191,7 +191,7 @@ def init_cli_connections():
@connections.command() @connections.command()
def show(): def show():
print json.dumps(signals.CLIENTS, indent=2) print json.dumps(signals.Connections.read_clients(), indent=2)
# TODO: this requires graphing libraries # TODO: this requires graphing libraries
@connections.command() @connections.command()

View File

@ -28,7 +28,7 @@ class BaseObserver(object):
def receivers(self): def receivers(self):
from solar.core import resource from solar.core import resource
signals.CLIENTS = signals.Connections.read_clients() #signals.CLIENTS = signals.Connections.read_clients()
for receiver_name, receiver_input in signals.Connections.receivers( for receiver_name, receiver_input in signals.Connections.receivers(
self._attached_to_name, self._attached_to_name,
self.name self.name

View File

@ -196,7 +196,7 @@ def load(resource_name):
raw_resource = db.read(resource_name, collection=db.COLLECTIONS.resource) raw_resource = db.read(resource_name, collection=db.COLLECTIONS.resource)
if raw_resource is None: if raw_resource is None:
raise NotImplementedError( raise KeyError(
'Resource {} does not exist'.format(resource_name) 'Resource {} does not exist'.format(resource_name)
) )

View File

@ -1,42 +1,34 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import atexit
from collections import defaultdict from collections import defaultdict
import itertools import itertools
import networkx as nx import networkx as nx
import os
from solar import utils
from solar.interfaces.db import get_db from solar.interfaces.db import get_db
db = get_db() db = get_db()
CLIENTS_CONFIG_KEY = 'clients-data-file'
#CLIENTS = utils.read_config_file(CLIENTS_CONFIG_KEY)
CLIENTS = {}
class Connections(object): class Connections(object):
"""
CLIENTS structure is:
emitter_name:
emitter_input_name:
- - dst_name
- dst_input_name
while DB structure is:
emitter_name_key:
emitter: emitter_name
sources:
emitter_input_name:
- - dst_name
- dst_input_name
"""
@staticmethod @staticmethod
def read_clients(): def read_clients():
"""
Returned structure is:
emitter_name:
emitter_input_name:
- - dst_name
- dst_input_name
while DB structure is:
emitter_name_key:
emitter: emitter_name
sources:
emitter_input_name:
- - dst_name
- dst_input_name
"""
ret = {} ret = {}
for data in db.get_list(collection=db.COLLECTIONS.connection): for data in db.get_list(collection=db.COLLECTIONS.connection):
@ -45,8 +37,8 @@ class Connections(object):
return ret return ret
@staticmethod @staticmethod
def save_clients(): def save_clients(clients):
for emitter_name, sources in CLIENTS.items(): for emitter_name, sources in clients.items():
data = { data = {
'emitter': emitter_name, 'emitter': emitter_name,
'sources': sources, 'sources': sources,
@ -58,78 +50,46 @@ class Connections(object):
if src not in emitter.args: if src not in emitter.args:
return return
clients = Connections.read_clients()
# TODO: implement general circular detection, this one is simple # TODO: implement general circular detection, this one is simple
if [emitter.name, src] in CLIENTS.get(receiver.name, {}).get(dst, []): if [emitter.name, src] in clients.get(receiver.name, {}).get(dst, []):
raise Exception('Attempted to create cycle in dependencies. Not nice.') raise Exception('Attempted to create cycle in dependencies. Not nice.')
CLIENTS.setdefault(emitter.name, {}) clients.setdefault(emitter.name, {})
CLIENTS[emitter.name].setdefault(src, []) clients[emitter.name].setdefault(src, [])
if [receiver.name, dst] not in CLIENTS[emitter.name][src]: if [receiver.name, dst] not in clients[emitter.name][src]:
CLIENTS[emitter.name][src].append([receiver.name, dst]) clients[emitter.name][src].append([receiver.name, dst])
#utils.save_to_config_file(CLIENTS_CONFIG_KEY, CLIENTS) Connections.save_clients(clients)
Connections.save_clients()
@staticmethod @staticmethod
def remove(emitter, src, receiver, dst): def remove(emitter, src, receiver, dst):
CLIENTS[emitter.name][src] = [ clients = Connections.read_clients()
destination for destination in CLIENTS[emitter.name][src]
clients[emitter.name][src] = [
destination for destination in clients[emitter.name][src]
if destination != [receiver.name, dst] if destination != [receiver.name, dst]
] ]
#utils.save_to_config_file(CLIENTS_CONFIG_KEY, CLIENTS) Connections.save_clients(clients)
Connections.save_clients()
@staticmethod
def reconnect_all():
"""Reconstruct connections for resource inputs from CLIENTS.
:return:
"""
from solar.core.resource import wrap_resource
for emitter_name, dest_dict in CLIENTS.items():
emitter = wrap_resource(
db.read(emitter_name, collection=db.COLLECTIONS.resource)
)
for emitter_input, destinations in dest_dict.items():
for receiver_name, receiver_input in destinations:
receiver = wrap_resource(
db.read(receiver_name, collection=db.COLLECTIONS.resource)
)
emitter.args[emitter_input].subscribe(
receiver.args[receiver_input])
@staticmethod @staticmethod
def receivers(emitter_name, emitter_input_name): def receivers(emitter_name, emitter_input_name):
return CLIENTS.get(emitter_name, {}).get(emitter_input_name, []) return Connections.read_clients().get(emitter_name, {}).get(
emitter_input_name, []
)
@staticmethod @staticmethod
def emitter(receiver_name, receiver_input_name): def emitter(receiver_name, receiver_input_name):
for emitter_name, dest_dict in CLIENTS.items(): for emitter_name, dest_dict in Connections.read_clients().items():
for emitter_input_name, destinations in dest_dict.items(): for emitter_input_name, destinations in dest_dict.items():
if [receiver_name, receiver_input_name] in destinations: if [receiver_name, receiver_input_name] in destinations:
return [emitter_name, emitter_input_name] return [emitter_name, emitter_input_name]
@staticmethod @staticmethod
def clear(): def clear():
global CLIENTS db.clear_collection(collection=db.COLLECTIONS.connection)
CLIENTS = {}
path = utils.read_config()[CLIENTS_CONFIG_KEY]
if os.path.exists(path):
os.remove(path)
@staticmethod
def flush():
print 'FLUSHING Connections'
#utils.save_to_config_file(CLIENTS_CONFIG_KEY, CLIENTS)
Connections.save_clients()
CLIENTS = Connections.read_clients()
#atexit.register(Connections.flush)
def guess_mapping(emitter, receiver): def guess_mapping(emitter, receiver):
@ -173,9 +133,9 @@ def connect(emitter, receiver, mapping=None):
def disconnect(emitter, receiver): def disconnect(emitter, receiver):
for src, destinations in CLIENTS[emitter.name].items(): clients = Connections.read_clients()
disconnect_by_src(emitter.name, src, receiver)
for src, destinations in clients[emitter.name].items():
for destination in destinations: for destination in destinations:
receiver_input = destination[1] receiver_input = destination[1]
if receiver_input in receiver.args: if receiver_input in receiver.args:
@ -183,6 +143,8 @@ def disconnect(emitter, receiver):
print 'Removing input {} from {}'.format(receiver_input, receiver.name) print 'Removing input {} from {}'.format(receiver_input, receiver.name)
emitter.args[src].unsubscribe(receiver.args[receiver_input]) emitter.args[src].unsubscribe(receiver.args[receiver_input])
disconnect_by_src(emitter.name, src, receiver)
def disconnect_receiver_by_input(receiver, input): def disconnect_receiver_by_input(receiver, input):
"""Find receiver connection by input and disconnect it. """Find receiver connection by input and disconnect it.
@ -191,31 +153,36 @@ def disconnect_receiver_by_input(receiver, input):
:param input: :param input:
:return: :return:
""" """
for emitter_name, inputs in CLIENTS.items(): clients = Connections.read_clients()
emitter = db.read(emitter_name, collection=db.COLLECTIONS.resource)
disconnect_by_src(emitter['id'], input, receiver) for emitter_name, inputs in clients.items():
disconnect_by_src(emitter_name, input, receiver)
def disconnect_by_src(emitter_name, src, receiver): def disconnect_by_src(emitter_name, src, receiver):
if src in CLIENTS[emitter_name]: clients = Connections.read_clients()
CLIENTS[emitter_name][src] = [
destination for destination in CLIENTS[emitter_name][src] if src in clients[emitter_name]:
clients[emitter_name][src] = [
destination for destination in clients[emitter_name][src]
if destination[0] != receiver.name if destination[0] != receiver.name
] ]
#utils.save_to_config_file(CLIENTS_CONFIG_KEY, CLIENTS) Connections.save_clients(clients)
def notify(source, key, value): def notify(source, key, value):
from solar.core.resource import wrap_resource from solar.core.resource import load
CLIENTS.setdefault(source.name, {}) clients = Connections.read_clients()
print 'Notify', source.name, key, value, CLIENTS[source.name]
if key in CLIENTS[source.name]: clients.setdefault(source.name, {})
for client, r_key in CLIENTS[source.name][key]: Connections.save_clients(clients)
resource = wrap_resource(
db.read(client, collection=db.COLLECTIONS.resource) print 'Notify', source.name, key, value, clients[source.name]
) if key in clients[source.name]:
for client, r_key in clients[source.name][key]:
resource = load(client)
print 'Resource found', client print 'Resource found', client
if resource: if resource:
resource.update({r_key: value}, emitter=source) resource.update({r_key: value}, emitter=source)
@ -236,7 +203,9 @@ def assign_connections(receiver, connections):
def connection_graph(): def connection_graph():
resource_dependencies = {} resource_dependencies = {}
for source, destination_values in CLIENTS.items(): clients = Connections.read_clients()
for source, destination_values in clients.items():
resource_dependencies.setdefault(source, set()) resource_dependencies.setdefault(source, set())
for src, destinations in destination_values.items(): for src, destinations in destination_values.items():
resource_dependencies[source].update([ resource_dependencies[source].update([
@ -262,8 +231,10 @@ def connection_graph():
def detailed_connection_graph(): def detailed_connection_graph():
g = nx.MultiDiGraph() g = nx.MultiDiGraph()
for emitter_name, destination_values in CLIENTS.items(): clients = Connections.read_clients()
for emitter_input, receivers in CLIENTS[emitter_name].items():
for emitter_name, destination_values in clients.items():
for emitter_input, receivers in clients[emitter_name].items():
for receiver_name, receiver_input in receivers: for receiver_name, receiver_input in receivers:
label = emitter_input label = emitter_input
if emitter_input != receiver_input: if emitter_input != receiver_input:

View File

@ -47,5 +47,13 @@ class RedisDB(object):
def clear(self): def clear(self):
self._r.flushdb() self._r.flushdb()
def clear_collection(self, collection=COLLECTIONS.resource):
key_glob = self._make_key(collection, '*')
self._r.delete(self._r.keys(key_glob))
def delete(self, uid, collection=COLLECTIONS.resource):
self._r.delete(self._make_key(collection, uid))
def _make_key(self, collection, _id): def _make_key(self, collection, _id):
return '{0}:{1}'.format(collection, _id) return '{0}:{1}'.format(collection, _id)

View File

@ -2,7 +2,8 @@ import unittest
import base import base
from solar.core import signals as xs from solar.core import resource
from solar.core import signals
class TestResource(base.BaseResourceTest): class TestResource(base.BaseResourceTest):
@ -26,6 +27,41 @@ input:
sample2 = self.create_resource('sample2', sample_meta_dir, {}) sample2 = self.create_resource('sample2', sample_meta_dir, {})
self.assertEqual(sample2.args['value'].value, 0) self.assertEqual(sample2.args['value'].value, 0)
def test_connections_recreated_after_load(self):
"""
Create resource in some process. Then in other process load it.
All connections should remain the same.
"""
sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
value:
schema: int
value: 0
""")
def creating_process():
sample1 = self.create_resource(
'sample1', sample_meta_dir, {'value': 1}
)
sample2 = self.create_resource(
'sample2', sample_meta_dir, {}
)
signals.connect(sample1, sample2)
self.assertEqual(sample1.args['value'], sample2.args['value'])
creating_process()
signals.CLIENTS = {}
sample1 = resource.load('sample1')
sample2 = resource.load('sample2')
sample1.update({'value': 2})
self.assertEqual(sample1.args['value'], sample2.args['value'])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()