Checkpoints and Staging Tables#

Persisting to disk#

Spark uses lazy evaluation. As we build up transformations, Spark creates an execution plan for the DataFrame, and the plan is only executed when an action is called. This execution plan represents the DataFrame’s lineage.
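For example, nothing in the snippet below is executed until the final line. This is just a minimal sketch, assuming a Spark session called spark exists and pyspark.sql.functions is imported as F, as in the code later in this article:

df = spark.range(5)                              # transformation: only added to the plan
df = df.withColumn("doubled", F.col("id") * 2)   # transformation: the plan grows, nothing runs
df.show()                                        # action: the whole plan is executed here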

Sometimes the DataFrame’s lineage can grow long and complex, which will slow down processing and may even cause an error. However, we can get around this by breaking the lineage.

There is more than one way of breaking the lineage of a DataFrame, which is discussed in more detail in the Persisting section.

Checkpoint#

In this article we cover a simple method of persisting to disk called checkpointing, which is essentially an out-of-the-box shortcut to a write/read operation.
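In other words, df.checkpoint() behaves much like writing the DataFrame to disk and immediately reading it back in. A minimal sketch of the manual equivalent, using a hypothetical temp_path, might look like this:

df.write.mode("overwrite").parquet(temp_path)   # write the results computed so far to disk
df = spark.read.parquet(temp_path)              # read them back; the new DataFrame's lineage starts here

The difference is that .checkpoint() handles this for you, writing to the checkpoint directory (which we set up below) in Spark’s own format rather than as a named parquet file.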

Experiment#

To demonstrate the benefit of checkpointing we’ll time how long it takes to create a DataFrame using an iterative calculation. We will run the process without persisting, then again using a checkpoint.

We’ll create a new Spark session for each run, just in case processing the DataFrame a second time in the same session gives it an advantage. We will also use the Python time module to measure how long it takes to create the DataFrame.

We’re going to create a new DataFrame with an id column and a column called col_0 that will consist of random numbers. We’ll then create a loop to add new columns whose values depend on a previous column. The contents of the columns aren’t important here. What is important is that Spark is creating an execution plan that gets longer with each iteration of the loop.

In general, we try to avoid using loops with Spark and this example shows why. A better solution to this problem using Spark would be to add new rows with each iteration as opposed to columns.

We will set a seed_num when creating the random numbers to make the results repeatable. The DataFrame will have num_rows rows, which we will set to a thousand, and the loop will iterate 11 times to create col_1 to col_11.

import os
from pyspark.sql import SparkSession, functions as F
from time import time
import yaml

spark = (SparkSession.builder.master("local[2]")
         .appName("checkpoint")
         .getOrCreate())

new_cols = 12
seed_num = 42
num_rows = 10**3
library(sparklyr)
library(dplyr)
library(DBI)

sc <- sparklyr::spark_connect(
    master = "local[2]",
    app_name = "checkpoint",
    config = sparklyr::spark_config())


set.seed(42)
new_cols <- 12
num_rows <- 10^3

Without persisting#

start_time = time()

df = spark.range(num_rows)
df = df.withColumn("col_0", F.ceil(F.rand(seed_num) * new_cols))

for i in range(1, new_cols):
    df = (df.withColumn("col_"+str(i), 
                        F.when(F.col("col_"+str(i-1)) > i, 
                               F.col("col_"+str(i-1)))
                        .otherwise(0)))

df.show(10)

time_taken = time() - start_time
print(f"Time taken to create the DataFrame:  {time_taken}")
+---+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------+
| id|col_0|col_1|col_2|col_3|col_4|col_5|col_6|col_7|col_8|col_9|col_10|col_11|
+---+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------+
|  0|    8|    8|    8|    8|    8|    8|    8|    8|    0|    0|     0|     0|
|  1|   11|   11|   11|   11|   11|   11|   11|   11|   11|   11|    11|     0|
|  2|   11|   11|   11|   11|   11|   11|   11|   11|   11|   11|    11|     0|
|  3|   11|   11|   11|   11|   11|   11|   11|   11|   11|   11|    11|     0|
|  4|    6|    6|    6|    6|    6|    6|    0|    0|    0|    0|     0|     0|
|  5|    7|    7|    7|    7|    7|    7|    7|    0|    0|    0|     0|     0|
|  6|    1|    0|    0|    0|    0|    0|    0|    0|    0|    0|     0|     0|
|  7|    2|    2|    0|    0|    0|    0|    0|    0|    0|    0|     0|     0|
|  8|    4|    4|    4|    4|    0|    0|    0|    0|    0|    0|     0|     0|
|  9|    9|    9|    9|    9|    9|    9|    9|    9|    9|    0|     0|     0|
+---+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------+
only showing top 10 rows

Time taken to create the DataFrame:  6.800409317016602
start_time <- Sys.time()

df <- sparklyr::sdf_seq(sc, 1, num_rows) %>%
    sparklyr::mutate(col_0 = ceiling(rand()*new_cols))

for (i in 1:(new_cols - 1))
{
  column_name <- paste0('col_', i)
  prev_column <- paste0('col_', i-1)
  df <- df %>%
    sparklyr::mutate(
      !!column_name := case_when(
        !!as.symbol(prev_column) > i ~ !!as.symbol(prev_column),
        TRUE ~ 0 ))
  
}

df %>%
    head(10) %>%
    sparklyr::collect() %>%
    print()

end_time <- Sys.time()
time_taken <- end_time - start_time

cat("Time taken to create DataFrame", time_taken)

The result above shows how long Spark took to create the plan and execute it to show the top 10 rows.

With checkpoints#

Next we will stop the Spark session and start a new one to repeat the operation using checkpoints.

To perform a checkpoint we need to set up a checkpoint directory on the file system, which is where the checkpointed DataFrames will be stored. It’s important to practise good housekeeping with this directory because new files are created with every checkpoint, but they are not automatically deleted.

spark.stop()

spark = (SparkSession.builder.master("local[2]")
         .appName("checkpoint")
         .getOrCreate())

with open("../../../config.yaml") as f:
    config = yaml.safe_load(f)
    
checkpoint_path = config["checkpoint_path"]
spark.sparkContext.setCheckpointDir(checkpoint_path)
sparklyr::spark_disconnect(sc)

sc <- sparklyr::spark_connect(
    master = "local[2]",
    app_name = "checkpoint",
    config = sparklyr::spark_config())


config <- yaml::yaml.load_file("ons-spark/config.yaml")

sparklyr::spark_set_checkpoint_dir(sc, config$checkpoint_path)

We will checkpoint the DataFrame every 3 iterations of the loop so that the lineage doesn’t grow as long. Again, we will time how long it takes for Spark to complete the operation.

start_time = time()

df = spark.range(num_rows)
df = df.withColumn("col_0", F.ceil(F.rand(seed_num) * new_cols))

for i in range(1, new_cols):
    df = (df.withColumn("col_"+str(i), 
                       F.when(F.col("col_"+str(i-1)) > i, 
                              F.col("col_"+str(i-1)))
                       .otherwise(0)))
    if i % 3 == 0: # this means if i is divisible by three then...
        df = df.checkpoint() # here is the checkpoint
        
df.show(10)

time_taken = time() - start_time
print(f"Time taken to create the DataFrame:  {time_taken}")
+---+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------+
| id|col_0|col_1|col_2|col_3|col_4|col_5|col_6|col_7|col_8|col_9|col_10|col_11|
+---+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------+
|  0|    8|    8|    8|    8|    8|    8|    8|    8|    0|    0|     0|     0|
|  1|   11|   11|   11|   11|   11|   11|   11|   11|   11|   11|    11|     0|
|  2|   11|   11|   11|   11|   11|   11|   11|   11|   11|   11|    11|     0|
|  3|   11|   11|   11|   11|   11|   11|   11|   11|   11|   11|    11|     0|
|  4|    6|    6|    6|    6|    6|    6|    0|    0|    0|    0|     0|     0|
|  5|    7|    7|    7|    7|    7|    7|    7|    0|    0|    0|     0|     0|
|  6|    1|    0|    0|    0|    0|    0|    0|    0|    0|    0|     0|     0|
|  7|    2|    2|    0|    0|    0|    0|    0|    0|    0|    0|     0|     0|
|  8|    4|    4|    4|    4|    0|    0|    0|    0|    0|    0|     0|     0|
|  9|    9|    9|    9|    9|    9|    9|    9|    9|    9|    0|     0|     0|
+---+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------+
only showing top 10 rows

Time taken to create the DataFrame:  1.0024805068969727
start_time <- Sys.time()

df1 <- sparklyr::sdf_seq(sc, 1, num_rows) %>%
    sparklyr::mutate(col_0 = ceiling(rand()*new_cols))

for (i in 1:(new_cols - 1))
{
  column_name <- paste0('col_', i)
  prev_column <- paste0('col_', i-1)
  df1 <- df1 %>%
    sparklyr::mutate(
    !!column_name := case_when(
        !!as.symbol(prev_column) > i ~ !!as.symbol(prev_column),
        TRUE ~ 0 ))
  
  
  if (i %% 3 == 0) 
  {
    df1 <- sparklyr::sdf_checkpoint(df1, eager= TRUE)
  }
}

df1 %>%
    head(10) %>%
    sparklyr::collect() %>%
    print()

end_time <- Sys.time()
time_taken <- end_time - start_time


cat("Time taken to create DataFrame: ", time_taken)

The exact times will vary with each run of this notebook, but hopefully you will see that the run using .checkpoint() was faster.

As mentioned earlier, the checkpoint files on HDFS are not deleted automatically. They are not intended to be used after you stop the Spark session, so make sure you delete them once you have finished.

Often the easiest way to delete files is through a GUI, but the cell below is handy to have at the end of your scripts when using checkpoints, so that you don’t forget to empty the checkpoint folder.

import subprocess
cmd = f'hdfs dfs -rm -r -skipTrash {checkpoint_path}' 
p = subprocess.run(cmd, shell=True)

spark.stop()
cmd <- paste0("hdfs dfs -rm -r -skipTrash ", config$checkpoint_path)
p <- system(command = cmd)

sparklyr::spark_disconnect(sc)

How often should I checkpoint?#

How did we come up with checkpointing every 3 iterations? Trial and error. Unfortunately, you may not have the luxury of experimenting to find the optimum number, but have a go at checkpointing and see if you can get any improvement in performance.

More frequent checkpointing means more writing and reading of data, which takes time in itself; the aim is to more than make up for this by keeping the execution plan simple.
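If you do want to experiment, one approach is to wrap the loop from above in a function that takes the checkpoint frequency as a parameter and time each setting. This is only a sketch, reusing the session, constants and checkpoint directory from earlier; time_with_checkpoint_freq and checkpoint_freq are hypothetical names, not part of Spark:

def time_with_checkpoint_freq(checkpoint_freq):
    """Time the iterative calculation, checkpointing every checkpoint_freq iterations."""
    start_time = time()
    df = spark.range(num_rows)
    df = df.withColumn("col_0", F.ceil(F.rand(seed_num) * new_cols))
    for i in range(1, new_cols):
        df = (df.withColumn("col_"+str(i),
                            F.when(F.col("col_"+str(i-1)) > i,
                                   F.col("col_"+str(i-1)))
                            .otherwise(0)))
        if i % checkpoint_freq == 0:
            df = df.checkpoint()
    df.count()  # an action to force execution of the plan
    return time() - start_time

for freq in [2, 3, 5]:
    print(f"Checkpoint every {freq} iterations: {time_with_checkpoint_freq(freq):.2f} seconds")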

As mentioned above, the use of loops shown here is not considered good practice with Spark, but it is a convenient way to demonstrate checkpoints. Of course, checkpointing can also be used outside loops; see the Persisting article for more information on the different forms of persisting data in Spark and their applications.

Staging Tables#

Staging tables are an alternative way of checkpointing data in Spark, in which the data is written out as a named Hive table in a database, rather than to the checkpointing location.

Staging tables: the concept#

You can write a staging table to HDFS with df.write.mode("overwrite").saveAsTable(table_name, format="parquet") or df.write.insertInto(table_name, overwrite=True) (if using .insertInto() you will need to create the table first). You can then read the table back in with spark.read.table(). Like checkpointing, this breaks the lineage of the DataFrame, so staging tables can be useful in large, complex pipelines, or in pipelines that involve processing in a loop. As Spark is more efficient at reading tables than CSV files, another use case is staging CSV files as tables at the start of your code, before doing any complex calculations.
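As a minimal sketch of the pattern (table_name here is a hypothetical database.table_name string; a full worked example follows below):

table_name = "training.staging_example"  # hypothetical name in the format database.table_name

df.write.mode("overwrite").saveAsTable(table_name, format="parquet")  # stage the DataFrame as a Hive table
df = spark.read.table(table_name)                                     # read it back; the lineage now starts here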

Staging has some advantages over checkpointing:

  • The same table can be overwritten, meaning there is no need to clean up old checkpointed files

  • It is stored in a location that is easier to access, rather than the checkpointing folder, which can help with debugging and testing changes to the code

  • They can be re-used elsewhere

  • If .insertInto() is used, you can take advantage of the table schema, as an exception will be raised if the DataFrame and table schemas do not match

  • It is more efficient for Spark to read Hive tables than CSV files as the underlying format is Parquet, so if your data are delivered as CSV files you may want to stage them as Hive tables first.

There are also some disadvantages:

  • It takes longer to write the code

  • They are more difficult to maintain, especially if .insertInto() is used, as you will have to alter the table whenever the DataFrame structure changes

  • They can be overused; make sure you are not staging data unnecessarily (the same is true of any method of persisting data)

The examples here use PySpark, but the same principles apply to R users who are using sparklyr in DAP.

Example#

Our example will show how to read a CSV file, perform some basic data cleansing, stage the result as a Hive table, and then read it back in as a DataFrame.

Staging tables are often most useful in large, complex pipelines; for obvious reasons our example is kept simple!

First, import the relevant modules and create a Spark session:

import os
from pyspark.sql import SparkSession
import pyspark.sql.functions as F 

spark = (SparkSession.builder.master("local[2]")
         .appName("staging-tables")
         .getOrCreate())
library(sparklyr)
library(dplyr)

sc <- sparklyr::spark_connect(
    master = "local[2]",
    app_name = "staging_tables",
    config = sparklyr::spark_config())

Now read in the CSV:

rescue_path = config['rescue_path_csv']
df = spark.read.csv(rescue_path, header = True)
animal_rescue_csv <- config$rescue_path_csv

df <- sparklyr::spark_read_csv(sc,
                              path=animal_rescue_csv,
                              header=TRUE,
                              infer_schema=TRUE)

Then do some preparation: drop and rename some columns, convert DateTimeOfCall to a timestamp, then sort by IncidentNumber.

Note that if saving as a Hive table there are some stricter rules, including:

  • Some characters aren’t allowed in column names, including £

  • The table won’t display in the HUE browser if you use a date type, but it will accept a timestamp

We then preview the DataFrame with .toPandas() (remember to use .limit() when looking at data in this way):

df = (df.
    drop(
        "WardCode", 
        "BoroughCode", 
        "Easting_m", 
        "Northing_m", 
        "Easting_rounded", 
        "Northing_rounded")
    .withColumnRenamed("PumpCount", "EngineCount")
    .withColumnRenamed("FinalDescription", "Description")
    .withColumnRenamed("HourlyNotionalCost(£)", "HourlyCost")
    .withColumnRenamed("IncidentNotionalCost(£)", "TotalCost")
    .withColumnRenamed("OriginofCall", "OriginOfCall")
    .withColumnRenamed("PumpHoursTotal", "JobHours")
    .withColumnRenamed("AnimalGroupParent", "AnimalGroup")
    .withColumn(
        "DateTimeOfCall", F.to_timestamp(F.col("DateTimeOfCall"), "dd/MM/yyyy"))
    .orderBy("IncidentNumber")
    )

df.limit(3).toPandas()
IncidentNumber DateTimeOfCall CalYear FinYear TypeOfIncident EngineCount JobHours HourlyCost TotalCost Description AnimalGroup OriginOfCall PropertyType PropertyCategory SpecialServiceTypeCategory SpecialServiceType Ward Borough StnGroundName PostcodeDistrict
0 000014-03092018M 2018-09-03 2018 2018/19 Special Service 2.0 3.0 333 999.0 None Unknown - Heavy Livestock Animal Other FRS Animal harm outdoors Outdoor Other animal assistance Animal harm involving livestock CARSHALTON SOUTH AND CLOCKHOUSE SUTTON Wallington CR8
1 000099-01012017 2017-01-01 2017 2016/17 Special Service 1.0 2.0 326 652.0 DOG WITH HEAD STUCK IN RAILINGS CALLED BY OWNER Dog Person (mobile) Railings Outdoor Structure Other animal assistance Assist trapped domestic animal BROMLEY TOWN BROMLEY Bromley BR2
2 000260-01012017 2017-01-01 2017 2016/17 Special Service 1.0 1.0 326 326.0 BIRD TRAPPED IN NETTING BY THE 02 SHOP AND NEA... Bird Person (land line) Single shop Non Residential Animal rescue from height Animal rescue from height - Bird Fairfield CROYDON Croydon CR0
df <- df %>%
    sparklyr::select(-c( 
             "WardCode", 
             "BoroughCode", 
             "Easting_m", 
             "Northing_m", 
             "Easting_rounded", 
             "Northing_rounded"), 
            "EngineCount" = "PumpCount",
            "Description" = "FinalDescription",
            "HourlyCost" = "HourlyNotionalCostGBP",
            "TotalCost" = "IncidentNotionalCostGBP",
            "OriginOfCall" = "OriginofCall",
            "JobHours" = "PumpHoursTotal",
            "AnimalGroup" = "AnimalGroupParent") %>%

    sparklyr::mutate(DateTimeOfCall = to_date(DateTimeOfCall, "dd/MM/yyyy")) %>%
    dplyr::arrange(desc(IncidentNumber)) 

df %>%
    head(3) %>%
    sparklyr::collect() %>%
    print() 

Let’s look at the plan with df.explain(). This displays exactly what Spark will do once an action is called (lazy evaluation). This is a simple example, but in long pipelines the plan can get complicated. Using a staging table can split this process, which is referred to as cutting the lineage.

df.explain()
== Physical Plan ==
*(2) Sort [IncidentNumber#322 ASC NULLS FIRST], true, 0
+- Exchange rangepartitioning(IncidentNumber#322 ASC NULLS FIRST, 200)
   +- *(1) Project [IncidentNumber#322, cast(unix_timestamp(DateTimeOfCall#323, dd/MM/yyyy, Some(Europe/London)) as timestamp) AS DateTimeOfCall#541, CalYear#324, FinYear#325, TypeOfIncident#326, PumpCount#327 AS EngineCount#394, PumpHoursTotal#328 AS JobHours#499, HourlyNotionalCost(£)#329 AS HourlyCost#436, IncidentNotionalCost(£)#330 AS TotalCost#457, FinalDescription#331 AS Description#415, AnimalGroupParent#332 AS AnimalGroup#520, OriginofCall#333 AS OriginOfCall#478, PropertyType#334, PropertyCategory#335, SpecialServiceTypeCategory#336, SpecialServiceType#337, Ward#339, Borough#341, StnGroundName#342, PostcodeDistrict#343]
      +- *(1) FileScan csv [IncidentNumber#322,DateTimeOfCall#323,CalYear#324,FinYear#325,TypeOfIncident#326,PumpCount#327,PumpHoursTotal#328,HourlyNotionalCost(£)#329,IncidentNotionalCost(£)#330,FinalDescription#331,AnimalGroupParent#332,OriginofCall#333,PropertyType#334,PropertyCategory#335,SpecialServiceTypeCategory#336,SpecialServiceType#337,Ward#339,Borough#341,StnGroundName#342,PostcodeDistrict#343] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/home/cdsw/ons-spark/ons-spark/data/animal_rescue.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<IncidentNumber:string,DateTimeOfCall:string,CalYear:string,FinYear:string,TypeOfIncident:s...
explain(df)

Now save the DataFrame as a table using mode("overwrite"), which overwrites the existing table if there is one. The first time you create a staging table this option is redundant, but on subsequent runs of the code you will get an error without it, as the table will already exist.

Note that we specify the database we want to save the table in. In this instance, we want to save the table in the training database. The format for saving within a specified database is database.table_name.

username = os.getenv('HADOOP_USER_NAME') 

table_name_plain = config['staging_table_example']
table_name = table_name_plain+username
database = "training"

df.write.mode("overwrite").saveAsTable(f"{database}.{table_name}", format="parquet")
username <- Sys.getenv('HADOOP_USER_NAME')
df <- sparklyr::sdf_register(df, 'df')

database <- config$database

table_name_plain <- config$staging_table_example
table_name <- paste0(table_name_plain, username)

sql <- paste0('DROP TABLE IF EXISTS ', database, '.', table_name)
invisible(DBI::dbExecute(sc, sql))

tbl_change_db(sc, database)
sparklyr::spark_write_table(df, name = table_name)

Now read the data in again and preview:

df = spark.read.table(f"{database}.{table_name}")
df.limit(3).toPandas()
df <- sparklyr::spark_read_table(sc, table_name, repartition = 0)

df %>%
    head(3) %>%
    sparklyr::collect() %>%
    print()

The DataFrame has the same structure as previously, but when we look at the plan with df.explain() we can see that less is being done. This is an example of cutting the lineage and can be useful when you have complex plans.

df.explain()
explain(df)

Using .insertInto()#

Another method is to create an empty table and then use .insertInto(); here we will just use a small number of columns as an example:

small_table = f"train_tmp.staging_small_{username}"

spark.sql(f"""
    CREATE TABLE {small_table} (
        IncidentNumber STRING,
        CalYear INT,
        EngineCount INT,
        AnimalGroup STRING
    )
    STORED AS PARQUET
    """)

Note that the columns will be inserted by position, not name, so it’s a good idea to re-select the column order to match that of the table before inserting:

col_order = spark.read.table(small_table).columns
df.select(col_order).write.insertInto(small_table, overwrite=True)

This can then be read in as before:

df = spark.read.table(small_table)
df.show(5)

Finally we will drop the tables used in this example, which we can do with the DROP SQL statement. This is much easier than deleting a checkpointed file.

Of course, with staging tables you generally want to keep the table, but just overwrite the data each time, so this step often won’t be needed.

Always be very careful when using DROP as this will delete the table without warning!

spark.sql(f"DROP TABLE {table_name}")
spark.sql(f"DROP TABLE {small_table}")

Further Resources#

Spark at the ONS Articles:

PySpark Documentation:

SparklyR Documentation:

Python Documentation:

R Documentation:

Other material: