unify duplicate code in replication, EC GET paths

The GetOrHeadHandler class in base.py and the ECFragGetter
class in obj.py have some methods that are very similar.
The two existing classes do also have unique code so they
will remain, but as subclasses of the common base class
called GetterBase.

Change-Id: I893c5fcb6b4f8a7dda351169f5f6b37375a34817
This commit is contained in:
indianwhocodes 2023-06-21 08:36:54 -07:00 committed by Alistair Coles
parent 2343521f75
commit e290d47c43
2 changed files with 56 additions and 128 deletions

View File

@ -1014,55 +1014,21 @@ class ByteCountEnforcer(object):
return chunk
class GetOrHeadHandler(object):
def __init__(self, app, req, server_type, node_iter, partition, path,
backend_headers, concurrency=1, policy=None,
client_chunk_size=None, newest=None, logger=None):
class GetterBase(object):
def __init__(self, app, req, node_iter, partition, policy,
path, backend_headers, logger=None):
self.app = app
self.req = req
self.node_iter = node_iter
self.server_type = server_type
self.partition = partition
self.policy = policy
self.path = path
self.backend_headers = backend_headers
self.client_chunk_size = client_chunk_size
self.logger = logger or app.logger
self.skip_bytes = 0
self.bytes_used_from_backend = 0
self.used_nodes = []
self.used_source_etag = ''
self.concurrency = concurrency
self.policy = policy
self.node = None
self.source = None
self.source_parts_iter = None
self.latest_404_timestamp = Timestamp(0)
if self.server_type == 'Object':
self.node_timeout = self.app.recoverable_node_timeout
else:
self.node_timeout = self.app.node_timeout
policy_options = self.app.get_policy_options(self.policy)
self.rebalance_missing_suppression_count = min(
policy_options.rebalance_missing_suppression_count,
node_iter.num_primary_nodes - 1)
# stuff from request
self.req_method = req.method
self.req_path = req.path
self.req_query_string = req.query_string
if newest is None:
self.newest = config_true_value(req.headers.get('x-newest', 'f'))
else:
self.newest = newest
# populated when finding source
self.statuses = []
self.reasons = []
self.bodies = []
self.source_headers = []
self.sources = []
# populated from response headers
self.start_byte = self.end_byte = self.length = None
def fast_forward(self, num_bytes):
"""
@ -1136,6 +1102,46 @@ class GetOrHeadHandler(object):
else:
self.backend_headers.pop('Range')
class GetOrHeadHandler(GetterBase):
def __init__(self, app, req, server_type, node_iter, partition, path,
backend_headers, concurrency=1, policy=None,
client_chunk_size=None, newest=None, logger=None):
super(GetOrHeadHandler, self).__init__(
app=app, req=req, node_iter=node_iter,
partition=partition, policy=policy, path=path,
backend_headers=backend_headers, logger=logger)
self.server_type = server_type
self.client_chunk_size = client_chunk_size
self.skip_bytes = 0
self.used_nodes = []
self.used_source_etag = ''
self.concurrency = concurrency
self.latest_404_timestamp = Timestamp(0)
if self.server_type == 'Object':
self.node_timeout = self.app.recoverable_node_timeout
else:
self.node_timeout = self.app.node_timeout
policy_options = self.app.get_policy_options(self.policy)
self.rebalance_missing_suppression_count = min(
policy_options.rebalance_missing_suppression_count,
node_iter.num_primary_nodes - 1)
if newest is None:
self.newest = config_true_value(req.headers.get('x-newest', 'f'))
else:
self.newest = newest
# populated when finding source
self.statuses = []
self.reasons = []
self.bodies = []
self.source_headers = []
self.sources = []
# populated from response headers
self.start_byte = self.end_byte = self.length = None
def learn_size_from_content_range(self, start, end, length):
"""
If client_chunk_size is set, makes sure we yield things starting on
@ -1403,9 +1409,9 @@ class GetOrHeadHandler(object):
with ConnectionTimeout(self.app.conn_timeout):
conn = http_connect(
ip, port, node['device'],
self.partition, self.req_method, self.path,
self.partition, self.req.method, self.path,
headers=req_headers,
query_string=self.req_query_string)
query_string=self.req.query_string)
self.app.set_node_timing(node, time.time() - start_node_timing)
with Timeout(node_timeout):
@ -1416,7 +1422,7 @@ class GetOrHeadHandler(object):
self.app.exception_occurred(
node, self.server_type,
'Trying to %(method)s %(path)s' %
{'method': self.req_method, 'path': self.req_path})
{'method': self.req.method, 'path': self.req.path})
return False
src_headers = dict(
@ -1486,7 +1492,7 @@ class GetOrHeadHandler(object):
if ts > self.latest_404_timestamp:
self.latest_404_timestamp = ts
self.app.check_response(node, self.server_type, possible_source,
self.req_method, self.path,
self.req.method, self.path,
self.bodies[-1])
return False

View File

@ -69,7 +69,7 @@ from swift.common.storage_policy import (POLICIES, REPL_POLICY, EC_POLICY,
ECDriverError, PolicyError)
from swift.proxy.controllers.base import Controller, delay_denial, \
cors_validation, update_headers, bytes_to_skip, close_swift_conn, \
ByteCountEnforcer, record_cache_op_metrics, get_cache_key
ByteCountEnforcer, record_cache_op_metrics, get_cache_key, GetterBase
from swift.common.swob import HTTPAccepted, HTTPBadRequest, HTTPNotFound, \
HTTPPreconditionFailed, HTTPRequestEntityTooLarge, HTTPRequestTimeout, \
HTTPServerError, HTTPServiceUnavailable, HTTPClientDisconnect, \
@ -2490,97 +2490,19 @@ def is_good_source(status):
return is_success(status) or is_redirection(status)
class ECFragGetter(object):
class ECFragGetter(GetterBase):
def __init__(self, app, req, node_iter, partition, policy, path,
backend_headers, header_provider, logger_thread_locals,
logger):
self.app = app
self.req = req
self.node_iter = node_iter
self.partition = partition
self.path = path
self.backend_headers = backend_headers
super(ECFragGetter, self).__init__(
app=app, req=req, node_iter=node_iter,
partition=partition, policy=policy, path=path,
backend_headers=backend_headers, logger=logger)
self.header_provider = header_provider
self.req_query_string = req.query_string
self.fragment_size = policy.fragment_size
self.skip_bytes = 0
self.bytes_used_from_backend = 0
self.source = self.node = None
self.logger_thread_locals = logger_thread_locals
self.logger = logger
def fast_forward(self, num_bytes):
"""
Will skip num_bytes into the current ranges.
:params num_bytes: the number of bytes that have already been read on
this request. This will change the Range header
so that the next req will start where it left off.
:raises HTTPRequestedRangeNotSatisfiable: if begin + num_bytes
> end of range + 1
:raises RangeAlreadyComplete: if begin + num_bytes == end of range + 1
"""
try:
req_range = Range(self.backend_headers.get('Range'))
except ValueError:
req_range = None
if req_range:
begin, end = req_range.ranges[0]
if begin is None:
# this is a -50 range req (last 50 bytes of file)
end -= num_bytes
if end == 0:
# we sent out exactly the first range's worth of bytes, so
# we're done with it
raise RangeAlreadyComplete()
if end < 0:
raise HTTPRequestedRangeNotSatisfiable()
else:
begin += num_bytes
if end is not None and begin == end + 1:
# we sent out exactly the first range's worth of bytes, so
# we're done with it
raise RangeAlreadyComplete()
if end is not None and begin > end:
raise HTTPRequestedRangeNotSatisfiable()
req_range.ranges = [(begin, end)] + req_range.ranges[1:]
self.backend_headers['Range'] = str(req_range)
else:
self.backend_headers['Range'] = 'bytes=%d-' % num_bytes
# Reset so if we need to do this more than once, we don't double-up
self.bytes_used_from_backend = 0
def pop_range(self):
"""
Remove the first byterange from our Range header.
This is used after a byterange has been completely sent to the
client; this way, should we need to resume the download from another
object server, we do not re-fetch byteranges that the client already
has.
If we have no Range header, this is a no-op.
"""
if 'Range' in self.backend_headers:
try:
req_range = Range(self.backend_headers['Range'])
except ValueError:
# there's a Range header, but it's garbage, so get rid of it
self.backend_headers.pop('Range')
return
begin, end = req_range.ranges.pop(0)
if len(req_range.ranges) > 0:
self.backend_headers['Range'] = str(req_range)
else:
self.backend_headers.pop('Range')
def learn_size_from_content_range(self, start, end, length):
"""
@ -2833,7 +2755,7 @@ class ECFragGetter(object):
ip, port, node['device'],
self.partition, 'GET', self.path,
headers=req_headers,
query_string=self.req_query_string)
query_string=self.req.query_string)
self.app.set_node_timing(node, time.time() - start_node_timing)
with Timeout(node_timeout):