Use ClosingMapper to ensure prompt client disconnect logging

Adds ClosingMapper class which is like map() but closes the
iterable.

Co-Authored-By: Alistair Coles <alistairncoles@gmail.com>
Change-Id: Idd0ac21b365a138b065f01d05a257af62ea88177
This commit is contained in:
Tim Burke 2024-05-02 00:17:29 +00:00
parent b4dddb7406
commit 9ec83c44fd
4 changed files with 80 additions and 5 deletions

View File

@ -2328,9 +2328,36 @@ class ClosingIterator(object):
if not self.closed:
for wrapped in self.closeables:
close_if_possible(wrapped)
# clear it out so they get GC'ed
self.closeables = []
self.wrapped_iter = iter([])
self.closed = True
class ClosingMapper(ClosingIterator):
"""
A closing iterator that yields the result of ``function`` as it is applied
to each item of ``iterable``.
Note that while this behaves similarly to the built-in ``map`` function,
``other_closeables`` does not have the same semantic as the ``iterables``
argument of ``map``.
:param function: a function that will be called with each item of
``iterable`` before yielding its result.
:param iterable: iterator to wrap.
:param other_closeables: other resources to attempt to close.
"""
__slots__ = ('func',)
def __init__(self, function, iterable, other_closeables=None):
self.func = function
super(ClosingMapper, self).__init__(iterable, other_closeables)
def _get_next_item(self):
return self.func(super(ClosingMapper, self)._get_next_item())
class CloseableChain(ClosingIterator):
"""
Like itertools.chain, but with a close method that will attempt to invoke

View File

@ -45,7 +45,7 @@ from swift.common.utils import Timestamp, WatchdogTimeout, config_true_value, \
public, split_path, list_from_csv, GreenthreadSafeIterator, \
GreenAsyncPile, quorum_size, parse_content_type, drain_and_close, \
document_iters_to_http_response_body, cache_from_env, \
CooperativeIterator, NamespaceBoundList, Namespace
CooperativeIterator, NamespaceBoundList, Namespace, ClosingMapper
from swift.common.bufferedhttp import http_connect
from swift.common import constraints
from swift.common.exceptions import ChunkReadTimeout, ChunkWriteTimeout, \
@ -1713,7 +1713,7 @@ class GetOrHeadHandler(GetterBase):
return response_part
return document_iters_to_http_response_body(
(add_content_type(pi) for pi in parts_iter),
ClosingMapper(add_content_type, parts_iter),
boundary, is_multipart, self.logger)
def get_working_response(self):

View File

@ -9645,6 +9645,51 @@ class TestClosingIterator(unittest.TestCase):
self.assertEqual([1, 1], [i.close_call_count for i in others])
class TestClosingMapper(unittest.TestCase):
def test_close(self):
calls = []
def func(args):
calls.append(args)
return sum(args)
wrapped = FakeIterable([(2, 3), (4, 5)])
other = FakeIterable([])
it = utils.ClosingMapper(func, wrapped, [other])
actual = [x for x in it]
self.assertEqual([(2, 3), (4, 5)], calls)
self.assertEqual([5, 9], actual)
self.assertEqual(1, wrapped.close_call_count)
self.assertEqual(1, other.close_call_count)
# check against result of map()
wrapped = FakeIterable([(2, 3), (4, 5)])
mapped = [x for x in map(func, wrapped)]
self.assertEqual(mapped, actual)
def test_function_raises_exception(self):
calls = []
class TestExc(Exception):
pass
def func(args):
calls.append(args)
if len(calls) > 1:
raise TestExc('boom')
else:
return sum(args)
wrapped = FakeIterable([(2, 3), (4, 5), (6, 7)])
it = utils.ClosingMapper(func, wrapped)
self.assertEqual(5, next(it))
with self.assertRaises(TestExc) as cm:
next(it)
self.assertIn('boom', str(cm.exception))
self.assertEqual(1, wrapped.close_call_count)
with self.assertRaises(StopIteration) as cm:
next(it)
class TestCloseableChain(unittest.TestCase):
def test_closeable_chain_iterates(self):
test_iter1 = FakeIterable([1])
@ -9669,7 +9714,8 @@ class TestCloseableChain(unittest.TestCase):
# close
chain = utils.CloseableChain([1, 2], [3])
chain.close()
self.assertEqual([1, 2, 3], [x for x in chain])
# read after close raises StopIteration
self.assertEqual([], [x for x in chain])
# check with generator in the chain
generator_closed = [False]

View File

@ -1724,7 +1724,8 @@ class TestGetOrHeadHandler(BaseTest):
resp = handler.get_working_response()
resp.app_iter.close()
# verify that iter exited
self.assertEqual({1: ['next', '__del__']}, factory.captured_calls)
self.assertEqual({1: ['next', 'close', '__del__']},
factory.captured_calls)
self.assertEqual(["Client disconnected on read of 'some-path'"],
self.logger.get_lines_for_level('info'))
@ -1741,7 +1742,8 @@ class TestGetOrHeadHandler(BaseTest):
resp = handler.get_working_response()
next(resp.app_iter)
resp.app_iter.close()
self.assertEqual({1: ['next', '__del__']}, factory.captured_calls)
self.assertEqual({1: ['next', 'close', '__del__']},
factory.captured_calls)
self.assertEqual([], self.logger.get_lines_for_level('warning'))
self.assertEqual([], self.logger.get_lines_for_level('info'))