Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/cli/commands/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
func newPushCmd() *cobra.Command {
c := &cobra.Command{
Use: "push MODEL",
Short: "Push a model to Docker Hub",
Short: "Push a model to Docker Hub or Hugging Face",
Args: requireExactArgs(1, "push", "MODEL"),
RunE: func(cmd *cobra.Command, args []string) error {
return pushModel(cmd, desktopClient, args[0])
Expand Down
18 changes: 17 additions & 1 deletion cmd/cli/desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,28 @@ func (c *Client) withRetries(
}

func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, bool, error) {
var hfToken string
modelLower := strings.ToLower(model)
if strings.HasPrefix(modelLower, "hf.co/") || strings.HasPrefix(modelLower, "huggingface.co/") {
hfToken = os.Getenv("HF_TOKEN")
}

return c.withRetries("push", 3, printer, func(attempt int) (string, bool, error, bool) {
pushPath := inference.ModelsPrefix + "/" + model + "/push"
var body io.Reader
if hfToken != "" {
jsonData, err := json.Marshal(dmrm.ModelPushRequest{
BearerToken: hfToken,
})
if err != nil {
return "", false, fmt.Errorf("error marshaling request: %w", err), false
}
body = bytes.NewReader(jsonData)
}
resp, err := c.doRequest(
http.MethodPost,
pushPath,
nil, // Assuming no body is needed for the push request
body,
)
if err != nil {
// Only retry on network errors, not on client errors
Expand Down
4 changes: 2 additions & 2 deletions cmd/cli/docs/reference/docker_model_push.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
command: docker model push
short: Push a model to Docker Hub
long: Push a model to Docker Hub
short: Push a model to Docker Hub or Hugging Face
long: Push a model to Docker Hub or Hugging Face
usage: docker model push MODEL
pname: docker model
plink: docker_model.yaml
Expand Down
2 changes: 1 addition & 1 deletion cmd/cli/docs/reference/model.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Docker Model Runner
| [`ps`](model_ps.md) | List running models |
| [`pull`](model_pull.md) | Pull a model from Docker Hub or HuggingFace to your local environment |
| [`purge`](model_purge.md) | Remove all models |
| [`push`](model_push.md) | Push a model to Docker Hub |
| [`push`](model_push.md) | Push a model to Docker Hub or Hugging Face |
| [`reinstall-runner`](model_reinstall-runner.md) | Reinstall Docker Model Runner (Docker Engine only) |
| [`requests`](model_requests.md) | Fetch requests+responses from Docker Model Runner |
| [`restart-runner`](model_restart-runner.md) | Restart Docker Model Runner (Docker Engine only) |
Expand Down
2 changes: 1 addition & 1 deletion cmd/cli/docs/reference/model_push.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# docker model push

<!---MARKER_GEN_START-->
Push a model to Docker Hub
Push a model to Docker Hub or Hugging Face


<!---MARKER_GEN_END-->
Expand Down
85 changes: 79 additions & 6 deletions pkg/distribution/distribution/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
"fmt"
"io"
"os"
"path/filepath"
"slices"
"strings"

"github.com/docker/model-runner/pkg/distribution/huggingface"
"github.com/docker/model-runner/pkg/distribution/internal/bundle"
"github.com/docker/model-runner/pkg/distribution/internal/mutate"
"github.com/docker/model-runner/pkg/distribution/internal/progress"
"github.com/docker/model-runner/pkg/distribution/internal/store"
Expand Down Expand Up @@ -589,21 +591,34 @@ func (c *Client) Tag(source string, target string) error {
}

// PushModel pushes a tagged model from the content store to the registry.
func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Writer) (err error) {
// Parse the tag
target, err := c.registry.NewTarget(tag)
func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Writer, bearerToken ...string) (err error) {
originalReference := tag
normalizedRef := c.normalizeModelName(tag)

var token string
if len(bearerToken) > 0 && bearerToken[0] != "" {
token = bearerToken[0]
}

if isHuggingFaceReference(originalReference) {
return c.pushNativeHuggingFace(ctx, originalReference, normalizedRef, progressWriter, token)
}

registryClient := c.registry
if token != "" {
auth := authn.NewBearer(token)
registryClient = registry.FromClient(c.registry, registry.WithAuth(auth))
}
target, err := registryClient.NewTarget(tag)
if err != nil {
return fmt.Errorf("new tag: %w", err)
}

// Get the model from the store
normalizedRef := c.normalizeModelName(tag)
mdl, err := c.store.Read(normalizedRef)
if err != nil {
return fmt.Errorf("reading model: %w", err)
}

// Push the model
c.log.Infoln("Pushing model:", utils.SanitizeForLog(tag, -1))
if err := target.Write(ctx, mdl, progressWriter); err != nil {
c.log.Errorln("Failed to push image:", err, "reference:", tag)
Expand All @@ -621,6 +636,64 @@ func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Wr
return nil
}

func (c *Client) pushNativeHuggingFace(ctx context.Context, reference, normalizedRef string, progressWriter io.Writer, token string) error {
repo, _, _ := parseHFReference(reference)
c.log.Infof("Pushing native HuggingFace model: repo=%s", utils.SanitizeForLog(repo))

if progressWriter != nil {
_ = progress.WriteProgress(progressWriter, "Preparing HuggingFace upload...", 0, 0, 0, "", oci.ModePush)
}

modelBundle, err := c.store.BundleForModel(normalizedRef)
if err != nil {
return fmt.Errorf("get model bundle: %w", err)
}

modelDir := filepath.Join(modelBundle.RootDir(), bundle.ModelSubdir)
files, totalSize, err := huggingface.CollectUploadFiles(modelDir)
if err != nil {
return fmt.Errorf("collect bundle files: %w", err)
}
if len(files) == 0 {
return fmt.Errorf("no model files found to upload")
}

hfOpts := []huggingface.ClientOption{
huggingface.WithUserAgent(registry.DefaultUserAgent),
}
if token != "" {
hfOpts = append(hfOpts, huggingface.WithToken(token))
}
hfClient := huggingface.NewClient(hfOpts...)

if progressWriter != nil {
msg := fmt.Sprintf("Uploading %d files (%.2f MB total)", len(files), float64(totalSize)/1024/1024)
_ = progress.WriteProgress(progressWriter, msg, uint64(totalSize), 0, 0, "", oci.ModePush)
}

if err := huggingface.UploadFiles(ctx, hfClient, repo, files, totalSize, progressWriter); err != nil {
c.log.Errorf("HuggingFace push failed: %v", err)
var authErr *huggingface.AuthError
var notFoundErr *huggingface.NotFoundError
if errors.As(err, &authErr) {
return registry.ErrUnauthorized
}
if errors.As(err, &notFoundErr) {
return registry.ErrModelNotFound
}
if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error()), oci.ModePush); writeErr != nil {
c.log.Warnf("Failed to write error message: %v", writeErr)
}
return fmt.Errorf("upload model to HuggingFace: %w", err)
}

if err := progress.WriteSuccess(progressWriter, "Model pushed successfully", oci.ModePush); err != nil {
c.log.Warnf("Failed to write success message: %v", err)
}

return nil
}

// WriteLightweightModel writes a model to the store without transferring layer data.
// This is used for config-only modifications where the layer data hasn't changed.
// The layers must already exist in the store.
Expand Down
Loading
Loading