ML Workflow Orchestration
Airflow for ML
Apache Airflow is the industry standard for workflow orchestration. Originally built for data engineering, it's widely used for ML pipelines too.
Why Airflow for ML?
| Strength | Description |
|---|---|
| Mature ecosystem | 10+ years, huge community |
| Rich integrations | AWS, GCP, Kubernetes, Spark |
| Flexible scheduling | Cron, sensors, triggers |
| Battle-tested | Used by Airbnb, Uber, Netflix |
Core Concepts
DAG (Directed Acyclic Graph)
```python
from airflow import DAG
from datetime import datetime, timedelta

default_args = {
    "owner": "ml-team",
    "retries": 2,
    "retry_delay": timedelta(minutes=5),
}

dag = DAG(
    dag_id="ml_training_pipeline",
    default_args=default_args,
    description="Daily ML training pipeline",
    schedule="0 2 * * *",  # Run at 2 AM daily
    start_date=datetime(2025, 1, 1),
    catchup=False,
)
```
Operators
Operators define what a task does:
| Operator | Use Case |
|---|---|
| `PythonOperator` | Run Python functions |
| `BashOperator` | Execute shell commands |
| `KubernetesPodOperator` | Run containers on K8s |
| `S3ToRedshiftOperator` | Data transfers |
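For ML workloads, `KubernetesPodOperator` is often the workhorse, since it moves heavy training off the Airflow workers and into its own container. A minimal sketch, assuming a recent `cncf.kubernetes` provider and a hypothetical trainer image and namespace:

```python
# Hypothetical example: run a containerized training job on Kubernetes.
# The image name, namespace, and command are placeholders.
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator

train_in_k8s = KubernetesPodOperator(
    task_id="train_in_k8s",
    name="ml-training-pod",
    namespace="ml",                                # placeholder namespace
    image="registry.example.com/trainer:latest",   # placeholder image
    cmds=["python", "train.py"],
    get_logs=True,  # stream pod logs into the Airflow task log
    dag=dag,
)
```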
Tasks and Dependencies
```python
from airflow.operators.python import PythonOperator

def preprocess_data():
    # Preprocessing logic
    pass

def train_model():
    # Training logic
    pass

def evaluate_model():
    # Evaluation logic
    pass

# Define tasks
preprocess = PythonOperator(
    task_id="preprocess_data",
    python_callable=preprocess_data,
    dag=dag,
)

train = PythonOperator(
    task_id="train_model",
    python_callable=train_model,
    dag=dag,
)

evaluate = PythonOperator(
    task_id="evaluate_model",
    python_callable=evaluate_model,
    dag=dag,
)

# Define dependencies
preprocess >> train >> evaluate
```
Complete ML Pipeline Example
```python
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta

default_args = {
    "owner": "ml-team",
    "retries": 2,
    "retry_delay": timedelta(minutes=5),
    "email_on_failure": True,
    "email": ["ml-team@company.com"],
}

def load_data(**context):
    """Load and validate data."""
    import pandas as pd

    df = pd.read_csv("s3://bucket/raw/data.csv")

    # Validate
    assert len(df) > 1000, "Not enough data"
    assert df.isnull().sum().sum() < 100, "Too many nulls"

    # Save processed
    df.to_parquet("/tmp/processed.parquet")

    # Pass metadata to the next task
    context["ti"].xcom_push(key="row_count", value=len(df))

def train_model(**context):
    """Train ML model."""
    import pandas as pd
    from sklearn.ensemble import RandomForestClassifier
    import joblib

    df = pd.read_parquet("/tmp/processed.parquet")
    X = df.drop("target", axis=1)
    y = df["target"]

    model = RandomForestClassifier(n_estimators=100)
    model.fit(X, y)
    joblib.dump(model, "/tmp/model.pkl")

    # Get data from the previous task
    row_count = context["ti"].xcom_pull(
        task_ids="load_data",
        key="row_count",
    )
    print(f"Trained on {row_count} rows")

def evaluate_model(**context):
    """Evaluate model performance."""
    import pandas as pd
    from sklearn.metrics import accuracy_score
    import joblib

    df = pd.read_parquet("/tmp/processed.parquet")
    model = joblib.load("/tmp/model.pkl")
    X = df.drop("target", axis=1)
    y = df["target"]

    predictions = model.predict(X)
    accuracy = accuracy_score(y, predictions)

    # Store metrics
    metrics = {"accuracy": accuracy, "timestamp": str(datetime.now())}
    context["ti"].xcom_push(key="metrics", value=metrics)

    # Fail if accuracy is too low
    if accuracy < 0.8:
        raise ValueError(f"Accuracy {accuracy} below threshold 0.8")

def deploy_model(**context):
    """Deploy model if evaluation passed."""
    import shutil

    metrics = context["ti"].xcom_pull(
        task_ids="evaluate_model",
        key="metrics",
    )
    print(f"Deploying model with accuracy: {metrics['accuracy']}")
    shutil.copy("/tmp/model.pkl", "/models/production/model.pkl")

with DAG(
    dag_id="ml_training_pipeline",
    default_args=default_args,
    description="End-to-end ML training",
    schedule="0 2 * * *",
    start_date=datetime(2025, 1, 1),
    catchup=False,
    tags=["ml", "training"],
) as dag:
    load_task = PythonOperator(
        task_id="load_data",
        python_callable=load_data,
    )
    train_task = PythonOperator(
        task_id="train_model",
        python_callable=train_model,
    )
    evaluate_task = PythonOperator(
        task_id="evaluate_model",
        python_callable=evaluate_model,
    )
    deploy_task = PythonOperator(
        task_id="deploy_model",
        python_callable=deploy_model,
    )

    # Pipeline: load -> train -> evaluate -> deploy
    load_task >> train_task >> evaluate_task >> deploy_task
```
XCom: Passing Data Between Tasks
```python
# Push data
def task_a(**context):
    result = {"accuracy": 0.95, "model_path": "/tmp/model.pkl"}
    context["ti"].xcom_push(key="result", value=result)

# Pull data
def task_b(**context):
    result = context["ti"].xcom_pull(task_ids="task_a", key="result")
    print(f"Accuracy: {result['accuracy']}")
```
Note: XCom is meant for small payloads such as metrics and paths; values live in the metadata database, and the practical size limit is on the order of kilobytes (commonly cited as ~48 KB). For large data, write to external storage (S3, GCS) and pass only the path.
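A common version of that pattern is to upload the artifact to object storage and pass only its URI through XCom. A minimal sketch, assuming boto3 is available on the workers and using a hypothetical `my-ml-bucket`:

```python
# Hypothetical example: keep large artifacts out of XCom.
# Bucket name and key paths are placeholders; assumes boto3 on the worker.
import boto3

def train_and_upload(**context):
    # ... training produces /tmp/model.pkl ...
    s3 = boto3.client("s3")
    s3.upload_file("/tmp/model.pkl", "my-ml-bucket", "models/model.pkl")
    # Only the small URI string goes through XCom
    context["ti"].xcom_push(key="model_uri", value="s3://my-ml-bucket/models/model.pkl")

def download_and_evaluate(**context):
    uri = context["ti"].xcom_pull(task_ids="train_and_upload", key="model_uri")
    bucket, key = uri.replace("s3://", "").split("/", 1)
    boto3.client("s3").download_file(bucket, key, "/tmp/model.pkl")
```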
Sensors: Wait for Conditions
```python
# The old airflow.sensors.s3_key_sensor path is deprecated;
# S3KeySensor now lives in the Amazon provider package.
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

wait_for_data = S3KeySensor(
    task_id="wait_for_data",
    bucket_name="my-bucket",
    bucket_key="data/daily/{{ ds }}/data.csv",
    timeout=3600,      # Give up after 1 hour
    poke_interval=60,  # Check every minute
    dag=dag,
)

wait_for_data >> load_task
```
TaskFlow API (Modern Syntax)
```python
from airflow.decorators import dag, task
from datetime import datetime

@dag(
    dag_id="ml_pipeline_taskflow",
    schedule="@daily",
    start_date=datetime(2025, 1, 1),
    catchup=False,
)
def ml_pipeline():
    @task
    def load_data():
        return {"data_path": "/tmp/data.parquet"}

    @task
    def train_model(data_info: dict):
        return {"model_path": "/tmp/model.pkl"}

    @task
    def evaluate(model_info: dict):
        return {"accuracy": 0.92}

    # Dependencies inferred from function calls
    data = load_data()
    model = train_model(data)
    metrics = evaluate(model)

ml_pipeline()
```
Best Practices
| Practice | Why |
|---|---|
| Use TaskFlow API | Cleaner code, better type hints |
| Keep tasks atomic | Easier retries and debugging |
| Externalize heavy compute | Use K8s, Spark for big jobs |
| Use pools | Control resource usage |
| Test DAGs locally | `airflow dags test <dag_id>` catches errors early; see the sketch below |
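For local testing, Airflow 2.5+ also exposes `dag.test()`, which runs an entire DAG in a single process without a scheduler or workers. A minimal sketch, assuming the `ml_training_pipeline` DAG object from the complete example above:

```python
# Hypothetical local test run; assumes Airflow 2.5+ and an initialized metadata DB.
if __name__ == "__main__":
    dag.test()  # executes every task of this DAG in-process, in dependency order
```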
Airflow vs Other Tools
| Feature | Airflow | Kubeflow | Prefect |
|---|---|---|---|
| Best for | Data + ML | ML on K8s | Modern Python |
| Scheduling | Excellent | Basic | Good |
| K8s native | No | Yes | No |
| Learning curve | Medium | High | Low |
| UI | Good | Good | Excellent |
Common Patterns
Branching
```python
from airflow.operators.python import BranchPythonOperator

def choose_branch(**context):
    accuracy = context["ti"].xcom_pull(task_ids="evaluate")["accuracy"]
    if accuracy > 0.9:
        return "deploy_production"
    else:
        return "deploy_staging"

branch = BranchPythonOperator(
    task_id="choose_deployment",
    python_callable=choose_branch,
)
```
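The task_id returned by the callable determines which downstream branch runs; the other branch is skipped. A sketch of the wiring with hypothetical placeholder tasks (note the trigger rule on the join task, so it still runs when one branch was skipped):

```python
# Hypothetical downstream wiring for the branch above (inside the same DAG context).
from airflow.operators.empty import EmptyOperator

deploy_production = EmptyOperator(task_id="deploy_production")
deploy_staging = EmptyOperator(task_id="deploy_staging")

# Join task that runs as long as at least one branch succeeded
notify = EmptyOperator(
    task_id="notify",
    trigger_rule="none_failed_min_one_success",
)

branch >> [deploy_production, deploy_staging] >> notify
```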
Dynamic Task Mapping
```python
from airflow.decorators import dag, task
from datetime import datetime

@task
def process_partition(partition_id: int):
    # Process one partition
    pass

@dag(schedule=None, start_date=datetime(2025, 1, 1), catchup=False)
def dynamic_pipeline():
    # Creates one mapped task instance per partition at runtime
    partitions = list(range(10))
    process_partition.expand(partition_id=partitions)

dynamic_pipeline()
```
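Fixed arguments can be combined with mapped ones via `partial()`. A minimal sketch with a hypothetical `config` parameter:

```python
# Hypothetical: fixed kwargs go through partial(), mapped kwargs through expand().
from airflow.decorators import dag, task
from datetime import datetime

@task
def process_with_config(partition_id: int, config: dict):
    print(f"Processing partition {partition_id} with {config}")

@dag(schedule=None, start_date=datetime(2025, 1, 1), catchup=False)
def mapped_with_fixed_args():
    # One task instance per partition, all sharing the same config
    process_with_config.partial(config={"batch_size": 256}).expand(
        partition_id=list(range(10))
    )

mapped_with_fixed_args()
```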
Key insight: Airflow excels when you need robust scheduling, monitoring, and integration with data infrastructure. For pure ML workflows, consider Kubeflow or Prefect.
Next, we'll explore Prefect as a modern alternative for Python-native workflows.