
MLflow for Darts implementation #3022

Draft
jakubchlapek wants to merge 56 commits into unit8co:master from jakubchlapek:feat/mlflow-base

Conversation

@jakubchlapek
Collaborator

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Addresses #2092.

Summary

Provides a custom MLflow flavor for Darts, implemented on Darts' side. It supports autologging, logging, saving, and loading of models.
This PR focuses on the base MLflow integration, leaving model serving to be discussed in the future.

An example quickstart for the integration is included; however, consider all of this a draft :)
The full example code is in the .ipynb, but here is a quick reproducible snippet as well:

import mlflow
import tempfile
import os
from darts.metrics.metrics import smape
from darts.utils.mlflow import load_model, autolog
from darts.models import NBEATSModel, LinearRegressionModel
from darts.datasets import AirPassengersDataset
from torchmetrics import MeanAbsoluteError

# temp file setup
tmpdir = tempfile.mkdtemp()
mlflow_db = os.path.join(tmpdir, "mlflow.db")
mlflow.set_tracking_uri(f"sqlite:///{mlflow_db}")
mlflow.set_experiment("darts-forecasting")

train, val = AirPassengersDataset().load().astype("float32").split_before(0.7)

# autologging - patches .fit() on all ForecastingModel subclasses.
# for PyTorch-based models, inject_per_epoch_callbacks injects a Lightning callback
# that automatically logs train/val loss and/or user-specified torch metrics at the end of each epoch.
autolog(
    log_models=True,
    log_params=True,
    log_training_metrics=True,
    log_validation_metrics=True,   # requires val_series in .fit()
    inject_per_epoch_callbacks=True, 
    extra_metrics=[smape],         # optional extra darts metric functions
)

with mlflow.start_run(run_name="nbeats") as run:
    model = NBEATSModel(
        input_chunk_length=24, 
        output_chunk_length=12,
        torch_metrics=MeanAbsoluteError())
    # val_series is forwarded to Lightning's val_dataloaders;
    # autolog captures per-epoch val metrics via the injected callback
    model.fit(train, val_series=val, epochs=10)
    run_id = run.info.run_id


# regression/sklearn models work identically
with mlflow.start_run(run_name="linreg"):
    model = LinearRegressionModel(lags=12)
    model.fit(train)  # logs params + in-sample metrics

# load back from MLflow
loaded = load_model(f"runs:/{run_id}/model")
preds = loaded.predict(12, series=train) # need to specify series as we save with clean=True in save_model

# import shutil
# shutil.rmtree(tmpdir)
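The autologging behaviour described above (patching `.fit()` on all `ForecastingModel` subclasses) can be sketched in plain Python. All names here (`Base`, `Child`, `LOGGED`, `get_all_subclasses`, `autolog`) are illustrative stand-ins for the idea, not the PR's actual API:

```python
# Hypothetical sketch of a fit()-patching autolog, assuming only that models
# share a common base class. LOGGED stands in for mlflow.log_params.
import functools

LOGGED = []  # stand-in for mlflow.log_params


def get_all_subclasses(cls):
    """Recursively collect every subclass of cls."""
    out = []
    for sub in cls.__subclasses__():
        out.append(sub)
        out.extend(get_all_subclasses(sub))
    return out


class Base:
    def fit(self, series):
        return self


class Child(Base):
    def __init__(self, lags=3):
        self.lags = lags


def autolog(base=Base):
    # Patch fit() only on classes that define it themselves, so an
    # inherited fit() is not wrapped (and logged) more than once.
    for cls in [base, *get_all_subclasses(base)]:
        if "fit" not in vars(cls):
            continue
        original = cls.fit

        @functools.wraps(original)
        def patched(self, *args, _orig=original, **kwargs):
            LOGGED.append(dict(vars(self)))  # record constructor params first
            return _orig(self, *args, **kwargs)

        cls.fit = patched


autolog()
fitted = Child(lags=5).fit([1, 2, 3])
```

The real implementation additionally injects Lightning callbacks for torch models; this sketch only shows the subclass walk and the wrap-then-delegate pattern.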

@review-notebook-app

Check out this pull request on ReviewNB to see visual diffs & provide feedback on Jupyter Notebooks.

@jakubchlapek
Collaborator Author

Hey @daidahao, adding this draft PR in the meantime so you and @dennisbader can have a look at what I currently have regarding the integration. There are still some decisions I am not too thrilled about, and decisions to be made about the overall direction, but I'm happy to talk more about it during the meeting. Thanks for being so active in the library; it's really nice to be working together :)

@daidahao
Contributor

Hi @jakubchlapek

Sorry to hear that. Thank you for all your work and I can attest that you absolutely built a solid foundation for MLflow in Darts!

@daidahao
Contributor

daidahao commented Mar 2, 2026

@jakubchlapek @dennisbader @mizeller

I've reviewed and provided some comments here for the code (except for the unit tests and notebook). Overall, I agree with most of the design choices by @jakubchlapek; there are only a few deviations from mine (e.g., using MLflow's existing APIs, Darts' raise_log(), logging fit() params) which I believe could benefit the current implementation.

I wonder if it would be easier if @jakubchlapek could allow me to edit the code directly to address my comments here? That way, @mizeller can then continue the work here with the best of both worlds, and focus on unit tests or extension, etc. Let me know what you think.

Once again, thank you for the great work @jakubchlapek, and I am glad that our implementations align so well.

@mizeller
Contributor

mizeller commented Mar 3, 2026

@daidahao Perfect timing - I wrapped my head around @jakubchlapek's implementation & had a short meeting w/ @dennisbader last Friday about what needs to be done.

I did fork off of @jakubchlapek's branch here.

You can add your changes there and we can merge into @jakubchlapek's branch when we're done (or wait for him to give you access to his branch). I've just added you as a collaborator there. Whatever you prefer :)

I'll try to squeeze this PR in this week and, if that's okay, ping you here if something's unclear!

@jakubchlapek
Collaborator Author

Hey @daidahao, @mizeller. I have added you both as collaborators now, so you should be able to edit the code directly here.

daidahao added 3 commits March 5, 2026 09:34 (Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>)

pyproject.toml (Outdated)

]
notorch = [
    "catboost>=1.0.6,<=1.2.9",
    "statsforecast>=1.4",
    "xgboost>=2.1.4",
]
mlflow = ["mlflow>=2.0"]
Contributor

I would like a discussion on the new option here. My understanding is that users who need the Darts-MLflow integration probably have MLflow installed and set up properly already. For users who have not, MLflow itself has options for Databricks, which some users might find useful. Could we instead direct users to MLflow's official installation guide?

Contributor

Also, @MichaelVerdegaal raised a suggestion for mlflow>=3.0. I have not used MLflow 2.x before but I think the minimum version should be deliberated as well.

does not apply the `with_managed_run` wrapper to the specified
`patch_function`.
"""
# Enable/disable mlflow.pytorch.autolog for per-epoch metrics on torch models.
Contributor

Sorry, I don't understand why the decorator would short-circuit here if we call mlflow.pytorch.autolog() with disable=True. Looking at the XGBoost flavour, it seems they are able to call mlflow.sklearn._autolog() within mlflow.xgboost.autolog(). Is it because mlflow.sklearn._autolog() is not wrapped but mlflow.pytorch.autolog() is?
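For readers following along, the short-circuit in question can be illustrated in plain Python. This is not MLflow's actual code; `autologging_integration`, `CALLS`, and `pytorch_autolog` here are simplified stand-ins for the wrapping behaviour being discussed:

```python
# Illustrative sketch: a decorator like MLflow's autologging_integration can
# return early when disable=True, so the decorated flavour's patching body
# never runs. Names are hypothetical stand-ins, not MLflow internals.
import functools

CALLS = []  # records which flavours actually ran their patching logic


def autologging_integration(name):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, disable=False, **kwargs):
            if disable:
                return None  # short-circuit: skip the body entirely
            CALLS.append(name)
            return fn(*args, disable=disable, **kwargs)
        return wrapper
    return decorator


@autologging_integration("pytorch")
def pytorch_autolog(disable=False):
    pass  # would install Lightning patches here


pytorch_autolog(disable=True)   # body skipped by the wrapper
pytorch_autolog(disable=False)  # body runs, patches installed
```

An unwrapped helper (analogous to a private _autolog) would not have this early-return guard, which would explain the asymmetry raised above.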



daidahao added 7 commits March 5, 2026 09:47 (Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>)
daidahao and others added 13 commits March 5, 2026 14:52 (Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>)
@daidahao
Contributor

daidahao commented Mar 7, 2026

@mizeller @jakubchlapek @dennisbader

Greetings! I've addressed most of the comments here, except for a few discussion points.

I've left a TODO note on post-fitting metrics which, IMHO, are hard to implement at this point due to how MLflow manages active runs in the autolog context. In short, we would need to keep a mapping between MLflow run ids, fitted models, model predictions, and metrics to ensure the metrics are logged under the right run id (see mlflow.sklearn).

Sincere apologies for suggesting post-fitting metrics in the first place! I didn't realise the complexity involved.

My suggestion is to skip post-fitting metrics for now or settle for compromises such as non-terminated active runs (at the risk of cross-logging).
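To make the bookkeeping problem concrete, here is a minimal sketch of the kind of mapping post-fitting metrics would require. All names (`FITTED_UNDER`, `record_fit`, `log_post_fit_metric`) are hypothetical; the dict `metric_store` stands in for MLflow's tracking client:

```python
# Hypothetical sketch: autolog must remember which MLflow run each model was
# fitted under, so metrics computed later (after predict) land on the right
# run id rather than whatever run happens to be active at that time.
FITTED_UNDER = {}  # model id -> run id, populated inside the patched fit()


def record_fit(model, run_id):
    FITTED_UNDER[id(model)] = run_id


def log_post_fit_metric(model, name, value, metric_store):
    run_id = FITTED_UNDER.get(id(model))
    if run_id is None:
        raise ValueError("model was not fitted under an autolog run")
    # stand-in for MlflowClient().log_metric(run_id, name, value)
    metric_store.setdefault(run_id, {})[name] = value


store = {}
model_a, model_b = object(), object()
record_fit(model_a, "run-1")
record_fit(model_b, "run-2")
log_post_fit_metric(model_a, "smape", 4.2, store)
log_post_fit_metric(model_b, "smape", 3.1, store)
```

Even this toy version shows the cost: every patched fit() must register its model, and the mapping must stay alive until prediction time, which is the complexity referred to above.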

Other than that, I am truly proud of what we have achieved here and will hand this over to @mizeller for follow-ups and more great work.
