Skip to content

Commit 8f7cfb0

Browse files
author
Mrutunjay Kinagi
committed
fix(auth): forward user-agent to oauth flow requests
1 parent be5bb7c commit 8f7cfb0

File tree

2 files changed

+103
-2
lines changed

2 files changed

+103
-2
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252

5353
logger = logging.getLogger(__name__)
5454

55+
_FORWARDED_AUTH_FLOW_HEADERS = ("User-Agent",)
56+
5557

5658
class PKCEParameters(BaseModel):
5759
"""PKCE (Proof Key for Code Exchange) parameters."""
@@ -477,6 +479,14 @@ def _add_auth_header(self, request: httpx.Request) -> None:
477479
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
478480
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
479481

482+
def _forward_request_headers(self, source_request: httpx.Request, outgoing_request: httpx.Request) -> httpx.Request:
483+
"""Forward selected caller headers to OAuth flow requests."""
484+
for header_name in _FORWARDED_AUTH_FLOW_HEADERS:
485+
header_value = source_request.headers.get(header_name)
486+
if header_value is not None and header_name not in outgoing_request.headers:
487+
outgoing_request.headers[header_name] = header_value
488+
return outgoing_request
489+
480490
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
481491
content = await response.aread()
482492
metadata = OAuthMetadata.model_validate_json(content)
@@ -508,6 +518,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
508518
if not self.context.is_token_valid() and self.context.can_refresh_token():
509519
# Try to refresh token
510520
refresh_request = await self._refresh_token() # pragma: no cover
521+
refresh_request = self._forward_request_headers(request, refresh_request) # pragma: no cover
511522
refresh_response = yield refresh_request # pragma: no cover
512523

513524
if not await self._handle_refresh_response(refresh_response): # pragma: no cover
@@ -532,6 +543,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
532543

533544
for url in prm_discovery_urls: # pragma: no branch
534545
discovery_request = create_oauth_metadata_request(url)
546+
discovery_request = self._forward_request_headers(request, discovery_request)
535547

536548
discovery_response = yield discovery_request # sending request
537549

@@ -558,6 +570,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
558570
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
559571
for url in asm_discovery_urls: # pragma: no branch
560572
oauth_metadata_request = create_oauth_metadata_request(url)
573+
oauth_metadata_request = self._forward_request_headers(request, oauth_metadata_request)
561574
oauth_metadata_response = yield oauth_metadata_request
562575

563576
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
@@ -596,13 +609,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
596609
self.context.client_metadata,
597610
self.context.get_authorization_base_url(self.context.server_url),
598611
)
612+
registration_request = self._forward_request_headers(request, registration_request)
599613
registration_response = yield registration_request
600614
client_information = await handle_registration_response(registration_response)
601615
self.context.client_info = client_information
602616
await self.context.storage.set_client_info(client_information)
603617

604618
# Step 5: Perform authorization and complete token exchange
605-
token_response = yield await self._perform_authorization()
619+
token_request = await self._perform_authorization()
620+
token_request = self._forward_request_headers(request, token_request)
621+
token_response = yield token_request
606622
await self._handle_token_response(token_response)
607623
except Exception: # pragma: no cover
608624
logger.exception("OAuth flow error")
@@ -624,7 +640,9 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
624640
)
625641

626642
# Step 2b: Perform (re-)authorization and token exchange
627-
token_response = yield await self._perform_authorization()
643+
token_request = await self._perform_authorization()
644+
token_request = self._forward_request_headers(request, token_request)
645+
token_response = yield token_request
628646
await self._handle_token_response(token_response)
629647
except Exception: # pragma: no cover
630648
logger.exception("OAuth flow error")

tests/client/test_auth.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,89 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide
11671167
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
11681168
assert oauth_provider.context.token_expiry_time is not None
11691169

1170+
@pytest.mark.anyio
1171+
async def test_auth_flow_forwards_user_agent_to_oauth_requests(
1172+
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage
1173+
):
1174+
oauth_provider.context.current_tokens = None
1175+
oauth_provider.context.token_expiry_time = None
1176+
oauth_provider._initialized = True
1177+
1178+
test_request = httpx.Request(
1179+
"GET", "https://api.example.com/mcp", headers={"User-Agent": "my-custom-client/1.0"}
1180+
)
1181+
auth_flow = oauth_provider.async_auth_flow(test_request)
1182+
1183+
request = await auth_flow.__anext__()
1184+
assert request.headers["User-Agent"] == "my-custom-client/1.0"
1185+
1186+
response = httpx.Response(
1187+
401,
1188+
headers={
1189+
"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
1190+
},
1191+
request=test_request,
1192+
)
1193+
1194+
discovery_request = await auth_flow.asend(response)
1195+
assert discovery_request.headers["User-Agent"] == "my-custom-client/1.0"
1196+
1197+
discovery_response = httpx.Response(
1198+
200,
1199+
content=b'{"resource":"https://api.example.com/v1/mcp","authorization_servers":["https://auth.example.com"]}',
1200+
request=discovery_request,
1201+
)
1202+
1203+
oauth_metadata_request = await auth_flow.asend(discovery_response)
1204+
assert oauth_metadata_request.headers["User-Agent"] == "my-custom-client/1.0"
1205+
1206+
oauth_metadata_response = httpx.Response(
1207+
200,
1208+
content=(
1209+
b'{"issuer":"https://auth.example.com",'
1210+
b'"authorization_endpoint":"https://auth.example.com/authorize",'
1211+
b'"token_endpoint":"https://auth.example.com/token",'
1212+
b'"registration_endpoint":"https://auth.example.com/register"}'
1213+
),
1214+
request=oauth_metadata_request,
1215+
)
1216+
1217+
registration_request = await auth_flow.asend(oauth_metadata_response)
1218+
assert registration_request.headers["User-Agent"] == "my-custom-client/1.0"
1219+
1220+
registration_response = httpx.Response(
1221+
201,
1222+
content=b'{"client_id":"test_client_id","client_secret":"test_client_secret","redirect_uris":["http://localhost:3030/callback"]}',
1223+
request=registration_request,
1224+
)
1225+
1226+
oauth_provider._perform_authorization_code_grant = mock.AsyncMock(
1227+
return_value=("test_auth_code", "test_code_verifier")
1228+
)
1229+
1230+
token_request = await auth_flow.asend(registration_response)
1231+
assert token_request.headers["User-Agent"] == "my-custom-client/1.0"
1232+
1233+
token_response = httpx.Response(
1234+
200,
1235+
content=(
1236+
b'{"access_token":"new_access_token","token_type":"Bearer","expires_in":3600,'
1237+
b'"refresh_token":"new_refresh_token"}'
1238+
),
1239+
request=token_request,
1240+
)
1241+
1242+
final_request = await auth_flow.asend(token_response)
1243+
assert final_request.headers["Authorization"] == "Bearer new_access_token"
1244+
1245+
final_response = httpx.Response(200, request=final_request)
1246+
with pytest.raises(StopAsyncIteration):
1247+
await auth_flow.asend(final_response)
1248+
1249+
stored_tokens = await mock_storage.get_tokens()
1250+
assert stored_tokens is not None
1251+
assert stored_tokens.access_token == "new_access_token"
1252+
11701253
@pytest.mark.anyio
11711254
async def test_auth_flow_no_unnecessary_retry_after_oauth(
11721255
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken

0 commit comments

Comments
 (0)