Skip to content

Commit f20722e

Browse files
authored
Merge branch 'main' into add-project-status-updates
2 parents 3e0de48 + efe9d40 commit f20722e

File tree

7 files changed

+79
-34
lines changed

7 files changed

+79
-34
lines changed

pkg/context/token.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,37 @@ import (
66
"github.com/github/github-mcp-server/pkg/utils"
77
)
88

9-
// tokenCtxKey is a context key for authentication token information
10-
type tokenCtx string
11-
12-
var tokenCtxKey tokenCtx = "tokenctx"
9+
type tokenCtxKey struct{}
1310

1411
type TokenInfo struct {
15-
Token string
16-
TokenType utils.TokenType
17-
ScopesFetched bool
18-
Scopes []string
12+
Token string
13+
TokenType utils.TokenType
1914
}
2015

2116
// WithTokenInfo adds TokenInfo to the context
2217
func WithTokenInfo(ctx context.Context, tokenInfo *TokenInfo) context.Context {
23-
return context.WithValue(ctx, tokenCtxKey, tokenInfo)
18+
return context.WithValue(ctx, tokenCtxKey{}, tokenInfo)
2419
}
2520

2621
// GetTokenInfo retrieves the authentication token from the context
2722
func GetTokenInfo(ctx context.Context) (*TokenInfo, bool) {
28-
if tokenInfo, ok := ctx.Value(tokenCtxKey).(*TokenInfo); ok {
23+
if tokenInfo, ok := ctx.Value(tokenCtxKey{}).(*TokenInfo); ok {
2924
return tokenInfo, true
3025
}
3126
return nil, false
3227
}
28+
29+
type tokenScopesKey struct{}
30+
31+
// WithTokenScopes adds token scopes to the context
32+
func WithTokenScopes(ctx context.Context, scopes []string) context.Context {
33+
return context.WithValue(ctx, tokenScopesKey{}, scopes)
34+
}
35+
36+
// GetTokenScopes retrieves token scopes from the context
37+
func GetTokenScopes(ctx context.Context) ([]string, bool) {
38+
if scopes, ok := ctx.Value(tokenScopesKey{}).([]string); ok {
39+
return scopes, true
40+
}
41+
return nil, false
42+
}

pkg/http/handler.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package http
22

33
import (
44
"context"
5+
"errors"
56
"log/slog"
67
"net/http"
78

@@ -178,6 +179,14 @@ func withInsiders(next http.Handler) http.Handler {
178179
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
179180
inv, err := h.inventoryFactoryFunc(r)
180181
if err != nil {
182+
if errors.Is(err, inventory.ErrUnknownTools) {
183+
w.WriteHeader(http.StatusBadRequest)
184+
if _, writeErr := w.Write([]byte(err.Error())); writeErr != nil {
185+
h.logger.Error("failed to write response", "error", writeErr)
186+
}
187+
return
188+
}
189+
181190
w.WriteHeader(http.StatusInternalServerError)
182191
return
183192
}
@@ -278,8 +287,10 @@ func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.Fetche
278287
// Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header.
279288
// Fine-grained PATs and other token types don't support this, so we skip filtering.
280289
if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken {
281-
if tokenInfo.ScopesFetched {
282-
return b.WithFilter(github.CreateToolScopeFilter(tokenInfo.Scopes))
290+
// Check if scopes are already in context (should be set by WithPATScopes). If not, fetch them.
291+
existingScopes, ok := ghcontext.GetTokenScopes(ctx)
292+
if ok {
293+
return b.WithFilter(github.CreateToolScopeFilter(existingScopes))
283294
}
284295

285296
scopesList, err := fetcher.FetchTokenScopes(ctx, tokenInfo.Token)

pkg/http/middleware/pat_scope.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,22 @@ func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) fu
2626
// Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header.
2727
// Fine-grained PATs and other token types don't support this, so we skip filtering.
2828
if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken {
29+
existingScopes, ok := ghcontext.GetTokenScopes(ctx)
30+
if ok {
31+
logger.Debug("using existing scopes from context", "scopes", existingScopes)
32+
next.ServeHTTP(w, r)
33+
return
34+
}
35+
2936
scopesList, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token)
3037
if err != nil {
3138
logger.Warn("failed to fetch PAT scopes", "error", err)
3239
next.ServeHTTP(w, r)
3340
return
3441
}
3542

36-
tokenInfo.Scopes = scopesList
37-
tokenInfo.ScopesFetched = true
38-
3943
// Store fetched scopes in context for downstream use
40-
ctx := ghcontext.WithTokenInfo(ctx, tokenInfo)
44+
ctx = ghcontext.WithTokenScopes(ctx, scopesList)
4145

4246
next.ServeHTTP(w, r.WithContext(ctx))
4347
return

pkg/http/middleware/pat_scope_test.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,13 @@ func TestWithPATScopes(t *testing.T) {
111111

112112
for _, tt := range tests {
113113
t.Run(tt.name, func(t *testing.T) {
114-
var capturedTokenInfo *ghcontext.TokenInfo
114+
var capturedScopes []string
115+
var scopesFound bool
115116
var nextHandlerCalled bool
116117

117118
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
118119
nextHandlerCalled = true
119-
capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context())
120+
capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context())
120121
w.WriteHeader(http.StatusOK)
121122
})
122123

@@ -141,10 +142,9 @@ func TestWithPATScopes(t *testing.T) {
141142

142143
assert.Equal(t, tt.expectNextHandlerCalled, nextHandlerCalled, "next handler called mismatch")
143144

144-
if tt.expectNextHandlerCalled && tt.tokenInfo != nil {
145-
require.NotNil(t, capturedTokenInfo, "expected token info in context")
146-
assert.Equal(t, tt.expectScopesFetched, capturedTokenInfo.ScopesFetched)
147-
assert.Equal(t, tt.expectedScopes, capturedTokenInfo.Scopes)
145+
if tt.expectNextHandlerCalled {
146+
assert.Equal(t, tt.expectScopesFetched, scopesFound, "scopes found mismatch")
147+
assert.Equal(t, tt.expectedScopes, capturedScopes)
148148
}
149149
})
150150
}
@@ -154,9 +154,12 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) {
154154
logger := slog.Default()
155155

156156
var capturedTokenInfo *ghcontext.TokenInfo
157+
var capturedScopes []string
158+
var scopesFound bool
157159

158160
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
159161
capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context())
162+
capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context())
160163
w.WriteHeader(http.StatusOK)
161164
})
162165

@@ -182,6 +185,6 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) {
182185
require.NotNil(t, capturedTokenInfo)
183186
assert.Equal(t, originalTokenInfo.Token, capturedTokenInfo.Token)
184187
assert.Equal(t, originalTokenInfo.TokenType, capturedTokenInfo.TokenType)
185-
assert.True(t, capturedTokenInfo.ScopesFetched)
186-
assert.Equal(t, []string{"repo", "user"}, capturedTokenInfo.Scopes)
188+
assert.True(t, scopesFound)
189+
assert.Equal(t, []string{"repo", "user"}, capturedScopes)
187190
}

pkg/http/middleware/scope_challenge.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,19 @@ func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInter
9494
return
9595
}
9696

97-
// Get OAuth scopes from GitHub API
98-
activeScopes, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token)
99-
if err != nil {
100-
next.ServeHTTP(w, r)
101-
return
97+
// Get OAuth scopes for Token. First check if scopes are already in context, then fetch from GitHub if not present.
98+
// This allows Remote Server to pass scope info to avoid redundant GitHub API calls.
99+
activeScopes, ok := ghcontext.GetTokenScopes(ctx)
100+
if !ok || (len(activeScopes) == 0 && tokenInfo.Token != "") {
101+
activeScopes, err = scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token)
102+
if err != nil {
103+
next.ServeHTTP(w, r)
104+
return
105+
}
102106
}
103107

104108
// Store active scopes in context for downstream use
105-
tokenInfo.Scopes = activeScopes
106-
tokenInfo.ScopesFetched = true
107-
ctx = ghcontext.WithTokenInfo(ctx, tokenInfo)
109+
ctx = ghcontext.WithTokenScopes(ctx, activeScopes)
108110
r = r.WithContext(ctx)
109111

110112
// Check if user has the required scopes

pkg/http/middleware/token.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,16 @@ import (
1313
func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler {
1414
return func(next http.Handler) http.Handler {
1515
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
16+
ctx := r.Context()
17+
18+
// Check if token info already exists in context, if it does, skip extraction.
19+
// In remote setup, we may have already extracted token info earlier.
20+
if _, ok := ghcontext.GetTokenInfo(ctx); ok {
21+
// Token info already exists in context, skip extraction
22+
next.ServeHTTP(w, r)
23+
return
24+
}
25+
1626
tokenType, token, err := utils.ParseAuthorizationHeader(r)
1727
if err != nil {
1828
// For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec
@@ -25,7 +35,6 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl
2535
return
2636
}
2737

28-
ctx := r.Context()
2938
ctx = ghcontext.WithTokenInfo(ctx, &ghcontext.TokenInfo{
3039
Token: token,
3140
TokenType: tokenType,

pkg/inventory/builder.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,18 @@ package inventory
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"maps"
78
"slices"
89
"strings"
910
)
1011

12+
var (
13+
// ErrUnknownTools is returned when tools specified via WithTools() are not recognized.
14+
ErrUnknownTools = errors.New("unknown tools specified in WithTools")
15+
)
16+
1117
// ToolFilter is a function that determines if a tool should be included.
1218
// Returns true if the tool should be included, false to exclude it.
1319
type ToolFilter func(ctx context.Context, tool *ServerTool) (bool, error)
@@ -219,7 +225,7 @@ func (b *Builder) Build() (*Inventory, error) {
219225

220226
// Error out if there are unrecognized tools
221227
if len(unrecognizedTools) > 0 {
222-
return nil, fmt.Errorf("unrecognized tools: %s", strings.Join(unrecognizedTools, ", "))
228+
return nil, fmt.Errorf("%w: %s", ErrUnknownTools, strings.Join(unrecognizedTools, ", "))
223229
}
224230
}
225231

0 commit comments

Comments
 (0)