Cleanup for iterators in SegmentedIterable

We had a pair of large, complicated iterators to handle fetching all
the segment data, and they were hard to read and think about. I tried
to break them out into some simpler pieces:

 * one to handle coalescing multiple requests to the same segment

 * one to handle fetching the bytes from each segment

 * one to check that the download isn't taking too long

 * one to count the bytes and make sure we sent the right number

 * one to catch errors and handle cleanup

It's more nesting, but each level now does just one thing.

Change-Id: If6f5cbd79edeff6ecb81350792449ce767919bcc
This commit is contained in:
Samuel Merritt 2018-01-26 16:51:03 -08:00
parent d6e911c623
commit 98d185905a
2 changed files with 106 additions and 106 deletions

View File

@ -353,7 +353,6 @@ class SegmentedIterable(object):
self.current_resp = None self.current_resp = None
def _coalesce_requests(self): def _coalesce_requests(self):
start_time = time.time()
pending_req = pending_etag = pending_size = None pending_req = pending_etag = pending_size = None
try: try:
for seg_dict in self.listing_iter: for seg_dict in self.listing_iter:
@ -376,11 +375,6 @@ class SegmentedIterable(object):
first_byte = first_byte or 0 first_byte = first_byte or 0
go_to_end = last_byte is None or ( go_to_end = last_byte is None or (
seg_size is not None and last_byte == seg_size - 1) seg_size is not None and last_byte == seg_size - 1)
if time.time() - start_time > self.max_get_time:
raise SegmentError(
'While processing manifest %s, '
'max LO GET time of %ds exceeded' %
(self.name, self.max_get_time))
# The "multipart-manifest=get" query param ensures that the # The "multipart-manifest=get" query param ensures that the
# segment is a plain old object, not some flavor of large # segment is a plain old object, not some flavor of large
# object; therefore, its etag is its MD5sum and hence we can # object; therefore, its etag is its MD5sum and hence we can
@ -433,108 +427,119 @@ class SegmentedIterable(object):
except ListingIterError: except ListingIterError:
e_type, e_value, e_traceback = sys.exc_info() e_type, e_value, e_traceback = sys.exc_info()
if time.time() - start_time > self.max_get_time:
raise SegmentError(
'While processing manifest %s, '
'max LO GET time of %ds exceeded' %
(self.name, self.max_get_time))
if pending_req: if pending_req:
yield pending_req, pending_etag, pending_size yield pending_req, pending_etag, pending_size
six.reraise(e_type, e_value, e_traceback) six.reraise(e_type, e_value, e_traceback)
if time.time() - start_time > self.max_get_time:
raise SegmentError(
'While processing manifest %s, '
'max LO GET time of %ds exceeded' %
(self.name, self.max_get_time))
if pending_req: if pending_req:
yield pending_req, pending_etag, pending_size yield pending_req, pending_etag, pending_size
def _internal_iter(self): def _requests_to_bytes_iter(self):
# Take the requests out of self._coalesce_requests, actually make
# the requests, and generate the bytes from the responses.
#
# Yields 2-tuples (segment-name, byte-chunk). The segment name is
# used for logging.
for data_or_req, seg_etag, seg_size in self._coalesce_requests():
if isinstance(data_or_req, bytes): # ugly, awful overloading
yield ('data segment', data_or_req)
continue
seg_req = data_or_req
seg_resp = seg_req.get_response(self.app)
if not is_success(seg_resp.status_int):
close_if_possible(seg_resp.app_iter)
raise SegmentError(
'While processing manifest %s, '
'got %d while retrieving %s' %
(self.name, seg_resp.status_int, seg_req.path))
elif ((seg_etag and (seg_resp.etag != seg_etag)) or
(seg_size and (seg_resp.content_length != seg_size) and
not seg_req.range)):
# The content-length check is for security reasons. Seems
# possible that an attacker could upload a >1mb object and
# then replace it with a much smaller object with same
# etag. Then create a big nested SLO that calls that
# object many times which would hammer our obj servers. If
# this is a range request, don't check content-length
# because it won't match.
close_if_possible(seg_resp.app_iter)
raise SegmentError(
'Object segment no longer valid: '
'%(path)s etag: %(r_etag)s != %(s_etag)s or '
'%(r_size)s != %(s_size)s.' %
{'path': seg_req.path, 'r_etag': seg_resp.etag,
'r_size': seg_resp.content_length,
's_etag': seg_etag,
's_size': seg_size})
else:
self.current_resp = seg_resp
seg_hash = None
if seg_resp.etag and not seg_req.headers.get('Range'):
# Only calculate the MD5 if it we can use it to validate
seg_hash = hashlib.md5()
document_iters = maybe_multipart_byteranges_to_document_iters(
seg_resp.app_iter,
seg_resp.headers['Content-Type'])
for chunk in itertools.chain.from_iterable(document_iters):
if seg_hash:
seg_hash.update(chunk)
yield (seg_req.path, chunk)
close_if_possible(seg_resp.app_iter)
if seg_hash and seg_hash.hexdigest() != seg_resp.etag:
raise SegmentError(
"Bad MD5 checksum in %(name)s for %(seg)s: headers had"
" %(etag)s, but object MD5 was actually %(actual)s" %
{'seg': seg_req.path, 'etag': seg_resp.etag,
'name': self.name, 'actual': seg_hash.hexdigest()})
def _byte_counting_iter(self):
# Checks that we give the client the right number of bytes. Raises
# SegmentError if the number of bytes is wrong.
bytes_left = self.response_body_length bytes_left = self.response_body_length
try: for seg_name, chunk in self._requests_to_bytes_iter():
for data_or_req, seg_etag, seg_size in self._coalesce_requests(): if bytes_left is None:
if isinstance(data_or_req, bytes): yield chunk
chunk = data_or_req # ugly, awful overloading elif bytes_left >= len(chunk):
if bytes_left is None: yield chunk
yield chunk bytes_left -= len(chunk)
elif bytes_left >= len(chunk): else:
yield chunk yield chunk[:bytes_left]
bytes_left -= len(chunk) bytes_left -= len(chunk)
else:
yield chunk[:bytes_left]
continue
seg_req = data_or_req
seg_resp = seg_req.get_response(self.app)
if not is_success(seg_resp.status_int):
close_if_possible(seg_resp.app_iter)
raise SegmentError(
'While processing manifest %s, '
'got %d while retrieving %s' %
(self.name, seg_resp.status_int, seg_req.path))
elif ((seg_etag and (seg_resp.etag != seg_etag)) or
(seg_size and (seg_resp.content_length != seg_size) and
not seg_req.range)):
# The content-length check is for security reasons. Seems
# possible that an attacker could upload a >1mb object and
# then replace it with a much smaller object with same
# etag. Then create a big nested SLO that calls that
# object many times which would hammer our obj servers. If
# this is a range request, don't check content-length
# because it won't match.
close_if_possible(seg_resp.app_iter)
raise SegmentError(
'Object segment no longer valid: '
'%(path)s etag: %(r_etag)s != %(s_etag)s or '
'%(r_size)s != %(s_size)s.' %
{'path': seg_req.path, 'r_etag': seg_resp.etag,
'r_size': seg_resp.content_length,
's_etag': seg_etag,
's_size': seg_size})
else:
self.current_resp = seg_resp
seg_hash = None
if seg_resp.etag and not seg_req.headers.get('Range'):
# Only calculate the MD5 if it we can use it to validate
seg_hash = hashlib.md5()
document_iters = maybe_multipart_byteranges_to_document_iters(
seg_resp.app_iter,
seg_resp.headers['Content-Type'])
for chunk in itertools.chain.from_iterable(document_iters):
if seg_hash:
seg_hash.update(chunk)
if bytes_left is None:
yield chunk
elif bytes_left >= len(chunk):
yield chunk
bytes_left -= len(chunk)
else:
yield chunk[:bytes_left]
bytes_left -= len(chunk)
close_if_possible(seg_resp.app_iter)
raise SegmentError(
'Too many bytes for %(name)s; truncating in '
'%(seg)s with %(left)d bytes left' %
{'name': self.name, 'seg': seg_req.path,
'left': bytes_left})
close_if_possible(seg_resp.app_iter)
if seg_hash and seg_hash.hexdigest() != seg_resp.etag:
raise SegmentError(
"Bad MD5 checksum in %(name)s for %(seg)s: headers had"
" %(etag)s, but object MD5 was actually %(actual)s" %
{'seg': seg_req.path, 'etag': seg_resp.etag,
'name': self.name, 'actual': seg_hash.hexdigest()})
if bytes_left:
raise SegmentError( raise SegmentError(
'Not enough bytes for %s; closing connection' % self.name) 'Too many bytes for %(name)s; truncating in '
'%(seg)s with %(left)d bytes left' %
{'name': self.name, 'seg': seg_name,
'left': bytes_left})
if bytes_left:
raise SegmentError(
'Not enough bytes for %s; closing connection' % self.name)
def _time_limited_iter(self):
# Makes sure a GET response doesn't take more than self.max_get_time
# seconds to process. Raises an exception if things take too long.
start_time = time.time()
for chunk in self._byte_counting_iter():
now = time.time()
yield chunk
if now - start_time > self.max_get_time:
raise SegmentError(
'While processing manifest %s, '
'max LO GET time of %ds exceeded' %
(self.name, self.max_get_time))
def _internal_iter(self):
# Top level of our iterator stack: pass bytes through; catch and
# handle exceptions.
try:
for chunk in self._time_limited_iter():
yield chunk
except (ListingIterError, SegmentError) as err: except (ListingIterError, SegmentError) as err:
self.logger.error(err) self.logger.error(err)
if not self.validated_first_segment: if not self.validated_first_segment:

View File

@ -3040,14 +3040,9 @@ class TestSloGetManifest(SloTestCase):
def test_download_takes_too_long(self, mock_time): def test_download_takes_too_long(self, mock_time):
mock_time.time.side_effect = [ mock_time.time.side_effect = [
0, # start time 0, # start time
1, # just building the first segment request; purely local 10 * 3600, # a_5
2, # build the second segment request object, too, so we know we 20 * 3600, # b_10
# can't coalesce and should instead go fetch the first segment 30 * 3600, # c_15, but then we time out
7 * 3600, # that takes a while, but gets serviced; we build the
# third request and service the second
21 * 3600, # which takes *even longer* (ostensibly something to
# do with submanifests), but we build the fourth...
28 * 3600, # and before we go to service it we time out
] ]
req = Request.blank( req = Request.blank(
'/v1/AUTH_test/gettest/manifest-abcd', '/v1/AUTH_test/gettest/manifest-abcd',