Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ require (
github.com/opencontainers/image-spec v1.1.1
github.com/prometheus/client_model v0.6.2
github.com/prometheus/common v0.67.5
github.com/sirupsen/logrus v1.9.4
github.com/stretchr/testify v1.11.1
golang.org/x/sync v0.19.0
)
Expand All @@ -42,6 +41,7 @@ require (
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/sirupsen/logrus v1.9.4 // indirect
github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
Expand Down
96 changes: 52 additions & 44 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package main
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"net"
"net/http"
"os"
Expand All @@ -24,21 +26,27 @@ import (
"github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/inference/platform"
"github.com/docker/model-runner/pkg/inference/scheduling"
"github.com/docker/model-runner/pkg/logging"
"github.com/docker/model-runner/pkg/metrics"
"github.com/docker/model-runner/pkg/middleware"
"github.com/docker/model-runner/pkg/ollama"
"github.com/docker/model-runner/pkg/responses"
"github.com/docker/model-runner/pkg/routing"
modeltls "github.com/docker/model-runner/pkg/tls"
"github.com/sirupsen/logrus"
)

const (
// DefaultTLSPort is the default TLS port for Moby
DefaultTLSPort = "12444"
)

var log = logrus.New()
// initLogger builds the application logger. The verbosity is taken from the
// LOG_LEVEL environment variable, whose raw value is passed to
// logging.ParseLevel; the resulting level is handed to logging.NewLogger.
func initLogger() *slog.Logger {
	return logging.NewLogger(logging.ParseLevel(os.Getenv("LOG_LEVEL")))
}

var log = initLogger()

// Log is the logger used by the application, exported for testing purposes.
var Log = log
Expand All @@ -57,7 +65,7 @@ func main() {

userHomeDir, err := os.UserHomeDir()
if err != nil {
log.Fatalf("Failed to get user home directory: %v", err)
log.Error(fmt.Sprintf("Failed to get user home directory: %v", err))
}

modelPath := os.Getenv("MODELS_PATH")
Expand Down Expand Up @@ -101,27 +109,27 @@ func main() {

clientConfig := models.ClientConfig{
StoreRootPath: modelPath,
Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
Logger: log.With("component", "model-manager"),
Transport: baseTransport,
}
modelManager := models.NewManager(log.WithFields(logrus.Fields{"component": "model-manager"}), clientConfig)
modelManager := models.NewManager(log.With("component", "model-manager"), clientConfig)
modelHandler := models.NewHTTPHandler(
log,
modelManager,
nil,
)
log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)
log.Info(fmt.Sprintf("LLAMA_SERVER_PATH: %s", llamaServerPath))
if vllmServerPath != "" {
log.Infof("VLLM_SERVER_PATH: %s", vllmServerPath)
log.Info(fmt.Sprintf("VLLM_SERVER_PATH: %s", vllmServerPath))
}
if sglangServerPath != "" {
log.Infof("SGLANG_SERVER_PATH: %s", sglangServerPath)
log.Info(fmt.Sprintf("SGLANG_SERVER_PATH: %s", sglangServerPath))
}
if mlxServerPath != "" {
log.Infof("MLX_SERVER_PATH: %s", mlxServerPath)
log.Info(fmt.Sprintf("MLX_SERVER_PATH: %s", mlxServerPath))
}
if vllmMetalServerPath != "" {
log.Infof("VLLM_METAL_SERVER_PATH: %s", vllmMetalServerPath)
log.Info(fmt.Sprintf("VLLM_METAL_SERVER_PATH: %s", vllmMetalServerPath))
}

// Create llama.cpp configuration from environment variables
Expand All @@ -130,7 +138,7 @@ func main() {
llamaCppBackend, err := llamacpp.New(
log,
modelManager,
log.WithFields(logrus.Fields{"component": llamacpp.Name}),
log.With("component", llamacpp.Name),
llamaServerPath,
func() string {
wd, _ := os.Getwd()
Expand All @@ -141,58 +149,58 @@ func main() {
llamaCppConfig,
)
if err != nil {
log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
log.Error(fmt.Sprintf("unable to initialize %s backend: %v", llamacpp.Name, err))
}

vllmBackend, err := initVLLMBackend(log, modelManager, vllmServerPath)
if err != nil {
log.Fatalf("unable to initialize %s backend: %v", vllm.Name, err)
log.Error(fmt.Sprintf("unable to initialize %s backend: %v", vllm.Name, err))
}

mlxBackend, err := mlx.New(
log,
modelManager,
log.WithFields(logrus.Fields{"component": mlx.Name}),
log.With("component", mlx.Name),
nil,
mlxServerPath,
)
if err != nil {
log.Fatalf("unable to initialize %s backend: %v", mlx.Name, err)
log.Error(fmt.Sprintf("unable to initialize %s backend: %v", mlx.Name, err))
}

sglangBackend, err := sglang.New(
log,
modelManager,
log.WithFields(logrus.Fields{"component": sglang.Name}),
log.With("component", sglang.Name),
nil,
sglangServerPath,
)
if err != nil {
log.Fatalf("unable to initialize %s backend: %v", sglang.Name, err)
log.Error(fmt.Sprintf("unable to initialize %s backend: %v", sglang.Name, err))
}

diffusersBackend, err := diffusers.New(
log,
modelManager,
log.WithFields(logrus.Fields{"component": diffusers.Name}),
log.With("component", diffusers.Name),
nil,
diffusersServerPath,
)

if err != nil {
log.Fatalf("unable to initialize diffusers backend: %v", err)
log.Error(fmt.Sprintf("unable to initialize diffusers backend: %v", err))
}

var vllmMetalBackend inference.Backend
if platform.SupportsVLLMMetal() {
vllmMetalBackend, err = vllmmetal.New(
log,
modelManager,
log.WithFields(logrus.Fields{"component": vllmmetal.Name}),
log.With("component", vllmmetal.Name),
vllmMetalServerPath,
)
if err != nil {
log.Warnf("Failed to initialize vllm-metal backend: %v", err)
log.Warn(fmt.Sprintf("Failed to initialize vllm-metal backend: %v", err))
}
}

Expand Down Expand Up @@ -222,7 +230,7 @@ func main() {
http.DefaultClient,
metrics.NewTracker(
http.DefaultClient,
log.WithField("component", "metrics"),
log.With("component", "metrics"),
"",
false,
),
Expand Down Expand Up @@ -278,7 +286,7 @@ func main() {
// Add metrics endpoint if enabled
if os.Getenv("DISABLE_METRICS") != "1" {
metricsHandler := metrics.NewAggregatedMetricsHandler(
log.WithField("component", "metrics"),
log.With("component", "metrics"),
schedulerHTTP,
)
router.Handle("/metrics", metricsHandler)
Expand All @@ -302,7 +310,7 @@ func main() {
if tcpPort != "" {
// Use TCP port
addr := ":" + tcpPort
log.Infof("Listening on TCP port %s", tcpPort)
log.Info(fmt.Sprintf("Listening on TCP port %s", tcpPort))
server.Addr = addr
go func() {
serverErrors <- server.ListenAndServe()
Expand All @@ -311,12 +319,12 @@ func main() {
// Use Unix socket
if err := os.Remove(sockName); err != nil {
if !os.IsNotExist(err) {
log.Fatalf("Failed to remove existing socket: %v", err)
log.Error(fmt.Sprintf("Failed to remove existing socket: %v", err))
}
}
ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: sockName, Net: "unix"})
if err != nil {
log.Fatalf("Failed to listen on socket: %v", err)
log.Error(fmt.Sprintf("Failed to listen on socket: %v", err))
}
go func() {
serverErrors <- server.Serve(ln)
Expand All @@ -341,19 +349,19 @@ func main() {
var err error
certPath, keyPath, err = modeltls.EnsureCertificates("", "")
if err != nil {
log.Fatalf("Failed to ensure TLS certificates: %v", err)
log.Error(fmt.Sprintf("Failed to ensure TLS certificates: %v", err))
}
log.Infof("Using TLS certificate: %s", certPath)
log.Infof("Using TLS key: %s", keyPath)
log.Info(fmt.Sprintf("Using TLS certificate: %s", certPath))
log.Info(fmt.Sprintf("Using TLS key: %s", keyPath))
} else {
log.Fatal("TLS enabled but no certificate provided and auto-cert is disabled")
log.Error("TLS enabled but no certificate provided and auto-cert is disabled")
}
}

// Load TLS configuration
tlsConfig, err := modeltls.LoadTLSConfig(certPath, keyPath)
if err != nil {
log.Fatalf("Failed to load TLS configuration: %v", err)
log.Error(fmt.Sprintf("Failed to load TLS configuration: %v", err))
}

tlsServer = &http.Server{
Expand All @@ -363,7 +371,7 @@ func main() {
ReadHeaderTimeout: 10 * time.Second,
}

log.Infof("Listening on TLS port %s", tlsPort)
log.Info(fmt.Sprintf("Listening on TLS port %s", tlsPort))
go func() {
// Use ListenAndServeTLS with empty strings since TLSConfig already has the certs
ln, err := tls.Listen("tcp", tlsServer.Addr, tlsConfig)
Expand Down Expand Up @@ -391,30 +399,30 @@ func main() {
select {
case err := <-serverErrors:
if err != nil {
log.Errorf("Server error: %v", err)
log.Error(fmt.Sprintf("Server error: %v", err))
}
case err := <-tlsServerErrorsChan:
if err != nil {
log.Errorf("TLS server error: %v", err)
log.Error(fmt.Sprintf("TLS server error: %v", err))
}
case <-ctx.Done():
log.Infoln("Shutdown signal received")
log.Infoln("Shutting down the server")
log.Info("Shutdown signal received")
log.Info("Shutting down the server")
if err := server.Close(); err != nil {
log.Errorf("Server shutdown error: %v", err)
log.Error(fmt.Sprintf("Server shutdown error: %v", err))
}
if tlsServer != nil {
log.Infoln("Shutting down the TLS server")
log.Info("Shutting down the TLS server")
if err := tlsServer.Close(); err != nil {
log.Errorf("TLS server shutdown error: %v", err)
log.Error(fmt.Sprintf("TLS server shutdown error: %v", err))
}
}
log.Infoln("Waiting for the scheduler to stop")
log.Info("Waiting for the scheduler to stop")
if err := <-schedulerErrors; err != nil {
log.Errorf("Scheduler error: %v", err)
log.Error(fmt.Sprintf("Scheduler error: %v", err))
}
}
log.Infoln("Docker Model Runner stopped")
log.Info("Docker Model Runner stopped")
}

// createLlamaCppConfigFromEnv creates a LlamaCppConfig from environment variables
Expand All @@ -435,12 +443,12 @@ func createLlamaCppConfigFromEnv() config.BackendConfig {
for _, arg := range args {
for _, disallowed := range disallowedArgs {
if arg == disallowed {
testLog.Fatalf("LLAMA_ARGS cannot override the %s argument as it is controlled by the model runner", disallowed)
testLog.Error(fmt.Sprintf("LLAMA_ARGS cannot override the %s argument as it is controlled by the model runner", disallowed))
}
}
}

testLog.Infof("Using custom arguments: %v", args)
testLog.Info(fmt.Sprintf("Using custom arguments: %v", args))
return &llamacpp.Config{
Args: args,
}
Expand Down
Loading