Plotting data with matplotlib

Plotting of data in pandas is handled by an external Python module called matplotlib. Like pandas, it is a large library with a venerable history (first released in 2003), so we couldn't hope to cover all its functionality in this course. To see the wide range of possibilities matplotlib offers, see its example gallery.

Here we will cover the basic uses of it and how it integrates with pandas. While working through these examples you will likely find it very useful to refer to the matplotlib documentation.

First we import pandas and numpy in the same way as we did previously.

In [1]:
import numpy as np
import pandas as pd
from pandas import Series, DataFrame

Some matplotlib functionality is provided directly through pandas (such as the plot() method as we will see) but for much of it you need to import the matplotlib interface itself.

The most common interface to matplotlib is its pyplot module which provides a way to affect the current state of matplotlib directly. By convention this is imported as plt.

We also set the figure format to be SVG so that the plots look a little nicer in our Jupyter notebook.

In [2]:
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'svg'

Once we have imported matplotlib we can start calling its functions. Functions called through plt act on matplotlib's current figure and axes, so each call affects whichever plot is active at that point in the script.
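To make the idea of "current" state concrete, here is a small sketch that is not part of the original notebook. It creates a figure, asks pyplot for its current axes with plt.gca(), and checks that plt.ylabel() was applied to those axes.

```python
import matplotlib.pyplot as plt

# plt.figure() creates a new figure and makes it the current one
fig = plt.figure()

# plt.gca() returns the current axes, creating them on the current figure if needed
ax = plt.gca()

# pyplot functions such as plt.ylabel() act on that current axes
plt.ylabel('Temperature')

print(ax.figure is fig)   # the axes belong to the figure we just created
print(ax.get_ylabel())    # the label set through plt

plt.close(fig)  # release the figure once we are done with it
```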

We first need to import some data to plot. Let's start with the data from the pandas section (available from cetml1659on.dat) and import it into a DataFrame:

In [3]:
df = pd.read_csv(
    'cetml1659on.dat',  # file name
    skiprows=6,  # skip header
    sep=r'\s+',  # whitespace separated (raw string avoids an invalid-escape warning)
    na_values=['-99.9', '-99.99'],  # NaNs
)
df.head()
Out[3]:
      JAN  FEB  MAR  APR   MAY   JUN   JUL   AUG   SEP   OCT  NOV  DEC  YEAR
1659  3.0  4.0  6.0  7.0  11.0  13.0  16.0  16.0  13.0  10.0  5.0  2.0  8.87
1660  0.0  4.0  6.0  9.0  11.0  14.0  15.0  16.0  13.0  10.0  6.0  5.0  9.10
1661  5.0  5.0  6.0  8.0  11.0  14.0  15.0  15.0  13.0  11.0  8.0  6.0  9.78
1662  5.0  6.0  6.0  8.0  11.0  15.0  15.0  15.0  13.0  11.0  6.0  3.0  9.52
1663  1.0  1.0  5.0  7.0  10.0  14.0  15.0  15.0  13.0  10.0  7.0  5.0  8.63

Pandas integrates matplotlib directly, so any DataFrame, or any single column of one, can be plotted simply by calling its plot() method. This draws the data onto the current matplotlib axes and returns that Axes object. You can then adjust the plot, for example by setting the y-axis label with plt.ylabel(), before displaying it with plt.show().

The pyplot interface operates on a single global state, and calling any function on plt alters that state. Calling df['JAN'].plot() draws onto the currently active plot, plt.ylabel() then alters it, and plt.show() displays it.

In [4]:
df['JAN'].plot()

plt.ylabel(r'Temperature ($^\circ$C)')

plt.show()
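Because plot() returns a matplotlib Axes object, you can also build the same figure in the object-oriented style by calling methods on that object instead of on plt. So that it runs on its own, this sketch uses a short stand-in Series holding the first five January values from the table above, rather than reading cetml1659on.dat.

```python
import pandas as pd
import matplotlib.pyplot as plt

# Stand-in data: the notebook itself uses df['JAN'] read from cetml1659on.dat
jan = pd.Series([3.0, 0.0, 5.0, 5.0, 1.0],
                index=[1659, 1660, 1661, 1662, 1663], name='JAN')

# Series.plot() draws onto the current axes and returns that Axes object
ax = jan.plot()

# Axes methods are the object-oriented counterparts of plt.ylabel() and friends
ax.set_ylabel(r'Temperature ($^\circ$C)')
ax.set_xlabel('Year')

plt.close(ax.figure)
```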

Exercise 3

  • Make sure you can reproduce the plot above. Try tweaking the labels or which column is plotted.
  • Try putting in two plot() calls with different months (January and July for example) before calling show().
In [5]:
# Answer

df['JUL'].plot()
df['JAN'].plot()

plt.ylabel(r'Temperature ($^\circ$C)')

plt.show()

Making it prettier

While it's useful to be able to quickly plot any data we have in front of us, matplotlib's power comes from its configurability. Let's experiment with a dataset and see how much we can change the plot.

We'll start with a simple DataFrame containing two columns, one with the values of a cosine, the other with the values of a sine.

In [6]:
X = np.linspace(-np.pi, np.pi, 256, endpoint=True)
data = {'cos': np.cos(X), 'sin': np.sin(X)}
trig = DataFrame(index=X, data=data)

trig.plot()
plt.show()

You can see that it has plotted the sine and cosine curves between $-\pi$ and $\pi$. Now, let's go through and see how we can affect the display of this plot.

Changing colours and line widths

First, we want the cosine in blue and the sine in red, with a slightly thicker line for both.

In [7]:
trig.cos.plot(color="blue", linewidth=2.5, linestyle="-")
trig.sin.plot(color="red", linewidth=2.5, linestyle="-")

plt.show()
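pandas can also style several columns in one call. DataFrame.plot() accepts a list of colours, which it assigns to the columns in order, and it passes keyword arguments such as linewidth on to matplotlib. The minimal sketch below rebuilds trig from the same definition. The styles dictionary exists only to inspect the result.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

X = np.linspace(-np.pi, np.pi, 256, endpoint=True)
trig = pd.DataFrame(index=X, data={'cos': np.cos(X), 'sin': np.sin(X)})

# One call styles both columns: the colours are matched to the columns in order
ax = trig.plot(color=['blue', 'red'], linewidth=2.5, linestyle='-')

# Read back each line's colour (as hex) and width to see what was applied
styles = {line.get_label(): (mcolors.to_hex(line.get_color()), line.get_linewidth())
          for line in ax.get_lines()}
print(styles)

plt.close(ax.figure)
```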

Exercise 4

  • Using the temperature dataset, set the colours of the July and January lines to a warm colour and a cool colour.
  • Add in the yearly average column to the plot with a dashed line style.
In [8]:
# Answer

df['JUL'].plot(color='orange')
df['JAN'].plot(color='blue')
df['YEAR'].plot(linestyle='--')  # dashed, as the exercise asks (':' would be dotted)

plt.ylabel(r'Temperature ($^\circ$C)')

plt.legend(loc='upper left')

plt.show()

Setting limits

The current limits of the figure are a bit tight, and we want to leave some space so that all data points are clearly visible.

In [9]:
trig.cos.plot(color="blue", linewidth=2.5, linestyle="-")
trig.sin.plot(color="red", linewidth=2.5, linestyle="-")

### New code
plt.xlim(trig.index.min() * 1.1, trig.index.max() * 1.1)
plt.ylim(trig.cos.min() * 1.1, trig.cos.max() * 1.1)
### End of new code

plt.show()
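Multiplying the minimum and maximum by 1.1 works here only because the data straddle zero. For all-positive data, min * 1.1 moves the lower limit up and cuts off points. Padding based on the data range avoids that. The padded_limits function below is a hypothetical helper, not part of matplotlib. With frac=0.05 it reproduces the $\pm 1.1\pi$ limits used above for trig.index.

```python
import numpy as np

def padded_limits(values, frac=0.05):
    """Hypothetical helper: pad [min, max] by `frac` of the data range on each side."""
    lo, hi = float(np.min(values)), float(np.max(values))
    pad = frac * (hi - lo)
    return lo - pad, hi + pad

X = np.linspace(-np.pi, np.pi, 256, endpoint=True)
print(padded_limits(X))         # roughly (-1.1*pi, 1.1*pi), as in the cell above

positive = np.array([10.0, 12.0, 20.0])
print(padded_limits(positive))  # (9.5, 20.5): every point stays inside
print(positive.min() * 1.1, positive.max() * 1.1)  # about 11 and 22: the point at 10 is cut off
```

You could use it in place of the hand-written limits as plt.xlim(*padded_limits(trig.index)). matplotlib's own ax.margins(0.05) applies the same range-relative padding when it autoscales.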

Setting ticks

The current ticks are not ideal because they do not show the interesting values ($\pm\pi$, $\pm\frac{\pi}{2}$) for sine and cosine. We’ll change them so that they show only these values, plus 0.

In [10]:
trig.cos.plot(color="blue", linewidth=2.5, linestyle="-")
trig.sin.plot(color="red", linewidth=2.5, linestyle="-")

plt.xlim(trig.index.min() * 1.1, trig.index.max() * 1.1)
plt.ylim(trig.cos.min() * 1.1, trig.cos.max() * 1.1)

### New code
plt.xticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi])
plt.yticks([-1, 0, +1])
### End of new code

plt.show()
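Instead of listing every tick position by hand, you can use matplotlib.ticker.MultipleLocator to place a tick at each multiple of a base value. This sketch uses the object-oriented interface (fig and ax from plt.subplots()) with $\frac{\pi}{2}$ as the base. The locator may also report positions just outside the view, so the code keeps only the visible ones before printing.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

X = np.linspace(-np.pi, np.pi, 256, endpoint=True)

fig, ax = plt.subplots()
ax.plot(X, np.cos(X), color='blue')
ax.set_xlim(-1.1 * np.pi, 1.1 * np.pi)

# Put a major tick at every multiple of pi/2 instead of listing the positions
ax.xaxis.set_major_locator(MultipleLocator(np.pi / 2))

# Keep only the tick positions inside the current x-limits
visible = [t for t in ax.get_xticks() if -1.1 * np.pi <= t <= 1.1 * np.pi]
print(np.array(visible) / np.pi)  # multiples of pi: about -1, -0.5, 0, 0.5, 1

plt.close(fig)
```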

Setting tick labels

The ticks are now properly placed, but their labels are not very explicit. We could guess that 3.142 is $\pi$, but it would be better to make it explicit. When we set tick values, we can also pass the corresponding labels as a second list argument. Note that we’ll use LaTeX to allow for nice rendering of the labels.

In [11]:
trig.cos.plot(color="blue", linewidth=2.5, linestyle="-")
trig.sin.plot(color="red", linewidth=2.5, linestyle="-")

plt.xlim(trig.index.min() * 1.1, trig.index.max() * 1.1)
plt.ylim(trig.cos.min() * 1.1, trig.cos.max() * 1.1)

### New code
plt.xticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi],
           [r'$-\pi$', r'$-\pi/2$', r'$0$', r'$+\pi/2$', r'$+\pi$'])

plt.yticks([-1, 0, +1],
           [r'$-1$', r'$0$', r'$+1$'])
### End of new code

plt.show()
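You can also generate the labels instead of typing them out, by pairing the locator with a matplotlib.ticker.FuncFormatter. format_pi is a hypothetical helper written for this sketch. It rounds each tick to the nearest multiple of $\frac{\pi}{2}$ and returns the same LaTeX strings as the hand-written list above.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter, MultipleLocator

def format_pi(value, pos):
    """Hypothetical formatter: label a multiple of pi/2 with a LaTeX string."""
    n = int(round(value / (np.pi / 2)))  # number of half-pi steps from zero
    if n == 0:
        return r'$0$'
    sign = '+' if n > 0 else '-'
    n = abs(n)
    if n % 2 == 0:  # a whole multiple of pi
        k = n // 2
        coeff = '' if k == 1 else str(k)
        return rf'${sign}{coeff}\pi$'
    coeff = '' if n == 1 else str(n)
    return rf'${sign}{coeff}\pi/2$'

X = np.linspace(-np.pi, np.pi, 256, endpoint=True)
fig, ax = plt.subplots()
ax.plot(X, np.sin(X), color='red')
ax.set_xlim(-1.1 * np.pi, 1.1 * np.pi)

# The locator decides where ticks go; the formatter decides what they say
ax.xaxis.set_major_locator(MultipleLocator(np.pi / 2))
ax.xaxis.set_major_formatter(FuncFormatter(format_pi))

print([format_pi(v, None) for v in (-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi)])
plt.close(fig)
```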

Moving spines

Spines are the lines that connect the axis tick marks and mark the boundaries of the data area. They can be placed at arbitrary positions, but until now they have been on the border of the axes. We want them in the middle instead. There are four of them (top/bottom/left/right), so we’ll discard the top and right ones by setting their color to none, and we’ll move the bottom and left ones to coordinate 0 in data space coordinates.

In [12]:
trig.cos.plot(color="blue", linewidth=2.5, linestyle="-")
trig.sin.plot(color="red", linewidth=2.5, linestyle="-")

plt.xlim(trig.index.min() * 1.1, trig.index.max() * 1.1)
plt.ylim(trig.cos.min() * 1.1, trig.cos.max() * 1.1)

plt.xticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi],
           [r'$-\pi$', r'$-\pi/2$', r'$0$', r'$+\pi/2$', r'$+\pi$'])

plt.yticks([-1, 0, +1],
           [r'$-1$', r'$0$', r'$+1$'])

### New code
ax = plt.gca()  # gca stands for 'get current axes'
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
ax.xaxis.set_ticks_position('bottom')
ax.spines['bottom'].set_position(('data',0))
ax.yaxis.set_ticks_position('left')
ax.spines['left'].set_position(('data',0))
### End of new code

plt.show()
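You can make the same spine adjustments directly on an Axes created with plt.subplots(). This sketch hides the unused spines with set_visible(False), an alternative to setting their colour to none, and then reads the settings back.

```python
import numpy as np
import matplotlib.pyplot as plt

X = np.linspace(-np.pi, np.pi, 256, endpoint=True)
fig, ax = plt.subplots()
ax.plot(X, np.cos(X), color='blue')

# Hide the top and right spines (an alternative to set_color('none'))
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

# Move the remaining spines so that they cross at the data origin
ax.spines['bottom'].set_position(('data', 0))
ax.spines['left'].set_position(('data', 0))

# Draw ticks only on the spines that are still shown
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')

top_visible = ax.spines['top'].get_visible()
bottom_position = ax.spines['bottom'].get_position()
print(top_visible, bottom_position)
plt.close(fig)
```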

Adding a legend

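Let’s add a legend in the upper left corner. With plain matplotlib you would pass a label keyword argument to each plot command, and that text is used in the legend box. Here pandas has already labelled each line with its column name (cos and sin), so a single call to plt.legend() is enough.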

In [13]:
trig.cos.plot(color="blue", linewidth=2.5, linestyle="-")
trig.sin.plot(color="red", linewidth=2.5, linestyle="-")

plt.xlim(trig.index.min() * 1.1, trig.index.max() * 1.1)
plt.ylim(trig.cos.min() * 1.1, trig.cos.max() * 1.1)

plt.xticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi],
           [r'$-\pi$', r'$-\pi/2$', r'$0$', r'$+\pi/2$', r'$+\pi$'])

plt.yticks([-1, 0, +1],
           [r'$-1$', r'$0$', r'$+1$'])

ax = plt.gca()  # gca stands for 'get current axes'
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
ax.xaxis.set_ticks_position('bottom')
ax.spines['bottom'].set_position(('data',0))
ax.yaxis.set_ticks_position('left')
ax.spines['left'].set_position(('data',0))

### New code
plt.legend(loc='upper left')
### End of new code

plt.show()
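If you plot with matplotlib directly rather than through pandas, the legend entries come from the label keyword argument. In this sketch the label strings 'cosine' and 'sine' are arbitrary choices.

```python
import numpy as np
import matplotlib.pyplot as plt

X = np.linspace(-np.pi, np.pi, 256, endpoint=True)
fig, ax = plt.subplots()

# With plain matplotlib, the legend text comes from each call's label= argument
ax.plot(X, np.cos(X), color='blue', linewidth=2.5, label='cosine')
ax.plot(X, np.sin(X), color='red', linewidth=2.5, label='sine')

legend = ax.legend(loc='upper left')
labels = [text.get_text() for text in legend.get_texts()]
print(labels)
plt.close(fig)
```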

Annotate some points

Let’s annotate some interesting points using the annotate command. We choose the value $\frac{2}{3}\pi$ and want to annotate both the sine and the cosine. We’ll first draw a marker on each curve as well as a straight dashed line. Then we’ll use the annotate command to display some text with an arrow.

In [27]:
trig.cos.plot(color="blue", linewidth=2.5, linestyle="-")
trig.sin.plot(color="red", linewidth=2.5, linestyle="-")

plt.xlim(trig.index.min() * 1.1, trig.index.max() * 1.1)
plt.ylim(trig.cos.min() * 1.1, trig.cos.max() * 1.1)

plt.xticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi],
           [r'$-\pi$', r'$-\pi/2$', r'$0$', r'$+\pi/2$', r'$+\pi$'])

plt.yticks([-1, 0, +1],
           [r'$-1$', r'$0$', r'$+1$'])

ax = plt.gca()  # gca stands for 'get current axes'
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
ax.xaxis.set_ticks_position('bottom')
ax.spines['bottom'].set_position(('data',0))
ax.yaxis.set_ticks_position('left')
ax.spines['left'].set_position(('data',0))

plt.legend(loc='upper left')

### New code
t = 2 * np.pi / 3
plt.plot([t, t], [0, np.cos(t)], color='blue', linewidth=2.5, linestyle="--")
plt.scatter([t, ], [np.cos(t), ], 50, color='blue')

plt.annotate(r'$cos(\frac{2\pi}{3})=-\frac{1}{2}$',
             xy=(t, np.cos(t)), xycoords='data',
             xytext=(-90, -50), textcoords='offset points', fontsize=16,
             arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))

plt.plot([t, t],[0, np.sin(t)], color='red', linewidth=2.5, linestyle="--")
plt.scatter([t, ],[np.sin(t), ], 50, color='red')

plt.annotate(r'$sin(\frac{2\pi}{3})=\frac{\sqrt{3}}{2}$',
             xy=(t, np.sin(t)), xycoords='data',
             xytext=(+10, +30), textcoords='offset points', fontsize=16,
             arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
### End of new code

plt.show()
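The values written in the two annotations follow from the identities $\cos(\pi - x) = -\cos x$ and $\sin(\pi - x) = \sin x$ with $x = \frac{\pi}{3}$:

```latex
\cos\frac{2\pi}{3} = \cos\left(\pi - \frac{\pi}{3}\right) = -\cos\frac{\pi}{3} = -\frac{1}{2},
\qquad
\sin\frac{2\pi}{3} = \sin\left(\pi - \frac{\pi}{3}\right) = \sin\frac{\pi}{3} = \frac{\sqrt{3}}{2}
```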

Now that you know how to make different modifications to your plots, we can make some of these changes to our temperature data.

Saving plot to a file

You can take any plot you've created within Jupyter and save it to a file on disk using the plt.savefig() function. You give the function the name of the file to create, and it infers the output format (such as SVG, PNG or PDF) from the file extension.

In [28]:
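trig.plot()

# Save before plt.show(): in Jupyter the figure may already be closed once it has
# been shown, so saving afterwards can produce a blank file
plt.savefig('my_fig.svg')

plt.show()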

You can then display the saved figure in a Markdown cell of your notebook with ![](my_fig.svg)
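Here is a sketch of saving the same figure in two formats, using a temporary directory from tempfile as the destination. Calling fig.savefig() on the figure object, before any plt.show(), avoids saving an already-closed figure, and the file extension alone selects the format. dpi and bbox_inches='tight' are optional savefig arguments, shown here for the raster copy.

```python
import os
import tempfile

import numpy as np
import matplotlib.pyplot as plt

X = np.linspace(-np.pi, np.pi, 256, endpoint=True)
fig, ax = plt.subplots()
ax.plot(X, np.cos(X))

out_dir = tempfile.mkdtemp()
svg_path = os.path.join(out_dir, 'my_fig.svg')
png_path = os.path.join(out_dir, 'my_fig.png')

# The output format is inferred from each file's extension
fig.savefig(svg_path)
fig.savefig(png_path, dpi=150, bbox_inches='tight')  # raster copy with trimmed margins

# Every PNG file starts with the same 8-byte signature
with open(png_path, 'rb') as f:
    png_header = f.read(8)
print(png_header == b'\x89PNG\r\n\x1a\n')

plt.close(fig)
```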

Exercise 5

  • Add in a legend for the data.
  • Add an annotation to one of the spikes in the data. Make sure the label is placed nicely.
    • Tip: you can get the year and temperature for a spike using:
      warm_winter_year = df['JAN'].idxmax()
      warm_winter_temp = df['JAN'].max()
      
  • Save the figure to a file and display it in your Jupyter notebook.
In [16]:
# Answer

df['JUL'].plot(color='orange')
df['JAN'].plot(color='blue')
df['YEAR'].plot(linestyle=':')

plt.ylabel(r'Temperature ($^\circ$C)')

plt.legend(loc='upper left')

warm_winter_year = df['JAN'].idxmax()
warm_winter_temp = df['JAN'].max()

plt.annotate(r'A warm winter',
             xy=(warm_winter_year, warm_winter_temp),
             xytext=(-30, +30), textcoords='offset points', fontsize=14,
             arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))

plt.savefig('warm_winter.svg')  # save before show(); the file name is just an example

plt.show()
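The tip uses idxmax() rather than argmax() because annotate() needs the x coordinate in data units, and here that coordinate is the year stored in the index. The short sketch below uses a stand-in Series holding the first five January values to show the difference.

```python
import pandas as pd

# Stand-in for df['JAN']: the index holds years, as in the notebook
jan = pd.Series([3.0, 0.0, 5.0, 5.0, 1.0], index=[1659, 1660, 1661, 1662, 1663])

print(jan.idxmax())  # 1661: the index label (a year), usable as an x coordinate
print(jan.argmax())  # 2: the integer position, not a year
print(jan.max())     # 5.0
```

When several values tie for the maximum, both methods report the first occurrence.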

Bar charts

Of course, Matplotlib can plot more than just line graphs. One of the other most common plot types is a bar chart. Let's work towards plotting a bar chart of the average temperature per decade.

Let's start by adding a new column to the data frame which represents the decade. We create it by taking the index (which is a list of years), converting each element to a string and then replacing the fourth character with a '0'.

In [17]:
years = Series(df.index, index=df.index).apply(str)
decade = years.apply(lambda x: x[:3]+'0')

df['decade'] = decade
df.head()
Out[17]:
      JAN  FEB  MAR  APR   MAY   JUN   JUL   AUG   SEP   OCT  NOV  DEC  YEAR  decade
1659  3.0  4.0  6.0  7.0  11.0  13.0  16.0  16.0  13.0  10.0  5.0  2.0  8.87    1650
1660  0.0  4.0  6.0  9.0  11.0  14.0  15.0  16.0  13.0  10.0  6.0  5.0  9.10    1660
1661  5.0  5.0  6.0  8.0  11.0  14.0  15.0  15.0  13.0  11.0  8.0  6.0  9.78    1660
1662  5.0  6.0  6.0  8.0  11.0  15.0  15.0  15.0  13.0  11.0  6.0  3.0  9.52    1660
1663  1.0  1.0  5.0  7.0  10.0  14.0  15.0  15.0  13.0  10.0  7.0  5.0  8.63    1660
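Because the years are integers, you can also compute the decade arithmetically with floor division, without converting to strings. This is an alternative to the string approach above. It is sketched on a stand-in frame with the first five annual means, and it has the side effect that the decade column stays numeric.

```python
import pandas as pd

# Stand-in frame indexed by year, like df in the notebook
df_years = pd.DataFrame({'YEAR': [8.87, 9.10, 9.78, 9.52, 8.63]},
                        index=[1659, 1660, 1661, 1662, 1663])

# Floor division drops the last digit of the year: 1663 // 10 * 10 == 1660
df_years['decade'] = df_years.index // 10 * 10
print(df_years['decade'].tolist())  # [1650, 1660, 1660, 1660, 1660]
```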

Once we have our decade column, we can use the pandas groupby() method to gather our data by decade and then aggregate it by taking the mean of each decade.

In [18]:
by_decade = df.groupby('decade')
agg = by_decade.mean()  # average every column within each decade

agg.head()
Out[18]:
         JAN   FEB   MAR   APR    MAY    JUN    JUL    AUG    SEP    OCT   NOV   DEC   YEAR
decade
1650    3.00  4.00  6.00  7.00  11.00  13.00  16.00  16.00  13.00  10.00  5.00  2.00  8.870
1660    2.60  4.00  5.10  7.70  10.60  14.50  16.00  15.70  13.30  10.00  6.30  3.80  9.157
1670    3.25  2.35  4.50  7.25  11.05  14.40  15.80  15.25  12.40   8.95  5.20  2.45  8.607
1680    2.50  2.80  4.80  7.40  11.45  14.00  15.45  14.90  12.70   9.55  5.45  4.05  8.785
1690    1.89  2.49  3.99  6.79   9.60  13.44  15.27  14.65  11.93   8.64  5.26  3.31  8.134
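To make the split-and-average step concrete, here is the same groupby on a small stand-in frame holding just the first five years. Only four of the 1660s years are included, so the 1660 mean comes out as (0 + 5 + 5 + 1) / 4 = 2.75 rather than the 2.60 computed from the full decade above.

```python
import pandas as pd

# Small stand-in: the first five years of January data, with their decades
df_small = pd.DataFrame(
    {'JAN': [3.0, 0.0, 5.0, 5.0, 1.0], 'decade': [1650, 1660, 1660, 1660, 1660]},
    index=[1659, 1660, 1661, 1662, 1663],
)

# Split the rows by decade, then average each group's January values
decade_means = df_small.groupby('decade')['JAN'].mean()
print(decade_means.to_dict())
```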

At this point, agg is a standard pandas DataFrame, so we can plot its columns like any other, here as a bar chart by calling plot.bar():

In [19]:
agg.YEAR.plot.bar()

plt.ylabel(r'Temperature ($^\circ$C)')

plt.show()
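plot.bar() also returns an Axes, so you can zoom the y-axis with set_ylim(), as Exercise 6 asks. This sketch plots the five decade means from the table above as a stand-in for agg.YEAR. Keep in mind that a truncated y-axis exaggerates the differences between bars.

```python
import pandas as pd
import matplotlib.pyplot as plt

# Stand-in for agg.YEAR: the decade means shown in the table above
decade_year = pd.Series([8.870, 9.157, 8.607, 8.785, 8.134],
                        index=['1650', '1660', '1670', '1680', '1690'], name='YEAR')

ax = decade_year.plot.bar(color='grey')
ax.set_ylabel(r'Temperature ($^\circ$C)')
ax.set_ylim(7.5, 9.5)  # the bars are drawn from 0, so this range zooms in on the tops

print(len(ax.patches))  # one rectangle per decade
plt.close(ax.figure)
```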

Exercise 6

  1. Plot a bar chart of the average temperature per century.

    • Set the limits of the y-axis to zoom in on the data.
  2. Plot a histogram of the average annual temperature.

    • Make sure that the x-axis is labelled correctly.
    • Tip: Look in the documentation for the right command to run
  3. Plot a scatter plot of each year's February temperature plotted against that year's January temperature. Is there an obvious correlation?

In [20]:
# Answer

years = Series(df.index, index=df.index).apply(str)
century = years.apply(lambda x: x[:2]+'00')

df['century'] = century

by_century = df.groupby('century')
century_avg = by_century.mean(numeric_only=True)  # skip the text 'decade' column

century_avg.YEAR.plot.bar()

plt.xlabel(r'Century')
plt.ylabel(r'Average yearly temperature ($^\circ$C)')
plt.ylim(8, 10.5)

plt.show()
In [21]:
# Answer

df.YEAR.plot.hist()

plt.xlabel(r'Temperature ($^\circ$C)')

plt.show()
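plot.hist() uses 10 bins by default, and the bins argument changes that. To check what is being drawn, np.histogram computes the same bin counts without plotting anything. The five-value Series here stands in for df.YEAR.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Stand-in for df.YEAR: the first five annual means shown earlier
year_means = pd.Series([8.87, 9.10, 9.78, 9.52, 8.63])

# plot.hist() bins the values; bins= controls how many bars are drawn
ax = year_means.plot.hist(bins=4, edgecolor='black')
ax.set_xlabel(r'Temperature ($^\circ$C)')

# np.histogram computes the same counts without drawing anything
counts, edges = np.histogram(year_means, bins=4)
print(counts, counts.sum())  # the counts add up to the number of values
plt.close(ax.figure)
```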
In [22]:
# Answer

df.plot.scatter(x='JAN', y='FEB')

plt.xlabel(r'Temperature in January ($^\circ$C)')
plt.ylabel(r'Temperature in February ($^\circ$C)')

plt.show()
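You can quantify the correlation asked about in part 3 with Series.corr(), which computes the Pearson coefficient and skips missing values pair by pair. On the full data you would call df['JAN'].corr(df['FEB']). The sketch below uses the first five years from the table above as stand-in data, so its value says nothing about the full record.

```python
import pandas as pd

# Stand-in for df[['JAN', 'FEB']]: the first five years from the table above
temps = pd.DataFrame({'JAN': [3.0, 0.0, 5.0, 5.0, 1.0],
                      'FEB': [4.0, 4.0, 5.0, 6.0, 1.0]})

# Pearson correlation: +1 is perfectly positive, 0 is none, -1 is perfectly negative
r = temps['JAN'].corr(temps['FEB'])
print(round(r, 2))
```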