How to Fix Data Skew in Spark

This is one of the most commonly asked questions in data engineering interviews, especially when you're expected to be comfortable with Spark in real-world production environments.

Interview Question

"In one of your Spark jobs, you notice that a few tasks are taking significantly longer than others, and the job is running slower than expected. How would you investigate and handle this skew?"

This question isn’t just theoretical - it tests your hands-on understanding of Spark internals, your debugging approach, and your ability to design efficient data pipelines.

Let’s dive deep into the answer that an experienced data engineer should be able to explain with clarity and structure.

Understanding the Problem: What Is Data Skew?

When you run a distributed Spark job, your data is split into partitions, and each partition is processed by a task. Ideally, all tasks should take roughly equal time to finish. But sometimes you’ll notice that while most tasks finish quickly, one or two keep running for a long time, delaying the whole job.

This is usually caused by data skew: the data is not evenly distributed across partitions, so a few tasks end up with far more records to process than the rest.

A Simple Example

Suppose you're processing a log file and grouping records by user_id. If 80% of the data belongs to just 2 users, then the task responsible for processing those 2 users will get overloaded while others finish early. This is exactly what data skew looks like.
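
A toy sketch of that shape (the user IDs and row counts below are made up purely for illustration):

from pyspark.sql import functions as F

# 8 million rows belong to two 'hot' users, 2 million are spread across everyone else
hot = spark.range(8_000_000).withColumn("user_id", (F.col("id") % 2).cast("string"))
rest = spark.range(2_000_000).withColumn("user_id", (F.col("id") % 100_000 + 2).cast("string"))
logs = hot.unionByName(rest)

logs.groupBy("user_id").count().orderBy("count", ascending=False).show(5)
# The two hot users dominate the counts, and the task that processes them becomes the straggler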

Step 1: How to Detect Data Skew

Before solving the issue, you first need to detect it properly. Here's how to do it in a real-world environment.

A. Use Spark UI

Go to your Spark Application's Web UI (usually accessible from your job orchestration platform like Databricks, EMR, or standalone Spark cluster).

  • Click on the "Stages" tab.
  • Look for stages where most tasks complete quickly, but a few are taking much longer.
  • Check the task execution time, shuffle read size, and GC (Garbage Collection) time. A large difference is your first hint.

For example, if 198 out of 200 tasks finish in 10 seconds, but 2 take 2 minutes, there is skew.

B. Profile the Data

Use Spark to manually check if some keys have much more data than others.

df.groupBy("join_key").count().orderBy("count", ascending=False).show()

This will show you if a small number of keys are responsible for a large chunk of the data.
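
To put a rough number on it, you can compare the heaviest key against a typical key. This is only a sketch; the "skew ratio" and its interpretation are an informal rule of thumb, not a Spark feature:

from pyspark.sql import functions as F

key_counts = df.groupBy("join_key").count()

# Compare the largest key to the median key size
stats = key_counts.agg(
    F.max("count").alias("max_count"),
    F.expr("percentile_approx(`count`, 0.5)").alias("median_count"),
).collect()[0]

skew_ratio = stats["max_count"] / max(stats["median_count"], 1)
print(f"max/median key ratio: {skew_ratio:.1f}")
# Ratios in the tens or hundreds usually mean a handful of keys dominate the data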

Step 2: Why Does This Happen?

Understanding why skew happens is important. It usually occurs in operations that shuffle data across the cluster, such as:

  • groupByKey()
  • reduceByKey()
  • join()
  • distinct()
  • aggregateByKey()

These operations rely on hash partitioning or range partitioning based on a key. If a key appears too often (i.e., is "hot"), all of its records are assigned to a single task. This creates an uneven workload.
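
You can see this effect directly by hash-partitioning on the key and counting rows per partition. A small sketch (the column name and partition count are illustrative):

from pyspark.sql import functions as F

# Hash-partition on the key, then count how many rows land in each partition
partitioned = df.repartition(8, "key")
partitioned.groupBy(F.spark_partition_id().alias("partition")) \
    .count() \
    .orderBy("count", ascending=False) \
    .show()
# A hot key shows up as one partition with far more rows than all the others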

Step 3: How to Fix or Handle Skew

Once you've confirmed that skew is happening, here are several ways to handle it, based on your situation.

1. Salting the Skewed Keys

Salting means adding a random suffix or prefix to your keys to distribute records more evenly across partitions.

When to Use

When you’re doing groupBy or join on a highly skewed key like "India" or "Unknown" that has millions of records.

How to Do It

Step 1: Add a random salt to your key column.

from pyspark.sql.functions import col, concat_ws, rand

df = df.withColumn("salt", (rand() * 10).cast("int"))

df = df.withColumn("new_key", concat_ws("_", col("key"), col("salt")))

If we are performing a join, then we also need to explode the other DataFrame across all possible salt values:

from pyspark.sql.functions import explode, array, lit

# Create a column with an array of all possible salt values
df2 = df2.withColumn("salt", explode(array([lit(i) for i in range(10)])))

# Construct the new join key
df2 = df2.withColumn("new_key", concat_ws("_", col("key"), col("salt")))

Step 2: Do your groupBy or join using this new_key.

Step 3: After processing, remove the salt if needed and recombine.

This approach splits a heavy key into multiple "virtual" keys like India_0, India_1, etc., distributing the load.
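
Putting the three steps together for a join, a minimal sketch (df is the skewed side and df2 the replicated side from above; other_col stands in for whatever columns you actually need from df2):

# Join on the salted key instead of the original key
joined = df.join(
    df2.select("new_key", "other_col"),  # other_col is a placeholder column name
    on="new_key",
    how="inner",
)

# Drop the helper columns once the join is done
result = joined.drop("salt", "new_key")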

2. Use Broadcast Join

If you're doing a join and one of the tables is small enough to fit comfortably in each executor's memory (as a rule of thumb, tens of MB up to a few hundred MB; Spark only broadcasts automatically when a table is below spark.sql.autoBroadcastJoinThreshold, which defaults to 10 MB), broadcast it.

from pyspark.sql.functions import broadcast

result = large_df.join(broadcast(small_df), on="id")

This avoids shuffling the larger table and sends the smaller one to each executor.

When to Use

  • When one side of the join is small and the skew is caused by the large table joining with a few hot keys.
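
Instead of the explicit hint, you can also raise the automatic broadcast threshold and let the optimizer pick the broadcast join itself. A sketch (the 100 MB value is arbitrary; size it to your executors' memory):

# Broadcast any table whose estimated size is under ~100 MB
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(100 * 1024 * 1024))

# No explicit hint needed; Spark chooses a broadcast hash join
# as long as small_df stays below the threshold
result = large_df.join(small_df, on="id")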

3. Enable Adaptive Query Execution (Spark 3.x +)

If you're using Spark version 3.0 or later, Adaptive Query Execution (AQE) helps Spark automatically detect and handle skewed partitions at runtime. (From Spark 3.2 onward, AQE is enabled by default.)

You just need to enable it in your Spark configuration:

spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

Spark will then:

  • Identify skewed partitions at runtime
  • Split them into smaller tasks
  • Optimize the join strategy on the fly

When to Use

  • Always, if you are using Spark 3.0+
  • Especially helpful if you don’t want to manually write salting logic
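
You can also tune what AQE considers "skewed". The two settings below exist in Spark 3.x; the values shown are roughly the defaults, so adjust them for your workload:

# A partition is treated as skewed if it is both larger than this factor
# times the median partition size...
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")

# ...and larger than this absolute size
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")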

4. Process Skewed Keys Separately

If only one or two keys are causing the skew (e.g., "India"), isolate those records and process them separately, then union the results back together (a fuller join sketch follows the list below).

skewed = df.filter("key = 'India'")
non_skewed = df.filter("key != 'India'")

# Process them individually, then union the results

When to Use

  • When only a small number of hot keys are causing problems
  • When you can handle those keys differently
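
As an illustration, here is a sketch of the join case (dim_df is a hypothetical dimension table, and the hot key is hard-coded only for clarity). The skewed slice of the dimension table is usually small enough to broadcast, while the rest takes the normal shuffle path:

from pyspark.sql.functions import broadcast

skewed = df.filter("key = 'India'")
non_skewed = df.filter("key != 'India'")

# The hot key's rows join against a broadcast slice of the dimension table
skewed_result = skewed.join(broadcast(dim_df.filter("key = 'India'")), on="key")

# Everything else goes through a regular shuffle join
non_skewed_result = non_skewed.join(dim_df, on="key")

# Recombine the two halves
result = skewed_result.unionByName(non_skewed_result)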

5. Repartitioning Based on Key

Repartition your data explicitly before shuffle-heavy operations like join or groupBy.

df = df.repartition(100, "key")

But use this carefully. Increasing the partition count too much creates many tiny tasks and hurts performance rather than helping. Also note that hash-repartitioning on the skewed key alone still sends every row of a hot key to the same partition, so for a truly hot key this is usually combined with a salt column (see the sketch below).
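
A minimal sketch of that combination (the column names mirror the salting example above; any downstream aggregation still needs to work on the salted key first):

from pyspark.sql.functions import col, rand

# Spread each hot key across ~10 partitions; aggregate on (key, salt) first, then on key
df = df.withColumn("salt", (rand() * 10).cast("int"))
df = df.repartition(100, col("key"), col("salt"))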

Real-World Example

Let’s say you are aggregating e-commerce data and grouping by country.

sales_df.groupBy("country").agg({"amount": "sum"})

You notice that one task takes forever, and Spark UI shows that most records are for "India".

Here’s how you’d handle it:

  1. Add a salt column:

from pyspark.sql.functions import rand, col, concat

salted = sales_df.withColumn("salt", (rand() * 10).cast("int"))
salted = salted.withColumn("country_salt", concat(col("country"), col("salt")))

  2. Do the groupBy on country_salt.

  3. After aggregation, remove the salt:

from pyspark.sql.functions import split, sum

result = salted.groupBy("country_salt").agg(sum("amount").alias("partial_total"))

final = result.withColumn("country", split("country_salt", "\\d+$")[0]) \
    .groupBy("country") \
    .agg(sum("partial_total").alias("total_amount"))

This balances the load and solves the skew problem.

Final Thoughts for Candidates

This is not just an interview question - it’s a real-world challenge you’ll face in data pipelines that process millions or billions of rows. Showing that you understand how to detect and solve skew tells interviewers that you’re production-ready.

  • Use Spark UI to diagnose
  • Use data profiling to find skewed keys
  • Apply techniques like salting, broadcast joins, or AQE
  • Bonus: Talk about a real-world example if you have one