diff --git a/tests/unit/test_worker.py b/tests/unit/test_worker.py index f78b0bc..210e6c2 100644 --- a/tests/unit/test_worker.py +++ b/tests/unit/test_worker.py @@ -181,6 +181,7 @@ class ConsumerTestCase(StacktachBaseTestCase): self.mox.StubOutWithMock(db, 'get_deployment') deployment = self.mox.CreateMockAnything() deployment.id = 1 + stats = self.mox.CreateMockAnything() db.get_deployment(deployment.id).AndReturn(deployment) self.mox.StubOutWithMock(kombu.connection, 'BrokerConnection') params = dict(hostname=config['rabbit_host'], @@ -199,11 +200,11 @@ class ConsumerTestCase(StacktachBaseTestCase): exchange = 'nova' consumer = worker.Consumer(config['name'], conn, deployment, config['durable_queue'], {}, exchange, - self._test_topics()) + self._test_topics(), stats=stats) consumer.run() worker.continue_running().AndReturn(False) self.mox.ReplayAll() - worker.run(config, deployment.id, exchange) + worker.run(config, deployment.id, exchange, stats) self.mox.VerifyAll() def test_run_queue_args(self): @@ -233,6 +234,7 @@ class ConsumerTestCase(StacktachBaseTestCase): deployment = self.mox.CreateMockAnything() deployment.id = 1 db.get_deployment(deployment.id).AndReturn(deployment) + stats = self.mox.CreateMockAnything() self.mox.StubOutWithMock(kombu.connection, 'BrokerConnection') params = dict(hostname=config['rabbit_host'], port=config['rabbit_port'], @@ -251,9 +253,9 @@ class ConsumerTestCase(StacktachBaseTestCase): consumer = worker.Consumer(config['name'], conn, deployment, config['durable_queue'], config['queue_arguments'], exchange, - self._test_topics()) + self._test_topics(), stats=stats) consumer.run() worker.continue_running().AndReturn(False) self.mox.ReplayAll() - worker.run(config, deployment.id, exchange) + worker.run(config, deployment.id, exchange, stats) self.mox.VerifyAll() diff --git a/worker/config.py b/worker/config.py index f924130..422b264 100644 --- a/worker/config.py +++ b/worker/config.py @@ -38,3 +38,11 @@ def deployments(): def topics(): return config['topics'] + + +def workers(): + if 'workers' in config: + return config['workers'] + else: + return dict() + diff --git a/worker/start_workers.py b/worker/start_workers.py index 3654cbc..4f0adf6 100644 --- a/worker/start_workers.py +++ b/worker/start_workers.py @@ -1,8 +1,10 @@ +import datetime import os import signal import sys +import time -from multiprocessing import Process +from multiprocessing import Process, Manager POSSIBLE_TOPDIR = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]), os.pardir, os.pardir)) @@ -15,45 +17,128 @@ from django.db import close_connection import worker.worker as worker from worker import config -processes = [] +processes = {} log_listener = None stacklog.set_default_logger_name('worker') +DEFAULT_PROC_TIMEOUT = 600 +RUNNING = True def _get_parent_logger(): return stacklog.get_logger('worker', is_parent=True) -def kill_time(signal, frame): - print "dying ..." - for process in processes: +def create_proc_table(manager): + for deployment in config.deployments(): + if deployment.get('enabled', True): + name = deployment['name'] + db_deployment, new = db.get_or_create_deployment(name) + for exchange in deployment.get('topics').keys(): + stats = manager.dict() + proc_info = dict(process=None, + pid=0, + deployment=deployment, + deploy_id=db_deployment.id, + exchange=exchange, + stats=stats) + processes[(name, exchange)] = proc_info + + +def is_alive(proc_info): + process = proc_info['process'] + if not proc_info['pid'] or process is None: + return False + return process.is_alive(): + + +def needs_restart(proc_info): + timeout = config.workers().get('process_timeout', DEFAULT_PROC_TIMEOUT) + process = proc_info['process'] + stats = proc_info['stats'] + age = datetime.datetime.utcnow() - stats['timestamp'] + if age > datetime.timedelta(seconds=timeout): process.terminate() - print "rose" - for process in processes: - process.join() - log_listener.end() - print "bud" - sys.exit(0) + return True + return False + + +def start_proc(proc_info): + logger = _get_parent_logger() + if is_alive(proc_info): + if needs_restart(proc_info): + logger.warning("Child process %s (%s %s) terminated due to " + "heartbeat timeout. Restarting..." % (proc_info['pid'], + proc_info['deployment']['name'], proc_info['exchange'])) + else: + return False + stats = proc_info['stats'] + stats['timestamp'] = datetime.datetime.utcnow() + stats['total_processed'] = 0 + stats['processed'] = 0 + args = (proc_info['deployment'], proc_info['deploy_id'], + proc_info['exchange'], stats) + process = Process(target=worker.run, args=args) + process.daemon = True + process.start() + proc_info['pid'] = process.pid + proc_info['process'] = process + logger.info("Started child process %s (%s %s)" % (proc_info['pid'], + proc_info['deployment']['name'], proc_info['exchange'])) + return True + + +def check_or_start_all(): + for proc_name in sorted(processes.keys()): + if RUNNING: + start_proc(processes[proc_name]) + + +def stop_all(): + procs = sorted(processes.keys()) + for pname in procs: + process = processes[pname]['process'] + if process is not None: + process.terminate() + for pname in procs: + process = processes[pname]['process'] + if process is not None: + process.join() + processes[pname]['process'] = None + processes[pname]['pid'] = 0 + + +def kill_time(signal, frame): + global RUNNING + RUNNING = False + stop_all() if __name__ == '__main__': - log_listener = stacklog.LogListener(_get_parent_logger()) + logger = _get_parent_logger() + log_listener = stacklog.LogListener(logger) log_listener.start() - for deployment in config.deployments(): - if deployment.get('enabled', True): - db_deployment, new = db.get_or_create_deployment(deployment['name']) - # NOTE (apmelton) - # Close the connection before spinning up the child process, - # otherwise the child process will attempt to use the connection - # the parent process opened up to get/create the deployment. - close_connection() - for exchange in deployment.get('topics').keys(): - process = Process(target=worker.run, args=(deployment, - db_deployment.id, - exchange,)) - process.daemon = True - process.start() - processes.append(process) + manager = Manager() + + create_proc_table(manager) + + # NOTE (apmelton) + # Close the connection before spinning up the child process, + # otherwise the child process will attempt to use the connection + # the parent process opened up to get/create the deployment. + close_connection() + signal.signal(signal.SIGINT, kill_time) signal.signal(signal.SIGTERM, kill_time) - signal.pause() + + logger.info("Starting Workers...") + while RUNNING: + check_or_start_all() + time.sleep(30) + logger.info("Workers Shutting down...") + + #make sure. + stop_all() + + log_listener.end() + sys.exit(0) + diff --git a/worker/worker.py b/worker/worker.py index 931def1..73079f5 100644 --- a/worker/worker.py +++ b/worker/worker.py @@ -51,7 +51,7 @@ def _get_child_logger(): class Consumer(kombu.mixins.ConsumerMixin): def __init__(self, name, connection, deployment, durable, queue_arguments, - exchange, topics, connect_max_retries=10): + exchange, topics, connect_max_retries=10, stats=None): self.connect_max_retries = connect_max_retries self.retry_attempts = 0 self.connection = connection @@ -65,6 +65,10 @@ class Consumer(kombu.mixins.ConsumerMixin): self.total_processed = 0 self.topics = topics self.exchange = exchange + if stats is not None: + self.stats = stats + else: + self.stats = dict() signal.signal(signal.SIGTERM, self._shutdown) def _create_exchange(self, name, type, exclusive=False, auto_delete=False): @@ -130,6 +134,9 @@ class Consumer(kombu.mixins.ConsumerMixin): "%3d/%4d msgs @ %6dk/msg" % (self.name, self.exchange, diff, idiff, self.processed, self.total_processed, per_message)) + self.stats['timestamp'] = utc + self.stats['total_processed'] = self.total_processed + self.stats['processed'] = self.processed self.last_vsz = self.pmi.vsz self.processed = 0 @@ -177,7 +184,7 @@ def exit_or_sleep(exit=False): time.sleep(5) -def run(deployment_config, deployment_id, exchange): +def run(deployment_config, deployment_id, exchange, stats=None): name = deployment_config['name'] host = deployment_config.get('rabbit_host', 'localhost') port = deployment_config.get('rabbit_port', 5672) @@ -211,7 +218,7 @@ def run(deployment_config, deployment_id, exchange): try: consumer = Consumer(name, conn, deployment, durable, queue_arguments, exchange, - topics[exchange]) + topics[exchange], stats=stats) consumer.run() except Exception as e: logger.error("!!!!Exception!!!!")