In [2]:
import pandas as pd
import numpy as np

# GroupBys and Pivot Tables

This subchapter goes over groupbys and pivot tables, two incredibly useful `pandas` tools.

To start, let's load in the same dataset we used in the first subchapter of the `Pandas` chapter. As a reminder, this dataset has beer sales across the 50 US states and Washington, D.C. It is sourced from [_Salience and Taxation: Theory and Evidence_](https://www.aeaweb.org/articles?id=10.1257/aer.99.4.1145) by Chetty, Looney, and Kroft (AER 2009), and it includes 7 columns:
- `st_name`: the state abbreviation
- `year`: the year the data was recorded
- `c_beer`: the quantity of beer consumed, in thousands of gallons
- `beer_tax`: the ad valorem tax, as a percentage
- `btax_dollars`: the excise tax, represented in dollars per case (24 cans) of beer 
- `population`: the population of the state, in thousands
- `salestax`: the sales tax percentage

In [3]:
df = pd.read_csv('data/beer_tax.csv')
df

,st_name,year,c_beer,beer_tax,btax_dollars,population,salestax
0,AL,1970,33098,72.341130,2.370,3450,4.0
1,AL,1971,37598,69.304600,2.370,3497,4.0
2,AL,1972,42719,67.149190,2.370,3539,4.0
3,AL,1973,46203,63.217026,2.370,3580,4.0
4,AL,1974,49769,56.933796,2.370,3627,4.0
...,...,...,...,...,...,...,...
1703,WY,1999,12423,0.319894,0.045,492,4.0
1704,WY,2000,12595,0.309491,0.045,494,4.0
1705,WY,2001,12808,0.300928,0.045,494,4.0
1706,WY,2002,13191,0.296244,0.045,499,4.0


## GroupBys

Let's say we're interested in how the average tax differs by state. Right now, we have many rows (one per year) for every state, so it'd be great if we could combine each state's data somehow. `.groupby()` helps us do exactly that.

When we groupby on a column `x`, `pandas` splits the DataFrame into groups, one for each unique value in that column. Each group (which we'll loosely call a groupby object) is a sub-DataFrame containing all the rows from the original DataFrame that have that value. The code below prints the first 5 rows from the groupby objects corresponding to the first 5 states (in alphabetical order) in our dataset.

In [4]:
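# Print the first 5 rows of the groups for the first 5 states
i = 0
for name, frame in df.groupby('st_name'):
    if i >= 5:
        break  # stop after 5 states (the original `i > 5` check printed 6)
    print(f'The first 5 rows from the subframe for {name} are: ')
    print(frame.head())
    print('________________________________________________________________________')
    i += 1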

The first 5 rows from the subframe for AK are: 
   st_name  year  c_beer   beer_tax  btax_dollars  population  salestax
34      AK  1970    5372  13.735660        0.5625         304       0.0
35      AK  1971    6336  13.159102        0.5625         316       0.0
36      AK  1972    6038  12.749847        0.5625         324       0.0
37      AK  1973    6453  12.003234        0.5625         331       0.0
38      AK  1974    7598  10.810215        0.5625         341       0.0
________________________________________________________________________
The first 5 rows from the subframe for AL are: 
  st_name  year  c_beer   beer_tax  btax_dollars  population  salestax
0      AL  1970   33098  72.341130          2.37        3450       4.0
1      AL  1971   37598  69.304600          2.37        3497       4.0
2      AL  1972   42719  67.149190          2.37        3539       4.0
3      AL  1973   46203  63.217026          2.37        3580       4.0
4      AL  1974   49769  56.933796          2.37        3627       4.0
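Instead of looping, you can also inspect the groups directly on the object returned by `.groupby()`. The sketch below uses a tiny excerpt of the beer data (three years each for Alabama and Alaska, copied from the rows shown above) so that it runs without `data/beer_tax.csv`; `ngroups`, `groups` and `get_group()` are standard `pandas` GroupBy attributes and methods, while the variable names are just for illustration.

```python
import pandas as pd

# Excerpt of the beer data: three years each for AL and AK
small = pd.DataFrame({
    'st_name': ['AL', 'AL', 'AL', 'AK', 'AK', 'AK'],
    'year': [1970, 1971, 1972, 1970, 1971, 1972],
    'c_beer': [33098, 37598, 42719, 5372, 6336, 6038],
})

grouped = small.groupby('st_name')

# Number of groups: one per unique state
print(grouped.ngroups)

# Dictionary-like mapping from each state to the row labels in its group
print(grouped.groups)

# The sub-DataFrame (groupby object) for a single state
al_rows = grouped.get_group('AL')
print(al_rows)
```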

### Aggregate

Now, we wish to collapse each groupby object into a single row, and return a DataFrame that contains one such row per groupby object. We can do that with the help of an aggregation function, which tells `pandas` how to combine the rows in each group. For example, the aggregation function `.mean()` tells `pandas` to take the mean of each column within each groupby object. We do this below.


In [5]:
df.groupby('st_name').mean()

,year,c_beer,beer_tax,btax_dollars,population,salestax
st_name,,,,,,
AK,1986.5,12131.117647,7.208826,0.745259,503.529412,0.0
AL,1986.5,72660.411765,31.817266,2.37,4022.294118,4.0
AR,1986.5,41951.705882,7.315287,0.5449,2359.441176,3.813235
AZ,1986.5,93827.470588,3.282313,0.283235,3490.0,4.494118
CA,1986.5,585206.794118,2.189079,0.222353,27467.588235,6.320794
CO,1986.5,81713.058824,2.189892,0.171397,3282.941176,3.013235
CT,1986.5,59092.058824,3.45767,0.307591,3228.5,6.617647
DC,1986.5,16505.735294,2.320071,0.180018,633.676471,6.823529
DE,1986.5,16022.147059,2.55755,0.226297,659.970588,0.0
FL,1986.5,299864.882353,15.981612,1.290588,11903.176471,5.073529


You can visualize the entire process (although using only a subset of the data) [here](https://pandastutor.com/vis.html#code=import%20pandas%20as%20pd%0Aimport%20io%0A%0Acsv%20%3D%20'''%0A,st_name,year,c_beer,beer_tax,btax_dollars,population,salestax%0A0,AL,1970,33098,72.34113,2.37,3450,4.0%0A1,AL,1971,37598,69.3046,2.37,3497,4.0%0A2,AL,1972,42719,67.14919,2.37,3539,4.0%0A34,AK,1970,5372,13.73566,0.5625,304,0.0%0A35,AK,1971,6336,13.159101500000002,0.5625,316,0.0%0A36,AK,1972,6038,12.749847,0.5625,324,0.0%0A68,AZ,1970,38604,5.494264,0.18,1795,3.0%0A69,AZ,1971,41837,5.263641000000001,0.18,1896,3.0%0A70,AZ,1972,47949,5.0999393,0.18,2008,3.0%0A136,CA,1970,363645,2.747132,0.09,20023,5.0%0A137,CA,1971,380397,2.6318203999999996,0.09,20346,5.0%0A138,CA,1972,401928,2.5499697,0.09,20585,5.0%0A'''%0A%0Adf%20%3D%20pd.read_csv%28io.StringIO%28csv%29%29%0Adf.groupby%28'st_name'%29.mean%28%29&d=2024-01-04&lang=py&v=v1). In general, we encourage you to play around with [Pandas Tutor](https://pandastutor.com/) to visualize how your code works!
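To make the split-apply-combine idea concrete, here is a minimal sketch that computes the per-state means by hand with a loop and checks them against the built-in `.groupby().mean()`. It uses a small excerpt of the data (rows copied from above) rather than the full CSV. The loop is for illustration only; the built-in version is both shorter and faster.

```python
import pandas as pd

# Excerpt of the beer data: three years each for AL and AK
small = pd.DataFrame({
    'st_name': ['AL', 'AL', 'AL', 'AK', 'AK', 'AK'],
    'year': [1970, 1971, 1972, 1970, 1971, 1972],
    'c_beer': [33098, 37598, 42719, 5372, 6336, 6038],
    'beer_tax': [72.341130, 69.304600, 67.149190, 13.735660, 13.159102, 12.749847],
})
cols = ['year', 'c_beer', 'beer_tax']

# Split: iterate over the groups; apply: take each column's mean
pieces = {}
for name, frame in small.groupby('st_name'):
    pieces[name] = frame[cols].mean()

# Combine: one row per state, indexed by state like the built-in result
manual = pd.DataFrame(pieces).T
manual.index.name = 'st_name'

built_in = small.groupby('st_name')[cols].mean()
print(manual)
print(built_in)
```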

There are many aggregation functions you can use. We list some common ones below, but the list is not exhaustive.

Built-in aggregation functions:
- `.mean()`
- `.median()`
- `.sum()`
- `.count()` 
    - Note: It may appear as if `.value_counts()` and `.count()` return the same data. However, `.value_counts()` (called on a single column) returns a Series sorted from the most common value to the least common, while the aggregation function `.count()` returns a DataFrame with the number of non-missing values in each column for each group, ordered by the group keys (the index).
- `.max()`
- `.min()`
- `.std()`
- `.var()`
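The difference between `.count()` and `.value_counts()` is easiest to see side by side. The toy DataFrame below is made up for illustration (the consumption values come from earlier rows, but the missing value is invented), and it gives Alabama more rows than Alaska so that the two orderings differ. `.size()`, also shown, counts rows per group including missing values.

```python
import numpy as np
import pandas as pd

# Made-up toy data: AL has four rows (one with a missing value), AK has three
toy = pd.DataFrame({
    'st_name': ['AL', 'AL', 'AL', 'AL', 'AK', 'AK', 'AK'],
    'c_beer': [33098, 37598, 42719, np.nan, 5372, 6336, 6038],
})

# Rows per state, sorted from most to least common: AL (4) then AK (3)
vc = toy['st_name'].value_counts()

# Non-missing values per column per state, ordered by the group keys: AK, AL
counts = toy.groupby('st_name').count()

# Rows per state (missing values included), also ordered by the group keys
sizes = toy.groupby('st_name').size()

print(vc, counts, sizes, sep='\n\n')
```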

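You can also pass other functions to `.agg()`, such as those defined by `NumPy`. Examples include:
- `.agg(np.mean)`
- `.agg(np.prod)`
- `.agg(np.cumsum)`
    - returns the cumulative sum within each group; read more [here](https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html#numpy.cumsum). Strictly speaking this is not an aggregation, since it returns one value per row rather than one per group; the equivalent `pandas` method is `.cumsum()`.

Note that recent versions of `pandas` (2.1 and later) emit a `FutureWarning` when you pass a NumPy function that has a built-in `pandas` equivalent, such as `np.mean`; passing the function's name as a string, as in `.agg('mean')`, avoids the warning.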
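Besides functions, `.agg()` accepts a function name as a string, or a dictionary that maps each column to its own aggregation. A small sketch on an excerpt of the data (rows copied from above):

```python
import pandas as pd

# Excerpt of the beer data: three years each for AL and AK
small = pd.DataFrame({
    'st_name': ['AL', 'AL', 'AL', 'AK', 'AK', 'AK'],
    'c_beer': [33098, 37598, 42719, 5372, 6336, 6038],
    'beer_tax': [72.341130, 69.304600, 67.149190, 13.735660, 13.159102, 12.749847],
})

# One function name (as a string) applied to every selected column
means = small.groupby('st_name')[['c_beer', 'beer_tax']].agg('mean')

# A dictionary assigns a different aggregation to each column
summary = small.groupby('st_name').agg({'c_beer': 'sum', 'beer_tax': 'max'})

print(means)
print(summary)
```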


In [6]:
# Cumulative sum of each column within each state (one output row per input row)
df.groupby('st_name').cumsum()

,year,c_beer,beer_tax,btax_dollars,population,salestax
0,1970,33098,72.341130,2.370,3450,4.0
1,3941,70696,141.645730,4.740,6947,8.0
2,5913,113415,208.794920,7.110,10486,12.0
3,7886,159618,272.011946,9.480,14066,16.0
4,9860,209387,328.945742,11.850,17693,20.0
...,...,...,...,...,...,...
1703,55594,340347,16.654479,1.260,12807,88.0
1704,57594,352942,16.963970,1.305,13301,92.0
1705,59595,365750,17.264899,1.350,13795,96.0
1706,61597,378941,17.561143,1.395,14294,100.0


Finally, if you like, you can also define your own aggregation function! In the example below, we define `last_10_vs_first_10`, which returns the average value in the last 10 rows of each group minus the average value in the first 10 rows. Since each state's rows in this dataset are listed in chronological order, this compares the last 10 years with the first 10 years. Remember, your aggregation function must be able to reduce a column of data to a single value.

In [7]:
def last_10_vs_first_10(obj):
    return obj.iloc[-10:].mean() - obj.iloc[:10].mean()

df.groupby('st_name').agg(last_10_vs_first_10)

,year,c_beer,beer_tax,btax_dollars,population,salestax
st_name,,,,,,
AK,24.0,6799.2,-5.255919,0.387,262.1,0.0
AL,24.0,41128.7,-38.866784,0.0,740.4,0.0
AR,24.0,20500.6,-8.936081,0.0,514.8,1.715
AZ,24.0,69478.1,-1.66223,0.18,2714.2,1.68
CA,24.0,173227.7,1.103402,0.36,11827.6,1.6833
CO,24.0,40179.1,-2.181436,0.02925,1542.9,-0.03
CT,24.0,2883.6,-1.191767,0.2626,311.3,-0.25
DC,24.0,-2372.4,-2.397173,0.0392,-136.6,2.2
DE,24.0,6749.2,-0.899509,0.2068,187.7,0.0
FL,24.0,194335.7,-13.366342,0.42,7388.0,2.0
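One subtlety with `last_10_vs_first_10` is that `.iloc` works by row position, so the result is only meaningful if each state's rows are in chronological order. The sketch below uses a simplified, hypothetical variant, `last_vs_first`, that subtracts the first value from the last, applied to a deliberately shuffled excerpt of the data. Sorting by `year` before grouping changes the answer.

```python
import pandas as pd

# Excerpt of the beer data, deliberately shuffled so the years are out of order
shuffled = pd.DataFrame({
    'st_name': ['AL', 'AK', 'AL', 'AK', 'AL', 'AK'],
    'year': [1972, 1971, 1970, 1972, 1971, 1970],
    'c_beer': [42719, 6336, 33098, 6038, 37598, 5372],
})

def last_vs_first(obj):
    # Hypothetical helper: last value minus first value, by row position
    return obj.iloc[-1] - obj.iloc[0]

# On unsorted rows, "first" and "last" are just whatever rows happen to come first/last
unsorted = shuffled.groupby('st_name')['c_beer'].agg(last_vs_first)

# Sorting by year first makes "last" and "first" mean latest and earliest
in_order = shuffled.sort_values('year').groupby('st_name')['c_beer'].agg(last_vs_first)

print(unsorted)
print(in_order)
```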


As a general tip, whenever you're trying to calculate differences across a categorical variable, consider whether groupbys could be helpful. Then, determine which column you would groupby on (normally the categorical column you're interested in). Finally, determine which aggregation function would be most helpful.

Moreover, if we only want to see the results of the `.groupby()` for a single column, it is common practice to select that column before applying the aggregation function, so that `pandas` doesn't spend time aggregating columns we don't need. For example, if we just want to see the median beer consumption for each state, we could do `df.groupby('st_name')['c_beer'].median()`.

In [8]:
df.groupby('st_name')['c_beer'].median()

st_name
AK     13666.5
AL     74930.5
AR     44731.0
AZ     98964.5
CA    620989.5
CO     83391.0
CT     58610.5
DC     16862.5
DE     16678.0
FL    327415.0
GA    131193.0
HI     29817.0
IA     66664.0
ID     23272.5
IL    280701.5
IN    119674.5
KS     49094.0
KY     71587.5
LA    107196.0
MA    131876.5
MD     98215.5
ME     26079.0
MI    210840.0
MN     98980.5
MO    122805.0
MS     56514.0
MT     23274.0
NC    127415.0
ND     15935.5
NE     39395.5
NH     35645.5
NJ    150161.5
NM     41557.5
NV     36316.0
NY    358784.0
OH    257559.5
OK     60139.0
OR     63020.0
PA    285790.5
RI     22784.5
SC     78973.0
SD     15490.0
TN    101125.0
TX    472811.5
UT     22402.5
VA    138821.0
VT     13601.0
WA    101889.0
WI    147526.5
WV     37934.5
WY     12071.0
Name: c_beer, dtype: float64
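Whether you select the column with single or double brackets determines the type of the result: single brackets give a Series (as above), while double brackets give a one-column DataFrame. A quick check on a small excerpt of the data:

```python
import pandas as pd

# Excerpt of the beer data: three years each for AL and AK
small = pd.DataFrame({
    'st_name': ['AL', 'AL', 'AL', 'AK', 'AK', 'AK'],
    'c_beer': [33098, 37598, 42719, 5372, 6336, 6038],
})

as_series = small.groupby('st_name')['c_beer'].median()   # single brackets -> Series
as_frame = small.groupby('st_name')[['c_beer']].median()  # double brackets -> DataFrame

print(type(as_series).__name__, type(as_frame).__name__)
```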

Selecting the relevant columns like this is especially important if your DataFrame contains string columns and you're applying an aggregation function that doesn't work with strings (for example, it makes no sense to take the mean of strings). If you apply an aggregation function to a type of data it cannot handle, your code will raise an error.
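Below is a small sketch of the problem and two ways around it. The `region` column is a string column added purely for illustration; it is not part of the original dataset. As far as we know, in `pandas` 2.0 and later calling `.mean()` on a grouped DataFrame with such a column raises a `TypeError`, whereas older versions dropped the column instead.

```python
import pandas as pd

small = pd.DataFrame({
    'st_name': ['AL', 'AL', 'AL', 'AK', 'AK', 'AK'],
    # Hypothetical string column, added for illustration only
    'region': ['South', 'South', 'South', 'West', 'West', 'West'],
    'c_beer': [33098, 37598, 42719, 5372, 6336, 6038],
})

# Averaging every column fails on recent pandas because 'region' holds strings
try:
    small.groupby('st_name').mean()
except TypeError as err:
    print('TypeError:', err)

# Fix 1: select only the numeric column(s) before aggregating
by_selection = small.groupby('st_name')[['c_beer']].mean()

# Fix 2: tell the aggregation to skip non-numeric columns
by_flag = small.groupby('st_name').mean(numeric_only=True)

print(by_flag)
```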





### Filter

Aggregation functions collapse each groupby object into a single row. Filter functions instead return `True` or `False` for each groupby object and keep only the rows of the groupby objects for which the function returned `True`. For example, let's say we want to keep all the states where the `beer_tax` was at least 50% in at least one year. We could accomplish that using the following line of code.

In [17]:
df.groupby('st_name').filter(lambda x: (x['beer_tax'] >= 50).any())

,st_name,year,c_beer,beer_tax,btax_dollars,population,salestax
0,AL,1970,33098,72.341130,2.370,3450,4.0
1,AL,1971,37598,69.304600,2.370,3497,4.0
2,AL,1972,42719,67.149190,2.370,3539,4.0
3,AL,1973,46203,63.217026,2.370,3580,4.0
4,AL,1974,49769,56.933796,2.370,3627,4.0
...,...,...,...,...,...,...,...
1365,SC,1999,101782,12.283935,1.728,3975,5.0
1366,SC,2000,104116,11.884457,1.728,4023,5.0
1367,SC,2001,105525,11.555637,1.728,4060,5.0
1368,SC,2002,108000,11.375785,1.728,4104,5.0


In the above code, the [`lambda` function](https://www.freecodecamp.org/news/python-lambda-function-explained/#:~:text=They're%20commonly%20referred%20to,to%20use%20the%20function%20once.) defines a Python function that takes in a groupby object named `x`. The function extracts the `beer_tax` Series from that groupby object and compares it to 50, producing a Boolean Series of `True/False` values indicating whether each original value was greater than or equal to 50. Then, `.any()` outputs whether any (in other words, one or more) of those comparisons are `True`. If so, the lambda function returns `True` (meaning the filter keeps the rows from that groupby object); if not, it returns `False` (meaning the filter removes the rows from that groupby object). As the code below shows, only Alabama, Georgia and South Carolina have ever had a `beer_tax` of at least 50%.

In [18]:
df.groupby('st_name').filter(lambda x: (x['beer_tax'] >= 50).any())['st_name'].unique()

array(['AL', 'GA', 'SC'], dtype=object)
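The same rows can also be selected without `.filter()`, by broadcasting each group's maximum back onto its rows with `.transform('max')` and then using an ordinary Boolean mask. This is an alternative technique, not a description of how `.filter()` is implemented; the sketch below checks that both approaches give the same result on a small excerpt of the data.

```python
import pandas as pd

# Excerpt of the beer data: two years each for AL and AK
small = pd.DataFrame({
    'st_name': ['AL', 'AL', 'AK', 'AK'],
    'year': [1970, 1971, 1970, 1971],
    'beer_tax': [72.341130, 69.304600, 13.735660, 13.159102],
})

# Approach used above: keep whole groups where any beer_tax is at least 50
via_filter = small.groupby('st_name').filter(lambda x: (x['beer_tax'] >= 50).any())

# Alternative: transform('max') repeats each group's maximum on every row of that group,
# so "max >= 50" is True exactly for the rows of groups with any value >= 50
group_max = small.groupby('st_name')['beer_tax'].transform('max')
via_mask = small[group_max >= 50]

print(via_mask)
```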

### Multiple Columns and Aggregations

Finally, we can also groupby multiple columns. To demonstrate this, let us first create another categorical variable by adding in a `decade` column representing which decade the data is from.

In [64]:
df['decade'] = df['year'] // 10 * 10
df

,st_name,year,c_beer,beer_tax,btax_dollars,population,salestax,decade
0,AL,1970,33098,72.341130,2.370,3450,4.0,1970
1,AL,1971,37598,69.304600,2.370,3497,4.0,1970
2,AL,1972,42719,67.149190,2.370,3539,4.0,1970
3,AL,1973,46203,63.217026,2.370,3580,4.0,1970
4,AL,1974,49769,56.933796,2.370,3627,4.0,1970
...,...,...,...,...,...,...,...,...
1703,WY,1999,12423,0.319894,0.045,492,4.0,1990
1704,WY,2000,12595,0.309491,0.045,494,4.0,2000
1705,WY,2001,12808,0.300928,0.045,494,4.0,2000
1706,WY,2002,13191,0.296244,0.045,499,4.0,2000
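The expression `df['year'] // 10 * 10` works because `//` is floor (integer) division: dividing by 10 drops the last digit of the year, and multiplying by 10 restores the scale. On a few example years:

```python
import pandas as pd

years = pd.Series([1970, 1979, 1985, 1999, 2003])

# e.g. 1979 // 10 == 197, then 197 * 10 == 1970
decades = years // 10 * 10
print(decades.tolist())
```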


Now, since `decade` is also a categorical variable, we can groupby both the state and the decade.

In [69]:
df.groupby(['st_name','decade']).mean()

,,year,c_beer,beer_tax,btax_dollars,population,salestax
st_name,decade,,,,,,
AK,1970,1974.5,7715.900,10.603964,0.56250,360.40,0.0
AK,1980,1984.5,13321.300,6.391171,0.70688,498.00,0.0
AK,1990,1994.5,14393.600,5.005872,0.78750,598.50,0.0
AK,2000,2001.5,14537.500,6.272505,1.19250,637.75,0.0
AL,1970,1974.5,50665.600,55.847542,2.37000,3658.50,4.0
...,...,...,...,...,...,...,...
WV,2000,2001.5,40984.500,2.653142,0.39920,1806.00,6.0
WY,1970,1975.5,11349.625,0.989311,0.04500,393.25,3.0
WY,1980,1984.5,13308.000,0.516437,0.04500,488.40,3.0
WY,1990,1994.5,11647.000,0.357562,0.04500,477.70,3.4


Additionally, we can apply multiple aggregation functions at once by passing a list of them to `.agg()`.

In [70]:
df.groupby(['st_name','decade']).agg(['min', 'max'])

,,year,year,c_beer,c_beer,beer_tax,beer_tax,btax_dollars,btax_dollars,population,population,salestax,salestax
,,min,max,min,max,min,max,min,max,min,max,min,max
st_name,decade,,,,,,,,,,,,
AK,1970,1970,1979,5372,9848,7.340821,13.735660,0.5625,0.5625,304,405,0.0,0.0
AK,1980,1980,1989,10350,14711,5.522732,7.181146,0.5625,0.7875,405,547,0.0,0.0
AK,1990,1990,1999,13014,15419,4.478518,5.708654,0.7875,0.7875,553,625,0.0,0.0
AK,2000,2000,2003,14256,15008,4.147421,12.396732,0.7875,2.4075,628,649,0.0,0.0
AL,1970,1970,1979,33098,63999,38.661655,72.341130,2.3700,2.3700,3450,3866,4.0,4.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
WV,2000,2000,2003,39513,42075,2.569457,2.745529,0.3992,0.3992,1802,1810,6.0,6.0
WY,1970,1972,1979,7950,14246,0.734082,1.274985,0.0450,0.0450,347,454,3.0,3.0
WY,1980,1980,1989,10936,15730,0.429793,0.646776,0.0450,0.0450,458,510,3.0,3.0
WY,1990,1990,1999,11253,12423,0.319894,0.407761,0.0450,0.0450,454,492,3.0,4.0
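The column labels above have two levels (for example `('c_beer', 'min')`), which can be awkward to work with. If you'd prefer flat column names, `pandas` supports named aggregation, where each keyword argument to `.agg()` is a `(column, function)` pair. The output names such as `c_beer_min` are our own choice, and the data is a small excerpt of the rows shown earlier.

```python
import pandas as pd

# Excerpt of the beer data for AL (1970s) and WY (1990s and 2000s)
small = pd.DataFrame({
    'st_name': ['AL', 'AL', 'AL', 'WY', 'WY', 'WY', 'WY'],
    'year': [1970, 1971, 1972, 1999, 2000, 2001, 2002],
    'c_beer': [33098, 37598, 42719, 12423, 12595, 12808, 13191],
})
small['decade'] = small['year'] // 10 * 10

# Named aggregation: output_name=(input_column, aggregation_function)
summary = small.groupby(['st_name', 'decade']).agg(
    c_beer_min=('c_beer', 'min'),
    c_beer_max=('c_beer', 'max'),
)
print(summary)
```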


## Pivot Tables

While we could groupby multiple columns above, the resulting DataFrame was a bit messy to read. If you're interested in how two categorical variables together relate to a third, numerical variable, a pivot table is often the best tool. For example, let us use a pivot table to compare the average beer consumption across the decades for each state.

In [81]:
pd.pivot_table(data=df, index='st_name', columns='decade', 
               values='c_beer', aggfunc='mean').head(12)

decade,1970,1980,1990,2000
st_name,,,,
AK,7715.9,13321.3,14393.6,14537.5
AL,50665.6,73146.3,85399.2,94585.75
AR,30060.5,43013.7,48867.0,51736.5
AZ,57785.9,91379.6,115175.2,136681.75
CA,460041.2,630967.9,637368.5,653313.75
CO,60350.5,83809.1,90452.1,108031.75
CT,55019.5,63606.1,58895.2,58480.5
DC,17258.4,17290.4,15789.8,14452.25
DE,12163.0,16415.8,17936.4,19900.25
FL,186083.9,309504.6,364002.4,399874.25


As you can see, in the `pd.pivot_table()` function:
- The `data` parameter tells the function which DataFrame to get the data from.
- The `index` parameter says which categorical variable to use to form the index.
- The `columns` parameter says which categorical variable to use to form the columns.
- The `values` parameter says which numerical variable to aggregate.
- The `aggfunc` parameter says which function to use to perform the aggregation.
- Optionally, you can set the `fill_value` parameter to the value you want to replace the `NaN`s with. For example, as we don't have data on Hawaii in the 1970s, that entry of the pivot table is currently a `NaN` value.
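For a single `values` column, a pivot table gives the same numbers as grouping by both categorical columns and then moving one of them into the columns with `.unstack()`. The sketch below checks this on a small excerpt of the data, where several state/decade combinations are missing, and also shows `fill_value` replacing those `NaN`s with 0.

```python
import pandas as pd

# Excerpt of the beer data: AK and AL in the 1970s, WY in the 1990s and 2000s
small = pd.DataFrame({
    'st_name': ['AK', 'AK', 'AL', 'AL', 'WY', 'WY', 'WY'],
    'year': [1970, 1971, 1970, 1971, 1999, 2000, 2001],
    'c_beer': [5372, 6336, 33098, 37598, 12423, 12595, 12808],
})
small['decade'] = small['year'] // 10 * 10

# Pivot table: states as rows, decades as columns, mean consumption in the cells
table = pd.pivot_table(data=small, index='st_name', columns='decade',
                       values='c_beer', aggfunc='mean')

# Same numbers via groupby on both columns, then moving 'decade' into the columns
same = small.groupby(['st_name', 'decade'])['c_beer'].mean().unstack()

# fill_value replaces the missing state/decade combinations (NaN) with 0
filled = pd.pivot_table(data=small, index='st_name', columns='decade',
                        values='c_beer', aggfunc='mean', fill_value=0)

print(table)
print(filled)
```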
