Returning Data from Cluster to Driver#
This article explores different ways of moving small amounts of data from a PySpark DataFrame, which is lazily evaluated on the Spark cluster, into the driver. Remember that the Spark cluster will have more memory than the driver, so be careful about the amount of data that you are returning.
There are lots of ways to do this; most users will use .show(), the .limit() and .toPandas() combination, or eager evaluation. Several other methods are included for completeness, including the original method, .collect():
.show()
.toPandas()
Eager Evaluation
.collect() and variations
.take()
.first()
.head()
Lazy Evaluation and .printSchema()#
Before looking at each of the methods in turn, it is worth revisiting the concept of lazy evaluation. Spark DataFrames are not evaluated by default; an action has to be called for the cluster to process the Spark plan. .show(), .toPandas() and .collect() are all examples of actions.
Note that if you implicitly print the DataFrame it will just return the schema, rather than evaluating the DataFrame. First, start a Spark session and load the Animal Rescue data:
from pyspark.sql import SparkSession, functions as F
import pandas as pd
import yaml
with open("../../config.yaml") as f:
config = yaml.safe_load(f)
spark = (SparkSession.builder.master("local[2]")
.appName("returning-data")
.getOrCreate())
rescue_path = config["rescue_path"]
rescue = spark.read.parquet(rescue_path)
Then type the name of the DataFrame; this will print the column names and types but will not evaluate the DataFrame:
rescue
DataFrame[incident_number: string, date_time_of_call: string, cal_year: int, fin_year: string, type_of_incident: string, engine_count: double, job_hours: double, hourly_cost: int, total_cost: double, description: string, animal_group: string, origin_of_call: string, property_type: string, property_category: string, special_service_type_category: string, special_service_type: string, ward_code: string, ward: string, borough_code: string, borough: string, stn_ground_name: string, postcode_district: string, easting_m: double, northing_m: double, easting_rounded: int, northing_rounded: int]
A better way of getting the schema is with .printSchema(), which prints it out in a much more readable manner. This is not an action, as the DataFrame does not need to be evaluated to get this information.
rescue.printSchema()
root
|-- incident_number: string (nullable = true)
|-- date_time_of_call: string (nullable = true)
|-- cal_year: integer (nullable = true)
|-- fin_year: string (nullable = true)
|-- type_of_incident: string (nullable = true)
|-- engine_count: double (nullable = true)
|-- job_hours: double (nullable = true)
|-- hourly_cost: integer (nullable = true)
|-- total_cost: double (nullable = true)
|-- description: string (nullable = true)
|-- animal_group: string (nullable = true)
|-- origin_of_call: string (nullable = true)
|-- property_type: string (nullable = true)
|-- property_category: string (nullable = true)
|-- special_service_type_category: string (nullable = true)
|-- special_service_type: string (nullable = true)
|-- ward_code: string (nullable = true)
|-- ward: string (nullable = true)
|-- borough_code: string (nullable = true)
|-- borough: string (nullable = true)
|-- stn_ground_name: string (nullable = true)
|-- postcode_district: string (nullable = true)
|-- easting_m: double (nullable = true)
|-- northing_m: double (nullable = true)
|-- easting_rounded: integer (nullable = true)
|-- northing_rounded: integer (nullable = true)
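To make the distinction between transformations and actions concrete, here is a minimal sketch (the variable name cats is just for illustration): the .filter() call is a transformation that only adds to the Spark plan, so nothing is processed until an action such as .count() is called.
# Transformation: builds up the plan, no data is read yet
cats = rescue.filter(F.col("animal_group") == "Cat")
# Action: triggers the cluster to evaluate the plan and return a number
cats.count()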
.show()#
Using .show() is the easiest way to preview a PySpark DataFrame. By default it will print out 20 rows of the DataFrame.
.show() returns None, so its output cannot usefully be assigned to a variable; it is purely for displaying information.
By default Spark DataFrames are not ordered, because they are distributed across the cluster. This means that calling .show() on the same DataFrame several times can return different results.
The output can look ugly if there are many columns in the DataFrame, so often you will want to use .select() to return only the columns you are interested in.
(rescue
    .select("incident_number", "animal_group", "total_cost", "description")
    .show())
+---------------+--------------------+----------+--------------------+
|incident_number| animal_group|total_cost| description|
+---------------+--------------------+----------+--------------------+
| 80771131| Cat| 290.0|CAT TRAPPED IN BA...|
| 141817141| Horse| 590.0|HORSE TRAPPED IN ...|
|143166-22102016| Bird| 326.0|PIGEON WITH WING ...|
| 43051141| Cat| 295.0|ASSIST RSPCA WITH...|
| 9393131| Dog| 260.0|DOG FALLEN INTO T...|
| 44345121| Deer| 520.0|DEER STUCK IN RAI...|
| 58835101| Deer| 260.0|DEER TRAPPED IN F...|
|126246-03092018| Cat| 333.0|KITTEN TRAPPED BE...|
| 98474151| Fox| 298.0|ASSIST RSPCA WAS ...|
| 17398141| Cat| 1160.0|CAT STUCK IN CAR ...|
| 26486141| Bird| 290.0|PEREGRINE FALCON ...|
| 144750111| Dog| 260.0|DOG IN PRECARIOUS...|
|129971-26092017| Cat| 328.0|KITTEN TRAPPED UN...|
| 113396111| Cat| 260.0|ASSIST RSPCA WITH...|
| 105429101| Horse| 260.0|HORSE STUCK IN DITCH|
|165278-09122017|Unknown - Domesti...| 328.0|CALLERS SMALL PET...|
| 143598091| Cat| 260.0| CAT UP TREE|
| 38468151| Bird| 298.0|BIRD TRAPPED IN C...|
|052371-03052016| Cat| 326.0|ASSIST RSPCA INSP...|
| 156017101| Deer| 520.0|DEER TRAPPED IN R...|
+---------------+--------------------+----------+--------------------+
only showing top 20 rows
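Because the row order is not guaranteed, you can add an .orderBy() before .show() if you want the preview to be reproducible; a minimal sketch:
(rescue
    .select("incident_number", "animal_group", "total_cost")
    # Sort first so repeated previews return the same rows in the same order
    .orderBy("incident_number")
    .show(5))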
.show() has three arguments: n is the number of rows to return (default 20), truncate will truncate long string entries (default True) and vertical will return one row per line (default False).
(rescue
    .select("incident_number", "animal_group", "total_cost", "description")
    .show(n=3, truncate=False, vertical=True))
-RECORD 0---------------------------------------------------------------------------
incident_number | 80771131
animal_group | Cat
total_cost | 290.0
description | CAT TRAPPED IN BASEMENT
-RECORD 1---------------------------------------------------------------------------
incident_number | 141817141
animal_group | Horse
total_cost | 590.0
description | HORSE TRAPPED IN GATE
-RECORD 2---------------------------------------------------------------------------
incident_number | 143166-22102016
animal_group | Bird
total_cost | 326.0
description | PIGEON WITH WING IMAPLED ON SHARP IMPLEMENT UNDER A BRIDGE NEAR
only showing top 3 rows
.toPandas()#
.toPandas() converts a PySpark DataFrame into a pandas DataFrame. You can then use all the usual methods on a pandas DataFrame, and it will have the standard pandas properties of being mutable, with a fixed row order.
Be careful with .toPandas() as it will convert the whole DataFrame, which is an issue when you are dealing with large data on the Spark cluster. Calling .limit() before .toPandas() is a good way to ensure that you do not overload the driver.
You can either print out the result of .toPandas() immediately, or assign it to a variable:
rescue_pandas = (rescue
    .select("incident_number", "animal_group", "total_cost", "description")
    .limit(5)
    .toPandas())
(rescue
    .select("incident_number", "animal_group", "total_cost", "description")
    .limit(5)
    .toPandas())
|   | incident_number | animal_group | total_cost | description |
|---|---|---|---|---|
| 0 | 80771131 | Cat | 290.0 | CAT TRAPPED IN BASEMENT |
| 1 | 141817141 | Horse | 590.0 | HORSE TRAPPED IN GATE |
| 2 | 143166-22102016 | Bird | 326.0 | PIGEON WITH WING IMAPLED ON SHARP IMPLEMENT U... |
| 3 | 43051141 | Cat | 295.0 | ASSIST RSPCA WITH CAT STUCK ON CHIMNEY |
| 4 | 9393131 | Dog | 260.0 | DOG FALLEN INTO THE CANAL |
An alternative to using .limit() is to first check the size of the DataFrame. If it is small there is no need to call .limit():
lizards = (rescue
    .select("incident_number", "animal_group", "total_cost", "description")
    .filter(F.col("animal_group") == "Lizard"))
lizards.count()
3
lizards.toPandas()
|   | incident_number | animal_group | total_cost | description |
|---|---|---|---|---|
| 0 | 117580091 | Lizard | 260.0 | PET LIZARD TRAPPED BEHIND RADIATOR |
| 1 | 74480101 | Lizard | 260.0 | IGUANA TRAPPED BEHIND HEATING PIPES |
| 2 | 070849-04062018 | Lizard | 333.0 | ASSIST RSPCA WITH IGUANA IN TREE (30 FEET HIGH) |
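One way to combine the two approaches is to check the row count and only apply .limit() when the DataFrame is larger than some threshold. A minimal sketch, where the threshold of 10,000 rows and the variable names are just assumptions for illustration:
# Assumed threshold for how many rows are safe to return to the driver
row_threshold = 10000
if lizards.count() <= row_threshold:
    lizards_pd = lizards.toPandas()
else:
    # Only bring back the first row_threshold rows
    lizards_pd = lizards.limit(row_threshold).toPandas()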
Eager Evaluation#
By default, Spark does not display the contents of your DataFrame when you type its name to implicitly print it, instead returning the column names and types; we saw the output from printing rescue earlier.
There is a way to change this behaviour so that typing the name of your DataFrame returns a nicely formatted preview of the results: add .config("spark.sql.repl.eagerEval.enabled", 'true') to the Spark session builder. This makes implicitly printing the DataFrame behave much like .show(). The number of rows returned is limited automatically, so there is no need to apply .limit() to your DataFrame. You can set the number of rows to be returned with .config("spark.sql.repl.eagerEval.maxNumRows", number_of_rows); by default this is 20. Columns are truncated to 20 characters by default; this can be changed with spark.sql.repl.eagerEval.truncate. More information is in the Spark Configuration documentation.
Before setting this property we need to stop the existing Spark session with spark.stop().
If this is something that you find desirable, you may want to add these settings to a spark-defaults.conf file so that they are applied automatically when starting a Spark session.
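For instance, the equivalent spark-defaults.conf entries might look something like the following sketch, reusing the same values as the session builder below; the exact file location depends on your platform.
# Sketch of spark-defaults.conf entries for eager evaluation
spark.sql.repl.eagerEval.enabled true
spark.sql.repl.eagerEval.maxNumRows 10
spark.sql.repl.eagerEval.truncate 100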
spark.stop()
spark = (SparkSession.builder.master("local[2]")
    .appName("returning-data")
    # Enable eager evaluation of PySpark DFs
    .config("spark.sql.repl.eagerEval.enabled", 'true')
    # Maximum rows to return from the DF preview
    .config("spark.sql.repl.eagerEval.maxNumRows", 10)
    # Set number of characters to return per column
    .config("spark.sql.repl.eagerEval.truncate", 100)
    .getOrCreate())
rescue = spark.read.parquet(rescue_path)
rescue.select("incident_number", "animal_group", "total_cost", "description")
| incident_number | animal_group | total_cost | description |
|---|---|---|---|
| 80771131 | Cat | 290.0 | CAT TRAPPED IN BASEMENT |
| 141817141 | Horse | 590.0 | HORSE TRAPPED IN GATE |
| 143166-22102016 | Bird | 326.0 | PIGEON WITH WING IMAPLED ON SHARP IMPLEMENT UNDER A BRIDGE NEAR |
| 43051141 | Cat | 295.0 | ASSIST RSPCA WITH CAT STUCK ON CHIMNEY |
| 9393131 | Dog | 260.0 | DOG FALLEN INTO THE CANAL |
| 44345121 | Deer | 520.0 | DEER STUCK IN RAILINGS |
| 58835101 | Deer | 260.0 | DEER TRAPPED IN FENCE |
| 126246-03092018 | Cat | 333.0 | KITTEN TRAPPED BEHIND BATH |
| 98474151 | Fox | 298.0 | ASSIST RSPCA WAS FOX TRAPPED IN BARBED WIRE |
| 17398141 | Cat | 1160.0 | CAT STUCK IN CAR ENGINE |
.collect() and variations#
.collect() is the original way of moving data from the Spark cluster to the driver, and it can be used directly on RDDs (Resilient Distributed Datasets) as well as DataFrames. Like .toPandas(), it will bring the whole DataFrame back to the driver, so either combine it with .limit() or verify that the DataFrame is not too large with .count() first.
Manipulating the output of .collect() is quite awkward, as it returns a list of Row objects. It is generally better to just use .toPandas(), as pandas DataFrames are easy to manipulate and have built-in methods to convert to other data structures if desired.
The output of .collect() can be assigned to a variable or printed out directly:
rescue_collected = (rescue
    .select("incident_number", "animal_group", "total_cost", "description")
    .limit(5)
    .collect())
rescue_collected
[Row(incident_number='80771131', animal_group='Cat', total_cost=290.0, description='CAT TRAPPED IN BASEMENT'),
Row(incident_number='141817141', animal_group='Horse', total_cost=590.0, description='HORSE TRAPPED IN GATE'),
Row(incident_number='143166-22102016', animal_group='Bird', total_cost=326.0, description='PIGEON WITH WING IMAPLED ON SHARP IMPLEMENT UNDER A BRIDGE NEAR'),
Row(incident_number='43051141', animal_group='Cat', total_cost=295.0, description='ASSIST RSPCA WITH CAT STUCK ON CHIMNEY'),
Row(incident_number='9393131', animal_group='Dog', total_cost=260.0, description='DOG FALLEN INTO THE CANAL')]
(rescue
    .select("incident_number", "animal_group", "total_cost", "description")
    .limit(5)
    .collect())
[Row(incident_number='80771131', animal_group='Cat', total_cost=290.0, description='CAT TRAPPED IN BASEMENT'),
Row(incident_number='141817141', animal_group='Horse', total_cost=590.0, description='HORSE TRAPPED IN GATE'),
Row(incident_number='143166-22102016', animal_group='Bird', total_cost=326.0, description='PIGEON WITH WING IMAPLED ON SHARP IMPLEMENT UNDER A BRIDGE NEAR'),
Row(incident_number='43051141', animal_group='Cat', total_cost=295.0, description='ASSIST RSPCA WITH CAT STUCK ON CHIMNEY'),
Row(incident_number='9393131', animal_group='Dog', total_cost=260.0, description='DOG FALLEN INTO THE CANAL')]
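If you do need to work with the collected rows directly, one option is to convert each Row into a dictionary with .asDict(); a minimal sketch (the variable names are just for illustration):
rows = (rescue
    .select("incident_number", "animal_group", "total_cost")
    .limit(3)
    .collect())
# Convert each Row object into a regular Python dictionary
rows_as_dicts = [row.asDict() for row in rows]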
Even just collecting one column still returns a list of Row objects. You can use an identity .flatMap() to return a normal Python list. Note that .flatMap() is an RDD method, so you need to convert the DataFrame to an RDD first with .rdd:
rescue.select("incident_number").limit(5).rdd.flatMap(lambda x: x).collect()
['80771131', '141817141', '143166-22102016', '43051141', '9393131']
It is much easier in pandas: convert the one-column DataFrame into a series with .squeeze() (or select the column name for implicit conversion), then use .tolist():
rescue.select("incident_number").limit(5).toPandas().squeeze().tolist()
['80771131', '141817141', '143166-22102016', '43051141', '9393131']
rescue.select("incident_number").limit(5).toPandas()["incident_number"].tolist()
['80771131', '141817141', '143166-22102016', '43051141', '9393131']
The main use case for .collect() is returning a value from a one-row DataFrame as a scalar value and then assigning it to a variable:
sum_cost = (rescue
    .agg(F.sum("total_cost"))
    .collect()[0][0])
sum_cost
2012231.0
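An equivalent pattern, shown here as a sketch, uses .first() (covered below) so that you index into a single Row rather than a list of rows:
# .first() returns one Row, so only one index is needed to get the scalar
sum_cost = rescue.agg(F.sum("total_cost")).first()[0]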
.take(): combine .limit() and .collect()#
An alternative to .collect() is .take(num), which will return at most num rows (or the whole DataFrame if it has fewer than num rows). This is equivalent to .limit(num).collect(). Remember that .collect() will bring the entire DataFrame into the driver, so if it is large you need to reduce its size first; using .take() removes this problem.
.take() has no default value, so you must supply num (unlike .show(), which has a default of 20).
Just like .collect(), you can assign the output of .take() to a variable or print it out directly:
rescue_take = rescue.select("incident_number", "animal_group", "total_cost", "description").take(5)
rescue_take
[Row(incident_number='80771131', animal_group='Cat', total_cost=290.0, description='CAT TRAPPED IN BASEMENT'),
Row(incident_number='141817141', animal_group='Horse', total_cost=590.0, description='HORSE TRAPPED IN GATE'),
Row(incident_number='143166-22102016', animal_group='Bird', total_cost=326.0, description='PIGEON WITH WING IMAPLED ON SHARP IMPLEMENT UNDER A BRIDGE NEAR'),
Row(incident_number='43051141', animal_group='Cat', total_cost=295.0, description='ASSIST RSPCA WITH CAT STUCK ON CHIMNEY'),
Row(incident_number='9393131', animal_group='Dog', total_cost=260.0, description='DOG FALLEN INTO THE CANAL')]
.first(): return one row#
If you only want one row, use .first(). The advantage of .first() over .take(1) or .limit(1).collect() is that it returns a single Row object, rather than a Row object within a list of length 1. This means the code looks neater if you only want to extract one value.
As PySpark DataFrames are not ordered by default, .first() is non-deterministic, so it is generally combined with .orderBy():
rescue_first = (rescue
    .select("incident_number", "animal_group", "total_cost", "description")
    .orderBy(F.desc("total_cost"))
    .first())
rescue_first
Row(incident_number='098141-28072016', animal_group='Cat', total_cost=3912.0, description='CAT STUCK WITHIN WALL SPACE RSPCA IN ATTENDANCE')
This Row object is not in a list, and so values can be referenced directly:
top_cost = rescue_first["total_cost"]
top_cost
3912.0
Note that if the DataFrame is empty this will return None:
first_none = (rescue
    .filter(F.col("animal_group") == "Dragon")
    .first())
print(first_none)
None
Whereas .collect() and .take() will return an empty list:
collect_empty = (rescue
    .filter(F.col("animal_group") == "Dragon")
    .collect())
collect_empty
[]
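Because of this, when the DataFrame might be empty it is worth guarding against None before extracting a value. A minimal sketch (the variable names are just for illustration):
dragon_first = (rescue
    .filter(F.col("animal_group") == "Dragon")
    .first())
# .first() returns None for an empty DataFrame, so check before indexing
dragon_cost = dragon_first["total_cost"] if dragon_first is not None else None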
.head(): a confusing function#
Another method included for completeness is .head(). This works in the same way as .take() if you specify the number of rows (returning a list of Row objects, even if there is only one row), or .first() if you do not (returning a single Row object). This can get confusing, so it is better to just use .take() or .first() instead.
# Same as .first()
rescue_blank = rescue.select("incident_number", "animal_group", "total_cost", "description").head()
# Returns a Row object
rescue_blank
Row(incident_number='80771131', animal_group='Cat', total_cost=290.0, description='CAT TRAPPED IN BASEMENT')
# Same as .take(1)
rescue_blank = rescue.select("incident_number", "animal_group", "total_cost", "description").head(1)
# Returns a list of Row objects (only one item in this case)
rescue_blank
[Row(incident_number='80771131', animal_group='Cat', total_cost=290.0, description='CAT TRAPPED IN BASEMENT')]
Source Code#
For these functions it is interesting to take a look at the PySpark source code; normally all you need is in the documentation, but here you can see how the methods actually work. For instance, df.take(10) is the same as df.limit(10).collect().
Further Resources#
Spark at the ONS Articles:
RDDs (Resilient Distributed Datasets)
PySpark Documentation:
Spark Configuration: details of spark.sql.repl.eagerEval.enabled, spark.sql.repl.eagerEval.maxNumRows and spark.sql.repl.eagerEval.truncate
PySpark Source Code: