proxy-server: de-duplicate _get_next_response_part method

Both GetOrHeadHandler (used for replicated policy GETs) and
ECFragGetter (used for EC policy GETs) have _get_next_response_part
methods that are very similar. This patch replaces them with a single
method in the common GetterBase superclass.

Both classes are modified to use *only* the Request instance passed to
their constructors. Previously their entry methods
(GetOrHeadHandler.get_working_response and
ECFragGetter.response_parts_iter) accepted a Request instance as an
arg and the class then variably referred to that or the Request
instance passed to the constructor. Both instances must be the same
and it is therefore safer to only allow the Request to be passed to
the constructor.

The 'newest' keyword arg is dropped from the GetOrHeadHandler
constructor because it is never used.

This refactoring patch makes no intentional behavioral changes, apart
from the text of some error log messages which have been changed to
differentiate replicated object GETs from EC fragment GETs.

Change-Id: I148e158ab046929d188289796abfbbce97dc8d90
This commit is contained in:
Alistair Coles 2023-10-06 15:19:14 +01:00
parent 50336c5098
commit 8061dfb1c3
4 changed files with 155 additions and 114 deletions

View File

@ -1139,6 +1139,14 @@ class ByteCountEnforcer(object):
class GetterSource(object):
"""
Encapsulates properties of a source from which a GET response is read.
:param app: a proxy app.
:param resp: an instance of ``HTTPResponse``.
:param node: a dict describing the node from which the response was
returned.
"""
__slots__ = ('app', 'resp', 'node', '_parts_iter')
def __init__(self, app, resp, node):
@ -1175,8 +1183,26 @@ class GetterSource(object):
class GetterBase(object):
"""
This base class provides helper methods for handling GET requests to
backend servers.
:param app: a proxy app.
:param req: an instance of ``swob.Request``.
:param node_iter: an iterator yielding nodes.
:param partition: partition.
:param policy: the policy instance, or None if Account or Container.
:param path: path for the request.
:param backend_headers: a dict of headers to be sent with backend requests.
:param node_timeout: the timeout value for backend requests.
:param resource_type: a string description of the type of resource being
accessed; ``resource type`` is used in logs and isn't necessarily the
server type.
:param logger: a logger instance.
"""
def __init__(self, app, req, node_iter, partition, policy,
path, backend_headers, logger=None):
path, backend_headers, node_timeout, resource_type,
logger=None):
self.app = app
self.req = req
self.node_iter = node_iter
@ -1184,6 +1210,9 @@ class GetterBase(object):
self.policy = policy
self.path = path
self.backend_headers = backend_headers
# resource type is used in logs and isn't necessarily the server type
self.resource_type = resource_type
self.node_timeout = node_timeout
self.logger = logger or app.logger
self.bytes_used_from_backend = 0
self.source = None
@ -1206,6 +1235,35 @@ class GetterBase(object):
self.source.close()
return self._find_source()
def _get_next_response_part(self):
    """
    Return the next part of the response body.

    There may only be one part unless the backend returned a
    multipart/byteranges response.

    :return: a tuple of ``(start_byte, end_byte, length, headers, part)``
        where ``part`` is a file-like object for the part's bytes.
    :raises ChunkReadTimeout: if reading the next part times out and no
        replacement source can be found.
    :raises StopIteration: when there are no more parts; this is
        deliberately allowed to escape and is handled by callers.
    """
    # return the next part of the response body; there may only be one part
    # unless it's a multipart/byteranges response
    while True:
        # the loop here is to resume if trying to parse
        # multipart/byteranges response raises a ChunkReadTimeout
        # and resets the source_parts_iter
        try:
            with WatchdogTimeout(self.app.watchdog, self.node_timeout,
                                 ChunkReadTimeout):
                # If we don't have a multipart/byteranges response,
                # but just a 200 or a single-range 206, then this
                # performs no IO, and either just returns source or
                # raises StopIteration.
                # Otherwise, this call to next() performs IO when
                # we have a multipart/byteranges response, as it
                # will read the MIME boundary and part headers. In this
                # case, ChunkReadTimeout may also be raised.
                # If StopIteration is raised, it escapes and is
                # handled elsewhere.
                start_byte, end_byte, length, headers, part = next(
                    self.source.parts_iter)
                return (start_byte, end_byte, length, headers, part)
        except ChunkReadTimeout:
            if not self._replace_source(
                    'Trying to read next part of %s multi-part GET '
                    '(retrying)' % self.resource_type):
                raise
def fast_forward(self, num_bytes):
"""
Will skip num_bytes into the current ranges.
@ -1311,31 +1369,41 @@ class GetterBase(object):
class GetOrHeadHandler(GetterBase):
"""
Handles GET requests to backend servers.
:param app: a proxy app.
:param req: an instance of ``swob.Request``.
:param server_type: server type used in logging
:param node_iter: an iterator yielding nodes.
:param partition: partition.
:param path: path for the request.
:param backend_headers: a dict of headers to be sent with backend requests.
:param concurrency: number of requests to run concurrently.
:param policy: the policy instance, or None if Account or Container.
:param logger: a logger instance.
"""
def __init__(self, app, req, server_type, node_iter, partition, path,
backend_headers, concurrency=1, policy=None,
newest=None, logger=None):
backend_headers, concurrency=1, policy=None, logger=None):
if server_type == 'Object':
node_timeout = app.recoverable_node_timeout
else:
node_timeout = app.node_timeout
super(GetOrHeadHandler, self).__init__(
app=app, req=req, node_iter=node_iter,
partition=partition, policy=policy, path=path,
backend_headers=backend_headers, logger=logger)
app=app, req=req, node_iter=node_iter, partition=partition,
policy=policy, path=path, backend_headers=backend_headers,
node_timeout=node_timeout, resource_type=server_type.lower(),
logger=logger)
self.server_type = server_type
self.used_nodes = []
self.used_source_etag = None
self.concurrency = concurrency
self.latest_404_timestamp = Timestamp(0)
if self.server_type == 'Object':
self.node_timeout = self.app.recoverable_node_timeout
else:
self.node_timeout = self.app.node_timeout
policy_options = self.app.get_policy_options(self.policy)
self.rebalance_missing_suppression_count = min(
policy_options.rebalance_missing_suppression_count,
node_iter.num_primary_nodes - 1)
if newest is None:
self.newest = config_true_value(req.headers.get('x-newest', 'f'))
else:
self.newest = newest
self.newest = config_true_value(req.headers.get('x-newest', 'f'))
# populated when finding source
self.statuses = []
@ -1347,31 +1415,6 @@ class GetOrHeadHandler(GetterBase):
# populated from response headers
self.start_byte = self.end_byte = self.length = None
def _get_next_response_part(self):
    """
    Return the next part of the response body as a
    ``(start_byte, end_byte, length, headers, part)`` tuple.

    A plain 200 or a single-range 206 response has exactly one part; a
    multipart/byteranges 206 response may have several. If a read times
    out, try to resume from a different source before giving up.
    """
    while True:
        try:
            with WatchdogTimeout(self.app.watchdog, self.node_timeout,
                                 ChunkReadTimeout):
                # For a multipart/byteranges response this next() call
                # performs IO: it consumes the MIME boundary and the
                # part headers. Otherwise it is IO-free and either
                # yields the single part or raises StopIteration, which
                # is deliberately left to propagate to the caller.
                next_part = next(self.source.parts_iter)
                return next_part
        except ChunkReadTimeout:
            # A timeout may be recoverable: switch to another source
            # and loop to retry; re-raise if no replacement is found.
            if not self._replace_source(
                    'Trying to read object during GET (retrying)'):
                raise
def _iter_bytes_from_response_part(self, part_file, nbytes):
# yield chunks of bytes from a single response part; if an error
# occurs, try to resume yielding bytes from a different source
@ -1416,7 +1459,7 @@ class GetOrHeadHandler(GetterBase):
self.bytes_used_from_backend += len(chunk)
yield chunk
def _iter_parts_from_response(self, req):
def _iter_parts_from_response(self):
# iterate over potentially multiple response body parts; for each
# part, yield an iterator over the part's bytes
try:
@ -1441,7 +1484,7 @@ class GetOrHeadHandler(GetterBase):
'part_iter': part_iter}
self.pop_range()
except StopIteration:
req.environ['swift.non_client_disconnect'] = True
self.req.environ['swift.non_client_disconnect'] = True
finally:
if part_iter:
part_iter.close()
@ -1462,7 +1505,8 @@ class GetOrHeadHandler(GetterBase):
if end is not None and begin is not None:
if end - begin + 1 == self.bytes_used_from_backend:
warn = False
if not req.environ.get('swift.non_client_disconnect') and warn:
if (warn and
not self.req.environ.get('swift.non_client_disconnect')):
self.logger.info('Client disconnected on read of %r',
self.path)
raise
@ -1641,13 +1685,12 @@ class GetOrHeadHandler(GetterBase):
return True
return False
def _make_app_iter(self, req):
def _make_app_iter(self):
"""
Returns an iterator over the contents of the source (via its read
func). There is also quite a bit of cleanup to ensure garbage
collection works and the underlying socket of the source is closed.
:param req: incoming request object
:return: an iterator that yields chunks of response body bytes
"""
@ -1664,7 +1707,7 @@ class GetOrHeadHandler(GetterBase):
# furnished one for us, so we'll just re-use it
boundary = dict(content_type_attrs)["boundary"]
parts_iter = self._iter_parts_from_response(req)
parts_iter = self._iter_parts_from_response()
def add_content_type(response_part):
response_part["content_type"] = \
@ -1675,15 +1718,15 @@ class GetOrHeadHandler(GetterBase):
(add_content_type(pi) for pi in parts_iter),
boundary, is_multipart, self.logger)
def get_working_response(self, req):
def get_working_response(self):
res = None
if self._replace_source():
res = Response(request=req)
res = Response(request=self.req)
res.status = self.source.resp.status
update_headers(res, self.source.resp.getheaders())
if req.method == 'GET' and \
if self.req.method == 'GET' and \
self.source.resp.status in (HTTP_OK, HTTP_PARTIAL_CONTENT):
res.app_iter = self._make_app_iter(req)
res.app_iter = self._make_app_iter()
# See NOTE: swift_conn at top of file about this.
res.swift_conn = self.source.resp.swift_conn
if not res.environ:
@ -2281,7 +2324,7 @@ class Controller(object):
partition, path, backend_headers,
concurrency, policy=policy,
logger=self.logger)
res = handler.get_working_response(req)
res = handler.get_working_response()
if not res:
res = self.best_response(

View File

@ -2468,9 +2468,10 @@ class ECFragGetter(GetterBase):
backend_headers, header_provider, logger_thread_locals,
logger):
super(ECFragGetter, self).__init__(
app=app, req=req, node_iter=node_iter,
partition=partition, policy=policy, path=path,
backend_headers=backend_headers, logger=logger)
app=app, req=req, node_iter=node_iter, partition=partition,
policy=policy, path=path, backend_headers=backend_headers,
node_timeout=app.recoverable_node_timeout,
resource_type='EC fragment', logger=logger)
self.header_provider = header_provider
self.fragment_size = policy.fragment_size
self.skip_bytes = 0
@ -2478,39 +2479,13 @@ class ECFragGetter(GetterBase):
self.status = self.reason = self.body = self.source_headers = None
self._source_iter = None
def _get_next_response_part(self):
    """
    Return the next ``(start_byte, end_byte, length, headers, part)``
    tuple from the EC fragment response body.

    If parsing a multipart/byteranges response raises a read timeout,
    attempt to replace the source and loop to try again; re-raise when
    no replacement source is available.
    """
    timeout = self.app.recoverable_node_timeout
    while True:
        try:
            with WatchdogTimeout(self.app.watchdog, timeout,
                                 ChunkReadTimeout):
                # For a 200 or single-range 206 this performs no IO and
                # simply returns the source (or raises StopIteration,
                # which escapes to the caller). For a
                # multipart/byteranges response, next() reads the MIME
                # boundary and part headers, so ChunkReadTimeout may
                # fire here.
                return next(self.source.parts_iter)
        except ChunkReadTimeout:
            if not self._replace_source(
                    'Trying to read next part of EC multi-part GET '
                    '(retrying)'):
                raise
def _iter_bytes_from_response_part(self, part_file, nbytes):
buf = b''
part_file = ByteCountEnforcer(part_file, nbytes)
while True:
try:
with WatchdogTimeout(self.app.watchdog,
self.app.recoverable_node_timeout,
self.node_timeout,
ChunkReadTimeout):
chunk = part_file.read(self.app.object_chunk_size)
# NB: this append must be *inside* the context
@ -2564,7 +2539,7 @@ class ECFragGetter(GetterBase):
if not chunk:
break
def _iter_parts_from_response(self, req):
def _iter_parts_from_response(self):
try:
part_iter = None
try:
@ -2575,7 +2550,7 @@ class ECFragGetter(GetterBase):
except StopIteration:
# it seems this is the only way out of the loop; not
# sure why the req.environ update is always needed
req.environ['swift.non_client_disconnect'] = True
self.req.environ['swift.non_client_disconnect'] = True
break
# skip_bytes compensates for the backend request range
# expansion done in _convert_range
@ -2619,7 +2594,8 @@ class ECFragGetter(GetterBase):
if end is not None and begin is not None:
if end - begin + 1 == self.bytes_used_from_backend:
warn = False
if not req.environ.get('swift.non_client_disconnect') and warn:
if (warn and
not self.req.environ.get('swift.non_client_disconnect')):
self.logger.warning(
'Client disconnected on read of EC frag %r', self.path)
raise
@ -2640,7 +2616,7 @@ class ECFragGetter(GetterBase):
else:
return HeaderKeyDict()
def _make_node_request(self, node, node_timeout):
def _make_node_request(self, node):
# make a backend request; return a response if it has an acceptable
# status code, otherwise None
self.logger.thread_locals = self.logger_thread_locals
@ -2657,7 +2633,7 @@ class ECFragGetter(GetterBase):
query_string=self.req.query_string)
self.app.set_node_timing(node, time.time() - start_node_timing)
with Timeout(node_timeout):
with Timeout(self.node_timeout):
possible_source = conn.getresponse()
# See NOTE: swift_conn at top of file about this.
possible_source.swift_conn = conn
@ -2713,9 +2689,7 @@ class ECFragGetter(GetterBase):
def _source_gen(self):
self.status = self.reason = self.body = self.source_headers = None
for node in self.node_iter:
source = self._make_node_request(
node, self.app.recoverable_node_timeout)
source = self._make_node_request(node)
if source:
yield GetterSource(self.app, source, node)
else:
@ -2739,11 +2713,10 @@ class ECFragGetter(GetterBase):
return True
return False
def response_parts_iter(self, req):
def response_parts_iter(self):
"""
Create an iterator over a single fragment response body.
:param req: a ``swob.Request``.
:return: an interator that yields chunks of bytes from a fragment
response body.
"""
@ -2755,7 +2728,7 @@ class ECFragGetter(GetterBase):
else:
if source:
self.source = source
it = self._iter_parts_from_response(req)
it = self._iter_parts_from_response()
return it
@ -2775,7 +2748,7 @@ class ECObjectController(BaseObjectController):
policy, req.swift_entity_path, backend_headers,
header_provider, logger_thread_locals,
self.logger)
return (getter, getter.response_parts_iter(req))
return getter, getter.response_parts_iter()
def _convert_range(self, req, policy):
"""

View File

@ -1673,24 +1673,35 @@ class TestGetterSource(unittest.TestCase):
@patch_policies([StoragePolicy(0, 'zero', True, object_ring=FakeRing())])
class TestGetOrHeadHandler(BaseTest):
def test_init_node_timeout(self):
conf = {'node_timeout': 2, 'recoverable_node_timeout': 3}
conf = {'node_timeout': 5, 'recoverable_node_timeout': 3}
app = proxy_server.Application(conf,
logger=self.logger,
account_ring=self.account_ring,
container_ring=self.container_ring)
# x-newest set
req = Request.blank('/v1/a/c/o', headers={'X-Newest': 'true'})
node_iter = Namespace(num_primary_nodes=3)
# app.recoverable_node_timeout
getter = GetOrHeadHandler(
app, req, 'Object', node_iter, None, None, {})
self.assertEqual(3, getter.node_timeout)
# x-newest not set
req = Request.blank('/v1/a/c/o')
node_iter = Namespace(num_primary_nodes=3)
# app.recoverable_node_timeout
getter = GetOrHeadHandler(
app, req, 'Object', node_iter, None, None, {})
self.assertEqual(3, getter.node_timeout)
# app.node_timeout
getter = GetOrHeadHandler(
app, req, 'Account', node_iter, None, None, {})
self.assertEqual(2, getter.node_timeout)
self.assertEqual(5, getter.node_timeout)
getter = GetOrHeadHandler(
app, req, 'Container', node_iter, None, None, {})
self.assertEqual(2, getter.node_timeout)
self.assertEqual(5, getter.node_timeout)
def test_disconnected_logging(self):
req = Request.blank('/v1/a/c/o')
@ -1710,7 +1721,7 @@ class TestGetOrHeadHandler(BaseTest):
with mock.patch.object(handler, '_find_source', mock_find_source):
with mock.patch.object(
handler, '_iter_parts_from_response', factory):
resp = handler.get_working_response(req)
resp = handler.get_working_response()
resp.app_iter.close()
# verify that iter exited
self.assertEqual({1: ['next', '__del__']}, factory.captured_calls)
@ -1727,7 +1738,7 @@ class TestGetOrHeadHandler(BaseTest):
with mock.patch.object(handler, '_find_source', mock_find_source):
with mock.patch.object(
handler, '_iter_parts_from_response', factory):
resp = handler.get_working_response(req)
resp = handler.get_working_response()
next(resp.app_iter)
resp.app_iter.close()
self.assertEqual({1: ['next', '__del__']}, factory.captured_calls)

View File

@ -20,6 +20,7 @@ import math
import random
import time
import unittest
import argparse
from collections import defaultdict
from contextlib import contextmanager
import json
@ -1799,15 +1800,16 @@ class TestReplicatedObjController(CommonObjectControllerMixin,
node_error_counts(self.app, self.obj_ring.devs))
# note: client response uses boundary from first backend response
self.assertEqual(resp_body1, actual_body)
error_lines = self.app.logger.get_lines_for_level('error')
self.assertEqual(1, len(error_lines))
self.assertIn('Trying to read object during GET ', error_lines[0])
return req_range_hdrs
def test_GET_with_multirange_slow_body_resumes(self):
req_range_hdrs = self._do_test_GET_with_multirange_slow_body_resumes(
slowdown_after=0)
self.assertEqual(['bytes=0-49,100-104'] * 2, req_range_hdrs)
error_lines = self.app.logger.get_lines_for_level('error')
self.assertEqual(1, len(error_lines))
self.assertIn('Trying to read next part of object multi-part GET '
'(retrying)', error_lines[0])
def test_GET_with_multirange_slow_body_resumes_before_body_started(self):
# First response times out while first part boundary/headers are being
@ -1816,6 +1818,10 @@ class TestReplicatedObjController(CommonObjectControllerMixin,
req_range_hdrs = self._do_test_GET_with_multirange_slow_body_resumes(
slowdown_after=40, resume_bytes=0)
self.assertEqual(['bytes=0-49,100-104'] * 2, req_range_hdrs)
error_lines = self.app.logger.get_lines_for_level('error')
self.assertEqual(1, len(error_lines))
self.assertIn('Trying to read next part of object multi-part GET '
'(retrying)', error_lines[0])
def test_GET_with_multirange_slow_body_resumes_after_body_started(self):
# First response times out after first part boundary/headers have been
@ -1829,6 +1835,10 @@ class TestReplicatedObjController(CommonObjectControllerMixin,
slowdown_after=140, resume_bytes=20)
self.assertEqual(['bytes=0-49,100-104', 'bytes=20-49,100-104'],
req_range_hdrs)
error_lines = self.app.logger.get_lines_for_level('error')
self.assertEqual(1, len(error_lines))
self.assertIn('Trying to read object during GET (retrying) ',
error_lines[0])
def test_GET_with_multirange_slow_body_unable_to_resume(self):
self.app.recoverable_node_timeout = 0.01
@ -1882,7 +1892,8 @@ class TestReplicatedObjController(CommonObjectControllerMixin,
error_lines = self.app.logger.get_lines_for_level('error')
self.assertEqual(3, len(error_lines))
for line in error_lines:
self.assertIn('Trying to read object during GET ', line)
self.assertIn('Trying to read next part of object multi-part GET '
'(retrying)', line)
def test_GET_unable_to_resume(self):
self.app.recoverable_node_timeout = 0.01
@ -4904,7 +4915,7 @@ class TestECObjController(ECObjectControllerMixin, unittest.TestCase):
self.assertEqual(len(log), self.policy.ec_n_unique_fragments * 2)
log_lines = self.app.logger.get_lines_for_level('error')
self.assertEqual(3, len(log_lines), log_lines)
self.assertIn('Trying to read next part of EC multi-part GET',
self.assertIn('Trying to read next part of EC fragment multi-part GET',
log_lines[0])
self.assertIn('Trying to read during GET: ChunkReadTimeout',
log_lines[1])
@ -4985,7 +4996,7 @@ class TestECObjController(ECObjectControllerMixin, unittest.TestCase):
self.assertEqual(len(log), self.policy.ec_n_unique_fragments * 2)
log_lines = self.app.logger.get_lines_for_level('error')
self.assertEqual(2, len(log_lines), log_lines)
self.assertIn('Trying to read next part of EC multi-part GET',
self.assertIn('Trying to read next part of EC fragment multi-part GET',
log_lines[0])
self.assertIn('Trying to read during GET: ChunkReadTimeout',
log_lines[1])
@ -5061,7 +5072,7 @@ class TestECObjController(ECObjectControllerMixin, unittest.TestCase):
self.assertEqual(resp.status_int, 206)
self.assertEqual(len(log), self.policy.ec_n_unique_fragments * 2)
log_lines = self.app.logger.get_lines_for_level('error')
self.assertIn("Trying to read next part of EC multi-part "
self.assertIn("Trying to read next part of EC fragment multi-part "
"GET (retrying)", log_lines[0])
# not the most graceful ending
self.assertIn("Exception fetching fragments for '/a/c/o'",
@ -7317,12 +7328,19 @@ class TestNumContainerUpdates(unittest.TestCase):
class TestECFragGetter(BaseObjectControllerMixin, unittest.TestCase):
def setUp(self):
super(TestECFragGetter, self).setUp()
req = Request.blank(path='/a/c/o')
req = Request.blank(path='/v1/a/c/o')
self.getter = obj.ECFragGetter(
self.app, req, None, None, self.policy, 'a/c/o',
{}, None, self.logger.thread_locals,
self.logger)
def test_init_node_timeout(self):
    """Verify ECFragGetter adopts the recoverable node timeout."""
    # ECFragGetter must always use the app's *recoverable* node timeout
    # (3 here), never the plain node_timeout (2).
    fake_app = argparse.Namespace(node_timeout=2,
                                  recoverable_node_timeout=3)
    frag_getter = obj.ECFragGetter(
        fake_app, None, None, None, self.policy, 'a/c/o',
        {}, None, None, self.logger)
    self.assertEqual(3, frag_getter.node_timeout)
def test_iter_bytes_from_response_part(self):
part = FileLikeIter([b'some', b'thing'])
it = self.getter._iter_bytes_from_response_part(part, nbytes=None)
@ -7364,15 +7382,13 @@ class TestECFragGetter(BaseObjectControllerMixin, unittest.TestCase):
def test_fragment_size(self):
source = FakeSource((
b'abcd', b'1234', b'abc', b'd1', b'234abcd1234abcd1', b'2'))
req = Request.blank('/v1/a/c/o')
def mock_source_gen():
yield GetterSource(self.app, source, {})
self.getter.fragment_size = 8
with mock.patch.object(self.getter, '_source_gen',
mock_source_gen):
it = self.getter.response_parts_iter(req)
with mock.patch.object(self.getter, '_source_gen', mock_source_gen):
it = self.getter.response_parts_iter()
fragments = list(next(it)['part_iter'])
self.assertEqual(fragments, [
@ -7386,7 +7402,6 @@ class TestECFragGetter(BaseObjectControllerMixin, unittest.TestCase):
# incomplete reads of fragment_size will be re-fetched
source2 = FakeSource([b'efgh', b'5678', b'lots', None])
source3 = FakeSource([b'lots', b'more', b'data'])
req = Request.blank('/v1/a/c/o')
range_headers = []
sources = [GetterSource(self.app, src, node)
for src in (source1, source2, source3)]
@ -7399,7 +7414,7 @@ class TestECFragGetter(BaseObjectControllerMixin, unittest.TestCase):
self.getter.fragment_size = 8
with mock.patch.object(self.getter, '_source_gen',
mock_source_gen):
it = self.getter.response_parts_iter(req)
it = self.getter.response_parts_iter()
fragments = list(next(it)['part_iter'])
self.assertEqual(fragments, [
@ -7415,7 +7430,6 @@ class TestECFragGetter(BaseObjectControllerMixin, unittest.TestCase):
range_headers = []
sources = [GetterSource(self.app, src, node)
for src in (source1, source2)]
req = Request.blank('/v1/a/c/o')
def mock_source_gen():
for source in sources:
@ -7425,7 +7439,7 @@ class TestECFragGetter(BaseObjectControllerMixin, unittest.TestCase):
self.getter.fragment_size = 8
with mock.patch.object(self.getter, '_source_gen',
mock_source_gen):
it = self.getter.response_parts_iter(req)
it = self.getter.response_parts_iter()
fragments = list(next(it)['part_iter'])
self.assertEqual(fragments, [b'abcd1234', b'efgh5678'])
self.assertEqual(range_headers, [None, 'bytes=8-'])