Add a per-transport allow_remote_exmods API

Currently we have a allowed_rpc_exception_modules configuration variable
which we use to configure a per-project list of modules which we will
allow exceptions to be instantiated from when deserializing remote
errors.

It makes no sense for this to be user configurable, instead the list of
modules should be set when you create a transport.

Closes-Bug: #1031719
Change-Id: Ib40e92cb920996ec5e8f63d6f2cbd88fd01a90f2
This commit is contained in:
Mark McLoughlin 2013-08-07 13:07:05 +01:00
parent 66f597f30d
commit ac2176cde3
9 changed files with 73 additions and 48 deletions

View File

@ -135,10 +135,11 @@ class ReplyWaiters(object):
class ReplyWaiter(object):
def __init__(self, conf, reply_q, conn):
def __init__(self, conf, reply_q, conn, allowed_remote_exmods):
self.conf = conf
self.conn = conn
self.reply_q = reply_q
self.allowed_remote_exmods = allowed_remote_exmods
self.conn_lock = threading.Lock()
self.incoming = []
@ -163,8 +164,8 @@ class ReplyWaiter(object):
self.msg_id_cache.check_duplicate_message(data)
if data['failure']:
failure = data['failure']
result = rpc_common.deserialize_remote_exception(self.conf,
failure)
result = rpc_common.deserialize_remote_exception(
failure, self.allowed_remote_exmods)
elif data.get('ending', False):
ending = True
else:
@ -241,8 +242,10 @@ class ReplyWaiter(object):
class AMQPDriverBase(base.BaseDriver):
def __init__(self, conf, connection_pool, url=None, default_exchange=None):
super(AMQPDriverBase, self).__init__(conf, url, default_exchange)
def __init__(self, conf, connection_pool, url=None, default_exchange=None,
allowed_remote_exmods=[]):
super(AMQPDriverBase, self).__init__(conf, url, default_exchange,
allowed_remote_exmods)
self._default_exchange = urls.exchange_from_url(url, default_exchange)
@ -271,7 +274,8 @@ class AMQPDriverBase(base.BaseDriver):
conn = self._get_connection(pooled=False)
self._waiter = ReplyWaiter(self.conf, reply_q, conn)
self._waiter = ReplyWaiter(self.conf, reply_q, conn,
self._allowed_remote_exmods)
self._reply_q = reply_q
self._reply_q_conn = conn

View File

@ -55,10 +55,12 @@ class BaseDriver(object):
__metaclass__ = abc.ABCMeta
def __init__(self, conf, url=None, default_exchange=None):
def __init__(self, conf, url=None, default_exchange=None,
allowed_remote_exmods=[]):
self.conf = conf
self._url = url
self._default_exchange = default_exchange
self._allowed_remote_exmods = allowed_remote_exmods
@abc.abstractmethod
def send(self, target, ctxt, message,

View File

@ -73,7 +73,6 @@ _MESSAGE_KEY = 'oslo.message'
_REMOTE_POSTFIX = '_Remote'
# FIXME(markmc): add an API to replace this option
_exception_opts = [
cfg.ListOpt('allowed_rpc_exception_modules',
default=['oslo.messaging.exceptions',
@ -330,7 +329,7 @@ def serialize_remote_exception(failure_info, log_failure=True):
return json_data
def deserialize_remote_exception(conf, data):
def deserialize_remote_exception(data, allowed_remote_exmods):
failure = jsonutils.loads(str(data))
trace = failure.get('tb', [])
@ -340,8 +339,7 @@ def deserialize_remote_exception(conf, data):
# NOTE(ameade): We DO NOT want to allow just any module to be imported, in
# order to prevent arbitrary code execution.
conf.register_opts(_exception_opts)
if module not in conf.allowed_rpc_exception_modules:
if module != 'exceptions' and module not in allowed_remote_exmods:
return messaging.RemoteError(name, failure.get('message'), trace)
try:

View File

@ -87,8 +87,10 @@ class FakeExchange(object):
class FakeDriver(base.BaseDriver):
def __init__(self, conf, url=None, default_exchange=None):
super(FakeDriver, self).__init__(conf, url, default_exchange)
def __init__(self, conf, url=None, default_exchange=None,
allowed_remote_exmods=[]):
super(FakeDriver, self).__init__(conf, url, default_exchange,
allowed_remote_exmods=[])
self._default_exchange = urls.exchange_from_url(url, default_exchange)

View File

@ -742,11 +742,13 @@ def cleanup():
class QpidDriver(amqpdriver.AMQPDriverBase):
def __init__(self, conf, url=None, default_exchange=None):
def __init__(self, conf, url=None, default_exchange=None,
allowed_remote_exmods=[]):
conf.register_opts(qpid_opts)
conf.register_opts(rpc_amqp.amqp_opts)
connection_pool = rpc_amqp.get_connection_pool(conf, Connection)
super(QpidDriver, self).__init__(conf, connection_pool,
url, default_exchange)
url, default_exchange,
allowed_remote_exmods)

View File

@ -873,11 +873,13 @@ def cleanup():
class RabbitDriver(amqpdriver.AMQPDriverBase):
def __init__(self, conf, url=None, default_exchange=None):
def __init__(self, conf, url=None, default_exchange=None,
allowed_remote_exmods=[]):
conf.register_opts(rabbit_opts)
conf.register_opts(rpc_amqp.amqp_opts)
connection_pool = rpc_amqp.get_connection_pool(conf, Connection)
super(RabbitDriver, self).__init__(conf, connection_pool,
url, default_exchange)
url, default_exchange,
allowed_remote_exmods)

View File

@ -119,7 +119,7 @@ class DriverLoadFailure(exceptions.MessagingException):
self.ex = ex
def get_transport(conf, url=None):
def get_transport(conf, url=None, allowed_remote_exmods=[]):
"""A factory method for Transport objects.
This method will construct a Transport object from transport configuration
@ -140,6 +140,9 @@ def get_transport(conf, url=None):
:type conf: cfg.ConfigOpts
:param url: a transport URL
:type url: str
:param allowed_remote_exmods: a list of modules which a client using this
transport will deserialize remote exceptions from
:type allowed_remote_exmods: list
"""
conf.register_opts(_transport_opts)
@ -151,7 +154,8 @@ def get_transport(conf, url=None):
else:
rpc_backend = conf.rpc_backend
kwargs = dict(default_exchange=conf.control_exchange)
kwargs = dict(default_exchange=conf.control_exchange,
allowed_remote_exmods=allowed_remote_exmods)
if url is not None:
kwargs['url'] = url

View File

@ -150,7 +150,7 @@ SerializeRemoteExceptionTestCase.generate_scenarios()
class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase):
_standard_allowed = [__name__, 'exceptions']
_standard_allowed = [__name__]
scenarios = [
('bog_standard',
@ -203,18 +203,18 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase):
remote_kwargs={})),
('not_allowed',
dict(allowed=[],
clsname='Exception',
modname='exceptions',
clsname='NovaStyleException',
modname=__name__,
cls=messaging.RemoteError,
args=[],
kwargs={},
str=("Remote error: Exception test\n"
str=("Remote error: NovaStyleException test\n"
"[u'traceback\\ntraceback\\n']."),
msg=("Remote error: Exception test\n"
msg=("Remote error: NovaStyleException test\n"
"[u'traceback\\ntraceback\\n']."),
remote_name='RemoteError',
remote_args=(),
remote_kwargs={'exc_type': 'Exception',
remote_kwargs={'exc_type': 'NovaStyleException',
'value': 'test',
'traceback': 'traceback\ntraceback\n'})),
('unknown_module',
@ -234,7 +234,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase):
'value': 'test',
'traceback': 'traceback\ntraceback\n'})),
('unknown_exception',
dict(allowed=['exceptions'],
dict(allowed=[],
clsname='FarcicalError',
modname='exceptions',
cls=messaging.RemoteError,
@ -250,7 +250,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase):
'value': 'test',
'traceback': 'traceback\ntraceback\n'})),
('unknown_kwarg',
dict(allowed=['exceptions'],
dict(allowed=[],
clsname='Exception',
modname='exceptions',
cls=messaging.RemoteError,
@ -266,7 +266,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase):
'value': 'test',
'traceback': 'traceback\ntraceback\n'})),
('system_exit',
dict(allowed=['exceptions'],
dict(allowed=[],
clsname='SystemExit',
modname='exceptions',
cls=messaging.RemoteError,
@ -283,13 +283,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase):
'traceback': 'traceback\ntraceback\n'})),
]
def setUp(self):
super(DeserializeRemoteExceptionTestCase, self).setUp()
self.conf.register_opts(exceptions._exception_opts)
def test_deserialize_remote_exception(self):
self.config(allowed_rpc_exception_modules=self.allowed)
failure = {
'class': self.clsname,
'module': self.modname,
@ -301,7 +295,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase):
serialized = jsonutils.dumps(failure)
ex = exceptions.deserialize_remote_exception(self.conf, serialized)
ex = exceptions.deserialize_remote_exception(serialized, self.allowed)
self.assertIsInstance(ex, self.cls)
self.assertEqual(ex.__class__.__name__, self.remote_name)

View File

@ -54,34 +54,46 @@ class GetTransportTestCase(test_utils.BaseTestCase):
scenarios = [
('all_none',
dict(url=None, transport_url=None, rpc_backend=None,
control_exchange=None,
control_exchange=None, allowed=None,
expect=dict(backend=None,
exchange=None,
url=None))),
url=None,
allowed=[]))),
('rpc_backend',
dict(url=None, transport_url=None, rpc_backend='testbackend',
control_exchange=None,
control_exchange=None, allowed=None,
expect=dict(backend='testbackend',
exchange=None,
url=None))),
url=None,
allowed=[]))),
('control_exchange',
dict(url=None, transport_url=None, rpc_backend=None,
control_exchange='testexchange',
control_exchange='testexchange', allowed=None,
expect=dict(backend=None,
exchange='testexchange',
url=None))),
url=None,
allowed=[]))),
('transport_url',
dict(url=None, transport_url='testtransport:', rpc_backend=None,
control_exchange=None,
control_exchange=None, allowed=None,
expect=dict(backend='testtransport',
exchange=None,
url='testtransport:'))),
url='testtransport:',
allowed=[]))),
('url_param',
dict(url='testtransport:', transport_url=None, rpc_backend=None,
control_exchange=None,
control_exchange=None, allowed=None,
expect=dict(backend='testtransport',
exchange=None,
url='testtransport:'))),
url='testtransport:',
allowed=[]))),
('allowed_remote_exmods',
dict(url=None, transport_url=None, rpc_backend=None,
control_exchange=None, allowed=['foo', 'bar'],
expect=dict(backend=None,
exchange=None,
url=None,
allowed=['foo', 'bar']))),
]
def setUp(self):
@ -96,7 +108,8 @@ class GetTransportTestCase(test_utils.BaseTestCase):
self.mox.StubOutWithMock(driver, 'DriverManager')
invoke_args = [self.conf]
invoke_kwds = dict(default_exchange=self.expect['exchange'])
invoke_kwds = dict(default_exchange=self.expect['exchange'],
allowed_remote_exmods=self.expect['allowed'])
if self.expect['url']:
invoke_kwds['url'] = self.expect['url']
@ -110,7 +123,10 @@ class GetTransportTestCase(test_utils.BaseTestCase):
self.mox.ReplayAll()
transport = messaging.get_transport(self.conf, url=self.url)
kwargs = dict(url=self.url)
if self.allowed is not None:
kwargs['allowed_remote_exmods'] = self.allowed
transport = messaging.get_transport(self.conf, **kwargs)
self.assertIsNotNone(transport)
self.assertIs(transport.conf, self.conf)
@ -149,7 +165,8 @@ class GetTransportSadPathTestCase(test_utils.BaseTestCase):
self.mox.StubOutWithMock(driver, 'DriverManager')
invoke_args = [self.conf]
invoke_kwds = dict(default_exchange='openstack')
invoke_kwds = dict(default_exchange='openstack',
allowed_remote_exmods=[])
driver.DriverManager('oslo.messaging.drivers',
self.rpc_backend,