ML Monitoring & Next Steps

Model Performance Monitoring

Beyond data drift, you must monitor the model's actual predictions and performance. This catches degradation even when input data looks stable.

What to Monitor

| Category | Metrics | Why |
| --- | --- | --- |
| Latency | p50, p95, p99 response time | User experience, SLA |
| Throughput | Requests/second | Capacity planning |
| Errors | 5xx rate, timeout rate | Reliability |
| Predictions | Distribution, confidence scores | Model behavior |
| Business | Conversion, revenue impact | Actual value |
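
These signals are easiest to collect if every request emits one structured record. A minimal sketch of such a record is below (the field names are illustrative, not a fixed schema); the prediction_id is what later lets you join predictions to delayed ground-truth labels.

# prediction_record.py - illustrative per-request log record (not a fixed schema)
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone

@dataclass
class PredictionRecord:
    """One structured log entry per prediction request."""
    model_name: str
    model_version: str
    features: dict            # raw inputs, or a reference/hash if they are sensitive
    predicted_class: str
    confidence: float
    latency_ms: float
    prediction_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())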

Prometheus Metrics for ML

Custom Metrics

# metrics.py
from prometheus_client import Counter, Histogram, Gauge, start_http_server

# Request metrics
prediction_requests = Counter(
    'ml_prediction_requests_total',
    'Total prediction requests',
    ['model_name', 'model_version']
)

prediction_latency = Histogram(
    'ml_prediction_latency_seconds',
    'Prediction latency',
    ['model_name'],
    buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5]
)

prediction_errors = Counter(
    'ml_prediction_errors_total',
    'Total prediction errors',
    ['model_name', 'error_type']
)

# Model output metrics
prediction_confidence = Histogram(
    'ml_prediction_confidence',
    'Model confidence scores',
    ['model_name', 'predicted_class'],
    buckets=[0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]
)

prediction_class_distribution = Counter(
    'ml_prediction_class_total',
    'Predictions by class',
    ['model_name', 'predicted_class']
)

# Feature statistics
feature_value = Gauge(
    'ml_feature_value',
    'Feature statistics',
    ['feature_name', 'statistic']  # mean, std, min, max
)
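
The feature_value gauge above is only declared; something has to push values into it. A minimal sketch of a batch update, assuming you periodically collect a window of recent values per feature (numpy is used only for the summary statistics):

# feature_stats.py - illustrative exporter for the feature_value gauge
import numpy as np

from metrics import feature_value

def export_feature_stats(feature_name: str, values: np.ndarray) -> None:
    """Publish summary statistics for a batch of recent feature values."""
    stats = {
        "mean": float(np.mean(values)),
        "std": float(np.std(values)),
        "min": float(np.min(values)),
        "max": float(np.max(values)),
    }
    for statistic, value in stats.items():
        feature_value.labels(feature_name=feature_name, statistic=statistic).set(value)

# Example: export stats for the latest window of a (hypothetical) amount feature
# export_feature_stats("transaction_amount", np.array([12.5, 80.0, 3.2]))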

Instrumented Prediction Service

# service.py
import time
from contextlib import contextmanager

from prometheus_client import start_http_server

from metrics import (
    prediction_requests,
    prediction_latency,
    prediction_errors,
    prediction_confidence,
    prediction_class_distribution,
)

class MonitoredPredictor:
    def __init__(self, model, model_name: str, model_version: str):
        self.model = model
        self.model_name = model_name
        self.model_version = model_version

    @contextmanager
    def track_latency(self):
        start = time.perf_counter()
        try:
            yield
        finally:
            duration = time.perf_counter() - start
            prediction_latency.labels(model_name=self.model_name).observe(duration)

    def predict(self, features: dict) -> dict:
        prediction_requests.labels(
            model_name=self.model_name,
            model_version=self.model_version
        ).inc()

        try:
            with self.track_latency():
                prediction = self.model.predict([features])
                confidence = self.model.predict_proba([features])

            # Track outputs
            predicted_class = str(prediction[0])
            prediction_class_distribution.labels(
                model_name=self.model_name,
                predicted_class=predicted_class
            ).inc()

            max_confidence = float(max(confidence[0]))
            prediction_confidence.labels(
                model_name=self.model_name,
                predicted_class=predicted_class
            ).observe(max_confidence)

            return {
                "prediction": predicted_class,
                "confidence": max_confidence
            }

        except Exception as e:
            prediction_errors.labels(
                model_name=self.model_name,
                error_type=type(e).__name__
            ).inc()
            raise

# Start metrics server
start_http_server(8000)
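
A minimal usage sketch, assuming the wrapped model accepts a list containing a single feature dict (for example a scikit-learn Pipeline that starts with a DictVectorizer); the model file and feature names are placeholders:

# app.py - minimal wiring sketch (model file and feature names are placeholders)
import joblib

from service import MonitoredPredictor  # importing service also starts the /metrics server on :8000

# Any model exposing predict() and predict_proba() over a list of feature dicts
model = joblib.load("fraud_model.joblib")

predictor = MonitoredPredictor(model, model_name="fraud_detector", model_version="1.4.0")

result = predictor.predict({"amount": 120.0, "country": "DE"})
print(result)  # e.g. {"prediction": "1", "confidence": 0.87}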

Grafana Dashboards

Key Queries

# Request rate by model
sum(rate(ml_prediction_requests_total[5m])) by (model_name)

# P99 latency
histogram_quantile(0.99,
  sum(rate(ml_prediction_latency_seconds_bucket[5m])) by (le, model_name)
)

# Error rate
sum(rate(ml_prediction_errors_total[5m])) by (model_name)
/
sum(rate(ml_prediction_requests_total[5m])) by (model_name)

# Prediction distribution (for detecting drift)
sum(rate(ml_prediction_class_total[1h])) by (predicted_class)

# Low confidence predictions (potential issues)
histogram_quantile(0.10,
  sum(rate(ml_prediction_confidence_bucket[1h])) by (le)
)

Dashboard JSON

{
  "panels": [
    {
      "title": "Prediction Requests/sec",
      "type": "graph",
      "targets": [{
        "expr": "sum(rate(ml_prediction_requests_total[5m])) by (model_name)"
      }]
    },
    {
      "title": "P99 Latency",
      "type": "graph",
      "targets": [{
        "expr": "histogram_quantile(0.99, sum(rate(ml_prediction_latency_seconds_bucket[5m])) by (le))"
      }]
    },
    {
      "title": "Error Rate",
      "type": "stat",
      "targets": [{
        "expr": "sum(rate(ml_prediction_errors_total[5m])) / sum(rate(ml_prediction_requests_total[5m]))"
      }]
    }
  ]
}

Alerting Rules

Prometheus Alert Rules

# prometheus-rules.yaml
groups:
- name: ml-model-alerts
  rules:
  # High latency
  - alert: MLModelHighLatency
    expr: |
      histogram_quantile(0.99,
        sum(rate(ml_prediction_latency_seconds_bucket[5m])) by (le, model_name)
      ) > 0.5
    for: 5m
    labels:
      severity: warning
    annotations:
      summary: "Model {{ $labels.model_name }} p99 latency > 500ms"

  # High error rate
  - alert: MLModelHighErrorRate
    expr: |
      sum(rate(ml_prediction_errors_total[5m])) by (model_name)
      /
      sum(rate(ml_prediction_requests_total[5m])) by (model_name)
      > 0.01
    for: 5m
    labels:
      severity: critical
    annotations:
      summary: "Model {{ $labels.model_name }} error rate > 1%"

  # Prediction skew (class imbalance changing)
  - alert: MLModelPredictionSkew
    expr: |
      max(
        sum(rate(ml_prediction_class_total[1h])) by (predicted_class)
      ) / sum(rate(ml_prediction_class_total[1h]))
      > 0.95
    for: 1h
    labels:
      severity: warning
    annotations:
      summary: "Model predictions heavily skewed to one class"

  # Low confidence predictions increasing
  - alert: MLModelLowConfidence
    expr: |
      histogram_quantile(0.25,
        sum(rate(ml_prediction_confidence_bucket[1h])) by (le)
      ) < 0.6
    for: 1h
    labels:
      severity: warning
    annotations:
      summary: "25% of predictions have confidence < 60%"

  # No predictions (model down)
  - alert: MLModelDown
    expr: |
      sum(rate(ml_prediction_requests_total[5m])) by (model_name) == 0
    for: 5m
    labels:
      severity: critical
    annotations:
      summary: "Model {{ $labels.model_name }} not receiving requests"

Ground Truth Monitoring

Delayed Label Collection

# ground_truth_monitor.py
# prediction_store / label_store are expected to return pandas DataFrames
# that share a prediction_id column.
from datetime import datetime, timedelta

from prometheus_client import Gauge

# Gauges exported by this module (metric names are illustrative)
accuracy_gauge = Gauge(
    'ml_model_rolling_accuracy',
    'Rolling accuracy against delayed ground-truth labels'
)
sample_size_gauge = Gauge(
    'ml_model_rolling_accuracy_samples',
    'Number of labeled predictions in the accuracy window'
)

class GroundTruthMonitor:
    def __init__(self, prediction_store, label_store):
        self.prediction_store = prediction_store
        self.label_store = label_store

    def calculate_accuracy(self, start: datetime, end: datetime) -> dict:
        """Calculate accuracy for predictions with ground truth."""
        predictions = self.prediction_store.query(start, end)
        labels = self.label_store.query(start, end)

        # Join on prediction_id
        joined = predictions.merge(labels, on="prediction_id", how="inner")

        if len(joined) == 0:
            return {"accuracy": None, "sample_size": 0}

        correct = (joined["predicted"] == joined["actual"]).sum()
        total = len(joined)

        return {
            "accuracy": correct / total,
            "sample_size": total,
            "precision": self._precision(joined),
            "recall": self._recall(joined),
            "f1": self._f1(joined)
        }

    def monitor_rolling_accuracy(self, window_days: int = 7):
        """Monitor accuracy over rolling window."""
        end = datetime.now()
        start = end - timedelta(days=window_days)

        metrics = self.calculate_accuracy(start, end)

        # Nothing to export if no labels arrived in the window yet
        if metrics["sample_size"] == 0:
            return metrics

        # Export to Prometheus
        accuracy_gauge.set(metrics["accuracy"])
        sample_size_gauge.set(metrics["sample_size"])

        # Alert if accuracy drops (send_alert is assumed to be defined
        # elsewhere, e.g. a Slack or PagerDuty helper)
        if metrics["accuracy"] < 0.85:
            send_alert(
                f"Model accuracy dropped to {metrics['accuracy']:.2%} "
                f"(sample: {metrics['sample_size']})"
            )

        return metrics
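
The _precision, _recall, and _f1 helpers referenced above are not shown. One possible implementation delegates to scikit-learn; macro averaging is an assumption here (use average="binary" for a two-class problem), and inside the class these would be methods taking self:

# Sketch of the metric helpers, delegating to scikit-learn
import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score

def _precision(joined: pd.DataFrame) -> float:
    return float(precision_score(joined["actual"], joined["predicted"],
                                 average="macro", zero_division=0))

def _recall(joined: pd.DataFrame) -> float:
    return float(recall_score(joined["actual"], joined["predicted"],
                              average="macro", zero_division=0))

def _f1(joined: pd.DataFrame) -> float:
    return float(f1_score(joined["actual"], joined["predicted"],
                          average="macro", zero_division=0))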

SLA Monitoring

SLO Definitions

# slo.yaml
slos:
  fraud_detector:
    availability:
      target: 99.9
      window: 30d
    latency:
      p99_target_ms: 100
      p50_target_ms: 20
    accuracy:
      target: 0.95
      window: 7d

SLO Burn Rate Alerts

# High burn rate = using error budget too fast
- alert: MLModelErrorBudgetBurn
  expr: |
    (
      sum(rate(ml_prediction_errors_total[1h]))
      /
      sum(rate(ml_prediction_requests_total[1h]))
    ) > (1 - 0.999) * 14.4
  for: 5m
  labels:
    severity: critical
  annotations:
    summary: "Burning through error budget 14.4x faster than sustainable"

Best Practices

| Practice | Implementation |
| --- | --- |
| Monitor predictions, not just inputs | Track the output distribution |
| Set baselines from production data | Not from training data |
| Alert on trends, not noise | Use rolling windows |
| Include business metrics | Revenue, conversion |
| Automate remediation | Rollback and retrain triggers (see the sketch below) |
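
To act on the last row, alerts can be routed from Alertmanager to a webhook that performs the remediation. A minimal sketch follows; FastAPI is used for brevity, and rollback_model / trigger_retraining are placeholders for your deployment and pipeline tooling.

# remediation_webhook.py - sketch of an Alertmanager webhook receiver
from fastapi import FastAPI, Request

app = FastAPI()

def rollback_model(model_name: str) -> None:
    ...  # placeholder: point the serving layer back at the previous model version

def trigger_retraining(model_name: str) -> None:
    ...  # placeholder: kick off the training pipeline for this model

@app.post("/alerts")
async def handle_alerts(request: Request) -> dict:
    payload = await request.json()
    for alert in payload.get("alerts", []):
        if alert.get("status") != "firing":
            continue
        name = alert["labels"].get("alertname")
        model = alert["labels"].get("model_name", "unknown")
        if name == "MLModelHighErrorRate":
            rollback_model(model)       # likely a bad deploy: roll back first, debug later
        elif name in ("MLModelPredictionSkew", "MLModelLowConfidence"):
            trigger_retraining(model)   # likely drift: schedule retraining
    return {"status": "ok"}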

Key insight: Good monitoring answers "Is my model helping users?" not just "Is my model running?"

Next, we'll explore ML governance and compliance requirements.
