ML Monitoring & Next Steps
Data Drift Detection
Production ML models degrade silently. The data they encounter changes over time, causing predictions to become less accurate. Detecting this drift early is critical.
Types of Drift
| Type | What Changes | Example |
|---|---|---|
| Data Drift | Feature distributions | User demographics shift |
| Concept Drift | Input→Output relationship | Fraud patterns evolve |
| Label Drift | Target distribution | Churn rate increases |
| Prediction Drift | Model output distribution | More high-risk predictions |
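To make the distinction concrete, here is a toy sketch (synthetic data, purely illustrative) contrasting the two most commonly confused types, data drift and concept drift:
import numpy as np

rng = np.random.default_rng(42)

# Data drift: the feature distribution shifts between training and production.
amount_train = rng.normal(50, 10, 1_000)  # transaction amounts at training time
amount_prod = rng.normal(80, 10, 1_000)   # same feature, shifted in production

# Concept drift: identical inputs now map to different outputs
# (here, a hypothetical fraud threshold that has moved over time).
fraud_train = (amount_train > 75).astype(int)
fraud_prod = (amount_train > 55).astype(int)  # same inputs, new rule
Data drift is visible from inputs alone; concept drift generally requires ground-truth labels to detect.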
Why Models Degrade
Training Time                Production Time
┌─────────────────┐          ┌─────────────────┐
│ Training Data   │          │ New Data        │
│                 │          │                 │
│ Mean age: 35    │  Drift   │ Mean age: 28    │
│ 70% urban       │  ────►   │ 85% urban       │
│ Avg income: 50K │          │ Avg income: 45K │
└─────────────────┘          └─────────────────┘
         │                            │
         ▼                            ▼
   Model trained              Model predictions
   for THIS data             degraded on NEW data
Detecting Drift with Evidently
Installation
pip install evidently
Basic Drift Report
import pandas as pd
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset
# Reference data (training distribution)
reference_data = pd.read_csv("training_data.csv")
# Current data (production)
current_data = pd.read_csv("production_batch.csv")
# Create drift report
report = Report(metrics=[DataDriftPreset()])
report.run(
    reference_data=reference_data,
    current_data=current_data
)
# View in notebook
report
# Or save as HTML
report.save_html("drift_report.html")
Drift Detection Results
# Get drift detection results as dict
results = report.as_dict()
# Check if dataset has drifted
dataset_drift = results["metrics"][0]["result"]["dataset_drift"]
print(f"Dataset drift detected: {dataset_drift}")
# Check individual feature drift (drift_by_columns lives in the preset's
# second metric, DataDriftTable, and is keyed by column name)
for feature in results["metrics"][1]["result"]["drift_by_columns"].values():
    col = feature["column_name"]
    drifted = feature["drift_detected"]
    score = feature["drift_score"]
    print(f"{col}: drift={drifted}, score={score:.3f}")
Column-Level Drift Metrics
from evidently.report import Report
from evidently.metrics import (
    ColumnDriftMetric,
    DataDriftTable
)

report = Report(metrics=[
    DataDriftTable(),
    ColumnDriftMetric(column_name="age"),
    ColumnDriftMetric(column_name="income"),
    ColumnDriftMetric(column_name="transaction_amount"),
])
report.run(reference_data=reference_data, current_data=current_data)
Statistical Tests for Drift
Kolmogorov-Smirnov Test
The two-sample KS test compares the empirical cumulative distributions of the reference and current samples; a p-value below the chosen threshold means the two samples are unlikely to come from the same distribution.
from scipy import stats
def detect_drift_ks(reference: pd.Series, current: pd.Series, threshold: float = 0.05):
    """Detect drift in a numerical feature using the two-sample KS test."""
    statistic, p_value = stats.ks_2samp(reference, current)
    drift_detected = p_value < threshold
    return {
        "statistic": statistic,
        "p_value": p_value,
        "drift_detected": drift_detected,
    }
# Example
result = detect_drift_ks(
    reference_data["transaction_amount"],
    current_data["transaction_amount"]
)
print(f"Drift detected: {result['drift_detected']}, p-value: {result['p_value']:.4f}")
Population Stability Index (PSI)
PSI buckets the reference data into quantile bins and measures how far the current data's bin proportions have moved, summing (cur - ref) * ln(cur / ref) over the bins.
import numpy as np
def calculate_psi(reference: pd.Series, current: pd.Series, bins: int = 10) -> float:
    """Calculate the Population Stability Index."""
    # Create quantile bins from the reference distribution
    breakpoints = np.percentile(reference, np.linspace(0, 100, bins + 1))
    breakpoints[0] = -np.inf
    breakpoints[-1] = np.inf
    # Proportion of observations falling in each bin
    ref_counts = np.histogram(reference, bins=breakpoints)[0] / len(reference)
    cur_counts = np.histogram(current, bins=breakpoints)[0] / len(current)
    # Clip to avoid division by zero and log(0)
    ref_counts = np.clip(ref_counts, 0.0001, None)
    cur_counts = np.clip(cur_counts, 0.0001, None)
    # PSI = sum over bins of (cur - ref) * ln(cur / ref)
    psi = np.sum((cur_counts - ref_counts) * np.log(cur_counts / ref_counts))
    return psi
# Interpretation
# PSI < 0.1: No significant drift
# 0.1 <= PSI < 0.2: Moderate drift (investigate)
# PSI >= 0.2: Significant drift (retrain)
psi = calculate_psi(reference_data["income"], current_data["income"])
print(f"PSI: {psi:.3f}")
Automated Drift Monitoring
Scheduled Monitoring Job
# drift_monitor.py
import time
from datetime import datetime, timedelta

import pandas as pd
import schedule
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset

def load_reference_data():
    """Load training data as the reference distribution."""
    return pd.read_parquet("s3://ml-data/reference/training.parquet")

def load_current_data(hours: int = 24):
    """Load recent production data."""
    end = datetime.now()
    start = end - timedelta(hours=hours)
    return pd.read_parquet(
        f"s3://ml-data/predictions/{start.date()}/",
    )

def check_drift():
    """Run drift detection and alert if needed."""
    reference = load_reference_data()
    current = load_current_data(hours=24)

    report = Report(metrics=[DataDriftPreset()])
    report.run(reference_data=reference, current_data=current)
    results = report.as_dict()

    drift_detected = results["metrics"][0]["result"]["dataset_drift"]
    if drift_detected:
        # drift_by_columns is keyed by column name (DataDriftTable,
        # the preset's second metric)
        drifted_features = [
            f["column_name"]
            for f in results["metrics"][1]["result"]["drift_by_columns"].values()
            if f["drift_detected"]
        ]
        send_alert(f"Data drift detected in: {drifted_features}")

    report.save_html(f"drift_report_{datetime.now().isoformat()}.html")
    return drift_detected

def send_alert(message: str):
    """Send an alert via Slack/PagerDuty/email."""
    print(f"ALERT: {message}")
    # slack.post(channel="#ml-alerts", message=message)

# Run every 6 hours
schedule.every(6).hours.do(check_drift)

while True:
    schedule.run_pending()
    time.sleep(60)
Prometheus Metrics
from prometheus_client import Gauge, start_http_server
# Define metrics
drift_score = Gauge('ml_drift_score', 'Data drift score', ['feature'])
drift_detected = Gauge('ml_drift_detected', 'Data drift detected', ['feature'])
def export_drift_metrics(results: dict):
    """Export per-feature drift metrics to Prometheus gauges."""
    # drift_by_columns is keyed by column name (DataDriftTable results)
    for feature in results["metrics"][1]["result"]["drift_by_columns"].values():
        col = feature["column_name"]
        drift_score.labels(feature=col).set(feature["drift_score"])
        drift_detected.labels(feature=col).set(int(feature["drift_detected"]))
# Start metrics server
start_http_server(8000)
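To wire the exporter into the scheduled job, run it after each report (a sketch reusing the loader functions from drift_monitor.py above):
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset

# Inside a monitoring cycle: run the report, then publish its results
report = Report(metrics=[DataDriftPreset()])
report.run(reference_data=load_reference_data(), current_data=load_current_data())
export_drift_metrics(report.as_dict())
Prometheus can then scrape port 8000 and alert whenever ml_drift_detected is 1 for any feature.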
Handling Drift
| Severity | Action |
|---|---|
| Low (PSI < 0.1) | Monitor; no action needed |
| Medium (0.1 ≤ PSI < 0.2) | Investigate and document |
| High (PSI ≥ 0.2) | Trigger retraining pipeline |
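The table translates directly into a small helper:
def classify_drift(psi: float) -> str:
    """Map a PSI value to the action tiers in the table above."""
    if psi < 0.1:
        return "monitor"      # low: no action needed
    if psi < 0.2:
        return "investigate"  # medium: investigate and document
    return "retrain"          # high: trigger retraining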
Automated Retraining Trigger
def trigger_retraining_if_needed(psi_scores: dict, threshold: float = 0.2):
    """Trigger retraining when drift exceeds threshold."""
    max_psi = max(psi_scores.values())
    if max_psi >= threshold:
        # Trigger Kubeflow/Airflow pipeline
        trigger_training_pipeline(
            reason="data_drift",
            drift_scores=psi_scores
        )
        return True
    return False
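trigger_training_pipeline is left undefined above. As one hypothetical implementation, assuming an Airflow 2.x deployment reachable at a placeholder URL via its stable REST API:
import requests

def trigger_training_pipeline(reason: str, drift_scores: dict):
    """Hypothetical stub: start a retraining DAG run via Airflow's REST API."""
    response = requests.post(
        "https://airflow.example.com/api/v1/dags/model_retraining/dagRuns",  # placeholder URL and DAG id
        json={"conf": {"reason": reason, "drift_scores": drift_scores}},
        auth=("user", "pass"),  # placeholder credentials
        timeout=30,
    )
    response.raise_for_status()
A Kubeflow Pipelines run could be kicked off the same way through its own client; the key design point is that the trigger carries the drift scores, so the training run can record why it was started.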
Key insight: Models don't fail loudly—they degrade silently. Continuous drift monitoring catches problems before they impact users.
Next, we'll explore model performance monitoring and alerting.