PySpark Interview Coding Patterns

PySpark coding questions test your ability to solve real-world data engineering problems efficiently. This lesson covers the most common patterns you'll encounter in interviews.

Essential PySpark Imports

from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, lit, when, coalesce,
    sum, count, avg, max, min,
    row_number, rank, dense_rank, lead, lag,
    date_format, to_date, datediff, date_add,
    explode, array, struct, collect_list, collect_set,
    concat, concat_ws, split, regexp_extract,
    broadcast, monotonically_increasing_id
)
from pyspark.sql.window import Window
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType,
    DoubleType, DateType, TimestampType, ArrayType
)
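
The examples below all assume an active SparkSession bound to the name spark. A minimal sketch for running them locally (the app name is arbitrary):

spark = SparkSession.builder \
    .appName("pyspark-interview-prep") \
    .master("local[*]") \
    .getOrCreate()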

Pattern 1: Window Functions

Running Totals and Cumulative Calculations

Interview Question: "Calculate running total of sales per customer ordered by date."

from pyspark.sql.functions import sum as spark_sum

# Sample data
data = [
    ("C001", "2024-01-01", 100),
    ("C001", "2024-01-02", 150),
    ("C001", "2024-01-03", 200),
    ("C002", "2024-01-01", 300),
    ("C002", "2024-01-02", 250),
]
df = spark.createDataFrame(data, ["customer_id", "date", "amount"])

# Define window
window = Window.partitionBy("customer_id").orderBy("date")

# Running total
df_with_running = df.withColumn(
    "running_total",
    spark_sum("amount").over(window)
)

# Running total with rows between
window_3day = Window.partitionBy("customer_id")\
    .orderBy("date")\
    .rowsBetween(-2, 0)  # Current + 2 previous rows

df_with_3day = df.withColumn(
    "rolling_3day_sum",
    spark_sum("amount").over(window_3day)
)
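
For the sample data above, the running total accumulates within each customer partition; checking by hand, the output looks roughly like this:

df_with_running.orderBy("customer_id", "date").show()
# +-----------+----------+------+-------------+
# |customer_id|      date|amount|running_total|
# +-----------+----------+------+-------------+
# |       C001|2024-01-01|   100|          100|
# |       C001|2024-01-02|   150|          250|
# |       C001|2024-01-03|   200|          450|
# |       C002|2024-01-01|   300|          300|
# |       C002|2024-01-02|   250|          550|
# +-----------+----------+------+-------------+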

Ranking Within Groups

Interview Question: "Find the top 3 products by revenue in each category."

from pyspark.sql.functions import row_number

# Sample data
products = [
    ("Electronics", "Laptop", 50000),
    ("Electronics", "Phone", 40000),
    ("Electronics", "Tablet", 30000),
    ("Electronics", "Watch", 20000),
    ("Clothing", "Jacket", 5000),
    ("Clothing", "Shirt", 3000),
    ("Clothing", "Pants", 4000),
]
df = spark.createDataFrame(products, ["category", "product", "revenue"])

# Rank products within each category
window = Window.partitionBy("category").orderBy(col("revenue").desc())

df_ranked = df.withColumn("rank", row_number().over(window))

# Filter top 3
top_3_per_category = df_ranked.filter(col("rank") <= 3)

Lead/Lag for Time Series Analysis

Interview Question: "Calculate day-over-day change in stock price."

from pyspark.sql.functions import lag

stock_data = [
    ("AAPL", "2024-01-01", 180.0),
    ("AAPL", "2024-01-02", 182.5),
    ("AAPL", "2024-01-03", 179.0),
    ("GOOG", "2024-01-01", 140.0),
    ("GOOG", "2024-01-02", 142.0),
]
df = spark.createDataFrame(stock_data, ["symbol", "date", "price"])

window = Window.partitionBy("symbol").orderBy("date")

df_with_change = df.withColumn(
    "prev_price", lag("price", 1).over(window)
).withColumn(
    "daily_change", col("price") - col("prev_price")
).withColumn(
    "pct_change",
    ((col("price") - col("prev_price")) / col("prev_price") * 100).cast("decimal(10,2)")
)
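
lead() is the mirror image of lag(); a sketch comparing each price with the next trading day's price over the same window:

from pyspark.sql.functions import lead

df_with_next = df.withColumn(
    "next_price", lead("price", 1).over(window)
).withColumn(
    "change_to_next_day", col("next_price") - col("price")
)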

Pattern 2: Complex Aggregations

Multiple Aggregations with GroupBy

Interview Question: "Calculate various statistics per department: total salary, average salary, employee count, highest paid employee."

from pyspark.sql.functions import (
    sum as spark_sum, avg as spark_avg,
    count, max as spark_max, first
)

employees = [
    ("Engineering", "Alice", 150000),
    ("Engineering", "Bob", 120000),
    ("Engineering", "Charlie", 130000),
    ("Sales", "David", 100000),
    ("Sales", "Eve", 110000),
]
df = spark.createDataFrame(employees, ["dept", "name", "salary"])

# Multiple aggregations
dept_stats = df.groupBy("dept").agg(
    spark_sum("salary").alias("total_salary"),
    spark_avg("salary").alias("avg_salary"),
    count("*").alias("employee_count"),
    spark_max("salary").alias("max_salary")
)

# To get the highest-paid employee's name, use a window function
window = Window.partitionBy("dept").orderBy(col("salary").desc())

df_with_rank = df.withColumn("rank", row_number().over(window))

highest_paid = df_with_rank.filter(col("rank") == 1)\
    .select("dept", col("name").alias("highest_paid_employee"))

# Join back
result = dept_stats.join(highest_paid, "dept")
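
Working through the sample employees by hand, the joined result should come out to roughly:

result.show()
# Engineering: total 400000, avg ~133333.33, count 3, max 150000, highest paid Alice
# Sales:       total 210000, avg 105000.0,   count 2, max 110000, highest paid Eve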

Pivot Tables

Interview Question: "Create a pivot table showing monthly revenue per product category."

from pyspark.sql.functions import date_format, to_date

sales = [
    ("2024-01-15", "Electronics", 1000),
    ("2024-01-20", "Clothing", 500),
    ("2024-02-10", "Electronics", 1200),
    ("2024-02-15", "Clothing", 600),
    ("2024-03-01", "Electronics", 800),
]
df = spark.createDataFrame(sales, ["date", "category", "revenue"])

df_with_month = df.withColumn(
    "month", date_format(to_date("date"), "yyyy-MM")
)

# Pivot: rows are categories, columns are months
pivot_df = df_with_month.groupBy("category").pivot("month").sum("revenue")

# Fill nulls with 0
pivot_df = pivot_df.fillna(0)
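
For the sample sales above, the pivoted table has one column per month (values checked by hand; row order may differ):

pivot_df.show()
# +-----------+-------+-------+-------+
# |   category|2024-01|2024-02|2024-03|
# +-----------+-------+-------+-------+
# |Electronics|   1000|   1200|    800|
# |   Clothing|    500|    600|      0|
# +-----------+-------+-------+-------+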

Pattern 3: Handling Complex Data Types

Working with Arrays

Interview Question: "Explode an array column and perform aggregations."

from pyspark.sql.functions import explode, collect_list, size

# Users, each with multiple skills
users = [
    (1, "Alice", ["Python", "Spark", "SQL"]),
    (2, "Bob", ["Java", "Spark"]),
    (3, "Charlie", ["Python", "ML", "SQL"]),
]
df = spark.createDataFrame(users, ["id", "name", "skills"])

# Explode: one row per skill
exploded = df.select("id", "name", explode("skills").alias("skill"))

# Count skills per user
skill_counts = df.withColumn("skill_count", size("skills"))

# Find most common skills
skill_popularity = exploded.groupBy("skill")\
    .agg(count("*").alias("user_count"))\
    .orderBy(col("user_count").desc())

# Collect skills back into array after filtering
python_users = exploded.filter(col("skill") == "Python")\
    .groupBy("id", "name")\
    .agg(collect_list("skill").alias("skills"))

Working with Structs and Nested Data

Interview Question: "Parse nested JSON data and extract specific fields."

from pyspark.sql.functions import col, struct

# Nested data
json_data = """
{"id": 1, "user": {"name": "Alice", "address": {"city": "NYC", "zip": "10001"}}, "orders": [{"id": 101, "amount": 50}, {"id": 102, "amount": 75}]}
{"id": 2, "user": {"name": "Bob", "address": {"city": "LA", "zip": "90001"}}, "orders": [{"id": 103, "amount": 100}]}
"""

# Read JSON
df = spark.read.json(spark.sparkContext.parallelize(json_data.strip().split("\n")))

# Access nested fields
df_flat = df.select(
    col("id"),
    col("user.name").alias("user_name"),
    col("user.address.city").alias("city"),
    col("user.address.zip").alias("zip_code"),
    explode("orders").alias("order")
).select(
    "*",
    col("order.id").alias("order_id"),
    col("order.amount").alias("order_amount")
).drop("order")

Pattern 4: Deduplication Strategies

Keep Most Recent Record

Interview Question: "Remove duplicates keeping only the most recent record per customer."

from pyspark.sql.functions import row_number, col

# Data with duplicates
data = [
    ("C001", "2024-01-01", "Address1"),
    ("C001", "2024-01-15", "Address2"),  # Most recent
    ("C001", "2024-01-10", "Address1.5"),
    ("C002", "2024-01-01", "AddressA"),
    ("C002", "2024-01-05", "AddressB"),  # Most recent
]
df = spark.createDataFrame(data, ["customer_id", "updated_at", "address"])

# Method 1: Window function
window = Window.partitionBy("customer_id").orderBy(col("updated_at").desc())

deduplicated = df.withColumn("rn", row_number().over(window))\
    .filter(col("rn") == 1)\
    .drop("rn")

# Method 2: Using dropDuplicates (keeps an arbitrary first occurrence)
# Caveat: the sort order is not guaranteed to survive the shuffle behind
# dropDuplicates, so prefer the window approach when you must keep the latest row
deduplicated_v2 = df.orderBy(col("updated_at").desc())\
    .dropDuplicates(["customer_id"])

Handling Multiple Dedup Keys

Interview Question: "Find duplicate transactions based on customer, amount, and date within same hour."

from pyspark.sql.functions import date_trunc, count

transactions = [
    ("C001", 100.0, "2024-01-01 10:15:00"),
    ("C001", 100.0, "2024-01-01 10:45:00"),  # Potential dup
    ("C001", 100.0, "2024-01-01 11:15:00"),  # Different hour
    ("C002", 200.0, "2024-01-01 10:00:00"),
]
df = spark.createDataFrame(transactions, ["customer_id", "amount", "timestamp"])

# Add hour bucket
df_with_hour = df.withColumn(
    "hour_bucket", date_trunc("hour", col("timestamp"))
)

# Find duplicates
dup_counts = df_with_hour.groupBy("customer_id", "amount", "hour_bucket")\
    .agg(count("*").alias("count"))\
    .filter(col("count") > 1)

# Flag duplicates in original
df_with_flag = df_with_hour.join(
    dup_counts.select("customer_id", "amount", "hour_bucket", lit(True).alias("is_duplicate")),
    ["customer_id", "amount", "hour_bucket"],
    "left"
).fillna(False, ["is_duplicate"])

Pattern 5: Efficient Joins

Broadcast Join for Small Tables

Interview Question: "Optimize a join between a large fact table and small dimension table."

from pyspark.sql.functions import broadcast

# Large fact table (millions of rows)
sales = spark.read.parquet("s3://bucket/sales/")  # 100M rows

# Small dimension table (thousands of rows)
products = spark.read.parquet("s3://bucket/products/")  # 10K rows

# Broadcast join - sends small table to all executors
result = sales.join(
    broadcast(products),
    sales.product_id == products.id
).select(
    sales["*"],
    products.name,
    products.category
)

# Spark automatically broadcasts tables smaller than
# spark.sql.autoBroadcastJoinThreshold (10MB by default);
# raise it when your dimension tables are slightly larger
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50MB")
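
A good point to raise in an interview is that you can verify the broadcast actually happened by inspecting the physical plan:

# Look for BroadcastHashJoin (rather than SortMergeJoin) in the output
result.explain()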

Handling Skewed Joins

Interview Question: "How would you handle a join where one key appears in 50% of the data?"

from pyspark.sql.functions import concat, lit, rand, explode, array

# Problem: customer_id = "unknown" appears in 50% of rows
# Solution: Salting technique

# Number of salt buckets (assumes large_df and small_df are already loaded)
num_salts = 10

# Salt the large table
large_df = large_df.withColumn(
    "salted_key",
    when(col("customer_id") == "unknown",
         concat(col("customer_id"), lit("_"), (rand() * num_salts).cast("int")))
    .otherwise(col("customer_id"))
)

# Replicate the small table's skewed key once per salt value
# (explode cannot be nested inside when(), so attach a salt array first)
salt_array = array([lit(i) for i in range(num_salts)])

small_df_exploded = small_df.withColumn(
    "salt",
    when(col("customer_id") == "unknown", salt_array)
    .otherwise(array(lit(None).cast("int")))
).withColumn("salt", explode("salt")).withColumn(
    "salted_key",
    when(col("salt").isNotNull(),
         concat(col("customer_id"), lit("_"), col("salt")))
    .otherwise(col("customer_id"))
).drop("salt")

# Now join on salted key
result = large_df.join(small_df_exploded, "salted_key")
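
On Spark 3.x it is also worth mentioning Adaptive Query Execution, which can detect and split skewed join partitions at runtime and often removes the need for manual salting; a sketch of the relevant settings:

spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")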

Pattern 6: Data Quality Checks

Interview Question: "Implement data quality validations in PySpark."

from pyspark.sql.functions import (
    col, count, sum as spark_sum, when,
    min as spark_min, max as spark_max, avg as spark_avg
)

def run_data_quality_checks(df, table_name):
    """Run comprehensive data quality checks."""

    # Row count
    row_count = df.count()
    print(f"{table_name}: {row_count} rows")

    # Null counts per column (add isnan() as well for float/double columns
    # if NaN values should also be counted)
    null_counts = df.select([
        spark_sum(when(col(c).isNull(), 1).otherwise(0)).alias(c)
        for c in df.columns
    ])

    print("Null counts:")
    null_counts.show()

    # Duplicate check on key columns
    total = df.count()
    distinct = df.select("id").distinct().count()
    duplicates = total - distinct
    print(f"Duplicate IDs: {duplicates}")

    # Value range checks
    numeric_stats = df.select(
        spark_min("amount").alias("min_amount"),
        spark_max("amount").alias("max_amount"),
        spark_avg("amount").alias("avg_amount")
    )
    numeric_stats.show()

    # Check for negative values (if unexpected)
    negative_count = df.filter(col("amount") < 0).count()
    if negative_count > 0:
        print(f"WARNING: {negative_count} negative amounts found")

    return {
        "row_count": row_count,
        "duplicate_count": duplicates,
        "negative_amounts": negative_count
    }

# Usage
quality_report = run_data_quality_checks(sales_df, "sales")
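
In a production pipeline you would usually fail the job when a check does not pass instead of only printing a warning; a minimal sketch using the report returned above (the row-count threshold is a made-up example):

expected_min_rows = 1000  # hypothetical threshold for this table

if quality_report["row_count"] < expected_min_rows:
    raise ValueError(f"sales row count {quality_report['row_count']} is below {expected_min_rows}")
if quality_report["duplicate_count"] > 0:
    raise ValueError("duplicate IDs found in sales")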

Pattern 7: Sessionization

Interview Question: "Group user events into sessions with 30-minute inactivity timeout."

from pyspark.sql.functions import (
    lag, unix_timestamp, to_timestamp, when,
    sum as spark_sum, col
)
from pyspark.sql.window import Window

events = [
    ("user1", "2024-01-01 10:00:00"),
    ("user1", "2024-01-01 10:05:00"),
    ("user1", "2024-01-01 10:20:00"),
    ("user1", "2024-01-01 11:00:00"),  # New session (>30 min gap)
    ("user1", "2024-01-01 11:10:00"),
    ("user2", "2024-01-01 10:00:00"),
]
df = spark.createDataFrame(events, ["user_id", "event_time"])

# Convert to timestamp
df = df.withColumn("event_ts", to_timestamp("event_time"))

# Window for lag
window = Window.partitionBy("user_id").orderBy("event_ts")

# Calculate time since previous event
df_with_gap = df.withColumn(
    "prev_event_ts", lag("event_ts", 1).over(window)
).withColumn(
    "gap_seconds",
    unix_timestamp("event_ts") - unix_timestamp("prev_event_ts")
)

# Mark session boundaries (gap > 30 minutes or first event)
session_timeout = 30 * 60  # 30 minutes in seconds

df_with_boundary = df_with_gap.withColumn(
    "new_session",
    when(
        col("prev_event_ts").isNull() | (col("gap_seconds") > session_timeout),
        1
    ).otherwise(0)
)

# Create session IDs using cumulative sum
df_with_sessions = df_with_boundary.withColumn(
    "session_id",
    spark_sum("new_session").over(window)
)

# Calculate session metrics
session_metrics = df_with_sessions.groupBy("user_id", "session_id").agg(
    count("*").alias("event_count"),
    min("event_ts").alias("session_start"),
    max("event_ts").alias("session_end")
)
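
For the sample events above, user1's 40-minute gap between 10:20 and 11:00 starts a second session, so the metrics work out (by hand) to:

session_metrics.orderBy("user_id", "session_id").show()
# +-------+----------+-----------+-------------------+-------------------+
# |user_id|session_id|event_count|      session_start|        session_end|
# +-------+----------+-----------+-------------------+-------------------+
# |  user1|         1|          3|2024-01-01 10:00:00|2024-01-01 10:20:00|
# |  user1|         2|          2|2024-01-01 11:00:00|2024-01-01 11:10:00|
# |  user2|         1|          1|2024-01-01 10:00:00|2024-01-01 10:00:00|
# +-------+----------+-----------+-------------------+-------------------+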

Interview Tips for PySpark Questions

  1. Always start with imports: Show you know the common functions
  2. Explain your approach: Talk through the solution before coding
  3. Consider performance: Mention broadcast joins, partitioning, avoiding shuffles
  4. Handle edge cases: Nulls, duplicates, empty DataFrames
  5. Know DataFrame vs SQL: Be comfortable with both APIs, as in the equivalent snippets below

# DataFrame API
df.filter(col("status") == "active")\
  .groupBy("region")\
  .agg(sum("revenue").alias("total"))

# SQL API (equivalent)
df.createOrReplaceTempView("sales")
spark.sql("""
    SELECT region, SUM(revenue) as total
    FROM sales
    WHERE status = 'active'
    GROUP BY region
""")

