Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions auth0/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def __str__(self) -> str:


class RateLimitError(Auth0Error):
def __init__(self, error_code: str, message: str, reset_at: int) -> None:
super().__init__(status_code=429, error_code=error_code, message=message)
def __init__(self, error_code: str, message: str, reset_at: int, headers: Any | None = None) -> None:
super().__init__(status_code=429, error_code=error_code, message=message, headers=headers)
self.reset_at = reset_at


Expand Down
1 change: 1 addition & 0 deletions auth0/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def content(self) -> Any:
error_code=self._error_code(),
message=self._error_message(),
reset_at=reset_at,
headers=self._headers,
)
if self._error_code() == "mfa_required":
raise Auth0Error(
Expand Down
6 changes: 6 additions & 0 deletions auth0/test/authentication/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ def test_post_rate_limit_error(self, mock_request):
self.assertEqual(context.exception.message, "desc")
self.assertIsInstance(context.exception, RateLimitError)
self.assertEqual(context.exception.reset_at, 9)
self.assertIsNotNone(context.exception.headers)
self.assertEqual(context.exception.headers["x-ratelimit-limit"], "3")
self.assertEqual(context.exception.headers["x-ratelimit-remaining"], "6")
self.assertEqual(context.exception.headers["x-ratelimit-reset"], "9")

@mock.patch("requests.request")
def test_post_rate_limit_error_without_headers(self, mock_request):
Expand All @@ -177,6 +181,8 @@ def test_post_rate_limit_error_without_headers(self, mock_request):
self.assertEqual(context.exception.message, "desc")
self.assertIsInstance(context.exception, RateLimitError)
self.assertEqual(context.exception.reset_at, -1)
self.assertIsNotNone(context.exception.headers)
self.assertEqual(context.exception.headers, {})

@mock.patch("requests.request")
def test_post_error_with_code_property(self, mock_request):
Expand Down
10 changes: 5 additions & 5 deletions auth0/test/authentication/test_get_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from cryptography.hazmat.primitives import asymmetric, serialization

from ... import Auth0Error
from ...exceptions import RateLimitError
from ...authentication.get_token import GetToken


Expand Down Expand Up @@ -339,22 +339,22 @@ def test_backchannel_login(self, mock_post):
)

@mock.patch("requests.request")
def test_backchannel_login_headers_on_failure(self, mock_requests_request):
def test_backchannel_login_headers_on_slow_down(self, mock_requests_request):
response = requests.Response()
response.status_code = 400
response.status_code = 429
response.headers = {"Retry-After": "100"}
response._content = b'{"error":"slow_down"}'
mock_requests_request.return_value = response

g = GetToken("my.domain.com", "cid", client_secret="csec")

with self.assertRaises(Auth0Error) as context:
with self.assertRaises(RateLimitError) as context:
g.backchannel_login(
auth_req_id="reqid",
grant_type="urn:openid:params:grant-type:ciba",
)
self.assertEqual(context.exception.headers["Retry-After"], "100")
self.assertEqual(context.exception.status_code, 400)
self.assertEqual(context.exception.status_code, 429)

@mock.patch("auth0.rest.RestClient.post")
def test_connection_login(self, mock_post):
Expand Down
6 changes: 6 additions & 0 deletions auth0/test/management/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ def test_get_rate_limit_error(self, mock_request):
self.assertEqual(context.exception.message, "message")
self.assertIsInstance(context.exception, RateLimitError)
self.assertEqual(context.exception.reset_at, 9)
self.assertIsNotNone(context.exception.headers)
self.assertEqual(context.exception.headers["x-ratelimit-limit"], "3")
self.assertEqual(context.exception.headers["x-ratelimit-remaining"], "6")
self.assertEqual(context.exception.headers["x-ratelimit-reset"], "9")

self.assertEqual(rc._metrics["retries"], 0)

Expand All @@ -300,6 +304,8 @@ def test_get_rate_limit_error_without_headers(self, mock_request):
self.assertEqual(context.exception.message, "message")
self.assertIsInstance(context.exception, RateLimitError)
self.assertEqual(context.exception.reset_at, -1)
self.assertIsNotNone(context.exception.headers)
self.assertEqual(context.exception.headers, {})

self.assertEqual(rc._metrics["retries"], 1)

Expand Down