Matplotlib by Example

Continuing the ‘by example’ series on data science, this post covers Matplotlib, the go-to visualisation library in Python. It works well with both NumPy and Pandas, which I covered in previous ‘by example’ posts.
The idea is to scroll up and down and pay attention only to the plots that spark your interest (and their associated code).
Achieving the kind of customisation one is used to in Excel takes a while in Matplotlib, because it sometimes requires understanding what goes on under the hood. Hopefully, these examples can be used recipe-style: copy and paste the code samples, then change the parameters.
Essentials
Matplotlib imports and ancillary libraries.
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.ticker import MaxNLocator
import numpy as np
import pandas as pd
import pandas.plotting as plotting
Plots
Dots and Lines
The plot() function can be used to plot both single dots and lines, among others.
# Dots
plt.plot(1,1,".") # colour is implicit
plt.plot(1,2,"*") # colour is implicit (note star!)
plt.plot(2,2,".",color='red') # colour is explicit!
plt.plot(1,2,".",color='green') # colour is explicit!

# Lines specifying only Y values
plt.plot([2,3,4])

# Lines specifying X and Y values (Y is last!)
plt.plot([2,3,4],[3,4,3])

# Set axes to specific ranges and make them integers
plt.gca().axis([0,4,0,5]) # x = 0..4, y = 0..5
plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
plt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))

plt.show()
Applying a Mathematical Function
Rather than plotting each single point using a loop, we can precalculate the values and pass them as a list/array.
# First generate a sequence of numbers

# Low res
numbers = np.linspace(0, 10, 8)
wave = np.sin(numbers)
plt.plot(numbers,wave,color='green');

# High res
numbers = np.linspace(0, 10, 100)
wave = np.sin(numbers)
plt.plot(numbers,wave,color='blue',alpha=0.4);

plt.show()
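To make the "precalculate instead of looping" point concrete, here is a minimal sketch that computes the same low-res wave twice: once point by point with math.sin(), and once in a single vectorised np.sin() call. The names wave_loop and wave_vec are mine, chosen for illustration.

```python
import math
import numpy as np

# Eight evenly spaced sample points between 0 and 10 (the low-res case)
numbers = np.linspace(0, 10, 8)

# Loop version: one math.sin() call per point
wave_loop = [math.sin(n) for n in numbers]

# Vectorised version: a single call over the whole array
wave_vec = np.sin(numbers)

# Both approaches yield the same values
print(np.allclose(wave_loop, wave_vec))  # True
```

Either result can be passed to plt.plot(); the vectorised form is simply shorter and, for large arrays, usually faster.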
Filling Areas
Sometimes we want to fill the area between two sets of values. This can be accomplished using the fill_between() function.
# Data sets
v1 = (np.array([1,2,3,4,5,6,7,8])**2)
v2 = v1*-1

# Plots
plt.plot(v1, '^',color='red')
plt.plot(v2, 's',color='blue')

# Fill area
plt.gca().fill_between(range(len(v1)),
                       v1, v2,
                       facecolor='purple',
                       alpha=0.1)
plt.show()
Scatter Plot
The scatter() function provides more flexibility than plot() when it comes to specifying the shape of each dot.
# Values
x = [1,2,3]
y = [1,2,3]
sizes = [7000.0,14000.0,25000.0]
colours = ['red','green','blue']
plt.scatter(x,y,s=sizes,c=colours,alpha=0.5)

# Set axes to specific ranges
plt.gca().axis([0,4,0,5]) # x = 0..4, y = 0..5

plt.show()
Bar Charts
Histogram
This is a type of bar chart that is used to show the frequency distribution of a collection of values. It indicates the number of observations which fall under each bin.
# Data set and number of bins
x = [1,
     2,2,
     3,3,3,
     4,4,4,4,
     5,5,5,5,5,
     6,6,6,6,6,6,
     7,7,7,7,7,7,7,
     8,8,8,8,8,8,8,8,
     9,9,9,9,9,9,9,9,9]

bin_n=3

# hist() function
n, bins, patches = plt.hist(x,bins=bin_n)

# Use different colours for each bar
c = 0.0
for patch in patches:
    if isinstance(patch,matplotlib.patches.Rectangle):
        c = c + (1/bin_n)
        patch.set_facecolor((c,0.0,0.0))

plt.show()
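To see the numbers behind the bars, the binning can be reproduced without drawing anything. The sketch below uses np.histogram(), which is what plt.hist() relies on for its binning, on the same data set; building x with np.repeat() is just my shorthand for typing the list out.

```python
import numpy as np

# Same data as above: the value k appears k times, for k = 1..9
x = np.repeat(np.arange(1, 10), np.arange(1, 10))

# Three equal-width bins spanning min(x)..max(x)
counts, edges = np.histogram(x, bins=3)

print(counts)  # [ 6 15 24]
print(edges)   # three bins need four edges, running from 1.0 to 9.0
```

The first bin collects the values 1 to 3 (1 + 2 + 3 = 6 observations), the second 4 to 6, and the third 7 to 9, which matches the bar heights in the plot.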
Histograms are typically used to visualise normal distributions like so.
# A thousand samples around 0.0 with a standard deviation of 1.0
a = np.random.normal(loc=0.0,scale=1.0,size=1000)
plt.hist(a) # 10 bins by default
plt.show()

# A thousand samples around 0.0 with a standard deviation of 1.0
a = np.random.normal(loc=0.0,scale=1.0,size=1000)
plt.hist(a,bins=50) # 50 Bins!
plt.show()
Bar Chart
The bar() function produces a vertical bar chart, whereas barh() produces a horizontal one.
# Data
countries = ['Bolivia','India','Zimbabwe','South Africa','Switzerland']
official_n = [37,18,16,11,4]
colours = ['yellow','green','black','blue','red']
x = np.arange(len(countries)) # [0..4]

# Chart and Labels
plt.xticks(x, countries)
plt.ylabel('# Spoken languages')
plt.title('Countries with the most official languages')
plt.bar(x, official_n, align='center',color=colours)

plt.show()
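The listing above only exercises bar(). As a rough sketch of the horizontal variant, the same data can be passed to barh(); here I use the Axes methods rather than the plt functions, and invert_yaxis() is an optional extra I added so that the first country ends up at the top.

```python
import matplotlib.pyplot as plt
import numpy as np

countries = ['Bolivia', 'India', 'Zimbabwe', 'South Africa', 'Switzerland']
official_n = [37, 18, 16, 11, 4]
y = np.arange(len(countries))  # bar positions now run along the y axis

fig, ax = plt.subplots()
bars = ax.barh(y, official_n, align='center')

# Country names label the y axis; the values extend along the x axis
ax.set_yticks(y)
ax.set_yticklabels(countries)
ax.invert_yaxis()  # put the first country at the top
ax.set_xlabel('# Spoken languages')
ax.set_title('Countries with the most official languages')
```

When running this as a script rather than in a notebook, finish with plt.show() as in the other listings.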
Clean Bar Chart
In this version we remove most of the noise introduced by the frame, the ticks, and the labels.
plt.figure()

# Data
countries = ['Bolivia','India','Zimbabwe','South Africa','Switzerland']
official_n = [37,18,16,11,4]
colours = ['red'] + (['blue']*4) # Optionally, use bars[0].set_color()
x = np.arange(len(countries)) # [0..4]

# Chart and labels
plt.xticks(x, countries, color='grey')
plt.title('Countries with the most official languages',color='grey')
bars = plt.bar(x, official_n, align='center',color=colours,linewidth=0)

# Remove ticks
plt.gca().axes.tick_params(top=False,
                           bottom=False,
                           left=False,
                           right=False,
                           labelleft=False,
                           labelbottom=True)

# Hide frame
for spine in plt.gca().spines.values():
    spine.set_visible(False)

# Place label in custom location
for bar in bars:
    h = bar.get_height()
    x = bar.get_x() + bar.get_width()/2
    y = bar.get_height() - 2
    plt.gca().text(x,y,h,ha='center',color='w',fontsize=11,fontweight='bold')

plt.show()
Box and Whisker Plots
These are plots whose aim is to show a statistical summary from a collection of values:
 ---   Q4 (100%)  Maximum Value / Upper Extreme
  |
  |
 ---   Q3 ( 75%)  Upper Quartile
 | |
 |-|   Q2 ( 50%)  Median
 | |
 ---   Q1 ( 25%)  Lower Quartile
  |
  |
 ---   Q0 (  0%)  Minimum Value / Lower Extreme
# Most basic box plot where Q0 = 0, Q1 = 2.50, Q2=5, Q3=7.5 and Q4=10
values = np.array([0,10])
_ = plt.boxplot(values)
plt.show()
Interquartile Range (IQR) Alert!
By default, boxplot() draws whiskers that reach at most 1.5 × IQR beyond the lower and upper quartiles, and treats any value beyond the whiskers as an outlier.
# Note that -30 is an outlier!
values = np.array([-30,0,1,2,3,4,5,6,7,8,9,10])
_ = plt.boxplot(values)
plt.show()
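Why is -30 flagged while 10 is not? The sketch below redoes the arithmetic by hand with np.percentile(), assuming the default whis=1.5 setting. The names lower_fence and upper_fence are mine; Matplotlib does not expose them under those names.

```python
import numpy as np

values = np.array([-30, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

# Lower and upper quartiles, and the interquartile range between them
q1, q3 = np.percentile(values, [25, 75])
iqr = q3 - q1

# Default whisker reach (whis=1.5): 1.5 x IQR beyond each quartile
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr

# Anything outside the fences is drawn as an outlier
outliers = values[(values < lower_fence) | (values > upper_fence)]

print(q1, q3, iqr)               # 1.75 7.25 5.5
print(lower_fence, upper_fence)  # -6.5 15.5
print(outliers)                  # [-30]
```

Only -30 falls below the lower fence of -6.5, so it is the single point drawn beyond the whiskers.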
# Add the sym="" argument to hide outliers
values = np.array([-30,0,1,2,3,4,5,6,7,8,9,10])
_ = plt.boxplot(values,sym="")
plt.show()
# Set the whis argument to a high value, or to (0, 100) ('range' in older
# Matplotlib versions), so that the whiskers cover all the data
values = np.array([-30,0,1,2,3,4,5,6,7,8,9,10])
_ = plt.boxplot(values,whis=99999)
plt.show()
Heatmaps
Heatmaps are similar to histograms in the sense that they show the frequency of values in a data set, grouped in bins.
The difference is that heatmaps have two dimensions, that is to say, we have two data sets rather than one: a data set for the x axis, and a data set for the y axis.
Let’s say that we wanted the heatmap to look like a cross:
  2     *
y 1   * * *
  0     *
      0 1 2
        x
To achieve this, we need to have (x,y) pairs for each point of the cross; that is to say: (0,1), (1,0), (1,1), (1,2), and (2,1).
x = np.array([0,1,1,1,2])
y = np.array([1,0,1,2,1])
When we visualise this as a heatmap, the frequency of each value is exactly 1.
fig = plt.figure()
h = plt.hist2d(x, y, bins=3)
cbar = fig.colorbar(h[3], ax=fig.gca())

# Set axes to integers
cbar.ax.yaxis.set_major_locator(MaxNLocator(integer=True))
plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
plt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))
But where is the heat? The idea is that a heatmap represents the relative frequency of some values over others. Say that we wanted the center of the cross to be the hottest value. We then just need to add another (1,1) pair:
x = np.array([0,1,1,1,1,2])
y = np.array([1,0,1,1,2,1])
With this simple change, now the center bin represents two values, whereas the extremes represent 1.
fig = plt.figure()
h = plt.hist2d(x, y, bins=3)
cbar = fig.colorbar(h[3], ax=fig.gca())

# Set axes to integers
cbar.ax.yaxis.set_major_locator(MaxNLocator(integer=True))
plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
plt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))
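The counts that drive the colours can also be inspected directly. This sketch calls np.histogram2d(), which plt.hist2d() uses for its binning, on the same six points; the variable names counts, x_edges and y_edges are my own.

```python
import numpy as np

x = np.array([0, 1, 1, 1, 1, 2])
y = np.array([1, 0, 1, 1, 2, 1])

# counts[i, j] is the number of points in x-bin i and y-bin j
counts, x_edges, y_edges = np.histogram2d(x, y, bins=3)

print(counts.astype(int))
# [[0 1 0]
#  [1 2 1]
#  [0 1 0]]
```

The 2 in the middle is the duplicated (1,1) pair, which is why the centre cell is the hottest one in the plot.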
A more elaborate example shows the visualisation that results from using a gamma distribution for the x axis, and a normal distribution for the y one.
fig = plt.figure()
x = np.random.gamma(2, scale=1.0, size=10000)
y = np.random.normal(loc=0.2, scale=1.0, size=10000)
h = plt.hist2d(x, y, bins=30)
_ = fig.colorbar(h[3], ax=fig.gca())
Subplots
Subplots are ‘different’ plots within the same canvas or figure.
Defining Compartments
We start working with a subplot with the plt.subplot(rows, columns, target_compartment) function, which specifies the number of compartments within the main figure, and the target compartment to be used by the next plot.
plt.figure()

for x in range(1,7):
    # 2 rows x 3 columns, targeting compartment x
    plt.subplot(2, 3, x)
    plt.plot([1,2,3,4] * x, '-o')

plt.show()
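The same six compartments can also be created in a single call with plt.subplots(), which returns the figure together with an array of Axes. This is a sketch of that alternative rather than a replacement for the listing above; the panel titles and the tight_layout() call are additions of mine.

```python
import matplotlib.pyplot as plt

# Create the figure and all six Axes in one call (2 rows x 3 columns)
fig, axes = plt.subplots(2, 3)

# axes is a 2x3 NumPy array; .flat walks it row by row
for n, ax in enumerate(axes.flat, start=1):
    ax.plot([1, 2, 3, 4] * n, '-o')
    ax.set_title(f'compartment {n}')

fig.tight_layout()  # keep titles and tick labels from overlapping
```

Plotting against explicit Axes objects avoids relying on whichever compartment happens to be "current". Add plt.show() at the end when running as a script.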
Sharing Axes
Sometimes a subplot is just a means to include multiple disparate plots on a single figure. Sometimes, though, we want to contrast two sets of data against the same axes.
# Problem. the X axes don't match
# First figure's X axis is 0..3 while second is 0..6

plt.subplot(1, 2, 1)
plt.plot([1,2,3,4], '-o')
plt.subplot(1, 2, 2)
plt.plot([1,2,3,4] * 2, '-o')
plt.show()

# Solution
# Add sharex (and/or sharey) argument

ax = plt.subplot(1, 2, 1)
plt.plot([1,2,3,4], '-o')
plt.subplot(1, 2, 2, sharex = ax) # add this argument!
plt.plot([1,2,3,4] * 2, '-o')
plt.show()

# Alternative approach
# Here we plot against each compartment's axes

fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=False)
ax1.plot([1,2,3,4], '-o')
ax2.plot([1,2,3,4] * 2, '-o')
plt.show()
Using GridSpec
GridSpec allows finer control over the location and size of a subplot, by allowing subplots to span over two or more ‘cells’ within the grid. Let’s consider a 2 row x 3 column grid:
|   | 0 | 1 | 2 |
|---|---|---|---|
| 0 | a | b | c |
| 1 | d | e | f |
gspec = gridspec.GridSpec(2, 3)
a = plt.subplot(gspec[0, 0])
b = plt.subplot(gspec[0, 1]).set_yticks([])
c = plt.subplot(gspec[0, 2]).set_yticks([])
d = plt.subplot(gspec[1, 0])
e = plt.subplot(gspec[1, 1]).set_yticks([])
f = plt.subplot(gspec[1, 2]).set_yticks([])
plt.show()
The above can be achieved using the simpler approaches already shown. GridSpec’s benefit is the ability to create subplots that span multiple cells.
In the below example, note that a covers both rows 0 and 1.
|   | 0 | 1 | 2 |
|---|---|---|---|
| 0 | a | b | c |
| 1 | a | e | f |
gspec = gridspec.GridSpec(2, 3)
a = plt.subplot(gspec[0:, 0]) # Spread from row 0 onward!
b = plt.subplot(gspec[0, 1]).set_yticks([])
c = plt.subplot(gspec[0, 2]).set_yticks([])
#d = plt.subplot(gspec[1, 0]) # 'Occupied' by a!
e = plt.subplot(gspec[1, 1]).set_yticks([])
f = plt.subplot(gspec[1, 2]).set_yticks([])
plt.show()
In the below example, we spread subplot d over columns 0 and 1 only.
|   | 0 | 1 | 2 |
|---|---|---|---|
| 0 | a | b | c |
| 1 | d | d | f |
gspec = gridspec.GridSpec(2, 3)
a = plt.subplot(gspec[0, 0])
b = plt.subplot(gspec[0, 1]).set_yticks([])
c = plt.subplot(gspec[0, 2]).set_yticks([])
d = plt.subplot(gspec[1, 0:2]) # Spread from column 0 to 1!
#e = plt.subplot(gspec[1, 1]).set_yticks([]) #'Occupied' by d
f = plt.subplot(gspec[1, 2]).set_yticks([])
plt.show()
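One detail worth knowing about the listings above: a line such as b = plt.subplot(...).set_yticks([]) stores the return value of set_yticks() in b, not the Axes itself. If you intend to customise a panel later, it helps to keep the Axes handle and call set_yticks() separately. The sketch below does that for a layout of my own choosing (not one of the grids above) and also shows how to ask an Axes which grid cells it spans.

```python
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

fig = plt.figure()
gspec = gridspec.GridSpec(2, 3, figure=fig)

# Keep the Axes objects themselves so they can be customised later
a = fig.add_subplot(gspec[0:, 0])   # both rows of column 0
d = fig.add_subplot(gspec[1, 1:3])  # row 1, columns 1 and 2
b = fig.add_subplot(gspec[0, 1])    # a single cell
b.set_yticks([])

# Each Axes remembers which grid cells it occupies
spec = a.get_subplotspec()
print(list(spec.rowspan), list(spec.colspan))  # [0, 1] [0]
```

Slices follow the usual Python convention, so 1:3 means columns 1 and 2, and 0: means "from row 0 to the end".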
In this last example, we showcase the kind of use case in which GridSpec shines.
|   | 0 | 1 | 2 |
|---|---|---|---|
| 0 | a | a |   |
| 1 | b | b | c |
| 2 | b | b | c |
normal_distribution = np.random.normal(loc=0.0,scale=1.0,size=1000)
# Running total of the samples
cumulative = np.cumsum(normal_distribution)

gspec = gridspec.GridSpec(3, 3)
a = plt.subplot(gspec[0, 0:2])
b = plt.subplot(gspec[1:3, 0:2])
c = plt.subplot(gspec[1:3, 2])

# This is the central plot in which we plot the normal distribution
b.plot(normal_distribution)
b.set_xlabel("1000 samples")

# Here, to the right, we plot in red the number frequency as a histogram
c.hist(normal_distribution, orientation='horizontal',bins=50,color='r')
#c.set_yticks([])
c.set_ylabel("Distribution",rotation=270,labelpad=10)
c.yaxis.set_label_position("right")

# Last, on top panel a, we plot the cumulative trend for those numbers
a.plot(cumulative, color='orange')
a.set_xticks([])
a.set_xlabel("Cumulative")
a.xaxis.set_label_position('top')
top = max([cumulative.max(),abs(cumulative.min())])
a.set_yticks([top,0,-top])
plt.show()
Pandas Built-in Matplotlib Support
Pandas can plot DataFrames directly, using Matplotlib behind the scenes.
Styles
Pandas will use either the default style or a user-defined one, set via plt.style.use(). Available styles can be obtained via plt.style.available.
plt.style.use('seaborn-bright') # named 'seaborn-v0_8-bright' from Matplotlib 3.6 onwards
plt.style.available
['Solarize_Light2',
'_classic_test_patch',
'_mpl-gallery',
'_mpl-gallery-nogrid',
'bmh',
'classic',
'dark_background',
'fast',
'fivethirtyeight',
'ggplot',
'grayscale',
'seaborn',
'seaborn-bright',
'seaborn-colorblind',
'seaborn-dark',
'seaborn-dark-palette',
'seaborn-darkgrid',
'seaborn-deep',
'seaborn-muted',
'seaborn-notebook',
'seaborn-paper',
'seaborn-pastel',
'seaborn-poster',
'seaborn-talk',
'seaborn-ticks',
'seaborn-white',
'seaborn-whitegrid',
'tableau-colorblind10']
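Style names vary between Matplotlib releases (newer ones list the seaborn styles under a 'seaborn-v0_8-' prefix), so it can be safer to check plt.style.available before choosing one. The sketch below does that and then applies the style temporarily with plt.style.context(). Note that first_available_style() is a hypothetical helper I wrote for this post, not part of Matplotlib.

```python
import matplotlib.pyplot as plt

def first_available_style(candidates):
    """Return the first style name that this Matplotlib install provides.

    Hypothetical helper written for this post, not part of Matplotlib.
    """
    for name in candidates:
        if name in plt.style.available:
            return name
    return 'default'  # 'default' is always accepted by plt.style.use()

# Try the newer name first, then the older one
style = first_available_style(['seaborn-v0_8-bright', 'seaborn-bright'])

# Apply the style only inside the with-block, leaving global settings alone
with plt.style.context(style):
    fig, ax = plt.subplots()
    ax.plot([2, 3, 4], [3, 4, 3])
```

Unlike plt.style.use(), the context manager restores the previous settings when the block ends, which is handy when only one figure needs a different look.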
Plotting
Pandas produces a default line plot when invoking DataFrame.plot().
np.random.seed(0)

df = pd.DataFrame({'Normal': np.sort(np.random.normal(loc=8.0, scale=1.0, size=100)),
                   'Gamma': np.sort(np.random.gamma(1, scale=1.0, size=100)),
                   'Logistic': np.sort(np.random.logistic(loc=-8.0, scale=1.0, size=100))},
                  )
df.plot();
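Because DataFrame.plot() returns the Matplotlib Axes it drew on, everything covered earlier in this post still applies afterwards. A minimal sketch, using a two-column DataFrame and labels of my own choosing:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.random.seed(0)
df = pd.DataFrame({'Normal': np.sort(np.random.normal(loc=8.0, scale=1.0, size=100)),
                   'Gamma': np.sort(np.random.gamma(1, scale=1.0, size=100))})

# DataFrame.plot() hands back the Matplotlib Axes it drew on
ax = df.plot()

# From here on it is plain Matplotlib
ax.set_title('Sorted samples')
ax.set_xlabel('Sample index')
ax.set_ylabel('Value')
ax.legend(loc='upper left')
```

The column names become the legend entries automatically, so there is usually no need to pass labels by hand.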
For further customisation, we can select the kind of chart we want: line, bar, barh, hist, box, kde, density, area, pie, scatter, or hexbin.
In the case of scatter, we must specify which DataFrame columns to use for the x and y axes.
df.plot('Normal','Gamma',kind='scatter');

# An example using a colour map
ax = df.plot.scatter('Normal', 'Logistic', c='Gamma', s=df['Gamma'], colormap='jet');
ax.set_aspect('equal')

# Histogram
df.plot(kind='hist',bins=20,alpha=0.6);

# Kernel Density Estimation
df.plot(kind='kde');
Scatter Matrix
This is a useful way to contrast different data sets with one another.
plotting.scatter_matrix(df);
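scatter_matrix() returns a two-dimensional array of Axes, one panel per pair of columns, so individual panels can be adjusted afterwards. The sketch below builds a fresh DataFrame (unsorted samples this time, which is my choice) and prints the shape of that array. The alpha, figsize and diagonal values are illustrative.

```python
import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix

np.random.seed(0)
df = pd.DataFrame({'Normal': np.random.normal(loc=8.0, scale=1.0, size=100),
                   'Gamma': np.random.gamma(1, scale=1.0, size=100),
                   'Logistic': np.random.logistic(loc=-8.0, scale=1.0, size=100)})

# One row and one column of panels per DataFrame column;
# the diagonal shows each column's own histogram
axes = scatter_matrix(df, alpha=0.5, figsize=(6, 6), diagonal='hist')

print(axes.shape)  # (3, 3)
```

For example, axes[0, 1] is the panel in the first row and second column, and it accepts the usual Axes methods such as set_title().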