diff --git a/packetary/library/connections.py b/packetary/library/connections.py index 6b013a7..b6df4de 100644 --- a/packetary/library/connections.py +++ b/packetary/library/connections.py @@ -93,9 +93,23 @@ class ResumableResponse(StreamWrapper): self.stream = response.stream -class RetryHandler(urllib.BaseHandler): +class RetryHandler(urllib.HTTPRedirectHandler): """urllib Handler to add ability for retrying on server errors.""" + def redirect_request(self, req, fp, code, msg, headers, newurl): + new_req = urllib.HTTPRedirectHandler.redirect_request( + self, req, fp, code, msg, headers, newurl + ) + if new_req is not None: + # We use class assignment for casting new request to type + # RetryableRequest + new_req.__class__ = RetryableRequest + new_req.retries_left = req.retries_left + new_req.offset = req.offset + new_req.start_time = req.start_time + new_req.retry_interval = req.retry_interval + return new_req + @staticmethod def http_request(request): """Initialises http request. @@ -118,6 +132,11 @@ class RetryHandler(urllib.BaseHandler): :return: ResumableResponse if success otherwise same response """ code, msg = response.getcode(), response.msg + + if 300 <= code < 400: + # the redirect group, pass to next handler as is + return response + # the server should response partial content if range is specified if request.offset > 0 and code != http_client.PARTIAL_CONTENT: raise RangeError(msg) diff --git a/packetary/tests/test_connections.py b/packetary/tests/test_connections.py index a2621c8..c80b03d 100644 --- a/packetary/tests/test_connections.py +++ b/packetary/tests/test_connections.py @@ -268,6 +268,25 @@ class TestRetryHandler(base.TestCase): self.handler.http_response(request, response_mock) self.handler.parent.open.assert_called_once_with(request) + @mock.patch( + 'packetary.library.connections.urllib.' + 'HTTPRedirectHandler.redirect_request' + ) + def test_redirect_request(self, redirect_mock, _): + redirect_mock.return_value = connections.urllib.Request( + 'http://localhost/' + ) + req = mock.MagicMock(retries_left=10, retry_interval=5, offset=100) + new_req = self.handler.redirect_request(req, -1, 301, "", {}, "") + self.assertIsInstance(new_req, connections.RetryableRequest) + self.assertEqual(req.retries_left, new_req.retries_left) + self.assertEqual(req.retry_interval, new_req.retry_interval) + self.assertEqual(req.offset, new_req.offset) + redirect_mock.return_value = None + self.assertIsNone( + self.handler.redirect_request(req, -1, 301, "", {}, "") + ) + class TestResumeableResponse(base.TestCase): def setUp(self):