diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go index 056cf3263..9a6e6c0d4 100644 --- a/cmd/cli/desktop/desktop.go +++ b/cmd/cli/desktop/desktop.go @@ -16,6 +16,9 @@ import ( "github.com/docker/model-runner/cmd/cli/pkg/standalone" "github.com/docker/model-runner/pkg/distribution/distribution" + "github.com/docker/model-runner/pkg/distribution/oci/authn" + "github.com/docker/model-runner/pkg/distribution/oci/reference" + "github.com/docker/model-runner/pkg/distribution/oci/remote" "github.com/docker/model-runner/pkg/inference" dmrm "github.com/docker/model-runner/pkg/inference/models" "github.com/docker/model-runner/pkg/inference/scheduling" @@ -29,6 +32,78 @@ var ( ErrServiceUnavailable = errors.New("service unavailable") ) +// resolveCredentials resolves Docker registry credentials for a model reference +// and exchanges them for a short-lived bearer token. +// Returns nil if no credentials are found (anonymous access). +func resolveCredentials(ctx context.Context, model string) *distribution.Credentials { + // Skip credential resolution for Hugging Face models (use HF_TOKEN env var instead). + if strings.HasPrefix(strings.ToLower(model), "hf.co/") { + if hfToken := os.Getenv("HF_TOKEN"); hfToken != "" { + return &distribution.Credentials{BearerToken: hfToken} + } + return nil + } + + ref, err := reference.ParseReference(model) + if err != nil { + return nil + } + + resource := authn.NewResource(ref) + auth, err := authn.DefaultKeychain.Resolve(resource) + if err != nil { + return nil + } + + authConfig, err := auth.Authorization() + if err != nil || authConfig == nil { + return nil + } + + if authConfig.RegistryToken != "" { + return &distribution.Credentials{BearerToken: authConfig.RegistryToken} + } + if authConfig.IdentityToken != "" { + return &distribution.Credentials{BearerToken: authConfig.IdentityToken} + } + + if authConfig.Username != "" && authConfig.Password != "" { + token, err := exchangeForToken(ctx, ref, auth) + if err != nil { + return &distribution.Credentials{ + Username: authConfig.Username, + Password: authConfig.Password, + } + } + return &distribution.Credentials{BearerToken: token} + } + + return nil +} + +// exchangeForToken exchanges credentials for a short-lived bearer token. +func exchangeForToken(ctx context.Context, ref reference.Reference, auth authn.Authenticator) (string, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + pr, err := remote.Ping(ctx, ref.Context().Registry, nil) + if err != nil { + return "", fmt.Errorf("pinging registry: %w", err) + } + + if pr.WWWAuthenticate.Realm == "" { + return "", fmt.Errorf("no auth required") + } + + scope := ref.Scope(remote.PushScope) + token, err := remote.Exchange(ctx, ref.Context().Registry, auth, nil, []string{scope}, pr) + if err != nil { + return "", fmt.Errorf("exchanging credentials: %w", err) + } + + return token.Token, nil +} + type otelErrorSilencer struct{} func (oes *otelErrorSilencer) Handle(error) {} @@ -98,17 +173,16 @@ func (c *Client) Status() Status { } func (c *Client) Pull(model string, printer standalone.StatusPrinter) (string, bool, error) { - // Check if this is a Hugging Face model and if HF_TOKEN is set - var hfToken string - if strings.HasPrefix(strings.ToLower(model), "hf.co/") { - hfToken = os.Getenv("HF_TOKEN") - } + creds := resolveCredentials(context.Background(), model) return c.withRetries("download", 3, printer, func(attempt int) (string, bool, error, bool) { - jsonData, err := json.Marshal(dmrm.ModelCreateRequest{ - From: model, - BearerToken: hfToken, - }) + req := dmrm.ModelCreateRequest{From: model} + if creds != nil { + req.Username = creds.Username + req.Password = creds.Password + req.BearerToken = creds.BearerToken + } + jsonData, err := json.Marshal(req) if err != nil { // Marshaling errors are not retryable return "", false, fmt.Errorf("error marshaling request: %w", err), false @@ -223,12 +297,27 @@ func (c *Client) withRetries( } func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, bool, error) { + creds := resolveCredentials(context.Background(), model) + return c.withRetries("push", 3, printer, func(attempt int) (string, bool, error, bool) { + var body io.Reader + if creds != nil { + jsonData, err := json.Marshal(dmrm.ModelPushRequest{ + Username: creds.Username, + Password: creds.Password, + BearerToken: creds.BearerToken, + }) + if err != nil { + return "", false, fmt.Errorf("error marshaling request: %w", err), false + } + body = bytes.NewReader(jsonData) + } + pushPath := inference.ModelsPrefix + "/" + model + "/push" resp, err := c.doRequest( http.MethodPost, pushPath, - nil, // Assuming no body is needed for the push request + body, ) if err != nil { // Only retry on network errors, not on client errors diff --git a/cmd/cli/pkg/standalone/containers.go b/cmd/cli/pkg/standalone/containers.go index a6ed8afa2..afe773899 100644 --- a/cmd/cli/pkg/standalone/containers.go +++ b/cmd/cli/pkg/standalone/containers.go @@ -1,8 +1,6 @@ package standalone import ( - "archive/tar" - "bytes" "context" "errors" "fmt" @@ -29,66 +27,6 @@ import ( // controllerContainerName is the name to use for the controller container. const controllerContainerName = "docker-model-runner" -// copyDockerConfigToContainer copies the Docker config file from the host to the container -// and sets up proper ownership and permissions for the modelrunner user. -// It does nothing for Desktop and Cloud engine kinds. -func copyDockerConfigToContainer(ctx context.Context, dockerClient *client.Client, containerID string, engineKind types.ModelRunnerEngineKind) error { - // Do nothing for Desktop and Cloud engine kinds - if engineKind == types.ModelRunnerEngineKindDesktop || engineKind == types.ModelRunnerEngineKindCloud || - os.Getenv("_MODEL_RUNNER_TREAT_DESKTOP_AS_MOBY") == "1" { - return nil - } - - dockerConfigPath := os.ExpandEnv("$HOME/.docker/config.json") - if s, err := os.Stat(dockerConfigPath); err != nil || s.Mode()&os.ModeType != 0 { - return nil - } - - configData, err := os.ReadFile(dockerConfigPath) - if err != nil { - return fmt.Errorf("failed to read Docker config file: %w", err) - } - - var buf bytes.Buffer - tw := tar.NewWriter(&buf) - header := &tar.Header{ - Name: ".docker/config.json", - Mode: 0600, - Size: int64(len(configData)), - } - if err := tw.WriteHeader(header); err != nil { - return fmt.Errorf("failed to write tar header: %w", err) - } - if _, err := tw.Write(configData); err != nil { - return fmt.Errorf("failed to write config data to tar: %w", err) - } - if err := tw.Close(); err != nil { - return fmt.Errorf("failed to close tar writer: %w", err) - } - - // Ensure the .docker directory exists - mkdirCmd := "mkdir -p /home/modelrunner/.docker && chown modelrunner:modelrunner /home/modelrunner/.docker" - if err := execInContainer(ctx, dockerClient, containerID, mkdirCmd, false); err != nil { - return err - } - - // Copy directly into the .docker directory - err = dockerClient.CopyToContainer(ctx, containerID, "/home/modelrunner", &buf, container.CopyToContainerOptions{ - CopyUIDGID: true, - }) - if err != nil { - return fmt.Errorf("failed to copy config file to container: %w", err) - } - - // Set correct ownership and permissions - chmodCmd := "chown modelrunner:modelrunner /home/modelrunner/.docker/config.json && chmod 600 /home/modelrunner/.docker/config.json" - if err := execInContainer(ctx, dockerClient, containerID, chmodCmd, false); err != nil { - return err - } - - return nil -} - func execInContainer(ctx context.Context, dockerClient *client.Client, containerID, cmd string, asRoot bool) error { execConfig := container.ExecOptions{ Cmd: []string{"sh", "-c", cmd}, @@ -447,14 +385,6 @@ func CreateControllerContainer(ctx context.Context, dockerClient *client.Client, return fmt.Errorf("failed to start container %s: %w", controllerContainerName, err) } - // Copy Docker config file if it exists and we're the container creator. - if created && !vllmOnWSL { - if err := copyDockerConfigToContainer(ctx, dockerClient, resp.ID, engineKind); err != nil { - // Log warning but continue - don't fail container creation - printer.Printf("Warning: failed to copy Docker config: %v\n", err) - } - } - // Add proxy certificate to the system CA bundle (requires root for update-ca-certificates) if created && proxyCert != "" { printer.Printf("Updating CA certificates...\n") diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index f150adef8..aca8e6263 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -23,6 +23,16 @@ import ( "github.com/sirupsen/logrus" ) +// Credentials holds authentication credentials for registry operations. +type Credentials struct { + // Username for basic authentication. + Username string + // Password for basic authentication. + Password string + // BearerToken for token-based authentication (e.g., Hugging Face). + BearerToken string +} + // Client provides model distribution functionality type Client struct { store *store.LocalStore @@ -227,32 +237,32 @@ func (c *Client) resolveID(id string) string { } // PullModel pulls a model from a registry and returns the local file path -func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer, bearerToken ...string) error { +func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer, creds *Credentials) error { // Store original reference before normalization (needed for case-sensitive HuggingFace API) originalReference := reference // Normalize the model reference reference = c.normalizeModelName(reference) c.log.Infoln("Starting model pull:", utils.SanitizeForLog(reference)) - // Handle bearer token for registry authentication - var token string - if len(bearerToken) > 0 && bearerToken[0] != "" { - token = bearerToken[0] - } - // HuggingFace references always use native pull (download raw files from HF Hub) if isHuggingFaceReference(originalReference) { c.log.Infoln("Using native HuggingFace pull for:", utils.SanitizeForLog(reference)) // Pass original reference to preserve case-sensitivity for HuggingFace API + var token string + if creds != nil && creds.BearerToken != "" { + token = creds.BearerToken + } return c.pullNativeHuggingFace(ctx, originalReference, progressWriter, token) } // For non-HF references, use OCI registry registryClient := c.registry - if token != "" { - // Create a temporary registry client with bearer token authentication - auth := authn.NewBearer(token) - registryClient = registry.FromClient(c.registry, registry.WithAuth(auth)) + if creds != nil { + if creds.Username != "" && creds.Password != "" { + registryClient = registry.FromClient(c.registry, registry.WithAuthConfig(creds.Username, creds.Password)) + } else if creds.BearerToken != "" { + registryClient = registry.FromClient(c.registry, registry.WithAuth(authn.NewBearer(creds.BearerToken))) + } } // Fetch the remote model to get the manifest @@ -538,9 +548,18 @@ func (c *Client) Tag(source string, target string) error { } // PushModel pushes a tagged model from the content store to the registry. -func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Writer) (err error) { +func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Writer, creds *Credentials) (err error) { + registryClient := c.registry + if creds != nil { + if creds.BearerToken != "" { + registryClient = registry.FromClient(c.registry, registry.WithAuth(authn.NewBearer(creds.BearerToken))) + } else if creds.Username != "" && creds.Password != "" { + registryClient = registry.FromClient(c.registry, registry.WithAuthConfig(creds.Username, creds.Password)) + } + } + // Parse the tag - target, err := c.registry.NewTarget(tag) + target, err := registryClient.NewTarget(tag) if err != nil { return fmt.Errorf("new tag: %w", err) } diff --git a/pkg/distribution/distribution/client_test.go b/pkg/distribution/distribution/client_test.go index 7e3b65f1e..9442680c1 100644 --- a/pkg/distribution/distribution/client_test.go +++ b/pkg/distribution/distribution/client_test.go @@ -79,7 +79,7 @@ func TestClientPullModel(t *testing.T) { t.Run("pull without progress writer", func(t *testing.T) { // Pull model from registry without progress writer - err := client.PullModel(t.Context(), tag, nil) + err := client.PullModel(t.Context(), tag, nil, nil) if err != nil { t.Fatalf("Failed to pull model: %v", err) } @@ -112,7 +112,7 @@ func TestClientPullModel(t *testing.T) { var progressBuffer bytes.Buffer // Pull model from registry with progress writer - if err := client.PullModel(t.Context(), tag, &progressBuffer); err != nil { + if err := client.PullModel(t.Context(), tag, &progressBuffer, nil); err != nil { t.Fatalf("Failed to pull model: %v", err) } @@ -160,7 +160,7 @@ func TestClientPullModel(t *testing.T) { // Test with non-existent repository nonExistentRef := registryHost + "/nonexistent/model:v1.0.0" - err = testClient.PullModel(t.Context(), nonExistentRef, &progressBuffer) + err = testClient.PullModel(t.Context(), nonExistentRef, &progressBuffer, nil) if err == nil { t.Fatal("Expected error for non-existent model, got nil") } @@ -216,7 +216,7 @@ func TestClientPullModel(t *testing.T) { } // Push model to registry - if err := testClient.PushModel(t.Context(), testTag, nil); err != nil { + if err := testClient.PushModel(t.Context(), testTag, nil, nil); err != nil { t.Fatalf("Failed to pull model: %v", err) } @@ -262,7 +262,7 @@ func TestClientPullModel(t *testing.T) { var progressBuffer bytes.Buffer // Pull the model again - this should detect the incomplete file and pull again - if err := testClient.PullModel(t.Context(), testTag, &progressBuffer); err != nil { + if err := testClient.PullModel(t.Context(), testTag, &progressBuffer, nil); err != nil { t.Fatalf("Failed to pull model: %v", err) } @@ -315,7 +315,7 @@ func TestClientPullModel(t *testing.T) { } // Pull first version of model - if err := testClient.PullModel(t.Context(), testTag, nil); err != nil { + if err := testClient.PullModel(t.Context(), testTag, nil, nil); err != nil { t.Fatalf("Failed to pull first version of model: %v", err) } @@ -359,7 +359,7 @@ func TestClientPullModel(t *testing.T) { var progressBuffer bytes.Buffer // Pull model again - should get the updated version - if err := testClient.PullModel(t.Context(), testTag, &progressBuffer); err != nil { + if err := testClient.PullModel(t.Context(), testTag, &progressBuffer, nil); err != nil { t.Fatalf("Failed to pull updated model: %v", err) } @@ -405,7 +405,7 @@ func TestClientPullModel(t *testing.T) { if err := remote.Write(ref, newMdl, nil, remote.WithPlainHTTP(true)); err != nil { t.Fatalf("Failed to push model: %v", err) } - if err := client.PullModel(t.Context(), testTag, nil); err == nil || !errors.Is(err, ErrUnsupportedMediaType) { + if err := client.PullModel(t.Context(), testTag, nil, nil); err == nil || !errors.Is(err, ErrUnsupportedMediaType) { t.Fatalf("Expected artifact version error, got %v", err) } }) @@ -446,7 +446,7 @@ func TestClientPullModel(t *testing.T) { // Try to pull the safetensors model with a progress writer to capture warnings var progressBuf bytes.Buffer - err = testClient.PullModel(t.Context(), testTag, &progressBuf) + err = testClient.PullModel(t.Context(), testTag, &progressBuf, nil) // Pull should succeed on all platforms now (with a warning on non-Linux) if err != nil { @@ -478,7 +478,7 @@ func TestClientPullModel(t *testing.T) { var progressBuffer bytes.Buffer // Pull model from registry with progress writer - if err := testClient.PullModel(t.Context(), tag, &progressBuffer); err != nil { + if err := testClient.PullModel(t.Context(), tag, &progressBuffer, nil); err != nil { t.Fatalf("Failed to pull model: %v", err) } @@ -555,7 +555,7 @@ func TestClientPullModel(t *testing.T) { // Test with non-existent model nonExistentRef := registryHost + "/nonexistent/model:v1.0.0" - err = testClient.PullModel(t.Context(), nonExistentRef, &progressBuffer) + err = testClient.PullModel(t.Context(), nonExistentRef, &progressBuffer, nil) // Expect an error if err == nil { @@ -813,7 +813,7 @@ func TestNewReferenceError(t *testing.T) { // Test with invalid reference invalidRef := "invalid:reference:format" - err = client.PullModel(t.Context(), invalidRef, nil) + err = client.PullModel(t.Context(), invalidRef, nil, nil) if err == nil { t.Fatal("Expected error for invalid reference, got nil") } @@ -858,7 +858,7 @@ func TestPush(t *testing.T) { } // Push the model to the registry - if err := client.PushModel(t.Context(), tag, nil); err != nil { + if err := client.PushModel(t.Context(), tag, nil, nil); err != nil { t.Fatalf("Failed to push model: %v", err) } @@ -868,7 +868,7 @@ func TestPush(t *testing.T) { } // Test that model can be pulled successfully - if err := client.PullModel(t.Context(), tag, nil); err != nil { + if err := client.PullModel(t.Context(), tag, nil, nil); err != nil { t.Fatalf("Failed to pull model: %v", err) } @@ -929,7 +929,7 @@ func TestPushProgress(t *testing.T) { done := make(chan error, 1) go func() { defer pw.Close() - done <- client.PushModel(t.Context(), tag, pw) + done <- client.PushModel(t.Context(), tag, pw, nil) close(done) }() @@ -1051,7 +1051,7 @@ func TestClientPushModelNotFound(t *testing.T) { t.Fatalf("Failed to create client: %v", err) } - if err := client.PushModel(t.Context(), "non-existent-model:latest", nil); !errors.Is(err, ErrModelNotFound) { + if err := client.PushModel(t.Context(), "non-existent-model:latest", nil, nil); !errors.Is(err, ErrModelNotFound) { t.Fatalf("Expected ErrModelNotFound got: %v", err) } } diff --git a/pkg/distribution/distribution/ecr_test.go b/pkg/distribution/distribution/ecr_test.go index 699b8dbb6..cf9edf6d3 100644 --- a/pkg/distribution/distribution/ecr_test.go +++ b/pkg/distribution/distribution/ecr_test.go @@ -43,7 +43,7 @@ func TestECRIntegration(t *testing.T) { if err := client.store.Write(mdl, []string{ecrTag}, nil); err != nil { t.Fatalf("Failed to write model to store: %v", err) } - if err := client.PushModel(t.Context(), ecrTag, nil); err != nil { + if err := client.PushModel(t.Context(), ecrTag, nil, nil); err != nil { t.Fatalf("Failed to push model to ECR: %v", err) } if _, err := client.DeleteModel(ecrTag, false); err != nil { // cleanup @@ -53,7 +53,7 @@ func TestECRIntegration(t *testing.T) { // Test pull from ECR t.Run("Pull without progress", func(t *testing.T) { - err := client.PullModel(t.Context(), ecrTag, nil) + err := client.PullModel(t.Context(), ecrTag, nil, nil) if err != nil { t.Fatalf("Failed to pull model from ECR: %v", err) } diff --git a/pkg/distribution/distribution/gar_test.go b/pkg/distribution/distribution/gar_test.go index 669e10c12..b241eec6b 100644 --- a/pkg/distribution/distribution/gar_test.go +++ b/pkg/distribution/distribution/gar_test.go @@ -44,7 +44,7 @@ func TestGARIntegration(t *testing.T) { if err := client.store.Write(mdl, []string{garTag}, nil); err != nil { t.Fatalf("Failed to write model to store: %v", err) } - if err := client.PushModel(t.Context(), garTag, nil); err != nil { + if err := client.PushModel(t.Context(), garTag, nil, nil); err != nil { t.Fatalf("Failed to push model to ECR: %v", err) } if _, err := client.DeleteModel(garTag, false); err != nil { // cleanup @@ -54,7 +54,7 @@ func TestGARIntegration(t *testing.T) { // Test pull from GAR t.Run("Pull without progress", func(t *testing.T) { - err := client.PullModel(t.Context(), garTag, nil) + err := client.PullModel(t.Context(), garTag, nil, nil) if err != nil { t.Fatalf("Failed to pull model from GAR: %v", err) } diff --git a/pkg/inference/models/api.go b/pkg/inference/models/api.go index ffb724c12..c4ad37ea4 100644 --- a/pkg/inference/models/api.go +++ b/pkg/inference/models/api.go @@ -16,6 +16,20 @@ import ( type ModelCreateRequest struct { // From is the name of the model to pull. From string `json:"from"` + // BearerToken is an optional bearer token for authentication (e.g., for Hugging Face). + BearerToken string `json:"bearer-token,omitempty"` + // Username is an optional username for registry authentication. + Username string `json:"username,omitempty"` + // Password is an optional password for registry authentication. + Password string `json:"password,omitempty"` +} + +// ModelPushRequest represents a model push request. +type ModelPushRequest struct { + // Username is an optional username for registry authentication. + Username string `json:"username,omitempty"` + // Password is an optional password for registry authentication. + Password string `json:"password,omitempty"` // BearerToken is an optional bearer token for authentication. BearerToken string `json:"bearer-token,omitempty"` } diff --git a/pkg/inference/models/handler_test.go b/pkg/inference/models/handler_test.go index d8c0d580e..c764c2cdc 100644 --- a/pkg/inference/models/handler_test.go +++ b/pkg/inference/models/handler_test.go @@ -114,7 +114,7 @@ func TestPullModel(t *testing.T) { } w := httptest.NewRecorder() - err = handler.manager.Pull(tag, "", r, w) + err = handler.manager.Pull(tag, nil, r, w) if err != nil { t.Fatalf("Failed to pull model: %v", err) } @@ -221,7 +221,7 @@ func TestHandleGetModel(t *testing.T) { if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") { r := httptest.NewRequest(http.MethodPost, "/models/create", strings.NewReader(`{"from": "`+tt.modelName+`"}`)) w := httptest.NewRecorder() - err = handler.manager.Pull(tt.modelName, "", r, w) + err = handler.manager.Pull(tt.modelName, nil, r, w) if err != nil { t.Fatalf("Failed to pull model: %v", err) } diff --git a/pkg/inference/models/http_handler.go b/pkg/inference/models/http_handler.go index 58886bfb0..901b39942 100644 --- a/pkg/inference/models/http_handler.go +++ b/pkg/inference/models/http_handler.go @@ -104,8 +104,17 @@ func (h *HTTPHandler) handleCreateModel(w http.ResponseWriter, r *http.Request) return } + var creds *distribution.Credentials + if request.Username != "" || request.Password != "" || request.BearerToken != "" { + creds = &distribution.Credentials{ + Username: request.Username, + Password: request.Password, + BearerToken: request.BearerToken, + } + } + // Pull the model - if err := h.manager.Pull(request.From, request.BearerToken, r, w); err != nil { + if err := h.manager.Pull(request.From, creds, r, w); err != nil { sanitizedFrom := utils.SanitizeForLog(request.From, -1) if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { h.log.Infof("Request canceled/timed out while pulling model %q", sanitizedFrom) @@ -417,7 +426,21 @@ func (h *HTTPHandler) handleTagModel(w http.ResponseWriter, r *http.Request, mod // handlePushModel handles POST /models/{name}/push requests. func (h *HTTPHandler) handlePushModel(w http.ResponseWriter, r *http.Request, model string) { - if err := h.manager.Push(model, r, w); err != nil { + var creds *distribution.Credentials + var request ModelPushRequest + if r.Body != nil && r.ContentLength > 0 { + if err := json.NewDecoder(r.Body).Decode(&request); err == nil { + if request.Username != "" || request.Password != "" || request.BearerToken != "" { + creds = &distribution.Credentials{ + Username: request.Username, + Password: request.Password, + BearerToken: request.BearerToken, + } + } + } + } + + if err := h.manager.Push(model, creds, r, w); err != nil { if errors.Is(err, distribution.ErrInvalidReference) { h.log.Warnf("Invalid model reference %q: %v", utils.SanitizeForLog(model, -1), err) http.Error(w, "Invalid model reference", http.StatusBadRequest) diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go index 858395905..e466175b0 100644 --- a/pkg/inference/models/manager.go +++ b/pkg/inference/models/manager.go @@ -207,7 +207,7 @@ func (m *Manager) Delete(reference string, force bool) (*distribution.DeleteMode // Pull pulls a model to local storage. Any error it returns is suitable // for writing back to the client. -func (m *Manager) Pull(model string, bearerToken string, r *http.Request, w http.ResponseWriter) error { +func (m *Manager) Pull(model string, creds *distribution.Credentials, r *http.Request, w http.ResponseWriter) error { // Restrict model pull concurrency. select { case <-m.pullTokens: @@ -250,16 +250,7 @@ func (m *Manager) Pull(model string, bearerToken string, r *http.Request, w http // Pull the model using the Docker model distribution client m.log.Infoln("Pulling model:", utils.SanitizeForLog(model, -1)) - // Use bearer token if provided - var err error - if bearerToken != "" { - m.log.Infoln("Using provided bearer token for authentication") - err = m.distributionClient.PullModel(r.Context(), model, progressWriter, bearerToken) - } else { - err = m.distributionClient.PullModel(r.Context(), model, progressWriter) - } - - if err != nil { + if err := m.distributionClient.PullModel(r.Context(), model, progressWriter, creds); err != nil { return fmt.Errorf("error while pulling model: %w", err) } @@ -369,7 +360,7 @@ func (m *Manager) Tag(ref, target string) error { } // Push pushes a model from the store to the registry. -func (m *Manager) Push(model string, r *http.Request, w http.ResponseWriter) error { +func (m *Manager) Push(model string, creds *distribution.Credentials, r *http.Request, w http.ResponseWriter) error { // Set up response headers for streaming w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") @@ -398,8 +389,7 @@ func (m *Manager) Push(model string, r *http.Request, w http.ResponseWriter) err isJSON: isJSON, } - err := m.distributionClient.PushModel(r.Context(), model, progressWriter) - if err != nil { + if err := m.distributionClient.PushModel(r.Context(), model, progressWriter, creds); err != nil { return fmt.Errorf("error while pushing model: %w", err) } diff --git a/pkg/ollama/http_handler.go b/pkg/ollama/http_handler.go index da724a62e..bbc0c4469 100644 --- a/pkg/ollama/http_handler.go +++ b/pkg/ollama/http_handler.go @@ -651,7 +651,7 @@ func (h *HTTPHandler) handlePull(w http.ResponseWriter, r *http.Request) { } // Call the model manager's Pull method with the wrapped writer - if err := h.modelManager.Pull(modelName, "", r, ollamaWriter); err != nil { + if err := h.modelManager.Pull(modelName, nil, r, ollamaWriter); err != nil { h.log.Errorf("Failed to pull model: %s", utils.SanitizeForLog(err.Error(), -1)) // Send error in Ollama JSON format