Here’s a clear and structured explanation of salting, repartitioning, and broadcast joins in Spark — including how they work and when to use them — with simple examples.
🔹 1. Salting — Avoiding Data Skew in Joins
🔍 Problem:
When joining two large datasets, if many records share the same key, Spark sends all of them to a single partition, causing data skew: one partition holds far more data than the others, and that one slow task stalls the whole stage.
✅ Solution: Add a random suffix (salt) to the join key to spread the skewed key over multiple partitions.
🧪 Example:
# Skewed Table A (many rows with key = 100)
dfA = spark.createDataFrame([(100, "value1")] * 1000 + [(200, "value2")], ["id", "val"])
# Small table B
dfB = spark.createDataFrame([(100, "ref1"), (200, "ref2")], ["id", "ref"])
# Salt small table B: replicate each row once per salt value (0 to 4)
from pyspark.sql.functions import explode, array, lit, concat_ws, floor, rand
dfB_salted = dfB.withColumn("salt", explode(array([lit(i) for i in range(5)])))
dfB_salted = dfB_salted.withColumn("id_salted", concat_ws("_", "id", "salt"))
# Salt skewed table A with a random salt per row (rand() avoids a slow Python UDF)
dfA_salted = dfA.withColumn("salt", floor(rand() * 5).cast("int"))
dfA_salted = dfA_salted.withColumn("id_salted", concat_ws("_", "id", "salt"))
# Perform salted join
result = dfA_salted.join(dfB_salted, on="id_salted", how="inner")
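The mechanics are easy to see without Spark at all. The sketch below is a plain-Python simulation, not Spark code: the partitioner is a simplified stand-in for Spark's hash partitioner, and the row counts mirror the skewed table above.

```python
# Conceptual sketch (plain Python): how salting spreads a hot key.
import random

NUM_PARTITIONS = 8
SALT_BUCKETS = 5
rows = [100] * 1000 + [200]              # skewed: key 100 dominates

def partition_of(key, salt=0):
    # simplified stand-in for Spark's hash partitioner
    return (key * 31 + salt) % NUM_PARTITIONS

# Without salting: all 1000 key-100 rows land in one partition
unsalted = {}
for key in rows:
    p = partition_of(key)
    unsalted[p] = unsalted.get(p, 0) + 1

# With salting: (key, salt) spreads the hot key over SALT_BUCKETS partitions
random.seed(0)
salted = {}
for key in rows:
    p = partition_of(key, random.randint(0, SALT_BUCKETS - 1))
    salted[p] = salted.get(p, 0) + 1

print(max(unsalted.values()))  # 1000 rows in the worst partition
print(max(salted.values()))    # roughly 1000 / SALT_BUCKETS in the worst partition
```

The cost of salting is that the small side must be replicated once per salt value, which is why the salt range is kept small (5 here).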
🔹 2. Repartition — Balance Data Across Partitions
🔍 Problem:
Sometimes data is not evenly distributed, causing some executors to do more work. This is especially true after filtering or reading skewed files.
✅ Solution: Use .repartition(n) or .repartition("column") to balance data across partitions.
🧪 Example:
# Repartition based on a column to ensure same-key records go together
df = df.repartition("customer_id")
# Or to set a specific number of partitions
df = df.repartition(20)
🔧 When to Use:
- Before a wide transformation (like join or groupBy) on skewed data
- After heavy filtering that removed a lot of rows
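The guarantee behind `.repartition("column")` can be sketched in plain Python (this is a conceptual stand-in for Spark's hash partitioner, with made-up keys, not Spark code): rows are routed by a hash of the column, so every row sharing a key lands in the same partition, and a later groupBy on that key needs no extra shuffle.

```python
# Conceptual sketch (plain Python): hash-partitioning by a column.
NUM_PARTITIONS = 4
rows = [("alice", 10), ("bob", 20), ("alice", 30), ("carol", 40), ("bob", 50)]

def partition_for(key):
    # stand-in for Spark's hash partitioner on the repartition column
    return hash(key) % NUM_PARTITIONS

partitions = {}
for key, amount in rows:
    partitions.setdefault(partition_for(key), []).append((key, amount))

# Every occurrence of a key sits in exactly one partition
for key in ("alice", "bob", "carol"):
    homes = {p for p, rs in partitions.items() if any(k == key for k, _ in rs)}
    print(key, "->", homes)
```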
🔹 3. Broadcast Join — Avoid Shuffling Big Data
🔍 Problem:
In a large join, Spark shuffles data from both sides to bring matching keys together — this is expensive.
✅ Solution: Use Broadcast Join if one side is small enough to fit in memory. Spark sends the small dataset to all executors instead of shuffling the big one.
🧪 Example:
from pyspark.sql.functions import broadcast
# df_small is the small dimension table
result = df_big.join(broadcast(df_small), on="id")
🔧 When to Use:
- When one table is small enough to fit in executor memory (Spark's automatic broadcast threshold, `spark.sql.autoBroadcastJoinThreshold`, defaults to 10 MB)
- Dimension tables, lookup tables, etc.
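What a broadcast hash join does under the hood can be sketched in plain Python (conceptual illustration with made-up rows, not Spark code): the small side is built into an in-memory lookup table, which is what Spark ships to every executor, and the big side is then streamed against it with no shuffle.

```python
# Conceptual sketch (plain Python): a broadcast hash join.
small = [(100, "ref1"), (200, "ref2")]                 # dimension table
big = [(100, "v1"), (100, "v2"), (200, "v3"), (300, "v4")]

lookup = dict(small)                 # built once, replicated to every executor
joined = [(k, v, lookup[k]) for k, v in big if k in lookup]
print(joined)  # [(100, 'v1', 'ref1'), (100, 'v2', 'ref1'), (200, 'v3', 'ref2')]
```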
🧠 Summary Table
| Strategy | Goal | Use When | Example |
|---|---|---|---|
| Salting | Fix skew in joins | One key has too many rows | Add random suffix to join key |
| Repartitioning | Balance workload | Data uneven across partitions | .repartition("col") |
| Broadcast Join | Avoid shuffle for small table | One side of the join is small enough | broadcast(df_small) |
How .collect() Works in Spark
In Apache Spark (including Databricks), .collect() is an action that runs the job and retrieves the entire dataset from the distributed cluster to the driver node as a local object (a list of Row objects in Python, an Array in Scala).
🔍 How .collect() Works in Spark
When you call:
df.collect()
- Spark executes the full job, pulls all rows from all partitions, and returns them to the driver node.
- It materializes the entire DataFrame/RDD in the driver’s memory.
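The memory difference between .collect() and .take(n) can be sketched in plain Python (real Spark runs take() as incremental partial jobs, but the driver-memory story is the same):

```python
# Conceptual sketch (plain Python, not Spark code): collect() concatenates
# every partition onto the driver, while take(n) can stop early.
partitions = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]   # rows spread across executors

def collect_all(parts):
    out = []
    for p in parts:              # every partition is shipped to the driver
        out.extend(p)
    return out

def take_n(parts, n):
    out = []
    for p in parts:              # stops as soon as n rows are gathered
        for row in p:
            out.append(row)
            if len(out) == n:
                return out
    return out

print(collect_all(partitions))   # all 9 rows held in driver memory
print(take_n(partitions, 4))     # [1, 2, 3, 4]
```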
⚠️ When to Use .collect() — and When Not To
| Use Case | ✅ Safe To Use When… | ❌ Avoid If… |
|---|---|---|
| View sample rows in small datasets | Dataset has a few hundred or thousand rows | Data is large (millions of rows or several GBs) |
| Debug or test locally | You’re testing on sample data in dev or notebook | You’re in production or working with full dataset |
| Logging / assertions / test cases | You’re collecting a small list to assert or print | It may OOM (Out of Memory) the driver |
🧠 Example:
# BAD (if dataset is large)
all_data = df.collect()
# BETTER: use show() or take()
df.show(10) # Displays 10 rows in the notebook
df.take(100) # Returns first 100 rows only
🛠️ Alternatives to .collect()
| Method | What it Does | Safer? |
|---|---|---|
| `.show(n)` | Prints n rows in a tabular format | ✅ Yes |
| `.take(n)` | Returns n rows as a local list | ✅ Yes |
| `.limit(n)` | Limits the result before an action (e.g., `.limit(n).collect()`) | ⚠️ Use with caution |
| `.toPandas()` | Converts the entire DataFrame to pandas on the driver | ❌ Dangerous on big data |
| `.write` | Writes to storage (e.g., `.write.csv(...)`) | ✅ Production-safe |
🔥 Key Caution
❌ Never use .collect() on big data in production pipelines: it can crash the driver and block the job.