diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py
index f46407754..d89c3e1df 100644
--- a/src/mcp/client/auth/oauth2.py
+++ b/src/mcp/client/auth/oauth2.py
@@ -496,6 +496,118 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None
         if not check_resource_allowed(requested_resource=default_resource, configured_resource=prm_resource):
             raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")
 
+    async def _perform_oauth_discovery_and_auth(
+        self,
+        www_auth_response: httpx.Response | None = None,
+    ) -> AsyncGenerator[httpx.Request, httpx.Response]:
+        """Perform the full OAuth discovery, registration, and authorization flow.
+
+        This is extracted as a helper to allow both eager (pre-request) and
+        reactive (post-401) OAuth flows to share the same implementation.
+
+        Every HTTP request the flow needs is yielded to the caller, which is
+        expected to send the matching response back with ``asend()``. The
+        generator finishes once the token response has been handled.
+
+        Args:
+            www_auth_response: Optional 401 response to extract WWW-Authenticate
+                header from for RFC 9728 resource_metadata discovery. When None,
+                falls back to well-known URL discovery.
+
+        Yields:
+            Metadata discovery requests, then a client registration request
+            (only when no client info is set and the client metadata URL is
+            not used), then the request returned by ``_perform_authorization``.
+
+        Raises:
+            OAuthFlowError: If the discovered protected resource does not match
+                the expected resource (raised by ``_validate_resource_match``).
+        """
+        # Test against None explicitly so that consulting the header never
+        # depends on the truthiness of the response object.
+        www_auth_resource_metadata_url = (
+            extract_resource_metadata_from_www_auth(www_auth_response) if www_auth_response is not None else None
+        )
+
+        # Step 1: Discover protected resource metadata (SEP-985 with fallback support)
+        prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
+            www_auth_resource_metadata_url, self.context.server_url
+        )
+
+        for url in prm_discovery_urls:  # pragma: no branch
+            discovery_request = create_oauth_metadata_request(url)
+
+            discovery_response = yield discovery_request  # sending request
+
+            prm = await handle_protected_resource_response(discovery_response)
+            if prm:
+                # Validate PRM resource matches server URL (RFC 8707)
+                await self._validate_resource_match(prm)
+                self.context.protected_resource_metadata = prm
+
+                # todo: try all authorization_servers to find the OASM
+                assert (
+                    len(prm.authorization_servers) > 0
+                )  # this is always true as authorization_servers has a min length of 1
+
+                self.context.auth_server_url = str(prm.authorization_servers[0])
+                break
+            else:
+                logger.debug(f"Protected resource metadata discovery failed: {url}")
+
+        asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
+            self.context.auth_server_url, self.context.server_url
+        )
+
+        # Step 2: Discover OAuth Authorization Server Metadata (OASM)
+        for url in asm_discovery_urls:  # pragma: no branch
+            oauth_metadata_request = create_oauth_metadata_request(url)
+            oauth_metadata_response = yield oauth_metadata_request
+
+            ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
+            if not ok:
+                # A false ``ok`` ends discovery without trying the remaining URLs
+                break
+            if asm:  # ``ok`` is always true at this point
+                self.context.oauth_metadata = asm
+                break
+            else:
+                logger.debug(f"OAuth metadata discovery failed: {url}")
+
+        # Step 3: Apply scope selection strategy
+        self.context.client_metadata.scope = get_client_metadata_scopes(
+            extract_scope_from_www_auth(www_auth_response) if www_auth_response is not None else None,
+            self.context.protected_resource_metadata,
+            self.context.oauth_metadata,
+        )
+
+        # Step 4: Register client or use URL-based client ID (CIMD)
+        if not self.context.client_info:
+            if should_use_client_metadata_url(self.context.oauth_metadata, self.context.client_metadata_url):
+                # Use URL-based client ID (CIMD)
+                logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}")
+                client_information = create_client_info_from_metadata_url(
+                    self.context.client_metadata_url,  # type: ignore[arg-type]
+                    redirect_uris=self.context.client_metadata.redirect_uris,
+                )
+            else:
+                # Fallback to Dynamic Client Registration
+                registration_request = create_client_registration_request(
+                    self.context.oauth_metadata,
+                    self.context.client_metadata,
+                    self.context.get_authorization_base_url(self.context.server_url),
+                )
+                registration_response = yield registration_request
+                client_information = await handle_registration_response(registration_response)
+
+            # Both branches remember the client info and persist it to storage
+            self.context.client_info = client_information
+            await self.context.storage.set_client_info(client_information)
+
+        # Step 5: Perform authorization and complete token exchange
+        token_response = yield await self._perform_authorization()
+        await self._handle_token_response(token_response)
+
     async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
         """HTTPX auth flow integration."""
         async with self.context.lock:
@@ -514,96 +626,48 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
                     # Refresh failed, need full re-authentication
                     self._initialized = False
 
+            # Eager OAuth: if we have no valid token, can't refresh, AND we have
+            # already been through the OAuth flow at least once (we have
+            # client_info and discovery metadata), re-run the discovery/auth flow
+            # BEFORE sending the MCP request. This avoids the unnecessary
+            # unauthenticated round-trip that some servers (e.g.
Notion) handle + # slowly, causing ~10 s latency per request. See #1274. + # + # On the very first connection (no client_info), we skip the eager + # flow and let the reactive 401 path handle discovery, because the + # server's WWW-Authenticate header may carry routing information + # (e.g. resource_metadata URL) that pure well-known discovery lacks. + # + # If the eager flow fails, we fall through gracefully and send the + # MCP request without auth so the reactive 401 path can take over. + if not self.context.is_token_valid() and self.context.client_info: + try: + oauth_gen = self._perform_oauth_discovery_and_auth() + oauth_request = await oauth_gen.__anext__() + while True: + oauth_response = yield oauth_request + oauth_request = await oauth_gen.asend(oauth_response) + except StopAsyncIteration: + pass + except Exception: + logger.debug("Eager OAuth discovery failed, falling back to reactive 401 path", exc_info=True) + if self.context.is_token_valid(): self._add_auth_header(request) response = yield request if response.status_code == 401: - # Perform full OAuth flow + # Perform full OAuth flow (reactive path — uses WWW-Authenticate + # header from the 401 response for RFC 9728 discovery) try: - # OAuth flow must be inline due to generator constraints - www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response) - - # Step 1: Discover protected resource metadata (SEP-985 with fallback support) - prm_discovery_urls = build_protected_resource_metadata_discovery_urls( - www_auth_resource_metadata_url, self.context.server_url - ) - - for url in prm_discovery_urls: # pragma: no branch - discovery_request = create_oauth_metadata_request(url) - - discovery_response = yield discovery_request # sending request - - prm = await handle_protected_resource_response(discovery_response) - if prm: - # Validate PRM resource matches server URL (RFC 8707) - await self._validate_resource_match(prm) - self.context.protected_resource_metadata = prm - - # todo: 
try all authorization_servers to find the OASM - assert ( - len(prm.authorization_servers) > 0 - ) # this is always true as authorization_servers has a min length of 1 - - self.context.auth_server_url = str(prm.authorization_servers[0]) - break - else: - logger.debug(f"Protected resource metadata discovery failed: {url}") - - asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( - self.context.auth_server_url, self.context.server_url - ) - - # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers) - for url in asm_discovery_urls: # pragma: no branch - oauth_metadata_request = create_oauth_metadata_request(url) - oauth_metadata_response = yield oauth_metadata_request - - ok, asm = await handle_auth_metadata_response(oauth_metadata_response) - if not ok: - break - if ok and asm: - self.context.oauth_metadata = asm - break - else: - logger.debug(f"OAuth metadata discovery failed: {url}") - - # Step 3: Apply scope selection strategy - self.context.client_metadata.scope = get_client_metadata_scopes( - extract_scope_from_www_auth(response), - self.context.protected_resource_metadata, - self.context.oauth_metadata, - ) - - # Step 4: Register client or use URL-based client ID (CIMD) - if not self.context.client_info: - if should_use_client_metadata_url( - self.context.oauth_metadata, self.context.client_metadata_url - ): - # Use URL-based client ID (CIMD) - logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}") - client_information = create_client_info_from_metadata_url( - self.context.client_metadata_url, # type: ignore[arg-type] - redirect_uris=self.context.client_metadata.redirect_uris, - ) - self.context.client_info = client_information - await self.context.storage.set_client_info(client_information) - else: - # Fallback to Dynamic Client Registration - registration_request = create_client_registration_request( - self.context.oauth_metadata, - self.context.client_metadata, - 
self.context.get_authorization_base_url(self.context.server_url),
-                            )
-                            registration_response = yield registration_request
-                            client_information = await handle_registration_response(registration_response)
-                            self.context.client_info = client_information
-                            await self.context.storage.set_client_info(client_information)
-
-                    # Step 5: Perform authorization and complete token exchange
-                    token_response = yield await self._perform_authorization()
-                    await self._handle_token_response(token_response)
+                    oauth_gen = self._perform_oauth_discovery_and_auth(www_auth_response=response)
+                    oauth_request = await oauth_gen.__anext__()
+                    while True:
+                        oauth_response = yield oauth_request
+                        oauth_request = await oauth_gen.asend(oauth_response)
+                except StopAsyncIteration:
+                    pass
                 except Exception:  # pragma: no cover
                     logger.exception("OAuth flow error")
                     raise
diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py
index 5aa985e36..6789db72c 100644
--- a/tests/client/test_auth.py
+++ b/tests/client/test_auth.py
@@ -1210,6 +1210,152 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth(
         # Verify exactly one request was yielded (no double-sending)
         assert request_yields == 1, f"Expected 1 request yield, got {request_yields}"
 
+    @pytest.mark.anyio
+    async def test_eager_oauth_flow_avoids_unauthenticated_roundtrip(
+        self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage
+    ):
+        """Test that a known client without tokens runs OAuth discovery BEFORE the MCP request.
+
+        This is the fix for issue #1274: servers like Notion can be very slow (~10s)
+        when handling unauthenticated requests, so we perform eager OAuth discovery
+        to obtain tokens first, then send the MCP request with auth already attached.
+        """
+        # Ensure no tokens but with existing client_info — triggers the eager
+        # OAuth flow. The eager flow only activates when client_info is already
+        # present (i.e. we've been through the OAuth flow at least once before).
+        oauth_provider.context.current_tokens = None
+        oauth_provider.context.token_expiry_time = None
+        oauth_provider._initialized = True
+
+        # Simulate having previously registered (required for eager flow)
+        oauth_provider.context.client_info = OAuthClientInformationFull(
+            client_id="existing_client",
+            redirect_uris=[AnyUrl("http://localhost:3030/callback")],
+        )
+
+        test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
+        auth_flow = oauth_provider.async_auth_flow(test_request)
+
+        # First yielded request should be PRM discovery (path-based), NOT the MCP request.
+        # This is the key behavioral change for #1274.
+        first_request = await auth_flow.__anext__()
+        assert str(first_request.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
+        assert first_request.method == "GET"
+        assert "Authorization" not in first_request.headers
+
+        # PRM discovery returns 404 (path-based fails)
+        prm_response_1 = httpx.Response(404, request=first_request)
+
+        # Should fall back to root-based PRM discovery
+        prm_request_2 = await auth_flow.asend(prm_response_1)
+        assert str(prm_request_2.url) == "https://api.example.com/.well-known/oauth-protected-resource"
+
+        # Root PRM discovery succeeds
+        prm_response_2 = httpx.Response(
+            200,
+            content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}',
+            request=prm_request_2,
+        )
+
+        # Next: OAuth authorization server metadata discovery
+        asm_request = await auth_flow.asend(prm_response_2)
+        assert asm_request.method == "GET"
+        assert str(asm_request.url).startswith("https://auth.example.com/")
+
+        asm_response = httpx.Response(
+            200,
+            content=(
+                b'{"issuer": "https://auth.example.com", '
+                b'"authorization_endpoint": "https://auth.example.com/authorize", '
+                b'"token_endpoint": "https://auth.example.com/token", '
+                b'"registration_endpoint": "https://auth.example.com/register"}'
+            ),
+            request=asm_request,
+        )
+
+        # Since client_info is already set, DCR is skipped.
+        # Mock authorization code grant
+        oauth_provider._perform_authorization_code_grant = mock.AsyncMock(
+            return_value=("test_auth_code", "test_code_verifier")
+        )
+
+        # Next: Token exchange (skipping DCR because client_info exists)
+        token_request = await auth_flow.asend(asm_response)
+        assert token_request.method == "POST"
+        assert str(token_request.url) == "https://auth.example.com/token"
+
+        token_response = httpx.Response(
+            200,
+            content=b'{"access_token": "eager_token", "token_type": "Bearer", "expires_in": 3600}',
+            request=token_request,
+        )
+
+        # NOW the MCP request is sent — with the token from the eager flow
+        mcp_request = await auth_flow.asend(token_response)
+        assert mcp_request.headers["Authorization"] == "Bearer eager_token"
+        assert str(mcp_request.url) == "https://api.example.com/v1/mcp"
+
+        # Server returns 200 — no 401 round-trip needed, so the flow must end
+        # here. pytest.raises makes the test fail if the flow yields another
+        # request instead of finishing.
+        mcp_response = httpx.Response(200, request=mcp_request)
+        with pytest.raises(StopAsyncIteration):
+            await auth_flow.asend(mcp_response)
+
+        # Verify tokens were stored
+        assert oauth_provider.context.current_tokens is not None
+        assert oauth_provider.context.current_tokens.access_token == "eager_token"
+
+    @pytest.mark.anyio
+    async def test_eager_oauth_falls_back_on_error(
+        self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage
+    ):
+        """Test that when the eager OAuth flow fails, we gracefully fall through
+        to send the MCP request without auth and let the reactive 401 path handle it.
+
+        This covers scenarios like resource mismatch (evil PRM) or registration
+        failures where the eager path cannot complete.
+        """
+        # Ensure no tokens but with existing client_info — triggers the eager flow
+        oauth_provider.context.current_tokens = None
+        oauth_provider.context.token_expiry_time = None
+        oauth_provider._initialized = True
+
+        # Simulate having previously registered (required for eager flow)
+        oauth_provider.context.client_info = OAuthClientInformationFull(
+            client_id="existing_client",
+            redirect_uris=[AnyUrl("http://localhost:3030/callback")],
+        )
+
+        test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
+        auth_flow = oauth_provider.async_auth_flow(test_request)
+
+        # Eager flow starts with PRM discovery
+        prm_request = await auth_flow.__anext__()
+        assert ".well-known/oauth-protected-resource" in str(prm_request.url)
+
+        # PRM returns a mismatched resource — this triggers OAuthFlowError
+        # inside _validate_resource_match. The eager flow should catch this
+        # and fall through gracefully.
+        evil_prm_response = httpx.Response(
+            200,
+            content=b'{"resource": "https://evil.example.com/mcp", "authorization_servers": ["https://evil.example.com"]}',
+            request=prm_request,
+        )
+
+        # The next yielded request should be the original MCP request (without auth),
+        # because the eager flow caught the error and fell through.
+        mcp_request = await auth_flow.asend(evil_prm_response)
+        assert str(mcp_request.url) == "https://api.example.com/v1/mcp"
+        assert "Authorization" not in mcp_request.headers
+
+        # Server returns 200 (no 401) — flow completes. pytest.raises makes
+        # the test fail if the flow yields another request instead of
+        # finishing.
+        mcp_response = httpx.Response(200, request=mcp_request)
+        with pytest.raises(StopAsyncIteration):
+            await auth_flow.asend(mcp_response)
+
     @pytest.mark.anyio
     async def test_token_exchange_accepts_201_status(
         self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage
@@ -1511,8 +1657,10 @@ async def callback_handler() -> tuple[str, str | None]:
             callback_handler=callback_handler,
         )
 
-        provider.context.current_tokens = None
-        provider.context.token_expiry_time = None
+        # Provide a token that appears valid so the eager OAuth flow is skipped;
+        # the server will still return 401 to trigger the reactive path.
+ provider.context.current_tokens = OAuthToken(access_token="stale_token", token_type="Bearer") + provider.context.token_expiry_time = time.time() + 3600 provider._initialized = True provider.context.client_info = OAuthClientInformationFull( @@ -1621,7 +1771,8 @@ async def callback_handler() -> tuple[str, str | None]: test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") auth_flow = provider.async_auth_flow(test_request) - await auth_flow.__anext__() + request = await auth_flow.__anext__() + assert request.headers["Authorization"] == "Bearer stale_token" # 401 with custom WWW-Authenticate PRM URL response = httpx.Response( @@ -1749,9 +1900,10 @@ async def callback_handler() -> tuple[str, str | None]: callback_handler=callback_handler, ) - # Ensure no tokens are stored - provider.context.current_tokens = None - provider.context.token_expiry_time = None + # Provide a token that appears valid so the eager OAuth flow is skipped; + # the server will still return 401 to trigger the reactive path. 
+ provider.context.current_tokens = OAuthToken(access_token="stale_token", token_type="Bearer") + provider.context.token_expiry_time = time.time() + 3600 provider._initialized = True # Mock client info to skip DCR @@ -1766,9 +1918,9 @@ async def callback_handler() -> tuple[str, str | None]: # Mock the auth flow auth_flow = provider.async_auth_flow(test_request) - # First request should be the original request without auth header + # First request carries the stale token request = await auth_flow.__anext__() - assert "Authorization" not in request.headers + assert request.headers["Authorization"] == "Bearer stale_token" # Send a 401 response without WWW-Authenticate header response = httpx.Response(401, headers={}, request=test_request) diff --git a/tests/client/test_scope_bug_1630.py b/tests/client/test_scope_bug_1630.py index fafa51007..d02523d8c 100644 --- a/tests/client/test_scope_bug_1630.py +++ b/tests/client/test_scope_bug_1630.py @@ -4,6 +4,7 @@ in the WWW-Authenticate header, the actual scope is used (not the resource_metadata URL). """ +import time from unittest import mock import httpx @@ -64,8 +65,10 @@ async def callback_handler() -> tuple[str, str | None]: callback_handler=callback_handler, ) - provider.context.current_tokens = None - provider.context.token_expiry_time = None + # Provide a token that appears valid so the eager OAuth flow is skipped; + # the server will still return 401 to trigger the reactive path. 
+ provider.context.current_tokens = OAuthToken(access_token="stale_token", token_type="Bearer") + provider.context.token_expiry_time = time.time() + 3600 provider._initialized = True # Pre-set client info to skip DCR @@ -77,8 +80,9 @@ async def callback_handler() -> tuple[str, str | None]: test_request = httpx.Request("GET", "https://api.example.com/mcp") auth_flow = provider.async_auth_flow(test_request) - # First request (no auth header yet) - await auth_flow.__anext__() + # First request carries the stale token + request = await auth_flow.__anext__() + assert request.headers["Authorization"] == "Bearer stale_token" # 401 response with BOTH resource_metadata URL and scope in WWW-Authenticate # This is the key: the bug would use the URL as scope instead of "read write"