Big Data & Streaming Systems
PySpark Interview Coding Patterns
PySpark coding questions test your ability to solve real-world data engineering problems efficiently. This lesson covers the most common patterns you'll encounter in interviews.
Essential PySpark Imports
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
col, lit, when, coalesce,
sum, count, avg, max, min,
row_number, rank, dense_rank, lead, lag,
date_format, to_date, to_timestamp, datediff, date_add,
explode, array, struct, collect_list, collect_set,
concat, concat_ws, split, regexp_extract,
broadcast, monotonically_increasing_id
)
from pyspark.sql.window import Window
from pyspark.sql.types import (
StructType, StructField, StringType, IntegerType,
DoubleType, DateType, TimestampType, ArrayType
)
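All of the examples below assume an active SparkSession bound to the name spark. A minimal sketch for creating one locally (the app name and master setting are illustrative):
# Local session for practice; in a real cluster or notebook, spark usually already exists
spark = SparkSession.builder \
    .appName("pyspark-interview-practice") \
    .master("local[*]") \
    .getOrCreate()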
Pattern 1: Window Functions
Running Totals and Cumulative Calculations
Interview Question: "Calculate running total of sales per customer ordered by date."
from pyspark.sql.functions import sum as spark_sum
# Sample data
data = [
("C001", "2024-01-01", 100),
("C001", "2024-01-02", 150),
("C001", "2024-01-03", 200),
("C002", "2024-01-01", 300),
("C002", "2024-01-02", 250),
]
df = spark.createDataFrame(data, ["customer_id", "date", "amount"])
# Define window
window = Window.partitionBy("customer_id").orderBy("date")
# Running total
df_with_running = df.withColumn(
"running_total",
spark_sum("amount").over(window)
)
# Running total with rows between
window_3day = Window.partitionBy("customer_id")\
.orderBy("date")\
.rowsBetween(-2, 0) # Current + 2 previous rows
df_with_3day = df.withColumn(
"rolling_3day_sum",
spark_sum("amount").over(window_3day)
)
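Note that rowsBetween counts physical rows, not calendar days, so the window above is really a "last 3 records" sum whenever a customer has gaps in their dates. For a true 3-day window, one option is to order by a day number and use rangeBetween. A sketch building on the same df (the epoch reference date is arbitrary):
from pyspark.sql.functions import datediff, to_date, lit
# Convert the date to an integer day count so rangeBetween can measure calendar distance
df_days = df.withColumn("day_num", datediff(to_date("date"), lit("1970-01-01")))
window_3day_range = Window.partitionBy("customer_id")\
    .orderBy("day_num")\
    .rangeBetween(-2, 0)  # All rows within the previous 2 calendar days
df_with_3day_range = df_days.withColumn(
    "rolling_3day_sum",
    spark_sum("amount").over(window_3day_range)
)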
Ranking Within Groups
Interview Question: "Find the top 3 products by revenue in each category."
from pyspark.sql.functions import row_number
# Sample data
products = [
("Electronics", "Laptop", 50000),
("Electronics", "Phone", 40000),
("Electronics", "Tablet", 30000),
("Electronics", "Watch", 20000),
("Clothing", "Jacket", 5000),
("Clothing", "Shirt", 3000),
("Clothing", "Pants", 4000),
]
df = spark.createDataFrame(products, ["category", "product", "revenue"])
# Rank products within each category
window = Window.partitionBy("category").orderBy(col("revenue").desc())
df_ranked = df.withColumn("rank", row_number().over(window))
# Filter top 3
top_3_per_category = df_ranked.filter(col("rank") <= 3)
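row_number breaks ties arbitrarily. If products with equal revenue should share a rank (which may return more than 3 rows), dense_rank is the better fit; a quick sketch using the same window:
from pyspark.sql.functions import dense_rank
df_dense = df.withColumn("rank", dense_rank().over(window))
top_3_with_ties = df_dense.filter(col("rank") <= 3)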
Lead/Lag for Time Series Analysis
Interview Question: "Calculate day-over-day change in stock price."
from pyspark.sql.functions import lag
stock_data = [
("AAPL", "2024-01-01", 180.0),
("AAPL", "2024-01-02", 182.5),
("AAPL", "2024-01-03", 179.0),
("GOOG", "2024-01-01", 140.0),
("GOOG", "2024-01-02", 142.0),
]
df = spark.createDataFrame(stock_data, ["symbol", "date", "price"])
window = Window.partitionBy("symbol").orderBy("date")
df_with_change = df.withColumn(
"prev_price", lag("price", 1).over(window)
).withColumn(
"daily_change", col("price") - col("prev_price")
).withColumn(
"pct_change",
((col("price") - col("prev_price")) / col("prev_price") * 100).cast("decimal(10,2)")
)
Pattern 2: Complex Aggregations
Multiple Aggregations with GroupBy
Interview Question: "Calculate various statistics per department: total salary, average salary, employee count, highest paid employee."
from pyspark.sql.functions import (
sum as spark_sum, avg as spark_avg,
count, max as spark_max, first
)
employees = [
("Engineering", "Alice", 150000),
("Engineering", "Bob", 120000),
("Engineering", "Charlie", 130000),
("Sales", "David", 100000),
("Sales", "Eve", 110000),
]
df = spark.createDataFrame(employees, ["dept", "name", "salary"])
# Multiple aggregations
dept_stats = df.groupBy("dept").agg(
spark_sum("salary").alias("total_salary"),
spark_avg("salary").alias("avg_salary"),
count("*").alias("employee_count"),
spark_max("salary").alias("max_salary")
)
# To get highest paid employee name, use window function
window = Window.partitionBy("dept").orderBy(col("salary").desc())
df_with_rank = df.withColumn("rank", row_number().over(window))
highest_paid = df_with_rank.filter(col("rank") == 1)\
.select("dept", col("name").alias("highest_paid_employee"))
# Join back
result = dept_stats.join(highest_paid, "dept")
Pivot Tables
Interview Question: "Create a pivot table showing monthly revenue per product category."
from pyspark.sql.functions import date_format, to_date
sales = [
("2024-01-15", "Electronics", 1000),
("2024-01-20", "Clothing", 500),
("2024-02-10", "Electronics", 1200),
("2024-02-15", "Clothing", 600),
("2024-03-01", "Electronics", 800),
]
df = spark.createDataFrame(sales, ["date", "category", "revenue"])
df_with_month = df.withColumn(
"month", date_format(to_date("date"), "yyyy-MM")
)
# Pivot: rows are categories, columns are months
pivot_df = df_with_month.groupBy("category").pivot("month").sum("revenue")
# Fill nulls with 0
pivot_df = pivot_df.fillna(0)
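When the pivot values are known ahead of time, passing them explicitly avoids an extra pass over the data to discover the distinct values and keeps the column order predictable. A sketch, assuming the three months present in the sample data:
# Explicit value list: skips the distinct-value scan and fixes the column order
pivot_df_fast = df_with_month.groupBy("category")\
    .pivot("month", ["2024-01", "2024-02", "2024-03"])\
    .sum("revenue")\
    .fillna(0)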
Pattern 3: Handling Complex Data Types
Working with Arrays
Interview Question: "Explode an array column and perform aggregations."
from pyspark.sql.functions import explode, collect_list, size
# User with multiple interests
users = [
(1, "Alice", ["Python", "Spark", "SQL"]),
(2, "Bob", ["Java", "Spark"]),
(3, "Charlie", ["Python", "ML", "SQL"]),
]
df = spark.createDataFrame(users, ["id", "name", "skills"])
# Explode: one row per skill
exploded = df.select("id", "name", explode("skills").alias("skill"))
# Count skills per user
skill_counts = df.withColumn("skill_count", size("skills"))
# Find most common skills
skill_popularity = exploded.groupBy("skill")\
.agg(count("*").alias("user_count"))\
.orderBy(col("user_count").desc())
# Collect skills back into array after filtering
python_users = exploded.filter(col("skill") == "Python")\
.groupBy("id", "name")\
.agg(collect_list("skill").alias("skills"))
Working with Structs and Nested Data
Interview Question: "Parse nested JSON data and extract specific fields."
from pyspark.sql.functions import col, struct
# Nested data
json_data = """
{"id": 1, "user": {"name": "Alice", "address": {"city": "NYC", "zip": "10001"}}, "orders": [{"id": 101, "amount": 50}, {"id": 102, "amount": 75}]}
{"id": 2, "user": {"name": "Bob", "address": {"city": "LA", "zip": "90001"}}, "orders": [{"id": 103, "amount": 100}]}
"""
# Read JSON (strip() so the surrounding blank lines don't become corrupt records)
df = spark.read.json(spark.sparkContext.parallelize(json_data.strip().split("\n")))
# Access nested fields
df_flat = df.select(
col("id"),
col("user.name").alias("user_name"),
col("user.address.city").alias("city"),
col("user.address.zip").alias("zip_code"),
explode("orders").alias("order")
).select(
"*",
col("order.id").alias("order_id"),
col("order.amount").alias("order_amount")
).drop("order")
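Interviewers often ask you to supply an explicit schema instead of relying on inference, which is where the StructType imports from the top of the lesson come in. A minimal sketch matching the sample JSON above (the field types are assumptions based on the sample values):
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, ArrayType
)
# Hypothetical explicit schema for the nested sample records
schema = StructType([
    StructField("id", IntegerType()),
    StructField("user", StructType([
        StructField("name", StringType()),
        StructField("address", StructType([
            StructField("city", StringType()),
            StructField("zip", StringType())
        ]))
    ])),
    StructField("orders", ArrayType(StructType([
        StructField("id", IntegerType()),
        StructField("amount", IntegerType())
    ])))
])
df_typed = spark.read.schema(schema).json(
    spark.sparkContext.parallelize(json_data.strip().split("\n"))
)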
Pattern 4: Deduplication Strategies
Keep Most Recent Record
Interview Question: "Remove duplicates keeping only the most recent record per customer."
from pyspark.sql.functions import row_number, col
# Data with duplicates
data = [
("C001", "2024-01-01", "Address1"),
("C001", "2024-01-15", "Address2"), # Most recent
("C001", "2024-01-10", "Address1.5"),
("C002", "2024-01-01", "AddressA"),
("C002", "2024-01-05", "AddressB"), # Most recent
]
df = spark.createDataFrame(data, ["customer_id", "updated_at", "address"])
# Method 1: Window function
window = Window.partitionBy("customer_id").orderBy(col("updated_at").desc())
deduplicated = df.withColumn("rn", row_number().over(window))\
.filter(col("rn") == 1)\
.drop("rn")
# Method 2: orderBy + dropDuplicates (keeps whichever row it encounters first)
# Caveat: Spark does not guarantee the sort order survives the shuffle inside
# dropDuplicates, so prefer the window-function approach when correctness matters
deduplicated_v2 = df.orderBy(col("updated_at").desc())\
    .dropDuplicates(["customer_id"])
Handling Multiple Dedup Keys
Interview Question: "Find duplicate transactions based on customer, amount, and date within same hour."
from pyspark.sql.functions import date_trunc, count
transactions = [
("C001", 100.0, "2024-01-01 10:15:00"),
("C001", 100.0, "2024-01-01 10:45:00"), # Potential dup
("C001", 100.0, "2024-01-01 11:15:00"), # Different hour
("C002", 200.0, "2024-01-01 10:00:00"),
]
df = spark.createDataFrame(transactions, ["customer_id", "amount", "timestamp"])
# Add hour bucket
df_with_hour = df.withColumn(
"hour_bucket", date_trunc("hour", col("timestamp"))
)
# Find duplicates
dup_counts = df_with_hour.groupBy("customer_id", "amount", "hour_bucket")\
.agg(count("*").alias("count"))\
.filter(col("count") > 1)
# Flag duplicates in original
df_with_flag = df_with_hour.join(
dup_counts.select("customer_id", "amount", "hour_bucket", lit(True).alias("is_duplicate")),
["customer_id", "amount", "hour_bucket"],
"left"
).fillna(False, ["is_duplicate"])
Pattern 5: Efficient Joins
Broadcast Join for Small Tables
Interview Question: "Optimize a join between a large fact table and small dimension table."
from pyspark.sql.functions import broadcast
# Large fact table (millions of rows)
sales = spark.read.parquet("s3://bucket/sales/") # 100M rows
# Small dimension table (thousands of rows)
products = spark.read.parquet("s3://bucket/products/") # 10K rows
# Broadcast join - sends small table to all executors
result = sales.join(
broadcast(products),
sales.product_id == products.id
).select(
sales["*"],
products.name,
products.category
)
# Spark automatically broadcasts tables under spark.sql.autoBroadcastJoinThreshold (10MB default)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50MB")
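To confirm the broadcast actually happened, inspect the physical plan: a BroadcastHashJoin means the small table was shipped to the executors, while a SortMergeJoin means the hint or threshold did not apply.
result.explain()  # look for BroadcastHashJoin in the physical plan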
Handling Skewed Joins
Interview Question: "How would you handle a join where one key appears in 50% of the data?"
from pyspark.sql.functions import concat, lit, rand, explode, array
# Problem: customer_id = "unknown" appears in 50% of rows
# Solution: Salting technique
# Salt the skewed key
num_salts = 10
# Salt the large table
large_df = large_df.withColumn(
"salted_key",
when(col("customer_id") == "unknown",
concat(col("customer_id"), lit("_"), (rand() * num_salts).cast("int")))
.otherwise(col("customer_id"))
)
# Explode the small table for skewed keys: one row per salt value for the
# skewed key, a single row otherwise. explode() must be the top-level
# expression -- it cannot be nested inside when()/otherwise()
salt_array = array([lit(i) for i in range(num_salts)])
small_df_exploded = small_df.withColumn(
    "salt",
    explode(when(col("customer_id") == "unknown", salt_array)
            .otherwise(array(lit(-1))))
).withColumn(
    "salted_key",
    when(col("customer_id") == "unknown",
         concat(col("customer_id"), lit("_"), col("salt")))
    .otherwise(col("customer_id"))
).drop("salt")
# Now join on salted key
result = large_df.join(small_df_exploded, "salted_key")
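On Spark 3.x it is also worth mentioning adaptive query execution, which can split skewed partitions at runtime and often removes the need for manual salting. A sketch of the relevant settings (the threshold values shown are illustrative, not recommendations):
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# A partition is treated as skewed when it exceeds both this factor times the
# median partition size and the byte threshold below (values are illustrative)
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")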
Pattern 6: Data Quality Checks
Interview Question: "Implement data quality validations in PySpark."
from pyspark.sql.functions import (
    sum as spark_sum, when, col,
    min as spark_min, max as spark_max, avg as spark_avg
)

def run_data_quality_checks(df, table_name):
    """Run comprehensive data quality checks and return a summary dict."""
    # Row count
    row_count = df.count()
    print(f"{table_name}: {row_count} rows")

    # Null counts per column
    # (isnan() would also catch NaN values, but it only works on float/double columns)
    null_counts = df.select([
        spark_sum(when(col(c).isNull(), 1).otherwise(0)).alias(c)
        for c in df.columns
    ])
    print("Null counts:")
    null_counts.show()

    # Duplicate check on the key column
    distinct = df.select("id").distinct().count()
    duplicates = row_count - distinct
    print(f"Duplicate IDs: {duplicates}")

    # Value range checks
    numeric_stats = df.select(
        spark_min("amount").alias("min_amount"),
        spark_max("amount").alias("max_amount"),
        spark_avg("amount").alias("avg_amount")
    )
    numeric_stats.show()

    # Check for negative values (if unexpected)
    negative_count = df.filter(col("amount") < 0).count()
    if negative_count > 0:
        print(f"WARNING: {negative_count} negative amounts found")

    return {
        "row_count": row_count,
        "duplicate_count": duplicates,
        "negative_amounts": negative_count
    }
# Usage
quality_report = run_data_quality_checks(sales_df, "sales")
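In a pipeline these checks are usually enforced rather than just printed. A minimal sketch that fails the job when a check does not pass (which checks are treated as fatal is up to you):
if quality_report["duplicate_count"] > 0 or quality_report["negative_amounts"] > 0:
    raise ValueError(f"Data quality checks failed for sales: {quality_report}")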
Pattern 7: Sessionization
Interview Question: "Group user events into sessions with 30-minute inactivity timeout."
from pyspark.sql.functions import (
    lag, unix_timestamp, sum as spark_sum, col,
    when, count, min, max, to_timestamp
)
from pyspark.sql.window import Window
events = [
("user1", "2024-01-01 10:00:00"),
("user1", "2024-01-01 10:05:00"),
("user1", "2024-01-01 10:20:00"),
("user1", "2024-01-01 11:00:00"), # New session (>30 min gap)
("user1", "2024-01-01 11:10:00"),
("user2", "2024-01-01 10:00:00"),
]
df = spark.createDataFrame(events, ["user_id", "event_time"])
# Convert to timestamp
df = df.withColumn("event_ts", to_timestamp("event_time"))
# Window for lag
window = Window.partitionBy("user_id").orderBy("event_ts")
# Calculate time since previous event
df_with_gap = df.withColumn(
"prev_event_ts", lag("event_ts", 1).over(window)
).withColumn(
"gap_seconds",
unix_timestamp("event_ts") - unix_timestamp("prev_event_ts")
)
# Mark session boundaries (gap > 30 minutes or first event)
session_timeout = 30 * 60 # 30 minutes in seconds
df_with_boundary = df_with_gap.withColumn(
"new_session",
when(
col("prev_event_ts").isNull() | (col("gap_seconds") > session_timeout),
1
).otherwise(0)
)
# Create session IDs using cumulative sum
df_with_sessions = df_with_boundary.withColumn(
"session_id",
spark_sum("new_session").over(window)
)
# Calculate session metrics
session_metrics = df_with_sessions.groupBy("user_id", "session_id").agg(
count("*").alias("event_count"),
min("event_ts").alias("session_start"),
max("event_ts").alias("session_end")
)
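A common follow-up is session duration, which drops straight out of the metrics above:
session_durations = session_metrics.withColumn(
    "duration_seconds",
    unix_timestamp("session_end") - unix_timestamp("session_start")
)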
Interview Tips for PySpark Questions
- Always start with imports: Show you know the common functions
- Explain your approach: Talk through the solution before coding
- Consider performance: Mention broadcast joins, partitioning, avoiding shuffles
- Handle edge cases: Nulls, duplicates, empty DataFrames
- Know DataFrame vs SQL: Be comfortable with both APIs, for example:
# DataFrame API
df.filter(col("status") == "active")\
.groupBy("region")\
.agg(sum("revenue").alias("total"))
# SQL API (equivalent)
df.createOrReplaceTempView("sales")
spark.sql("""
SELECT region, SUM(revenue) as total
FROM sales
WHERE status = 'active'
GROUP BY region
""")
:::