Learn Spark with me

Posted Saturday, November 7th, 2020

Python, Apache Spark

Prerequisites

Here are a few things you will need for this tutorial.

  • The GitHub repo, which contains all datasets and notebooks.
  • A computer with Spark installed. Installation links are given below.

What is Spark?

Spark is a general-purpose distributed computing engine, often described as a unified analytics engine for large-scale data processing. As ever more data is generated by our daily digital activities, actions, and footprints, tools like Spark have emerged to make that data easy to load, transform, and query, among other tasks. This is how businesses and big tech companies process huge datasets to uncover the insights hidden in them.

Among the things you can do with Spark are processing streaming data, such as logs, and machine learning.

Use cases

Spark is useful for SQL, stream processing, machine learning, and graph computations. I found this Google article that explains in simple words how Spark is useful. I recommend reading it because it contains some really good insights.

Apache Spark is written in Scala and supports several languages, including Java, Python, R, and Scala. For this tutorial, I will use PySpark, the Python library, so all my examples will be in Python.

Installing Spark and PySpark

To install PySpark, I used the tutorial here to set up Spark on Mac OSX. For other platforms like Linux, you can try this one.

At the end of the installation, you should be able to issue the pyspark command in your terminal, which starts a Jupyter session and opens a browser where you can create notebooks.

~/Code/learning/spark:- pyspark
[I 20:38:31.625 NotebookApp] JupyterLab extension loaded from /Users/iam/opt/anaconda3/lib/python3.8/site-packages/jupyterlab
[I 20:38:31.625 NotebookApp] JupyterLab application directory is /Users/iam/opt/anaconda3/share/jupyter/lab
[I 20:38:31.628 NotebookApp] Serving notebooks from local directory: /Volumes/Code/learning/spark
[I 20:38:31.628 NotebookApp] The Jupyter Notebook is running at:
[I 20:38:31.628 NotebookApp] http://localhost:8888/?token=ddaa7bad24bf21d32b35dcc5260df6152badc4caf698f022
[I 20:38:31.628 NotebookApp]  or http://127.0.0.1:8888/?token=ddaa7bad24bf21d32b35dcc5260df6152badc4caf698f022
[I 20:38:31.628 NotebookApp] Use Control-C to stop this server and shut down all kernels (twice to skip confirmation).

For ease of setup, you can clone the repository, start your shell at its root, and then start pyspark from there. This way, all the notebooks here will work out of the box without modification, unless you are on a newer version of the PySpark stack. It also saves you from downloading the datasets separately.

Data Sources

You can load data into Spark from different sources. The ones we will look at here are:

  • JSON file
  • CSV file
  • JDBC Connection.

Spark Data Formats

Spark supports different data structures into which we can load our data, mainly three.

  • RDD
  • DataFrame
  • DataSet

There is a good comparison here between them with details on performance and features.

Both the DataFrame and DataSet formats are built on top of the RDD format. For this post, I will focus on DataFrame in my examples. Usually, the choice of data format depends on the kind of transforms and control you want.

Sample Datasets

I have added some datasets here downloaded from Kaggle. These are the ones I will use throughout the tutorial.

Download the datasets from the GitHub repo for this tutorial. You can also use the provided Kaggle links to download the files. As for the population dataset, I made a few changes.

Loading Data Into Spark

Let us see how to load data of different formats into a Spark DataFrame that we can then interact with as we desire.

Loading CSV Data

Let us now load our sample data into Spark. I will start with the Covid-19 CSV data. To load a CSV file into Spark, we just give the PySpark load function the path to the file.

import pyspark
from pyspark import SparkContext, SQLContext
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)

df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("covid_dataset/country_wise_latest.csv"))

df.printSchema()
"""
root
 |-- Country/Region: string (nullable = true)
 |-- Confirmed: string (nullable = true)
 |-- Deaths: string (nullable = true)
 |-- Recovered: string (nullable = true)
 |-- Active: string (nullable = true)
 |-- New cases: string (nullable = true)
 |-- New deaths: string (nullable = true)
 |-- New recovered: string (nullable = true)
 |-- Deaths / 100 Cases: string (nullable = true)
 |-- Recovered / 100 Cases: string (nullable = true)
 |-- Deaths / 100 Recovered: string (nullable = true)
 |-- Confirmed last week: string (nullable = true)
 |-- 1 week change: string (nullable = true)
 |-- 1 week % increase: string (nullable = true)
 |-- WHO Region: string (nullable = true)
"""
df.show()
"""
 +--------------+---------+------+---------+------+---------+----------+-------------+------------------+---------------------+----------------------+-------------------+-------------+-----------------+--------------------+
|Country/Region|Confirmed|Deaths|Recovered|Active|New cases|New deaths|New recovered|Deaths / 100 Cases|Recovered / 100 Cases|Deaths / 100 Recovered|Confirmed last week|1 week change|1 week % increase|          WHO Region|
+--------------+---------+------+---------+------+---------+----------+-------------+------------------+---------------------+----------------------+-------------------+-------------+-----------------+--------------------+
|   Afghanistan|    36263|  1269|    25198|  9796|      106|        10|           18|               3.5|                69.49|                  5.04|              35526|          737|             2.07|Eastern Mediterra...|
|       Albania|     4880|   144|     2745|  1991|      117|         6|           63|              2.95|                56.25|                  5.25|               4171|          709|             17.0|              Europe|
|       Algeria|    27973|  1163|    18837|  7973|      616|         8|          749|              4.16|                67.34|                  6.17|              23691|         4282|            18.07|              Africa|
+--------------+---------+------+---------+------+---------+----------+-------------+------------------+---------------------+----------------------+-------------------+-------------+-----------------+--------------------+
"""

Loading JSON Data

Let us now load the world population JSON data. To load a JSON file into Spark, we just give the PySpark json function the path to the file. For this JSON data, I also define an explicit schema for Spark to apply.

import pyspark.sql.types as Types
from pyspark.sql import SparkSession

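# Placeholder values (assumptions): any app name works, and "local[*]" runs Spark locally on all cores.
appName = "learn-spark"
master = "local[*]"

spark = SparkSession.builder \
    .appName(appName) \
    .master(master) \
    .getOrCreate()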

schema = Types.StructType([
    Types.StructField('Country/Region', Types.StringType(), True),
    Types.StructField('Population', Types.StringType(), True),
    Types.StructField('Urban Pop %', Types.StringType(), True),
    Types.StructField('World Share %', Types.StringType(), True),
    Types.StructField('Med. Age', Types.StringType(), True)
])

df = spark\
    .read\
    .json("world_population/population.json", schema, multiLine=True)

df.printSchema()
"""
root
 |-- Country/Region: string (nullable = true)
 |-- Population: string (nullable = true)
 |-- Urban Pop %: string (nullable = true)
 |-- World Share %: string (nullable = true)
 |-- Med. Age: string (nullable = true)
"""
df.limit(3).show()
"""
+--------------+----------+-----------+-------------+--------+
|Country/Region|Population|Urban Pop %|World Share %|Med. Age|
+--------------+----------+-----------+-------------+--------+
|         China|1440297825|         61|        18.47|      38|
|         India|1382345085|         35|        17.70|      28|
| United States| 331341050|         83|         4.25|      38|
+--------------+----------+-----------+-------------+--------+

"""

Loading Data from a DBMS (PostgreSQL)

For this, you will need to download the sample database from this link and load it into your PostgreSQL instance.

First, download a version of the JDBC driver from here that Spark will use to connect to the PostgreSQL instance. Place the jar file in a suitable location, then start Spark with the option below that matches the driver version you have downloaded.

Note that you may have connection issues and may end up having to do some digging to get it to work.

pyspark --packages org.postgresql:postgresql:42.2.16

After this, open a new notebook and create a new DataFrame. Here I am creating a new DataFrame from the payment table in the database named dvdrental.

from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark SQL basic example") \
    .config("spark.jars", "jars/postgresql-42.2.16.jar") \
    .getOrCreate()

df = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:postgresql://localhost:5432/dvdrental") \
    .option("dbtable", "payment") \
    .option("user", "postgres") \
    .option("password", "postgres") \
    .option("driver", "org.postgresql.Driver") \
    .load()

df.printSchema()
"""
root
 |-- payment_id: integer (nullable = true)
 |-- customer_id: short (nullable = true)
 |-- staff_id: short (nullable = true)
 |-- rental_id: integer (nullable = true)
 |-- amount: decimal(5,2) (nullable = true)
 |-- payment_date: timestamp (nullable = true)
"""

You can also run raw SQL against the data loaded from the DBMS. You do so by registering the loaded Spark DataFrame as a temporary view.

from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark SQL basic example") \
    .config("spark.jars", "jars/postgresql-42.2.16.jar") \
    .getOrCreate()

df = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:postgresql://localhost:5432/dvdrental") \
    .option("dbtable", "payment") \
    .option("user", "postgres") \
    .option("password", "postgres") \
    .option("driver", "org.postgresql.Driver") \
    .load()

# createOrReplaceTempView returns None, so its result is not assigned.
df.createOrReplaceTempView("payment_table_view")
df = spark.sql("Select * from payment_table_view where amount > 2.00")
df.limit(3).show()
"""
+----------+-----------+--------+---------+------+--------------------+
|payment_id|customer_id|staff_id|rental_id|amount|        payment_date|
+----------+-----------+--------+---------+------+--------------------+
|     17503|        341|       2|     1520|  7.99|2007-02-15 22:25:...|
|     17505|        341|       1|     1849|  7.99|2007-02-16 22:41:...|
|     17506|        341|       2|     2829|  2.99|2007-02-19 19:39:...|
+----------+-----------+--------+---------+------+--------------------+
"""

Select, Limit, Show

Use select to pick the desired columns; it takes a single column or a list of columns. You can also limit the number of rows to fetch. The show function prints the rows to the console. Select can also be used to select nested columns (see the from_json section).

import pyspark
from pyspark import SparkContext, SQLContext
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)

df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("covid_dataset/country_wise_latest.csv"))
df.select("Country/Region").limit(3).show()
"""
+--------------+
|Country/Region|
+--------------+
|   Afghanistan|
|       Albania|
|       Algeria|
+--------------+
"""
df.select(["Country/Region", "Deaths", "Recovered"]).limit(3).show()
"""
+--------------+------+---------+
|Country/Region|Deaths|Recovered|
+--------------+------+---------+
|   Afghanistan|  1269|    25198|
|       Albania|   144|     2745|
|       Algeria|  1163|    18837|
+--------------+------+---------+
"""

Where, Filter

You can use the where or filter functions to apply conditions and keep only the rows of the DataFrame that match.

import pyspark
import pyspark.sql.functions as F
from pyspark import SparkContext, SQLContext
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)

df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("covid_dataset/country_wise_latest.csv"))

deathsG30k = df.where(F.col("Deaths") > 30000)
deathsG30k.select(["Country/Region", "Deaths"]).show()
"""
+--------------+------+
|Country/Region|Deaths|
+--------------+------+
|        Brazil| 87618|
|        France| 30212|
|         India| 33408|
|         Italy| 35112|
|        Mexico| 44022|
|            US|148011|
|United Kingdom| 45844|
+--------------+------+
"""
deathsG30kAndInEurope = deathsG30k.filter(F.col("WHO Region") == "Europe")
deathsG30kAndInEurope.select(["Country/Region", "Deaths"]).show()
"""
+--------------+------+
|Country/Region|Deaths|
+--------------+------+
|        France| 30212|
|         Italy| 35112|
|United Kingdom| 45844|
+--------------+------+
"""

Casting, Renaming

Casting is used to change the data type of a field, and withColumnRenamed gives a column a new name.

import pyspark
import pyspark.sql.functions as F
from pyspark import SparkContext, SQLContext
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)

df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("covid_dataset/country_wise_latest.csv"))

df.printSchema()
"""
root
 |-- Country/Region: string (nullable = true)
 |-- Confirmed: string (nullable = true)
 |-- Deaths: string (nullable = true)
 |-- Recovered: string (nullable = true)
...
"""
df = df.withColumn("Deaths", F.col("Deaths").cast('int'))

df.printSchema()
"""
root
 |-- Country/Region: string (nullable = true)
 |-- Confirmed: string (nullable = true)
 |-- Deaths: integer (nullable = true)
 |-- Recovered: string (nullable = true)
...
"""
df = df.withColumnRenamed("Recovered", "Recoveries")

df.printSchema()
"""
root
 |-- Country/Region: string (nullable = true)
 |-- Confirmed: string (nullable = true)
 |-- Deaths: integer (nullable = true)
 |-- Recoveries: string (nullable = true)
...
"""

Order, Sort

You can sort a DataFrame by any column and the rules are just as they apply in SQL.
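
The examples in this section cast Population to an integer before sorting. The column was loaded as a string, and strings are ordered character by character rather than by numeric value. The plain-Python sketch below (not Spark code) shows the difference on a few population values taken from this tutorial's outputs; the variable names are only illustrative.

```python
# Population values as they appear in the string-typed column.
populations = ["1440297825", "801", "331341050", "1360"]

# Sorting the raw strings compares them character by character.
as_strings = sorted(populations)

# Converting to int first gives the numeric order we actually want.
as_numbers = sorted(populations, key=int)

print(as_strings)
print(as_numbers)
```

Note how "1360" lands before "1440297825" in both lists, but "801" moves from last to first once the values are compared as numbers.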

import pyspark
import pyspark.sql.types as Types
from pyspark import SparkContext, SQLContext
from pyspark.sql import SparkSession

sc = SparkContext.getOrCreate()
sql = SQLContext(sc)

df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("covid_dataset/country_wise_latest.csv"))


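# Placeholder values (assumptions): any app name works, and "local[*]" runs Spark locally on all cores.
appName = "learn-spark"
master = "local[*]"

spark = SparkSession.builder \
    .appName(appName) \
    .master(master) \
    .getOrCreate()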

schema = Types.StructType([
    Types.StructField('Country/Region', Types.StringType(), True),
    Types.StructField('Population', Types.StringType(), True),
    Types.StructField('Urban Pop %', Types.StringType(), True),
    Types.StructField('World Share %', Types.StringType(), True),
    Types.StructField('Med. Age', Types.StringType(), True)
])

df = spark\
    .read\
    .json("world_population/population.json", schema, multiLine=True)

df.orderBy(df['Population'].cast("int").desc()).limit(4).show()
df.sort(df['Population'].cast("int").desc()).limit(4).show()
"""
+--------------+----------+-----------+-------------+--------+
|Country/Region|Population|Urban Pop %|World Share %|Med. Age|
+--------------+----------+-----------+-------------+--------+
|         China|1440297825|         61|        18.47|      38|
|         India|1382345085|         35|        17.70|      28|
| United States| 331341050|         83|         4.25|      38|
|     Indonesia| 274021604|         56|         3.51|      30|
+--------------+----------+-----------+-------------+--------+
"""
df.orderBy(df['Population'].cast("int").asc()).limit(4).show()
df.sort(df['Population'].cast("int").asc()).limit(4).show()
"""
+----------------+----------+-----------+-------------+--------+
|  Country/Region|Population|Urban Pop %|World Share %|Med. Age|
+----------------+----------+-----------+-------------+--------+
|        Holy See|       801|       N.A.|         0.00|    N.A.|
|         Tokelau|      1360|          0|         0.00|    N.A.|
|            Niue|      1628|         46|         0.00|    N.A.|
|Falkland Islands|      3497|         66|         0.00|    N.A.|
+----------------+----------+-----------+-------------+--------+
"""

df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("world_population/worldcities.csv"))

df = df.where(df['population'].cast("int") > 10000000)

df.orderBy(df['country'].desc(), df['population'].cast("int").desc()).limit(6)\
  .select(["city", "country", "population"]).show()
"""
+-----------+-------------+----------+
|       city|      country|population|
+-----------+-------------+----------+
|   New York|United States|19354922.0|
|Los Angeles|United States|12815475.0|
|   Istanbul|       Turkey|  10061000|
|     Moscow|       Russia|  10452000|
|     Manila|  Philippines|  11100000|
|    Karachi|     Pakistan|  12130000|
+-----------+-------------+----------+
"""
df.sort(df['country'].desc(), df['population'].cast("int").asc()).limit(6)\
  .select(["city", "country", "population"]).show()
"""
+-----------+-------------+----------+
|       city|      country|population|
+-----------+-------------+----------+
|Los Angeles|United States|12815475.0|
|   New York|United States|19354922.0|
|   Istanbul|       Turkey|  10061000|
|     Moscow|       Russia|  10452000|
|     Manila|  Philippines|  11100000|
|    Karachi|     Pakistan|  12130000|
+-----------+-------------+----------+
"""

As you can see, you can sort by multiple fields: ties on one field are broken by the next field, and so on. You can also chain sorts because both the sort and orderBy functions return DataFrames.
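
To make the tie-breaking rule concrete, here is a plain-Python sketch (not Spark code) that orders a few rows from the output above by country descending and then by population ascending. The `rows` list and the `compare` helper are illustrative names, not part of the PySpark API.

```python
from functools import cmp_to_key

# (city, country, population) rows copied from the output above.
rows = [
    ("New York", "United States", 19354922),
    ("Istanbul", "Turkey", 10061000),
    ("Los Angeles", "United States", 12815475),
    ("Moscow", "Russia", 10452000),
]

def compare(a, b):
    # Primary key: country, descending.
    if a[1] != b[1]:
        return -1 if a[1] > b[1] else 1
    # Tie on country: fall back to population, ascending.
    return (a[2] > b[2]) - (a[2] < b[2])

sorted_rows = sorted(rows, key=cmp_to_key(compare))
for city, country, population in sorted_rows:
    print(city, country, population)
```

The two United States rows tie on the first key, so the second key alone decides that Los Angeles comes before New York.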

Joins

Joining works just like SQL. Let us join the COVID data and the population data using the Country/Region column.

import pyspark
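from pyspark import SparkContext, SQLContext
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pyspark.sql.types as Types
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)
# Reuse (or create) the session used for the JSON read below.
spark = SparkSession.builder.getOrCreate()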

covid19Df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("covid_dataset/country_wise_latest.csv"))

schema = Types.StructType([
    Types.StructField('Country/Region', Types.StringType(), True),
    Types.StructField('Population', Types.StringType(), True),
    Types.StructField('Urban Pop %', Types.StringType(), True),
    Types.StructField('World Share %', Types.StringType(), True),
    Types.StructField('Med. Age', Types.StringType(), True)
])

populationsDf = spark\
    .read\
    .json("world_population/population.json", schema, multiLine=True)\
    .withColumnRenamed("Country/Region", "CountryName")

joinedDf = populationsDf\
  .join(covid19Df, (populationsDf.CountryName == covid19Df["Country/Region"]))

joinedDf\
  .select(["CountryName" ,"Population", "Urban Pop %", "Confirmed", "Recovered", "Deaths"])\
  .sort(F.col("Confirmed").cast("int").desc())\
  .limit(5)\
  .show()
"""
+-------------+----------+-----------+---------+---------+------+
|  CountryName|Population|Urban Pop %|Confirmed|Recovered|Deaths|
+-------------+----------+-----------+---------+---------+------+
|United States| 331341050|         83|  4290259|  1325804|148011|
|       Brazil| 212821986|         88|  2442375|  1846641| 87618|
|        India|1382345085|         35|  1480073|   951166| 33408|
|       Russia| 145945524|         74|   816680|   602249| 13334|
| South Africa|  59436725|         67|   452529|   274925|  7067|
+-------------+----------+-----------+---------+---------+------+
"""

You can also join on multiple columns by combining two conditions with a logical operator such as &.
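
As a rough illustration of what a two-column join does, here is a plain-Python model (not Spark code) of an inner join on two keys. The `inner_join` helper, the `region` column, and its values are made up for this sketch; only the confirmed and population figures come from the outputs above.

```python
# In PySpark the equivalent join condition would be written as
#   (left.country == right.country) & (left.region == right.region)
# The "region" values below are invented purely to show a non-matching row.
cases = [
    {"country": "Brazil", "region": "Americas", "confirmed": 2442375},
    {"country": "India", "region": "South-East Asia", "confirmed": 1480073},
]
populations = [
    {"country": "Brazil", "region": "Americas", "population": 212821986},
    {"country": "India", "region": "Asia", "population": 1382345085},
]

def inner_join(left, right, keys):
    # Index the right side by the tuple of key values.
    index = {}
    for row in right:
        index.setdefault(tuple(row[k] for k in keys), []).append(row)
    joined = []
    for row in left:
        # A row is emitted only when every key column matches.
        for match in index.get(tuple(row[k] for k in keys), []):
            joined.append({**match, **row})
    return joined

joined = inner_join(cases, populations, ["country", "region"])
print(joined)
```

India matches on country but not on region, so only the Brazil row survives the join.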

Count

Just like in SQL, Spark DataFrames have a count function that returns the number of rows.

import pyspark
from pyspark import SparkContext, SQLContext
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)

df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("covid_dataset/country_wise_latest.csv"))

df.count() # 187
df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("world_population/worldcities.csv"))
df.count() # 15493

Grouping and Aggregates

For grouping, we use the groupBy function. You can group data in Spark and compute common aggregates such as the ones listed below:

  • Count the rows in each group.
  • Find the mean (average) value for each group.
  • Find the maximum value for each group.
  • Find the minimum value for each group.
  • Find the total (sum) of the values for each group.

Each of the above operations is available as a function of the grouped data.

The result from the groupBy function also has the agg function that allows multiple aggregates with aliasing.
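
To show what each aggregate computes, here is a plain-Python model (not Spark code) of grouping by country and then taking the count, sum, max, min, and average. The city figures are taken from outputs elsewhere in this tutorial; the `aggregate_by_country` helper and the dictionary keys are hypothetical names.

```python
from collections import defaultdict

# (city, country, population) rows.
cities = [
    ("Tokyo", "Japan", 35676000),
    ("Ōsaka", "Japan", 11294000),
    ("New York", "United States", 19354922),
    ("Los Angeles", "United States", 12815475),
]

def aggregate_by_country(cities):
    # Collect the populations of each country's cities (the "group by" step).
    groups = defaultdict(list)
    for _city, country, population in cities:
        groups[country].append(population)
    # Compute several aggregates per group, each under its own alias.
    return {
        country: {
            "cities": len(values),
            "total_pop": sum(values),
            "max_pop": max(values),
            "min_pop": min(values),
            "avg_pop": sum(values) / len(values),
        }
        for country, values in groups.items()
    }

summary = aggregate_by_country(cities)
print(summary["Japan"])
```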

Here is a complete notebook with some groupings.

import pyspark
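from pyspark import SparkContext, SQLContext
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pyspark.sql.types as Types
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)
# Reuse (or create) the session used for the JSON read below.
spark = SparkSession.builder.getOrCreate()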

citiesPopDf = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("world_population/worldcities.csv"))\
         .withColumnRenamed("Population", "city_population")\
         .where(F.col("city_population").cast("int") > 0)

schema = Types.StructType([
    Types.StructField('Country/Region', Types.StringType(), True),
    Types.StructField('Population', Types.StringType(), True),
    Types.StructField('Urban Pop %', Types.StringType(), True),
    Types.StructField('World Share %', Types.StringType(), True),
    Types.StructField('Med. Age', Types.StringType(), True)
])

countriesPopDf = spark\
    .read\
    .json("world_population/population.json", schema, multiLine=True)\
    .withColumnRenamed("Population", "country_population")

joinedDf = citiesPopDf\
   .join(countriesPopDf, (citiesPopDf.country == countriesPopDf["Country/Region"]))

joinedDf\
    .groupBy("country")\
    .count().sort(F.col("count").desc())\
    .limit(5)\
    .show()
"""
+-------------+-----+
|      country|count|
+-------------+-----+
|United States| 7328|
|       Russia|  564|
|        China|  392|
|       Brazil|  387|
|       Canada|  249|
+-------------+-----+
"""

joinedDf\
    .sort(F.col("country_population").cast("int").asc())\
    .groupBy("country")\
    .agg(\
         F.count("city").alias("Cities"),\
         F.sum(F.col("city_population").cast("int")).alias("urban_pop"),\
         F.max(F.col("city_population").cast("int")).alias("most_pop_in_one_city"),\
         F.min(F.col("city_population").cast("int")).alias("min_pop_in_one_city"),\
         F.avg(F.col("city_population").cast("int")).alias("avg_city_pop")\
        )\
    .sort(F.col("urban_pop").desc())\
    .limit(5)\
    .show()
"""
+-------------+------+---------+--------------------+-------------------+------------------+
|      country|Cities|urban_pop|most_pop_in_one_city|min_pop_in_one_city|      avg_city_pop|
+-------------+------+---------+--------------------+-------------------+------------------+
|United States|  7328|390924051|            19354922|               1991|53346.622680131004|
|        China|   392|358546021|            14987000|                100| 914658.2168367347|
|        India|   212|204338075|            18978000|              10688| 963858.8443396227|
|       Brazil|   387|127259225|            18845000|                956| 328835.2067183463|
|        Japan|    69| 89712598|            35676000|              82335| 1300182.579710145|
+-------------+------+---------+--------------------+-------------------+------------------+
"""

It is also important to mention that you can still join an aggregate like the one above with another DataFrame as it is also just a DataFrame.

Window Functions

Window functions operate on a group of rows, referred to as a window, and calculate a return value for each row based on that group. For example, suppose we need to rank the cities in each country by population. To do this, we take these steps.

  • Group cities by country
  • Assign ranks on population in descending order
  • Select where rank is 1

import pyspark
from pyspark import SparkContext, SQLContext
from pyspark.sql.window import Window
import pyspark.sql.functions as F
import pyspark.sql.types as Types
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)

citiesPopDf = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("world_population/worldcities.csv"))\
         .withColumn("city_population", F.col("Population").cast("float"))\
         .where(F.col("city_population").cast("int") > 0)

citiesIthRanksDf = citiesPopDf\
  .withColumn("rank", \
    F.rank().over(Window.partitionBy("country")\
    .orderBy(F.col("city_population").desc())))

citiesIthRanksDf\
.sort(F.col("city_population").desc())\
.where(F.col('rank') == 1)\
.select(["city", "country", "city_population"])\
.limit(10)\
.show()

"""
+------------+-------------+---------------+
|        city|      country|city_population|
+------------+-------------+---------------+
|       Tokyo|        Japan|       3.5676E7|
|    New York|United States|    1.9354922E7|
| Mexico City|       Mexico|       1.9028E7|
|      Mumbai|        India|       1.8978E7|
|   São Paulo|       Brazil|       1.8845E7|
|    Shanghai|        China|       1.4987E7|
|       Dhaka|   Bangladesh|    1.2797394E7|
|Buenos Aires|    Argentina|       1.2795E7|
|     Karachi|     Pakistan|        1.213E7|
|       Cairo|        Egypt|       1.1893E7|
+------------+-------------+---------------+
"""

To see the second most populous city in each country:

citiesIthRanksDf\
.sort(F.col("city_population").desc())\
.where(F.col('rank') == 2)\
.select(["city", "country", "city_population"])\
.limit(10)\
.show()

"""
+----------------+-------------+---------------+
|            city|      country|city_population|
+----------------+-------------+---------------+
|           Delhi|        India|       1.5926E7|
|     Los Angeles|United States|    1.2815475E7|
|  Rio de Janeiro|       Brazil|       1.1748E7|
|           Ōsaka|        Japan|       1.1294E7|
|         Beijing|        China|       1.1106E7|
|          Lahore|     Pakistan|      6577000.0|
|       Barcelona|        Spain|      4920000.0|
|Saint Petersburg|       Russia|      4553000.0|
|      Chittagong|   Bangladesh|      4529000.0|
|           Hanoi|      Vietnam|      4378000.0|
+----------------+-------------+---------------+
"""

Window functions are very handy. You can read more about them here.
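
For intuition about what the rank over a window partitioned by country computes, here is a plain-Python model (not Spark code). The city figures are taken from the outputs above; the `rank_within_country` helper is a hypothetical name, and it follows the SQL RANK convention where ties share a rank.

```python
from collections import defaultdict

# (city, country, population) rows from the outputs above.
cities = [
    ("Tokyo", "Japan", 35676000),
    ("Ōsaka", "Japan", 11294000),
    ("New York", "United States", 19354922),
    ("Los Angeles", "United States", 12815475),
    ("Mumbai", "India", 18978000),
    ("Delhi", "India", 15926000),
]

def rank_within_country(cities):
    # Partition step: collect the rows of each country.
    partitions = defaultdict(list)
    for city in cities:
        partitions[city[1]].append(city)
    ranked = []
    for members in partitions.values():
        # Order step: population descending inside the partition.
        members.sort(key=lambda c: c[2], reverse=True)
        rank, previous = 0, None
        for position, city in enumerate(members, start=1):
            # Ties keep the same rank; the next distinct value takes its position.
            if city[2] != previous:
                rank, previous = position, city[2]
            ranked.append((*city, rank))
    return ranked

ranked = rank_within_country(cities)
top = [name for name, _country, _pop, rank in ranked if rank == 1]
second = [name for name, _country, _pop, rank in ranked if rank == 2]
print(top)
print(second)
```

Every row keeps its own rank, which is what distinguishes a window function from a groupBy that collapses each group into one row.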

Other Spark Functions

Spark supports other functions that are useful for doing transformations. Here is a list of all the built-in functions. In PySpark, these built-in functions are accessible in the module pyspark.sql.functions. Below are a few examples.

concat

Here is an example of using concat and concat_ws to join string columns of a dataframe.

import pyspark
from pyspark import SparkContext, SQLContext
import pyspark.sql.functions as F
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)

df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("covid_dataset/country_wise_latest.csv"))

df.printSchema()
df.withColumn("CountryWHORegion", F.concat(F.col("Country/Region"), F.col('WHO Region')))\
.limit(3)\
.select(['Country/Region', 'WHO Region', 'CountryWHORegion']).show()
"""
+--------------+--------------------+--------------------+
|Country/Region|          WHO Region|    CountryWHORegion|
+--------------+--------------------+--------------------+
|   Afghanistan|Eastern Mediterra...|AfghanistanEaster...|
|       Albania|              Europe|       AlbaniaEurope|
|       Algeria|              Africa|       AlgeriaAfrica|
+--------------+--------------------+--------------------+
"""
df.withColumn("Country-WHO-Region", F.concat_ws('-', F.col("Country/Region"), F.col('WHO Region')))\
.limit(3)\
.select(['Country/Region', 'WHO Region', 'Country-WHO-Region']).show()
"""
+--------------+--------------------+--------------------+
|Country/Region|          WHO Region|  Country-WHO-Region|
+--------------+--------------------+--------------------+
|   Afghanistan|Eastern Mediterra...|Afghanistan-Easte...|
|       Albania|              Europe|      Albania-Europe|
|       Algeria|              Africa|      Algeria-Africa|
+--------------+--------------------+--------------------+
"""

User Defined Functions

Sometimes you need to run a custom function on a data frame. This is where UDFs come in. You create a function, register it, and then call it.

Let's say, for example, that in the cities' populations dataset the population values are floats and thus do not represent the real world, because there is no half or third of a person. To convert them to integer values, we can use a custom function.

Here we define a function called toInt, register it with Spark as the UDF toInteger, and then use it.

import pyspark
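from pyspark import SparkContext, SQLContext
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pyspark.sql.types as T

from pyspark.sql.functions import udf

sc = SparkContext.getOrCreate()
sql = SQLContext(sc)
# Reuse (or create) the session that spark.udf.register is called on below.
spark = SparkSession.builder.getOrCreate()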

df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("world_population/worldcities.csv"))\
         .withColumn("city_population", F.col("population").cast("float"))\
         .where(F.col("city_population").cast("int") > 0)

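def toInt(s):
    # Convert a numeric string such as "19354922.0" to an int; fall back to 0.
    try:
        return int(float(s))
    except (TypeError, ValueError, OverflowError):
        return 0

# Wrap the function as a UDF that returns an integer, then register it by name.
toInteger = udf(toInt, T.IntegerType())
spark.udf.register("toInteger", toInteger)

df.withColumn("TruncatedPopulation", toInteger('Population'))\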
.limit(5)\
.select(['country', 'city', 'TruncatedPopulation']).show()
"""
+-------------+-----------+-------------------+
|      country|       city|TruncatedPopulation|
+-------------+-----------+-------------------+
|        Japan|      Tokyo|           35676000|
|United States|   New York|           19354922|
|       Mexico|Mexico City|           19028000|
|        India|     Mumbai|           18978000|
|       Brazil|  São Paulo|           18845000|
+-------------+-----------+-------------------+
"""

UDFs are known to have performance issues because, in Python for example, they involve constant movement of data between the JVM and the Python interpreter. In most cases, you should be able to achieve what you need without a UDF by combining a series of Spark SQL functions.
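
Because the body of a UDF is ordinary Python, its logic can be checked on a few sample values without starting Spark. The sketch below restates the toInt logic under the name `to_int` (a renamed copy for illustration only). It catches only the conversion errors I would expect, which is my assumption about the inputs rather than something the dataset guarantees.

```python
def to_int(value):
    """Truncate a numeric string such as "19354922.0" to an int, or return 0."""
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        # None, non-numeric text such as "N.A.", NaN, and infinity all land here.
        return 0

samples = ["19354922.0", "10061000", "N.A.", None]
results = [to_int(s) for s in samples]
print(results)
```

Checking the function this way makes it easier to see which inputs silently become 0 before the function is applied to a whole column.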

Running raw SQL

Running raw SQL, as we have seen above in the JDBC source section, is done by registering a temporary view that we can then execute SQL on.

For example, here I find all cities in Germany where the population is greater than 2 million people.

import pyspark
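from pyspark import SparkContext, SQLContext
from pyspark.sql import SparkSession
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)
# Reuse (or create) the session that spark.sql is called on below.
spark = SparkSession.builder.getOrCreate()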


df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("world_population/worldcities.csv"))
df.createOrReplaceTempView("tmp_cities_view_table")

df = spark.sql("""
                SELECT
                 city, country, population
                FROM tmp_cities_view_table
                WHERE 
                 country = 'Germany' and CAST(population AS integer) > 2000000
               """)
df.show()
"""
+---------+-------+----------+
|     city|country|population|
+---------+-------+----------+
|   Berlin|Germany|   3406000|
|Stuttgart|Germany|   2944700|
|Frankfurt|Germany|   2895000|
| Mannheim|Germany|   2362000|
+---------+-------+----------+
"""

Mapping

A map transform is applicable when you want to run each row through a function that returns a new set of rows, most likely with a different schema. The PySpark DataFrame does not have a map transform function. To apply a mapping, we have to convert the DataFrame to the RDD data format, which has a map method.

Here is an example using RDD to run a map function and Spark Row to infer a new schema from the dictionary returned by the map function.

import pyspark
from pyspark import SparkContext, SQLContext
import pyspark.sql.functions as F
from pyspark.sql import Row
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)

df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("covid_dataset/country_wise_latest.csv"))

def map_function(x):
    # Build one output record per input row. CSV values arrive as strings,
    # so cast to int before doing any arithmetic.
    rec = {}
    rec['country_name'] = x['Country/Region']
    rec['total_cases'] = x['Confirmed']
    rec['total_deaths'] = int(x['Deaths'])
    rec['total_recoveries'] = x['Recovered']
    rec['recovery_rate'] = (int(x['Recovered']) / int(x['Confirmed'])) * 100
    rec['death_rate'] = (int(x['Deaths']) / int(x['Confirmed'])) * 100
    return rec

# Map over the underlying RDD, then let Row infer the schema of the new DataFrame.
rdd = df.rdd.map(lambda x: Row(**map_function(x)))
df2 = rdd.toDF()
df2.sort(F.col('total_deaths').desc()).limit(10).show(truncate=False)

"""
+--------------+-----------+------------+----------------+-------------------+------------------+
|country_name  |total_cases|total_deaths|total_recoveries|recovery_rate      |death_rate        |
+--------------+-----------+------------+----------------+-------------------+------------------+
|United States |4290259    |148011      |1325804         |30.902656459668286 |3.4499315775574386|
|Brazil        |2442375    |87618       |1846641         |75.60841394134808  |3.5874097957930293|
|United Kingdom|301708     |45844       |1437            |0.47628833176448754|15.194824134593713|
|Mexico        |395489     |44022       |303810          |76.8188242909411   |11.131030193001576|
|Italy         |246286     |35112       |198593          |80.63511527248808  |14.256595990027854|
|India         |1480073    |33408       |951166          |64.26480315497952  |2.2571859631247917|
|France        |220352     |30212       |81212           |36.855576532094105 |13.71079000871333 |
|Spain         |272421     |28432       |150376          |55.19985610507266  |10.436787178668311|
|Peru          |389717     |18418       |272547          |69.93459356404776  |4.725993477318156 |
|Iran          |293606     |15912       |255144          |86.90013146870295  |5.419507775726654 |
+--------------+-----------+------------+----------------+-------------------+------------------+
"""
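One thing the map function above does not handle is a row where Confirmed is 0, which would raise a ZeroDivisionError. Below is a minimal sketch of a guarded version. The names percentage and summarize, and the zero-case sample row, are my own and only for illustration. The function only needs key access on the row, so you can try it on plain dictionaries without starting Spark.

```python
def percentage(part, total):
    """Return part / total as a percentage, or None when total is zero."""
    part, total = int(part), int(total)
    if total == 0:
        return None  # avoid ZeroDivisionError for rows with no confirmed cases
    return part / total * 100


def summarize(row):
    """Build one output record from one input row.

    `row` only needs key access, so a plain dict works for testing and a
    pyspark.sql.Row works inside df.rdd.map(...).
    """
    confirmed = int(row["Confirmed"])
    return {
        "country_name": row["Country/Region"],
        "total_cases": confirmed,
        "total_deaths": int(row["Deaths"]),
        "total_recoveries": int(row["Recovered"]),
        "recovery_rate": percentage(row["Recovered"], confirmed),
        "death_rate": percentage(row["Deaths"], confirmed),
    }


# Values are strings, as they are when the CSV is read without a schema.
italy = {"Country/Region": "Italy", "Confirmed": "246286",
         "Deaths": "35112", "Recovered": "198593"}
# Hypothetical row, used only to exercise the zero guard.
no_cases = {"Country/Region": "Nowhere", "Confirmed": "0",
            "Deaths": "0", "Recovered": "0"}

print(summarize(italy))
print(summarize(no_cases))
```

In the pipeline above, this could be plugged in as df.rdd.map(lambda x: Row(**summarize(x))). A None value ends up as a null in the resulting DataFrame.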

Writing Data - Outputs

After transforming your data, you usually want to write it somewhere in a format that the next stage or reader supports. Common output formats include:

  • CSV
  • JSON
  • Parquet
  • ORC etc

Here is an example that writes to each of these output formats.

import pyspark
from pyspark import SparkContext, SQLContext
import pyspark.sql.functions as F
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)

df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("world_population/worldcities.csv"))\
         .withColumnRenamed("Population", "city_population")\
         .where(F.col("city_population").cast("int") > 0)

df.repartition(1).write.option("header", "true").csv('outputs/countries-population-csv')

df.repartition(1).write.json("outputs/countries-population-json")

df.repartition(1).write.parquet("outputs/countries-population-parquet")

df.repartition(1).write.orc("outputs/countries-population-orc")

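At this point, you should see four directories under outputs, one per format, each containing the part files that Spark wrote.

Spark's DataFrameWriter supports the following save modes.

  • error (or errorifexists) – the default; the write fails if the output path already exists
  • overwrite – replaces whatever already exists at the output path
  • append – adds the new data to what already exists at the output path
  • ignore – skips the write when the output path already exists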
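To make the difference between the modes concrete, here is a toy model in plain Python. It is not Spark code: plan_write is a name I made up, and FileExistsError stands in for the error Spark raises. It also includes Spark's default mode, error (also spelled errorifexists), which is what you get when you do not call mode() at all.

```python
VALID_MODES = {"error", "errorifexists", "overwrite", "append", "ignore"}


def plan_write(mode, path_exists):
    """Toy model: describe what a write does for a given save mode."""
    if mode not in VALID_MODES:
        raise ValueError(f"unknown save mode: {mode}")
    if not path_exists:
        return "write"  # nothing is there yet, so every mode just writes
    if mode == "overwrite":
        return "replace"  # existing data is removed, then the new data is written
    if mode == "append":
        return "append"  # new files are added next to the existing ones
    if mode == "ignore":
        return "skip"  # the write is silently dropped
    # "error" / "errorifexists": refuse to touch existing data
    raise FileExistsError("output path already exists")


for mode in ("overwrite", "append", "ignore"):
    print(mode, "->", plan_write(mode, path_exists=True))
```

This is also why running the first write example twice fails on the second run: without a mode, the existing output directory is treated as an error.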

Below is an example of overwriting files.

import pyspark
from pyspark import SparkContext, SQLContext
import pyspark.sql.functions as F
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)

df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("world_population/worldcities.csv"))\
         .withColumnRenamed("Population", "city_population")\
         .where(F.col("city_population").cast("int") > 0)


df.repartition(1).write.mode('overwrite').option("header", "true").csv('outputs/countries-population-csv')

df.repartition(1).write.mode('overwrite').json("outputs/countries-population-json")

df.repartition(1).write.mode('overwrite').parquet("outputs/countries-population-parquet")

df.repartition(1).write.mode('overwrite').orc("outputs/countries-population-orc")

Writing partitioned data

It is also possible to write partitioned files. For example, we can partition the Covid dataset by WHO region. This produces one subdirectory per partition value, named Key=Value after the partition key. When more than one partition key is given, the subdirectories are nested in the order of the keys.

import pyspark
from pyspark import SparkContext, SQLContext
import pyspark.sql.functions as F
sc = SparkContext.getOrCreate()
sql = SQLContext(sc)

df = (sql.read
         .format("com.databricks.spark.csv")
         .option("header", "true")
         .load("covid_dataset/country_wise_latest.csv"))

df.write.partitionBy('WHO Region')\
.mode('overwrite')\
.json("outputs/countries-covid-json")
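To picture the directory layout this produces, here is a small plain-Python sketch that builds the path for one partition. The helper partition_dir is hypothetical, the second key (year) is made up to show nesting, and the sketch ignores the escaping Spark applies to special characters in partition values.

```python
from pathlib import PurePosixPath


def partition_dir(base, partitions):
    """Build the Key=Value directory path for one combination of partition values.

    `partitions` is a list of (column, value) pairs in partitionBy order.
    """
    path = PurePosixPath(base)
    for key, value in partitions:
        path = path / f"{key}={value}"  # one nested directory per partition key
    return str(path)


# One partition key, as in the example above ("Europe" is one example value).
print(partition_dir("outputs/countries-covid-json", [("WHO Region", "Europe")]))

# Two partition keys nest in the order they are given ("year" is hypothetical).
print(partition_dir("outputs/countries-covid-json",
                    [("WHO Region", "Europe"), ("year", "2020")]))
```

The part files for the matching rows sit inside the innermost directory.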

Managed Spark, Provisioning Spark

There are many good managed services that give you access to Spark at a generally reasonable cost and without the overhead of setting up and maintaining the cluster yourself. Here are a few places where you can spin up Spark instances or jobs in the cloud. Who does not like the cloud :-)

Resources

Here is a list of some very good resources on the Spark Python API.



Thank you for finding time to read my post. I hope you found this helpful and it was insightful to you. I enjoy creating content like this for knowledge sharing, my own mastery and reference.

If you want to contribute, you can do any or all of the following 😉. It will go a long way! Thanks again and cheers!