Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions tests/test_fetcher_ng.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,108 @@ def test_session_get_timeout(self, mock_session_get: Mock) -> None:
self.fetcher.fetch(self.url)
mock_session_get.assert_called_once()

# Test retry on ReadTimeoutError during streaming
@patch.object(urllib3.PoolManager, "request")
def test_download_bytes_retry_on_streaming_timeout(
    self, mock_request: Mock
) -> None:
    """Test that download_bytes retries when ReadTimeoutError occurs during streaming."""
    # Two responses fail mid-stream, the third delivers the full payload.
    timeout = urllib3.exceptions.ReadTimeoutError(
        urllib3.connectionpool.ConnectionPool("localhost"),
        "",
        "Read timed out",
    )
    failing_response = Mock()
    failing_response.status = 200
    failing_response.stream.side_effect = timeout

    ok_response = Mock()
    ok_response.status = 200
    ok_response.stream.return_value = iter(
        [self.file_contents[:4], self.file_contents[4:]]
    )

    mock_request.side_effect = [
        failing_response,
        failing_response,
        ok_response,
    ]

    result = self.fetcher.download_bytes(self.url, self.file_length)
    self.assertEqual(self.file_contents, result)
    self.assertEqual(mock_request.call_count, 3)

# Test retry exhaustion
@patch.object(urllib3.PoolManager, "request")
def test_download_bytes_retry_exhaustion(self, mock_request: Mock) -> None:
    """Test that download_bytes fails after exhausting all retries."""
    # Every single attempt times out while streaming the body.
    always_failing = Mock()
    always_failing.status = 200
    always_failing.stream.side_effect = urllib3.exceptions.ReadTimeoutError(
        urllib3.connectionpool.ConnectionPool("localhost"),
        "",
        "Read timed out",
    )
    mock_request.return_value = always_failing

    with self.assertRaises(exceptions.SlowRetrievalError):
        self.fetcher.download_bytes(self.url, self.file_length)

    # max_retries is 3, so exactly three requests were made.
    self.assertEqual(mock_request.call_count, 3)

# Test retry on ProtocolError during streaming
@patch.object(urllib3.PoolManager, "request")
def test_download_bytes_retry_on_protocol_error(
    self, mock_request: Mock
) -> None:
    """Test that download_bytes retries when ProtocolError occurs during streaming."""
    # First attempt breaks mid-stream with a protocol error; the
    # second attempt succeeds and delivers the file in two chunks.
    broken_response = Mock()
    broken_response.status = 200
    broken_response.stream.side_effect = urllib3.exceptions.ProtocolError(
        "Connection broken"
    )

    ok_response = Mock()
    ok_response.status = 200
    ok_response.stream.return_value = iter(
        [self.file_contents[:4], self.file_contents[4:]]
    )

    mock_request.side_effect = [broken_response, ok_response]

    result = self.fetcher.download_bytes(self.url, self.file_length)
    self.assertEqual(self.file_contents, result)
    self.assertEqual(mock_request.call_count, 2)

# Test that non-timeout errors are not retried
@patch.object(urllib3.PoolManager, "request")
def test_download_bytes_no_retry_on_http_error(
    self, mock_request: Mock
) -> None:
    """Test that download_bytes does not retry on HTTP errors like 404."""
    not_found = Mock()
    not_found.status = 404
    mock_request.return_value = not_found

    with self.assertRaises(exceptions.DownloadHTTPError):
        self.fetcher.download_bytes(self.url, self.file_length)

    # An HTTP error is fatal: exactly one request, no retries.
    mock_request.assert_called_once()

# Test that length mismatch errors are not retried
def test_download_bytes_no_retry_on_length_mismatch(self) -> None:
    """Test that download_bytes does not retry on length mismatch errors."""
    # Request fewer bytes than the file actually contains: the stream
    # delivers more than max_length, which must fail without retrying.
    with self.assertRaises(exceptions.DownloadLengthMismatchError):
        self.fetcher.download_bytes(self.url, self.file_length - 4)

# Simple bytes download
def test_download_bytes(self) -> None:
data = self.fetcher.download_bytes(self.url, self.file_length)
Expand Down
71 changes: 70 additions & 1 deletion tuf/ngclient/urllib3_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

# Imports
import urllib3
from urllib3.util.retry import Retry

import tuf
from tuf.api import exceptions
Expand Down Expand Up @@ -50,7 +51,21 @@ def __init__(
if app_user_agent is not None:
ua = f"{app_user_agent} {ua}"

self._proxy_env = ProxyEnvironment(headers={"User-Agent": ua})
# Configure retry strategy for connection-level retries.
# Note: This only retries at the HTTP request level (before streaming
# begins). Streaming failures are handled by the retry loop in
# download_bytes().
retry_strategy = Retry(
total=3,
read=3,
connect=3,
status_forcelist=[500, 502, 503, 504],
raise_on_status=False,
)

self._proxy_env = ProxyEnvironment(
headers={"User-Agent": ua}, retries=retry_strategy
)

def _fetch(self, url: str) -> Iterator[bytes]:
"""Fetch the contents of HTTP/HTTPS url from a remote server.
Expand Down Expand Up @@ -82,6 +97,7 @@ def _fetch(self, url: str) -> Iterator[bytes]:
except urllib3.exceptions.MaxRetryError as e:
if isinstance(e.reason, urllib3.exceptions.TimeoutError):
raise exceptions.SlowRetrievalError from e
raise

if response.status >= 400:
response.close()
Expand All @@ -106,6 +122,59 @@ def _chunks(
except urllib3.exceptions.MaxRetryError as e:
if isinstance(e.reason, urllib3.exceptions.TimeoutError):
raise exceptions.SlowRetrievalError from e
raise
except (
urllib3.exceptions.ReadTimeoutError,
urllib3.exceptions.ProtocolError,
) as e:
raise exceptions.SlowRetrievalError from e

finally:
response.release_conn()

def download_bytes(self, url: str, max_length: int) -> bytes:
    """Download bytes from given ``url`` with retry on streaming failures.

    This override adds retry logic for mid-stream timeout and connection
    errors (surfaced as ``SlowRetrievalError``) that are not automatically
    retried by urllib3. Non-retryable errors — HTTP status errors and
    length mismatches — propagate immediately without a retry.

    Args:
        url: URL string that represents the location of the file.
        max_length: Upper bound of data size in bytes.

    Raises:
        exceptions.DownloadError: An error occurred during download.
        exceptions.DownloadLengthMismatchError: Downloaded bytes exceed
            ``max_length``.
        exceptions.DownloadHTTPError: An HTTP error code was received.

    Returns:
        Content of the file in bytes.
    """
    max_retries = 3

    for attempt in range(1, max_retries + 1):
        try:
            return super().download_bytes(url, max_length)
        except exceptions.SlowRetrievalError:
            # Out of attempts: surface the last streaming error as-is.
            if attempt == max_retries:
                raise
            logger.debug(
                "Retrying download after streaming error "
                "(attempt %d/%d): %s",
                attempt,
                max_retries,
                url,
            )

    # Unreachable: the loop above always returns or raises, but static
    # analyzers cannot prove it, so keep a defensive terminal raise.
    raise exceptions.DownloadError(f"Failed to download {url}")