diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 20ba2711b..ba987e63c 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -109,6 +109,7 @@ var ( ContentWindowSize: viper.GetInt("content-window-size"), LockdownMode: viper.GetBool("lockdown-mode"), RepoAccessCacheTTL: &ttl, + ScopeChallenge: viper.GetBool("scope-challenge"), } return ghhttp.RunHTTPServer(httpConfig) @@ -141,6 +142,7 @@ func init() { httpCmd.Flags().Int("port", 8082, "HTTP server port") httpCmd.Flags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") httpCmd.Flags().String("base-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)") + httpCmd.Flags().Bool("scope-challenge", false, "Enable OAuth scope challenge responses and tool filtering based on token scopes") // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) @@ -159,7 +161,7 @@ func init() { _ = viper.BindPFlag("port", httpCmd.Flags().Lookup("port")) _ = viper.BindPFlag("base-url", httpCmd.Flags().Lookup("base-url")) _ = viper.BindPFlag("base-path", httpCmd.Flags().Lookup("base-path")) - + _ = viper.BindPFlag("scope-challenge", httpCmd.Flags().Lookup("scope-challenge")) // Add subcommands rootCmd.AddCommand(stdioCmd) rootCmd.AddCommand(httpCmd) diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index bb0cc277b..7ffb457ce 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -366,14 +366,7 @@ func fetchTokenScopesForHost(ctx context.Context, token, host string) ([]string, return nil, fmt.Errorf("failed to parse API host: %w", err) } - baseRestURL, err := apiHost.BaseRESTURL(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get base REST URL: %w", err) - } - - fetcher := scopes.NewFetcher(scopes.FetcherOptions{ - APIHost: baseRestURL.String(), - }) + fetcher := scopes.NewFetcher(apiHost, scopes.FetcherOptions{}) return fetcher.FetchTokenScopes(ctx, token) } diff --git a/pkg/context/mcp_info.go b/pkg/context/mcp_info.go new file mode 100644 index 000000000..ce5505682 --- /dev/null +++ b/pkg/context/mcp_info.go @@ -0,0 +1,39 @@ +package context + +import "context" + +type mcpMethodInfoCtx string + +var mcpMethodInfoCtxKey mcpMethodInfoCtx = "mcpmethodinfo" + +// MCPMethodInfo contains pre-parsed MCP method information extracted from the JSON-RPC request. +// This is populated early in the request lifecycle to enable: +// - Inventory filtering via ForMCPRequest (only register needed tools/resources/prompts) +// - Avoiding duplicate JSON parsing in middlewares (secret-scanning, scope-challenge) +// - Performance optimization for per-request server creation +type MCPMethodInfo struct { + // Method is the MCP method being called (e.g., "tools/call", "tools/list", "initialize") + Method string + // ItemName is the name of the specific item being accessed (tool name, resource URI, prompt name) + // Only populated for call/get methods (tools/call, prompts/get, resources/read) + ItemName string + // Owner is the repository owner from tool call arguments, if present + Owner string + // Repo is the repository name from tool call arguments, if present + Repo string + // Arguments contains the raw tool arguments for tools/call requests + Arguments map[string]any +} + +// WithMCPMethodInfo stores the MCPMethodInfo in the context. +func WithMCPMethodInfo(ctx context.Context, info *MCPMethodInfo) context.Context { + return context.WithValue(ctx, mcpMethodInfoCtxKey, info) +} + +// MCPMethod retrieves the MCPMethodInfo from the context. +func MCPMethod(ctx context.Context) (*MCPMethodInfo, bool) { + if info, ok := ctx.Value(mcpMethodInfoCtxKey).(*MCPMethodInfo); ok { + return info, true + } + return nil, false +} diff --git a/pkg/context/token.go b/pkg/context/token.go index dd303f029..27f276740 100644 --- a/pkg/context/token.go +++ b/pkg/context/token.go @@ -1,19 +1,39 @@ package context -import "context" +import ( + "context" + + "github.com/github/github-mcp-server/pkg/utils" +) // tokenCtxKey is a context key for authentication token information -type tokenCtxKey struct{} +type tokenCtx string + +var tokenCtxKey tokenCtx = "tokenctx" + +type TokenInfo struct { + Token string + TokenType utils.TokenType + ScopesFetched bool + Scopes []string +} // WithTokenInfo adds TokenInfo to the context -func WithTokenInfo(ctx context.Context, token string) context.Context { - return context.WithValue(ctx, tokenCtxKey{}, token) +func WithTokenInfo(ctx context.Context, tokenInfo *TokenInfo) context.Context { + return context.WithValue(ctx, tokenCtxKey, tokenInfo) +} + +func SetTokenScopes(ctx context.Context, scopes []string) { + if tokenInfo, ok := GetTokenInfo(ctx); ok { + tokenInfo.Scopes = scopes + tokenInfo.ScopesFetched = true + } } // GetTokenInfo retrieves the authentication token from the context -func GetTokenInfo(ctx context.Context) (string, bool) { - if token, ok := ctx.Value(tokenCtxKey{}).(string); ok { - return token, true +func GetTokenInfo(ctx context.Context) (*TokenInfo, bool) { + if tokenInfo, ok := ctx.Value(tokenCtxKey).(*TokenInfo); ok { + return tokenInfo, true } - return "", false + return nil, false } diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index 75804ad1f..8d656d0bd 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -282,7 +282,11 @@ func (d *RequestDeps) GetClient(ctx context.Context) (*gogithub.Client, error) { } // extract the token from the context - token, _ := ghcontext.GetTokenInfo(ctx) + tokenInfo, ok := ghcontext.GetTokenInfo(ctx) + if !ok { + return nil, fmt.Errorf("no token info in context") + } + token := tokenInfo.Token baseRestURL, err := d.apiHosts.BaseRESTURL(ctx) if err != nil { @@ -308,7 +312,11 @@ func (d *RequestDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error } // extract the token from the context - token, _ := ghcontext.GetTokenInfo(ctx) + tokenInfo, ok := ghcontext.GetTokenInfo(ctx) + if !ok { + return nil, fmt.Errorf("no token info in context") + } + token := tokenInfo.Token // Construct GraphQL client // We use NewEnterpriseClient unconditionally since we already parsed the API host diff --git a/pkg/github/server.go b/pkg/github/server.go index fddd85123..203dcabbd 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -8,6 +8,7 @@ import ( "strings" "time" + ghcontext "github.com/github/github-mcp-server/pkg/context" gherrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/octicons" @@ -73,10 +74,10 @@ type MCPServerConfig struct { type MCPServerOption func(*mcp.ServerOptions) -func NewMCPServer(ctx context.Context, cfg *MCPServerConfig, deps ToolDependencies, inventory *inventory.Inventory) (*mcp.Server, error) { +func NewMCPServer(ctx context.Context, cfg *MCPServerConfig, deps ToolDependencies, inv *inventory.Inventory) (*mcp.Server, error) { // Create the MCP server serverOpts := &mcp.ServerOptions{ - Instructions: inventory.Instructions(), + Instructions: inv.Instructions(), Logger: cfg.Logger, CompletionHandler: CompletionsHandler(deps.GetClient), } @@ -102,20 +103,25 @@ func NewMCPServer(ctx context.Context, cfg *MCPServerConfig, deps ToolDependenci ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) ghServer.AddReceivingMiddleware(InjectDepsMiddleware(deps)) - if unrecognized := inventory.UnrecognizedToolsets(); len(unrecognized) > 0 { + if unrecognized := inv.UnrecognizedToolsets(); len(unrecognized) > 0 { cfg.Logger.Warn("Warning: unrecognized toolsets ignored", "toolsets", strings.Join(unrecognized, ", ")) } + invToUse := inv + if methodInfo, ok := ghcontext.MCPMethod(ctx); ok && methodInfo != nil { + invToUse = inv.ForMCPRequest(methodInfo.Method, methodInfo.ItemName) + } + // Register GitHub tools/resources/prompts from the inventory. // In dynamic mode with no explicit toolsets, this is a no-op since enabledToolsets // is empty - users enable toolsets at runtime via the dynamic tools below (but can // enable toolsets or tools explicitly that do need registration). - inventory.RegisterAll(ctx, ghServer, deps) + invToUse.RegisterAll(ctx, ghServer, deps) // Register dynamic toolset management tools (enable/disable) - these are separate // meta-tools that control the inventory, not part of the inventory itself if cfg.DynamicToolsets { - registerDynamicTools(ghServer, inventory, deps, cfg.Translator) + registerDynamicTools(ghServer, invToUse, deps, cfg.Translator) } return ghServer, nil diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 9bb98b86b..c529f7405 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -10,7 +10,9 @@ import ( "github.com/github/github-mcp-server/pkg/http/middleware" "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/scopes" "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" "github.com/go-chi/chi/v5" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -23,21 +25,30 @@ type Handler struct { config *ServerConfig deps github.ToolDependencies logger *slog.Logger + apiHosts utils.APIHostResolver t translations.TranslationHelperFunc githubMcpServerFactory GitHubMCPServerFactoryFunc inventoryFactoryFunc InventoryFactoryFunc oauthCfg *oauth.Config + scopeFetcher scopes.FetcherInterface } type HandlerOptions struct { GitHubMcpServerFactory GitHubMCPServerFactoryFunc InventoryFactory InventoryFactoryFunc OAuthConfig *oauth.Config + ScopeFetcher scopes.FetcherInterface FeatureChecker inventory.FeatureFlagChecker } type HandlerOption func(*HandlerOptions) +func WithScopeFetcher(f scopes.FetcherInterface) HandlerOption { + return func(o *HandlerOptions) { + o.ScopeFetcher = f + } +} + func WithGitHubMCPServerFactory(f GitHubMCPServerFactoryFunc) HandlerOption { return func(o *HandlerOptions) { o.GitHubMcpServerFactory = f @@ -68,6 +79,7 @@ func NewHTTPMcpHandler( deps github.ToolDependencies, t translations.TranslationHelperFunc, logger *slog.Logger, + apiHost utils.APIHostResolver, options ...HandlerOption) *Handler { opts := &HandlerOptions{} for _, o := range options { @@ -79,9 +91,14 @@ func NewHTTPMcpHandler( githubMcpServerFactory = DefaultGitHubMCPServerFactory } + scopeFetcher := opts.ScopeFetcher + if scopeFetcher == nil { + scopeFetcher = scopes.NewFetcher(apiHost, scopes.FetcherOptions{}) + } + inventoryFactory := opts.InventoryFactory if inventoryFactory == nil { - inventoryFactory = DefaultInventoryFactory(cfg, t, opts.FeatureChecker) + inventoryFactory = DefaultInventoryFactory(cfg, t, opts.FeatureChecker, scopeFetcher) } return &Handler{ @@ -89,10 +106,12 @@ func NewHTTPMcpHandler( config: cfg, deps: deps, logger: logger, + apiHosts: apiHost, t: t, githubMcpServerFactory: githubMcpServerFactory, inventoryFactoryFunc: inventoryFactory, oauthCfg: opts.OAuthConfig, + scopeFetcher: scopeFetcher, } } @@ -100,7 +119,12 @@ func (h *Handler) RegisterMiddleware(r chi.Router) { r.Use( middleware.ExtractUserToken(h.oauthCfg), middleware.WithRequestConfig, + middleware.WithMCPParse(), ) + + if h.config.ScopeChallenge { + r.Use(middleware.WithScopeChallenge(h.oauthCfg, h.scopeFetcher)) + } } // RegisterRoutes registers the routes for the MCP server @@ -177,13 +201,15 @@ func DefaultGitHubMCPServerFactory(r *http.Request, deps github.ToolDependencies } // DefaultInventoryFactory creates the default inventory factory for HTTP mode -func DefaultInventoryFactory(_ *ServerConfig, t translations.TranslationHelperFunc, featureChecker inventory.FeatureFlagChecker) InventoryFactoryFunc { +func DefaultInventoryFactory(_ *ServerConfig, t translations.TranslationHelperFunc, featureChecker inventory.FeatureFlagChecker, scopeFetcher scopes.FetcherInterface) InventoryFactoryFunc { return func(r *http.Request) (*inventory.Inventory, error) { b := github.NewInventory(t). WithDeprecatedAliases(github.DeprecatedToolAliases). WithFeatureChecker(featureChecker) b = InventoryFiltersForRequest(r, b) + b = PATScopeFilter(b, r, scopeFetcher) + b.WithServerInstructions() return b.Build() @@ -212,3 +238,29 @@ func InventoryFiltersForRequest(r *http.Request, builder *inventory.Builder) *in return builder } + +func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.FetcherInterface) *inventory.Builder { + ctx := r.Context() + + tokenInfo, ok := ghcontext.GetTokenInfo(ctx) + if !ok || tokenInfo == nil { + return b + } + + // Fetch token scopes for scope-based tool filtering (PAT tokens only) + // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. + // Fine-grained PATs and other token types don't support this, so we skip filtering. + if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken { + scopesList, err := fetcher.FetchTokenScopes(ctx, tokenInfo.Token) + if err != nil { + return b + } + + // Store fetched scopes in context for downstream use + ghcontext.SetTokenScopes(ctx, scopesList) + + return b.WithFilter(github.CreateToolScopeFilter(scopesList)) + } + + return b +} diff --git a/pkg/http/handler_test.go b/pkg/http/handler_test.go index 70258436c..c92075569 100644 --- a/pkg/http/handler_test.go +++ b/pkg/http/handler_test.go @@ -11,9 +11,10 @@ import ( ghcontext "github.com/github/github-mcp-server/pkg/context" "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/http/headers" - "github.com/github/github-mcp-server/pkg/http/middleware" "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/scopes" "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" "github.com/go-chi/chi/v5" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/assert" @@ -33,6 +34,20 @@ func mockTool(name, toolsetID string, readOnly bool) inventory.ServerTool { } } +type allScopesFetcher struct{} + +func (f allScopesFetcher) FetchTokenScopes(_ context.Context, _ string) ([]string, error) { + return []string{ + string(scopes.Repo), + string(scopes.WriteOrg), + string(scopes.User), + string(scopes.Gist), + string(scopes.Notifications), + }, nil +} + +var _ scopes.FetcherInterface = allScopesFetcher{} + func mockToolWithFeatureFlag(name, toolsetID string, readOnly bool, enableFlag, disableFlag string) inventory.ServerTool { tool := mockTool(name, toolsetID, readOnly) tool.FeatureFlagEnable = enableFlag @@ -261,6 +276,9 @@ func TestHTTPHandlerRoutes(t *testing.T) { // Create feature checker that reads from context (same as production) featureChecker := createHTTPFeatureChecker() + apiHost, err := utils.NewAPIHost("https://api.github.com") + require.NoError(t, err) + // Create inventory factory that captures the built inventory inventoryFactory := func(r *http.Request) (*inventory.Inventory, error) { capturedCtx = r.Context() @@ -282,6 +300,8 @@ func TestHTTPHandlerRoutes(t *testing.T) { return mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil), nil } + allScopesFetcher := allScopesFetcher{} + // Create handler with our factories handler := NewHTTPMcpHandler( context.Background(), @@ -289,17 +309,23 @@ func TestHTTPHandlerRoutes(t *testing.T) { nil, // deps not needed for this test translations.NullTranslationHelper, slog.Default(), + apiHost, WithInventoryFactory(inventoryFactory), WithGitHubMCPServerFactory(mcpServerFactory), + WithScopeFetcher(allScopesFetcher), ) // Create router and register routes r := chi.NewRouter() - r.Use(middleware.WithRequestConfig) + handler.RegisterMiddleware(r) handler.RegisterRoutes(r) // Create request req := httptest.NewRequest(http.MethodPost, tt.path, nil) + + // Ensure we're setting Authorization header for token context + req.Header.Set(headers.AuthorizationHeader, "Bearer ghp_testtoken") + for k, v := range tt.headers { req.Header.Set(k, v) } diff --git a/pkg/http/middleware/mcp_parse.go b/pkg/http/middleware/mcp_parse.go new file mode 100644 index 000000000..c82616b27 --- /dev/null +++ b/pkg/http/middleware/mcp_parse.go @@ -0,0 +1,126 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + + ghcontext "github.com/github/github-mcp-server/pkg/context" +) + +// mcpJSONRPCRequest represents the structure of an MCP JSON-RPC request. +// We only parse the fields needed for routing and optimization. +type mcpJSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params struct { + // For tools/call + Name string `json:"name,omitempty"` + Arguments json.RawMessage `json:"arguments,omitempty"` + // For prompts/get + // Name is shared with tools/call + // For resources/read + URI string `json:"uri,omitempty"` + } `json:"params"` +} + +// WithMCPParse creates a middleware that parses MCP JSON-RPC requests early in the +// request lifecycle and stores the parsed information in the request context. +// This enables: +// - Registry filtering via ForMCPRequest (only register needed tools/resources/prompts) +// - Avoiding duplicate JSON parsing in downstream middlewares +// - Access to owner/repo for secret-scanning middleware +// +// The middleware reads the request body, parses it, restores the body for downstream +// handlers, and stores the parsed MCPMethodInfo in the request context. +func WithMCPParse() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Skip health check endpoints + if r.URL.Path == "/_ping" { + next.ServeHTTP(w, r) + return + } + + // Only parse POST requests (MCP uses JSON-RPC over POST) + if r.Method != http.MethodPost { + next.ServeHTTP(w, r) + return + } + + // Read the request body + body, err := io.ReadAll(r.Body) + if err != nil { + // Log but continue - don't block requests on parse errors + next.ServeHTTP(w, r) + return + } + + // Restore the body for downstream handlers + r.Body = io.NopCloser(bytes.NewReader(body)) + + // Skip empty bodies + if len(body) == 0 { + next.ServeHTTP(w, r) + return + } + + // Parse the JSON-RPC request + var mcpReq mcpJSONRPCRequest + err = json.Unmarshal(body, &mcpReq) + if err != nil { + // Log but continue - could be a non-MCP request or malformed JSON + next.ServeHTTP(w, r) + return + } + + // Skip if not a valid JSON-RPC 2.0 request + if mcpReq.JSONRPC != "2.0" || mcpReq.Method == "" { + next.ServeHTTP(w, r) + return + } + + // Build the MCPMethodInfo + methodInfo := &ghcontext.MCPMethodInfo{ + Method: mcpReq.Method, + } + + // Extract item name based on method type + + switch mcpReq.Method { + case "tools/call": + methodInfo.ItemName = mcpReq.Params.Name + // Parse arguments if present + if len(mcpReq.Params.Arguments) > 0 { + var args map[string]any + err := json.Unmarshal(mcpReq.Params.Arguments, &args) + if err == nil { + methodInfo.Arguments = args + // Extract owner and repo if present + if owner, ok := args["owner"].(string); ok { + methodInfo.Owner = owner + } + if repo, ok := args["repo"].(string); ok { + methodInfo.Repo = repo + } + } + } + case "prompts/get": + methodInfo.ItemName = mcpReq.Params.Name + case "resources/read": + methodInfo.ItemName = mcpReq.Params.URI + default: + // Whatever + } + + // Store the parsed info in context + ctx = ghcontext.WithMCPMethodInfo(ctx, methodInfo) + + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) + } +} diff --git a/pkg/http/middleware/scope_challenge.go b/pkg/http/middleware/scope_challenge.go new file mode 100644 index 000000000..da2f06752 --- /dev/null +++ b/pkg/http/middleware/scope_challenge.go @@ -0,0 +1,140 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/oauth" + "github.com/github/github-mcp-server/pkg/scopes" + "github.com/github/github-mcp-server/pkg/utils" +) + +// WithScopeChallenge creates a new middleware that determines if an OAuth request contains sufficient scopes to +// complete the request and returns a scope challenge if not. +func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInterface) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Skip health check endpoints + if r.URL.Path == "/_ping" { + next.ServeHTTP(w, r) + return + } + + // Get user from context + tokenInfo, ok := ghcontext.GetTokenInfo(ctx) + if !ok { + next.ServeHTTP(w, r) + return + } + + // Only check OAuth tokens - scope challenge allows OAuth apps to request additional scopes + if tokenInfo.TokenType != utils.TokenTypeOAuthAccessToken { + next.ServeHTTP(w, r) + return + } + + // Try to use pre-parsed MCP method info first (performance optimization) + // This avoids re-parsing the JSON body if WithMCPParse middleware ran earlier + var toolName string + if methodInfo, ok := ghcontext.MCPMethod(ctx); ok && methodInfo != nil { + // Only check tools/call requests + if methodInfo.Method != "tools/call" { + next.ServeHTTP(w, r) + return + } + toolName = methodInfo.ItemName + } else { + // Fallback: parse the request body directly + body, err := io.ReadAll(r.Body) + if err != nil { + next.ServeHTTP(w, r) + return + } + r.Body = io.NopCloser(bytes.NewReader(body)) + + var mcpRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params struct { + Name string `json:"name,omitempty"` + Arguments map[string]any `json:"arguments,omitempty"` + } `json:"params"` + } + + err = json.Unmarshal(body, &mcpRequest) + if err != nil { + next.ServeHTTP(w, r) + return + } + + // Only check tools/call requests + if mcpRequest.Method != "tools/call" { + next.ServeHTTP(w, r) + return + } + + toolName = mcpRequest.Params.Name + } + toolScopeInfo, err := scopes.GetToolScopeInfo(toolName) + if err != nil { + next.ServeHTTP(w, r) + return + } + + // If tool not found in scope map, allow the request + if toolScopeInfo == nil { + next.ServeHTTP(w, r) + return + } + + // Get OAuth scopes from GitHub API + activeScopes, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token) + if err != nil { + next.ServeHTTP(w, r) + return + } + + // Store active scopes in context for downstream use + ghcontext.SetTokenScopes(ctx, activeScopes) + + // Check if user has the required scopes + if toolScopeInfo.HasAcceptedScope(activeScopes...) { + next.ServeHTTP(w, r) + return + } + + // User lacks required scopes - get the scopes they need + requiredScopes := toolScopeInfo.GetRequiredScopesSlice() + + // Build the resource metadata URL using the shared utility + // GetEffectiveResourcePath returns the original path (e.g., /mcp or /mcp/x/all) + // which is used to construct the well-known OAuth protected resource URL + resourcePath := oauth.ResolveResourcePath(r, oauthCfg) + resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, resourcePath) + + // Build recommended scopes: existing scopes + required scopes + recommendedScopes := make([]string, 0, len(activeScopes)+len(requiredScopes)) + recommendedScopes = append(recommendedScopes, activeScopes...) + recommendedScopes = append(recommendedScopes, requiredScopes...) + + // Build the WWW-Authenticate header value + wwwAuthenticateHeader := fmt.Sprintf(`Bearer error="insufficient_scope", scope=%q, resource_metadata=%q, error_description=%q`, + strings.Join(recommendedScopes, " "), + resourceMetadataURL, + "Additional scopes required: "+strings.Join(requiredScopes, ", "), + ) + + // Send scope challenge response with the superset of existing and required scopes + w.Header().Set("WWW-Authenticate", wwwAuthenticateHeader) + http.Error(w, "Forbidden: insufficient scopes", http.StatusForbidden) + } + return http.HandlerFunc(fn) + } +} diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go index 26973a548..c362ea201 100644 --- a/pkg/http/middleware/token.go +++ b/pkg/http/middleware/token.go @@ -4,49 +4,19 @@ import ( "errors" "fmt" "net/http" - "regexp" - "strings" ghcontext "github.com/github/github-mcp-server/pkg/context" - httpheaders "github.com/github/github-mcp-server/pkg/http/headers" - "github.com/github/github-mcp-server/pkg/http/mark" "github.com/github/github-mcp-server/pkg/http/oauth" + "github.com/github/github-mcp-server/pkg/utils" ) -type authType int - -const ( - authTypeUnknown authType = iota - authTypeIDE - authTypeGhToken -) - -var ( - errMissingAuthorizationHeader = fmt.Errorf("%w: missing required Authorization header", mark.ErrBadRequest) - errBadAuthorizationHeader = fmt.Errorf("%w: Authorization header is badly formatted", mark.ErrBadRequest) - errUnsupportedAuthorizationHeader = fmt.Errorf("%w: unsupported Authorization header", mark.ErrBadRequest) -) - -var supportedThirdPartyTokenPrefixes = []string{ - "ghp_", // Personal access token (classic) - "github_pat_", // Fine-grained personal access token - "gho_", // OAuth access token - "ghu_", // User access token for a GitHub App - "ghs_", // Installation access token for a GitHub App (a.k.a. server-to-server token) -} - -// oldPatternRegexp is the regular expression for the old pattern of the token. -// Until 2021, GitHub API tokens did not have an identifiable prefix. They -// were 40 characters long and only contained the characters a-f and 0-9. -var oldPatternRegexp = regexp.MustCompile(`\A[a-f0-9]{40}\z`) - func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, token, err := parseAuthorizationHeader(r) + tokenType, token, err := utils.ParseAuthorizationHeader(r) if err != nil { // For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec - if errors.Is(err, errMissingAuthorizationHeader) { + if errors.Is(err, utils.ErrMissingAuthorizationHeader) { sendAuthChallenge(w, r, oauthCfg) return } @@ -56,7 +26,10 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl } ctx := r.Context() - ctx = ghcontext.WithTokenInfo(ctx, token) + ctx = ghcontext.WithTokenInfo(ctx, &ghcontext.TokenInfo{ + Token: token, + TokenType: tokenType, + }) r = r.WithContext(ctx) next.ServeHTTP(w, r) @@ -72,42 +45,3 @@ func sendAuthChallenge(w http.ResponseWriter, r *http.Request, oauthCfg *oauth.C w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata=%q`, resourceMetadataURL)) http.Error(w, "Unauthorized", http.StatusUnauthorized) } - -func parseAuthorizationHeader(req *http.Request) (authType authType, token string, _ error) { - authHeader := req.Header.Get(httpheaders.AuthorizationHeader) - if authHeader == "" { - return 0, "", errMissingAuthorizationHeader - } - - switch { - // decrypt dotcom token and set it as token - case strings.HasPrefix(authHeader, "GitHub-Bearer "): - return 0, "", errUnsupportedAuthorizationHeader - default: - // support both "Bearer" and "bearer" to conform to api.github.com - if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") { - token = authHeader[7:] - } else { - token = authHeader - } - } - - // Do a naïve check for a colon in the token - currently, only the IDE token has a colon in it. - // ex: tid=1;exp=25145314523;chat=1: - if strings.Contains(token, ":") { - return authTypeIDE, token, nil - } - - for _, prefix := range supportedThirdPartyTokenPrefixes { - if strings.HasPrefix(token, prefix) { - return authTypeGhToken, token, nil - } - } - - matchesOldTokenPattern := oldPatternRegexp.MatchString(token) - if matchesOldTokenPattern { - return authTypeGhToken, token, nil - } - - return 0, "", errBadAuthorizationHeader -} diff --git a/pkg/http/server.go b/pkg/http/server.go index c2aad4c61..7a7ab46de 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -17,6 +17,7 @@ import ( "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/scopes" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/go-chi/chi/v5" @@ -65,6 +66,10 @@ type ServerConfig struct { // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. RepoAccessCacheTTL *time.Duration + + // ScopeChallenge indicates if we should return OAuth scope challenges, and if we should perform + // tool filtering based on token scopes. + ScopeChallenge bool } func RunHTTPServer(cfg ServerConfig) error { @@ -114,29 +119,42 @@ func RunHTTPServer(cfg ServerConfig) error { featureChecker, ) - r := chi.NewRouter() + // Initialize the global tool scope map + err = initGlobalToolScopeMap(t) + if err != nil { + return fmt.Errorf("failed to initialize tool scope map: %w", err) + } // Register OAuth protected resource metadata endpoints oauthCfg := &oauth.Config{ BaseURL: cfg.BaseURL, ResourcePath: cfg.ResourcePath, } + + serverOptions := []HandlerOption{} + if cfg.ScopeChallenge { + scopeFetcher := scopes.NewFetcher(apiHost, scopes.FetcherOptions{}) + serverOptions = append(serverOptions, WithScopeFetcher(scopeFetcher)) + } + + r := chi.NewRouter() + handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(serverOptions, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg))...) oauthHandler, err := oauth.NewAuthHandler(oauthCfg) if err != nil { return fmt.Errorf("failed to create OAuth handler: %w", err) } - handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg)) - - // MCP routes with middleware r.Group(func(r chi.Router) { + // Register Middleware First, needs to be before route registration handler.RegisterMiddleware(r) + + // Register MCP server routes handler.RegisterRoutes(r) }) logger.Info("MCP endpoints registered", "baseURL", cfg.BaseURL) - // OAuth routes without MCP middleware r.Group(func(r chi.Router) { + // Register OAuth protected resource metadata endpoints oauthHandler.RegisterRoutes(r) }) logger.Info("OAuth protected resource endpoints registered", "baseURL", cfg.BaseURL) @@ -172,6 +190,22 @@ func RunHTTPServer(cfg ServerConfig) error { return nil } +func initGlobalToolScopeMap(t translations.TranslationHelperFunc) error { + // Build inventory with all tools to extract scope information + inv, err := inventory.NewBuilder(). + SetTools(github.AllTools(t)). + Build() + + if err != nil { + return fmt.Errorf("failed to build inventory for tool scope map: %w", err) + } + + // Initialize the global scope map + scopes.SetToolScopeMapFromInventory(inv) + + return nil +} + // createHTTPFeatureChecker creates a feature checker that reads header features from context // and validates them against the knownFeatureFlags whitelist func createHTTPFeatureChecker() inventory.FeatureFlagChecker { diff --git a/pkg/scopes/fetcher.go b/pkg/scopes/fetcher.go index 48e000179..458eaf7b7 100644 --- a/pkg/scopes/fetcher.go +++ b/pkg/scopes/fetcher.go @@ -7,6 +7,8 @@ import ( "net/url" "strings" "time" + + "github.com/github/github-mcp-server/pkg/utils" ) // OAuthScopesHeader is the HTTP response header containing the token's OAuth scopes. @@ -23,28 +25,27 @@ type FetcherOptions struct { // APIHost is the GitHub API host (e.g., "https://api.github.com"). // Defaults to "https://api.github.com" if empty. - APIHost string + APIHost utils.APIHostResolver +} + +type FetcherInterface interface { + FetchTokenScopes(ctx context.Context, token string) ([]string, error) } // Fetcher retrieves token scopes from GitHub's API. // It uses an HTTP HEAD request to minimize bandwidth since we only need headers. type Fetcher struct { client *http.Client - apiHost string + apiHost utils.APIHostResolver } // NewFetcher creates a new scope fetcher with the given options. -func NewFetcher(opts FetcherOptions) *Fetcher { +func NewFetcher(apiHost utils.APIHostResolver, opts FetcherOptions) *Fetcher { client := opts.HTTPClient if client == nil { client = &http.Client{Timeout: DefaultFetchTimeout} } - apiHost := opts.APIHost - if apiHost == "" { - apiHost = "https://api.github.com" - } - return &Fetcher{ client: client, apiHost: apiHost, @@ -61,8 +62,13 @@ func NewFetcher(opts FetcherOptions) *Fetcher { // Note: Fine-grained PATs don't return the X-OAuth-Scopes header, so an empty // slice is returned for those tokens. func (f *Fetcher) FetchTokenScopes(ctx context.Context, token string) ([]string, error) { + apiHostURL, err := f.apiHost.BaseRESTURL(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get API host URL: %w", err) + } + // Use a lightweight endpoint that requires authentication - endpoint, err := url.JoinPath(f.apiHost, "/") + endpoint, err := url.JoinPath(apiHostURL.String(), "/") if err != nil { return nil, fmt.Errorf("failed to construct API URL: %w", err) } @@ -115,11 +121,16 @@ func ParseScopeHeader(header string) []string { // FetchTokenScopes is a convenience function that creates a default fetcher // and fetches the token scopes. func FetchTokenScopes(ctx context.Context, token string) ([]string, error) { - return NewFetcher(FetcherOptions{}).FetchTokenScopes(ctx, token) + apiHost, err := utils.NewAPIHost("https://api.github.com/") + if err != nil { + return nil, fmt.Errorf("failed to create default API host: %w", err) + } + + return NewFetcher(apiHost, FetcherOptions{}).FetchTokenScopes(ctx, token) } // FetchTokenScopesWithHost is a convenience function that creates a fetcher // for a specific API host and fetches the token scopes. -func FetchTokenScopesWithHost(ctx context.Context, token, apiHost string) ([]string, error) { - return NewFetcher(FetcherOptions{APIHost: apiHost}).FetchTokenScopes(ctx, token) +func FetchTokenScopesWithHost(ctx context.Context, token string, apiHost utils.APIHostResolver) ([]string, error) { + return NewFetcher(apiHost, FetcherOptions{}).FetchTokenScopes(ctx, token) } diff --git a/pkg/scopes/fetcher_test.go b/pkg/scopes/fetcher_test.go index 13feab5b0..2d887d7a8 100644 --- a/pkg/scopes/fetcher_test.go +++ b/pkg/scopes/fetcher_test.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/http/httptest" + "net/url" "testing" "time" @@ -11,6 +12,23 @@ import ( "github.com/stretchr/testify/require" ) +type testAPIHostResolver struct { + baseURL string +} + +func (t testAPIHostResolver) BaseRESTURL(_ context.Context) (*url.URL, error) { + return url.Parse(t.baseURL) +} +func (t testAPIHostResolver) GraphqlURL(_ context.Context) (*url.URL, error) { + return nil, nil +} +func (t testAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) { + return nil, nil +} +func (t testAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) { + return nil, nil +} + func TestParseScopeHeader(t *testing.T) { tests := []struct { name string @@ -146,10 +164,8 @@ func TestFetcher_FetchTokenScopes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { server := httptest.NewServer(tt.handler) defer server.Close() - - fetcher := NewFetcher(FetcherOptions{ - APIHost: server.URL, - }) + apiHost := testAPIHostResolver{baseURL: server.URL} + fetcher := NewFetcher(apiHost, FetcherOptions{}) scopes, err := fetcher.FetchTokenScopes(context.Background(), "test-token") @@ -167,10 +183,13 @@ func TestFetcher_FetchTokenScopes(t *testing.T) { } func TestFetcher_DefaultOptions(t *testing.T) { - fetcher := NewFetcher(FetcherOptions{}) + apiHost := testAPIHostResolver{baseURL: "https://api.github.com"} + fetcher := NewFetcher(apiHost, FetcherOptions{}) // Verify default API host is set - assert.Equal(t, "https://api.github.com", fetcher.apiHost) + apiURL, err := fetcher.apiHost.BaseRESTURL(context.Background()) + require.NoError(t, err) + assert.Equal(t, "https://api.github.com", apiURL.String()) // Verify default HTTP client is set with timeout assert.NotNil(t, fetcher.client) @@ -180,7 +199,8 @@ func TestFetcher_DefaultOptions(t *testing.T) { func TestFetcher_CustomHTTPClient(t *testing.T) { customClient := &http.Client{Timeout: 5 * time.Second} - fetcher := NewFetcher(FetcherOptions{ + apiHost := testAPIHostResolver{baseURL: "https://api.github.com"} + fetcher := NewFetcher(apiHost, FetcherOptions{ HTTPClient: customClient, }) @@ -188,11 +208,12 @@ func TestFetcher_CustomHTTPClient(t *testing.T) { } func TestFetcher_CustomAPIHost(t *testing.T) { - fetcher := NewFetcher(FetcherOptions{ - APIHost: "https://api.github.enterprise.com", - }) + apiHost := testAPIHostResolver{baseURL: "https://api.github.enterprise.com"} + fetcher := NewFetcher(apiHost, FetcherOptions{}) - assert.Equal(t, "https://api.github.enterprise.com", fetcher.apiHost) + apiURL, err := fetcher.apiHost.BaseRESTURL(context.Background()) + require.NoError(t, err) + assert.Equal(t, "https://api.github.enterprise.com", apiURL.String()) } func TestFetcher_ContextCancellation(t *testing.T) { @@ -202,9 +223,8 @@ func TestFetcher_ContextCancellation(t *testing.T) { })) defer server.Close() - fetcher := NewFetcher(FetcherOptions{ - APIHost: server.URL, - }) + apiHost := testAPIHostResolver{baseURL: server.URL} + fetcher := NewFetcher(apiHost, FetcherOptions{}) ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately diff --git a/pkg/scopes/map.go b/pkg/scopes/map.go new file mode 100644 index 000000000..3c9833834 --- /dev/null +++ b/pkg/scopes/map.go @@ -0,0 +1,129 @@ +package scopes + +import "github.com/github/github-mcp-server/pkg/inventory" + +// ToolScopeMap maps tool names to their scope requirements. +type ToolScopeMap map[string]*ToolScopeInfo + +// ToolScopeInfo contains scope information for a single tool. +type ToolScopeInfo struct { + // RequiredScopes contains the scopes that are directly required by this tool. + RequiredScopes []string + + // AcceptedScopes contains all scopes that satisfy the requirements (including parent scopes). + AcceptedScopes []string +} + +// globalToolScopeMap is populated from inventory when SetToolScopeMapFromInventory is called +var globalToolScopeMap ToolScopeMap + +// SetToolScopeMapFromInventory builds and stores a tool scope map from an inventory. +// This should be called after building the inventory to make scopes available for middleware. +func SetToolScopeMapFromInventory(inv *inventory.Inventory) { + globalToolScopeMap = GetToolScopeMapFromInventory(inv) +} + +// SetGlobalToolScopeMap sets the global tool scope map directly. +// This is useful for testing when you don't have a full inventory. +func SetGlobalToolScopeMap(m ToolScopeMap) { + globalToolScopeMap = m +} + +// GetToolScopeMap returns the global tool scope map. +// Returns an empty map if SetToolScopeMapFromInventory hasn't been called yet. +func GetToolScopeMap() (ToolScopeMap, error) { + if globalToolScopeMap == nil { + return make(ToolScopeMap), nil + } + return globalToolScopeMap, nil +} + +// GetToolScopeInfo returns scope information for a specific tool from the global scope map. +func GetToolScopeInfo(toolName string) (*ToolScopeInfo, error) { + m, err := GetToolScopeMap() + if err != nil { + return nil, err + } + return m[toolName], nil +} + +// GetToolScopeMapFromInventory builds a tool scope map from an inventory. +// This extracts scope information from ServerTool.RequiredScopes and ServerTool.AcceptedScopes. +func GetToolScopeMapFromInventory(inv *inventory.Inventory) ToolScopeMap { + result := make(ToolScopeMap) + + // Get all tools from the inventory (both enabled and disabled) + // We need all tools for scope checking purposes + allTools := inv.AllTools() + for i := range allTools { + tool := &allTools[i] + if len(tool.RequiredScopes) > 0 || len(tool.AcceptedScopes) > 0 { + result[tool.Tool.Name] = &ToolScopeInfo{ + RequiredScopes: tool.RequiredScopes, + AcceptedScopes: tool.AcceptedScopes, + } + } + } + + return result +} + +// HasAcceptedScope checks if any of the provided user scopes satisfy the tool's requirements. +func (t *ToolScopeInfo) HasAcceptedScope(userScopes ...string) bool { + if t == nil || len(t.AcceptedScopes) == 0 { + return true // No scopes required + } + + userScopeSet := make(map[string]bool) + for _, scope := range userScopes { + userScopeSet[scope] = true + } + + for _, scope := range t.AcceptedScopes { + if userScopeSet[scope] { + return true + } + } + return false +} + +// MissingScopes returns the required scopes that are not present in the user's scopes. +func (t *ToolScopeInfo) MissingScopes(userScopes ...string) []string { + if t == nil || len(t.RequiredScopes) == 0 { + return nil + } + + // Create a set of user scopes for O(1) lookup + userScopeSet := make(map[string]bool, len(userScopes)) + for _, s := range userScopes { + userScopeSet[s] = true + } + + // Check if any accepted scope is present + hasAccepted := false + for _, scope := range t.AcceptedScopes { + if userScopeSet[scope] { + hasAccepted = true + break + } + } + + if hasAccepted { + return nil // User has sufficient scopes + } + + // Return required scopes as the minimum needed + missing := make([]string, len(t.RequiredScopes)) + copy(missing, t.RequiredScopes) + return missing +} + +// GetRequiredScopesSlice returns the required scopes as a slice of strings. +func (t *ToolScopeInfo) GetRequiredScopesSlice() []string { + if t == nil { + return nil + } + scopes := make([]string, len(t.RequiredScopes)) + copy(scopes, t.RequiredScopes) + return scopes +} diff --git a/pkg/scopes/map_test.go b/pkg/scopes/map_test.go new file mode 100644 index 000000000..5f33cdda2 --- /dev/null +++ b/pkg/scopes/map_test.go @@ -0,0 +1,194 @@ +package scopes + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetToolScopeMap(t *testing.T) { + // Reset and set up a test map + SetGlobalToolScopeMap(ToolScopeMap{ + "test_tool": &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + }) + + m, err := GetToolScopeMap() + require.NoError(t, err) + require.NotNil(t, m) + require.Greater(t, len(m), 0, "expected at least one tool in the scope map") + + testTool, ok := m["test_tool"] + require.True(t, ok, "expected test_tool to be in the scope map") + assert.Contains(t, testTool.RequiredScopes, "read:org") + assert.Contains(t, testTool.AcceptedScopes, "read:org") + assert.Contains(t, testTool.AcceptedScopes, "admin:org") +} + +func TestGetToolScopeInfo(t *testing.T) { + // Set up test scope map + SetGlobalToolScopeMap(ToolScopeMap{ + "search_orgs": &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + }) + + info, err := GetToolScopeInfo("search_orgs") + require.NoError(t, err) + require.NotNil(t, info) + + // Non-existent tool should return nil + info, err = GetToolScopeInfo("nonexistent_tool") + require.NoError(t, err) + assert.Nil(t, info) +} + +func TestToolScopeInfo_HasAcceptedScope(t *testing.T) { + testCases := []struct { + name string + scopeInfo *ToolScopeInfo + userScopes []string + expected bool + }{ + { + name: "has exact required scope", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"read:org"}, + expected: true, + }, + { + name: "has parent scope (admin:org grants read:org)", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"admin:org"}, + expected: true, + }, + { + name: "has parent scope (write:org grants read:org)", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"write:org"}, + expected: true, + }, + { + name: "missing required scope", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"repo"}, + expected: false, + }, + { + name: "no scope required", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{}, + AcceptedScopes: []string{}, + }, + userScopes: []string{}, + expected: true, + }, + { + name: "nil scope info", + scopeInfo: nil, + userScopes: []string{}, + expected: true, + }, + { + name: "repo scope for tool requiring repo", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"repo"}, + AcceptedScopes: []string{"repo"}, + }, + userScopes: []string{"repo"}, + expected: true, + }, + { + name: "missing repo scope", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"repo"}, + AcceptedScopes: []string{"repo"}, + }, + userScopes: []string{"public_repo"}, + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.scopeInfo.HasAcceptedScope(tc.userScopes...) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestToolScopeInfo_MissingScopes(t *testing.T) { + testCases := []struct { + name string + scopeInfo *ToolScopeInfo + userScopes []string + expectedLen int + expectedScopes []string + }{ + { + name: "has required scope - no missing", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"read:org"}, + expectedLen: 0, + expectedScopes: nil, + }, + { + name: "missing scope", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"repo"}, + expectedLen: 1, + expectedScopes: []string{"read:org"}, + }, + { + name: "no scope required - no missing", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{}, + AcceptedScopes: []string{}, + }, + userScopes: []string{}, + expectedLen: 0, + expectedScopes: nil, + }, + { + name: "nil scope info - no missing", + scopeInfo: nil, + userScopes: []string{}, + expectedLen: 0, + expectedScopes: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + missing := tc.scopeInfo.MissingScopes(tc.userScopes...) + assert.Len(t, missing, tc.expectedLen) + if tc.expectedScopes != nil { + for _, expected := range tc.expectedScopes { + assert.Contains(t, missing, expected) + } + } + }) + } +} diff --git a/pkg/utils/api.go b/pkg/utils/api.go index 4a33f1dd2..24abf7342 100644 --- a/pkg/utils/api.go +++ b/pkg/utils/api.go @@ -1,4 +1,4 @@ -package utils +package utils //nolint:revive //TODO: figure out a better name for this package import ( "context" diff --git a/pkg/utils/token.go b/pkg/utils/token.go new file mode 100644 index 000000000..fa3423942 --- /dev/null +++ b/pkg/utils/token.go @@ -0,0 +1,82 @@ +package utils //nolint:revive //TODO: figure out a better name for this package + +import ( + "fmt" + "net/http" + "regexp" + "strings" + + httpheaders "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/http/mark" +) + +type TokenType int + +const ( + TokenTypeUnknown TokenType = iota + TokenTypePersonalAccessToken + TokenTypeFineGrainedPersonalAccessToken + TokenTypeOAuthAccessToken + TokenTypeUserToServerGitHubAppToken + TokenTypeServerToServerGitHubAppToken + TokenTypeIDEToken +) + +var supportedThirdPartyTokenPrefixes = map[string]TokenType{ + "ghp_": TokenTypePersonalAccessToken, // Personal access token (classic) + "github_pat_": TokenTypeFineGrainedPersonalAccessToken, // Fine-grained personal access token + "gho_": TokenTypeOAuthAccessToken, // OAuth access token + "ghu_": TokenTypeUserToServerGitHubAppToken, // User access token for a GitHub App + "ghs_": TokenTypeServerToServerGitHubAppToken, // Installation access token for a GitHub App (a.k.a. server-to-server token) +} + +var ( + ErrMissingAuthorizationHeader = fmt.Errorf("%w: missing required Authorization header", mark.ErrBadRequest) + ErrBadAuthorizationHeader = fmt.Errorf("%w: Authorization header is badly formatted", mark.ErrBadRequest) + ErrUnsupportedAuthorizationHeader = fmt.Errorf("%w: unsupported Authorization header", mark.ErrBadRequest) +) + +// oldPatternRegexp is the regular expression for the old pattern of the token. +// Until 2021, GitHub API tokens did not have an identifiable prefix. They +// were 40 characters long and only contained the characters a-f and 0-9. +var oldPatternRegexp = regexp.MustCompile(`\A[a-f0-9]{40}\z`) + +// ParseAuthorizationHeader parses the Authorization header from the HTTP request +func ParseAuthorizationHeader(req *http.Request) (tokenType TokenType, token string, _ error) { + authHeader := req.Header.Get(httpheaders.AuthorizationHeader) + if authHeader == "" { + return 0, "", ErrMissingAuthorizationHeader + } + + switch { + // decrypt dotcom token and set it as token + case strings.HasPrefix(authHeader, "GitHub-Bearer "): + return 0, "", ErrUnsupportedAuthorizationHeader + default: + // support both "Bearer" and "bearer" to conform to api.github.com + if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") { + token = authHeader[7:] + } else { + token = authHeader + } + } + + // Do a naïve check for a colon in the token - currently, only the IDE token has a colon in it. + // ex: tid=1;exp=25145314523;chat=1: + if strings.Contains(token, ":") { + return TokenTypeIDEToken, token, nil + } + + for prefix, tokenType := range supportedThirdPartyTokenPrefixes { + if strings.HasPrefix(token, prefix) { + return tokenType, token, nil + } + } + + matchesOldTokenPattern := oldPatternRegexp.MatchString(token) + if matchesOldTokenPattern { + return TokenTypePersonalAccessToken, token, nil + } + + return 0, "", ErrBadAuthorizationHeader +}