From 0f95870c51c696b076b3c9b266e1d7cde52a30d4 Mon Sep 17 00:00:00 2001 From: Alistair Coles Date: Tue, 25 Apr 2023 21:18:01 +0100 Subject: [PATCH] ECFragGetter: simplify iter_bytes_from_response_part Refactor and add some targeted unit tests. No behavioral change. Change-Id: I153528b8a1709f3756c261cf3eb2acfd5de10f9c --- swift/proxy/controllers/obj.py | 40 +++++--------- test/unit/proxy/controllers/test_obj.py | 69 ++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 30 deletions(-) diff --git a/swift/proxy/controllers/obj.py b/swift/proxy/controllers/obj.py index 057c77ffa3..fa9169e474 100644 --- a/swift/proxy/controllers/obj.py +++ b/swift/proxy/controllers/obj.py @@ -2500,7 +2500,7 @@ class ECFragGetter(object): self.client_chunk_size = policy.fragment_size self.skip_bytes = 0 self.bytes_used_from_backend = 0 - self.source = None + self.source = self.node = None self.logger_thread_locals = logger_thread_locals self.logger = logger @@ -2660,14 +2660,13 @@ class ECFragGetter(object): read_chunk_size=self.app.object_chunk_size) def iter_bytes_from_response_part(self, part_file, nbytes): - client_chunk_size = self.client_chunk_size - node_timeout = self.app.recoverable_node_timeout nchunks = 0 buf = b'' part_file = ByteCountEnforcer(part_file, nbytes) while True: try: - with WatchdogTimeout(self.app.watchdog, node_timeout, + with WatchdogTimeout(self.app.watchdog, + self.app.recoverable_node_timeout, ChunkReadTimeout): chunk = part_file.read(self.app.object_chunk_size) nchunks += 1 @@ -2726,33 +2725,18 @@ class ECFragGetter(object): self.bytes_used_from_backend += len(buf) buf = b'' - if not chunk: - if buf: - with WatchdogTimeout(self.app.watchdog, - self.app.client_timeout, - ChunkWriteTimeout): - self.bytes_used_from_backend += len(buf) - yield buf - buf = b'' - break - - if client_chunk_size is not None: - while len(buf) >= client_chunk_size: - client_chunk = buf[:client_chunk_size] - buf = buf[client_chunk_size:] - with WatchdogTimeout(self.app.watchdog, - self.app.client_timeout, - ChunkWriteTimeout): - self.bytes_used_from_backend += \ - len(client_chunk) - yield client_chunk - else: + client_chunk_size = self.client_chunk_size or len(buf) + while buf and (len(buf) >= client_chunk_size or not chunk): + client_chunk = buf[:client_chunk_size] + buf = buf[client_chunk_size:] with WatchdogTimeout(self.app.watchdog, self.app.client_timeout, ChunkWriteTimeout): - self.bytes_used_from_backend += len(buf) - yield buf - buf = b'' + self.bytes_used_from_backend += len(client_chunk) + yield client_chunk + + if not chunk: + break # This is for fairness; if the network is outpacing # the CPU, we'll always be able to read and write diff --git a/test/unit/proxy/controllers/test_obj.py b/test/unit/proxy/controllers/test_obj.py index 79753e00b5..8224a2f76b 100644 --- a/test/unit/proxy/controllers/test_obj.py +++ b/test/unit/proxy/controllers/test_obj.py @@ -39,8 +39,9 @@ else: import swift from swift.common import utils, swob, exceptions -from swift.common.exceptions import ChunkWriteTimeout -from swift.common.utils import Timestamp, list_from_csv, md5 +from swift.common.exceptions import ChunkWriteTimeout, ShortReadError, \ + ChunkReadTimeout +from swift.common.utils import Timestamp, list_from_csv, md5, FileLikeIter from swift.proxy import server as proxy_server from swift.proxy.controllers import obj from swift.proxy.controllers.base import \ @@ -6676,5 +6677,69 @@ class TestNumContainerUpdates(unittest.TestCase): c_replica, o_replica, o_quorum)) +@patch_policies(with_ec_default=True) +class TestECFragGetter(BaseObjectControllerMixin, unittest.TestCase): + def setUp(self): + super(TestECFragGetter, self).setUp() + req = Request.blank(path='/a/c/o') + self.getter = obj.ECFragGetter( + self.app, req, None, None, self.policy, 'a/c/o', + {}, None, self.logger.thread_locals, + self.logger) + + def test_iter_bytes_from_response_part(self): + part = FileLikeIter([b'some', b'thing']) + it = self.getter.iter_bytes_from_response_part(part, nbytes=None) + self.assertEqual(b'something', b''.join(it)) + + def test_iter_bytes_from_response_part_insufficient_bytes(self): + part = FileLikeIter([b'some', b'thing']) + it = self.getter.iter_bytes_from_response_part(part, nbytes=100) + with mock.patch.object(self.getter, '_dig_for_source_and_node', + return_value=(None, None)): + with self.assertRaises(ShortReadError) as cm: + b''.join(it) + self.assertEqual('Too few bytes; read 9, expecting 100', + str(cm.exception)) + + def test_iter_bytes_from_response_part_read_timeout(self): + part = FileLikeIter([b'some', b'thing']) + self.app.recoverable_node_timeout = 0.05 + self.app.client_timeout = 0.8 + it = self.getter.iter_bytes_from_response_part(part, nbytes=9) + with mock.patch.object(self.getter, '_dig_for_source_and_node', + return_value=(None, None)): + with mock.patch.object(part, 'read', + side_effect=[b'some', ChunkReadTimeout(9)]): + with self.assertRaises(ChunkReadTimeout) as cm: + b''.join(it) + self.assertEqual('9 seconds', str(cm.exception)) + + def test_iter_bytes_from_response_part_null_chunk_size(self): + # we don't expect a policy to have fragment_size None or zero but + # verify that the getter is defensive + self.getter.client_chunk_size = None + part = FileLikeIter([b'some', b'thing', b'']) + it = self.getter.iter_bytes_from_response_part(part, nbytes=None) + self.assertEqual(b'something', b''.join(it)) + + self.getter.client_chunk_size = 0 + part = FileLikeIter([b'some', b'thing', b'']) + it = self.getter.iter_bytes_from_response_part(part, nbytes=None) + self.assertEqual(b'something', b''.join(it)) + + def test_iter_bytes_from_response_part_small_chunk_size(self): + # we don't expect a policy to have fragment_size None or zero but + # verify that the getter is defensive + self.getter.client_chunk_size = 4 + part = FileLikeIter([b'some', b'thing', b'']) + it = self.getter.iter_bytes_from_response_part(part, nbytes=None) + self.assertEqual([b'some', b'thin', b'g'], [ch for ch in it]) + self.getter.client_chunk_size = 1 + part = FileLikeIter([b'some', b'thing', b'']) + it = self.getter.iter_bytes_from_response_part(part, nbytes=None) + self.assertEqual([c.encode() for c in 'something'], [ch for ch in it]) + + if __name__ == '__main__': unittest.main()