Mohammad Gufran Jahangir, August 7, 2025

Here’s a clear and structured explanation of salting, repartitioning, and broadcast joins in Spark — including how they work and when to use them — with simple examples.


🔹 1. Salting — Avoiding Data Skew in Joins

🔍 Problem:

When joining two large datasets, if many records share the same key, Spark sends all of them to the same partition, causing data skew (one partition holds far more data than the others, so it finishes last and slows the whole job).

✅ Solution: Add a random suffix (salt) to the join key to spread the skewed key over multiple partitions.

🧪 Example:

# Skewed table A (many rows with key = 100)
dfA = spark.createDataFrame([(100, "value1")] * 1000 + [(200, "value2")], ["id", "val"])

# Small table B
dfB = spark.createDataFrame([(100, "ref1"), (200, "ref2")], ["id", "ref"])

from pyspark.sql.functions import explode, array, lit, concat_ws, rand, floor

# Replicate B once per salt value (0 to 4) so every salted A key finds a match
dfB_salted = dfB.withColumn("salt", explode(array([lit(i) for i in range(5)])))
dfB_salted = dfB_salted.withColumn("id_salted", concat_ws("_", "id", "salt"))

# Salt A with a random value in [0, 4]; the built-in rand() avoids the
# overhead of a Python UDF, and the string key "id_salt" avoids the
# collisions an arithmetic scheme like id * 10 + salt can produce
dfA_salted = dfA.withColumn("salt", floor(rand() * 5).cast("int"))
dfA_salted = dfA_salted.withColumn("id_salted", concat_ws("_", "id", "salt"))

# Perform the salted join
result = dfA_salted.join(dfB_salted, on="id_salted", how="inner")
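To build intuition for why the salt helps, here is a small plain-Python sketch (no Spark required; the partition count and salt range are illustrative, not anything Spark mandates). It shows a single hot key fanning out across buckets once a salt is appended:

```python
import random
from collections import Counter

random.seed(0)  # deterministic for the example

NUM_PARTITIONS = 5
HOT_KEY = 100

# 1000 rows that all share the hot key (mirrors dfA above)
keys = [HOT_KEY] * 1000

# Without salting, every row maps to the same partition
unsalted = Counter(k % NUM_PARTITIONS for k in keys)
# -> Counter({0: 1000}): one partition does all the work

# With a random salt in [0, 4], the hot key fans out over 5 buckets
salted_keys = [HOT_KEY * 10 + random.randint(0, 4) for _ in keys]
salted = Counter(sk % NUM_PARTITIONS for sk in salted_keys)
# -> 5 buckets with roughly 200 rows each

print(unsalted)
print(salted)
```

The trade-off is the extra work of replicating the small side once per salt value, which is why salting is reserved for genuinely skewed keys.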

🔹 2. Repartition — Balance Data Across Partitions

🔍 Problem:

Sometimes data is not evenly distributed, causing some executors to do more work than others. This is especially common after heavy filtering or after reading skewed files.

✅ Solution: Use .repartition(n) or .repartition("column") to redistribute data evenly across partitions (note that repartition triggers a full shuffle).

🧪 Example:

# Repartition based on a column to ensure same-key records go together
df = df.repartition("customer_id")

# Or to set a specific number of partitions
df = df.repartition(20)

🔧 When to Use:

  • Before a wide transformation (like join or groupBy) on skewed data
  • After heavy filtering that removed a lot of rows
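Conceptually, .repartition("column") assigns each row to a partition by hashing the column value, so rows with equal keys always land together. A minimal plain-Python sketch of that assignment rule (a simplified stand-in for Spark's hash partitioner, with made-up data):

```python
from collections import defaultdict

def hash_partition(rows, key, num_partitions):
    """Group rows into partitions by hash(key) % num_partitions,
    mimicking what repartition("col") does during its shuffle."""
    parts = defaultdict(list)
    for row in rows:
        parts[hash(row[key]) % num_partitions].append(row)
    return dict(parts)

# 100 rows spread over 10 customer ids
rows = [{"customer_id": cid % 10, "amount": cid} for cid in range(100)]
parts = hash_partition(rows, "customer_id", 4)

# Every row with the same customer_id ends up in the same partition
for p in parts.values():
    buckets = {hash(r["customer_id"]) % 4 for r in p}
    assert len(buckets) == 1

print({k: len(v) for k, v in parts.items()})
# {0: 30, 1: 30, 2: 20, 3: 20}
```

Note that the partitions here are not perfectly even (30/30/20/20): hash partitioning only balances well when keys are numerous and not skewed, which is exactly why salting exists for the hot-key case.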

🔹 3. Broadcast Join — Avoid Shuffling Big Data

🔍 Problem:

In a large join, Spark shuffles data from both sides to bring matching keys together — this is expensive.

✅ Solution: Use a broadcast join if one side is small enough to fit in memory. Spark sends the small dataset to all executors instead of shuffling the big one.

🧪 Example:

from pyspark.sql.functions import broadcast

# df_small is the small dimension table
result = df_big.join(broadcast(df_small), on="id")

🔧 When to Use:

  • When one table is small enough to fit in executor memory (Spark auto-broadcasts tables under 10 MB by default, tunable via spark.sql.autoBroadcastJoinThreshold)
  • Dimension tables, lookup tables, etc.
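Under the hood a broadcast join is a local hash join: each executor receives a full copy of the small table as an in-memory lookup structure and probes it row by row, so the big table never moves. A plain-Python sketch of that idea (the tables here are tiny illustrative stand-ins):

```python
# Small dimension table, shipped whole to every executor (here: a dict)
small = {100: "ref1", 200: "ref2"}

# One partition of the large fact table, processed where it already lives
big_partition = [(100, "value1"), (100, "value2"), (200, "value3"), (300, "value4")]

# Inner join by probing the dict -- no shuffle of the big side
joined = [(k, v, small[k]) for k, v in big_partition if k in small]
print(joined)
# [(100, 'value1', 'ref1'), (100, 'value2', 'ref1'), (200, 'value3', 'ref2')]
```

This is why only the small side has a size limit: every executor must hold the whole dict, while the big side is streamed partition by partition.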

🧠 Summary Table

| Strategy | Goal | Use When | Example |
| --- | --- | --- | --- |
| Salting | Fix skew in joins | One key has too many rows | Add random suffix to join key |
| Repartitioning | Balance workload | Data uneven across partitions | .repartition("col") |
| Broadcast Join | Avoid shuffle for small table | One side of the join is small enough | broadcast(df_small) |

How .collect() Works in Spark

In Apache Spark (including Databricks), .collect() is an action that retrieves the entire dataset from the distributed Spark cluster to the driver node as a local object (a list of Row objects in Python, an Array in Scala).


πŸ” How .collect() Works in Spark

When you call:

df.collect()
  • Spark executes the full job, pulls all rows from all partitions, and returns them to the driver node.
  • It materializes the entire DataFrame/RDD in the driver’s memory.

⚠️ When to Use .collect() — and When Not To

| Use Case | ✅ Safe To Use When… | ❌ Avoid If… |
| --- | --- | --- |
| View sample rows in small datasets | Dataset has a few hundred or thousand rows | Data is large (millions of rows or several GBs) |
| Debug or test locally | You’re testing on sample data in dev or a notebook | You’re in production or working with the full dataset |
| Logging / assertions / test cases | You’re collecting a small list to assert or print | It may OOM (Out of Memory) the driver |

🧠 Example:

# BAD (if dataset is large)
all_data = df.collect()

# BETTER: use show() or take()
df.show(10)  # Displays 10 rows in the notebook
df.take(100)  # Returns first 100 rows only
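The reason .take(n) is safer is that it short-circuits: it stops fetching as soon as it has n rows, while .collect() materializes everything on the driver. A plain-Python analogy using a lazy stream (sizes are illustrative):

```python
from itertools import islice

def rows():
    """A lazy stream standing in for a large distributed dataset."""
    for i in range(10_000_000):
        yield {"id": i}

# take(100)-style: pull only what is needed onto the "driver"
first_100 = list(islice(rows(), 100))
assert len(first_100) == 100

# collect()-style would be list(rows()):
# all 10M rows materialized in driver memory at once
```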

πŸ› οΈ Alternatives to .collect()

| Method | What It Does | Safer? |
| --- | --- | --- |
| .show(n) | Prints n rows in a tabular format | ✅ Yes |
| .take(n) | Returns n rows as a local list | ✅ Yes |
| .limit(n) | Limits the result before an action (e.g., .limit(n).collect()) | ⚠️ Use with caution |
| .toPandas() | Converts the entire DataFrame to pandas (on the driver) | ❌ Dangerous on big data |
| .write | Writes to storage (e.g., .write.csv(...)) | ✅ Production-safe |

🔥 Key Caution

❌ Never use .collect() on big data in production pipelines — it can crash the driver and block the job.

