Friday, 2 December 2022

Can I interrogate a PySpark DataFrame to get the list of referenced columns?

Given a PySpark DataFrame, is it possible to obtain a list of the source columns that are being referenced by the DataFrame?

Perhaps a more concrete example might help explain what I'm after. Say I have a DataFrame defined as:

import pyspark.sql.functions as func
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
source_df = spark.createDataFrame(
    [("pru", 23, "finance"), ("paul", 26, "HR"), ("noel", 20, "HR")],
    ["name", "age", "department"],
)
source_df.createOrReplaceTempView("people")
sqlDF = spark.sql("SELECT name, age, department FROM people")
df = sqlDF.groupBy("department").agg(func.max("age").alias("max_age"))
df.show()

which returns:

+----------+-------+
|department|max_age|
+----------+-------+
|   finance|     23|
|        HR|     26|
+----------+-------+

The columns that are referenced by df are [department, age]. Is it possible to get that list of referenced columns programmatically?
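Just to make the shape of what I'm after explicit, here is a purely hypothetical sketch (referenced_columns does not exist in PySpark; it's only there to illustrate the desired result):

# Hypothetical method, for illustration only -- PySpark has no such API:
referenced = df.referenced_columns()
print(referenced)  # ideally something like ['department', 'age']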

Thanks to Capturing the result of explain() in pyspark, I know I can extract the plan as a string:

df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), "formatted")

which returns:

== Physical Plan ==
AdaptiveSparkPlan (6)
+- HashAggregate (5)
   +- Exchange (4)
      +- HashAggregate (3)
         +- Project (2)
            +- Scan ExistingRDD (1)


(1) Scan ExistingRDD
Output [3]: [name#0, age#1L, department#2]
Arguments: [name#0, age#1L, department#2], MapPartitionsRDD[4] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)

(2) Project
Output [2]: [age#1L, department#2]
Input [3]: [name#0, age#1L, department#2]

(3) HashAggregate
Input [2]: [age#1L, department#2]
Keys [1]: [department#2]
Functions [1]: [partial_max(age#1L)]
Aggregate Attributes [1]: [max#22L]
Results [2]: [department#2, max#23L]

(4) Exchange
Input [2]: [department#2, max#23L]
Arguments: hashpartitioning(department#2, 200), ENSURE_REQUIREMENTS, [plan_id=60]

(5) HashAggregate
Input [2]: [department#2, max#23L]
Keys [1]: [department#2]
Functions [1]: [max(age#1L)]
Aggregate Attributes [1]: [max(age#1L)#12L]
Results [2]: [department#2, max(age#1L)#12L AS max_age#13L]

(6) AdaptiveSparkPlan
Output [2]: [department#2, max_age#13L]
Arguments: isFinalPlan=false

which is useful; however, it's not what I need. I need a list of the referenced columns. Is this possible?
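One crude workaround I've considered is scraping attribute names out of that plan text, since Catalyst renders column references as name#exprId (for example age#1L or department#2). This is only a rough sketch and it over-matches, picking up intermediate and output attributes such as max#22L and max_age#13L, so it isn't really what I'm after:

import re

plan_str = df._sc._jvm.PythonSQLUtils.explainString(
    df._jdf.queryExecution(), "formatted"
)
# Catalyst prints attribute references as name#<exprId>, sometimes with a type
# hint suffix such as L, so pick out the identifier in front of each '#'.
candidates = set(re.findall(r"([A-Za-z_][A-Za-z0-9_]*)#\d+", plan_str))
print(candidates)  # over-approximates: includes intermediate/output attributes too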

Perhaps another way of asking the question is: is there a way to obtain the explain plan as an object that I can iterate over and explore?
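For what it's worth, I can reach the underlying Catalyst plan objects through the py4j gateway, but I don't know whether this is a supported or stable way to do it; the methods used below (analyzed, numberedTreeString, references, toSeq, apply, name) are JVM-side internals and may differ between Spark versions:

# Unsupported/internal: inspect the analyzed logical plan via py4j.
jplan = df._jdf.queryExecution().analyzed()
print(jplan.numberedTreeString())  # the logical plan as a numbered tree (still just text)

# references() is an AttributeSet of the columns the root node reads from its
# children; convert it to a Scala Seq and pull out each attribute's name.
refs = jplan.references().toSeq()
ref_names = [refs.apply(i).name() for i in range(refs.size())]
print(ref_names)  # for this example I'd hope to see ['department', 'age'], though only for the root node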



from Can I interrogate a PySpark DataFrame to get the list of referenced columns?
