Return timer after adding it to internal list

In some cases, its useful to selectively stop
some timers and not all. This fix allows a way
for client applications to maintain a handle
on timers and be able to stop some of them
selectively.

This is needed by
https://review.openstack.org/#/c/190842/

Change-Id: Ib258c07d6dd12ef7cc9f4a6c8cbde885048ceef8
Closes-Bug: 1466661
This commit is contained in:
Rohit Jaiswal 2015-06-23 17:04:30 +00:00
parent 59ed2ec386
commit 02242929f1
2 changed files with 36 additions and 0 deletions

View File

@ -79,3 +79,34 @@ class ThreadGroupTestCase(test_base.BaseTestCase):
self.assertEqual(1, len(self.tg.timers))
self.tg.stop_timers()
self.assertEqual(0, len(self.tg.timers))
def test_add_and_remove_timer(self):
def foo(*args, **kwargs):
pass
timer = self.tg.add_timer('1234', foo)
self.assertEqual(1, len(self.tg.timers))
timer.stop()
self.assertEqual(1, len(self.tg.timers))
self.tg.timer_done(timer)
self.assertEqual(0, len(self.tg.timers))
def test_add_and_remove_dynamic_timer(self):
def foo(*args, **kwargs):
pass
initial_delay = 1
periodic_interval_max = 2
timer = self.tg.add_dynamic_timer(foo, initial_delay,
periodic_interval_max)
self.assertEqual(1, len(self.tg.timers))
self.assertTrue(timer._running)
timer.stop()
self.assertEqual(1, len(self.tg.timers))
self.tg.timer_done(timer)
self.assertEqual(0, len(self.tg.timers))

View File

@ -69,6 +69,7 @@ class ThreadGroup(object):
timer.start(initial_delay=initial_delay,
periodic_interval_max=periodic_interval_max)
self.timers.append(timer)
return timer
def add_timer(self, interval, callback, initial_delay=None,
*args, **kwargs):
@ -76,6 +77,7 @@ class ThreadGroup(object):
pulse.start(interval=interval,
initial_delay=initial_delay)
self.timers.append(pulse)
return pulse
def add_thread(self, callback, *args, **kwargs):
gt = self.pool.spawn(callback, *args, **kwargs)
@ -86,6 +88,9 @@ class ThreadGroup(object):
def thread_done(self, thread):
self.threads.remove(thread)
def timer_done(self, timer):
self.timers.remove(timer)
def _stop_threads(self):
current = threading.current_thread()