Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package cmd

import (
"log"
"net/http"
_ "net/http/pprof"
"os"
Expand Down Expand Up @@ -51,7 +50,7 @@ var rootCmd = &cobra.Command{
go func() {
err := http.ListenAndServe(rootConfig.PprofAddr, nil)
if err != nil {
log.Fatal(err)
logrus.Fatal(err)
}
}()
}
Expand Down
42 changes: 22 additions & 20 deletions pkg/modelprovider/mlflow/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"github.com/databricks/databricks-sdk-go/client"
"github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/service/ml"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
)

type MlFlowClient struct {
Expand All @@ -39,7 +39,7 @@ func NewMlFlowRegistry(mlflowClient *client.DatabricksClient) (MlFlowClient, err

if mlflowClient != nil {
registry = ml.NewModelRegistry(mlflowClient)
log.Println("Use default mlflow client for MlFlowRegistryAPI")
logrus.Infof("mlflow: using default client for MlFlowRegistryAPI")
return MlFlowClient{registry: registry}, nil
}

Expand Down Expand Up @@ -83,7 +83,7 @@ func (mlfr *MlFlowClient) PullModelByName(

pullVersion = versions[0].Version

log.Printf("Found versions: '%v' for model '%s'\n", pullVersion, modelName)
logrus.Infof("mlflow: found versions '%v' for model '%s'", pullVersion, modelName)

} else {

Expand All @@ -109,7 +109,7 @@ func (mlfr *MlFlowClient) PullModelByName(
pullVersion = modelVersion
}
}
log.Printf("Start pull model from model registry with version %s", pullVersion)
logrus.Infof("mlflow: starting pull model from registry with version %s", pullVersion)

uri, err := mlfr.registry.GetModelVersionDownloadUri(ctx, ml.GetModelVersionDownloadUriRequest{
Name: modelName,
Expand All @@ -118,7 +118,7 @@ func (mlfr *MlFlowClient) PullModelByName(
if err != nil {
return "", errors.Join(errors.New("failed fetch download uri for model"), err)
}
log.Printf("Try pull model from uri %s", uri.ArtifactUri)
logrus.Infof("mlflow: pulling model from uri %s", uri.ArtifactUri)
parsed, err := url.Parse(uri.ArtifactUri)
if err != nil {
return "", fmt.Errorf("failed to parse artifact uri: %w", err)
Expand All @@ -141,7 +141,7 @@ func (mlfr *MlFlowClient) PullModelByName(
return "", err
}

log.Printf("✅ Model downloaded")
logrus.Infof("mlflow: model downloaded successfully")

return destSrc, nil
}
Expand All @@ -162,15 +162,15 @@ func (s3back *S3StorageBackend) DownloadModel(

bucketName := parsed.Host
s3FolderPrefix := strings.TrimPrefix(parsed.Path, "/")
log.Printf("Parsed s3 bucket %s, path %s from path", bucketName, s3FolderPrefix)
logrus.Debugf("mlflow: parsed s3 bucket %s, path %s", bucketName, s3FolderPrefix)

cfg, err := awsconfig.LoadDefaultConfig(ctx)
if err != nil {
wrap := fmt.Errorf("Error loading AWS config, try change envs or profile: %v\n", err)
return errors.Join(wrap, err)
}

log.Printf("Region - %s, endpoint - %s", cfg.Region, aws.ToString(cfg.BaseEndpoint))
logrus.Debugf("mlflow: aws region %s, endpoint %s", cfg.Region, aws.ToString(cfg.BaseEndpoint))

s3Client := s3.NewFromConfig(cfg)

Expand All @@ -184,18 +184,18 @@ func (s3back *S3StorageBackend) DownloadModel(
Prefix: aws.String(s3FolderPrefix),
})

log.Printf("Start downloading from s3 bucket %s, path %s", bucketName, s3FolderPrefix)
logrus.Infof("mlflow: starting download from s3 bucket %s, path %s", bucketName, s3FolderPrefix)

for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
log.Printf("Error listing objects: %v\n", err)
logrus.Errorf("mlflow: failed to list objects: %v", err)
return err
}

for _, object := range page.Contents {
s3Key := *object.Key
log.Printf("Downloading object: %s\n", s3Key)
logrus.Debugf("mlflow: downloading object %s", s3Key)
if strings.HasSuffix(s3Key, "/") { // Skip S3 "folder" markers
continue
}
Expand All @@ -208,18 +208,17 @@ func (s3back *S3StorageBackend) DownloadModel(
// Create local directories if they don't exist
err = os.MkdirAll(filepath.Dir(localFilePath), 0o755)
if err != nil {
log.Printf(
"Error creating local directory %s: %v\n",
filepath.Dir(localFilePath),
err,
logrus.Errorf(
"mlflow: failed to create local directory %s: %v",
filepath.Dir(localFilePath), err,
)
continue
}

// Download the object
file, err := os.Create(localFilePath)
if err != nil {
log.Printf("Error creating local file %s: %v\n", localFilePath, err)
logrus.Errorf("mlflow: failed to create local file %s: %v", localFilePath, err)
continue
}

Expand All @@ -230,18 +229,21 @@ func (s3back *S3StorageBackend) DownloadModel(
closeErr := file.Close()
if err != nil || closeErr != nil {
if err != nil {
log.Printf("Error downloading object %s: %v\n", s3Key, err)
logrus.Errorf("mlflow: failed to download object %s: %v", s3Key, err)
}
if closeErr != nil {
log.Printf("Error closing file %s: %v\n", localFilePath, closeErr)
logrus.Errorf("mlflow: failed to close file %s: %v", localFilePath, closeErr)
}
if removeErr := os.Remove(localFilePath); removeErr != nil &&
!errors.Is(removeErr, os.ErrNotExist) {
log.Printf("Error removing partial file %s: %v\n", localFilePath, removeErr)
logrus.Errorf(
"mlflow: failed to remove partial file %s: %v",
localFilePath, removeErr,
)
}
continue
}
log.Printf("Downloaded %s to %s (%d bytes)\n", s3Key, localFilePath, numBytes)
logrus.Debugf("mlflow: downloaded %s to %s (%d bytes)", s3Key, localFilePath, numBytes)
}
}

Expand Down
10 changes: 5 additions & 5 deletions pkg/modelprovider/mlflow/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
"slices"
"strings"

log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
)

// MlflowProvider implements the modelprovider.Provider interface for Mlflow
Expand Down Expand Up @@ -112,11 +112,11 @@ func checkMlflowAuth() error {
if isAllNonEmpty(databricksEnvs) {
return nil
} else if isAllNonEmpty(mlflowEnvs) {
log.Printf("Detected MlFlow environment variables, set DATABRICKS_* envs \n")
logrus.Infof("mlflow: detected MlFlow environment variables, setting DATABRICKS_* envs")
} else {
log.Println("Please set DATABRICKS_HOST or MLFLOW_TRACKING_URI environment variable.")
log.Println("Authentication for MLflow/Databricks is not configured.")
log.Println("See https://pkg.go.dev/github.com/databricks/databricks-sdk-go/config for more details on configuration.")
logrus.Warnf("mlflow: please set DATABRICKS_HOST or MLFLOW_TRACKING_URI environment variable")
logrus.Warnf("mlflow: authentication for MLflow/Databricks is not configured")
logrus.Warnf("mlflow: see https://pkg.go.dev/github.com/databricks/databricks-sdk-go/config for more details on configuration")

return errors.New("mlflow/databricks authentication not configured")
}
Expand Down