diff --git a/oslo_messaging/_drivers/impl_kafka.py b/oslo_messaging/_drivers/impl_kafka.py index 52e793327..6c90d30cb 100644 --- a/oslo_messaging/_drivers/impl_kafka.py +++ b/oslo_messaging/_drivers/impl_kafka.py @@ -115,35 +115,6 @@ def with_reconnect(retries=None): return decorator -class Producer(object): - _producer = None - _servers = None - _lock = threading.Lock() - - @staticmethod - @with_reconnect() - def connect(servers, **kwargs): - return kafka.KafkaProducer( - bootstrap_servers=servers, - selector=KAFKA_SELECTOR, - **kwargs) - - @classmethod - def producer(cls, servers, **kwargs): - with cls._lock: - if not cls._producer or cls._servers != servers: - cls._servers = servers - cls._producer = cls.connect(servers, **kwargs) - return cls._producer - - @classmethod - def cleanup(cls): - with cls._lock: - if cls._producer: - cls._producer.close() - cls._producer = None - - class Connection(object): def __init__(self, conf, url, purpose): @@ -154,6 +125,7 @@ class Connection(object): self.linger_ms = driver_conf.producer_batch_timeout * 1000 self.conf = conf self.producer = None + self.producer_lock = threading.Lock() self.consumer = None self.consumer_timeout = float(driver_conf.kafka_consumer_timeout) self.max_fetch_bytes = driver_conf.kafka_max_fetch_bytes @@ -189,25 +161,24 @@ class Connection(object): :param msg: messages for publishing :param retry: the number of retry """ - - message = pack_message(ctxt, msg) - self._ensure_connection() - self._send_and_retry(message, topic, retry) - - def _send_and_retry(self, message, topic, retry): - if not isinstance(message, str): - message = jsonutils.dumps(message) retry = retry if retry >= 0 else None + message = pack_message(ctxt, msg) + message = jsonutils.dumps(message) @with_reconnect(retries=retry) - def _send(topic, message): + def wrapped_with_reconnect(): + self._ensure_producer() + # NOTE(sileht): This returns a future, we can use get() + # if we want to block like other driver self.producer.send(topic, message) try: - _send(topic, message) + wrapped_with_reconnect() except Exception: - Producer.cleanup() - LOG.exception(_LE("Failed to send message")) + # NOTE(sileht): if something goes wrong close the producer + # connection + self._close_producer() + raise @with_reconnect() def _poll_messages(self, timeout): @@ -239,12 +210,10 @@ class Connection(object): pass def close(self): - if self.producer: - self.producer.close() - self.producer = None + self._close_producer() if self.consumer: self.consumer.close() - self.consumer = None + self.consumer = None def commit(self): """Commit is used by subscribers belonging to the same group. @@ -257,14 +226,23 @@ class Connection(object): """ self.consumer.commit() - def _ensure_connection(self): - try: - self.producer = Producer.producer(self.hostaddrs, - linger_ms=self.linger_ms, - batch_size=self.batch_size) - except kafka.errors.KafkaError as e: - LOG.exception(_LE("KafkaProducer could not be initialized: %s"), e) - raise + def _close_producer(self): + with self.producer_lock: + if self.producer: + self.producer.close() + self.producer = None + + def _ensure_producer(self): + if self.producer: + return + with self.producer_lock: + if self.producer: + return + self.producer = kafka.KafkaProducer( + bootstrap_servers=self.hostaddrs, + linger_ms=self.linger_ms, + batch_size=self.batch_size, + selector=KAFKA_SELECTOR) @with_reconnect() def declare_topic_consumer(self, topics, group=None): diff --git a/oslo_messaging/tests/drivers/test_impl_kafka.py b/oslo_messaging/tests/drivers/test_impl_kafka.py index 6262aab4f..8d76cdc13 100644 --- a/oslo_messaging/tests/drivers/test_impl_kafka.py +++ b/oslo_messaging/tests/drivers/test_impl_kafka.py @@ -74,7 +74,6 @@ class TestKafkaDriver(test_utils.BaseTestCase): self.messaging_conf.transport_driver = 'kafka' transport = oslo_messaging.get_transport(self.conf) self.driver = transport._driver - self.addCleanup(kafka_driver.Producer.cleanup) def test_send(self): target = oslo_messaging.Target(topic="topic_test") @@ -87,8 +86,10 @@ class TestKafkaDriver(test_utils.BaseTestCase): with mock.patch("kafka.KafkaProducer") as fake_producer_class: fake_producer = fake_producer_class.return_value fake_producer.send.side_effect = kafka.errors.NoBrokersAvailable - self.driver.send_notification(target, {}, {"payload": ["test_1"]}, - None, retry=3) + self.assertRaises(kafka.errors.NoBrokersAvailable, + self.driver.send_notification, + target, {}, {"payload": ["test_1"]}, + None, retry=3) self.assertEqual(3, fake_producer.send.call_count) def test_listen(self): @@ -127,10 +128,11 @@ class TestKafkaConnection(test_utils.BaseTestCase): transport = oslo_messaging.get_transport(self.conf) self.driver = transport._driver - @mock.patch.object(kafka_driver.Connection, '_ensure_connection') - @mock.patch.object(kafka_driver.Connection, '_send_and_retry') - def test_notify(self, fake_send, fake_ensure_connection): + def test_notify(self): conn = self.driver._get_connection(common_driver.PURPOSE_SEND) - conn.notify_send("fake_topic", {"fake_ctxt": "fake_param"}, - {"fake_text": "fake_message_1"}, 10) - self.assertEqual(1, len(fake_send.mock_calls)) + + with mock.patch("kafka.KafkaProducer") as fake_producer_class: + fake_producer = fake_producer_class.return_value + conn.notify_send("fake_topic", {"fake_ctxt": "fake_param"}, + {"fake_text": "fake_message_1"}, 10) + self.assertEqual(1, len(fake_producer.send.mock_calls))