
Seaborn - Data Visualization

Seaborn is a statistical visualization library built on top of matplotlib, and is designed to work very well with pandas DataFrame objects.

Distribution Plots

We'll start with a built-in data set in the seaborn library.

import seaborn as sns
tips = sns.load_dataset('tips')

Graph Graph

We can plot a distribution (histogram) for univariate numerical data:

sns.displot(tips['total_bill'],kde=False)
sns.displot(tips['total_bill'], bins=30) # second image below

Graph Graph

Graph Graph

For bivariate numerical data we can plot the two distributions together using jointplot.

sns.jointplot(x='total_bill',y='tip',data=tips)

Graph Graph

sns.jointplot(x='total_bill',y='tip',data=tips,kind='hex')

Graph Graph

sns.jointplot(x='total_bill',y='tip',data=tips,kind='reg')

Graph Graph

To visualize pairwise relationships between numerical data across an entire data frame we can use pairplot. This draws a scatter plot for every pair of numerical columns in the data frame, and arranges the plots in a (symmetric) grid. The diagonal shows a univariate histogram of each column. Using the optional hue argument on a categorical column will colour the data points according to their categorical value.

sns.pairplot(tips)

Graph Graph

sns.pairplot(tips,hue='sex',palette='coolwarm')

Graph Graph

Kernel Density Estimates

sns.rugplot(tips['total_bill'])

Graph Graph

#  Don't worry about understanding this code!
# It's just for the diagram below
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

#Create dataset
dataset = np.random.randn(25)

# Create another rugplot
sns.rugplot(dataset);

# Set up the x-axis for the plot
x_min = dataset.min() - 2
x_max = dataset.max() + 2

# 100 equally spaced points from x_min to x_max
x_axis = np.linspace(x_min,x_max,100)

# Set up the bandwidth, for info on this:
url = 'http://en.wikipedia.org/wiki/Kernel_density_estimation#Practical_estimation_of_the_bandwidth'

bandwidth = ((4*dataset.std()**5)/(3*len(dataset)))**.2


# Create an empty kernel list
kernel_list = []

# Plot each basis function
for data_point in dataset:

    # Create a kernel for each point and append to list
    kernel = stats.norm(data_point,bandwidth).pdf(x_axis)
    kernel_list.append(kernel)

    #Scale for plotting
    kernel = kernel / kernel.max()
    kernel = kernel * .4
    plt.plot(x_axis,kernel,color = 'grey',alpha=0.5)

plt.ylim(0,1)

Graph Graph

# To get the kde plot we can sum these basis functions.

# Plot the sum of the basis function
sum_of_kde = np.sum(kernel_list,axis=0)

# Plot figure
fig = plt.plot(x_axis,sum_of_kde,color='steelblue')

# Add the initial rugplot
sns.rugplot(dataset,c = 'steelblue')

# Get rid of y-tick marks
plt.yticks([])

# Set title
plt.suptitle("Sum of the Basis Functions")

Graph Graph
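The sum above becomes a proper density estimate once it is divided by the number of points. As a rough check, the sketch below compares that average of Gaussian kernels with scipy's gaussian_kde at the same kernel width. The names sample, h, x_grid, manual_kde and scipy_kde are illustrative and not part of the original code, and the seeded sample is a stand-in for dataset.

```python
import numpy as np
from scipy import stats

# Hypothetical sample standing in for `dataset` above (seeded for repeatability)
rng = np.random.default_rng(0)
sample = rng.standard_normal(25)

# Same rule-of-thumb bandwidth formula as in the code above
h = ((4 * sample.std() ** 5) / (3 * len(sample))) ** 0.2

# Evaluation grid, built the same way as x_axis above
x_grid = np.linspace(sample.min() - 2, sample.max() + 2, 100)

# Average (rather than sum) one Gaussian kernel per data point
kernels = [stats.norm(point, h).pdf(x_grid) for point in sample]
manual_kde = np.mean(kernels, axis=0)

# gaussian_kde multiplies a scalar bw_method by the sample standard
# deviation (ddof=1), so divide h by that to request kernels of width h
scipy_kde = stats.gaussian_kde(sample, bw_method=h / sample.std(ddof=1))(x_grid)

# Check that the two curves agree on the grid
print(np.allclose(manual_kde, scipy_kde))
```

In practice sns.kdeplot produces this kind of curve directly; the manual version is only meant to show what is being summed.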

Categorical Plots

import seaborn as sns
import numpy as np
tips = sns.load_dataset('tips')

Bar Plots

# displays the average total_bill for each sex
sns.barplot(x='sex', y='total_bill', data=tips)
# displays the standard deviation of total_bill for each sex
sns.barplot(x='sex', y='total_bill', data=tips,estimator=np.std)
# count occurrences per sex in data
sns.countplot(x='sex', data=tips)

Graph Graph Graph Graph Graph Graph
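The numbers behind the first and last plots can be cross-checked with plain pandas. This is a minimal sketch on a hypothetical five-row frame called toy (not the real tips data): by default barplot draws the mean of y for each category, and countplot draws the number of rows in each category.

```python
import pandas as pd

# Hypothetical miniature stand-in for the tips data
toy = pd.DataFrame({
    'sex': ['Male', 'Female', 'Male', 'Female', 'Male'],
    'total_bill': [10.0, 20.0, 30.0, 44.0, 50.0],
})

# Bar heights of sns.barplot(x='sex', y='total_bill', data=toy)
mean_bill = toy.groupby('sex')['total_bill'].mean()

# Bar heights of sns.countplot(x='sex', data=toy)
counts = toy['sex'].value_counts()

print(mean_bill)
print(counts)
```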

Box and Whisker Plots

sns.boxplot(x='day',y='total_bill',data=tips)

Graph Graph

sns.boxplot(x='day',y='total_bill',data=tips, hue ='smoker')

Graph Graph

Violin Plots

sns.violinplot(x='day',y='total_bill',data=tips)

Graph Graph

sns.violinplot(x='day',y='total_bill',data=tips, hue='smoker')

Graph Graph

sns.violinplot(x='day',y='total_bill',data=tips, hue='smoker', split=True)

Graph Graph

Strip Plots

sns.stripplot(x='day',y='total_bill',data=tips)

Graph Graph

sns.stripplot(x='day',y='total_bill',data=tips, hue='sex', dodge=True)

Graph Graph

Swarm Plots

sns.swarmplot(x='day',y='total_bill',data=tips)

Graph Graph

sns.violinplot(x='day',y='total_bill',data=tips)
sns.swarmplot(x='day',y='total_bill',data=tips, color='black')

Graph Graph

Cat Plots

catplot is the general type of plot for categorical data. All the specific commands above are just a type of catplot, selected with the kind argument.

sns.catplot(x='day', y='total_bill', data=tips, kind='bar')
sns.catplot(x='day', y='total_bill', data=tips, kind='violin')
sns.catplot(x='day', y='total_bill', data=tips, kind='strip',
    hue='sex', dodge=True)

Graph Graph Graph Graph Graph Graph

Matrix Plots

import seaborn as sns
tips = sns.load_dataset('tips')
flights = sns.load_dataset('flights')
flights.head()

Graph Graph

For the plots we will explore in this section we need to restructure our tables so each row and column represent a variable. In the case of the tips data set we'll look at a simple example where we construct a correlation table. Notice each row corresponds to a variable, and so does each column.

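# numeric_only=True skips the categorical columns (sex, smoker, day, time),
# which recent pandas versions no longer drop automatically
tc = tips.corr(numeric_only=True)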
tc

Graph Graph

sns.heatmap(tc,annot=True, cmap='coolwarm')

Graph Graph

We use a pivot table to restructure the flights data: rows correspond to months, columns to years, and the values come from the passengers column.

fp = flights.pivot_table(index='month', columns='year', values='passengers')
fp

Graph Graph
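To make the reshaping concrete, here is a small sketch on a hypothetical four-row table named long_form (the values are placeholders, not a claim about the flights data). Each distinct month becomes a row, each distinct year becomes a column, and the passengers values fill the cells.

```python
import pandas as pd

# Hypothetical long-format records: one row per (year, month) observation
long_form = pd.DataFrame({
    'year': [1949, 1949, 1950, 1950],
    'month': ['Jan', 'Feb', 'Jan', 'Feb'],
    'passengers': [112, 118, 115, 126],
})

# pivot_table averages duplicate (month, year) pairs by default;
# here every pair occurs once, so each cell is just the original value
wide = long_form.pivot_table(index='month', columns='year', values='passengers')

print(wide)
```

The resulting wide table has one variable along each axis, which is the layout heatmap expects.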

sns.heatmap(fp, cmap='magma', linecolor='white', linewidths=1)

Graph Graph

sns.clustermap(fp, cmap='coolwarm', standard_scale = 1)

Graph Graph

Grid

import seaborn as sns
import matplotlib.pyplot as plt
iris = sns.load_dataset('iris')
iris.head()

Graph Graph

iris['species'].unique()
# array(['setosa', 'versicolor', 'virginica'], dtype=object)
sns.pairplot(iris, hue='species')

Graph Graph

# create an empty grid of axes to plot on, and store it in the variable g
g = sns.PairGrid(iris)
g.map(plt.scatter)

Graph Graph

g = sns.PairGrid(iris)
g.map_diag(sns.histplot)
g.map_upper(plt.scatter)
g.map_lower(sns.kdeplot)

Graph Graph

g = sns.PairGrid(iris, hue='species')
g.map_diag(sns.histplot)
g.map_upper(plt.scatter)
g.map_lower(sns.kdeplot)

Graph Graph

tips = sns.load_dataset('tips')
tips.head()

Graph Graph

g = sns.FacetGrid(tips, col='time', row='sex')
g.map(sns.histplot,'total_bill', color = 'steelblue')

Graph Graph

g = sns.FacetGrid(tips, col='time', row='sex')
g.map(plt.scatter,'total_bill','tip', color = 'forestgreen')

Graph Graph

Regression Plots

In this section we explore the lmplot command for producing a linear model (regression) plot over a scatter plot.

import seaborn as sns
tips = sns.load_dataset('tips')
sns.lmplot(x='total_bill', y='tip', data=tips)

Graph Graph

Under the hood lmplot calls matplotlib, so we can pass parameters straight through to the underlying matplotlib calls using the scatter_kws and line_kws dictionaries.

sns.lmplot(x='total_bill', y='tip', data=tips, hue='sex',
    markers=['o','v'], scatter_kws={'s':60})

Graph Graph

We can produce a FacetGrid by using the col and row parameters.

sns.lmplot(x='total_bill', y='tip', data=tips, col='sex', row='time')

Graph Graph
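Because lmplot returns a FacetGrid, the result can be inspected or adjusted afterwards. The sketch below uses hypothetical synthetic data (the frame toy and the columns x, y, group_a and group_b are made up for illustration) and the Agg backend so that it runs without a display.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless

import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical synthetic data: 2 x 2 category combinations, 10 points in each
rng = np.random.default_rng(1)
toy = pd.DataFrame({
    'x': rng.uniform(0, 10, 40),
    'group_a': np.repeat(['a1', 'a2'], 20),
    'group_b': np.tile(np.repeat(['b1', 'b2'], 10), 2),
})
toy['y'] = 2 * toy['x'] + rng.normal(0, 1, 40)

# One regression panel per (group_b, group_a) combination; ci=None skips
# the bootstrapped confidence band to keep the example fast
g = sns.lmplot(x='x', y='y', data=toy, col='group_a', row='group_b', ci=None)

print(type(g).__name__)  # class of the returned object
print(g.axes.shape)      # grid layout as (rows, columns)
```

Methods of the returned grid, such as g.set_axis_labels, can then be used to adjust all panels at once.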

sns.lmplot(x='total_bill', y='tip', data=tips, col='day',
    hue='sex', aspect=0.6, height=4)

Graph Graph

Style and Colour

import seaborn as sns
import matplotlib.pyplot as plt  # needed for plt.figure below
tips = sns.load_dataset('tips')
sns.set_style('whitegrid')
sns.countplot(x='sex', data=tips)
sns.despine()

Graph Graph

plt.figure(figsize=(3,5))
sns.countplot(x='sex', data=tips)

Graph Graph

sns.set_context('poster', font_scale=1)
sns.countplot(x='sex', data=tips)

Graph Graph


Exercises

We will be working with the famous Titanic data set for these exercises. Later on in the Machine Learning section of the course, we will revisit this data, and use it to predict survival rates of passengers. For now, we'll just focus on the visualization of the data with seaborn:

import seaborn as sns
import matplotlib.pyplot as plt
sns.set_style('whitegrid')
titanic = sns.load_dataset('titanic')
titanic.head()

Recreate the plots below using the titanic dataframe. There are very few hints since most of the plots can be done with just one or two lines of code and a hint would basically give away the solution. Pay careful attention to the x and y labels for hints.

1.

Graph Graph

2.

Graph Graph

3.

Graph Graph

4.

Graph Graph

5.

Graph Graph

6.

Graph Graph

7.

Graph Graph