233 lines
7.3 KiB
Python
233 lines
7.3 KiB
Python
# Copyright 2016 Mirantis, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
# not use this file except in compliance with the License. You may obtain
|
|
# a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import logging
|
|
import logging.handlers
|
|
import multiprocessing
|
|
import os
|
|
import sys
|
|
import threading
|
|
import time
|
|
import uuid
|
|
|
|
from oslo_config import cfg
|
|
|
|
import oslo_messaging
|
|
from oslo_messaging._drivers.zmq_driver import zmq_async
|
|
from oslo_messaging.tests.functional import utils
|
|
|
|
|
|
zmq = zmq_async.import_zmq()
|
|
|
|
LOG = logging.getLogger(__name__)
|
|
|
|
|
|
class QueueHandler(logging.Handler):
|
|
"""This is a logging handler which sends events to a multiprocessing queue.
|
|
|
|
The plan is to add it to Python 3.2, but this can be copy pasted into
|
|
user code for use with earlier Python versions.
|
|
"""
|
|
|
|
def __init__(self, queue):
|
|
"""Initialise an instance, using the passed queue."""
|
|
logging.Handler.__init__(self)
|
|
self.queue = queue
|
|
|
|
def emit(self, record):
|
|
"""Emit a record.
|
|
|
|
Writes the LogRecord to the queue.
|
|
"""
|
|
try:
|
|
ei = record.exc_info
|
|
if ei:
|
|
# just to get traceback text into record.exc_text
|
|
dummy = self.format(record) # noqa
|
|
record.exc_info = None # not needed any more
|
|
self.queue.put_nowait(record)
|
|
except (KeyboardInterrupt, SystemExit):
|
|
raise
|
|
except Exception:
|
|
self.handleError(record)
|
|
|
|
|
|
def listener_configurer(conf):
|
|
root = logging.getLogger()
|
|
h = logging.StreamHandler(sys.stdout)
|
|
f = logging.Formatter('%(asctime)s %(processName)-10s %(name)s '
|
|
'%(levelname)-8s %(message)s')
|
|
h.setFormatter(f)
|
|
root.addHandler(h)
|
|
log_path = conf.rpc_zmq_ipc_dir + "/" + "zmq_multiproc.log"
|
|
file_handler = logging.StreamHandler(open(log_path, 'w'))
|
|
file_handler.setFormatter(f)
|
|
root.addHandler(file_handler)
|
|
|
|
|
|
def server_configurer(queue):
|
|
h = QueueHandler(queue)
|
|
root = logging.getLogger()
|
|
root.addHandler(h)
|
|
root.setLevel(logging.DEBUG)
|
|
|
|
|
|
def listener_thread(queue, configurer, conf):
|
|
configurer(conf)
|
|
while True:
|
|
time.sleep(0.3)
|
|
try:
|
|
record = queue.get()
|
|
if record is None:
|
|
break
|
|
logger = logging.getLogger(record.name)
|
|
logger.handle(record)
|
|
except (KeyboardInterrupt, SystemExit):
|
|
raise
|
|
|
|
|
|
class Client(oslo_messaging.RPCClient):
|
|
|
|
def __init__(self, transport, topic):
|
|
super(Client, self).__init__(
|
|
transport=transport, target=oslo_messaging.Target(topic=topic))
|
|
self.replies = []
|
|
|
|
def call_a(self):
|
|
LOG.warning("call_a - client side")
|
|
rep = self.call({}, 'call_a')
|
|
LOG.warning("after call_a - client side")
|
|
self.replies.append(rep)
|
|
return rep
|
|
|
|
|
|
class ReplyServerEndpoint(object):
|
|
|
|
def call_a(self, *args, **kwargs):
|
|
LOG.warning("call_a - Server endpoint reached!")
|
|
return "OK"
|
|
|
|
|
|
class Server(object):
|
|
|
|
def __init__(self, conf, log_queue, transport_url, name, topic=None):
|
|
self.conf = conf
|
|
self.log_queue = log_queue
|
|
self.transport_url = transport_url
|
|
self.name = name
|
|
self.topic = topic or str(uuid.uuid4())
|
|
self.ready = multiprocessing.Value('b', False)
|
|
self._stop = multiprocessing.Event()
|
|
|
|
def start(self):
|
|
self.process = multiprocessing.Process(target=self._run_server,
|
|
name=self.name,
|
|
args=(self.conf,
|
|
self.transport_url,
|
|
self.log_queue,
|
|
self.ready))
|
|
self.process.start()
|
|
LOG.debug("Server process started: pid: %d", self.process.pid)
|
|
|
|
def _run_server(self, conf, url, log_queue, ready):
|
|
server_configurer(log_queue)
|
|
LOG.debug("Starting RPC server")
|
|
|
|
transport = oslo_messaging.get_transport(conf, url=url)
|
|
target = oslo_messaging.Target(topic=self.topic, server=self.name)
|
|
self.rpc_server = oslo_messaging.get_rpc_server(
|
|
transport=transport, target=target,
|
|
endpoints=[ReplyServerEndpoint()],
|
|
executor='eventlet')
|
|
self.rpc_server.start()
|
|
ready.value = True
|
|
LOG.debug("RPC server being started")
|
|
while not self._stop.is_set():
|
|
LOG.debug("Waiting for the stop signal ...")
|
|
time.sleep(1)
|
|
self.rpc_server.stop()
|
|
LOG.debug("Leaving process T:%s Pid:%d", str(target), os.getpid())
|
|
|
|
def cleanup(self):
|
|
LOG.debug("Stopping server")
|
|
self.shutdown()
|
|
|
|
def shutdown(self):
|
|
self._stop.set()
|
|
|
|
def restart(self, time_for_restart=1):
|
|
pass
|
|
|
|
def hang(self):
|
|
pass
|
|
|
|
def crash(self):
|
|
pass
|
|
|
|
def ping(self):
|
|
pass
|
|
|
|
|
|
class MutliprocTestCase(utils.SkipIfNoTransportURL):
|
|
|
|
def setUp(self):
|
|
super(MutliprocTestCase, self).setUp(conf=cfg.ConfigOpts())
|
|
|
|
if not self.url.startswith("zmq:"):
|
|
self.skipTest("ZeroMQ specific skipped ...")
|
|
|
|
self.transport = oslo_messaging.get_transport(self.conf, url=self.url)
|
|
|
|
LOG.debug("Start log queue")
|
|
|
|
self.log_queue = multiprocessing.Queue()
|
|
self.log_listener = threading.Thread(target=listener_thread,
|
|
args=(self.log_queue,
|
|
listener_configurer,
|
|
self.conf))
|
|
self.log_listener.start()
|
|
self.spawned = []
|
|
|
|
self.conf.prog = "test_prog"
|
|
self.conf.project = "test_project"
|
|
|
|
def tearDown(self):
|
|
super(MutliprocTestCase, self).tearDown()
|
|
for process in self.spawned:
|
|
process.cleanup()
|
|
|
|
def get_client(self, topic):
|
|
return Client(self.transport, topic)
|
|
|
|
def spawn_server(self, name, wait_for_server=False, topic=None):
|
|
srv = Server(self.conf, self.log_queue, self.url, name, topic)
|
|
LOG.debug("[SPAWN] %s (starting)...", srv.name)
|
|
srv.start()
|
|
if wait_for_server:
|
|
while not srv.ready.value:
|
|
LOG.debug("[SPAWN] %s (waiting for server ready)...",
|
|
srv.name)
|
|
time.sleep(1)
|
|
LOG.debug("[SPAWN] Server %s:%d started.", srv.name, srv.process.pid)
|
|
self.spawned.append(srv)
|
|
return srv
|
|
|
|
def spawn_servers(self, number, wait_for_server=False, random_topic=True):
|
|
common_topic = str(uuid.uuid4()) if random_topic else None
|
|
names = ["server_%i_%s" % (i, str(uuid.uuid4())[:8])
|
|
for i in range(number)]
|
|
for name in names:
|
|
server = self.spawn_server(name, wait_for_server, common_topic)
|
|
self.spawned.append(server)
|