From c522f5676e33dd0b9422019818a9106cde1d39f7 Mon Sep 17 00:00:00 2001
From: Tim Burke
Date: Thu, 4 Jan 2024 05:06:32 +0000
Subject: [PATCH] Add ClosingIterator class; be more explicit about closes

... in document_iters_to_http_response_body.

We seemed to be relying a little too heavily upon prompt garbage collection
to log client disconnects, leading to failures in
test_base.py::TestGetOrHeadHandler::test_disconnected_logging under
python 3.12.

Closes-Bug: #2046352
Co-Authored-By: Alistair Coles
Change-Id: I4479d2690f708312270eb92759789ddce7f7f930
---
 swift/common/utils/__init__.py           | 147 +++++++++++++-----
 test/unit/__init__.py                    |  66 +++++++-
 test/unit/common/test_utils.py           | 188 ++++++++++++++++++++++-
 test/unit/proxy/controllers/test_base.py |  26 ++--
 4 files changed, 370 insertions(+), 57 deletions(-)

diff --git a/swift/common/utils/__init__.py b/swift/common/utils/__init__.py
index ce4e099f44..0a000b8c5d 100644
--- a/swift/common/utils/__init__.py
+++ b/swift/common/utils/__init__.py
@@ -3687,27 +3687,66 @@ def csv_append(csv_string, item):
     return item
 
 
-class CloseableChain(object):
+class ClosingIterator(object):
+    """
+    Wrap another iterator and close it, if possible, on completion/exception.
+
+    If other closeable objects are given then they will also be closed when
+    this iterator is closed.
+
+    This is particularly useful for ensuring a generator properly closes its
+    resources, even if the generator was never started.
+
+    This class may be subclassed to override the behavior of
+    ``_get_next_item``.
+
+    :param iterable: iterator to wrap.
+    :param other_closeables: other resources to attempt to close.
+    """
+    __slots__ = ('closeables', 'wrapped_iter', 'closed')
+
+    def __init__(self, iterable, other_closeables=None):
+        self.closeables = [iterable]
+        if other_closeables:
+            self.closeables.extend(other_closeables)
+        # this is usually, but not necessarily, the same object
+        self.wrapped_iter = iter(iterable)
+        self.closed = False
+
+    def __iter__(self):
+        return self
+
+    def _get_next_item(self):
+        return next(self.wrapped_iter)
+
+    def __next__(self):
+        try:
+            return self._get_next_item()
+        except Exception:
+            # note: if wrapped_iter is a generator then the exception
+            # already caused it to exit (without raising a GeneratorExit)
+            # but we still need to close any other closeables.
+            self.close()
+            raise
+
+    next = __next__  # py2
+
+    def close(self):
+        if not self.closed:
+            for wrapped in self.closeables:
+                close_if_possible(wrapped)
+            self.closed = True
+
+
+class CloseableChain(ClosingIterator):
     """
     Like itertools.chain, but with a close method that will attempt to invoke
     its sub-iterators' close methods, if any.
     """
     def __init__(self, *iterables):
-        self.iterables = iterables
-        self.chained_iter = itertools.chain(*self.iterables)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        return next(self.chained_iter)
-
-    next = __next__  # py2
-
-    def close(self):
-        for it in self.iterables:
-            close_if_possible(it)
+        chained_iter = itertools.chain(*iterables)
+        super(CloseableChain, self).__init__(chained_iter, iterables)
 
 
 def reiterate(iterable):
@@ -4396,6 +4435,47 @@ def document_iters_to_multipart_byteranges(ranges_iter, boundary):
     yield terminator
 
 
+class StringAlong(ClosingIterator):
+    """
+    This iterator wraps and iterates over a first iterator until it stops,
+    and then iterates a second iterator, expecting it to stop immediately.
+    This "stringing along" of the second iterator is useful when the exit
+    of the second iterator must be delayed until the first iterator has
+    stopped. For example, when the second iterator has already yielded its
+    item(s) but has resources that mustn't be garbage collected until the
+    first iterator has stopped.
+
+    The second iterator is expected to have no more items and raise
+    StopIteration when called. If this is not the case then
+    ``unexpected_items_func`` is called.
+
+    :param iterable: a first iterator that is wrapped and iterated.
+    :param other_iter: a second iterator that is stopped once the first
+        iterator has stopped.
+    :param unexpected_items_func: a no-arg function that will be called if
+        the second iterator is found to have remaining items.
+    """
+    __slots__ = ('other_iter', 'unexpected_items_func')
+
+    def __init__(self, iterable, other_iter, unexpected_items_func):
+        super(StringAlong, self).__init__(iterable, [other_iter])
+        self.other_iter = other_iter
+        self.unexpected_items_func = unexpected_items_func
+
+    def _get_next_item(self):
+        try:
+            return super(StringAlong, self)._get_next_item()
+        except StopIteration:
+            try:
+                next(self.other_iter)
+            except StopIteration:
+                pass
+            else:
+                self.unexpected_items_func()
+            finally:
+                raise
+
+
 def document_iters_to_http_response_body(ranges_iter, boundary, multipart,
                                          logger):
     """
@@ -4445,20 +4525,11 @@ def document_iters_to_http_response_body(ranges_iter, boundary, multipart,
         # ranges_iter has a finally block that calls close_swift_conn, and
         # so if that finally block fires before we read response_body_iter,
         # there's nothing there.
-        def string_along(useful_iter, useless_iter_iter, logger):
-            with closing_if_possible(useful_iter):
-                for x in useful_iter:
-                    yield x
-
-            try:
-                next(useless_iter_iter)
-            except StopIteration:
-                pass
-            else:
-                logger.warning(
-                    "More than one part in a single-part response?")
-
-        return string_along(response_body_iter, ranges_iter, logger)
+        result = StringAlong(
+            response_body_iter, ranges_iter,
+            lambda: logger.warning(
+                "More than one part in a single-part response?"))
+        return result
 
 
 def multipart_byteranges_to_document_iters(input_file, boundary,
@@ -6430,7 +6501,7 @@ class WatchdogTimeout(object):
             self.watchdog.stop(self.key)
 
 
-class CooperativeIterator(object):
+class CooperativeIterator(ClosingIterator):
     """
     Wrapper to make a deliberate periodic call to ``sleep()`` while iterating
     over wrapped iterator, providing an opportunity to switch greenthreads.
@@ -6452,24 +6523,16 @@ class CooperativeIterator(object):
     :param period: number of items yielded from this iterator between calls
         to ``sleep()``.
     """
-    __slots__ = ('period', 'count', 'wrapped_iter')
+    __slots__ = ('period', 'count')
 
     def __init__(self, iterable, period=5):
-        self.wrapped_iter = iterable
+        super(CooperativeIterator, self).__init__(iterable)
         self.count = 0
         self.period = period
 
-    def __iter__(self):
-        return self
-
-    def next(self):
+    def _get_next_item(self):
         if self.count >= self.period:
             self.count = 0
             sleep()
         self.count += 1
-        return next(self.wrapped_iter)
-
-    __next__ = next
-
-    def close(self):
-        close_if_possible(self.wrapped_iter)
+        return super(CooperativeIterator, self)._get_next_item()
diff --git a/test/unit/__init__.py b/test/unit/__init__.py
index 6f958cb0d9..6eb1db2c45 100644
--- a/test/unit/__init__.py
+++ b/test/unit/__init__.py
@@ -53,7 +53,7 @@ from swift.common import storage_policy, swob, utils, exceptions
 from swift.common.memcached import MemcacheConnectionError
 from swift.common.storage_policy import (StoragePolicy, ECStoragePolicy,
                                          VALID_EC_TYPES)
-from swift.common.utils import Timestamp, md5
+from swift.common.utils import Timestamp, md5, close_if_possible
 from test import get_config
 from test.debug_logger import FakeLogger
 from swift.common.header_key_dict import HeaderKeyDict
@@ -1499,6 +1499,70 @@ class FakeSource(object):
                 [(k, v) for k, v in self.headers.items()]
 
 
+class CaptureIterator(object):
+    """
+    Wraps an iterable, forwarding all calls to the wrapped iterable but
+    capturing the calls via a callback.
+
+    This class may be used to observe garbage collection, so tests should not
+    have to hold a reference to instances of this class because that would
+    prevent them being garbage collected. Calls are therefore captured via a
+    callback rather than being stashed locally.
+
+    :param wrapped: an iterable to wrap.
+    :param call_capture_callback: a function that will be called to capture
+        calls to this iterator.
+    """
+    def __init__(self, wrapped, call_capture_callback):
+        self.call_capture_callback = call_capture_callback
+        self.wrapped_iter = wrapped
+
+    def _capture_call(self):
+        # call home to capture the call
+        self.call_capture_callback(inspect.stack()[1][3])
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        self._capture_call()
+        return next(self.wrapped_iter)
+
+    __next__ = next
+
+    def __del__(self):
+        self._capture_call()
+
+    def close(self):
+        self._capture_call()
+        close_if_possible(self.wrapped_iter)
+
+
+class CaptureIteratorFactory(object):
+    """
+    Creates instances of ``CaptureIterator`` to wrap a given iterable, and
+    provides a callback function for the ``CaptureIterator`` to capture its
+    calls.
+
+    :param wrapped: an iterable to wrap.
+    """
+    def __init__(self, wrapped):
+        self.wrapped = wrapped
+        self.instance_count = 0
+        self.captured_calls = defaultdict(list)
+
+    def log_call(self, instance_number, call):
+        self.captured_calls[instance_number].append(call)
+
+    def __call__(self, *args, **kwargs):
+        # note: do not keep a reference to the CaptureIterator because that
+        # would prevent it being garbage collected
+        self.instance_count += 1
+        return CaptureIterator(
+            self.wrapped(*args, **kwargs),
+            functools.partial(self.log_call, self.instance_count))
+
+
 def get_node_error_stats(proxy_app, ring_node):
     node_key = proxy_app.error_limiter.node_key(ring_node)
     return proxy_app.error_limiter.stats.get(node_key) or {}
diff --git a/test/unit/common/test_utils.py b/test/unit/common/test_utils.py
index bd692c8ea1..5acc0a5324 100644
--- a/test/unit/common/test_utils.py
+++ b/test/unit/common/test_utils.py
@@ -6861,20 +6861,34 @@ class FakeResponse(object):
 
 class TestDocumentItersToHTTPResponseBody(unittest.TestCase):
     def test_no_parts(self):
+        logger = debug_logger()
         body = utils.document_iters_to_http_response_body(
-            iter([]), 'dontcare',
-            multipart=False, logger=debug_logger())
+            iter([]), 'dontcare', multipart=False, logger=logger)
         self.assertEqual(body, '')
+        self.assertFalse(logger.all_log_lines())
 
     def test_single_part(self):
         body = b"time flies like an arrow; fruit flies like a banana"
         doc_iters = [{'part_iter': iter(BytesIO(body).read, b'')}]
+        logger = debug_logger()
 
         resp_body = b''.join(
             utils.document_iters_to_http_response_body(
-                iter(doc_iters), b'dontcare',
-                multipart=False, logger=debug_logger()))
+                iter(doc_iters), b'dontcare', multipart=False, logger=logger))
         self.assertEqual(resp_body, body)
+        self.assertFalse(logger.all_log_lines())
+
+    def test_single_part_unexpected_ranges(self):
+        body = b"time flies like an arrow; fruit flies like a banana"
+        doc_iters = [{'part_iter': iter(BytesIO(body).read, b'')}, 'junk']
+        logger = debug_logger()
+
+        resp_body = b''.join(
+            utils.document_iters_to_http_response_body(
+                iter(doc_iters), b'dontcare', multipart=False, logger=logger))
+        self.assertEqual(resp_body, body)
+        self.assertEqual(['More than one part in a single-part response?'],
+                         logger.get_lines_for_level('warning'))
 
     def test_multiple_parts(self):
         part1 = b"two peanuts were walking down a railroad track"
@@ -6915,7 +6929,6 @@ class TestDocumentItersToHTTPResponseBody(unittest.TestCase):
                          b"--boundaryboundary--"))
 
     def test_closed_part_iterator(self):
-        print('test')
         useful_iter_mock = mock.MagicMock()
         useful_iter_mock.__iter__.return_value = ['']
         body_iter = utils.document_iters_to_http_response_body(
@@ -9563,6 +9576,138 @@ class TestReiterate(unittest.TestCase):
         self.assertIs(test_tuple, reiterated)
 
 
+class TestClosingIterator(unittest.TestCase):
+    def _make_gen(self, items, captured_exit):
+        def gen():
+            try:
+                for it in items:
+                    if isinstance(it, Exception):
+                        raise it
+                    yield it
+            except GeneratorExit as e:
+                captured_exit.append(e)
+                raise
+        return gen()
+
+    def test_close(self):
+        wrapped = FakeIterable([1, 2, 3])
+        # note: iter(FakeIterable) is the same object
+        self.assertIs(wrapped, iter(wrapped))
+        it = utils.ClosingIterator(wrapped)
+        actual = [x for x in it]
+        self.assertEqual([1, 2, 3], actual)
+        self.assertEqual(1, wrapped.close_call_count)
+        it.close()
+        self.assertEqual(1, wrapped.close_call_count)
+
+    def test_close_others(self):
+        wrapped = FakeIterable([1, 2, 3])
+        others = [FakeIterable([4, 5, 6]), FakeIterable([])]
+        self.assertIs(wrapped, iter(wrapped))
+        it = utils.ClosingIterator(wrapped, others)
+        actual = [x for x in it]
+        self.assertEqual([1, 2, 3], actual)
+        self.assertEqual([1, 1, 1],
+                         [i.close_call_count for i in others + [wrapped]])
+        it.close()
+        self.assertEqual([1, 1, 1],
+                         [i.close_call_count for i in others + [wrapped]])
+
+    def test_close_gen(self):
+        # explicitly check generator closing
+        captured_exit = []
+        gen = self._make_gen([1, 2], captured_exit)
+        it = utils.ClosingIterator(gen)
+        self.assertFalse(captured_exit)
+        it.close()
+        self.assertFalse(captured_exit)  # the generator didn't start
+
+        captured_exit = []
+        gen = self._make_gen([1, 2], captured_exit)
+        it = utils.ClosingIterator(gen)
+        self.assertFalse(captured_exit)
+        self.assertEqual(1, next(it))  # start the generator
+        it.close()
+        self.assertEqual(1, len(captured_exit))
+
+    def test_close_wrapped_is_not_same_as_iter(self):
+        class AltFakeIterable(FakeIterable):
+            def __iter__(self):
+                return (x for x in self.values)
+
+        wrapped = AltFakeIterable([1, 2, 3])
+        # note: iter(AltFakeIterable) is a generator, not the same object
+        self.assertIsNot(wrapped, iter(wrapped))
+        it = utils.ClosingIterator(wrapped)
+        actual = [x for x in it]
+        self.assertEqual([1, 2, 3], actual)
+        self.assertEqual(1, wrapped.close_call_count)
+        it.close()
+        self.assertEqual(1, wrapped.close_call_count)
+
+    def test_init_with_iterable(self):
+        wrapped = [1, 2, 3]  # list is iterable but not an iterator
+        it = utils.ClosingIterator(wrapped)
+        actual = [x for x in it]
+        self.assertEqual([1, 2, 3], actual)
+        it.close()  # safe to call even though list has no close
+
+    def test_nested_iters(self):
+        wrapped = FakeIterable([1, 2, 3])
+        it = utils.ClosingIterator(utils.ClosingIterator(wrapped))
+        actual = [x for x in it]
+        self.assertEqual([1, 2, 3], actual)
+        self.assertEqual(1, wrapped.close_call_count)
+        it.close()
+        self.assertEqual(1, wrapped.close_call_count)
+
+    def test_close_on_stop_iteration(self):
+        wrapped = FakeIterable([1, 2, 3])
+        others = [FakeIterable([4, 5, 6]), FakeIterable([])]
+        self.assertIs(wrapped, iter(wrapped))
+        it = utils.ClosingIterator(wrapped, others)
+        actual = [x for x in it]
+        self.assertEqual([1, 2, 3], actual)
+        self.assertEqual([1, 1, 1],
+                         [i.close_call_count for i in others + [wrapped]])
+        it.close()
+        self.assertEqual([1, 1, 1],
+                         [i.close_call_count for i in others + [wrapped]])
+
+    def test_close_on_exception(self):
+        # sanity check: generator exits on raising exception without executing
+        # GeneratorExit
+        captured_exit = []
+        gen = self._make_gen([1, ValueError(), 2], captured_exit)
+        self.assertEqual(1, next(gen))
+        with self.assertRaises(ValueError):
+            next(gen)
+        self.assertFalse(captured_exit)
+        gen.close()
+        self.assertFalse(captured_exit)  # gen already exited
+
+        captured_exit = []
+        gen = self._make_gen([1, ValueError(), 2], captured_exit)
+        self.assertEqual(1, next(gen))
+        with self.assertRaises(ValueError):
+            next(gen)
+        self.assertFalse(captured_exit)
+        with self.assertRaises(StopIteration):
+            next(gen)  # gen already exited
+
+        # wrapped gen does the same...
+        captured_exit = []
+        gen = self._make_gen([1, ValueError(), 2], captured_exit)
+        others = [FakeIterable([4, 5, 6]), FakeIterable([])]
+        it = utils.ClosingIterator(gen, others)
+        self.assertEqual(1, next(it))
+        with self.assertRaises(ValueError):
+            next(it)
+        self.assertFalse(captured_exit)
+        # but other iters are closed :)
+        self.assertEqual([1, 1], [i.close_call_count for i in others])
+
+
 class TestCloseableChain(unittest.TestCase):
     def test_closeable_chain_iterates(self):
         test_iter1 = FakeIterable([1])
@@ -9619,6 +9764,39 @@ class TestCloseableChain(unittest.TestCase):
         self.assertTrue(generator_closed[0])
 
 
+class TestStringAlong(unittest.TestCase):
+    def test_happy(self):
+        logger = debug_logger()
+        it = FakeIterable([1, 2, 3])
+        other_it = FakeIterable([])
+        string_along = utils.StringAlong(
+            it, other_it, lambda: logger.warning('boom'))
+        for i, x in enumerate(string_along):
+            self.assertEqual(i + 1, x)
+            self.assertEqual(0, other_it.next_call_count, x)
+            self.assertEqual(0, other_it.close_call_count, x)
+        self.assertEqual(1, other_it.next_call_count, x)
+        self.assertEqual(1, other_it.close_call_count, x)
+        lines = logger.get_lines_for_level('warning')
+        self.assertFalse(lines)
+
+    def test_unhappy(self):
+        logger = debug_logger()
+        it = FakeIterable([1, 2, 3])
+        other_it = FakeIterable([1])
+        string_along = utils.StringAlong(
+            it, other_it, lambda: logger.warning('boom'))
+        for i, x in enumerate(string_along):
+            self.assertEqual(i + 1, x)
+            self.assertEqual(0, other_it.next_call_count, x)
+            self.assertEqual(0, other_it.close_call_count, x)
+        self.assertEqual(1, other_it.next_call_count, x)
+        self.assertEqual(1, other_it.close_call_count, x)
+        lines = logger.get_lines_for_level('warning')
+        self.assertEqual(1, len(lines))
+        self.assertIn('boom', lines[0])
+
+
 class TestCooperativeIterator(unittest.TestCase):
     def test_init(self):
         wrapped = itertools.count()
diff --git a/test/unit/proxy/controllers/test_base.py b/test/unit/proxy/controllers/test_base.py
index 44b6f804ed..9c442b8ec3 100644
--- a/test/unit/proxy/controllers/test_base.py
+++ b/test/unit/proxy/controllers/test_base.py
@@ -42,7 +42,7 @@ from swift.common.storage_policy import StoragePolicy, StoragePolicyCollection
 from test.debug_logger import debug_logger
 from test.unit import (
     fake_http_connect, FakeRing, FakeMemcache, PatchPolicies, patch_policies,
-    FakeSource, StubResponse)
+    FakeSource, StubResponse, CaptureIteratorFactory)
 from swift.common.request_helpers import (
     get_sys_meta_prefix, get_object_transient_sysmeta
 )
@@ -1706,10 +1706,14 @@ class TestGetOrHeadHandler(BaseTest):
             handler.source = GetterSource(self.app, source, node)
             return True
 
-        with mock.patch.object(handler, '_find_source',
-                               mock_find_source):
-            resp = handler.get_working_response(req)
-            resp.app_iter.close()
+        factory = CaptureIteratorFactory(handler._iter_parts_from_response)
+        with mock.patch.object(handler, '_find_source', mock_find_source):
+            with mock.patch.object(
+                    handler, '_iter_parts_from_response', factory):
+                resp = handler.get_working_response(req)
+                resp.app_iter.close()
+        # verify that iter exited
+        self.assertEqual({1: ['next', '__del__']}, factory.captured_calls)
         self.assertEqual(["Client disconnected on read of 'some-path'"],
                          self.logger.get_lines_for_level('info'))
 
@@ -1719,12 +1723,16 @@ class TestGetOrHeadHandler(BaseTest):
             self.app, req, 'Object', Namespace(num_primary_nodes=1), None,
             None, {})
 
-        with mock.patch.object(handler, '_find_source',
-                               mock_find_source):
-            resp = handler.get_working_response(req)
-            next(resp.app_iter)
+        factory = CaptureIteratorFactory(handler._iter_parts_from_response)
+        with mock.patch.object(handler, '_find_source', mock_find_source):
+            with mock.patch.object(
+                    handler, '_iter_parts_from_response', factory):
+                resp = handler.get_working_response(req)
+                next(resp.app_iter)
         resp.app_iter.close()
+        self.assertEqual({1: ['next', '__del__']}, factory.captured_calls)
         self.assertEqual([], self.logger.get_lines_for_level('warning'))
+        self.assertEqual([], self.logger.get_lines_for_level('info'))
 
     def test_range_fast_forward(self):
         req = Request.blank('/')
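
For reference, a minimal usage sketch of the new ClosingIterator follows (it is
not part of the diff above). It assumes this change has been applied to
swift.common.utils; the toy chunk_gen generator and the cleanup_log list are
illustrative only.

    # Illustrative only: exercises ClosingIterator as defined in this patch.
    from swift.common.utils import ClosingIterator

    def chunk_gen(cleanup_log):
        # stand-in for a generator that holds a resource (e.g. a backend
        # connection) and releases it in a finally block
        try:
            yield b'chunk-1'
            yield b'chunk-2'
        finally:
            cleanup_log.append('cleaned up')

    cleanup_log = []
    it = ClosingIterator(chunk_gen(cleanup_log))
    print(next(it))      # b'chunk-1'
    it.close()           # closes the generator; its finally block runs now
    print(cleanup_log)   # ['cleaned up']

Because close() also runs when iteration finishes or raises, callers no longer
need to rely on prompt garbage collection to trigger the wrapped generator's
cleanup, which is the failure mode described in the commit message.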