Absolutely! Let’s walk through a complete, step-by-step tutorial to help you understand MLflow on Databricks from start to finish, using revised and working code.
We’ll cover all the key components of MLflow:
- Tracking
- Models
- Model Registry
- Signature + Input Example
Objective:
Train a classification model on the Iris dataset, log everything with MLflow, and register the model.
Step-by-Step MLflow Lab on Databricks
Step 1: Setup (Install Required Libraries)
Run this in a cell:
%pip install scikit-learn pandas mlflow
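Tip: on Databricks, packages installed with %pip only become importable after the Python process restarts. If an import fails right after the install, run this Databricks notebook utility in the next cell:

# Restart the Python interpreter so the freshly installed packages are picked up
dbutils.library.restartPython()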
Step 2: Import Libraries and Load Data
import pandas as pd
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import mlflow
import mlflow.sklearn
from mlflow.models.signature import infer_signature
Step 3: Prepare the Data
# Load Iris dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target
# Split into training and testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
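Optionally, peek at the data as a pandas DataFrame before training (display() is Databricks’ built-in rich table renderer):

# Optional: inspect the data as a DataFrame
df = pd.DataFrame(X, columns=iris.feature_names)
df["target"] = y
display(df.head())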
Step 4: Start MLflow Run and Train Model
# Start MLflow experiment run
with mlflow.start_run() as run:
    # Train the model
    model = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=42)
    model.fit(X_train, y_train)

    # Make predictions and calculate accuracy
    predictions = model.predict(X_test)
    acc = accuracy_score(y_test, predictions)

    # Log parameters and metric
    mlflow.log_param("n_estimators", 100)
    mlflow.log_param("max_depth", 5)
    mlflow.log_metric("accuracy", acc)

    # Create sample input and signature
    input_example = X_test[:5]
    signature = infer_signature(X_train, model.predict(X_train))

    # Log model with signature and input example
    mlflow.sklearn.log_model(
        sk_model=model,
        artifact_path="iris_rf_model",
        input_example=input_example,
        signature=signature
    )

    # Save run_id for model registration
    run_id = run.info.run_id
    print(f"Run ID: {run_id}")
    print(f"Accuracy: {acc}")
Step 5: Register the Model
Paste this in a new cell:
model_uri = f"runs:/{run_id}/iris_rf_model"

# Register the model under a name
model_details = mlflow.register_model(
    model_uri=model_uri,
    name="IrisClassifierModel"
)
Now go to the “Models” tab in Databricks, and you’ll see IrisClassifierModel with versioning.
Step 6: Promote the Model (via UI)
Go to:
- Models > IrisClassifierModel
- Click on the version (e.g., Version 1)
- Click Stage → choose Staging or Production
# To promote Version 1 to Production
from mlflow.tracking import MlflowClient

client = MlflowClient()
client.transition_model_version_stage(
    name="IrisClassifierModel",
    version=1,           # or the actual version you created
    stage="Production"   # or "Staging", "Archived"
)
OR, assign an alias instead (recommended: recent MLflow versions deprecate stages in favor of aliases):
# Assign an alias (like @production) to the version
from mlflow.tracking import MlflowClient
client = MlflowClient()
client.set_registered_model_alias(
    name="IrisClassifierModel",
    alias="production",
    version=1
)
# Run this to list your registered models:
client = MlflowClient()
models = client.search_registered_models()  # list_registered_models() was removed in MLflow 2.x
for m in models:
    print(m.name)
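You can also confirm exactly which version the alias points to (get_model_version_by_alias is available in MLflow 2.3+):

# Look up the version behind the 'production' alias
mv = client.get_model_version_by_alias("IrisClassifierModel", "production")
print(f"Alias 'production' -> version {mv.version}")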
Step 7: Load Model from Registry and Predict
Option A: If your model is in “Production” stage (via UI):
Now that the model is registered and staged, let’s load it and use it:
from mlflow.pyfunc import load_model
# Load model from registry using stage name
model = load_model("models:/IrisClassifierModel/Production")
# Predict
preds = model.predict(X_test)
print(preds)
Option B: If you used an alias (e.g., production) instead of a stage:
from mlflow.pyfunc import load_model
# Load model using alias
model = load_model("models:/IrisClassifierModel@production")
# Predict
preds = model.predict(X_test)
print(preds)
You Did It!
| MLflow Component | Covered |
|---|---|
| Tracking | Yes (log params, metrics, artifacts) |
| Models | Yes (logged with signature + input example) |
| Model Registry | Yes (registered and promoted via UI) |
| Serving/Loading | Yes (loaded from registry and used for predictions) |
Bonus Tips:
- Want to serve the model as REST API? MLflow on Databricks supports model serving.
- Want to run this as a Project? We can convert this into an MLflow Project next.
Let me know if you want this exported as a .dbc Databricks Notebook or want to integrate with CI/CD!
Absolutely! Since you’ve already registered the model and assigned an alias like production, you’re ready to serve the model using Databricks Model Serving.
Let’s go through this step-by-step, assuming you’re using a Databricks Trial or Paid workspace (not Community Edition).
Objective: Serve Your MLflow Model via REST API on Databricks
We’ll go from registered model with alias ➜ to real-time REST endpoint ➜ to making predictions from code.
Step 1: Confirm Model Registration & Alias
You’ve already done this, but here’s a quick reference for future:
import mlflow
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Register model (already done)
model_uri = f"runs:/{run_id}/iris_rf_model"
mlflow.register_model(model_uri=model_uri, name="IrisClassifierModel")

# Set alias
client.set_registered_model_alias(
    name="IrisClassifierModel",
    alias="production",
    version=1
)
Step 2: Enable Model Serving from Databricks UI
- Go to Databricks Workspace.
- In the left sidebar, click “Models”.
- Click on IrisClassifierModel.
- Click on Version 1 (or the version you aliased).
- You should see a “Serving” or “Enable Serving” button.
- Click it, then:
  - Choose Real-time serving
  - Click Start serving
Once serving is enabled, you’ll see the endpoint URL (copy it!).
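If you’d rather create the endpoint from code than click through the UI, here’s a minimal sketch using the databricks-sdk package (pip install databricks-sdk). Treat the endpoint name iris-classifier as a placeholder; class and field names have shifted across SDK versions, and the entity_name format can differ between Unity Catalog and workspace-registry models, so check the docs for your installed version:

# Minimal sketch: create a serving endpoint via the Databricks SDK (assumed API)
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedEntityInput

w = WorkspaceClient()  # reads credentials from the notebook/session context

w.serving_endpoints.create(
    name="iris-classifier",  # placeholder endpoint name
    config=EndpointCoreConfigInput(
        served_entities=[
            ServedEntityInput(
                entity_name="IrisClassifierModel",
                entity_version="1",
                workload_size="Small",
                scale_to_zero_enabled=True,  # scale down when idle to save cost
            )
        ]
    ),
)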
Step 3: Use the REST Endpoint for Predictions
Here’s a full Python example to send test data and get predictions:
import requests
import json
# Replace with your actual endpoint from Databricks
url = "https://<your-databricks-instance>/serving-endpoints/IrisClassifierModel/invocations"
# If needed, generate a Personal Access Token from Databricks User Settings
token = "dapiXXXXXXXXXXXXXXXXXXXX"
# Headers
headers = {
    "Authorization": f"Bearer {token}",
    "Content-Type": "application/json"
}

# Input payload (match your model’s input structure)
data = {
    "dataframe_split": {
        "columns": ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"],
        "data": [[5.1, 3.5, 1.4, 0.2]]
    }
}
# Send request
response = requests.post(url, headers=headers, json=data)
# Print response
print("Prediction:", response.json())
Step 4: Test It!
Run the above Python code in:
- Databricks notebook
- Jupyter notebook
- Any Python script
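One small robustness tip: check the HTTP status before parsing the body, so an auth or payload problem surfaces as a useful error instead of a JSON decode failure:

# Fail loudly on non-200 responses before calling .json()
if response.status_code != 200:
    raise RuntimeError(f"Request failed ({response.status_code}): {response.text}")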
Generate a Personal Access Token (If Needed)
- Click on your profile icon in the top-right corner of Databricks.
- Go to “User Settings” > “Access Tokens”
- Click Generate New Token
- Copy it and use it in your token variable
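A quick hardening note: rather than hardcoding the token in the notebook, read it from an environment variable (or Databricks secrets). For example:

import os

# Read the PAT from an environment variable set outside the notebook
token = os.environ["DATABRICKS_TOKEN"]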
Example Output:
Prediction: [0]
This means it predicted class 0 (Setosa, in the Iris dataset).
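To turn the numeric class into a species name, you can map it through the dataset’s target names (this assumes the serving response has the usual {"predictions": [...]} shape):

# Map predicted class indices back to species names
pred_classes = response.json()["predictions"]  # assumed response shape
print([iris.target_names[c] for c in pred_classes])  # e.g. ['setosa']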
Summary of Steps:
| Step | Action |
|---|---|
| 1 | Register model and set alias (production) |
| 2 | Enable model serving in the Databricks UI |
| 3 | Copy the REST endpoint URL |
| 4 | Send a test prediction via Python using requests |
Would you like me to generate a ready-to-run notebook (.dbc) with this entire process? Or help you test it directly with your live Databricks instance?