Here's a clear, structured explanation of salting, repartitioning, and broadcast joins in Spark, covering how they work and when to use them, with simple examples.
🔹 1. Salting: Avoiding Data Skew in Joins
📌 Problem:
When joining two large datasets, if many records share the same key, Spark sends them all to one partition, causing data skew (one partition holds far more data than the others and becomes the bottleneck).
✅ Solution: Add a random suffix (salt) to the join key to spread the skewed key across multiple partitions.
🧪 Example:
# Skewed table A (many rows with key = 100)
dfA = spark.createDataFrame([(100, "value1")] * 1000 + [(200, "value2")], ["id", "val"])
# Small table B
dfB = spark.createDataFrame([(100, "ref1"), (200, "ref2")], ["id", "ref"])

from pyspark.sql.functions import array, concat_ws, explode, floor, lit, rand

NUM_SALTS = 5

# Replicate B once per salt value (0 to 4) so every salted A key has a match
dfB_salted = dfB.withColumn("salt", explode(array([lit(i) for i in range(NUM_SALTS)])))
dfB_salted = dfB_salted.withColumn("id_salted", concat_ws("_", "id", "salt"))

# Salt A with a random value in [0, NUM_SALTS) using the built-in rand()
# (no Python UDF needed; string concatenation avoids salted-key collisions)
dfA_salted = dfA.withColumn("salt", floor(rand() * NUM_SALTS).cast("int"))
dfA_salted = dfA_salted.withColumn("id_salted", concat_ws("_", "id", "salt"))

# Perform the salted join; the hot key 100 is now spread across 5 sub-keys
result = dfA_salted.join(dfB_salted, on="id_salted", how="inner")
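To see why salting works, here is a minimal plain-Python sketch (no Spark needed) simulating how records are routed to partitions. It uses `key % n` as a stand-in for Spark's hash partitioner so the routing is easy to follow, and mirrors the `id * 10 + salt` style of salted key:

```python
import random

NUM_PARTITIONS = 5

def partition_counts(keys, n=NUM_PARTITIONS):
    """Count how many records land on each partition under key % n routing."""
    counts = [0] * n
    for k in keys:
        counts[k % n] += 1
    return counts

rows = [100] * 1000 + [200]  # key 100 is heavily skewed

# Without salting: all 1000 copies of key 100 land on one partition
plain = partition_counts(rows)

# With salting: rewrite each key as key*10 + random salt in [0, 5),
# so the hot key spreads across five sub-keys (and five partitions)
random.seed(0)
salted = partition_counts([k * 10 + random.randint(0, 4) for k in rows])

print("unsalted partition sizes:", plain)
print("salted partition sizes:  ", salted)
```

The unsalted run puts all 1001 rows on a single partition, while the salted run spreads them roughly evenly, which is exactly the effect the PySpark example above achieves at cluster scale.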
🔹 2. Repartition: Balance Data Across Partitions
📌 Problem:
Sometimes data is not evenly distributed across partitions, so some executors do more work than others. This is especially common after heavy filtering or after reading skewed files.
✅ Solution: Use .repartition(n) or .repartition("column") to redistribute data evenly across partitions.
🧪 Example:
# Repartition by a column so records with the same key land in the same partition
df = df.repartition("customer_id")
# Or set an explicit number of partitions (note: repartition triggers a full shuffle)
df = df.repartition(20)
🧠 When to Use:
- Before a wide transformation (like join or groupBy) on skewed data
- After heavy filtering that removed a lot of rows
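The difference between the two repartition modes can be sketched in plain Python (no Spark required): repartition(n) spreads rows evenly regardless of key (round-robin style), while repartition("column") routes each row by a hash of the column, co-locating rows that share a key:

```python
from itertools import cycle

rows = [{"customer_id": cid} for cid in [1, 1, 1, 2, 3, 3]]
n = 3

# repartition(n): balance row counts, ignoring keys (round-robin)
round_robin = [[] for _ in range(n)]
for part, row in zip(cycle(range(n)), rows):
    round_robin[part].append(row)

# repartition("customer_id"): co-locate rows that share a key
by_key = [[] for _ in range(n)]
for row in rows:
    by_key[hash(row["customer_id"]) % n].append(row)

print("round-robin sizes:", [len(p) for p in round_robin])
print("keyed partitions: ", by_key)
```

The round-robin layout is always balanced, but only the keyed layout guarantees that all rows for one customer_id end up in the same partition, which is why .repartition("column") is the right choice before a keyed wide transformation.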
🔹 3. Broadcast Join: Avoid Shuffling Big Data
📌 Problem:
In a large join, Spark shuffles data from both sides to bring matching keys together, which is expensive.
✅ Solution: Use a broadcast join when one side is small enough to fit in memory. Spark sends the small dataset to every executor instead of shuffling the big one.
🧪 Example:
from pyspark.sql.functions import broadcast
# df_small is the small dimension table
result = df_big.join(broadcast(df_small), on="id")
🧠 When to Use:
- When one table is small (Spark auto-broadcasts tables below spark.sql.autoBroadcastJoinThreshold, 10MB by default)
- Dimension tables, lookup tables, etc.
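What a broadcast join does can be sketched in plain Python (no Spark needed): the small table becomes an in-memory lookup dict that is "shipped" to every worker, so each partition of the big table joins locally with no shuffle:

```python
# Small dimension table and large fact table as plain (key, value) rows
df_small = [(100, "ref1"), (200, "ref2")]
df_big = [(100, "a"), (200, "b"), (100, "c")]

# Build the broadcast hash map once (Spark does this on the driver,
# then sends a copy to every executor)
lookup = {key: ref for key, ref in df_small}

# Each partition of the big side probes the map locally (map-side join)
result = [(key, val, lookup[key]) for key, val in df_big if key in lookup]
print(result)
```

No row of df_big ever moves to another partition; only the tiny lookup map is copied, which is the whole point of broadcasting.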
🧠 Summary Table
| Strategy | Goal | Use When | Example |
|---|---|---|---|
| Salting | Fix skew in joins | One key has too many rows | Add random suffix to join key |
| Repartitioning | Balance workload | Data uneven across partitions | .repartition("col") |
| Broadcast Join | Avoid shuffle for small table | One side of the join is small enough | broadcast(df_small) |
How .collect() Works in Spark
In Apache Spark (including Databricks), .collect() is an action (not a transformation) that retrieves the entire dataset from the distributed cluster to the driver node as a local object (a list in Python, an Array in Scala).
📌 How .collect() Works in Spark
When you call:
df.collect()
- Spark executes the full job, pulls all rows from all partitions, and returns them to the driver node.
- It materializes the entire DataFrame/RDD in the driver’s memory.
⚠️ When to Use .collect() (and When Not To)
| Use Case | ✅ Safe To Use When… | ❌ Avoid If… |
|---|---|---|
| View sample rows in small datasets | Dataset has a few hundred or thousand rows | Data is large (millions of rows or several GBs) |
| Debug or test locally | You're testing on sample data in dev or a notebook | You're in production or working with the full dataset |
| Logging / assertions / test cases | You're collecting a small list to assert or print | It may OOM (Out of Memory) the driver |
🧪 Example:
# BAD (if dataset is large)
all_data = df.collect()
# BETTER: use show() or take()
df.show(10) # Displays 10 rows in the notebook
df.take(100) # Returns first 100 rows only
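The reason take(n) is gentler than collect() can be sketched with a lazy iterator in plain Python: take can stop reading as soon as it has n rows, while collect materializes every row on the driver. (Spark's real take() fetches partition by partition, so this is a simplification.)

```python
from itertools import islice

def scan_rows(total, counter):
    """Yield rows one at a time, tracking how many were actually read."""
    for i in range(total):
        counter[0] += 1
        yield i

# collect()-style read: pulls every row
c1 = [0]
all_rows = list(scan_rows(1_000_000, c1))

# take(100)-style read: stops after the first 100 rows
c2 = [0]
first_100 = list(islice(scan_rows(1_000_000, c2), 100))

print("rows scanned by collect-style read:", c1[0])
print("rows scanned by take-style read:  ", c2[0])
```

The collect-style read touches all one million rows; the take-style read stops at 100, which is why it cannot blow up driver memory the way an unbounded collect can.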
🛠️ Alternatives to .collect()
| Method | What it Does | Safer? |
|---|---|---|
| .show(n) | Prints n rows in a tabular format | ✅ Yes |
| .take(n) | Returns n rows as a local list | ✅ Yes |
| .limit(n) | Limits the result before an action (e.g., .limit(n).collect()) | ⚠️ Use with caution |
| .toPandas() | Converts the entire DataFrame to pandas on the driver | ❌ Dangerous on big data |
| .write | Writes to storage (e.g., .write.csv(...)) | ✅ Production-safe |
🔥 Key Caution
❌ Never use .collect() on big data in production pipelines: it can crash the driver and block the job.