Skip to content

Commit efe9d40

Browse files
authored
Token scopes context (#1997)
* Move scope storage into its own context key, separately from token info. This allows us to provide scopes seperately in the remote server, where we have scopes before we do the auth. * Skip token extraction if token info already exists in context. This is to avoid redundant token extraction in remote setup where token info may have already been extracted earlier in the request lifecycle. * Check for existing scopes in context before fetching from GitHub API in scope challenge middleware * Return error type for unknown tools in inventory builder and handle it in HTTP handler
1 parent d44894e commit efe9d40

File tree

7 files changed

+79
-34
lines changed

7 files changed

+79
-34
lines changed

pkg/context/token.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,37 @@ import (
66
"github.com/github/github-mcp-server/pkg/utils"
77
)
88

9-
// tokenCtxKey is a context key for authentication token information
10-
type tokenCtx string
11-
12-
var tokenCtxKey tokenCtx = "tokenctx"
9+
type tokenCtxKey struct{}
1310

1411
type TokenInfo struct {
15-
Token string
16-
TokenType utils.TokenType
17-
ScopesFetched bool
18-
Scopes []string
12+
Token string
13+
TokenType utils.TokenType
1914
}
2015

2116
// WithTokenInfo adds TokenInfo to the context
2217
func WithTokenInfo(ctx context.Context, tokenInfo *TokenInfo) context.Context {
23-
return context.WithValue(ctx, tokenCtxKey, tokenInfo)
18+
return context.WithValue(ctx, tokenCtxKey{}, tokenInfo)
2419
}
2520

2621
// GetTokenInfo retrieves the authentication token from the context
2722
func GetTokenInfo(ctx context.Context) (*TokenInfo, bool) {
28-
if tokenInfo, ok := ctx.Value(tokenCtxKey).(*TokenInfo); ok {
23+
if tokenInfo, ok := ctx.Value(tokenCtxKey{}).(*TokenInfo); ok {
2924
return tokenInfo, true
3025
}
3126
return nil, false
3227
}
28+
29+
type tokenScopesKey struct{}
30+
31+
// WithTokenScopes adds token scopes to the context
32+
func WithTokenScopes(ctx context.Context, scopes []string) context.Context {
33+
return context.WithValue(ctx, tokenScopesKey{}, scopes)
34+
}
35+
36+
// GetTokenScopes retrieves token scopes from the context
37+
func GetTokenScopes(ctx context.Context) ([]string, bool) {
38+
if scopes, ok := ctx.Value(tokenScopesKey{}).([]string); ok {
39+
return scopes, true
40+
}
41+
return nil, false
42+
}

pkg/http/handler.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package http
22

33
import (
44
"context"
5+
"errors"
56
"log/slog"
67
"net/http"
78

@@ -178,6 +179,14 @@ func withInsiders(next http.Handler) http.Handler {
178179
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
179180
inv, err := h.inventoryFactoryFunc(r)
180181
if err != nil {
182+
if errors.Is(err, inventory.ErrUnknownTools) {
183+
w.WriteHeader(http.StatusBadRequest)
184+
if _, writeErr := w.Write([]byte(err.Error())); writeErr != nil {
185+
h.logger.Error("failed to write response", "error", writeErr)
186+
}
187+
return
188+
}
189+
181190
w.WriteHeader(http.StatusInternalServerError)
182191
return
183192
}
@@ -278,8 +287,10 @@ func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.Fetche
278287
// Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header.
279288
// Fine-grained PATs and other token types don't support this, so we skip filtering.
280289
if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken {
281-
if tokenInfo.ScopesFetched {
282-
return b.WithFilter(github.CreateToolScopeFilter(tokenInfo.Scopes))
290+
// Check if scopes are already in context (should be set by WithPATScopes). If not, fetch them.
291+
existingScopes, ok := ghcontext.GetTokenScopes(ctx)
292+
if ok {
293+
return b.WithFilter(github.CreateToolScopeFilter(existingScopes))
283294
}
284295

285296
scopesList, err := fetcher.FetchTokenScopes(ctx, tokenInfo.Token)

pkg/http/middleware/pat_scope.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,22 @@ func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) fu
2626
// Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header.
2727
// Fine-grained PATs and other token types don't support this, so we skip filtering.
2828
if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken {
29+
existingScopes, ok := ghcontext.GetTokenScopes(ctx)
30+
if ok {
31+
logger.Debug("using existing scopes from context", "scopes", existingScopes)
32+
next.ServeHTTP(w, r)
33+
return
34+
}
35+
2936
scopesList, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token)
3037
if err != nil {
3138
logger.Warn("failed to fetch PAT scopes", "error", err)
3239
next.ServeHTTP(w, r)
3340
return
3441
}
3542

36-
tokenInfo.Scopes = scopesList
37-
tokenInfo.ScopesFetched = true
38-
3943
// Store fetched scopes in context for downstream use
40-
ctx := ghcontext.WithTokenInfo(ctx, tokenInfo)
44+
ctx = ghcontext.WithTokenScopes(ctx, scopesList)
4145

4246
next.ServeHTTP(w, r.WithContext(ctx))
4347
return

pkg/http/middleware/pat_scope_test.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,13 @@ func TestWithPATScopes(t *testing.T) {
111111

112112
for _, tt := range tests {
113113
t.Run(tt.name, func(t *testing.T) {
114-
var capturedTokenInfo *ghcontext.TokenInfo
114+
var capturedScopes []string
115+
var scopesFound bool
115116
var nextHandlerCalled bool
116117

117118
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
118119
nextHandlerCalled = true
119-
capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context())
120+
capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context())
120121
w.WriteHeader(http.StatusOK)
121122
})
122123

@@ -141,10 +142,9 @@ func TestWithPATScopes(t *testing.T) {
141142

142143
assert.Equal(t, tt.expectNextHandlerCalled, nextHandlerCalled, "next handler called mismatch")
143144

144-
if tt.expectNextHandlerCalled && tt.tokenInfo != nil {
145-
require.NotNil(t, capturedTokenInfo, "expected token info in context")
146-
assert.Equal(t, tt.expectScopesFetched, capturedTokenInfo.ScopesFetched)
147-
assert.Equal(t, tt.expectedScopes, capturedTokenInfo.Scopes)
145+
if tt.expectNextHandlerCalled {
146+
assert.Equal(t, tt.expectScopesFetched, scopesFound, "scopes found mismatch")
147+
assert.Equal(t, tt.expectedScopes, capturedScopes)
148148
}
149149
})
150150
}
@@ -154,9 +154,12 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) {
154154
logger := slog.Default()
155155

156156
var capturedTokenInfo *ghcontext.TokenInfo
157+
var capturedScopes []string
158+
var scopesFound bool
157159

158160
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
159161
capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context())
162+
capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context())
160163
w.WriteHeader(http.StatusOK)
161164
})
162165

@@ -182,6 +185,6 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) {
182185
require.NotNil(t, capturedTokenInfo)
183186
assert.Equal(t, originalTokenInfo.Token, capturedTokenInfo.Token)
184187
assert.Equal(t, originalTokenInfo.TokenType, capturedTokenInfo.TokenType)
185-
assert.True(t, capturedTokenInfo.ScopesFetched)
186-
assert.Equal(t, []string{"repo", "user"}, capturedTokenInfo.Scopes)
188+
assert.True(t, scopesFound)
189+
assert.Equal(t, []string{"repo", "user"}, capturedScopes)
187190
}

pkg/http/middleware/scope_challenge.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,19 @@ func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInter
9494
return
9595
}
9696

97-
// Get OAuth scopes from GitHub API
98-
activeScopes, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token)
99-
if err != nil {
100-
next.ServeHTTP(w, r)
101-
return
97+
// Get OAuth scopes for Token. First check if scopes are already in context, then fetch from GitHub if not present.
98+
// This allows Remote Server to pass scope info to avoid redundant GitHub API calls.
99+
activeScopes, ok := ghcontext.GetTokenScopes(ctx)
100+
if !ok || (len(activeScopes) == 0 && tokenInfo.Token != "") {
101+
activeScopes, err = scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token)
102+
if err != nil {
103+
next.ServeHTTP(w, r)
104+
return
105+
}
102106
}
103107

104108
// Store active scopes in context for downstream use
105-
tokenInfo.Scopes = activeScopes
106-
tokenInfo.ScopesFetched = true
107-
ctx = ghcontext.WithTokenInfo(ctx, tokenInfo)
109+
ctx = ghcontext.WithTokenScopes(ctx, activeScopes)
108110
r = r.WithContext(ctx)
109111

110112
// Check if user has the required scopes

pkg/http/middleware/token.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,16 @@ import (
1313
func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler {
1414
return func(next http.Handler) http.Handler {
1515
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
16+
ctx := r.Context()
17+
18+
// Check if token info already exists in context, if it does, skip extraction.
19+
// In remote setup, we may have already extracted token info earlier.
20+
if _, ok := ghcontext.GetTokenInfo(ctx); ok {
21+
// Token info already exists in context, skip extraction
22+
next.ServeHTTP(w, r)
23+
return
24+
}
25+
1626
tokenType, token, err := utils.ParseAuthorizationHeader(r)
1727
if err != nil {
1828
// For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec
@@ -25,7 +35,6 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl
2535
return
2636
}
2737

28-
ctx := r.Context()
2938
ctx = ghcontext.WithTokenInfo(ctx, &ghcontext.TokenInfo{
3039
Token: token,
3140
TokenType: tokenType,

pkg/inventory/builder.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,18 @@ package inventory
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"maps"
78
"slices"
89
"strings"
910
)
1011

12+
var (
13+
// ErrUnknownTools is returned when tools specified via WithTools() are not recognized.
14+
ErrUnknownTools = errors.New("unknown tools specified in WithTools")
15+
)
16+
1117
// ToolFilter is a function that determines if a tool should be included.
1218
// Returns true if the tool should be included, false to exclude it.
1319
type ToolFilter func(ctx context.Context, tool *ServerTool) (bool, error)
@@ -219,7 +225,7 @@ func (b *Builder) Build() (*Inventory, error) {
219225

220226
// Error out if there are unrecognized tools
221227
if len(unrecognizedTools) > 0 {
222-
return nil, fmt.Errorf("unrecognized tools: %s", strings.Join(unrecognizedTools, ", "))
228+
return nil, fmt.Errorf("%w: %s", ErrUnknownTools, strings.Join(unrecognizedTools, ", "))
223229
}
224230
}
225231

0 commit comments

Comments
 (0)