diff --git a/swift/common/exceptions.py b/swift/common/exceptions.py index 8eb6879d72..d7ea759d66 100644 --- a/swift/common/exceptions.py +++ b/swift/common/exceptions.py @@ -111,6 +111,10 @@ class LockTimeout(MessageTimeout): pass +class ThreadPoolDead(SwiftException): + pass + + class RingBuilderError(SwiftException): pass diff --git a/swift/common/utils.py b/swift/common/utils.py index 2d54dadf74..f04860a341 100644 --- a/swift/common/utils.py +++ b/swift/common/utils.py @@ -2737,6 +2737,7 @@ class ThreadPool(object): self._run_queue = Queue() self._result_queue = Queue() self._threads = [] + self._alive = True if nthreads <= 0: return @@ -2784,6 +2785,8 @@ class ThreadPool(object): """ while True: item = work_queue.get() + if item is None: + break ev, func, args, kwargs = item try: result = func(*args, **kwargs) @@ -2838,6 +2841,9 @@ class ThreadPool(object): :returns: result of calling func :raises: whatever func raises """ + if not self._alive: + raise swift.common.exceptions.ThreadPoolDead() + if self.nthreads <= 0: result = func(*args, **kwargs) sleep() @@ -2882,11 +2888,38 @@ class ThreadPool(object): :returns: result of calling func :raises: whatever func raises """ + if not self._alive: + raise swift.common.exceptions.ThreadPoolDead() + if self.nthreads <= 0: return self._run_in_eventlet_tpool(func, *args, **kwargs) else: return self.run_in_thread(func, *args, **kwargs) + def terminate(self): + """ + Releases the threadpool's resources (OS threads, greenthreads, pipes, + etc.) and renders it unusable. + + Don't call run_in_thread() or force_run_in_thread() after calling + terminate(). + """ + self._alive = False + if self.nthreads <= 0: + return + + for _junk in range(self.nthreads): + self._run_queue.put(None) + for thr in self._threads: + thr.join() + self._threads = [] + self.nthreads = 0 + + greenthread.kill(self._consumer_coro) + + self.rpipe.close() + os.close(self.wpipe) + def ismount(path): """ diff --git a/test/unit/common/test_utils.py b/test/unit/common/test_utils.py index 87a1362702..dce0d76aba 100644 --- a/test/unit/common/test_utils.py +++ b/test/unit/common/test_utils.py @@ -28,6 +28,7 @@ import mock import random import re import socket +import stat import sys import json import math @@ -55,7 +56,7 @@ from mock import MagicMock, patch from swift.common.exceptions import (Timeout, MessageTimeout, ConnectionTimeout, LockTimeout, ReplicationLockTimeout, - MimeInvalid) + MimeInvalid, ThreadPoolDead) from swift.common import utils from swift.common.container_sync_realms import ContainerSyncRealms from swift.common.swob import Request, Response @@ -3747,7 +3748,28 @@ class TestStatsdLoggingDelegation(unittest.TestCase): self.assertEquals(called, [12345]) -class TestThreadpool(unittest.TestCase): +class TestThreadPool(unittest.TestCase): + + def setUp(self): + self.tp = None + + def tearDown(self): + if self.tp: + self.tp.terminate() + + def _pipe_count(self): + # Counts the number of pipes that this process owns. + fd_dir = "/proc/%d/fd" % os.getpid() + + def is_pipe(path): + try: + stat_result = os.stat(path) + return stat.S_ISFIFO(stat_result.st_mode) + except OSError: + return False + + return len([fd for fd in os.listdir(fd_dir) + if is_pipe(os.path.join(fd_dir, fd))]) def _thread_id(self): return threading.current_thread().ident @@ -3759,7 +3781,7 @@ class TestThreadpool(unittest.TestCase): return int('fishcakes') def test_run_in_thread_with_threads(self): - tp = utils.ThreadPool(1) + tp = self.tp = utils.ThreadPool(1) my_id = self._thread_id() other_id = tp.run_in_thread(self._thread_id) @@ -3778,7 +3800,7 @@ class TestThreadpool(unittest.TestCase): def test_force_run_in_thread_with_threads(self): # with nthreads > 0, force_run_in_thread looks just like run_in_thread - tp = utils.ThreadPool(1) + tp = self.tp = utils.ThreadPool(1) my_id = self._thread_id() other_id = tp.force_run_in_thread(self._thread_id) @@ -3828,7 +3850,7 @@ class TestThreadpool(unittest.TestCase): def alpha(): return beta() - tp = utils.ThreadPool(1) + tp = self.tp = utils.ThreadPool(1) try: tp.run_in_thread(alpha) except ZeroDivisionError: @@ -3846,6 +3868,44 @@ class TestThreadpool(unittest.TestCase): self.assertEqual(tb_func[1], "run_in_thread") self.assertEqual(tb_func[0], "test_preserving_stack_trace_from_thread") + def test_terminate(self): + initial_thread_count = threading.activeCount() + initial_pipe_count = self._pipe_count() + + tp = utils.ThreadPool(4) + # do some work to ensure any lazy initialization happens + tp.run_in_thread(os.path.join, 'foo', 'bar') + tp.run_in_thread(os.path.join, 'baz', 'quux') + + # 4 threads in the ThreadPool, plus one pipe for IPC; this also + # serves as a sanity check that we're actually allocating some + # resources to free later + self.assertEqual(initial_thread_count, threading.activeCount() - 4) + self.assertEqual(initial_pipe_count, self._pipe_count() - 2) + + tp.terminate() + self.assertEqual(initial_thread_count, threading.activeCount()) + self.assertEqual(initial_pipe_count, self._pipe_count()) + + def test_cant_run_after_terminate(self): + tp = utils.ThreadPool(0) + tp.terminate() + self.assertRaises(ThreadPoolDead, tp.run_in_thread, lambda: 1) + self.assertRaises(ThreadPoolDead, tp.force_run_in_thread, lambda: 1) + + def test_double_terminate_doesnt_crash(self): + tp = utils.ThreadPool(0) + tp.terminate() + tp.terminate() + + tp = utils.ThreadPool(1) + tp.terminate() + tp.terminate() + + def test_terminate_no_threads_doesnt_crash(self): + tp = utils.ThreadPool(0) + tp.terminate() + class TestAuditLocationGenerator(unittest.TestCase):