How to Train and Track ML Models with MLflow in Databricks (Beginner to Advanced Guide)
MLflow is the open-source standard for managing the end-to-end machine learning lifecycle, and Databricks integrates it seamlessly. Whether you’re a beginner experimenting with basic models or an advanced user deploying models in production, MLflow in Databricks offers powerful tools to streamline your workflow.
This blog walks through training and tracking ML models using MLflow in Databricks, covering everything from setup to model registry and deployment.

What is MLflow?
MLflow is an open-source platform for managing the ML lifecycle, including:
- Experiment Tracking: log and query experiments (code, data, config, and results).
- Model Packaging: package models in a reusable format.
- Model Registry: a central store for versioned models.
- Model Deployment: deploy ML models to various serving environments.
Prerequisites
Before you start:
- A Databricks workspace (Community or Enterprise)
- A notebook (preferably in Python)
- Basic knowledge of ML (Scikit-learn, PySpark ML, or TensorFlow)
Step 1: Set Up MLflow in Databricks
MLflow is pre-installed on Databricks Runtime for Machine Learning, so no installation is needed there (on a standard runtime, run %pip install mlflow first). Just import it:
import mlflow
import mlflow.sklearn # or mlflow.spark / mlflow.tensorflow etc.
Step 2: Create or Select an Experiment
Create a new experiment or set an existing one:
mlflow.set_experiment("/Users/your.name@databricks.com/ML_demo_experiment")
Alternatively, you can assign the experiment inline by passing its ID to start_run:
with mlflow.start_run(experiment_id="<experiment-id>"):
    # training code here
    pass
Step 3: Train a Model and Log Parameters, Metrics, and Artifacts
Here's a Scikit-learn example:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
import pandas as pd
# Load data
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target)
# Start MLflow run
with mlflow.start_run():
    clf = RandomForestClassifier(n_estimators=100)
    clf.fit(X_train, y_train)
    preds = clf.predict(X_test)
    acc = accuracy_score(y_test, preds)
    # Log parameters, metrics, and model
    mlflow.log_param("n_estimators", 100)
    mlflow.log_metric("accuracy", acc)
    mlflow.sklearn.log_model(clf, "model")
You'll see your run in the MLflow Experiment UI under the "Experiment" tab.
Step 4: Log Additional Artifacts (Data, Images, Configs)
MLflow lets you log any file:
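import os
import tempfile
import matplotlib.pyplot as plt
# Example plot of the first two features, colored by predicted class
plt.scatter(X_test[:, 0], X_test[:, 1], c=preds)
plt.title("Predictions")
# tempfile.mktemp is deprecated and race-prone, so write into a temporary directory
with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_path = os.path.join(tmp_dir, "predictions.png")
    plt.savefig(tmp_path)
    # Call this while the Step 3 run is active; outside a run, MLflow starts a new one
    mlflow.log_artifact(tmp_path, "plots")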
You can also log:
- Confusion matrices
- Feature importance charts
- Model configs as JSON or YAML
- Training datasets
Step 5: Explore and Compare Experiments
Navigate to the MLflow Experiments UI to:
- Compare multiple runs by accuracy and parameters
- Visualize metrics over time
- Filter runs using tags and expressions
Example:
runs_df = mlflow.search_runs(order_by=["metrics.accuracy DESC"])
display(runs_df)
Step 6: Use Tags and Descriptions for Better Organization
mlflow.set_tag("model_type", "RandomForest")
mlflow.set_tag("feature_set", "Iris")
mlflow.set_tag("developer", "Patrick Aaron")
Call these inside an active run (for example, within the with mlflow.start_run() block). Tags help when searching and filtering runs.
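Step 7: Register the Model
# Look up the run that logged the model (here, the most recent run)
run_id = mlflow.last_active_run().info.run_id
result = mlflow.register_model(
    f"runs:/{run_id}/model", "iris_random_forest_model"
)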
In the Databricks UI, go to Model Registry to:
- Track model versions
- Transition stages (Staging, Production); note that MLflow 2.9 and later deprecate stages in favor of model version aliases
- Add comments and approvals
Step 8: Load and Use the Registered Model
from mlflow.pyfunc import load_model
model = load_model("models:/iris_random_forest_model/Production")
predictions = model.predict(X_test)
Or using UDFs with Spark:
import mlflow.pyfunc
from pyspark.sql.functions import struct
model_udf = mlflow.pyfunc.spark_udf(spark, model_uri="models:/iris_random_forest_model/Production")
df = spark.createDataFrame(pd.DataFrame(X_test))
df.withColumn("predictions", model_udf(struct(*df.columns))).show()
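Both the registration and loading calls take a model URI, and the two schemes used in this guide are easy to mix up. The helpers below are hypothetical (they are not part of MLflow) and only show how each URI is assembled; the run ID is a made-up placeholder.

```python
def run_model_uri(run_id: str, artifact_path: str = "model") -> str:
    """URI of a model logged under a run, as passed to mlflow.register_model."""
    return f"runs:/{run_id}/{artifact_path}"


def registry_model_uri(name: str, version_or_stage: str) -> str:
    """URI of a registered model pinned to a version number or a stage name."""
    return f"models:/{name}/{version_or_stage}"


# "abc123" is a made-up run ID used only to show the shape of the URI.
print(run_model_uri("abc123"))
print(registry_model_uri("iris_random_forest_model", "Production"))
print(registry_model_uri("iris_random_forest_model", "3"))
```

Pinning to a version number gives a fixed model, while a stage name follows whichever version currently holds that stage.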
Step 9: Automate with MLflow Projects and Pipelines (Advanced)
MLflow Projects lets you define reproducible pipelines with MLproject YAML files.
Sample MLproject:
name: iris-classifier
conda_env: conda.yaml
entry_points:
  main:
    parameters:
      n_estimators: {type: int, default: 100}
    command: "python train.py --n_estimators {n_estimators}"
Trigger it:
mlflow run . -P n_estimators=200
Step 10: Deploy Model as a REST API (Advanced)
You can deploy directly in Databricks using Model Serving:
# From UI: Serving → create an endpoint for the registered model
# Or use REST endpoint for real-time inference
Sample request (replace <endpoint-name> with your serving endpoint; MLflow 2.x scoring expects a top-level key such as "inputs" or "dataframe_split" rather than "data"):
curl -X POST https://<databricks-host>/serving-endpoints/<endpoint-name>/invocations \
  -H "Authorization: Bearer <token>" \
  -H "Content-Type: application/json" \
  -d '{"inputs": [[5.1, 3.5, 1.4, 0.2]]}'
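The same kind of request can be assembled from Python. This sketch only builds the request and does not send it. The host, endpoint name, and token are placeholders, and it assumes the endpoint accepts the MLflow 2.x "inputs" payload format; build_request is a hypothetical helper name.

```python
import json
import urllib.request

# Assumptions: the host, endpoint name, and token below are placeholders,
# and the endpoint accepts the MLflow 2.x "inputs" payload format.
ENDPOINT_URL = (
    "https://example.cloud.databricks.com/serving-endpoints/iris-endpoint/invocations"
)
TOKEN = "<token>"


def build_request(rows):
    """Build (but do not send) a POST request carrying feature rows as JSON."""
    body = json.dumps({"inputs": rows}).encode("utf-8")
    return urllib.request.Request(
        ENDPOINT_URL,
        data=body,
        headers={
            "Authorization": f"Bearer {TOKEN}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


request = build_request([[5.1, 3.5, 1.4, 0.2]])
print(request.get_method(), request.full_url)
# To send it for real: urllib.request.urlopen(request).read()
```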
Advanced Tips
| Feature | Use Case |
|---|---|
| Autologging | Automatically log params, metrics, models with mlflow.sklearn.autolog() |
| Tracking Server | Use a central MLflow Tracking Server across teams |
| Artifacts Location | Store artifacts in S3, DBFS, or Azure Blob |
| Model Versioning | Roll back to a previous model version easily |
| CI/CD Integration | Trigger re-training/deployment from GitHub Actions or Azure DevOps |
Final Thoughts
MLflow on Databricks offers a seamless, production-grade environment for tracking, versioning, and deploying machine learning models. Whether you're experimenting or building enterprise-grade pipelines, MLflow's native integration with Databricks covers the workflow from the first logged run to a served model.
