diff --git a/auth/enterprise_auth.go b/auth/enterprise_auth.go new file mode 100644 index 00000000..54535c40 --- /dev/null +++ b/auth/enterprise_auth.go @@ -0,0 +1,181 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements the client-side Enterprise Managed Authorization flow +// for MCP as specified in SEP-990. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "fmt" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +// EnterpriseAuthConfig contains configuration for Enterprise Managed Authorization +// (SEP-990). This configures both the IdP (for token exchange) and the MCP Server +// (for JWT Bearer grant). +type EnterpriseAuthConfig struct { + // IdP configuration (where the user authenticates) + IdPIssuerURL string // e.g., "https://acme.okta.com" + IdPClientID string // MCP Client's ID at the IdP + IdPClientSecret string // MCP Client's secret at the IdP + + // MCP Server configuration (the resource being accessed) + MCPAuthServerURL string // MCP Server's auth server issuer URL + MCPResourceURI string // MCP Server's resource identifier + MCPClientID string // MCP Client's ID at the MCP Server + MCPClientSecret string // MCP Client's secret at the MCP Server + MCPScopes []string // Requested scopes at the MCP Server + + // Optional HTTP client for customization + HTTPClient *http.Client +} + +// EnterpriseAuthFlow performs the complete Enterprise Managed Authorization flow: +// 1. Token Exchange: ID Token → ID-JAG at IdP +// 2. JWT Bearer: ID-JAG → Access Token at MCP Server +// +// This function takes an ID Token that was obtained via SSO (e.g., OIDC login) +// and exchanges it for an access token that can be used to call the MCP Server. +// +// There are two ways to obtain an ID Token for use with this function: +// +// Option 1: Use the OIDC login helper functions (full flow with SSO): +// +// // Step 1: Initiate OIDC login +// oidcConfig := &OIDCLoginConfig{ +// IssuerURL: "https://acme.okta.com", +// ClientID: "client-id", +// RedirectURL: "http://localhost:8080/callback", +// Scopes: []string{"openid", "profile", "email"}, +// } +// authReq, err := InitiateOIDCLogin(ctx, oidcConfig) +// if err != nil { +// log.Fatal(err) +// } +// +// // Step 2: Direct user to authReq.AuthURL for authentication +// fmt.Printf("Visit: %s\n", authReq.AuthURL) +// +// // Step 3: After redirect, complete login with authorization code +// tokens, err := CompleteOIDCLogin(ctx, oidcConfig, authCode, authReq.CodeVerifier) +// if err != nil { +// log.Fatal(err) +// } +// +// // Step 4: Use ID token for enterprise auth +// enterpriseConfig := &EnterpriseAuthConfig{ +// IdPIssuerURL: "https://acme.okta.com", +// IdPClientID: "client-id-at-idp", +// IdPClientSecret: "secret-at-idp", +// MCPAuthServerURL: "https://auth.mcpserver.example", +// MCPResourceURI: "https://mcp.mcpserver.example", +// MCPClientID: "client-id-at-mcp", +// MCPClientSecret: "secret-at-mcp", +// MCPScopes: []string{"read", "write"}, +// } +// accessToken, err := EnterpriseAuthFlow(ctx, enterpriseConfig, tokens.IDToken) +// if err != nil { +// log.Fatal(err) +// } +// +// Option 2: Bring your own ID Token (if you already have one): +// +// config := &EnterpriseAuthConfig{ +// IdPIssuerURL: "https://acme.okta.com", +// IdPClientID: "client-id-at-idp", +// IdPClientSecret: "secret-at-idp", +// MCPAuthServerURL: "https://auth.mcpserver.example", +// MCPResourceURI: "https://mcp.mcpserver.example", +// MCPClientID: "client-id-at-mcp", +// MCPClientSecret: "secret-at-mcp", +// MCPScopes: []string{"read", "write"}, +// } +// +// // If you already obtained an ID token through your own means +// accessToken, err := EnterpriseAuthFlow(ctx, config, myIDToken) +// if err != nil { +// log.Fatal(err) +// } +// +// // Use accessToken to call MCP Server APIs +func EnterpriseAuthFlow( + ctx context.Context, + config *EnterpriseAuthConfig, + idToken string, +) (*oauth2.Token, error) { + if config == nil { + return nil, fmt.Errorf("config is required") + } + if idToken == "" { + return nil, fmt.Errorf("idToken is required") + } + // Validate configuration + if config.IdPIssuerURL == "" { + return nil, fmt.Errorf("IdPIssuerURL is required") + } + if config.MCPAuthServerURL == "" { + return nil, fmt.Errorf("MCPAuthServerURL is required") + } + if config.MCPResourceURI == "" { + return nil, fmt.Errorf("MCPResourceURI is required") + } + httpClient := config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + + // Step 1: Discover IdP token endpoint via OIDC discovery + idpMeta, err := oauthex.GetAuthServerMeta(ctx, config.IdPIssuerURL, httpClient) + if err != nil { + return nil, fmt.Errorf("failed to discover IdP metadata: %w", err) + } + + // Step 2: Token Exchange (ID Token → ID-JAG) + tokenExchangeReq := &oauthex.TokenExchangeRequest{ + RequestedTokenType: oauthex.TokenTypeIDJAG, + Audience: config.MCPAuthServerURL, + Resource: config.MCPResourceURI, + Scope: config.MCPScopes, + SubjectToken: idToken, + SubjectTokenType: oauthex.TokenTypeIDToken, + } + + tokenExchangeResp, err := oauthex.ExchangeToken( + ctx, + idpMeta.TokenEndpoint, + tokenExchangeReq, + config.IdPClientID, + config.IdPClientSecret, + httpClient, + ) + if err != nil { + return nil, fmt.Errorf("token exchange failed: %w", err) + } + + // Step 3: JWT Bearer Grant (ID-JAG → Access Token) + mcpMeta, err := oauthex.GetAuthServerMeta(ctx, config.MCPAuthServerURL, httpClient) + if err != nil { + return nil, fmt.Errorf("failed to discover MCP auth server metadata: %w", err) + } + + accessToken, err := oauthex.ExchangeJWTBearer( + ctx, + mcpMeta.TokenEndpoint, + tokenExchangeResp.AccessToken, + config.MCPClientID, + config.MCPClientSecret, + httpClient, + ) + if err != nil { + return nil, fmt.Errorf("JWT bearer grant failed: %w", err) + } + return accessToken, nil +} diff --git a/auth/enterprise_auth_test.go b/auth/enterprise_auth_test.go new file mode 100644 index 00000000..c44e4233 --- /dev/null +++ b/auth/enterprise_auth_test.go @@ -0,0 +1,218 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +// TestEnterpriseAuthFlow tests the complete enterprise auth flow. +func TestEnterpriseAuthFlow(t *testing.T) { + // Create test servers for IdP and MCP Server + idpServer := createMockIdPServer(t) + defer idpServer.Close() + mcpServer := createMockMCPServer(t) + defer mcpServer.Close() + // Create a test ID Token + idToken := createTestIDToken() + // Configure enterprise auth + config := &EnterpriseAuthConfig{ + IdPIssuerURL: idpServer.URL, + IdPClientID: "test-idp-client", + IdPClientSecret: "test-idp-secret", + MCPAuthServerURL: mcpServer.URL, + MCPResourceURI: "https://mcp.example.com", + MCPClientID: "test-mcp-client", + MCPClientSecret: "test-mcp-secret", + MCPScopes: []string{"read", "write"}, + HTTPClient: idpServer.Client(), + } + // Test successful flow + t.Run("successful flow", func(t *testing.T) { + token, err := EnterpriseAuthFlow(context.Background(), config, idToken) + if err != nil { + t.Fatalf("EnterpriseAuthFlow failed: %v", err) + } + if token.AccessToken != "mcp-access-token" { + t.Errorf("expected access token 'mcp-access-token', got '%s'", token.AccessToken) + } + if token.TokenType != "Bearer" { + t.Errorf("expected token type 'Bearer', got '%s'", token.TokenType) + } + }) + // Test missing config + t.Run("nil config", func(t *testing.T) { + _, err := EnterpriseAuthFlow(context.Background(), nil, idToken) + if err == nil { + t.Error("expected error for nil config, got nil") + } + }) + // Test missing ID token + t.Run("empty ID token", func(t *testing.T) { + _, err := EnterpriseAuthFlow(context.Background(), config, "") + if err == nil { + t.Error("expected error for empty ID token, got nil") + } + }) + // Test missing IdP issuer + t.Run("missing IdP issuer", func(t *testing.T) { + badConfig := *config + badConfig.IdPIssuerURL = "" + _, err := EnterpriseAuthFlow(context.Background(), &badConfig, idToken) + if err == nil { + t.Error("expected error for missing IdP issuer, got nil") + } + }) +} + +// createMockIdPServer creates a mock IdP server for testing. +func createMockIdPServer(t *testing.T) *httptest.Server { + var serverURL string + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle OIDC discovery endpoint + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": serverURL, // Use actual server URL + "token_endpoint": serverURL + "/oauth2/v1/token", + "jwks_uri": serverURL + "/.well-known/jwks.json", + "code_challenge_methods_supported": []string{"S256"}, + "grant_types_supported": []string{ + "authorization_code", + "urn:ietf:params:oauth:grant-type:token-exchange", + }, + "response_types_supported": []string{"code"}, + }) + return + } + + // Handle token exchange endpoint + if r.URL.Path != "/oauth2/v1/token" { + http.NotFound(w, r) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + grantType := r.FormValue("grant_type") + if grantType != oauthex.GrantTypeTokenExchange { + http.Error(w, "invalid grant type", http.StatusBadRequest) + return + } + + // Return a mock ID-JAG + now := time.Now().Unix() + header := map[string]string{"typ": "oauth-id-jag+jwt", "alg": "RS256"} + claims := map[string]interface{}{ + "iss": "https://test.okta.com", + "sub": "test-user", + "aud": r.FormValue("audience"), + "resource": r.FormValue("resource"), + "client_id": r.FormValue("client_id"), + "jti": "test-jti", + "exp": now + 300, + "iat": now, + "scope": r.FormValue("scope"), + } + headerJSON, _ := json.Marshal(header) + claimsJSON, _ := json.Marshal(claims) + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) + mockIDJAG := fmt.Sprintf("%s.%s.mock-signature", headerB64, claimsB64) + + resp := oauthex.TokenExchangeResponse{ + IssuedTokenType: oauthex.TokenTypeIDJAG, + AccessToken: mockIDJAG, + TokenType: "N_A", + ExpiresIn: 300, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + serverURL = server.URL // Capture server URL for discovery response + return server +} + +// createMockMCPServer creates a mock MCP Server for testing. +func createMockMCPServer(t *testing.T) *httptest.Server { + var serverURL string + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle OIDC discovery endpoint + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": serverURL, // Use actual server URL + "token_endpoint": serverURL + "/v1/token", + "jwks_uri": serverURL + "/.well-known/jwks.json", + "code_challenge_methods_supported": []string{"S256"}, + "grant_types_supported": []string{ + "urn:ietf:params:oauth:grant-type:jwt-bearer", + }, + }) + return + } + + // Handle JWT Bearer endpoint + if r.URL.Path != "/v1/token" { + http.NotFound(w, r) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + grantType := r.FormValue("grant_type") + if grantType != oauthex.GrantTypeJWTBearer { + http.Error(w, "invalid grant type", http.StatusBadRequest) + return + } + + resp := oauthex.JWTBearerResponse{ + AccessToken: "mcp-access-token", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: "read write", + RefreshToken: "mcp-refresh-token", + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + serverURL = server.URL // Capture server URL for discovery response + return server +} + +// createTestIDToken creates a mock ID Token for testing. +func createTestIDToken() string { + now := time.Now().Unix() + header := map[string]string{"typ": "JWT", "alg": "RS256"} + claims := map[string]interface{}{ + "iss": "https://test.okta.com", + "sub": "test-user", + "aud": "test-client", + "exp": now + 3600, + "iat": now, + "email": "test@example.com", + } + headerJSON, _ := json.Marshal(header) + claimsJSON, _ := json.Marshal(claims) + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) + + return fmt.Sprintf("%s.%s.mock-signature", headerB64, claimsB64) +} diff --git a/auth/oidc_login.go b/auth/oidc_login.go new file mode 100644 index 00000000..5c1d9106 --- /dev/null +++ b/auth/oidc_login.go @@ -0,0 +1,454 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements OIDC Authorization Code flow for obtaining ID tokens +// as part of Enterprise Managed Authorization (SEP-990). +// See https://openid.net/specs/openid-connect-core-1_0.html + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +// OIDCLoginConfig configures the OIDC Authorization Code flow for obtaining +// an ID Token. This is an OPTIONAL step before calling EnterpriseAuthFlow. +// Users can alternatively obtain ID tokens through their own methods. +type OIDCLoginConfig struct { + // IssuerURL is the IdP's issuer URL (e.g., "https://acme.okta.com"). + IssuerURL string + // ClientID is the MCP Client's ID registered at the IdP. + ClientID string + // ClientSecret is the MCP Client's secret at the IdP. + // This is OPTIONAL and only used if the client is confidential. + ClientSecret string + // RedirectURL is the OAuth2 redirect URI registered with the IdP. + // This must match exactly what was registered with the IdP. + RedirectURL string + // Scopes are the OAuth2/OIDC scopes to request. + // "openid" is REQUIRED for OIDC. Common values: ["openid", "profile", "email"] + Scopes []string + // LoginHint is an OPTIONAL hint to the IdP about the user's identity. + // Some IdPs may require this (e.g., as an email address for routing to SSO providers). + // Example: "user@example.com" + LoginHint string + // HTTPClient is the HTTP client for making requests. + // If nil, http.DefaultClient is used. + HTTPClient *http.Client +} + +// OIDCAuthorizationRequest represents the result of initiating an OIDC +// authorization code flow. Users must direct the end-user to AuthURL +// to complete authentication. +type OIDCAuthorizationRequest struct { + // AuthURL is the URL the user should visit to authenticate. + // This URL includes the authorization request parameters. + AuthURL string + // State is the OAuth2 state parameter for CSRF protection. + // Users MUST validate that the state returned from the IdP matches this value. + State string + // CodeVerifier is the PKCE code verifier for secure authorization code exchange. + // This must be provided to CompleteOIDCLogin along with the authorization code. + CodeVerifier string +} + +// OIDCTokenResponse contains the tokens returned from a successful OIDC login. +type OIDCTokenResponse struct { + // IDToken is the OpenID Connect ID Token (JWT). + // This can be passed to EnterpriseAuthFlow for token exchange. + IDToken string + // AccessToken is the OAuth2 access token (if issued by IdP). + // This is typically not needed for SEP-990, but may be useful for other IdP APIs. + AccessToken string + // RefreshToken is the OAuth2 refresh token (if issued by IdP). + RefreshToken string + // TokenType is the token type (typically "Bearer"). + TokenType string + // ExpiresAt is when the ID token expires. + ExpiresAt int64 +} + +// InitiateOIDCLogin initiates an OIDC Authorization Code flow with PKCE. +// This is the first step for users who want to use SSO to obtain an ID token. +// +// The returned AuthURL should be presented to the user (e.g., opened in a browser). +// After the user authenticates, the IdP will redirect to the RedirectURL with +// an authorization code and state parameter. +// +// Example: +// +// config := &OIDCLoginConfig{ +// IssuerURL: "https://acme.okta.com", +// ClientID: "client-id", +// RedirectURL: "http://localhost:8080/callback", +// Scopes: []string{"openid", "profile", "email"}, +// } +// +// authReq, err := InitiateOIDCLogin(ctx, config) +// if err != nil { +// log.Fatal(err) +// } +// +// // Direct user to authReq.AuthURL +// fmt.Printf("Visit this URL to login: %s\n", authReq.AuthURL) +// +// // After user completes login, IdP redirects to RedirectURL with code & state +// // Extract code and state from the redirect, then call CompleteOIDCLogin +func InitiateOIDCLogin( + ctx context.Context, + config *OIDCLoginConfig, +) (*OIDCAuthorizationRequest, error) { + if config == nil { + return nil, fmt.Errorf("config is required") + } + // Validate required fields + if config.IssuerURL == "" { + return nil, fmt.Errorf("IssuerURL is required") + } + if config.ClientID == "" { + return nil, fmt.Errorf("ClientID is required") + } + if config.RedirectURL == "" { + return nil, fmt.Errorf("RedirectURL is required") + } + if len(config.Scopes) == 0 { + return nil, fmt.Errorf("Scopes is required (must include 'openid')") + } + // Validate that "openid" scope is present (required for OIDC) + hasOpenID := false + for _, scope := range config.Scopes { + if scope == "openid" { + hasOpenID = true + break + } + } + if !hasOpenID { + return nil, fmt.Errorf("Scopes must include 'openid' for OIDC") + } + // Validate URL schemes to prevent XSS attacks + if err := oauthex.CheckURLScheme(config.IssuerURL); err != nil { + return nil, fmt.Errorf("invalid IssuerURL: %w", err) + } + if err := oauthex.CheckURLScheme(config.RedirectURL); err != nil { + return nil, fmt.Errorf("invalid RedirectURL: %w", err) + } + // Discover OIDC endpoints via .well-known + httpClient := config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + meta, err := oauthex.GetAuthServerMeta(ctx, config.IssuerURL, httpClient) + if err != nil { + return nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) + } + if meta.AuthorizationEndpoint == "" { + return nil, fmt.Errorf("authorization_endpoint not found in OIDC metadata") + } + // Generate PKCE code verifier and challenge (RFC 7636) + codeVerifier, err := generateCodeVerifier() + if err != nil { + return nil, fmt.Errorf("failed to generate PKCE verifier: %w", err) + } + codeChallenge := generateCodeChallenge(codeVerifier) + // Generate state for CSRF protection (RFC 6749 Section 10.12) + state, err := generateState() + if err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + // Build authorization URL per OIDC Core Section 3.1.2.1 + authURL, err := buildAuthorizationURL( + meta.AuthorizationEndpoint, + config.ClientID, + config.RedirectURL, + config.Scopes, + state, + codeChallenge, + config.LoginHint, + ) + if err != nil { + return nil, fmt.Errorf("failed to build authorization URL: %w", err) + } + return &OIDCAuthorizationRequest{ + AuthURL: authURL, + State: state, + CodeVerifier: codeVerifier, + }, nil +} + +// CompleteOIDCLogin completes the OIDC Authorization Code flow by exchanging +// the authorization code for tokens. This is the second step after the user +// has authenticated and been redirected back to the application. +// +// The authCode and returnedState parameters should come from the redirect URL +// query parameters. The state MUST match the state from InitiateOIDCLogin +// for CSRF protection. +// +// Example: +// +// // In your redirect handler (e.g., http://localhost:8080/callback) +// authCode := r.URL.Query().Get("code") +// returnedState := r.URL.Query().Get("state") +// +// // Validate state matches what we sent +// if returnedState != authReq.State { +// log.Fatal("State mismatch - possible CSRF attack") +// } +// +// // Exchange code for tokens +// tokens, err := CompleteOIDCLogin(ctx, config, authCode, authReq.CodeVerifier) +// if err != nil { +// log.Fatal(err) +// } +// +// // Now use tokens.IDToken with EnterpriseAuthFlow +// accessToken, err := EnterpriseAuthFlow(ctx, enterpriseConfig, tokens.IDToken) +func CompleteOIDCLogin( + ctx context.Context, + config *OIDCLoginConfig, + authCode string, + codeVerifier string, +) (*OIDCTokenResponse, error) { + if config == nil { + return nil, fmt.Errorf("config is required") + } + if authCode == "" { + return nil, fmt.Errorf("authCode is required") + } + if codeVerifier == "" { + return nil, fmt.Errorf("codeVerifier is required") + } + // Validate required fields + if config.IssuerURL == "" { + return nil, fmt.Errorf("IssuerURL is required") + } + if config.ClientID == "" { + return nil, fmt.Errorf("ClientID is required") + } + if config.RedirectURL == "" { + return nil, fmt.Errorf("RedirectURL is required") + } + // Discover token endpoint + httpClient := config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + meta, err := oauthex.GetAuthServerMeta(ctx, config.IssuerURL, httpClient) + if err != nil { + return nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) + } + if meta.TokenEndpoint == "" { + return nil, fmt.Errorf("token_endpoint not found in OIDC metadata") + } + // Build token request per OIDC Core Section 3.1.3.1 + formData := url.Values{} + formData.Set("grant_type", "authorization_code") + formData.Set("code", authCode) + formData.Set("redirect_uri", config.RedirectURL) + formData.Set("client_id", config.ClientID) + formData.Set("code_verifier", codeVerifier) + // Add client_secret if provided (confidential client) + if config.ClientSecret != "" { + formData.Set("client_secret", config.ClientSecret) + } + // Exchange authorization code for tokens + oauth2Token, err := exchangeAuthorizationCode( + ctx, + meta.TokenEndpoint, + formData, + httpClient, + ) + if err != nil { + return nil, fmt.Errorf("token exchange failed: %w", err) + } + // Extract ID Token from response + idToken, ok := oauth2Token.Extra("id_token").(string) + if !ok || idToken == "" { + return nil, fmt.Errorf("id_token not found in token response") + } + return &OIDCTokenResponse{ + IDToken: idToken, + AccessToken: oauth2Token.AccessToken, + RefreshToken: oauth2Token.RefreshToken, + TokenType: oauth2Token.TokenType, + ExpiresAt: oauth2Token.Expiry.Unix(), + }, nil +} + +// generateCodeVerifier generates a cryptographically random code verifier +// for PKCE per RFC 7636 Section 4.1. +func generateCodeVerifier() (string, error) { + // Per RFC 7636: code verifier is 43-128 characters from [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~" + // We use 32 random bytes (256 bits) base64url-encoded = 43 characters + randomBytes := make([]byte, 32) + if _, err := rand.Read(randomBytes); err != nil { + return "", fmt.Errorf("failed to generate random bytes: %w", err) + } + return base64.RawURLEncoding.EncodeToString(randomBytes), nil +} + +// generateCodeChallenge generates the PKCE code challenge from the verifier +// using SHA256 per RFC 7636 Section 4.2. +func generateCodeChallenge(verifier string) string { + hash := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(hash[:]) +} + +// generateState generates a cryptographically random state parameter +// for CSRF protection per RFC 6749 Section 10.12. +func generateState() (string, error) { + randomBytes := make([]byte, 32) + if _, err := rand.Read(randomBytes); err != nil { + return "", fmt.Errorf("failed to generate random bytes: %w", err) + } + return base64.RawURLEncoding.EncodeToString(randomBytes), nil +} + +// buildAuthorizationURL constructs the OIDC authorization URL. +func buildAuthorizationURL( + authEndpoint string, + clientID string, + redirectURL string, + scopes []string, + state string, + codeChallenge string, + loginHint string, +) (string, error) { + u, err := url.Parse(authEndpoint) + if err != nil { + return "", fmt.Errorf("invalid authorization endpoint: %w", err) + } + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", clientID) + q.Set("redirect_uri", redirectURL) + q.Set("scope", strings.Join(scopes, " ")) + q.Set("state", state) + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + // Add login_hint if provided (optional per OIDC spec, but some IdPs may require it) + if loginHint != "" { + q.Set("login_hint", loginHint) + } + u.RawQuery = q.Encode() + return u.String(), nil +} + +// exchangeAuthorizationCode exchanges the authorization code for tokens. +func exchangeAuthorizationCode( + ctx context.Context, + tokenEndpoint string, + formData url.Values, + httpClient *http.Client, +) (*oauth2.Token, error) { + // Create HTTP request + httpReq, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + tokenEndpoint, + strings.NewReader(formData.Encode()), + ) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + httpReq.Header.Set("Accept", "application/json") + + // Execute request + httpResp, err := httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("token request failed: %w", err) + } + defer httpResp.Body.Close() + + // Read response body (limit to 1MB for safety) + body, err := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("failed to read token response: %w", err) + } + + // Handle success response (200 OK) + if httpResp.StatusCode == http.StatusOK { + // Parse token response manually (following jwt_bearer.go pattern) + var tokenResp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + IDToken string `json:"id_token,omitempty"` + Scope string `json:"scope,omitempty"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w (body: %s)", err, string(body)) + } + + // Validate required fields + if tokenResp.AccessToken == "" { + return nil, fmt.Errorf("response missing required field: access_token") + } + if tokenResp.TokenType == "" { + return nil, fmt.Errorf("response missing required field: token_type") + } + + // Convert to oauth2.Token + token := &oauth2.Token{ + AccessToken: tokenResp.AccessToken, + TokenType: tokenResp.TokenType, + RefreshToken: tokenResp.RefreshToken, + } + + // Set expiration if provided + if tokenResp.ExpiresIn > 0 { + token.Expiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + // Add extra fields (id_token, scope) + extra := make(map[string]interface{}) + if tokenResp.IDToken != "" { + extra["id_token"] = tokenResp.IDToken + } + if tokenResp.Scope != "" { + extra["scope"] = tokenResp.Scope + } + if len(extra) > 0 { + token = token.WithExtra(extra) + } + + return token, nil + } + + // Handle error response (400 Bad Request) + if httpResp.StatusCode == http.StatusBadRequest { + var errResp struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` + ErrorURI string `json:"error_uri,omitempty"` + } + if err := json.Unmarshal(body, &errResp); err != nil { + return nil, fmt.Errorf("failed to parse error response: %w (body: %s)", err, string(body)) + } + if errResp.ErrorDescription != "" { + return nil, fmt.Errorf("token request failed: %s (%s)", errResp.Error, errResp.ErrorDescription) + } + return nil, fmt.Errorf("token request failed: %s", errResp.Error) + } + + // Handle unexpected status codes + return nil, fmt.Errorf("unexpected status code %d: %s", httpResp.StatusCode, string(body)) +} diff --git a/auth/oidc_login_test.go b/auth/oidc_login_test.go new file mode 100644 index 00000000..ca8c3609 --- /dev/null +++ b/auth/oidc_login_test.go @@ -0,0 +1,384 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +// TestInitiateOIDCLogin tests the OIDC authorization request generation. +func TestInitiateOIDCLogin(t *testing.T) { + // Create mock IdP server + idpServer := createMockOIDCServer(t) + defer idpServer.Close() + config := &OIDCLoginConfig{ + IssuerURL: idpServer.URL, + ClientID: "test-client", + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{"openid", "profile", "email"}, + HTTPClient: idpServer.Client(), + } + t.Run("successful initiation", func(t *testing.T) { + authReq, err := InitiateOIDCLogin(context.Background(), config) + if err != nil { + t.Fatalf("InitiateOIDCLogin failed: %v", err) + } + // Validate AuthURL + if authReq.AuthURL == "" { + t.Error("AuthURL is empty") + } + // Parse and validate URL parameters + u, err := url.Parse(authReq.AuthURL) + if err != nil { + t.Fatalf("Failed to parse AuthURL: %v", err) + } + q := u.Query() + if q.Get("response_type") != "code" { + t.Errorf("expected response_type 'code', got '%s'", q.Get("response_type")) + } + if q.Get("client_id") != "test-client" { + t.Errorf("expected client_id 'test-client', got '%s'", q.Get("client_id")) + } + if q.Get("redirect_uri") != "http://localhost:8080/callback" { + t.Errorf("expected redirect_uri 'http://localhost:8080/callback', got '%s'", q.Get("redirect_uri")) + } + if q.Get("scope") != "openid profile email" { + t.Errorf("expected scope 'openid profile email', got '%s'", q.Get("scope")) + } + if q.Get("code_challenge_method") != "S256" { + t.Errorf("expected code_challenge_method 'S256', got '%s'", q.Get("code_challenge_method")) + } + // Validate state is generated + if authReq.State == "" { + t.Error("State is empty") + } + if q.Get("state") != authReq.State { + t.Errorf("state in URL doesn't match returned state") + } + // Validate PKCE parameters + if authReq.CodeVerifier == "" { + t.Error("CodeVerifier is empty") + } + if q.Get("code_challenge") == "" { + t.Error("code_challenge is empty") + } + }) + t.Run("with login_hint", func(t *testing.T) { + configWithHint := *config + configWithHint.LoginHint = "user@example.com" + authReq, err := InitiateOIDCLogin(context.Background(), &configWithHint) + if err != nil { + t.Fatalf("InitiateOIDCLogin failed: %v", err) + } + u, err := url.Parse(authReq.AuthURL) + if err != nil { + t.Fatalf("Failed to parse AuthURL: %v", err) + } + q := u.Query() + if q.Get("login_hint") != "user@example.com" { + t.Errorf("expected login_hint 'user@example.com', got '%s'", q.Get("login_hint")) + } + }) + t.Run("without login_hint", func(t *testing.T) { + authReq, err := InitiateOIDCLogin(context.Background(), config) + if err != nil { + t.Fatalf("InitiateOIDCLogin failed: %v", err) + } + u, err := url.Parse(authReq.AuthURL) + if err != nil { + t.Fatalf("Failed to parse AuthURL: %v", err) + } + q := u.Query() + if q.Has("login_hint") { + t.Errorf("expected no login_hint parameter, but got '%s'", q.Get("login_hint")) + } + }) + t.Run("nil config", func(t *testing.T) { + _, err := InitiateOIDCLogin(context.Background(), nil) + if err == nil { + t.Error("expected error for nil config, got nil") + } + }) + t.Run("missing openid scope", func(t *testing.T) { + badConfig := *config + badConfig.Scopes = []string{"profile", "email"} // Missing "openid" + _, err := InitiateOIDCLogin(context.Background(), &badConfig) + if err == nil { + t.Error("expected error for missing openid scope, got nil") + } + if !strings.Contains(err.Error(), "openid") { + t.Errorf("expected error about missing 'openid', got: %v", err) + } + }) + t.Run("missing required fields", func(t *testing.T) { + tests := []struct { + name string + mutate func(*OIDCLoginConfig) + expectErr string + }{ + { + name: "missing IssuerURL", + mutate: func(c *OIDCLoginConfig) { c.IssuerURL = "" }, + expectErr: "IssuerURL is required", + }, + { + name: "missing ClientID", + mutate: func(c *OIDCLoginConfig) { c.ClientID = "" }, + expectErr: "ClientID is required", + }, + { + name: "missing RedirectURL", + mutate: func(c *OIDCLoginConfig) { c.RedirectURL = "" }, + expectErr: "RedirectURL is required", + }, + { + name: "missing Scopes", + mutate: func(c *OIDCLoginConfig) { c.Scopes = nil }, + expectErr: "Scopes is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + badConfig := *config + tt.mutate(&badConfig) + _, err := InitiateOIDCLogin(context.Background(), &badConfig) + if err == nil { + t.Error("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.expectErr) { + t.Errorf("expected error containing '%s', got: %v", tt.expectErr, err) + } + }) + } + }) +} + +// TestCompleteOIDCLogin tests the authorization code exchange. +func TestCompleteOIDCLogin(t *testing.T) { + // Create mock IdP server + idpServer := createMockOIDCServerWithToken(t) + defer idpServer.Close() + config := &OIDCLoginConfig{ + IssuerURL: idpServer.URL, + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{"openid", "profile", "email"}, + HTTPClient: idpServer.Client(), + } + t.Run("successful code exchange", func(t *testing.T) { + tokens, err := CompleteOIDCLogin( + context.Background(), + config, + "test-auth-code", + "test-code-verifier", + ) + if err != nil { + t.Fatalf("CompleteOIDCLogin failed: %v", err) + } + // Validate tokens + if tokens.IDToken == "" { + t.Error("IDToken is empty") + } + if tokens.AccessToken == "" { + t.Error("AccessToken is empty") + } + if tokens.TokenType != "Bearer" { + t.Errorf("expected TokenType 'Bearer', got '%s'", tokens.TokenType) + } + if tokens.ExpiresAt == 0 { + t.Error("ExpiresAt is zero") + } + }) + t.Run("nil config", func(t *testing.T) { + _, err := CompleteOIDCLogin( + context.Background(), + nil, + "test-auth-code", + "test-code-verifier", + ) + if err == nil { + t.Error("expected error for nil config, got nil") + } + }) + t.Run("missing parameters", func(t *testing.T) { + tests := []struct { + name string + authCode string + codeVerifier string + expectErr string + }{ + { + name: "missing authCode", + authCode: "", + codeVerifier: "test-verifier", + expectErr: "authCode is required", + }, + { + name: "missing codeVerifier", + authCode: "test-code", + codeVerifier: "", + expectErr: "codeVerifier is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := CompleteOIDCLogin( + context.Background(), + config, + tt.authCode, + tt.codeVerifier, + ) + if err == nil { + t.Error("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.expectErr) { + t.Errorf("expected error containing '%s', got: %v", tt.expectErr, err) + } + }) + } + }) +} + +// TestOIDCLoginE2E tests the complete OIDC login flow end-to-end. +func TestOIDCLoginE2E(t *testing.T) { + // Create mock IdP server + idpServer := createMockOIDCServerWithToken(t) + defer idpServer.Close() + config := &OIDCLoginConfig{ + IssuerURL: idpServer.URL, + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{"openid", "profile", "email"}, + HTTPClient: idpServer.Client(), + } + // Step 1: Initiate login + authReq, err := InitiateOIDCLogin(context.Background(), config) + if err != nil { + t.Fatalf("InitiateOIDCLogin failed: %v", err) + } + // Step 2: Simulate user authentication and redirect + // (In real flow, user would visit authReq.AuthURL and IdP would redirect back) + // Here we just use a mock authorization code + mockAuthCode := "mock-authorization-code" + // Step 3: Complete login with authorization code + tokens, err := CompleteOIDCLogin( + context.Background(), + config, + mockAuthCode, + authReq.CodeVerifier, + ) + if err != nil { + t.Fatalf("CompleteOIDCLogin failed: %v", err) + } + // Validate we got an ID token + if tokens.IDToken == "" { + t.Error("Expected ID token, got empty string") + } + // Validate ID token is a JWT (has 3 parts) + parts := strings.Split(tokens.IDToken, ".") + if len(parts) != 3 { + t.Errorf("Expected JWT with 3 parts, got %d parts", len(parts)) + } +} + +// createMockOIDCServer creates a mock OIDC server for testing InitiateOIDCLogin. +func createMockOIDCServer(t *testing.T) *httptest.Server { + var serverURL string + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle OIDC discovery + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": serverURL, + "authorization_endpoint": serverURL + "/authorize", + "token_endpoint": serverURL + "/token", + "jwks_uri": serverURL + "/.well-known/jwks.json", + "response_types_supported": []string{"code"}, + "code_challenge_methods_supported": []string{"S256"}, + "grant_types_supported": []string{"authorization_code"}, + }) + return + } + http.NotFound(w, r) + })) + serverURL = server.URL + return server +} + +// createMockOIDCServerWithToken creates a mock OIDC server that also handles token exchange. +func createMockOIDCServerWithToken(t *testing.T) *httptest.Server { + var serverURL string + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle OIDC discovery + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": serverURL, + "authorization_endpoint": serverURL + "/authorize", + "token_endpoint": serverURL + "/token", + "jwks_uri": serverURL + "/.well-known/jwks.json", + "response_types_supported": []string{"code"}, + "code_challenge_methods_supported": []string{"S256"}, + "grant_types_supported": []string{"authorization_code"}, + }) + return + } + // Handle token endpoint + if r.URL.Path == "/token" { + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + // Validate grant type + if r.FormValue("grant_type") != "authorization_code" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + // Create mock ID token (JWT) + now := time.Now().Unix() + idToken := fmt.Sprintf("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.%s.mock-signature", + base64EncodeClaims(map[string]interface{}{ + "iss": serverURL, + "sub": "test-user", + "aud": "test-client", + "exp": now + 3600, + "iat": now, + "email": "test@example.com", + })) + // Return token response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "mock-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "mock-refresh-token", + "id_token": idToken, + }) + return + } + http.NotFound(w, r) + })) + serverURL = server.URL + return server +} + +// base64EncodeClaims encodes JWT claims for testing. +func base64EncodeClaims(claims map[string]interface{}) string { + claimsJSON, _ := json.Marshal(claims) + return base64.RawURLEncoding.EncodeToString(claimsJSON) +} diff --git a/docs/protocol.md b/docs/protocol.md index 16ba0bfa..2fa9d2db 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -310,6 +310,41 @@ Client-side OAuth is implemented by setting [`StreamableClientTransport.HTTPClient`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk@v0.5.0/mcp#StreamableClientTransport.HTTPClient) to a custom [`http.Client`](https://pkg.go.dev/net/http#Client) Additional support is forthcoming; see modelcontextprotocol/go-sdk#493. +#### Enterprise Authentication Flow (SEP-990) + +For enterprise SSO scenarios, the SDK provides an +[`EnterpriseAuthFlow`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#EnterpriseAuthFlow) +function that implements the complete token exchange flow: + +1. **Token Exchange** at IdP: ID Token → ID-JAG +2. **JWT Bearer Grant** at MCP Server: ID-JAG → Access Token + +This flow is typically used after obtaining an ID Token via OIDC login: + +```go +// Step 1: Obtain ID token via OIDC (see auth.InitiateOIDCLogin and auth.CompleteOIDCLogin) +idToken := "..." // from OIDC login + +// Step 2: Exchange for MCP access token +config := &auth.EnterpriseAuthConfig{ + IdPIssuerURL: "https://company.okta.com", + IdPClientID: "client-id-at-idp", + IdPClientSecret: "secret-at-idp", + MCPAuthServerURL: "https://auth.mcpserver.example", + MCPResourceURI: "https://mcp.mcpserver.example", + MCPClientID: "client-id-at-mcp", + MCPClientSecret: "secret-at-mcp", + MCPScopes: []string{"read", "write"}, +} + +accessToken, err := auth.EnterpriseAuthFlow(ctx, config, idToken) +// Use accessToken with MCP client +``` + +Helper functions are provided for OIDC login: +- [`InitiateOIDCLogin`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#InitiateOIDCLogin) - Generate authorization URL with PKCE +- [`CompleteOIDCLogin`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#CompleteOIDCLogin) - Exchange authorization code for tokens + ## Security Here we discuss the mitigations described under @@ -504,3 +539,4 @@ func Example_progress() { // frobbing widgets 2/2 } ``` + diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index ada34371..63a7e20d 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -236,6 +236,41 @@ Client-side OAuth is implemented by setting [`StreamableClientTransport.HTTPClient`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk@v0.5.0/mcp#StreamableClientTransport.HTTPClient) to a custom [`http.Client`](https://pkg.go.dev/net/http#Client) Additional support is forthcoming; see modelcontextprotocol/go-sdk#493. +#### Enterprise Authentication Flow (SEP-990) + +For enterprise SSO scenarios, the SDK provides an +[`EnterpriseAuthFlow`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#EnterpriseAuthFlow) +function that implements the complete token exchange flow: + +1. **Token Exchange** at IdP: ID Token → ID-JAG +2. **JWT Bearer Grant** at MCP Server: ID-JAG → Access Token + +This flow is typically used after obtaining an ID Token via OIDC login: + +```go +// Step 1: Obtain ID token via OIDC (see auth.InitiateOIDCLogin and auth.CompleteOIDCLogin) +idToken := "..." // from OIDC login + +// Step 2: Exchange for MCP access token +config := &auth.EnterpriseAuthConfig{ + IdPIssuerURL: "https://company.okta.com", + IdPClientID: "client-id-at-idp", + IdPClientSecret: "secret-at-idp", + MCPAuthServerURL: "https://auth.mcpserver.example", + MCPResourceURI: "https://mcp.mcpserver.example", + MCPClientID: "client-id-at-mcp", + MCPClientSecret: "secret-at-mcp", + MCPScopes: []string{"read", "write"}, +} + +accessToken, err := auth.EnterpriseAuthFlow(ctx, config, idToken) +// Use accessToken with MCP client +``` + +Helper functions are provided for OIDC login: +- [`InitiateOIDCLogin`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#InitiateOIDCLogin) - Generate authorization URL with PKCE +- [`CompleteOIDCLogin`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#CompleteOIDCLogin) - Exchange authorization code for tokens + ## Security Here we discuss the mitigations described under @@ -328,3 +363,4 @@ or Issue #460 discusses some potential ergonomic improvements to this API. %include ../../mcp/mcp_example_test.go progress - + diff --git a/oauthex/id_jag.go b/oauthex/id_jag.go new file mode 100644 index 00000000..860a36ea --- /dev/null +++ b/oauthex/id_jag.go @@ -0,0 +1,138 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements ID-JAG (Identity Assertion JWT Authorization Grant) parsing +// for Enterprise Managed Authorization (SEP-990). +// See https://github.com/modelcontextprotocol/ext-auth/blob/main/specification/draft/enterprise-managed-authorization.mdx + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "time" +) + +// IDJAGClaims represents the claims in an Identity Assertion JWT Authorization Grant +// per SEP-990 Section 4.3. The ID-JAG is issued by the IdP during token exchange +// and describes the authorization grant for accessing an MCP Server. +type IDJAGClaims struct { + // Issuer is the IdP's issuer URL. + Issuer string `json:"iss"` + // Subject is the user identifier at the MCP Server. + Subject string `json:"sub"` + // Audience is the Issuer URL of the MCP Server's authorization server. + Audience string `json:"aud"` + // Resource is the Resource Identifier of the MCP Server. + Resource string `json:"resource"` + // ClientID is the identifier of the MCP Client that this JWT was issued to. + ClientID string `json:"client_id"` + // JTI is the unique identifier of this JWT. + JTI string `json:"jti"` + // ExpiresAt is the expiration time of this JWT (Unix timestamp). + ExpiresAt int64 `json:"exp"` + // IssuedAt is the time this JWT was issued (Unix timestamp). + IssuedAt int64 `json:"iat"` + // Scope is a space-separated list of scopes associated with the token. + Scope string `json:"scope,omitempty"` +} + +// Expiry returns the expiration time as a time.Time. +func (c *IDJAGClaims) Expiry() time.Time { + return time.Unix(c.ExpiresAt, 0) +} + +// IssuedTime returns the issued-at time as a time.Time. +func (c *IDJAGClaims) IssuedTime() time.Time { + return time.Unix(c.IssuedAt, 0) +} + +// IsExpired checks if the ID-JAG has expired. +func (c *IDJAGClaims) IsExpired() bool { + return time.Now().After(c.Expiry()) +} + +// ParseIDJAG parses an ID-JAG JWT and extracts its claims without validating +// the signature. This is useful for inspecting the contents of an ID-JAG during +// development or debugging. +// +// For production use on the server-side, use ValidateIDJAG instead, which +// performs full signature validation and claim verification. +// +// The JWT must have a "typ" header of "oauth-id-jag+jwt" per SEP-990 Section 4.3. +// +// Example: +// +// claims, err := ParseIDJAG(idJAG) +// if err != nil { +// log.Fatalf("Failed to parse ID-JAG: %v", err) +// } +// fmt.Printf("Subject: %s\n", claims.Subject) +// fmt.Printf("Expires: %v\n", claims.Expiry()) +func ParseIDJAG(jwt string) (*IDJAGClaims, error) { + if jwt == "" { + return nil, fmt.Errorf("JWT is empty") + } + // Split JWT into parts (header.payload.signature) + parts := strings.Split(jwt, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) + } + // Decode header to check typ claim + headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return nil, fmt.Errorf("failed to decode JWT header: %w", err) + } + var header struct { + Type string `json:"typ"` + Alg string `json:"alg"` + } + if err := json.Unmarshal(headerJSON, &header); err != nil { + return nil, fmt.Errorf("failed to parse JWT header: %w", err) + } + // Verify typ claim per SEP-990 Section 4.3 + if header.Type != "oauth-id-jag+jwt" { + return nil, fmt.Errorf("invalid JWT type: expected 'oauth-id-jag+jwt', got '%s'", header.Type) + } + // Decode payload + payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("failed to decode JWT payload: %w", err) + } + // Parse claims + var claims IDJAGClaims + if err := json.Unmarshal(payloadJSON, &claims); err != nil { + return nil, fmt.Errorf("failed to parse JWT claims: %w", err) + } + // Validate required claims are present per SEP-990 Section 4.3 + if claims.Issuer == "" { + return nil, fmt.Errorf("missing required claim: iss") + } + if claims.Subject == "" { + return nil, fmt.Errorf("missing required claim: sub") + } + if claims.Audience == "" { + return nil, fmt.Errorf("missing required claim: aud") + } + if claims.Resource == "" { + return nil, fmt.Errorf("missing required claim: resource") + } + if claims.ClientID == "" { + return nil, fmt.Errorf("missing required claim: client_id") + } + if claims.JTI == "" { + return nil, fmt.Errorf("missing required claim: jti") + } + if claims.ExpiresAt == 0 { + return nil, fmt.Errorf("missing required claim: exp") + } + if claims.IssuedAt == 0 { + return nil, fmt.Errorf("missing required claim: iat") + } + return &claims, nil +} diff --git a/oauthex/id_jag_test.go b/oauthex/id_jag_test.go new file mode 100644 index 00000000..ff710fcc --- /dev/null +++ b/oauthex/id_jag_test.go @@ -0,0 +1,176 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "testing" + "time" +) + +// TestParseIDJAG tests parsing of ID-JAG tokens. +func TestParseIDJAG(t *testing.T) { + // Create a test ID-JAG JWT + now := time.Now().Unix() + + header := map[string]string{ + "typ": "oauth-id-jag+jwt", + "alg": "RS256", + } + + claims := map[string]interface{}{ + "iss": "https://acme.okta.com", + "sub": "alice@acme.com", + "aud": "https://auth.mcpserver.example", + "resource": "https://mcp.mcpserver.example", + "client_id": "xyz789", + "jti": "unique-id-123", + "exp": now + 300, + "iat": now, + "scope": "read write", + } + // Encode header and payload + headerJSON, _ := json.Marshal(header) + claimsJSON, _ := json.Marshal(claims) + + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) + + // Create fake JWT (header.payload.signature) + fakeJWT := fmt.Sprintf("%s.%s.fake-signature", headerB64, claimsB64) + // Test successful parsing + t.Run("successful parse", func(t *testing.T) { + parsed, err := ParseIDJAG(fakeJWT) + if err != nil { + t.Fatalf("ParseIDJAG failed: %v", err) + } + if parsed.Issuer != "https://acme.okta.com" { + t.Errorf("expected issuer 'https://acme.okta.com', got '%s'", parsed.Issuer) + } + if parsed.Subject != "alice@acme.com" { + t.Errorf("expected subject 'alice@acme.com', got '%s'", parsed.Subject) + } + if parsed.Audience != "https://auth.mcpserver.example" { + t.Errorf("expected audience 'https://auth.mcpserver.example', got '%s'", parsed.Audience) + } + if parsed.Resource != "https://mcp.mcpserver.example" { + t.Errorf("expected resource 'https://mcp.mcpserver.example', got '%s'", parsed.Resource) + } + if parsed.ClientID != "xyz789" { + t.Errorf("expected client_id 'xyz789', got '%s'", parsed.ClientID) + } + if parsed.JTI != "unique-id-123" { + t.Errorf("expected jti 'unique-id-123', got '%s'", parsed.JTI) + } + if parsed.Scope != "read write" { + t.Errorf("expected scope 'read write', got '%s'", parsed.Scope) + } + if parsed.IsExpired() { + t.Error("expected ID-JAG not to be expired") + } + }) + // Test empty JWT + t.Run("empty JWT", func(t *testing.T) { + _, err := ParseIDJAG("") + if err == nil { + t.Error("expected error for empty JWT, got nil") + } + }) + // Test invalid format + t.Run("invalid format", func(t *testing.T) { + _, err := ParseIDJAG("invalid.jwt") + if err == nil { + t.Error("expected error for invalid JWT format, got nil") + } + }) + // Test wrong typ header + t.Run("wrong typ header", func(t *testing.T) { + wrongHeader := map[string]string{ + "typ": "JWT", // Should be "oauth-id-jag+jwt" + "alg": "RS256", + } + wrongHeaderJSON, _ := json.Marshal(wrongHeader) + wrongHeaderB64 := base64.RawURLEncoding.EncodeToString(wrongHeaderJSON) + wrongJWT := fmt.Sprintf("%s.%s.fake-signature", wrongHeaderB64, claimsB64) + _, err := ParseIDJAG(wrongJWT) + if err == nil { + t.Error("expected error for wrong typ header, got nil") + } + if err != nil && !strings.Contains(err.Error(), "invalid JWT type") { + t.Errorf("expected 'invalid JWT type' error, got: %v", err) + } + }) + // Test missing required claims + t.Run("missing required claims", func(t *testing.T) { + incompleteClaims := map[string]interface{}{ + "iss": "https://acme.okta.com", + // Missing other required claims + } + incompleteJSON, _ := json.Marshal(incompleteClaims) + incompleteB64 := base64.RawURLEncoding.EncodeToString(incompleteJSON) + incompleteJWT := fmt.Sprintf("%s.%s.fake-signature", headerB64, incompleteB64) + _, err := ParseIDJAG(incompleteJWT) + if err == nil { + t.Error("expected error for missing claims, got nil") + } + }) + // Test expired ID-JAG + t.Run("expired ID-JAG", func(t *testing.T) { + expiredClaims := map[string]interface{}{ + "iss": "https://acme.okta.com", + "sub": "alice@acme.com", + "aud": "https://auth.mcpserver.example", + "resource": "https://mcp.mcpserver.example", + "client_id": "xyz789", + "jti": "unique-id-123", + "exp": now - 300, // Expired 5 minutes ago + "iat": now - 600, + "scope": "read write", + } + expiredJSON, _ := json.Marshal(expiredClaims) + expiredB64 := base64.RawURLEncoding.EncodeToString(expiredJSON) + expiredJWT := fmt.Sprintf("%s.%s.fake-signature", headerB64, expiredB64) + parsed, err := ParseIDJAG(expiredJWT) + if err != nil { + t.Fatalf("ParseIDJAG failed: %v", err) + } + if !parsed.IsExpired() { + t.Error("expected ID-JAG to be expired") + } + }) +} + +// TestIDJAGClaimsMethods tests the helper methods on IDJAGClaims. +func TestIDJAGClaimsMethods(t *testing.T) { + now := time.Now() + claims := &IDJAGClaims{ + ExpiresAt: now.Add(1 * time.Hour).Unix(), + IssuedAt: now.Unix(), + } + // Test Expiry + expiry := claims.Expiry() + if expiry.Before(now) { + t.Error("expected expiry to be in the future") + } + // Test IssuedTime + issued := claims.IssuedTime() + if issued.After(now.Add(1 * time.Second)) { + t.Error("expected issued time to be in the past") + } + // Test IsExpired (should not be expired) + if claims.IsExpired() { + t.Error("expected claims not to be expired") + } + // Test IsExpired (should be expired) + claims.ExpiresAt = now.Add(-1 * time.Hour).Unix() + if !claims.IsExpired() { + t.Error("expected claims to be expired") + } +} diff --git a/oauthex/jwt_bearer.go b/oauthex/jwt_bearer.go new file mode 100644 index 00000000..03a0df5d --- /dev/null +++ b/oauthex/jwt_bearer.go @@ -0,0 +1,193 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements JWT Bearer Authorization Grant (RFC 7523) for Enterprise Managed Authorization. +// See https://datatracker.ietf.org/doc/html/rfc7523 + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "golang.org/x/oauth2" +) + +// GrantTypeJWTBearer is the grant type for RFC 7523 JWT Bearer authorization grant. +// This is used in SEP-990 to exchange an ID-JAG for an access token at the MCP Server. +const GrantTypeJWTBearer = "urn:ietf:params:oauth:grant-type:jwt-bearer" + +// JWTBearerResponse represents the response from a JWT Bearer grant request +// per RFC 7523. This uses the standard OAuth 2.0 token response format. +type JWTBearerResponse struct { + // AccessToken is the OAuth access token issued by the MCP Server's + // authorization server. + AccessToken string `json:"access_token"` + // TokenType is the type of token issued. This is typically "Bearer". + TokenType string `json:"token_type"` + // ExpiresIn is the lifetime in seconds of the access token. + ExpiresIn int `json:"expires_in,omitempty"` + // RefreshToken is the refresh token, which can be used to obtain new + // access tokens using the same authorization grant. + RefreshToken string `json:"refresh_token,omitempty"` + // Scope is the scope of the access token as described by RFC 6749 Section 3.3. + Scope string `json:"scope,omitempty"` +} + +// JWTBearerError represents an error response from a JWT Bearer grant request. +type JWTBearerError struct { + // ErrorCode is the error code as defined in RFC 6749 Section 5.2. + // The JSON field name is "error" per the RFC specification. + ErrorCode string `json:"error"` + // ErrorDescription is a human-readable description of the error. + ErrorDescription string `json:"error_description,omitempty"` + // ErrorURI is a URI identifying a human-readable web page with information + // about the error. + ErrorURI string `json:"error_uri,omitempty"` +} + +func (e *JWTBearerError) Error() string { + if e.ErrorDescription != "" { + return fmt.Sprintf("JWT bearer grant failed: %s (%s)", e.ErrorCode, e.ErrorDescription) + } + return fmt.Sprintf("JWT bearer grant failed: %s", e.ErrorCode) +} + +// ExchangeJWTBearer exchanges an Identity Assertion JWT Authorization Grant (ID-JAG) +// for an access token using JWT Bearer Grant per RFC 7523. This is the second step +// in Enterprise Managed Authorization (SEP-990) after obtaining the ID-JAG from the +// IdP via token exchange. +// +// The tokenEndpoint parameter should be the MCP Server's token endpoint (typically +// obtained from the MCP Server's authorization server metadata). +// +// The assertion parameter should be the ID-JAG JWT obtained from the token exchange +// step with the enterprise IdP. +// +// Client authentication must be performed by the caller by including appropriate +// credentials in the request (e.g., using Basic auth via the Authorization header, +// or including client_id and client_secret in the form data). +// +// Example: +// +// // First, get ID-JAG via token exchange +// idJAG := tokenExchangeResp.AccessToken +// +// // Then exchange ID-JAG for access token +// token, err := ExchangeJWTBearer( +// ctx, +// "https://auth.mcpserver.example/oauth2/token", +// idJAG, +// "mcp-client-id", +// "mcp-client-secret", +// nil, +// ) +func ExchangeJWTBearer( + ctx context.Context, + tokenEndpoint string, + assertion string, + clientID string, + clientSecret string, + httpClient *http.Client, +) (*oauth2.Token, error) { + if tokenEndpoint == "" { + return nil, fmt.Errorf("token endpoint is required") + } + if assertion == "" { + return nil, fmt.Errorf("assertion is required") + } + // Validate URL scheme to prevent XSS attacks (see #526) + if err := checkURLScheme(tokenEndpoint); err != nil { + return nil, fmt.Errorf("invalid token endpoint: %w", err) + } + // Build the JWT Bearer grant request per RFC 7523 Section 2.1 + formData := url.Values{} + formData.Set("grant_type", GrantTypeJWTBearer) + formData.Set("assertion", assertion) + // Add client authentication (following OAuth 2.0 client_secret_post method) + // Note: Per SEP-990 Section 5.1, the client_id in the assertion must match + // the authenticated client + if clientID != "" { + formData.Set("client_id", clientID) + } + if clientSecret != "" { + formData.Set("client_secret", clientSecret) + } + // Create HTTP request + httpReq, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + tokenEndpoint, + strings.NewReader(formData.Encode()), + ) + if err != nil { + return nil, fmt.Errorf("failed to create JWT bearer grant request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + httpReq.Header.Set("Accept", "application/json") + // Use provided client or default + if httpClient == nil { + httpClient = http.DefaultClient + } + // Execute the request + httpResp, err := httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("JWT bearer grant request failed: %w", err) + } + defer httpResp.Body.Close() + // Read response body (limit to 1MB for safety, following SDK pattern) + body, err := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("failed to read JWT bearer grant response: %w", err) + } + // Handle success response (200 OK per OAuth 2.0) + if httpResp.StatusCode == http.StatusOK { + var resp JWTBearerResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to parse JWT bearer grant response: %w (body: %s)", err, string(body)) + } + // Validate response per OAuth 2.0 + if resp.AccessToken == "" { + return nil, fmt.Errorf("response missing required field: access_token") + } + if resp.TokenType == "" { + return nil, fmt.Errorf("response missing required field: token_type") + } + // Convert to golang.org/x/oauth2.Token + token := &oauth2.Token{ + AccessToken: resp.AccessToken, + TokenType: resp.TokenType, + RefreshToken: resp.RefreshToken, + } + // Set expiration if provided + if resp.ExpiresIn > 0 { + token.Expiry = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second) + } + // Add scope to extra data if provided + if resp.Scope != "" { + token = token.WithExtra(map[string]interface{}{ + "scope": resp.Scope, + }) + } + return token, nil + } + // Handle error response (400 Bad Request per RFC 6749) + if httpResp.StatusCode == http.StatusBadRequest { + var errResp JWTBearerError + if err := json.Unmarshal(body, &errResp); err != nil { + return nil, fmt.Errorf("failed to parse error response: %w (body: %s)", err, string(body)) + } + return nil, &errResp + } + // Handle unexpected status codes + return nil, fmt.Errorf("unexpected status code %d: %s", httpResp.StatusCode, string(body)) +} diff --git a/oauthex/jwt_bearer_test.go b/oauthex/jwt_bearer_test.go new file mode 100644 index 00000000..3145d0bf --- /dev/null +++ b/oauthex/jwt_bearer_test.go @@ -0,0 +1,154 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +// TestExchangeJWTBearer tests the JWT Bearer grant flow. +func TestExchangeJWTBearer(t *testing.T) { + // Create a test MCP Server auth server that accepts JWT Bearer grants + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method and content type + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + contentType := r.Header.Get("Content-Type") + if contentType != "application/x-www-form-urlencoded" { + t.Errorf("expected application/x-www-form-urlencoded, got %s", contentType) + http.Error(w, "invalid content type", http.StatusBadRequest) + return + } + // Parse form data + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + // Verify grant type per RFC 7523 + grantType := r.FormValue("grant_type") + if grantType != GrantTypeJWTBearer { + t.Errorf("expected grant_type %s, got %s", GrantTypeJWTBearer, grantType) + writeJWTBearerErrorResponse(w, "unsupported_grant_type", "grant type not supported") + return + } + // Verify assertion is provided + assertion := r.FormValue("assertion") + if assertion == "" { + t.Error("assertion is required") + writeJWTBearerErrorResponse(w, "invalid_request", "missing assertion") + return + } + // Verify client authentication + clientID := r.FormValue("client_id") + clientSecret := r.FormValue("client_secret") + if clientID == "" || clientSecret == "" { + t.Error("client authentication required") + writeJWTBearerErrorResponse(w, "invalid_client", "client authentication failed") + return + } + if clientID != "mcp-client-id" || clientSecret != "mcp-client-secret" { + t.Error("invalid client credentials") + writeJWTBearerErrorResponse(w, "invalid_client", "invalid credentials") + return + } + // Return successful OAuth token response + resp := JWTBearerResponse{ + AccessToken: "mcp-access-token-123", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: "read write", + RefreshToken: "mcp-refresh-token-456", + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + // Test successful JWT Bearer grant + t.Run("successful exchange", func(t *testing.T) { + token, err := ExchangeJWTBearer( + context.Background(), + server.URL, + "fake-id-jag-jwt", + "mcp-client-id", + "mcp-client-secret", + server.Client(), + ) + if err != nil { + t.Fatalf("ExchangeJWTBearer failed: %v", err) + } + if token.AccessToken != "mcp-access-token-123" { + t.Errorf("expected access_token 'mcp-access-token-123', got %s", token.AccessToken) + } + if token.TokenType != "Bearer" { + t.Errorf("expected token_type 'Bearer', got %s", token.TokenType) + } + if token.RefreshToken != "mcp-refresh-token-456" { + t.Errorf("expected refresh_token 'mcp-refresh-token-456', got %s", token.RefreshToken) + } + // Check expiration (should be ~1 hour from now) + expectedExpiry := time.Now().Add(3600 * time.Second) + if token.Expiry.Before(time.Now()) || token.Expiry.After(expectedExpiry.Add(5*time.Second)) { + t.Errorf("unexpected expiry time: %v", token.Expiry) + } + // Check scope in extra data + scope, ok := token.Extra("scope").(string) + if !ok || scope != "read write" { + t.Errorf("expected scope 'read write', got %v", token.Extra("scope")) + } + }) + // Test missing assertion + t.Run("missing assertion", func(t *testing.T) { + _, err := ExchangeJWTBearer( + context.Background(), + server.URL, + "", // empty assertion + "mcp-client-id", + "mcp-client-secret", + server.Client(), + ) + if err == nil { + t.Error("expected error for missing assertion, got nil") + } + }) + // Test invalid URL scheme + t.Run("invalid token endpoint URL", func(t *testing.T) { + _, err := ExchangeJWTBearer( + context.Background(), + "javascript:alert(1)", + "fake-id-jag-jwt", + "mcp-client-id", + "mcp-client-secret", + server.Client(), + ) + if err == nil { + t.Error("expected error for invalid URL scheme, got nil") + } + }) +} + +// writeJWTBearerErrorResponse writes an OAuth 2.0 error response per RFC 6749 Section 5.2. +func writeJWTBearerErrorResponse(w http.ResponseWriter, errorCode, errorDescription string) { + errResp := JWTBearerError{ + ErrorCode: errorCode, + ErrorDescription: errorDescription, + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(errResp) +} diff --git a/oauthex/oauth2.go b/oauthex/oauth2.go index cdda695b..ab72f699 100644 --- a/oauthex/oauth2.go +++ b/oauthex/oauth2.go @@ -58,10 +58,9 @@ func getJSON[T any](ctx context.Context, c *http.Client, url string, limit int64 return nil, fmt.Errorf("bad status %s", res.Status) } // Specs require application/json. - ct := res.Header.Get("Content-Type") - mediaType, _, err := mime.ParseMediaType(ct) - if err != nil || mediaType != "application/json" { - return nil, fmt.Errorf("bad content type %q", ct) + ct := strings.TrimSpace(strings.SplitN(res.Header.Get("Content-Type"), ";", 2)[0]) + if ct != "application/json" { + return nil, fmt.Errorf("bad content type %q", res.Header.Get("Content-Type")) } var t T @@ -89,3 +88,9 @@ func checkURLScheme(u string) error { } return nil } + +// CheckURLScheme validates a URL scheme for security. +// This is exported for use by the auth package. +func CheckURLScheme(u string) error { + return checkURLScheme(u) +} diff --git a/oauthex/token_exchange.go b/oauthex/token_exchange.go new file mode 100644 index 00000000..fb162d0b --- /dev/null +++ b/oauthex/token_exchange.go @@ -0,0 +1,267 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Token Exchange (RFC 8693) for Enterprise Managed Authorization. +// See https://datatracker.ietf.org/doc/html/rfc8693 + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" +) + +// Token type identifiers defined by RFC 8693 and SEP-990. +const ( + // TokenTypeIDToken is the URN for OpenID Connect ID Tokens. + TokenTypeIDToken = "urn:ietf:params:oauth:token-type:id_token" + + // TokenTypeSAML2 is the URN for SAML 2.0 assertions. + TokenTypeSAML2 = "urn:ietf:params:oauth:token-type:saml2" + + // TokenTypeIDJAG is the URN for Identity Assertion JWT Authorization Grants. + // This is the token type returned by IdP during token exchange for SEP-990. + TokenTypeIDJAG = "urn:ietf:params:oauth:token-type:id-jag" + + // GrantTypeTokenExchange is the grant type for RFC 8693 token exchange. + GrantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange" +) + +// TokenExchangeRequest represents a Token Exchange request per RFC 8693. +// This is used for Enterprise Managed Authorization (SEP-990) where an MCP Client +// exchanges an ID Token from an enterprise IdP for an ID-JAG that can be used +// to obtain an access token from an MCP Server's authorization server. +type TokenExchangeRequest struct { + // RequestedTokenType indicates the type of security token being requested. + // For SEP-990, this MUST be TokenTypeIDJAG. + RequestedTokenType string + + // Audience is the logical name of the target service where the client + // intends to use the requested token. For SEP-990, this MUST be the + // Issuer URL of the MCP Server's authorization server. + Audience string + + // Resource is the physical location or identifier of the target resource. + // For SEP-990, this MUST be the RFC9728 Resource Identifier of the MCP Server. + Resource string + + // Scope is a list of space-separated scopes for the requested token. + // This is OPTIONAL per RFC 8693 but commonly used in SEP-990. + Scope []string + + // SubjectToken is the security token that represents the identity of the + // party on behalf of whom the request is being made. For SEP-990, this is + // typically an OpenID Connect ID Token. + SubjectToken string + + // SubjectTokenType is the type of the security token in SubjectToken. + // For SEP-990 with OIDC, this MUST be TokenTypeIDToken. + SubjectTokenType string +} + +// TokenExchangeResponse represents the response from a token exchange request +// per RFC 8693 Section 2.2. +type TokenExchangeResponse struct { + // IssuedTokenType is the type of the security token in AccessToken. + // For SEP-990, this MUST be TokenTypeIDJAG. + IssuedTokenType string `json:"issued_token_type"` + + // AccessToken is the security token issued by the authorization server. + // Despite the name "access_token" (required by RFC 8693), for SEP-990 + // this contains an ID-JAG JWT, not an OAuth access token. + AccessToken string `json:"access_token"` + + // TokenType indicates the type of token returned. For SEP-990, this is "N_A" + // because the issued token is not an OAuth access token. + TokenType string `json:"token_type"` + + // Scope is the scope of the issued token, if the issued token scope is + // different from the requested scope. Per RFC 8693, this SHOULD be included + // if the scope differs from the request. + Scope string `json:"scope,omitempty"` + + // ExpiresIn is the lifetime in seconds of the issued token. + ExpiresIn int `json:"expires_in,omitempty"` +} + +// TokenExchangeError represents an error response from a token exchange request. +type TokenExchangeError struct { + // Error is the error code as defined in RFC 6749 Section 5.2. + ErrorCode string `json:"error"` + + // ErrorDescription is a human-readable description of the error. + ErrorDescription string `json:"error_description,omitempty"` + + // ErrorURI is a URI identifying a human-readable web page with information + // about the error. + ErrorURI string `json:"error_uri,omitempty"` +} + +func (e *TokenExchangeError) Error() string { + if e.ErrorDescription != "" { + return fmt.Sprintf("token exchange failed: %s (%s)", e.ErrorCode, e.ErrorDescription) + } + return fmt.Sprintf("token exchange failed: %s", e.ErrorCode) +} + +// ExchangeToken performs a token exchange request per RFC 8693 for Enterprise +// Managed Authorization (SEP-990). It exchanges an identity assertion (typically +// an ID Token) for an Identity Assertion JWT Authorization Grant (ID-JAG) that +// can be used to obtain an access token from an MCP Server. +// +// The tokenEndpoint parameter should be the IdP's token endpoint (typically +// obtained from the IdP's authorization server metadata). +// +// Client authentication must be performed by the caller by including appropriate +// credentials in the request (e.g., using Basic auth via the Authorization header, +// or including client_id and client_secret in the form data). +// +// Example: +// +// req := &TokenExchangeRequest{ +// RequestedTokenType: TokenTypeIDJAG, +// Audience: "https://auth.mcpserver.example/", +// Resource: "https://mcp.mcpserver.example/", +// Scope: []string{"read", "write"}, +// SubjectToken: idToken, +// SubjectTokenType: TokenTypeIDToken, +// } +// +// resp, err := ExchangeToken(ctx, idpTokenEndpoint, req, clientID, clientSecret, nil) +func ExchangeToken( + ctx context.Context, + tokenEndpoint string, + req *TokenExchangeRequest, + clientID string, + clientSecret string, + httpClient *http.Client, +) (*TokenExchangeResponse, error) { + if tokenEndpoint == "" { + return nil, fmt.Errorf("token endpoint is required") + } + if req == nil { + return nil, fmt.Errorf("token exchange request is required") + } + + // Validate required fields per SEP-990 Section 4 + if req.RequestedTokenType == "" { + return nil, fmt.Errorf("requested_token_type is required") + } + if req.Audience == "" { + return nil, fmt.Errorf("audience is required") + } + if req.Resource == "" { + return nil, fmt.Errorf("resource is required") + } + if req.SubjectToken == "" { + return nil, fmt.Errorf("subject_token is required") + } + if req.SubjectTokenType == "" { + return nil, fmt.Errorf("subject_token_type is required") + } + + // Validate URL schemes to prevent XSS attacks (see #526) + if err := checkURLScheme(tokenEndpoint); err != nil { + return nil, fmt.Errorf("invalid token endpoint: %w", err) + } + if err := checkURLScheme(req.Audience); err != nil { + return nil, fmt.Errorf("invalid audience: %w", err) + } + if err := checkURLScheme(req.Resource); err != nil { + return nil, fmt.Errorf("invalid resource: %w", err) + } + + // Build the token exchange request body per RFC 8693 + formData := url.Values{} + formData.Set("grant_type", GrantTypeTokenExchange) + formData.Set("requested_token_type", req.RequestedTokenType) + formData.Set("audience", req.Audience) + formData.Set("resource", req.Resource) + formData.Set("subject_token", req.SubjectToken) + formData.Set("subject_token_type", req.SubjectTokenType) + + if len(req.Scope) > 0 { + formData.Set("scope", strings.Join(req.Scope, " ")) + } + + // Add client authentication (following OAuth 2.0 client_secret_post method) + if clientID != "" { + formData.Set("client_id", clientID) + } + if clientSecret != "" { + formData.Set("client_secret", clientSecret) + } + + // Create HTTP request + httpReq, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + tokenEndpoint, + strings.NewReader(formData.Encode()), + ) + if err != nil { + return nil, fmt.Errorf("failed to create token exchange request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + httpReq.Header.Set("Accept", "application/json") + + // Use provided client or default + if httpClient == nil { + httpClient = http.DefaultClient + } + + // Execute the request + httpResp, err := httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("token exchange request failed: %w", err) + } + defer httpResp.Body.Close() + + // Read response body (limit to 1MB for safety, following SDK pattern) + body, err := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("failed to read token exchange response: %w", err) + } + + // Handle success response (200 OK per RFC 8693) + if httpResp.StatusCode == http.StatusOK { + var resp TokenExchangeResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to parse token exchange response: %w (body: %s)", err, string(body)) + } + + // Validate response per SEP-990 Section 4.2 + if resp.IssuedTokenType == "" { + return nil, fmt.Errorf("response missing required field: issued_token_type") + } + if resp.AccessToken == "" { + return nil, fmt.Errorf("response missing required field: access_token") + } + if resp.TokenType == "" { + return nil, fmt.Errorf("response missing required field: token_type") + } + + return &resp, nil + } + + // Handle error response (400 Bad Request per RFC 6749) + if httpResp.StatusCode == http.StatusBadRequest { + var errResp TokenExchangeError + if err := json.Unmarshal(body, &errResp); err != nil { + return nil, fmt.Errorf("failed to parse error response: %w (body: %s)", err, string(body)) + } + return nil, &errResp + } + + // Handle unexpected status codes + return nil, fmt.Errorf("unexpected status code %d: %s", httpResp.StatusCode, string(body)) +} diff --git a/oauthex/token_exchange_test.go b/oauthex/token_exchange_test.go new file mode 100644 index 00000000..316fa8e3 --- /dev/null +++ b/oauthex/token_exchange_test.go @@ -0,0 +1,220 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +// TestExchangeToken tests the basic token exchange flow. +func TestExchangeToken(t *testing.T) { + // Create a test IdP server that implements token exchange + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method and content type + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + contentType := r.Header.Get("Content-Type") + if contentType != "application/x-www-form-urlencoded" { + t.Errorf("expected application/x-www-form-urlencoded, got %s", contentType) + http.Error(w, "invalid content type", http.StatusBadRequest) + return + } + + // Parse form data + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + + // Verify required parameters per SEP-990 Section 4 + grantType := r.FormValue("grant_type") + if grantType != GrantTypeTokenExchange { + t.Errorf("expected grant_type %s, got %s", GrantTypeTokenExchange, grantType) + writeErrorResponse(w, "invalid_grant", "invalid grant_type") + return + } + + requestedTokenType := r.FormValue("requested_token_type") + if requestedTokenType != TokenTypeIDJAG { + t.Errorf("expected requested_token_type %s, got %s", TokenTypeIDJAG, requestedTokenType) + writeErrorResponse(w, "invalid_request", "invalid requested_token_type") + return + } + + audience := r.FormValue("audience") + if audience == "" { + t.Error("audience is required") + writeErrorResponse(w, "invalid_request", "missing audience") + return + } + + resource := r.FormValue("resource") + if resource == "" { + t.Error("resource is required") + writeErrorResponse(w, "invalid_request", "missing resource") + return + } + + subjectToken := r.FormValue("subject_token") + if subjectToken == "" { + t.Error("subject_token is required") + writeErrorResponse(w, "invalid_request", "missing subject_token") + return + } + + subjectTokenType := r.FormValue("subject_token_type") + if subjectTokenType != TokenTypeIDToken { + t.Errorf("expected subject_token_type %s, got %s", TokenTypeIDToken, subjectTokenType) + writeErrorResponse(w, "invalid_request", "invalid subject_token_type") + return + } + + // Verify client authentication + clientID := r.FormValue("client_id") + clientSecret := r.FormValue("client_secret") + if clientID == "" || clientSecret == "" { + t.Error("client authentication required") + writeErrorResponse(w, "invalid_client", "client authentication failed") + return + } + + if clientID != "test-client-id" || clientSecret != "test-client-secret" { + t.Error("invalid client credentials") + writeErrorResponse(w, "invalid_client", "invalid credentials") + return + } + + // Return successful token exchange response per SEP-990 Section 4.2 + resp := TokenExchangeResponse{ + IssuedTokenType: TokenTypeIDJAG, + AccessToken: "fake-id-jag-token", + TokenType: "N_A", + Scope: r.FormValue("scope"), + ExpiresIn: 300, + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + // Test successful token exchange + t.Run("successful exchange", func(t *testing.T) { + req := &TokenExchangeRequest{ + RequestedTokenType: TokenTypeIDJAG, + Audience: "https://auth.mcpserver.example/", + Resource: "https://mcp.mcpserver.example/", + Scope: []string{"read", "write"}, + SubjectToken: "fake-id-token", + SubjectTokenType: TokenTypeIDToken, + } + + resp, err := ExchangeToken( + context.Background(), + server.URL, + req, + "test-client-id", + "test-client-secret", + server.Client(), + ) + + if err != nil { + t.Fatalf("ExchangeToken failed: %v", err) + } + + if resp.IssuedTokenType != TokenTypeIDJAG { + t.Errorf("expected issued_token_type %s, got %s", TokenTypeIDJAG, resp.IssuedTokenType) + } + + if resp.AccessToken != "fake-id-jag-token" { + t.Errorf("expected access_token 'fake-id-jag-token', got %s", resp.AccessToken) + } + + if resp.TokenType != "N_A" { + t.Errorf("expected token_type 'N_A', got %s", resp.TokenType) + } + + if resp.Scope != "read write" { + t.Errorf("expected scope 'read write', got %s", resp.Scope) + } + + if resp.ExpiresIn != 300 { + t.Errorf("expected expires_in 300, got %d", resp.ExpiresIn) + } + }) + + // Test missing required fields + t.Run("missing audience", func(t *testing.T) { + req := &TokenExchangeRequest{ + RequestedTokenType: TokenTypeIDJAG, + Resource: "https://mcp.mcpserver.example/", + SubjectToken: "fake-id-token", + SubjectTokenType: TokenTypeIDToken, + } + + _, err := ExchangeToken( + context.Background(), + server.URL, + req, + "test-client-id", + "test-client-secret", + server.Client(), + ) + + if err == nil { + t.Error("expected error for missing audience, got nil") + } + }) + + // Test invalid URL schemes + t.Run("invalid audience URL scheme", func(t *testing.T) { + req := &TokenExchangeRequest{ + RequestedTokenType: TokenTypeIDJAG, + Audience: "javascript:alert(1)", + Resource: "https://mcp.mcpserver.example/", + SubjectToken: "fake-id-token", + SubjectTokenType: TokenTypeIDToken, + } + + _, err := ExchangeToken( + context.Background(), + server.URL, + req, + "test-client-id", + "test-client-secret", + server.Client(), + ) + + if err == nil { + t.Error("expected error for invalid audience URL scheme, got nil") + } + }) +} + +// writeErrorResponse writes an OAuth 2.0 error response per RFC 6749 Section 5.2. +func writeErrorResponse(w http.ResponseWriter, errorCode, errorDescription string) { + errResp := TokenExchangeError{ + ErrorCode: errorCode, + ErrorDescription: errorDescription, + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(errResp) +}