How to Train and Track ML Models with MLflow in Databricks (Beginner to Advanced Guide)


MLflow is the open-source standard for managing the end-to-end machine learning lifecycle, and Databricks integrates it seamlessly. Whether you’re a beginner experimenting with basic models or an advanced user deploying models in production, MLflow in Databricks offers powerful tools to streamline your workflow.

This blog walks through training and tracking ML models using MLflow in Databricks, covering everything from setup to model registry and deployment.


🔹 What is MLflow?

MLflow is an open-source platform for managing the ML lifecycle, including:

  • Experiment Tracking – Log and query experiments: code, data, config, and results.
  • Model Packaging – Package models in a reusable format.
  • Model Registry – Central store for versioned models.
  • Model Deployment – Deploy ML models to various serving environments.

🧰 Prerequisites

Before you start:

  • A Databricks workspace (Community or Enterprise)
  • A notebook (preferably in Python)
  • Basic knowledge of ML (Scikit-learn, PySpark ML, or TensorFlow)

🚀 Step 1: Set Up MLflow in Databricks

MLflow comes pre-installed in Databricks notebooks, so no installation is needed. Just import it:

import mlflow
import mlflow.sklearn  # or mlflow.spark / mlflow.tensorflow etc.

🧪 Step 2: Create or Select an Experiment

Create a new experiment or set an existing one:

mlflow.set_experiment("/Users/your.name@databricks.com/ML_demo_experiment")

If you skip this step, Databricks logs runs to the notebook's default experiment. Either way, wrap your training code in a run:

with mlflow.start_run():
    pass  # training code goes here

🧠 Step 3: Train a Model and Log Parameters, Metrics, and Artifacts

Here's a Scikit-learn example:

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
import pandas as pd

# Load data
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target)

# Start MLflow run
with mlflow.start_run():
    clf = RandomForestClassifier(n_estimators=100)
    clf.fit(X_train, y_train)
    preds = clf.predict(X_test)
    acc = accuracy_score(y_test, preds)

    # Log parameters, metrics, and model
    mlflow.log_param("n_estimators", 100)
    mlflow.log_metric("accuracy", acc)
    mlflow.sklearn.log_model(clf, "model")

✅ You'll see your run in the MLflow Experiment UI under the "Experiment" tab.


📊 Step 4: Log Additional Artifacts (Data, Images, Configs)

MLflow lets you log any file:

import os
import tempfile

import matplotlib.pyplot as plt

# Example plot: first two features colored by predicted class
plt.scatter(X_test[:, 0], X_test[:, 1], c=preds)
plt.title("Predictions")

# Save to a temp file and log it under the "plots" artifact directory
tmp_path = os.path.join(tempfile.mkdtemp(), "predictions.png")
plt.savefig(tmp_path)
mlflow.log_artifact(tmp_path, "plots")

You can also log:

  • Confusion matrices
  • Feature importance charts
  • Model configs as JSON or YAML
  • Training datasets
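
For example, a confusion matrix and a model config can be logged with mlflow.log_figure and mlflow.log_dict. A minimal sketch, assuming it is called inside the same active run as the training code above:

from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Log a confusion matrix plot directly as a figure artifact
fig, ax = plt.subplots()
ConfusionMatrixDisplay.from_predictions(y_test, preds, ax=ax)
mlflow.log_figure(fig, "plots/confusion_matrix.png")

# Log the model configuration as a JSON artifact
mlflow.log_dict({"n_estimators": 100, "dataset": "iris"}, "config/model_config.json")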

🧭 Step 5: Explore and Compare Experiments

Navigate to the MLflow Experiments UI to:

  • Compare multiple runs by accuracy and parameters
  • Visualize metrics over time
  • Filter runs using tags and expressions

Example:

runs_df = mlflow.search_runs(order_by=["metrics.accuracy DESC"])
display(runs_df)
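
You can also filter runs programmatically with a filter expression. A small sketch, using the model_type tag set in Step 6 below:

best_rf_runs = mlflow.search_runs(
    filter_string="tags.model_type = 'RandomForest' and metrics.accuracy > 0.9",
    order_by=["metrics.accuracy DESC"],
)
display(best_rf_runs)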

🧾 Step 6: Use Tags and Descriptions for Better Organization

mlflow.set_tag("model_type", "RandomForest")
mlflow.set_tag("feature_set", "Iris")
mlflow.set_tag("developer", "Patrick Aaron")

These help when searching and filtering experiments.


๐Ÿ—‚๏ธ Step 7: Register the Model

# run_id is the ID of the training run (e.g. run.info.run_id when using "with mlflow.start_run() as run:")
result = mlflow.register_model(
    f"runs:/{run_id}/model", "iris_random_forest_model"
)

In the Databricks UI, go to Model Registry to:

  • Track model versions
  • Transition stages (Staging, Production)
  • Add comments and approvals
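
Stage transitions can also be scripted. A minimal sketch using the MLflow client; the version number here is an assumption, use the one shown in the registry:

from mlflow.tracking import MlflowClient

client = MlflowClient()
# Promote version 1 of the registered model to Production
client.transition_model_version_stage(
    name="iris_random_forest_model",
    version=1,
    stage="Production",
)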

📦 Step 8: Load and Use the Registered Model

from mlflow.pyfunc import load_model

model = load_model("models:/iris_random_forest_model/Production")
predictions = model.predict(X_test)

Or using UDFs with Spark:

import mlflow.pyfunc
from pyspark.sql.functions import struct

model_udf = mlflow.pyfunc.spark_udf(spark, model_uri="models:/iris_random_forest_model/Production")
df = spark.createDataFrame(pd.DataFrame(X_test))
df.withColumn("predictions", model_udf(struct(*df.columns))).show()

🧪 Step 9: Automate with MLflow Projects and Pipelines (Advanced)

MLflow Projects lets you define reproducible pipelines with MLproject YAML files.

Sample MLproject:

name: iris-classifier
conda_env: conda.yaml
entry_points:
  main:
    parameters:
      n_estimators: {type: int, default: 100}
    command: "python train.py --n_estimators {n_estimators}"

Trigger it:

mlflow run . -P n_estimators=200
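
The entry point calls a train.py script, which is not shown in this post. A minimal sketch of what it might contain, treating the script itself as an illustrative assumption:

# train.py (hypothetical entry-point script for the MLproject above)
import argparse

import mlflow
import mlflow.sklearn
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

parser = argparse.ArgumentParser()
parser.add_argument("--n_estimators", type=int, default=100)
args = parser.parse_args()

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target)

with mlflow.start_run():
    clf = RandomForestClassifier(n_estimators=args.n_estimators)
    clf.fit(X_train, y_train)
    acc = accuracy_score(y_test, clf.predict(X_test))

    mlflow.log_param("n_estimators", args.n_estimators)
    mlflow.log_metric("accuracy", acc)
    mlflow.sklearn.log_model(clf, "model")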

๐Ÿง‘โ€๐Ÿ’ผ Step 10: Deploy Model as a REST API (Advanced)

You can deploy directly in Databricks using MLflow Model Serving:

# From UI: Serve → Enable serving
# Or use REST endpoint for real-time inference

Sample request:

curl -X POST https://<databricks-host>/model/iris_random_forest_model/production/invocations \
-H "Authorization: Bearer <token>" \
-H "Content-Type: application/json" \
-d '{"data": [[5.1, 3.5, 1.4, 0.2]]}'
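
The same endpoint can be called from Python. A minimal sketch using requests; the host and token environment variables are assumptions, and newer MLflow serving versions expect the payload keyed as dataframe_records, dataframe_split, or inputs rather than data:

import os
import requests

# Hypothetical workspace host and personal access token, read from environment variables
host = os.environ["DATABRICKS_HOST"]
token = os.environ["DATABRICKS_TOKEN"]

response = requests.post(
    f"{host}/model/iris_random_forest_model/production/invocations",
    headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
    json={"data": [[5.1, 3.5, 1.4, 0.2]]},
)
print(response.json())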

๐Ÿ› ๏ธ Advanced Tips

FeatureUse Case
AutologgingAutomatically log params, metrics, models with mlflow.sklearn.autolog()
Tracking ServerUse a central MLflow Tracking Server across teams
Artifacts LocationStore artifacts in S3, DBFS, or Azure Blob
Model VersioningRoll back to a previous model version easily
CI/CD IntegrationTrigger re-training/deployment from GitHub Actions or Azure DevOps
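
Autologging is the quickest win. A minimal sketch; call it before training so the fit call is captured automatically:

import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier

# Enable autologging for scikit-learn: params, metrics, and the model are logged automatically
mlflow.sklearn.autolog()

with mlflow.start_run():
    clf = RandomForestClassifier(n_estimators=200)
    clf.fit(X_train, y_train)  # logged without explicit log_param/log_metric calls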

🔚 Final Thoughts

MLflow on Databricks offers a seamless, production-grade environment for tracking, versioning, and deploying machine learning models. Whether you're experimenting or building enterprise-grade pipelines, MLflow's native integration with Databricks is a game-changer.

