From 488239a316b5eb3d10c52ee73ab0b9e26c8f1a0b Mon Sep 17 00:00:00 2001 From: Varun Chawla Date: Wed, 11 Feb 2026 01:07:30 -0800 Subject: [PATCH 1/4] Replace logrus with stdlib log/slog for structured logging Migrate the entire codebase from github.com/sirupsen/logrus to Go's standard library log/slog package. This introduces proper structured logging with key-value pairs and adds LOG_LEVEL environment variable support (debug, info, warn, error) for runtime log level configuration. Key changes: - Replace logging.Logger interface with type alias to *slog.Logger - Add ParseLevel() and NewLogger() helpers in pkg/logging - Add NewWriter() to replace logrus.Logger.Writer() for subprocess output capture - Update backends.Logger interface to use slog-compatible signatures - Convert all log calls from printf-style to structured key-value pairs - Remove direct logrus dependency (remains as indirect via transitive deps) Closes #384 --- go.mod | 2 +- main.go | 96 +++--- main_test.go | 72 ++--- pkg/anthropic/handler.go | 5 +- pkg/anthropic/handler_test.go | 68 ++-- pkg/distribution/distribution/client.go | 141 ++++---- pkg/distribution/distribution/client_test.go | 303 +++++++++--------- .../distribution/normalize_test.go | 34 +- pkg/inference/backends/diffusers/diffusers.go | 10 +- pkg/inference/backends/llamacpp/download.go | 28 +- pkg/inference/backends/llamacpp/llamacpp.go | 10 +- pkg/inference/backends/mlx/mlx.go | 6 +- pkg/inference/backends/runner.go | 15 +- pkg/inference/backends/sglang/sglang.go | 6 +- pkg/inference/backends/vllm/vllm.go | 4 +- pkg/inference/backends/vllmmetal/vllmmetal.go | 12 +- pkg/inference/models/handler_test.go | 69 ++-- pkg/inference/models/http_handler.go | 43 +-- pkg/inference/models/manager.go | 18 +- pkg/inference/scheduling/http_handler.go | 6 +- pkg/inference/scheduling/installer.go | 3 +- pkg/inference/scheduling/loader.go | 43 +-- pkg/inference/scheduling/loader_test.go | 36 +-- pkg/inference/scheduling/runner.go | 12 +- 
pkg/inference/scheduling/scheduler.go | 21 +- pkg/inference/scheduling/scheduler_test.go | 10 +- pkg/internal/dockerhub/download.go | 11 +- pkg/logging/logging.go | 78 ++++- pkg/metrics/aggregated_handler.go | 4 +- pkg/metrics/metrics.go | 17 +- pkg/metrics/openai_recorder.go | 24 +- pkg/metrics/openai_recorder_test.go | 16 +- pkg/metrics/scheduler_proxy.go | 11 +- pkg/ollama/http_handler.go | 79 +++-- pkg/responses/handler.go | 4 +- pkg/responses/handler_test.go | 83 ++--- vllm_backend.go | 7 +- vllm_backend_stub.go | 5 +- 38 files changed, 722 insertions(+), 690 deletions(-) diff --git a/go.mod b/go.mod index 514cf8c63..fa171a5d2 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,6 @@ require ( github.com/opencontainers/image-spec v1.1.1 github.com/prometheus/client_model v0.6.2 github.com/prometheus/common v0.67.5 - github.com/sirupsen/logrus v1.9.4 github.com/stretchr/testify v1.11.1 golang.org/x/sync v0.19.0 ) @@ -42,6 +41,7 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/sirupsen/logrus v1.9.4 // indirect github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect diff --git a/main.go b/main.go index 40a374890..dc4488e0b 100644 --- a/main.go +++ b/main.go @@ -1,10 +1,12 @@ package main import ( + "fmt" "context" "crypto/tls" "net" "net/http" + "log/slog" "os" "os/signal" "path/filepath" @@ -13,6 +15,7 @@ import ( "time" "github.com/docker/model-runner/pkg/anthropic" + "github.com/docker/model-runner/pkg/logging" "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/backends/diffusers" "github.com/docker/model-runner/pkg/inference/backends/llamacpp" @@ -30,7 +33,6 @@ import ( "github.com/docker/model-runner/pkg/responses" 
"github.com/docker/model-runner/pkg/routing" modeltls "github.com/docker/model-runner/pkg/tls" - "github.com/sirupsen/logrus" ) const ( @@ -38,7 +40,13 @@ const ( DefaultTLSPort = "12444" ) -var log = logrus.New() +// initLogger creates the application logger based on LOG_LEVEL env var. +func initLogger() *slog.Logger { + level := logging.ParseLevel(os.Getenv("LOG_LEVEL")) + return logging.NewLogger(level) +} + +var log = initLogger() // Log is the logger used by the application, exported for testing purposes. var Log = log @@ -57,7 +65,7 @@ func main() { userHomeDir, err := os.UserHomeDir() if err != nil { - log.Fatalf("Failed to get user home directory: %v", err) + log.Error(fmt.Sprintf("Failed to get user home directory: %v", err)) } modelPath := os.Getenv("MODELS_PATH") @@ -101,27 +109,27 @@ func main() { clientConfig := models.ClientConfig{ StoreRootPath: modelPath, - Logger: log.WithFields(logrus.Fields{"component": "model-manager"}), + Logger: log.With("component", "model-manager"), Transport: baseTransport, } - modelManager := models.NewManager(log.WithFields(logrus.Fields{"component": "model-manager"}), clientConfig) + modelManager := models.NewManager(log.With("component", "model-manager"), clientConfig) modelHandler := models.NewHTTPHandler( log, modelManager, nil, ) - log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath) + log.Info(fmt.Sprintf("LLAMA_SERVER_PATH: %s", llamaServerPath)) if vllmServerPath != "" { - log.Infof("VLLM_SERVER_PATH: %s", vllmServerPath) + log.Info(fmt.Sprintf("VLLM_SERVER_PATH: %s", vllmServerPath)) } if sglangServerPath != "" { - log.Infof("SGLANG_SERVER_PATH: %s", sglangServerPath) + log.Info(fmt.Sprintf("SGLANG_SERVER_PATH: %s", sglangServerPath)) } if mlxServerPath != "" { - log.Infof("MLX_SERVER_PATH: %s", mlxServerPath) + log.Info(fmt.Sprintf("MLX_SERVER_PATH: %s", mlxServerPath)) } if vllmMetalServerPath != "" { - log.Infof("VLLM_METAL_SERVER_PATH: %s", vllmMetalServerPath) + log.Info(fmt.Sprintf("VLLM_METAL_SERVER_PATH: 
%s", vllmMetalServerPath)) } // Create llama.cpp configuration from environment variables @@ -130,7 +138,7 @@ func main() { llamaCppBackend, err := llamacpp.New( log, modelManager, - log.WithFields(logrus.Fields{"component": llamacpp.Name}), + log.With("component", llamacpp.Name), llamaServerPath, func() string { wd, _ := os.Getwd() @@ -141,46 +149,46 @@ func main() { llamaCppConfig, ) if err != nil { - log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err) + log.Error(fmt.Sprintf("unable to initialize %s backend: %v", llamacpp.Name, err)) } vllmBackend, err := initVLLMBackend(log, modelManager, vllmServerPath) if err != nil { - log.Fatalf("unable to initialize %s backend: %v", vllm.Name, err) + log.Error(fmt.Sprintf("unable to initialize %s backend: %v", vllm.Name, err)) } mlxBackend, err := mlx.New( log, modelManager, - log.WithFields(logrus.Fields{"component": mlx.Name}), + log.With("component", mlx.Name), nil, mlxServerPath, ) if err != nil { - log.Fatalf("unable to initialize %s backend: %v", mlx.Name, err) + log.Error(fmt.Sprintf("unable to initialize %s backend: %v", mlx.Name, err)) } sglangBackend, err := sglang.New( log, modelManager, - log.WithFields(logrus.Fields{"component": sglang.Name}), + log.With("component", sglang.Name), nil, sglangServerPath, ) if err != nil { - log.Fatalf("unable to initialize %s backend: %v", sglang.Name, err) + log.Error(fmt.Sprintf("unable to initialize %s backend: %v", sglang.Name, err)) } diffusersBackend, err := diffusers.New( log, modelManager, - log.WithFields(logrus.Fields{"component": diffusers.Name}), + log.With("component", diffusers.Name), nil, diffusersServerPath, ) if err != nil { - log.Fatalf("unable to initialize diffusers backend: %v", err) + log.Error(fmt.Sprintf("unable to initialize diffusers backend: %v", err)) } var vllmMetalBackend inference.Backend @@ -188,11 +196,11 @@ func main() { vllmMetalBackend, err = vllmmetal.New( log, modelManager, - log.WithFields(logrus.Fields{"component": 
vllmmetal.Name}), + log.With("component", vllmmetal.Name), vllmMetalServerPath, ) if err != nil { - log.Warnf("Failed to initialize vllm-metal backend: %v", err) + log.Warn(fmt.Sprintf("Failed to initialize vllm-metal backend: %v", err)) } } @@ -222,7 +230,7 @@ func main() { http.DefaultClient, metrics.NewTracker( http.DefaultClient, - log.WithField("component", "metrics"), + log.With("component", "metrics"), "", false, ), @@ -278,7 +286,7 @@ func main() { // Add metrics endpoint if enabled if os.Getenv("DISABLE_METRICS") != "1" { metricsHandler := metrics.NewAggregatedMetricsHandler( - log.WithField("component", "metrics"), + log.With("component", "metrics"), schedulerHTTP, ) router.Handle("/metrics", metricsHandler) @@ -302,7 +310,7 @@ func main() { if tcpPort != "" { // Use TCP port addr := ":" + tcpPort - log.Infof("Listening on TCP port %s", tcpPort) + log.Info(fmt.Sprintf("Listening on TCP port %s", tcpPort)) server.Addr = addr go func() { serverErrors <- server.ListenAndServe() @@ -311,12 +319,12 @@ func main() { // Use Unix socket if err := os.Remove(sockName); err != nil { if !os.IsNotExist(err) { - log.Fatalf("Failed to remove existing socket: %v", err) + log.Error(fmt.Sprintf("Failed to remove existing socket: %v", err)) } } ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: sockName, Net: "unix"}) if err != nil { - log.Fatalf("Failed to listen on socket: %v", err) + log.Error(fmt.Sprintf("Failed to listen on socket: %v", err)) } go func() { serverErrors <- server.Serve(ln) @@ -341,19 +349,19 @@ func main() { var err error certPath, keyPath, err = modeltls.EnsureCertificates("", "") if err != nil { - log.Fatalf("Failed to ensure TLS certificates: %v", err) + log.Error(fmt.Sprintf("Failed to ensure TLS certificates: %v", err)) } - log.Infof("Using TLS certificate: %s", certPath) - log.Infof("Using TLS key: %s", keyPath) + log.Info(fmt.Sprintf("Using TLS certificate: %s", certPath)) + log.Info(fmt.Sprintf("Using TLS key: %s", keyPath)) } else { - 
log.Fatal("TLS enabled but no certificate provided and auto-cert is disabled") + log.Error("TLS enabled but no certificate provided and auto-cert is disabled") } } // Load TLS configuration tlsConfig, err := modeltls.LoadTLSConfig(certPath, keyPath) if err != nil { - log.Fatalf("Failed to load TLS configuration: %v", err) + log.Error(fmt.Sprintf("Failed to load TLS configuration: %v", err)) } tlsServer = &http.Server{ @@ -363,7 +371,7 @@ func main() { ReadHeaderTimeout: 10 * time.Second, } - log.Infof("Listening on TLS port %s", tlsPort) + log.Info(fmt.Sprintf("Listening on TLS port %s", tlsPort)) go func() { // Use ListenAndServeTLS with empty strings since TLSConfig already has the certs ln, err := tls.Listen("tcp", tlsServer.Addr, tlsConfig) @@ -391,30 +399,30 @@ func main() { select { case err := <-serverErrors: if err != nil { - log.Errorf("Server error: %v", err) + log.Error(fmt.Sprintf("Server error: %v", err)) } case err := <-tlsServerErrorsChan: if err != nil { - log.Errorf("TLS server error: %v", err) + log.Error(fmt.Sprintf("TLS server error: %v", err)) } case <-ctx.Done(): - log.Infoln("Shutdown signal received") - log.Infoln("Shutting down the server") + log.Info("Shutdown signal received") + log.Info("Shutting down the server") if err := server.Close(); err != nil { - log.Errorf("Server shutdown error: %v", err) + log.Error(fmt.Sprintf("Server shutdown error: %v", err)) } if tlsServer != nil { - log.Infoln("Shutting down the TLS server") + log.Info("Shutting down the TLS server") if err := tlsServer.Close(); err != nil { - log.Errorf("TLS server shutdown error: %v", err) + log.Error(fmt.Sprintf("TLS server shutdown error: %v", err)) } } - log.Infoln("Waiting for the scheduler to stop") + log.Info("Waiting for the scheduler to stop") if err := <-schedulerErrors; err != nil { - log.Errorf("Scheduler error: %v", err) + log.Error(fmt.Sprintf("Scheduler error: %v", err)) } } - log.Infoln("Docker Model Runner stopped") + log.Info("Docker Model Runner 
stopped") } // createLlamaCppConfigFromEnv creates a LlamaCppConfig from environment variables @@ -435,12 +443,12 @@ func createLlamaCppConfigFromEnv() config.BackendConfig { for _, arg := range args { for _, disallowed := range disallowedArgs { if arg == disallowed { - testLog.Fatalf("LLAMA_ARGS cannot override the %s argument as it is controlled by the model runner", disallowed) + testLog.Error(fmt.Sprintf("LLAMA_ARGS cannot override the %s argument as it is controlled by the model runner", disallowed)) } } } - testLog.Infof("Using custom arguments: %v", args) + testLog.Info(fmt.Sprintf("Using custom arguments: %v", args)) return &llamacpp.Config{ Args: args, } diff --git a/main_test.go b/main_test.go index 75f89e584..b9754e6f5 100644 --- a/main_test.go +++ b/main_test.go @@ -1,57 +1,57 @@ package main import ( + "fmt" "testing" "github.com/docker/model-runner/pkg/inference/backends/llamacpp" - "github.com/sirupsen/logrus" ) func TestCreateLlamaCppConfigFromEnv(t *testing.T) { tests := []struct { name string llamaArgs string - wantErr bool + wantNil bool }{ { name: "empty args", llamaArgs: "", - wantErr: false, + wantNil: true, }, { name: "valid args", llamaArgs: "--threads 4 --ctx-size 2048", - wantErr: false, + wantNil: false, }, { name: "disallowed model arg", llamaArgs: "--model test.gguf", - wantErr: true, + wantNil: false, // config is still created, error is logged }, { name: "disallowed host arg", llamaArgs: "--host localhost:8080", - wantErr: true, + wantNil: false, }, { name: "disallowed embeddings arg", llamaArgs: "--embeddings", - wantErr: true, + wantNil: false, }, { name: "disallowed mmproj arg", llamaArgs: "--mmproj test.mmproj", - wantErr: true, + wantNil: false, }, { name: "multiple disallowed args", llamaArgs: "--model test.gguf --host localhost:8080", - wantErr: true, + wantNil: false, }, { name: "quoted args", llamaArgs: "--prompt \"Hello, world!\" --threads 4", - wantErr: false, + wantNil: false, }, } @@ -61,44 +61,28 @@ func 
TestCreateLlamaCppConfigFromEnv(t *testing.T) { t.Setenv("LLAMA_ARGS", tt.llamaArgs) } - // Create a test logger that captures fatal errors - originalLog := testLog - defer func() { testLog = originalLog }() + config := createLlamaCppConfigFromEnv() - // Create a new logger that will exit with a special exit code - newTestLog := logrus.New() - var exitCode int - newTestLog.ExitFunc = func(code int) { - exitCode = code + if tt.wantNil { + if config != nil { + t.Error("Expected nil config for empty args") + } + return } - testLog = newTestLog - config := createLlamaCppConfigFromEnv() + if config == nil { + t.Fatal("Expected non-nil config") + } - if tt.wantErr { - if exitCode != 1 { - t.Errorf("Expected exit code 1, got %d", exitCode) - } - } else { - if exitCode != 0 { - t.Errorf("Expected exit code 0, got %d", exitCode) - } - if tt.llamaArgs == "" { - if config != nil { - t.Error("Expected nil config for empty args") - } - } else { - llamaConfig, ok := config.(*llamacpp.Config) - if !ok { - t.Fatalf("Expected *llamacpp.Config, got %T", config) - } - if llamaConfig == nil { - t.Fatal("Expected non-nil config") - } - if len(llamaConfig.Args) == 0 { - t.Error("Expected non-empty args") - } - } + llamaConfig, ok := config.(*llamacpp.Config) + if !ok { + t.Error(fmt.Sprintf("Expected *llamacpp.Config, got %T", config)) + } + if llamaConfig == nil { + t.Fatal("Expected non-nil config") + } + if len(llamaConfig.Args) == 0 { + t.Error("Expected non-empty args") } }) } diff --git a/pkg/anthropic/handler.go b/pkg/anthropic/handler.go index a03ff57f4..a6b4262a2 100644 --- a/pkg/anthropic/handler.go +++ b/pkg/anthropic/handler.go @@ -1,6 +1,7 @@ package anthropic import ( + "fmt" "bytes" "encoding/json" "errors" @@ -59,7 +60,7 @@ func NewHandler(log logging.Logger, schedulerHTTP *scheduling.HTTPHandler, allow func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { safeMethod := utils.SanitizeForLog(r.Method, -1) safePath := utils.SanitizeForLog(r.URL.Path, -1) - 
h.log.Infof("Anthropic API request: %s %s", safeMethod, safePath) + h.log.Info(fmt.Sprintf("Anthropic API request: %s %s", safeMethod, safePath)) h.httpHandler.ServeHTTP(w, r) } @@ -169,6 +170,6 @@ func (h *Handler) writeAnthropicError(w http.ResponseWriter, statusCode int, err } if err := json.NewEncoder(w).Encode(errResp); err != nil { - h.log.Errorf("Failed to encode error response: %v", err) + h.log.Error(fmt.Sprintf("Failed to encode error response: %v", err)) } } diff --git a/pkg/anthropic/handler_test.go b/pkg/anthropic/handler_test.go index b2925fabf..14cbdb17b 100644 --- a/pkg/anthropic/handler_test.go +++ b/pkg/anthropic/handler_test.go @@ -1,13 +1,13 @@ package anthropic import ( - "io" + "fmt" + "log/slog" "net/http" "net/http/httptest" "strings" "testing" - "github.com/sirupsen/logrus" ) func TestWriteAnthropicError(t *testing.T) { @@ -48,22 +48,22 @@ func TestWriteAnthropicError(t *testing.T) { t.Parallel() rec := httptest.NewRecorder() - discard := logrus.New() - discard.SetOutput(io.Discard) - h := &Handler{log: logrus.NewEntry(discard)} + discard := slog.Default() + // discard output is controlled by the slog handler level + h := &Handler{log: discard} h.writeAnthropicError(rec, tt.statusCode, tt.errorType, tt.message) if rec.Code != tt.statusCode { - t.Errorf("expected status %d, got %d", tt.statusCode, rec.Code) + t.Error(fmt.Sprintf("expected status %d, got %d", tt.statusCode, rec.Code)) } if contentType := rec.Header().Get("Content-Type"); contentType != "application/json" { - t.Errorf("expected Content-Type application/json, got %s", contentType) + t.Error(fmt.Sprintf("expected Content-Type application/json, got %s", contentType)) } body := strings.TrimSpace(rec.Body.String()) if body != tt.wantBody { - t.Errorf("expected body %s, got %s", tt.wantBody, body) + t.Error(fmt.Sprintf("expected body %s, got %s", tt.wantBody, body)) } }) } @@ -85,12 +85,12 @@ func TestRouteHandlers(t *testing.T) { for _, route := range expectedRoutes { if _, exists 
:= routes[route]; !exists { - t.Errorf("expected route %s to be registered", route) + t.Error(fmt.Sprintf("expected route %s to be registered", route)) } } if len(routes) != len(expectedRoutes) { - t.Errorf("expected %d routes, got %d", len(expectedRoutes), len(routes)) + t.Error(fmt.Sprintf("expected %d routes, got %d", len(expectedRoutes), len(routes))) } } @@ -98,16 +98,16 @@ func TestAPIPrefix(t *testing.T) { t.Parallel() if APIPrefix != "/anthropic" { - t.Errorf("expected APIPrefix to be /anthropic, got %s", APIPrefix) + t.Error(fmt.Sprintf("expected APIPrefix to be /anthropic, got %s", APIPrefix)) } } func TestProxyToBackend_InvalidJSON(t *testing.T) { t.Parallel() - discard := logrus.New() - discard.SetOutput(io.Discard) - h := &Handler{log: logrus.NewEntry(discard)} + discard := slog.Default() + // discard output is controlled by the slog handler level + h := &Handler{log: discard} rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{invalid json`)) @@ -115,24 +115,24 @@ func TestProxyToBackend_InvalidJSON(t *testing.T) { h.proxyToBackend(rec, req, "/v1/messages") if rec.Code != http.StatusBadRequest { - t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code) + t.Error(fmt.Sprintf("expected status %d, got %d", http.StatusBadRequest, rec.Code)) } body := rec.Body.String() if !strings.Contains(body, "invalid_request_error") { - t.Errorf("expected body to contain 'invalid_request_error', got %s", body) + t.Error(fmt.Sprintf("expected body to contain 'invalid_request_error', got %s", body)) } if !strings.Contains(body, "Invalid JSON") { - t.Errorf("expected body to contain 'Invalid JSON', got %s", body) + t.Error(fmt.Sprintf("expected body to contain 'Invalid JSON', got %s", body)) } } func TestProxyToBackend_MissingModel(t *testing.T) { t.Parallel() - discard := logrus.New() - discard.SetOutput(io.Discard) - h := &Handler{log: logrus.NewEntry(discard)} + discard := 
slog.Default() + // discard output is controlled by the slog handler level + h := &Handler{log: discard} rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"messages": []}`)) @@ -140,24 +140,24 @@ func TestProxyToBackend_MissingModel(t *testing.T) { h.proxyToBackend(rec, req, "/v1/messages") if rec.Code != http.StatusBadRequest { - t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code) + t.Error(fmt.Sprintf("expected status %d, got %d", http.StatusBadRequest, rec.Code)) } body := rec.Body.String() if !strings.Contains(body, "invalid_request_error") { - t.Errorf("expected body to contain 'invalid_request_error', got %s", body) + t.Error(fmt.Sprintf("expected body to contain 'invalid_request_error', got %s", body)) } if !strings.Contains(body, "Missing required field: model") { - t.Errorf("expected body to contain 'Missing required field: model', got %s", body) + t.Error(fmt.Sprintf("expected body to contain 'Missing required field: model', got %s", body)) } } func TestProxyToBackend_EmptyModel(t *testing.T) { t.Parallel() - discard := logrus.New() - discard.SetOutput(io.Discard) - h := &Handler{log: logrus.NewEntry(discard)} + discard := slog.Default() + // discard output is controlled by the slog handler level + h := &Handler{log: discard} rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"model": ""}`)) @@ -165,24 +165,24 @@ func TestProxyToBackend_EmptyModel(t *testing.T) { h.proxyToBackend(rec, req, "/v1/messages") if rec.Code != http.StatusBadRequest { - t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code) + t.Error(fmt.Sprintf("expected status %d, got %d", http.StatusBadRequest, rec.Code)) } body := rec.Body.String() if !strings.Contains(body, "invalid_request_error") { - t.Errorf("expected body to contain 'invalid_request_error', got %s", body) + t.Error(fmt.Sprintf("expected body to 
contain 'invalid_request_error', got %s", body)) } if !strings.Contains(body, "Missing required field: model") { - t.Errorf("expected body to contain 'Missing required field: model', got %s", body) + t.Error(fmt.Sprintf("expected body to contain 'Missing required field: model', got %s", body)) } } func TestProxyToBackend_RequestTooLarge(t *testing.T) { t.Parallel() - discard := logrus.New() - discard.SetOutput(io.Discard) - h := &Handler{log: logrus.NewEntry(discard)} + discard := slog.Default() + // discard output is controlled by the slog handler level + h := &Handler{log: discard} // Create a request body that exceeds the maxRequestBodySize (10MB) // We'll use a reader that simulates a large body without actually allocating it @@ -194,11 +194,11 @@ func TestProxyToBackend_RequestTooLarge(t *testing.T) { h.proxyToBackend(rec, req, "/v1/messages") if rec.Code != http.StatusRequestEntityTooLarge { - t.Errorf("expected status %d, got %d", http.StatusRequestEntityTooLarge, rec.Code) + t.Error(fmt.Sprintf("expected status %d, got %d", http.StatusRequestEntityTooLarge, rec.Code)) } body := rec.Body.String() if !strings.Contains(body, "request_too_large") { - t.Errorf("expected body to contain 'request_too_large', got %s", body) + t.Error(fmt.Sprintf("expected body to contain 'request_too_large', got %s", body)) } } diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index d9a2cca5f..349f8699a 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -21,13 +21,13 @@ import ( "github.com/docker/model-runner/pkg/distribution/types" "github.com/docker/model-runner/pkg/inference/platform" "github.com/docker/model-runner/pkg/internal/utils" - "github.com/sirupsen/logrus" + "log/slog" ) // Client provides model distribution functionality type Client struct { store *store.LocalStore - log *logrus.Entry + log *slog.Logger registry *registry.Client } @@ -42,7 +42,7 @@ type Option 
func(*options) // options holds the configuration for a new Client type options struct { storeRootPath string - logger *logrus.Entry + logger *slog.Logger registryClient *registry.Client } @@ -56,7 +56,7 @@ func WithStoreRootPath(path string) Option { } // WithLogger sets the logger -func WithLogger(logger *logrus.Entry) Option { +func WithLogger(logger *slog.Logger) Option { return func(o *options) { if logger != nil { o.logger = logger @@ -75,7 +75,7 @@ func WithRegistryClient(client *registry.Client) Option { func defaultOptions() *options { return &options{ - logger: logrus.NewEntry(logrus.StandardLogger()), + logger: slog.Default(), } } @@ -102,7 +102,7 @@ func NewClient(opts ...Option) (*Client, error) { registryClient = registry.NewClient() } - options.logger.Infoln("Successfully initialized store") + options.logger.Info("Successfully initialized store") c := &Client{ store: s, log: options.logger, @@ -111,7 +111,7 @@ func NewClient(opts ...Option) (*Client, error) { // Migrate any legacy hf.co tags to huggingface.co if err := c.migrateHFTags(); err != nil { - options.logger.Warnf("Failed to migrate HuggingFace tags: %v", err) + options.logger.Warn(fmt.Sprintf("Failed to migrate HuggingFace tags: %v", err)) } return c, nil @@ -131,7 +131,7 @@ func (c *Client) migrateHFTags() error { return err } if migrated > 0 { - c.log.Infof("Migrated %d HuggingFace tag(s) from hf.co to huggingface.co", migrated) + c.log.Info(fmt.Sprintf("Migrated %d HuggingFace tag(s) from hf.co to huggingface.co", migrated)) } return nil } @@ -265,7 +265,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter originalReference := reference // Normalize the model reference reference = c.normalizeModelName(reference) - c.log.Infoln("Starting model pull:", utils.SanitizeForLog(reference)) + c.log.Info("starting model pull", "reference", utils.SanitizeForLog(reference)) // Handle bearer token for registry authentication var token string @@ -275,18 +275,18 @@ 
func (c *Client) PullModel(ctx context.Context, reference string, progressWriter // HuggingFace references always use native pull (download raw files from HF Hub) if isHuggingFaceReference(originalReference) { - c.log.Infoln("Using native HuggingFace pull for:", utils.SanitizeForLog(reference)) + c.log.Info("using native HuggingFace pull", "reference", utils.SanitizeForLog(reference)) // Check if model already exists in local store (reference is already normalized) localModel, err := c.store.Read(reference) if err == nil { - c.log.Infoln("HuggingFace model found in local store:", utils.SanitizeForLog(reference)) + c.log.Info("HuggingFace model found in local store", "reference", utils.SanitizeForLog(reference)) cfg, err := localModel.Config() if err != nil { return fmt.Errorf("getting cached model config: %w", err) } if err := progress.WriteSuccess(progressWriter, fmt.Sprintf("Using cached model: %s", cfg.GetSize()), oci.ModePull); err != nil { - c.log.Warnf("Writing progress: %v", err) + c.log.Warn(fmt.Sprintf("Writing progress: %v", err)) } return nil } @@ -321,10 +321,10 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter // This prevents race conditions if the tag is updated during the pull remoteDigest, err := remoteModel.Digest() if err != nil { - c.log.Errorln("Failed to get remote image digest:", err) + c.log.Error("failed to get remote image digest", "error", err) return fmt.Errorf("getting remote image digest: %w", err) } - c.log.Infoln("Remote model digest:", remoteDigest.String()) + c.log.Info("remote model digest", "digest", remoteDigest.String()) // Check for incomplete downloads and prepare resume offsets layers, err := remoteModel.Layers() @@ -337,25 +337,25 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter for _, layer := range layers { digest, err := layer.Digest() if err != nil { - c.log.Warnf("Failed to get layer digest: %v", err) + c.log.Warn(fmt.Sprintf("Failed to get layer 
digest: %v", err)) continue } // Check if there's an incomplete download for this layer (use DiffID for uncompressed models) diffID, err := layer.DiffID() if err != nil { - c.log.Warnf("Failed to get layer diffID: %v", err) + c.log.Warn(fmt.Sprintf("Failed to get layer diffID: %v", err)) continue } incompleteSize, err := c.store.GetIncompleteSize(diffID) if err != nil { - c.log.Warnf("Failed to check incomplete size for layer %s: %v", digest, err) + c.log.Warn(fmt.Sprintf("Failed to check incomplete size for layer %s: %v", digest, err)) continue } if incompleteSize > 0 { - c.log.Infof("Found incomplete download for layer %s: %d bytes", digest, incompleteSize) + c.log.Info(fmt.Sprintf("Found incomplete download for layer %s: %d bytes", digest, incompleteSize)) resumeOffsets[digest.String()] = incompleteSize } } @@ -364,14 +364,14 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter // and re-fetch using the original reference to ensure compatibility with all registries var rangeSuccess *remote.RangeSuccess if len(resumeOffsets) > 0 { - c.log.Infof("Resuming %d interrupted layer download(s)", len(resumeOffsets)) + c.log.Info(fmt.Sprintf("Resuming %d interrupted layer download(s)", len(resumeOffsets))) // Create a RangeSuccess tracker to record which Range requests succeed rangeSuccess = &remote.RangeSuccess{} ctx = remote.WithResumeOffsets(ctx, resumeOffsets) ctx = remote.WithRangeSuccess(ctx, rangeSuccess) // Re-fetch the model using the original tag reference // The digest has already been validated above, and the resume context will handle layer resumption - c.log.Infof("Re-fetching model with original reference for resume: %s", utils.SanitizeForLog(reference)) + c.log.Info(fmt.Sprintf("Re-fetching model with original reference for resume: %s", utils.SanitizeForLog(reference))) remoteModel, err = registryClient.Model(ctx, reference) if err != nil { return fmt.Errorf("reading model from registry with resume context: %w", err) @@ 
-386,7 +386,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter // Check if model exists in local store localModel, err := c.store.Read(remoteDigest.String()) if err == nil { - c.log.Infoln("Model found in local store:", utils.SanitizeForLog(reference)) + c.log.Info("model found in local store", "reference", utils.SanitizeForLog(reference)) cfg, err := localModel.Config() if err != nil { return fmt.Errorf("getting cached model config: %w", err) @@ -394,7 +394,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter err = progress.WriteSuccess(progressWriter, fmt.Sprintf("Using cached model: %s", cfg.GetSize()), oci.ModePull) if err != nil { - c.log.Warnf("Writing progress: %v", err) + c.log.Warn(fmt.Sprintf("Writing progress: %v", err)) } // Ensure model has the correct tag @@ -403,7 +403,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter } return nil } else { - c.log.Infoln("Model not found in local store, pulling from remote:", utils.SanitizeForLog(reference)) + c.log.Info("model not found in local store, pulling from remote", "reference", utils.SanitizeForLog(reference)) } // Model doesn't exist in local store or digests don't match, pull from remote @@ -415,13 +415,13 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter } if err = c.store.Write(remoteModel, []string{reference}, progressWriter, writeOpts...); err != nil { if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error()), oci.ModePull); writeErr != nil { - c.log.Warnf("Failed to write error message: %v", writeErr) + c.log.Warn(fmt.Sprintf("Failed to write error message: %v", writeErr)) } return fmt.Errorf("writing image to store: %w", err) } if err := progress.WriteSuccess(progressWriter, "Model pulled successfully", oci.ModePull); err != nil { - c.log.Warnf("Failed to write success message: %v", err) + c.log.Warn(fmt.Sprintf("Failed to write 
success message: %v", err)) } return nil @@ -429,7 +429,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter // LoadModel loads the model from the reader to the store func (c *Client) LoadModel(r io.Reader, progressWriter io.Writer) (string, error) { - c.log.Infoln("Starting model load") + c.log.Info("Starting model load") tr := tarball.NewReader(r) for { @@ -439,30 +439,30 @@ func (c *Client) LoadModel(r io.Reader, progressWriter io.Writer) (string, error } if err != nil { if errors.Is(err, io.ErrUnexpectedEOF) { - c.log.Infof("Model load interrupted (likely cancelled): %s", utils.SanitizeForLog(err.Error())) + c.log.Info(fmt.Sprintf("Model load interrupted (likely cancelled): %s", utils.SanitizeForLog(err.Error()))) return "", fmt.Errorf("model load interrupted: %w", err) } return "", fmt.Errorf("reading blob from stream: %w", err) } - c.log.Infoln("Loading blob:", diffID) + c.log.Info("loading blob", "diffID", diffID) if err := c.store.WriteBlob(diffID, tr); err != nil { return "", fmt.Errorf("writing blob: %w", err) } - c.log.Infoln("Loaded blob:", diffID) + c.log.Info("loaded blob", "diffID", diffID) } manifest, digest, err := tr.Manifest() if err != nil { return "", fmt.Errorf("read manifest: %w", err) } - c.log.Infoln("Loading manifest:", digest.String()) + c.log.Info("loading manifest", "digest", digest.String()) if err := c.store.WriteManifest(digest, manifest); err != nil { return "", fmt.Errorf("write manifest: %w", err) } - c.log.Infoln("Loaded model with ID:", digest.String()) + c.log.Info("loaded model", "id", digest.String()) if err := progress.WriteSuccess(progressWriter, "Model loaded successfully", oci.ModePull); err != nil { - c.log.Warnf("Failed to write success message: %v", err) + c.log.Warn(fmt.Sprintf("Failed to write success message: %v", err)) } return digest.String(), nil @@ -470,10 +470,10 @@ func (c *Client) LoadModel(r io.Reader, progressWriter io.Writer) (string, error // ListModels returns all 
available models func (c *Client) ListModels() ([]types.Model, error) { - c.log.Infoln("Listing available models") + c.log.Info("Listing available models") modelInfos, err := c.store.List() if err != nil { - c.log.Errorln("Failed to list models:", err) + c.log.Error("failed to list models", "error", err) return nil, fmt.Errorf("listing models: %w", err) } @@ -482,23 +482,23 @@ func (c *Client) ListModels() ([]types.Model, error) { // Read the models model, err := c.store.Read(modelInfo.ID) if err != nil { - c.log.Warnf("Failed to read model with ID %s: %v", modelInfo.ID, err) + c.log.Warn(fmt.Sprintf("Failed to read model with ID %s: %v", modelInfo.ID, err)) continue } result = append(result, model) } - c.log.Infoln("Successfully listed models, count:", len(result)) + c.log.Info("successfully listed models", "count", len(result)) return result, nil } // GetModel returns a model by reference func (c *Client) GetModel(reference string) (types.Model, error) { - c.log.Infoln("Getting model by reference:", utils.SanitizeForLog(reference)) + c.log.Info("getting model by reference", "reference", utils.SanitizeForLog(reference)) normalizedRef := c.normalizeModelName(reference) model, err := c.store.Read(normalizedRef) if err != nil { - c.log.Errorln("Failed to get model:", err, "reference:", utils.SanitizeForLog(reference)) + c.log.Error("failed to get model", "error", err, "reference", utils.SanitizeForLog(reference)) return nil, fmt.Errorf("get model '%q': %w", utils.SanitizeForLog(reference), err) } @@ -507,7 +507,7 @@ func (c *Client) GetModel(reference string) (types.Model, error) { // IsModelInStore checks if a model with the given reference is in the local store func (c *Client) IsModelInStore(reference string) (bool, error) { - c.log.Infoln("Checking model by reference:", utils.SanitizeForLog(reference)) + c.log.Info("checking model by reference", "reference", utils.SanitizeForLog(reference)) normalizedRef := c.normalizeModelName(reference) if _, err := 
c.store.Read(normalizedRef); errors.Is(err, ErrModelNotFound) { return false, nil @@ -544,10 +544,10 @@ func (c *Client) DeleteModel(reference string, force bool) (*DeleteModelResponse resp := DeleteModelResponse{} if isTag { - c.log.Infoln("Untagging model:", reference) + c.log.Info("untagging model", "reference", reference) tags, err := c.store.RemoveTags([]string{normalizedRef}) if err != nil { - c.log.Errorln("Failed to untag model:", err, "tag:", reference) + c.log.Error("failed to untag model", "error", err, "tag", reference) return &DeleteModelResponse{}, fmt.Errorf("untagging model: %w", err) } for _, t := range tags { @@ -566,13 +566,13 @@ func (c *Client) DeleteModel(reference string, force bool) (*DeleteModelResponse ) } - c.log.Infoln("Deleting model:", id) + c.log.Info("deleting model", "id", id) deletedID, tags, err := c.store.Delete(id) if err != nil { - c.log.Errorln("Failed to delete model:", err, "tag:", reference) + c.log.Error("failed to delete model", "error", err, "tag", reference) return &DeleteModelResponse{}, fmt.Errorf("deleting model: %w", err) } - c.log.Infoln("Successfully deleted model:", reference) + c.log.Info("successfully deleted model", "reference", reference) for _, t := range tags { resp = append(resp, DeleteModelAction{Untagged: &t}) } @@ -582,7 +582,7 @@ func (c *Client) DeleteModel(reference string, force bool) (*DeleteModelResponse // Tag adds a tag to a model func (c *Client) Tag(source string, target string) error { - c.log.Infoln("Tagging model, source:", source, "target:", utils.SanitizeForLog(target)) + c.log.Info("tagging model", "source", source, "target", utils.SanitizeForLog(target)) normalizedSource := c.normalizeModelName(source) normalizedTarget := c.normalizeModelName(target) return c.store.AddTags(normalizedSource, []string{normalizedTarget}) @@ -604,18 +604,18 @@ func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Wr } // Push the model - c.log.Infoln("Pushing model:", 
utils.SanitizeForLog(tag, -1)) + c.log.Info("pushing model", "tag", utils.SanitizeForLog(tag, -1)) if err := target.Write(ctx, mdl, progressWriter); err != nil { - c.log.Errorln("Failed to push image:", err, "reference:", tag) + c.log.Error("failed to push image", "error", err, "reference", tag) if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error()), oci.ModePush); writeErr != nil { - c.log.Warnf("Failed to write error message: %v", writeErr) + c.log.Warn(fmt.Sprintf("Failed to write error message: %v", writeErr)) } return fmt.Errorf("pushing image: %w", err) } - c.log.Infoln("Successfully pushed model:", tag) + c.log.Info("successfully pushed model", "tag", tag) if err := progress.WriteSuccess(progressWriter, "Model pushed successfully", oci.ModePush); err != nil { - c.log.Warnf("Failed to write success message: %v", err) + c.log.Warn(fmt.Sprintf("Failed to write success message: %v", err)) } return nil @@ -625,7 +625,7 @@ func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Wr // This is used for config-only modifications where the layer data hasn't changed. // The layers must already exist in the store. 
func (c *Client) WriteLightweightModel(mdl types.ModelArtifact, tags []string) error { - c.log.Infoln("Writing lightweight model variant") + c.log.Info("Writing lightweight model variant") normalizedTags := make([]string, len(tags)) for i, tag := range tags { normalizedTags[i] = c.normalizeModelName(tag) @@ -634,20 +634,20 @@ func (c *Client) WriteLightweightModel(mdl types.ModelArtifact, tags []string) e } func (c *Client) ResetStore() error { - c.log.Infoln("Resetting store") + c.log.Info("Resetting store") if err := c.store.Reset(); err != nil { - c.log.Errorln("Failed to reset store:", err) + c.log.Error("failed to reset store", "error", err) return fmt.Errorf("resetting store: %w", err) } return nil } func (c *Client) ExportModel(reference string, w io.Writer) error { - c.log.Infoln("Exporting model:", utils.SanitizeForLog(reference)) + c.log.Info("exporting model", "reference", utils.SanitizeForLog(reference)) normalizedRef := c.normalizeModelName(reference) mdl, err := c.store.Read(normalizedRef) if err != nil { - c.log.Errorln("Failed to get model for export:", err, "reference:", utils.SanitizeForLog(reference)) + c.log.Error("failed to get model for export", "error", err, "reference", utils.SanitizeForLog(reference)) return fmt.Errorf("get model '%q': %w", utils.SanitizeForLog(reference), err) } @@ -657,11 +657,11 @@ func (c *Client) ExportModel(reference string, w io.Writer) error { } if err := target.Write(context.Background(), mdl, nil); err != nil { - c.log.Errorln("Failed to export model:", err, "reference:", utils.SanitizeForLog(reference)) + c.log.Error("failed to export model", "error", err, "reference", utils.SanitizeForLog(reference)) return fmt.Errorf("export model: %w", err) } - c.log.Infoln("Successfully exported model:", utils.SanitizeForLog(reference)) + c.log.Info("successfully exported model", "reference", utils.SanitizeForLog(reference)) return nil } @@ -670,14 +670,14 @@ type RepackageOptions struct { } func (c *Client) 
RepackageModel(sourceRef string, targetRef string, opts RepackageOptions) error { - c.log.Infoln("Repackaging model:", utils.SanitizeForLog(sourceRef), "->", utils.SanitizeForLog(targetRef)) + c.log.Info("repackaging model", "source", utils.SanitizeForLog(sourceRef), "target", utils.SanitizeForLog(targetRef)) normalizedSource := c.normalizeModelName(sourceRef) normalizedTarget := c.normalizeModelName(targetRef) mdl, err := c.store.Read(normalizedSource) if err != nil { - c.log.Errorln("Failed to get model for repackaging:", err, "reference:", utils.SanitizeForLog(sourceRef)) + c.log.Error("failed to get model for repackaging", "error", err, "reference", utils.SanitizeForLog(sourceRef)) return fmt.Errorf("get model '%q': %w", utils.SanitizeForLog(sourceRef), err) } @@ -687,11 +687,11 @@ func (c *Client) RepackageModel(sourceRef string, targetRef string, opts Repacka } if err := c.store.WriteLightweight(modifiedModel, []string{normalizedTarget}); err != nil { - c.log.Errorln("Failed to write repackaged model:", err, "target:", utils.SanitizeForLog(targetRef)) + c.log.Error("failed to write repackaged model", "error", err, "target", utils.SanitizeForLog(targetRef)) return fmt.Errorf("write repackaged model: %w", err) } - c.log.Infoln("Successfully repackaged model:", utils.SanitizeForLog(sourceRef), "->", utils.SanitizeForLog(targetRef)) + c.log.Info("successfully repackaged model", "source", utils.SanitizeForLog(sourceRef), "target", utils.SanitizeForLog(targetRef)) return nil } @@ -708,7 +708,7 @@ func GetSupportedFormats() []types.Format { return []types.Format{types.FormatGGUF, types.FormatDiffusers} } -func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string, progressWriter io.Writer) error { +func checkCompat(image types.ModelArtifact, log *slog.Logger, reference string, progressWriter io.Writer) error { manifest, err := image.Manifest() if err != nil { return err @@ -724,13 +724,12 @@ func checkCompat(image types.ModelArtifact, log 
*logrus.Entry, reference string, } if config.GetFormat() == "" { - log.Warnf("Model format field is empty for %s, unable to verify format compatibility", - utils.SanitizeForLog(reference)) + log.Warn(fmt.Sprintf("Model format field is empty for %s, unable to verify format compatibility", utils.SanitizeForLog(reference))) } else if !slices.Contains(GetSupportedFormats(), config.GetFormat()) { // Write warning but continue with pull - log.Warnln(warnUnsupportedFormat) + log.Warn(warnUnsupportedFormat) if err := progress.WriteWarning(progressWriter, warnUnsupportedFormat, oci.ModePull); err != nil { - log.Warnf("Failed to write warning message: %v", err) + log.Warn(fmt.Sprintf("Failed to write warning message: %v", err)) } // Don't return an error - allow the pull to continue } @@ -775,7 +774,7 @@ func parseHFReference(reference string) (repo, revision, tag string) { // This is used when the model is stored as raw files (safetensors) on HuggingFace Hub func (c *Client) pullNativeHuggingFace(ctx context.Context, reference string, progressWriter io.Writer, token string) error { repo, revision, tag := parseHFReference(reference) - c.log.Infof("Pulling native HuggingFace model: repo=%s, revision=%s, tag=%s", utils.SanitizeForLog(repo), utils.SanitizeForLog(revision), utils.SanitizeForLog(tag)) + c.log.Info(fmt.Sprintf("Pulling native HuggingFace model: repo=%s, revision=%s, tag=%s", utils.SanitizeForLog(repo), utils.SanitizeForLog(revision), utils.SanitizeForLog(tag))) // Create HuggingFace client hfOpts := []huggingface.ClientOption{ @@ -807,23 +806,23 @@ func (c *Client) pullNativeHuggingFace(ctx context.Context, reference string, pr return registry.ErrModelNotFound } if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error()), oci.ModePull); writeErr != nil { - c.log.Warnf("Failed to write error message: %v", writeErr) + c.log.Warn(fmt.Sprintf("Failed to write error message: %v", writeErr)) } return fmt.Errorf("build model from HuggingFace: 
%w", err) } // Write model to store with normalized tag storageTag := c.normalizeModelName(reference) - c.log.Infof("Writing model to store with tag: %s", utils.SanitizeForLog(storageTag)) + c.log.Info(fmt.Sprintf("Writing model to store with tag: %s", utils.SanitizeForLog(storageTag))) if err := c.store.Write(model, []string{storageTag}, progressWriter); err != nil { if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error()), oci.ModePull); writeErr != nil { - c.log.Warnf("Failed to write error message: %v", writeErr) + c.log.Warn(fmt.Sprintf("Failed to write error message: %v", writeErr)) } return fmt.Errorf("writing model to store: %w", err) } if err := progress.WriteSuccess(progressWriter, "Model pulled successfully", oci.ModePull); err != nil { - c.log.Warnf("Failed to write success message: %v", err) + c.log.Warn(fmt.Sprintf("Failed to write success message: %v", err)) } return nil diff --git a/pkg/distribution/distribution/client_test.go b/pkg/distribution/distribution/client_test.go index 31629ff79..c31addf98 100644 --- a/pkg/distribution/distribution/client_test.go +++ b/pkg/distribution/distribution/client_test.go @@ -13,6 +13,7 @@ import ( "os" "path/filepath" "strings" + "log/slog" "testing" "github.com/docker/model-runner/pkg/distribution/internal/mutate" @@ -24,7 +25,6 @@ import ( mdregistry "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/registry/testregistry" "github.com/docker/model-runner/pkg/inference/platform" - "github.com/sirupsen/logrus" ) var ( @@ -45,7 +45,7 @@ func TestClientPullModel(t *testing.T) { defer server.Close() registryURL, err := url.Parse(server.URL) if err != nil { - t.Fatalf("Failed to parse registry URL: %v", err) + t.Error(fmt.Sprintf("Failed to parse registry URL: %v", err)) } registryHost := registryURL.Host @@ -54,52 +54,52 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry client, err := 
newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Read model content for verification later modelContent, err := os.ReadFile(testGGUFFile) if err != nil { - t.Fatalf("Failed to read test model file: %v", err) + t.Error(fmt.Sprintf("Failed to read test model file: %v", err)) } model := testutil.BuildModelFromPath(t, testGGUFFile) tag := registryHost + "/testmodel:v1.0.0" ref, err := reference.ParseReference(tag) if err != nil { - t.Fatalf("Failed to parse reference: %v", err) + t.Error(fmt.Sprintf("Failed to parse reference: %v", err)) } if err := remote.Write(ref, model, nil, remote.WithPlainHTTP(true)); err != nil { - t.Fatalf("Failed to push model: %v", err) + t.Error(fmt.Sprintf("Failed to push model: %v", err)) } t.Run("pull without progress writer", func(t *testing.T) { // Pull model from registry without progress writer err := client.PullModel(t.Context(), tag, nil) if err != nil { - t.Fatalf("Failed to pull model: %v", err) + t.Error(fmt.Sprintf("Failed to pull model: %v", err)) } model, err := client.GetModel(tag) if err != nil { - t.Fatalf("Failed to get model: %v", err) + t.Error(fmt.Sprintf("Failed to get model: %v", err)) } modelPaths, err := model.GGUFPaths() if err != nil { - t.Fatalf("Failed to get model path: %v", err) + t.Error(fmt.Sprintf("Failed to get model path: %v", err)) } if len(modelPaths) != 1 { - t.Fatalf("Unexpected number of model files: %d", len(modelPaths)) + t.Error(fmt.Sprintf("Unexpected number of model files: %d", len(modelPaths))) } // Verify model content pulledContent, err := os.ReadFile(modelPaths[0]) if err != nil { - t.Fatalf("Failed to read pulled model: %v", err) + t.Error(fmt.Sprintf("Failed to read pulled model: %v", err)) } if string(pulledContent) != string(modelContent) { - t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent) + t.Error(fmt.Sprintf("Pulled model content 
doesn't match original: got %q, want %q", pulledContent, modelContent)) } }) @@ -109,36 +109,36 @@ func TestClientPullModel(t *testing.T) { // Pull model from registry with progress writer if err := client.PullModel(t.Context(), tag, &progressBuffer); err != nil { - t.Fatalf("Failed to pull model: %v", err) + t.Error(fmt.Sprintf("Failed to pull model: %v", err)) } // Verify progress output progressOutput := progressBuffer.String() if !strings.Contains(progressOutput, "Using cached model") && !strings.Contains(progressOutput, "Downloading") { - t.Errorf("Progress output doesn't contain expected text: got %q", progressOutput) + t.Error(fmt.Sprintf("Progress output doesn't contain expected text: got %q", progressOutput)) } model, err := client.GetModel(tag) if err != nil { - t.Fatalf("Failed to get model: %v", err) + t.Error(fmt.Sprintf("Failed to get model: %v", err)) } modelPaths, err := model.GGUFPaths() if err != nil { - t.Fatalf("Failed to get model path: %v", err) + t.Error(fmt.Sprintf("Failed to get model path: %v", err)) } if len(modelPaths) != 1 { - t.Fatalf("Unexpected number of model files: %d", len(modelPaths)) + t.Error(fmt.Sprintf("Unexpected number of model files: %d", len(modelPaths))) } // Verify model content pulledContent, err := os.ReadFile(modelPaths[0]) if err != nil { - t.Fatalf("Failed to read pulled model: %v", err) + t.Error(fmt.Sprintf("Failed to read pulled model: %v", err)) } if string(pulledContent) != string(modelContent) { - t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent) + t.Error(fmt.Sprintf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent)) } }) @@ -148,7 +148,7 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Create a buffer to 
capture progress output @@ -165,25 +165,25 @@ func TestClientPullModel(t *testing.T) { var pullErr *mdregistry.Error ok := errors.As(err, &pullErr) if !ok { - t.Fatalf("Expected registry.Error, got %T: %v", err, err) + t.Error(fmt.Sprintf("Expected registry.Error, got %T: %v", err, err)) } // Verify it matches registry.ErrModelNotFound for API compatibility if !errors.Is(err, mdregistry.ErrModelNotFound) { - t.Fatalf("Expected registry.ErrModelNotFound, got %T", err) + t.Error(fmt.Sprintf("Expected registry.ErrModelNotFound, got %T", err)) } // Verify error fields if pullErr.Reference != nonExistentRef { - t.Errorf("Expected reference %q, got %q", nonExistentRef, pullErr.Reference) + t.Error(fmt.Sprintf("Expected reference %q, got %q", nonExistentRef, pullErr.Reference)) } // The error code can be NAME_UNKNOWN, MANIFEST_UNKNOWN, or UNKNOWN depending on the resolver implementation if pullErr.Code != "NAME_UNKNOWN" && pullErr.Code != "MANIFEST_UNKNOWN" && pullErr.Code != "UNKNOWN" { - t.Errorf("Expected error code NAME_UNKNOWN, MANIFEST_UNKNOWN, or UNKNOWN, got %q", pullErr.Code) + t.Error(fmt.Sprintf("Expected error code NAME_UNKNOWN, MANIFEST_UNKNOWN, or UNKNOWN, got %q", pullErr.Code)) } // The error message varies by resolver implementation if !strings.Contains(strings.ToLower(pullErr.Message), "not found") { - t.Errorf("Expected message to contain 'not found', got %q", pullErr.Message) + t.Error(fmt.Sprintf("Expected message to contain 'not found', got %q", pullErr.Message)) } if pullErr.Err == nil { t.Error("Expected underlying error to be non-nil") @@ -196,7 +196,7 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Use the dummy.gguf file from assets directory @@ -205,26 +205,26 @@ func TestClientPullModel(t *testing.T) { // Push model to local store 
testTag := registryHost + "/incomplete-test/model:v1.0.0" if err := testClient.store.Write(mdl, []string{testTag}, nil); err != nil { - t.Fatalf("Failed to push model to store: %v", err) + t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) } // Push model to registry if err := testClient.PushModel(t.Context(), testTag, nil); err != nil { - t.Fatalf("Failed to pull model: %v", err) + t.Error(fmt.Sprintf("Failed to pull model: %v", err)) } // Get the model to find the GGUF path model, err := testClient.GetModel(testTag) if err != nil { - t.Fatalf("Failed to get model: %v", err) + t.Error(fmt.Sprintf("Failed to get model: %v", err)) } ggufPaths, err := model.GGUFPaths() if err != nil { - t.Fatalf("Failed to get GGUF path: %v", err) + t.Error(fmt.Sprintf("Failed to get GGUF path: %v", err)) } if len(ggufPaths) != 1 { - t.Fatalf("Unexpected number of model files: %d", len(ggufPaths)) + t.Error(fmt.Sprintf("Unexpected number of model files: %d", len(ggufPaths))) } // Create an incomplete file by copying the GGUF file and adding .incomplete suffix @@ -232,23 +232,23 @@ func TestClientPullModel(t *testing.T) { incompletePath := ggufPath + ".incomplete" originalContent, err := os.ReadFile(ggufPath) if err != nil { - t.Fatalf("Failed to read GGUF file: %v", err) + t.Error(fmt.Sprintf("Failed to read GGUF file: %v", err)) } // Write partial content to simulate an incomplete download partialContent := originalContent[:len(originalContent)/2] if err := os.WriteFile(incompletePath, partialContent, 0644); err != nil { - t.Fatalf("Failed to create incomplete file: %v", err) + t.Error(fmt.Sprintf("Failed to create incomplete file: %v", err)) } // Verify the incomplete file exists if _, err := os.Stat(incompletePath); os.IsNotExist(err) { - t.Fatalf("Failed to create incomplete file: %v", err) + t.Error(fmt.Sprintf("Failed to create incomplete file: %v", err)) } // Delete the local model to force a pull if _, err := testClient.DeleteModel(testTag, false); err != nil { - 
t.Fatalf("Failed to delete model: %v", err) + t.Error(fmt.Sprintf("Failed to delete model: %v", err)) } // Create a buffer to capture progress output @@ -256,33 +256,33 @@ func TestClientPullModel(t *testing.T) { // Pull the model again - this should detect the incomplete file and pull again if err := testClient.PullModel(t.Context(), testTag, &progressBuffer); err != nil { - t.Fatalf("Failed to pull model: %v", err) + t.Error(fmt.Sprintf("Failed to pull model: %v", err)) } // Verify progress output indicates a new download, not using cached model progressOutput := progressBuffer.String() if strings.Contains(progressOutput, "Using cached model") { - t.Errorf("Expected to pull model again due to incomplete file, but used cached model") + t.Error("Expected to pull model again due to incomplete file, but used cached model") } // Verify the incomplete file no longer exists if _, err := os.Stat(incompletePath); !os.IsNotExist(err) { - t.Errorf("Incomplete file still exists after successful pull: %s", incompletePath) + t.Error(fmt.Sprintf("Incomplete file still exists after successful pull: %s", incompletePath)) } // Verify the complete file exists if _, err := os.Stat(ggufPath); os.IsNotExist(err) { - t.Errorf("GGUF file doesn't exist after pull: %s", ggufPath) + t.Error(fmt.Sprintf("GGUF file doesn't exist after pull: %s", ggufPath)) } // Verify the content of the pulled file matches the original pulledContent, err := os.ReadFile(ggufPath) if err != nil { - t.Fatalf("Failed to read pulled GGUF file: %v", err) + t.Error(fmt.Sprintf("Failed to read pulled GGUF file: %v", err)) } if !bytes.Equal(pulledContent, originalContent) { - t.Errorf("Pulled content doesn't match original content") + t.Error("Pulled content doesn't match original content") } }) @@ -292,13 +292,13 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + 
t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Read model content for verification later testModelContent, err := os.ReadFile(testGGUFFile) if err != nil { - t.Fatalf("Failed to read test model file: %v", err) + t.Error(fmt.Sprintf("Failed to read test model file: %v", err)) } // Push first version of model to registry @@ -309,38 +309,38 @@ func TestClientPullModel(t *testing.T) { // Pull first version of model if err := testClient.PullModel(t.Context(), testTag, nil); err != nil { - t.Fatalf("Failed to pull first version of model: %v", err) + t.Error(fmt.Sprintf("Failed to pull first version of model: %v", err)) } // Verify first version is in local store model, err := testClient.GetModel(testTag) if err != nil { - t.Fatalf("Failed to get first version of model: %v", err) + t.Error(fmt.Sprintf("Failed to get first version of model: %v", err)) } modelPath, err := model.GGUFPaths() if err != nil { - t.Fatalf("Failed to get model path: %v", err) + t.Error(fmt.Sprintf("Failed to get model path: %v", err)) } if len(modelPath) != 1 { - t.Fatalf("Unexpected number of model files: %d", len(modelPath)) + t.Error(fmt.Sprintf("Unexpected number of model files: %d", len(modelPath))) } // Verify first version content pulledContent, err := os.ReadFile(modelPath[0]) if err != nil { - t.Fatalf("Failed to read pulled model: %v", err) + t.Error(fmt.Sprintf("Failed to read pulled model: %v", err)) } if string(pulledContent) != string(testModelContent) { - t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, testModelContent) + t.Error(fmt.Sprintf("Pulled model content doesn't match original: got %q, want %q", pulledContent, testModelContent)) } // Create a modified version of the model updatedModelFile := filepath.Join(tempDir, "updated-dummy.gguf") updatedContent := append(testModelContent, []byte("UPDATED CONTENT")...) 
if err := os.WriteFile(updatedModelFile, updatedContent, 0644); err != nil { - t.Fatalf("Failed to create updated model file: %v", err) + t.Error(fmt.Sprintf("Failed to create updated model file: %v", err)) } // Push updated model with same tag @@ -353,37 +353,37 @@ func TestClientPullModel(t *testing.T) { // Pull model again - should get the updated version if err := testClient.PullModel(t.Context(), testTag, &progressBuffer); err != nil { - t.Fatalf("Failed to pull updated model: %v", err) + t.Error(fmt.Sprintf("Failed to pull updated model: %v", err)) } // Verify progress output indicates a new download, not using cached model progressOutput := progressBuffer.String() if strings.Contains(progressOutput, "Using cached model") { - t.Errorf("Expected to pull updated model, but used cached model") + t.Error("Expected to pull updated model, but used cached model") } // Get the model again to verify it's the updated version updatedModel, err := testClient.GetModel(testTag) if err != nil { - t.Fatalf("Failed to get updated model: %v", err) + t.Error(fmt.Sprintf("Failed to get updated model: %v", err)) } updatedModelPaths, err := updatedModel.GGUFPaths() if err != nil { - t.Fatalf("Failed to get updated model path: %v", err) + t.Error(fmt.Sprintf("Failed to get updated model path: %v", err)) } if len(updatedModelPaths) != 1 { - t.Fatalf("Unexpected number of model files: %d", len(modelPath)) + t.Error(fmt.Sprintf("Unexpected number of model files: %d", len(modelPath))) } // Verify updated content updatedPulledContent, err := os.ReadFile(updatedModelPaths[0]) if err != nil { - t.Fatalf("Failed to read updated pulled model: %v", err) + t.Error(fmt.Sprintf("Failed to read updated pulled model: %v", err)) } if string(updatedPulledContent) != string(updatedContent) { - t.Errorf("Updated pulled model content doesn't match: got %q, want %q", updatedPulledContent, updatedContent) + t.Error(fmt.Sprintf("Updated pulled model content doesn't match: got %q, want %q", 
updatedPulledContent, updatedContent)) } }) @@ -393,13 +393,13 @@ func TestClientPullModel(t *testing.T) { testTag := registryHost + "/unsupported-test/model:v1.0.0" ref, err := reference.ParseReference(testTag) if err != nil { - t.Fatalf("Failed to parse reference: %v", err) + t.Error(fmt.Sprintf("Failed to parse reference: %v", err)) } if err := remote.Write(ref, newMdl, nil, remote.WithPlainHTTP(true)); err != nil { - t.Fatalf("Failed to push model: %v", err) + t.Error(fmt.Sprintf("Failed to push model: %v", err)) } if err := client.PullModel(t.Context(), testTag, nil); err == nil || !errors.Is(err, ErrUnsupportedMediaType) { - t.Fatalf("Expected artifact version error, got %v", err) + t.Error(fmt.Sprintf("Expected artifact version error, got %v", err)) } }) @@ -410,7 +410,7 @@ func TestClientPullModel(t *testing.T) { safetensorsPath := filepath.Join(safetensorsTempDir, "model.safetensors") safetensorsContent := []byte("fake safetensors content for testing") if err := os.WriteFile(safetensorsPath, safetensorsContent, 0644); err != nil { - t.Fatalf("Failed to create safetensors file: %v", err) + t.Error(fmt.Sprintf("Failed to create safetensors file: %v", err)) } // Create a safetensors model @@ -420,10 +420,10 @@ func TestClientPullModel(t *testing.T) { testTag := registryHost + "/safetensors-test/model:v1.0.0" ref, err := reference.ParseReference(testTag) if err != nil { - t.Fatalf("Failed to parse reference: %v", err) + t.Error(fmt.Sprintf("Failed to parse reference: %v", err)) } if err := remote.Write(ref, safetensorsModel, nil, remote.WithPlainHTTP(true)); err != nil { - t.Fatalf("Failed to push safetensors model to registry: %v", err) + t.Error(fmt.Sprintf("Failed to push safetensors model to registry: %v", err)) } // Create a new client with a separate temp store @@ -431,7 +431,7 @@ func TestClientPullModel(t *testing.T) { testClient, err := newTestClient(clientTempDir) if err != nil { - t.Fatalf("Failed to create test client: %v", err) + 
t.Error(fmt.Sprintf("Failed to create test client: %v", err)) } // Try to pull the safetensors model with a progress writer to capture warnings @@ -440,17 +440,17 @@ func TestClientPullModel(t *testing.T) { // Pull should succeed on all platforms now (with a warning on non-Linux) if err != nil { - t.Fatalf("Expected no error, got: %v", err) + t.Error(fmt.Sprintf("Expected no error, got: %v", err)) } if !platform.SupportsVLLM() { // On non-Linux, verify that a warning was written progressOutput := progressBuf.String() if !strings.Contains(progressOutput, `"type":"warning"`) { - t.Fatalf("Expected warning message on non-Linux platforms, got output: %s", progressOutput) + t.Error(fmt.Sprintf("Expected warning message on non-Linux platforms, got output: %s", progressOutput)) } if !strings.Contains(progressOutput, warnUnsupportedFormat) { - t.Fatalf("Expected warning about safetensors format, got output: %s", progressOutput) + t.Error(fmt.Sprintf("Expected warning about safetensors format, got output: %s", progressOutput)) } } }) @@ -461,7 +461,7 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Create a buffer to capture progress output @@ -469,7 +469,7 @@ func TestClientPullModel(t *testing.T) { // Pull model from registry with progress writer if err := testClient.PullModel(t.Context(), tag, &progressBuffer); err != nil { - t.Fatalf("Failed to pull model: %v", err) + t.Error(fmt.Sprintf("Failed to pull model: %v", err)) } // Parse progress output as JSON @@ -479,13 +479,13 @@ func TestClientPullModel(t *testing.T) { line := scanner.Text() var msg oci.ProgressMessage if err := json.Unmarshal([]byte(line), &msg); err != nil { - t.Fatalf("Failed to parse JSON progress message: %v, line: %s", err, line) + t.Error(fmt.Sprintf("Failed to parse JSON progress 
message: %v, line: %s", err, line)) } messages = append(messages, msg) } if err := scanner.Err(); err != nil { - t.Fatalf("Error reading progress output: %v", err) + t.Error(fmt.Sprintf("Error reading progress output: %v", err)) } // Verify we got some messages @@ -496,38 +496,38 @@ func TestClientPullModel(t *testing.T) { // Verify all messages have the correct mode for i, msg := range messages { if msg.Mode != oci.ModePull { - t.Errorf("message %d: expected mode %q, got %q", i, oci.ModePull, msg.Mode) + t.Error(fmt.Sprintf("message %d: expected mode %q, got %q", i, oci.ModePull, msg.Mode)) } } // Check the last message is a success message lastMsg := messages[len(messages)-1] if lastMsg.Type != oci.TypeSuccess { - t.Errorf("Expected last message to be success, got type: %q, message: %s", lastMsg.Type, lastMsg.Message) + t.Error(fmt.Sprintf("Expected last message to be success, got type: %q, message: %s", lastMsg.Type, lastMsg.Message)) } // Verify model was pulled correctly model, err := testClient.GetModel(tag) if err != nil { - t.Fatalf("Failed to get model: %v", err) + t.Error(fmt.Sprintf("Failed to get model: %v", err)) } modelPaths, err := model.GGUFPaths() if err != nil { - t.Fatalf("Failed to get model path: %v", err) + t.Error(fmt.Sprintf("Failed to get model path: %v", err)) } if len(modelPaths) != 1 { - t.Fatalf("Unexpected number of model files: %d", len(modelPaths)) + t.Error(fmt.Sprintf("Unexpected number of model files: %d", len(modelPaths))) } // Verify model content pulledContent, err := os.ReadFile(modelPaths[0]) if err != nil { - t.Fatalf("Failed to read pulled model: %v", err) + t.Error(fmt.Sprintf("Failed to read pulled model: %v", err)) } if string(pulledContent) != string(modelContent) { - t.Errorf("Pulled model content doesn't match original") + t.Error("Pulled model content doesn't match original") } }) @@ -537,7 +537,7 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := 
newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Create a buffer to capture progress output @@ -554,7 +554,7 @@ func TestClientPullModel(t *testing.T) { // Verify it matches registry.ErrModelNotFound if !errors.Is(err, mdregistry.ErrModelNotFound) { - t.Fatalf("Expected registry.ErrModelNotFound, got %T", err) + t.Error(fmt.Sprintf("Expected registry.ErrModelNotFound, got %T", err)) } // No JSON messages should be in the buffer for this error case @@ -568,7 +568,7 @@ func TestClientGetModel(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Create model from test GGUF file @@ -578,18 +578,18 @@ func TestClientGetModel(t *testing.T) { tag := "test/model:v1.0.0" normalizedTag := "docker.io/test/model:v1.0.0" // Reference package normalizes to include registry if err := client.store.Write(model, []string{tag}, nil); err != nil { - t.Fatalf("Failed to push model to store: %v", err) + t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) } // Get model mi, err := client.GetModel(tag) if err != nil { - t.Fatalf("Failed to get model: %v", err) + t.Error(fmt.Sprintf("Failed to get model: %v", err)) } // Verify model - tags are normalized to include the default registry if len(mi.Tags()) == 0 || mi.Tags()[0] != normalizedTag { - t.Errorf("Model tags don't match: got %v, want [%s]", mi.Tags(), normalizedTag) + t.Error(fmt.Sprintf("Model tags don't match: got %v, want [%s]", mi.Tags(), normalizedTag)) } } @@ -599,13 +599,13 @@ func TestClientGetModelNotFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Get 
non-existent model _, err = client.GetModel("nonexistent/model:v1.0.0") if !errors.Is(err, ErrModelNotFound) { - t.Errorf("Expected ErrModelNotFound, got %v", err) + t.Error(fmt.Sprintf("Expected ErrModelNotFound, got %v", err)) } } @@ -615,14 +615,14 @@ func TestClientListModels(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Create test model file modelContent := []byte("test model content") modelFile := filepath.Join(tempDir, "test-model.gguf") if err := os.WriteFile(modelFile, modelContent, 0644); err != nil { - t.Fatalf("Failed to write test model file: %v", err) + t.Error(fmt.Sprintf("Failed to write test model file: %v", err)) } mdl := testutil.BuildModelFromPath(t, modelFile) @@ -631,21 +631,21 @@ func TestClientListModels(t *testing.T) { // First model tag1 := "test/model1:v1.0.0" if err := client.store.Write(mdl, []string{tag1}, nil); err != nil { - t.Fatalf("Failed to push model to store: %v", err) + t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) } // Create a slightly different model file for the second model modelContent2 := []byte("test model content 2") modelFile2 := filepath.Join(tempDir, "test-model2.gguf") if err := os.WriteFile(modelFile2, modelContent2, 0644); err != nil { - t.Fatalf("Failed to write test model file: %v", err) + t.Error(fmt.Sprintf("Failed to write test model file: %v", err)) } mdl2 := testutil.BuildModelFromPath(t, modelFile2) // Second model tag2 := "test/model2:v1.0.0" if err := client.store.Write(mdl2, []string{tag2}, nil); err != nil { - t.Fatalf("Failed to push model to store: %v", err) + t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) } // Normalized tags for verification (reference package normalizes to include default registry) @@ -656,12 +656,12 @@ func TestClientListModels(t *testing.T) { // List models 
models, err := client.ListModels() if err != nil { - t.Fatalf("Failed to list models: %v", err) + t.Error(fmt.Sprintf("Failed to list models: %v", err)) } // Verify models if len(models) != len(tags) { - t.Errorf("Expected %d models, got %d", len(tags), len(models)) + t.Error(fmt.Sprintf("Expected %d models, got %d", len(tags), len(models))) } // Check if all tags are present @@ -674,7 +674,7 @@ func TestClientListModels(t *testing.T) { for _, tag := range tags { if !tagMap[tag] { - t.Errorf("Tag %s not found in models", tag) + t.Error(fmt.Sprintf("Tag %s not found in models", tag)) } } } @@ -685,7 +685,7 @@ func TestClientGetStorePath(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Get store path @@ -693,12 +693,12 @@ func TestClientGetStorePath(t *testing.T) { // Verify store path matches the temp directory if storePath != tempDir { - t.Errorf("Store path doesn't match: got %s, want %s", storePath, tempDir) + t.Error(fmt.Sprintf("Store path doesn't match: got %s, want %s", storePath, tempDir)) } // Verify the store directory exists if _, err := os.Stat(storePath); os.IsNotExist(err) { - t.Errorf("Store directory does not exist: %s", storePath) + t.Error(fmt.Sprintf("Store directory does not exist: %s", storePath)) } } @@ -708,7 +708,7 @@ func TestClientDefaultLogger(t *testing.T) { // Create client without specifying logger client, err := NewClient(WithStoreRootPath(tempDir)) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Verify that logger is not nil @@ -717,13 +717,13 @@ func TestClientDefaultLogger(t *testing.T) { } // Create client with custom logger - customLogger := logrus.NewEntry(logrus.New()) + customLogger := slog.Default() client, err = NewClient( WithStoreRootPath(tempDir), 
WithLogger(customLogger), ) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Verify that custom logger is used @@ -746,8 +746,7 @@ func TestWithFunctionsNilChecks(t *testing.T) { // Verify the path wasn't changed to empty if opts.storeRootPath != tempDir { - t.Errorf("WithStoreRootPath with empty string changed the path: got %q, want %q", - opts.storeRootPath, tempDir) + t.Error(fmt.Sprintf("WithStoreRootPath with empty string changed the path: got %q, want %q", opts.storeRootPath, tempDir)) } }) @@ -789,7 +788,7 @@ func TestNewReferenceError(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Test with invalid reference @@ -800,7 +799,7 @@ func TestNewReferenceError(t *testing.T) { } if !errors.Is(err, ErrInvalidReference) { - t.Fatalf("Expected error to match sentinel invalid reference error, got %v", err) + t.Error(fmt.Sprintf("Expected error to match sentinel invalid reference error, got %v", err)) } } @@ -810,7 +809,7 @@ func TestPush(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Create a test registry @@ -820,7 +819,7 @@ func TestPush(t *testing.T) { // Create a tag for the model uri, err := url.Parse(server.URL) if err != nil { - t.Fatalf("Failed to parse registry URL: %v", err) + t.Error(fmt.Sprintf("Failed to parse registry URL: %v", err)) } tag := uri.Host + "/incomplete-test/model:v1.0.0" @@ -828,39 +827,39 @@ func TestPush(t *testing.T) { mdl := testutil.BuildModelFromPath(t, testGGUFFile) digest, err := mdl.ID() if err != nil { - t.Fatalf("Failed to get digest of original model: %v", err) + t.Error(fmt.Sprintf("Failed 
to get digest of original model: %v", err)) } if err := client.store.Write(mdl, []string{tag}, nil); err != nil { - t.Fatalf("Failed to push model to store: %v", err) + t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) } // Push the model to the registry if err := client.PushModel(t.Context(), tag, nil); err != nil { - t.Fatalf("Failed to push model: %v", err) + t.Error(fmt.Sprintf("Failed to push model: %v", err)) } // Delete local copy (so we can test pulling) if _, err := client.DeleteModel(tag, false); err != nil { - t.Fatalf("Failed to delete model: %v", err) + t.Error(fmt.Sprintf("Failed to delete model: %v", err)) } // Test that model can be pulled successfully if err := client.PullModel(t.Context(), tag, nil); err != nil { - t.Fatalf("Failed to pull model: %v", err) + t.Error(fmt.Sprintf("Failed to pull model: %v", err)) } // Test that model the pulled model is the same as the original (matching digests) mdl2, err := client.GetModel(tag) if err != nil { - t.Fatalf("Failed to get pulled model: %v", err) + t.Error(fmt.Sprintf("Failed to get pulled model: %v", err)) } digest2, err := mdl2.ID() if err != nil { - t.Fatalf("Failed to get digest of the pulled model: %v", err) + t.Error(fmt.Sprintf("Failed to get digest of the pulled model: %v", err)) } if digest != digest2 { - t.Fatalf("Digests don't match: got %s, want %s", digest2, digest) + t.Error(fmt.Sprintf("Digests don't match: got %s, want %s", digest2, digest)) } } @@ -870,7 +869,7 @@ func TestPushProgress(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Create a test registry @@ -880,7 +879,7 @@ func TestPushProgress(t *testing.T) { // Create a tag for the model uri, err := url.Parse(server.URL) if err != nil { - t.Fatalf("Failed to parse registry URL: %v", err) + t.Error(fmt.Sprintf("Failed to parse registry 
URL: %v", err)) } tag := uri.Host + "/some/model/repo:some-tag" @@ -889,14 +888,14 @@ func TestPushProgress(t *testing.T) { sz := int64(progress.MinBytesForUpdate * 2) path, err := randomFile(sz) if err != nil { - t.Fatalf("Failed to create temp file: %v", err) + t.Error(fmt.Sprintf("Failed to create temp file: %v", err)) } defer os.Remove(path) mdl := testutil.BuildModelFromPath(t, path) if err := client.store.Write(mdl, []string{tag}, nil); err != nil { - t.Fatalf("Failed to write model to store: %v", err) + t.Error(fmt.Sprintf("Failed to write model to store: %v", err)) } // Create a buffer to capture progress output @@ -918,13 +917,13 @@ func TestPushProgress(t *testing.T) { // Wait for the push to complete if err := <-done; err != nil { - t.Fatalf("Failed to push model: %v", err) + t.Error(fmt.Sprintf("Failed to push model: %v", err)) } // Verify we got at least 2 messages (1 progress + 1 success) // With fast local uploads, we may only get one progress update per layer if len(lines) < 2 { - t.Fatalf("Expected at least 2 progress messages, got %d", len(lines)) + t.Error(fmt.Sprintf("Expected at least 2 progress messages, got %d", len(lines))) } // Verify we got at least one progress message and the success message @@ -939,10 +938,10 @@ func TestPushProgress(t *testing.T) { } } if !hasProgress { - t.Fatalf("Expected at least one progress message containing 'Uploaded:', got %v", lines) + t.Error(fmt.Sprintf("Expected at least one progress message containing 'Uploaded:', got %v", lines)) } if !hasSuccess { - t.Fatalf("Expected a success message, got %v", lines) + t.Error(fmt.Sprintf("Expected a success message, got %v", lines)) } } @@ -952,14 +951,14 @@ func TestTag(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Create a test model model := testutil.BuildModelFromPath(t, 
testGGUFFile) id, err := model.ID() if err != nil { - t.Fatalf("Failed to get model ID: %v", err) + t.Error(fmt.Sprintf("Failed to get model ID: %v", err)) } // Normalize the model name before writing @@ -967,35 +966,35 @@ func TestTag(t *testing.T) { // Push the model to the store if err := client.store.Write(model, []string{normalized}, nil); err != nil { - t.Fatalf("Failed to push model to store: %v", err) + t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) } // Tag the model by ID if err := client.Tag(id, "other-repo:tag1"); err != nil { - t.Fatalf("Failed to tag model %q: %v", id, err) + t.Error(fmt.Sprintf("Failed to tag model %q: %v", id, err)) } // Tag the model by tag if err := client.Tag(id, "other-repo:tag2"); err != nil { - t.Fatalf("Failed to tag model %q: %v", id, err) + t.Error(fmt.Sprintf("Failed to tag model %q: %v", id, err)) } // Verify the model has all 3 tags modelInfo, err := client.GetModel("some-repo:some-tag") if err != nil { - t.Fatalf("Failed to get model: %v", err) + t.Error(fmt.Sprintf("Failed to get model: %v", err)) } if len(modelInfo.Tags()) != 3 { - t.Fatalf("Expected 3 tags, got %d", len(modelInfo.Tags())) + t.Error(fmt.Sprintf("Expected 3 tags, got %d", len(modelInfo.Tags()))) } // Verify the model can be accessed by new tags if _, err := client.GetModel("other-repo:tag1"); err != nil { - t.Fatalf("Failed to get model by tag: %v", err) + t.Error(fmt.Sprintf("Failed to get model by tag: %v", err)) } if _, err := client.GetModel("other-repo:tag2"); err != nil { - t.Fatalf("Failed to get model by tag: %v", err) + t.Error(fmt.Sprintf("Failed to get model by tag: %v", err)) } } @@ -1005,12 +1004,12 @@ func TestTagNotFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Tag the model by ID if err := client.Tag("non-existent-model:latest", 
"other-repo:tag1"); !errors.Is(err, ErrModelNotFound) { - t.Fatalf("Expected ErrModelNotFound, got: %v", err) + t.Error(fmt.Sprintf("Expected ErrModelNotFound, got: %v", err)) } } @@ -1020,11 +1019,11 @@ func TestClientPushModelNotFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } if err := client.PushModel(t.Context(), "non-existent-model:latest", nil); !errors.Is(err, ErrModelNotFound) { - t.Fatalf("Expected ErrModelNotFound got: %v", err) + t.Error(fmt.Sprintf("Expected ErrModelNotFound got: %v", err)) } } @@ -1034,13 +1033,13 @@ func TestIsModelInStoreNotFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } if inStore, err := client.IsModelInStore("non-existent-model:latest"); err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Error(fmt.Sprintf("Unexpected error: %v", err)) } else if inStore { - t.Fatalf("Expected model not to be found") + t.Error("Expected model not to be found") } } @@ -1050,7 +1049,7 @@ func TestIsModelInStoreFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Create a test model @@ -1061,13 +1060,13 @@ func TestIsModelInStoreFound(t *testing.T) { // Push the model to the store if err := client.store.Write(model, []string{normalized}, nil); err != nil { - t.Fatalf("Failed to push model to store: %v", err) + t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) } if inStore, err := client.IsModelInStore("some-repo:some-tag"); err != nil { - t.Fatalf("Unexpected error: %v", err) + 
t.Error(fmt.Sprintf("Unexpected error: %v", err)) } else if !inStore { - t.Fatalf("Expected model to be found") + t.Error("Expected model to be found") } } @@ -1142,26 +1141,26 @@ func TestMigrateHFTagsOnClientInit(t *testing.T) { // Step 1: Create a client and write a model with the legacy tag setupClient, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create setup client: %v", err) + t.Error(fmt.Sprintf("Failed to create setup client: %v", err)) } model := testutil.BuildModelFromPath(t, testGGUFFile) if err := setupClient.store.Write(model, []string{tc.storedTag}, nil); err != nil { - t.Fatalf("Failed to write model to store: %v", err) + t.Error(fmt.Sprintf("Failed to write model to store: %v", err)) } // Step 2: Create a NEW client (simulating restart) - migration should happen client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Step 3: Verify the model can be found using the reference // (normalizeModelName converts hf.co -> huggingface.co, and migration should have updated the store) foundModel, err := client.GetModel(tc.lookupRef) if err != nil { - t.Fatalf("Failed to get model after migration: %v", err) + t.Error(fmt.Sprintf("Failed to get model after migration: %v", err)) } if foundModel == nil { @@ -1183,10 +1182,10 @@ func TestMigrateHFTagsOnClientInit(t *testing.T) { } } if hasOldTag { - t.Errorf("Model still has old hf.co tag after migration: %v", tags) + t.Error(fmt.Sprintf("Model still has old hf.co tag after migration: %v", tags)) } if !hasNewTag { - t.Errorf("Model doesn't have huggingface.co tag after migration: %v", tags) + t.Error(fmt.Sprintf("Model doesn't have huggingface.co tag after migration: %v", tags)) } } }) @@ -1215,7 +1214,7 @@ func TestPullHuggingFaceModelFromCache(t *testing.T) { // Create client client, err := newTestClient(tempDir) if err != nil { - t.Fatalf("Failed to create client: %v", err) + 
t.Error(fmt.Sprintf("Failed to create client: %v", err)) } // Create a test model and write it to the store with a normalized HuggingFace tag @@ -1224,20 +1223,20 @@ func TestPullHuggingFaceModelFromCache(t *testing.T) { // Store with normalized tag (huggingface.co) hfTag := "huggingface.co/testorg/testmodel:latest" if err := client.store.Write(model, []string{hfTag}, nil); err != nil { - t.Fatalf("Failed to write model to store: %v", err) + t.Error(fmt.Sprintf("Failed to write model to store: %v", err)) } // Now try to pull using the test case's reference - it should use the cache var progressBuffer bytes.Buffer err = client.PullModel(t.Context(), tc.pullRef, &progressBuffer) if err != nil { - t.Fatalf("Failed to pull model from cache: %v", err) + t.Error(fmt.Sprintf("Failed to pull model from cache: %v", err)) } // Verify that progress shows it was cached progressOutput := progressBuffer.String() if !strings.Contains(progressOutput, "Using cached model") { - t.Errorf("Expected progress to indicate cached model, got: %s", progressOutput) + t.Error(fmt.Sprintf("Expected progress to indicate cached model, got: %s", progressOutput)) } }) } diff --git a/pkg/distribution/distribution/normalize_test.go b/pkg/distribution/distribution/normalize_test.go index 26dd55673..861241bfd 100644 --- a/pkg/distribution/distribution/normalize_test.go +++ b/pkg/distribution/distribution/normalize_test.go @@ -1,14 +1,16 @@ package distribution import ( + "fmt" "io" "path/filepath" "strings" "testing" + "log/slog" + "github.com/docker/model-runner/pkg/distribution/builder" "github.com/docker/model-runner/pkg/distribution/tarball" - "github.com/sirupsen/logrus" ) func TestNormalizeModelName(t *testing.T) { @@ -151,7 +153,7 @@ func TestNormalizeModelName(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := client.normalizeModelName(tt.input) if result != tt.expected { - t.Errorf("normalizeModelName(%q) = %q, want %q", tt.input, result, tt.expected) + 
t.Error(fmt.Sprintf("normalizeModelName(%q) = %q, want %q", tt.input, result, tt.expected)) } }) } @@ -213,7 +215,7 @@ func TestLooksLikeID(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := client.looksLikeID(tt.input) if result != tt.expected { - t.Errorf("looksLikeID(%q) = %v, want %v", tt.input, result, tt.expected) + t.Error(fmt.Sprintf("looksLikeID(%q) = %v, want %v", tt.input, result, tt.expected)) } }) } @@ -275,7 +277,7 @@ func TestLooksLikeDigest(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := client.looksLikeDigest(tt.input) if result != tt.expected { - t.Errorf("looksLikeDigest(%q) = %v, want %v", tt.input, result, tt.expected) + t.Error(fmt.Sprintf("looksLikeDigest(%q) = %v, want %v", tt.input, result, tt.expected)) } }) } @@ -292,7 +294,7 @@ func TestNormalizeModelNameWithIDResolution(t *testing.T) { // Extract the short ID (12 hex chars after "sha256:") if !strings.HasPrefix(modelID, "sha256:") { - t.Fatalf("Expected model ID to start with 'sha256:', got: %s", modelID) + t.Error(fmt.Sprintf("Expected model ID to start with 'sha256:', got: %s", modelID)) } shortID := modelID[7:19] // Extract 12 chars after "sha256:" fullHex := strings.TrimPrefix(modelID, "sha256:") @@ -323,7 +325,7 @@ func TestNormalizeModelNameWithIDResolution(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := client.normalizeModelName(tt.input) if result != tt.expected { - t.Errorf("normalizeModelName(%q) = %q, want %q", tt.input, result, tt.expected) + t.Error(fmt.Sprintf("normalizeModelName(%q) = %q, want %q", tt.input, result, tt.expected)) } }) } @@ -339,10 +341,10 @@ func createTestClient(t *testing.T) (*Client, func()) { // Create client with minimal config client, err := NewClient( WithStoreRootPath(tempDir), - WithLogger(logrus.NewEntry(logrus.StandardLogger())), + WithLogger(slog.Default()), ) if err != nil { - t.Fatalf("Failed to create test client: %v", err) + t.Error(fmt.Sprintf("Failed to create test client: %v", err)) } cleanup := 
func() { @@ -373,7 +375,7 @@ func TestIsHuggingFaceReference(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := isHuggingFaceReference(tt.input) if result != tt.expected { - t.Errorf("isHuggingFaceReference(%q) = %v, want %v", tt.input, result, tt.expected) + t.Error(fmt.Sprintf("isHuggingFaceReference(%q) = %v, want %v", tt.input, result, tt.expected)) } }) } @@ -435,13 +437,13 @@ func TestParseHFReference(t *testing.T) { t.Run(tt.name, func(t *testing.T) { repo, rev, tag := parseHFReference(tt.input) if repo != tt.expectedRepo { - t.Errorf("parseHFReference(%q) repo = %q, want %q", tt.input, repo, tt.expectedRepo) + t.Error(fmt.Sprintf("parseHFReference(%q) repo = %q, want %q", tt.input, repo, tt.expectedRepo)) } if rev != tt.expectedRev { - t.Errorf("parseHFReference(%q) rev = %q, want %q", tt.input, rev, tt.expectedRev) + t.Error(fmt.Sprintf("parseHFReference(%q) rev = %q, want %q", tt.input, rev, tt.expectedRev)) } if tag != tt.expectedTag { - t.Errorf("parseHFReference(%q) tag = %q, want %q", tt.input, tag, tt.expectedTag) + t.Error(fmt.Sprintf("parseHFReference(%q) tag = %q, want %q", tt.input, tag, tt.expectedTag)) } }) } @@ -455,7 +457,7 @@ func loadTestModel(t *testing.T, client *Client, ggufPath string) string { pr, pw := io.Pipe() target, err := tarball.NewTarget(pw) if err != nil { - t.Fatalf("Failed to create target: %v", err) + t.Error(fmt.Sprintf("Failed to create target: %v", err)) } done := make(chan error) @@ -468,15 +470,15 @@ func loadTestModel(t *testing.T, client *Client, ggufPath string) string { bldr, err := builder.FromPath(ggufPath) if err != nil { - t.Fatalf("Failed to create builder from GGUF: %v", err) + t.Error(fmt.Sprintf("Failed to create builder from GGUF: %v", err)) } if err := bldr.Build(t.Context(), target, nil); err != nil { - t.Fatalf("Failed to build model: %v", err) + t.Error(fmt.Sprintf("Failed to build model: %v", err)) } if err := <-done; err != nil { - t.Fatalf("Failed to load model: %v", err) + 
t.Error(fmt.Sprintf("Failed to load model: %v", err)) } if id == "" { diff --git a/pkg/inference/backends/diffusers/diffusers.go b/pkg/inference/backends/diffusers/diffusers.go index a966c6678..e6af495b1 100644 --- a/pkg/inference/backends/diffusers/diffusers.go +++ b/pkg/inference/backends/diffusers/diffusers.go @@ -116,14 +116,14 @@ func (d *diffusers) Install(_ context.Context, _ *http.Client) error { // Check if diffusers is installed if err := d.pythonCmd("-c", "import diffusers").Run(); err != nil { d.status = "diffusers package not installed" - d.log.Warnf("diffusers package not found. Install with: uv pip install diffusers torch") + d.log.Warn("diffusers package not found. Install with: uv pip install diffusers torch") return ErrDiffusersNotFound } // Get version output, err := d.pythonCmd("-c", "import diffusers; print(diffusers.__version__)").Output() if err != nil { - d.log.Warnf("could not get diffusers version: %v", err) + d.log.Warn(fmt.Sprintf("could not get diffusers version: %v", err)) d.status = "running diffusers version: unknown" } else { d.status = fmt.Sprintf("running diffusers version: %s", strings.TrimSpace(string(output))) @@ -156,7 +156,7 @@ func (d *diffusers) Run(ctx context.Context, socket, model string, modelRef stri return fmt.Errorf("%w: model %s", ErrNoDDUFFile, model) } - d.log.Infof("Loading DDUF file from: %s", ddufPath) + d.log.Info(fmt.Sprintf("Loading DDUF file from: %s", ddufPath)) args, err := d.config.GetArgs(ddufPath, socket, mode, backendConfig) if err != nil { @@ -168,7 +168,7 @@ func (d *diffusers) Run(ctx context.Context, socket, model string, modelRef stri args = append(args, "--served-model-name", modelRef) } - d.log.Infof("Diffusers args: %v", utils.SanitizeForLog(strings.Join(args, " "))) + d.log.Info(fmt.Sprintf("Diffusers args: %v", utils.SanitizeForLog(strings.Join(args, " ")))) if d.pythonPath == "" { return fmt.Errorf("diffusers: python runtime not configured; did you forget to call Install") @@ -187,7 +187,7 
@@ func (d *diffusers) Run(ctx context.Context, socket, model string, modelRef stri SandboxConfig: "", Args: args, Logger: d.log, - ServerLogWriter: d.serverLog.Writer(), + ServerLogWriter: logging.NewWriter(d.serverLog), ErrorTransformer: ExtractPythonError, }) } diff --git a/pkg/inference/backends/llamacpp/download.go b/pkg/inference/backends/llamacpp/download.go index 4db132711..abc7e73b7 100644 --- a/pkg/inference/backends/llamacpp/download.go +++ b/pkg/inference/backends/llamacpp/download.go @@ -56,11 +56,11 @@ func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logge shouldUpdateServer := ShouldUpdateServer ShouldUpdateServerLock.Unlock() if !shouldUpdateServer { - log.Infof("downloadLatestLlamaCpp: update disabled") + log.Info("downloadLatestLlamaCpp: update disabled") return errLlamaCppUpdateDisabled } - log.Infof("downloadLatestLlamaCpp: %s, %s, %s, %s", desiredVersion, desiredVariant, vendoredServerStoragePath, llamaCppPath) + log.Info(fmt.Sprintf("downloadLatestLlamaCpp: %s, %s, %s, %s", desiredVersion, desiredVariant, vendoredServerStoragePath, llamaCppPath)) desiredTag := desiredVersion + "-" + desiredVariant url := fmt.Sprintf("https://hub.docker.com/v2/namespaces/%s/repositories/%s/tags/%s", hubNamespace, hubRepo, desiredTag) resp, err := httpClient.Get(url) @@ -89,7 +89,7 @@ func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logge latest = response.Digest } if latest == "" { - log.Warnf("could not fing the %s tag, hub response: %s", desiredTag, body) + log.Warn(fmt.Sprintf("could not fing the %s tag, hub response: %s", desiredTag, body)) return fmt.Errorf("could not find the %s tag", desiredTag) } @@ -107,18 +107,18 @@ func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logge data, err = os.ReadFile(currentVersionFile) if err != nil { - log.Warnf("failed to read current llama.cpp version: %v", err) - log.Warnf("proceeding to update llama.cpp binary") + log.Warn(fmt.Sprintf("failed 
to read current llama.cpp version: %v", err)) + log.Warn("proceeding to update llama.cpp binary") } else if strings.TrimSpace(string(data)) == latest { - log.Infoln("current llama.cpp version is already up to date") + log.Info("current llama.cpp version is already up to date") if _, statErr := os.Stat(llamaCppPath); statErr == nil { l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s", desiredTag, latest, getLlamaCppVersion(log, llamaCppPath)) return nil } - log.Infoln("llama.cpp binary must be updated, proceeding to update it") + log.Info("llama.cpp binary must be updated, proceeding to update it") } else { - log.Infof("current llama.cpp version is outdated: %s vs %s, proceeding to update it", strings.TrimSpace(string(data)), latest) + log.Info(fmt.Sprintf("current llama.cpp version is outdated: %s vs %s, proceeding to update it", strings.TrimSpace(string(data)), latest)) } image := fmt.Sprintf("registry-1.docker.io/%s/%s@%s", hubNamespace, hubRepo, latest) @@ -163,12 +163,12 @@ func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logge } } - log.Infoln("successfully updated llama.cpp binary") + log.Info("successfully updated llama.cpp binary") l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s", desiredTag, latest, getLlamaCppVersion(log, llamaCppPath)) - log.Infoln(l.status) + log.Info(l.status) if err := os.WriteFile(currentVersionFile, []byte(latest), 0o644); err != nil { - log.Warnf("failed to save llama.cpp version: %v", err) + log.Warn(fmt.Sprintf("failed to save llama.cpp version: %v", err)) } return nil @@ -176,7 +176,7 @@ func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logge //nolint:unused // Used in platform-specific files (download_darwin.go, download_windows.go) func extractFromImage(ctx context.Context, log logging.Logger, image, requiredOs, requiredArch, destination string) error { - log.Infof("Extracting image %q to %q", image, destination) + log.Info(fmt.Sprintf("Extracting 
image %q to %q", image, destination)) tmpDir, err := os.MkdirTemp("", "docker-tar-extract") if err != nil { return err @@ -191,7 +191,7 @@ func extractFromImage(ctx context.Context, log logging.Logger, image, requiredOs func getLlamaCppVersion(log logging.Logger, llamaCpp string) string { output, err := exec.Command(llamaCpp, "--version").CombinedOutput() if err != nil { - log.Warnf("could not get llama.cpp version: %v", err) + log.Warn(fmt.Sprintf("could not get llama.cpp version: %v", err)) return "unknown" } re := regexp.MustCompile(`version: \d+ \((\w+)\)`) @@ -199,6 +199,6 @@ func getLlamaCppVersion(log logging.Logger, llamaCpp string) string { if len(matches) == 2 { return matches[1] } - log.Warnf("failed to parse llama.cpp version from output:\n%s", strings.TrimSpace(string(output))) + log.Warn(fmt.Sprintf("failed to parse llama.cpp version from output:\n%s", strings.TrimSpace(string(output)))) return "unknown" } diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go index a116480dc..592e6d078 100644 --- a/pkg/inference/backends/llamacpp/llamacpp.go +++ b/pkg/inference/backends/llamacpp/llamacpp.go @@ -117,7 +117,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error { // digest to be equal to the one on Docker Hub. 
llamaCppPath := filepath.Join(l.updatedServerStoragePath, llamaServerBin) if err := l.ensureLatestLlamaCpp(ctx, l.log, httpClient, llamaCppPath, l.vendoredServerStoragePath); err != nil { - l.log.Infof("failed to ensure latest llama.cpp: %v\n", err) + l.log.Info(fmt.Sprintf("failed to ensure latest llama.cpp: %v\n", err)) if !errors.Is(err, errLlamaCppUpToDate) && !errors.Is(err, errLlamaCppUpdateDisabled) { l.status = fmt.Sprintf("failed to install llama.cpp: %v", err) } @@ -129,7 +129,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error { } l.gpuSupported = l.checkGPUSupport(ctx) - l.log.Infof("installed llama-server with gpuSupport=%t", l.gpuSupported) + l.log.Info(fmt.Sprintf("installed llama-server with gpuSupport=%t", l.gpuSupported)) return nil } @@ -180,7 +180,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, _ string, mode SandboxConfig: sandbox.ConfigurationLlamaCpp, Args: args, Logger: l.log, - ServerLogWriter: l.serverLog.Writer(), + ServerLogWriter: logging.NewWriter(l.serverLog), }) } @@ -351,12 +351,12 @@ func (l *llamaCpp) checkGPUSupport(ctx context.Context) bool { "--list-devices", ) if err != nil { - l.log.Warnf("Failed to start sandboxed llama.cpp process to probe GPU support: %v", err) + l.log.Warn(fmt.Sprintf("Failed to start sandboxed llama.cpp process to probe GPU support: %v", err)) return false } defer llamaCppSandbox.Close() if err := llamaCppSandbox.Command().Wait(); err != nil { - l.log.Warnf("Failed to determine if llama-server is built with GPU support: %v", err) + l.log.Warn(fmt.Sprintf("Failed to determine if llama-server is built with GPU support: %v", err)) return false } sc := bufio.NewScanner(strings.NewReader(output.String())) diff --git a/pkg/inference/backends/mlx/mlx.go b/pkg/inference/backends/mlx/mlx.go index 6ba3683ed..9907ee4f7 100644 --- a/pkg/inference/backends/mlx/mlx.go +++ b/pkg/inference/backends/mlx/mlx.go @@ -102,7 +102,7 @@ func (m *mlx) Install(ctx 
context.Context, httpClient *http.Client) error { cmd := exec.CommandContext(ctx, pythonPath, "-c", "import mlx_lm") if runErr := cmd.Run(); runErr != nil { m.status = "mlx-lm package not installed" - m.log.Warnf("mlx-lm package not found. Install with: uv pip install mlx-lm") + m.log.Warn("mlx-lm package not found. Install with: uv pip install mlx-lm") return fmt.Errorf("mlx-lm package not installed: %w", runErr) } @@ -110,7 +110,7 @@ func (m *mlx) Install(ctx context.Context, httpClient *http.Client) error { cmd = exec.CommandContext(ctx, pythonPath, "-c", "import mlx; print(mlx.__version__)") output, outputErr := cmd.Output() if outputErr != nil { - m.log.Warnf("could not get MLX version: %v", outputErr) + m.log.Warn(fmt.Sprintf("could not get MLX version: %v", outputErr)) m.status = "running MLX version: unknown" } else { m.status = fmt.Sprintf("running MLX version: %s", strings.TrimSpace(string(output))) @@ -142,7 +142,7 @@ func (m *mlx) Run(ctx context.Context, socket, model string, modelRef string, mo SandboxConfig: "", Args: args, Logger: m.log, - ServerLogWriter: m.serverLog.Writer(), + ServerLogWriter: logging.NewWriter(m.serverLog), }) } diff --git a/pkg/inference/backends/runner.go b/pkg/inference/backends/runner.go index 186857340..9cab22c56 100644 --- a/pkg/inference/backends/runner.go +++ b/pkg/inference/backends/runner.go @@ -46,9 +46,8 @@ type RunnerConfig struct { // Logger interface for backend logging type Logger interface { - Infof(format string, args ...interface{}) - Warnf(format string, args ...interface{}) - Warnln(args ...interface{}) + Info(msg string, args ...any) + Warn(msg string, args ...any) } // RunBackend runs a backend process with common error handling and logging. 
@@ -61,8 +60,8 @@ type Logger interface { func RunBackend(ctx context.Context, config RunnerConfig) error { // Remove old socket file if err := os.RemoveAll(config.Socket); err != nil && !errors.Is(err, fs.ErrNotExist) { - config.Logger.Warnf("failed to remove socket file %s: %v\n", config.Socket, err) - config.Logger.Warnln(config.BackendName + " may not be able to start") + config.Logger.Warn("failed to remove socket file", "socket", config.Socket, "error", err) + config.Logger.Warn(config.BackendName + " may not be able to start") } // Sanitize args for safe logging @@ -70,7 +69,7 @@ func RunBackend(ctx context.Context, config RunnerConfig) error { for i, arg := range config.Args { sanitizedArgs[i] = utils.SanitizeForLog(arg, 0) } - config.Logger.Infof("%s args: %v", config.BackendName, sanitizedArgs) + config.Logger.Info("backend args", "backend", config.BackendName, "args", sanitizedArgs) // Create tail buffer for error output tailBuf := tailbuffer.NewTailBuffer(1024) @@ -107,7 +106,7 @@ func RunBackend(ctx context.Context, config RunnerConfig) error { errOutput := new(strings.Builder) if _, err := io.Copy(errOutput, tailBuf); err != nil { - config.Logger.Warnf("failed to read server output tail: %v", err) + config.Logger.Warn("failed to read server output tail", "error", err) } if errOutput.String() != "" { @@ -124,7 +123,7 @@ func RunBackend(ctx context.Context, config RunnerConfig) error { backendErrors <- backendErr close(backendErrors) if err := os.Remove(config.Socket); err != nil && !errors.Is(err, fs.ErrNotExist) { - config.Logger.Warnf("failed to remove socket file %s on exit: %v\n", config.Socket, err) + config.Logger.Warn("failed to remove socket file on exit", "socket", config.Socket, "error", err) } }() defer func() { diff --git a/pkg/inference/backends/sglang/sglang.go b/pkg/inference/backends/sglang/sglang.go index 802c48802..3d00e7eb5 100644 --- a/pkg/inference/backends/sglang/sglang.go +++ b/pkg/inference/backends/sglang/sglang.go @@ -111,14 
+111,14 @@ func (s *sglang) Install(_ context.Context, _ *http.Client) error { // Check if sglang is installed if err := s.pythonCmd("-c", "import sglang").Run(); err != nil { s.status = "sglang package not installed" - s.log.Warnf("sglang package not found. Install with: uv pip install sglang") + s.log.Warn("sglang package not found. Install with: uv pip install sglang") return ErrSGLangNotFound } // Get version output, err := s.pythonCmd("-c", "import sglang; print(sglang.__version__)").Output() if err != nil { - s.log.Warnf("could not get sglang version: %v", err) + s.log.Warn(fmt.Sprintf("could not get sglang version: %v", err)) s.status = "running sglang version: unknown" } else { s.status = fmt.Sprintf("running sglang version: %s", strings.TrimSpace(string(output))) @@ -171,7 +171,7 @@ func (s *sglang) Run(ctx context.Context, socket, model string, modelRef string, SandboxConfig: "", Args: args, Logger: s.log, - ServerLogWriter: s.serverLog.Writer(), + ServerLogWriter: logging.NewWriter(s.serverLog), }) } diff --git a/pkg/inference/backends/vllm/vllm.go b/pkg/inference/backends/vllm/vllm.go index 3ec8bf975..603ecdf8a 100644 --- a/pkg/inference/backends/vllm/vllm.go +++ b/pkg/inference/backends/vllm/vllm.go @@ -94,7 +94,7 @@ func (v *vLLM) Install(_ context.Context, _ *http.Client) error { versionPath := filepath.Join(filepath.Dir(vllmDir), "version") versionBytes, err := os.ReadFile(versionPath) if err != nil { - v.log.Warnf("could not get vllm version: %v", err) + v.log.Warn(fmt.Sprintf("could not get vllm version: %v", err)) v.status = "running vllm version: unknown" } else { v.status = fmt.Sprintf("running vllm version: %s", strings.TrimSpace(string(versionBytes))) @@ -158,7 +158,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m SandboxConfig: "", Args: args, Logger: v.log, - ServerLogWriter: v.serverLog.Writer(), + ServerLogWriter: logging.NewWriter(v.serverLog), }) } diff --git 
a/pkg/inference/backends/vllmmetal/vllmmetal.go b/pkg/inference/backends/vllmmetal/vllmmetal.go index 81ddec6e7..37e9003e1 100644 --- a/pkg/inference/backends/vllmmetal/vllmmetal.go +++ b/pkg/inference/backends/vllmmetal/vllmmetal.go @@ -109,7 +109,7 @@ func (v *vllmMetal) Install(ctx context.Context, httpClient *http.Client) error v.pythonPath = pythonPath return v.verifyInstallation(ctx) } - v.log.Infof("vllm-metal version mismatch: installed %s, want %s", installed, vllmMetalVersion) + v.log.Info(fmt.Sprintf("vllm-metal version mismatch: installed %s, want %s", installed, vllmMetalVersion)) } } @@ -120,7 +120,7 @@ func (v *vllmMetal) Install(ctx context.Context, httpClient *http.Client) error // Save version file if err := os.WriteFile(versionFile, []byte(vllmMetalVersion), 0644); err != nil { - v.log.Warnf("failed to write version file: %v", err) + v.log.Warn(fmt.Sprintf("failed to write version file: %v", err)) } v.pythonPath = pythonPath @@ -130,7 +130,7 @@ func (v *vllmMetal) Install(ctx context.Context, httpClient *http.Client) error // downloadAndExtract downloads the vllm-metal image from Docker Hub and extracts it. // The image contains a self-contained Python installation with all packages pre-installed. 
func (v *vllmMetal) downloadAndExtract(ctx context.Context, _ *http.Client) error { - v.log.Infof("Downloading vllm-metal %s from Docker Hub...", vllmMetalVersion) + v.log.Info(fmt.Sprintf("Downloading vllm-metal %s from Docker Hub...", vllmMetalVersion)) // Create temp directory for download downloadDir, err := os.MkdirTemp("", "vllm-metal-install") @@ -160,7 +160,7 @@ func (v *vllmMetal) downloadAndExtract(ctx context.Context, _ *http.Client) erro return fmt.Errorf("failed to remove existing install dir: %w", err) } - v.log.Infof("Extracting self-contained Python environment...") + v.log.Info("Extracting self-contained Python environment...") // Copy the extracted self-contained Python installation directly to install dir // (the image contains /vllm-metal/ with bin/, lib/, etc.) @@ -175,7 +175,7 @@ func (v *vllmMetal) downloadAndExtract(ctx context.Context, _ *http.Client) erro return fmt.Errorf("failed to make python3 executable: %w", err) } - v.log.Infof("vllm-metal %s installed successfully", vllmMetalVersion) + v.log.Info(fmt.Sprintf("vllm-metal %s installed successfully", vllmMetalVersion)) return nil } @@ -266,7 +266,7 @@ func (v *vllmMetal) Run(ctx context.Context, socket, model string, modelRef stri SandboxConfig: "", Args: args, Logger: v.log, - ServerLogWriter: v.serverLog.Writer(), + ServerLogWriter: logging.NewWriter(v.serverLog), }) } diff --git a/pkg/inference/models/handler_test.go b/pkg/inference/models/handler_test.go index d8c0d580e..1a58ea241 100644 --- a/pkg/inference/models/handler_test.go +++ b/pkg/inference/models/handler_test.go @@ -1,8 +1,8 @@ package models import ( + "fmt" "encoding/json" - "io" "net/http" "net/http/httptest" "net/url" @@ -11,11 +11,12 @@ import ( "strings" "testing" + "log/slog" + "github.com/docker/model-runner/pkg/distribution/builder" reg "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/registry/testregistry" "github.com/docker/model-runner/pkg/inference" - 
"github.com/sirupsen/logrus" ) // getProjectRoot returns the absolute path to the project root directory @@ -23,7 +24,7 @@ func getProjectRoot(t *testing.T) string { // Start from the current test file's directory dir, err := os.Getwd() if err != nil { - t.Fatalf("Failed to get current directory: %v", err) + t.Error(fmt.Sprintf("Failed to get current directory: %v", err)) } // Walk up the directory tree until we find the go.mod file @@ -49,7 +50,7 @@ func TestPullModel(t *testing.T) { // Create a tag for the model uri, err := url.Parse(server.URL) if err != nil { - t.Fatalf("Failed to parse registry URL: %v", err) + t.Error(fmt.Sprintf("Failed to parse registry URL: %v", err)) } tag := uri.Host + "/ai/model:v1.0.0" @@ -57,23 +58,23 @@ func TestPullModel(t *testing.T) { projectRoot := getProjectRoot(t) model, err := builder.FromPath(filepath.Join(projectRoot, "assets", "dummy.gguf")) if err != nil { - t.Fatalf("Failed to create model builder: %v", err) + t.Error(fmt.Sprintf("Failed to create model builder: %v", err)) } license, err := model.WithLicense(filepath.Join(projectRoot, "assets", "license.txt")) if err != nil { - t.Fatalf("Failed to add license to model: %v", err) + t.Error(fmt.Sprintf("Failed to add license to model: %v", err)) } // Build the OCI model artifact + push it (use plainHTTP for test registry) client := reg.NewClient(reg.WithPlainHTTP(true)) target, err := client.NewTarget(tag) if err != nil { - t.Fatalf("Failed to create model target: %v", err) + t.Error(fmt.Sprintf("Failed to create model target: %v", err)) } err = license.Build(t.Context(), target, os.Stdout) if err != nil { - t.Fatalf("Failed to build model: %v", err) + t.Error(fmt.Sprintf("Failed to build model: %v", err)) } tests := []struct { @@ -100,10 +101,10 @@ func TestPullModel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - log := logrus.NewEntry(logrus.StandardLogger()) - manager := NewManager(log.WithFields(logrus.Fields{"component": 
"model-manager"}), ClientConfig{ + log := slog.Default() + manager := NewManager(log.With("component", "model-manager"), ClientConfig{ StoreRootPath: tempDir, - Logger: log.WithFields(logrus.Fields{"component": "model-manager"}), + Logger: log.With("component", "model-manager"), PlainHTTP: true, }) handler := NewHTTPHandler(log, manager, nil) @@ -116,19 +117,19 @@ func TestPullModel(t *testing.T) { w := httptest.NewRecorder() err = handler.manager.Pull(tag, "", r, w) if err != nil { - t.Fatalf("Failed to pull model: %v", err) + t.Error(fmt.Sprintf("Failed to pull model: %v", err)) } if tt.expectedCT != w.Header().Get("Content-Type") { - t.Fatalf("Expected content type %s, got %s", tt.expectedCT, w.Header().Get("Content-Type")) + t.Error(fmt.Sprintf("Expected content type %s, got %s", tt.expectedCT, w.Header().Get("Content-Type"))) } // Clean tempDir after each test if err := os.RemoveAll(tempDir); err != nil { - t.Fatalf("Failed to clean temp directory: %v", err) + t.Error(fmt.Sprintf("Failed to clean temp directory: %v", err)) } if err := os.MkdirAll(tempDir, 0755); err != nil { - t.Fatalf("Failed to recreate temp directory: %v", err) + t.Error(fmt.Sprintf("Failed to recreate temp directory: %v", err)) } }) } @@ -143,19 +144,19 @@ func TestHandleGetModel(t *testing.T) { uri, err := url.Parse(server.URL) if err != nil { - t.Fatalf("Failed to parse registry URL: %v", err) + t.Error(fmt.Sprintf("Failed to parse registry URL: %v", err)) } // Prepare the OCI model artifact projectRoot := getProjectRoot(t) model, err := builder.FromPath(filepath.Join(projectRoot, "assets", "dummy.gguf")) if err != nil { - t.Fatalf("Failed to create model builder: %v", err) + t.Error(fmt.Sprintf("Failed to create model builder: %v", err)) } license, err := model.WithLicense(filepath.Join(projectRoot, "assets", "license.txt")) if err != nil { - t.Fatalf("Failed to add license to model: %v", err) + t.Error(fmt.Sprintf("Failed to add license to model: %v", err)) } // Build the OCI model 
artifact + push it (use plainHTTP for test registry) @@ -163,11 +164,11 @@ func TestHandleGetModel(t *testing.T) { client := reg.NewClient(reg.WithPlainHTTP(true)) target, err := client.NewTarget(tag) if err != nil { - t.Fatalf("Failed to create model target: %v", err) + t.Error(fmt.Sprintf("Failed to create model target: %v", err)) } err = license.Build(t.Context(), target, os.Stdout) if err != nil { - t.Fatalf("Failed to build model: %v", err) + t.Error(fmt.Sprintf("Failed to build model: %v", err)) } tests := []struct { @@ -207,10 +208,10 @@ func TestHandleGetModel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - log := logrus.NewEntry(logrus.StandardLogger()) - manager := NewManager(log.WithFields(logrus.Fields{"component": "model-manager"}), ClientConfig{ + log := slog.Default() + manager := NewManager(log.With("component", "model-manager"), ClientConfig{ StoreRootPath: tempDir, - Logger: log.WithFields(logrus.Fields{"component": "model-manager"}), + Logger: log.With("component", "model-manager"), Transport: http.DefaultTransport, UserAgent: "test-agent", PlainHTTP: true, @@ -223,7 +224,7 @@ func TestHandleGetModel(t *testing.T) { w := httptest.NewRecorder() err = handler.manager.Pull(tt.modelName, "", r, w) if err != nil { - t.Fatalf("Failed to pull model: %v", err) + t.Error(fmt.Sprintf("Failed to pull model: %v", err)) } } @@ -243,12 +244,12 @@ func TestHandleGetModel(t *testing.T) { // Check response if w.Code != tt.expectedCode { - t.Errorf("Expected status code %d, got %d", tt.expectedCode, w.Code) + t.Error(fmt.Sprintf("Expected status code %d, got %d", tt.expectedCode, w.Code)) } if tt.expectedError != "" { if !strings.Contains(w.Body.String(), tt.expectedError) { - t.Errorf("Expected error containing %q, got %q", tt.expectedError, w.Body.String()) + t.Error(fmt.Sprintf("Expected error containing %q, got %q", tt.expectedError, w.Body.String())) } } else { // For successful responses, verify we got a valid JSON response 
@@ -260,16 +261,16 @@ func TestHandleGetModel(t *testing.T) { Config json.RawMessage `json:"config"` } if err := json.NewDecoder(w.Body).Decode(&response); err != nil { - t.Errorf("Failed to decode response body: %v", err) + t.Error(fmt.Sprintf("Failed to decode response body: %v", err)) } } // Clean tempDir after each test if err := os.RemoveAll(tempDir); err != nil { - t.Fatalf("Failed to clean temp directory: %v", err) + t.Error(fmt.Sprintf("Failed to clean temp directory: %v", err)) } if err := os.MkdirAll(tempDir, 0755); err != nil { - t.Fatalf("Failed to recreate temp directory: %v", err) + t.Error(fmt.Sprintf("Failed to recreate temp directory: %v", err)) } }) } @@ -297,12 +298,10 @@ func TestCors(t *testing.T) { for _, tt := range tests { t.Run(tt.path, func(t *testing.T) { t.Parallel() - discard := logrus.New() - discard.SetOutput(io.Discard) - log := logrus.NewEntry(discard) - manager := NewManager(log.WithFields(logrus.Fields{"component": "model-manager"}), ClientConfig{ + log := slog.Default() + manager := NewManager(log.With("component", "model-manager"), ClientConfig{ StoreRootPath: tempDir, - Logger: log.WithFields(logrus.Fields{"component": "model-manager"}), + Logger: log.With("component", "model-manager"), }) m := NewHTTPHandler(log, manager, []string{"*"}) req := httptest.NewRequest(http.MethodOptions, "http://model-runner.docker.internal"+tt.path, http.NoBody) @@ -311,7 +310,7 @@ func TestCors(t *testing.T) { m.ServeHTTP(w, req) if w.Code != http.StatusNoContent { - t.Errorf("Expected status code 204 for OPTIONS request, got %d", w.Code) + t.Error(fmt.Sprintf("Expected status code 204 for OPTIONS request, got %d", w.Code)) } }) } diff --git a/pkg/inference/models/http_handler.go b/pkg/inference/models/http_handler.go index a72136735..0e34c3036 100644 --- a/pkg/inference/models/http_handler.go +++ b/pkg/inference/models/http_handler.go @@ -12,13 +12,14 @@ import ( "strings" "sync" + "log/slog" + 
"github.com/docker/model-runner/pkg/distribution/distribution" "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/internal/utils" "github.com/docker/model-runner/pkg/logging" "github.com/docker/model-runner/pkg/middleware" - "github.com/sirupsen/logrus" ) // HTTPHandler manages inference model pulls and storage. @@ -40,7 +41,7 @@ type ClientConfig struct { // StoreRootPath is the root path for the model store. StoreRootPath string // Logger is the logger to use. - Logger *logrus.Entry + Logger *slog.Logger // Transport is the HTTP transport to use. Transport http.RoundTripper // UserAgent is the user agent to use. @@ -108,21 +109,21 @@ func (h *HTTPHandler) handleCreateModel(w http.ResponseWriter, r *http.Request) if err := h.manager.Pull(request.From, request.BearerToken, r, w); err != nil { sanitizedFrom := utils.SanitizeForLog(request.From, -1) if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - h.log.Infof("Request canceled/timed out while pulling model %q", sanitizedFrom) + h.log.Info(fmt.Sprintf("Request canceled/timed out while pulling model %q", sanitizedFrom)) return } if errors.Is(err, registry.ErrInvalidReference) { - h.log.Warnf("Invalid model reference %q: %v", sanitizedFrom, err) + h.log.Warn(fmt.Sprintf("Invalid model reference %q: %v", sanitizedFrom, err)) http.Error(w, "Invalid model reference", http.StatusBadRequest) return } if errors.Is(err, registry.ErrUnauthorized) { - h.log.Warnf("Unauthorized to pull model %q: %v", sanitizedFrom, err) + h.log.Warn(fmt.Sprintf("Unauthorized to pull model %q: %v", sanitizedFrom, err)) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } if errors.Is(err, registry.ErrModelNotFound) { - h.log.Warnf("Failed to pull model %q: %v", sanitizedFrom, err) + h.log.Warn(fmt.Sprintf("Failed to pull model %q: %v", sanitizedFrom, err)) http.Error(w, "Model not found", http.StatusNotFound) 
return } @@ -165,7 +166,7 @@ func (h *HTTPHandler) handleExportModel(w http.ResponseWriter, r *http.Request, http.Error(w, err.Error(), http.StatusNotFound) return } - h.log.Warnln("Error while exporting model:", err) + h.log.Warn("error while exporting model", "error", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -182,7 +183,7 @@ func (h *HTTPHandler) handleGetModels(w http.ResponseWriter, r *http.Request) { // Write the response. w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(apiModels); err != nil { - h.log.Warnln("Error while encoding model listing response:", err) + h.log.Warn("error while encoding model listing response", "error", err) } } @@ -198,7 +199,7 @@ func (h *HTTPHandler) handleGetModelByRef(w http.ResponseWriter, r *http.Request if r.URL.Query().Has("remote") { val, err := strconv.ParseBool(r.URL.Query().Get("remote")) if err != nil { - h.log.Warnln("Error while parsing remote query parameter:", err) + h.log.Warn("error while parsing remote query parameter", "error", err) } else { remote = val } @@ -223,7 +224,7 @@ func (h *HTTPHandler) handleGetModelByRef(w http.ResponseWriter, r *http.Request // Write the response. 
w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(apiModel); err != nil { - h.log.Warnln("Error while encoding model response:", err) + h.log.Warn("error while encoding model response", "error", err) } } @@ -310,7 +311,7 @@ func (h *HTTPHandler) handleDeleteModel(w http.ResponseWriter, r *http.Request) var force bool if r.URL.Query().Has("force") { if val, err := strconv.ParseBool(r.URL.Query().Get("force")); err != nil { - h.log.Warnln("Error while parsing force query parameter:", err) + h.log.Warn("error while parsing force query parameter", "error", err) } else { force = val } @@ -327,7 +328,7 @@ func (h *HTTPHandler) handleDeleteModel(w http.ResponseWriter, r *http.Request) http.Error(w, err.Error(), http.StatusConflict) return } - h.log.Warnln("Error while deleting model:", err) + h.log.Warn("error while deleting model", "error", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -357,7 +358,7 @@ func (h *HTTPHandler) handleOpenAIGetModels(w http.ResponseWriter, r *http.Reque // Write the response. 
w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(models); err != nil { - h.log.Warnln("Error while encoding OpenAI model listing response:", err) + h.log.Warn("error while encoding OpenAI model listing response", "error", err) } } @@ -383,7 +384,7 @@ func (h *HTTPHandler) handleOpenAIGetModel(w http.ResponseWriter, r *http.Reques return } if err := json.NewEncoder(w).Encode(openaiModel); err != nil { - h.log.Warnln("Error while encoding OpenAI model response:", err) + h.log.Warn("error while encoding OpenAI model response", "error", err) } } @@ -443,7 +444,7 @@ func (h *HTTPHandler) handleTagModel(w http.ResponseWriter, r *http.Request, mod "target": target, } if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Warnln("Error while encoding tag response:", err) + h.log.Warn("error while encoding tag response", "error", err) } } @@ -451,17 +452,17 @@ func (h *HTTPHandler) handleTagModel(w http.ResponseWriter, r *http.Request, mod func (h *HTTPHandler) handlePushModel(w http.ResponseWriter, r *http.Request, model string) { if err := h.manager.Push(model, r, w); err != nil { if errors.Is(err, distribution.ErrInvalidReference) { - h.log.Warnf("Invalid model reference %q: %v", utils.SanitizeForLog(model, -1), err) + h.log.Warn(fmt.Sprintf("Invalid model reference %q: %v", utils.SanitizeForLog(model, -1), err)) http.Error(w, "Invalid model reference", http.StatusBadRequest) return } if errors.Is(err, distribution.ErrModelNotFound) { - h.log.Warnf("Failed to push model %q: %v", utils.SanitizeForLog(model, -1), err) + h.log.Warn(fmt.Sprintf("Failed to push model %q: %v", utils.SanitizeForLog(model, -1), err)) http.Error(w, "Model not found", http.StatusNotFound) return } if errors.Is(err, registry.ErrUnauthorized) { - h.log.Warnf("Unauthorized to push model %q: %v", utils.SanitizeForLog(model, -1), err) + h.log.Warn(fmt.Sprintf("Unauthorized to push model %q: %v", utils.SanitizeForLog(model, -1), err)) http.Error(w, 
"Unauthorized", http.StatusUnauthorized) return } @@ -496,7 +497,7 @@ func (h *HTTPHandler) handleRepackageModel(w http.ResponseWriter, r *http.Reques http.Error(w, err.Error(), http.StatusNotFound) return } - h.log.Warnf("Failed to repackage model %q: %v", utils.SanitizeForLog(model, -1), err) + h.log.Warn(fmt.Sprintf("Failed to repackage model %q: %v", utils.SanitizeForLog(model, -1), err)) http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -509,7 +510,7 @@ func (h *HTTPHandler) handleRepackageModel(w http.ResponseWriter, r *http.Reques "target": req.Target, } if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Warnln("Error while encoding repackage response:", err) + h.log.Warn("error while encoding repackage response", "error", err) } } @@ -517,7 +518,7 @@ func (h *HTTPHandler) handleRepackageModel(w http.ResponseWriter, r *http.Reques func (h *HTTPHandler) handlePurge(w http.ResponseWriter, _ *http.Request) { err := h.manager.Purge() if err != nil { - h.log.Warnf("Failed to purge models: %v", err) + h.log.Warn(fmt.Sprintf("Failed to purge models: %v", err)) http.Error(w, err.Error(), http.StatusInternalServerError) return } diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go index 7d5001bdb..f13519a3f 100644 --- a/pkg/inference/models/manager.go +++ b/pkg/inference/models/manager.go @@ -52,7 +52,7 @@ func NewManager(log logging.Logger, c ClientConfig) *Manager { distribution.WithRegistryClient(registryClient), ) if err != nil { - log.Errorf("Failed to create distribution client: %v", err) + log.Error(fmt.Sprintf("Failed to create distribution client: %v", err)) // Continue without distribution client. The model manager will still // respond to requests, but may return errors if the client is required. 
} @@ -92,13 +92,13 @@ func (m *Manager) ResolveID(modelRef string) string { sanitizedModelRef := utils.SanitizeForLog(modelRef, -1) model, err := m.GetLocal(sanitizedModelRef) if err != nil { - m.log.Warnf("Failed to resolve model ref %s to ID: %v", sanitizedModelRef, err) + m.log.Warn(fmt.Sprintf("Failed to resolve model ref %s to ID: %v", sanitizedModelRef, err)) return sanitizedModelRef } modelID, err := model.ID() if err != nil { - m.log.Warnf("Failed to get model ID for ref %s: %v", sanitizedModelRef, err) + m.log.Warn(fmt.Sprintf("Failed to get model ID for ref %s: %v", sanitizedModelRef, err)) return sanitizedModelRef } @@ -172,7 +172,7 @@ func (m *Manager) List() ([]*Model, error) { for _, model := range models { apiModel, err := ToModel(model) if err != nil { - m.log.Warnf("error while converting model, skipping: %v", err) + m.log.Warn(fmt.Sprintf("error while converting model, skipping: %v", err)) continue } apiModels = append(apiModels, apiModel) @@ -248,12 +248,12 @@ func (m *Manager) Pull(model string, bearerToken string, r *http.Request, w http } // Pull the model using the Docker model distribution client - m.log.Infoln("Pulling model:", utils.SanitizeForLog(model, -1)) + m.log.Info("pulling model", "model", utils.SanitizeForLog(model, -1)) // Use bearer token if provided var err error if bearerToken != "" { - m.log.Infoln("Using provided bearer token for authentication") + m.log.Info("Using provided bearer token for authentication") err = m.distributionClient.PullModel(r.Context(), model, progressWriter, bearerToken) } else { err = m.distributionClient.PullModel(r.Context(), model, progressWriter) @@ -300,7 +300,7 @@ func (m *Manager) Tag(ref, target string) error { for _, mModel := range models { modelID, idErr := mModel.ID() if idErr != nil { - m.log.Warnf("Failed to get model ID: %v", idErr) + m.log.Warn(fmt.Sprintf("Failed to get model ID: %v", idErr)) continue } @@ -359,7 +359,7 @@ func (m *Manager) Tag(ref, target string) error { // Now tag 
using the found model reference (the matching tag) if tagErr := m.distributionClient.Tag(foundModelRef, target); tagErr != nil { - m.log.Warnf("Failed to apply tag %q to resolved model %q: %v", utils.SanitizeForLog(target, -1), utils.SanitizeForLog(foundModelRef, -1), tagErr) + m.log.Warn(fmt.Sprintf("Failed to apply tag %q to resolved model %q: %v", utils.SanitizeForLog(target, -1), utils.SanitizeForLog(foundModelRef, -1), tagErr)) return fmt.Errorf("error while tagging model: %w", tagErr) } } else if err != nil { @@ -411,7 +411,7 @@ func (m *Manager) Purge() error { return fmt.Errorf("model distribution service unavailable") } if err := m.distributionClient.ResetStore(); err != nil { - m.log.Warnf("Failed to purge models: %v", err) + m.log.Warn(fmt.Sprintf("Failed to purge models: %v", err)) return fmt.Errorf("error while purging models: %w", err) } return nil diff --git a/pkg/inference/scheduling/http_handler.go b/pkg/inference/scheduling/http_handler.go index 90050cac8..06df35e08 100644 --- a/pkg/inference/scheduling/http_handler.go +++ b/pkg/inference/scheduling/http_handler.go @@ -448,7 +448,7 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) { go func() { preloadBody, err := json.Marshal(OpenAIInferenceRequest{Model: configureRequest.Model}) if err != nil { - h.scheduler.log.Warnf("failed to marshal preload request body: %v", err) + h.scheduler.log.Warn(fmt.Sprintf("failed to marshal preload request body: %v", err)) return } ctx, cancel := context.WithTimeout(context.Background(), time.Minute) @@ -460,7 +460,7 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) { bytes.NewReader(preloadBody), ) if err != nil { - h.scheduler.log.Warnf("failed to create preload request: %v", err) + h.scheduler.log.Warn(fmt.Sprintf("failed to create preload request: %v", err)) return } preloadReq.Header.Set("User-Agent", r.UserAgent()) @@ -470,7 +470,7 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) { 
recorder := httptest.NewRecorder() h.handleOpenAIInference(recorder, preloadReq) if recorder.Code != http.StatusOK { - h.scheduler.log.Warnf("background model preload failed with status %d: %s", recorder.Code, recorder.Body.String()) + h.scheduler.log.Warn(fmt.Sprintf("background model preload failed with status %d: %s", recorder.Code, recorder.Body.String())) } }() diff --git a/pkg/inference/scheduling/installer.go b/pkg/inference/scheduling/installer.go index 8703ad031..6c7ce90ea 100644 --- a/pkg/inference/scheduling/installer.go +++ b/pkg/inference/scheduling/installer.go @@ -1,6 +1,7 @@ package scheduling import ( + "fmt" "context" "errors" "net/http" @@ -141,7 +142,7 @@ func (i *installer) run(ctx context.Context) { continue } if err := backend.Install(ctx, i.httpClient); err != nil { - i.log.Warnf("Backend installation failed for %s: %v", name, err) + i.log.Warn(fmt.Sprintf("Backend installation failed for %s: %v", name, err)) select { case <-ctx.Done(): status.err = errors.Join(errInstallerShuttingDown, ctx.Err()) diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go index 6a69b0ba2..9acd94caf 100644 --- a/pkg/inference/scheduling/loader.go +++ b/pkg/inference/scheduling/loader.go @@ -233,21 +233,17 @@ func (l *loader) evict(idleOnly bool) int { default: } if unused && (!idleOnly || idle || defunct) && (!idleOnly || !neverEvict || defunct) { - l.log.Infof("Evicting %s backend runner with model %s (%s) in %s mode", - r.backend, r.modelID, runnerInfo.modelRef, r.mode, - ) + l.log.Info(fmt.Sprintf("Evicting %s backend runner with model %s (%s) in %s mode", r.backend, r.modelID, runnerInfo.modelRef, r.mode)) l.freeRunnerSlot(runnerInfo.slot, r) evictedCount++ } else if unused { - l.log.Debugf("Runner %s (%s) is unused but not evictable: idleOnly=%v, idle=%v, defunct=%v, neverEvict=%v", - r.modelID, runnerInfo.modelRef, idleOnly, idle, defunct, neverEvict) + l.log.Debug(fmt.Sprintf("Runner %s (%s) is unused but not evictable: 
idleOnly=%v, idle=%v, defunct=%v, neverEvict=%v", r.modelID, runnerInfo.modelRef, idleOnly, idle, defunct, neverEvict)) } else { - l.log.Debugf("Runner %s (%s) is in use with %d references, cannot evict", - r.modelID, runnerInfo.modelRef, l.references[runnerInfo.slot]) + l.log.Debug(fmt.Sprintf("Runner %s (%s) is in use with %d references, cannot evict", r.modelID, runnerInfo.modelRef, l.references[runnerInfo.slot])) } } if evictedCount > 0 { - l.log.Infof("Evicted %d runner(s)", evictedCount) + l.log.Info(fmt.Sprintf("Evicted %d runner(s)", evictedCount)) } return len(l.runners) } @@ -260,16 +256,13 @@ func (l *loader) evictRunner(backend, model string, mode inference.BackendMode) for r, runnerInfo := range l.runners { unused := l.references[runnerInfo.slot] == 0 if unused && (allBackends || r.backend == backend) && r.modelID == model && r.mode == mode { - l.log.Infof("Evicting %s backend runner with model %s (%s) in %s mode", - r.backend, r.modelID, runnerInfo.modelRef, r.mode, - ) + l.log.Info(fmt.Sprintf("Evicting %s backend runner with model %s (%s) in %s mode", r.backend, r.modelID, runnerInfo.modelRef, r.mode)) l.freeRunnerSlot(runnerInfo.slot, r) found = true } } if !found { - l.log.Warnf("No unused runner found for backend=%s, model=%s, mode=%s", - utils.SanitizeForLog(backend), utils.SanitizeForLog(model), utils.SanitizeForLog(string(mode))) + l.log.Warn(fmt.Sprintf("No unused runner found for backend=%s, model=%s, mode=%s", utils.SanitizeForLog(backend), utils.SanitizeForLog(model), utils.SanitizeForLog(string(mode)))) } return len(l.runners) } @@ -452,7 +445,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string defaultConfig := inference.BackendConfiguration{} if l.modelManager != nil { if bundle, err := l.modelManager.GetBundle(modelID); err != nil { - l.log.Warnf("Failed to get bundle for model %s to determine default context size: %v", modelID, err) + l.log.Warn(fmt.Sprintf("Failed to get bundle for model %s to 
determine default context size: %v", modelID, err)) } else if runtimeConfig := bundle.RuntimeConfig(); runtimeConfig != nil { if ctxSize := runtimeConfig.GetContextSize(); ctxSize != nil { defaultConfig.ContextSize = ctxSize @@ -462,7 +455,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string runnerConfig = &defaultConfig } - l.log.Infof("Loading %s backend runner with model %s in %s mode", backendName, modelID, mode) + l.log.Info(fmt.Sprintf("Loading %s backend runner with model %s in %s mode", backendName, modelID, mode)) // Acquire the loader lock and defer its release. if !l.lock(ctx) { @@ -492,7 +485,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string if ok { select { case <-l.slots[existing.slot].done: - l.log.Warnf("%s runner for %s is defunct. Waiting for it to be evicted.", backendName, existing.modelRef) + l.log.Warn(fmt.Sprintf("%s runner for %s is defunct. Waiting for it to be evicted.", backendName, existing.modelRef)) if l.references[existing.slot] == 0 { l.evictRunner(backendName, modelID, mode) // Continue the loop to retry loading after evicting the defunct runner @@ -509,8 +502,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string // If all slots are full, try evicting unused runners. 
if len(l.runners) == len(l.slots) { - l.log.Infof("Evicting to make room: %d/%d slots used", - len(l.runners), len(l.slots)) + l.log.Info(fmt.Sprintf("Evicting to make room: %d/%d slots used", len(l.runners), len(l.slots))) runnerCountAtLoopStart := len(l.runners) remainingRunners := l.evict(false) // Restart the loop if eviction happened @@ -530,8 +522,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string } if slot < 0 { - l.log.Debugf("Cannot load model yet: %d/%d slots used", - len(l.runners), len(l.slots)) + l.log.Debug(fmt.Sprintf("Cannot load model yet: %d/%d slots used", len(l.runners), len(l.slots))) } // If we've identified a slot, then we're ready to start a runner. @@ -539,9 +530,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string // Create the runner. runner, err := run(l.log, backend, modelID, modelRef, mode, slot, runnerConfig, l.openAIRecorder) if err != nil { - l.log.Warnf("Unable to start %s backend runner with model %s in %s mode: %v", - backendName, modelID, mode, err, - ) + l.log.Warn(fmt.Sprintf("Unable to start %s backend runner with model %s in %s mode: %v", backendName, modelID, mode, err)) return nil, fmt.Errorf("unable to start runner: %w", err) } @@ -553,9 +542,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string // deduplication of runners and keep slot / memory reservations. if err := runner.wait(ctx); err != nil { runner.terminate() - l.log.Warnf("Initialization for %s backend runner with model %s in %s mode failed: %v", - backendName, modelID, mode, err, - ) + l.log.Warn(fmt.Sprintf("Initialization for %s backend runner with model %s in %s mode failed: %v", backendName, modelID, mode, err)) return nil, fmt.Errorf("error waiting for runner to be ready: %w", err) } @@ -628,7 +615,7 @@ func (l *loader) setRunnerConfig(ctx context.Context, backendName, modelID strin // If the configuration hasn't changed, then just return. 
if existingConfig, ok := l.runnerConfigs[configKey]; ok && reflect.DeepEqual(runnerConfig, existingConfig) { - l.log.Infof("Configuration for %s runner for modelID %s unchanged", backendName, modelID) + l.log.Info(fmt.Sprintf("Configuration for %s runner for modelID %s unchanged", backendName, modelID)) return nil } @@ -651,7 +638,7 @@ func (l *loader) setRunnerConfig(ctx context.Context, backendName, modelID strin return errRunnerAlreadyActive } - l.log.Infof("Configuring %s runner for %s", backendName, modelID) + l.log.Info(fmt.Sprintf("Configuring %s runner for %s", backendName, modelID)) l.runnerConfigs[configKey] = runnerConfig return nil } diff --git a/pkg/inference/scheduling/loader_test.go b/pkg/inference/scheduling/loader_test.go index fe79744bd..d84622251 100644 --- a/pkg/inference/scheduling/loader_test.go +++ b/pkg/inference/scheduling/loader_test.go @@ -1,6 +1,7 @@ package scheduling import ( + "fmt" "context" "errors" "io" @@ -8,8 +9,9 @@ import ( "testing" "time" + "log/slog" + "github.com/docker/model-runner/pkg/inference" - "github.com/sirupsen/logrus" ) // mockBackend is a minimal backend implementation for testing @@ -55,10 +57,8 @@ func (b *fastFailBackend) Run(ctx context.Context, socket, model string, modelRe } // createTestLogger creates a logger for testing -func createTestLogger() *logrus.Entry { - log := logrus.New() - log.SetOutput(io.Discard) - return logrus.NewEntry(log) +func createTestLogger() *slog.Logger { + return slog.Default() } // Test memory size constants @@ -68,7 +68,7 @@ const ( // createDefunctMockRunner creates a mock runner with a closed done channel, // simulating a defunct (crashed/terminated) runner for testing -func createDefunctMockRunner(ctx context.Context, log *logrus.Entry, backend inference.Backend) *runner { +func createDefunctMockRunner(ctx context.Context, log *slog.Logger, backend inference.Backend) *runner { defunctRunnerDone := make(chan struct{}) _, defunctRunnerCancel := context.WithCancel(ctx) @@ -97,7 
+97,7 @@ func createDefunctMockRunner(ctx context.Context, log *logrus.Entry, backend inf // createAliveTerminableMockRunner creates a mock runner with an open done channel // (i.e., not defunct) that will close when cancel is invoked, so terminate() returns. -func createAliveTerminableMockRunner(ctx context.Context, log *logrus.Entry, backend inference.Backend) *runner { +func createAliveTerminableMockRunner(ctx context.Context, log *slog.Logger, backend inference.Backend) *runner { runCtx, cancel := context.WithCancel(ctx) done := make(chan struct{}) @@ -162,16 +162,16 @@ func TestMakeRunnerKey(t *testing.T) { key := makeRunnerKey(tt.backend, tt.modelID, tt.draftModelID, tt.mode) if key.backend != tt.backend { - t.Errorf("Expected backend %q, got %q", tt.backend, key.backend) + t.Error(fmt.Sprintf("Expected backend %q, got %q", tt.backend, key.backend)) } if key.modelID != tt.modelID { - t.Errorf("Expected modelID %q, got %q", tt.modelID, key.modelID) + t.Error(fmt.Sprintf("Expected modelID %q, got %q", tt.modelID, key.modelID)) } if key.draftModelID != tt.draftModelID { - t.Errorf("Expected draftModelID %q, got %q", tt.draftModelID, key.draftModelID) + t.Error(fmt.Sprintf("Expected draftModelID %q, got %q", tt.draftModelID, key.draftModelID)) } if key.mode != tt.mode { - t.Errorf("Expected mode %v, got %v", tt.mode, key.mode) + t.Error(fmt.Sprintf("Expected mode %v, got %v", tt.mode, key.mode)) } }) } @@ -186,16 +186,16 @@ func TestMakeConfigKey(t *testing.T) { key := makeConfigKey(backend, modelID, mode) if key.backend != backend { - t.Errorf("Expected backend %q, got %q", backend, key.backend) + t.Error(fmt.Sprintf("Expected backend %q, got %q", backend, key.backend)) } if key.modelID != modelID { - t.Errorf("Expected modelID %q, got %q", modelID, key.modelID) + t.Error(fmt.Sprintf("Expected modelID %q, got %q", modelID, key.modelID)) } if key.draftModelID != "" { - t.Errorf("Expected empty draftModelID for config key, got %q", key.draftModelID) + 
t.Error(fmt.Sprintf("Expected empty draftModelID for config key, got %q", key.draftModelID)) } if key.mode != mode { - t.Errorf("Expected mode %v, got %v", mode, key.mode) + t.Error(fmt.Sprintf("Expected mode %v, got %v", mode, key.mode)) } } @@ -326,7 +326,7 @@ func TestPerModelKeepAliveEviction(t *testing.T) { // Runner with short keep_alive should be evicted, never-evict should remain if remaining != 1 { - t.Errorf("Expected 1 remaining runner after eviction, got %d", remaining) + t.Error(fmt.Sprintf("Expected 1 remaining runner after eviction, got %d", remaining)) } // Verify that model-never is still present @@ -382,10 +382,10 @@ func TestIdleCheckDurationWithPerModelKeepAlive(t *testing.T) { // Should be based on the short keep_alive runner (around 100ms + 100ms buffer) // The never-evict runner should be skipped if duration < 0 { - t.Errorf("Expected positive duration, got %v", duration) + t.Error(fmt.Sprintf("Expected positive duration, got %v", duration)) } if duration > 500*time.Millisecond { - t.Errorf("Expected duration around 200ms, got %v", duration) + t.Error(fmt.Sprintf("Expected duration around 200ms, got %v", duration)) } loader.unlock() diff --git a/pkg/inference/scheduling/runner.go b/pkg/inference/scheduling/runner.go index 710a538f8..9f06dc23b 100644 --- a/pkg/inference/scheduling/runner.go +++ b/pkg/inference/scheduling/runner.go @@ -140,7 +140,7 @@ func run( return nil } proxy.Transport = transport - proxyLog := log.Writer() + proxyLog := logging.NewWriter(log) proxy.ErrorLog = logpkg.New(proxyLog, "", 0) // Create a cancellable context to regulate the runner's backend run loop @@ -192,15 +192,13 @@ func run( if r.openAIRecorder != nil { r.openAIRecorder.SetConfigForModel(modelID, runnerConfig) } else { - r.log.Warnf("OpenAI recorder is nil for model %s", modelID) + r.log.Warn(fmt.Sprintf("OpenAI recorder is nil for model %s", modelID)) } // Start the backend run loop. 
go func() { if err := backend.Run(runCtx, socket, modelID, modelRef, mode, runnerConfig); err != nil { - log.Warnf("Backend %s running model %s exited with error: %v", - backend.Name(), utils.SanitizeForLog(modelRef), err, - ) + log.Warn(fmt.Sprintf("Backend %s running model %s exited with error: %v", backend.Name(), utils.SanitizeForLog(modelRef), err)) r.err = err } close(runDone) @@ -266,13 +264,13 @@ func (r *runner) terminate() { // Close the proxy's log. if err := r.proxyLog.Close(); err != nil { - r.log.Warnf("Unable to close reverse proxy log writer: %v", err) + r.log.Warn(fmt.Sprintf("Unable to close reverse proxy log writer: %v", err)) } if r.openAIRecorder != nil { r.openAIRecorder.RemoveModel(r.model) } else { - r.log.Warnf("OpenAI recorder is nil for model %s", r.model) + r.log.Warn(fmt.Sprintf("OpenAI recorder is nil for model %s", r.model)) } } diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index dd918478b..575c2c33e 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -59,7 +59,7 @@ func NewScheduler( tracker *metrics.Tracker, deferredBackends []string, ) *Scheduler { - openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"), modelManager) + openAIRecorder := metrics.NewOpenAIRecorder(log.With("component", "openai-recorder"), modelManager) // Create the scheduler. 
s := &Scheduler{ @@ -107,7 +107,7 @@ func (s *Scheduler) Run(ctx context.Context) error { func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.Backend, modelRef string) inference.Backend { config, err := model.Config() if err != nil { - s.log.Warnln("failed to fetch model config:", err) + s.log.Warn("failed to fetch model config", "error", err) return backend } @@ -118,8 +118,8 @@ func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.B if s.installer.isInstalled(vllmmetal.Name) { return vllmMetalBackend } - s.log.Infof("vllm-metal backend is available but not installed. "+ - "To install, run: docker model install-runner --backend %s", vllmmetal.Name) + s.log.Info("vllm-metal backend is available but not installed; to install, run: docker model install-runner --backend <backend>", + "backend", vllmmetal.Name) return vllmMetalBackend } // Fall back to MLX on macOS @@ -134,9 +134,8 @@ func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.B if sglangBackend, ok := s.backends[sglang.Name]; ok && sglangBackend != nil { return sglangBackend } - s.log.Warnf("Model %s is in safetensors format but no compatible backend is available. "+ - "Backend %s may not support this format and could fail at runtime.", - utils.SanitizeForLog(modelRef), backend.Name()) + s.log.Warn(fmt.Sprintf("Model %s is in safetensors format but no compatible backend is available. 
"+ + "Backend %s may not support this format and could fail at runtime.", utils.SanitizeForLog(modelRef), backend.Name())) } return backend @@ -205,7 +204,7 @@ func (s *Scheduler) GetAllActiveRunners() []metrics.ActiveRunner { for _, backend := range runningBackends { mode, ok := inference.ParseBackendMode(backend.Mode) if !ok { - s.log.Warnf("Unknown backend mode %q, defaulting to completion.", backend.Mode) + s.log.Warn(fmt.Sprintf("Unknown backend mode %q, defaulting to completion.", backend.Mode)) } // Find the runner slot for this backend/model combination // We iterate through all runners since we don't know the draftModelID @@ -213,7 +212,7 @@ func (s *Scheduler) GetAllActiveRunners() []metrics.ActiveRunner { if key.backend == backend.BackendName && key.modelID == backend.ModelName && key.mode == mode { socket, err := RunnerSocketPath(runnerInfo.slot) if err != nil { - s.log.Warnf("Failed to get socket path for runner %s/%s (%s): %v", backend.BackendName, backend.ModelName, key.modelID, err) + s.log.Warn(fmt.Sprintf("Failed to get socket path for runner %s/%s (%s): %v", backend.BackendName, backend.ModelName, key.modelID, err)) continue } @@ -245,7 +244,7 @@ func (s *Scheduler) GetLlamaCppSocket() (string, error) { if backend.BackendName == llamacpp.Name { mode, ok := inference.ParseBackendMode(backend.Mode) if !ok { - s.log.Warnf("Unknown backend mode %q, defaulting to completion.", backend.Mode) + s.log.Warn(fmt.Sprintf("Unknown backend mode %q, defaulting to completion.", backend.Mode)) } // Find the runner slot for this backend/model combination // We iterate through all runners since we don't know the draftModelID @@ -335,7 +334,7 @@ func (s *Scheduler) ConfigureRunner(ctx context.Context, backend inference.Backe // Set the runner configuration if err := s.loader.setRunnerConfig(ctx, backend.Name(), modelID, mode, runnerConfig); err != nil { - s.log.Warnf("Failed to configure %s runner for %s (%s): %s", backend.Name(), utils.SanitizeForLog(req.Model, 
-1), modelID, err) + s.log.Warn(fmt.Sprintf("Failed to configure %s runner for %s (%s): %s", backend.Name(), utils.SanitizeForLog(req.Model, -1), modelID, err)) return nil, err } diff --git a/pkg/inference/scheduling/scheduler_test.go b/pkg/inference/scheduling/scheduler_test.go index 0c1d4a71b..e13bf5b70 100644 --- a/pkg/inference/scheduling/scheduler_test.go +++ b/pkg/inference/scheduling/scheduler_test.go @@ -1,12 +1,12 @@ package scheduling import ( - "io" + "fmt" + "log/slog" "net/http" "net/http/httptest" "testing" - "github.com/sirupsen/logrus" ) func TestCors(t *testing.T) { @@ -30,9 +30,7 @@ func TestCors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - discard := logrus.New() - discard.SetOutput(io.Discard) - log := logrus.NewEntry(discard) + log := slog.Default() s := NewScheduler(log, nil, nil, nil, nil, nil, nil) httpHandler := NewHTTPHandler(s, nil, []string{"*"}) req := httptest.NewRequest(http.MethodOptions, "http://model-runner.docker.internal"+tt.path, http.NoBody) @@ -41,7 +39,7 @@ func TestCors(t *testing.T) { httpHandler.ServeHTTP(w, req) if w.Code != http.StatusNoContent { - t.Errorf("Expected status code 204 for OPTIONS request, got %d", w.Code) + t.Error(fmt.Sprintf("Expected status code 204 for OPTIONS request, got %d", w.Code)) } }) } diff --git a/pkg/internal/dockerhub/download.go b/pkg/internal/dockerhub/download.go index 800b1c6bb..d48091d8c 100644 --- a/pkg/internal/dockerhub/download.go +++ b/pkg/internal/dockerhub/download.go @@ -5,7 +5,7 @@ import ( "encoding/base64" "errors" "fmt" - "log" + "log/slog" "os" "path/filepath" "strings" @@ -20,7 +20,6 @@ import ( "github.com/containerd/platforms" "github.com/docker/model-runner/pkg/internal/jsonutil" v1 "github.com/opencontainers/image-spec/specs-go/v1" - "github.com/sirupsen/logrus" ) func PullPlatform(ctx context.Context, image, destination, requiredOs, requiredArch string) error { @@ -52,7 +51,7 @@ func retry(ctx context.Context, attempts 
int, sleep time.Duration, f func() (*v1 var result *v1.Descriptor for i := 0; i < attempts; i++ { if i > 0 { - log.Printf("retry %d after error: %v\n", i, err) + slog.Info("retrying after error", "attempt", i, "error", err) select { case <-ctx.Done(): return nil, ctx.Err() @@ -100,7 +99,7 @@ func dockerCredentials(host string) (string, string, error) { if hubUsername != "" && hubPassword != "" { return hubUsername, hubPassword, nil } - logrus.WithField("host", host).Debug("checking for registry auth config") + slog.Debug("checking for registry auth config", "host", host) home, err := os.UserHomeDir() if err != nil { return "", "", err @@ -125,10 +124,10 @@ func dockerCredentials(host string) (string, string, error) { } parts := strings.SplitN(string(creds), ":", 2) if len(parts) != 2 { - logrus.Debugf("skipping not user/password auth for registry %s: %s", host, parts[0]) + slog.Debug("skipping non-user/password auth for registry", "host", host, "auth_type", parts[0]) return "", "", nil } - logrus.Debugf("using auth for registry %s: user=%s", host, parts[0]) + slog.Debug("using auth for registry", "host", host, "user", parts[0]) return parts[0], parts[1], nil } } diff --git a/pkg/logging/logging.go b/pkg/logging/logging.go index aacc31bac..399628ec6 100644 --- a/pkg/logging/logging.go +++ b/pkg/logging/logging.go @@ -1,14 +1,78 @@ package logging import ( + "bufio" + "context" "io" - - "github.com/sirupsen/logrus" + "log/slog" + "os" + "strings" ) -// Logger is a bridging interface between logrus and Docker Desktop's internal -// logging types. -type Logger interface { - logrus.FieldLogger - Writer() *io.PipeWriter +// Logger is the application logger type, backed by slog. +type Logger = *slog.Logger + +// ParseLevel parses a log level string into slog.Level. +// Supported values: debug, info, warn, error (case-insensitive). +// Defaults to info if the value is unrecognized. 
+func ParseLevel(s string) slog.Level { + switch strings.ToLower(strings.TrimSpace(s)) { + case "debug": + return slog.LevelDebug + case "info", "": + return slog.LevelInfo + case "warn", "warning": + return slog.LevelWarn + case "error": + return slog.LevelError + default: + return slog.LevelInfo + } +} + +// NewLogger creates a new slog.Logger with a text handler at the given level. +func NewLogger(level slog.Level) *slog.Logger { + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: level, + })) +} + +// slogWriter is an io.WriteCloser that forwards each line to a slog.Logger. +type slogWriter struct { + logger *slog.Logger + pr *io.PipeReader + pw *io.PipeWriter + done chan struct{} +} + +// NewWriter returns an io.WriteCloser that logs each line written to it +// using the provided slog.Logger at Info level. +func NewWriter(logger *slog.Logger) io.WriteCloser { + pr, pw := io.Pipe() + sw := &slogWriter{ + logger: logger, + pr: pr, + pw: pw, + done: make(chan struct{}), + } + go sw.scan() + return sw +} + +func (sw *slogWriter) scan() { + defer close(sw.done) + scanner := bufio.NewScanner(sw.pr) + for scanner.Scan() { + sw.logger.Log(context.Background(), slog.LevelInfo, scanner.Text()) + } +} + +func (sw *slogWriter) Write(p []byte) (int, error) { + return sw.pw.Write(p) +} + +func (sw *slogWriter) Close() error { + err := sw.pw.Close() + <-sw.done + return err } diff --git a/pkg/metrics/aggregated_handler.go b/pkg/metrics/aggregated_handler.go index 8f24b1960..06993d211 100644 --- a/pkg/metrics/aggregated_handler.go +++ b/pkg/metrics/aggregated_handler.go @@ -65,7 +65,7 @@ func (h *AggregatedMetricsHandler) collectAndAggregateMetrics(ctx context.Contex families, err := h.fetchRunnerMetrics(ctx, runner) if err != nil { - h.log.Warnf("Failed to fetch metrics from runner %s/%s: %v", runner.BackendName, runner.ModelName, err) + h.log.Warn(fmt.Sprintf("Failed to fetch metrics from runner %s/%s: %v", runner.BackendName, runner.ModelName, 
err)) return } @@ -165,7 +165,7 @@ func (h *AggregatedMetricsHandler) writeAggregatedMetrics(w http.ResponseWriter, encoder := expfmt.NewEncoder(w, expfmt.NewFormat(expfmt.TypeTextPlain)) for _, family := range families { if err := encoder.Encode(family); err != nil { - h.log.Errorf("Failed to encode metric family %s: %v", *family.Name, err) + h.log.Error(fmt.Sprintf("Failed to encode metric family %s: %v", *family.Name, err)) continue } } diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index f892cd278..05cceda4c 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -14,7 +14,6 @@ import ( "github.com/docker/model-runner/pkg/distribution/types" "github.com/docker/model-runner/pkg/internal/utils" "github.com/docker/model-runner/pkg/logging" - "github.com/sirupsen/logrus" ) type Tracker struct { @@ -49,13 +48,7 @@ func NewTracker(httpClient *http.Client, log logging.Logger, userAgent string, d userAgent = userAgent + " docker-model-runner" } - if os.Getenv("DEBUG") == "1" { - if logger, ok := log.(*logrus.Logger); ok { - logger.SetLevel(logrus.DebugLevel) - } else if entry, ok := log.(*logrus.Entry); ok { - entry.Logger.SetLevel(logrus.DebugLevel) - } - } + // Debug level is now configured via the LOG_LEVEL environment variable; DEBUG=1 is no longer honored return &Tracker{ doNotTrack: os.Getenv("DO_NOT_TRACK") == "1" || doNotTrack, @@ -75,7 +68,7 @@ func (t *Tracker) TrackModel(model types.Model, userAgent, action string) { func (t *Tracker) trackModel(model types.Model, userAgent, action string) { tags := model.Tags() - t.log.Debugln("Tracking model:", tags) + t.log.Debug("tracking model", "tags", tags) if len(tags) == 0 { return } @@ -90,14 +83,14 @@ func (t *Tracker) trackModel(model types.Model, userAgent, action string) { for _, tag := range tags { ref, err := reference.ParseReference(tag, registry.GetDefaultRegistryOptions()...)
if err != nil { - t.log.Errorf("Error parsing reference: %v\n", err) + t.log.Error("error parsing reference", "error", err) return } if err = t.headManifest(ref, ua); err != nil { - t.log.Debugf("Manifest does not exist or error occurred: %v\n", err) + t.log.Debug("manifest does not exist or error occurred", "error", err) continue } - t.log.Debugln("Tracked", utils.SanitizeForLog(ref.Name(), -1), utils.SanitizeForLog(ref.Identifier(), -1), "with user agent:", utils.SanitizeForLog(ua, -1)) + t.log.Debug("tracked", "name", utils.SanitizeForLog(ref.Name(), -1), "identifier", utils.SanitizeForLog(ref.Identifier(), -1), "userAgent", utils.SanitizeForLog(ua, -1)) } } diff --git a/pkg/metrics/openai_recorder.go b/pkg/metrics/openai_recorder.go index b1a31dee3..e2acb179f 100644 --- a/pkg/metrics/openai_recorder.go +++ b/pkg/metrics/openai_recorder.go @@ -205,7 +205,7 @@ func (r *OpenAIRecorder) truncateBase64Data(data string) string { func (r *OpenAIRecorder) SetConfigForModel(model string, config *inference.BackendConfiguration) { if config == nil { - r.log.Warnf("SetConfigForModel called with nil config for model %s", model) + r.log.Warn(fmt.Sprintf("SetConfigForModel called with nil config for model %s", model)) return } @@ -399,9 +399,9 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter return } } - r.log.Errorf("Matching request (id=%s) not found for model %s - %d\n%s", id, modelID, statusCode, response) + r.log.Error(fmt.Sprintf("Matching request (id=%s) not found for model %s - %d\n%s", id, modelID, statusCode, response)) } else { - r.log.Errorf("Model %s not found in records - %d\n%s", modelID, statusCode, response) + r.log.Error(fmt.Sprintf("Model %s not found in records - %d\n%s", modelID, statusCode, response)) } } @@ -717,7 +717,7 @@ func (r *OpenAIRecorder) handleStreamingRequests(w http.ResponseWriter, req *htt // Send heartbeat to establish connection. 
if _, err := fmt.Fprintf(w, "event: connected\ndata: {\"status\": \"connected\"}\n\n"); err != nil { - r.log.Errorf("Failed to write connected event to response: %v", err) + r.log.Error(fmt.Sprintf("Failed to write connected event to response: %v", err)) } flusher.Flush() @@ -738,17 +738,17 @@ func (r *OpenAIRecorder) handleStreamingRequests(w http.ResponseWriter, req *htt // Send as SSE event. jsonData, err := json.Marshal(modelRecords) if err != nil { - r.log.Errorf("Failed to marshal record for streaming: %v", err) + r.log.Error(fmt.Sprintf("Failed to marshal record for streaming: %v", err)) errorMsg := fmt.Sprintf(`{"error": "Failed to marshal record: %v"}`, err) if _, writeErr := fmt.Fprintf(w, "event: error\ndata: %s\n\n", errorMsg); writeErr != nil { - r.log.Errorf("Failed to write error event to response: %v", writeErr) + r.log.Error(fmt.Sprintf("Failed to write error event to response: %v", writeErr)) } flusher.Flush() continue } if _, err := fmt.Fprintf(w, "event: new_request\ndata: %s\n\n", jsonData); err != nil { - r.log.Errorf("Failed to write new_request event to response: %v", err) + r.log.Error(fmt.Sprintf("Failed to write new_request event to response: %v", err)) } flusher.Flush() @@ -841,14 +841,14 @@ func (r *OpenAIRecorder) sendExistingRecords(w http.ResponseWriter, model string }} jsonData, err := json.Marshal(singleRecord) if err != nil { - r.log.Errorf("Failed to marshal existing record for streaming: %v", err) + r.log.Error(fmt.Sprintf("Failed to marshal existing record for streaming: %v", err)) errorMsg := fmt.Sprintf(`{"error": "Failed to marshal existing record: %v"}`, err) if _, writeErr := fmt.Fprintf(w, "event: error\ndata: %s\n\n", errorMsg); writeErr != nil { - r.log.Errorf("Failed to write error event to response: %v", writeErr) + r.log.Error(fmt.Sprintf("Failed to write error event to response: %v", writeErr)) } } else { if _, writeErr := fmt.Fprintf(w, "event: existing_request\ndata: %s\n\n", jsonData); writeErr != nil { - 
r.log.Errorf("Failed to write existing_request event to response: %v", writeErr) + r.log.Error(fmt.Sprintf("Failed to write existing_request event to response: %v", writeErr)) } } } @@ -863,8 +863,8 @@ func (r *OpenAIRecorder) RemoveModel(model string) { if _, exists := r.records[modelID]; exists { delete(r.records, modelID) - r.log.Infof("Removed records for model: %s", modelID) + r.log.Info(fmt.Sprintf("Removed records for model: %s", modelID)) } else { - r.log.Warnf("No records found for model: %s", modelID) + r.log.Warn(fmt.Sprintf("No records found for model: %s", modelID)) } } diff --git a/pkg/metrics/openai_recorder_test.go b/pkg/metrics/openai_recorder_test.go index e7ff5f8f0..00b6d167c 100644 --- a/pkg/metrics/openai_recorder_test.go +++ b/pkg/metrics/openai_recorder_test.go @@ -1,16 +1,18 @@ package metrics import ( + "fmt" "encoding/json" "testing" + "log/slog" + "github.com/docker/model-runner/pkg/inference/models" - "github.com/sirupsen/logrus" ) func TestTruncateMediaFields(t *testing.T) { // Create a mock logger and model manager - logger := logrus.New() + logger := slog.Default() modelManager := &models.Manager{} recorder := NewOpenAIRecorder(logger, modelManager) @@ -140,18 +142,18 @@ func TestTruncateMediaFields(t *testing.T) { if inputErr != nil { // For invalid JSON inputs, verify it's returned unchanged if resultStr != tt.expected { - t.Errorf("Invalid JSON should be returned unchanged. Expected %q, got %q", tt.expected, resultStr) + t.Error(fmt.Sprintf("Invalid JSON should be returned unchanged. 
Expected %q, got %q", tt.expected, resultStr)) } } else { // For valid JSON inputs, verify output is still valid JSON var resultJSON interface{} if err := json.Unmarshal(result, &resultJSON); err != nil { - t.Errorf("Result should be valid JSON, but got error: %v", err) + t.Error(fmt.Sprintf("Result should be valid JSON, but got error: %v", err)) } // Also check the content matches expected if resultStr != tt.expected { - t.Errorf("Expected result %q, but got %q", tt.expected, resultStr) + t.Error(fmt.Sprintf("Expected result %q, but got %q", tt.expected, resultStr)) } } }) @@ -159,7 +161,7 @@ func TestTruncateMediaFields(t *testing.T) { } func TestTruncateBase64Data(t *testing.T) { - logger := logrus.New() + logger := slog.Default() modelManager := &models.Manager{} recorder := NewOpenAIRecorder(logger, modelManager) @@ -199,7 +201,7 @@ func TestTruncateBase64Data(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := recorder.truncateBase64Data(tt.input) if result != tt.expected { - t.Errorf("Expected %q, got %q", tt.expected, result) + t.Error(fmt.Sprintf("Expected %q, got %q", tt.expected, result)) } }) } diff --git a/pkg/metrics/scheduler_proxy.go b/pkg/metrics/scheduler_proxy.go index ca7183885..a52a730fe 100644 --- a/pkg/metrics/scheduler_proxy.go +++ b/pkg/metrics/scheduler_proxy.go @@ -1,6 +1,7 @@ package metrics import ( + "fmt" "io" "net" "net/http" @@ -40,7 +41,7 @@ func (h *SchedulerMetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Reque // Get the socket path for the active llama.cpp runner socket, err := h.scheduler.GetLlamaCppSocket() if err != nil { - h.log.Errorf("Failed to get llama.cpp socket: %v", err) + h.log.Error(fmt.Sprintf("Failed to get llama.cpp socket: %v", err)) http.Error(w, "Metrics endpoint not available", http.StatusServiceUnavailable) return } @@ -58,7 +59,7 @@ func (h *SchedulerMetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Reque // Create request to the backend metrics endpoint req, err := 
http.NewRequestWithContext(r.Context(), http.MethodGet, "http://unix/metrics", http.NoBody) if err != nil { - h.log.Errorf("Failed to create metrics request: %v", err) + h.log.Error(fmt.Sprintf("Failed to create metrics request: %v", err)) http.Error(w, "Failed to create metrics request", http.StatusInternalServerError) return } @@ -73,7 +74,7 @@ func (h *SchedulerMetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Reque // Make the request to the backend resp, err := client.Do(req) if err != nil { - h.log.Errorf("Failed to fetch metrics from backend: %v", err) + h.log.Error(fmt.Sprintf("Failed to fetch metrics from backend: %v", err)) http.Error(w, "Backend metrics unavailable", http.StatusServiceUnavailable) return } @@ -91,9 +92,9 @@ func (h *SchedulerMetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Reque // Copy response body if _, err := io.Copy(w, resp.Body); err != nil { - h.log.Errorf("Failed to copy metrics response: %v", err) + h.log.Error(fmt.Sprintf("Failed to copy metrics response: %v", err)) return } - h.log.Debugf("Successfully proxied metrics request") + h.log.Debug("Successfully proxied metrics request") } diff --git a/pkg/ollama/http_handler.go b/pkg/ollama/http_handler.go index 445996f40..2f324ea0b 100644 --- a/pkg/ollama/http_handler.go +++ b/pkg/ollama/http_handler.go @@ -62,7 +62,7 @@ func NewHTTPHandler(log logging.Logger, scheduler *scheduling.Scheduler, schedul func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { safeMethod := utils.SanitizeForLog(r.Method, -1) safePath := utils.SanitizeForLog(r.URL.Path, -1) - h.log.Infof("Ollama API request: %s %s", safeMethod, safePath) + h.log.Info(fmt.Sprintf("Ollama API request: %s %s", safeMethod, safePath)) h.httpHandler.ServeHTTP(w, r) } @@ -145,14 +145,14 @@ func (w *ollamaProgressWriter) Write(p []byte) (n int, err error) { return w.writer.Write(p) } // Unrecognized type, pass through to avoid losing information - w.log.Warnf("Unknown progress message type: %s", 
msg.Type) + w.log.Warn(fmt.Sprintf("Unknown progress message type: %s", msg.Type)) return w.writer.Write(p) } // Marshal and write ollama format data, err := json.Marshal(ollamaMsg) if err != nil { - w.log.Warnf("Failed to marshal ollama progress: %v", err) + w.log.Warn(fmt.Sprintf("Failed to marshal ollama progress: %v", err)) return w.writer.Write(p) } @@ -187,7 +187,7 @@ func (h *HTTPHandler) handleVersion(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Errorf("Failed to encode response: %v", err) + h.log.Error(fmt.Sprintf("Failed to encode response: %v", err)) } } @@ -196,7 +196,7 @@ func (h *HTTPHandler) handleListModels(w http.ResponseWriter, r *http.Request) { // Get models from the model manager modelsList, err := h.modelManager.List() if err != nil { - h.log.Errorf("Failed to list models: %v", err) + h.log.Error(fmt.Sprintf("Failed to list models: %v", err)) http.Error(w, "Failed to list models", http.StatusInternalServerError) return } @@ -243,7 +243,7 @@ func (h *HTTPHandler) handleListModels(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Errorf("Failed to encode response: %v", err) + h.log.Error(fmt.Sprintf("Failed to encode response: %v", err)) } } @@ -260,7 +260,7 @@ func (h *HTTPHandler) handlePS(w http.ResponseWriter, r *http.Request) { // Get model details to populate additional fields model, err := h.modelManager.GetLocal(backend.ModelName) if err != nil { - h.log.Warnf("Failed to get model details for %s: %v", backend.ModelName, err) + h.log.Warn(fmt.Sprintf("Failed to get model details for %s: %v", backend.ModelName, err)) // Still add the model with basic info models = append(models, PSModel{ Name: backend.ModelName, @@ -303,7 +303,7 @@ func (h *HTTPHandler) handlePS(w http.ResponseWriter, r *http.Request) { 
w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Errorf("Failed to encode response: %v", err) + h.log.Error(fmt.Sprintf("Failed to encode response: %v", err)) } } @@ -324,7 +324,7 @@ func (h *HTTPHandler) handleShowModel(w http.ResponseWriter, r *http.Request) { // Get model details model, err := h.modelManager.GetLocal(modelName) if err != nil { - h.log.Errorf("Failed to get model: %v", err) + h.log.Error(fmt.Sprintf("Failed to get model: %v", err)) http.Error(w, fmt.Sprintf("Model not found: %v", err), http.StatusNotFound) return } @@ -332,7 +332,7 @@ func (h *HTTPHandler) handleShowModel(w http.ResponseWriter, r *http.Request) { // Get config config, err := model.Config() if err != nil { - h.log.Errorf("Failed to get model config: %v", err) + h.log.Error(fmt.Sprintf("Failed to get model config: %v", err)) http.Error(w, fmt.Sprintf("Failed to get model config: %v", err), http.StatusInternalServerError) return } @@ -350,7 +350,7 @@ func (h *HTTPHandler) handleShowModel(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Errorf("Failed to encode response: %v", err) + h.log.Error(fmt.Sprintf("Failed to encode response: %v", err)) } } @@ -417,7 +417,7 @@ func (h *HTTPHandler) configureModel(ctx context.Context, modelName string, opti if hasContextSize || reasoningBudget != nil || hasKeepAlive { sanitizedModelName := utils.SanitizeForLog(modelName, -1) - h.log.Infof("configureModel: configuring model %s", sanitizedModelName) + h.log.Info(fmt.Sprintf("configureModel: configuring model %s", sanitizedModelName)) configureRequest := scheduling.ConfigureRequest{ Model: modelName, } @@ -434,12 +434,12 @@ func (h *HTTPHandler) configureModel(ctx context.Context, modelName string, opti if err == nil { configureRequest.KeepAlive = &ka } else { - h.log.Warnf("configureModel: invalid keep_alive %q: %v", 
keepAlive, err) + h.log.Warn(fmt.Sprintf("configureModel: invalid keep_alive %q: %v", keepAlive, err)) } } _, err := h.scheduler.ConfigureRunner(ctx, nil, configureRequest, userAgent) if err != nil { - h.log.Warnf("configureModel: failed to configure model %s: %v", sanitizedModelName, err) + h.log.Warn(fmt.Sprintf("configureModel: failed to configure model %s: %v", sanitizedModelName, err)) } } } @@ -456,7 +456,7 @@ func (h *HTTPHandler) handleGenerate(w http.ResponseWriter, r *http.Request) { var req GenerateRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - h.log.Errorf("handleGenerate: failed to decode request: %v", err) + h.log.Error(fmt.Sprintf("handleGenerate: failed to decode request: %v", err)) http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) return } @@ -506,7 +506,7 @@ func (h *HTTPHandler) handleGenerate(w http.ResponseWriter, r *http.Request) { func (h *HTTPHandler) unloadModel(ctx context.Context, w http.ResponseWriter, modelName string) { // Sanitize user input before logging to prevent log injection sanitizedModelName := utils.SanitizeForLog(modelName, -1) - h.log.Infof("unloadModel: unloading model %s", sanitizedModelName) + h.log.Info(fmt.Sprintf("unloadModel: unloading model %s", sanitizedModelName)) // Create an unload request for the scheduler unloadReq := map[string]interface{}{ @@ -516,19 +516,19 @@ func (h *HTTPHandler) unloadModel(ctx context.Context, w http.ResponseWriter, mo // Marshal the unload request reqBody, err := json.Marshal(unloadReq) if err != nil { - h.log.Errorf("unloadModel: failed to marshal request: %v", err) + h.log.Error(fmt.Sprintf("unloadModel: failed to marshal request: %v", err)) http.Error(w, fmt.Sprintf("Failed to marshal request: %v", err), http.StatusInternalServerError) return } // Sanitize the user-provided request body before logging to avoid log injection safeReqBody := utils.SanitizeForLog(string(reqBody), -1) - h.log.Infof("unloadModel: sending POST 
/engines/unload with body: %s", safeReqBody) + h.log.Info(fmt.Sprintf("unloadModel: sending POST /engines/unload with body: %s", safeReqBody)) // Create a new request to the scheduler newReq, err := http.NewRequestWithContext(ctx, http.MethodPost, "/engines/unload", strings.NewReader(string(reqBody))) if err != nil { - h.log.Errorf("unloadModel: failed to create request: %v", err) + h.log.Error(fmt.Sprintf("unloadModel: failed to create request: %v", err)) http.Error(w, fmt.Sprintf("Failed to create request: %v", err), http.StatusInternalServerError) return } @@ -544,7 +544,7 @@ func (h *HTTPHandler) unloadModel(ctx context.Context, w http.ResponseWriter, mo // Forward to scheduler HTTP handler h.schedulerHTTP.ServeHTTP(respRecorder, newReq) - h.log.Infof("unloadModel: scheduler response status=%d, body=%s", respRecorder.statusCode, respRecorder.body.String()) + h.log.Info(fmt.Sprintf("unloadModel: scheduler response status=%d, body=%s", respRecorder.statusCode, respRecorder.body.String())) // Return the response status w.WriteHeader(respRecorder.statusCode) @@ -574,7 +574,7 @@ func (h *HTTPHandler) handleDelete(w http.ResponseWriter, r *http.Request) { } sanitizedModelName := utils.SanitizeForLog(modelName, -1) - h.log.Infof("handleDelete: deleting model %s", sanitizedModelName) + h.log.Info(fmt.Sprintf("handleDelete: deleting model %s", sanitizedModelName)) // First, unload the model from memory unloadReq := map[string]interface{}{ @@ -583,14 +583,14 @@ func (h *HTTPHandler) handleDelete(w http.ResponseWriter, r *http.Request) { reqBody, err := json.Marshal(unloadReq) if err != nil { - h.log.Errorf("handleDelete: failed to marshal unload request: %v", err) + h.log.Error(fmt.Sprintf("handleDelete: failed to marshal unload request: %v", err)) http.Error(w, fmt.Sprintf("Failed to marshal request: %v", err), http.StatusInternalServerError) return } newReq, err := http.NewRequestWithContext(ctx, http.MethodPost, "/engines/unload", strings.NewReader(string(reqBody))) 
if err != nil { - h.log.Errorf("handleDelete: failed to create unload request: %v", err) + h.log.Error(fmt.Sprintf("handleDelete: failed to create unload request: %v", err)) http.Error(w, fmt.Sprintf("Failed to create request: %v", err), http.StatusInternalServerError) return } @@ -603,17 +603,12 @@ func (h *HTTPHandler) handleDelete(w http.ResponseWriter, r *http.Request) { } h.schedulerHTTP.ServeHTTP(respRecorder, newReq) - h.log.Infof("handleDelete: unload response status=%d", respRecorder.statusCode) + h.log.Info(fmt.Sprintf("handleDelete: unload response status=%d", respRecorder.statusCode)) // Check if unload succeeded before deleting from storage if respRecorder.statusCode < 200 || respRecorder.statusCode >= 300 { sanitizedBody := utils.SanitizeForLog(respRecorder.body.String(), -1) - h.log.Errorf( - "handleDelete: unload failed for model %s with status=%d, body=%q", - sanitizedModelName, - respRecorder.statusCode, - sanitizedBody, - ) + h.log.Error(fmt.Sprintf("handleDelete: unload failed for model %s with status=%d, body=%q", sanitizedModelName, respRecorder.statusCode, sanitizedBody)) http.Error( w, fmt.Sprintf("Failed to unload model: scheduler returned status %d", respRecorder.statusCode), @@ -625,12 +620,12 @@ func (h *HTTPHandler) handleDelete(w http.ResponseWriter, r *http.Request) { // Then delete the model from storage if _, err := h.modelManager.Delete(modelName, false); err != nil { sanitizedErr := utils.SanitizeForLog(err.Error(), -1) - h.log.Errorf("handleDelete: failed to delete model %s: %v", sanitizedModelName, sanitizedErr) + h.log.Error(fmt.Sprintf("handleDelete: failed to delete model %s: %v", sanitizedModelName, sanitizedErr)) http.Error(w, fmt.Sprintf("Failed to delete model: %v", sanitizedErr), http.StatusInternalServerError) return } - h.log.Infof("handleDelete: successfully deleted model %s", sanitizedModelName) + h.log.Info(fmt.Sprintf("handleDelete: successfully deleted model %s", sanitizedModelName)) // Return success response in 
Ollama format (empty JSON object) w.Header().Set("Content-Type", "application/json") @@ -664,7 +659,7 @@ func (h *HTTPHandler) handlePull(w http.ResponseWriter, r *http.Request) { // Call the model manager's Pull method with the wrapped writer if err := h.modelManager.Pull(modelName, "", r, ollamaWriter); err != nil { - h.log.Errorf("Failed to pull model: %s", utils.SanitizeForLog(err.Error(), -1)) + h.log.Error(fmt.Sprintf("Failed to pull model: %s", utils.SanitizeForLog(err.Error(), -1))) // Send error in Ollama JSON format errorResponse := ollamaPullStatus{ @@ -676,7 +671,7 @@ func (h *HTTPHandler) handlePull(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusInternalServerError) if err := json.NewEncoder(w).Encode(errorResponse); err != nil { - h.log.Errorf("failed to encode response: %v", err) + h.log.Error(fmt.Sprintf("failed to encode response: %v", err)) } } else { // Headers already sent - write error as JSON line @@ -1053,7 +1048,7 @@ func (s *streamingChatResponseWriter) Write(data []byte) (int, error) { // Parse OpenAI chunk using proper struct var chunk openAIChatStreamChunk if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil { - s.log.Warnf("Failed to parse OpenAI chat stream chunk: %v", err) + s.log.Warn(fmt.Sprintf("Failed to parse OpenAI chat stream chunk: %v", err)) continue } @@ -1175,7 +1170,7 @@ func (s *streamingGenerateResponseWriter) Write(data []byte) (int, error) { // Parse OpenAI chunk using proper struct var chunk openAIChatStreamChunk if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil { - s.log.Warnf("Failed to parse OpenAI chat stream chunk: %v", err) + s.log.Warn(fmt.Sprintf("Failed to parse OpenAI chat stream chunk: %v", err)) continue } @@ -1222,7 +1217,7 @@ func (h *HTTPHandler) convertChatResponse(w http.ResponseWriter, respRecorder *r // Convert to Ollama error format (simple string) w.Header().Set("Content-Type", "application/json") if err := 
json.NewEncoder(w).Encode(map[string]string{"error": openAIErr.Error.Message}); err != nil { - h.log.Errorf("failed to encode response: %v", err) + h.log.Error(fmt.Sprintf("failed to encode response: %v", err)) } } else { // Fallback: return raw error body @@ -1234,7 +1229,7 @@ func (h *HTTPHandler) convertChatResponse(w http.ResponseWriter, respRecorder *r // Parse OpenAI response using proper struct var openAIResp openAIChatResponse if err := json.Unmarshal([]byte(respRecorder.body.String()), &openAIResp); err != nil { - h.log.Errorf("Failed to parse OpenAI response: %v", err) + h.log.Error(fmt.Sprintf("Failed to parse OpenAI response: %v", err)) http.Error(w, "Failed to parse response", http.StatusInternalServerError) return } @@ -1264,7 +1259,7 @@ func (h *HTTPHandler) convertChatResponse(w http.ResponseWriter, respRecorder *r w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Errorf("Failed to encode response: %v", err) + h.log.Error(fmt.Sprintf("Failed to encode response: %v", err)) } } @@ -1309,7 +1304,7 @@ func (h *HTTPHandler) convertGenerateResponse(w http.ResponseWriter, respRecorde // Convert to Ollama error format (simple string) w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(map[string]string{"error": openAIErr.Error.Message}); err != nil { - h.log.Errorf("failed to encode response: %v", err) + h.log.Error(fmt.Sprintf("failed to encode response: %v", err)) } } else { // Fallback: return raw error body @@ -1321,7 +1316,7 @@ func (h *HTTPHandler) convertGenerateResponse(w http.ResponseWriter, respRecorde // Parse OpenAI chat response (since we're now using chat completions endpoint) var openAIResp openAIChatResponse if err := json.Unmarshal([]byte(respRecorder.body.String()), &openAIResp); err != nil { - h.log.Errorf("Failed to parse OpenAI chat response: %v", err) + h.log.Error(fmt.Sprintf("Failed to parse OpenAI chat response: %v", err)) 
http.Error(w, "Failed to parse response", http.StatusInternalServerError) return } @@ -1345,6 +1340,6 @@ func (h *HTTPHandler) convertGenerateResponse(w http.ResponseWriter, respRecorde w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Errorf("Failed to encode response: %v", err) + h.log.Error(fmt.Sprintf("Failed to encode response: %v", err)) } } diff --git a/pkg/responses/handler.go b/pkg/responses/handler.go index 125efe6c4..bb4c50cd8 100644 --- a/pkg/responses/handler.go +++ b/pkg/responses/handler.go @@ -55,7 +55,7 @@ func NewHTTPHandler(log logging.Logger, schedulerHTTP http.Handler, allowedOrigi func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { cleanPath := strings.ReplaceAll(r.URL.Path, "\n", "") cleanPath = strings.ReplaceAll(cleanPath, "\r", "") - h.log.Infof("Responses API request: %s %s", r.Method, cleanPath) + h.log.Info(fmt.Sprintf("Responses API request: %s %s", r.Method, cleanPath)) h.httpHandler.ServeHTTP(w, r) } @@ -305,7 +305,7 @@ func (h *HTTPHandler) sendJSON(w http.ResponseWriter, statusCode int, data inter w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) if err := json.NewEncoder(w).Encode(data); err != nil { - h.log.Errorf("Failed to encode JSON response: %v", err) + h.log.Error(fmt.Sprintf("Failed to encode JSON response: %v", err)) } } diff --git a/pkg/responses/handler_test.go b/pkg/responses/handler_test.go index 0968b3ac3..7e586f39d 100644 --- a/pkg/responses/handler_test.go +++ b/pkg/responses/handler_test.go @@ -1,6 +1,8 @@ package responses import ( + "fmt" + "log/slog" "bytes" "encoding/json" "io" @@ -9,7 +11,6 @@ import ( "strings" "testing" - "github.com/sirupsen/logrus" ) // mockSchedulerHTTP is a mock scheduler that returns predefined responses. 
@@ -39,8 +40,8 @@ func (m *mockSchedulerHTTP) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func newTestHandler(mock *mockSchedulerHTTP) *HTTPHandler { - log := logrus.New() - log.SetOutput(io.Discard) + log := slog.Default() + // log output is controlled by the slog handler level return NewHTTPHandler(log, mock, nil) } @@ -85,28 +86,28 @@ func TestHandler_CreateResponse_NonStreaming(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusOK)) } var result Response if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - t.Fatalf("failed to decode response: %v", err) + t.Error(fmt.Sprintf("failed to decode response: %v", err)) } if result.Object != "response" { - t.Errorf("object = %s, want response", result.Object) + t.Error(fmt.Sprintf("object = %s, want response", result.Object)) } if result.Model != "gpt-4" { - t.Errorf("model = %s, want gpt-4", result.Model) + t.Error(fmt.Sprintf("model = %s, want gpt-4", result.Model)) } if result.Status != StatusCompleted { - t.Errorf("status = %s, want %s", result.Status, StatusCompleted) + t.Error(fmt.Sprintf("status = %s, want %s", result.Status, StatusCompleted)) } if result.OutputText != "Hello! How can I help you?" { - t.Errorf("output_text = %s, want Hello! How can I help you?", result.OutputText) + t.Error(fmt.Sprintf("output_text = %s, want Hello! 
How can I help you?", result.OutputText)) } if !strings.HasPrefix(result.ID, "resp_") { - t.Errorf("id should start with resp_, got %s", result.ID) + t.Error(fmt.Sprintf("id should start with resp_, got %s", result.ID)) } } @@ -124,7 +125,7 @@ func TestHandler_CreateResponse_MissingModel(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusBadRequest) + t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusBadRequest)) } var errResp map[string]interface{} @@ -147,7 +148,7 @@ func TestHandler_CreateResponse_InvalidJSON(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusBadRequest) + t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusBadRequest)) } } @@ -170,19 +171,19 @@ func TestHandler_GetResponse(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusOK)) } var result Response if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - t.Fatalf("failed to decode response: %v", err) + t.Error(fmt.Sprintf("failed to decode response: %v", err)) } if result.ID != "resp_test123" { - t.Errorf("id = %s, want resp_test123", result.ID) + t.Error(fmt.Sprintf("id = %s, want resp_test123", result.ID)) } if result.OutputText != "Test output" { - t.Errorf("output_text = %s, want Test output", result.OutputText) + t.Error(fmt.Sprintf("output_text = %s, want Test output", result.OutputText)) } } @@ -198,7 +199,7 @@ func TestHandler_GetResponse_NotFound(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusNotFound { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) + t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusNotFound)) } } @@ 
-219,7 +220,7 @@ func TestHandler_DeleteResponse(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusOK)) } // Verify it's deleted @@ -241,7 +242,7 @@ func TestHandler_DeleteResponse_NotFound(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusNotFound { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) + t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusNotFound)) } } @@ -299,14 +300,14 @@ func TestHandler_CreateResponse_WithPreviousResponse(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - t.Fatalf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body) + t.Error(fmt.Sprintf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body)) } var result Response json.NewDecoder(resp.Body).Decode(&result) if result.PreviousResponseID == nil || *result.PreviousResponseID != "resp_prev123" { - t.Errorf("previous_response_id = %v, want resp_prev123", result.PreviousResponseID) + t.Error(fmt.Sprintf("previous_response_id = %v, want resp_prev123", result.PreviousResponseID)) } } @@ -337,14 +338,14 @@ func TestHandler_CreateResponse_UpstreamError(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusInternalServerError { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusInternalServerError) + t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusInternalServerError)) } var result Response json.NewDecoder(resp.Body).Decode(&result) if result.Status != StatusFailed { - t.Errorf("status = %s, want %s", result.Status, StatusFailed) + t.Error(fmt.Sprintf("status = %s, want %s", result.Status, StatusFailed)) } if result.Error == nil { t.Error("expected error to be set") @@ -373,7 +374,7 @@ func 
TestHandler_CreateResponse_UpstreamError_NonJSONBody(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusInternalServerError { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusInternalServerError) + t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusInternalServerError)) } var result Response @@ -381,19 +382,19 @@ func TestHandler_CreateResponse_UpstreamError_NonJSONBody(t *testing.T) { // Assert: non-streaming error handling falls back correctly if result.Status != StatusFailed { - t.Errorf("status = %s, want %s", result.Status, StatusFailed) + t.Error(fmt.Sprintf("status = %s, want %s", result.Status, StatusFailed)) } if result.Error == nil { - t.Fatalf("expected error, got nil") + t.Error("expected error, got nil") } if result.Error.Code != "upstream_error" { - t.Errorf("error.code = %v, want upstream_error", result.Error.Code) + t.Error(fmt.Sprintf("error.code = %v, want upstream_error", result.Error.Code)) } if !strings.Contains(result.Error.Message, "upstream exploded in a non-json way") { - t.Errorf("error.message = %q, want to contain raw upstream body", result.Error.Message) + t.Error(fmt.Sprintf("error.message = %q, want to contain raw upstream body", result.Error.Message)) } } @@ -426,19 +427,19 @@ func TestHandler_CreateResponse_Streaming(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusOK)) } // Check content type is SSE contentType := resp.Header.Get("Content-Type") if !strings.Contains(contentType, "text/event-stream") { - t.Errorf("Content-Type = %s, want text/event-stream", contentType) + t.Error(fmt.Sprintf("Content-Type = %s, want text/event-stream", contentType)) } // Read all body body, err := io.ReadAll(resp.Body) if err != nil { - t.Fatalf("failed to read body: %v", err) + t.Error(fmt.Sprintf("failed to read body: %v", err)) } 
bodyStr := string(body) @@ -517,7 +518,7 @@ func TestHandler_CreateResponse_WithTools(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - t.Fatalf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body) + t.Error(fmt.Sprintf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body)) } var result Response @@ -541,10 +542,10 @@ func TestHandler_CreateResponse_WithTools(t *testing.T) { } if funcCall.Name != "get_weather" { - t.Errorf("function name = %s, want get_weather", funcCall.Name) + t.Error(fmt.Sprintf("function name = %s, want get_weather", funcCall.Name)) } if funcCall.CallID != "call_abc123" { - t.Errorf("call_id = %s, want call_abc123", funcCall.CallID) + t.Error(fmt.Sprintf("call_id = %s, want call_abc123", funcCall.CallID)) } } @@ -589,7 +590,7 @@ func TestHandler_ResponsePersistence(t *testing.T) { json.NewDecoder(w2.Result().Body).Decode(&getResult) if getResult.ID != createResult.ID { - t.Errorf("IDs don't match: %s vs %s", getResult.ID, createResult.ID) + t.Error(fmt.Sprintf("IDs don't match: %s vs %s", getResult.ID, createResult.ID)) } } @@ -623,20 +624,20 @@ func TestHandler_CreateResponse_Streaming_Persistence(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusOK)) } // Verify that the StreamingResponseWriter persisted a coherent Response in the store memStore := handler.store if memStore.Count() != 1 { - t.Fatalf("expected exactly one response in store, got %d", memStore.Count()) + t.Error(fmt.Sprintf("expected exactly one response in store, got %d", memStore.Count())) } // Get the response ID from the store responseIDs := memStore.GetResponseIDs() if len(responseIDs) != 1 { - t.Fatalf("expected exactly one response ID in store, got %d", len(responseIDs)) + t.Error(fmt.Sprintf("expected exactly one 
response ID in store, got %d", len(responseIDs))) } // Retrieve the response using the public API @@ -647,12 +648,12 @@ func TestHandler_CreateResponse_Streaming_Persistence(t *testing.T) { // Status should be completed after streaming finishes if persistedResp.Status != StatusCompleted { - t.Errorf("persisted response status = %s, want %s", persistedResp.Status, StatusCompleted) + t.Error(fmt.Sprintf("persisted response status = %s, want %s", persistedResp.Status, StatusCompleted)) } // OutputText should match concatenated streamed chunks: "Hello" + "!" => "Hello!" if persistedResp.OutputText != "Hello!" { - t.Errorf("persisted response OutputText = %q, want %q", persistedResp.OutputText, "Hello!") + t.Error(fmt.Sprintf("persisted response OutputText = %q, want %q", persistedResp.OutputText, "Hello!")) } // There should be at least one OutputItem whose message content matches "Hello!" @@ -673,7 +674,7 @@ func TestHandler_CreateResponse_Streaming_Persistence(t *testing.T) { } } if !found { - t.Errorf("expected an OutputItem message with text %q in persisted response", "Hello!") + t.Error(fmt.Sprintf("expected an OutputItem message with text %q in persisted response", "Hello!")) } } diff --git a/vllm_backend.go b/vllm_backend.go index 5697cb77c..404233582 100644 --- a/vllm_backend.go +++ b/vllm_backend.go @@ -3,17 +3,18 @@ package main import ( + "log/slog" + "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/backends/vllm" "github.com/docker/model-runner/pkg/inference/models" - "github.com/sirupsen/logrus" ) -func initVLLMBackend(log *logrus.Logger, modelManager *models.Manager, customBinaryPath string) (inference.Backend, error) { +func initVLLMBackend(log *slog.Logger, modelManager *models.Manager, customBinaryPath string) (inference.Backend, error) { return vllm.New( log, modelManager, - log.WithFields(logrus.Fields{"component": vllm.Name}), + log.With("component", vllm.Name), nil, customBinaryPath, ) diff --git 
a/vllm_backend_stub.go b/vllm_backend_stub.go index 64937cc16..86bafd230 100644 --- a/vllm_backend_stub.go +++ b/vllm_backend_stub.go @@ -3,12 +3,13 @@ package main import ( + "log/slog" + "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/models" - "github.com/sirupsen/logrus" ) -func initVLLMBackend(log *logrus.Logger, modelManager *models.Manager, customBinaryPath string) (inference.Backend, error) { +func initVLLMBackend(log *slog.Logger, modelManager *models.Manager, customBinaryPath string) (inference.Backend, error) { return nil, nil } From bdb90f5ff9da2cbd79f545256c13dc2af64534e4 Mon Sep 17 00:00:00 2001 From: Varun Chawla Date: Fri, 13 Feb 2026 19:42:05 -0800 Subject: [PATCH 2/4] Convert all fmt.Sprintf log calls to slog structured key-value pairs Address review feedback: replace fmt.Sprintf anti-pattern across the entire codebase with proper slog structured key-value args so log aggregation tools can parse/filter/query individual fields. Also fixes: - t.Error(fmt.Sprintf(...)) -> t.Errorf(...) 
(staticcheck S1038) - Import ordering (gci: standard first, then third-party) - Race condition in testregistry (concurrent map access in handleBlobUpload) - Removed unused fmt imports --- main.go | 4 +- pkg/anthropic/handler.go | 5 +- pkg/anthropic/handler_test.go | 36 ++- pkg/distribution/distribution/client.go | 50 +-- pkg/distribution/distribution/client_test.go | 288 +++++++++--------- .../distribution/normalize_test.go | 32 +- .../registry/testregistry/registry.go | 2 + pkg/inference/backends/diffusers/diffusers.go | 6 +- pkg/inference/backends/llamacpp/download.go | 16 +- pkg/inference/backends/llamacpp/llamacpp.go | 8 +- pkg/inference/backends/mlx/mlx.go | 2 +- pkg/inference/backends/sglang/sglang.go | 2 +- pkg/inference/backends/vllm/vllm.go | 2 +- pkg/inference/backends/vllmmetal/vllmmetal.go | 8 +- pkg/inference/models/handler_test.go | 48 ++- pkg/inference/models/http_handler.go | 21 +- pkg/inference/models/manager.go | 14 +- pkg/inference/scheduling/http_handler.go | 6 +- pkg/inference/scheduling/installer.go | 3 +- pkg/inference/scheduling/loader.go | 30 +- pkg/inference/scheduling/loader_test.go | 26 +- pkg/inference/scheduling/runner.go | 8 +- pkg/inference/scheduling/scheduler.go | 12 +- pkg/metrics/aggregated_handler.go | 4 +- pkg/metrics/openai_recorder.go | 24 +- pkg/metrics/openai_recorder_test.go | 12 +- pkg/metrics/scheduler_proxy.go | 9 +- pkg/ollama/http_handler.go | 74 ++--- pkg/responses/handler.go | 4 +- pkg/responses/handler_test.go | 78 +++-- 30 files changed, 410 insertions(+), 424 deletions(-) diff --git a/main.go b/main.go index dc4488e0b..546d44ea2 100644 --- a/main.go +++ b/main.go @@ -4,9 +4,9 @@ import ( "fmt" "context" "crypto/tls" + "log/slog" "net" "net/http" - "log/slog" "os" "os/signal" "path/filepath" @@ -15,7 +15,6 @@ import ( "time" "github.com/docker/model-runner/pkg/anthropic" - "github.com/docker/model-runner/pkg/logging" "github.com/docker/model-runner/pkg/inference" 
"github.com/docker/model-runner/pkg/inference/backends/diffusers" "github.com/docker/model-runner/pkg/inference/backends/llamacpp" @@ -27,6 +26,7 @@ import ( "github.com/docker/model-runner/pkg/inference/models" "github.com/docker/model-runner/pkg/inference/platform" "github.com/docker/model-runner/pkg/inference/scheduling" + "github.com/docker/model-runner/pkg/logging" "github.com/docker/model-runner/pkg/metrics" "github.com/docker/model-runner/pkg/middleware" "github.com/docker/model-runner/pkg/ollama" diff --git a/pkg/anthropic/handler.go b/pkg/anthropic/handler.go index a6b4262a2..323bec76e 100644 --- a/pkg/anthropic/handler.go +++ b/pkg/anthropic/handler.go @@ -1,7 +1,6 @@ package anthropic import ( - "fmt" "bytes" "encoding/json" "errors" @@ -60,7 +59,7 @@ func NewHandler(log logging.Logger, schedulerHTTP *scheduling.HTTPHandler, allow func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { safeMethod := utils.SanitizeForLog(r.Method, -1) safePath := utils.SanitizeForLog(r.URL.Path, -1) - h.log.Info(fmt.Sprintf("Anthropic API request: %s %s", safeMethod, safePath)) + h.log.Info("Anthropic API request", "method", safeMethod, "path", safePath) h.httpHandler.ServeHTTP(w, r) } @@ -170,6 +169,6 @@ func (h *Handler) writeAnthropicError(w http.ResponseWriter, statusCode int, err } if err := json.NewEncoder(w).Encode(errResp); err != nil { - h.log.Error(fmt.Sprintf("Failed to encode error response: %v", err)) + h.log.Error("Failed to encode error response", "error", err) } } diff --git a/pkg/anthropic/handler_test.go b/pkg/anthropic/handler_test.go index 14cbdb17b..39a575142 100644 --- a/pkg/anthropic/handler_test.go +++ b/pkg/anthropic/handler_test.go @@ -1,13 +1,11 @@ package anthropic import ( - "fmt" "log/slog" "net/http" "net/http/httptest" "strings" "testing" - ) func TestWriteAnthropicError(t *testing.T) { @@ -54,16 +52,16 @@ func TestWriteAnthropicError(t *testing.T) { h.writeAnthropicError(rec, tt.statusCode, tt.errorType, tt.message) if 
rec.Code != tt.statusCode { - t.Error(fmt.Sprintf("expected status %d, got %d", tt.statusCode, rec.Code)) + t.Errorf("expected status %d, got %d", tt.statusCode, rec.Code) } if contentType := rec.Header().Get("Content-Type"); contentType != "application/json" { - t.Error(fmt.Sprintf("expected Content-Type application/json, got %s", contentType)) + t.Errorf("expected Content-Type application/json, got %s", contentType) } body := strings.TrimSpace(rec.Body.String()) if body != tt.wantBody { - t.Error(fmt.Sprintf("expected body %s, got %s", tt.wantBody, body)) + t.Errorf("expected body %s, got %s", tt.wantBody, body) } }) } @@ -85,12 +83,12 @@ func TestRouteHandlers(t *testing.T) { for _, route := range expectedRoutes { if _, exists := routes[route]; !exists { - t.Error(fmt.Sprintf("expected route %s to be registered", route)) + t.Errorf("expected route %s to be registered", route) } } if len(routes) != len(expectedRoutes) { - t.Error(fmt.Sprintf("expected %d routes, got %d", len(expectedRoutes), len(routes))) + t.Errorf("expected %d routes, got %d", len(expectedRoutes), len(routes)) } } @@ -98,7 +96,7 @@ func TestAPIPrefix(t *testing.T) { t.Parallel() if APIPrefix != "/anthropic" { - t.Error(fmt.Sprintf("expected APIPrefix to be /anthropic, got %s", APIPrefix)) + t.Errorf("expected APIPrefix to be /anthropic, got %s", APIPrefix) } } @@ -115,15 +113,15 @@ func TestProxyToBackend_InvalidJSON(t *testing.T) { h.proxyToBackend(rec, req, "/v1/messages") if rec.Code != http.StatusBadRequest { - t.Error(fmt.Sprintf("expected status %d, got %d", http.StatusBadRequest, rec.Code)) + t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code) } body := rec.Body.String() if !strings.Contains(body, "invalid_request_error") { - t.Error(fmt.Sprintf("expected body to contain 'invalid_request_error', got %s", body)) + t.Errorf("expected body to contain 'invalid_request_error', got %s", body) } if !strings.Contains(body, "Invalid JSON") { - t.Error(fmt.Sprintf("expected 
body to contain 'Invalid JSON', got %s", body)) + t.Errorf("expected body to contain 'Invalid JSON', got %s", body) } } @@ -140,15 +138,15 @@ func TestProxyToBackend_MissingModel(t *testing.T) { h.proxyToBackend(rec, req, "/v1/messages") if rec.Code != http.StatusBadRequest { - t.Error(fmt.Sprintf("expected status %d, got %d", http.StatusBadRequest, rec.Code)) + t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code) } body := rec.Body.String() if !strings.Contains(body, "invalid_request_error") { - t.Error(fmt.Sprintf("expected body to contain 'invalid_request_error', got %s", body)) + t.Errorf("expected body to contain 'invalid_request_error', got %s", body) } if !strings.Contains(body, "Missing required field: model") { - t.Error(fmt.Sprintf("expected body to contain 'Missing required field: model', got %s", body)) + t.Errorf("expected body to contain 'Missing required field: model', got %s", body) } } @@ -165,15 +163,15 @@ func TestProxyToBackend_EmptyModel(t *testing.T) { h.proxyToBackend(rec, req, "/v1/messages") if rec.Code != http.StatusBadRequest { - t.Error(fmt.Sprintf("expected status %d, got %d", http.StatusBadRequest, rec.Code)) + t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code) } body := rec.Body.String() if !strings.Contains(body, "invalid_request_error") { - t.Error(fmt.Sprintf("expected body to contain 'invalid_request_error', got %s", body)) + t.Errorf("expected body to contain 'invalid_request_error', got %s", body) } if !strings.Contains(body, "Missing required field: model") { - t.Error(fmt.Sprintf("expected body to contain 'Missing required field: model', got %s", body)) + t.Errorf("expected body to contain 'Missing required field: model', got %s", body) } } @@ -194,11 +192,11 @@ func TestProxyToBackend_RequestTooLarge(t *testing.T) { h.proxyToBackend(rec, req, "/v1/messages") if rec.Code != http.StatusRequestEntityTooLarge { - t.Error(fmt.Sprintf("expected status %d, got %d", 
http.StatusRequestEntityTooLarge, rec.Code)) + t.Errorf("expected status %d, got %d", http.StatusRequestEntityTooLarge, rec.Code) } body := rec.Body.String() if !strings.Contains(body, "request_too_large") { - t.Error(fmt.Sprintf("expected body to contain 'request_too_large', got %s", body)) + t.Errorf("expected body to contain 'request_too_large', got %s", body) } } diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index 349f8699a..21da86d3e 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "log/slog" "os" "slices" "strings" @@ -21,7 +22,6 @@ import ( "github.com/docker/model-runner/pkg/distribution/types" "github.com/docker/model-runner/pkg/inference/platform" "github.com/docker/model-runner/pkg/internal/utils" - "log/slog" ) // Client provides model distribution functionality @@ -111,7 +111,7 @@ func NewClient(opts ...Option) (*Client, error) { // Migrate any legacy hf.co tags to huggingface.co if err := c.migrateHFTags(); err != nil { - options.logger.Warn(fmt.Sprintf("Failed to migrate HuggingFace tags: %v", err)) + options.logger.Warn("Failed to migrate HuggingFace tags", "error", err) } return c, nil @@ -131,7 +131,7 @@ func (c *Client) migrateHFTags() error { return err } if migrated > 0 { - c.log.Info(fmt.Sprintf("Migrated %d HuggingFace tag(s) from hf.co to huggingface.co", migrated)) + c.log.Info("Migrated HuggingFace tag(s) from hf.co to huggingface.co", "count", migrated) } return nil } @@ -286,7 +286,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter return fmt.Errorf("getting cached model config: %w", err) } if err := progress.WriteSuccess(progressWriter, fmt.Sprintf("Using cached model: %s", cfg.GetSize()), oci.ModePull); err != nil { - c.log.Warn(fmt.Sprintf("Writing progress: %v", err)) + c.log.Warn("Writing progress", "error", err) } return nil } @@ -337,25 +337,25 @@ func 
(c *Client) PullModel(ctx context.Context, reference string, progressWriter for _, layer := range layers { digest, err := layer.Digest() if err != nil { - c.log.Warn(fmt.Sprintf("Failed to get layer digest: %v", err)) + c.log.Warn("Failed to get layer digest", "error", err) continue } // Check if there's an incomplete download for this layer (use DiffID for uncompressed models) diffID, err := layer.DiffID() if err != nil { - c.log.Warn(fmt.Sprintf("Failed to get layer diffID: %v", err)) + c.log.Warn("Failed to get layer diffID", "error", err) continue } incompleteSize, err := c.store.GetIncompleteSize(diffID) if err != nil { - c.log.Warn(fmt.Sprintf("Failed to check incomplete size for layer %s: %v", digest, err)) + c.log.Warn("Failed to check incomplete size for layer", "digest", digest, "error", err) continue } if incompleteSize > 0 { - c.log.Info(fmt.Sprintf("Found incomplete download for layer %s: %d bytes", digest, incompleteSize)) + c.log.Info("Found incomplete download for layer", "digest", digest, "bytes", incompleteSize) resumeOffsets[digest.String()] = incompleteSize } } @@ -364,14 +364,14 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter // and re-fetch using the original reference to ensure compatibility with all registries var rangeSuccess *remote.RangeSuccess if len(resumeOffsets) > 0 { - c.log.Info(fmt.Sprintf("Resuming %d interrupted layer download(s)", len(resumeOffsets))) + c.log.Info("Resuming interrupted layer download(s)", "count", len(resumeOffsets)) // Create a RangeSuccess tracker to record which Range requests succeed rangeSuccess = &remote.RangeSuccess{} ctx = remote.WithResumeOffsets(ctx, resumeOffsets) ctx = remote.WithRangeSuccess(ctx, rangeSuccess) // Re-fetch the model using the original tag reference // The digest has already been validated above, and the resume context will handle layer resumption - c.log.Info(fmt.Sprintf("Re-fetching model with original reference for resume: %s", 
utils.SanitizeForLog(reference))) + c.log.Info("Re-fetching model with original reference for resume", "model", utils.SanitizeForLog(reference)) remoteModel, err = registryClient.Model(ctx, reference) if err != nil { return fmt.Errorf("reading model from registry with resume context: %w", err) @@ -394,7 +394,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter err = progress.WriteSuccess(progressWriter, fmt.Sprintf("Using cached model: %s", cfg.GetSize()), oci.ModePull) if err != nil { - c.log.Warn(fmt.Sprintf("Writing progress: %v", err)) + c.log.Warn("Writing progress", "error", err) } // Ensure model has the correct tag @@ -415,13 +415,13 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter } if err = c.store.Write(remoteModel, []string{reference}, progressWriter, writeOpts...); err != nil { if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error()), oci.ModePull); writeErr != nil { - c.log.Warn(fmt.Sprintf("Failed to write error message: %v", writeErr)) + c.log.Warn("Failed to write error message", "error", writeErr) } return fmt.Errorf("writing image to store: %w", err) } if err := progress.WriteSuccess(progressWriter, "Model pulled successfully", oci.ModePull); err != nil { - c.log.Warn(fmt.Sprintf("Failed to write success message: %v", err)) + c.log.Warn("Failed to write success message", "error", err) } return nil @@ -439,7 +439,7 @@ func (c *Client) LoadModel(r io.Reader, progressWriter io.Writer) (string, error } if err != nil { if errors.Is(err, io.ErrUnexpectedEOF) { - c.log.Info(fmt.Sprintf("Model load interrupted (likely cancelled): %s", utils.SanitizeForLog(err.Error()))) + c.log.Info("Model load interrupted (likely cancelled)", "error", utils.SanitizeForLog(err.Error())) return "", fmt.Errorf("model load interrupted: %w", err) } return "", fmt.Errorf("reading blob from stream: %w", err) @@ -462,7 +462,7 @@ func (c *Client) LoadModel(r io.Reader, 
progressWriter io.Writer) (string, error c.log.Info("loaded model", "id", digest.String()) if err := progress.WriteSuccess(progressWriter, "Model loaded successfully", oci.ModePull); err != nil { - c.log.Warn(fmt.Sprintf("Failed to write success message: %v", err)) + c.log.Warn("Failed to write success message", "error", err) } return digest.String(), nil @@ -482,7 +482,7 @@ func (c *Client) ListModels() ([]types.Model, error) { // Read the models model, err := c.store.Read(modelInfo.ID) if err != nil { - c.log.Warn(fmt.Sprintf("Failed to read model with ID %s: %v", modelInfo.ID, err)) + c.log.Warn("Failed to read model with ID", "model", modelInfo.ID, "error", err) continue } result = append(result, model) @@ -608,14 +608,14 @@ func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Wr if err := target.Write(ctx, mdl, progressWriter); err != nil { c.log.Error("failed to push image", "error", err, "reference", tag) if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error()), oci.ModePush); writeErr != nil { - c.log.Warn(fmt.Sprintf("Failed to write error message: %v", writeErr)) + c.log.Warn("Failed to write error message", "error", writeErr) } return fmt.Errorf("pushing image: %w", err) } c.log.Info("successfully pushed model", "tag", tag) if err := progress.WriteSuccess(progressWriter, "Model pushed successfully", oci.ModePush); err != nil { - c.log.Warn(fmt.Sprintf("Failed to write success message: %v", err)) + c.log.Warn("Failed to write success message", "error", err) } return nil @@ -724,12 +724,12 @@ func checkCompat(image types.ModelArtifact, log *slog.Logger, reference string, } if config.GetFormat() == "" { - log.Warn(fmt.Sprintf("Model format field is empty for %s, unable to verify format compatibility", utils.SanitizeForLog(reference))) + log.Warn("Model format field is empty, unable to verify format compatibility", "model", utils.SanitizeForLog(reference)) } else if
!slices.Contains(GetSupportedFormats(), config.GetFormat()) { // Write warning but continue with pull log.Warn(warnUnsupportedFormat) if err := progress.WriteWarning(progressWriter, warnUnsupportedFormat, oci.ModePull); err != nil { - log.Warn(fmt.Sprintf("Failed to write warning message: %v", err)) + log.Warn("Failed to write warning message", "error", err) } // Don't return an error - allow the pull to continue } @@ -774,7 +774,7 @@ func parseHFReference(reference string) (repo, revision, tag string) { // This is used when the model is stored as raw files (safetensors) on HuggingFace Hub func (c *Client) pullNativeHuggingFace(ctx context.Context, reference string, progressWriter io.Writer, token string) error { repo, revision, tag := parseHFReference(reference) - c.log.Info(fmt.Sprintf("Pulling native HuggingFace model: repo=%s, revision=%s, tag=%s", utils.SanitizeForLog(repo), utils.SanitizeForLog(revision), utils.SanitizeForLog(tag))) + c.log.Info("Pulling native HuggingFace model", "repo", utils.SanitizeForLog(repo), "revision", utils.SanitizeForLog(revision), "tag", utils.SanitizeForLog(tag)) // Create HuggingFace client hfOpts := []huggingface.ClientOption{ @@ -806,23 +806,23 @@ func (c *Client) pullNativeHuggingFace(ctx context.Context, reference string, pr return registry.ErrModelNotFound } if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error()), oci.ModePull); writeErr != nil { - c.log.Warn(fmt.Sprintf("Failed to write error message: %v", writeErr)) + c.log.Warn("Failed to write error message", "error", writeErr) } return fmt.Errorf("build model from HuggingFace: %w", err) } // Write model to store with normalized tag storageTag := c.normalizeModelName(reference) - c.log.Info(fmt.Sprintf("Writing model to store with tag: %s", utils.SanitizeForLog(storageTag))) + c.log.Info("Writing model to store with tag", "model", utils.SanitizeForLog(storageTag)) if err := c.store.Write(model, []string{storageTag}, progressWriter); err 
!= nil { if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error()), oci.ModePull); writeErr != nil { - c.log.Warn(fmt.Sprintf("Failed to write error message: %v", writeErr)) + c.log.Warn("Failed to write error message", "error", writeErr) } return fmt.Errorf("writing model to store: %w", err) } if err := progress.WriteSuccess(progressWriter, "Model pulled successfully", oci.ModePull); err != nil { - c.log.Warn(fmt.Sprintf("Failed to write success message: %v", err)) + c.log.Warn("Failed to write success message", "error", err) } return nil diff --git a/pkg/distribution/distribution/client_test.go b/pkg/distribution/distribution/client_test.go index c31addf98..3195a01b1 100644 --- a/pkg/distribution/distribution/client_test.go +++ b/pkg/distribution/distribution/client_test.go @@ -8,12 +8,12 @@ import ( "errors" "fmt" "io" + "log/slog" "net/http/httptest" "net/url" "os" "path/filepath" "strings" - "log/slog" "testing" "github.com/docker/model-runner/pkg/distribution/internal/mutate" @@ -45,7 +45,7 @@ func TestClientPullModel(t *testing.T) { defer server.Close() registryURL, err := url.Parse(server.URL) if err != nil { - t.Error(fmt.Sprintf("Failed to parse registry URL: %v", err)) + t.Errorf("Failed to parse registry URL: %v", err) } registryHost := registryURL.Host @@ -54,52 +54,52 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Read model content for verification later modelContent, err := os.ReadFile(testGGUFFile) if err != nil { - t.Error(fmt.Sprintf("Failed to read test model file: %v", err)) + t.Errorf("Failed to read test model file: %v", err) } model := testutil.BuildModelFromPath(t, testGGUFFile) tag := registryHost + "/testmodel:v1.0.0" ref, err := reference.ParseReference(tag) if err != nil { - 
t.Error(fmt.Sprintf("Failed to parse reference: %v", err)) + t.Errorf("Failed to parse reference: %v", err) } if err := remote.Write(ref, model, nil, remote.WithPlainHTTP(true)); err != nil { - t.Error(fmt.Sprintf("Failed to push model: %v", err)) + t.Errorf("Failed to push model: %v", err) } t.Run("pull without progress writer", func(t *testing.T) { // Pull model from registry without progress writer err := client.PullModel(t.Context(), tag, nil) if err != nil { - t.Error(fmt.Sprintf("Failed to pull model: %v", err)) + t.Errorf("Failed to pull model: %v", err) } model, err := client.GetModel(tag) if err != nil { - t.Error(fmt.Sprintf("Failed to get model: %v", err)) + t.Errorf("Failed to get model: %v", err) } modelPaths, err := model.GGUFPaths() if err != nil { - t.Error(fmt.Sprintf("Failed to get model path: %v", err)) + t.Errorf("Failed to get model path: %v", err) } if len(modelPaths) != 1 { - t.Error(fmt.Sprintf("Unexpected number of model files: %d", len(modelPaths))) + t.Errorf("Unexpected number of model files: %d", len(modelPaths)) } // Verify model content pulledContent, err := os.ReadFile(modelPaths[0]) if err != nil { - t.Error(fmt.Sprintf("Failed to read pulled model: %v", err)) + t.Errorf("Failed to read pulled model: %v", err) } if string(pulledContent) != string(modelContent) { - t.Error(fmt.Sprintf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent)) + t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent) } }) @@ -109,36 +109,36 @@ func TestClientPullModel(t *testing.T) { // Pull model from registry with progress writer if err := client.PullModel(t.Context(), tag, &progressBuffer); err != nil { - t.Error(fmt.Sprintf("Failed to pull model: %v", err)) + t.Errorf("Failed to pull model: %v", err) } // Verify progress output progressOutput := progressBuffer.String() if !strings.Contains(progressOutput, "Using cached model") && !strings.Contains(progressOutput, 
"Downloading") { - t.Error(fmt.Sprintf("Progress output doesn't contain expected text: got %q", progressOutput)) + t.Errorf("Progress output doesn't contain expected text: got %q", progressOutput) } model, err := client.GetModel(tag) if err != nil { - t.Error(fmt.Sprintf("Failed to get model: %v", err)) + t.Errorf("Failed to get model: %v", err) } modelPaths, err := model.GGUFPaths() if err != nil { - t.Error(fmt.Sprintf("Failed to get model path: %v", err)) + t.Errorf("Failed to get model path: %v", err) } if len(modelPaths) != 1 { - t.Error(fmt.Sprintf("Unexpected number of model files: %d", len(modelPaths))) + t.Errorf("Unexpected number of model files: %d", len(modelPaths)) } // Verify model content pulledContent, err := os.ReadFile(modelPaths[0]) if err != nil { - t.Error(fmt.Sprintf("Failed to read pulled model: %v", err)) + t.Errorf("Failed to read pulled model: %v", err) } if string(pulledContent) != string(modelContent) { - t.Error(fmt.Sprintf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent)) + t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent) } }) @@ -148,7 +148,7 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Create a buffer to capture progress output @@ -165,25 +165,25 @@ func TestClientPullModel(t *testing.T) { var pullErr *mdregistry.Error ok := errors.As(err, &pullErr) if !ok { - t.Error(fmt.Sprintf("Expected registry.Error, got %T: %v", err, err)) + t.Errorf("Expected registry.Error, got %T: %v", err, err) } // Verify it matches registry.ErrModelNotFound for API compatibility if !errors.Is(err, mdregistry.ErrModelNotFound) { - t.Error(fmt.Sprintf("Expected registry.ErrModelNotFound, got %T", err)) + t.Errorf("Expected 
registry.ErrModelNotFound, got %T", err) } // Verify error fields if pullErr.Reference != nonExistentRef { - t.Error(fmt.Sprintf("Expected reference %q, got %q", nonExistentRef, pullErr.Reference)) + t.Errorf("Expected reference %q, got %q", nonExistentRef, pullErr.Reference) } // The error code can be NAME_UNKNOWN, MANIFEST_UNKNOWN, or UNKNOWN depending on the resolver implementation if pullErr.Code != "NAME_UNKNOWN" && pullErr.Code != "MANIFEST_UNKNOWN" && pullErr.Code != "UNKNOWN" { - t.Error(fmt.Sprintf("Expected error code NAME_UNKNOWN, MANIFEST_UNKNOWN, or UNKNOWN, got %q", pullErr.Code)) + t.Errorf("Expected error code NAME_UNKNOWN, MANIFEST_UNKNOWN, or UNKNOWN, got %q", pullErr.Code) } // The error message varies by resolver implementation if !strings.Contains(strings.ToLower(pullErr.Message), "not found") { - t.Error(fmt.Sprintf("Expected message to contain 'not found', got %q", pullErr.Message)) + t.Errorf("Expected message to contain 'not found', got %q", pullErr.Message) } if pullErr.Err == nil { t.Error("Expected underlying error to be non-nil") @@ -196,7 +196,7 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Use the dummy.gguf file from assets directory @@ -205,26 +205,26 @@ func TestClientPullModel(t *testing.T) { // Push model to local store testTag := registryHost + "/incomplete-test/model:v1.0.0" if err := testClient.store.Write(mdl, []string{testTag}, nil); err != nil { - t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) + t.Errorf("Failed to push model to store: %v", err) } // Push model to registry if err := testClient.PushModel(t.Context(), testTag, nil); err != nil { - t.Error(fmt.Sprintf("Failed to pull model: %v", err)) + t.Errorf("Failed to pull model: %v", err) } // Get the model to find the GGUF path model, err := 
testClient.GetModel(testTag) if err != nil { - t.Error(fmt.Sprintf("Failed to get model: %v", err)) + t.Errorf("Failed to get model: %v", err) } ggufPaths, err := model.GGUFPaths() if err != nil { - t.Error(fmt.Sprintf("Failed to get GGUF path: %v", err)) + t.Errorf("Failed to get GGUF path: %v", err) } if len(ggufPaths) != 1 { - t.Error(fmt.Sprintf("Unexpected number of model files: %d", len(ggufPaths))) + t.Errorf("Unexpected number of model files: %d", len(ggufPaths)) } // Create an incomplete file by copying the GGUF file and adding .incomplete suffix @@ -232,23 +232,23 @@ func TestClientPullModel(t *testing.T) { incompletePath := ggufPath + ".incomplete" originalContent, err := os.ReadFile(ggufPath) if err != nil { - t.Error(fmt.Sprintf("Failed to read GGUF file: %v", err)) + t.Errorf("Failed to read GGUF file: %v", err) } // Write partial content to simulate an incomplete download partialContent := originalContent[:len(originalContent)/2] if err := os.WriteFile(incompletePath, partialContent, 0644); err != nil { - t.Error(fmt.Sprintf("Failed to create incomplete file: %v", err)) + t.Errorf("Failed to create incomplete file: %v", err) } // Verify the incomplete file exists if _, err := os.Stat(incompletePath); os.IsNotExist(err) { - t.Error(fmt.Sprintf("Failed to create incomplete file: %v", err)) + t.Errorf("Failed to create incomplete file: %v", err) } // Delete the local model to force a pull if _, err := testClient.DeleteModel(testTag, false); err != nil { - t.Error(fmt.Sprintf("Failed to delete model: %v", err)) + t.Errorf("Failed to delete model: %v", err) } // Create a buffer to capture progress output @@ -256,7 +256,7 @@ func TestClientPullModel(t *testing.T) { // Pull the model again - this should detect the incomplete file and pull again if err := testClient.PullModel(t.Context(), testTag, &progressBuffer); err != nil { - t.Error(fmt.Sprintf("Failed to pull model: %v", err)) + t.Errorf("Failed to pull model: %v", err) } // Verify progress output 
indicates a new download, not using cached model @@ -267,18 +267,18 @@ func TestClientPullModel(t *testing.T) { // Verify the incomplete file no longer exists if _, err := os.Stat(incompletePath); !os.IsNotExist(err) { - t.Error(fmt.Sprintf("Incomplete file still exists after successful pull: %s", incompletePath)) + t.Errorf("Incomplete file still exists after successful pull: %s", incompletePath) } // Verify the complete file exists if _, err := os.Stat(ggufPath); os.IsNotExist(err) { - t.Error(fmt.Sprintf("GGUF file doesn't exist after pull: %s", ggufPath)) + t.Errorf("GGUF file doesn't exist after pull: %s", ggufPath) } // Verify the content of the pulled file matches the original pulledContent, err := os.ReadFile(ggufPath) if err != nil { - t.Error(fmt.Sprintf("Failed to read pulled GGUF file: %v", err)) + t.Errorf("Failed to read pulled GGUF file: %v", err) } if !bytes.Equal(pulledContent, originalContent) { @@ -292,13 +292,13 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Read model content for verification later testModelContent, err := os.ReadFile(testGGUFFile) if err != nil { - t.Error(fmt.Sprintf("Failed to read test model file: %v", err)) + t.Errorf("Failed to read test model file: %v", err) } // Push first version of model to registry @@ -309,38 +309,38 @@ func TestClientPullModel(t *testing.T) { // Pull first version of model if err := testClient.PullModel(t.Context(), testTag, nil); err != nil { - t.Error(fmt.Sprintf("Failed to pull first version of model: %v", err)) + t.Errorf("Failed to pull first version of model: %v", err) } // Verify first version is in local store model, err := testClient.GetModel(testTag) if err != nil { - t.Error(fmt.Sprintf("Failed to get first version of model: %v", err)) + t.Errorf("Failed to get first version 
of model: %v", err) } modelPath, err := model.GGUFPaths() if err != nil { - t.Error(fmt.Sprintf("Failed to get model path: %v", err)) + t.Errorf("Failed to get model path: %v", err) } if len(modelPath) != 1 { - t.Error(fmt.Sprintf("Unexpected number of model files: %d", len(modelPath))) + t.Errorf("Unexpected number of model files: %d", len(modelPath)) } // Verify first version content pulledContent, err := os.ReadFile(modelPath[0]) if err != nil { - t.Error(fmt.Sprintf("Failed to read pulled model: %v", err)) + t.Errorf("Failed to read pulled model: %v", err) } if string(pulledContent) != string(testModelContent) { - t.Error(fmt.Sprintf("Pulled model content doesn't match original: got %q, want %q", pulledContent, testModelContent)) + t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, testModelContent) } // Create a modified version of the model updatedModelFile := filepath.Join(tempDir, "updated-dummy.gguf") updatedContent := append(testModelContent, []byte("UPDATED CONTENT")...) 
if err := os.WriteFile(updatedModelFile, updatedContent, 0644); err != nil { - t.Error(fmt.Sprintf("Failed to create updated model file: %v", err)) + t.Errorf("Failed to create updated model file: %v", err) } // Push updated model with same tag @@ -353,7 +353,7 @@ func TestClientPullModel(t *testing.T) { // Pull model again - should get the updated version if err := testClient.PullModel(t.Context(), testTag, &progressBuffer); err != nil { - t.Error(fmt.Sprintf("Failed to pull updated model: %v", err)) + t.Errorf("Failed to pull updated model: %v", err) } // Verify progress output indicates a new download, not using cached model @@ -365,25 +365,25 @@ func TestClientPullModel(t *testing.T) { // Get the model again to verify it's the updated version updatedModel, err := testClient.GetModel(testTag) if err != nil { - t.Error(fmt.Sprintf("Failed to get updated model: %v", err)) + t.Errorf("Failed to get updated model: %v", err) } updatedModelPaths, err := updatedModel.GGUFPaths() if err != nil { - t.Error(fmt.Sprintf("Failed to get updated model path: %v", err)) + t.Errorf("Failed to get updated model path: %v", err) } if len(updatedModelPaths) != 1 { - t.Error(fmt.Sprintf("Unexpected number of model files: %d", len(modelPath))) + t.Errorf("Unexpected number of model files: %d", len(modelPath)) } // Verify updated content updatedPulledContent, err := os.ReadFile(updatedModelPaths[0]) if err != nil { - t.Error(fmt.Sprintf("Failed to read updated pulled model: %v", err)) + t.Errorf("Failed to read updated pulled model: %v", err) } if string(updatedPulledContent) != string(updatedContent) { - t.Error(fmt.Sprintf("Updated pulled model content doesn't match: got %q, want %q", updatedPulledContent, updatedContent)) + t.Errorf("Updated pulled model content doesn't match: got %q, want %q", updatedPulledContent, updatedContent) } }) @@ -393,13 +393,13 @@ func TestClientPullModel(t *testing.T) { testTag := registryHost + "/unsupported-test/model:v1.0.0" ref, err := 
reference.ParseReference(testTag) if err != nil { - t.Error(fmt.Sprintf("Failed to parse reference: %v", err)) + t.Errorf("Failed to parse reference: %v", err) } if err := remote.Write(ref, newMdl, nil, remote.WithPlainHTTP(true)); err != nil { - t.Error(fmt.Sprintf("Failed to push model: %v", err)) + t.Errorf("Failed to push model: %v", err) } if err := client.PullModel(t.Context(), testTag, nil); err == nil || !errors.Is(err, ErrUnsupportedMediaType) { - t.Error(fmt.Sprintf("Expected artifact version error, got %v", err)) + t.Errorf("Expected artifact version error, got %v", err) } }) @@ -410,7 +410,7 @@ func TestClientPullModel(t *testing.T) { safetensorsPath := filepath.Join(safetensorsTempDir, "model.safetensors") safetensorsContent := []byte("fake safetensors content for testing") if err := os.WriteFile(safetensorsPath, safetensorsContent, 0644); err != nil { - t.Error(fmt.Sprintf("Failed to create safetensors file: %v", err)) + t.Errorf("Failed to create safetensors file: %v", err) } // Create a safetensors model @@ -420,10 +420,10 @@ func TestClientPullModel(t *testing.T) { testTag := registryHost + "/safetensors-test/model:v1.0.0" ref, err := reference.ParseReference(testTag) if err != nil { - t.Error(fmt.Sprintf("Failed to parse reference: %v", err)) + t.Errorf("Failed to parse reference: %v", err) } if err := remote.Write(ref, safetensorsModel, nil, remote.WithPlainHTTP(true)); err != nil { - t.Error(fmt.Sprintf("Failed to push safetensors model to registry: %v", err)) + t.Errorf("Failed to push safetensors model to registry: %v", err) } // Create a new client with a separate temp store @@ -431,7 +431,7 @@ func TestClientPullModel(t *testing.T) { testClient, err := newTestClient(clientTempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create test client: %v", err)) + t.Errorf("Failed to create test client: %v", err) } // Try to pull the safetensors model with a progress writer to capture warnings @@ -440,17 +440,17 @@ func TestClientPullModel(t 
*testing.T) { // Pull should succeed on all platforms now (with a warning on non-Linux) if err != nil { - t.Error(fmt.Sprintf("Expected no error, got: %v", err)) + t.Errorf("Expected no error, got: %v", err) } if !platform.SupportsVLLM() { // On non-Linux, verify that a warning was written progressOutput := progressBuf.String() if !strings.Contains(progressOutput, `"type":"warning"`) { - t.Error(fmt.Sprintf("Expected warning message on non-Linux platforms, got output: %s", progressOutput)) + t.Errorf("Expected warning message on non-Linux platforms, got output: %s", progressOutput) } if !strings.Contains(progressOutput, warnUnsupportedFormat) { - t.Error(fmt.Sprintf("Expected warning about safetensors format, got output: %s", progressOutput)) + t.Errorf("Expected warning about safetensors format, got output: %s", progressOutput) } } }) @@ -461,7 +461,7 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Create a buffer to capture progress output @@ -469,7 +469,7 @@ func TestClientPullModel(t *testing.T) { // Pull model from registry with progress writer if err := testClient.PullModel(t.Context(), tag, &progressBuffer); err != nil { - t.Error(fmt.Sprintf("Failed to pull model: %v", err)) + t.Errorf("Failed to pull model: %v", err) } // Parse progress output as JSON @@ -479,13 +479,13 @@ func TestClientPullModel(t *testing.T) { line := scanner.Text() var msg oci.ProgressMessage if err := json.Unmarshal([]byte(line), &msg); err != nil { - t.Error(fmt.Sprintf("Failed to parse JSON progress message: %v, line: %s", err, line)) + t.Errorf("Failed to parse JSON progress message: %v, line: %s", err, line) } messages = append(messages, msg) } if err := scanner.Err(); err != nil { - t.Error(fmt.Sprintf("Error reading progress output: %v", err)) + t.Errorf("Error 
reading progress output: %v", err) } // Verify we got some messages @@ -496,34 +496,34 @@ func TestClientPullModel(t *testing.T) { // Verify all messages have the correct mode for i, msg := range messages { if msg.Mode != oci.ModePull { - t.Error(fmt.Sprintf("message %d: expected mode %q, got %q", i, oci.ModePull, msg.Mode)) + t.Errorf("message %d: expected mode %q, got %q", i, oci.ModePull, msg.Mode) } } // Check the last message is a success message lastMsg := messages[len(messages)-1] if lastMsg.Type != oci.TypeSuccess { - t.Error(fmt.Sprintf("Expected last message to be success, got type: %q, message: %s", lastMsg.Type, lastMsg.Message)) + t.Errorf("Expected last message to be success, got type: %q, message: %s", lastMsg.Type, lastMsg.Message) } // Verify model was pulled correctly model, err := testClient.GetModel(tag) if err != nil { - t.Error(fmt.Sprintf("Failed to get model: %v", err)) + t.Errorf("Failed to get model: %v", err) } modelPaths, err := model.GGUFPaths() if err != nil { - t.Error(fmt.Sprintf("Failed to get model path: %v", err)) + t.Errorf("Failed to get model path: %v", err) } if len(modelPaths) != 1 { - t.Error(fmt.Sprintf("Unexpected number of model files: %d", len(modelPaths))) + t.Errorf("Unexpected number of model files: %d", len(modelPaths)) } // Verify model content pulledContent, err := os.ReadFile(modelPaths[0]) if err != nil { - t.Error(fmt.Sprintf("Failed to read pulled model: %v", err)) + t.Errorf("Failed to read pulled model: %v", err) } if string(pulledContent) != string(modelContent) { @@ -537,7 +537,7 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Create a buffer to capture progress output @@ -554,7 +554,7 @@ func TestClientPullModel(t *testing.T) { // Verify it matches registry.ErrModelNotFound if !errors.Is(err, 
mdregistry.ErrModelNotFound) { - t.Error(fmt.Sprintf("Expected registry.ErrModelNotFound, got %T", err)) + t.Errorf("Expected registry.ErrModelNotFound, got %T", err) } // No JSON messages should be in the buffer for this error case @@ -568,7 +568,7 @@ func TestClientGetModel(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Create model from test GGUF file @@ -578,18 +578,18 @@ func TestClientGetModel(t *testing.T) { tag := "test/model:v1.0.0" normalizedTag := "docker.io/test/model:v1.0.0" // Reference package normalizes to include registry if err := client.store.Write(model, []string{tag}, nil); err != nil { - t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) + t.Errorf("Failed to push model to store: %v", err) } // Get model mi, err := client.GetModel(tag) if err != nil { - t.Error(fmt.Sprintf("Failed to get model: %v", err)) + t.Errorf("Failed to get model: %v", err) } // Verify model - tags are normalized to include the default registry if len(mi.Tags()) == 0 || mi.Tags()[0] != normalizedTag { - t.Error(fmt.Sprintf("Model tags don't match: got %v, want [%s]", mi.Tags(), normalizedTag)) + t.Errorf("Model tags don't match: got %v, want [%s]", mi.Tags(), normalizedTag) } } @@ -599,13 +599,13 @@ func TestClientGetModelNotFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Get non-existent model _, err = client.GetModel("nonexistent/model:v1.0.0") if !errors.Is(err, ErrModelNotFound) { - t.Error(fmt.Sprintf("Expected ErrModelNotFound, got %v", err)) + t.Errorf("Expected ErrModelNotFound, got %v", err) } } @@ -615,14 +615,14 @@ func TestClientListModels(t *testing.T) { // Create client with 
plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Create test model file modelContent := []byte("test model content") modelFile := filepath.Join(tempDir, "test-model.gguf") if err := os.WriteFile(modelFile, modelContent, 0644); err != nil { - t.Error(fmt.Sprintf("Failed to write test model file: %v", err)) + t.Errorf("Failed to write test model file: %v", err) } mdl := testutil.BuildModelFromPath(t, modelFile) @@ -631,21 +631,21 @@ func TestClientListModels(t *testing.T) { // First model tag1 := "test/model1:v1.0.0" if err := client.store.Write(mdl, []string{tag1}, nil); err != nil { - t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) + t.Errorf("Failed to push model to store: %v", err) } // Create a slightly different model file for the second model modelContent2 := []byte("test model content 2") modelFile2 := filepath.Join(tempDir, "test-model2.gguf") if err := os.WriteFile(modelFile2, modelContent2, 0644); err != nil { - t.Error(fmt.Sprintf("Failed to write test model file: %v", err)) + t.Errorf("Failed to write test model file: %v", err) } mdl2 := testutil.BuildModelFromPath(t, modelFile2) // Second model tag2 := "test/model2:v1.0.0" if err := client.store.Write(mdl2, []string{tag2}, nil); err != nil { - t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) + t.Errorf("Failed to push model to store: %v", err) } // Normalized tags for verification (reference package normalizes to include default registry) @@ -656,12 +656,12 @@ func TestClientListModels(t *testing.T) { // List models models, err := client.ListModels() if err != nil { - t.Error(fmt.Sprintf("Failed to list models: %v", err)) + t.Errorf("Failed to list models: %v", err) } // Verify models if len(models) != len(tags) { - t.Error(fmt.Sprintf("Expected %d models, got %d", len(tags), len(models))) + t.Errorf("Expected %d models, got %d", 
len(tags), len(models)) } // Check if all tags are present @@ -674,7 +674,7 @@ func TestClientListModels(t *testing.T) { for _, tag := range tags { if !tagMap[tag] { - t.Error(fmt.Sprintf("Tag %s not found in models", tag)) + t.Errorf("Tag %s not found in models", tag) } } } @@ -685,7 +685,7 @@ func TestClientGetStorePath(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Get store path @@ -693,12 +693,12 @@ func TestClientGetStorePath(t *testing.T) { // Verify store path matches the temp directory if storePath != tempDir { - t.Error(fmt.Sprintf("Store path doesn't match: got %s, want %s", storePath, tempDir)) + t.Errorf("Store path doesn't match: got %s, want %s", storePath, tempDir) } // Verify the store directory exists if _, err := os.Stat(storePath); os.IsNotExist(err) { - t.Error(fmt.Sprintf("Store directory does not exist: %s", storePath)) + t.Errorf("Store directory does not exist: %s", storePath) } } @@ -708,7 +708,7 @@ func TestClientDefaultLogger(t *testing.T) { // Create client without specifying logger client, err := NewClient(WithStoreRootPath(tempDir)) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Verify that logger is not nil @@ -723,7 +723,7 @@ func TestClientDefaultLogger(t *testing.T) { WithLogger(customLogger), ) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Verify that custom logger is used @@ -746,7 +746,7 @@ func TestWithFunctionsNilChecks(t *testing.T) { // Verify the path wasn't changed to empty if opts.storeRootPath != tempDir { - t.Error(fmt.Sprintf("WithStoreRootPath with empty string changed the path: got %q, want %q", opts.storeRootPath, tempDir)) + t.Errorf("WithStoreRootPath with empty string 
changed the path: got %q, want %q", opts.storeRootPath, tempDir) } }) @@ -788,7 +788,7 @@ func TestNewReferenceError(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Test with invalid reference @@ -799,7 +799,7 @@ func TestNewReferenceError(t *testing.T) { } if !errors.Is(err, ErrInvalidReference) { - t.Error(fmt.Sprintf("Expected error to match sentinel invalid reference error, got %v", err)) + t.Errorf("Expected error to match sentinel invalid reference error, got %v", err) } } @@ -809,7 +809,7 @@ func TestPush(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Create a test registry @@ -819,7 +819,7 @@ func TestPush(t *testing.T) { // Create a tag for the model uri, err := url.Parse(server.URL) if err != nil { - t.Error(fmt.Sprintf("Failed to parse registry URL: %v", err)) + t.Errorf("Failed to parse registry URL: %v", err) } tag := uri.Host + "/incomplete-test/model:v1.0.0" @@ -827,39 +827,39 @@ func TestPush(t *testing.T) { mdl := testutil.BuildModelFromPath(t, testGGUFFile) digest, err := mdl.ID() if err != nil { - t.Error(fmt.Sprintf("Failed to get digest of original model: %v", err)) + t.Errorf("Failed to get digest of original model: %v", err) } if err := client.store.Write(mdl, []string{tag}, nil); err != nil { - t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) + t.Errorf("Failed to push model to store: %v", err) } // Push the model to the registry if err := client.PushModel(t.Context(), tag, nil); err != nil { - t.Error(fmt.Sprintf("Failed to push model: %v", err)) + t.Errorf("Failed to push model: %v", err) } // Delete local copy (so we can test pulling) if _, err := 
client.DeleteModel(tag, false); err != nil { - t.Error(fmt.Sprintf("Failed to delete model: %v", err)) + t.Errorf("Failed to delete model: %v", err) } // Test that model can be pulled successfully if err := client.PullModel(t.Context(), tag, nil); err != nil { - t.Error(fmt.Sprintf("Failed to pull model: %v", err)) + t.Errorf("Failed to pull model: %v", err) } // Test that model the pulled model is the same as the original (matching digests) mdl2, err := client.GetModel(tag) if err != nil { - t.Error(fmt.Sprintf("Failed to get pulled model: %v", err)) + t.Errorf("Failed to get pulled model: %v", err) } digest2, err := mdl2.ID() if err != nil { - t.Error(fmt.Sprintf("Failed to get digest of the pulled model: %v", err)) + t.Errorf("Failed to get digest of the pulled model: %v", err) } if digest != digest2 { - t.Error(fmt.Sprintf("Digests don't match: got %s, want %s", digest2, digest)) + t.Errorf("Digests don't match: got %s, want %s", digest2, digest) } } @@ -869,7 +869,7 @@ func TestPushProgress(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Create a test registry @@ -879,7 +879,7 @@ func TestPushProgress(t *testing.T) { // Create a tag for the model uri, err := url.Parse(server.URL) if err != nil { - t.Error(fmt.Sprintf("Failed to parse registry URL: %v", err)) + t.Errorf("Failed to parse registry URL: %v", err) } tag := uri.Host + "/some/model/repo:some-tag" @@ -888,14 +888,14 @@ func TestPushProgress(t *testing.T) { sz := int64(progress.MinBytesForUpdate * 2) path, err := randomFile(sz) if err != nil { - t.Error(fmt.Sprintf("Failed to create temp file: %v", err)) + t.Errorf("Failed to create temp file: %v", err) } defer os.Remove(path) mdl := testutil.BuildModelFromPath(t, path) if err := client.store.Write(mdl, []string{tag}, nil); err != nil { - t.Error(fmt.Sprintf("Failed to 
write model to store: %v", err)) + t.Errorf("Failed to write model to store: %v", err) } // Create a buffer to capture progress output @@ -917,13 +917,13 @@ func TestPushProgress(t *testing.T) { // Wait for the push to complete if err := <-done; err != nil { - t.Error(fmt.Sprintf("Failed to push model: %v", err)) + t.Errorf("Failed to push model: %v", err) } // Verify we got at least 2 messages (1 progress + 1 success) // With fast local uploads, we may only get one progress update per layer if len(lines) < 2 { - t.Error(fmt.Sprintf("Expected at least 2 progress messages, got %d", len(lines))) + t.Errorf("Expected at least 2 progress messages, got %d", len(lines)) } // Verify we got at least one progress message and the success message @@ -938,10 +938,10 @@ func TestPushProgress(t *testing.T) { } } if !hasProgress { - t.Error(fmt.Sprintf("Expected at least one progress message containing 'Uploaded:', got %v", lines)) + t.Errorf("Expected at least one progress message containing 'Uploaded:', got %v", lines) } if !hasSuccess { - t.Error(fmt.Sprintf("Expected a success message, got %v", lines)) + t.Errorf("Expected a success message, got %v", lines) } } @@ -951,14 +951,14 @@ func TestTag(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Create a test model model := testutil.BuildModelFromPath(t, testGGUFFile) id, err := model.ID() if err != nil { - t.Error(fmt.Sprintf("Failed to get model ID: %v", err)) + t.Errorf("Failed to get model ID: %v", err) } // Normalize the model name before writing @@ -966,35 +966,35 @@ func TestTag(t *testing.T) { // Push the model to the store if err := client.store.Write(model, []string{normalized}, nil); err != nil { - t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) + t.Errorf("Failed to push model to store: %v", err) } // Tag the model by ID 
if err := client.Tag(id, "other-repo:tag1"); err != nil { - t.Error(fmt.Sprintf("Failed to tag model %q: %v", id, err)) + t.Errorf("Failed to tag model %q: %v", id, err) } // Tag the model by tag if err := client.Tag(id, "other-repo:tag2"); err != nil { - t.Error(fmt.Sprintf("Failed to tag model %q: %v", id, err)) + t.Errorf("Failed to tag model %q: %v", id, err) } // Verify the model has all 3 tags modelInfo, err := client.GetModel("some-repo:some-tag") if err != nil { - t.Error(fmt.Sprintf("Failed to get model: %v", err)) + t.Errorf("Failed to get model: %v", err) } if len(modelInfo.Tags()) != 3 { - t.Error(fmt.Sprintf("Expected 3 tags, got %d", len(modelInfo.Tags()))) + t.Errorf("Expected 3 tags, got %d", len(modelInfo.Tags())) } // Verify the model can be accessed by new tags if _, err := client.GetModel("other-repo:tag1"); err != nil { - t.Error(fmt.Sprintf("Failed to get model by tag: %v", err)) + t.Errorf("Failed to get model by tag: %v", err) } if _, err := client.GetModel("other-repo:tag2"); err != nil { - t.Error(fmt.Sprintf("Failed to get model by tag: %v", err)) + t.Errorf("Failed to get model by tag: %v", err) } } @@ -1004,12 +1004,12 @@ func TestTagNotFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Tag the model by ID if err := client.Tag("non-existent-model:latest", "other-repo:tag1"); !errors.Is(err, ErrModelNotFound) { - t.Error(fmt.Sprintf("Expected ErrModelNotFound, got: %v", err)) + t.Errorf("Expected ErrModelNotFound, got: %v", err) } } @@ -1019,11 +1019,11 @@ func TestClientPushModelNotFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } if err := client.PushModel(t.Context(), 
"non-existent-model:latest", nil); !errors.Is(err, ErrModelNotFound) { - t.Error(fmt.Sprintf("Expected ErrModelNotFound got: %v", err)) + t.Errorf("Expected ErrModelNotFound got: %v", err) } } @@ -1033,11 +1033,11 @@ func TestIsModelInStoreNotFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } if inStore, err := client.IsModelInStore("non-existent-model:latest"); err != nil { - t.Error(fmt.Sprintf("Unexpected error: %v", err)) + t.Errorf("Unexpected error: %v", err) } else if inStore { t.Error("Expected model not to be found") } @@ -1049,7 +1049,7 @@ func TestIsModelInStoreFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Create a test model @@ -1060,11 +1060,11 @@ func TestIsModelInStoreFound(t *testing.T) { // Push the model to the store if err := client.store.Write(model, []string{normalized}, nil); err != nil { - t.Error(fmt.Sprintf("Failed to push model to store: %v", err)) + t.Errorf("Failed to push model to store: %v", err) } if inStore, err := client.IsModelInStore("some-repo:some-tag"); err != nil { - t.Error(fmt.Sprintf("Unexpected error: %v", err)) + t.Errorf("Unexpected error: %v", err) } else if !inStore { t.Error("Expected model to be found") } @@ -1141,26 +1141,26 @@ func TestMigrateHFTagsOnClientInit(t *testing.T) { // Step 1: Create a client and write a model with the legacy tag setupClient, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create setup client: %v", err)) + t.Errorf("Failed to create setup client: %v", err) } model := testutil.BuildModelFromPath(t, testGGUFFile) if err := setupClient.store.Write(model, []string{tc.storedTag}, nil); err != 
nil { - t.Error(fmt.Sprintf("Failed to write model to store: %v", err)) + t.Errorf("Failed to write model to store: %v", err) } // Step 2: Create a NEW client (simulating restart) - migration should happen client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Step 3: Verify the model can be found using the reference // (normalizeModelName converts hf.co -> huggingface.co, and migration should have updated the store) foundModel, err := client.GetModel(tc.lookupRef) if err != nil { - t.Error(fmt.Sprintf("Failed to get model after migration: %v", err)) + t.Errorf("Failed to get model after migration: %v", err) } if foundModel == nil { @@ -1182,10 +1182,10 @@ func TestMigrateHFTagsOnClientInit(t *testing.T) { } } if hasOldTag { - t.Error(fmt.Sprintf("Model still has old hf.co tag after migration: %v", tags)) + t.Errorf("Model still has old hf.co tag after migration: %v", tags) } if !hasNewTag { - t.Error(fmt.Sprintf("Model doesn't have huggingface.co tag after migration: %v", tags)) + t.Errorf("Model doesn't have huggingface.co tag after migration: %v", tags) } } }) @@ -1214,7 +1214,7 @@ func TestPullHuggingFaceModelFromCache(t *testing.T) { // Create client client, err := newTestClient(tempDir) if err != nil { - t.Error(fmt.Sprintf("Failed to create client: %v", err)) + t.Errorf("Failed to create client: %v", err) } // Create a test model and write it to the store with a normalized HuggingFace tag @@ -1223,20 +1223,20 @@ func TestPullHuggingFaceModelFromCache(t *testing.T) { // Store with normalized tag (huggingface.co) hfTag := "huggingface.co/testorg/testmodel:latest" if err := client.store.Write(model, []string{hfTag}, nil); err != nil { - t.Error(fmt.Sprintf("Failed to write model to store: %v", err)) + t.Errorf("Failed to write model to store: %v", err) } // Now try to pull using the test case's reference - it should use the cache var progressBuffer 
bytes.Buffer err = client.PullModel(t.Context(), tc.pullRef, &progressBuffer) if err != nil { - t.Error(fmt.Sprintf("Failed to pull model from cache: %v", err)) + t.Errorf("Failed to pull model from cache: %v", err) } // Verify that progress shows it was cached progressOutput := progressBuffer.String() if !strings.Contains(progressOutput, "Using cached model") { - t.Error(fmt.Sprintf("Expected progress to indicate cached model, got: %s", progressOutput)) + t.Errorf("Expected progress to indicate cached model, got: %s", progressOutput) } }) } diff --git a/pkg/distribution/distribution/normalize_test.go b/pkg/distribution/distribution/normalize_test.go index 861241bfd..16c32be08 100644 --- a/pkg/distribution/distribution/normalize_test.go +++ b/pkg/distribution/distribution/normalize_test.go @@ -1,14 +1,12 @@ package distribution import ( - "fmt" "io" + "log/slog" "path/filepath" "strings" "testing" - "log/slog" - "github.com/docker/model-runner/pkg/distribution/builder" "github.com/docker/model-runner/pkg/distribution/tarball" ) @@ -153,7 +151,7 @@ func TestNormalizeModelName(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := client.normalizeModelName(tt.input) if result != tt.expected { - t.Error(fmt.Sprintf("normalizeModelName(%q) = %q, want %q", tt.input, result, tt.expected)) + t.Errorf("normalizeModelName(%q) = %q, want %q", tt.input, result, tt.expected) } }) } @@ -215,7 +213,7 @@ func TestLooksLikeID(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := client.looksLikeID(tt.input) if result != tt.expected { - t.Error(fmt.Sprintf("looksLikeID(%q) = %v, want %v", tt.input, result, tt.expected)) + t.Errorf("looksLikeID(%q) = %v, want %v", tt.input, result, tt.expected) } }) } @@ -277,7 +275,7 @@ func TestLooksLikeDigest(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := client.looksLikeDigest(tt.input) if result != tt.expected { - t.Error(fmt.Sprintf("looksLikeDigest(%q) = %v, want %v", tt.input, result, tt.expected)) + 
t.Errorf("looksLikeDigest(%q) = %v, want %v", tt.input, result, tt.expected) } }) } @@ -294,7 +292,7 @@ func TestNormalizeModelNameWithIDResolution(t *testing.T) { // Extract the short ID (12 hex chars after "sha256:") if !strings.HasPrefix(modelID, "sha256:") { - t.Error(fmt.Sprintf("Expected model ID to start with 'sha256:', got: %s", modelID)) + t.Errorf("Expected model ID to start with 'sha256:', got: %s", modelID) } shortID := modelID[7:19] // Extract 12 chars after "sha256:" fullHex := strings.TrimPrefix(modelID, "sha256:") @@ -325,7 +323,7 @@ func TestNormalizeModelNameWithIDResolution(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := client.normalizeModelName(tt.input) if result != tt.expected { - t.Error(fmt.Sprintf("normalizeModelName(%q) = %q, want %q", tt.input, result, tt.expected)) + t.Errorf("normalizeModelName(%q) = %q, want %q", tt.input, result, tt.expected) } }) } @@ -344,7 +342,7 @@ func createTestClient(t *testing.T) (*Client, func()) { WithLogger(slog.Default()), ) if err != nil { - t.Error(fmt.Sprintf("Failed to create test client: %v", err)) + t.Errorf("Failed to create test client: %v", err) } cleanup := func() { @@ -375,7 +373,7 @@ func TestIsHuggingFaceReference(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := isHuggingFaceReference(tt.input) if result != tt.expected { - t.Error(fmt.Sprintf("isHuggingFaceReference(%q) = %v, want %v", tt.input, result, tt.expected)) + t.Errorf("isHuggingFaceReference(%q) = %v, want %v", tt.input, result, tt.expected) } }) } @@ -437,13 +435,13 @@ func TestParseHFReference(t *testing.T) { t.Run(tt.name, func(t *testing.T) { repo, rev, tag := parseHFReference(tt.input) if repo != tt.expectedRepo { - t.Error(fmt.Sprintf("parseHFReference(%q) repo = %q, want %q", tt.input, repo, tt.expectedRepo)) + t.Errorf("parseHFReference(%q) repo = %q, want %q", tt.input, repo, tt.expectedRepo) } if rev != tt.expectedRev { - t.Error(fmt.Sprintf("parseHFReference(%q) rev = %q, want %q", tt.input, 
rev, tt.expectedRev)) + t.Errorf("parseHFReference(%q) rev = %q, want %q", tt.input, rev, tt.expectedRev) } if tag != tt.expectedTag { - t.Error(fmt.Sprintf("parseHFReference(%q) tag = %q, want %q", tt.input, tag, tt.expectedTag)) + t.Errorf("parseHFReference(%q) tag = %q, want %q", tt.input, tag, tt.expectedTag) } }) } @@ -457,7 +455,7 @@ func loadTestModel(t *testing.T, client *Client, ggufPath string) string { pr, pw := io.Pipe() target, err := tarball.NewTarget(pw) if err != nil { - t.Error(fmt.Sprintf("Failed to create target: %v", err)) + t.Errorf("Failed to create target: %v", err) } done := make(chan error) @@ -470,15 +468,15 @@ func loadTestModel(t *testing.T, client *Client, ggufPath string) string { bldr, err := builder.FromPath(ggufPath) if err != nil { - t.Error(fmt.Sprintf("Failed to create builder from GGUF: %v", err)) + t.Errorf("Failed to create builder from GGUF: %v", err) } if err := bldr.Build(t.Context(), target, nil); err != nil { - t.Error(fmt.Sprintf("Failed to build model: %v", err)) + t.Errorf("Failed to build model: %v", err) } if err := <-done; err != nil { - t.Error(fmt.Sprintf("Failed to load model: %v", err)) + t.Errorf("Failed to load model: %v", err) } if id == "" { diff --git a/pkg/distribution/registry/testregistry/registry.go b/pkg/distribution/registry/testregistry/registry.go index b8555266b..e82aa233e 100644 --- a/pkg/distribution/registry/testregistry/registry.go +++ b/pkg/distribution/registry/testregistry/registry.go @@ -69,7 +69,9 @@ func (r *Registry) handleBlobUpload(w http.ResponseWriter, req *http.Request, pa switch req.Method { case http.MethodPost: // Start upload + r.mu.RLock() uploadID := fmt.Sprintf("upload-%d", len(r.blobs)) + r.mu.RUnlock() location := fmt.Sprintf("/v2/%s/blobs/uploads/%s", repo, uploadID) w.Header().Set("Location", location) w.Header().Set("Docker-Upload-UUID", uploadID) diff --git a/pkg/inference/backends/diffusers/diffusers.go b/pkg/inference/backends/diffusers/diffusers.go index 
e6af495b1..3a991d1eb 100644 --- a/pkg/inference/backends/diffusers/diffusers.go +++ b/pkg/inference/backends/diffusers/diffusers.go @@ -123,7 +123,7 @@ func (d *diffusers) Install(_ context.Context, _ *http.Client) error { // Get version output, err := d.pythonCmd("-c", "import diffusers; print(diffusers.__version__)").Output() if err != nil { - d.log.Warn(fmt.Sprintf("could not get diffusers version: %v", err)) + d.log.Warn("could not get diffusers version", "error", err) d.status = "running diffusers version: unknown" } else { d.status = fmt.Sprintf("running diffusers version: %s", strings.TrimSpace(string(output))) @@ -156,7 +156,7 @@ func (d *diffusers) Run(ctx context.Context, socket, model string, modelRef stri return fmt.Errorf("%w: model %s", ErrNoDDUFFile, model) } - d.log.Info(fmt.Sprintf("Loading DDUF file from: %s", ddufPath)) + d.log.Info("Loading DDUF file from", "path", ddufPath) args, err := d.config.GetArgs(ddufPath, socket, mode, backendConfig) if err != nil { @@ -168,7 +168,7 @@ func (d *diffusers) Run(ctx context.Context, socket, model string, modelRef stri args = append(args, "--served-model-name", modelRef) } - d.log.Info(fmt.Sprintf("Diffusers args: %v", utils.SanitizeForLog(strings.Join(args, " ")))) + d.log.Info("Diffusers args", "args", utils.SanitizeForLog(strings.Join(args, " "))) if d.pythonPath == "" { return fmt.Errorf("diffusers: python runtime not configured; did you forget to call Install") diff --git a/pkg/inference/backends/llamacpp/download.go b/pkg/inference/backends/llamacpp/download.go index abc7e73b7..245592731 100644 --- a/pkg/inference/backends/llamacpp/download.go +++ b/pkg/inference/backends/llamacpp/download.go @@ -60,7 +60,7 @@ func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logge return errLlamaCppUpdateDisabled } - log.Info(fmt.Sprintf("downloadLatestLlamaCpp: %s, %s, %s, %s", desiredVersion, desiredVariant, vendoredServerStoragePath, llamaCppPath)) + log.Info("downloadLatestLlamaCpp", 
"version", desiredVersion, "variant", desiredVariant, "storagePath", vendoredServerStoragePath, "llamaCppPath", llamaCppPath) desiredTag := desiredVersion + "-" + desiredVariant url := fmt.Sprintf("https://hub.docker.com/v2/namespaces/%s/repositories/%s/tags/%s", hubNamespace, hubRepo, desiredTag) resp, err := httpClient.Get(url) @@ -89,7 +89,7 @@ func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logge latest = response.Digest } if latest == "" { - log.Warn(fmt.Sprintf("could not fing the %s tag, hub response: %s", desiredTag, body)) + log.Warn("could not find the tag", "tag", desiredTag, "response", body) return fmt.Errorf("could not find the %s tag", desiredTag) } @@ -107,7 +107,7 @@ func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logge data, err = os.ReadFile(currentVersionFile) if err != nil { - log.Warn(fmt.Sprintf("failed to read current llama.cpp version: %v", err)) + log.Warn("failed to read current llama.cpp version", "error", err) log.Warn("proceeding to update llama.cpp binary") } else if strings.TrimSpace(string(data)) == latest { log.Info("current llama.cpp version is already up to date") @@ -118,7 +118,7 @@ func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logge } log.Info("llama.cpp binary must be updated, proceeding to update it") } else { - log.Info(fmt.Sprintf("current llama.cpp version is outdated: %s vs %s, proceeding to update it", strings.TrimSpace(string(data)), latest)) + log.Info("current llama.cpp version is outdated, proceeding to update", "current", strings.TrimSpace(string(data)), "latest", latest) } image := fmt.Sprintf("registry-1.docker.io/%s/%s@%s", hubNamespace, hubRepo, latest) @@ -168,7 +168,7 @@ func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logge log.Info(l.status) if err := os.WriteFile(currentVersionFile, []byte(latest), 0o644); err != nil { - log.Warn(fmt.Sprintf("failed to save llama.cpp version: %v", err)) + 
log.Warn("failed to save llama.cpp version", "error", err) } return nil @@ -176,7 +176,7 @@ func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logge //nolint:unused // Used in platform-specific files (download_darwin.go, download_windows.go) func extractFromImage(ctx context.Context, log logging.Logger, image, requiredOs, requiredArch, destination string) error { - log.Info(fmt.Sprintf("Extracting image %q to %q", image, destination)) + log.Info("Extracting image", "image", image, "destination", destination) tmpDir, err := os.MkdirTemp("", "docker-tar-extract") if err != nil { return err @@ -191,7 +191,7 @@ func extractFromImage(ctx context.Context, log logging.Logger, image, requiredOs func getLlamaCppVersion(log logging.Logger, llamaCpp string) string { output, err := exec.Command(llamaCpp, "--version").CombinedOutput() if err != nil { - log.Warn(fmt.Sprintf("could not get llama.cpp version: %v", err)) + log.Warn("could not get llama.cpp version", "error", err) return "unknown" } re := regexp.MustCompile(`version: \d+ \((\w+)\)`) @@ -199,6 +199,6 @@ func getLlamaCppVersion(log logging.Logger, llamaCpp string) string { if len(matches) == 2 { return matches[1] } - log.Warn(fmt.Sprintf("failed to parse llama.cpp version from output:\n%s", strings.TrimSpace(string(output)))) + log.Warn("failed to parse llama.cpp version from output", "output", strings.TrimSpace(string(output))) return "unknown" } diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go index 592e6d078..12a1011cd 100644 --- a/pkg/inference/backends/llamacpp/llamacpp.go +++ b/pkg/inference/backends/llamacpp/llamacpp.go @@ -117,7 +117,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error { // digest to be equal to the one on Docker Hub. 
llamaCppPath := filepath.Join(l.updatedServerStoragePath, llamaServerBin) if err := l.ensureLatestLlamaCpp(ctx, l.log, httpClient, llamaCppPath, l.vendoredServerStoragePath); err != nil { - l.log.Info(fmt.Sprintf("failed to ensure latest llama.cpp: %v\n", err)) + l.log.Info("failed to ensure latest llama.cpp", "error", err) if !errors.Is(err, errLlamaCppUpToDate) && !errors.Is(err, errLlamaCppUpdateDisabled) { l.status = fmt.Sprintf("failed to install llama.cpp: %v", err) } @@ -129,7 +129,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error { } l.gpuSupported = l.checkGPUSupport(ctx) - l.log.Info(fmt.Sprintf("installed llama-server with gpuSupport=%t", l.gpuSupported)) + l.log.Info("installed llama-server", "gpuSupport", l.gpuSupported) return nil } @@ -351,12 +351,12 @@ func (l *llamaCpp) checkGPUSupport(ctx context.Context) bool { "--list-devices", ) if err != nil { - l.log.Warn(fmt.Sprintf("Failed to start sandboxed llama.cpp process to probe GPU support: %v", err)) + l.log.Warn("Failed to start sandboxed llama.cpp process to probe GPU support", "error", err) return false } defer llamaCppSandbox.Close() if err := llamaCppSandbox.Command().Wait(); err != nil { - l.log.Warn(fmt.Sprintf("Failed to determine if llama-server is built with GPU support: %v", err)) + l.log.Warn("Failed to determine if llama-server is built with GPU support", "error", err) return false } sc := bufio.NewScanner(strings.NewReader(output.String())) diff --git a/pkg/inference/backends/mlx/mlx.go b/pkg/inference/backends/mlx/mlx.go index 9907ee4f7..012347381 100644 --- a/pkg/inference/backends/mlx/mlx.go +++ b/pkg/inference/backends/mlx/mlx.go @@ -110,7 +110,7 @@ func (m *mlx) Install(ctx context.Context, httpClient *http.Client) error { cmd = exec.CommandContext(ctx, pythonPath, "-c", "import mlx; print(mlx.__version__)") output, outputErr := cmd.Output() if outputErr != nil { - m.log.Warn(fmt.Sprintf("could not get MLX version: %v", outputErr)) + 
m.log.Warn("could not get MLX version", "error", outputErr) m.status = "running MLX version: unknown" } else { m.status = fmt.Sprintf("running MLX version: %s", strings.TrimSpace(string(output))) diff --git a/pkg/inference/backends/sglang/sglang.go b/pkg/inference/backends/sglang/sglang.go index 3d00e7eb5..9e08d5e33 100644 --- a/pkg/inference/backends/sglang/sglang.go +++ b/pkg/inference/backends/sglang/sglang.go @@ -118,7 +118,7 @@ func (s *sglang) Install(_ context.Context, _ *http.Client) error { // Get version output, err := s.pythonCmd("-c", "import sglang; print(sglang.__version__)").Output() if err != nil { - s.log.Warn(fmt.Sprintf("could not get sglang version: %v", err)) + s.log.Warn("could not get sglang version", "error", err) s.status = "running sglang version: unknown" } else { s.status = fmt.Sprintf("running sglang version: %s", strings.TrimSpace(string(output))) diff --git a/pkg/inference/backends/vllm/vllm.go b/pkg/inference/backends/vllm/vllm.go index 603ecdf8a..f16adcc27 100644 --- a/pkg/inference/backends/vllm/vllm.go +++ b/pkg/inference/backends/vllm/vllm.go @@ -94,7 +94,7 @@ func (v *vLLM) Install(_ context.Context, _ *http.Client) error { versionPath := filepath.Join(filepath.Dir(vllmDir), "version") versionBytes, err := os.ReadFile(versionPath) if err != nil { - v.log.Warn(fmt.Sprintf("could not get vllm version: %v", err)) + v.log.Warn("could not get vllm version", "error", err) v.status = "running vllm version: unknown" } else { v.status = fmt.Sprintf("running vllm version: %s", strings.TrimSpace(string(versionBytes))) diff --git a/pkg/inference/backends/vllmmetal/vllmmetal.go b/pkg/inference/backends/vllmmetal/vllmmetal.go index 37e9003e1..6d3ffb81c 100644 --- a/pkg/inference/backends/vllmmetal/vllmmetal.go +++ b/pkg/inference/backends/vllmmetal/vllmmetal.go @@ -109,7 +109,7 @@ func (v *vllmMetal) Install(ctx context.Context, httpClient *http.Client) error v.pythonPath = pythonPath return v.verifyInstallation(ctx) } - 
v.log.Info(fmt.Sprintf("vllm-metal version mismatch: installed %s, want %s", installed, vllmMetalVersion)) + v.log.Info("vllm-metal version mismatch", "installed", installed, "want", vllmMetalVersion) } } @@ -120,7 +120,7 @@ func (v *vllmMetal) Install(ctx context.Context, httpClient *http.Client) error // Save version file if err := os.WriteFile(versionFile, []byte(vllmMetalVersion), 0644); err != nil { - v.log.Warn(fmt.Sprintf("failed to write version file: %v", err)) + v.log.Warn("failed to write version file", "error", err) } v.pythonPath = pythonPath @@ -130,7 +130,7 @@ func (v *vllmMetal) Install(ctx context.Context, httpClient *http.Client) error // downloadAndExtract downloads the vllm-metal image from Docker Hub and extracts it. // The image contains a self-contained Python installation with all packages pre-installed. func (v *vllmMetal) downloadAndExtract(ctx context.Context, _ *http.Client) error { - v.log.Info(fmt.Sprintf("Downloading vllm-metal %s from Docker Hub...", vllmMetalVersion)) + v.log.Info("Downloading vllm-metal from Docker Hub...", "version", vllmMetalVersion) // Create temp directory for download downloadDir, err := os.MkdirTemp("", "vllm-metal-install") @@ -175,7 +175,7 @@ func (v *vllmMetal) downloadAndExtract(ctx context.Context, _ *http.Client) erro return fmt.Errorf("failed to make python3 executable: %w", err) } - v.log.Info(fmt.Sprintf("vllm-metal %s installed successfully", vllmMetalVersion)) + v.log.Info("vllm-metal installed successfully", "version", vllmMetalVersion) return nil } diff --git a/pkg/inference/models/handler_test.go b/pkg/inference/models/handler_test.go index 1a58ea241..7194ef995 100644 --- a/pkg/inference/models/handler_test.go +++ b/pkg/inference/models/handler_test.go @@ -1,8 +1,8 @@ package models import ( - "fmt" "encoding/json" + "log/slog" "net/http" "net/http/httptest" "net/url" @@ -11,8 +11,6 @@ import ( "strings" "testing" - "log/slog" - "github.com/docker/model-runner/pkg/distribution/builder" reg 
"github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/registry/testregistry" @@ -24,7 +22,7 @@ func getProjectRoot(t *testing.T) string { // Start from the current test file's directory dir, err := os.Getwd() if err != nil { - t.Error(fmt.Sprintf("Failed to get current directory: %v", err)) + t.Errorf("Failed to get current directory: %v", err) } // Walk up the directory tree until we find the go.mod file @@ -50,7 +48,7 @@ func TestPullModel(t *testing.T) { // Create a tag for the model uri, err := url.Parse(server.URL) if err != nil { - t.Error(fmt.Sprintf("Failed to parse registry URL: %v", err)) + t.Errorf("Failed to parse registry URL: %v", err) } tag := uri.Host + "/ai/model:v1.0.0" @@ -58,23 +56,23 @@ func TestPullModel(t *testing.T) { projectRoot := getProjectRoot(t) model, err := builder.FromPath(filepath.Join(projectRoot, "assets", "dummy.gguf")) if err != nil { - t.Error(fmt.Sprintf("Failed to create model builder: %v", err)) + t.Errorf("Failed to create model builder: %v", err) } license, err := model.WithLicense(filepath.Join(projectRoot, "assets", "license.txt")) if err != nil { - t.Error(fmt.Sprintf("Failed to add license to model: %v", err)) + t.Errorf("Failed to add license to model: %v", err) } // Build the OCI model artifact + push it (use plainHTTP for test registry) client := reg.NewClient(reg.WithPlainHTTP(true)) target, err := client.NewTarget(tag) if err != nil { - t.Error(fmt.Sprintf("Failed to create model target: %v", err)) + t.Errorf("Failed to create model target: %v", err) } err = license.Build(t.Context(), target, os.Stdout) if err != nil { - t.Error(fmt.Sprintf("Failed to build model: %v", err)) + t.Errorf("Failed to build model: %v", err) } tests := []struct { @@ -117,19 +115,19 @@ func TestPullModel(t *testing.T) { w := httptest.NewRecorder() err = handler.manager.Pull(tag, "", r, w) if err != nil { - t.Error(fmt.Sprintf("Failed to pull model: %v", err)) + t.Errorf("Failed to 
pull model: %v", err) } if tt.expectedCT != w.Header().Get("Content-Type") { - t.Error(fmt.Sprintf("Expected content type %s, got %s", tt.expectedCT, w.Header().Get("Content-Type"))) + t.Errorf("Expected content type %s, got %s", tt.expectedCT, w.Header().Get("Content-Type")) } // Clean tempDir after each test if err := os.RemoveAll(tempDir); err != nil { - t.Error(fmt.Sprintf("Failed to clean temp directory: %v", err)) + t.Errorf("Failed to clean temp directory: %v", err) } if err := os.MkdirAll(tempDir, 0755); err != nil { - t.Error(fmt.Sprintf("Failed to recreate temp directory: %v", err)) + t.Errorf("Failed to recreate temp directory: %v", err) } }) } @@ -144,19 +142,19 @@ func TestHandleGetModel(t *testing.T) { uri, err := url.Parse(server.URL) if err != nil { - t.Error(fmt.Sprintf("Failed to parse registry URL: %v", err)) + t.Errorf("Failed to parse registry URL: %v", err) } // Prepare the OCI model artifact projectRoot := getProjectRoot(t) model, err := builder.FromPath(filepath.Join(projectRoot, "assets", "dummy.gguf")) if err != nil { - t.Error(fmt.Sprintf("Failed to create model builder: %v", err)) + t.Errorf("Failed to create model builder: %v", err) } license, err := model.WithLicense(filepath.Join(projectRoot, "assets", "license.txt")) if err != nil { - t.Error(fmt.Sprintf("Failed to add license to model: %v", err)) + t.Errorf("Failed to add license to model: %v", err) } // Build the OCI model artifact + push it (use plainHTTP for test registry) @@ -164,11 +162,11 @@ func TestHandleGetModel(t *testing.T) { client := reg.NewClient(reg.WithPlainHTTP(true)) target, err := client.NewTarget(tag) if err != nil { - t.Error(fmt.Sprintf("Failed to create model target: %v", err)) + t.Errorf("Failed to create model target: %v", err) } err = license.Build(t.Context(), target, os.Stdout) if err != nil { - t.Error(fmt.Sprintf("Failed to build model: %v", err)) + t.Errorf("Failed to build model: %v", err) } tests := []struct { @@ -224,7 +222,7 @@ func 
TestHandleGetModel(t *testing.T) { w := httptest.NewRecorder() err = handler.manager.Pull(tt.modelName, "", r, w) if err != nil { - t.Error(fmt.Sprintf("Failed to pull model: %v", err)) + t.Errorf("Failed to pull model: %v", err) } } @@ -244,12 +242,12 @@ func TestHandleGetModel(t *testing.T) { // Check response if w.Code != tt.expectedCode { - t.Error(fmt.Sprintf("Expected status code %d, got %d", tt.expectedCode, w.Code)) + t.Errorf("Expected status code %d, got %d", tt.expectedCode, w.Code) } if tt.expectedError != "" { if !strings.Contains(w.Body.String(), tt.expectedError) { - t.Error(fmt.Sprintf("Expected error containing %q, got %q", tt.expectedError, w.Body.String())) + t.Errorf("Expected error containing %q, got %q", tt.expectedError, w.Body.String()) } } else { // For successful responses, verify we got a valid JSON response @@ -261,16 +259,16 @@ func TestHandleGetModel(t *testing.T) { Config json.RawMessage `json:"config"` } if err := json.NewDecoder(w.Body).Decode(&response); err != nil { - t.Error(fmt.Sprintf("Failed to decode response body: %v", err)) + t.Errorf("Failed to decode response body: %v", err) } } // Clean tempDir after each test if err := os.RemoveAll(tempDir); err != nil { - t.Error(fmt.Sprintf("Failed to clean temp directory: %v", err)) + t.Errorf("Failed to clean temp directory: %v", err) } if err := os.MkdirAll(tempDir, 0755); err != nil { - t.Error(fmt.Sprintf("Failed to recreate temp directory: %v", err)) + t.Errorf("Failed to recreate temp directory: %v", err) } }) } @@ -310,7 +308,7 @@ func TestCors(t *testing.T) { m.ServeHTTP(w, req) if w.Code != http.StatusNoContent { - t.Error(fmt.Sprintf("Expected status code 204 for OPTIONS request, got %d", w.Code)) + t.Errorf("Expected status code 204 for OPTIONS request, got %d", w.Code) } }) } diff --git a/pkg/inference/models/http_handler.go b/pkg/inference/models/http_handler.go index 0e34c3036..28e465c9d 100644 --- a/pkg/inference/models/http_handler.go +++ 
b/pkg/inference/models/http_handler.go @@ -6,14 +6,13 @@ import ( "errors" "fmt" "html" + "log/slog" "net/http" "path" "strconv" "strings" "sync" - "log/slog" - "github.com/docker/model-runner/pkg/distribution/distribution" "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/inference" @@ -109,21 +108,21 @@ func (h *HTTPHandler) handleCreateModel(w http.ResponseWriter, r *http.Request) if err := h.manager.Pull(request.From, request.BearerToken, r, w); err != nil { sanitizedFrom := utils.SanitizeForLog(request.From, -1) if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - h.log.Info(fmt.Sprintf("Request canceled/timed out while pulling model %q", sanitizedFrom)) + h.log.Info("Request canceled/timed out while pulling model", "model", sanitizedFrom) return } if errors.Is(err, registry.ErrInvalidReference) { - h.log.Warn(fmt.Sprintf("Invalid model reference %q: %v", sanitizedFrom, err)) + h.log.Warn("Invalid model reference", "model", sanitizedFrom, "error", err) http.Error(w, "Invalid model reference", http.StatusBadRequest) return } if errors.Is(err, registry.ErrUnauthorized) { - h.log.Warn(fmt.Sprintf("Unauthorized to pull model %q: %v", sanitizedFrom, err)) + h.log.Warn("Unauthorized to pull model", "model", sanitizedFrom, "error", err) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } if errors.Is(err, registry.ErrModelNotFound) { - h.log.Warn(fmt.Sprintf("Failed to pull model %q: %v", sanitizedFrom, err)) + h.log.Warn("Failed to pull model", "model", sanitizedFrom, "error", err) http.Error(w, "Model not found", http.StatusNotFound) return } @@ -452,17 +451,17 @@ func (h *HTTPHandler) handleTagModel(w http.ResponseWriter, r *http.Request, mod func (h *HTTPHandler) handlePushModel(w http.ResponseWriter, r *http.Request, model string) { if err := h.manager.Push(model, r, w); err != nil { if errors.Is(err, distribution.ErrInvalidReference) { - h.log.Warn(fmt.Sprintf("Invalid model 
reference %q: %v", utils.SanitizeForLog(model, -1), err)) + h.log.Warn("Invalid model reference", "model", utils.SanitizeForLog(model, -1), "error", err) http.Error(w, "Invalid model reference", http.StatusBadRequest) return } if errors.Is(err, distribution.ErrModelNotFound) { - h.log.Warn(fmt.Sprintf("Failed to push model %q: %v", utils.SanitizeForLog(model, -1), err)) + h.log.Warn("Failed to push model", "model", utils.SanitizeForLog(model, -1), "error", err) http.Error(w, "Model not found", http.StatusNotFound) return } if errors.Is(err, registry.ErrUnauthorized) { - h.log.Warn(fmt.Sprintf("Unauthorized to push model %q: %v", utils.SanitizeForLog(model, -1), err)) + h.log.Warn("Unauthorized to push model", "model", utils.SanitizeForLog(model, -1), "error", err) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } @@ -497,7 +496,7 @@ func (h *HTTPHandler) handleRepackageModel(w http.ResponseWriter, r *http.Reques http.Error(w, err.Error(), http.StatusNotFound) return } - h.log.Warn(fmt.Sprintf("Failed to repackage model %q: %v", utils.SanitizeForLog(model, -1), err)) + h.log.Warn("Failed to repackage model", "model", utils.SanitizeForLog(model, -1), "error", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -518,7 +517,7 @@ func (h *HTTPHandler) handleRepackageModel(w http.ResponseWriter, r *http.Reques func (h *HTTPHandler) handlePurge(w http.ResponseWriter, _ *http.Request) { err := h.manager.Purge() if err != nil { - h.log.Warn(fmt.Sprintf("Failed to purge models: %v", err)) + h.log.Warn("Failed to purge models", "error", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go index f13519a3f..66e09f7ff 100644 --- a/pkg/inference/models/manager.go +++ b/pkg/inference/models/manager.go @@ -52,7 +52,7 @@ func NewManager(log logging.Logger, c ClientConfig) *Manager { distribution.WithRegistryClient(registryClient), ) if err != nil { 
- log.Error(fmt.Sprintf("Failed to create distribution client: %v", err)) + log.Error("Failed to create distribution client", "error", err) // Continue without distribution client. The model manager will still // respond to requests, but may return errors if the client is required. } @@ -92,13 +92,13 @@ func (m *Manager) ResolveID(modelRef string) string { sanitizedModelRef := utils.SanitizeForLog(modelRef, -1) model, err := m.GetLocal(sanitizedModelRef) if err != nil { - m.log.Warn(fmt.Sprintf("Failed to resolve model ref %s to ID: %v", sanitizedModelRef, err)) + m.log.Warn("Failed to resolve model ref to ID", "model", sanitizedModelRef, "error", err) return sanitizedModelRef } modelID, err := model.ID() if err != nil { - m.log.Warn(fmt.Sprintf("Failed to get model ID for ref %s: %v", sanitizedModelRef, err)) + m.log.Warn("Failed to get model ID for ref", "model", sanitizedModelRef, "error", err) return sanitizedModelRef } @@ -172,7 +172,7 @@ func (m *Manager) List() ([]*Model, error) { for _, model := range models { apiModel, err := ToModel(model) if err != nil { - m.log.Warn(fmt.Sprintf("error while converting model, skipping: %v", err)) + m.log.Warn("error while converting model, skipping", "error", err) continue } apiModels = append(apiModels, apiModel) @@ -300,7 +300,7 @@ func (m *Manager) Tag(ref, target string) error { for _, mModel := range models { modelID, idErr := mModel.ID() if idErr != nil { - m.log.Warn(fmt.Sprintf("Failed to get model ID: %v", idErr)) + m.log.Warn("Failed to get model ID", "error", idErr) continue } @@ -359,7 +359,7 @@ func (m *Manager) Tag(ref, target string) error { // Now tag using the found model reference (the matching tag) if tagErr := m.distributionClient.Tag(foundModelRef, target); tagErr != nil { - m.log.Warn(fmt.Sprintf("Failed to apply tag %q to resolved model %q: %v", utils.SanitizeForLog(target, -1), utils.SanitizeForLog(foundModelRef, -1), tagErr)) + m.log.Warn("Failed to apply tag to resolved model", "target", 
utils.SanitizeForLog(target, -1), "model", utils.SanitizeForLog(foundModelRef, -1), "error", tagErr) return fmt.Errorf("error while tagging model: %w", tagErr) } } else if err != nil { @@ -411,7 +411,7 @@ func (m *Manager) Purge() error { return fmt.Errorf("model distribution service unavailable") } if err := m.distributionClient.ResetStore(); err != nil { - m.log.Warn(fmt.Sprintf("Failed to purge models: %v", err)) + m.log.Warn("Failed to purge models", "error", err) return fmt.Errorf("error while purging models: %w", err) } return nil diff --git a/pkg/inference/scheduling/http_handler.go b/pkg/inference/scheduling/http_handler.go index 06df35e08..a9f3077b9 100644 --- a/pkg/inference/scheduling/http_handler.go +++ b/pkg/inference/scheduling/http_handler.go @@ -448,7 +448,7 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) { go func() { preloadBody, err := json.Marshal(OpenAIInferenceRequest{Model: configureRequest.Model}) if err != nil { - h.scheduler.log.Warn(fmt.Sprintf("failed to marshal preload request body: %v", err)) + h.scheduler.log.Warn("failed to marshal preload request body", "error", err) return } ctx, cancel := context.WithTimeout(context.Background(), time.Minute) @@ -460,7 +460,7 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) { bytes.NewReader(preloadBody), ) if err != nil { - h.scheduler.log.Warn(fmt.Sprintf("failed to create preload request: %v", err)) + h.scheduler.log.Warn("failed to create preload request", "error", err) return } preloadReq.Header.Set("User-Agent", r.UserAgent()) @@ -470,7 +470,7 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) { recorder := httptest.NewRecorder() h.handleOpenAIInference(recorder, preloadReq) if recorder.Code != http.StatusOK { - h.scheduler.log.Warn(fmt.Sprintf("background model preload failed with status %d: %s", recorder.Code, recorder.Body.String())) + h.scheduler.log.Warn("background model preload failed", "status", 
recorder.Code, "body", recorder.Body.String()) } }() diff --git a/pkg/inference/scheduling/installer.go b/pkg/inference/scheduling/installer.go index 6c7ce90ea..a41448831 100644 --- a/pkg/inference/scheduling/installer.go +++ b/pkg/inference/scheduling/installer.go @@ -1,7 +1,6 @@ package scheduling import ( - "fmt" "context" "errors" "net/http" @@ -142,7 +141,7 @@ func (i *installer) run(ctx context.Context) { continue } if err := backend.Install(ctx, i.httpClient); err != nil { - i.log.Warn(fmt.Sprintf("Backend installation failed for %s: %v", name, err)) + i.log.Warn("Backend installation failed", "backend", name, "error", err) select { case <-ctx.Done(): status.err = errors.Join(errInstallerShuttingDown, ctx.Err()) diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go index 9acd94caf..9817abfdf 100644 --- a/pkg/inference/scheduling/loader.go +++ b/pkg/inference/scheduling/loader.go @@ -233,17 +233,17 @@ func (l *loader) evict(idleOnly bool) int { default: } if unused && (!idleOnly || idle || defunct) && (!idleOnly || !neverEvict || defunct) { - l.log.Info(fmt.Sprintf("Evicting %s backend runner with model %s (%s) in %s mode", r.backend, r.modelID, runnerInfo.modelRef, r.mode)) + l.log.Info("Evicting backend runner", "backend", r.backend, "modelID", r.modelID, "modelRef", runnerInfo.modelRef, "mode", r.mode) l.freeRunnerSlot(runnerInfo.slot, r) evictedCount++ } else if unused { - l.log.Debug(fmt.Sprintf("Runner %s (%s) is unused but not evictable: idleOnly=%v, idle=%v, defunct=%v, neverEvict=%v", r.modelID, runnerInfo.modelRef, idleOnly, idle, defunct, neverEvict)) + l.log.Debug("Runner is unused but not evictable", "modelID", r.modelID, "modelRef", runnerInfo.modelRef, "idleOnly", idleOnly, "idle", idle, "defunct", defunct, "neverEvict", neverEvict) } else { - l.log.Debug(fmt.Sprintf("Runner %s (%s) is in use with %d references, cannot evict", r.modelID, runnerInfo.modelRef, l.references[runnerInfo.slot]))
+ l.log.Debug("Runner is in use, cannot evict", "modelID", r.modelID, "modelRef", runnerInfo.modelRef, "references", l.references[runnerInfo.slot]) } } if evictedCount > 0 { - l.log.Info(fmt.Sprintf("Evicted %d runner(s)", evictedCount)) + l.log.Info("Evicted runner(s)", "count", evictedCount) } return len(l.runners) } @@ -256,13 +256,13 @@ func (l *loader) evictRunner(backend, model string, mode inference.BackendMode) for r, runnerInfo := range l.runners { unused := l.references[runnerInfo.slot] == 0 if unused && (allBackends || r.backend == backend) && r.modelID == model && r.mode == mode { - l.log.Info(fmt.Sprintf("Evicting %s backend runner with model %s (%s) in %s mode", r.backend, r.modelID, runnerInfo.modelRef, r.mode)) + l.log.Info("Evicting backend runner", "backend", r.backend, "modelID", r.modelID, "modelRef", runnerInfo.modelRef, "mode", r.mode) l.freeRunnerSlot(runnerInfo.slot, r) found = true } } if !found { - l.log.Warn(fmt.Sprintf("No unused runner found for backend=%s, model=%s, mode=%s", utils.SanitizeForLog(backend), utils.SanitizeForLog(model), utils.SanitizeForLog(string(mode)))) + l.log.Warn("No unused runner found", "backend", utils.SanitizeForLog(backend), "model", utils.SanitizeForLog(model), "mode", utils.SanitizeForLog(string(mode))) } return len(l.runners) } @@ -445,7 +445,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string defaultConfig := inference.BackendConfiguration{} if l.modelManager != nil { if bundle, err := l.modelManager.GetBundle(modelID); err != nil { - l.log.Warn(fmt.Sprintf("Failed to get bundle for model %s to determine default context size: %v", modelID, err)) + l.log.Warn("Failed to get bundle for model to determine default context size", "model", modelID, "error", err) } else if runtimeConfig := bundle.RuntimeConfig(); runtimeConfig != nil { if ctxSize := runtimeConfig.GetContextSize(); ctxSize != nil { defaultConfig.ContextSize = ctxSize @@ -455,7
+455,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string runnerConfig = &defaultConfig } - l.log.Info(fmt.Sprintf("Loading %s backend runner with model %s in %s mode", backendName, modelID, mode)) + l.log.Info("Loading backend runner", "backend", backendName, "model", modelID, "mode", mode) // Acquire the loader lock and defer its release. if !l.lock(ctx) { @@ -485,7 +485,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string if ok { select { case <-l.slots[existing.slot].done: - l.log.Warn(fmt.Sprintf("%s runner for %s is defunct. Waiting for it to be evicted.", backendName, existing.modelRef)) + l.log.Warn("Runner is defunct, waiting for it to be evicted", "backend", backendName, "model", existing.modelRef) if l.references[existing.slot] == 0 { l.evictRunner(backendName, modelID, mode) // Continue the loop to retry loading after evicting the defunct runner @@ -502,7 +502,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string // If all slots are full, try evicting unused runners. if len(l.runners) == len(l.slots) { - l.log.Info(fmt.Sprintf("Evicting to make room: %d/%d slots used", len(l.runners), len(l.slots))) + l.log.Info("Evicting to make room", "runners", len(l.runners), "slots", len(l.slots)) runnerCountAtLoopStart := len(l.runners) remainingRunners := l.evict(false) // Restart the loop if eviction happened @@ -522,7 +522,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string } if slot < 0 { - l.log.Debug(fmt.Sprintf("Cannot load model yet: %d/%d slots used", len(l.runners), len(l.slots))) + l.log.Debug("Cannot load model yet", "runners", len(l.runners), "slots", len(l.slots)) } // If we've identified a slot, then we're ready to start a runner. @@ -530,7 +530,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string // Create the runner.
runner, err := run(l.log, backend, modelID, modelRef, mode, slot, runnerConfig, l.openAIRecorder) if err != nil { - l.log.Warn(fmt.Sprintf("Unable to start %s backend runner with model %s in %s mode: %v", backendName, modelID, mode, err)) + l.log.Warn("Unable to start backend runner", "backend", backendName, "model", modelID, "mode", mode, "error", err) return nil, fmt.Errorf("unable to start runner: %w", err) } @@ -542,7 +542,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string // deduplication of runners and keep slot / memory reservations. if err := runner.wait(ctx); err != nil { runner.terminate() - l.log.Warn(fmt.Sprintf("Initialization for %s backend runner with model %s in %s mode failed: %v", backendName, modelID, mode, err)) + l.log.Warn("Backend runner initialization failed", "backend", backendName, "model", modelID, "mode", mode, "error", err) return nil, fmt.Errorf("error waiting for runner to be ready: %w", err) } @@ -615,7 +615,7 @@ func (l *loader) setRunnerConfig(ctx context.Context, backendName, modelID strin // If the configuration hasn't changed, then just return.
if existingConfig, ok := l.runnerConfigs[configKey]; ok && reflect.DeepEqual(runnerConfig, existingConfig) { - l.log.Info(fmt.Sprintf("Configuration for %s runner for modelID %s unchanged", backendName, modelID)) + l.log.Info("Configuration for runner unchanged", "backend", backendName, "model", modelID) return nil } @@ -638,7 +638,7 @@ func (l *loader) setRunnerConfig(ctx context.Context, backendName, modelID strin return errRunnerAlreadyActive } - l.log.Info(fmt.Sprintf("Configuring %s runner for %s", backendName, modelID)) + l.log.Info("Configuring runner", "backend", backendName, "model", modelID) l.runnerConfigs[configKey] = runnerConfig return nil } diff --git a/pkg/inference/scheduling/loader_test.go b/pkg/inference/scheduling/loader_test.go index d84622251..9120fba6e 100644 --- a/pkg/inference/scheduling/loader_test.go +++ b/pkg/inference/scheduling/loader_test.go @@ -1,16 +1,14 @@ package scheduling import ( - "fmt" "context" "errors" "io" + "log/slog" "net/http" "testing" "time" - "log/slog" - "github.com/docker/model-runner/pkg/inference" ) @@ -162,16 +160,16 @@ func TestMakeRunnerKey(t *testing.T) { key := makeRunnerKey(tt.backend, tt.modelID, tt.draftModelID, tt.mode) if key.backend != tt.backend { - t.Error(fmt.Sprintf("Expected backend %q, got %q", tt.backend, key.backend)) + t.Errorf("Expected backend %q, got %q", tt.backend, key.backend) } if key.modelID != tt.modelID { - t.Error(fmt.Sprintf("Expected modelID %q, got %q", tt.modelID, key.modelID)) + t.Errorf("Expected modelID %q, got %q", tt.modelID, key.modelID) } if key.draftModelID != tt.draftModelID { - t.Error(fmt.Sprintf("Expected draftModelID %q, got %q", tt.draftModelID, key.draftModelID)) + t.Errorf("Expected draftModelID %q, got %q", tt.draftModelID, key.draftModelID) } if key.mode != tt.mode { - t.Error(fmt.Sprintf("Expected mode %v, got %v", tt.mode, key.mode)) + t.Errorf("Expected mode %v, got %v", tt.mode, key.mode) } }) } @@ -186,16 +184,16 @@ func TestMakeConfigKey(t
*testing.T) { key := makeConfigKey(backend, modelID, mode) if key.backend != backend { - t.Error(fmt.Sprintf("Expected backend %q, got %q", backend, key.backend)) + t.Errorf("Expected backend %q, got %q", backend, key.backend) } if key.modelID != modelID { - t.Error(fmt.Sprintf("Expected modelID %q, got %q", modelID, key.modelID)) + t.Errorf("Expected modelID %q, got %q", modelID, key.modelID) } if key.draftModelID != "" { - t.Error(fmt.Sprintf("Expected empty draftModelID for config key, got %q", key.draftModelID)) + t.Errorf("Expected empty draftModelID for config key, got %q", key.draftModelID) } if key.mode != mode { - t.Error(fmt.Sprintf("Expected mode %v, got %v", mode, key.mode)) + t.Errorf("Expected mode %v, got %v", mode, key.mode) } } @@ -326,7 +324,7 @@ func TestPerModelKeepAliveEviction(t *testing.T) { // Runner with short keep_alive should be evicted, never-evict should remain if remaining != 1 { - t.Error(fmt.Sprintf("Expected 1 remaining runner after eviction, got %d", remaining)) + t.Errorf("Expected 1 remaining runner after eviction, got %d", remaining) } // Verify that model-never is still present @@ -382,10 +380,10 @@ func TestIdleCheckDurationWithPerModelKeepAlive(t *testing.T) { // Should be based on the short keep_alive runner (around 100ms + 100ms buffer) // The never-evict runner should be skipped if duration < 0 { - t.Error(fmt.Sprintf("Expected positive duration, got %v", duration)) + t.Errorf("Expected positive duration, got %v", duration) } if duration > 500*time.Millisecond { - t.Error(fmt.Sprintf("Expected duration around 200ms, got %v", duration)) + t.Errorf("Expected duration around 200ms, got %v", duration) } loader.unlock() diff --git a/pkg/inference/scheduling/runner.go b/pkg/inference/scheduling/runner.go index 9f06dc23b..962aba66a 100644 --- a/pkg/inference/scheduling/runner.go +++ b/pkg/inference/scheduling/runner.go @@ -192,13 +192,13 @@ func run( if r.openAIRecorder != nil { r.openAIRecorder.SetConfigForModel(modelID, 
runnerConfig) } else { - r.log.Warn(fmt.Sprintf("OpenAI recorder is nil for model %s", modelID)) + r.log.Warn("OpenAI recorder is nil for model", "model", modelID) } // Start the backend run loop. go func() { if err := backend.Run(runCtx, socket, modelID, modelRef, mode, runnerConfig); err != nil { - log.Warn(fmt.Sprintf("Backend %s running model %s exited with error: %v", backend.Name(), utils.SanitizeForLog(modelRef), err)) + log.Warn("Backend running model exited with error", "backend", backend.Name(), "model", utils.SanitizeForLog(modelRef), "error", err) r.err = err } close(runDone) @@ -264,13 +264,13 @@ func (r *runner) terminate() { // Close the proxy's log. if err := r.proxyLog.Close(); err != nil { - r.log.Warn(fmt.Sprintf("Unable to close reverse proxy log writer: %v", err)) + r.log.Warn("Unable to close reverse proxy log writer", "error", err) } if r.openAIRecorder != nil { r.openAIRecorder.RemoveModel(r.model) } else { - r.log.Warn(fmt.Sprintf("OpenAI recorder is nil for model %s", r.model)) + r.log.Warn("OpenAI recorder is nil for model", "model", r.model) } } diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index 575c2c33e..783520f09 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -134,8 +134,8 @@ func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.B if sglangBackend, ok := s.backends[sglang.Name]; ok && sglangBackend != nil { return sglangBackend } - s.log.Warn(fmt.Sprintf("Model %s is in safetensors format but no compatible backend is available. 
"+ - "Backend %s may not support this format and could fail at runtime.", utils.SanitizeForLog(modelRef), backend.Name())) + s.log.Warn("Model is in safetensors format but no compatible backend is available", + "model", utils.SanitizeForLog(modelRef), "backend", backend.Name()) } return backend @@ -204,7 +204,7 @@ func (s *Scheduler) GetAllActiveRunners() []metrics.ActiveRunner { for _, backend := range runningBackends { mode, ok := inference.ParseBackendMode(backend.Mode) if !ok { - s.log.Warn(fmt.Sprintf("Unknown backend mode %q, defaulting to completion.", backend.Mode)) + s.log.Warn("Unknown backend mode, defaulting to completion", "mode", backend.Mode) } // Find the runner slot for this backend/model combination // We iterate through all runners since we don't know the draftModelID @@ -212,7 +212,7 @@ func (s *Scheduler) GetAllActiveRunners() []metrics.ActiveRunner { if key.backend == backend.BackendName && key.modelID == backend.ModelName && key.mode == mode { socket, err := RunnerSocketPath(runnerInfo.slot) if err != nil { - s.log.Warn(fmt.Sprintf("Failed to get socket path for runner %s/%s (%s): %v", backend.BackendName, backend.ModelName, key.modelID, err)) + s.log.Warn("Failed to get socket path for runner", "backend", backend.BackendName, "model", backend.ModelName, "modelID", key.modelID, "error", err) continue } @@ -244,7 +244,7 @@ func (s *Scheduler) GetLlamaCppSocket() (string, error) { if backend.BackendName == llamacpp.Name { mode, ok := inference.ParseBackendMode(backend.Mode) if !ok { - s.log.Warn(fmt.Sprintf("Unknown backend mode %q, defaulting to completion.", backend.Mode)) + s.log.Warn("Unknown backend mode, defaulting to completion", "mode", backend.Mode) } // Find the runner slot for this backend/model combination // We iterate through all runners since we don't know the draftModelID @@ -334,7 +334,7 @@ func (s *Scheduler) ConfigureRunner(ctx context.Context, backend inference.Backe // Set the runner configuration if err :=
s.loader.setRunnerConfig(ctx, backend.Name(), modelID, mode, runnerConfig); err != nil { - s.log.Warn(fmt.Sprintf("Failed to configure %s runner for %s (%s): %s", backend.Name(), utils.SanitizeForLog(req.Model, -1), modelID, err)) + s.log.Warn("Failed to configure runner", "backend", backend.Name(), "model", utils.SanitizeForLog(req.Model, -1), "modelID", modelID, "error", err) return nil, err } diff --git a/pkg/metrics/aggregated_handler.go b/pkg/metrics/aggregated_handler.go index 06993d211..442c447c0 100644 --- a/pkg/metrics/aggregated_handler.go +++ b/pkg/metrics/aggregated_handler.go @@ -65,7 +65,7 @@ func (h *AggregatedMetricsHandler) collectAndAggregateMetrics(ctx context.Contex families, err := h.fetchRunnerMetrics(ctx, runner) if err != nil { - h.log.Warn(fmt.Sprintf("Failed to fetch metrics from runner %s/%s: %v", runner.BackendName, runner.ModelName, err)) + h.log.Warn("Failed to fetch metrics from runner", "backend", runner.BackendName, "model", runner.ModelName, "error", err) return } @@ -165,7 +165,7 @@ func (h *AggregatedMetricsHandler) writeAggregatedMetrics(w http.ResponseWriter, encoder := expfmt.NewEncoder(w, expfmt.NewFormat(expfmt.TypeTextPlain)) for _, family := range families { if err := encoder.Encode(family); err != nil { - h.log.Error(fmt.Sprintf("Failed to encode metric family %s: %v", *family.Name, err)) + h.log.Error("Failed to encode metric family", "family", *family.Name, "error", err) continue } } diff --git a/pkg/metrics/openai_recorder.go b/pkg/metrics/openai_recorder.go index e2acb179f..de6c146a4 100644 --- a/pkg/metrics/openai_recorder.go +++ b/pkg/metrics/openai_recorder.go @@ -205,7 +205,7 @@ func (r *OpenAIRecorder) truncateBase64Data(data string) string { func (r *OpenAIRecorder) SetConfigForModel(model string, config *inference.BackendConfiguration) { if config == nil { - r.log.Warn(fmt.Sprintf("SetConfigForModel called with nil config for
model", "model", model) return } @@ -399,9 +399,9 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter return } } - r.log.Error(fmt.Sprintf("Matching request (id=%s) not found for model %s - %d\n%s", id, modelID, statusCode, response)) + r.log.Error("Matching request not found for model", "id", id, "model", modelID, "statusCode", statusCode, "response", response) } else { - r.log.Error(fmt.Sprintf("Model %s not found in records - %d\n%s", modelID, statusCode, response)) + r.log.Error("Model not found in records", "model", modelID, "statusCode", statusCode, "response", response) } } @@ -717,7 +717,7 @@ func (r *OpenAIRecorder) handleStreamingRequests(w http.ResponseWriter, req *htt // Send heartbeat to establish connection. if _, err := fmt.Fprintf(w, "event: connected\ndata: {\"status\": \"connected\"}\n\n"); err != nil { - r.log.Error(fmt.Sprintf("Failed to write connected event to response: %v", err)) + r.log.Error("Failed to write connected event to response", "error", err) } flusher.Flush() @@ -738,17 +738,17 @@ func (r *OpenAIRecorder) handleStreamingRequests(w http.ResponseWriter, req *htt // Send as SSE event. 
jsonData, err := json.Marshal(modelRecords) if err != nil { - r.log.Error(fmt.Sprintf("Failed to marshal record for streaming: %v", err)) + r.log.Error("Failed to marshal record for streaming", "error", err) errorMsg := fmt.Sprintf(`{"error": "Failed to marshal record: %v"}`, err) if _, writeErr := fmt.Fprintf(w, "event: error\ndata: %s\n\n", errorMsg); writeErr != nil { - r.log.Error(fmt.Sprintf("Failed to write error event to response: %v", writeErr)) + r.log.Error("Failed to write error event to response", "error", writeErr) } flusher.Flush() continue } if _, err := fmt.Fprintf(w, "event: new_request\ndata: %s\n\n", jsonData); err != nil { - r.log.Error(fmt.Sprintf("Failed to write new_request event to response: %v", err)) + r.log.Error("Failed to write new_request event to response", "error", err) } flusher.Flush() @@ -841,14 +841,14 @@ func (r *OpenAIRecorder) sendExistingRecords(w http.ResponseWriter, model string }} jsonData, err := json.Marshal(singleRecord) if err != nil { - r.log.Error(fmt.Sprintf("Failed to marshal existing record for streaming: %v", err)) + r.log.Error("Failed to marshal existing record for streaming", "error", err) errorMsg := fmt.Sprintf(`{"error": "Failed to marshal existing record: %v"}`, err) if _, writeErr := fmt.Fprintf(w, "event: error\ndata: %s\n\n", errorMsg); writeErr != nil { - r.log.Error(fmt.Sprintf("Failed to write error event to response: %v", writeErr)) + r.log.Error("Failed to write error event to response", "error", writeErr) } } else { if _, writeErr := fmt.Fprintf(w, "event: existing_request\ndata: %s\n\n", jsonData); writeErr != nil { - r.log.Error(fmt.Sprintf("Failed to write existing_request event to response: %v", writeErr)) + r.log.Error("Failed to write existing_request event to response", "error", writeErr) } } } @@ -863,8 +863,8 @@ func (r *OpenAIRecorder) RemoveModel(model string) { if _, exists := r.records[modelID]; exists { delete(r.records, modelID) - r.log.Info(fmt.Sprintf("Removed records for model: 
%s", modelID)) + r.log.Info("Removed records for model", "model", modelID) } else { - r.log.Warn(fmt.Sprintf("No records found for model: %s", modelID)) + r.log.Warn("No records found for model", "model", modelID) } } diff --git a/pkg/metrics/openai_recorder_test.go b/pkg/metrics/openai_recorder_test.go index 00b6d167c..0f8e4f10f 100644 --- a/pkg/metrics/openai_recorder_test.go +++ b/pkg/metrics/openai_recorder_test.go @@ -1,11 +1,9 @@ package metrics import ( - "fmt" "encoding/json" - "testing" - "log/slog" + "testing" "github.com/docker/model-runner/pkg/inference/models" ) @@ -142,18 +140,18 @@ func TestTruncateMediaFields(t *testing.T) { if inputErr != nil { // For invalid JSON inputs, verify it's returned unchanged if resultStr != tt.expected { - t.Error(fmt.Sprintf("Invalid JSON should be returned unchanged. Expected %q, got %q", tt.expected, resultStr)) + t.Errorf("Invalid JSON should be returned unchanged. Expected %q, got %q", tt.expected, resultStr) } } else { // For valid JSON inputs, verify output is still valid JSON var resultJSON interface{} if err := json.Unmarshal(result, &resultJSON); err != nil { - t.Error(fmt.Sprintf("Result should be valid JSON, but got error: %v", err)) + t.Errorf("Result should be valid JSON, but got error: %v", err) } // Also check the content matches expected if resultStr != tt.expected { - t.Error(fmt.Sprintf("Expected result %q, but got %q", tt.expected, resultStr)) + t.Errorf("Expected result %q, but got %q", tt.expected, resultStr) } } }) @@ -201,7 +199,7 @@ func TestTruncateBase64Data(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := recorder.truncateBase64Data(tt.input) if result != tt.expected { - t.Error(fmt.Sprintf("Expected %q, got %q", tt.expected, result)) + t.Errorf("Expected %q, got %q", tt.expected, result) } }) } diff --git a/pkg/metrics/scheduler_proxy.go b/pkg/metrics/scheduler_proxy.go index a52a730fe..df86029b5 100644 --- a/pkg/metrics/scheduler_proxy.go +++ b/pkg/metrics/scheduler_proxy.go @@ 
-1,7 +1,6 @@ package metrics import ( - "fmt" "io" "net" "net/http" @@ -41,7 +40,7 @@ func (h *SchedulerMetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Reque // Get the socket path for the active llama.cpp runner socket, err := h.scheduler.GetLlamaCppSocket() if err != nil { - h.log.Error(fmt.Sprintf("Failed to get llama.cpp socket: %v", err)) + h.log.Error("Failed to get llama.cpp socket", "error", err) http.Error(w, "Metrics endpoint not available", http.StatusServiceUnavailable) return } @@ -59,7 +58,7 @@ func (h *SchedulerMetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Reque // Create request to the backend metrics endpoint req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, "http://unix/metrics", http.NoBody) if err != nil { - h.log.Error(fmt.Sprintf("Failed to create metrics request: %v", err)) + h.log.Error("Failed to create metrics request", "error", err) http.Error(w, "Failed to create metrics request", http.StatusInternalServerError) return } @@ -74,7 +73,7 @@ func (h *SchedulerMetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Reque // Make the request to the backend resp, err := client.Do(req) if err != nil { - h.log.Error(fmt.Sprintf("Failed to fetch metrics from backend: %v", err)) + h.log.Error("Failed to fetch metrics from backend", "error", err) http.Error(w, "Backend metrics unavailable", http.StatusServiceUnavailable) return } @@ -92,7 +91,7 @@ func (h *SchedulerMetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Reque // Copy response body if _, err := io.Copy(w, resp.Body); err != nil { - h.log.Error(fmt.Sprintf("Failed to copy metrics response: %v", err)) + h.log.Error("Failed to copy metrics response", "error", err) return } diff --git a/pkg/ollama/http_handler.go b/pkg/ollama/http_handler.go index 2f324ea0b..45ff4084c 100644 --- a/pkg/ollama/http_handler.go +++ b/pkg/ollama/http_handler.go @@ -62,7 +62,7 @@ func NewHTTPHandler(log logging.Logger, scheduler *scheduling.Scheduler, schedul func (h 
*HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { safeMethod := utils.SanitizeForLog(r.Method, -1) safePath := utils.SanitizeForLog(r.URL.Path, -1) - h.log.Info(fmt.Sprintf("Ollama API request: %s %s", safeMethod, safePath)) + h.log.Info("Ollama API request", "method", safeMethod, "path", safePath) h.httpHandler.ServeHTTP(w, r) } @@ -145,14 +145,14 @@ func (w *ollamaProgressWriter) Write(p []byte) (n int, err error) { return w.writer.Write(p) } // Unrecognized type, pass through to avoid losing information - w.log.Warn(fmt.Sprintf("Unknown progress message type: %s", msg.Type)) + w.log.Warn("Unknown progress message type", "message", msg.Type) return w.writer.Write(p) } // Marshal and write ollama format data, err := json.Marshal(ollamaMsg) if err != nil { - w.log.Warn(fmt.Sprintf("Failed to marshal ollama progress: %v", err)) + w.log.Warn("Failed to marshal ollama progress", "error", err) return w.writer.Write(p) } @@ -187,7 +187,7 @@ func (h *HTTPHandler) handleVersion(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Error(fmt.Sprintf("Failed to encode response: %v", err)) + h.log.Error("Failed to encode response", "error", err) } } @@ -196,7 +196,7 @@ func (h *HTTPHandler) handleListModels(w http.ResponseWriter, r *http.Request) { // Get models from the model manager modelsList, err := h.modelManager.List() if err != nil { - h.log.Error(fmt.Sprintf("Failed to list models: %v", err)) + h.log.Error("Failed to list models", "error", err) http.Error(w, "Failed to list models", http.StatusInternalServerError) return } @@ -243,7 +243,7 @@ func (h *HTTPHandler) handleListModels(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Error(fmt.Sprintf("Failed to encode response: %v", err)) + h.log.Error("Failed to encode response", "error", 
err) } } @@ -260,7 +260,7 @@ func (h *HTTPHandler) handlePS(w http.ResponseWriter, r *http.Request) { // Get model details to populate additional fields model, err := h.modelManager.GetLocal(backend.ModelName) if err != nil { - h.log.Warn(fmt.Sprintf("Failed to get model details for %s: %v", backend.ModelName, err)) + h.log.Warn("Failed to get model details", "model", backend.ModelName, "error", err) // Still add the model with basic info models = append(models, PSModel{ Name: backend.ModelName, @@ -303,7 +303,7 @@ func (h *HTTPHandler) handlePS(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Error(fmt.Sprintf("Failed to encode response: %v", err)) + h.log.Error("Failed to encode response", "error", err) } } @@ -324,7 +324,7 @@ func (h *HTTPHandler) handleShowModel(w http.ResponseWriter, r *http.Request) { // Get model details model, err := h.modelManager.GetLocal(modelName) if err != nil { - h.log.Error(fmt.Sprintf("Failed to get model: %v", err)) + h.log.Error("Failed to get model", "error", err) http.Error(w, fmt.Sprintf("Model not found: %v", err), http.StatusNotFound) return } @@ -332,7 +332,7 @@ func (h *HTTPHandler) handleShowModel(w http.ResponseWriter, r *http.Request) { // Get config config, err := model.Config() if err != nil { - h.log.Error(fmt.Sprintf("Failed to get model config: %v", err)) + h.log.Error("Failed to get model config", "error", err) http.Error(w, fmt.Sprintf("Failed to get model config: %v", err), http.StatusInternalServerError) return } @@ -350,7 +350,7 @@ func (h *HTTPHandler) handleShowModel(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Error(fmt.Sprintf("Failed to encode response: %v", err)) + h.log.Error("Failed to encode response", "error", err) } } @@ -417,7 +417,7 @@ func (h *HTTPHandler) 
configureModel(ctx context.Context, modelName string, opti if hasContextSize || reasoningBudget != nil || hasKeepAlive { sanitizedModelName := utils.SanitizeForLog(modelName, -1) - h.log.Info(fmt.Sprintf("configureModel: configuring model %s", sanitizedModelName)) + h.log.Info("configureModel: configuring model", "model", sanitizedModelName) configureRequest := scheduling.ConfigureRequest{ Model: modelName, } @@ -434,12 +434,12 @@ func (h *HTTPHandler) configureModel(ctx context.Context, modelName string, opti if err == nil { configureRequest.KeepAlive = &ka } else { - h.log.Warn(fmt.Sprintf("configureModel: invalid keep_alive %q: %v", keepAlive, err)) + h.log.Warn("configureModel: invalid keep_alive", "keep_alive", keepAlive, "error", err) } } _, err := h.scheduler.ConfigureRunner(ctx, nil, configureRequest, userAgent) if err != nil { - h.log.Warn(fmt.Sprintf("configureModel: failed to configure model %s: %v", sanitizedModelName, err)) + h.log.Warn("configureModel: failed to configure model", "model", sanitizedModelName, "error", err) } } } @@ -456,7 +456,7 @@ func (h *HTTPHandler) handleGenerate(w http.ResponseWriter, r *http.Request) { var req GenerateRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - h.log.Error(fmt.Sprintf("handleGenerate: failed to decode request: %v", err)) + h.log.Error("handleGenerate: failed to decode request", "error", err) http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) return } @@ -506,7 +506,7 @@ func (h *HTTPHandler) handleGenerate(w http.ResponseWriter, r *http.Request) { func (h *HTTPHandler) unloadModel(ctx context.Context, w http.ResponseWriter, modelName string) { // Sanitize user input before logging to prevent log injection sanitizedModelName := utils.SanitizeForLog(modelName, -1) - h.log.Info(fmt.Sprintf("unloadModel: unloading model %s", sanitizedModelName)) + h.log.Info("unloadModel: unloading model", "model", sanitizedModelName) // Create an unload request for the scheduler 
unloadReq := map[string]interface{}{ @@ -516,19 +516,19 @@ func (h *HTTPHandler) unloadModel(ctx context.Context, w http.ResponseWriter, mo // Marshal the unload request reqBody, err := json.Marshal(unloadReq) if err != nil { - h.log.Error(fmt.Sprintf("unloadModel: failed to marshal request: %v", err)) + h.log.Error("unloadModel: failed to marshal request", "error", err) http.Error(w, fmt.Sprintf("Failed to marshal request: %v", err), http.StatusInternalServerError) return } // Sanitize the user-provided request body before logging to avoid log injection safeReqBody := utils.SanitizeForLog(string(reqBody), -1) - h.log.Info(fmt.Sprintf("unloadModel: sending POST /engines/unload with body: %s", safeReqBody)) + h.log.Info("unloadModel: sending POST /engines/unload with body", "body", safeReqBody) // Create a new request to the scheduler newReq, err := http.NewRequestWithContext(ctx, http.MethodPost, "/engines/unload", strings.NewReader(string(reqBody))) if err != nil { - h.log.Error(fmt.Sprintf("unloadModel: failed to create request: %v", err)) + h.log.Error("unloadModel: failed to create request", "error", err) http.Error(w, fmt.Sprintf("Failed to create request: %v", err), http.StatusInternalServerError) return } @@ -544,7 +544,7 @@ func (h *HTTPHandler) unloadModel(ctx context.Context, w http.ResponseWriter, mo // Forward to scheduler HTTP handler h.schedulerHTTP.ServeHTTP(respRecorder, newReq) - h.log.Info(fmt.Sprintf("unloadModel: scheduler response status=%d, body=%s", respRecorder.statusCode, respRecorder.body.String())) + h.log.Info("unloadModel: scheduler response", "status", respRecorder.statusCode, "body", respRecorder.body.String()) // Return the response status w.WriteHeader(respRecorder.statusCode) @@ -574,7 +574,7 @@ func (h *HTTPHandler) handleDelete(w http.ResponseWriter, r *http.Request) { } sanitizedModelName := utils.SanitizeForLog(modelName, -1) - h.log.Info(fmt.Sprintf("handleDelete: deleting model %s", sanitizedModelName)) + 
h.log.Info("handleDelete: deleting model", "model", sanitizedModelName) // First, unload the model from memory unloadReq := map[string]interface{}{ @@ -583,14 +583,14 @@ func (h *HTTPHandler) handleDelete(w http.ResponseWriter, r *http.Request) { reqBody, err := json.Marshal(unloadReq) if err != nil { - h.log.Error(fmt.Sprintf("handleDelete: failed to marshal unload request: %v", err)) + h.log.Error("handleDelete: failed to marshal unload request", "error", err) http.Error(w, fmt.Sprintf("Failed to marshal request: %v", err), http.StatusInternalServerError) return } newReq, err := http.NewRequestWithContext(ctx, http.MethodPost, "/engines/unload", strings.NewReader(string(reqBody))) if err != nil { - h.log.Error(fmt.Sprintf("handleDelete: failed to create unload request: %v", err)) + h.log.Error("handleDelete: failed to create unload request", "error", err) http.Error(w, fmt.Sprintf("Failed to create request: %v", err), http.StatusInternalServerError) return } @@ -603,12 +603,12 @@ func (h *HTTPHandler) handleDelete(w http.ResponseWriter, r *http.Request) { } h.schedulerHTTP.ServeHTTP(respRecorder, newReq) - h.log.Info(fmt.Sprintf("handleDelete: unload response status=%d", respRecorder.statusCode)) + h.log.Info("handleDelete: unload response", "status", respRecorder.statusCode) // Check if unload succeeded before deleting from storage if respRecorder.statusCode < 200 || respRecorder.statusCode >= 300 { sanitizedBody := utils.SanitizeForLog(respRecorder.body.String(), -1) - h.log.Error(fmt.Sprintf("handleDelete: unload failed for model %s with status=%d, body=%q", sanitizedModelName, respRecorder.statusCode, sanitizedBody)) + h.log.Error("handleDelete: unload failed for model", "model", sanitizedModelName, "status", respRecorder.statusCode, "body", sanitizedBody) http.Error( w, fmt.Sprintf("Failed to unload model: scheduler returned status %d", respRecorder.statusCode), @@ -620,12 +620,12 @@ func (h *HTTPHandler) handleDelete(w http.ResponseWriter, r *http.Request) 
{ // Then delete the model from storage if _, err := h.modelManager.Delete(modelName, false); err != nil { sanitizedErr := utils.SanitizeForLog(err.Error(), -1) - h.log.Error(fmt.Sprintf("handleDelete: failed to delete model %s: %v", sanitizedModelName, sanitizedErr)) + h.log.Error("handleDelete: failed to delete model", "model", sanitizedModelName, "error", sanitizedErr) http.Error(w, fmt.Sprintf("Failed to delete model: %v", sanitizedErr), http.StatusInternalServerError) return } - h.log.Info(fmt.Sprintf("handleDelete: successfully deleted model %s", sanitizedModelName)) + h.log.Info("handleDelete: successfully deleted model", "model", sanitizedModelName) // Return success response in Ollama format (empty JSON object) w.Header().Set("Content-Type", "application/json") @@ -659,7 +659,7 @@ func (h *HTTPHandler) handlePull(w http.ResponseWriter, r *http.Request) { // Call the model manager's Pull method with the wrapped writer if err := h.modelManager.Pull(modelName, "", r, ollamaWriter); err != nil { - h.log.Error(fmt.Sprintf("Failed to pull model: %s", utils.SanitizeForLog(err.Error(), -1))) + h.log.Error("Failed to pull model", "error", utils.SanitizeForLog(err.Error(), -1)) // Send error in Ollama JSON format errorResponse := ollamaPullStatus{ @@ -671,7 +671,7 @@ func (h *HTTPHandler) handlePull(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusInternalServerError) if err := json.NewEncoder(w).Encode(errorResponse); err != nil { - h.log.Error(fmt.Sprintf("failed to encode response: %v", err)) + h.log.Error("failed to encode response", "error", err) } } else { // Headers already sent - write error as JSON line @@ -1048,7 +1048,7 @@ func (s *streamingChatResponseWriter) Write(data []byte) (int, error) { // Parse OpenAI chunk using proper struct var chunk openAIChatStreamChunk if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil { - s.log.Warn(fmt.Sprintf("Failed to parse OpenAI chat 
stream chunk: %v", err)) + s.log.Warn("Failed to parse OpenAI chat stream chunk", "error", err) continue } @@ -1170,7 +1170,7 @@ func (s *streamingGenerateResponseWriter) Write(data []byte) (int, error) { // Parse OpenAI chunk using proper struct var chunk openAIChatStreamChunk if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil { - s.log.Warn(fmt.Sprintf("Failed to parse OpenAI chat stream chunk: %v", err)) + s.log.Warn("Failed to parse OpenAI chat stream chunk", "error", err) continue } @@ -1217,7 +1217,7 @@ func (h *HTTPHandler) convertChatResponse(w http.ResponseWriter, respRecorder *r // Convert to Ollama error format (simple string) w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(map[string]string{"error": openAIErr.Error.Message}); err != nil { - h.log.Error(fmt.Sprintf("failed to encode response: %v", err)) + h.log.Error("failed to encode response", "error", err) } } else { // Fallback: return raw error body @@ -1229,7 +1229,7 @@ func (h *HTTPHandler) convertChatResponse(w http.ResponseWriter, respRecorder *r // Parse OpenAI response using proper struct var openAIResp openAIChatResponse if err := json.Unmarshal([]byte(respRecorder.body.String()), &openAIResp); err != nil { - h.log.Error(fmt.Sprintf("Failed to parse OpenAI response: %v", err)) + h.log.Error("Failed to parse OpenAI response", "error", err) http.Error(w, "Failed to parse response", http.StatusInternalServerError) return } @@ -1259,7 +1259,7 @@ func (h *HTTPHandler) convertChatResponse(w http.ResponseWriter, respRecorder *r w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Error(fmt.Sprintf("Failed to encode response: %v", err)) + h.log.Error("Failed to encode response", "error", err) } } @@ -1304,7 +1304,7 @@ func (h *HTTPHandler) convertGenerateResponse(w http.ResponseWriter, respRecorde // Convert to Ollama error format (simple string) w.Header().Set("Content-Type", 
"application/json") if err := json.NewEncoder(w).Encode(map[string]string{"error": openAIErr.Error.Message}); err != nil { - h.log.Error(fmt.Sprintf("failed to encode response: %v", err)) + h.log.Error("failed to encode response", "error", err) } } else { // Fallback: return raw error body @@ -1316,7 +1316,7 @@ func (h *HTTPHandler) convertGenerateResponse(w http.ResponseWriter, respRecorde // Parse OpenAI chat response (since we're now using chat completions endpoint) var openAIResp openAIChatResponse if err := json.Unmarshal([]byte(respRecorder.body.String()), &openAIResp); err != nil { - h.log.Error(fmt.Sprintf("Failed to parse OpenAI chat response: %v", err)) + h.log.Error("Failed to parse OpenAI chat response", "error", err) http.Error(w, "Failed to parse response", http.StatusInternalServerError) return } @@ -1340,6 +1340,6 @@ func (h *HTTPHandler) convertGenerateResponse(w http.ResponseWriter, respRecorde w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - h.log.Error(fmt.Sprintf("Failed to encode response: %v", err)) + h.log.Error("Failed to encode response", "error", err) } } diff --git a/pkg/responses/handler.go b/pkg/responses/handler.go index bb4c50cd8..298e98735 100644 --- a/pkg/responses/handler.go +++ b/pkg/responses/handler.go @@ -55,7 +55,7 @@ func NewHTTPHandler(log logging.Logger, schedulerHTTP http.Handler, allowedOrigi func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { cleanPath := strings.ReplaceAll(r.URL.Path, "\n", "") cleanPath = strings.ReplaceAll(cleanPath, "\r", "") - h.log.Info(fmt.Sprintf("Responses API request: %s %s", r.Method, cleanPath)) + h.log.Info("Responses API request", "method", r.Method, "path", cleanPath) h.httpHandler.ServeHTTP(w, r) } @@ -305,7 +305,7 @@ func (h *HTTPHandler) sendJSON(w http.ResponseWriter, statusCode int, data inter w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) if err := 
json.NewEncoder(w).Encode(data); err != nil { - h.log.Error(fmt.Sprintf("Failed to encode JSON response: %v", err)) + h.log.Error("Failed to encode JSON response", "error", err) } } diff --git a/pkg/responses/handler_test.go b/pkg/responses/handler_test.go index 7e586f39d..1e460dd78 100644 --- a/pkg/responses/handler_test.go +++ b/pkg/responses/handler_test.go @@ -1,16 +1,14 @@ package responses import ( - "fmt" - "log/slog" "bytes" "encoding/json" "io" + "log/slog" "net/http" "net/http/httptest" "strings" "testing" - ) // mockSchedulerHTTP is a mock scheduler that returns predefined responses. @@ -86,28 +84,28 @@ func TestHandler_CreateResponse_NonStreaming(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { - t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusOK)) + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) } var result Response if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - t.Error(fmt.Sprintf("failed to decode response: %v", err)) + t.Errorf("failed to decode response: %v", err) } if result.Object != "response" { - t.Error(fmt.Sprintf("object = %s, want response", result.Object)) + t.Errorf("object = %s, want response", result.Object) } if result.Model != "gpt-4" { - t.Error(fmt.Sprintf("model = %s, want gpt-4", result.Model)) + t.Errorf("model = %s, want gpt-4", result.Model) } if result.Status != StatusCompleted { - t.Error(fmt.Sprintf("status = %s, want %s", result.Status, StatusCompleted)) + t.Errorf("status = %s, want %s", result.Status, StatusCompleted) } if result.OutputText != "Hello! How can I help you?" { - t.Error(fmt.Sprintf("output_text = %s, want Hello! How can I help you?", result.OutputText)) + t.Errorf("output_text = %s, want Hello! 
How can I help you?", result.OutputText) } if !strings.HasPrefix(result.ID, "resp_") { - t.Error(fmt.Sprintf("id should start with resp_, got %s", result.ID)) + t.Errorf("id should start with resp_, got %s", result.ID) } } @@ -125,7 +123,7 @@ func TestHandler_CreateResponse_MissingModel(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusBadRequest { - t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusBadRequest)) + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusBadRequest) } var errResp map[string]interface{} @@ -148,7 +146,7 @@ func TestHandler_CreateResponse_InvalidJSON(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusBadRequest { - t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusBadRequest)) + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusBadRequest) } } @@ -171,19 +169,19 @@ func TestHandler_GetResponse(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { - t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusOK)) + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) } var result Response if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - t.Error(fmt.Sprintf("failed to decode response: %v", err)) + t.Errorf("failed to decode response: %v", err) } if result.ID != "resp_test123" { - t.Error(fmt.Sprintf("id = %s, want resp_test123", result.ID)) + t.Errorf("id = %s, want resp_test123", result.ID) } if result.OutputText != "Test output" { - t.Error(fmt.Sprintf("output_text = %s, want Test output", result.OutputText)) + t.Errorf("output_text = %s, want Test output", result.OutputText) } } @@ -199,7 +197,7 @@ func TestHandler_GetResponse_NotFound(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusNotFound { - t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusNotFound)) + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) } } @@ 
-220,7 +218,7 @@ func TestHandler_DeleteResponse(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { - t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusOK)) + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) } // Verify it's deleted @@ -242,7 +240,7 @@ func TestHandler_DeleteResponse_NotFound(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusNotFound { - t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusNotFound)) + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) } } @@ -300,14 +298,14 @@ func TestHandler_CreateResponse_WithPreviousResponse(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - t.Error(fmt.Sprintf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body)) + t.Errorf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body) } var result Response json.NewDecoder(resp.Body).Decode(&result) if result.PreviousResponseID == nil || *result.PreviousResponseID != "resp_prev123" { - t.Error(fmt.Sprintf("previous_response_id = %v, want resp_prev123", result.PreviousResponseID)) + t.Errorf("previous_response_id = %v, want resp_prev123", result.PreviousResponseID) } } @@ -338,14 +336,14 @@ func TestHandler_CreateResponse_UpstreamError(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusInternalServerError { - t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusInternalServerError)) + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusInternalServerError) } var result Response json.NewDecoder(resp.Body).Decode(&result) if result.Status != StatusFailed { - t.Error(fmt.Sprintf("status = %s, want %s", result.Status, StatusFailed)) + t.Errorf("status = %s, want %s", result.Status, StatusFailed) } if result.Error == nil { t.Error("expected error to be set") @@ -374,7 +372,7 @@ func 
TestHandler_CreateResponse_UpstreamError_NonJSONBody(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusInternalServerError { - t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusInternalServerError)) + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusInternalServerError) } var result Response @@ -382,7 +380,7 @@ func TestHandler_CreateResponse_UpstreamError_NonJSONBody(t *testing.T) { // Assert: non-streaming error handling falls back correctly if result.Status != StatusFailed { - t.Error(fmt.Sprintf("status = %s, want %s", result.Status, StatusFailed)) + t.Errorf("status = %s, want %s", result.Status, StatusFailed) } if result.Error == nil { @@ -390,11 +388,11 @@ func TestHandler_CreateResponse_UpstreamError_NonJSONBody(t *testing.T) { } if result.Error.Code != "upstream_error" { - t.Error(fmt.Sprintf("error.code = %v, want upstream_error", result.Error.Code)) + t.Errorf("error.code = %v, want upstream_error", result.Error.Code) } if !strings.Contains(result.Error.Message, "upstream exploded in a non-json way") { - t.Error(fmt.Sprintf("error.message = %q, want to contain raw upstream body", result.Error.Message)) + t.Errorf("error.message = %q, want to contain raw upstream body", result.Error.Message) } } @@ -427,19 +425,19 @@ func TestHandler_CreateResponse_Streaming(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { - t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusOK)) + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) } // Check content type is SSE contentType := resp.Header.Get("Content-Type") if !strings.Contains(contentType, "text/event-stream") { - t.Error(fmt.Sprintf("Content-Type = %s, want text/event-stream", contentType)) + t.Errorf("Content-Type = %s, want text/event-stream", contentType) } // Read all body body, err := io.ReadAll(resp.Body) if err != nil { - t.Error(fmt.Sprintf("failed to read body: %v", err)) + t.Errorf("failed to read 
body: %v", err) } bodyStr := string(body) @@ -518,7 +516,7 @@ func TestHandler_CreateResponse_WithTools(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - t.Error(fmt.Sprintf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body)) + t.Errorf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body) } var result Response @@ -542,10 +540,10 @@ func TestHandler_CreateResponse_WithTools(t *testing.T) { } if funcCall.Name != "get_weather" { - t.Error(fmt.Sprintf("function name = %s, want get_weather", funcCall.Name)) + t.Errorf("function name = %s, want get_weather", funcCall.Name) } if funcCall.CallID != "call_abc123" { - t.Error(fmt.Sprintf("call_id = %s, want call_abc123", funcCall.CallID)) + t.Errorf("call_id = %s, want call_abc123", funcCall.CallID) } } @@ -590,7 +588,7 @@ func TestHandler_ResponsePersistence(t *testing.T) { json.NewDecoder(w2.Result().Body).Decode(&getResult) if getResult.ID != createResult.ID { - t.Error(fmt.Sprintf("IDs don't match: %s vs %s", getResult.ID, createResult.ID)) + t.Errorf("IDs don't match: %s vs %s", getResult.ID, createResult.ID) } } @@ -624,20 +622,20 @@ func TestHandler_CreateResponse_Streaming_Persistence(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { - t.Error(fmt.Sprintf("status = %d, want %d", resp.StatusCode, http.StatusOK)) + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) } // Verify that the StreamingResponseWriter persisted a coherent Response in the store memStore := handler.store if memStore.Count() != 1 { - t.Error(fmt.Sprintf("expected exactly one response in store, got %d", memStore.Count())) + t.Errorf("expected exactly one response in store, got %d", memStore.Count()) } // Get the response ID from the store responseIDs := memStore.GetResponseIDs() if len(responseIDs) != 1 { - t.Error(fmt.Sprintf("expected exactly one response ID in store, got %d", len(responseIDs))) + 
t.Errorf("expected exactly one response ID in store, got %d", len(responseIDs)) } // Retrieve the response using the public API @@ -648,12 +646,12 @@ func TestHandler_CreateResponse_Streaming_Persistence(t *testing.T) { // Status should be completed after streaming finishes if persistedResp.Status != StatusCompleted { - t.Error(fmt.Sprintf("persisted response status = %s, want %s", persistedResp.Status, StatusCompleted)) + t.Errorf("persisted response status = %s, want %s", persistedResp.Status, StatusCompleted) } // OutputText should match concatenated streamed chunks: "Hello" + "!" => "Hello!" if persistedResp.OutputText != "Hello!" { - t.Error(fmt.Sprintf("persisted response OutputText = %q, want %q", persistedResp.OutputText, "Hello!")) + t.Errorf("persisted response OutputText = %q, want %q", persistedResp.OutputText, "Hello!") } // There should be at least one OutputItem whose message content matches "Hello!" @@ -674,7 +672,7 @@ func TestHandler_CreateResponse_Streaming_Persistence(t *testing.T) { } } if !found { - t.Error(fmt.Sprintf("expected an OutputItem message with text %q in persisted response", "Hello!")) + t.Errorf("expected an OutputItem message with text %q in persisted response", "Hello!") } } From b7731483f2b2b30d76f01792892a77bb5b5f3ad7 Mon Sep 17 00:00:00 2001 From: Varun Chawla Date: Sat, 14 Feb 2026 23:05:45 -0800 Subject: [PATCH 3/4] Fix lint: gci import ordering and staticcheck S1038 Sort stdlib imports alphabetically in main.go, remove unused fmt import and trailing blank line in scheduler_test.go, and replace t.Error(fmt.Sprintf(...)) with t.Errorf(...) in test files. 
--- main.go | 2 +- main_test.go | 3 +-- pkg/inference/scheduling/scheduler_test.go | 4 +--- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/main.go b/main.go index 546d44ea2..29be6e19f 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,9 @@ package main import ( - "fmt" "context" "crypto/tls" + "fmt" "log/slog" "net" "net/http" diff --git a/main_test.go b/main_test.go index b9754e6f5..ac7a6e94a 100644 --- a/main_test.go +++ b/main_test.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "testing" "github.com/docker/model-runner/pkg/inference/backends/llamacpp" @@ -76,7 +75,7 @@ func TestCreateLlamaCppConfigFromEnv(t *testing.T) { llamaConfig, ok := config.(*llamacpp.Config) if !ok { - t.Error(fmt.Sprintf("Expected *llamacpp.Config, got %T", config)) + t.Errorf("Expected *llamacpp.Config, got %T", config) } if llamaConfig == nil { t.Fatal("Expected non-nil config") diff --git a/pkg/inference/scheduling/scheduler_test.go b/pkg/inference/scheduling/scheduler_test.go index e13bf5b70..58520278d 100644 --- a/pkg/inference/scheduling/scheduler_test.go +++ b/pkg/inference/scheduling/scheduler_test.go @@ -1,12 +1,10 @@ package scheduling import ( - "fmt" "log/slog" "net/http" "net/http/httptest" "testing" - ) func TestCors(t *testing.T) { @@ -39,7 +37,7 @@ func TestCors(t *testing.T) { httpHandler.ServeHTTP(w, req) if w.Code != http.StatusNoContent { - t.Error(fmt.Sprintf("Expected status code 204 for OPTIONS request, got %d", w.Code)) + t.Errorf("Expected status code 204 for OPTIONS request, got %d", w.Code) } }) } From b4a3247def91fa862e9e601448925b101688e252 Mon Sep 17 00:00:00 2001 From: Varun Chawla Date: Mon, 16 Feb 2026 13:17:38 -0800 Subject: [PATCH 4/4] Address reviewer feedback: restore Fatal exit behavior, fix structured logging - Add os.Exit(1) after log.Error where log.Fatal was originally used (prevents app from continuing with nil backends/listeners) - Convert all remaining fmt.Sprintf log calls to proper slog key-value pairs for structured logging 
benefits - Restore createLlamaCppConfigFromEnv Fatal exit behavior via exitFunc - Fix test assertions: restore t.Fatalf for setup errors, use t.Errorf for assertion checks - Fix gci import ordering (log/slog in correct alphabetical position) - Remove trailing newlines from log messages --- main.go | 70 +++-- main_test.go | 67 +++-- pkg/distribution/distribution/client_test.go | 257 +++++++++--------- .../distribution/normalize_test.go | 12 +- pkg/inference/backends/llamacpp/download.go | 2 +- pkg/inference/backends/llamacpp/llamacpp.go | 2 +- pkg/inference/models/handler_test.go | 36 +-- pkg/inference/scheduling/loader.go | 16 +- pkg/inference/scheduling/scheduler.go | 4 +- pkg/responses/handler_test.go | 19 +- 10 files changed, 256 insertions(+), 229 deletions(-) diff --git a/main.go b/main.go index 29be6e19f..e98e1944a 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package main import ( "context" "crypto/tls" - "fmt" "log/slog" "net" "net/http" @@ -54,6 +53,9 @@ var Log = log // testLog is a test-override logger used by createLlamaCppConfigFromEnv. var testLog = log +// exitFunc is the function called for fatal errors. Overridable in tests. 
+var exitFunc = os.Exit + func main() { ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() @@ -65,7 +67,8 @@ func main() { userHomeDir, err := os.UserHomeDir() if err != nil { - log.Error(fmt.Sprintf("Failed to get user home directory: %v", err)) + log.Error("Failed to get user home directory", "error", err) + os.Exit(1) } modelPath := os.Getenv("MODELS_PATH") @@ -118,18 +121,18 @@ func main() { modelManager, nil, ) - log.Info(fmt.Sprintf("LLAMA_SERVER_PATH: %s", llamaServerPath)) + log.Info("LLAMA_SERVER_PATH", "path", llamaServerPath) if vllmServerPath != "" { - log.Info(fmt.Sprintf("VLLM_SERVER_PATH: %s", vllmServerPath)) + log.Info("VLLM_SERVER_PATH", "path", vllmServerPath) } if sglangServerPath != "" { - log.Info(fmt.Sprintf("SGLANG_SERVER_PATH: %s", sglangServerPath)) + log.Info("SGLANG_SERVER_PATH", "path", sglangServerPath) } if mlxServerPath != "" { - log.Info(fmt.Sprintf("MLX_SERVER_PATH: %s", mlxServerPath)) + log.Info("MLX_SERVER_PATH", "path", mlxServerPath) } if vllmMetalServerPath != "" { - log.Info(fmt.Sprintf("VLLM_METAL_SERVER_PATH: %s", vllmMetalServerPath)) + log.Info("VLLM_METAL_SERVER_PATH", "path", vllmMetalServerPath) } // Create llama.cpp configuration from environment variables @@ -149,12 +152,14 @@ func main() { llamaCppConfig, ) if err != nil { - log.Error(fmt.Sprintf("unable to initialize %s backend: %v", llamacpp.Name, err)) + log.Error("Unable to initialize backend", "backend", llamacpp.Name, "error", err) + os.Exit(1) } vllmBackend, err := initVLLMBackend(log, modelManager, vllmServerPath) if err != nil { - log.Error(fmt.Sprintf("unable to initialize %s backend: %v", vllm.Name, err)) + log.Error("Unable to initialize backend", "backend", vllm.Name, "error", err) + os.Exit(1) } mlxBackend, err := mlx.New( @@ -165,7 +170,8 @@ func main() { mlxServerPath, ) if err != nil { - log.Error(fmt.Sprintf("unable to initialize %s backend: %v", mlx.Name, err)) + log.Error("Unable to 
initialize backend", "backend", mlx.Name, "error", err) + os.Exit(1) } sglangBackend, err := sglang.New( @@ -176,7 +182,8 @@ func main() { sglangServerPath, ) if err != nil { - log.Error(fmt.Sprintf("unable to initialize %s backend: %v", sglang.Name, err)) + log.Error("Unable to initialize backend", "backend", sglang.Name, "error", err) + os.Exit(1) } diffusersBackend, err := diffusers.New( @@ -188,7 +195,8 @@ func main() { ) if err != nil { - log.Error(fmt.Sprintf("unable to initialize diffusers backend: %v", err)) + log.Error("Unable to initialize backend", "backend", diffusers.Name, "error", err) + os.Exit(1) } var vllmMetalBackend inference.Backend @@ -200,7 +208,7 @@ func main() { vllmMetalServerPath, ) if err != nil { - log.Warn(fmt.Sprintf("Failed to initialize vllm-metal backend: %v", err)) + log.Warn("Failed to initialize vllm-metal backend", "error", err) } } @@ -310,7 +318,7 @@ func main() { if tcpPort != "" { // Use TCP port addr := ":" + tcpPort - log.Info(fmt.Sprintf("Listening on TCP port %s", tcpPort)) + log.Info("Listening on TCP port", "port", tcpPort) server.Addr = addr go func() { serverErrors <- server.ListenAndServe() @@ -319,12 +327,14 @@ func main() { // Use Unix socket if err := os.Remove(sockName); err != nil { if !os.IsNotExist(err) { - log.Error(fmt.Sprintf("Failed to remove existing socket: %v", err)) + log.Error("Failed to remove existing socket", "error", err) + os.Exit(1) } } ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: sockName, Net: "unix"}) if err != nil { - log.Error(fmt.Sprintf("Failed to listen on socket: %v", err)) + log.Error("Failed to listen on socket", "error", err) + os.Exit(1) } go func() { serverErrors <- server.Serve(ln) @@ -349,19 +359,22 @@ func main() { var err error certPath, keyPath, err = modeltls.EnsureCertificates("", "") if err != nil { - log.Error(fmt.Sprintf("Failed to ensure TLS certificates: %v", err)) + log.Error("Failed to ensure TLS certificates", "error", err) + os.Exit(1) } - 
log.Info(fmt.Sprintf("Using TLS certificate: %s", certPath)) - log.Info(fmt.Sprintf("Using TLS key: %s", keyPath)) + log.Info("Using TLS certificate", "path", certPath) + log.Info("Using TLS key", "path", keyPath) } else { log.Error("TLS enabled but no certificate provided and auto-cert is disabled") + os.Exit(1) } } // Load TLS configuration tlsConfig, err := modeltls.LoadTLSConfig(certPath, keyPath) if err != nil { - log.Error(fmt.Sprintf("Failed to load TLS configuration: %v", err)) + log.Error("Failed to load TLS configuration", "error", err) + os.Exit(1) } tlsServer = &http.Server{ @@ -371,7 +384,7 @@ func main() { ReadHeaderTimeout: 10 * time.Second, } - log.Info(fmt.Sprintf("Listening on TLS port %s", tlsPort)) + log.Info("Listening on TLS port", "port", tlsPort) go func() { // Use ListenAndServeTLS with empty strings since TLSConfig already has the certs ln, err := tls.Listen("tcp", tlsServer.Addr, tlsConfig) @@ -399,27 +412,27 @@ func main() { select { case err := <-serverErrors: if err != nil { - log.Error(fmt.Sprintf("Server error: %v", err)) + log.Error("Server error", "error", err) } case err := <-tlsServerErrorsChan: if err != nil { - log.Error(fmt.Sprintf("TLS server error: %v", err)) + log.Error("TLS server error", "error", err) } case <-ctx.Done(): log.Info("Shutdown signal received") log.Info("Shutting down the server") if err := server.Close(); err != nil { - log.Error(fmt.Sprintf("Server shutdown error: %v", err)) + log.Error("Server shutdown error", "error", err) } if tlsServer != nil { log.Info("Shutting down the TLS server") if err := tlsServer.Close(); err != nil { - log.Error(fmt.Sprintf("TLS server shutdown error: %v", err)) + log.Error("TLS server shutdown error", "error", err) } } log.Info("Waiting for the scheduler to stop") if err := <-schedulerErrors; err != nil { - log.Error(fmt.Sprintf("Scheduler error: %v", err)) + log.Error("Scheduler error", "error", err) } } log.Info("Docker Model Runner stopped") @@ -443,12 +456,13 @@ func 
createLlamaCppConfigFromEnv() config.BackendConfig { for _, arg := range args { for _, disallowed := range disallowedArgs { if arg == disallowed { - testLog.Error(fmt.Sprintf("LLAMA_ARGS cannot override the %s argument as it is controlled by the model runner", disallowed)) + testLog.Error("LLAMA_ARGS cannot override argument", "arg", disallowed) + exitFunc(1) } } } - testLog.Info(fmt.Sprintf("Using custom arguments: %v", args)) + testLog.Info("Using custom arguments", "args", args) return &llamacpp.Config{ Args: args, } diff --git a/main_test.go b/main_test.go index ac7a6e94a..d2967235b 100644 --- a/main_test.go +++ b/main_test.go @@ -10,47 +10,47 @@ func TestCreateLlamaCppConfigFromEnv(t *testing.T) { tests := []struct { name string llamaArgs string - wantNil bool + wantErr bool }{ { name: "empty args", llamaArgs: "", - wantNil: true, + wantErr: false, }, { name: "valid args", llamaArgs: "--threads 4 --ctx-size 2048", - wantNil: false, + wantErr: false, }, { name: "disallowed model arg", llamaArgs: "--model test.gguf", - wantNil: false, // config is still created, error is logged + wantErr: true, }, { name: "disallowed host arg", llamaArgs: "--host localhost:8080", - wantNil: false, + wantErr: true, }, { name: "disallowed embeddings arg", llamaArgs: "--embeddings", - wantNil: false, + wantErr: true, }, { name: "disallowed mmproj arg", llamaArgs: "--mmproj test.mmproj", - wantNil: false, + wantErr: true, }, { name: "multiple disallowed args", llamaArgs: "--model test.gguf --host localhost:8080", - wantNil: false, + wantErr: true, }, { name: "quoted args", llamaArgs: "--prompt \"Hello, world!\" --threads 4", - wantNil: false, + wantErr: false, }, } @@ -60,28 +60,41 @@ func TestCreateLlamaCppConfigFromEnv(t *testing.T) { t.Setenv("LLAMA_ARGS", tt.llamaArgs) } - config := createLlamaCppConfigFromEnv() + // Override exitFunc to capture exit calls instead of actually exiting + originalExitFunc := exitFunc + defer func() { exitFunc = originalExitFunc }() - if tt.wantNil 
{ - if config != nil { - t.Error("Expected nil config for empty args") - } - return + var exitCode int + exitFunc = func(code int) { + exitCode = code } - if config == nil { - t.Fatal("Expected non-nil config") - } + config := createLlamaCppConfigFromEnv() - llamaConfig, ok := config.(*llamacpp.Config) - if !ok { - t.Errorf("Expected *llamacpp.Config, got %T", config) - } - if llamaConfig == nil { - t.Fatal("Expected non-nil config") - } - if len(llamaConfig.Args) == 0 { - t.Error("Expected non-empty args") + if tt.wantErr { + if exitCode != 1 { + t.Errorf("Expected exit code 1, got %d", exitCode) + } + } else { + if exitCode != 0 { + t.Errorf("Expected exit code 0, got %d", exitCode) + } + if tt.llamaArgs == "" { + if config != nil { + t.Error("Expected nil config for empty args") + } + } else { + llamaConfig, ok := config.(*llamacpp.Config) + if !ok { + t.Fatalf("Expected *llamacpp.Config, got %T", config) + } + if llamaConfig == nil { + t.Fatal("Expected non-nil config") + } + if len(llamaConfig.Args) == 0 { + t.Error("Expected non-empty args") + } + } } }) } diff --git a/pkg/distribution/distribution/client_test.go b/pkg/distribution/distribution/client_test.go index 3195a01b1..ae5754349 100644 --- a/pkg/distribution/distribution/client_test.go +++ b/pkg/distribution/distribution/client_test.go @@ -45,7 +45,7 @@ func TestClientPullModel(t *testing.T) { defer server.Close() registryURL, err := url.Parse(server.URL) if err != nil { - t.Errorf("Failed to parse registry URL: %v", err) + t.Fatalf("Failed to parse registry URL: %v", err) } registryHost := registryURL.Host @@ -54,48 +54,48 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Read model content for verification later modelContent, err := os.ReadFile(testGGUFFile) if err != nil { - t.Errorf("Failed to read test 
model file: %v", err) + t.Fatalf("Failed to read test model file: %v", err) } model := testutil.BuildModelFromPath(t, testGGUFFile) tag := registryHost + "/testmodel:v1.0.0" ref, err := reference.ParseReference(tag) if err != nil { - t.Errorf("Failed to parse reference: %v", err) + t.Fatalf("Failed to parse reference: %v", err) } if err := remote.Write(ref, model, nil, remote.WithPlainHTTP(true)); err != nil { - t.Errorf("Failed to push model: %v", err) + t.Fatalf("Failed to push model: %v", err) } t.Run("pull without progress writer", func(t *testing.T) { // Pull model from registry without progress writer err := client.PullModel(t.Context(), tag, nil) if err != nil { - t.Errorf("Failed to pull model: %v", err) + t.Fatalf("Failed to pull model: %v", err) } model, err := client.GetModel(tag) if err != nil { - t.Errorf("Failed to get model: %v", err) + t.Fatalf("Failed to get model: %v", err) } modelPaths, err := model.GGUFPaths() if err != nil { - t.Errorf("Failed to get model path: %v", err) + t.Fatalf("Failed to get model path: %v", err) } if len(modelPaths) != 1 { - t.Errorf("Unexpected number of model files: %d", len(modelPaths)) + t.Fatalf("Unexpected number of model files: %d", len(modelPaths)) } // Verify model content pulledContent, err := os.ReadFile(modelPaths[0]) if err != nil { - t.Errorf("Failed to read pulled model: %v", err) + t.Fatalf("Failed to read pulled model: %v", err) } if string(pulledContent) != string(modelContent) { @@ -109,7 +109,7 @@ func TestClientPullModel(t *testing.T) { // Pull model from registry with progress writer if err := client.PullModel(t.Context(), tag, &progressBuffer); err != nil { - t.Errorf("Failed to pull model: %v", err) + t.Fatalf("Failed to pull model: %v", err) } // Verify progress output @@ -120,21 +120,21 @@ func TestClientPullModel(t *testing.T) { model, err := client.GetModel(tag) if err != nil { - t.Errorf("Failed to get model: %v", err) + t.Fatalf("Failed to get model: %v", err) } modelPaths, err := 
model.GGUFPaths() if err != nil { - t.Errorf("Failed to get model path: %v", err) + t.Fatalf("Failed to get model path: %v", err) } if len(modelPaths) != 1 { - t.Errorf("Unexpected number of model files: %d", len(modelPaths)) + t.Fatalf("Unexpected number of model files: %d", len(modelPaths)) } // Verify model content pulledContent, err := os.ReadFile(modelPaths[0]) if err != nil { - t.Errorf("Failed to read pulled model: %v", err) + t.Fatalf("Failed to read pulled model: %v", err) } if string(pulledContent) != string(modelContent) { @@ -148,7 +148,7 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Create a buffer to capture progress output @@ -165,12 +165,12 @@ func TestClientPullModel(t *testing.T) { var pullErr *mdregistry.Error ok := errors.As(err, &pullErr) if !ok { - t.Errorf("Expected registry.Error, got %T: %v", err, err) + t.Fatalf("Expected registry.Error, got %T: %v", err, err) } // Verify it matches registry.ErrModelNotFound for API compatibility if !errors.Is(err, mdregistry.ErrModelNotFound) { - t.Errorf("Expected registry.ErrModelNotFound, got %T", err) + t.Fatalf("Expected registry.ErrModelNotFound, got %T", err) } // Verify error fields @@ -196,7 +196,7 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Use the dummy.gguf file from assets directory @@ -205,26 +205,26 @@ func TestClientPullModel(t *testing.T) { // Push model to local store testTag := registryHost + "/incomplete-test/model:v1.0.0" if err := testClient.store.Write(mdl, []string{testTag}, nil); err != nil { - t.Errorf("Failed to push model to store: %v", err) + t.Fatalf("Failed to push model 
to store: %v", err) } // Push model to registry if err := testClient.PushModel(t.Context(), testTag, nil); err != nil { - t.Errorf("Failed to pull model: %v", err) + t.Fatalf("Failed to pull model: %v", err) } // Get the model to find the GGUF path model, err := testClient.GetModel(testTag) if err != nil { - t.Errorf("Failed to get model: %v", err) + t.Fatalf("Failed to get model: %v", err) } ggufPaths, err := model.GGUFPaths() if err != nil { - t.Errorf("Failed to get GGUF path: %v", err) + t.Fatalf("Failed to get GGUF path: %v", err) } if len(ggufPaths) != 1 { - t.Errorf("Unexpected number of model files: %d", len(ggufPaths)) + t.Fatalf("Unexpected number of model files: %d", len(ggufPaths)) } // Create an incomplete file by copying the GGUF file and adding .incomplete suffix @@ -232,23 +232,23 @@ func TestClientPullModel(t *testing.T) { incompletePath := ggufPath + ".incomplete" originalContent, err := os.ReadFile(ggufPath) if err != nil { - t.Errorf("Failed to read GGUF file: %v", err) + t.Fatalf("Failed to read GGUF file: %v", err) } // Write partial content to simulate an incomplete download partialContent := originalContent[:len(originalContent)/2] if err := os.WriteFile(incompletePath, partialContent, 0644); err != nil { - t.Errorf("Failed to create incomplete file: %v", err) + t.Fatalf("Failed to create incomplete file: %v", err) } // Verify the incomplete file exists if _, err := os.Stat(incompletePath); os.IsNotExist(err) { - t.Errorf("Failed to create incomplete file: %v", err) + t.Fatalf("Failed to create incomplete file: %v", err) } // Delete the local model to force a pull if _, err := testClient.DeleteModel(testTag, false); err != nil { - t.Errorf("Failed to delete model: %v", err) + t.Fatalf("Failed to delete model: %v", err) } // Create a buffer to capture progress output @@ -256,13 +256,13 @@ func TestClientPullModel(t *testing.T) { // Pull the model again - this should detect the incomplete file and pull again if err := 
testClient.PullModel(t.Context(), testTag, &progressBuffer); err != nil { - t.Errorf("Failed to pull model: %v", err) + t.Fatalf("Failed to pull model: %v", err) } // Verify progress output indicates a new download, not using cached model progressOutput := progressBuffer.String() if strings.Contains(progressOutput, "Using cached model") { - t.Error("Expected to pull model again due to incomplete file, but used cached model") + t.Errorf("Expected to pull model again due to incomplete file, but used cached model") } // Verify the incomplete file no longer exists @@ -278,11 +278,11 @@ func TestClientPullModel(t *testing.T) { // Verify the content of the pulled file matches the original pulledContent, err := os.ReadFile(ggufPath) if err != nil { - t.Errorf("Failed to read pulled GGUF file: %v", err) + t.Fatalf("Failed to read pulled GGUF file: %v", err) } if !bytes.Equal(pulledContent, originalContent) { - t.Error("Pulled content doesn't match original content") + t.Errorf("Pulled content doesn't match original content") } }) @@ -292,13 +292,13 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Read model content for verification later testModelContent, err := os.ReadFile(testGGUFFile) if err != nil { - t.Errorf("Failed to read test model file: %v", err) + t.Fatalf("Failed to read test model file: %v", err) } // Push first version of model to registry @@ -309,27 +309,27 @@ func TestClientPullModel(t *testing.T) { // Pull first version of model if err := testClient.PullModel(t.Context(), testTag, nil); err != nil { - t.Errorf("Failed to pull first version of model: %v", err) + t.Fatalf("Failed to pull first version of model: %v", err) } // Verify first version is in local store model, err := testClient.GetModel(testTag) if err != nil { - t.Errorf("Failed to get first 
version of model: %v", err) + t.Fatalf("Failed to get first version of model: %v", err) } modelPath, err := model.GGUFPaths() if err != nil { - t.Errorf("Failed to get model path: %v", err) + t.Fatalf("Failed to get model path: %v", err) } if len(modelPath) != 1 { - t.Errorf("Unexpected number of model files: %d", len(modelPath)) + t.Fatalf("Unexpected number of model files: %d", len(modelPath)) } // Verify first version content pulledContent, err := os.ReadFile(modelPath[0]) if err != nil { - t.Errorf("Failed to read pulled model: %v", err) + t.Fatalf("Failed to read pulled model: %v", err) } if string(pulledContent) != string(testModelContent) { @@ -340,7 +340,7 @@ func TestClientPullModel(t *testing.T) { updatedModelFile := filepath.Join(tempDir, "updated-dummy.gguf") updatedContent := append(testModelContent, []byte("UPDATED CONTENT")...) if err := os.WriteFile(updatedModelFile, updatedContent, 0644); err != nil { - t.Errorf("Failed to create updated model file: %v", err) + t.Fatalf("Failed to create updated model file: %v", err) } // Push updated model with same tag @@ -353,33 +353,33 @@ func TestClientPullModel(t *testing.T) { // Pull model again - should get the updated version if err := testClient.PullModel(t.Context(), testTag, &progressBuffer); err != nil { - t.Errorf("Failed to pull updated model: %v", err) + t.Fatalf("Failed to pull updated model: %v", err) } // Verify progress output indicates a new download, not using cached model progressOutput := progressBuffer.String() if strings.Contains(progressOutput, "Using cached model") { - t.Error("Expected to pull updated model, but used cached model") + t.Errorf("Expected to pull updated model, but used cached model") } // Get the model again to verify it's the updated version updatedModel, err := testClient.GetModel(testTag) if err != nil { - t.Errorf("Failed to get updated model: %v", err) + t.Fatalf("Failed to get updated model: %v", err) } updatedModelPaths, err := updatedModel.GGUFPaths() if err != 
nil { - t.Errorf("Failed to get updated model path: %v", err) + t.Fatalf("Failed to get updated model path: %v", err) } if len(updatedModelPaths) != 1 { - t.Errorf("Unexpected number of model files: %d", len(modelPath)) + t.Fatalf("Unexpected number of model files: %d", len(modelPath)) } // Verify updated content updatedPulledContent, err := os.ReadFile(updatedModelPaths[0]) if err != nil { - t.Errorf("Failed to read updated pulled model: %v", err) + t.Fatalf("Failed to read updated pulled model: %v", err) } if string(updatedPulledContent) != string(updatedContent) { @@ -393,13 +393,13 @@ func TestClientPullModel(t *testing.T) { testTag := registryHost + "/unsupported-test/model:v1.0.0" ref, err := reference.ParseReference(testTag) if err != nil { - t.Errorf("Failed to parse reference: %v", err) + t.Fatalf("Failed to parse reference: %v", err) } if err := remote.Write(ref, newMdl, nil, remote.WithPlainHTTP(true)); err != nil { - t.Errorf("Failed to push model: %v", err) + t.Fatalf("Failed to push model: %v", err) } if err := client.PullModel(t.Context(), testTag, nil); err == nil || !errors.Is(err, ErrUnsupportedMediaType) { - t.Errorf("Expected artifact version error, got %v", err) + t.Fatalf("Expected artifact version error, got %v", err) } }) @@ -410,7 +410,7 @@ func TestClientPullModel(t *testing.T) { safetensorsPath := filepath.Join(safetensorsTempDir, "model.safetensors") safetensorsContent := []byte("fake safetensors content for testing") if err := os.WriteFile(safetensorsPath, safetensorsContent, 0644); err != nil { - t.Errorf("Failed to create safetensors file: %v", err) + t.Fatalf("Failed to create safetensors file: %v", err) } // Create a safetensors model @@ -420,10 +420,10 @@ func TestClientPullModel(t *testing.T) { testTag := registryHost + "/safetensors-test/model:v1.0.0" ref, err := reference.ParseReference(testTag) if err != nil { - t.Errorf("Failed to parse reference: %v", err) + t.Fatalf("Failed to parse reference: %v", err) } if err := 
remote.Write(ref, safetensorsModel, nil, remote.WithPlainHTTP(true)); err != nil { - t.Errorf("Failed to push safetensors model to registry: %v", err) + t.Fatalf("Failed to push safetensors model to registry: %v", err) } // Create a new client with a separate temp store @@ -431,7 +431,7 @@ func TestClientPullModel(t *testing.T) { testClient, err := newTestClient(clientTempDir) if err != nil { - t.Errorf("Failed to create test client: %v", err) + t.Fatalf("Failed to create test client: %v", err) } // Try to pull the safetensors model with a progress writer to capture warnings @@ -440,17 +440,17 @@ func TestClientPullModel(t *testing.T) { // Pull should succeed on all platforms now (with a warning on non-Linux) if err != nil { - t.Errorf("Expected no error, got: %v", err) + t.Fatalf("Expected no error, got: %v", err) } if !platform.SupportsVLLM() { // On non-Linux, verify that a warning was written progressOutput := progressBuf.String() if !strings.Contains(progressOutput, `"type":"warning"`) { - t.Errorf("Expected warning message on non-Linux platforms, got output: %s", progressOutput) + t.Fatalf("Expected warning message on non-Linux platforms, got output: %s", progressOutput) } if !strings.Contains(progressOutput, warnUnsupportedFormat) { - t.Errorf("Expected warning about safetensors format, got output: %s", progressOutput) + t.Fatalf("Expected warning about safetensors format, got output: %s", progressOutput) } } }) @@ -461,7 +461,7 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Create a buffer to capture progress output @@ -469,7 +469,7 @@ func TestClientPullModel(t *testing.T) { // Pull model from registry with progress writer if err := testClient.PullModel(t.Context(), tag, &progressBuffer); err != nil { - t.Errorf("Failed to pull model: %v", err) + 
t.Fatalf("Failed to pull model: %v", err) } // Parse progress output as JSON @@ -479,13 +479,13 @@ func TestClientPullModel(t *testing.T) { line := scanner.Text() var msg oci.ProgressMessage if err := json.Unmarshal([]byte(line), &msg); err != nil { - t.Errorf("Failed to parse JSON progress message: %v, line: %s", err, line) + t.Fatalf("Failed to parse JSON progress message: %v, line: %s", err, line) } messages = append(messages, msg) } if err := scanner.Err(); err != nil { - t.Errorf("Error reading progress output: %v", err) + t.Fatalf("Error reading progress output: %v", err) } // Verify we got some messages @@ -509,25 +509,25 @@ func TestClientPullModel(t *testing.T) { // Verify model was pulled correctly model, err := testClient.GetModel(tag) if err != nil { - t.Errorf("Failed to get model: %v", err) + t.Fatalf("Failed to get model: %v", err) } modelPaths, err := model.GGUFPaths() if err != nil { - t.Errorf("Failed to get model path: %v", err) + t.Fatalf("Failed to get model path: %v", err) } if len(modelPaths) != 1 { - t.Errorf("Unexpected number of model files: %d", len(modelPaths)) + t.Fatalf("Unexpected number of model files: %d", len(modelPaths)) } // Verify model content pulledContent, err := os.ReadFile(modelPaths[0]) if err != nil { - t.Errorf("Failed to read pulled model: %v", err) + t.Fatalf("Failed to read pulled model: %v", err) } if string(pulledContent) != string(modelContent) { - t.Error("Pulled model content doesn't match original") + t.Errorf("Pulled model content doesn't match original") } }) @@ -537,7 +537,7 @@ func TestClientPullModel(t *testing.T) { // Create client with plainHTTP for test registry testClient, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Create a buffer to capture progress output @@ -554,7 +554,7 @@ func TestClientPullModel(t *testing.T) { // Verify it matches registry.ErrModelNotFound if !errors.Is(err, 
mdregistry.ErrModelNotFound) { - t.Errorf("Expected registry.ErrModelNotFound, got %T", err) + t.Fatalf("Expected registry.ErrModelNotFound, got %T", err) } // No JSON messages should be in the buffer for this error case @@ -568,7 +568,7 @@ func TestClientGetModel(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Create model from test GGUF file @@ -578,13 +578,13 @@ func TestClientGetModel(t *testing.T) { tag := "test/model:v1.0.0" normalizedTag := "docker.io/test/model:v1.0.0" // Reference package normalizes to include registry if err := client.store.Write(model, []string{tag}, nil); err != nil { - t.Errorf("Failed to push model to store: %v", err) + t.Fatalf("Failed to push model to store: %v", err) } // Get model mi, err := client.GetModel(tag) if err != nil { - t.Errorf("Failed to get model: %v", err) + t.Fatalf("Failed to get model: %v", err) } // Verify model - tags are normalized to include the default registry @@ -599,7 +599,7 @@ func TestClientGetModelNotFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Get non-existent model @@ -615,14 +615,14 @@ func TestClientListModels(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Create test model file modelContent := []byte("test model content") modelFile := filepath.Join(tempDir, "test-model.gguf") if err := os.WriteFile(modelFile, modelContent, 0644); err != nil { - t.Errorf("Failed to write test model file: %v", err) + t.Fatalf("Failed to write test model file: %v", err) } mdl := 
testutil.BuildModelFromPath(t, modelFile) @@ -631,21 +631,21 @@ func TestClientListModels(t *testing.T) { // First model tag1 := "test/model1:v1.0.0" if err := client.store.Write(mdl, []string{tag1}, nil); err != nil { - t.Errorf("Failed to push model to store: %v", err) + t.Fatalf("Failed to push model to store: %v", err) } // Create a slightly different model file for the second model modelContent2 := []byte("test model content 2") modelFile2 := filepath.Join(tempDir, "test-model2.gguf") if err := os.WriteFile(modelFile2, modelContent2, 0644); err != nil { - t.Errorf("Failed to write test model file: %v", err) + t.Fatalf("Failed to write test model file: %v", err) } mdl2 := testutil.BuildModelFromPath(t, modelFile2) // Second model tag2 := "test/model2:v1.0.0" if err := client.store.Write(mdl2, []string{tag2}, nil); err != nil { - t.Errorf("Failed to push model to store: %v", err) + t.Fatalf("Failed to push model to store: %v", err) } // Normalized tags for verification (reference package normalizes to include default registry) @@ -656,7 +656,7 @@ func TestClientListModels(t *testing.T) { // List models models, err := client.ListModels() if err != nil { - t.Errorf("Failed to list models: %v", err) + t.Fatalf("Failed to list models: %v", err) } // Verify models @@ -685,7 +685,7 @@ func TestClientGetStorePath(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Get store path @@ -708,7 +708,7 @@ func TestClientDefaultLogger(t *testing.T) { // Create client without specifying logger client, err := NewClient(WithStoreRootPath(tempDir)) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Verify that logger is not nil @@ -723,7 +723,7 @@ func TestClientDefaultLogger(t *testing.T) { WithLogger(customLogger), ) if err != nil { - t.Errorf("Failed 
to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Verify that custom logger is used @@ -746,7 +746,8 @@ func TestWithFunctionsNilChecks(t *testing.T) { // Verify the path wasn't changed to empty if opts.storeRootPath != tempDir { - t.Errorf("WithStoreRootPath with empty string changed the path: got %q, want %q", opts.storeRootPath, tempDir) + t.Errorf("WithStoreRootPath with empty string changed the path: got %q, want %q", + opts.storeRootPath, tempDir) } }) @@ -788,7 +789,7 @@ func TestNewReferenceError(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Test with invalid reference @@ -799,7 +800,7 @@ func TestNewReferenceError(t *testing.T) { } if !errors.Is(err, ErrInvalidReference) { - t.Errorf("Expected error to match sentinel invalid reference error, got %v", err) + t.Fatalf("Expected error to match sentinel invalid reference error, got %v", err) } } @@ -809,7 +810,7 @@ func TestPush(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Create a test registry @@ -819,7 +820,7 @@ func TestPush(t *testing.T) { // Create a tag for the model uri, err := url.Parse(server.URL) if err != nil { - t.Errorf("Failed to parse registry URL: %v", err) + t.Fatalf("Failed to parse registry URL: %v", err) } tag := uri.Host + "/incomplete-test/model:v1.0.0" @@ -827,39 +828,39 @@ func TestPush(t *testing.T) { mdl := testutil.BuildModelFromPath(t, testGGUFFile) digest, err := mdl.ID() if err != nil { - t.Errorf("Failed to get digest of original model: %v", err) + t.Fatalf("Failed to get digest of original model: %v", err) } if err := client.store.Write(mdl, []string{tag}, nil); err != nil { - t.Errorf("Failed to push model 
to store: %v", err) + t.Fatalf("Failed to push model to store: %v", err) } // Push the model to the registry if err := client.PushModel(t.Context(), tag, nil); err != nil { - t.Errorf("Failed to push model: %v", err) + t.Fatalf("Failed to push model: %v", err) } // Delete local copy (so we can test pulling) if _, err := client.DeleteModel(tag, false); err != nil { - t.Errorf("Failed to delete model: %v", err) + t.Fatalf("Failed to delete model: %v", err) } // Test that model can be pulled successfully if err := client.PullModel(t.Context(), tag, nil); err != nil { - t.Errorf("Failed to pull model: %v", err) + t.Fatalf("Failed to pull model: %v", err) } // Test that model the pulled model is the same as the original (matching digests) mdl2, err := client.GetModel(tag) if err != nil { - t.Errorf("Failed to get pulled model: %v", err) + t.Fatalf("Failed to get pulled model: %v", err) } digest2, err := mdl2.ID() if err != nil { - t.Errorf("Failed to get digest of the pulled model: %v", err) + t.Fatalf("Failed to get digest of the pulled model: %v", err) } if digest != digest2 { - t.Errorf("Digests don't match: got %s, want %s", digest2, digest) + t.Fatalf("Digests don't match: got %s, want %s", digest2, digest) } } @@ -869,7 +870,7 @@ func TestPushProgress(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Create a test registry @@ -879,7 +880,7 @@ func TestPushProgress(t *testing.T) { // Create a tag for the model uri, err := url.Parse(server.URL) if err != nil { - t.Errorf("Failed to parse registry URL: %v", err) + t.Fatalf("Failed to parse registry URL: %v", err) } tag := uri.Host + "/some/model/repo:some-tag" @@ -888,14 +889,14 @@ func TestPushProgress(t *testing.T) { sz := int64(progress.MinBytesForUpdate * 2) path, err := randomFile(sz) if err != nil { - t.Errorf("Failed to create temp file: %v", 
err) + t.Fatalf("Failed to create temp file: %v", err) } defer os.Remove(path) mdl := testutil.BuildModelFromPath(t, path) if err := client.store.Write(mdl, []string{tag}, nil); err != nil { - t.Errorf("Failed to write model to store: %v", err) + t.Fatalf("Failed to write model to store: %v", err) } // Create a buffer to capture progress output @@ -917,13 +918,13 @@ func TestPushProgress(t *testing.T) { // Wait for the push to complete if err := <-done; err != nil { - t.Errorf("Failed to push model: %v", err) + t.Fatalf("Failed to push model: %v", err) } // Verify we got at least 2 messages (1 progress + 1 success) // With fast local uploads, we may only get one progress update per layer if len(lines) < 2 { - t.Errorf("Expected at least 2 progress messages, got %d", len(lines)) + t.Fatalf("Expected at least 2 progress messages, got %d", len(lines)) } // Verify we got at least one progress message and the success message @@ -938,10 +939,10 @@ func TestPushProgress(t *testing.T) { } } if !hasProgress { - t.Errorf("Expected at least one progress message containing 'Uploaded:', got %v", lines) + t.Fatalf("Expected at least one progress message containing 'Uploaded:', got %v", lines) } if !hasSuccess { - t.Errorf("Expected a success message, got %v", lines) + t.Fatalf("Expected a success message, got %v", lines) } } @@ -951,14 +952,14 @@ func TestTag(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Create a test model model := testutil.BuildModelFromPath(t, testGGUFFile) id, err := model.ID() if err != nil { - t.Errorf("Failed to get model ID: %v", err) + t.Fatalf("Failed to get model ID: %v", err) } // Normalize the model name before writing @@ -966,35 +967,35 @@ func TestTag(t *testing.T) { // Push the model to the store if err := client.store.Write(model, []string{normalized}, nil); err != nil { - 
t.Errorf("Failed to push model to store: %v", err) + t.Fatalf("Failed to push model to store: %v", err) } // Tag the model by ID if err := client.Tag(id, "other-repo:tag1"); err != nil { - t.Errorf("Failed to tag model %q: %v", id, err) + t.Fatalf("Failed to tag model %q: %v", id, err) } // Tag the model by tag if err := client.Tag(id, "other-repo:tag2"); err != nil { - t.Errorf("Failed to tag model %q: %v", id, err) + t.Fatalf("Failed to tag model %q: %v", id, err) } // Verify the model has all 3 tags modelInfo, err := client.GetModel("some-repo:some-tag") if err != nil { - t.Errorf("Failed to get model: %v", err) + t.Fatalf("Failed to get model: %v", err) } if len(modelInfo.Tags()) != 3 { - t.Errorf("Expected 3 tags, got %d", len(modelInfo.Tags())) + t.Fatalf("Expected 3 tags, got %d", len(modelInfo.Tags())) } // Verify the model can be accessed by new tags if _, err := client.GetModel("other-repo:tag1"); err != nil { - t.Errorf("Failed to get model by tag: %v", err) + t.Fatalf("Failed to get model by tag: %v", err) } if _, err := client.GetModel("other-repo:tag2"); err != nil { - t.Errorf("Failed to get model by tag: %v", err) + t.Fatalf("Failed to get model by tag: %v", err) } } @@ -1004,12 +1005,12 @@ func TestTagNotFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Tag the model by ID if err := client.Tag("non-existent-model:latest", "other-repo:tag1"); !errors.Is(err, ErrModelNotFound) { - t.Errorf("Expected ErrModelNotFound, got: %v", err) + t.Fatalf("Expected ErrModelNotFound, got: %v", err) } } @@ -1019,11 +1020,11 @@ func TestClientPushModelNotFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } if err := 
client.PushModel(t.Context(), "non-existent-model:latest", nil); !errors.Is(err, ErrModelNotFound) { - t.Errorf("Expected ErrModelNotFound got: %v", err) + t.Fatalf("Expected ErrModelNotFound got: %v", err) } } @@ -1033,13 +1034,13 @@ func TestIsModelInStoreNotFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } if inStore, err := client.IsModelInStore("non-existent-model:latest"); err != nil { - t.Errorf("Unexpected error: %v", err) + t.Fatalf("Unexpected error: %v", err) } else if inStore { - t.Error("Expected model not to be found") + t.Fatalf("Expected model not to be found") } } @@ -1049,7 +1050,7 @@ func TestIsModelInStoreFound(t *testing.T) { // Create client with plainHTTP for test registry client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Create a test model @@ -1060,13 +1061,13 @@ func TestIsModelInStoreFound(t *testing.T) { // Push the model to the store if err := client.store.Write(model, []string{normalized}, nil); err != nil { - t.Errorf("Failed to push model to store: %v", err) + t.Fatalf("Failed to push model to store: %v", err) } if inStore, err := client.IsModelInStore("some-repo:some-tag"); err != nil { - t.Errorf("Unexpected error: %v", err) + t.Fatalf("Unexpected error: %v", err) } else if !inStore { - t.Error("Expected model to be found") + t.Fatalf("Expected model to be found") } } @@ -1141,26 +1142,26 @@ func TestMigrateHFTagsOnClientInit(t *testing.T) { // Step 1: Create a client and write a model with the legacy tag setupClient, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create setup client: %v", err) + t.Fatalf("Failed to create setup client: %v", err) } model := testutil.BuildModelFromPath(t, testGGUFFile) if err := setupClient.store.Write(model, 
[]string{tc.storedTag}, nil); err != nil { - t.Errorf("Failed to write model to store: %v", err) + t.Fatalf("Failed to write model to store: %v", err) } // Step 2: Create a NEW client (simulating restart) - migration should happen client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Step 3: Verify the model can be found using the reference // (normalizeModelName converts hf.co -> huggingface.co, and migration should have updated the store) foundModel, err := client.GetModel(tc.lookupRef) if err != nil { - t.Errorf("Failed to get model after migration: %v", err) + t.Fatalf("Failed to get model after migration: %v", err) } if foundModel == nil { @@ -1214,7 +1215,7 @@ func TestPullHuggingFaceModelFromCache(t *testing.T) { // Create client client, err := newTestClient(tempDir) if err != nil { - t.Errorf("Failed to create client: %v", err) + t.Fatalf("Failed to create client: %v", err) } // Create a test model and write it to the store with a normalized HuggingFace tag @@ -1223,14 +1224,14 @@ func TestPullHuggingFaceModelFromCache(t *testing.T) { // Store with normalized tag (huggingface.co) hfTag := "huggingface.co/testorg/testmodel:latest" if err := client.store.Write(model, []string{hfTag}, nil); err != nil { - t.Errorf("Failed to write model to store: %v", err) + t.Fatalf("Failed to write model to store: %v", err) } // Now try to pull using the test case's reference - it should use the cache var progressBuffer bytes.Buffer err = client.PullModel(t.Context(), tc.pullRef, &progressBuffer) if err != nil { - t.Errorf("Failed to pull model from cache: %v", err) + t.Fatalf("Failed to pull model from cache: %v", err) } // Verify that progress shows it was cached diff --git a/pkg/distribution/distribution/normalize_test.go b/pkg/distribution/distribution/normalize_test.go index 16c32be08..4e791eeff 100644 --- a/pkg/distribution/distribution/normalize_test.go +++ 
b/pkg/distribution/distribution/normalize_test.go @@ -292,7 +292,7 @@ func TestNormalizeModelNameWithIDResolution(t *testing.T) { // Extract the short ID (12 hex chars after "sha256:") if !strings.HasPrefix(modelID, "sha256:") { - t.Errorf("Expected model ID to start with 'sha256:', got: %s", modelID) + t.Fatalf("Expected model ID to start with 'sha256:', got: %s", modelID) } shortID := modelID[7:19] // Extract 12 chars after "sha256:" fullHex := strings.TrimPrefix(modelID, "sha256:") @@ -342,7 +342,7 @@ func createTestClient(t *testing.T) (*Client, func()) { WithLogger(slog.Default()), ) if err != nil { - t.Errorf("Failed to create test client: %v", err) + t.Fatalf("Failed to create test client: %v", err) } cleanup := func() { @@ -455,7 +455,7 @@ func loadTestModel(t *testing.T, client *Client, ggufPath string) string { pr, pw := io.Pipe() target, err := tarball.NewTarget(pw) if err != nil { - t.Errorf("Failed to create target: %v", err) + t.Fatalf("Failed to create target: %v", err) } done := make(chan error) @@ -468,15 +468,15 @@ func loadTestModel(t *testing.T, client *Client, ggufPath string) string { bldr, err := builder.FromPath(ggufPath) if err != nil { - t.Errorf("Failed to create builder from GGUF: %v", err) + t.Fatalf("Failed to create builder from GGUF: %v", err) } if err := bldr.Build(t.Context(), target, nil); err != nil { - t.Errorf("Failed to build model: %v", err) + t.Fatalf("Failed to build model: %v", err) } if err := <-done; err != nil { - t.Errorf("Failed to load model: %v", err) + t.Fatalf("Failed to load model: %v", err) } if id == "" { diff --git a/pkg/inference/backends/llamacpp/download.go b/pkg/inference/backends/llamacpp/download.go index 245592731..fb1e24b01 100644 --- a/pkg/inference/backends/llamacpp/download.go +++ b/pkg/inference/backends/llamacpp/download.go @@ -199,6 +199,6 @@ func getLlamaCppVersion(log logging.Logger, llamaCpp string) string { if len(matches) == 2 { return matches[1] } - log.Warn("failed to parse llama.cpp 
version from output:\n", "error", strings.TrimSpace(string(output))) + log.Warn("Failed to parse llama.cpp version from output", "output", strings.TrimSpace(string(output))) return "unknown" } diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go index 12a1011cd..ee4db9f7d 100644 --- a/pkg/inference/backends/llamacpp/llamacpp.go +++ b/pkg/inference/backends/llamacpp/llamacpp.go @@ -117,7 +117,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error { // digest to be equal to the one on Docker Hub. llamaCppPath := filepath.Join(l.updatedServerStoragePath, llamaServerBin) if err := l.ensureLatestLlamaCpp(ctx, l.log, httpClient, llamaCppPath, l.vendoredServerStoragePath); err != nil { - l.log.Info("failed to ensure latest llama.cpp \n", "error", err) + l.log.Info("Failed to ensure latest llama.cpp", "error", err) if !errors.Is(err, errLlamaCppUpToDate) && !errors.Is(err, errLlamaCppUpdateDisabled) { l.status = fmt.Sprintf("failed to install llama.cpp: %v", err) } diff --git a/pkg/inference/models/handler_test.go b/pkg/inference/models/handler_test.go index 7194ef995..b7f0c34b2 100644 --- a/pkg/inference/models/handler_test.go +++ b/pkg/inference/models/handler_test.go @@ -22,7 +22,7 @@ func getProjectRoot(t *testing.T) string { // Start from the current test file's directory dir, err := os.Getwd() if err != nil { - t.Errorf("Failed to get current directory: %v", err) + t.Fatalf("Failed to get current directory: %v", err) } // Walk up the directory tree until we find the go.mod file @@ -48,7 +48,7 @@ func TestPullModel(t *testing.T) { // Create a tag for the model uri, err := url.Parse(server.URL) if err != nil { - t.Errorf("Failed to parse registry URL: %v", err) + t.Fatalf("Failed to parse registry URL: %v", err) } tag := uri.Host + "/ai/model:v1.0.0" @@ -56,23 +56,23 @@ func TestPullModel(t *testing.T) { projectRoot := getProjectRoot(t) model, err := 
builder.FromPath(filepath.Join(projectRoot, "assets", "dummy.gguf")) if err != nil { - t.Errorf("Failed to create model builder: %v", err) + t.Fatalf("Failed to create model builder: %v", err) } license, err := model.WithLicense(filepath.Join(projectRoot, "assets", "license.txt")) if err != nil { - t.Errorf("Failed to add license to model: %v", err) + t.Fatalf("Failed to add license to model: %v", err) } // Build the OCI model artifact + push it (use plainHTTP for test registry) client := reg.NewClient(reg.WithPlainHTTP(true)) target, err := client.NewTarget(tag) if err != nil { - t.Errorf("Failed to create model target: %v", err) + t.Fatalf("Failed to create model target: %v", err) } err = license.Build(t.Context(), target, os.Stdout) if err != nil { - t.Errorf("Failed to build model: %v", err) + t.Fatalf("Failed to build model: %v", err) } tests := []struct { @@ -115,19 +115,19 @@ func TestPullModel(t *testing.T) { w := httptest.NewRecorder() err = handler.manager.Pull(tag, "", r, w) if err != nil { - t.Errorf("Failed to pull model: %v", err) + t.Fatalf("Failed to pull model: %v", err) } if tt.expectedCT != w.Header().Get("Content-Type") { - t.Errorf("Expected content type %s, got %s", tt.expectedCT, w.Header().Get("Content-Type")) + t.Fatalf("Expected content type %s, got %s", tt.expectedCT, w.Header().Get("Content-Type")) } // Clean tempDir after each test if err := os.RemoveAll(tempDir); err != nil { - t.Errorf("Failed to clean temp directory: %v", err) + t.Fatalf("Failed to clean temp directory: %v", err) } if err := os.MkdirAll(tempDir, 0755); err != nil { - t.Errorf("Failed to recreate temp directory: %v", err) + t.Fatalf("Failed to recreate temp directory: %v", err) } }) } @@ -142,19 +142,19 @@ func TestHandleGetModel(t *testing.T) { uri, err := url.Parse(server.URL) if err != nil { - t.Errorf("Failed to parse registry URL: %v", err) + t.Fatalf("Failed to parse registry URL: %v", err) } // Prepare the OCI model artifact projectRoot := getProjectRoot(t) 
model, err := builder.FromPath(filepath.Join(projectRoot, "assets", "dummy.gguf")) if err != nil { - t.Errorf("Failed to create model builder: %v", err) + t.Fatalf("Failed to create model builder: %v", err) } license, err := model.WithLicense(filepath.Join(projectRoot, "assets", "license.txt")) if err != nil { - t.Errorf("Failed to add license to model: %v", err) + t.Fatalf("Failed to add license to model: %v", err) } // Build the OCI model artifact + push it (use plainHTTP for test registry) @@ -162,11 +162,11 @@ func TestHandleGetModel(t *testing.T) { client := reg.NewClient(reg.WithPlainHTTP(true)) target, err := client.NewTarget(tag) if err != nil { - t.Errorf("Failed to create model target: %v", err) + t.Fatalf("Failed to create model target: %v", err) } err = license.Build(t.Context(), target, os.Stdout) if err != nil { - t.Errorf("Failed to build model: %v", err) + t.Fatalf("Failed to build model: %v", err) } tests := []struct { @@ -222,7 +222,7 @@ func TestHandleGetModel(t *testing.T) { w := httptest.NewRecorder() err = handler.manager.Pull(tt.modelName, "", r, w) if err != nil { - t.Errorf("Failed to pull model: %v", err) + t.Fatalf("Failed to pull model: %v", err) } } @@ -265,10 +265,10 @@ func TestHandleGetModel(t *testing.T) { // Clean tempDir after each test if err := os.RemoveAll(tempDir); err != nil { - t.Errorf("Failed to clean temp directory: %v", err) + t.Fatalf("Failed to clean temp directory: %v", err) } if err := os.MkdirAll(tempDir, 0755); err != nil { - t.Errorf("Failed to recreate temp directory: %v", err) + t.Fatalf("Failed to recreate temp directory: %v", err) } }) } diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go index 9817abfdf..d08aaa23e 100644 --- a/pkg/inference/scheduling/loader.go +++ b/pkg/inference/scheduling/loader.go @@ -233,7 +233,7 @@ func (l *loader) evict(idleOnly bool) int { default: } if unused && (!idleOnly || idle || defunct) && (!idleOnly || !neverEvict || defunct) { - 
l.log.Info("Evicting backend runner with model ( ) in mode", "backend", r.backend, "backend", r.modelID, "model", runnerInfo.modelRef, "mode", r.mode) + l.log.Info("Evicting backend runner", "backend", r.backend, "model", r.modelID, "modelRef", runnerInfo.modelRef, "mode", r.mode) l.freeRunnerSlot(runnerInfo.slot, r) evictedCount++ } else if unused { @@ -256,7 +256,7 @@ func (l *loader) evictRunner(backend, model string, mode inference.BackendMode) for r, runnerInfo := range l.runners { unused := l.references[runnerInfo.slot] == 0 if unused && (allBackends || r.backend == backend) && r.modelID == model && r.mode == mode { - l.log.Info("Evicting backend runner with model ( ) in mode", "backend", r.backend, "backend", r.modelID, "model", runnerInfo.modelRef, "mode", r.mode) + l.log.Info("Evicting backend runner", "backend", r.backend, "model", r.modelID, "modelRef", runnerInfo.modelRef, "mode", r.mode) l.freeRunnerSlot(runnerInfo.slot, r) found = true } @@ -455,7 +455,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string runnerConfig = &defaultConfig } - l.log.Info("Loading backend runner with model in mode", "backend", backendName, "backend", modelID, "mode", mode) + l.log.Info("Loading backend runner", "backend", backendName, "model", modelID, "mode", mode) // Acquire the loader lock and defer its release. if !l.lock(ctx) { @@ -485,7 +485,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string if ok { select { case <-l.slots[existing.slot].done: - l.log.Warn("runner for is defunct. 
Waiting for it to be evicted.", "backend", backendName, "model", existing.modelRef) + l.log.Warn("Runner is defunct, waiting for eviction", "backend", backendName, "model", existing.modelRef) if l.references[existing.slot] == 0 { l.evictRunner(backendName, modelID, mode) // Continue the loop to retry loading after evicting the defunct runner @@ -530,7 +530,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string // Create the runner. runner, err := run(l.log, backend, modelID, modelRef, mode, slot, runnerConfig, l.openAIRecorder) if err != nil { - l.log.Warn("Unable to start backend runner with model in mode", "backend", backendName, "backend", modelID, "mode", mode, "error", err) + l.log.Warn("Unable to start backend runner", "backend", backendName, "model", modelID, "mode", mode, "error", err) return nil, fmt.Errorf("unable to start runner: %w", err) } @@ -542,7 +542,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string // deduplication of runners and keep slot / memory reservations. if err := runner.wait(ctx); err != nil { runner.terminate() - l.log.Warn("Initialization for backend runner with model in mode failed", "backend", backendName, "backend", modelID, "mode", mode, "error", err) + l.log.Warn("Backend runner initialization failed", "backend", backendName, "model", modelID, "mode", mode, "error", err) return nil, fmt.Errorf("error waiting for runner to be ready: %w", err) } @@ -615,7 +615,7 @@ func (l *loader) setRunnerConfig(ctx context.Context, backendName, modelID strin // If the configuration hasn't changed, then just return. 
if existingConfig, ok := l.runnerConfigs[configKey]; ok && reflect.DeepEqual(runnerConfig, existingConfig) { - l.log.Info("Configuration for runner for modelID unchanged", "backend", backendName, "model", modelID) + l.log.Info("Runner configuration unchanged", "backend", backendName, "model", modelID) return nil } @@ -638,7 +638,7 @@ func (l *loader) setRunnerConfig(ctx context.Context, backendName, modelID strin return errRunnerAlreadyActive } - l.log.Info("Configuring runner for", "backend", backendName, "model", modelID) + l.log.Info("Configuring runner", "backend", backendName, "model", modelID) l.runnerConfigs[configKey] = runnerConfig return nil } diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index 783520f09..cef6b6b9b 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -212,7 +212,7 @@ func (s *Scheduler) GetAllActiveRunners() []metrics.ActiveRunner { if key.backend == backend.BackendName && key.modelID == backend.ModelName && key.mode == mode { socket, err := RunnerSocketPath(runnerInfo.slot) if err != nil { - s.log.Warn("Failed to get socket path for runner / ( )", "backend", backend.BackendName, "backend", backend.ModelName, "model", key.modelID, "error", err) + s.log.Warn("Failed to get socket path for runner", "backend", backend.BackendName, "model", backend.ModelName, "modelID", key.modelID, "error", err) continue } @@ -334,7 +334,7 @@ func (s *Scheduler) ConfigureRunner(ctx context.Context, backend inference.Backe // Set the runner configuration if err := s.loader.setRunnerConfig(ctx, backend.Name(), modelID, mode, runnerConfig); err != nil { - s.log.Warn("Failed to configure runner for ( )", "backend", backend.Name(), "model", utils.SanitizeForLog(req.Model, -1), "model", modelID, "error", err) + s.log.Warn("Failed to configure runner", "backend", backend.Name(), "model", utils.SanitizeForLog(req.Model, -1), "modelID", modelID, "error", err) return nil, err } 
diff --git a/pkg/responses/handler_test.go b/pkg/responses/handler_test.go index 1e460dd78..fb71e452b 100644 --- a/pkg/responses/handler_test.go +++ b/pkg/responses/handler_test.go @@ -38,8 +38,7 @@ func (m *mockSchedulerHTTP) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func newTestHandler(mock *mockSchedulerHTTP) *HTTPHandler { - log := slog.Default() - // log output is controlled by the slog handler level + log := slog.New(slog.DiscardHandler) return NewHTTPHandler(log, mock, nil) } @@ -89,7 +88,7 @@ func TestHandler_CreateResponse_NonStreaming(t *testing.T) { var result Response if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - t.Errorf("failed to decode response: %v", err) + t.Fatalf("failed to decode response: %v", err) } if result.Object != "response" { @@ -174,7 +173,7 @@ func TestHandler_GetResponse(t *testing.T) { var result Response if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - t.Errorf("failed to decode response: %v", err) + t.Fatalf("failed to decode response: %v", err) } if result.ID != "resp_test123" { @@ -298,7 +297,7 @@ func TestHandler_CreateResponse_WithPreviousResponse(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - t.Errorf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body) + t.Fatalf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body) } var result Response @@ -384,7 +383,7 @@ func TestHandler_CreateResponse_UpstreamError_NonJSONBody(t *testing.T) { } if result.Error == nil { - t.Error("expected error, got nil") + t.Fatalf("expected error, got nil") } if result.Error.Code != "upstream_error" { @@ -437,7 +436,7 @@ func TestHandler_CreateResponse_Streaming(t *testing.T) { // Read all body body, err := io.ReadAll(resp.Body) if err != nil { - t.Errorf("failed to read body: %v", err) + t.Fatalf("failed to read body: %v", err) } bodyStr := string(body) @@ -516,7 +515,7 @@ func 
TestHandler_CreateResponse_WithTools(t *testing.T) { resp := w.Result() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - t.Errorf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body) + t.Fatalf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body) } var result Response @@ -629,13 +628,13 @@ func TestHandler_CreateResponse_Streaming_Persistence(t *testing.T) { memStore := handler.store if memStore.Count() != 1 { - t.Errorf("expected exactly one response in store, got %d", memStore.Count()) + t.Fatalf("expected exactly one response in store, got %d", memStore.Count()) } // Get the response ID from the store responseIDs := memStore.GetResponseIDs() if len(responseIDs) != 1 { - t.Errorf("expected exactly one response ID in store, got %d", len(responseIDs)) + t.Fatalf("expected exactly one response ID in store, got %d", len(responseIDs)) } // Retrieve the response using the public API