diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py index 4549fd6d8..fb001f10b 100644 --- a/oslo/messaging/_drivers/amqpdriver.py +++ b/oslo/messaging/_drivers/amqpdriver.py @@ -307,7 +307,6 @@ class AMQPDriverBase(base.BaseDriver): try: with self._get_connection() as conn: - # FIXME(markmc): check that target.topic is set if target.fanout: conn.fanout_send(target.topic, msg) else: @@ -326,8 +325,6 @@ class AMQPDriverBase(base.BaseDriver): self._waiter.unlisten(msg_id) def listen(self, target): - # FIXME(markmc): check that topic.target and topic.server is set - conn = self._get_connection(pooled=False) listener = AMQPListener(self, target, conn) diff --git a/oslo/messaging/_drivers/impl_fake.py b/oslo/messaging/_drivers/impl_fake.py index 4be1d8d6c..6afbfad37 100644 --- a/oslo/messaging/_drivers/impl_fake.py +++ b/oslo/messaging/_drivers/impl_fake.py @@ -25,14 +25,6 @@ from oslo.messaging._drivers import base from oslo.messaging import _urls as urls -class InvalidTarget(base.TransportDriverError, ValueError): - - def __init__(self, msg, target): - msg = msg + ":" + str(target) - super(InvalidTarget, self).__init__(msg) - self.target = target - - class FakeIncomingMessage(base.IncomingMessage): def __init__(self, listener, ctxt, message, reply_q): @@ -122,13 +114,6 @@ class FakeDriver(base.BaseDriver): def send(self, target, ctxt, message, wait_for_reply=None, timeout=None, envelope=False): - if not target.topic: - raise InvalidTarget('A topic is required to send', target) - - # FIXME(markmc): preconditions to enforce: - # - timeout and not wait_for_reply - # - target.fanout and (wait_for_reply or timeout) - self._check_serialize(message) exchange = self._get_exchange(target.exchange or @@ -153,10 +138,6 @@ class FakeDriver(base.BaseDriver): return None def listen(self, target): - if not (target.topic and target.server): - raise InvalidTarget('Topic and server are required to listen', - target) - exchange = self._get_exchange(target.exchange or self._default_exchange) diff --git a/oslo/messaging/exceptions.py b/oslo/messaging/exceptions.py index 32373c4df..b6bbb74be 100644 --- a/oslo/messaging/exceptions.py +++ b/oslo/messaging/exceptions.py @@ -13,7 +13,7 @@ # License for the specific language governing permissions and limitations # under the License. -__all__ = ['MessagingException', 'MessagingTimeout'] +__all__ = ['MessagingException', 'MessagingTimeout', 'InvalidTarget'] class MessagingException(Exception): @@ -28,3 +28,12 @@ class MessagingException(Exception): class MessagingTimeout(MessagingException): """Raised if message sending times out.""" + + +class InvalidTarget(MessagingException, ValueError): + """Raised if a target does not meet certain pre-conditions.""" + + def __init__(self, msg, target): + msg = msg + ":" + str(target) + super(InvalidTarget, self).__init__(msg) + self.target = target diff --git a/oslo/messaging/transport.py b/oslo/messaging/transport.py index 390899262..f4667c83b 100644 --- a/oslo/messaging/transport.py +++ b/oslo/messaging/transport.py @@ -77,12 +77,19 @@ class Transport(object): def _send(self, target, ctxt, message, wait_for_reply=None, timeout=None, envelope=False): + if not target.topic: + raise exceptions.InvalidTarget('A topic is required to send', + target) return self._driver.send(target, ctxt, message, wait_for_reply=wait_for_reply, timeout=timeout, envelope=envelope) def _listen(self, target): + if not (target.topic and target.server): + raise exceptions.InvalidTarget('A server\'s target must have ' + 'topic and server names specified', + target) return self._driver.listen(target) def cleanup(self): diff --git a/tests/test_rpc_server.py b/tests/test_rpc_server.py index 28731f6cd..71e3c5644 100644 --- a/tests/test_rpc_server.py +++ b/tests/test_rpc_server.py @@ -114,7 +114,7 @@ class TestRPCServer(test_utils.BaseTestCase, ServerSetupMixin): try: server.start() except Exception as ex: - self.assertIsInstance(ex, messaging.ServerListenError, ex) + self.assertIsInstance(ex, messaging.InvalidTarget, ex) self.assertEqual(ex.target.topic, 'testtopic') else: self.assertTrue(False) @@ -126,7 +126,7 @@ class TestRPCServer(test_utils.BaseTestCase, ServerSetupMixin): try: server.start() except Exception as ex: - self.assertIsInstance(ex, messaging.ServerListenError, ex) + self.assertIsInstance(ex, messaging.InvalidTarget, ex) self.assertEqual(ex.target.server, 'testserver') else: self.assertTrue(False) @@ -141,7 +141,7 @@ class TestRPCServer(test_utils.BaseTestCase, ServerSetupMixin): try: method({}, 'ping', arg='foo') except Exception as ex: - self.assertIsInstance(ex, messaging.ClientSendError, ex) + self.assertIsInstance(ex, messaging.InvalidTarget, ex) self.assertIsNotNone(ex.target) else: self.assertTrue(False) diff --git a/tests/test_transport.py b/tests/test_transport.py index 9240f964b..55e581c37 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -231,29 +231,31 @@ class TestSetDefaults(test_utils.BaseTestCase): class TestTransportMethodArgs(test_utils.BaseTestCase): + _target = messaging.Target(topic='topic', server='server') + def test_send_defaults(self): t = transport.Transport(_FakeDriver(cfg.CONF)) self.mox.StubOutWithMock(t._driver, 'send') - t._driver.send('target', 'ctxt', 'message', + t._driver.send(self._target, 'ctxt', 'message', wait_for_reply=None, timeout=None, envelope=False) self.mox.ReplayAll() - t._send('target', 'ctxt', 'message') + t._send(self._target, 'ctxt', 'message') def test_send_all_args(self): t = transport.Transport(_FakeDriver(cfg.CONF)) self.mox.StubOutWithMock(t._driver, 'send') - t._driver.send('target', 'ctxt', 'message', + t._driver.send(self._target, 'ctxt', 'message', wait_for_reply='wait_for_reply', timeout='timeout', envelope='envelope') self.mox.ReplayAll() - t._send('target', 'ctxt', 'message', + t._send(self._target, 'ctxt', 'message', wait_for_reply='wait_for_reply', timeout='timeout', envelope='envelope') @@ -262,7 +264,7 @@ class TestTransportMethodArgs(test_utils.BaseTestCase): t = transport.Transport(_FakeDriver(cfg.CONF)) self.mox.StubOutWithMock(t._driver, 'listen') - t._driver.listen('target') + t._driver.listen(self._target) self.mox.ReplayAll() - t._listen('target') + t._listen(self._target)