From 4fdcd26cd6a2f35a772264204b343808b68f6996 Mon Sep 17 00:00:00 2001 From: Patrick Malouin Date: Wed, 17 Sep 2025 14:46:01 -0400 Subject: [PATCH] fix(backchannel): expose headers on `slow_down` errors (HTTP 429s) --- auth0/exceptions.py | 4 ++-- auth0/rest.py | 1 + auth0/test/authentication/test_base.py | 6 ++++++ auth0/test/authentication/test_get_token.py | 10 +++++----- auth0/test/management/test_rest.py | 6 ++++++ 5 files changed, 20 insertions(+), 7 deletions(-) diff --git a/auth0/exceptions.py b/auth0/exceptions.py index 533e9321..03801e68 100644 --- a/auth0/exceptions.py +++ b/auth0/exceptions.py @@ -23,8 +23,8 @@ def __str__(self) -> str: class RateLimitError(Auth0Error): - def __init__(self, error_code: str, message: str, reset_at: int) -> None: - super().__init__(status_code=429, error_code=error_code, message=message) + def __init__(self, error_code: str, message: str, reset_at: int, headers: Any | None = None) -> None: + super().__init__(status_code=429, error_code=error_code, message=message, headers=headers) self.reset_at = reset_at diff --git a/auth0/rest.py b/auth0/rest.py index 74d897ce..a2d9bd9a 100644 --- a/auth0/rest.py +++ b/auth0/rest.py @@ -289,6 +289,7 @@ def content(self) -> Any: error_code=self._error_code(), message=self._error_message(), reset_at=reset_at, + headers=self._headers, ) if self._error_code() == "mfa_required": raise Auth0Error( diff --git a/auth0/test/authentication/test_base.py b/auth0/test/authentication/test_base.py index eed9d040..a4f52d83 100644 --- a/auth0/test/authentication/test_base.py +++ b/auth0/test/authentication/test_base.py @@ -158,6 +158,10 @@ def test_post_rate_limit_error(self, mock_request): self.assertEqual(context.exception.message, "desc") self.assertIsInstance(context.exception, RateLimitError) self.assertEqual(context.exception.reset_at, 9) + self.assertIsNotNone(context.exception.headers) + self.assertEqual(context.exception.headers["x-ratelimit-limit"], "3") + self.assertEqual(context.exception.headers["x-ratelimit-remaining"], "6") + self.assertEqual(context.exception.headers["x-ratelimit-reset"], "9") @mock.patch("requests.request") def test_post_rate_limit_error_without_headers(self, mock_request): @@ -177,6 +181,8 @@ def test_post_rate_limit_error_without_headers(self, mock_request): self.assertEqual(context.exception.message, "desc") self.assertIsInstance(context.exception, RateLimitError) self.assertEqual(context.exception.reset_at, -1) + self.assertIsNotNone(context.exception.headers) + self.assertEqual(context.exception.headers, {}) @mock.patch("requests.request") def test_post_error_with_code_property(self, mock_request): diff --git a/auth0/test/authentication/test_get_token.py b/auth0/test/authentication/test_get_token.py index bc6721f1..7c98d341 100644 --- a/auth0/test/authentication/test_get_token.py +++ b/auth0/test/authentication/test_get_token.py @@ -6,7 +6,7 @@ from cryptography.hazmat.primitives import asymmetric, serialization -from ... import Auth0Error +from ...exceptions import RateLimitError from ...authentication.get_token import GetToken @@ -339,22 +339,22 @@ def test_backchannel_login(self, mock_post): ) @mock.patch("requests.request") - def test_backchannel_login_headers_on_failure(self, mock_requests_request): + def test_backchannel_login_headers_on_slow_down(self, mock_requests_request): response = requests.Response() - response.status_code = 400 + response.status_code = 429 response.headers = {"Retry-After": "100"} response._content = b'{"error":"slow_down"}' mock_requests_request.return_value = response g = GetToken("my.domain.com", "cid", client_secret="csec") - with self.assertRaises(Auth0Error) as context: + with self.assertRaises(RateLimitError) as context: g.backchannel_login( auth_req_id="reqid", grant_type="urn:openid:params:grant-type:ciba", ) self.assertEqual(context.exception.headers["Retry-After"], "100") - self.assertEqual(context.exception.status_code, 400) + self.assertEqual(context.exception.status_code, 429) @mock.patch("auth0.rest.RestClient.post") def test_connection_login(self, mock_post): diff --git a/auth0/test/management/test_rest.py b/auth0/test/management/test_rest.py index 7113c446..6288daf9 100644 --- a/auth0/test/management/test_rest.py +++ b/auth0/test/management/test_rest.py @@ -278,6 +278,10 @@ def test_get_rate_limit_error(self, mock_request): self.assertEqual(context.exception.message, "message") self.assertIsInstance(context.exception, RateLimitError) self.assertEqual(context.exception.reset_at, 9) + self.assertIsNotNone(context.exception.headers) + self.assertEqual(context.exception.headers["x-ratelimit-limit"], "3") + self.assertEqual(context.exception.headers["x-ratelimit-remaining"], "6") + self.assertEqual(context.exception.headers["x-ratelimit-reset"], "9") self.assertEqual(rc._metrics["retries"], 0) @@ -300,6 +304,8 @@ def test_get_rate_limit_error_without_headers(self, mock_request): self.assertEqual(context.exception.message, "message") self.assertIsInstance(context.exception, RateLimitError) self.assertEqual(context.exception.reset_at, -1) + self.assertIsNotNone(context.exception.headers) + self.assertEqual(context.exception.headers, {}) self.assertEqual(rc._metrics["retries"], 1)