ML Workflow Orchestration

Airflow for ML

Apache Airflow is the industry standard for workflow orchestration. Originally built for data engineering, it's widely used for ML pipelines too.

Why Airflow for ML?

| Strength | Description |
| --- | --- |
| Mature ecosystem | 10+ years of development, huge community |
| Rich integrations | AWS, GCP, Kubernetes, Spark |
| Flexible scheduling | Cron, sensors, triggers |
| Battle-tested | Used by Airbnb, Uber, Netflix |

Core Concepts

DAG (Directed Acyclic Graph)

from airflow import DAG
from datetime import datetime, timedelta

default_args = {
    "owner": "ml-team",
    "retries": 2,
    "retry_delay": timedelta(minutes=5),
}

dag = DAG(
    dag_id="ml_training_pipeline",
    default_args=default_args,
    description="Daily ML training pipeline",
    schedule="0 2 * * *",  # Run at 2 AM daily
    start_date=datetime(2025, 1, 1),
    catchup=False,
)

Operators

Operators define what a task does:

| Operator | Use Case |
| --- | --- |
| PythonOperator | Run Python functions |
| BashOperator | Execute shell commands |
| KubernetesPodOperator | Run containers on Kubernetes |
| S3ToRedshiftOperator | Load data from S3 into Redshift |
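
Heavy training steps are often pushed out of the Airflow workers and into containers. A minimal KubernetesPodOperator sketch, assuming the cncf-kubernetes provider is installed (the namespace, image, and registry names below are illustrative):

from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator

train_in_pod = KubernetesPodOperator(
    task_id="train_in_pod",
    name="train-model",
    namespace="ml",                      # illustrative namespace
    image="my-registry/trainer:latest",  # illustrative training image
    cmds=["python", "train.py"],
    get_logs=True,  # stream container logs into the Airflow task log
)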

Tasks and Dependencies

from airflow.operators.python import PythonOperator

def preprocess_data():
    # Preprocessing logic
    pass

def train_model():
    # Training logic
    pass

def evaluate_model():
    # Evaluation logic
    pass

# Define tasks
preprocess = PythonOperator(
    task_id="preprocess_data",
    python_callable=preprocess_data,
    dag=dag,
)

train = PythonOperator(
    task_id="train_model",
    python_callable=train_model,
    dag=dag,
)

evaluate = PythonOperator(
    task_id="evaluate_model",
    python_callable=evaluate_model,
    dag=dag,
)

# Define dependencies
preprocess >> train >> evaluate

Complete ML Pipeline Example

from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta
import json

default_args = {
    "owner": "ml-team",
    "retries": 2,
    "retry_delay": timedelta(minutes=5),
    "email_on_failure": True,
    "email": ["ml-team@company.com"],
}


def load_data(**context):
    """Load and validate data."""
    import pandas as pd

    df = pd.read_csv("s3://bucket/raw/data.csv")  # reading s3:// paths requires s3fs

    # Validate
    assert len(df) > 1000, "Not enough data"
    assert df.isnull().sum().sum() < 100, "Too many nulls"

    # Save processed data locally (assumes downstream tasks run on the same
    # worker; in a distributed deployment, write to S3/GCS instead)
    df.to_parquet("/tmp/processed.parquet")

    # Pass metadata to next task
    context["ti"].xcom_push(key="row_count", value=len(df))


def train_model(**context):
    """Train ML model."""
    import pandas as pd
    from sklearn.ensemble import RandomForestClassifier
    import joblib

    df = pd.read_parquet("/tmp/processed.parquet")
    X = df.drop("target", axis=1)
    y = df["target"]

    model = RandomForestClassifier(n_estimators=100)
    model.fit(X, y)

    joblib.dump(model, "/tmp/model.pkl")

    # Get data from previous task
    row_count = context["ti"].xcom_pull(
        task_ids="load_data",
        key="row_count"
    )
    print(f"Trained on {row_count} rows")


def evaluate_model(**context):
    """Evaluate model performance."""
    import pandas as pd
    from sklearn.metrics import accuracy_score
    import joblib

    df = pd.read_parquet("/tmp/processed.parquet")
    model = joblib.load("/tmp/model.pkl")

    X = df.drop("target", axis=1)
    y = df["target"]
    predictions = model.predict(X)  # evaluating on training data for brevity; use a held-out set in practice

    accuracy = accuracy_score(y, predictions)

    # Store metrics
    metrics = {"accuracy": accuracy, "timestamp": str(datetime.now())}
    context["ti"].xcom_push(key="metrics", value=metrics)

    # Fail if accuracy too low
    if accuracy < 0.8:
        raise ValueError(f"Accuracy {accuracy} below threshold 0.8")


def deploy_model(**context):
    """Deploy model if evaluation passed."""
    import shutil

    metrics = context["ti"].xcom_pull(
        task_ids="evaluate_model",
        key="metrics"
    )

    print(f"Deploying model with accuracy: {metrics['accuracy']}")
    shutil.copy("/tmp/model.pkl", "/models/production/model.pkl")


with DAG(
    dag_id="ml_training_pipeline",
    default_args=default_args,
    description="End-to-end ML training",
    schedule="0 2 * * *",
    start_date=datetime(2025, 1, 1),
    catchup=False,
    tags=["ml", "training"],
) as dag:

    load_task = PythonOperator(
        task_id="load_data",
        python_callable=load_data,
    )

    train_task = PythonOperator(
        task_id="train_model",
        python_callable=train_model,
    )

    evaluate_task = PythonOperator(
        task_id="evaluate_model",
        python_callable=evaluate_model,
    )

    deploy_task = PythonOperator(
        task_id="deploy_model",
        python_callable=deploy_model,
    )

    # Pipeline: load -> train -> evaluate -> deploy
    load_task >> train_task >> evaluate_task >> deploy_task

XCom: Passing Data Between Tasks

# Push data
def task_a(**context):
    result = {"accuracy": 0.95, "model_path": "/tmp/model.pkl"}
    context["ti"].xcom_push(key="result", value=result)

# Pull data
def task_b(**context):
    result = context["ti"].xcom_pull(task_ids="task_a", key="result")
    print(f"Accuracy: {result['accuracy']}")

Note: XCom is meant for small pieces of metadata; the size limit depends on the metadata database backend (roughly 64 KB on MySQL, up to 1 GB on Postgres). For large data, store it externally (S3, GCS) and pass only a reference through XCom.
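
A common version of that pattern is to upload the artifact and push only its URI; a minimal sketch, assuming boto3 is available on the workers and a hypothetical my-ml-artifacts bucket:

import boto3

def save_model(**context):
    # Upload the heavy artifact to S3 rather than pushing it through XCom
    s3 = boto3.client("s3")
    s3.upload_file("/tmp/model.pkl", "my-ml-artifacts", "models/model.pkl")
    # Push only the lightweight reference
    context["ti"].xcom_push(key="model_uri", value="s3://my-ml-artifacts/models/model.pkl")

def load_model(**context):
    uri = context["ti"].xcom_pull(task_ids="save_model", key="model_uri")
    print(f"Fetching model from {uri}")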

Sensors: Wait for Conditions

from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

wait_for_data = S3KeySensor(
    task_id="wait_for_data",
    bucket_name="my-bucket",
    bucket_key="data/daily/{{ ds }}/data.csv",
    timeout=3600,  # give up after 1 hour
    poke_interval=60,  # check every minute
    mode="reschedule",  # free the worker slot between checks
    dag=dag,
)

wait_for_data >> load_task

TaskFlow API (Modern Syntax)

from airflow.decorators import dag, task
from datetime import datetime

@dag(
    dag_id="ml_pipeline_taskflow",
    schedule="@daily",
    start_date=datetime(2025, 1, 1),
    catchup=False,
)
def ml_pipeline():

    @task
    def load_data():
        return {"data_path": "/tmp/data.parquet"}

    @task
    def train_model(data_info: dict):
        return {"model_path": "/tmp/model.pkl"}

    @task
    def evaluate(model_info: dict):
        return {"accuracy": 0.92}

    # Dependencies inferred from function calls
    data = load_data()
    model = train_model(data)
    metrics = evaluate(model)

ml_pipeline()
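
TaskFlow return values are ordinary XComs stored under the default return_value key, so classic operators can still read them. A minimal sketch of a callable that consumes the evaluate task's return value:

def log_metrics(**context):
    # xcom_pull without an explicit key returns the task's "return_value" XCom
    metrics = context["ti"].xcom_pull(task_ids="evaluate")
    print(f"Accuracy: {metrics['accuracy']}")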

Best Practices

| Practice | Why |
| --- | --- |
| Use the TaskFlow API | Cleaner code, better type hints |
| Keep tasks atomic | Easier retries and debugging |
| Externalize heavy compute | Run big jobs on Kubernetes or Spark |
| Use pools | Control resource usage |
| Test DAGs locally | airflow dags test <dag_id> |
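
Two of these deserve a quick illustration. Every operator accepts a pool argument (the pool itself is created via the UI or CLI; ml_gpu_pool below is a hypothetical name), and since Airflow 2.5 a DAG object also exposes dag.test(), which runs all tasks in a single local process:

train_task = PythonOperator(
    task_id="train_model",
    python_callable=train_model,
    pool="ml_gpu_pool",  # hypothetical pool capping concurrent GPU jobs
)

# At the bottom of the DAG file: run the whole DAG in-process for debugging
if __name__ == "__main__":
    dag.test()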

Airflow vs Other Tools

| Feature | Airflow | Kubeflow | Prefect |
| --- | --- | --- | --- |
| Best for | Data + ML | ML on K8s | Modern Python |
| Scheduling | Excellent | Basic | Good |
| K8s native | No | Yes | No |
| Learning curve | Medium | High | Low |
| UI | Good | Good | Excellent |

Common Patterns

Branching

from airflow.operators.python import BranchPythonOperator

def choose_branch(**context):
    accuracy = context["ti"].xcom_pull(task_ids="evaluate")["accuracy"]
    if accuracy > 0.9:
        return "deploy_production"
    else:
        return "deploy_staging"

branch = BranchPythonOperator(
    task_id="choose_deployment",
    python_callable=choose_branch,
)
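
The branch task returns the task_id to follow, and Airflow skips the other path. Wiring it up might look like this, using placeholder deploy tasks:

from airflow.operators.empty import EmptyOperator

deploy_production = EmptyOperator(task_id="deploy_production")
deploy_staging = EmptyOperator(task_id="deploy_staging")

# Whichever task_id choose_branch returns runs; the other is skipped
branch >> [deploy_production, deploy_staging]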

Dynamic Task Mapping

from airflow.decorators import dag, task
from datetime import datetime

@task
def process_partition(partition_id: int):
    # Process one partition
    pass

@dag(schedule=None, start_date=datetime(2025, 1, 1), catchup=False)
def dynamic_pipeline():
    # One mapped task instance is created per partition id
    partitions = list(range(10))
    process_partition.expand(partition_id=partitions)

dynamic_pipeline()

Key insight: Airflow excels when you need robust scheduling, monitoring, and integration with data infrastructure. For pure ML workflows, consider Kubeflow or Prefect.

Next, we'll explore Prefect as a modern alternative for Python-native workflows.
