Fix GetOrHeadHandler unit tests to use public interface

Modify some unit tests to call the get_working_response() method of
GetOrHeadHandler rather than a private method.

Extract and re-use a TestSource helper class.

Change-Id: Ie8e0560f90207eede1992c8665e4d8571a505580
This commit is contained in:
Alistair Coles 2023-05-23 13:36:41 +01:00
parent 3e08fcd7b1
commit 149b617c28

View File

@ -183,6 +183,39 @@ class FakeCache(FakeMemcache):
return self.stub or self.store.get(key)
class TestSource(object):
def __init__(self, chunks, headers=None, body=b''):
self.chunks = list(chunks)
self.headers = headers or {}
self.status = 200
self.swift_conn = None
self.body = body
def read(self, _read_size):
if self.chunks:
chunk = self.chunks.pop(0)
if chunk is None:
raise exceptions.ChunkReadTimeout()
else:
return chunk
else:
return self.body
def getheader(self, header):
# content-length for the whole object is generated dynamically
# by summing non-None chunks
if header.lower() == "content-length":
if self.chunks:
return str(sum(len(c) for c in self.chunks
if c is not None))
return len(self.read(-1))
return self.headers.get(header.lower())
def getheaders(self):
return [('content-length', self.getheader('content-length'))] + \
[(k, v) for k, v in self.headers.items()]
class BaseTest(unittest.TestCase):
def setUp(self):
@ -1272,67 +1305,21 @@ class TestFuncs(BaseTest):
self.assertEqual('', dst_headers['Referer'])
def test_client_chunk_size(self):
class TestSource(object):
def __init__(self, chunks):
self.chunks = list(chunks)
self.status = 200
def read(self, _read_size):
if self.chunks:
return self.chunks.pop(0)
else:
return b''
def getheader(self, header):
if header.lower() == "content-length":
return str(sum(len(c) for c in self.chunks))
def getheaders(self):
return [('content-length', self.getheader('content-length'))]
source = TestSource((
b'abcd', b'1234', b'abc', b'd1', b'234abcd1234abcd1', b'2'))
req = Request.blank('/v1/a/c/o')
node = {}
handler = GetOrHeadHandler(
self.app, req, None, Namespace(num_primary_nodes=3), None, None,
{}, client_chunk_size=8)
handler.source = source
handler.node = node
app_iter = handler._make_app_iter(req)
client_chunks = list(app_iter)
with mock.patch.object(handler, '_get_source_and_node',
return_value=(source, {})):
resp = handler.get_working_response(req)
client_chunks = list(resp.app_iter)
self.assertEqual(client_chunks, [
b'abcd1234', b'abcd1234', b'abcd1234', b'abcd12'])
def test_client_chunk_size_resuming(self):
class TestSource(object):
def __init__(self, chunks):
self.chunks = list(chunks)
self.status = 200
def read(self, _read_size):
if self.chunks:
chunk = self.chunks.pop(0)
if chunk is None:
raise exceptions.ChunkReadTimeout()
else:
return chunk
else:
return b''
def getheader(self, header):
# content-length for the whole object is generated dynamically
# by summing non-None chunks initialized as source1
if header.lower() == "content-length":
return str(sum(len(c) for c in self.chunks
if c is not None))
def getheaders(self):
return [('content-length', self.getheader('content-length'))]
node = {'ip': '1.2.3.4', 'port': 6200, 'device': 'sda'}
source1 = TestSource([b'abcd', b'1234', None,
@ -1346,93 +1333,55 @@ class TestFuncs(BaseTest):
None, {}, client_chunk_size=8)
range_headers = []
sources = [(source2, node), (source3, node)]
sources = [(source1, node), (source2, node), (source3, node)]
def mock_get_source_and_node():
range_headers.append(handler.backend_headers['Range'])
range_headers.append(handler.backend_headers.get('Range'))
return sources.pop(0)
handler.source = source1
handler.node = node
app_iter = handler._make_app_iter(req)
with mock.patch.object(handler, '_get_source_and_node',
side_effect=mock_get_source_and_node):
client_chunks = list(app_iter)
self.assertEqual(range_headers, ['bytes=8-27', 'bytes=16-27'])
mock_get_source_and_node):
resp = handler.get_working_response(req)
client_chunks = list(resp.app_iter)
self.assertEqual(range_headers, [None, 'bytes=8-27', 'bytes=16-27'])
self.assertEqual(client_chunks, [
b'abcd1234', b'efgh5678', b'lotsmore', b'data'])
def test_client_chunk_size_resuming_chunked(self):
class TestChunkedSource(object):
def __init__(self, chunks):
self.chunks = list(chunks)
self.status = 200
self.headers = {'transfer-encoding': 'chunked',
'content-type': 'text/plain'}
def read(self, _read_size):
if self.chunks:
chunk = self.chunks.pop(0)
if chunk is None:
raise exceptions.ChunkReadTimeout()
else:
return chunk
else:
return b''
def getheader(self, header):
return self.headers.get(header.lower())
def getheaders(self):
return self.headers
node = {'ip': '1.2.3.4', 'port': 6200, 'device': 'sda'}
source1 = TestChunkedSource([b'abcd', b'1234', b'abc', None])
source2 = TestChunkedSource([b'efgh5678'])
headers = {'transfer-encoding': 'chunked',
'content-type': 'text/plain'}
source1 = TestSource([b'abcd', b'1234', b'abc', None], headers=headers)
source2 = TestSource([b'efgh5678'], headers=headers)
sources = [(source1, node), (source2, node)]
req = Request.blank('/v1/a/c/o')
handler = GetOrHeadHandler(
self.app, req, 'Object', Namespace(num_primary_nodes=1), None,
None, {}, client_chunk_size=8)
handler.source = source1
handler.node = node
app_iter = handler._make_app_iter(req)
def mock_get_source_and_node():
return sources.pop(0)
with mock.patch.object(handler, '_get_source_and_node',
lambda: (source2, node)):
client_chunks = list(app_iter)
mock_get_source_and_node):
resp = handler.get_working_response(req)
client_chunks = list(resp.app_iter)
self.assertEqual(client_chunks, [b'abcd1234', b'efgh5678'])
def test_disconnected_logging(self):
self.app.logger = mock.Mock()
req = Request.blank('/v1/a/c/o')
class TestSource(object):
def __init__(self):
self.headers = {'content-type': 'text/plain',
'content-length': len(self.read(-1))}
self.status = 200
def read(self, _read_size):
return b'the cake is a lie'
def getheader(self, header):
return self.headers.get(header.lower())
def getheaders(self):
return self.headers
source = TestSource()
headers = {'content-type': 'text/plain'}
source = TestSource([], headers=headers, body=b'the cake is a lie')
node = {'ip': '1.2.3.4', 'port': 6200, 'device': 'sda'}
handler = GetOrHeadHandler(
self.app, req, 'Object', Namespace(num_primary_nodes=1), None,
'some-path', {})
handler.source = source
handler.node = node
app_iter = handler._make_app_iter(req)
app_iter.close()
with mock.patch.object(handler, '_get_source_and_node',
return_value=(source, node)):
resp = handler.get_working_response(req)
resp.app_iter.close()
self.app.logger.info.assert_called_once_with(
'Client disconnected on read of %r', 'some-path')
@ -1441,11 +1390,12 @@ class TestFuncs(BaseTest):
handler = GetOrHeadHandler(
self.app, req, 'Object', Namespace(num_primary_nodes=1), None,
None, {})
handler.source = source
handler.node = node
app_iter = handler._make_app_iter(req)
next(app_iter)
app_iter.close()
with mock.patch.object(handler, '_get_source_and_node',
return_value=(source, node)):
resp = handler.get_working_response(req)
next(resp.app_iter)
resp.app_iter.close()
self.app.logger.warning.assert_not_called()
def test_bytes_to_skip(self):