Returning Data from Cluster to Driver#

This article explores different ways of moving small amounts of data from a PySpark DataFrame, which is lazily evaluated on the Spark cluster, into the driver. Remember that the Spark cluster will have more memory than the driver, so be careful about the amount of data that you are returning.

There are lots of ways to do this; most users will use .show(), the .limit() and .toPandas() combination, or eager evaluation. Several other methods are included for completeness, including the original method, .collect():

  • .show()

  • .toPandas()

  • Eager Evaluation

  • .collect() and Variations

    • .take()

    • .first()

    • .head()

Lazy Evaluation and .printSchema()#

Before looking at each of the methods in turn, it is worth revisiting the concept of lazy evaluation. Spark DataFrames are not evaluated by default. An action has to be called in order for the cluster to process the Spark plan. .show(), .toPandas() and .collect() are all examples of actions.

Note that if you implicitly print the DataFrame (by just typing its name) it will only return the schema, rather than evaluating the DataFrame. First, start a Spark session and load the Animal Rescue data:

from pyspark.sql import SparkSession, functions as F

import pandas as pd
import yaml

with open("../../config.yaml") as f:
    config = yaml.safe_load(f)
    
spark = (SparkSession.builder.master("local[2]")
         .appName("returning-data")
         .getOrCreate())

rescue_path = config["rescue_path"]

rescue = spark.read.parquet(rescue_path)

Then type the name of the DataFrame; this will print the column names and types but will not evaluate the DataFrame:

rescue
DataFrame[incident_number: string, date_time_of_call: string, cal_year: int, fin_year: string, type_of_incident: string, engine_count: double, job_hours: double, hourly_cost: int, total_cost: double, description: string, animal_group: string, origin_of_call: string, property_type: string, property_category: string, special_service_type_category: string, special_service_type: string, ward_code: string, ward: string, borough_code: string, borough: string, stn_ground_name: string, postcode_district: string, easting_m: double, northing_m: double, easting_rounded: int, northing_rounded: int]

A better way of getting the schema is with .printSchema(), which prints it in a much more readable format. This is not an action, as the DataFrame does not need to be evaluated to get this information.

rescue.printSchema()
root
 |-- incident_number: string (nullable = true)
 |-- date_time_of_call: string (nullable = true)
 |-- cal_year: integer (nullable = true)
 |-- fin_year: string (nullable = true)
 |-- type_of_incident: string (nullable = true)
 |-- engine_count: double (nullable = true)
 |-- job_hours: double (nullable = true)
 |-- hourly_cost: integer (nullable = true)
 |-- total_cost: double (nullable = true)
 |-- description: string (nullable = true)
 |-- animal_group: string (nullable = true)
 |-- origin_of_call: string (nullable = true)
 |-- property_type: string (nullable = true)
 |-- property_category: string (nullable = true)
 |-- special_service_type_category: string (nullable = true)
 |-- special_service_type: string (nullable = true)
 |-- ward_code: string (nullable = true)
 |-- ward: string (nullable = true)
 |-- borough_code: string (nullable = true)
 |-- borough: string (nullable = true)
 |-- stn_ground_name: string (nullable = true)
 |-- postcode_district: string (nullable = true)
 |-- easting_m: double (nullable = true)
 |-- northing_m: double (nullable = true)
 |-- easting_rounded: integer (nullable = true)
 |-- northing_rounded: integer (nullable = true)

.show()#

Using .show() is the easiest way to preview a PySpark DataFrame. By default it will print out 20 rows of the DataFrame.

You cannot assign the results of .show() to a variable, so it is purely for information.
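
The return value of .show() is just None, so there is nothing useful to assign; a quick illustrative check using the rescue DataFrame loaded above:

# .show() prints the preview as a side effect and returns None
returned_value = rescue.select("incident_number").show(n=1)
print(returned_value)  # prints None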

By default Spark DataFrames are not ordered, because they are distributed across the cluster. This means that calling .show() on the same DataFrame several times can return different results.
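
If you need a repeatable preview, one option is to sort the DataFrame before calling .show(); for instance (sorting on incident_number purely for illustration, output not shown):

# Sorting first makes the preview order deterministic
rescue.orderBy("incident_number").show(n=3)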

This can look ugly if there are many columns in the DataFrame, so often you will want to use .select() to only return the columns you are interested in.

(rescue
    .select("incident_number", "animal_group", "total_cost", "description")
    .show())
+---------------+--------------------+----------+--------------------+
|incident_number|        animal_group|total_cost|         description|
+---------------+--------------------+----------+--------------------+
|       80771131|                 Cat|     290.0|CAT TRAPPED IN BA...|
|      141817141|               Horse|     590.0|HORSE TRAPPED IN ...|
|143166-22102016|                Bird|     326.0|PIGEON WITH WING ...|
|       43051141|                 Cat|     295.0|ASSIST RSPCA WITH...|
|        9393131|                 Dog|     260.0|DOG FALLEN INTO T...|
|       44345121|                Deer|     520.0|DEER STUCK IN RAI...|
|       58835101|                Deer|     260.0|DEER TRAPPED IN F...|
|126246-03092018|                 Cat|     333.0|KITTEN TRAPPED BE...|
|       98474151|                 Fox|     298.0|ASSIST RSPCA WAS ...|
|       17398141|                 Cat|    1160.0|CAT STUCK IN CAR ...|
|       26486141|                Bird|     290.0|PEREGRINE FALCON ...|
|      144750111|                 Dog|     260.0|DOG IN PRECARIOUS...|
|129971-26092017|                 Cat|     328.0|KITTEN TRAPPED UN...|
|      113396111|                 Cat|     260.0|ASSIST RSPCA WITH...|
|      105429101|               Horse|     260.0|HORSE STUCK IN DITCH|
|165278-09122017|Unknown - Domesti...|     328.0|CALLERS SMALL PET...|
|      143598091|                 Cat|     260.0|         CAT UP TREE|
|       38468151|                Bird|     298.0|BIRD TRAPPED IN C...|
|052371-03052016|                 Cat|     326.0|ASSIST RSPCA INSP...|
|      156017101|                Deer|     520.0|DEER TRAPPED IN R...|
+---------------+--------------------+----------+--------------------+
only showing top 20 rows

.show() has three arguments: n, the number of rows to return (default 20); truncate, which truncates long string entries (default True); and vertical, which returns one row per line (default False).

(rescue
    .select("incident_number", "animal_group", "total_cost", "description")
    .show(n=3, truncate=False, vertical=True))
-RECORD 0---------------------------------------------------------------------------
 incident_number | 80771131                                                         
 animal_group    | Cat                                                              
 total_cost      | 290.0                                                            
 description     | CAT TRAPPED IN BASEMENT                                          
-RECORD 1---------------------------------------------------------------------------
 incident_number | 141817141                                                        
 animal_group    | Horse                                                            
 total_cost      | 590.0                                                            
 description     | HORSE TRAPPED IN GATE                                            
-RECORD 2---------------------------------------------------------------------------
 incident_number | 143166-22102016                                                  
 animal_group    | Bird                                                             
 total_cost      | 326.0                                                            
 description     | PIGEON WITH WING IMAPLED ON SHARP IMPLEMENT  UNDER A BRIDGE NEAR 
only showing top 3 rows

.toPandas()#

.toPandas() converts a PySpark DataFrame into a pandas DataFrame. You can then use all the usual pandas methods, and the result has the standard pandas properties of being mutable and having a fixed row order.

Be careful with .toPandas() as it will convert the whole DataFrame, which is an issue when you are dealing with large data on the Spark cluster. Calling .limit() before .toPandas() is a good way to ensure that you do not overload the driver.

You can either assign the result of .toPandas() to a variable or print it out directly:

rescue_pandas = (rescue
                 .select("incident_number", "animal_group", "total_cost", "description")
                 .limit(5)
                 .toPandas())
rescue_pandas
  incident_number animal_group  total_cost                                        description
0        80771131          Cat       290.0                            CAT TRAPPED IN BASEMENT
1       141817141        Horse       590.0                              HORSE TRAPPED IN GATE
2 143166-22102016         Bird       326.0  PIGEON WITH WING IMAPLED ON SHARP IMPLEMENT U...
3        43051141          Cat       295.0             ASSIST RSPCA WITH CAT STUCK ON CHIMNEY
4         9393131          Dog       260.0                          DOG FALLEN INTO THE CANAL
(rescue
    .select("incident_number", "animal_group", "total_cost", "description")
    .limit(5)
    .toPandas())
  incident_number animal_group  total_cost                                        description
0        80771131          Cat       290.0                            CAT TRAPPED IN BASEMENT
1       141817141        Horse       590.0                              HORSE TRAPPED IN GATE
2 143166-22102016         Bird       326.0  PIGEON WITH WING IMAPLED ON SHARP IMPLEMENT U...
3        43051141          Cat       295.0             ASSIST RSPCA WITH CAT STUCK ON CHIMNEY
4         9393131          Dog       260.0                          DOG FALLEN INTO THE CANAL

An alternative to using .limit() is to first check the size of the DataFrame. If it is small there is no need to call .limit():

lizards = (rescue
           .select("incident_number", "animal_group", "total_cost", "description")
           .filter(F.col("animal_group") == "Lizard"))
lizards.count()
3
lizards.toPandas()
  incident_number animal_group  total_cost                                      description
0       117580091       Lizard       260.0               PET LIZARD TRAPPED BEHIND RADIATOR
1        74480101       Lizard       260.0              IGUANA TRAPPED BEHIND HEATING PIPES
2 070849-04062018       Lizard       333.0  ASSIST RSPCA WITH IGUANA IN TREE (30 FEET HIGH)

Eager Evaluation#

By default, Spark does not display the contents of your DataFrame when you type its name to implicitly print it, instead returning the column names and types; we saw this output from printing rescue earlier.

There is a way to change this behaviour so that typing the name of your DataFrame returns a nicely formatted preview of the results: add .config("spark.sql.repl.eagerEval.enabled", 'true') to the Spark session builder. This will make the behaviour of typing the DF similar to .show(). The number of rows returned will be limited automatically so there is no need to apply the .limit() command to your DataFrame. You can set the number of rows to be returned with .config("spark.sql.repl.eagerEval.maxNumRows", number_of_rows); by default this is 20. Columns are truncated to 20 characters by default; this can be changed with spark.sql.repl.eagerEval.truncate. More information is in the Spark Configuration documentation.

Before setting this property we need to stop the existing Spark session with spark.stop().

If this is something that you find desirable you may want to add these settings to a spark-defaults.conf file, so they are applied automatically when starting a Spark session.
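
For reference, the equivalent spark-defaults.conf entries would look something like the sketch below; the exact file location depends on how your Spark environment is set up:

spark.sql.repl.eagerEval.enabled      true
spark.sql.repl.eagerEval.maxNumRows   10
spark.sql.repl.eagerEval.truncate     100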

spark.stop()

spark = (SparkSession.builder.master("local[2]")
         .appName("returning-data")
         # Enable eager evaluation of PySpark DFs
         .config("spark.sql.repl.eagerEval.enabled", 'true')
         # Maximum rows to return from the DF preview
         .config("spark.sql.repl.eagerEval.maxNumRows", 10)
         # Set number of characters to return per column
         .config("spark.sql.repl.eagerEval.truncate", 100)
         .getOrCreate())

rescue = spark.read.parquet(rescue_path)

rescue.select("incident_number", "animal_group", "total_cost", "description")
| incident_number | animal_group | total_cost | description                                                      |
| 80771131        | Cat          | 290.0      | CAT TRAPPED IN BASEMENT                                          |
| 141817141       | Horse        | 590.0      | HORSE TRAPPED IN GATE                                            |
| 143166-22102016 | Bird         | 326.0      | PIGEON WITH WING IMAPLED ON SHARP IMPLEMENT UNDER A BRIDGE NEAR  |
| 43051141        | Cat          | 295.0      | ASSIST RSPCA WITH CAT STUCK ON CHIMNEY                           |
| 9393131         | Dog          | 260.0      | DOG FALLEN INTO THE CANAL                                        |
| 44345121        | Deer         | 520.0      | DEER STUCK IN RAILINGS                                           |
| 58835101        | Deer         | 260.0      | DEER TRAPPED IN FENCE                                            |
| 126246-03092018 | Cat          | 333.0      | KITTEN TRAPPED BEHIND BATH                                       |
| 98474151        | Fox          | 298.0      | ASSIST RSPCA WAS FOX TRAPPED IN BARBED WIRE                      |
| 17398141        | Cat          | 1160.0     | CAT STUCK IN CAR ENGINE                                          |
only showing top 10 rows

.collect() and variations#

.collect() is the original way of returning data from the cluster to the driver; it can be used directly on RDDs (Resilient Distributed Datasets) as well as DataFrames. Like .toPandas(), this will bring back the whole DataFrame to the driver, so either combine it with .limit() or verify that the DataFrame is not too large with .count() first.

Manipulating the output of .collect() is quite awkward, as it returns a list of Row objects. It is generally better to use .toPandas() instead, as pandas DataFrames are easy to manipulate and have built-in methods to convert to other data structures if desired.

The output of .collect() can be assigned to a variable or printed out directly:

rescue_collected = (rescue
                    .select("incident_number", "animal_group", "total_cost", "description")
                    .limit(5)
                    .collect())
rescue_collected
[Row(incident_number='80771131', animal_group='Cat', total_cost=290.0, description='CAT TRAPPED IN BASEMENT'),
 Row(incident_number='141817141', animal_group='Horse', total_cost=590.0, description='HORSE TRAPPED IN GATE'),
 Row(incident_number='143166-22102016', animal_group='Bird', total_cost=326.0, description='PIGEON WITH WING IMAPLED ON SHARP IMPLEMENT  UNDER A BRIDGE NEAR'),
 Row(incident_number='43051141', animal_group='Cat', total_cost=295.0, description='ASSIST RSPCA WITH CAT STUCK ON CHIMNEY'),
 Row(incident_number='9393131', animal_group='Dog', total_cost=260.0, description='DOG FALLEN INTO THE CANAL')]
(rescue
    .select("incident_number", "animal_group", "total_cost", "description")
    .limit(5)
    .collect())
[Row(incident_number='80771131', animal_group='Cat', total_cost=290.0, description='CAT TRAPPED IN BASEMENT'),
 Row(incident_number='141817141', animal_group='Horse', total_cost=590.0, description='HORSE TRAPPED IN GATE'),
 Row(incident_number='143166-22102016', animal_group='Bird', total_cost=326.0, description='PIGEON WITH WING IMAPLED ON SHARP IMPLEMENT  UNDER A BRIDGE NEAR'),
 Row(incident_number='43051141', animal_group='Cat', total_cost=295.0, description='ASSIST RSPCA WITH CAT STUCK ON CHIMNEY'),
 Row(incident_number='9393131', animal_group='Dog', total_cost=260.0, description='DOG FALLEN INTO THE CANAL')]
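
For example, pulling the costs out of the collected rows means looping over the list and indexing into each Row:

# Extract a single field from each Row in the collected list
costs = [row["total_cost"] for row in rescue_collected]
costs
[290.0, 590.0, 326.0, 295.0, 260.0]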

Even just collecting one column still returns a list of Row objects. You can use an identity .flatMap() to return a normal Python list. Note that .flatMap() is an RDD method, so you need to convert the DataFrame to an RDD first with .rdd:

rescue.select("incident_number").limit(5).rdd.flatMap(lambda x: x).collect()
['80771131', '141817141', '143166-22102016', '43051141', '9393131']

It is much easier in pandas: convert the one-column DataFrame into a series with .squeeze() (or select the column by name for implicit conversion), then use .tolist():

rescue.select("incident_number").limit(5).toPandas().squeeze().tolist()
['80771131', '141817141', '143166-22102016', '43051141', '9393131']
rescue.select("incident_number").limit(5).toPandas()["incident_number"].tolist()
['80771131', '141817141', '143166-22102016', '43051141', '9393131']

The main use case for .collect() is returning a value from a one-row DataFrame as a scalar and assigning it to a variable:

sum_cost = (rescue
            .agg(F.sum("total_cost"))
            .collect()[0][0])
sum_cost
2012231.0
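
If you would rather reference the value by name than by position, one small variation is to alias the aggregated column first:

sum_cost = (rescue
            .agg(F.sum("total_cost").alias("sum_cost"))
            # Reference the aliased column by name rather than position
            .collect()[0]["sum_cost"])
sum_cost
2012231.0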

.take(): combine .limit() and .collect()#

An alternative to .collect() is .take(num), which returns at most num rows (or the whole DataFrame if it has fewer than num rows). This is equivalent to .limit(num).collect(). Remember that .collect() brings the entire DataFrame into the driver, so if it is large you need to reduce its size first; .take() avoids this problem.

.take() has no default value, so you must supply num (unlike .show(), which has a default of 20).

Just like .collect() you can assign the output of .take() to a variable or print out directly:

rescue_take = rescue.select("incident_number", "animal_group", "total_cost", "description").take(5)
rescue_take
[Row(incident_number='80771131', animal_group='Cat', total_cost=290.0, description='CAT TRAPPED IN BASEMENT'),
 Row(incident_number='141817141', animal_group='Horse', total_cost=590.0, description='HORSE TRAPPED IN GATE'),
 Row(incident_number='143166-22102016', animal_group='Bird', total_cost=326.0, description='PIGEON WITH WING IMAPLED ON SHARP IMPLEMENT  UNDER A BRIDGE NEAR'),
 Row(incident_number='43051141', animal_group='Cat', total_cost=295.0, description='ASSIST RSPCA WITH CAT STUCK ON CHIMNEY'),
 Row(incident_number='9393131', animal_group='Dog', total_cost=260.0, description='DOG FALLEN INTO THE CANAL')]

.first(): return one row#

If you only want one row, use .first(). The advantage of .first() over .take(1) or .limit(1).collect() is that a single Row object is returned, rather than a Row object within a list of length 1. This means the code looks neater if you only want to extract one value.

As PySpark DataFrames are not ordered by default, .first() is non-deterministic, so it is generally combined with .orderBy():

rescue_first = (rescue
                .select("incident_number", "animal_group", "total_cost", "description")
                .orderBy(F.desc("total_cost"))
                .first())
rescue_first
Row(incident_number='098141-28072016', animal_group='Cat', total_cost=3912.0, description='CAT STUCK WITHIN WALL SPACE  RSPCA IN ATTENDANCE')

This Row object is not in a list, and so values can be referenced directly:

top_cost = rescue_first["total_cost"]
top_cost
3912.0
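
Row objects also support attribute-style access, so the following is equivalent:

top_cost = rescue_first.total_cost
top_cost
3912.0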

Note that if the DataFrame is empty, .first() will return None:

first_none = (rescue
              .filter(F.col("animal_group") == "Dragon")
              .first())
print(first_none)
None

Whereas .collect() and .take() will return an empty list:

collect_empty = (rescue
                 .filter(F.col("animal_group") == "Dragon")
                 .collect())
collect_empty
[]
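
When a filter may match nothing it is worth guarding against an empty result before extracting values; a minimal sketch:

# Guard against .first() returning None when no rows match
dragon = rescue.filter(F.col("animal_group") == "Dragon").first()
if dragon is None:
    print("No rows matched the filter")
else:
    print(dragon["incident_number"])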

.head(): a confusing function#

Another method included for completeness is .head(). This works in the same way as .take() if you specify the number of rows (returning a list of Row objects, even if it only has one row), or .first() if not (returning a single Row object). This could potentially get confusing, so it is better to use .take() or .first() instead.

# Same as .first()
rescue_blank = rescue.select("incident_number", "animal_group", "total_cost", "description").head()

# Returns a Row object
rescue_blank
Row(incident_number='80771131', animal_group='Cat', total_cost=290.0, description='CAT TRAPPED IN BASEMENT')
# Same as .take(1)
rescue_blank = rescue.select("incident_number", "animal_group", "total_cost", "description").head(1)

# Returns a list of Row objects (only one item in this case)
rescue_blank
[Row(incident_number='80771131', animal_group='Cat', total_cost=290.0, description='CAT TRAPPED IN BASEMENT')]

Source Code#

For these functions it is interesting to take a look at the PySpark source code; normally all you need is in the documentation, but here you can see how the methods actually work. For instance, df.take(10) is the same as df.limit(10).collect().
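
As an illustration, the definition of .take() in the PySpark source is essentially a one-line wrapper; this is paraphrased, so check the source for the version you are using:

# Paraphrased from pyspark/sql/dataframe.py
def take(self, num):
    return self.limit(num).collect()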

Further Resources#

Spark at the ONS Articles:

PySpark Documentation:

PySpark Source Code: