diff --git a/test/unit/common/test_utils.py b/test/unit/common/test_utils.py index c59172098a..33858fd708 100644 --- a/test/unit/common/test_utils.py +++ b/test/unit/common/test_utils.py @@ -8836,3 +8836,32 @@ class TestCloseableChain(unittest.TestCase): chain = utils.CloseableChain([1, 2], [3]) chain.close() self.assertEqual([1, 2, 3], [x for x in chain]) + + # check with generator in the chain + generator_closed = [False] + + def gen(): + try: + yield 2 + yield 3 + except GeneratorExit: + generator_closed[0] = True + raise + + test_iter1 = FakeIterable([1]) + chain = utils.CloseableChain(test_iter1, gen()) + self.assertEqual(0, test_iter1.close_call_count) + self.assertFalse(generator_closed[0]) + chain.close() + self.assertEqual(1, test_iter1.close_call_count) + # Generator never kicked off, so there's no GeneratorExit + self.assertFalse(generator_closed[0]) + + test_iter1 = FakeIterable([1]) + chain = utils.CloseableChain(gen(), test_iter1) + self.assertEqual(2, next(chain)) # Kick off the generator + self.assertEqual(0, test_iter1.close_call_count) + self.assertFalse(generator_closed[0]) + chain.close() + self.assertEqual(1, test_iter1.close_call_count) + self.assertTrue(generator_closed[0])