
GroupBys and Pivot Tables

import pandas as pd
import numpy as np

This subchapter goes over groupbys and pivot tables, two incredibly useful pandas features.

To start, let’s load the same dataset as in the first subchapter of the Pandas chapter. As a reminder, this dataset has beer consumption data for the 50 US states plus Washington, D.C. It is sourced from Salience and Taxation: Theory and Evidence by Chetty, Looney, and Kroft (AER 2009), and it includes 7 columns:

  • st_name: the state abbreviation

  • year: the year the data was recorded

  • c_beer: the quantity of beer consumed, in thousands of gallons

  • beer_tax: the ad valorem tax, as a percentage

  • btax_dollars: the excise tax, represented in dollars per case (24 cans) of beer

  • population: the population of the state, in thousands

  • salestax: the sales tax percentage

df = pd.read_csv('data/beer_tax.csv')
df

GroupBys

Let’s say we’re interested in how the average tax differs by state. Right now, we have many rows of data for every state, so it’d be great if we could combine the data for each state somehow. .groupby() helps us do exactly that.

When we groupby on a column ‘x’, pandas splits the DataFrame into one group for each unique value in that column. Each group (which we’ll call a groupby object) contains all the rows from the original DataFrame that have that value of ‘x’. The code below prints the first 5 rows from the groupby objects for the first 6 states in our dataset (groups are sorted alphabetically by default).

# Print the first 5 rows of the subframe for each of the first 6 states
for i, (name, frame) in enumerate(df.groupby('st_name')):
    if i > 5:
        break  # stop once 6 states have been shown
    print(f'The first 5 rows from the subframe for {name} are: ')
    print(frame.head())
    print('________________________________________________________________________')
The first 5 rows from the subframe for AK are: 
   st_name  year  c_beer   beer_tax  btax_dollars  population  salestax
34      AK  1970    5372  13.735660        0.5625         304       0.0
35      AK  1971    6336  13.159102        0.5625         316       0.0
36      AK  1972    6038  12.749847        0.5625         324       0.0
37      AK  1973    6453  12.003234        0.5625         331       0.0
38      AK  1974    7598  10.810215        0.5625         341       0.0
________________________________________________________________________
The first 5 rows from the subframe for AL are: 
  st_name  year  c_beer   beer_tax  btax_dollars  population  salestax
0      AL  1970   33098  72.341130          2.37        3450       4.0
1      AL  1971   37598  69.304600          2.37        3497       4.0
2      AL  1972   42719  67.149190          2.37        3539       4.0
3      AL  1973   46203  63.217026          2.37        3580       4.0
4      AL  1974   49769  56.933796          2.37        3627       4.0
________________________________________________________________________
The first 5 rows from the subframe for AR are: 
    st_name  year  c_beer   beer_tax  btax_dollars  population  salestax
102      AR  1970   22378  16.632357        0.5449        1930       3.0
103      AR  1971   25020  15.934210        0.5449        1972       3.0
104      AR  1972   25614  15.438649        0.5449        2019       3.0
105      AR  1973   27946  14.534582        0.5449        2059       3.0
106      AR  1974   30915  13.089970        0.5449        2101       3.0
________________________________________________________________________
The first 5 rows from the subframe for AZ are: 
   st_name  year  c_beer  beer_tax  btax_dollars  population  salestax
68      AZ  1970   38604  5.494264          0.18        1795       3.0
69      AZ  1971   41837  5.263641          0.18        1896       3.0
70      AZ  1972   47949  5.099939          0.18        2008       3.0
71      AZ  1973   53380  4.801294          0.18        2124       3.0
72      AZ  1974   58188  4.324086          0.18        2223       3.0
________________________________________________________________________
The first 5 rows from the subframe for CA are: 
    st_name  year  c_beer  beer_tax  btax_dollars  population  salestax
136      CA  1970  363645  2.747132          0.09       20023     5.000
137      CA  1971  380397  2.631820          0.09       20346     5.000
138      CA  1972  401928  2.549970          0.09       20585     5.000
139      CA  1973  417463  2.400647          0.09       20869     5.167
140      CA  1974  464237  2.162043          0.09       21174     5.250
________________________________________________________________________
The first 5 rows from the subframe for CO are: 
    st_name  year  c_beer  beer_tax  btax_dollars  population  salestax
170      CO  1970   42145  4.120698         0.135        2224       3.0
171      CO  1971   45359  3.947731         0.135        2304       3.0
172      CO  1972   50444  3.824955         0.135        2405       3.0
173      CO  1973   55332  3.600970         0.135        2496       3.0
174      CO  1974   60162  3.243065         0.135        2541       3.0
________________________________________________________________________
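To see the split step in isolation, here is a minimal sketch on a tiny hypothetical DataFrame (the values are made up, not taken from beer_tax.csv). Iterating over a groupby yields (key, sub-DataFrame) pairs in sorted key order, .get_group() pulls out a single group, and .groups maps each key to the row labels that belong to it.

```python
import pandas as pd

# Hypothetical toy data shaped like the beer dataset (values are made up)
toy = pd.DataFrame({
    'st_name': ['CA', 'AK', 'CA', 'AK', 'AL'],
    'year':    [1970, 1970, 1971, 1971, 1970],
    'c_beer':  [100, 10, 120, 12, 50],
})

grouped = toy.groupby('st_name')

# Iterating yields (group key, sub-DataFrame) pairs, sorted by key by default
for name, frame in grouped:
    print(name, len(frame))

# .get_group() returns one sub-DataFrame; the original row labels are kept
ak = grouped.get_group('AK')
print(ak)

# .groups maps each key to the row labels in that group
print(grouped.groups)
```

Note that each subframe keeps the row labels from the original DataFrame, which is why the AK rows in the output above start at index 34.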

Aggregate

Now, we wish to collapse each groupby object into a single row and return a DataFrame made up of those rows, one per group. We can do that with an aggregation function, which tells pandas how to combine the rows within each group. For example, the aggregation function .mean() tells pandas to take the mean of each column within each groupby object. We do this below.

df.groupby('st_name').mean()
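The one-line .mean() is shorthand for the split-apply-combine pattern. As a minimal sketch on hypothetical toy data (values made up), here are the three steps done by hand, checked against the built-in version:

```python
import pandas as pd

# Hypothetical toy data (values made up for illustration)
toy = pd.DataFrame({
    'st_name':  ['AK', 'AK', 'AL', 'AL', 'AL'],
    'c_beer':   [10, 20, 30, 40, 50],
    'beer_tax': [1.0, 3.0, 2.0, 4.0, 6.0],
})

# Split: one sub-DataFrame per state
# Apply: take the mean of each numeric column in that sub-DataFrame
# Combine: stack the per-state results into one DataFrame indexed by state
manual = pd.DataFrame(
    {name: frame.drop(columns='st_name').mean()
     for name, frame in toy.groupby('st_name')}
).T
manual.index.name = 'st_name'

# The built-in one-liner does the same thing
built_in = toy.groupby('st_name').mean()
print(built_in)
```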

You can visualize the entire process (although using only a subset of the data) here. In general, we encourage you to play around with Pandas Tutor to visualize how your code works!

There are many aggregation functions you can use! We’ve listed some common ones below, but there are plenty more.

Built-in aggregation functions:

  • .mean()

  • .median()

  • .sum()

  • .count()

    • Note: It may appear as if .value_counts() and .count() return the same data. However, .value_counts() returns a Series of row counts sorted from most common to least common, while the aggregation function .count() returns a DataFrame with the number of non-missing values in each column, sorted by the group keys.

  • .max()

  • .min()

  • .std()

  • .var()
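The difference between .value_counts() and .count() is easiest to see on a small example. This is a sketch on hypothetical toy data (values made up) with one deliberately missing value; it also shows .size(), which counts rows per group whether or not values are missing.

```python
import pandas as pd

# Hypothetical toy data; note the missing beer_tax value for one CA row
toy = pd.DataFrame({
    'st_name':  ['CA', 'AK', 'CA', 'CA', 'AK', 'AL'],
    'beer_tax': [2.7, 13.7, 2.6, None, 13.2, 72.3],
})

# value_counts(): a Series of row counts, sorted from most to least common
vc = toy['st_name'].value_counts()
print(vc)

# groupby(...).count(): a DataFrame of NON-MISSING counts per column,
# ordered by the group keys (alphabetically here)
gc = toy.groupby('st_name').count()
print(gc)

# groupby(...).size(): number of rows per group, missing values included
sizes = toy.groupby('st_name').size()
print(sizes)
```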

You can also use other functions, such as those defined by NumPy. Examples include:

  • .agg(np.mean)

  • .agg(np.prod)

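  • .agg(np.cumsum)

    • returns the cumulative sum within each group. Unlike the other examples, this is not a true aggregation: the result has one row per row of the original DataFrame rather than one row per state.

In recent pandas versions, passing a NumPy function to .agg() may emit a FutureWarning; passing the function’s name as a string instead (for example, .agg('mean')) avoids it.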

df.groupby('st_name').agg(np.cumsum)
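Because a cumulative sum does not collapse each group, it helps to compare it side by side with a true aggregation. A minimal sketch on hypothetical toy data (values made up), using the built-in .sum() and .cumsum() groupby methods:

```python
import pandas as pd

# Hypothetical toy data (values made up)
toy = pd.DataFrame({
    'st_name': ['AK', 'AK', 'AL', 'AL', 'AL'],
    'c_beer':  [10, 20, 30, 40, 50],
})

# A true aggregation: one value per group, indexed by the group key
totals = toy.groupby('st_name')['c_beer'].sum()

# A running total within each group: one value per ORIGINAL row,
# indexed by the original row labels rather than the group keys
running = toy.groupby('st_name')['c_beer'].cumsum()

print(totals)
print(running)
```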

Finally, you can also define your own aggregation function. In the example below, we define last_10_vs_first_10 to return the average value over the last 10 years minus the average value over the first 10 years. Remember, your aggregation function must be able to reduce a column of data to a single value.

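def last_10_vs_first_10(obj):
    # obj is one column of one state's subframe; assumes rows are sorted by year
    # Mean of the last 10 values minus mean of the first 10 values
    return obj.iloc[-10:].mean() - obj.iloc[:10].mean()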

df.groupby('st_name').agg(last_10_vs_first_10)
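A custom function like this depends on row order: "first" and "last" only mean earliest and latest years if the rows are sorted by year. Here is a minimal sketch on hypothetical toy data (values made up) with a hypothetical scaled-down helper, last_2_vs_first_2, that sorts explicitly before grouping:

```python
import pandas as pd

# Hypothetical toy data: 4 years per state (values made up)
toy = pd.DataFrame({
    'st_name': ['AK'] * 4 + ['AL'] * 4,
    'year':    [1970, 1971, 1972, 1973] * 2,
    'c_beer':  [10, 12, 14, 20, 50, 50, 40, 30],
})

def last_2_vs_first_2(col):
    # Hypothetical analogue of last_10_vs_first_10 with a window of 2:
    # mean of the last 2 values minus mean of the first 2 values
    return col.iloc[-2:].mean() - col.iloc[:2].mean()

# Sort first so iloc positions correspond to chronological order
result = (toy.sort_values(['st_name', 'year'])
             .groupby('st_name')['c_beer']
             .agg(last_2_vs_first_2))
print(result)
```

For AK this is (14 + 20) / 2 - (10 + 12) / 2 = 6, and for AL it is (40 + 30) / 2 - (50 + 50) / 2 = -15.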

As a general tip, whenever you’re trying to compare values across the categories of a categorical variable, consider whether a groupby could help. Then, decide which column to groupby on (normally the categorical column you’re interested in). Finally, decide which aggregation function is most appropriate.

Moreover, if we only want to see the result of the .groupby() for a single column, it is common practice to select that column before applying the aggregation function, which avoids computing results for columns we don’t need. For example, if we just want to see the median beer consumption for each state, we could do df.groupby('st_name')['c_beer'].median().

df.groupby('st_name')['c_beer'].median()
st_name
AK     13666.5
AL     74930.5
AR     44731.0
AZ     98964.5
CA    620989.5
CO     83391.0
CT     58610.5
DC     16862.5
DE     16678.0
FL    327415.0
GA    131193.0
HI     29817.0
IA     66664.0
ID     23272.5
IL    280701.5
IN    119674.5
KS     49094.0
KY     71587.5
LA    107196.0
MA    131876.5
MD     98215.5
ME     26079.0
MI    210840.0
MN     98980.5
MO    122805.0
MS     56514.0
MT     23274.0
NC    127415.0
ND     15935.5
NE     39395.5
NH     35645.5
NJ    150161.5
NM     41557.5
NV     36316.0
NY    358784.0
OH    257559.5
OK     60139.0
OR     63020.0
PA    285790.5
RI     22784.5
SC     78973.0
SD     15490.0
TN    101125.0
TX    472811.5
UT     22402.5
VA    138821.0
VT     13601.0
WA    101889.0
WI    147526.5
WV     37934.5
WY     12071.0
Name: c_beer, dtype: float64

Selecting the relevant columns like this is especially important if your DataFrame contains other string columns and you’re applying an aggregation function that doesn’t work with strings (for example, it makes no sense to take the mean of strings). If you apply an aggregation function to a type of data it cannot handle, your code will raise an error.
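Here is a minimal sketch of two ways around the string problem, on hypothetical toy data with a made-up region column. In recent pandas versions (2.0 and later), calling .mean() on the whole groupby here would raise a TypeError, so the sketch does not run that line.

```python
import pandas as pd

# Hypothetical toy data with an extra string column (values made up)
toy = pd.DataFrame({
    'st_name': ['AK', 'AK', 'AL', 'AL'],
    'region':  ['West', 'West', 'South', 'South'],
    'c_beer':  [10, 20, 30, 50],
})

# toy.groupby('st_name').mean() would try to average 'region' and fail.

# Option 1: select the numeric column before aggregating
by_col = toy.groupby('st_name')['c_beer'].mean()

# Option 2: ask the aggregation to skip non-numeric columns
numeric = toy.groupby('st_name').mean(numeric_only=True)

print(by_col)
print(numeric)
```

Option 1 is usually clearer when you only care about one column; numeric_only=True is handy when you want every numeric column at once.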

df

Filter

Aggregation functions collapse each groupby object into a single row. Filter functions instead return True or False for each groupby object and keep only the rows of the groupby objects that returned True. For example, let’s say we want to keep all the states where beer_tax was at least 50% in at least one year. We could accomplish that using the following line of code.

df.groupby('st_name').filter(lambda x: (x['beer_tax'] >= 50).any())

In the above code, the lambda function defines a Python function that takes in a groupby object named x. The function extracts the beer_tax Series from that groupby object and compares it to 50, producing a boolean Series that indicates whether each value is at least 50. Then, .any() returns whether any (in other words, one or more) of those comparisons are True. If so, the lambda function returns True, and the filter keeps all the rows from that groupby object; if not, it returns False, and the filter removes them. As you can see from the code below, only Alabama, Georgia and South Carolina have ever had a beer_tax of at least 50%.

df.groupby('st_name').filter(lambda x: (x['beer_tax'] >= 50).any())['st_name'].unique()
array(['AL', 'GA', 'SC'], dtype=object)
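A filter keeps or drops whole groups, so a group stays in even if only one of its rows passes the test. The sketch below, on hypothetical toy data (values made up), evaluates the lambda on a single group and then shows an equivalent row mask built with .transform('max'), which broadcasts each group’s maximum back onto its rows.

```python
import pandas as pd

# Hypothetical toy data (values made up)
toy = pd.DataFrame({
    'st_name':  ['AK', 'AK', 'AL', 'AL', 'GA'],
    'beer_tax': [13.7, 12.0, 72.3, 49.0, 55.0],
})

# What the lambda computes for a single group: one True/False
al = toy[toy['st_name'] == 'AL']
print((al['beer_tax'] >= 50).any())  # True, since 72.3 >= 50

# filter keeps every row of each group whose lambda returned True,
# including AL's 49.0 row
kept = toy.groupby('st_name').filter(lambda x: (x['beer_tax'] >= 50).any())

# Equivalent mask: "any value >= 50" is the same as "group max >= 50"
mask = toy.groupby('st_name')['beer_tax'].transform('max') >= 50
print(kept.equals(toy[mask]))
```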

Multiple Columns and Aggregations

Finally, we can also groupby multiple columns. To demonstrate this, let us first create another categorical variable by adding in a decade column representing which decade the data is from.

df['decade'] = df['year'] // 10 * 10
df
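The decade formula works because // and * have the same precedence and are evaluated left to right, so it computes (year // 10) * 10: floor division drops the last digit, and multiplying by 10 restores the scale. A quick check on a few sample years:

```python
import pandas as pd

years = pd.Series([1970, 1975, 1979, 1980, 1999, 2000])

# (year // 10) drops the last digit; * 10 turns it back into a year
decades = years // 10 * 10
print(decades.tolist())  # [1970, 1970, 1970, 1980, 1990, 2000]
```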

Now, since decade is also a categorical variable, we can groupby both the state and the decade.

df.groupby(['st_name','decade']).mean()
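Grouping by two columns produces a result whose rows are labeled by (state, decade) pairs, a MultiIndex. A minimal sketch on hypothetical toy data (values made up) showing how to look up one group’s result with a tuple:

```python
import pandas as pd

# Hypothetical toy data (values made up)
toy = pd.DataFrame({
    'st_name': ['AK', 'AK', 'AK', 'AL', 'AL'],
    'decade':  [1970, 1970, 1980, 1970, 1980],
    'c_beer':  [10, 20, 40, 30, 50],
})

means = toy.groupby(['st_name', 'decade']).mean()
print(means)

# Each row label is a (state, decade) tuple
print(list(means.index.names))            # ['st_name', 'decade']
print(means.loc[('AK', 1970), 'c_beer'])  # 15.0
```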

Additionally, we can apply multiple aggregation functions at once by passing a list of them to .agg().

df.groupby(['st_name','decade']).agg(['min', 'max'])
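With a list of functions, each original column splits into one column per function, so the columns become (column, function) pairs. Passing a dict instead applies a different function to each column. A minimal sketch on hypothetical toy data (values made up):

```python
import pandas as pd

# Hypothetical toy data (values made up)
toy = pd.DataFrame({
    'st_name':  ['AK', 'AK', 'AL', 'AL'],
    'decade':   [1970, 1970, 1970, 1970],
    'c_beer':   [10, 20, 30, 50],
    'beer_tax': [1.0, 3.0, 2.0, 4.0],
})

# A list of function names: one column per (column, function) pair
stats = toy.groupby(['st_name', 'decade']).agg(['min', 'max'])
print(stats)
print(stats[('c_beer', 'max')].tolist())  # [20, 50]

# A dict: a different aggregation for each column
per_col = toy.groupby('st_name').agg({'c_beer': 'sum', 'beer_tax': 'mean'})
print(per_col)
```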

Pivot Tables

While we could groupby multiple columns above, the resulting dataset was honestly a bit messy. If you’re interested in the relationship between two categorical variables and how they affect a third numerical variable, a pivot table is often best. For example, let us use a pivot table to visualize the average beer consumption across the decades for each state.

pd.pivot_table(data=df, index='st_name', columns='decade',
               values='c_beer', aggfunc='mean').head(12)

As you can see, in the pd.pivot_table() method:

  • The data parameter tells the function which DataFrame to get the data from.

  • The index parameter says which categorical variable to use to form the index.

  • The columns parameter says which categorical variable to use to form the columns.

  • The values parameter says which numerical variable to aggregate.

  • The aggfunc parameter says which function to use to perform the aggregation.

  • Optionally, you can set the fill_value parameter to the value you want to replace NaNs with. For example, since we don’t have data on Hawaii in the 1970s, that entry currently shows as NaN.
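A pivot table is closely related to the two-column groupby from earlier: grouping by both keys and then moving one key into the columns with .unstack() gives the same table. The sketch below checks this on hypothetical toy data (values made up, with HI missing a 1970s row) and also shows fill_value. It passes aggfunc='mean' as a string, which pivot_table accepts.

```python
import pandas as pd

# Hypothetical toy data: HI has no 1970s row (values made up)
toy = pd.DataFrame({
    'st_name': ['AK', 'AK', 'HI', 'AK'],
    'decade':  [1970, 1980, 1980, 1980],
    'c_beer':  [10, 20, 30, 45],
})

pt = pd.pivot_table(data=toy, index='st_name', columns='decade',
                    values='c_beer', aggfunc='mean')

# Same table from a groupby: group by both keys, then move 'decade'
# from the row index into the columns
gb = toy.groupby(['st_name', 'decade'])['c_beer'].mean().unstack()

# fill_value replaces the NaN for the missing (HI, 1970) combination
filled = pd.pivot_table(data=toy, index='st_name', columns='decade',
                        values='c_beer', aggfunc='mean', fill_value=0)
print(pt)
print(filled)
```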

References
  1. Chetty, R., Looney, A., & Kroft, K. (2009). Salience and Taxation: Theory and Evidence. American Economic Review, 99(4), 1145–1177. 10.1257/aer.99.4.1145