Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 131 additions & 83 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,102 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None
if not check_resource_allowed(requested_resource=default_resource, configured_resource=prm_resource):
raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")

async def _perform_oauth_discovery_and_auth(
self,
www_auth_response: httpx.Response | None = None,
) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""Perform the full OAuth discovery, registration, and authorization flow.

This is extracted as a helper to allow both eager (pre-request) and
reactive (post-401) OAuth flows to share the same implementation.

Args:
www_auth_response: Optional 401 response to extract WWW-Authenticate
header from for RFC 9728 resource_metadata discovery. When None,
falls back to well-known URL discovery.
"""
www_auth_resource_metadata_url = (
extract_resource_metadata_from_www_auth(www_auth_response) if www_auth_response else None
)

# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url, self.context.server_url
)

for url in prm_discovery_urls: # pragma: no branch
discovery_request = create_oauth_metadata_request(url)

discovery_response = yield discovery_request # sending request

prm = await handle_protected_resource_response(discovery_response)
if prm:
# Validate PRM resource matches server URL (RFC 8707)
await self._validate_resource_match(prm)
self.context.protected_resource_metadata = prm

# todo: try all authorization_servers to find the OASM
assert (
len(prm.authorization_servers) > 0
) # this is always true as authorization_servers has a min length of 1

self.context.auth_server_url = str(prm.authorization_servers[0])
break
else:
logger.debug(f"Protected resource metadata discovery failed: {url}")

asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
self.context.auth_server_url, self.context.server_url
)

# Step 2: Discover OAuth Authorization Server Metadata (OASM)
for url in asm_discovery_urls: # pragma: no branch
oauth_metadata_request = create_oauth_metadata_request(url)
oauth_metadata_response = yield oauth_metadata_request

ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
if not ok:
break
if ok and asm:
self.context.oauth_metadata = asm
break
else:
logger.debug(f"OAuth metadata discovery failed: {url}")

# Step 3: Apply scope selection strategy
self.context.client_metadata.scope = get_client_metadata_scopes(
extract_scope_from_www_auth(www_auth_response) if www_auth_response else None,
self.context.protected_resource_metadata,
self.context.oauth_metadata,
)

# Step 4: Register client or use URL-based client ID (CIMD)
if not self.context.client_info:
if should_use_client_metadata_url(self.context.oauth_metadata, self.context.client_metadata_url):
# Use URL-based client ID (CIMD)
logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}")
client_information = create_client_info_from_metadata_url(
self.context.client_metadata_url, # type: ignore[arg-type]
redirect_uris=self.context.client_metadata.redirect_uris,
)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)
else:
# Fallback to Dynamic Client Registration
registration_request = create_client_registration_request(
self.context.oauth_metadata,
self.context.client_metadata,
self.context.get_authorization_base_url(self.context.server_url),
)
registration_response = yield registration_request
client_information = await handle_registration_response(registration_response)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)

# Step 5: Perform authorization and complete token exchange
token_response = yield await self._perform_authorization()
await self._handle_token_response(token_response)

async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
async with self.context.lock:
Expand All @@ -514,96 +610,48 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Refresh failed, need full re-authentication
self._initialized = False

# Eager OAuth: if we have no valid token, can't refresh, AND we have
# already been through the OAuth flow at least once (we have
# client_info and discovery metadata), re-run the discovery/auth flow
# BEFORE sending the MCP request. This avoids the unnecessary
# unauthenticated round-trip that some servers (e.g. Notion) handle
# slowly, causing ~10 s latency per request. See #1274.
#
# On the very first connection (no client_info), we skip the eager
# flow and let the reactive 401 path handle discovery, because the
# server's WWW-Authenticate header may carry routing information
# (e.g. resource_metadata URL) that pure well-known discovery lacks.
#
# If the eager flow fails, we fall through gracefully and send the
# MCP request without auth so the reactive 401 path can take over.
if not self.context.is_token_valid() and self.context.client_info:
try:
oauth_gen = self._perform_oauth_discovery_and_auth()
oauth_request = await oauth_gen.__anext__()
while True:
oauth_response = yield oauth_request
oauth_request = await oauth_gen.asend(oauth_response)
except StopAsyncIteration:
pass
except Exception:
logger.debug("Eager OAuth discovery failed, falling back to reactive 401 path", exc_info=True)

if self.context.is_token_valid():
self._add_auth_header(request)

response = yield request

if response.status_code == 401:
# Perform full OAuth flow
# Perform full OAuth flow (reactive path — uses WWW-Authenticate
# header from the 401 response for RFC 9728 discovery)
try:
# OAuth flow must be inline due to generator constraints
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)

# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url, self.context.server_url
)

for url in prm_discovery_urls: # pragma: no branch
discovery_request = create_oauth_metadata_request(url)

discovery_response = yield discovery_request # sending request

prm = await handle_protected_resource_response(discovery_response)
if prm:
# Validate PRM resource matches server URL (RFC 8707)
await self._validate_resource_match(prm)
self.context.protected_resource_metadata = prm

# todo: try all authorization_servers to find the OASM
assert (
len(prm.authorization_servers) > 0
) # this is always true as authorization_servers has a min length of 1

self.context.auth_server_url = str(prm.authorization_servers[0])
break
else:
logger.debug(f"Protected resource metadata discovery failed: {url}")

asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
self.context.auth_server_url, self.context.server_url
)

# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
for url in asm_discovery_urls: # pragma: no branch
oauth_metadata_request = create_oauth_metadata_request(url)
oauth_metadata_response = yield oauth_metadata_request

ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
if not ok:
break
if ok and asm:
self.context.oauth_metadata = asm
break
else:
logger.debug(f"OAuth metadata discovery failed: {url}")

# Step 3: Apply scope selection strategy
self.context.client_metadata.scope = get_client_metadata_scopes(
extract_scope_from_www_auth(response),
self.context.protected_resource_metadata,
self.context.oauth_metadata,
)

# Step 4: Register client or use URL-based client ID (CIMD)
if not self.context.client_info:
if should_use_client_metadata_url(
self.context.oauth_metadata, self.context.client_metadata_url
):
# Use URL-based client ID (CIMD)
logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}")
client_information = create_client_info_from_metadata_url(
self.context.client_metadata_url, # type: ignore[arg-type]
redirect_uris=self.context.client_metadata.redirect_uris,
)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)
else:
# Fallback to Dynamic Client Registration
registration_request = create_client_registration_request(
self.context.oauth_metadata,
self.context.client_metadata,
self.context.get_authorization_base_url(self.context.server_url),
)
registration_response = yield registration_request
client_information = await handle_registration_response(registration_response)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)

# Step 5: Perform authorization and complete token exchange
token_response = yield await self._perform_authorization()
await self._handle_token_response(token_response)
oauth_gen = self._perform_oauth_discovery_and_auth(www_auth_response=response)
oauth_request = await oauth_gen.__anext__()
while True:
oauth_response = yield oauth_request
oauth_request = await oauth_gen.asend(oauth_response)
except StopAsyncIteration:
pass
except Exception: # pragma: no cover
logger.exception("OAuth flow error")
raise
Expand Down
Loading
Loading