How can I use salting to perform a cumulative sum window operation? Although the sample below is tiny, my id column is heavily skewed, and I effectively need to perform this operation on it:
window_unsalted = Window.partitionBy("id").orderBy("timestamp")
# expected value
df = df.withColumn("Expected", F.sum('value').over(window_unsalted))
However, I want to try salting because, at the scale of my data, I cannot compute it otherwise.
Consider this MWE. How can I replicate the expected value, 20, using salting techniques?
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()
data = [
(7329, 1636617182, 1.0),
(7329, 1636142065, 1.0),
(7329, 1636142003, 1.0),
(7329, 1680400388, 1.0),
(7329, 1636142400, 1.0),
(7329, 1636397030, 1.0),
(7329, 1636142926, 1.0),
(7329, 1635970969, 1.0),
(7329, 1636122419, 1.0),
(7329, 1636142195, 1.0),
(7329, 1636142654, 1.0),
(7329, 1636142484, 1.0),
(7329, 1636119628, 1.0),
(7329, 1636404275, 1.0),
(7329, 1680827925, 1.0),
(7329, 1636413478, 1.0),
(7329, 1636143578, 1.0),
(7329, 1636413800, 1.0),
(7329, 1636124556, 1.0),
(7329, 1636143614, 1.0),
(7329, 1636617778, -1.0),
(7329, 1636142155, -1.0),
(7329, 1636142061, -1.0),
(7329, 1680400415, -1.0),
(7329, 1636142480, -1.0),
(7329, 1636400183, -1.0),
(7329, 1636143444, -1.0),
(7329, 1635977251, -1.0),
(7329, 1636122624, -1.0),
(7329, 1636142298, -1.0),
(7329, 1636142720, -1.0),
(7329, 1636142584, -1.0),
(7329, 1636122147, -1.0),
(7329, 1636413382, -1.0),
(7329, 1680827958, -1.0),
(7329, 1636413538, -1.0),
(7329, 1636143610, -1.0),
(7329, 1636414011, -1.0),
(7329, 1636141936, -1.0),
(7329, 1636146843, -1.0)
]
df = spark.createDataFrame(data, ["id", "timestamp", "value"])
# Define the number of salt buckets
num_buckets = 100
# Add a salted_id column to the dataframe
df = df.withColumn("salted_id", (F.concat(F.col("id"),
(F.rand(seed=42)*num_buckets).cast("int")).cast("string")))
# Define a window partitioned by the salted_id, and ordered by timestamp
window = Window.partitionBy("salted_id").orderBy("timestamp")
# Add a cumulative sum column
df = df.withColumn("cumulative_sum", F.sum("value").over(window))
# Define a window partitioned by the original id, and ordered by timestamp
window_unsalted = Window.partitionBy("id").orderBy("timestamp")
# Compute the final cumulative sum by adding up the cumulative sums within each original id
df = df.withColumn("final_cumulative_sum",
F.sum("cumulative_sum").over(window_unsalted))
# expected value
df = df.withColumn("Expected", F.sum('value').over(window_unsalted))
# incorrect trial
df.agg(F.sum('final_cumulative_sum')).show()
# expected value
df.agg(F.sum('Expected')).show()
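For context, the direction I have been considering instead is to "salt" by contiguous timestamp ranges rather than by a random number, so that ordering is preserved across buckets, and then add each bucket's running offset back in. Below is a rough, unverified sketch of that idea; the bucket_width value is an arbitrary placeholder for illustration, not something tuned to real data.
# Sketch: order-preserving "salt" via contiguous timestamp buckets
bucket_width = 10_000_000  # arbitrary placeholder; pick so buckets come out balanced
df2 = df.withColumn("bucket", F.floor(F.col("timestamp") / F.lit(bucket_width)))
# 1) cumulative sum within each (id, bucket); each bucket is far smaller than the full id
w_bucket = Window.partitionBy("id", "bucket").orderBy("timestamp")
df2 = df2.withColumn("cs_in_bucket", F.sum("value").over(w_bucket))
# 2) per-bucket totals, then the running offset contributed by all earlier buckets;
#    this window sees one row per (id, bucket), so the skew is gone
totals = df2.groupBy("id", "bucket").agg(F.sum("value").alias("bucket_total"))
w_offset = (Window.partitionBy("id").orderBy("bucket")
            .rowsBetween(Window.unboundedPreceding, -1))
totals = totals.withColumn(
    "offset", F.coalesce(F.sum("bucket_total").over(w_offset), F.lit(0.0))
)
# 3) final cumulative sum = within-bucket running sum + offset from earlier buckets
df2 = (df2.join(totals.select("id", "bucket", "offset"), ["id", "bucket"])
          .withColumn("salted_cumulative_sum", F.col("cs_in_bucket") + F.col("offset")))
# should match df.agg(F.sum('Expected')) if the per-row values line up
df2.agg(F.sum("salted_cumulative_sum")).show()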