Cleanup for iterators in SegmentedIterable

We had a pair of large, complicated iterators to handle fetching all
the segment data, and they were hard to read and think about. I tried
to break them out into some simpler pieces:

 * one to handle coalescing multiple requests to the same segment

 * one to handle fetching the bytes from each segment

 * one to check that the download isn't taking too long

 * one to count the bytes and make sure we sent the right number

 * one to catch errors and handle cleanup

It's more nesting, but each level now does just one thing; a sketch of
the resulting generator stack follows.
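To make the new shape concrete, here is a minimal, self-contained sketch of the same layering idea. It is purely illustrative: the names are made up, the request-coalescing layer is omitted, and plain ValueError stands in for SegmentError. The real methods are in the diff below.

    # Illustrative only -- hypothetical names, not Swift's actual code.
    import time


    def fetch_chunks(segments):
        # Innermost layer: produce (segment-name, byte-chunk) pairs.
        for name, data in segments:
            yield name, data


    def count_bytes(pairs, expected):
        # Enforce that exactly `expected` bytes go out (None = unknown).
        remaining = expected
        for name, chunk in pairs:
            if remaining is not None and len(chunk) > remaining:
                yield chunk[:remaining]
                raise ValueError('too many bytes; truncating in %s' % name)
            yield chunk
            if remaining is not None:
                remaining -= len(chunk)
        if remaining:
            raise ValueError('not enough bytes')


    def time_limit(chunks, max_seconds):
        # Give up if the whole download takes too long.
        start = time.time()
        for chunk in chunks:
            now = time.time()
            yield chunk
            if now - start > max_seconds:
                raise ValueError('max GET time of %ds exceeded' % max_seconds)


    def respond(segments, expected, max_seconds):
        # Top level: pass bytes through; catch errors and clean up.
        try:
            stack = time_limit(count_bytes(fetch_chunks(segments), expected),
                               max_seconds)
            for chunk in stack:
                yield chunk
        except ValueError as err:
            print('error: %s' % err)  # the real code logs the error instead


    print(list(respond([('a', b'hello'), ('b', b'world')], 10, 60)))
    # [b'hello', b'world']

Each layer consumes the generator below it and yields chunks upward, which is exactly the nesting described above.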

Change-Id: If6f5cbd79edeff6ecb81350792449ce767919bcc
Samuel Merritt 2018-01-26 16:51:03 -08:00
parent d6e911c623
commit 98d185905a
2 changed files with 106 additions and 106 deletions

swift/common/request_helpers.py

@@ -353,7 +353,6 @@ class SegmentedIterable(object):
         self.current_resp = None
 
     def _coalesce_requests(self):
-        start_time = time.time()
         pending_req = pending_etag = pending_size = None
         try:
             for seg_dict in self.listing_iter:
@ -376,11 +375,6 @@ class SegmentedIterable(object):
first_byte = first_byte or 0
go_to_end = last_byte is None or (
seg_size is not None and last_byte == seg_size - 1)
if time.time() - start_time > self.max_get_time:
raise SegmentError(
'While processing manifest %s, '
'max LO GET time of %ds exceeded' %
(self.name, self.max_get_time))
# The "multipart-manifest=get" query param ensures that the
# segment is a plain old object, not some flavor of large
# object; therefore, its etag is its MD5sum and hence we can
@@ -433,37 +427,22 @@ class SegmentedIterable(object):
         except ListingIterError:
             e_type, e_value, e_traceback = sys.exc_info()
-            if time.time() - start_time > self.max_get_time:
-                raise SegmentError(
-                    'While processing manifest %s, '
-                    'max LO GET time of %ds exceeded' %
-                    (self.name, self.max_get_time))
             if pending_req:
                 yield pending_req, pending_etag, pending_size
             six.reraise(e_type, e_value, e_traceback)
 
-        if time.time() - start_time > self.max_get_time:
-            raise SegmentError(
-                'While processing manifest %s, '
-                'max LO GET time of %ds exceeded' %
-                (self.name, self.max_get_time))
         if pending_req:
             yield pending_req, pending_etag, pending_size
 
-    def _internal_iter(self):
-        bytes_left = self.response_body_length
-        try:
+    def _requests_to_bytes_iter(self):
+        # Take the requests out of self._coalesce_requests, actually make
+        # the requests, and generate the bytes from the responses.
+        #
+        # Yields 2-tuples (segment-name, byte-chunk). The segment name is
+        # used for logging.
         for data_or_req, seg_etag, seg_size in self._coalesce_requests():
-            if isinstance(data_or_req, bytes):
-                chunk = data_or_req  # ugly, awful overloading
-                if bytes_left is None:
-                    yield chunk
-                elif bytes_left >= len(chunk):
-                    yield chunk
-                    bytes_left -= len(chunk)
-                else:
-                    yield chunk[:bytes_left]
+            if isinstance(data_or_req, bytes):  # ugly, awful overloading
+                yield ('data segment', data_or_req)
                 continue
             seg_req = data_or_req
             seg_resp = seg_req.get_response(self.app)
@@ -508,21 +487,7 @@ class SegmentedIterable(object):
             for chunk in itertools.chain.from_iterable(document_iters):
                 if seg_hash:
                     seg_hash.update(chunk)
-                if bytes_left is None:
-                    yield chunk
-                elif bytes_left >= len(chunk):
-                    yield chunk
-                    bytes_left -= len(chunk)
-                else:
-                    yield chunk[:bytes_left]
-                    bytes_left -= len(chunk)
-                    close_if_possible(seg_resp.app_iter)
-                    raise SegmentError(
-                        'Too many bytes for %(name)s; truncating in '
-                        '%(seg)s with %(left)d bytes left' %
-                        {'name': self.name, 'seg': seg_req.path,
-                         'left': bytes_left})
+                yield (seg_req.path, chunk)
             close_if_possible(seg_resp.app_iter)
 
             if seg_hash and seg_hash.hexdigest() != seg_resp.etag:
@@ -532,9 +497,49 @@ class SegmentedIterable(object):
                     {'seg': seg_req.path, 'etag': seg_resp.etag,
                      'name': self.name, 'actual': seg_hash.hexdigest()})
+
+    def _byte_counting_iter(self):
+        # Checks that we give the client the right number of bytes. Raises
+        # SegmentError if the number of bytes is wrong.
+        bytes_left = self.response_body_length
+
+        for seg_name, chunk in self._requests_to_bytes_iter():
+            if bytes_left is None:
+                yield chunk
+            elif bytes_left >= len(chunk):
+                yield chunk
+                bytes_left -= len(chunk)
+            else:
+                yield chunk[:bytes_left]
+                bytes_left -= len(chunk)
+                raise SegmentError(
+                    'Too many bytes for %(name)s; truncating in '
+                    '%(seg)s with %(left)d bytes left' %
+                    {'name': self.name, 'seg': seg_name,
+                     'left': bytes_left})
+
+        if bytes_left:
+            raise SegmentError(
+                'Not enough bytes for %s; closing connection' % self.name)
+
+    def _time_limited_iter(self):
+        # Makes sure a GET response doesn't take more than self.max_get_time
+        # seconds to process. Raises an exception if things take too long.
+        start_time = time.time()
+        for chunk in self._byte_counting_iter():
+            now = time.time()
+            yield chunk
+            if now - start_time > self.max_get_time:
+                raise SegmentError(
+                    'While processing manifest %s, '
+                    'max LO GET time of %ds exceeded' %
+                    (self.name, self.max_get_time))
+
+    def _internal_iter(self):
+        # Top level of our iterator stack: pass bytes through; catch and
+        # handle exceptions.
+        try:
+            for chunk in self._time_limited_iter():
+                yield chunk
         except (ListingIterError, SegmentError) as err:
             self.logger.error(err)
             if not self.validated_first_segment:
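One subtlety visible in the new _time_limited_iter (an observation about the diff above, not something the commit states): the clock is sampled before the yield, and the comparison only runs once the consumer asks for the next chunk. So a chunk that was ready in time is always delivered, and the error surfaces on the following read. A tiny standalone demonstration of that ordering, with hypothetical names:

    import time


    def limited(chunks, max_seconds):
        # Same shape as _time_limited_iter: sample the clock, hand the
        # chunk out, and only compare against the deadline afterwards.
        start = time.time()
        for chunk in chunks:
            now = time.time()
            yield chunk
            if now - start > max_seconds:
                raise RuntimeError('deadline exceeded')


    # Even with a deadline in the past, the first chunk still goes out;
    # the error appears when the consumer pulls again.
    it = limited(iter([b'x', b'y']), max_seconds=-1.0)
    print(next(it))  # b'x'
    try:
        next(it)
    except RuntimeError as err:
        print(err)  # deadline exceeded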

test/unit/common/middleware/test_slo.py

@@ -3040,14 +3040,9 @@ class TestSloGetManifest(SloTestCase):
     def test_download_takes_too_long(self, mock_time):
         mock_time.time.side_effect = [
             0,  # start time
-            1,  # just building the first segment request; purely local
-            2,  # build the second segment request object, too, so we know we
-                # can't coalesce and should instead go fetch the first segment
-            7 * 3600,  # that takes a while, but gets serviced; we build the
-                       # third request and service the second
-            21 * 3600,  # which takes *even longer* (ostensibly something to
-                        # do with submanifests), but we build the fourth...
-            28 * 3600,  # and before we go to service it we time out
+            10 * 3600,  # a_5
+            20 * 3600,  # b_10
+            30 * 3600,  # c_15, but then we time out
         ]
         req = Request.blank(
             '/v1/AUTH_test/gettest/manifest-abcd',
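The rewritten test drives the clock with mock.patch: each value in side_effect is returned by one successive time.time() call, so the test can fast-forward past max_get_time (assumed here to be the 24-hour SLO default) without sleeping. A minimal standalone sketch of the same technique, with hypothetical names rather than the actual test fixtures:

    import time
    from unittest import mock


    def download(segments, max_get_time=24 * 3600):
        # Simplified stand-in for the iterator stack under test.
        start = time.time()
        for name in segments:
            if time.time() - start > max_get_time:
                raise RuntimeError('timed out at %s' % name)
            yield name


    with mock.patch('time.time') as mock_time:
        # One value per time.time() call: the start, then one per segment.
        mock_time.side_effect = [0, 10 * 3600, 20 * 3600, 30 * 3600]
        try:
            print(list(download(['a_5', 'b_10', 'c_15'])))
        except RuntimeError as err:
            print(err)  # timed out at c_15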