ML Monitoring & Next Steps
Model Performance Monitoring
Beyond data drift, you must monitor the model's actual predictions and performance. This catches degradation even when input data looks stable.
What to Monitor
| Category | Metrics | Why |
|---|---|---|
| Latency | p50, p95, p99 response time | User experience, SLA |
| Throughput | Requests/second | Capacity planning |
| Errors | 5xx rate, timeout rate | Reliability |
| Predictions | Distribution, confidence scores | Model behavior |
| Business | Conversion, revenue impact | Actual value |
Prometheus Metrics for ML
Custom Metrics
# metrics.py
from prometheus_client import Counter, Histogram, Gauge, start_http_server
import time
# Request metrics
prediction_requests = Counter(
'ml_prediction_requests_total',
'Total prediction requests',
['model_name', 'model_version']
)
prediction_latency = Histogram(
'ml_prediction_latency_seconds',
'Prediction latency',
['model_name'],
buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5]
)
prediction_errors = Counter(
'ml_prediction_errors_total',
'Total prediction errors',
['model_name', 'error_type']
)
# Model output metrics
prediction_confidence = Histogram(
'ml_prediction_confidence',
'Model confidence scores',
['model_name', 'predicted_class'],
buckets=[0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]
)
prediction_class_distribution = Counter(
'ml_prediction_class_total',
'Predictions by class',
['model_name', 'predicted_class']
)
# Feature statistics
feature_value = Gauge(
'ml_feature_value',
'Feature statistics',
['feature_name', 'statistic'] # mean, std, min, max
)
Instrumented Prediction Service
# service.py
import time
from contextlib import contextmanager

from prometheus_client import start_http_server

# Metrics defined in metrics.py above
from metrics import (
    prediction_requests,
    prediction_latency,
    prediction_errors,
    prediction_confidence,
    prediction_class_distribution,
)
class MonitoredPredictor:
def __init__(self, model, model_name: str, model_version: str):
self.model = model
self.model_name = model_name
self.model_version = model_version
@contextmanager
def track_latency(self):
start = time.perf_counter()
try:
yield
finally:
duration = time.perf_counter() - start
prediction_latency.labels(model_name=self.model_name).observe(duration)
def predict(self, features: dict) -> dict:
prediction_requests.labels(
model_name=self.model_name,
model_version=self.model_version
).inc()
try:
            with self.track_latency():
                # Assumes a pipeline that accepts a list of feature dicts
                # (e.g. one whose first step is a DictVectorizer)
                prediction = self.model.predict([features])
                confidence = self.model.predict_proba([features])
# Track outputs
predicted_class = str(prediction[0])
prediction_class_distribution.labels(
model_name=self.model_name,
predicted_class=predicted_class
).inc()
max_confidence = float(max(confidence[0]))
prediction_confidence.labels(
model_name=self.model_name,
predicted_class=predicted_class
).observe(max_confidence)
return {
"prediction": predicted_class,
"confidence": max_confidence
}
except Exception as e:
prediction_errors.labels(
model_name=self.model_name,
error_type=type(e).__name__
).inc()
raise
# Start metrics server
start_http_server(8000)
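A minimal usage sketch, assuming the model was persisted with joblib and is a pipeline that accepts feature dicts; the path, model name, and feature keys below are illustrative.
# serve_example.py -- illustrative usage of MonitoredPredictor
import joblib

from service import MonitoredPredictor  # importing service also starts the /metrics server on :8000

# Assumption: a persisted pipeline with predict/predict_proba that accepts feature dicts
model = joblib.load("models/fraud_detector.joblib")
predictor = MonitoredPredictor(model, model_name="fraud_detector", model_version="1.4.0")

result = predictor.predict({"amount": 120.50, "merchant_category": "electronics"})
print(result)  # e.g. {"prediction": "1", "confidence": 0.87}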
Grafana Dashboards
Key Queries
# Request rate by model
sum(rate(ml_prediction_requests_total[5m])) by (model_name)
# P99 latency
histogram_quantile(0.99,
sum(rate(ml_prediction_latency_seconds_bucket[5m])) by (le, model_name)
)
# Error rate
sum(rate(ml_prediction_errors_total[5m])) by (model_name)
/
sum(rate(ml_prediction_requests_total[5m])) by (model_name)
# Prediction distribution (for detecting drift)
sum(rate(ml_prediction_class_total[1h])) by (predicted_class)
# Low confidence predictions (potential issues)
histogram_quantile(0.10,
sum(rate(ml_prediction_confidence_bucket[1h])) by (le)
)
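The same expressions can be evaluated programmatically against the Prometheus HTTP API, which is handy for automation outside Grafana. A sketch, assuming Prometheus is reachable at localhost:9090 and the requests library is installed:
# prom_query.py -- sketch using the Prometheus HTTP API (/api/v1/query)
import requests

PROMETHEUS_URL = "http://localhost:9090"  # assumption: local Prometheus

def instant_query(expr: str) -> list[dict]:
    resp = requests.get(f"{PROMETHEUS_URL}/api/v1/query", params={"query": expr}, timeout=10)
    resp.raise_for_status()
    return resp.json()["data"]["result"]

error_rate = instant_query(
    "sum(rate(ml_prediction_errors_total[5m])) by (model_name)"
    " / sum(rate(ml_prediction_requests_total[5m])) by (model_name)"
)
for series in error_rate:
    model = series["metric"].get("model_name", "unknown")
    _, value = series["value"]  # [timestamp, value-as-string]
    print(f"{model}: {float(value):.2%} errors")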
Dashboard JSON
{
"panels": [
{
"title": "Prediction Requests/sec",
"type": "graph",
"targets": [{
"expr": "sum(rate(ml_prediction_requests_total[5m])) by (model_name)"
}]
},
{
"title": "P99 Latency",
"type": "graph",
"targets": [{
"expr": "histogram_quantile(0.99, sum(rate(ml_prediction_latency_seconds_bucket[5m])) by (le))"
}]
},
{
"title": "Error Rate",
"type": "stat",
"targets": [{
"expr": "sum(rate(ml_prediction_errors_total[5m])) / sum(rate(ml_prediction_requests_total[5m]))"
}]
}
]
}
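If you version dashboards as JSON, they can be pushed through the Grafana HTTP API instead of being edited by hand. A sketch, assuming the panels above are saved to ml_dashboard.json, Grafana runs at localhost:3000, and a service-account token is set in the GRAFANA_TOKEN environment variable:
# push_dashboard.py -- sketch; URL, file name, and token variable are assumptions
import json
import os

import requests

GRAFANA_URL = "http://localhost:3000"
headers = {"Authorization": f"Bearer {os.environ['GRAFANA_TOKEN']}"}

with open("ml_dashboard.json") as f:  # the panels JSON above, saved to a file
    panels = json.load(f)["panels"]

payload = {
    "dashboard": {"id": None, "title": "ML Model Monitoring", "panels": panels},
    "overwrite": True,
}
resp = requests.post(f"{GRAFANA_URL}/api/dashboards/db", headers=headers, json=payload, timeout=10)
resp.raise_for_status()
print(resp.json()["url"])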
Alerting Rules
Prometheus Alert Rules
# prometheus-rules.yaml
groups:
- name: ml-model-alerts
rules:
# High latency
- alert: MLModelHighLatency
expr: |
histogram_quantile(0.99,
sum(rate(ml_prediction_latency_seconds_bucket[5m])) by (le, model_name)
) > 0.5
for: 5m
labels:
severity: warning
annotations:
summary: "Model {{ $labels.model_name }} p99 latency > 500ms"
# High error rate
- alert: MLModelHighErrorRate
expr: |
sum(rate(ml_prediction_errors_total[5m])) by (model_name)
/
sum(rate(ml_prediction_requests_total[5m])) by (model_name)
> 0.01
for: 5m
labels:
severity: critical
annotations:
summary: "Model {{ $labels.model_name }} error rate > 1%"
# Prediction skew (class imbalance changing)
- alert: MLModelPredictionSkew
expr: |
max(
sum(rate(ml_prediction_class_total[1h])) by (predicted_class)
) / sum(rate(ml_prediction_class_total[1h]))
> 0.95
for: 1h
labels:
severity: warning
annotations:
summary: "Model predictions heavily skewed to one class"
# Low confidence predictions increasing
- alert: MLModelLowConfidence
expr: |
histogram_quantile(0.25,
sum(rate(ml_prediction_confidence_bucket[1h])) by (le)
) < 0.6
for: 1h
labels:
severity: warning
annotations:
summary: "25% of predictions have confidence < 60%"
# No predictions (model down)
- alert: MLModelDown
expr: |
sum(rate(ml_prediction_requests_total[5m])) by (model_name) == 0
for: 5m
labels:
severity: critical
annotations:
summary: "Model {{ $labels.model_name }} not receiving requests"
Ground Truth Monitoring
Delayed Label Collection
# ground_truth_monitor.py
from datetime import datetime, timedelta

import pandas as pd
from prometheus_client import Gauge
from sklearn.metrics import f1_score, precision_score, recall_score

# Gauges exported by the rolling accuracy check (metric names are illustrative)
accuracy_gauge = Gauge("ml_model_rolling_accuracy", "Rolling accuracy against ground truth")
sample_size_gauge = Gauge("ml_model_rolling_accuracy_samples", "Labeled predictions in the window")
class GroundTruthMonitor:
def __init__(self, prediction_store, label_store):
self.prediction_store = prediction_store
self.label_store = label_store
def calculate_accuracy(self, start: datetime, end: datetime) -> dict:
"""Calculate accuracy for predictions with ground truth."""
predictions = self.prediction_store.query(start, end)
labels = self.label_store.query(start, end)
# Join on prediction_id
joined = predictions.merge(labels, on="prediction_id", how="inner")
if len(joined) == 0:
return {"accuracy": None, "sample_size": 0}
correct = (joined["predicted"] == joined["actual"]).sum()
total = len(joined)
return {
"accuracy": correct / total,
"sample_size": total,
"precision": self._precision(joined),
"recall": self._recall(joined),
"f1": self._f1(joined)
}
def monitor_rolling_accuracy(self, window_days: int = 7):
"""Monitor accuracy over rolling window."""
end = datetime.now()
start = end - timedelta(days=window_days)
metrics = self.calculate_accuracy(start, end)
        # Export to Prometheus (accuracy is None until labels have arrived)
        sample_size_gauge.set(metrics["sample_size"])
        if metrics["accuracy"] is not None:
            accuracy_gauge.set(metrics["accuracy"])
            # Alert if accuracy drops; send_alert is your notification hook
            # (Slack, PagerDuty, ...)
            if metrics["accuracy"] < 0.85:
                send_alert(
                    f"Model accuracy dropped to {metrics['accuracy']:.2%} "
                    f"(sample: {metrics['sample_size']})"
                )
return metrics
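Ground-truth monitoring only works if every prediction is logged with an ID that the delayed label can later be joined on (the prediction_id used in the merge above). A minimal sketch of that logging step at serving time; the CSV append is a stand-in for whatever prediction store you actually use.
# prediction_logger.py -- sketch; swap the CSV append for your warehouse or table of choice
import csv
import uuid
from datetime import datetime, timezone

def log_prediction(features: dict, predicted: str, confidence: float,
                   path: str = "predictions_log.csv") -> str:
    """Persist a prediction with an ID so delayed labels can be joined later."""
    prediction_id = str(uuid.uuid4())
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            prediction_id,
            datetime.now(timezone.utc).isoformat(),
            predicted,
            confidence,
            str(features),
        ])
    return prediction_id  # return to the caller so it can be stored with the request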
SLA Monitoring
SLO Definitions
# slo.yaml
slos:
fraud_detector:
availability:
target: 99.9
window: 30d
latency:
p99_target_ms: 100
p50_target_ms: 20
accuracy:
target: 0.95
window: 7d
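To make the burn-rate math concrete: a 99.9% availability target leaves a 0.1% error budget, and a 14.4x burn rate consumes roughly 2% of a 30-day budget per hour, exhausting it in about two days. A sketch of deriving the threshold used in the alert below from slo.yaml (requires PyYAML):
# burn_rate.py -- sketch; derives the fast-burn threshold used in the alert below
import yaml

with open("slo.yaml") as f:
    slo = yaml.safe_load(f)["slos"]["fraud_detector"]

error_budget = 1 - slo["availability"]["target"] / 100   # 0.001 for a 99.9% target
fast_burn = 14.4                                          # ~2% of a 30d budget per hour
threshold = error_budget * fast_burn                      # 0.0144 -> alert above 1.44%
print(f"Alert if 1h error rate > {threshold:.2%}")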
SLO Burn Rate Alerts
# High burn rate = using error budget too fast
- alert: MLModelErrorBudgetBurn
expr: |
(
sum(rate(ml_prediction_errors_total[1h]))
/
sum(rate(ml_prediction_requests_total[1h]))
) > (1 - 0.999) * 14.4
for: 5m
labels:
severity: critical
annotations:
summary: "Burning through error budget 14.4x faster than sustainable"
Best Practices
| Practice | Implementation |
|---|---|
| Monitor predictions, not just inputs | Track output distribution |
| Set baselines from production | Not training data |
| Alert on trends, not noise | Use rolling windows |
| Include business metrics | Revenue, conversion |
| Automate remediation | Rollback, retrain triggers (see the webhook sketch below) |
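As referenced in the table, remediation can be driven directly by the alerts. A hedged sketch of an Alertmanager webhook receiver that triggers a rollback when the critical error-rate alert fires; rollback_model is a placeholder for your own deployment tooling, and the route and port are arbitrary.
# remediation_webhook.py -- sketch of an Alertmanager webhook receiver
from flask import Flask, request

app = Flask(__name__)

def rollback_model(model_name: str) -> None:
    """Placeholder: call your model registry / deployment API here."""
    print(f"Rolling back {model_name} to the previous version")

@app.route("/alerts", methods=["POST"])
def handle_alerts():
    payload = request.get_json(force=True)
    for alert in payload.get("alerts", []):
        # Alertmanager marks each alert as "firing" or "resolved"
        if alert.get("status") != "firing":
            continue
        labels = alert.get("labels", {})
        if labels.get("alertname") == "MLModelHighErrorRate":
            rollback_model(labels.get("model_name", "unknown"))
    return {"status": "ok"}

if __name__ == "__main__":
    app.run(port=5001)  # point Alertmanager's webhook_config url at this endpoint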
Key insight: Good monitoring answers "Is my model helping users?" not just "Is my model running?"
Next, we'll explore ML governance and compliance requirements.