Pivot tables in Spark

A pivot table is a way of displaying the result of grouped and aggregated data as a two-dimensional table, rather than in the list form that you get from regular grouping and aggregating. You might be familiar with them from Excel.
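The following plain-Python sketch (no Spark involved) shows the difference between the two forms. It turns a few (animal_group, cal_year, count) rows, taken from the Hamster and Sheep counts in the grouped output on this page, into a table with one row per animal group and one column per year. The function pivot_counts() is hypothetical and only illustrates the reshaping. Combinations with no rows are left as None, in the same way that Spark returns null.

```python
from collections import defaultdict

# (animal_group, cal_year, count) rows in "list" form, as produced by
# grouping on both columns (a few Hamster and Sheep counts from this page)
long_rows = [
    ("Hamster", 2010, 3), ("Hamster", 2011, 3), ("Hamster", 2013, 3),
    ("Sheep", 2009, 1), ("Sheep", 2012, 1),
]


def pivot_counts(rows):
    """Reshape (group, pivot value, count) rows into one row per group."""
    # The output columns are the sorted distinct values of the pivot column
    columns = sorted({year for _, year, _ in rows})
    cells = defaultdict(dict)
    for group, year, count in rows:
        cells[group][year] = count
    # Combinations with no input row are filled with None (null in Spark)
    table = {group: [values.get(col) for col in columns]
             for group, values in cells.items()}
    return columns, table


columns, table = pivot_counts(long_rows)
print(columns)
for group, row in table.items():
    print(group, row)
```

Here columns is [2009, 2010, 2011, 2012, 2013]. The Hamster row is [None, 3, 3, None, 3] and the Sheep row is [1, None, None, 1, None].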

The principles are the same in PySpark and sparklyr. However, many Spark functions work in a similar way in both languages, and pivoting is not one of them: the syntax is very different.

Python Explanation

You can create pivot tables in PySpark by using .pivot() with .groupBy(). If you would otherwise group your data by two or more columns, this layout can make the result easier to read. One of those columns becomes the output columns.

.pivot() has two arguments. The first, pivot_col, is the column whose values become the output columns. It must be a single column and cannot accept a list of multiple columns. The second, values, is optional but recommended, and lets you specify exactly which output columns are returned. If it is omitted, Spark first has to compute every distinct value of pivot_col to use as the output columns. This is less efficient, and the output will be untidy if there are a large number of columns.
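The plain-Python sketch below shows the trade-off with values. The hypothetical pivot_sum() illustrates the behaviour only; it is not Spark's implementation, and the rows are made up. When values is None, the function needs an extra pass over the data to find the distinct pivot values. A supplied list is used as-is, in the order given. Every group still gets an output row, even if none of its rows match a requested value.

```python
def pivot_sum(rows, group_col, pivot_col, value_col, values=None):
    """Sum value_col for each group (rows) and pivot value (columns).

    rows is a list of dicts. If values is None, the distinct pivot values
    are found with an extra pass over the data before pivoting.
    """
    if values is None:
        # Extra pass: collect and sort every distinct pivot value
        values = sorted({row[pivot_col] for row in rows})
    table = {}
    for row in rows:
        # Every group gets an output row, even if none of its pivot
        # values were requested; unrequested values are ignored
        cells = table.setdefault(row[group_col], dict.fromkeys(values))
        if row[pivot_col] in cells:
            current = cells[row[pivot_col]]
            cells[row[pivot_col]] = (current or 0) + row[value_col]
    return values, table


# Made-up rows for illustration only
rows = [
    {"animal_group": "Cat", "cal_year": 2009, "total_cost": 260.0},
    {"animal_group": "Cat", "cal_year": 2010, "total_cost": 520.0},
    {"animal_group": "Cat", "cal_year": 2010, "total_cost": 260.0},
    {"animal_group": "Dog", "cal_year": 2012, "total_cost": 255.0},
]

# values omitted: the columns are discovered from the data
print(pivot_sum(rows, "animal_group", "cal_year", "total_cost"))
# values supplied: only these columns, in this order
print(pivot_sum(rows, "animal_group", "cal_year", "total_cost",
                values=[2010, 2011]))
```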

R Explanation

You can create pivot tables in sparklyr with sdf_pivot(). This is a sparklyr-specific function, so it cannot be used on base R DataFrames or tibbles. An example of pivoting a tibble is given at the end for comparison.

sdf_pivot(x, formula, fun.aggregate) has three arguments:

- x is the sparklyr DataFrame.
- formula is an R formula with the grouping columns on the left and the pivot column on the right, separated by a tilde (e.g. col1 + col2 ~ pivot_col).
- fun.aggregate specifies the aggregation. If it is omitted, the rows are counted.

Be careful when pivoting on a column with a large number of distinct values. This returns a very wide DataFrame that is hard to read. It is recommended to filter() the data first so that it only includes the values you want as output columns. The second example uses filter().
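A formula such as col1 + col2 ~ pivot_col gives one output row per distinct combination of col1 and col2, and one column per distinct value of pivot_col. With the default aggregation, each cell is a row count. The plain-Python sketch below shows that output shape only. count_pivot() is a hypothetical helper, the rows are made up, and the sketch says nothing about how sparklyr computes the result. Missing combinations come back as None, corresponding to NA in sparklyr output.

```python
from collections import Counter


def count_pivot(rows, group_cols, pivot_col):
    """Count rows for each combination of group_cols and pivot_col value.

    Mirrors the shape of col1 + col2 ~ pivot_col with a row-count
    aggregation. rows is a list of dicts.
    """
    counts = Counter(
        (tuple(row[c] for c in group_cols), row[pivot_col]) for row in rows
    )
    columns = sorted({pivot for _, pivot in counts})
    groups = sorted({key for key, _ in counts})
    # Counter.get (inherited from dict) returns None for combinations
    # that never occurred
    return columns, {key: [counts.get((key, col)) for col in columns]
                     for key in groups}


# Made-up rows for illustration only
rows = [
    {"animal_group": "Dog", "origin_of_call": "Police", "cal_year": 2009},
    {"animal_group": "Dog", "origin_of_call": "Police", "cal_year": 2009},
    {"animal_group": "Dog", "origin_of_call": "Police", "cal_year": 2010},
    {"animal_group": "Cat", "origin_of_call": "Other FRS", "cal_year": 2010},
]

columns, table = count_pivot(rows, ["animal_group", "origin_of_call"], "cal_year")
print(columns)
for key, row in table.items():
    print(key, row)
```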

Example 1: Group by one column and count

Create a new Spark session and read the Animal Rescue data. To keep the examples easy to read, filter to just a few animal groups:

import yaml
from pyspark.sql import SparkSession, functions as F

# Start a small local Spark session
spark = (SparkSession.builder.master("local[2]")
         .appName("pivot")
         .getOrCreate())

# Read the location of the Animal Rescue data from the config file
with open("../../../config.yaml") as f:
    config = yaml.safe_load(f)

rescue_path = config["rescue_path"]
# Keep only the columns used on this page and four animal groups
rescue = (spark.read.parquet(rescue_path)
          .select("incident_number", "animal_group", "cal_year", "total_cost", "origin_of_call")
          .filter(F.col("animal_group").isin("Cat", "Dog", "Hamster", "Sheep")))

The minimal example groups by one column, pivots on another and counts the rows, rather than aggregating the values in another column.

Python Example

In PySpark, use .groupBy() and .count() as you normally would when grouping and getting the row count, but add .pivot() between the two functions.

rescue_pivot = (rescue
                .groupBy("animal_group")
                .pivot("cal_year")
                .count())

rescue_pivot.show()
+------------+----+----+----+----+----+----+----+----+----+----+----+
|animal_group|2009|2010|2011|2012|2013|2014|2015|2016|2017|2018|2019|
+------------+----+----+----+----+----+----+----+----+----+----+----+
|     Hamster|null|   3|   3|null|   3|   1|null|   4|null|null|null|
|         Cat| 262| 294| 309| 302| 312| 295| 262| 296| 257| 304|  16|
|         Dog| 132| 122| 103| 100|  93|  90|  88| 107|  81|  91|   1|
|       Sheep|   1|null|null|   1|null|null|   1|   1|null|null|null|
+------------+----+----+----+----+----+----+----+----+----+----+----+
R Example

In sparklyr, use sdf_pivot(). The pipe (%>%) passes the DataFrame in as the first argument, x, so this minimal example only needs to supply formula, which is a tilde expression. The left-hand side is the grouping column, animal_group, and the right-hand side is the pivot column, cal_year. The default aggregation is the row count, so there is no need to specify fun.aggregate.

Note that the printed R output is too wide for the console and wraps onto multiple rows. The second example avoids this by filtering on what will become the pivot columns.

rescue_pivot <- rescue %>%
    sparklyr::sdf_pivot(animal_group ~ cal_year)

rescue_pivot %>%
    sparklyr::collect() %>%
    print()
# A tibble: 4 × 12
  animal_group `2009` `2010` `2011` `2012` `2013` `2014` `2015` `2016` `2017`
  <chr>         <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
1 Dog             132    122    103    100     93     90     88    107     81
2 Cat             262    294    309    302    312    295    262    296    257
3 Hamster          NA      3      3     NA      3      1     NA      4     NA
4 Sheep             1     NA     NA      1     NA     NA      1      1     NA
  `2018` `2019`
   <dbl>  <dbl>
1     91      1
2    304     16
3     NA     NA
4     NA     NA

You could also view the same information with a regular grouping. However, comparing years is harder when the DataFrame is displayed in this long form:

Python Example
rescue_grouped = (rescue
                  .groupBy("animal_group", "cal_year")
                  .count()
                  .orderBy("animal_group", "cal_year"))

rescue_grouped.show(40)
+------------+--------+-----+
|animal_group|cal_year|count|
+------------+--------+-----+
|         Cat|    2009|  262|
|         Cat|    2010|  294|
|         Cat|    2011|  309|
|         Cat|    2012|  302|
|         Cat|    2013|  312|
|         Cat|    2014|  295|
|         Cat|    2015|  262|
|         Cat|    2016|  296|
|         Cat|    2017|  257|
|         Cat|    2018|  304|
|         Cat|    2019|   16|
|         Dog|    2009|  132|
|         Dog|    2010|  122|
|         Dog|    2011|  103|
|         Dog|    2012|  100|
|         Dog|    2013|   93|
|         Dog|    2014|   90|
|         Dog|    2015|   88|
|         Dog|    2016|  107|
|         Dog|    2017|   81|
|         Dog|    2018|   91|
|         Dog|    2019|    1|
|     Hamster|    2010|    3|
|     Hamster|    2011|    3|
|     Hamster|    2013|    3|
|     Hamster|    2014|    1|
|     Hamster|    2016|    4|
|       Sheep|    2009|    1|
|       Sheep|    2012|    1|
|       Sheep|    2015|    1|
|       Sheep|    2016|    1|
+------------+--------+-----+
R Example
rescue_grouped <- rescue %>%
    dplyr::group_by(animal_group, cal_year) %>%
    dplyr::summarise(n()) %>%
    sparklyr::sdf_sort(c("animal_group", "cal_year"))

rescue_grouped %>%
    sparklyr::collect() %>%
    print()
# A tibble: 31 × 3
   animal_group cal_year `n()`
   <chr>           <int> <dbl>
 1 Cat              2009   262
 2 Cat              2010   294
 3 Cat              2011   309
 4 Cat              2012   302
 5 Cat              2013   312
 6 Cat              2014   295
 7 Cat              2015   262
 8 Cat              2016   296
 9 Cat              2017   257
10 Cat              2018   304
11 Cat              2019    16
12 Dog              2009   132
13 Dog              2010   122
14 Dog              2011   103
15 Dog              2012   100
16 Dog              2013    93
17 Dog              2014    90
18 Dog              2015    88
19 Dog              2016   107
20 Dog              2017    81
21 Dog              2018    91
22 Dog              2019     1
23 Hamster          2010     3
24 Hamster          2011     3
25 Hamster          2013     3
26 Hamster          2014     1
27 Hamster          2016     4
28 Sheep            2009     1
29 Sheep            2012     1
30 Sheep            2015     1
31 Sheep            2016     1

Example 2: Aggregate by another column and specify values

Python Example

You can use .agg() with .pivot() in the same way as you do with .groupBy(). This example sums total_cost.

Providing the values argument is more efficient, because Spark then does not need to compute the distinct values of the pivot column first. The PySpark documentation explains this in more detail. As an example, we just look at three years.

rescue_pivot = (rescue
                .groupBy("animal_group")
                .pivot("cal_year", values=["2009", "2010", "2011"])
                .agg(F.sum("total_cost")))

rescue_pivot.show()
+------------+-------+-------+-------+
|animal_group|   2009|   2010|   2011|
+------------+-------+-------+-------+
|     Hamster|   null|  780.0|  780.0|
|         Cat|76685.0|88140.0|89440.0|
|         Dog|39295.0|38480.0|31200.0|
|       Sheep|  255.0|   null|   null|
+------------+-------+-------+-------+
R Example

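fun.aggregate takes a named list that maps a column to an aggregation, so list(total_cost = "sum") sums total_cost. This example groups by a single column. To group by several columns, list them on the left-hand side of the formula separated by +, e.g. animal_group + origin_of_call ~ cal_year, as in Example 3.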

To only look at a certain subset of the pivot column you can just use filter() before pivoting. This is a good idea if your pivot column has a large number of distinct values. As an example, we just look at three years.

rescue_pivot <- rescue %>%
    sparklyr::filter(cal_year %in% c("2009", "2010", "2011")) %>%
    sparklyr::sdf_pivot(
        animal_group ~ cal_year,
        fun.aggregate = list(total_cost = "sum"))

rescue_pivot %>%
    sparklyr::collect() %>%
    print()
# A tibble: 4 × 4
  animal_group `2009` `2010` `2011`
  <chr>         <dbl>  <dbl>  <dbl>
1 Cat           76685  88140  89440
2 Dog           39295  38480  31200
3 Hamster          NA    780    780
4 Sheep           255     NA     NA

Example 3: Multiple groupings and aggregations, fill nulls and sort

Python Example

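You can only supply one column to .pivot(), but you can apply multiple aggregations. Adding an .alias() to each makes the result easier to read, because the output columns are then named <value>_<alias>, e.g. 2009_sum.

Any missing combinations of the grouping and pivot columns are returned as null. For example, there are no incidents with Hamster, Person (land line) and 2009. To set these to zero, use .fillna(). Note that values only restricts the output columns, not the rows. Groups with no incidents in any of the chosen years, such as Cat and Ambulance, are still returned.

If grouping by multiple columns you may also want to add .orderBy().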

rescue_pivot = (rescue
                .groupBy("animal_group", "origin_of_call")
                .pivot("cal_year", values=["2009", "2010", "2011"])
                .agg(F.sum("total_cost").alias("sum"),
                     F.max("total_cost").alias("max"))
                .fillna(0)
                .orderBy("animal_group", "origin_of_call"))

rescue_pivot.show()
+------------+--------------------+--------+--------+--------+--------+--------+--------+
|animal_group|      origin_of_call|2009_sum|2009_max|2010_sum|2010_max|2011_sum|2011_max|
+------------+--------------------+--------+--------+--------+--------+--------+--------+
|         Cat|           Ambulance|     0.0|     0.0|     0.0|     0.0|     0.0|     0.0|
|         Cat|           Other FRS|   260.0|   260.0|   520.0|   260.0|  1040.0|   520.0|
|         Cat|  Person (land line)| 45365.0|   780.0| 53040.0|  1040.0| 53040.0|  1040.0|
|         Cat|     Person (mobile)| 30545.0|  1820.0| 33800.0|  2080.0| 34580.0|  1040.0|
|         Cat|Person (running c...|     0.0|     0.0|   260.0|   260.0|     0.0|     0.0|
|         Cat|              Police|   515.0|   260.0|   520.0|   260.0|   780.0|   260.0|
|         Dog|           Ambulance|   255.0|   255.0|     0.0|     0.0|     0.0|     0.0|
|         Dog|           Other FRS|  1540.0|   765.0|  1040.0|   520.0|     0.0|     0.0|
|         Dog|  Person (land line)| 13460.0|   780.0|  9880.0|  1040.0|  9100.0|   520.0|
|         Dog|     Person (mobile)| 20675.0|   780.0| 24180.0|  1040.0| 21320.0|  1040.0|
|         Dog|              Police|  3365.0|   765.0|  3380.0|  1300.0|   780.0|   260.0|
|     Hamster|  Person (land line)|     0.0|     0.0|   260.0|   260.0|   520.0|   260.0|
|     Hamster|     Person (mobile)|     0.0|     0.0|   520.0|   260.0|   260.0|   260.0|
|       Sheep|           Other FRS|     0.0|     0.0|     0.0|     0.0|     0.0|     0.0|
|       Sheep|  Person (land line)|   255.0|   255.0|     0.0|     0.0|     0.0|     0.0|
|       Sheep|     Person (mobile)|     0.0|     0.0|     0.0|     0.0|     0.0|     0.0|
+------------+--------------------+--------+--------+--------+--------+--------+--------+
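The column naming can also be sketched in plain Python. In the hypothetical multi_agg_pivot() below, each (pivot value, alias) pair becomes a column named <value>_<alias>. This matches names such as 2009_sum and 2009_max in the output above. The fill argument plays the role of .fillna(0). The rows are made up, and the sketch is an illustration only, not how PySpark builds the columns.

```python
def multi_agg_pivot(rows, group_col, pivot_col, value_col, values, aggs,
                    fill=None):
    """Pivot with several named aggregations of one value column.

    aggs maps an alias (e.g. "sum") to a function over a list of numbers.
    Output columns are named "<pivot value>_<alias>"; cells with no
    matching rows get fill (None by default, like null before .fillna()).
    """
    collected = {}
    for row in rows:
        per_value = collected.setdefault(row[group_col],
                                         {v: [] for v in values})
        if row[pivot_col] in per_value:
            per_value[row[pivot_col]].append(row[value_col])
    # Pivot value is the outer loop, so each year's aggregations sit together
    columns = [f"{v}_{alias}" for v in values for alias in aggs]
    table = {}
    for group, per_value in collected.items():
        table[group] = {
            f"{v}_{alias}": (fn(per_value[v]) if per_value[v] else fill)
            for v in values
            for alias, fn in aggs.items()
        }
    return columns, table


# Made-up rows for illustration only
rows = [
    {"animal_group": "Dog", "cal_year": "2009", "total_cost": 255.0},
    {"animal_group": "Dog", "cal_year": "2009", "total_cost": 510.0},
    {"animal_group": "Sheep", "cal_year": "2010", "total_cost": 260.0},
]

columns, table = multi_agg_pivot(
    rows, "animal_group", "cal_year", "total_cost",
    values=["2009", "2010"], aggs={"sum": sum, "max": max}, fill=0)
print(columns)
print(table)
```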
R Example

sdf_pivot() is quite awkward with multiple aggregations on the same column. fun.aggregate can take a named list, but only one aggregation can be applied to each column. We want both the sum and the max of total_cost, so we create a copy of the column, total_cost_copy, and sum that instead. The generated column names are long, so rename_with() shortens them to their first eight characters, e.g. 2009_sum.

Any missing combinations of the grouping and pivot will be returned as NA, e.g. there are no incidents with Hamster, Person (land line) and 2009. To set this to zero, use na.replace().

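If grouping by multiple columns you may also want to add sdf_sort(). The data are filtered before pivoting, so groups with no incidents between 2009 and 2011 are not returned. This gives 13 rows rather than the 16 in the PySpark output.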

rescue_pivot <- rescue %>%
    sparklyr::filter(cal_year %in% c("2009", "2010", "2011")) %>%
    sparklyr::mutate(total_cost_copy = total_cost) %>%
    sparklyr::sdf_pivot(
        animal_group + origin_of_call ~ cal_year,
        fun.aggregate = list(
            total_cost_copy = "sum",
            total_cost = "max"
        )) %>%
    dplyr::rename_with(~substr(., 1, 8), contains(c("_max", "_sum"))) %>%
    sparklyr::sdf_sort(c("animal_group", "origin_of_call")) %>%
    sparklyr::na.replace(0)

rescue_pivot %>%
    sparklyr::collect() %>%
    print()
# A tibble: 13 × 8
   animal_group origin_of_call        `2009_max` `2009_sum` `2010_max`
   <chr>        <chr>                      <dbl>      <dbl>      <dbl>
 1 Cat          Other FRS                    260        260        260
 2 Cat          Person (land line)           780      45365       1040
 3 Cat          Person (mobile)             1820      30545       2080
 4 Cat          Person (running call)          0          0        260
 5 Cat          Police                       260        515        260
 6 Dog          Ambulance                    255        255          0
 7 Dog          Other FRS                    765       1540        520
 8 Dog          Person (land line)           780      13460       1040
 9 Dog          Person (mobile)              780      20675       1040
10 Dog          Police                       765       3365       1300
11 Hamster      Person (land line)             0          0        260
12 Hamster      Person (mobile)                0          0        260
13 Sheep        Person (land line)           255        255          0
   `2010_sum` `2011_max` `2011_sum`
        <dbl>      <dbl>      <dbl>
 1        520        520       1040
 2      53040       1040      53040
 3      33800       1040      34580
 4        260          0          0
 5        520        260        780
 6          0          0          0
 7       1040          0          0
 8       9880        520       9100
 9      24180       1040      21320
10       3380        260        780
11        260        260        520
12        520        260        260
13          0          0          0

Comparison with pivot_wider()

This section is just for those interested in R and dplyr.

R Explanation

sdf_pivot() can only be used on sparklyr DataFrames. If you have a base R DataFrame or tibble, you can use tidyr::pivot_wider(). The documentation for sdf_pivot() explains that it was based on reshape2::dcast(), but the tidyr package is now recommended over reshape2. The syntax is different from sdf_pivot(), so it is worth looking at an example for comparison.

First, filter the sparklyr DataFrame and collect it to the driver as a tibble. Be careful when collecting data from the Spark cluster to the driver. The filtered rescue DataFrame is small enough here, but collecting a large DataFrame can exhaust the driver's memory:

rescue_tibble <- rescue %>%
    sparklyr::filter(cal_year %in% c("2009", "2010", "2011")) %>%
    sparklyr::collect()

# Check that this is a tibble
class(rescue_tibble)
[1] "tbl_df"     "tbl"        "data.frame"

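Now use pivot_wider(). Rather than a formula with ~, it uses names_from (the pivot column) and values_from (the column to aggregate). It groups by all columns not given in these arguments, which is why select() is used first to keep only the columns needed: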

tibble_pivot <- rescue_tibble %>%
    dplyr::select(animal_group, origin_of_call, cal_year, total_cost) %>%
    tidyr::pivot_wider(
        names_from = cal_year,
        values_from = total_cost,
        values_fn = list(total_cost = sum)) %>%
    dplyr::arrange(animal_group, origin_of_call)
    
tibble_pivot %>%
    print()
# A tibble: 13 × 5
   animal_group origin_of_call        `2011` `2009` `2010`
   <chr>        <chr>                  <dbl>  <dbl>  <dbl>
 1 Cat          Other FRS               1040    260    520
 2 Cat          Person (land line)     53040     NA  53040
 3 Cat          Person (mobile)        34580     NA  33800
 4 Cat          Person (running call)     NA     NA    260
 5 Cat          Police                   780    515    520
 6 Dog          Ambulance                 NA    255     NA
 7 Dog          Other FRS                 NA   1540   1040
 8 Dog          Person (land line)        NA  13460   9880
 9 Dog          Person (mobile)        21320  20675  24180
10 Dog          Police                   780   3365   3380
11 Hamster      Person (land line)       520     NA    260
12 Hamster      Person (mobile)          260     NA    520
13 Sheep        Person (land line)        NA    255     NA
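Two behaviours of pivot_wider() explain the output above. First, the id columns are every column not named in names_from or values_from. Second, by default (names_sort = FALSE) the new columns appear in order of first appearance in the data, which is why 2011 comes before 2009. Below is a rough pure-Python analogue. The function and rows are hypothetical, and tidyr's real implementation handles many more cases.

```python
def pivot_wider(rows, names_from, values_from, values_fn=sum):
    """Rough pure-Python analogue of tidyr::pivot_wider() for a list of dicts.

    Every column other than names_from and values_from is an id column,
    and new columns are ordered by first appearance in the data.
    """
    # Id columns are inferred from the keys of the first row
    id_cols = [c for c in rows[0] if c not in (names_from, values_from)]
    new_cols = []  # first-appearance order
    cells = {}     # id tuple -> {new column -> list of values}
    for row in rows:
        name = row[names_from]
        if name not in new_cols:
            new_cols.append(name)
        key = tuple(row[c] for c in id_cols)
        cells.setdefault(key, {}).setdefault(name, []).append(row[values_from])
    wide = []
    for key, per_name in cells.items():
        out = dict(zip(id_cols, key))
        for name in new_cols:
            vals = per_name.get(name)
            # Missing combinations become None (NA in R)
            out[name] = values_fn(vals) if vals else None
        wide.append(out)
    return id_cols, new_cols, wide


# Made-up rows; 2011 happens to appear first
rows = [
    {"animal_group": "Cat", "origin_of_call": "Police", "cal_year": 2011, "total_cost": 260.0},
    {"animal_group": "Cat", "origin_of_call": "Police", "cal_year": 2009, "total_cost": 255.0},
    {"animal_group": "Cat", "origin_of_call": "Police", "cal_year": 2011, "total_cost": 520.0},
    {"animal_group": "Dog", "origin_of_call": "Police", "cal_year": 2009, "total_cost": 260.0},
]

id_cols, new_cols, wide = pivot_wider(rows, "cal_year", "total_cost")
print(id_cols, new_cols)
for out in wide:
    print(out)
```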

Further Resources

PySpark Documentation:

sparklyr and tidyverse Documentation: