diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py index 3486a53a6..40c1fcc82 100644 --- a/oslo/messaging/_drivers/amqpdriver.py +++ b/oslo/messaging/_drivers/amqpdriver.py @@ -135,10 +135,11 @@ class ReplyWaiters(object): class ReplyWaiter(object): - def __init__(self, conf, reply_q, conn): + def __init__(self, conf, reply_q, conn, allowed_remote_exmods): self.conf = conf self.conn = conn self.reply_q = reply_q + self.allowed_remote_exmods = allowed_remote_exmods self.conn_lock = threading.Lock() self.incoming = [] @@ -163,8 +164,8 @@ class ReplyWaiter(object): self.msg_id_cache.check_duplicate_message(data) if data['failure']: failure = data['failure'] - result = rpc_common.deserialize_remote_exception(self.conf, - failure) + result = rpc_common.deserialize_remote_exception( + failure, self.allowed_remote_exmods) elif data.get('ending', False): ending = True else: @@ -241,8 +242,10 @@ class ReplyWaiter(object): class AMQPDriverBase(base.BaseDriver): - def __init__(self, conf, connection_pool, url=None, default_exchange=None): - super(AMQPDriverBase, self).__init__(conf, url, default_exchange) + def __init__(self, conf, connection_pool, url=None, default_exchange=None, + allowed_remote_exmods=[]): + super(AMQPDriverBase, self).__init__(conf, url, default_exchange, + allowed_remote_exmods) self._default_exchange = urls.exchange_from_url(url, default_exchange) @@ -271,7 +274,8 @@ class AMQPDriverBase(base.BaseDriver): conn = self._get_connection(pooled=False) - self._waiter = ReplyWaiter(self.conf, reply_q, conn) + self._waiter = ReplyWaiter(self.conf, reply_q, conn, + self._allowed_remote_exmods) self._reply_q = reply_q self._reply_q_conn = conn diff --git a/oslo/messaging/_drivers/base.py b/oslo/messaging/_drivers/base.py index 0868d6d1d..085b28dd6 100644 --- a/oslo/messaging/_drivers/base.py +++ b/oslo/messaging/_drivers/base.py @@ -55,10 +55,12 @@ class BaseDriver(object): __metaclass__ = abc.ABCMeta - def __init__(self, conf, url=None, default_exchange=None): + def __init__(self, conf, url=None, default_exchange=None, + allowed_remote_exmods=[]): self.conf = conf self._url = url self._default_exchange = default_exchange + self._allowed_remote_exmods = allowed_remote_exmods @abc.abstractmethod def send(self, target, ctxt, message, diff --git a/oslo/messaging/_drivers/common.py b/oslo/messaging/_drivers/common.py index b9b68d6e6..9fd169cb0 100644 --- a/oslo/messaging/_drivers/common.py +++ b/oslo/messaging/_drivers/common.py @@ -73,7 +73,6 @@ _MESSAGE_KEY = 'oslo.message' _REMOTE_POSTFIX = '_Remote' -# FIXME(markmc): add an API to replace this option _exception_opts = [ cfg.ListOpt('allowed_rpc_exception_modules', default=['oslo.messaging.exceptions', @@ -330,7 +329,7 @@ def serialize_remote_exception(failure_info, log_failure=True): return json_data -def deserialize_remote_exception(conf, data): +def deserialize_remote_exception(data, allowed_remote_exmods): failure = jsonutils.loads(str(data)) trace = failure.get('tb', []) @@ -340,8 +339,7 @@ def deserialize_remote_exception(conf, data): # NOTE(ameade): We DO NOT want to allow just any module to be imported, in # order to prevent arbitrary code execution. - conf.register_opts(_exception_opts) - if module not in conf.allowed_rpc_exception_modules: + if module != 'exceptions' and module not in allowed_remote_exmods: return messaging.RemoteError(name, failure.get('message'), trace) try: diff --git a/oslo/messaging/_drivers/impl_fake.py b/oslo/messaging/_drivers/impl_fake.py index ffce9677a..42ddd86b4 100644 --- a/oslo/messaging/_drivers/impl_fake.py +++ b/oslo/messaging/_drivers/impl_fake.py @@ -87,8 +87,10 @@ class FakeExchange(object): class FakeDriver(base.BaseDriver): - def __init__(self, conf, url=None, default_exchange=None): - super(FakeDriver, self).__init__(conf, url, default_exchange) + def __init__(self, conf, url=None, default_exchange=None, + allowed_remote_exmods=[]): + super(FakeDriver, self).__init__(conf, url, default_exchange, + allowed_remote_exmods=[]) self._default_exchange = urls.exchange_from_url(url, default_exchange) diff --git a/oslo/messaging/_drivers/impl_qpid.py b/oslo/messaging/_drivers/impl_qpid.py index cdb77924a..e159277ce 100644 --- a/oslo/messaging/_drivers/impl_qpid.py +++ b/oslo/messaging/_drivers/impl_qpid.py @@ -742,11 +742,13 @@ def cleanup(): class QpidDriver(amqpdriver.AMQPDriverBase): - def __init__(self, conf, url=None, default_exchange=None): + def __init__(self, conf, url=None, default_exchange=None, + allowed_remote_exmods=[]): conf.register_opts(qpid_opts) conf.register_opts(rpc_amqp.amqp_opts) connection_pool = rpc_amqp.get_connection_pool(conf, Connection) super(QpidDriver, self).__init__(conf, connection_pool, - url, default_exchange) + url, default_exchange, + allowed_remote_exmods) diff --git a/oslo/messaging/_drivers/impl_rabbit.py b/oslo/messaging/_drivers/impl_rabbit.py index d997d78ec..71f65b49f 100644 --- a/oslo/messaging/_drivers/impl_rabbit.py +++ b/oslo/messaging/_drivers/impl_rabbit.py @@ -873,11 +873,13 @@ def cleanup(): class RabbitDriver(amqpdriver.AMQPDriverBase): - def __init__(self, conf, url=None, default_exchange=None): + def __init__(self, conf, url=None, default_exchange=None, + allowed_remote_exmods=[]): conf.register_opts(rabbit_opts) conf.register_opts(rpc_amqp.amqp_opts) connection_pool = rpc_amqp.get_connection_pool(conf, Connection) super(RabbitDriver, self).__init__(conf, connection_pool, - url, default_exchange) + url, default_exchange, + allowed_remote_exmods) diff --git a/oslo/messaging/transport.py b/oslo/messaging/transport.py index 20083688f..87572f5ff 100644 --- a/oslo/messaging/transport.py +++ b/oslo/messaging/transport.py @@ -119,7 +119,7 @@ class DriverLoadFailure(exceptions.MessagingException): self.ex = ex -def get_transport(conf, url=None): +def get_transport(conf, url=None, allowed_remote_exmods=[]): """A factory method for Transport objects. This method will construct a Transport object from transport configuration @@ -140,6 +140,9 @@ def get_transport(conf, url=None): :type conf: cfg.ConfigOpts :param url: a transport URL :type url: str + :param allowed_remote_exmods: a list of modules which a client using this + transport will deserialize remote exceptions from + :type allowed_remote_exmods: list """ conf.register_opts(_transport_opts) @@ -151,7 +154,8 @@ def get_transport(conf, url=None): else: rpc_backend = conf.rpc_backend - kwargs = dict(default_exchange=conf.control_exchange) + kwargs = dict(default_exchange=conf.control_exchange, + allowed_remote_exmods=allowed_remote_exmods) if url is not None: kwargs['url'] = url diff --git a/tests/test_exception_serialization.py b/tests/test_exception_serialization.py index 71195b806..8884c60d3 100644 --- a/tests/test_exception_serialization.py +++ b/tests/test_exception_serialization.py @@ -150,7 +150,7 @@ SerializeRemoteExceptionTestCase.generate_scenarios() class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): - _standard_allowed = [__name__, 'exceptions'] + _standard_allowed = [__name__] scenarios = [ ('bog_standard', @@ -203,18 +203,18 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): remote_kwargs={})), ('not_allowed', dict(allowed=[], - clsname='Exception', - modname='exceptions', + clsname='NovaStyleException', + modname=__name__, cls=messaging.RemoteError, args=[], kwargs={}, - str=("Remote error: Exception test\n" + str=("Remote error: NovaStyleException test\n" "[u'traceback\\ntraceback\\n']."), - msg=("Remote error: Exception test\n" + msg=("Remote error: NovaStyleException test\n" "[u'traceback\\ntraceback\\n']."), remote_name='RemoteError', remote_args=(), - remote_kwargs={'exc_type': 'Exception', + remote_kwargs={'exc_type': 'NovaStyleException', 'value': 'test', 'traceback': 'traceback\ntraceback\n'})), ('unknown_module', @@ -234,7 +234,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): 'value': 'test', 'traceback': 'traceback\ntraceback\n'})), ('unknown_exception', - dict(allowed=['exceptions'], + dict(allowed=[], clsname='FarcicalError', modname='exceptions', cls=messaging.RemoteError, @@ -250,7 +250,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): 'value': 'test', 'traceback': 'traceback\ntraceback\n'})), ('unknown_kwarg', - dict(allowed=['exceptions'], + dict(allowed=[], clsname='Exception', modname='exceptions', cls=messaging.RemoteError, @@ -266,7 +266,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): 'value': 'test', 'traceback': 'traceback\ntraceback\n'})), ('system_exit', - dict(allowed=['exceptions'], + dict(allowed=[], clsname='SystemExit', modname='exceptions', cls=messaging.RemoteError, @@ -283,13 +283,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): 'traceback': 'traceback\ntraceback\n'})), ] - def setUp(self): - super(DeserializeRemoteExceptionTestCase, self).setUp() - self.conf.register_opts(exceptions._exception_opts) - def test_deserialize_remote_exception(self): - self.config(allowed_rpc_exception_modules=self.allowed) - failure = { 'class': self.clsname, 'module': self.modname, @@ -301,7 +295,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): serialized = jsonutils.dumps(failure) - ex = exceptions.deserialize_remote_exception(self.conf, serialized) + ex = exceptions.deserialize_remote_exception(serialized, self.allowed) self.assertIsInstance(ex, self.cls) self.assertEqual(ex.__class__.__name__, self.remote_name) diff --git a/tests/test_transport.py b/tests/test_transport.py index 0722ac554..bbdc0b796 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -54,34 +54,46 @@ class GetTransportTestCase(test_utils.BaseTestCase): scenarios = [ ('all_none', dict(url=None, transport_url=None, rpc_backend=None, - control_exchange=None, + control_exchange=None, allowed=None, expect=dict(backend=None, exchange=None, - url=None))), + url=None, + allowed=[]))), ('rpc_backend', dict(url=None, transport_url=None, rpc_backend='testbackend', - control_exchange=None, + control_exchange=None, allowed=None, expect=dict(backend='testbackend', exchange=None, - url=None))), + url=None, + allowed=[]))), ('control_exchange', dict(url=None, transport_url=None, rpc_backend=None, - control_exchange='testexchange', + control_exchange='testexchange', allowed=None, expect=dict(backend=None, exchange='testexchange', - url=None))), + url=None, + allowed=[]))), ('transport_url', dict(url=None, transport_url='testtransport:', rpc_backend=None, - control_exchange=None, + control_exchange=None, allowed=None, expect=dict(backend='testtransport', exchange=None, - url='testtransport:'))), + url='testtransport:', + allowed=[]))), ('url_param', dict(url='testtransport:', transport_url=None, rpc_backend=None, - control_exchange=None, + control_exchange=None, allowed=None, expect=dict(backend='testtransport', exchange=None, - url='testtransport:'))), + url='testtransport:', + allowed=[]))), + ('allowed_remote_exmods', + dict(url=None, transport_url=None, rpc_backend=None, + control_exchange=None, allowed=['foo', 'bar'], + expect=dict(backend=None, + exchange=None, + url=None, + allowed=['foo', 'bar']))), ] def setUp(self): @@ -96,7 +108,8 @@ class GetTransportTestCase(test_utils.BaseTestCase): self.mox.StubOutWithMock(driver, 'DriverManager') invoke_args = [self.conf] - invoke_kwds = dict(default_exchange=self.expect['exchange']) + invoke_kwds = dict(default_exchange=self.expect['exchange'], + allowed_remote_exmods=self.expect['allowed']) if self.expect['url']: invoke_kwds['url'] = self.expect['url'] @@ -110,7 +123,10 @@ class GetTransportTestCase(test_utils.BaseTestCase): self.mox.ReplayAll() - transport = messaging.get_transport(self.conf, url=self.url) + kwargs = dict(url=self.url) + if self.allowed is not None: + kwargs['allowed_remote_exmods'] = self.allowed + transport = messaging.get_transport(self.conf, **kwargs) self.assertIsNotNone(transport) self.assertIs(transport.conf, self.conf) @@ -149,7 +165,8 @@ class GetTransportSadPathTestCase(test_utils.BaseTestCase): self.mox.StubOutWithMock(driver, 'DriverManager') invoke_args = [self.conf] - invoke_kwds = dict(default_exchange='openstack') + invoke_kwds = dict(default_exchange='openstack', + allowed_remote_exmods=[]) driver.DriverManager('oslo.messaging.drivers', self.rpc_backend,