Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@

logger = logging.getLogger(__name__)

_FORWARDED_AUTH_FLOW_HEADERS = ("User-Agent",)


class PKCEParameters(BaseModel):
"""PKCE (Proof Key for Code Exchange) parameters."""
Expand Down Expand Up @@ -477,6 +479,14 @@ def _add_auth_header(self, request: httpx.Request) -> None:
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"

def _forward_request_headers(self, source_request: httpx.Request, outgoing_request: httpx.Request) -> httpx.Request:
"""Forward selected caller headers to OAuth flow requests."""
for header_name in _FORWARDED_AUTH_FLOW_HEADERS:
header_value = source_request.headers.get(header_name)
if header_value is not None and header_name not in outgoing_request.headers:
outgoing_request.headers[header_name] = header_value
return outgoing_request

async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
content = await response.aread()
metadata = OAuthMetadata.model_validate_json(content)
Expand Down Expand Up @@ -508,6 +518,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
if not self.context.is_token_valid() and self.context.can_refresh_token():
# Try to refresh token
refresh_request = await self._refresh_token() # pragma: no cover
refresh_request = self._forward_request_headers(request, refresh_request) # pragma: no cover
refresh_response = yield refresh_request # pragma: no cover

if not await self._handle_refresh_response(refresh_response): # pragma: no cover
Expand All @@ -532,6 +543,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.

for url in prm_discovery_urls: # pragma: no branch
discovery_request = create_oauth_metadata_request(url)
discovery_request = self._forward_request_headers(request, discovery_request)

discovery_response = yield discovery_request # sending request

Expand All @@ -558,6 +570,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
for url in asm_discovery_urls: # pragma: no branch
oauth_metadata_request = create_oauth_metadata_request(url)
oauth_metadata_request = self._forward_request_headers(request, oauth_metadata_request)
oauth_metadata_response = yield oauth_metadata_request

ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
Expand Down Expand Up @@ -596,13 +609,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
self.context.client_metadata,
self.context.get_authorization_base_url(self.context.server_url),
)
registration_request = self._forward_request_headers(request, registration_request)
registration_response = yield registration_request
client_information = await handle_registration_response(registration_response)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)

# Step 5: Perform authorization and complete token exchange
token_response = yield await self._perform_authorization()
token_request = await self._perform_authorization()
token_request = self._forward_request_headers(request, token_request)
token_response = yield token_request
await self._handle_token_response(token_response)
except Exception: # pragma: no cover
logger.exception("OAuth flow error")
Expand All @@ -624,7 +640,9 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
)

# Step 2b: Perform (re-)authorization and token exchange
token_response = yield await self._perform_authorization()
token_request = await self._perform_authorization()
token_request = self._forward_request_headers(request, token_request)
token_response = yield token_request
await self._handle_token_response(token_response)
except Exception: # pragma: no cover
logger.exception("OAuth flow error")
Expand Down
85 changes: 84 additions & 1 deletion tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self):
self._client_info: OAuthClientInformationFull | None = None

async def get_tokens(self) -> OAuthToken | None:
return self._tokens # pragma: no cover
return self._tokens

async def set_tokens(self, tokens: OAuthToken) -> None:
self._tokens = tokens
Expand Down Expand Up @@ -1167,6 +1167,89 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
assert oauth_provider.context.token_expiry_time is not None

@pytest.mark.anyio
async def test_auth_flow_forwards_user_agent_to_oauth_requests(
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage
):
oauth_provider.context.current_tokens = None
oauth_provider.context.token_expiry_time = None
oauth_provider._initialized = True

test_request = httpx.Request(
"GET", "https://api.example.com/mcp", headers={"User-Agent": "my-custom-client/1.0"}
)
auth_flow = oauth_provider.async_auth_flow(test_request)

request = await auth_flow.__anext__()
assert request.headers["User-Agent"] == "my-custom-client/1.0"

response = httpx.Response(
401,
headers={
"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
},
request=test_request,
)

discovery_request = await auth_flow.asend(response)
assert discovery_request.headers["User-Agent"] == "my-custom-client/1.0"

discovery_response = httpx.Response(
200,
content=b'{"resource":"https://api.example.com/v1/mcp","authorization_servers":["https://auth.example.com"]}',
request=discovery_request,
)

oauth_metadata_request = await auth_flow.asend(discovery_response)
assert oauth_metadata_request.headers["User-Agent"] == "my-custom-client/1.0"

oauth_metadata_response = httpx.Response(
200,
content=(
b'{"issuer":"https://auth.example.com",'
b'"authorization_endpoint":"https://auth.example.com/authorize",'
b'"token_endpoint":"https://auth.example.com/token",'
b'"registration_endpoint":"https://auth.example.com/register"}'
),
request=oauth_metadata_request,
)

registration_request = await auth_flow.asend(oauth_metadata_response)
assert registration_request.headers["User-Agent"] == "my-custom-client/1.0"

registration_response = httpx.Response(
201,
content=b'{"client_id":"test_client_id","client_secret":"test_client_secret","redirect_uris":["http://localhost:3030/callback"]}',
request=registration_request,
)

oauth_provider._perform_authorization_code_grant = mock.AsyncMock(
return_value=("test_auth_code", "test_code_verifier")
)

token_request = await auth_flow.asend(registration_response)
assert token_request.headers["User-Agent"] == "my-custom-client/1.0"

token_response = httpx.Response(
200,
content=(
b'{"access_token":"new_access_token","token_type":"Bearer","expires_in":3600,'
b'"refresh_token":"new_refresh_token"}'
),
request=token_request,
)

final_request = await auth_flow.asend(token_response)
assert final_request.headers["Authorization"] == "Bearer new_access_token"

final_response = httpx.Response(200, request=final_request)
with pytest.raises(StopAsyncIteration):
await auth_flow.asend(final_response)

stored_tokens = await mock_storage.get_tokens()
assert stored_tokens is not None
assert stored_tokens.access_token == "new_access_token"

@pytest.mark.anyio
async def test_auth_flow_no_unnecessary_retry_after_oauth(
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken
Expand Down