Thursday 23 November 2023

PySpark: CumSum with Salting over Window w/ Skew

How can I use salting to perform a cumulative sum window operation? Although this is only a tiny sample, my id column is heavily skewed, and I effectively need to perform this operation on it:

window_unsalted = Window.partitionBy("id").orderBy("timestamp")  

# expected value
df = df.withColumn("Expected", F.sum('value').over(window_unsalted))

However, I want to try salting because at the scale of my data, I cannot compute it otherwise.

Consider this MWE. How can I replicate the expected value, 20, using salting techniques?

from pyspark.sql import SparkSession  
from pyspark.sql import functions as F  
from pyspark.sql.window import Window  

spark = SparkSession.builder.getOrCreate()  

data = [  
    (7329, 1636617182, 1.0),  
    (7329, 1636142065, 1.0),  
    (7329, 1636142003, 1.0),  
    (7329, 1680400388, 1.0),  
    (7329, 1636142400, 1.0),  
    (7329, 1636397030, 1.0),  
    (7329, 1636142926, 1.0),  
    (7329, 1635970969, 1.0),  
    (7329, 1636122419, 1.0),  
    (7329, 1636142195, 1.0),  
    (7329, 1636142654, 1.0),  
    (7329, 1636142484, 1.0),  
    (7329, 1636119628, 1.0),  
    (7329, 1636404275, 1.0),  
    (7329, 1680827925, 1.0),  
    (7329, 1636413478, 1.0),  
    (7329, 1636143578, 1.0),  
    (7329, 1636413800, 1.0),  
    (7329, 1636124556, 1.0),  
    (7329, 1636143614, 1.0),  
    (7329, 1636617778, -1.0),  
    (7329, 1636142155, -1.0),  
    (7329, 1636142061, -1.0),  
    (7329, 1680400415, -1.0),  
    (7329, 1636142480, -1.0),  
    (7329, 1636400183, -1.0),  
    (7329, 1636143444, -1.0),  
    (7329, 1635977251, -1.0),  
    (7329, 1636122624, -1.0),  
    (7329, 1636142298, -1.0),  
    (7329, 1636142720, -1.0),  
    (7329, 1636142584, -1.0),  
    (7329, 1636122147, -1.0),  
    (7329, 1636413382, -1.0),  
    (7329, 1680827958, -1.0),  
    (7329, 1636413538, -1.0),  
    (7329, 1636143610, -1.0),  
    (7329, 1636414011, -1.0),  
    (7329, 1636141936, -1.0),  
    (7329, 1636146843, -1.0)  
]  
  
df = spark.createDataFrame(data, ["id", "timestamp", "value"])  
  
# Define the number of salt buckets  
num_buckets = 100  
  
# Add a salted_id column to the dataframe  
df = df.withColumn("salted_id", (F.concat(F.col("id"),   
                (F.rand(seed=42)*num_buckets).cast("int")).cast("string")))  
  
# Define a window partitioned by the salted_id, and ordered by timestamp  
window = Window.partitionBy("salted_id").orderBy("timestamp")  
  
# Add a cumulative sum column  
df = df.withColumn("cumulative_sum", F.sum("value").over(window))  
  
# Define a window partitioned by the original id, and ordered by timestamp  
window_unsalted = Window.partitionBy("id").orderBy("timestamp")  
  
# Compute the final cumulative sum by adding up the cumulative sums within each original id  
df = df.withColumn("final_cumulative_sum",   
                   F.sum("cumulative_sum").over(window_unsalted))  

# expected value
df = df.withColumn("Expected", F.sum('value').over(window_unsalted))

# incorrect trial
df.agg(F.sum('final_cumulative_sum')).show()

# expected value
df.agg(F.sum('Expected')).show()
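
One direction I am considering (sketched below and untested at scale; the bucket / local_cumsum / offset / cumsum_salted names and the use of approxQuantile for the split points are just illustrative): instead of a random salt, which scatters rows that must stay ordered together, salt by timestamp range. Each bucket then gets a small cumulative sum of its own, and the total of all earlier buckets is added back as an offset, which should reproduce the unsalted cumulative sum.

from pyspark.sql import functions as F  
from pyspark.sql.window import Window  

num_range_buckets = 4  # illustrative; tune for the real data volume  

# Range-based "salt": bucket boundaries from approximate quantiles of the  
# timestamp, so no window ever has to hold the whole skewed id at once.  
probs = [i / num_range_buckets for i in range(1, num_range_buckets)]  
splits = df.approxQuantile("timestamp", probs, 0.001)  

bucket = F.lit(0)  
for s in splits:  
    bucket = bucket + F.when(F.col("timestamp") > F.lit(s), 1).otherwise(0)  
df_b = df.withColumn("bucket", bucket)  

# Cumulative sum inside each (id, bucket) partition only.  
# Ties in timestamp land in the same bucket, so the default window frame  
# behaves the same as in the unsalted case.  
w_local = Window.partitionBy("id", "bucket").orderBy("timestamp")  
df_b = df_b.withColumn("local_cumsum", F.sum("value").over(w_local))  

# Offset for each bucket = total of all strictly earlier buckets of the same id.  
totals = df_b.groupBy("id", "bucket").agg(F.sum("value").alias("bucket_total"))  
w_buckets = Window.partitionBy("id").orderBy("bucket")  
offsets = totals.withColumn(  
    "offset", F.sum("bucket_total").over(w_buckets) - F.col("bucket_total")  
)  

# offset + local cumulative sum should equal the unsalted cumulative sum,  
# because every row in bucket k is ordered after every row in buckets < k.  
df_b = (  
    df_b.join(offsets.select("id", "bucket", "offset"), on=["id", "bucket"])  
        .withColumn("cumsum_salted", F.col("offset") + F.col("local_cumsum"))  
)  

The offsets table only has num_range_buckets rows per id, so the join back is cheap; the key assumption is that rows in bucket k always sort after rows in lower buckets, which the range-based split is meant to guarantee. Is this the right way to go, or is there a better salting scheme for ordered windows?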


