mlflow: Add version tags for registered models #31
Add the following model version tags when logging a model to MLflow:

* model_uri: The URI of the model artifact
* model_type: The type of the model (e.g. 'medcat_snomed')
* validation_status: The validation status of the model (e.g. 'pending')

Signed-off-by: Phoevos Kalemkeris <phoevos.kalemkeris@ucl.ac.uk>
app/management/tracker_client.py
Outdated
    model_name: str,
    model_manager: ModelManager,
    validation_status: str = "pending",
    model_type: Optional[str] = None,
Every CMS model has a ModelType, so there is no need to make the argument optional; in addition, mlflow.set_tag() will not set a None value.
@baixiac so
Pull request overview
This PR adds version tags to registered models in MLflow to improve model tracking and discovery. The changes introduce three new tags: model_uri, model_type, and validation_status that are attached to model versions when logging models.
Key changes:
- Added _set_model_version_tags helper method to standardize tag setting across model registration flows
- Updated save_model and save_pretrained_model methods to accept model_type parameter and set version tags
- Modified all trainer implementations to pass model type information when saving models
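The helper described above might look roughly like the sketch below. This is an illustration under assumptions, not the PR's actual code: the function name, its parameters, and the RecordingClient stand-in are hypothetical. The only external API it relies on is a client exposing set_model_version_tag(name, version, key, value), which is the shape of MLflow's MlflowClient method of that name. The tag keys and the 'pending' default come from the PR description.

```python
from typing import Any, Dict, List, Tuple


def set_model_version_tags(
    client: Any,
    model_name: str,
    version: str,
    model_uri: str,
    model_type: str,
    validation_status: str = "pending",
) -> Dict[str, str]:
    """Attach the three version tags to one registered model version.

    `client` is assumed to expose set_model_version_tag(name, version, key, value),
    as MLflow's MlflowClient does. Everything else here is hypothetical.
    """
    tags = {
        "model_uri": model_uri,
        "model_type": model_type,
        "validation_status": validation_status,
    }
    for key, value in tags.items():
        client.set_model_version_tag(name=model_name, version=version, key=key, value=value)
    return tags


class RecordingClient:
    """Hypothetical stand-in for a tracking client; it only records calls."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, str, str, str]] = []

    def set_model_version_tag(self, name: str, version: str, key: str, value: str) -> None:
        self.calls.append((name, version, key, value))


client = RecordingClient()
applied = set_model_version_tags(
    client, "example_model", "1", "s3://example-bucket/model", "medcat_snomed"
)
```

Making model_type a required argument mirrors the review feedback that every CMS model has a ModelType.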
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| app/management/tracker_client.py | Added _set_model_version_tags static method and updated save_model/save_pretrained_model to set version tags including model_uri, model_type, and validation_status |
| app/trainers/metacat_trainer.py | Updated save_model call to include model type from model service |
| app/trainers/medcat_trainer.py | Updated save_model calls (2 locations) to include model type from model service |
| app/trainers/medcat_deid_trainer.py | Updated save_model call to include model type from model service |
| app/trainers/huggingface_ner_trainer.py | Updated save_model calls (2 locations) to include model type from model service |
| app/trainers/huggingface_llm_trainer.py | Updated save_model call to include model type from model service |
| app/cli/cli.py | Updated save_pretrained_model call to pass model_type parameter |
| tests/app/monitoring/test_tracker_client.py | Enhanced tests to verify version tags are set correctly; added mock setup for pretrained model test |
    mlflow.set_tag.has_calls(
        [
            call("training.output.package", "file.zip"),
            call("training.output.model_uri", artifact_uri),
            call("training.output.model_type", "model_type"),
        ],
        any_order=False,
    )
Incorrect assertion method. Should be assert_has_calls instead of has_calls. The current code will not actually perform the assertion, allowing the test to pass even if the calls were not made.
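A minimal sketch of the corrected assertion, using only unittest.mock. The set_tag mock stands in for the patched mlflow.set_tag, and the artifact URI is a made-up value; neither comes from the PR's test file. Whether the misspelled has_calls is silently accepted or rejected with an AttributeError may depend on the Python version, so the sketch exercises only the real assert_has_calls method.

```python
from unittest.mock import MagicMock, call

# Stand-in for the patched `mlflow.set_tag` (assumption: the real test patches it).
set_tag = MagicMock()
artifact_uri = "runs:/hypothetical-run-id/model"  # made-up value for illustration

# Simulate the calls the code under test would make.
set_tag("training.output.package", "file.zip")
set_tag("training.output.model_uri", artifact_uri)
set_tag("training.output.model_type", "model_type")

expected = [
    call("training.output.package", "file.zip"),
    call("training.output.model_uri", artifact_uri),
    call("training.output.model_type", "model_type"),
]

# Passes: these calls were made, consecutively and in this order.
set_tag.assert_has_calls(expected, any_order=False)


def detects_missing_call(mock: MagicMock) -> bool:
    """Return True if assert_has_calls rejects a call that was never made."""
    try:
        mock.assert_has_calls([call("training.output.never_set", "x")])
    except AssertionError:
        return True
    return False
```

Unlike the misspelled method, assert_has_calls raises AssertionError when an expected call is missing, which is what lets the test fail.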