From 794feeaf65a73246c54e3babaabb51891abda907 Mon Sep 17 00:00:00 2001 From: BabyChrist666 Date: Tue, 17 Feb 2026 14:49:42 -0500 Subject: [PATCH] fix: restore eager OAuth discovery to avoid slow unauthenticated roundtrip (#1274) When the client has no valid tokens but already holds client_info from an earlier OAuth flow, perform OAuth discovery and authorization BEFORE sending the MCP request. This restores the eager behavior from v1.11.0 that was removed in v1.12.0, eliminating the unnecessary unauthenticated roundtrip that servers like Notion handle slowly (~10s latency per operation). First-time connections (no client_info yet) still use the reactive 401 path so that the server's WWW-Authenticate header can guide discovery, and a failed eager attempt falls back to that same path. Both the eager (pre-request) and reactive (post-401) paths now share a single `_perform_oauth_discovery_and_auth()` helper, keeping the code DRY while preserving RFC 9728 WWW-Authenticate header support on the 401 path. Co-Authored-By: Claude Opus 4.6 --- src/mcp/client/auth/oauth2.py | 214 +++++++++++++++++----------- tests/client/test_auth.py | 176 +++++++++++++++++++++-- tests/client/test_scope_bug_1630.py | 12 +- 3 files changed, 303 insertions(+), 99 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index f46407754..d89c3e1df 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -496,6 +496,102 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None if not check_resource_allowed(requested_resource=default_resource, configured_resource=prm_resource): raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}") + async def _perform_oauth_discovery_and_auth( + self, + www_auth_response: httpx.Response | None = None, + ) -> AsyncGenerator[httpx.Request, httpx.Response]: + """Perform the full OAuth discovery, registration, and authorization flow. + + This is extracted as a helper to allow both eager (pre-request) and + reactive (post-401) OAuth flows to share the same implementation. 
+ + Args: + www_auth_response: Optional 401 response to extract WWW-Authenticate + header from for RFC 9728 resource_metadata discovery. When None, + falls back to well-known URL discovery. + """ + www_auth_resource_metadata_url = ( + extract_resource_metadata_from_www_auth(www_auth_response) if www_auth_response else None + ) + + # Step 1: Discover protected resource metadata (SEP-985 with fallback support) + prm_discovery_urls = build_protected_resource_metadata_discovery_urls( + www_auth_resource_metadata_url, self.context.server_url + ) + + for url in prm_discovery_urls: # pragma: no branch + discovery_request = create_oauth_metadata_request(url) + + discovery_response = yield discovery_request # sending request + + prm = await handle_protected_resource_response(discovery_response) + if prm: + # Validate PRM resource matches server URL (RFC 8707) + await self._validate_resource_match(prm) + self.context.protected_resource_metadata = prm + + # todo: try all authorization_servers to find the OASM + assert ( + len(prm.authorization_servers) > 0 + ) # this is always true as authorization_servers has a min length of 1 + + self.context.auth_server_url = str(prm.authorization_servers[0]) + break + else: + logger.debug(f"Protected resource metadata discovery failed: {url}") + + asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( + self.context.auth_server_url, self.context.server_url + ) + + # Step 2: Discover OAuth Authorization Server Metadata (OASM) + for url in asm_discovery_urls: # pragma: no branch + oauth_metadata_request = create_oauth_metadata_request(url) + oauth_metadata_response = yield oauth_metadata_request + + ok, asm = await handle_auth_metadata_response(oauth_metadata_response) + if not ok: + break + if ok and asm: + self.context.oauth_metadata = asm + break + else: + logger.debug(f"OAuth metadata discovery failed: {url}") + + # Step 3: Apply scope selection strategy + self.context.client_metadata.scope = 
get_client_metadata_scopes( + extract_scope_from_www_auth(www_auth_response) if www_auth_response else None, + self.context.protected_resource_metadata, + self.context.oauth_metadata, + ) + + # Step 4: Register client or use URL-based client ID (CIMD) + if not self.context.client_info: + if should_use_client_metadata_url(self.context.oauth_metadata, self.context.client_metadata_url): + # Use URL-based client ID (CIMD) + logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}") + client_information = create_client_info_from_metadata_url( + self.context.client_metadata_url, # type: ignore[arg-type] + redirect_uris=self.context.client_metadata.redirect_uris, + ) + self.context.client_info = client_information + await self.context.storage.set_client_info(client_information) + else: + # Fallback to Dynamic Client Registration + registration_request = create_client_registration_request( + self.context.oauth_metadata, + self.context.client_metadata, + self.context.get_authorization_base_url(self.context.server_url), + ) + registration_response = yield registration_request + client_information = await handle_registration_response(registration_response) + self.context.client_info = client_information + await self.context.storage.set_client_info(client_information) + + # Step 5: Perform authorization and complete token exchange + token_response = yield await self._perform_authorization() + await self._handle_token_response(token_response) + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: """HTTPX auth flow integration.""" async with self.context.lock: @@ -514,96 +610,48 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. 
# Refresh failed, need full re-authentication self._initialized = False + # Eager OAuth: if we have no valid token, can't refresh, AND we have + # already been through the OAuth flow at least once (we have + # client_info and discovery metadata), re-run the discovery/auth flow + # BEFORE sending the MCP request. This avoids the unnecessary + # unauthenticated round-trip that some servers (e.g. Notion) handle + # slowly, causing ~10 s latency per request. See #1274. + # + # On the very first connection (no client_info), we skip the eager + # flow and let the reactive 401 path handle discovery, because the + # server's WWW-Authenticate header may carry routing information + # (e.g. resource_metadata URL) that pure well-known discovery lacks. + # + # If the eager flow fails, we fall through gracefully and send the + # MCP request without auth so the reactive 401 path can take over. + if not self.context.is_token_valid() and self.context.client_info: + try: + oauth_gen = self._perform_oauth_discovery_and_auth() + oauth_request = await oauth_gen.__anext__() + while True: + oauth_response = yield oauth_request + oauth_request = await oauth_gen.asend(oauth_response) + except StopAsyncIteration: + pass + except Exception: + logger.debug("Eager OAuth discovery failed, falling back to reactive 401 path", exc_info=True) + if self.context.is_token_valid(): self._add_auth_header(request) response = yield request if response.status_code == 401: - # Perform full OAuth flow + # Perform full OAuth flow (reactive path — uses WWW-Authenticate + # header from the 401 response for RFC 9728 discovery) try: - # OAuth flow must be inline due to generator constraints - www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response) - - # Step 1: Discover protected resource metadata (SEP-985 with fallback support) - prm_discovery_urls = build_protected_resource_metadata_discovery_urls( - www_auth_resource_metadata_url, self.context.server_url - ) - - for url in 
prm_discovery_urls: # pragma: no branch - discovery_request = create_oauth_metadata_request(url) - - discovery_response = yield discovery_request # sending request - - prm = await handle_protected_resource_response(discovery_response) - if prm: - # Validate PRM resource matches server URL (RFC 8707) - await self._validate_resource_match(prm) - self.context.protected_resource_metadata = prm - - # todo: try all authorization_servers to find the OASM - assert ( - len(prm.authorization_servers) > 0 - ) # this is always true as authorization_servers has a min length of 1 - - self.context.auth_server_url = str(prm.authorization_servers[0]) - break - else: - logger.debug(f"Protected resource metadata discovery failed: {url}") - - asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( - self.context.auth_server_url, self.context.server_url - ) - - # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers) - for url in asm_discovery_urls: # pragma: no branch - oauth_metadata_request = create_oauth_metadata_request(url) - oauth_metadata_response = yield oauth_metadata_request - - ok, asm = await handle_auth_metadata_response(oauth_metadata_response) - if not ok: - break - if ok and asm: - self.context.oauth_metadata = asm - break - else: - logger.debug(f"OAuth metadata discovery failed: {url}") - - # Step 3: Apply scope selection strategy - self.context.client_metadata.scope = get_client_metadata_scopes( - extract_scope_from_www_auth(response), - self.context.protected_resource_metadata, - self.context.oauth_metadata, - ) - - # Step 4: Register client or use URL-based client ID (CIMD) - if not self.context.client_info: - if should_use_client_metadata_url( - self.context.oauth_metadata, self.context.client_metadata_url - ): - # Use URL-based client ID (CIMD) - logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}") - client_information = create_client_info_from_metadata_url( - 
self.context.client_metadata_url, # type: ignore[arg-type] - redirect_uris=self.context.client_metadata.redirect_uris, - ) - self.context.client_info = client_information - await self.context.storage.set_client_info(client_information) - else: - # Fallback to Dynamic Client Registration - registration_request = create_client_registration_request( - self.context.oauth_metadata, - self.context.client_metadata, - self.context.get_authorization_base_url(self.context.server_url), - ) - registration_response = yield registration_request - client_information = await handle_registration_response(registration_response) - self.context.client_info = client_information - await self.context.storage.set_client_info(client_information) - - # Step 5: Perform authorization and complete token exchange - token_response = yield await self._perform_authorization() - await self._handle_token_response(token_response) + oauth_gen = self._perform_oauth_discovery_and_auth(www_auth_response=response) + oauth_request = await oauth_gen.__anext__() + while True: + oauth_response = yield oauth_request + oauth_request = await oauth_gen.asend(oauth_response) + except StopAsyncIteration: + pass except Exception: # pragma: no cover logger.exception("OAuth flow error") raise diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 5aa985e36..6789db72c 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1210,6 +1210,152 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth( # Verify exactly one request was yielded (no double-sending) assert request_yields == 1, f"Expected 1 request yield, got {request_yields}" + @pytest.mark.anyio + async def test_eager_oauth_flow_avoids_unauthenticated_roundtrip( + self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage + ): + """Test that when no tokens exist, OAuth discovery happens BEFORE the MCP request. 
+ + This is the fix for issue #1274: servers like Notion can be very slow (~10s) + when handling unauthenticated requests, so we perform eager OAuth discovery + to obtain tokens first, then send the MCP request with auth already attached. + """ + # Ensure no tokens but with existing client_info — triggers the eager + # OAuth flow. The eager flow only activates when client_info is already + # present (i.e. we've been through the OAuth flow at least once before). + oauth_provider.context.current_tokens = None + oauth_provider.context.token_expiry_time = None + oauth_provider._initialized = True + + # Simulate having previously registered (required for eager flow) + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="existing_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = oauth_provider.async_auth_flow(test_request) + + # First yielded request should be PRM discovery (path-based), NOT the MCP request. + # This is the key behavioral change for #1274. 
+ first_request = await auth_flow.__anext__() + assert str(first_request.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" + assert first_request.method == "GET" + assert "Authorization" not in first_request.headers + + # PRM discovery returns 404 (path-based fails) + prm_response_1 = httpx.Response(404, request=first_request) + + # Should fall back to root-based PRM discovery + prm_request_2 = await auth_flow.asend(prm_response_1) + assert str(prm_request_2.url) == "https://api.example.com/.well-known/oauth-protected-resource" + + # Root PRM discovery succeeds + prm_response_2 = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', + request=prm_request_2, + ) + + # Next: OAuth authorization server metadata discovery + asm_request = await auth_flow.asend(prm_response_2) + assert asm_request.method == "GET" + assert str(asm_request.url).startswith("https://auth.example.com/") + + asm_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register"}' + ), + request=asm_request, + ) + + # Since client_info is already set, DCR is skipped. 
+ # Mock authorization code grant + oauth_provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) + + # Next: Token exchange (skipping DCR because client_info exists) + token_request = await auth_flow.asend(asm_response) + assert token_request.method == "POST" + assert str(token_request.url) == "https://auth.example.com/token" + + token_response = httpx.Response( + 200, + content=b'{"access_token": "eager_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + + # NOW the MCP request is sent — with the token from the eager flow + mcp_request = await auth_flow.asend(token_response) + assert mcp_request.headers["Authorization"] == "Bearer eager_token" + assert str(mcp_request.url) == "https://api.example.com/v1/mcp" + + # Server returns 200 — no 401 round-trip needed + mcp_response = httpx.Response(200, request=mcp_request) + try: + await auth_flow.asend(mcp_response) + except StopAsyncIteration: + pass + + # Verify tokens were stored + assert oauth_provider.context.current_tokens is not None + assert oauth_provider.context.current_tokens.access_token == "eager_token" + + @pytest.mark.anyio + async def test_eager_oauth_falls_back_on_error( + self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage + ): + """Test that when the eager OAuth flow fails, we gracefully fall through + to send the MCP request without auth and let the reactive 401 path handle it. + + This covers scenarios like resource mismatch (evil PRM) or registration + failures where the eager path cannot complete. 
+ """ + # Ensure no tokens but with existing client_info — triggers the eager flow + oauth_provider.context.current_tokens = None + oauth_provider.context.token_expiry_time = None + oauth_provider._initialized = True + + # Simulate having previously registered (required for eager flow) + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="existing_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = oauth_provider.async_auth_flow(test_request) + + # Eager flow starts with PRM discovery + prm_request = await auth_flow.__anext__() + assert ".well-known/oauth-protected-resource" in str(prm_request.url) + + # PRM returns a mismatched resource — this triggers OAuthFlowError + # inside _validate_resource_match. The eager flow should catch this + # and fall through gracefully. + evil_prm_response = httpx.Response( + 200, + content=b'{"resource": "https://evil.example.com/mcp", "authorization_servers": ["https://evil.example.com"]}', + request=prm_request, + ) + + # The next yielded request should be the original MCP request (without auth), + # because the eager flow caught the error and fell through. 
+ mcp_request = await auth_flow.asend(evil_prm_response) + assert str(mcp_request.url) == "https://api.example.com/v1/mcp" + assert "Authorization" not in mcp_request.headers + + # Server returns 200 (no 401) — flow completes + mcp_response = httpx.Response(200, request=mcp_request) + try: + await auth_flow.asend(mcp_response) + except StopAsyncIteration: + pass + @pytest.mark.anyio async def test_token_exchange_accepts_201_status( self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage @@ -1511,8 +1657,10 @@ async def callback_handler() -> tuple[str, str | None]: callback_handler=callback_handler, ) - provider.context.current_tokens = None - provider.context.token_expiry_time = None + # Provide a token that appears valid so the eager OAuth flow is skipped; + # the server will still return 401 to trigger the reactive path. + provider.context.current_tokens = OAuthToken(access_token="stale_token", token_type="Bearer") + provider.context.token_expiry_time = time.time() + 3600 provider._initialized = True # Mock client info to skip DCR @@ -1524,9 +1672,9 @@ async def callback_handler() -> tuple[str, str | None]: test_request = httpx.Request("GET", "https://mcp.linear.app/sse") auth_flow = provider.async_auth_flow(test_request) - # First request + # First request — now carries the stale token request = await auth_flow.__anext__() - assert "Authorization" not in request.headers + assert request.headers["Authorization"] == "Bearer stale_token" # Send 401 without WWW-Authenticate header (typical legacy server) response = httpx.Response(401, headers={}, request=test_request) @@ -1609,8 +1757,10 @@ async def callback_handler() -> tuple[str, str | None]: callback_handler=callback_handler, ) - provider.context.current_tokens = None - provider.context.token_expiry_time = None + # Provide a token that appears valid so the eager OAuth flow is skipped; + # the server will still return 401 to trigger the reactive path. 
+ provider.context.current_tokens = OAuthToken(access_token="stale_token", token_type="Bearer") + provider.context.token_expiry_time = time.time() + 3600 provider._initialized = True provider.context.client_info = OAuthClientInformationFull( @@ -1621,7 +1771,8 @@ async def callback_handler() -> tuple[str, str | None]: test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") auth_flow = provider.async_auth_flow(test_request) - await auth_flow.__anext__() + request = await auth_flow.__anext__() + assert request.headers["Authorization"] == "Bearer stale_token" # 401 with custom WWW-Authenticate PRM URL response = httpx.Response( @@ -1749,9 +1900,10 @@ async def callback_handler() -> tuple[str, str | None]: callback_handler=callback_handler, ) - # Ensure no tokens are stored - provider.context.current_tokens = None - provider.context.token_expiry_time = None + # Provide a token that appears valid so the eager OAuth flow is skipped; + # the server will still return 401 to trigger the reactive path. 
+ provider.context.current_tokens = OAuthToken(access_token="stale_token", token_type="Bearer") + provider.context.token_expiry_time = time.time() + 3600 provider._initialized = True # Mock client info to skip DCR @@ -1766,9 +1918,9 @@ async def callback_handler() -> tuple[str, str | None]: # Mock the auth flow auth_flow = provider.async_auth_flow(test_request) - # First request should be the original request without auth header + # First request carries the stale token request = await auth_flow.__anext__() - assert "Authorization" not in request.headers + assert request.headers["Authorization"] == "Bearer stale_token" # Send a 401 response without WWW-Authenticate header response = httpx.Response(401, headers={}, request=test_request) diff --git a/tests/client/test_scope_bug_1630.py b/tests/client/test_scope_bug_1630.py index fafa51007..d02523d8c 100644 --- a/tests/client/test_scope_bug_1630.py +++ b/tests/client/test_scope_bug_1630.py @@ -4,6 +4,7 @@ in the WWW-Authenticate header, the actual scope is used (not the resource_metadata URL). """ +import time from unittest import mock import httpx @@ -64,8 +65,10 @@ async def callback_handler() -> tuple[str, str | None]: callback_handler=callback_handler, ) - provider.context.current_tokens = None - provider.context.token_expiry_time = None + # Provide a token that appears valid so the eager OAuth flow is skipped; + # the server will still return 401 to trigger the reactive path. 
+ provider.context.current_tokens = OAuthToken(access_token="stale_token", token_type="Bearer") + provider.context.token_expiry_time = time.time() + 3600 provider._initialized = True # Pre-set client info to skip DCR @@ -77,8 +80,9 @@ async def callback_handler() -> tuple[str, str | None]: test_request = httpx.Request("GET", "https://api.example.com/mcp") auth_flow = provider.async_auth_flow(test_request) - # First request (no auth header yet) - await auth_flow.__anext__() + # First request carries the stale token + request = await auth_flow.__anext__() + assert request.headers["Authorization"] == "Bearer stale_token" # 401 response with BOTH resource_metadata URL and scope in WWW-Authenticate # This is the key: the bug would use the URL as scope instead of "read write"