diff --git a/swift/common/middleware/slo.py b/swift/common/middleware/slo.py index 47e9a660ff..d376d3b659 100644 --- a/swift/common/middleware/slo.py +++ b/swift/common/middleware/slo.py @@ -213,12 +213,11 @@ from swift.common.exceptions import ListingIterError, SegmentError from swift.common.swob import Request, HTTPBadRequest, HTTPServerError, \ HTTPMethodNotAllowed, HTTPRequestEntityTooLarge, HTTPLengthRequired, \ HTTPOk, HTTPPreconditionFailed, HTTPException, HTTPNotFound, \ - HTTPUnauthorized, HTTPConflict, HTTPRequestedRangeNotSatisfiable,\ - Response, Range + HTTPUnauthorized, HTTPConflict, Response, Range from swift.common.utils import get_logger, config_true_value, \ get_valid_utf8_str, override_bytes_from_content_type, split_path, \ register_swift_info, RateLimitedIterator, quote, close_if_possible, \ - closing_if_possible + closing_if_possible, LRUCache from swift.common.request_helpers import SegmentedIterable from swift.common.constraints import check_utf8, MAX_BUFFERED_SLO_SEGMENTS from swift.common.http import HTTP_NOT_FOUND, HTTP_UNAUTHORIZED, is_success @@ -386,8 +385,6 @@ class SloGetContext(WSGIContext): def __init__(self, slo): self.slo = slo - self.first_byte = None - self.last_byte = None super(SloGetContext, self).__init__(slo.app) def _fetch_sub_slo_segments(self, req, version, acc, con, obj): @@ -434,7 +431,7 @@ class SloGetContext(WSGIContext): return int(seg_dict['bytes']) def _segment_listing_iterator(self, req, version, account, segments, - recursion_depth=1): + byteranges): for seg_dict in segments: if config_true_value(seg_dict.get('sub_slo')): override_bytes_from_content_type(seg_dict, @@ -448,23 +445,46 @@ class SloGetContext(WSGIContext): # If we were to make SegmentedIterable handle all the range # calculations, we would be unable to make this optimization. total_length = sum(self._segment_length(seg) for seg in segments) - if self.first_byte is None: - self.first_byte = 0 - if self.last_byte is None: - self.last_byte = total_length - 1 + if not byteranges: + byteranges = [(0, total_length - 1)] + # Cache segments from sub-SLOs in case more than one byterange + # includes data from a particular sub-SLO. We only cache a few sets + # of segments so that a malicious user cannot build a giant SLO tree + # and then GET it to run the proxy out of memory. + # + # LRUCache is a little awkward to use this way, but it beats doing + # things manually. + # + # 20 is sort of an arbitrary choice; it's twice our max recursion + # depth, so we know this won't expand memory requirements by too + # much. + cached_fetch_sub_slo_segments = \ + LRUCache(maxsize=20)(self._fetch_sub_slo_segments) + + for first_byte, last_byte in byteranges: + byterange_listing_iter = self._byterange_listing_iterator( + req, version, account, segments, first_byte, last_byte, + cached_fetch_sub_slo_segments) + for seg_info in byterange_listing_iter: + yield seg_info + + def _byterange_listing_iterator(self, req, version, account, segments, + first_byte, last_byte, + cached_fetch_sub_slo_segments, + recursion_depth=1): last_sub_path = None for seg_dict in segments: seg_length = self._segment_length(seg_dict) - if self.first_byte >= seg_length: + if first_byte >= seg_length: # don't need any bytes from this segment - self.first_byte -= seg_length - self.last_byte -= seg_length + first_byte -= seg_length + last_byte -= seg_length continue - if self.last_byte < 0: + if last_byte < 0: # no bytes are needed from this or any future segment - break + return seg_range = seg_dict.get('range') if seg_range is None: @@ -483,33 +503,30 @@ class SloGetContext(WSGIContext): sub_path = get_valid_utf8_str(seg_dict['name']) sub_cont, sub_obj = split_path(sub_path, 2, 2, True) if last_sub_path != sub_path: - sub_segments = self._fetch_sub_slo_segments( + sub_segments = cached_fetch_sub_slo_segments( req, version, account, sub_cont, sub_obj) last_sub_path = sub_path # Use the existing machinery to slice into the sub-SLO. - # This requires that we save off our current state, and - # restore at the other end. - orig_start, orig_end = self.first_byte, self.last_byte - self.first_byte = range_start + max(0, self.first_byte) - self.last_byte = min(range_end, range_start + self.last_byte) - - for sub_seg_dict, sb, eb in self._segment_listing_iterator( + for sub_seg_dict, sb, eb in self._byterange_listing_iterator( req, version, account, sub_segments, + # This adjusts first_byte and last_byte to be + # relative to the sub-SLO. + range_start + max(0, first_byte), + min(range_end, range_start + last_byte), + + cached_fetch_sub_slo_segments, recursion_depth=recursion_depth + 1): yield sub_seg_dict, sb, eb - - # Restore the first/last state - self.first_byte, self.last_byte = orig_start, orig_end else: if isinstance(seg_dict['name'], six.text_type): seg_dict['name'] = seg_dict['name'].encode("utf-8") yield (seg_dict, - max(0, self.first_byte) + range_start, - min(range_end, range_start + self.last_byte)) + max(0, first_byte) + range_start, + min(range_end, range_start + last_byte)) - self.first_byte -= seg_length - self.last_byte -= seg_length + first_byte -= seg_length + last_byte -= seg_length def _need_to_refetch_manifest(self, req): """ @@ -692,22 +709,18 @@ class SloGetContext(WSGIContext): def _manifest_get_response(self, req, content_length, response_headers, segments): - self.first_byte, self.last_byte = None, None if req.range: - byteranges = req.range.ranges_for_length(content_length) - if len(byteranges) == 0: - return HTTPRequestedRangeNotSatisfiable(request=req) - elif len(byteranges) == 1: - self.first_byte, self.last_byte = byteranges[0] + byteranges = [ # For some reason, swob.Range.ranges_for_length adds 1 to the # last byte's position. - self.last_byte -= 1 - else: - req.range = None + (start, end - 1) for start, end + in req.range.ranges_for_length(content_length)] + else: + byteranges = [] ver, account, _junk = req.split_path(3, 3, rest_with_last=True) plain_listing_iter = self._segment_listing_iterator( - req, ver, account, segments) + req, ver, account, segments, byteranges) def is_small_segment((seg_dict, start_byte, end_byte)): start = 0 if start_byte is None else start_byte diff --git a/swift/common/request_helpers.py b/swift/common/request_helpers.py index 9d231900ca..308f15d3ac 100644 --- a/swift/common/request_helpers.py +++ b/swift/common/request_helpers.py @@ -35,11 +35,11 @@ from swift.common.constraints import FORMAT2CONTENT_TYPE from swift.common.exceptions import ListingIterError, SegmentError from swift.common.http import is_success from swift.common.swob import HTTPBadRequest, HTTPNotAcceptable, \ - HTTPServiceUnavailable, Range, is_chunked + HTTPServiceUnavailable, Range, is_chunked, multi_range_iterator from swift.common.utils import split_path, validate_device_partition, \ close_if_possible, maybe_multipart_byteranges_to_document_iters, \ multipart_byteranges_to_document_iters, parse_content_type, \ - parse_content_range, csv_append, list_from_csv + parse_content_range, csv_append, list_from_csv, Spliterator from swift.common.wsgi import make_subrequest @@ -520,6 +520,25 @@ class SegmentedIterable(object): """ return self + def app_iter_ranges(self, ranges, content_type, boundary, content_size): + """ + This method assumes that iter(self) yields all the data bytes that + go into the response, but none of the MIME stuff. For example, if + the response will contain three MIME docs with data "abcd", "efgh", + and "ijkl", then iter(self) will give out the bytes "abcdefghijkl". + + This method inserts the MIME stuff around the data bytes. + """ + si = Spliterator(self) + mri = multi_range_iterator( + ranges, content_type, boundary, content_size, + lambda start, end_plus_one: si.take(end_plus_one - start)) + try: + for x in mri: + yield x + finally: + self.close() + def validate_first_segment(self): """ Start fetching object data to ensure that the first segment (if any) is diff --git a/swift/common/swob.py b/swift/common/swob.py index 1a8c6ea630..df66664a9d 100644 --- a/swift/common/swob.py +++ b/swift/common/swob.py @@ -1245,9 +1245,9 @@ class Response(object): ranges = self.request.range.ranges_for_length(self.content_length) if ranges == []: self.status = 416 - self.content_length = 0 close_if_possible(app_iter) - return [''] + body = None + app_iter = None elif ranges: range_size = len(ranges) if range_size > 0: diff --git a/swift/common/utils.py b/swift/common/utils.py index 57e0e9be3f..24ec0f5faa 100644 --- a/swift/common/utils.py +++ b/swift/common/utils.py @@ -3292,6 +3292,80 @@ class LRUCache(object): return LRUCacheWrapped() +class Spliterator(object): + """ + Takes an iterator yielding sliceable things (e.g. strings or lists) and + yields subiterators, each yielding up to the requested number of items + from the source. + + >>> si = Spliterator(["abcde", "fg", "hijkl"]) + >>> ''.join(si.take(4)) + "abcd" + >>> ''.join(si.take(3)) + "efg" + >>> ''.join(si.take(1)) + "h" + >>> ''.join(si.take(3)) + "ijk" + >>> ''.join(si.take(3)) + "l" # shorter than requested; this can happen with the last iterator + + """ + def __init__(self, source_iterable): + self.input_iterator = iter(source_iterable) + self.leftovers = None + self.leftovers_index = 0 + self._iterator_in_progress = False + + def take(self, n): + if self._iterator_in_progress: + raise ValueError("cannot call take() again until the first" + " iterator is exhausted") + self._iterator_in_progress = True + + try: + if self.leftovers: + # All this string slicing is a little awkward, but it's for + # a good reason. Consider a length N string that someone is + # taking k bytes at a time. + # + # With this implementation, we create one new string of + # length k (copying the bytes) on each call to take(). Once + # the whole input has been consumed, each byte has been + # copied exactly once, giving O(N) bytes copied. + # + # If, instead of this, we were to set leftovers = + # leftovers[k:] and omit leftovers_index, then each call to + # take() would copy k bytes to create the desired substring, + # then copy all the remaining bytes to reset leftovers, + # resulting in an overall O(N^2) bytes copied. + llen = len(self.leftovers) - self.leftovers_index + if llen <= n: + n -= llen + yield self.leftovers[self.leftovers_index:] + self.leftovers = None + self.leftovers_index = 0 + else: + yield self.leftovers[ + self.leftovers_index:(self.leftovers_index + n)] + self.leftovers_index += n + n = 0 + + while n > 0: + chunk = next(self.input_iterator) + cl = len(chunk) + if cl <= n: + n -= cl + yield chunk + else: + yield chunk[:n] + self.leftovers = chunk + self.leftovers_index = n + n = 0 + finally: + self._iterator_in_progress = False + + def tpool_reraise(func, *args, **kwargs): """ Hack to work around Eventlet's tpool not catching and reraising Timeouts. diff --git a/test/functional/tests.py b/test/functional/tests.py index 82b91d9a26..664b9509db 100644 --- a/test/functional/tests.py +++ b/test/functional/tests.py @@ -3097,6 +3097,34 @@ class TestSlo(Base): self.assertEqual('b', file_contents[-2]) self.assertEqual('c', file_contents[-1]) + def test_slo_multi_ranged_get(self): + file_item = self.env.container.file('manifest-abcde') + file_contents = file_item.read( + hdrs={"Range": "bytes=1048571-1048580,2097147-2097156"}) + + # See testMultiRangeGets for explanation + parser = email.parser.FeedParser() + parser.feed("Content-Type: %s\r\n\r\n" % file_item.content_type) + parser.feed(file_contents) + + root_message = parser.close() + self.assertTrue(root_message.is_multipart()) # sanity check + + byteranges = root_message.get_payload() + self.assertEqual(len(byteranges), 2) + + self.assertEqual(byteranges[0]['Content-Type'], + "application/octet-stream") + self.assertEqual( + byteranges[0]['Content-Range'], "bytes 1048571-1048580/4194305") + self.assertEqual(byteranges[0].get_payload(), "aaaaabbbbb") + + self.assertEqual(byteranges[1]['Content-Type'], + "application/octet-stream") + self.assertEqual( + byteranges[1]['Content-Range'], "bytes 2097147-2097156/4194305") + self.assertEqual(byteranges[1].get_payload(), "bbbbbccccc") + def test_slo_ranged_submanifest(self): file_item = self.env.container.file('manifest-abcde-submanifest') file_contents = file_item.read(size=1024 * 1024 + 2, diff --git a/test/unit/common/middleware/test_slo.py b/test/unit/common/middleware/test_slo.py index 3898c9db6b..50c86c7e7d 100644 --- a/test/unit/common/middleware/test_slo.py +++ b/test/unit/common/middleware/test_slo.py @@ -22,12 +22,14 @@ import time import unittest from mock import patch from hashlib import md5 +from StringIO import StringIO from swift.common import swob, utils from swift.common.exceptions import ListingIterError, SegmentError from swift.common.header_key_dict import HeaderKeyDict from swift.common.middleware import slo from swift.common.swob import Request, HTTPException -from swift.common.utils import quote, closing_if_possible, close_if_possible +from swift.common.utils import quote, closing_if_possible, close_if_possible, \ + parse_content_type, iter_multipart_mime_documents, parse_mime_headers from test.unit.common.middleware.helpers import FakeSwift @@ -1735,6 +1737,116 @@ class TestSloGetManifest(SloTestCase): self.assertEqual(self.app.swift_sources[1:], ['SLO'] * (len(self.app.swift_sources) - 1)) + def test_multiple_ranges_get_manifest(self): + req = Request.blank( + '/v1/AUTH_test/gettest/manifest-abcd', + environ={'REQUEST_METHOD': 'GET'}, + headers={'Range': 'bytes=3-17,20-24,35-999999'}) + status, headers, body = self.call_slo(req) + headers = HeaderKeyDict(headers) + + self.assertEqual(status, '206 Partial Content') + + ct, params = parse_content_type(headers['Content-Type']) + params = dict(params) + self.assertEqual(ct, 'multipart/byteranges') + boundary = params.get('boundary') + self.assertTrue(boundary is not None) + + self.assertEqual(len(body), int(headers['Content-Length'])) + + got_mime_docs = [] + for mime_doc_fh in iter_multipart_mime_documents( + StringIO(body), boundary): + headers = parse_mime_headers(mime_doc_fh) + body = mime_doc_fh.read() + got_mime_docs.append((headers, body)) + self.assertEqual(len(got_mime_docs), 3) + + first_range_headers = got_mime_docs[0][0] + first_range_body = got_mime_docs[0][1] + self.assertEqual(first_range_headers['Content-Range'], + 'bytes 3-17/50') + self.assertEqual(first_range_headers['Content-Type'], + 'application/json') + self.assertEqual(first_range_body, 'aabbbbbbbbbbccc') + + second_range_headers = got_mime_docs[1][0] + second_range_body = got_mime_docs[1][1] + self.assertEqual(second_range_headers['Content-Range'], + 'bytes 20-24/50') + self.assertEqual(second_range_headers['Content-Type'], + 'application/json') + self.assertEqual(second_range_body, 'ccccc') + + third_range_headers = got_mime_docs[2][0] + third_range_body = got_mime_docs[2][1] + self.assertEqual(third_range_headers['Content-Range'], + 'bytes 35-49/50') + self.assertEqual(third_range_headers['Content-Type'], + 'application/json') + self.assertEqual(third_range_body, 'ddddddddddddddd') + + self.assertEqual( + self.app.calls, + [('GET', '/v1/AUTH_test/gettest/manifest-abcd'), + ('GET', '/v1/AUTH_test/gettest/manifest-abcd'), + ('GET', '/v1/AUTH_test/gettest/manifest-bc'), + ('GET', '/v1/AUTH_test/gettest/a_5?multipart-manifest=get'), + ('GET', '/v1/AUTH_test/gettest/b_10?multipart-manifest=get'), + ('GET', '/v1/AUTH_test/gettest/c_15?multipart-manifest=get'), + ('GET', '/v1/AUTH_test/gettest/d_20?multipart-manifest=get')]) + + ranges = [c[2].get('Range') for c in self.app.calls_with_headers] + self.assertEqual(ranges, [ + 'bytes=3-17,20-24,35-999999', # initial GET + None, # re-fetch top-level manifest + None, # fetch manifest-bc as sub-slo + 'bytes=3-', # a_5 + None, # b_10 + 'bytes=0-2,5-9', # c_15 + 'bytes=5-']) # d_20 + # we set swift.source for everything but the first request + self.assertIsNone(self.app.swift_sources[0]) + self.assertEqual(self.app.swift_sources[1:], + ['SLO'] * (len(self.app.swift_sources) - 1)) + + def test_multiple_ranges_including_suffix_get_manifest(self): + req = Request.blank( + '/v1/AUTH_test/gettest/manifest-abcd', + environ={'REQUEST_METHOD': 'GET'}, + headers={'Range': 'bytes=3-17,-21'}) + status, headers, body = self.call_slo(req) + headers = HeaderKeyDict(headers) + + self.assertEqual(status, '206 Partial Content') + + ct, params = parse_content_type(headers['Content-Type']) + params = dict(params) + self.assertEqual(ct, 'multipart/byteranges') + boundary = params.get('boundary') + self.assertTrue(boundary is not None) + + got_mime_docs = [] + for mime_doc_fh in iter_multipart_mime_documents( + StringIO(body), boundary): + headers = parse_mime_headers(mime_doc_fh) + body = mime_doc_fh.read() + got_mime_docs.append((headers, body)) + self.assertEqual(len(got_mime_docs), 2) + + first_range_headers = got_mime_docs[0][0] + first_range_body = got_mime_docs[0][1] + self.assertEqual(first_range_headers['Content-Range'], + 'bytes 3-17/50') + self.assertEqual(first_range_body, 'aabbbbbbbbbbccc') + + second_range_headers = got_mime_docs[1][0] + second_range_body = got_mime_docs[1][1] + self.assertEqual(second_range_headers['Content-Range'], + 'bytes 29-49/50') + self.assertEqual(second_range_body, 'cdddddddddddddddddddd') + def test_range_get_includes_whole_manifest(self): # If the first range GET results in retrieval of the entire manifest # body (which we can detect by looking at Content-Range), then we @@ -1924,21 +2036,6 @@ class TestSloGetManifest(SloTestCase): status, headers, body = self.call_slo(req) self.assertEqual(status, '416 Requested Range Not Satisfiable') - def test_multi_range_get_manifest(self): - # SLO doesn't support multi-range GETs. The way that you express - # "unsupported" in HTTP is to return a 200 and the whole entity. - req = Request.blank( - '/v1/AUTH_test/gettest/manifest-abcd', - environ={'REQUEST_METHOD': 'GET'}, - headers={'Range': 'bytes=0-0,2-2'}) - status, headers, body = self.call_slo(req) - headers = HeaderKeyDict(headers) - - self.assertEqual(status, '200 OK') - self.assertEqual(headers['Content-Length'], '50') - self.assertEqual( - body, 'aaaaabbbbbbbbbbcccccccccccccccdddddddddddddddddddd') - def test_get_segment_with_non_ascii_path(self): segment_body = u"a møøse once bit my sister".encode("utf-8") self.app.register( @@ -2027,11 +2124,9 @@ class TestSloGetManifest(SloTestCase): ('GET', '/v1/AUTH_test/gettest/manifest-bc-ranges'), ('GET', '/v1/AUTH_test/gettest/a_5?multipart-manifest=get'), ('GET', '/v1/AUTH_test/gettest/c_15?multipart-manifest=get'), - ('GET', '/v1/AUTH_test/gettest/manifest-bc-ranges'), ('GET', '/v1/AUTH_test/gettest/d_20?multipart-manifest=get'), ('GET', '/v1/AUTH_test/gettest/c_15?multipart-manifest=get'), ('GET', '/v1/AUTH_test/gettest/b_10?multipart-manifest=get'), - ('GET', '/v1/AUTH_test/gettest/manifest-bc-ranges'), ('GET', '/v1/AUTH_test/gettest/a_5?multipart-manifest=get'), ('GET', '/v1/AUTH_test/gettest/b_10?multipart-manifest=get'), ('GET', '/v1/AUTH_test/gettest/d_20?multipart-manifest=get')]) @@ -2043,11 +2138,9 @@ class TestSloGetManifest(SloTestCase): None, 'bytes=3-', 'bytes=0-2', - None, 'bytes=11-11', 'bytes=13-', 'bytes=4-6', - None, 'bytes=0-0', 'bytes=4-5', 'bytes=0-2']) @@ -2114,11 +2207,9 @@ class TestSloGetManifest(SloTestCase): ('GET', '/v1/AUTH_test/gettest/manifest-abcd-ranges'), ('GET', '/v1/AUTH_test/gettest/manifest-bc-ranges'), ('GET', '/v1/AUTH_test/gettest/c_15?multipart-manifest=get'), - ('GET', '/v1/AUTH_test/gettest/manifest-bc-ranges'), ('GET', '/v1/AUTH_test/gettest/d_20?multipart-manifest=get'), ('GET', '/v1/AUTH_test/gettest/c_15?multipart-manifest=get'), ('GET', '/v1/AUTH_test/gettest/b_10?multipart-manifest=get'), - ('GET', '/v1/AUTH_test/gettest/manifest-bc-ranges'), ('GET', '/v1/AUTH_test/gettest/a_5?multipart-manifest=get'), ('GET', '/v1/AUTH_test/gettest/b_10?multipart-manifest=get')]) @@ -2129,11 +2220,9 @@ class TestSloGetManifest(SloTestCase): None, None, 'bytes=2-2', - None, 'bytes=11-11', 'bytes=13-', 'bytes=4-6', - None, 'bytes=0-0', 'bytes=4-4']) # we set swift.source for everything but the first request @@ -2180,23 +2269,6 @@ class TestSloGetManifest(SloTestCase): self.assertEqual(self.app.swift_sources[1:], ['SLO'] * (len(self.app.swift_sources) - 1)) - def test_multi_range_get_range_manifest(self): - # SLO doesn't support multi-range GETs. The way that you express - # "unsupported" in HTTP is to return a 200 and the whole entity. - req = Request.blank( - '/v1/AUTH_test/gettest/manifest-abcd-ranges', - environ={'REQUEST_METHOD': 'GET'}, - headers={'Range': 'bytes=0-0,2-2'}) - status, headers, body = self.call_slo(req) - headers = HeaderKeyDict(headers) - - self.assertEqual(status, '200 OK') - self.assertEqual(headers['Content-Type'], 'application/json') - self.assertEqual(body, 'aaaaaaaaccccccccbbbbbbbbdddddddd') - self.assertNotIn('Transfer-Encoding', headers) - self.assertNotIn('Content-Range', headers) - self.assertEqual(headers['Content-Length'], '32') - def test_get_bogus_manifest(self): req = Request.blank( '/v1/AUTH_test/gettest/manifest-badjson', diff --git a/test/unit/common/test_swob.py b/test/unit/common/test_swob.py index f1a11e1fcb..b785bf289c 100644 --- a/test/unit/common/test_swob.py +++ b/test/unit/common/test_swob.py @@ -1298,16 +1298,16 @@ class TestResponse(unittest.TestCase): resp = req.get_response(test_app) resp.conditional_response = True body = ''.join(resp([], start_response)) - self.assertEqual(body, '') - self.assertEqual(resp.content_length, 0) + self.assertIn('The Range requested is not available', body) + self.assertEqual(resp.content_length, len(body)) self.assertEqual(resp.status, '416 Requested Range Not Satisfiable') resp = swift.common.swob.Response( body='1234567890', request=req, conditional_response=True) body = ''.join(resp([], start_response)) - self.assertEqual(body, '') - self.assertEqual(resp.content_length, 0) + self.assertIn('The Range requested is not available', body) + self.assertEqual(resp.content_length, len(body)) self.assertEqual(resp.status, '416 Requested Range Not Satisfiable') # Syntactically-invalid Range headers "MUST" be ignored diff --git a/test/unit/common/test_utils.py b/test/unit/common/test_utils.py index 707a293d6e..85c3e2ed50 100644 --- a/test/unit/common/test_utils.py +++ b/test/unit/common/test_utils.py @@ -5335,6 +5335,68 @@ class TestLRUCache(unittest.TestCase): self.assertEqual(f.size(), 4) +class TestSpliterator(unittest.TestCase): + def test_string(self): + input_chunks = ["coun", "ter-", "b", "ra", "nch-mater", + "nit", "y-fungusy", "-nummular"] + si = utils.Spliterator(input_chunks) + + self.assertEqual(''.join(si.take(8)), "counter-") + self.assertEqual(''.join(si.take(7)), "branch-") + self.assertEqual(''.join(si.take(10)), "maternity-") + self.assertEqual(''.join(si.take(8)), "fungusy-") + self.assertEqual(''.join(si.take(8)), "nummular") + + def test_big_input_string(self): + input_chunks = ["iridium"] + si = utils.Spliterator(input_chunks) + + self.assertEqual(''.join(si.take(2)), "ir") + self.assertEqual(''.join(si.take(1)), "i") + self.assertEqual(''.join(si.take(2)), "di") + self.assertEqual(''.join(si.take(1)), "u") + self.assertEqual(''.join(si.take(1)), "m") + + def test_chunk_boundaries(self): + input_chunks = ["soylent", "green", "is", "people"] + si = utils.Spliterator(input_chunks) + + self.assertEqual(''.join(si.take(7)), "soylent") + self.assertEqual(''.join(si.take(5)), "green") + self.assertEqual(''.join(si.take(2)), "is") + self.assertEqual(''.join(si.take(6)), "people") + + def test_no_empty_strings(self): + input_chunks = ["soylent", "green", "is", "people"] + si = utils.Spliterator(input_chunks) + + outputs = (list(si.take(7)) # starts and ends on chunk boundary + + list(si.take(2)) # spans two chunks + + list(si.take(3)) # begins but does not end chunk + + list(si.take(2)) # ends but does not begin chunk + + list(si.take(6))) # whole chunk + EOF + self.assertNotIn('', outputs) + + def test_running_out(self): + input_chunks = ["not much"] + si = utils.Spliterator(input_chunks) + + self.assertEqual(''.join(si.take(4)), "not ") + self.assertEqual(''.join(si.take(99)), "much") # short + self.assertEqual(''.join(si.take(4)), "") + self.assertEqual(''.join(si.take(4)), "") + + def test_overlap(self): + input_chunks = ["one fish", "two fish", "red fish", "blue fish"] + + si = utils.Spliterator(input_chunks) + t1 = si.take(20) # longer than first chunk + self.assertLess(len(next(t1)), 20) # it's not exhausted + + t2 = si.take(20) + self.assertRaises(ValueError, next, t2) + + class TestParseContentRange(unittest.TestCase): def test_good(self): start, end, total = utils.parse_content_range("bytes 100-200/300")