Window Functions in Spark#

Window functions use values from other rows within the same group, or window, and return a value in a new column for every row. This can be in the form of aggregations (similar to a .groupBy()/group_by() but preserving the original DataFrame), ranking rows within groups, or returning values from previous rows. If you’re familiar with window functions in SQL then they work in the same way in Spark.

This article explains how to use window functions in three ways: for aggregation, ranking, and referencing the previous row. An SQL example is also given.

Window Functions for Aggregations#

You can use a window function for aggregations. Rather than returning an aggregated DataFrame, the result of the aggregation will be placed in a new column.

One example of where this is useful is for deriving a total to be used as the denominator for another calculation. For instance, in the Animal Rescue data we may want to work out what percentage of animals rescued each year are dogs. We can do this by getting the total of all animals by year, then dividing each animal group count by this.

First, import the relevant packages and start a Spark session. To use window functions in PySpark, we need to import Window from pyspark.sql.window. No extra packages are needed for sparklyr, as Spark functions are referenced inside mutate().

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window
import yaml

with open("../../../config.yaml") as f:
    config = yaml.safe_load(f)
    
spark = (SparkSession.builder.master("local[2]")
         .appName("window-functions")
         .getOrCreate())

rescue_path = config["rescue_path"]

In this example we will use an aggregated version of the Animal Rescue data, containing animal_group, cal_year and animal_count.

rescue_agg = (
    spark.read.parquet(rescue_path)
    .withColumn("animal_group", F.initcap(F.col("animal_group")))
    .groupBy("animal_group", "cal_year")
    .agg(F.count("animal_group").alias("animal_count")))

rescue_agg.show(5, truncate=False)
+------------------------------------------------+--------+------------+
|animal_group                                    |cal_year|animal_count|
+------------------------------------------------+--------+------------+
|Snake                                           |2018    |1           |
|Deer                                            |2015    |6           |
|Unknown - Animal Rescue From Water - Farm Animal|2014    |1           |
|Unknown - Domestic Animal Or Pet                |2017    |8           |
|Bird                                            |2012    |112         |
+------------------------------------------------+--------+------------+
only showing top 5 rows

We want to calculate the percentage of animals rescued each year that are dogs. To do this, we first need to calculate the annual totals, and can then divide the number of dogs in each year by this.

We could create a new DataFrame by grouping and aggregating and then joining back to the original DF; this would get the correct result, but a window function is much more efficient as it will reduce the number of shuffles required, as well as making the code more succinct and readable.

The syntax is quite different between PySpark and sparklyr, although the principle is identical in each, and Spark will process them in the same way. The process for using a window function for aggregation in PySpark is as follows:

  • First, use .withColumn(), as the result is stored in a new column in the DataFrame.

  • Then do the aggregation: F.sum("animal_count").

  • Then perform this over a window with .over(Window.partitionBy("cal_year")). Note that this uses .partitionBy() rather than .groupBy() (for some window functions you will also use .orderBy(), but we do not need to here).

In sparklyr:

  • Use group_by(cal_year) to partition the data.

  • Then define a new column, annual_count, as sum(animal_count) inside mutate() (rather than summarise(), which is used for regular aggregations).

  • Finally, ungroup() to remove the grouping from the DataFrame.

rescue_annual = (rescue_agg
                 .withColumn("annual_count",
                     F.sum("animal_count").over(Window.partitionBy("cal_year"))))

Now display the DF, using Cat, Dog and Hamster between 2012 and 2014 as an example:

(rescue_annual
    .filter(
        (F.col("animal_group").isin("Cat", "Dog", "Hamster")) &
        (F.col("cal_year").between(2012, 2014)))
    .orderBy("cal_year", "animal_group")
    .show())
+------------+--------+------------+------------+
|animal_group|cal_year|animal_count|annual_count|
+------------+--------+------------+------------+
|         Cat|    2012|         305|         603|
|         Dog|    2012|         100|         603|
|         Cat|    2013|         313|         585|
|         Dog|    2013|          93|         585|
|     Hamster|    2013|           3|         585|
|         Cat|    2014|         298|         583|
|         Dog|    2014|          90|         583|
|     Hamster|    2014|           1|         583|
+------------+--------+------------+------------+

The values in annual_count are repeated in every year, as the original rows in the DF have been preserved. Had we aggregated this in the usual way we would have lost the animal_group and animal_count columns, and only returned one annual_count for each cal_year.

Once we have the annual_count column we can complete our calculation with a simple narrow transformation to get the percentage and filter on "Dog":

rescue_annual = (rescue_annual
                 .withColumn("animal_pct",
                             F.round(
                                 (F.col("animal_count") / F.col("annual_count")) * 100, 2)))

rescue_annual.filter(F.col("animal_group") == "Dog").orderBy("cal_year").show()
+------------+--------+------------+------------+----------+
|animal_group|cal_year|animal_count|annual_count|animal_pct|
+------------+--------+------------+------------+----------+
|         Dog|    2009|         132|         568|     23.24|
|         Dog|    2010|         122|         611|     19.97|
|         Dog|    2011|         103|         620|     16.61|
|         Dog|    2012|         100|         603|     16.58|
|         Dog|    2013|          93|         585|      15.9|
|         Dog|    2014|          90|         583|     15.44|
|         Dog|    2015|          88|         540|      16.3|
|         Dog|    2016|         107|         604|     17.72|
|         Dog|    2017|          81|         539|     15.03|
|         Dog|    2018|          91|         609|     14.94|
|         Dog|    2019|           1|          36|      2.78|
+------------+--------+------------+------------+----------+

This example used F.sum()/sum() but other aggregations are possible too, e.g. F.mean()/mean(), F.max()/max(). In PySpark, use multiple .withColumn() statements; in sparklyr, you can combine them in mutate(). In this example we filter on "Snake":

rescue_annual = (rescue_annual
                 .withColumn("avg_count",
                             F.mean("animal_count").over(Window.partitionBy("cal_year")))
                 .withColumn("max_count",
                             F.max("animal_count").over(Window.partitionBy("cal_year"))))

rescue_annual.filter(F.col("animal_group") == "Snake").show()
+------------+--------+------------+------------+----------+------------------+---------+
|animal_group|cal_year|animal_count|annual_count|animal_pct|         avg_count|max_count|
+------------+--------+------------+------------+----------+------------------+---------+
|       Snake|    2018|           1|         609|      0.16|           38.0625|      305|
|       Snake|    2013|           2|         585|      0.34|              45.0|      313|
|       Snake|    2009|           3|         568|      0.53| 37.86666666666667|      263|
|       Snake|    2016|           1|         604|      0.17|43.142857142857146|      297|
|       Snake|    2017|           1|         539|      0.19|              49.0|      258|
+------------+--------+------------+------------+----------+------------------+---------+
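If you are using the same window several times in PySpark, you can assign the window specification to a variable and reuse it, which avoids repetition. A minimal sketch of this pattern (the variable name annual_window is just illustrative):

# Window specifications are ordinary Python objects, so they can be defined once and reused
annual_window = Window.partitionBy("cal_year")

rescue_annual = (rescue_annual
                 .withColumn("avg_count", F.mean("animal_count").over(annual_window))
                 .withColumn("max_count", F.max("animal_count").over(annual_window)))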

The alternative to window functions is creating a new grouped and aggregated DF, then joining it back to the original one. As well as being less efficient, the code will also be harder to read. For example:

rescue_counts = rescue_agg.groupBy("cal_year").agg(F.sum("animal_count").alias("annual_count"))
rescue_annual_alternative = rescue_agg.join(rescue_counts, on="cal_year", how="left")
rescue_annual_alternative.filter(F.col("animal_group") == "Dog").orderBy("cal_year").show()
+--------+------------+------------+------------+
|cal_year|animal_group|animal_count|annual_count|
+--------+------------+------------+------------+
|    2009|         Dog|         132|         568|
|    2010|         Dog|         122|         611|
|    2011|         Dog|         103|         620|
|    2012|         Dog|         100|         603|
|    2013|         Dog|          93|         585|
|    2014|         Dog|          90|         583|
|    2015|         Dog|          88|         540|
|    2016|         Dog|         107|         604|
|    2017|         Dog|          81|         539|
|    2018|         Dog|          91|         609|
|    2019|         Dog|           1|          36|
+--------+------------+------------+------------+
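If you want to verify the efficiency claim yourself, you can compare the physical plans of the two approaches with .explain(). The exact output depends on your Spark version and configuration, but the window version will typically show fewer shuffles (Exchange steps) than the group-and-join version:

# Compare the physical plans: look for the number of Exchange (shuffle) steps in each
rescue_annual.explain()
rescue_annual_alternative.explain()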

Using Window Functions for Ranking#

Window functions can be ordered as well as grouped. This can be combined with F.rank()/rank() or F.row_number()/row_number() to get ranks within groups. For instance, we can get the ranking of the most commonly rescued animals by year, then filter on the top three.

The syntax is again different between PySpark and sparklyr. In PySpark, use the same method as described above for aggregations, but replace F.sum() with F.rank() (or another ordered function), and add orderBy(). In this example, use F.desc("animal_count") to sort descending. The .partitionBy() step is optional; without a .partitionBy() it will treat the whole DataFrame as one group.

In sparklyr, the method is also almost the same as using aggregations. The ordering is done directly with the rank() function. desc(animal_count) is used to sort descending.

rescue_rank = (
    rescue_agg
    .withColumn("rank",
                F.rank().over(
                    Window.partitionBy("cal_year").orderBy(F.desc("animal_count")))))

Once we have the rank column we can filter on those less than or equal to 3, to get the top 3 animals rescued by year:

rescue_rank.filter(F.col("rank") <= 3).orderBy("cal_year", "rank").show(12, truncate=False)
+------------+--------+------------+----+
|animal_group|cal_year|animal_count|rank|
+------------+--------+------------+----+
|Cat         |2009    |263         |1   |
|Dog         |2009    |132         |2   |
|Bird        |2009    |89          |3   |
|Cat         |2010    |297         |1   |
|Dog         |2010    |122         |2   |
|Bird        |2010    |99          |3   |
|Cat         |2011    |309         |1   |
|Bird        |2011    |120         |2   |
|Dog         |2011    |103         |3   |
|Cat         |2012    |305         |1   |
|Bird        |2012    |112         |2   |
|Dog         |2012    |100         |3   |
+------------+--------+------------+----+
only showing top 12 rows

Another common use case is getting just the top row from each group:

rescue_rank.filter(F.col("rank") == 1).orderBy("cal_year").show(truncate=False)
+------------+--------+------------+----+
|animal_group|cal_year|animal_count|rank|
+------------+--------+------------+----+
|Cat         |2009    |263         |1   |
|Cat         |2010    |297         |1   |
|Cat         |2011    |309         |1   |
|Cat         |2012    |305         |1   |
|Cat         |2013    |313         |1   |
|Cat         |2014    |298         |1   |
|Cat         |2015    |263         |1   |
|Cat         |2016    |297         |1   |
|Cat         |2017    |258         |1   |
|Cat         |2018    |305         |1   |
|Cat         |2019    |16          |1   |
+------------+--------+------------+----+

Comparison of ranking methods#

Note that you can have duplicate ranks within each group when using rank(); if this is not desirable then one method is to add more columns to the ordering to break ties. There are also alternatives to rank() depending on your use case:

  • F.rank()/rank() will assign the same value to ties.

  • F.dense_rank()/dense_rank() will not skip a rank after ties.

  • F.row_number()/row_number() will give a unique number to each row within the grouping specified. Note that this can be non-deterministic if there are duplicate rows for the ordering condition specified. This can be avoided by specifying extra columns in the ordering to use as a tiebreaker (see the sketch after the comparison below).

We can see the difference by comparing the three methods:

rank_comparison = (rescue_agg
    .withColumn("rank",
                F.rank().over(
                    Window
                    .partitionBy("cal_year")
                    .orderBy(F.desc("animal_count"))))
    .withColumn("dense_rank",
                F.dense_rank().over(
                    Window
                    .partitionBy("cal_year")
                    .orderBy(F.desc("animal_count"))))
    .withColumn("row_number",
                F.row_number().over(
                    Window
                    .partitionBy("cal_year")
                    .orderBy(F.desc("animal_count"))))
)

(rank_comparison
    .filter(F.col("cal_year") == 2012)
    .orderBy("cal_year", "row_number")
    .show(truncate=False))
+------------------------------------------------+--------+------------+----+----------+----------+
|animal_group                                    |cal_year|animal_count|rank|dense_rank|row_number|
+------------------------------------------------+--------+------------+----+----------+----------+
|Cat                                             |2012    |305         |1   |1         |1         |
|Bird                                            |2012    |112         |2   |2         |2         |
|Dog                                             |2012    |100         |3   |3         |3         |
|Horse                                           |2012    |28          |4   |4         |4         |
|Unknown - Domestic Animal Or Pet                |2012    |18          |5   |5         |5         |
|Fox                                             |2012    |14          |6   |6         |6         |
|Deer                                            |2012    |7           |7   |7         |7         |
|Squirrel                                        |2012    |4           |8   |8         |8         |
|Unknown - Wild Animal                           |2012    |4           |8   |8         |9         |
|Unknown - Heavy Livestock Animal                |2012    |4           |8   |8         |10        |
|Cow                                             |2012    |3           |11  |9         |11        |
|Ferret                                          |2012    |1           |12  |10        |12        |
|Lamb                                            |2012    |1           |12  |10        |13        |
|Sheep                                           |2012    |1           |12  |10        |14        |
|Unknown - Animal Rescue From Water - Farm Animal|2012    |1           |12  |10        |15        |
+------------------------------------------------+--------+------------+----+----------+----------+

For all the values where animal_count is 4, rank and dense_rank have 8, whereas row_number gives a unique number from 8 to 10. As no other sorting columns were specified, these three rows could be assigned differently on subsequent runs of the code.

For animal_count less than 4, dense_rank has left no gap in the ranking sequence, whereas rank will leave gaps.

row_number has a unique value for each row, even for tied values.
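As suggested above, one way to make row_number() deterministic is to add a tiebreaker column to the ordering, for example sorting ties alphabetically by animal_group. A minimal sketch (the name row_number_deterministic is just illustrative):

# animal_group acts as a tiebreaker, so tied counts are always numbered in the same order
row_number_deterministic = (rescue_agg
    .withColumn("row_number",
                F.row_number().over(
                    Window
                    .partitionBy("cal_year")
                    .orderBy(F.desc("animal_count"), "animal_group"))))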

Generating unique row numbers#

Spark DataFrames do not have an index in the same way as pandas or base R DataFrames as they are partitioned on the cluster. You can however use row_number() to generate a unique identifier for each row.

Whereas in the previous example we ranked within groups, here we need to treat the whole DataFrame as one group.

To do this in PySpark, use just Window.orderBy(col1, col2, ...) without the partitionBy(). In sparklyr, just use mutate() without group_by() and ungroup().

Remember to be careful as this can be non-deterministic if there are duplicate rows for the ordering condition specified.

(rescue_agg
    .withColumn("row_number",
                F.row_number().over(Window.orderBy("cal_year")))
    .show(10, truncate=False))
+--------------------------------+--------+------------+----------+
|animal_group                    |cal_year|animal_count|row_number|
+--------------------------------+--------+------------+----------+
|Hedgehog                        |2009    |1           |1         |
|Dog                             |2009    |132         |2         |
|Sheep                           |2009    |1           |3         |
|Deer                            |2009    |8           |4         |
|Lizard                          |2009    |1           |5         |
|Unknown - Heavy Livestock Animal|2009    |14          |6         |
|Unknown - Wild Animal           |2009    |6           |7         |
|Rabbit                          |2009    |1           |8         |
|Bird                            |2009    |89          |9         |
|Horse                           |2009    |19          |10        |
+--------------------------------+--------+------------+----------+
only showing top 10 rows

Reference other rows with lag() and lead()#

The window function F.lag()/lag() allows you to reference the values of previous rows within a group, and F.lead()/lead() does the same for subsequent rows. You can specify how many rows back to look with the count argument (renamed offset in newer versions of PySpark), which defaults to 1. Note that it can also be negative, so lag(col, 1) is the same as lead(col, -1).

For lag(), the value for the first row within each window partition will be null, as there is no previous row to reference; similarly for lead() and the last row. This can be changed by setting the default parameter, which is None by default; a sketch of this is given after the example below.

These window functions differ from rank() and row_number() as they reference values from other rows, rather than returning a rank. They do however use ordering in the same way.

We can use lag() or lead() to get the number of animals rescued in the previous year, with the intention of calculating the annual change. The process for this is identical to the Using Window Functions for Ranking section, just using lag() as the function within the window.

(rescue_agg
    .withColumn("previous_count",
                F.lag("animal_count").over(
                    Window.partitionBy("animal_group").orderBy("cal_year")))
    .show(10, truncate=False))
+------------------------------------------------+--------+------------+--------------+
|animal_group                                    |cal_year|animal_count|previous_count|
+------------------------------------------------+--------+------------+--------------+
|Unknown - Animal Rescue From Water - Farm Animal|2012    |1           |null          |
|Unknown - Animal Rescue From Water - Farm Animal|2014    |1           |1             |
|Unknown - Animal Rescue From Water - Farm Animal|2019    |1           |1             |
|Cow                                             |2010    |2           |null          |
|Cow                                             |2012    |3           |2             |
|Cow                                             |2014    |1           |3             |
|Cow                                             |2016    |1           |1             |
|Horse                                           |2009    |19          |null          |
|Horse                                           |2010    |15          |19            |
|Horse                                           |2011    |22          |15            |
+------------------------------------------------+--------+------------+--------------+
only showing top 10 rows
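The null for the first row in each partition can be replaced by setting the default argument to lag(). A minimal sketch, using 0 as the default purely to show the parameter (note that this does not fix the incomplete data issue described next):

# The third argument to F.lag() is the default, returned when there is no previous row
(rescue_agg
    .withColumn("previous_count",
                F.lag("animal_count", 1, 0).over(
                    Window.partitionBy("animal_group").orderBy("cal_year")))
    .show(5, truncate=False))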

Be careful if using lag() with incomplete data: where no animals of a group were rescued in a year, that year is missing from the DataFrame entirely, so previous_count will refer to the most recent year present in the data rather than the immediately preceding year.

There are several ways to resolve this; one method is to use a cross join to get all combinations of animal_group and cal_year, join rescue_agg to this, fill the null values with 0, and then do the window calculation:

# Create a DF of all combinations of animal_group and cal_year
all_animals_years = (rescue_agg
                     .select("animal_group")
                     .distinct()
                     .crossJoin(
                         rescue_agg
                         .select("cal_year")
                         .distinct()))

# Use this DF as a base to join the rescue_agg DF to
rescue_agg_prev = (
    all_animals_years
    .join(rescue_agg, on=["animal_group", "cal_year"], how="left")
    # Replace null with 0
    .fillna(0, "animal_count")
    # lag will then reference previous year, even if 0
    .withColumn("previous_count",
                F.lag("animal_count").over(
                    Window.partitionBy("animal_group").orderBy("cal_year"))))

rescue_agg_prev.orderBy("animal_group", "cal_year").show(truncate=False)
+------------+--------+------------+--------------+
|animal_group|cal_year|animal_count|previous_count|
+------------+--------+------------+--------------+
|Bird        |2009    |89          |null          |
|Bird        |2010    |99          |89            |
|Bird        |2011    |120         |99            |
|Bird        |2012    |112         |120           |
|Bird        |2013    |85          |112           |
|Bird        |2014    |110         |85            |
|Bird        |2015    |106         |110           |
|Bird        |2016    |120         |106           |
|Bird        |2017    |124         |120           |
|Bird        |2018    |126         |124           |
|Bird        |2019    |9           |126           |
|Budgie      |2009    |0           |null          |
|Budgie      |2010    |0           |0             |
|Budgie      |2011    |1           |0             |
|Budgie      |2012    |0           |1             |
|Budgie      |2013    |0           |0             |
|Budgie      |2014    |0           |0             |
|Budgie      |2015    |0           |0             |
|Budgie      |2016    |0           |0             |
|Budgie      |2017    |0           |0             |
+------------+--------+------------+--------------+
only showing top 20 rows

Window Functions in SQL#

You can also use the regular SQL syntax for window functions when using Spark: OVER(PARTITION BY ... ORDER BY ...). This needs an SQL wrapper to be processed in Spark: spark.sql() in PySpark and tbl(sc, sql()) in sparklyr. Remember that SQL works on tables rather than DataFrames, so register the DataFrame as a temporary view first.

rescue_agg.createOrReplaceTempView("rescue_agg")

sql_window = spark.sql(
    """
    SELECT
        cal_year,
        animal_group,
        animal_count,
        SUM(animal_count) OVER(PARTITION BY cal_year) AS annual_count
    FROM rescue_agg
    """
)

sql_window.filter(F.col("animal_group") == "Snake").show()
+--------+------------+------------+------------+
|cal_year|animal_group|animal_count|annual_count|
+--------+------------+------------+------------+
|    2018|       Snake|           1|         609|
|    2013|       Snake|           2|         585|
|    2009|       Snake|           3|         568|
|    2016|       Snake|           1|         604|
|    2017|       Snake|           1|         539|
+--------+------------+------------+------------+
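Ordered windows use the same SQL syntax, with ORDER BY inside the OVER clause. A minimal sketch of the earlier ranking example, reusing the rescue_agg view registered above:

sql_rank = spark.sql(
    """
    SELECT
        cal_year,
        animal_group,
        animal_count,
        RANK() OVER(PARTITION BY cal_year ORDER BY animal_count DESC) AS rank
    FROM rescue_agg
    """
)

sql_rank.filter(F.col("rank") <= 3).orderBy("cal_year", "rank").show(5)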

Further Resources#

Spark at the ONS Articles:

  • Cross Joins
