Refactor some file-like iters as utils.InputProxy subclasses

There are a few places where bespoke file-like wrapper classes have been
implemented. The common methods are now inherited from
utils.InputProxy.

Make utils.FileLikeIter tolerate size=None to mean the same as size=-1
so that it is consistent with the behavior of other input streams.

Fix docstrings in FileLikeIter.

Depends-On: https://review.opendev.org/c/openstack/requirements/+/942845
Change-Id: I20741ab58b0933390dc4679c3e6b2d888857d577
Alistair Coles 2025-01-23 12:17:44 +00:00 committed by Tim Burke
parent 7140633925
commit e4cc228ed0
6 changed files with 304 additions and 90 deletions
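
For context, the pattern this change applies throughout is: a wsgi.input wrapper subclasses utils.InputProxy and overrides only chunk_update(), which InputProxy calls with every chunk it reads, instead of re-implementing read() and readline() forwarding. A minimal sketch of the pattern (the ByteCounter class below is hypothetical, not part of this change, and assumes a Swift checkout is importable):

import io
from swift.common.utils import InputProxy

class ByteCounter(InputProxy):
    # Hypothetical subclass: observe chunks without re-implementing read()/readline().
    def __init__(self, wsgi_input):
        super().__init__(wsgi_input)
        self.chunks_seen = 0

    def chunk_update(self, chunk, eof, *args, **kwargs):
        # Called by InputProxy.read()/readline() with each chunk read; return the
        # (possibly modified) chunk that the caller should receive.
        if chunk:
            self.chunks_seen += 1
        return chunk

wrapped = ByteCounter(io.BytesIO(b'hello world'))
assert wrapped.read(5) == b'hello'
assert wrapped.read() == b' world'
assert wrapped.bytes_received == 11
assert wrapped.chunks_seen == 2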


@ -27,7 +27,7 @@ from swift.common.request_helpers import get_object_transient_sysmeta, \
from swift.common.swob import Request, Match, HTTPException, \
HTTPUnprocessableEntity, wsgi_to_bytes, bytes_to_wsgi, normalize_etag
from swift.common.utils import get_logger, config_true_value, \
MD5_OF_EMPTY_STRING, md5
MD5_OF_EMPTY_STRING, md5, InputProxy
def encrypt_header_val(crypto, value, key):
@ -66,11 +66,11 @@ def _hmac_etag(key, etag):
return base64.b64encode(result).decode()
class EncInputWrapper(object):
class EncInputWrapper(InputProxy):
"""File-like object to be swapped in for wsgi.input."""
def __init__(self, crypto, keys, req, logger):
super().__init__(req.environ['wsgi.input'])
self.env = req.environ
self.wsgi_input = req.environ['wsgi.input']
self.path = req.path
self.crypto = crypto
self.body_crypto_ctxt = None
@ -180,15 +180,7 @@ class EncInputWrapper(object):
req.environ['swift.callback.update_footers'] = footers_callback
def read(self, *args, **kwargs):
return self.readChunk(self.wsgi_input.read, *args, **kwargs)
def readline(self, *args, **kwargs):
return self.readChunk(self.wsgi_input.readline, *args, **kwargs)
def readChunk(self, read_method, *args, **kwargs):
chunk = read_method(*args, **kwargs)
def chunk_update(self, chunk, eof, *args, **kwargs):
if chunk:
self._init_encryption_context()
self.plaintext_md5.update(chunk)


@ -134,7 +134,7 @@ from swift.common.digest import get_allowed_digests, \
extract_digest_and_algorithm, DEFAULT_ALLOWED_DIGESTS
from swift.common.utils import streq_const_time, parse_content_disposition, \
parse_mime_headers, iter_multipart_mime_documents, reiterate, \
closing_if_possible, get_logger
closing_if_possible, get_logger, InputProxy
from swift.common.registry import register_swift_info
from swift.common.wsgi import WSGIContext, make_pre_authed_env
from swift.common.swob import HTTPUnauthorized, wsgi_to_str, str_to_wsgi
@ -158,7 +158,7 @@ class FormUnauthorized(Exception):
pass
class _CappedFileLikeObject(object):
class _CappedFileLikeObject(InputProxy):
"""
A file-like object wrapping another file-like object that raises
an EOFError if the amount of data read exceeds a given
@ -170,26 +170,15 @@ class _CappedFileLikeObject(object):
"""
def __init__(self, fp, max_file_size):
self.fp = fp
super().__init__(fp)
self.max_file_size = max_file_size
self.amount_read = 0
self.file_size_exceeded = False
def read(self, size=None):
ret = self.fp.read(size)
self.amount_read += len(ret)
if self.amount_read > self.max_file_size:
def chunk_update(self, chunk, eof, *args, **kwargs):
if self.bytes_received > self.max_file_size:
self.file_size_exceeded = True
raise EOFError('max_file_size exceeded')
return ret
def readline(self):
ret = self.fp.readline()
self.amount_read += len(ret)
if self.amount_read > self.max_file_size:
self.file_size_exceeded = True
raise EOFError('max_file_size exceeded')
return ret
return chunk
class FormPost(object):
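
A rough usage sketch of the capped wrapper after this change (assumes a Swift checkout; _CappedFileLikeObject is a private helper, so importing it directly is for illustration only):

import io
from swift.common.middleware.formpost import _CappedFileLikeObject

capped = _CappedFileLikeObject(io.BytesIO(b'0123456789'), max_file_size=4)
assert capped.read(3) == b'012'  # 3 bytes received, still under the cap
try:
    capped.read(3)               # bytes_received would become 6 > 4
except EOFError:
    assert capped.file_size_exceeded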


@ -24,8 +24,8 @@ import re
from urllib.parse import quote, unquote, parse_qsl
import string
from swift.common.utils import split_path, json, close_if_possible, md5, \
streq_const_time, get_policy_index
from swift.common.utils import split_path, json, md5, streq_const_time, \
get_policy_index, InputProxy
from swift.common.registry import get_swift_info
from swift.common import swob
from swift.common.http import HTTP_OK, HTTP_CREATED, HTTP_ACCEPTED, \
@ -133,41 +133,42 @@ class S3InputSHA256Mismatch(BaseException):
self.computed = computed
class HashingInput(object):
class HashingInput(InputProxy):
"""
wsgi.input wrapper to verify the SHA256 of the input as it's read.
"""
def __init__(self, reader, content_length, expected_hex_hash):
self._input = reader
self._to_read = content_length
def __init__(self, wsgi_input, content_length, expected_hex_hash):
super().__init__(wsgi_input)
self._expected_length = content_length
self._hasher = sha256()
self._expected = expected_hex_hash
self._expected_hash = expected_hex_hash
if content_length == 0 and \
self._hasher.hexdigest() != self._expected.lower():
self._hasher.hexdigest() != self._expected_hash.lower():
self.close()
raise XAmzContentSHA256Mismatch(
client_computed_content_s_h_a256=self._expected,
client_computed_content_s_h_a256=self._expected_hash,
s3_computed_content_s_h_a256=self._hasher.hexdigest(),
)
def read(self, size=None):
chunk = self._input.read(size)
def chunk_update(self, chunk, eof, *args, **kwargs):
self._hasher.update(chunk)
self._to_read -= len(chunk)
short_read = bool(chunk) if size is None else (len(chunk) < size)
if self._to_read < 0 or (short_read and self._to_read) or (
self._to_read == 0 and
self._hasher.hexdigest() != self._expected.lower()):
if self.bytes_received < self._expected_length:
error = eof
elif self.bytes_received == self._expected_length:
error = self._hasher.hexdigest() != self._expected_hash.lower()
else:
error = True
if error:
self.close()
# Since we don't return the last chunk, the PUT never completes
raise S3InputSHA256Mismatch(
self._expected,
self._expected_hash,
self._hasher.hexdigest())
return chunk
def close(self):
close_if_possible(self._input)
return chunk
class SigV4Mixin(object):
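
A worked example of the new length and hash checks (assumes a Swift checkout; the import path swift.common.middleware.s3api.s3request is an assumption): a read that ends before the declared content length, an exact-length read with a wrong hash, or extra data beyond the declared length all raise S3InputSHA256Mismatch and close the input; only an exact-length read with a matching SHA256 succeeds.

import hashlib
import io
from swift.common.middleware.s3api.s3request import (
    HashingInput, S3InputSHA256Mismatch)

raw = b'123456789'

good = HashingInput(io.BytesIO(raw), len(raw), hashlib.sha256(raw).hexdigest())
assert good.read() == raw  # length and SHA256 both match

short = HashingInput(io.BytesIO(raw), len(raw) + 1, hashlib.sha256(raw).hexdigest())
try:
    short.read()                   # EOF before the declared content length
except S3InputSHA256Mismatch:
    assert short.wsgi_input.closed  # the wrapper closes its input on error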


@ -489,7 +489,9 @@ class FileLikeIter(object):
def __next__(self):
"""
next(x) -> the next value, or raise StopIteration
:raise StopIteration: if there are no more values to iterate.
:raise ValueError: if the close() method has been called.
:return: the next value.
"""
if self.closed:
raise ValueError('I/O operation on closed file')
@ -502,12 +504,14 @@ class FileLikeIter(object):
def read(self, size=-1):
"""
read([size]) -> read at most size bytes, returned as a bytes string.
If the size argument is negative or omitted, read until EOF is reached.
Notice that when in non-blocking mode, less data than what was
requested may be returned, even if no size parameter was given.
:param size: (optional) the maximum number of bytes to read. The
default value of ``-1`` means 'unlimited' i.e. read until the wrapped
iterable is exhausted.
:raise ValueError: if the close() method has been called.
:return: a bytes literal; if the wrapped iterable has been exhausted
then a zero-length bytes literal is returned.
"""
size = -1 if size is None else size
if self.closed:
raise ValueError('I/O operation on closed file')
if size < 0:
@ -529,12 +533,17 @@ class FileLikeIter(object):
def readline(self, size=-1):
"""
readline([size]) -> next line from the file, as a bytes string.
Read the next line.
Retain newline. A non-negative size argument limits the maximum
number of bytes to return (an incomplete line may be returned then).
Return an empty string at EOF.
:param size: (optional) the maximum number of bytes of the next line to
read. The default value of ``-1`` means 'unlimited' i.e. read to
the end of the line or until the wrapped iterable is exhausted,
whichever is first.
:raise ValueError: if the close() method has been called.
:return: a bytes literal; if the wrapped iterable has been exhausted
then a zero-length bytes literal is returned.
"""
size = -1 if size is None else size
if self.closed:
raise ValueError('I/O operation on closed file')
data = b''
@ -557,12 +566,16 @@ class FileLikeIter(object):
def readlines(self, sizehint=-1):
"""
readlines([size]) -> list of bytes strings, each a line from the file.
Call readline() repeatedly and return a list of the lines so read.
The optional size argument, if given, is an approximate bound on the
total number of bytes in the lines returned.
:param sizehint: (optional) an approximate bound on the total number of
bytes in the lines returned. Lines are read until ``sizehint`` has
been exceeded but complete lines are always returned, so the total
bytes read may exceed ``sizehint``.
:raise ValueError: if the close() method has been called.
:return: a list of bytes literals, each a line from the file.
"""
sizehint = -1 if sizehint is None else sizehint
if self.closed:
raise ValueError('I/O operation on closed file')
lines = []
@ -579,12 +592,10 @@ class FileLikeIter(object):
def close(self):
"""
close() -> None or (perhaps) an integer. Close the file.
Close the iter.
Sets data attribute .closed to True. A closed file cannot be used for
further I/O operations. close() may be called more than once without
error. Some kinds of file objects (for example, opened by popen())
may return an exit status upon closing.
Once close() has been called the iter cannot be used for further I/O
operations. close() may be called more than once without error.
"""
self.iterator = None
self.closed = True
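
The effect of the size=None tolerance described in the commit message, as a quick sketch (assumes a Swift checkout):

from swift.common.utils import FileLikeIter

chunks = [b'abc', b'de', b'fghijk', b'l']
assert FileLikeIter(chunks).read() == b'abcdefghijkl'
assert FileLikeIter(chunks).read(-1) == b'abcdefghijkl'
assert FileLikeIter(chunks).read(None) == b'abcdefghijkl'      # None now behaves like -1
assert FileLikeIter(chunks).readline(None) == b'abcdefghijkl'  # no newline in the data
assert FileLikeIter(chunks).readlines(None) == [b'abcdefghijkl']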
@ -2515,41 +2526,79 @@ class InputProxy(object):
"""
File-like object that counts bytes read.
To be swapped in for wsgi.input for accounting purposes.
:param wsgi_input: file-like object to be wrapped
"""
def __init__(self, wsgi_input):
"""
:param wsgi_input: file-like object to wrap the functionality of
"""
self.wsgi_input = wsgi_input
#: total number of bytes read from the wrapped input
self.bytes_received = 0
#: ``True`` if an exception is raised by ``read()`` or ``readline()``,
#: ``False`` otherwise
self.client_disconnect = False
def read(self, *args, **kwargs):
def chunk_update(self, chunk, eof, *args, **kwargs):
"""
Called each time a chunk of bytes is read from the wrapped input.
:param chunk: the chunk of bytes that has been read.
:param eof: ``True`` if there are no more bytes to read from the
wrapped input, ``False`` otherwise. If ``read()`` has been called
this will be ``True`` when the size of ``chunk`` is less than the
requested size or the requested size is None. If ``readline`` has
been called this will be ``True`` when an incomplete line is read
(i.e. not ending with ``b'\\n'``) whose length is less than the
requested size or the requested size is None. If ``read()`` or
``readline()`` are called with a requested size that exactly
matches the number of bytes remaining in the wrapped input then
``eof`` will be ``False``. A subsequent call to ``read()`` or
``readline()`` with non-zero ``size`` would result in ``eof`` being
``True``. Alternatively, the end of the input could be inferred
by comparing ``bytes_received`` with the expected length of the
input.
"""
# subclasses may override this method; either the given chunk or an
# alternative chunk value should be returned
return chunk
def read(self, size=None, *args, **kwargs):
"""
Pass read request to the underlying file-like object and
add bytes read to total.
:param size: (optional) maximum number of bytes to read; the default
``None`` means unlimited.
"""
try:
chunk = self.wsgi_input.read(*args, **kwargs)
chunk = self.wsgi_input.read(size, *args, **kwargs)
except Exception:
self.client_disconnect = True
raise
self.bytes_received += len(chunk)
return chunk
eof = size is None or size < 0 or len(chunk) < size
return self.chunk_update(chunk, eof)
def readline(self, *args, **kwargs):
def readline(self, size=None, *args, **kwargs):
"""
Pass readline request to the underlying file-like object and
add bytes read to total.
:param size: (optional) maximum number of bytes to read from the
current line; the default ``None`` means unlimited.
"""
try:
line = self.wsgi_input.readline(*args, **kwargs)
line = self.wsgi_input.readline(size, *args, **kwargs)
except Exception:
self.client_disconnect = True
raise
self.bytes_received += len(line)
return line
eof = ((size is None or size < 0 or len(line) < size)
and (line[-1:] != b'\n'))
return self.chunk_update(line, eof)
def close(self):
close_if_possible(self.wsgi_input)
class LRUCache(object):
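
The eof flag semantics spelled out in the chunk_update() docstring can be seen with a tiny subclass (EofSpy is hypothetical, for illustration only): an exact-size read passes eof=False even when the input happens to be exhausted, and only a subsequent read observes eof=True.

import io
from swift.common.utils import InputProxy

class EofSpy(InputProxy):
    # Hypothetical subclass that records the eof flag passed to chunk_update().
    def __init__(self, wsgi_input):
        super().__init__(wsgi_input)
        self.calls = []

    def chunk_update(self, chunk, eof, *args, **kwargs):
        self.calls.append((chunk, eof))
        return chunk

spy = EofSpy(io.BytesIO(b'abc'))
spy.read(3)  # exact-size read: eof is False even though the input is exhausted
spy.read(1)  # a further read is needed to observe eof
assert spy.calls == [(b'abc', False), (b'', True)]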


@ -1468,9 +1468,17 @@ class TestHashingInput(S3ApiTestCase):
# can continue trying to read -- but it'll be empty
self.assertEqual(b'', wrapped.read(2))
self.assertFalse(wrapped._input.closed)
self.assertFalse(wrapped.wsgi_input.closed)
wrapped.close()
self.assertTrue(wrapped._input.closed)
self.assertTrue(wrapped.wsgi_input.closed)
def test_good_readline(self):
raw = b'12345\n6789'
wrapped = HashingInput(
BytesIO(raw), 10, hashlib.sha256(raw).hexdigest())
self.assertEqual(b'12345\n', wrapped.readline())
self.assertEqual(b'6789', wrapped.readline())
self.assertEqual(b'', wrapped.readline())
def test_empty(self):
wrapped = HashingInput(
@ -1478,9 +1486,9 @@ class TestHashingInput(S3ApiTestCase):
self.assertEqual(b'', wrapped.read(4))
self.assertEqual(b'', wrapped.read(2))
self.assertFalse(wrapped._input.closed)
self.assertFalse(wrapped.wsgi_input.closed)
wrapped.close()
self.assertTrue(wrapped._input.closed)
self.assertTrue(wrapped.wsgi_input.closed)
def test_too_long(self):
raw = b'123456789'
@ -1495,18 +1503,26 @@ class TestHashingInput(S3ApiTestCase):
# won't get caught by most things in a pipeline
self.assertNotIsInstance(raised.exception, Exception)
# the error causes us to close the input
self.assertTrue(wrapped._input.closed)
self.assertTrue(wrapped.wsgi_input.closed)
def test_too_short(self):
def test_too_short_read_piecemeal(self):
raw = b'123456789'
wrapped = HashingInput(
BytesIO(raw), 10, hashlib.sha256(raw).hexdigest())
self.assertEqual(b'1234', wrapped.read(4))
self.assertEqual(b'56', wrapped.read(2))
# even though the hash matches, there was more data than we expected
self.assertEqual(b'56789', wrapped.read(5))
# even though the hash matches, there was less data than we expected
with self.assertRaises(S3InputSHA256Mismatch):
wrapped.read(4)
self.assertTrue(wrapped._input.closed)
wrapped.read(1)
self.assertTrue(wrapped.wsgi_input.closed)
def test_too_short_read_all(self):
raw = b'123456789'
wrapped = HashingInput(
BytesIO(raw), 10, hashlib.sha256(raw).hexdigest())
with self.assertRaises(S3InputSHA256Mismatch):
wrapped.read()
self.assertTrue(wrapped.wsgi_input.closed)
def test_bad_hash(self):
raw = b'123456789'
@ -1516,7 +1532,7 @@ class TestHashingInput(S3ApiTestCase):
self.assertEqual(b'5678', wrapped.read(4))
with self.assertRaises(S3InputSHA256Mismatch):
wrapped.read(4)
self.assertTrue(wrapped._input.closed)
self.assertTrue(wrapped.wsgi_input.closed)
def test_empty_bad_hash(self):
_input = BytesIO(b'')
@ -1526,6 +1542,14 @@ class TestHashingInput(S3ApiTestCase):
HashingInput(_input, 0, 'nope')
self.assertTrue(_input.closed)
def test_bad_hash_readline(self):
raw = b'12345\n6789'
wrapped = HashingInput(
BytesIO(raw), 10, hashlib.sha256(raw[:-3]).hexdigest())
self.assertEqual(b'12345\n', wrapped.readline())
with self.assertRaises(S3InputSHA256Mismatch):
self.assertEqual(b'6789', wrapped.readline())
if __name__ == '__main__':
unittest.main()


@ -18,6 +18,7 @@ from __future__ import print_function
import argparse
import hashlib
import io
import itertools
from swift.common.statsd_client import StatsdClient
@ -2961,8 +2962,16 @@ class TestFileLikeIter(unittest.TestCase):
def test_read(self):
in_iter = [b'abc', b'de', b'fghijk', b'l']
iter_file = utils.FileLikeIter(in_iter)
self.assertEqual(iter_file.read(), b''.join(in_iter))
expected = b''.join(in_iter)
self.assertEqual(utils.FileLikeIter(in_iter).read(), expected)
self.assertEqual(utils.FileLikeIter(in_iter).read(-1), expected)
self.assertEqual(utils.FileLikeIter(in_iter).read(None), expected)
def test_read_empty(self):
in_iter = [b'abc']
ip = utils.FileLikeIter(in_iter)
self.assertEqual(b'abc', ip.read())
self.assertEqual(b'', ip.read())
def test_read_with_size(self):
in_iter = [b'abc', b'de', b'fghijk', b'l']
@ -2995,6 +3004,15 @@ class TestFileLikeIter(unittest.TestCase):
[v if v == b'trailing.' else v + b'\n'
for v in b''.join(in_iter).split(b'\n')])
def test_readline_size_unlimited(self):
in_iter = [b'abc', b'd\nef']
self.assertEqual(
utils.FileLikeIter(in_iter).readline(-1),
b'abcd\n')
self.assertEqual(
utils.FileLikeIter(in_iter).readline(None),
b'abcd\n')
def test_readline2(self):
self.assertEqual(
utils.FileLikeIter([b'abc', b'def\n']).readline(4),
@ -3029,6 +3047,16 @@ class TestFileLikeIter(unittest.TestCase):
lines,
[v if v == b'trailing.' else v + b'\n'
for v in b''.join(in_iter).split(b'\n')])
lines = utils.FileLikeIter(in_iter).readlines(sizehint=-1)
self.assertEqual(
lines,
[v if v == b'trailing.' else v + b'\n'
for v in b''.join(in_iter).split(b'\n')])
lines = utils.FileLikeIter(in_iter).readlines(sizehint=None)
self.assertEqual(
lines,
[v if v == b'trailing.' else v + b'\n'
for v in b''.join(in_iter).split(b'\n')])
def test_readlines_with_size(self):
in_iter = [b'abc\n', b'd', b'\nef', b'g\nh', b'\nij\n\nk\n',
@ -3092,6 +3120,137 @@ class TestFileLikeIter(unittest.TestCase):
self.assertEqual(utils.get_hub(), 'selects')
class TestInputProxy(unittest.TestCase):
def test_read_all(self):
self.assertEqual(utils.InputProxy(io.BytesIO(b'abc')).read(), b'abc')
self.assertEqual(utils.InputProxy(io.BytesIO(b'abc')).read(-1), b'abc')
self.assertEqual(
utils.InputProxy(io.BytesIO(b'abc')).read(None), b'abc')
def test_read_size(self):
self.assertEqual(utils.InputProxy(io.BytesIO(b'abc')).read(0), b'')
self.assertEqual(utils.InputProxy(io.BytesIO(b'abc')).read(2), b'ab')
self.assertEqual(utils.InputProxy(io.BytesIO(b'abc')).read(4), b'abc')
def test_readline(self):
ip = utils.InputProxy(io.BytesIO(b'ab\nc'))
self.assertEqual(ip.readline(), b'ab\n')
self.assertFalse(ip.client_disconnect)
def test_bytes_received(self):
ip = utils.InputProxy(io.BytesIO(b'ab\ncdef'))
ip.readline()
self.assertEqual(3, ip.bytes_received)
ip.read(2)
self.assertEqual(5, ip.bytes_received)
ip.read(99)
self.assertEqual(7, ip.bytes_received)
def test_close(self):
utils.InputProxy(object()).close() # safe
fake = mock.MagicMock()
fake.close = mock.MagicMock()
ip = (utils.InputProxy(fake))
ip.close()
self.assertEqual([mock.call()], fake.close.call_args_list)
self.assertFalse(ip.client_disconnect)
def test_read_piecemeal_chunk_update(self):
ip = utils.InputProxy(io.BytesIO(b'abc'))
with mock.patch.object(ip, 'chunk_update') as mocked:
ip.read(1)
ip.read(2)
ip.read(1)
ip.read(1)
self.assertEqual([mock.call(b'a', False),
mock.call(b'bc', False),
mock.call(b'', True),
mock.call(b'', True)], mocked.call_args_list)
def test_read_unlimited_chunk_update(self):
ip = utils.InputProxy(io.BytesIO(b'abc'))
with mock.patch.object(ip, 'chunk_update') as mocked:
ip.read()
ip.read()
self.assertEqual([mock.call(b'abc', True),
mock.call(b'', True)], mocked.call_args_list)
ip = utils.InputProxy(io.BytesIO(b'abc'))
with mock.patch.object(ip, 'chunk_update') as mocked:
ip.read(None)
ip.read(None)
self.assertEqual([mock.call(b'abc', True),
mock.call(b'', True)], mocked.call_args_list)
ip = utils.InputProxy(io.BytesIO(b'abc'))
with mock.patch.object(ip, 'chunk_update') as mocked:
ip.read(-1)
ip.read(-1)
self.assertEqual([mock.call(b'abc', True),
mock.call(b'', True)], mocked.call_args_list)
def test_readline_piecemeal_chunk_update(self):
ip = utils.InputProxy(io.BytesIO(b'ab\nc'))
with mock.patch.object(ip, 'chunk_update') as mocked:
ip.readline(3)
ip.readline(1) # read to exact length
ip.readline(1)
self.assertEqual([mock.call(b'ab\n', False),
mock.call(b'c', False),
mock.call(b'', True)], mocked.call_args_list)
ip = utils.InputProxy(io.BytesIO(b'ab\nc'))
with mock.patch.object(ip, 'chunk_update') as mocked:
ip.readline(3)
ip.readline(2) # read beyond exact length
ip.readline(1)
self.assertEqual([mock.call(b'ab\n', False),
mock.call(b'c', True),
mock.call(b'', True)], mocked.call_args_list)
def test_readline_unlimited_chunk_update(self):
ip = utils.InputProxy(io.BytesIO(b'ab\nc'))
with mock.patch.object(ip, 'chunk_update') as mocked:
ip.readline()
ip.readline()
self.assertEqual([mock.call(b'ab\n', False),
mock.call(b'c', True)], mocked.call_args_list)
ip = utils.InputProxy(io.BytesIO(b'ab\nc'))
with mock.patch.object(ip, 'chunk_update') as mocked:
ip.readline(None)
ip.readline(None)
self.assertEqual([mock.call(b'ab\n', False),
mock.call(b'c', True)], mocked.call_args_list)
ip = utils.InputProxy(io.BytesIO(b'ab\nc'))
with mock.patch.object(ip, 'chunk_update') as mocked:
ip.readline(-1)
ip.readline(-1)
self.assertEqual([mock.call(b'ab\n', False),
mock.call(b'c', True)], mocked.call_args_list)
def test_chunk_update_modifies_chunk(self):
ip = utils.InputProxy(io.BytesIO(b'abc'))
with mock.patch.object(ip, 'chunk_update', return_value='modified'):
actual = ip.read()
self.assertEqual('modified', actual)
def test_read_client_disconnect(self):
fake = mock.MagicMock()
fake.read = mock.MagicMock(side_effect=ValueError('boom'))
ip = utils.InputProxy(fake)
with self.assertRaises(ValueError) as cm:
ip.read()
self.assertTrue(ip.client_disconnect)
self.assertEqual('boom', str(cm.exception))
def test_readline_client_disconnect(self):
fake = mock.MagicMock()
fake.readline = mock.MagicMock(side_effect=ValueError('boom'))
ip = utils.InputProxy(fake)
with self.assertRaises(ValueError) as cm:
ip.readline()
self.assertTrue(ip.client_disconnect)
self.assertEqual('boom', str(cm.exception))
class UnsafeXrange(object):
"""
Like range(limit), but with extra context switching to screw things up.