diff --git a/docs/error-handling.md b/docs/error-handling.md index a2ecfe83d..bf639a7ea 100644 --- a/docs/error-handling.md +++ b/docs/error-handling.md @@ -23,10 +23,51 @@ There are two acceptable ways to construct errors in ToolHive: ToolHive provides a typed error system for common error scenarios. Each error type has an associated HTTP status code for API responses. -Error types are defined in: -- `pkg/errors/errors.go` - Core application errors (CLI, proxy, etc.) +### Creating Errors with HTTP Status Codes + +Use `errors.WithCode()` to associate HTTP status codes with errors: + +```go +import ( + "errors" + "net/http" + + thverrors "github.com/stacklok/toolhive/pkg/errors" +) + +// Define an error with a status code +var ErrWorkloadNotFound = thverrors.WithCode( + errors.New("workload not found"), + http.StatusNotFound, +) + +// Create a new error inline with a status code +return thverrors.WithCode( + fmt.Errorf("invalid request: %w", err), + http.StatusBadRequest, +) +``` + +### Extracting Status Codes + +Use `errors.Code()` to extract the HTTP status code from an error: + +```go +code := thverrors.Code(err) // Returns 500 if no code is found +``` + +### Error Definitions + +Error types with HTTP status codes are defined in: +- `pkg/errors/errors.go` - Core error utilities (`WithCode`, `Code`, `CodedError`) +- `pkg/groups/errors.go` - Group-related errors +- `pkg/container/runtime/types.go` - Runtime errors (`ErrWorkloadNotFound`) +- `pkg/workloads/types/validate.go` - Workload validation errors +- `pkg/secrets/factory.go` - Secrets provider errors +- `pkg/transport/session/errors.go` - Transport session errors - `pkg/vmcp/errors.go` - Virtual MCP Server domain errors +In general, define errors near the code that produces the error. ## Error Wrapping Guidelines @@ -54,23 +95,41 @@ which particular step is failing. Consider using `errors.WithStack` or `errors.W ## API Error Handling -### Response Format - -API errors are returned as plain text using `http.Error()`: +### Handler Pattern -Response codes are derived from unwrapping the error and this happens in a common middleware layer. +API handlers return errors instead of calling `http.Error()` directly. The `ErrorHandler` decorator in `pkg/api/errors/handler.go` converts errors to HTTP responses: -See pkg/api/errors/ for more details. -TODO: implement common middleware for error and panic handling. -TODO: integrate handler into setupDefaultRoutes. -TODO: update documentation on APIs. +```go +// Define a handler that returns an error +func (s *Routes) getWorkload(w http.ResponseWriter, r *http.Request) error { + workload, err := s.manager.GetWorkload(ctx, name) + if err != nil { + return err // ErrWorkloadNotFound already has 404 status code + } + + // For errors without a status code, wrap with WithCode + if someCondition { + return thverrors.WithCode( + fmt.Errorf("invalid input"), + http.StatusBadRequest, + ) + } + + // Success case - write response + return json.NewEncoder(w).Encode(workload) +} +// Wire up with the ErrorHandler decorator +r.Get("/{name}", apierrors.ErrorHandler(routes.getWorkload)) +``` ### Error Response Behavior -1. **First matching error code wins** - Check specific errors first, then fall back to generic internal server error. -2. **Hide internal details** - For 500 errors, log the full error but return a generic message to the user -3. **Include context for client errors** - For 400/404 errors, include helpful context in the message +1. **Status codes from errors** - The `ErrorHandler` extracts status codes using `errors.Code()`. Errors without codes default to 500. +2. **Hide internal details** - For 5xx errors, the full error is logged but only a generic message is returned to the user. +3. **Include context for client errors** - For 4xx errors, the error message is returned to the client. + +See `pkg/api/errors/handler.go` for implementation details. ## CLI Error Handling diff --git a/pkg/api/errors/handler.go b/pkg/api/errors/handler.go new file mode 100644 index 000000000..a93bc66cc --- /dev/null +++ b/pkg/api/errors/handler.go @@ -0,0 +1,49 @@ +// Package errors provides HTTP error handling utilities for the API. +package errors + +import ( + "net/http" + + "github.com/stacklok/toolhive/pkg/errors" + "github.com/stacklok/toolhive/pkg/logger" +) + +// HandlerWithError is an HTTP handler that can return an error. +// This signature allows handlers to return errors instead of manually +// writing error responses, enabling centralized error handling. +type HandlerWithError func(http.ResponseWriter, *http.Request) error + +// ErrorHandler wraps a HandlerWithError and converts returned errors +// into appropriate HTTP responses. +// +// The decorator: +// - Returns early if no error is returned (handler already wrote response) +// - Extracts HTTP status code from the error using errors.Code() +// - For 5xx errors: logs full error details, returns generic message to client +// - For 4xx errors: returns error message to client +// +// Usage: +// +// r.Get("/{name}", apierrors.ErrorHandler(routes.getWorkload)) +func ErrorHandler(fn HandlerWithError) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + err := fn(w, r) + if err == nil { + // No error returned, handler already wrote the response + return + } + + // Extract HTTP status code from the error + code := errors.Code(err) + + // For 5xx errors, log the full error but return a generic message + if code >= http.StatusInternalServerError { + logger.Errorf("Internal server error: %v", err) + http.Error(w, http.StatusText(code), code) + return + } + + // For 4xx errors, return the error message to the client + http.Error(w, err.Error(), code) + } +} diff --git a/pkg/api/errors/handler_test.go b/pkg/api/errors/handler_test.go new file mode 100644 index 000000000..124e25dfe --- /dev/null +++ b/pkg/api/errors/handler_test.go @@ -0,0 +1,168 @@ +package errors + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + thverrors "github.com/stacklok/toolhive/pkg/errors" +) + +func TestErrorHandler(t *testing.T) { + t.Parallel() + + t.Run("passes through successful response", func(t *testing.T) { + t.Parallel() + + handler := ErrorHandler(func(w http.ResponseWriter, _ *http.Request) error { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("success")) + return nil + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "success", rec.Body.String()) + }) + + t.Run("converts 400 error to HTTP response with message", func(t *testing.T) { + t.Parallel() + + handler := ErrorHandler(func(_ http.ResponseWriter, _ *http.Request) error { + return thverrors.WithCode( + fmt.Errorf("invalid input"), + http.StatusBadRequest, + ) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "invalid input") + }) + + t.Run("converts 404 error to HTTP response with message", func(t *testing.T) { + t.Parallel() + + handler := ErrorHandler(func(_ http.ResponseWriter, _ *http.Request) error { + return thverrors.WithCode( + fmt.Errorf("resource not found"), + http.StatusNotFound, + ) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusNotFound, rec.Code) + require.Contains(t, rec.Body.String(), "resource not found") + }) + + t.Run("converts 409 error to HTTP response with message", func(t *testing.T) { + t.Parallel() + + handler := ErrorHandler(func(_ http.ResponseWriter, _ *http.Request) error { + return thverrors.WithCode( + fmt.Errorf("resource already exists"), + http.StatusConflict, + ) + }) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusConflict, rec.Code) + require.Contains(t, rec.Body.String(), "resource already exists") + }) + + t.Run("converts 500 error to generic HTTP response", func(t *testing.T) { + t.Parallel() + + handler := ErrorHandler(func(_ http.ResponseWriter, _ *http.Request) error { + return thverrors.WithCode( + fmt.Errorf("sensitive database error details"), + http.StatusInternalServerError, + ) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusInternalServerError, rec.Code) + // Should NOT contain the sensitive error details + require.False(t, strings.Contains(rec.Body.String(), "sensitive")) + // Should contain generic message + require.Contains(t, rec.Body.String(), "Internal Server Error") + }) + + t.Run("error without code defaults to 500 with generic message", func(t *testing.T) { + t.Parallel() + + handler := ErrorHandler(func(_ http.ResponseWriter, _ *http.Request) error { + return errors.New("plain error without code") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusInternalServerError, rec.Code) + // Should NOT contain the original error details + require.False(t, strings.Contains(rec.Body.String(), "plain error")) + // Should contain generic message + require.Contains(t, rec.Body.String(), "Internal Server Error") + }) + + t.Run("handles wrapped error with code", func(t *testing.T) { + t.Parallel() + + sentinelErr := thverrors.WithCode( + errors.New("not found"), + http.StatusNotFound, + ) + + handler := ErrorHandler(func(_ http.ResponseWriter, _ *http.Request) error { + return fmt.Errorf("workload lookup failed: %w", sentinelErr) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusNotFound, rec.Code) + require.Contains(t, rec.Body.String(), "workload lookup failed") + }) +} + +func TestHandlerWithError_Type(t *testing.T) { + t.Parallel() + + // Ensure HandlerWithError can be used as expected + var handler HandlerWithError = func(w http.ResponseWriter, _ *http.Request) error { + w.WriteHeader(http.StatusOK) + return nil + } + + wrapped := ErrorHandler(handler) + require.NotNil(t, wrapped) +} diff --git a/pkg/api/v1/clients.go b/pkg/api/v1/clients.go index 097236cd9..a774ec91f 100644 --- a/pkg/api/v1/clients.go +++ b/pkg/api/v1/clients.go @@ -8,9 +8,11 @@ import ( "github.com/go-chi/chi/v5" + apierrors "github.com/stacklok/toolhive/pkg/api/errors" "github.com/stacklok/toolhive/pkg/client" "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/core" + thverrors "github.com/stacklok/toolhive/pkg/errors" "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/workloads" @@ -36,12 +38,12 @@ func ClientRouter( } r := chi.NewRouter() - r.Get("/", routes.listClients) - r.Post("/", routes.registerClient) - r.Delete("/{name}", routes.unregisterClient) - r.Delete("/{name}/groups/{group}", routes.unregisterClientFromGroup) - r.Post("/register", routes.registerClientsBulk) - r.Post("/unregister", routes.unregisterClientsBulk) + r.Get("/", apierrors.ErrorHandler(routes.listClients)) + r.Post("/", apierrors.ErrorHandler(routes.registerClient)) + r.Delete("/{name}", apierrors.ErrorHandler(routes.unregisterClient)) + r.Delete("/{name}/groups/{group}", apierrors.ErrorHandler(routes.unregisterClientFromGroup)) + r.Post("/register", apierrors.ErrorHandler(routes.registerClientsBulk)) + r.Post("/unregister", apierrors.ErrorHandler(routes.unregisterClientsBulk)) return r } @@ -53,20 +55,17 @@ func ClientRouter( // @Produce json // @Success 200 {array} client.RegisteredClient // @Router /api/v1beta/clients [get] -func (c *ClientRoutes) listClients(w http.ResponseWriter, r *http.Request) { +func (c *ClientRoutes) listClients(w http.ResponseWriter, r *http.Request) error { clients, err := c.clientManager.ListClients(r.Context()) if err != nil { - logger.Errorf("Failed to list clients: %v", err) - http.Error(w, "Failed to list clients", http.StatusInternalServerError) - return + return fmt.Errorf("failed to list clients: %w", err) } w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(clients) - if err != nil { - http.Error(w, "Failed to encode client list", http.StatusInternalServerError) - return + if err := json.NewEncoder(w).Encode(clients); err != nil { + return fmt.Errorf("failed to encode client list: %w", err) } + return nil } // registerClient @@ -80,13 +79,13 @@ func (c *ClientRoutes) listClients(w http.ResponseWriter, r *http.Request) { // @Success 200 {object} createClientResponse // @Failure 400 {string} string "Invalid request" // @Router /api/v1beta/clients [post] -func (c *ClientRoutes) registerClient(w http.ResponseWriter, r *http.Request) { +func (c *ClientRoutes) registerClient(w http.ResponseWriter, r *http.Request) error { var newClient createClientRequest - err := json.NewDecoder(r.Body).Decode(&newClient) - if err != nil { - logger.Errorf("Failed to decode request body: %v", err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return + if err := json.NewDecoder(r.Body).Decode(&newClient); err != nil { + return thverrors.WithCode( + fmt.Errorf("invalid request body: %w", err), + http.StatusBadRequest, + ) } // Default groups to "default" group if it exists @@ -100,19 +99,16 @@ func (c *ClientRoutes) registerClient(w http.ResponseWriter, r *http.Request) { } } - err = c.performClientRegistration(r.Context(), []client.Client{{Name: newClient.Name}}, newClient.Groups) - if err != nil { - logger.Errorf("Failed to register client: %v", err) - http.Error(w, "Failed to register client", http.StatusInternalServerError) - return + if err := c.performClientRegistration(r.Context(), []client.Client{{Name: newClient.Name}}, newClient.Groups); err != nil { + return fmt.Errorf("failed to register client: %w", err) } w.Header().Set("Content-Type", "application/json") resp := createClientResponse(newClient) - if err = json.NewEncoder(w).Encode(resp); err != nil { - http.Error(w, "Failed to marshal server details", http.StatusInternalServerError) - return + if err := json.NewEncoder(w).Encode(resp); err != nil { + return fmt.Errorf("failed to marshal server details: %w", err) } + return nil } // unregisterClient @@ -124,21 +120,21 @@ func (c *ClientRoutes) registerClient(w http.ResponseWriter, r *http.Request) { // @Success 204 // @Failure 400 {string} string "Invalid request" // @Router /api/v1beta/clients/{name} [delete] -func (c *ClientRoutes) unregisterClient(w http.ResponseWriter, r *http.Request) { +func (c *ClientRoutes) unregisterClient(w http.ResponseWriter, r *http.Request) error { clientName := chi.URLParam(r, "name") if clientName == "" { - http.Error(w, "Client name is required", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("client name is required"), + http.StatusBadRequest, + ) } - err := c.removeClient(r.Context(), []client.Client{{Name: client.MCPClient(clientName)}}, nil) - if err != nil { - logger.Errorf("Failed to unregister client: %v", err) - http.Error(w, "Failed to unregister client", http.StatusInternalServerError) - return + if err := c.removeClient(r.Context(), []client.Client{{Name: client.MCPClient(clientName)}}, nil); err != nil { + return fmt.Errorf("failed to unregister client: %w", err) } w.WriteHeader(http.StatusNoContent) + return nil } // unregisterClientFromGroup @@ -152,28 +148,30 @@ func (c *ClientRoutes) unregisterClient(w http.ResponseWriter, r *http.Request) // @Failure 400 {string} string "Invalid request" // @Failure 404 {string} string "Client or group not found" // @Router /api/v1beta/clients/{name}/groups/{group} [delete] -func (c *ClientRoutes) unregisterClientFromGroup(w http.ResponseWriter, r *http.Request) { +func (c *ClientRoutes) unregisterClientFromGroup(w http.ResponseWriter, r *http.Request) error { clientName := chi.URLParam(r, "name") if clientName == "" { - http.Error(w, "Client name is required", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("client name is required"), + http.StatusBadRequest, + ) } groupName := chi.URLParam(r, "group") if groupName == "" { - http.Error(w, "Group name is required", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("group name is required"), + http.StatusBadRequest, + ) } // Remove client from the specific group - err := c.removeClient(r.Context(), []client.Client{{Name: client.MCPClient(clientName)}}, []string{groupName}) - if err != nil { - logger.Errorf("Failed to unregister client from group: %v", err) - http.Error(w, "Failed to unregister client from group", http.StatusInternalServerError) - return + if err := c.removeClient(r.Context(), []client.Client{{Name: client.MCPClient(clientName)}}, []string{groupName}); err != nil { + return fmt.Errorf("failed to unregister client from group: %w", err) } w.WriteHeader(http.StatusNoContent) + return nil } // registerClientsBulk @@ -187,18 +185,20 @@ func (c *ClientRoutes) unregisterClientFromGroup(w http.ResponseWriter, r *http. // @Success 200 {array} createClientResponse // @Failure 400 {string} string "Invalid request" // @Router /api/v1beta/clients/register [post] -func (c *ClientRoutes) registerClientsBulk(w http.ResponseWriter, r *http.Request) { +func (c *ClientRoutes) registerClientsBulk(w http.ResponseWriter, r *http.Request) error { var req bulkClientRequest - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - logger.Errorf("Failed to decode request body: %v", err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return thverrors.WithCode( + fmt.Errorf("invalid request body: %w", err), + http.StatusBadRequest, + ) } if len(req.Names) == 0 { - http.Error(w, "At least one client name is required", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("at least one client name is required"), + http.StatusBadRequest, + ) } clients := make([]client.Client, len(req.Names)) @@ -206,11 +206,8 @@ func (c *ClientRoutes) registerClientsBulk(w http.ResponseWriter, r *http.Reques clients[i] = client.Client{Name: name} } - err = c.performClientRegistration(r.Context(), clients, req.Groups) - if err != nil { - logger.Errorf("Failed to register clients: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return + if err := c.performClientRegistration(r.Context(), clients, req.Groups); err != nil { + return fmt.Errorf("failed to register clients: %w", err) } responses := make([]createClientResponse, len(req.Names)) @@ -219,10 +216,10 @@ func (c *ClientRoutes) registerClientsBulk(w http.ResponseWriter, r *http.Reques } w.Header().Set("Content-Type", "application/json") - if err = json.NewEncoder(w).Encode(responses); err != nil { - http.Error(w, "Failed to marshal response", http.StatusInternalServerError) - return + if err := json.NewEncoder(w).Encode(responses); err != nil { + return fmt.Errorf("failed to marshal response: %w", err) } + return nil } // unregisterClientsBulk @@ -235,18 +232,20 @@ func (c *ClientRoutes) registerClientsBulk(w http.ResponseWriter, r *http.Reques // @Success 204 // @Failure 400 {string} string "Invalid request" // @Router /api/v1beta/clients/unregister [post] -func (c *ClientRoutes) unregisterClientsBulk(w http.ResponseWriter, r *http.Request) { +func (c *ClientRoutes) unregisterClientsBulk(w http.ResponseWriter, r *http.Request) error { var req bulkClientRequest - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - logger.Errorf("Failed to decode request body: %v", err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return thverrors.WithCode( + fmt.Errorf("invalid request body: %w", err), + http.StatusBadRequest, + ) } if len(req.Names) == 0 { - http.Error(w, "At least one client name is required", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("at least one client name is required"), + http.StatusBadRequest, + ) } // Convert to client.Client slice @@ -255,14 +254,12 @@ func (c *ClientRoutes) unregisterClientsBulk(w http.ResponseWriter, r *http.Requ clients[i] = client.Client{Name: name} } - err = c.removeClient(r.Context(), clients, req.Groups) - if err != nil { - logger.Errorf("Failed to unregister clients: %v", err) - http.Error(w, "Failed to unregister clients", http.StatusInternalServerError) - return + if err := c.removeClient(r.Context(), clients, req.Groups); err != nil { + return fmt.Errorf("failed to unregister clients: %w", err) } w.WriteHeader(http.StatusNoContent) + return nil } type createClientRequest struct { @@ -353,13 +350,13 @@ func (c *ClientRoutes) removeClient(ctx context.Context, clients []client.Client } if len(groupNames) > 0 { - return c.removeClientFromGroup(ctx, clients, groupNames, runningWorkloads) + return c.removeClientFromGroupInternal(ctx, clients, groupNames, runningWorkloads) } return c.removeClientGlobally(ctx, clients, runningWorkloads) } -func (c *ClientRoutes) removeClientFromGroup( +func (c *ClientRoutes) removeClientFromGroupInternal( ctx context.Context, clients []client.Client, groupNames []string, diff --git a/pkg/api/v1/groups.go b/pkg/api/v1/groups.go index 01953a04f..549ac49f0 100644 --- a/pkg/api/v1/groups.go +++ b/pkg/api/v1/groups.go @@ -3,15 +3,16 @@ package v1 import ( "context" "encoding/json" - "errors" "fmt" "net/http" "github.com/go-chi/chi/v5" + apierrors "github.com/stacklok/toolhive/pkg/api/errors" "github.com/stacklok/toolhive/pkg/client" "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/core" + thverrors "github.com/stacklok/toolhive/pkg/errors" "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/validation" @@ -34,10 +35,10 @@ func GroupsRouter(groupManager groups.Manager, workloadManager workloads.Manager } r := chi.NewRouter() - r.Get("/", routes.listGroups) - r.Post("/", routes.createGroup) - r.Get("/{name}", routes.getGroup) - r.Delete("/{name}", routes.deleteGroup) + r.Get("/", apierrors.ErrorHandler(routes.listGroups)) + r.Post("/", apierrors.ErrorHandler(routes.createGroup)) + r.Get("/{name}", apierrors.ErrorHandler(routes.getGroup)) + r.Delete("/{name}", apierrors.ErrorHandler(routes.deleteGroup)) return r } @@ -57,22 +58,18 @@ func GroupsRouter(groupManager groups.Manager, workloadManager workloads.Manager // @Success 200 {object} groupListResponse // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/groups [get] -func (s *GroupsRoutes) listGroups(w http.ResponseWriter, r *http.Request) { +func (s *GroupsRoutes) listGroups(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() groupList, err := s.groupManager.List(ctx) if err != nil { - logger.Errorf("Failed to list groups: %v", err) - http.Error(w, "Failed to list groups", http.StatusInternalServerError) - return + return fmt.Errorf("failed to list groups: %w", err) } w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(groupListResponse{Groups: groupList}) - if err != nil { - logger.Errorf("Failed to marshal group list: %v", err) - http.Error(w, "Failed to marshal group list", http.StatusInternalServerError) - return + if err := json.NewEncoder(w).Encode(groupListResponse{Groups: groupList}); err != nil { + return fmt.Errorf("failed to marshal group list: %w", err) } + return nil } // createGroup @@ -88,42 +85,37 @@ func (s *GroupsRoutes) listGroups(w http.ResponseWriter, r *http.Request) { // @Failure 409 {string} string "Conflict" // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/groups [post] -func (s *GroupsRoutes) createGroup(w http.ResponseWriter, r *http.Request) { +func (s *GroupsRoutes) createGroup(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() var req createGroupRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - logger.Errorf("Failed to decode create group request: %v", err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("invalid request body: %w", err), + http.StatusBadRequest, + ) } // Validate group name if err := validation.ValidateGroupName(req.Name); err != nil { - logger.Errorf("Invalid group name: %v", err) - http.Error(w, fmt.Sprintf("Invalid group name: %v", err), http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("invalid group name: %w", err), + http.StatusBadRequest, + ) } err := s.groupManager.Create(ctx, req.Name) if err != nil { - logger.Errorf("Failed to create group: %v", err) - if errors.Is(err, groups.ErrGroupAlreadyExists) { - http.Error(w, err.Error(), http.StatusConflict) - } else { - http.Error(w, "Failed to create group", http.StatusInternalServerError) - } - return + return err } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) response := createGroupResponse(req) if err := json.NewEncoder(w).Encode(response); err != nil { - logger.Errorf("Failed to marshal create group response: %v", err) - http.Error(w, "Failed to marshal response", http.StatusInternalServerError) - return + return fmt.Errorf("failed to marshal create group response: %w", err) } + return nil } // getGroup @@ -137,30 +129,28 @@ func (s *GroupsRoutes) createGroup(w http.ResponseWriter, r *http.Request) { // @Failure 404 {string} string "Not Found" // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/groups/{name} [get] -func (s *GroupsRoutes) getGroup(w http.ResponseWriter, r *http.Request) { +func (s *GroupsRoutes) getGroup(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() name := chi.URLParam(r, "name") // Validate group name if err := validation.ValidateGroupName(name); err != nil { - logger.Errorf("Invalid group name: %v", err) - http.Error(w, fmt.Sprintf("Invalid group name: %v", err), http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("invalid group name: %w", err), + http.StatusBadRequest, + ) } group, err := s.groupManager.Get(ctx, name) if err != nil { - logger.Errorf("Failed to get group %s: %v", name, err) - http.Error(w, "Group not found", http.StatusNotFound) - return + return err } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(group); err != nil { - logger.Errorf("Failed to marshal group: %v", err) - http.Error(w, "Failed to marshal group", http.StatusInternalServerError) - return + return fmt.Errorf("failed to marshal group: %w", err) } + return nil } // deleteGroup @@ -175,34 +165,34 @@ func (s *GroupsRoutes) getGroup(w http.ResponseWriter, r *http.Request) { // @Failure 404 {string} string "Not Found" // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/groups/{name} [delete] -func (s *GroupsRoutes) deleteGroup(w http.ResponseWriter, r *http.Request) { +func (s *GroupsRoutes) deleteGroup(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() name := chi.URLParam(r, "name") // Validate group name if err := validation.ValidateGroupName(name); err != nil { - logger.Errorf("Invalid group name: %v", err) - http.Error(w, fmt.Sprintf("Invalid group name: %v", err), http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("invalid group name: %w", err), + http.StatusBadRequest, + ) } // Check if this is the default group if name == groups.DefaultGroup { - http.Error(w, "Cannot delete the default group", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("cannot delete the default group"), + http.StatusBadRequest, + ) } // Check if group exists before deleting exists, err := s.groupManager.Exists(ctx, name) if err != nil { - logger.Errorf("Failed to check if group exists %s: %v", name, err) - http.Error(w, "Failed to check group existence", http.StatusInternalServerError) - return + return fmt.Errorf("failed to check group existence: %w", err) } if !exists { - http.Error(w, "Group not found", http.StatusNotFound) - return + return groups.ErrGroupNotFound } // Get the with-workloads flag from query parameter @@ -211,36 +201,29 @@ func (s *GroupsRoutes) deleteGroup(w http.ResponseWriter, r *http.Request) { // Get all workloads and filter for the group allWorkloads, err := s.workloadManager.ListWorkloads(ctx, true) // listAll=true to include stopped workloads if err != nil { - logger.Errorf("Failed to list workloads: %v", err) - http.Error(w, "Failed to list workloads", http.StatusInternalServerError) - return + return fmt.Errorf("failed to list workloads: %w", err) } groupWorkloads, err := workloads.FilterByGroup(allWorkloads, name) if err != nil { - logger.Errorf("Failed to filter workloads by group %s: %v", name, err) - http.Error(w, "Failed to filter workloads by group", http.StatusInternalServerError) - return + return fmt.Errorf("failed to filter workloads by group: %w", err) } // Handle workloads if any exist if len(groupWorkloads) > 0 { if err := s.handleWorkloadsForGroupDeletion(ctx, name, groupWorkloads, withWorkloads); err != nil { - logger.Errorf("Failed to handle workloads for group %s: %v", name, err) - http.Error(w, "Failed to handle workloads", http.StatusInternalServerError) - return + return fmt.Errorf("failed to handle workloads: %w", err) } } // Delete the group err = s.groupManager.Delete(ctx, name) if err != nil { - logger.Errorf("Failed to delete group %s: %v", name, err) - http.Error(w, "Failed to delete group", http.StatusInternalServerError) - return + return fmt.Errorf("failed to delete group: %w", err) } w.WriteHeader(http.StatusNoContent) + return nil } // handleWorkloadsForGroupDeletion handles workloads when deleting a group diff --git a/pkg/api/v1/groups_test.go b/pkg/api/v1/groups_test.go index 6b9030e2a..9fee0372d 100644 --- a/pkg/api/v1/groups_test.go +++ b/pkg/api/v1/groups_test.go @@ -59,7 +59,7 @@ func TestGroupsRouter(t *testing.T) { gm.EXPECT().List(gomock.Any()).Return(nil, fmt.Errorf("database error")) }, expectedStatus: http.StatusInternalServerError, - expectedBody: "Failed to list groups", + expectedBody: "Internal Server Error", // 5xx errors return generic message }, { name: "create group success", @@ -103,7 +103,7 @@ func TestGroupsRouter(t *testing.T) { // No mock setup needed as JSON parsing fails first }, expectedStatus: http.StatusBadRequest, - expectedBody: "Invalid request body", + expectedBody: "invalid request body", }, { name: "get group success", @@ -121,10 +121,10 @@ func TestGroupsRouter(t *testing.T) { method: "GET", path: "/nonexistent", setupMock: func(gm *groupsmocks.MockManager, _ *workloadsmocks.MockManager) { - gm.EXPECT().Get(gomock.Any(), "nonexistent").Return(nil, fmt.Errorf("group not found")) + gm.EXPECT().Get(gomock.Any(), "nonexistent").Return(nil, groups.ErrGroupNotFound) }, expectedStatus: http.StatusNotFound, - expectedBody: "Group not found", + expectedBody: "group not found", }, { name: "delete group success", @@ -146,7 +146,7 @@ func TestGroupsRouter(t *testing.T) { gm.EXPECT().Exists(gomock.Any(), "nonexistent").Return(false, nil) }, expectedStatus: http.StatusNotFound, - expectedBody: "Group not found", + expectedBody: "group not found", }, { name: "delete default group protected", @@ -156,7 +156,7 @@ func TestGroupsRouter(t *testing.T) { // No mock setup needed as validation happens before manager call }, expectedStatus: http.StatusBadRequest, - expectedBody: "Cannot delete the default group", + expectedBody: "cannot delete the default group", }, { name: "delete group with workloads flag true", diff --git a/pkg/api/v1/secrets.go b/pkg/api/v1/secrets.go index 2149f54e9..24a1054b2 100644 --- a/pkg/api/v1/secrets.go +++ b/pkg/api/v1/secrets.go @@ -6,10 +6,13 @@ import ( "errors" "fmt" "net/http" + "strings" "github.com/go-chi/chi/v5" + apierrors "github.com/stacklok/toolhive/pkg/api/errors" "github.com/stacklok/toolhive/pkg/config" + thverrors "github.com/stacklok/toolhive/pkg/errors" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/secrets" ) @@ -48,16 +51,16 @@ func secretsRouterWithRoutes(routes *SecretsRoutes) http.Handler { r := chi.NewRouter() // Setup secrets provider - r.Post("/", routes.setupSecretsProvider) + r.Post("/", apierrors.ErrorHandler(routes.setupSecretsProvider)) // Default provider routes r.Route("/default", func(r chi.Router) { - r.Get("/", routes.getSecretsProvider) + r.Get("/", apierrors.ErrorHandler(routes.getSecretsProvider)) r.Route("/keys", func(r chi.Router) { - r.Get("/", routes.listSecrets) - r.Post("/", routes.createSecret) - r.Put("/{key}", routes.updateSecret) - r.Delete("/{key}", routes.deleteSecret) + r.Get("/", apierrors.ErrorHandler(routes.listSecrets)) + r.Post("/", apierrors.ErrorHandler(routes.createSecret)) + r.Put("/{key}", apierrors.ErrorHandler(routes.updateSecret)) + r.Delete("/{key}", apierrors.ErrorHandler(routes.deleteSecret)) }) }) @@ -78,12 +81,13 @@ func secretsRouterWithRoutes(routes *SecretsRoutes) http.Handler { // @Failure 400 {string} string "Bad Request" // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/secrets [post] -func (s *SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Request) { +func (s *SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Request) error { var req setupSecretsRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - logger.Errorf("Failed to decode request body: %v", err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("invalid request body: %w", err), + http.StatusBadRequest, + ) } // Validate provider type @@ -96,13 +100,16 @@ func (s *SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Requ case string(secrets.NoneType): providerType = secrets.NoneType case "": - http.Error(w, "Provider type cannot be empty", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("provider type cannot be empty"), + http.StatusBadRequest, + ) default: - http.Error(w, fmt.Sprintf("Invalid secrets provider type: %s (valid types: %s, %s, %s)", - req.ProviderType, string(secrets.EncryptedType), string(secrets.OnePasswordType), string(secrets.NoneType)), - http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("invalid secrets provider type: %s (valid types: %s, %s, %s)", + req.ProviderType, string(secrets.EncryptedType), string(secrets.OnePasswordType), string(secrets.NoneType)), + http.StatusBadRequest, + ) } // Check current secrets provider configuration for appropriate messaging @@ -112,9 +119,7 @@ func (s *SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Requ if cfg.Secrets.SetupCompleted { currentProviderType, err := cfg.Secrets.GetProviderType() if err != nil { - logger.Errorf("Failed to get current provider type: %v", err) - http.Error(w, "Failed to get current provider configuration", http.StatusInternalServerError) - return + return fmt.Errorf("failed to get current provider configuration: %w", err) } // TODO Handle provider reconfiguration in a better way @@ -139,9 +144,7 @@ func (s *SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Requ // Generate a secure random password generatedPassword, err := secrets.GenerateSecurePassword() if err != nil { - logger.Errorf("Failed to generate secure password: %v", err) - http.Error(w, "Failed to generate secure password", http.StatusInternalServerError) - return + return fmt.Errorf("failed to generate secure password: %w", err) } passwordToUse = generatedPassword logger.Infof("Generated secure random password for encrypted provider setup") @@ -154,13 +157,10 @@ func (s *SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Requ ctx := context.Background() result := secrets.ValidateProviderWithPassword(ctx, providerType, passwordToUse) if !result.Success { - logger.Errorf("Provider validation failed: %v", result.Error) if errors.Is(result.Error, secrets.ErrKeyringNotAvailable) { - http.Error(w, result.Error.Error(), http.StatusBadRequest) - return + return result.Error } - http.Error(w, fmt.Sprintf("Provider validation failed: %v", result.Error), http.StatusInternalServerError) - return + return fmt.Errorf("provider validation failed: %w", result.Error) } // For encrypted provider during initial setup or reconfiguration, ensure we create the provider @@ -168,9 +168,7 @@ func (s *SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Requ if providerType == secrets.EncryptedType && (isInitialSetup || isReconfiguration) { _, err := secrets.CreateSecretProviderWithPassword(providerType, passwordToUse) if err != nil { - logger.Errorf("Failed to initialize encrypted provider: %v", err) - http.Error(w, fmt.Sprintf("Failed to initialize encrypted provider: %v", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to initialize encrypted provider: %w", err) } logger.Info("Encrypted provider initialized and password saved to keyring") } @@ -181,9 +179,7 @@ func (s *SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Requ c.Secrets.SetupCompleted = true }) if err != nil { - logger.Errorf("Failed to update configuration: %v", err) - http.Error(w, fmt.Sprintf("Failed to update configuration: %v", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to update configuration: %w", err) } // Need to force the singleton to be reloaded so that SetupComplete is updated. @@ -204,10 +200,9 @@ func (s *SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Requ Message: message, } if err := json.NewEncoder(w).Encode(resp); err != nil { - logger.Errorf("Failed to encode response: %v", err) - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - return + return fmt.Errorf("failed to encode response: %w", err) } + return nil } // getSecretsProvider @@ -220,27 +215,22 @@ func (s *SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Requ // @Failure 404 {string} string "Not Found - Provider not setup" // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/secrets/default [get] -func (s *SecretsRoutes) getSecretsProvider(w http.ResponseWriter, _ *http.Request) { +func (s *SecretsRoutes) getSecretsProvider(w http.ResponseWriter, _ *http.Request) error { cfg := s.configProvider.GetConfig() // Check if secrets provider is setup if !cfg.Secrets.SetupCompleted { - http.Error(w, "Secrets provider not setup", http.StatusNotFound) - return + return secrets.ErrSecretsNotSetup } providerType, err := cfg.Secrets.GetProviderType() if err != nil { - logger.Errorf("Failed to get provider type: %v", err) - http.Error(w, "Failed to get provider type", http.StatusInternalServerError) - return + return fmt.Errorf("failed to get provider type: %w", err) } // Get provider capabilities provider, err := s.getSecretsManager() if err != nil { - logger.Errorf("Failed to create secrets provider: %v", err) - http.Error(w, "Failed to access secrets provider", http.StatusInternalServerError) - return + return fmt.Errorf("failed to access secrets provider: %w", err) } capabilities := provider.Capabilities() @@ -258,10 +248,9 @@ func (s *SecretsRoutes) getSecretsProvider(w http.ResponseWriter, _ *http.Reques }, } if err := json.NewEncoder(w).Encode(resp); err != nil { - logger.Errorf("Failed to encode response: %v", err) - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - return + return fmt.Errorf("failed to encode response: %w", err) } + return nil } // listSecrets @@ -275,30 +264,23 @@ func (s *SecretsRoutes) getSecretsProvider(w http.ResponseWriter, _ *http.Reques // @Failure 405 {string} string "Method Not Allowed - Provider doesn't support listing" // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/secrets/default/keys [get] -func (s *SecretsRoutes) listSecrets(w http.ResponseWriter, r *http.Request) { - +func (s *SecretsRoutes) listSecrets(w http.ResponseWriter, r *http.Request) error { provider, err := s.getSecretsManager() if err != nil { - if errors.Is(err, secrets.ErrSecretsNotSetup) { - http.Error(w, "Secrets provider not setup", http.StatusNotFound) - return - } - logger.Errorf("Failed to get secrets manager: %v", err) - http.Error(w, "Failed to access secrets provider", http.StatusInternalServerError) - return + return err } // Check if provider supports listing if !provider.Capabilities().CanList { - http.Error(w, "Secrets provider does not support listing keys", http.StatusMethodNotAllowed) - return + return thverrors.WithCode( + fmt.Errorf("secrets provider does not support listing keys"), + http.StatusMethodNotAllowed, + ) } secretDescriptions, err := provider.ListSecrets(r.Context()) if err != nil { - logger.Errorf("Failed to list secrets: %v", err) - http.Error(w, "Failed to list secrets", http.StatusInternalServerError) - return + return fmt.Errorf("failed to list secrets: %w", err) } w.Header().Set("Content-Type", "application/json") @@ -312,10 +294,9 @@ func (s *SecretsRoutes) listSecrets(w http.ResponseWriter, r *http.Request) { } } if err := json.NewEncoder(w).Encode(resp); err != nil { - logger.Errorf("Failed to encode response: %v", err) - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - return + return fmt.Errorf("failed to encode response: %w", err) } + return nil } // createSecret @@ -333,50 +314,49 @@ func (s *SecretsRoutes) listSecrets(w http.ResponseWriter, r *http.Request) { // @Failure 409 {string} string "Conflict - Secret already exists" // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/secrets/default/keys [post] -func (s *SecretsRoutes) createSecret(w http.ResponseWriter, r *http.Request) { +func (s *SecretsRoutes) createSecret(w http.ResponseWriter, r *http.Request) error { var req createSecretRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - logger.Errorf("Failed to decode request body: %v", err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("invalid request body: %w", err), + http.StatusBadRequest, + ) } if req.Key == "" || req.Value == "" { - http.Error(w, "Both 'key' and 'value' are required", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("both 'key' and 'value' are required"), + http.StatusBadRequest, + ) } provider, err := s.getSecretsManager() if err != nil { - if errors.Is(err, secrets.ErrSecretsNotSetup) { - http.Error(w, "Secrets provider not setup", http.StatusNotFound) - return - } - logger.Errorf("Failed to get secrets manager: %v", err) - http.Error(w, "Failed to access secrets provider", http.StatusInternalServerError) - return + return err } // Check if provider supports writing if !provider.Capabilities().CanWrite { - http.Error(w, "Secrets provider does not support creating secrets", http.StatusMethodNotAllowed) - return + return thverrors.WithCode( + fmt.Errorf("secrets provider does not support creating secrets"), + http.StatusMethodNotAllowed, + ) } // Check if secret already exists (if provider supports reading) if provider.Capabilities().CanRead { _, err := provider.GetSecret(r.Context(), req.Key) if err == nil { - http.Error(w, "Secret already exists", http.StatusConflict) - return + return thverrors.WithCode( + fmt.Errorf("secret already exists"), + http.StatusConflict, + ) } } // Create the secret if err := provider.SetSecret(r.Context(), req.Key, req.Value); err != nil { - logger.Errorf("Failed to create secret: %v", err) - http.Error(w, "Failed to create secret", http.StatusInternalServerError) - return + return fmt.Errorf("failed to create secret: %w", err) } w.Header().Set("Content-Type", "application/json") @@ -386,10 +366,9 @@ func (s *SecretsRoutes) createSecret(w http.ResponseWriter, r *http.Request) { Message: "Secret created successfully", } if err := json.NewEncoder(w).Encode(resp); err != nil { - logger.Errorf("Failed to encode response: %v", err) - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - return + return fmt.Errorf("failed to encode response: %w", err) } + return nil } // updateSecret @@ -407,56 +386,57 @@ func (s *SecretsRoutes) createSecret(w http.ResponseWriter, r *http.Request) { // @Failure 405 {string} string "Method Not Allowed - Provider doesn't support writing" // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/secrets/default/keys/{key} [put] -func (s *SecretsRoutes) updateSecret(w http.ResponseWriter, r *http.Request) { +func (s *SecretsRoutes) updateSecret(w http.ResponseWriter, r *http.Request) error { key := chi.URLParam(r, "key") if key == "" { - http.Error(w, "Secret key is required", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("secret key is required"), + http.StatusBadRequest, + ) } var req updateSecretRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - logger.Errorf("Failed to decode request body: %v", err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("invalid request body: %w", err), + http.StatusBadRequest, + ) } if req.Value == "" { - http.Error(w, "Value is required", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("value is required"), + http.StatusBadRequest, + ) } provider, err := s.getSecretsManager() if err != nil { - if errors.Is(err, secrets.ErrSecretsNotSetup) { - http.Error(w, "Secrets provider not setup", http.StatusNotFound) - return - } - logger.Errorf("Failed to get secrets manager: %v", err) - http.Error(w, "Failed to access secrets provider", http.StatusInternalServerError) - return + return err } // Check if provider supports writing if !provider.Capabilities().CanWrite { - http.Error(w, "Secrets provider does not support updating secrets", http.StatusMethodNotAllowed) - return + return thverrors.WithCode( + fmt.Errorf("secrets provider does not support updating secrets"), + http.StatusMethodNotAllowed, + ) } // Check if secret exists (if provider supports reading) if provider.Capabilities().CanRead { _, err := provider.GetSecret(r.Context(), key) if err != nil { - http.Error(w, "Secret not found", http.StatusNotFound) - return + return thverrors.WithCode( + fmt.Errorf("secret not found"), + http.StatusNotFound, + ) } } // Update the secret if err := provider.SetSecret(r.Context(), key, req.Value); err != nil { - logger.Errorf("Failed to update secret: %v", err) - http.Error(w, "Failed to update secret", http.StatusInternalServerError) - return + return fmt.Errorf("failed to update secret: %w", err) } w.Header().Set("Content-Type", "application/json") @@ -465,10 +445,9 @@ func (s *SecretsRoutes) updateSecret(w http.ResponseWriter, r *http.Request) { Message: "Secret updated successfully", } if err := json.NewEncoder(w).Encode(resp); err != nil { - logger.Errorf("Failed to encode response: %v", err) - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - return + return fmt.Errorf("failed to encode response: %w", err) } + return nil } // deleteSecret @@ -482,43 +461,42 @@ func (s *SecretsRoutes) updateSecret(w http.ResponseWriter, r *http.Request) { // @Failure 405 {string} string "Method Not Allowed - Provider doesn't support deletion" // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/secrets/default/keys/{key} [delete] -func (s *SecretsRoutes) deleteSecret(w http.ResponseWriter, r *http.Request) { +func (s *SecretsRoutes) deleteSecret(w http.ResponseWriter, r *http.Request) error { key := chi.URLParam(r, "key") if key == "" { - http.Error(w, "Secret key is required", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("secret key is required"), + http.StatusBadRequest, + ) } provider, err := s.getSecretsManager() if err != nil { - if errors.Is(err, secrets.ErrSecretsNotSetup) { - http.Error(w, "Secrets provider not setup", http.StatusNotFound) - return - } - logger.Errorf("Failed to get secrets manager: %v", err) - http.Error(w, "Failed to access secrets provider", http.StatusInternalServerError) - return + return err } // Check if provider supports deletion if !provider.Capabilities().CanDelete { - http.Error(w, "Secrets provider does not support deleting secrets", http.StatusMethodNotAllowed) - return + return thverrors.WithCode( + fmt.Errorf("secrets provider does not support deleting secrets"), + http.StatusMethodNotAllowed, + ) } // Delete the secret if err := provider.DeleteSecret(r.Context(), key); err != nil { - logger.Errorf("Failed to delete secret: %v", err) // Check if it's a "not found" error - if err.Error() == "cannot delete non-existent secret: "+key { - http.Error(w, "Secret not found", http.StatusNotFound) - return + if strings.Contains(err.Error(), "cannot delete non-existent secret") { + return thverrors.WithCode( + fmt.Errorf("secret not found"), + http.StatusNotFound, + ) } - http.Error(w, "Failed to delete secret", http.StatusInternalServerError) - return + return fmt.Errorf("failed to delete secret: %w", err) } w.WriteHeader(http.StatusNoContent) + return nil } // getSecretsManager is a helper function to get the secrets manager diff --git a/pkg/api/v1/secrets_test.go b/pkg/api/v1/secrets_test.go index 851e6ab15..21f1e3459 100644 --- a/pkg/api/v1/secrets_test.go +++ b/pkg/api/v1/secrets_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + apierrors "github.com/stacklok/toolhive/pkg/api/errors" "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/secrets" @@ -75,7 +76,7 @@ func TestSetupSecretsProvider_ValidRequests(t *testing.T) { w := httptest.NewRecorder() routes := NewSecretsRoutesWithProvider(configProvider) - routes.setupSecretsProvider(w, req) + apierrors.ErrorHandler(routes.setupSecretsProvider).ServeHTTP(w, req) assert.Equal(t, tt.expectedCode, w.Code) @@ -107,13 +108,13 @@ func TestSetupSecretsProvider_InvalidRequests(t *testing.T) { ProviderType: "invalid", }, expectedCode: http.StatusBadRequest, - errorMessage: "Invalid secrets provider type: invalid (valid types: encrypted, 1password, none)", + errorMessage: "invalid secrets provider type: invalid (valid types: encrypted, 1password, none)", }, { name: "invalid json body", requestBody: "invalid json", expectedCode: http.StatusBadRequest, - errorMessage: "Invalid request body", + errorMessage: "invalid request body", }, } @@ -146,7 +147,7 @@ func TestSetupSecretsProvider_InvalidRequests(t *testing.T) { w := httptest.NewRecorder() routes := NewSecretsRoutesWithProvider(configProvider) - routes.setupSecretsProvider(w, req) + apierrors.ErrorHandler(routes.setupSecretsProvider).ServeHTTP(w, req) assert.Equal(t, tt.expectedCode, w.Code) assert.Contains(t, w.Body.String(), tt.errorMessage) @@ -171,7 +172,7 @@ func TestCreateSecret_InvalidRequests(t *testing.T) { Value: "test-value", }, expectedCode: http.StatusBadRequest, - errorMessage: "Both 'key' and 'value' are required", + errorMessage: "both 'key' and 'value' are required", }, { name: "missing value", @@ -180,13 +181,13 @@ func TestCreateSecret_InvalidRequests(t *testing.T) { Value: "", }, expectedCode: http.StatusBadRequest, - errorMessage: "Both 'key' and 'value' are required", + errorMessage: "both 'key' and 'value' are required", }, { name: "invalid json body", requestBody: "invalid json", expectedCode: http.StatusBadRequest, - errorMessage: "Invalid request body", + errorMessage: "invalid request body", }, } @@ -219,7 +220,7 @@ func TestCreateSecret_InvalidRequests(t *testing.T) { w := httptest.NewRecorder() routes := NewSecretsRoutesWithProvider(configProvider) - routes.createSecret(w, req) + apierrors.ErrorHandler(routes.createSecret).ServeHTTP(w, req) assert.Equal(t, tt.expectedCode, w.Code) assert.Contains(t, w.Body.String(), tt.errorMessage) @@ -245,7 +246,7 @@ func TestUpdateSecret_InvalidRequests(t *testing.T) { Value: "new-value", }, expectedCode: http.StatusBadRequest, - errorMessage: "Secret key is required", + errorMessage: "secret key is required", }, { name: "missing value", @@ -254,14 +255,14 @@ func TestUpdateSecret_InvalidRequests(t *testing.T) { Value: "", }, expectedCode: http.StatusBadRequest, - errorMessage: "Value is required", + errorMessage: "value is required", }, { name: "invalid json body", secretKey: "test-key", requestBody: "invalid json", expectedCode: http.StatusBadRequest, - errorMessage: "Invalid request body", + errorMessage: "invalid request body", }, } @@ -301,7 +302,7 @@ func TestUpdateSecret_InvalidRequests(t *testing.T) { w := httptest.NewRecorder() routes := NewSecretsRoutesWithProvider(configProvider) - routes.updateSecret(w, req) + apierrors.ErrorHandler(routes.updateSecret).ServeHTTP(w, req) assert.Equal(t, tt.expectedCode, w.Code) assert.Contains(t, w.Body.String(), tt.errorMessage) @@ -323,7 +324,7 @@ func TestDeleteSecret_InvalidRequests(t *testing.T) { name: "empty secret key", secretKey: "", expectedCode: http.StatusBadRequest, - errorMessage: "Secret key is required", + errorMessage: "secret key is required", }, } @@ -354,7 +355,7 @@ func TestDeleteSecret_InvalidRequests(t *testing.T) { w := httptest.NewRecorder() routes := NewSecretsRoutesWithProvider(configProvider) - routes.deleteSecret(w, req) + apierrors.ErrorHandler(routes.deleteSecret).ServeHTTP(w, req) assert.Equal(t, tt.expectedCode, w.Code) assert.Contains(t, w.Body.String(), tt.errorMessage) @@ -485,10 +486,10 @@ func TestErrorHandling(t *testing.T) { w := httptest.NewRecorder() routes := NewSecretsRoutesWithProvider(configProvider) - routes.setupSecretsProvider(w, req) + apierrors.ErrorHandler(routes.setupSecretsProvider).ServeHTTP(w, req) assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Contains(t, w.Body.String(), "Invalid request body") + assert.Contains(t, w.Body.String(), "invalid request body") }) t.Run("empty request body", func(t *testing.T) { @@ -510,7 +511,7 @@ func TestErrorHandling(t *testing.T) { w := httptest.NewRecorder() routes := NewSecretsRoutesWithProvider(configProvider) - routes.createSecret(w, req) + apierrors.ErrorHandler(routes.createSecret).ServeHTTP(w, req) assert.Equal(t, http.StatusBadRequest, w.Code) }) @@ -534,7 +535,7 @@ func TestErrorHandling(t *testing.T) { w := httptest.NewRecorder() routes := NewSecretsRoutesWithProvider(configProvider) - routes.setupSecretsProvider(w, req) + apierrors.ErrorHandler(routes.setupSecretsProvider).ServeHTTP(w, req) // Should still work as the handler doesn't strictly require content-type assert.Equal(t, http.StatusCreated, w.Code) diff --git a/pkg/api/v1/workloads.go b/pkg/api/v1/workloads.go index 8c7d40c92..4fe82f8d5 100644 --- a/pkg/api/v1/workloads.go +++ b/pkg/api/v1/workloads.go @@ -3,21 +3,19 @@ package v1 import ( "context" "encoding/json" - "errors" "fmt" "net/http" "github.com/go-chi/chi/v5" + apierrors "github.com/stacklok/toolhive/pkg/api/errors" "github.com/stacklok/toolhive/pkg/container/runtime" + thverrors "github.com/stacklok/toolhive/pkg/errors" "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/runner" - "github.com/stacklok/toolhive/pkg/runner/retriever" "github.com/stacklok/toolhive/pkg/validation" "github.com/stacklok/toolhive/pkg/workloads" wt "github.com/stacklok/toolhive/pkg/workloads/types" - werr "github.com/stacklok/toolhive/pkg/workloads/types/errors" ) // WorkloadRoutes defines the routes for workload management. @@ -58,20 +56,20 @@ func WorkloadRouter( } r := chi.NewRouter() - r.Get("/", routes.listWorkloads) - r.Post("/", routes.createWorkload) - r.Post("/stop", routes.stopWorkloadsBulk) - r.Post("/restart", routes.restartWorkloadsBulk) - r.Post("/delete", routes.deleteWorkloadsBulk) - r.Get("/{name}", routes.getWorkload) - r.Post("/{name}/edit", routes.updateWorkload) - r.Post("/{name}/stop", routes.stopWorkload) - r.Post("/{name}/restart", routes.restartWorkload) - r.Get("/{name}/status", routes.getWorkloadStatus) - r.Get("/{name}/logs", routes.getLogsForWorkload) - r.Get("/{name}/proxy-logs", routes.getProxyLogsForWorkload) - r.Get("/{name}/export", routes.exportWorkload) - r.Delete("/{name}", routes.deleteWorkload) + r.Get("/", apierrors.ErrorHandler(routes.listWorkloads)) + r.Post("/", apierrors.ErrorHandler(routes.createWorkload)) + r.Post("/stop", apierrors.ErrorHandler(routes.stopWorkloadsBulk)) + r.Post("/restart", apierrors.ErrorHandler(routes.restartWorkloadsBulk)) + r.Post("/delete", apierrors.ErrorHandler(routes.deleteWorkloadsBulk)) + r.Get("/{name}", apierrors.ErrorHandler(routes.getWorkload)) + r.Post("/{name}/edit", apierrors.ErrorHandler(routes.updateWorkload)) + r.Post("/{name}/stop", apierrors.ErrorHandler(routes.stopWorkload)) + r.Post("/{name}/restart", apierrors.ErrorHandler(routes.restartWorkload)) + r.Get("/{name}/status", apierrors.ErrorHandler(routes.getWorkloadStatus)) + r.Get("/{name}/logs", apierrors.ErrorHandler(routes.getLogsForWorkload)) + r.Get("/{name}/proxy-logs", apierrors.ErrorHandler(routes.getProxyLogsForWorkload)) + r.Get("/{name}/export", apierrors.ErrorHandler(routes.exportWorkload)) + r.Delete("/{name}", apierrors.ErrorHandler(routes.deleteWorkload)) return r } @@ -86,42 +84,35 @@ func WorkloadRouter( // @Success 200 {object} workloadListResponse // @Failure 404 {string} string "Group not found" // @Router /api/v1beta/workloads [get] -func (s *WorkloadRoutes) listWorkloads(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) listWorkloads(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() listAll := r.URL.Query().Get("all") == "true" groupFilter := r.URL.Query().Get("group") workloadList, err := s.workloadManager.ListWorkloads(ctx, listAll) if err != nil { - logger.Errorf("Failed to list workloads: %v", err) - http.Error(w, "Failed to list workloads", http.StatusInternalServerError) - return + return fmt.Errorf("failed to list workloads: %w", err) } // Apply group filtering if specified if groupFilter != "" { if err := validation.ValidateGroupName(groupFilter); err != nil { - http.Error(w, "Invalid group name: "+err.Error(), http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("invalid group name: %w", err), + http.StatusBadRequest, + ) } workloadList, err = workloads.FilterByGroup(workloadList, groupFilter) if err != nil { - if errors.Is(err, groups.ErrGroupNotFound) { - http.Error(w, "Group not found", http.StatusNotFound) - } else { - logger.Errorf("Failed to filter workloads by group: %v", err) - http.Error(w, "Failed to list workloads in group", http.StatusInternalServerError) - } - return + return err // groups.ErrGroupNotFound already has 404 status code } } w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(workloadListResponse{Workloads: workloadList}) - if err != nil { - http.Error(w, "Failed to marshal workload list", http.StatusInternalServerError) - return + if err := json.NewEncoder(w).Encode(workloadListResponse{Workloads: workloadList}); err != nil { + return fmt.Errorf("failed to marshal workload list: %w", err) } + return nil } // getWorkload @@ -134,40 +125,32 @@ func (s *WorkloadRoutes) listWorkloads(w http.ResponseWriter, r *http.Request) { // @Success 200 {object} createRequest // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/workloads/{name} [get] -func (s *WorkloadRoutes) getWorkload(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) getWorkload(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() name := chi.URLParam(r, "name") // Check if workload exists first _, err := s.workloadManager.GetWorkload(ctx, name) if err != nil { - if errors.Is(err, runtime.ErrWorkloadNotFound) { - http.Error(w, "Workload not found", http.StatusNotFound) - return - } else if errors.Is(err, wt.ErrInvalidWorkloadName) { - http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) - return - } - logger.Errorf("Failed to get workload: %v", err) - http.Error(w, "Failed to get workload", http.StatusInternalServerError) - return + return err // ErrWorkloadNotFound (404) or ErrInvalidWorkloadName (400) already have status codes } // Load the workload configuration runConfig, err := runner.LoadState(ctx, name) if err != nil { - logger.Errorf("Failed to load workload configuration for %s: %v", name, err) - http.Error(w, "Workload configuration not found", http.StatusNotFound) - return + return thverrors.WithCode( + fmt.Errorf("workload configuration not found: %w", err), + http.StatusNotFound, + ) } config := runConfigToCreateRequest(runConfig) w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(config); err != nil { - http.Error(w, "Failed to marshal workload configuration", http.StatusInternalServerError) - return + return fmt.Errorf("failed to marshal workload configuration: %w", err) } + return nil } // stopWorkload @@ -180,22 +163,17 @@ func (s *WorkloadRoutes) getWorkload(w http.ResponseWriter, r *http.Request) { // @Failure 400 {string} string "Bad Request" // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/workloads/{name}/stop [post] -func (s *WorkloadRoutes) stopWorkload(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) stopWorkload(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() name := chi.URLParam(r, "name") // Use the bulk method with a single workload _, err := s.workloadManager.StopWorkloads(ctx, []string{name}) if err != nil { - if errors.Is(err, wt.ErrInvalidWorkloadName) { - http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) - return - } - logger.Errorf("Failed to stop workload: %v", err) - http.Error(w, "Failed to stop workload", http.StatusInternalServerError) - return + return err // ErrInvalidWorkloadName already has 400 status code } w.WriteHeader(http.StatusAccepted) + return nil } // restartWorkload @@ -208,22 +186,17 @@ func (s *WorkloadRoutes) stopWorkload(w http.ResponseWriter, r *http.Request) { // @Failure 400 {string} string "Bad Request" // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/workloads/{name}/restart [post] -func (s *WorkloadRoutes) restartWorkload(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) restartWorkload(w http.ResponseWriter, r *http.Request) error { name := chi.URLParam(r, "name") // Use the bulk method with a single workload // Note: In the API, we always assume that the restart is a background operation _, err := s.workloadManager.RestartWorkloads(context.Background(), []string{name}, false) if err != nil { - if errors.Is(err, wt.ErrInvalidWorkloadName) { - http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) - return - } - logger.Errorf("Failed to restart workload: %v", err) - http.Error(w, "Failed to restart workload", http.StatusInternalServerError) - return + return err // ErrInvalidWorkloadName already has 400 status code } w.WriteHeader(http.StatusAccepted) + return nil } // deleteWorkload @@ -236,22 +209,17 @@ func (s *WorkloadRoutes) restartWorkload(w http.ResponseWriter, r *http.Request) // @Failure 400 {string} string "Bad Request" // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/workloads/{name} [delete] -func (s *WorkloadRoutes) deleteWorkload(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) deleteWorkload(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() name := chi.URLParam(r, "name") // Use the bulk method with a single workload _, err := s.workloadManager.DeleteWorkloads(ctx, []string{name}) if err != nil { - if errors.Is(err, wt.ErrInvalidWorkloadName) { - http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) - return - } - logger.Errorf("Failed to delete workload: %v", err) - http.Error(w, "Failed to delete workload", http.StatusInternalServerError) - return + return err // ErrInvalidWorkloadName already has 400 status code } w.WriteHeader(http.StatusAccepted) + return nil } // createWorkload @@ -266,39 +234,34 @@ func (s *WorkloadRoutes) deleteWorkload(w http.ResponseWriter, r *http.Request) // @Failure 400 {string} string "Bad Request" // @Failure 409 {string} string "Conflict" // @Router /api/v1beta/workloads [post] -func (s *WorkloadRoutes) createWorkload(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) createWorkload(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() var req createRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Failed to decode request", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("failed to decode request: %w", err), + http.StatusBadRequest, + ) } // check if the workload already exists if req.Name != "" { exists, err := s.workloadManager.DoesWorkloadExist(ctx, req.Name) if err != nil { - http.Error(w, fmt.Sprintf("Failed to check if workload exists: %v", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to check if workload exists: %w", err) } if exists { - http.Error(w, fmt.Sprintf("Workload with name %s already exists", req.Name), http.StatusConflict) - return + return thverrors.WithCode( + fmt.Errorf("workload with name %s already exists", req.Name), + http.StatusConflict, + ) } } // Create the workload using shared logic runConfig, err := s.workloadService.CreateWorkloadFromRequest(ctx, &req) if err != nil { - // Error messages already logged in createWorkloadFromRequest - if errors.Is(err, retriever.ErrImageNotFound) { - http.Error(w, err.Error(), http.StatusNotFound) - } else if errors.Is(err, retriever.ErrInvalidRunConfig) { - http.Error(w, err.Error(), http.StatusBadRequest) - } else { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - return + return err // ErrImageNotFound (404) and ErrInvalidRunConfig (400) already have status codes } // Return name so that the client will get the auto-generated name. @@ -309,9 +272,9 @@ func (s *WorkloadRoutes) createWorkload(w http.ResponseWriter, r *http.Request) Port: runConfig.Port, } if err = json.NewEncoder(w).Encode(resp); err != nil { - http.Error(w, "Failed to marshal workload details", http.StatusInternalServerError) - return + return fmt.Errorf("failed to marshal workload details: %w", err) } + return nil } // updateWorkload @@ -327,23 +290,23 @@ func (s *WorkloadRoutes) createWorkload(w http.ResponseWriter, r *http.Request) // @Failure 400 {string} string "Bad Request" // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/workloads/{name}/edit [post] -func (s *WorkloadRoutes) updateWorkload(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) updateWorkload(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() name := chi.URLParam(r, "name") // Parse request body var updateReq updateRequest if err := json.NewDecoder(r.Body).Decode(&updateReq); err != nil { - http.Error(w, "Invalid JSON: "+err.Error(), http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("invalid JSON: %w", err), + http.StatusBadRequest, + ) } // Check if workload exists and get its current port existingWorkload, err := s.workloadManager.GetWorkload(ctx, name) if err != nil { - logger.Errorf("Failed to get workload: %v", err) - http.Error(w, "Workload not found", http.StatusNotFound) - return + return err // ErrWorkloadNotFound (404) already has status code } // Convert updateRequest to createRequest with the existing workload name @@ -354,15 +317,7 @@ func (s *WorkloadRoutes) updateWorkload(w http.ResponseWriter, r *http.Request) runConfig, err := s.workloadService.UpdateWorkloadFromRequest(ctx, name, &createReq, existingWorkload.Port) if err != nil { - // Error messages already logged in UpdateWorkloadFromRequest - if errors.Is(err, retriever.ErrImageNotFound) { - http.Error(w, err.Error(), http.StatusNotFound) - } else if errors.Is(err, retriever.ErrInvalidRunConfig) { - http.Error(w, err.Error(), http.StatusBadRequest) - } else { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - return + return err // ErrImageNotFound (404) and ErrInvalidRunConfig (400) already have status codes } // Return the same response format as create @@ -372,9 +327,9 @@ func (s *WorkloadRoutes) updateWorkload(w http.ResponseWriter, r *http.Request) Port: runConfig.Port, } if err = json.NewEncoder(w).Encode(resp); err != nil { - http.Error(w, "Failed to marshal workload details", http.StatusInternalServerError) - return + return fmt.Errorf("failed to marshal workload details: %w", err) } + return nil } // stopWorkloadsBulk @@ -387,39 +342,34 @@ func (s *WorkloadRoutes) updateWorkload(w http.ResponseWriter, r *http.Request) // @Success 202 {string} string "Accepted" // @Failure 400 {string} string "Bad Request" // @Router /api/v1beta/workloads/stop [post] -func (s *WorkloadRoutes) stopWorkloadsBulk(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) stopWorkloadsBulk(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() var req bulkOperationRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Failed to decode request", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("failed to decode request: %w", err), + http.StatusBadRequest, + ) } if err := validateBulkOperationRequest(req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + return thverrors.WithCode(err, http.StatusBadRequest) } workloadNames, err := s.workloadService.GetWorkloadNamesFromRequest(ctx, req) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + return thverrors.WithCode(err, http.StatusBadRequest) } // Note that this is an asynchronous operation. // The request is not blocked on completion. _, err = s.workloadManager.StopWorkloads(ctx, workloadNames) if err != nil { - if errors.Is(err, wt.ErrInvalidWorkloadName) { - http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) - return - } - logger.Errorf("Failed to stop workloads: %v", err) - http.Error(w, "Failed to stop workloads", http.StatusInternalServerError) - return + return err // ErrInvalidWorkloadName already has 400 status code } w.WriteHeader(http.StatusAccepted) + return nil } // restartWorkloadsBulk @@ -432,24 +382,24 @@ func (s *WorkloadRoutes) stopWorkloadsBulk(w http.ResponseWriter, r *http.Reques // @Success 202 {string} string "Accepted" // @Failure 400 {string} string "Bad Request" // @Router /api/v1beta/workloads/restart [post] -func (s *WorkloadRoutes) restartWorkloadsBulk(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) restartWorkloadsBulk(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() var req bulkOperationRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Failed to decode request", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("failed to decode request: %w", err), + http.StatusBadRequest, + ) } if err := validateBulkOperationRequest(req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + return thverrors.WithCode(err, http.StatusBadRequest) } workloadNames, err := s.workloadService.GetWorkloadNamesFromRequest(ctx, req) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + return thverrors.WithCode(err, http.StatusBadRequest) } // Note that this is an asynchronous operation. @@ -457,15 +407,10 @@ func (s *WorkloadRoutes) restartWorkloadsBulk(w http.ResponseWriter, r *http.Req // Note: In the API, we always assume that the restart is a background operation. _, err = s.workloadManager.RestartWorkloads(context.Background(), workloadNames, false) if err != nil { - if errors.Is(err, wt.ErrInvalidWorkloadName) { - http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) - return - } - logger.Errorf("Failed to restart workloads: %v", err) - http.Error(w, "Failed to restart workloads", http.StatusInternalServerError) - return + return err // ErrInvalidWorkloadName already has 400 status code } w.WriteHeader(http.StatusAccepted) + return nil } // deleteWorkloadsBulk @@ -478,39 +423,34 @@ func (s *WorkloadRoutes) restartWorkloadsBulk(w http.ResponseWriter, r *http.Req // @Success 202 {string} string "Accepted" // @Failure 400 {string} string "Bad Request" // @Router /api/v1beta/workloads/delete [post] -func (s *WorkloadRoutes) deleteWorkloadsBulk(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) deleteWorkloadsBulk(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() var req bulkOperationRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Failed to decode request", http.StatusBadRequest) - return + return thverrors.WithCode( + fmt.Errorf("failed to decode request: %w", err), + http.StatusBadRequest, + ) } if err := validateBulkOperationRequest(req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + return thverrors.WithCode(err, http.StatusBadRequest) } workloadNames, err := s.workloadService.GetWorkloadNamesFromRequest(ctx, req) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + return thverrors.WithCode(err, http.StatusBadRequest) } // Note that this is an asynchronous operation. // The request is not blocked on completion. _, err = s.workloadManager.DeleteWorkloads(ctx, workloadNames) if err != nil { - if errors.Is(err, wt.ErrInvalidWorkloadName) { - http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) - return - } - logger.Errorf("Failed to delete workloads: %v", err) - http.Error(w, "Failed to delete workloads", http.StatusInternalServerError) - return + return err // ErrInvalidWorkloadName already has 400 status code } w.WriteHeader(http.StatusAccepted) + return nil } // getLogsForWorkload @@ -524,33 +464,24 @@ func (s *WorkloadRoutes) deleteWorkloadsBulk(w http.ResponseWriter, r *http.Requ // @Failure 400 {string} string "Invalid workload name" // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/workloads/{name}/logs [get] -func (s *WorkloadRoutes) getLogsForWorkload(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) getLogsForWorkload(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() name := chi.URLParam(r, "name") // Validate workload name to prevent path traversal if err := wt.ValidateWorkloadName(name); err != nil { - http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) - return + return err // ErrInvalidWorkloadName already has 400 status code } logs, err := s.workloadManager.GetLogs(ctx, name, false) if err != nil { - if errors.Is(err, runtime.ErrWorkloadNotFound) { - http.Error(w, "Workload not found", http.StatusNotFound) - return - } - logger.Errorf("Failed to get logs: %v", err) - http.Error(w, "Failed to get logs", http.StatusInternalServerError) - return + return err // ErrWorkloadNotFound (404) already has status code } w.Header().Set("Content-Type", "text/plain") - _, err = w.Write([]byte(logs)) - if err != nil { - logger.Errorf("Failed to write logs response: %v", err) - http.Error(w, "Failed to write logs response", http.StatusInternalServerError) - return + if _, err = w.Write([]byte(logs)); err != nil { + return fmt.Errorf("failed to write logs response: %w", err) } + return nil } // getProxyLogsForWorkload @@ -564,30 +495,28 @@ func (s *WorkloadRoutes) getLogsForWorkload(w http.ResponseWriter, r *http.Reque // @Failure 400 {string} string "Invalid workload name" // @Failure 404 {string} string "Proxy logs not found for workload" // @Router /api/v1beta/workloads/{name}/proxy-logs [get] -func (s *WorkloadRoutes) getProxyLogsForWorkload(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) getProxyLogsForWorkload(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() name := chi.URLParam(r, "name") // Validate workload name to prevent path traversal if err := wt.ValidateWorkloadName(name); err != nil { - http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) - return + return err // ErrInvalidWorkloadName already has 400 status code } logs, err := s.workloadManager.GetProxyLogs(ctx, name) if err != nil { - logger.Errorf("Failed to get proxy logs: %v", err) - http.Error(w, "Proxy logs not found for workload", http.StatusNotFound) - return + return thverrors.WithCode( + fmt.Errorf("proxy logs not found for workload: %w", err), + http.StatusNotFound, + ) } w.Header().Set("Content-Type", "text/plain") - _, err = w.Write([]byte(logs)) - if err != nil { - logger.Errorf("Failed to write proxy logs response: %v", err) - http.Error(w, "Failed to write proxy logs response", http.StatusInternalServerError) - return + if _, err = w.Write([]byte(logs)); err != nil { + return fmt.Errorf("failed to write proxy logs response: %w", err) } + return nil } // getWorkloadStatus @@ -600,22 +529,13 @@ func (s *WorkloadRoutes) getProxyLogsForWorkload(w http.ResponseWriter, r *http. // @Success 200 {object} workloadStatusResponse // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/workloads/{name}/status [get] -func (s *WorkloadRoutes) getWorkloadStatus(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) getWorkloadStatus(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() name := chi.URLParam(r, "name") workload, err := s.workloadManager.GetWorkload(ctx, name) if err != nil { - if errors.Is(err, runtime.ErrWorkloadNotFound) { - http.Error(w, "Workload not found", http.StatusNotFound) - return - } else if errors.Is(err, wt.ErrInvalidWorkloadName) { - http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) - return - } - logger.Errorf("Failed to get workload: %v", err) - http.Error(w, "Failed to get workload", http.StatusInternalServerError) - return + return err // ErrWorkloadNotFound (404) or ErrInvalidWorkloadName (400) already have status codes } response := workloadStatusResponse{ @@ -624,9 +544,9 @@ func (s *WorkloadRoutes) getWorkloadStatus(w http.ResponseWriter, r *http.Reques w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - http.Error(w, "Failed to marshal workload status", http.StatusInternalServerError) - return + return fmt.Errorf("failed to marshal workload status: %w", err) } + return nil } // exportWorkload @@ -639,27 +559,20 @@ func (s *WorkloadRoutes) getWorkloadStatus(w http.ResponseWriter, r *http.Reques // @Success 200 {object} runner.RunConfig // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/workloads/{name}/export [get] -func (*WorkloadRoutes) exportWorkload(w http.ResponseWriter, r *http.Request) { +func (*WorkloadRoutes) exportWorkload(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() name := chi.URLParam(r, "name") // Load the saved run configuration runConfig, err := runner.LoadState(ctx, name) if err != nil { - if errors.Is(err, werr.ErrRunConfigNotFound) { - http.Error(w, "Workload configuration not found", http.StatusNotFound) - return - } - logger.Errorf("Failed to load workload configuration: %v", err) - http.Error(w, "Failed to load workload configuration", http.StatusInternalServerError) - return + return err // ErrRunConfigNotFound (404) already has status code } // Return the configuration as JSON w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(runConfig); err != nil { - logger.Errorf("Failed to encode workload configuration: %v", err) - http.Error(w, "Failed to encode workload configuration", http.StatusInternalServerError) - return + return fmt.Errorf("failed to encode workload configuration: %w", err) } + return nil } diff --git a/pkg/api/v1/workloads_test.go b/pkg/api/v1/workloads_test.go index 2548aa4f7..4ee84beca 100644 --- a/pkg/api/v1/workloads_test.go +++ b/pkg/api/v1/workloads_test.go @@ -13,6 +13,7 @@ import ( "go.uber.org/mock/gomock" "golang.org/x/sync/errgroup" + apierrors "github.com/stacklok/toolhive/pkg/api/errors" "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/container/runtime" runtimemocks "github.com/stacklok/toolhive/pkg/container/runtime/mocks" @@ -46,7 +47,7 @@ func TestGetWorkload(t *testing.T) { Return(core.Workload{}, runtime.ErrWorkloadNotFound) }, expectedStatus: http.StatusNotFound, - expectedBody: "Workload not found", + expectedBody: "workload not found", }, { name: "invalid workload name", @@ -56,7 +57,7 @@ func TestGetWorkload(t *testing.T) { Return(core.Workload{}, wt.ErrInvalidWorkloadName) }, expectedStatus: http.StatusBadRequest, - expectedBody: "Invalid workload name", + expectedBody: "invalid workload name", }, } @@ -85,7 +86,7 @@ func TestGetWorkload(t *testing.T) { req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) w := httptest.NewRecorder() - routes.getWorkload(w, req) + apierrors.ErrorHandler(routes.getWorkload).ServeHTTP(w, req) assert.Equal(t, tt.expectedStatus, w.Code) assert.Contains(t, w.Body.String(), tt.expectedBody) @@ -111,7 +112,7 @@ func TestCreateWorkload(t *testing.T) { setupMock: func(_ *testing.T, _ *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, _ *groupsmocks.MockManager) { }, expectedStatus: http.StatusBadRequest, - expectedBody: "Failed to decode request", + expectedBody: "failed to decode request", }, { name: "workload already exists", @@ -120,13 +121,14 @@ func TestCreateWorkload(t *testing.T) { wm.EXPECT().DoesWorkloadExist(gomock.Any(), "existing-workload").Return(true, nil) }, expectedStatus: http.StatusConflict, - expectedBody: "Workload with name existing-workload already exists", + expectedBody: "workload with name existing-workload already exists", }, { name: "invalid proxy mode", requestBody: `{"name": "test-workload", "image": "test-image", "proxy_mode": "invalid"}`, - setupMock: func(_ *testing.T, wm *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, _ *groupsmocks.MockManager) { + setupMock: func(_ *testing.T, wm *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, gm *groupsmocks.MockManager) { wm.EXPECT().DoesWorkloadExist(gomock.Any(), "test-workload").Return(false, nil) + gm.EXPECT().Exists(gomock.Any(), "default").Return(true, nil).AnyTimes() }, expectedStatus: http.StatusBadRequest, expectedBody: "Invalid proxy_mode", @@ -231,7 +233,7 @@ func TestCreateWorkload(t *testing.T) { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - routes.createWorkload(w, req) + apierrors.ErrorHandler(routes.createWorkload).ServeHTTP(w, req) assert.Equal(t, tt.expectedStatus, w.Code) assert.Contains(t, w.Body.String(), tt.expectedBody) @@ -259,7 +261,7 @@ func TestUpdateWorkload(t *testing.T) { setupMock: func(_ *testing.T, _ *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, _ *groupsmocks.MockManager) { }, expectedStatus: http.StatusBadRequest, - expectedBody: "Invalid JSON", + expectedBody: "invalid JSON", }, { name: "workload not found", @@ -267,10 +269,10 @@ func TestUpdateWorkload(t *testing.T) { requestBody: `{"image": "test-image"}`, setupMock: func(_ *testing.T, wm *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, _ *groupsmocks.MockManager) { wm.EXPECT().GetWorkload(gomock.Any(), "nonexistent"). - Return(core.Workload{}, fmt.Errorf("workload not found")) + Return(core.Workload{}, runtime.ErrWorkloadNotFound) }, expectedStatus: http.StatusNotFound, - expectedBody: "Workload not found", + expectedBody: "workload not found", }, { name: "stop workload fails", @@ -284,7 +286,7 @@ func TestUpdateWorkload(t *testing.T) { Return(nil, fmt.Errorf("stop failed")) }, expectedStatus: http.StatusInternalServerError, - expectedBody: "failed to update workload: stop failed", + expectedBody: "Internal Server Error", // 5xx errors return generic message }, { name: "delete workload fails", @@ -298,7 +300,7 @@ func TestUpdateWorkload(t *testing.T) { Return(nil, fmt.Errorf("delete failed")) }, expectedStatus: http.StatusInternalServerError, - expectedBody: "failed to update workload: delete failed", + expectedBody: "Internal Server Error", // 5xx errors return generic message }, { name: "with tool filters", @@ -427,7 +429,7 @@ func TestUpdateWorkload(t *testing.T) { req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) w := httptest.NewRecorder() - routes.updateWorkload(w, req) + apierrors.ErrorHandler(routes.updateWorkload).ServeHTTP(w, req) assert.Equal(t, tt.expectedStatus, w.Code) assert.Contains(t, w.Body.String(), tt.expectedBody) @@ -575,7 +577,7 @@ func TestUpdateWorkload_PortReuse(t *testing.T) { req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) w := httptest.NewRecorder() - routes.updateWorkload(w, req) + apierrors.ErrorHandler(routes.updateWorkload).ServeHTTP(w, req) assert.Equal(t, tt.expectedStatus, w.Code, tt.description) assert.Contains(t, w.Body.String(), tt.expectedBody, tt.description) diff --git a/pkg/container/docker/errors.go b/pkg/container/docker/errors.go index 455c62b6a..6e1ef2d9a 100644 --- a/pkg/container/docker/errors.go +++ b/pkg/container/docker/errors.go @@ -1,26 +1,31 @@ package docker -import "fmt" +import ( + "fmt" + "net/http" + + "github.com/stacklok/toolhive/pkg/errors" +) // Error types for container operations var ( // ErrContainerNotFound is returned when a container is not found - ErrContainerNotFound = fmt.Errorf("container not found") + ErrContainerNotFound = errors.WithCode(fmt.Errorf("container not found"), http.StatusNotFound) // ErrMultipleContainersFound is returned when multiple containers are found - ErrMultipleContainersFound = fmt.Errorf("multiple containers found with same name") + ErrMultipleContainersFound = errors.WithCode(fmt.Errorf("multiple containers found with same name"), http.StatusBadRequest) // ErrContainerNotRunning is returned when a container is not running - ErrContainerNotRunning = fmt.Errorf("container not running") + ErrContainerNotRunning = errors.WithCode(fmt.Errorf("container not running"), http.StatusBadRequest) // ErrAttachFailed is returned when attaching to a container fails - ErrAttachFailed = fmt.Errorf("failed to attach to container") + ErrAttachFailed = errors.WithCode(fmt.Errorf("failed to attach to container"), http.StatusBadRequest) // ErrContainerExited is returned when a container has exited unexpectedly - ErrContainerExited = fmt.Errorf("container exited unexpectedly") + ErrContainerExited = errors.WithCode(fmt.Errorf("container exited unexpectedly"), http.StatusBadRequest) // ErrContainerRemoved is returned when a container has been removed - ErrContainerRemoved = fmt.Errorf("container removed") + ErrContainerRemoved = errors.WithCode(fmt.Errorf("container removed"), http.StatusBadRequest) ) // ContainerError represents an error related to container operations diff --git a/pkg/container/runtime/types.go b/pkg/container/runtime/types.go index b234b0ae5..286dc45a8 100644 --- a/pkg/container/runtime/types.go +++ b/pkg/container/runtime/types.go @@ -4,12 +4,14 @@ package runtime import ( "context" - "fmt" + "errors" "io" + "net/http" "strings" "time" "github.com/stacklok/toolhive/pkg/env" + thverrors "github.com/stacklok/toolhive/pkg/errors" "github.com/stacklok/toolhive/pkg/ignore" "github.com/stacklok/toolhive/pkg/permissions" ) @@ -312,5 +314,8 @@ func IsKubernetesRuntimeWithEnv(envReader env.Reader) bool { // Common errors var ( // ErrWorkloadNotFound indicates that the specified workload was not found. - ErrWorkloadNotFound = fmt.Errorf("workload not found") + ErrWorkloadNotFound = thverrors.WithCode( + errors.New("workload not found"), + http.StatusNotFound, + ) ) diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go new file mode 100644 index 000000000..ace20c7cf --- /dev/null +++ b/pkg/errors/errors.go @@ -0,0 +1,62 @@ +// Package errors provides error types with HTTP status codes for API error handling. +package errors + +import ( + "errors" + "net/http" +) + +// CodedError wraps an error with an HTTP status code. +// This allows errors to carry their intended HTTP response code through the call stack, +// enabling centralized error handling in API handlers. +type CodedError struct { + err error + code int +} + +// Error implements the error interface. +func (e *CodedError) Error() string { + return e.err.Error() +} + +// Unwrap returns the underlying error for errors.Is() and errors.As() compatibility. +func (e *CodedError) Unwrap() error { + return e.err +} + +// HTTPCode returns the HTTP status code associated with this error. +func (e *CodedError) HTTPCode() int { + return e.code +} + +// WithCode wraps an error with an HTTP status code. +// The returned error implements Unwrap() for use with errors.Is() and errors.As(). +// If err is nil, WithCode returns nil. +func WithCode(err error, code int) error { + if err == nil { + return nil + } + return &CodedError{err: err, code: code} +} + +// Code extracts the HTTP status code from an error. +// It unwraps the error chain looking for a CodedError. +// If no CodedError is found, it returns http.StatusInternalServerError (500). +func Code(err error) int { + if err == nil { + return http.StatusOK + } + + var coded *CodedError + if errors.As(err, &coded) { + return coded.code + } + + return http.StatusInternalServerError +} + +// New creates a new error with the given message and HTTP status code. +// This is a convenience function equivalent to WithCode(errors.New(message), code). +func New(message string, code int) error { + return &CodedError{err: errors.New(message), code: code} +} diff --git a/pkg/errors/errors_test.go b/pkg/errors/errors_test.go new file mode 100644 index 000000000..c2491a55c --- /dev/null +++ b/pkg/errors/errors_test.go @@ -0,0 +1,152 @@ +package errors + +import ( + "errors" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWithCode(t *testing.T) { + t.Parallel() + + t.Run("wraps error with code", func(t *testing.T) { + t.Parallel() + + baseErr := errors.New("test error") + err := WithCode(baseErr, http.StatusNotFound) + + require.NotNil(t, err) + + coded, ok := err.(*CodedError) + require.True(t, ok, "expected *CodedError, got %T", err) + require.Equal(t, http.StatusNotFound, coded.HTTPCode()) + require.Equal(t, "test error", coded.Error()) + }) + + t.Run("returns nil for nil error", func(t *testing.T) { + t.Parallel() + + err := WithCode(nil, http.StatusNotFound) + require.Nil(t, err) + }) +} + +func TestCode(t *testing.T) { + t.Parallel() + + t.Run("extracts code from CodedError", func(t *testing.T) { + t.Parallel() + + err := WithCode(errors.New("not found"), http.StatusNotFound) + code := Code(err) + require.Equal(t, http.StatusNotFound, code) + }) + + t.Run("returns 500 for error without code", func(t *testing.T) { + t.Parallel() + + err := errors.New("plain error") + code := Code(err) + require.Equal(t, http.StatusInternalServerError, code) + }) + + t.Run("returns 200 for nil error", func(t *testing.T) { + t.Parallel() + + code := Code(nil) + require.Equal(t, http.StatusOK, code) + }) + + t.Run("extracts code from wrapped error", func(t *testing.T) { + t.Parallel() + + baseErr := WithCode(errors.New("not found"), http.StatusNotFound) + wrappedErr := fmt.Errorf("outer context: %w", baseErr) + code := Code(wrappedErr) + require.Equal(t, http.StatusNotFound, code) + }) + + t.Run("extracts code from deeply wrapped error", func(t *testing.T) { + t.Parallel() + + baseErr := WithCode(errors.New("bad request"), http.StatusBadRequest) + wrapped1 := fmt.Errorf("layer 1: %w", baseErr) + wrapped2 := fmt.Errorf("layer 2: %w", wrapped1) + wrapped3 := fmt.Errorf("layer 3: %w", wrapped2) + code := Code(wrapped3) + require.Equal(t, http.StatusBadRequest, code) + }) +} + +func TestCodedError_Unwrap(t *testing.T) { + t.Parallel() + + t.Run("errors.Is works with wrapped error", func(t *testing.T) { + t.Parallel() + + sentinel := errors.New("sentinel") + err := WithCode(sentinel, http.StatusNotFound) + require.ErrorIs(t, err, sentinel) + }) + + t.Run("errors.Is works with double wrapped error", func(t *testing.T) { + t.Parallel() + + sentinel := errors.New("sentinel") + coded := WithCode(sentinel, http.StatusNotFound) + wrapped := fmt.Errorf("outer: %w", coded) + require.ErrorIs(t, wrapped, sentinel) + }) + + t.Run("errors.As works with CodedError", func(t *testing.T) { + t.Parallel() + + err := WithCode(errors.New("test"), http.StatusBadRequest) + wrapped := fmt.Errorf("wrapped: %w", err) + + var coded *CodedError + require.ErrorAs(t, wrapped, &coded) + require.Equal(t, http.StatusBadRequest, coded.HTTPCode()) + }) +} + +func TestNew(t *testing.T) { + t.Parallel() + + t.Run("creates error with message and code", func(t *testing.T) { + t.Parallel() + + err := New("custom error", http.StatusForbidden) + require.Equal(t, "custom error", err.Error()) + require.Equal(t, http.StatusForbidden, Code(err)) + }) +} + +func TestCodedError_HTTPCode(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + code int + expected int + }{ + {"OK", http.StatusOK, http.StatusOK}, + {"BadRequest", http.StatusBadRequest, http.StatusBadRequest}, + {"NotFound", http.StatusNotFound, http.StatusNotFound}, + {"Conflict", http.StatusConflict, http.StatusConflict}, + {"InternalServerError", http.StatusInternalServerError, http.StatusInternalServerError}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := WithCode(errors.New("test"), tt.code) + coded := err.(*CodedError) + require.Equal(t, tt.expected, coded.HTTPCode()) + }) + } +} diff --git a/pkg/groups/errors.go b/pkg/groups/errors.go index 4a8ec04d5..30e958586 100644 --- a/pkg/groups/errors.go +++ b/pkg/groups/errors.go @@ -1,14 +1,28 @@ package groups -import "errors" +import ( + "errors" + "net/http" + + thverrors "github.com/stacklok/toolhive/pkg/errors" +) var ( // ErrGroupAlreadyExists is returned when a group already exists - ErrGroupAlreadyExists = errors.New("group already exists") + ErrGroupAlreadyExists = thverrors.WithCode( + errors.New("group already exists"), + http.StatusConflict, + ) // ErrGroupNotFound is returned when a group is not found - ErrGroupNotFound = errors.New("group not found") + ErrGroupNotFound = thverrors.WithCode( + errors.New("group not found"), + http.StatusNotFound, + ) // ErrInvalidGroupName is returned when an invalid argument is provided - ErrInvalidGroupName = errors.New("invalid group name") + ErrInvalidGroupName = thverrors.WithCode( + errors.New("invalid group name"), + http.StatusBadRequest, + ) ) diff --git a/pkg/runner/retriever/retriever.go b/pkg/runner/retriever/retriever.go index fb0635797..f75741403 100644 --- a/pkg/runner/retriever/retriever.go +++ b/pkg/runner/retriever/retriever.go @@ -5,12 +5,14 @@ import ( "context" "errors" "fmt" + "net/http" nameref "github.com/google/go-containerregistry/pkg/name" "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/container/images" "github.com/stacklok/toolhive/pkg/container/verifier" + thverrors "github.com/stacklok/toolhive/pkg/errors" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/registry" types "github.com/stacklok/toolhive/pkg/registry/registry" @@ -28,11 +30,20 @@ const ( var ( // ErrBadProtocolScheme is returned when the provided serverOrImage is not a valid protocol scheme. - ErrBadProtocolScheme = errors.New("invalid protocol scheme provided for MCP server") + ErrBadProtocolScheme = thverrors.WithCode( + errors.New("invalid protocol scheme provided for MCP server"), + http.StatusBadRequest, + ) // ErrImageNotFound is returned when the specified image is not found in the registry. - ErrImageNotFound = errors.New("image not found in registry, please check the image name or tag") + ErrImageNotFound = thverrors.WithCode( + errors.New("image not found in registry, please check the image name or tag"), + http.StatusNotFound, + ) // ErrInvalidRunConfig is returned when the run configuration built by RunConfigBuilder is invalid - ErrInvalidRunConfig = errors.New("invalid run configuration provided") + ErrInvalidRunConfig = thverrors.WithCode( + errors.New("invalid run configuration provided"), + http.StatusBadRequest, + ) ) // Retriever is a function that retrieves the MCP server definition from the registry. diff --git a/pkg/secrets/factory.go b/pkg/secrets/factory.go index 7d761a901..67397221b 100644 --- a/pkg/secrets/factory.go +++ b/pkg/secrets/factory.go @@ -7,6 +7,7 @@ import ( "encoding/base64" "errors" "fmt" + "net/http" "os" "strings" "sync" @@ -14,6 +15,7 @@ import ( "github.com/adrg/xdg" "golang.org/x/term" + thverrors "github.com/stacklok/toolhive/pkg/errors" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/process" "github.com/stacklok/toolhive/pkg/secrets/keyring" @@ -59,11 +61,17 @@ const ( ) // ErrUnknownManagerType is returned when an invalid value for ProviderType is specified. -var ErrUnknownManagerType = errors.New("unknown secret manager type") +var ErrUnknownManagerType = thverrors.WithCode( + errors.New("unknown secret manager type"), + http.StatusBadRequest, +) // ErrSecretsNotSetup is returned when secrets functionality is used before running setup. -var ErrSecretsNotSetup = errors.New("secrets provider not configured. " + - "Please run 'thv secret setup' to configure a secrets provider first") +var ErrSecretsNotSetup = thverrors.WithCode( + errors.New("secrets provider not configured. "+ + "Please run 'thv secret setup' to configure a secrets provider first"), + http.StatusNotFound, +) // SetupResult contains the result of a provider setup operation type SetupResult struct { @@ -194,10 +202,13 @@ func validateNoneProvider(result *SetupResult) *SetupResult { } // ErrKeyringNotAvailable is returned when the OS keyring is not available for the encrypted provider. -var ErrKeyringNotAvailable = errors.New("OS keyring is not available. " + - "The encrypted provider requires an OS keyring to securely store passwords. " + - "Please use a different secrets provider (e.g., 1password) " + - "or ensure your system has a keyring service available") +var ErrKeyringNotAvailable = thverrors.WithCode( + errors.New("OS keyring is not available. "+ + "The encrypted provider requires an OS keyring to securely store passwords. "+ + "Please use a different secrets provider (e.g., 1password) "+ + "or ensure your system has a keyring service available"), + http.StatusBadRequest, +) // IsKeyringAvailable tests if any keyring backend is available func IsKeyringAvailable() bool { diff --git a/pkg/state/local.go b/pkg/state/local.go index 38f03eaf7..b439b6cb1 100644 --- a/pkg/state/local.go +++ b/pkg/state/local.go @@ -4,11 +4,14 @@ import ( "context" "fmt" "io" + "net/http" "os" "path/filepath" "strings" "github.com/adrg/xdg" + + "github.com/stacklok/toolhive/pkg/errors" ) const ( @@ -63,7 +66,7 @@ func (s *LocalStore) GetReader(_ context.Context, name string) (io.ReadCloser, e file, err := os.Open(filePath) if err != nil { if os.IsNotExist(err) { - return nil, fmt.Errorf("state '%s' not found", name) + return nil, errors.WithCode(fmt.Errorf("state '%s' not found", name), http.StatusNotFound) } return nil, fmt.Errorf("failed to open state file: %w", err) } diff --git a/pkg/transport/session/errors.go b/pkg/transport/session/errors.go index 7a8aab6bb..8f70d93a5 100644 --- a/pkg/transport/session/errors.go +++ b/pkg/transport/session/errors.go @@ -1,21 +1,41 @@ package session -import "errors" +import ( + "errors" + "net/http" + + thverrors "github.com/stacklok/toolhive/pkg/errors" +) // Common session errors var ( // ErrSessionDisconnected is returned when trying to send to a disconnected session - ErrSessionDisconnected = errors.New("session is disconnected") + ErrSessionDisconnected = thverrors.WithCode( + errors.New("session is disconnected"), + http.StatusServiceUnavailable, + ) // ErrMessageChannelFull is returned when the message channel is full - ErrMessageChannelFull = errors.New("message channel is full") + ErrMessageChannelFull = thverrors.WithCode( + errors.New("message channel is full"), + http.StatusServiceUnavailable, + ) // ErrSessionNotFound is returned when a session cannot be found - ErrSessionNotFound = errors.New("session not found") + ErrSessionNotFound = thverrors.WithCode( + errors.New("session not found"), + http.StatusNotFound, + ) // ErrSessionAlreadyExists is returned when trying to create a session with an existing ID - ErrSessionAlreadyExists = errors.New("session already exists") + ErrSessionAlreadyExists = thverrors.WithCode( + errors.New("session already exists"), + http.StatusConflict, + ) // ErrInvalidSessionType is returned when an invalid session type is provided - ErrInvalidSessionType = errors.New("invalid session type") + ErrInvalidSessionType = thverrors.WithCode( + errors.New("invalid session type"), + http.StatusBadRequest, + ) ) diff --git a/pkg/workloads/types/errors/errors.go b/pkg/workloads/types/errors/errors.go index 423b3ed54..09be3a96f 100644 --- a/pkg/workloads/types/errors/errors.go +++ b/pkg/workloads/types/errors/errors.go @@ -2,7 +2,15 @@ // It is located in a separate package to side-step an import cycle package errors -import "errors" +import ( + "errors" + "net/http" + + thverrors "github.com/stacklok/toolhive/pkg/errors" +) // ErrRunConfigNotFound is returned when a run config cannot be found for a workload. -var ErrRunConfigNotFound = errors.New("run config not found") +var ErrRunConfigNotFound = thverrors.WithCode( + errors.New("run config not found"), + http.StatusNotFound, +) diff --git a/pkg/workloads/types/validate.go b/pkg/workloads/types/validate.go index c5b29bea2..93780e40c 100644 --- a/pkg/workloads/types/validate.go +++ b/pkg/workloads/types/validate.go @@ -3,14 +3,21 @@ package types import ( + "errors" "fmt" + "net/http" "path/filepath" "regexp" "strings" + + thverrors "github.com/stacklok/toolhive/pkg/errors" ) // ErrInvalidWorkloadName is returned when a workload name fails validation. -var ErrInvalidWorkloadName = fmt.Errorf("invalid workload name") +var ErrInvalidWorkloadName = thverrors.WithCode( + errors.New("invalid workload name"), + http.StatusBadRequest, +) // workloadNamePattern validates workload names to prevent path traversal attacks // and other security issues. Workload names should only contain alphanumeric