import pandas as pd
import numpy as np
GroupBys and Pivot Tables#
This subchapter goes over groupbys and pivot tables, two incredibly useful pandas methods.
To start, let's load in the same dataset as the first subchapter in the Pandas chapter. As a reminder, this dataset has beer sales across US states (plus DC). It is sourced from Salience and Taxation: Theory and Evidence by Chetty, Looney, and Kroft (AER 2009), and it includes 7 columns:
- st_name: the state abbreviation
- year: the year the data was recorded
- c_beer: the quantity of beer consumed, in thousands of gallons
- beer_tax: the ad valorem tax, as a percentage
- btax_dollars: the excise tax, represented in dollars per case (24 cans) of beer
- population: the population of the state, in thousands
- salestax: the sales tax percentage
df = pd.read_csv('data/beer_tax.csv')
df
| | st_name | year | c_beer | beer_tax | btax_dollars | population | salestax |
|---|---|---|---|---|---|---|---|
0 | AL | 1970 | 33098 | 72.341130 | 2.370 | 3450 | 4.0 |
1 | AL | 1971 | 37598 | 69.304600 | 2.370 | 3497 | 4.0 |
2 | AL | 1972 | 42719 | 67.149190 | 2.370 | 3539 | 4.0 |
3 | AL | 1973 | 46203 | 63.217026 | 2.370 | 3580 | 4.0 |
4 | AL | 1974 | 49769 | 56.933796 | 2.370 | 3627 | 4.0 |
... | ... | ... | ... | ... | ... | ... | ... |
1703 | WY | 1999 | 12423 | 0.319894 | 0.045 | 492 | 4.0 |
1704 | WY | 2000 | 12595 | 0.309491 | 0.045 | 494 | 4.0 |
1705 | WY | 2001 | 12808 | 0.300928 | 0.045 | 494 | 4.0 |
1706 | WY | 2002 | 13191 | 0.296244 | 0.045 | 499 | 4.0 |
1707 | WY | 2003 | 15535 | 0.289643 | 0.045 | 501 | 4.0 |
1708 rows × 7 columns
GroupBys#
Let's say we're interested in how the average tax differs by state. Right now, we have a fair amount of data for every state, so it'd be great if we could combine the data for each state somehow. .groupby()
helps us do exactly that.
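Before looking at the groups themselves, it can help to see what .groupby() returns on its own. Here is a minimal sketch (the exact type name printed may vary slightly across pandas versions):
grouped = df.groupby('st_name')
print(type(grouped))    # a DataFrameGroupBy object -- nothing has been computed yet
print(grouped.ngroups)  # the number of distinct states in the column we grouped on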
When we groupby on a column 'x', we create a groupby object
for each unique value in the column. Each groupby object contains all the rows corresponding to that unique value from the original DataFrame. The code below prints the first 5 rows from the groupby objects corresponding to the first 6 states in our dataset.
i = 0
for name, frame in df.groupby('st_name'):
    if i > 5:
        break  # only show the first 6 states
    print(f'The first 5 rows from the subframe for {name} are: ')
    print(frame.head())
    print('________________________________________________________________________')
    i += 1
The first 5 rows from the subframe for AK are:
st_name year c_beer beer_tax btax_dollars population salestax
34 AK 1970 5372 13.735660 0.5625 304 0.0
35 AK 1971 6336 13.159102 0.5625 316 0.0
36 AK 1972 6038 12.749847 0.5625 324 0.0
37 AK 1973 6453 12.003234 0.5625 331 0.0
38 AK 1974 7598 10.810215 0.5625 341 0.0
________________________________________________________________________
The first 5 rows from the subframe for AL are:
st_name year c_beer beer_tax btax_dollars population salestax
0 AL 1970 33098 72.341130 2.37 3450 4.0
1 AL 1971 37598 69.304600 2.37 3497 4.0
2 AL 1972 42719 67.149190 2.37 3539 4.0
3 AL 1973 46203 63.217026 2.37 3580 4.0
4 AL 1974 49769 56.933796 2.37 3627 4.0
________________________________________________________________________
The first 5 rows from the subframe for AR are:
st_name year c_beer beer_tax btax_dollars population salestax
102 AR 1970 22378 16.632357 0.5449 1930 3.0
103 AR 1971 25020 15.934210 0.5449 1972 3.0
104 AR 1972 25614 15.438649 0.5449 2019 3.0
105 AR 1973 27946 14.534582 0.5449 2059 3.0
106 AR 1974 30915 13.089970 0.5449 2101 3.0
________________________________________________________________________
The first 5 rows from the subframe for AZ are:
st_name year c_beer beer_tax btax_dollars population salestax
68 AZ 1970 38604 5.494264 0.18 1795 3.0
69 AZ 1971 41837 5.263641 0.18 1896 3.0
70 AZ 1972 47949 5.099939 0.18 2008 3.0
71 AZ 1973 53380 4.801294 0.18 2124 3.0
72 AZ 1974 58188 4.324086 0.18 2223 3.0
________________________________________________________________________
The first 5 rows from the subframe for CA are:
st_name year c_beer beer_tax btax_dollars population salestax
136 CA 1970 363645 2.747132 0.09 20023 5.000
137 CA 1971 380397 2.631820 0.09 20346 5.000
138 CA 1972 401928 2.549970 0.09 20585 5.000
139 CA 1973 417463 2.400647 0.09 20869 5.167
140 CA 1974 464237 2.162043 0.09 21174 5.250
________________________________________________________________________
The first 5 rows from the subframe for CO are:
st_name year c_beer beer_tax btax_dollars population salestax
170 CO 1970 42145 4.120698 0.135 2224 3.0
171 CO 1971 45359 3.947731 0.135 2304 3.0
172 CO 1972 50444 3.824955 0.135 2405 3.0
173 CO 1973 55332 3.600970 0.135 2496 3.0
174 CO 1974 60162 3.243065 0.135 2541 3.0
________________________________________________________________________
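If you only want the sub-DataFrame for a single state, you don't need to loop at all: .get_group() pulls it out directly. A small sketch, using California as an example:
# Retrieve just the rows belonging to the 'CA' groupby object
df.groupby('st_name').get_group('CA').head()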
Aggregate#
Now, we wish to collapse each groupby object into a single row, and return a DataFrame containing one such row per groupby object. We can do that with the help of an aggregation function, which tells pandas
how to combine the rows. For example, the aggregation function .mean()
tells pandas
to take the mean of each column of the rows in the groupby object. We do this below.
df.groupby('st_name').mean()
st_name | year | c_beer | beer_tax | btax_dollars | population | salestax
---|---|---|---|---|---|---
AK | 1986.5 | 12131.117647 | 7.208826 | 0.745259 | 503.529412 | 0.000000 |
AL | 1986.5 | 72660.411765 | 31.817266 | 2.370000 | 4022.294118 | 4.000000 |
AR | 1986.5 | 41951.705882 | 7.315287 | 0.544900 | 2359.441176 | 3.813235 |
AZ | 1986.5 | 93827.470588 | 3.282313 | 0.283235 | 3490.000000 | 4.494118 |
CA | 1986.5 | 585206.794118 | 2.189079 | 0.222353 | 27467.588235 | 6.320794 |
CO | 1986.5 | 81713.058824 | 2.189892 | 0.171397 | 3282.941176 | 3.013235 |
CT | 1986.5 | 59092.058824 | 3.457670 | 0.307591 | 3228.500000 | 6.617647 |
DC | 1986.5 | 16505.735294 | 2.320071 | 0.180018 | 633.676471 | 6.823529 |
DE | 1986.5 | 16022.147059 | 2.557550 | 0.226297 | 659.970588 | 0.000000 |
FL | 1986.5 | 299864.882353 | 15.981612 | 1.290588 | 11903.176471 | 5.073529 |
GA | 1986.5 | 125192.500000 | 30.609017 | 2.280000 | 6346.941176 | 3.411765 |
HI | 1994.5 | 30017.666667 | 14.083035 | 1.994400 | 1168.833333 | 4.000000 |
IA | 1987.5 | 67063.187500 | 4.402103 | 0.376525 | 2868.093750 | 3.968750 |
ID | 1986.5 | 22516.970588 | 4.530940 | 0.337500 | 1025.823529 | 4.117647 |
IL | 1986.5 | 271348.970588 | 2.182229 | 0.167700 | 11664.470588 | 5.595588 |
IN | 1986.5 | 113699.176471 | 3.078238 | 0.241494 | 5619.705882 | 4.382353 |
KS | 1987.5 | 49179.375000 | 4.742910 | 0.393400 | 2477.718750 | 6.625000 |
KY | 1986.5 | 69095.970588 | 2.436639 | 0.181500 | 3712.764706 | 5.382353 |
LA | 1986.5 | 101652.470588 | 11.196456 | 0.834000 | 4223.088235 | 3.529412 |
MA | 1986.5 | 132109.176471 | 3.035118 | 0.232885 | 5973.147059 | 4.588235 |
MD | 1986.5 | 98233.852941 | 2.543661 | 0.196544 | 4630.058824 | 4.764706 |
ME | 1986.5 | 25878.147059 | 8.673873 | 0.694303 | 1174.176471 | 5.235294 |
MI | 1986.5 | 211427.117647 | 6.139256 | 0.457300 | 9380.529412 | 4.558824 |
MN | 1987.5 | 98194.843750 | 3.776908 | 0.312781 | 4369.312500 | 6.984375 |
MO | 1987.5 | 120355.656250 | 1.673511 | 0.135000 | 5147.312500 | 3.796875 |
MS | 1986.5 | 54309.264706 | 12.940641 | 0.964768 | 2587.617647 | 5.897059 |
MT | 1986.5 | 23113.264706 | 3.621642 | 0.287794 | 815.294118 | 0.000000 |
NC | 1986.5 | 126038.911765 | 16.105477 | 1.199318 | 6554.441176 | 3.367647 |
ND | 1986.5 | 15923.823529 | 4.833003 | 0.360000 | 647.588235 | 5.411765 |
NE | 1986.5 | 39524.617647 | 4.398752 | 0.390994 | 1603.352941 | 3.888471 |
NH | 1986.5 | 33263.970588 | 6.042833 | 0.536862 | 1025.823529 | 0.000000 |
NJ | 1986.5 | 151812.058824 | 1.564956 | 0.149721 | 7746.117647 | 5.617647 |
NM | 1986.5 | 38953.970588 | 4.713016 | 0.472835 | 1467.323529 | 4.397059 |
NV | 1986.5 | 39440.941176 | 2.162150 | 0.175865 | 1170.235294 | 5.112132 |
NY | 1986.5 | 353700.823529 | 2.300439 | 0.221374 | 18200.617647 | 3.941176 |
OH | 1986.5 | 249413.294118 | 4.913885 | 0.371435 | 10950.147059 | 4.617647 |
OK | 1986.5 | 57819.441176 | 10.633824 | 0.831315 | 3131.647059 | 3.235294 |
OR | 1986.5 | 62440.235294 | 2.045688 | 0.169976 | 2810.235294 | 0.000000 |
PA | 1986.5 | 285304.941176 | 2.416501 | 0.180000 | 11993.882353 | 6.000000 |
RI | 1986.5 | 23093.794118 | 2.739664 | 0.217518 | 991.558824 | 6.176471 |
SC | 1986.5 | 75692.176471 | 23.198414 | 1.728000 | 3387.911765 | 4.558824 |
SD | 1986.5 | 15654.676471 | 8.112361 | 0.608353 | 708.176471 | 4.029412 |
TN | 1986.5 | 97898.352941 | 3.531423 | 0.271576 | 4866.147059 | 5.014706 |
TX | 1986.5 | 444450.558824 | 5.335410 | 0.413312 | 16405.764706 | 5.003676 |
UT | 1986.5 | 22289.352941 | 6.028931 | 0.575079 | 1688.323529 | 4.834559 |
VA | 1986.5 | 128329.147059 | 7.745603 | 0.605824 | 5949.764706 | 4.176471 |
VT | 1986.5 | 13341.676471 | 7.750310 | 0.584868 | 540.764706 | 3.970588 |
WA | 1986.5 | 99727.617647 | 2.896268 | 0.291253 | 4664.941176 | 5.814706 |
WI | 1986.5 | 147105.176471 | 1.949311 | 0.145200 | 4891.000000 | 4.617647 |
WV | 1986.5 | 35606.941176 | 5.359263 | 0.399200 | 1847.058824 | 4.735294 |
WY | 1987.5 | 12327.375000 | 0.557837 | 0.045000 | 462.343750 | 3.250000 |
You can visualize the entire process (although using only a subset of the data) here. In general, we encourage you to play around with Pandas Tutor to visualize how your code works!
There are a lot of different aggregation functions you can use! We've given a few below, but there are plenty more.
Built-in aggregation functions:
.mean()
.median()
.sum()
.count()
Note: It may appear as if
.value_counts()
and .count()
return the same data. However, while .value_counts()
returns a Series sorted from most common to least common, the aggregation function .count()
returns a DataFrame ordered by the index (see the short comparison after this list).
.max()
.min()
.std()
.var()
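As a quick sketch of the difference mentioned in the note above (the exact counts depend on how many years of data each state has):
# A Series of row counts per state, sorted from most common to least common
df['st_name'].value_counts().head()

# One row per state, ordered by the index, counting non-missing values in every column
df.groupby('st_name').count().head()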
You can also use other functions, such as those defined by NumPy. Examples include:
.agg(np.mean)
.agg(np.prod)
.agg(np.cumsum), which returns the cumulative sum; read more here
Note that unlike the other aggregation functions above, np.cumsum returns one value per row rather than a single value per groupby object, which is why the output below still has 1708 rows.
df.groupby('st_name').agg(np.cumsum)
| | year | c_beer | beer_tax | btax_dollars | population | salestax |
|---|---|---|---|---|---|---|
0 | 1970 | 33098 | 72.341130 | 2.370 | 3450 | 4.0 |
1 | 3941 | 70696 | 141.645730 | 4.740 | 6947 | 8.0 |
2 | 5913 | 113415 | 208.794920 | 7.110 | 10486 | 12.0 |
3 | 7886 | 159618 | 272.011946 | 9.480 | 14066 | 16.0 |
4 | 9860 | 209387 | 328.945742 | 11.850 | 17693 | 20.0 |
... | ... | ... | ... | ... | ... | ... |
1703 | 55594 | 340347 | 16.654479 | 1.260 | 12807 | 88.0 |
1704 | 57594 | 352942 | 16.963970 | 1.305 | 13301 | 92.0 |
1705 | 59595 | 365750 | 17.264899 | 1.350 | 13795 | 96.0 |
1706 | 61597 | 378941 | 17.561143 | 1.395 | 14294 | 100.0 |
1707 | 63600 | 394476 | 17.850786 | 1.440 | 14795 | 104.0 |
1708 rows × 6 columns
Finally, if you like, you can also define your own aggregation function! An example is given below, where we define last_10_vs_first_10
to return the average value in the last 10 years minus the average value in the first 10 years. Remember, your aggregation function must be able to aggregate a column of data into a single value.
def last_10_vs_first_10(obj):
    # average over the last 10 rows minus the average over the first 10 rows
    return obj.iloc[-10:].mean() - obj.iloc[:10].mean()
df.groupby('st_name').agg(last_10_vs_first_10)
st_name | year | c_beer | beer_tax | btax_dollars | population | salestax
---|---|---|---|---|---|---
AK | 24.0 | 6799.2 | -5.255919 | 0.387000 | 262.1 | 0.0000 |
AL | 24.0 | 41128.7 | -38.866784 | 0.000000 | 740.4 | 0.0000 |
AR | 24.0 | 20500.6 | -8.936081 | 0.000000 | 514.8 | 1.7150 |
AZ | 24.0 | 69478.1 | -1.662230 | 0.180000 | 2714.2 | 1.6800 |
CA | 24.0 | 173227.7 | 1.103402 | 0.360000 | 11827.6 | 1.6833 |
CO | 24.0 | 40179.1 | -2.181436 | 0.029250 | 1542.9 | -0.0300 |
CT | 24.0 | 2883.6 | -1.191767 | 0.262600 | 311.3 | -0.2500 |
DC | 24.0 | -2372.4 | -2.397173 | 0.039200 | -136.6 | 2.2000 |
DE | 24.0 | 6749.2 | -0.899509 | 0.206800 | 187.7 | 0.0000 |
FL | 24.0 | 194335.7 | -13.366342 | 0.420000 | 7388.0 | 2.0000 |
GA | 24.0 | 88146.5 | -37.390833 | 0.000000 | 2928.6 | 1.0000 |
HI | 8.0 | -1523.8 | -2.805578 | 0.069790 | 90.0 | 0.0000 |
IA | 22.0 | 4125.4 | -3.340308 | 0.112500 | 10.3 | 2.0000 |
ID | 24.0 | 6536.0 | -5.534827 | 0.000000 | 438.7 | 2.0000 |
IL | 24.0 | 33124.9 | -2.352431 | 0.034680 | 1011.3 | 1.4500 |
IN | 24.0 | 25558.0 | -3.001028 | 0.051340 | 651.0 | 1.9000 |
KS | 22.0 | 7374.4 | -4.442635 | 0.037120 | 346.3 | 4.0000 |
KY | 24.0 | 17057.1 | -2.976507 | 0.000000 | 546.1 | 1.0000 |
LA | 24.0 | 35240.5 | -13.677173 | 0.000000 | 556.6 | 1.1000 |
MA | 24.0 | 1161.4 | -3.315096 | 0.022490 | 537.7 | 1.4000 |
MD | 24.0 | 5272.9 | -2.726227 | 0.020250 | 1139.4 | 0.8000 |
ME | 24.0 | 3304.3 | -7.612615 | 0.225000 | 202.7 | 0.6000 |
MI | 24.0 | -4159.0 | -7.499486 | 0.000000 | 767.3 | 1.9000 |
MN | 22.0 | 18407.3 | -3.508841 | 0.043600 | 862.3 | 4.5000 |
MO | 22.0 | 27237.9 | -1.777010 | 0.000000 | 681.2 | 1.1775 |
MS | 24.0 | 29677.5 | -15.748429 | 0.000000 | 422.0 | 2.0000 |
MT | 24.0 | 3531.8 | -3.334571 | 0.070710 | 150.7 | 0.0000 |
NC | 24.0 | 81130.0 | -19.694798 | -0.002319 | 2354.6 | 1.0500 |
ND | 24.0 | 2483.5 | -5.903816 | 0.000000 | 4.6 | 3.2000 |
NE | 24.0 | 5797.6 | -1.511346 | 0.292690 | 159.3 | 2.3250 |
NH | 24.0 | 11805.3 | -2.105625 | 0.374620 | 390.7 | 0.0000 |
NJ | 24.0 | 2076.7 | 0.169546 | 0.195100 | 1005.8 | 1.2000 |
NM | 24.0 | 19793.1 | 2.314072 | 0.735750 | 642.7 | 1.0250 |
NV | 24.0 | 40810.7 | -1.730302 | 0.067500 | 1266.5 | 3.3500 |
NY | 24.0 | -43471.5 | 0.155826 | 0.247060 | 804.5 | 0.2000 |
OH | 24.0 | 43336.6 | -5.628816 | 0.038880 | 554.4 | 1.0000 |
OK | 24.0 | 22580.2 | -10.607037 | 0.181300 | 642.7 | 2.5000 |
OR | 24.0 | 22796.4 | -1.436705 | 0.063660 | 1035.0 | 0.0000 |
PA | 24.0 | 12249.7 | -2.951908 | 0.000000 | 376.7 | 0.0000 |
RI | 24.0 | -195.2 | -2.530756 | 0.073200 | 84.7 | 1.7000 |
SC | 24.0 | 50838.6 | -28.338316 | 0.000000 | 1076.5 | 1.0000 |
SD | 24.0 | 6363.1 | -9.672514 | 0.018150 | 68.3 | 0.0000 |
TN | 24.0 | 44875.0 | -3.770087 | 0.038987 | 1326.5 | 2.4000 |
TX | 24.0 | 224717.2 | -5.557496 | 0.074200 | 7861.6 | 2.4000 |
UT | 24.0 | 10895.3 | 0.418456 | 0.573400 | 943.3 | 0.3125 |
VA | 24.0 | 50298.4 | -7.734889 | 0.102600 | 1964.4 | 0.7000 |
VT | 24.0 | 2383.1 | -8.982539 | 0.033800 | 126.3 | 2.0000 |
WA | 24.0 | 33204.6 | 1.306892 | 0.435141 | 2147.6 | 1.6200 |
WI | 24.0 | 9254.4 | -2.381206 | 0.000000 | 754.6 | 1.0000 |
WV | 24.0 | 11053.0 | -6.546676 | 0.000000 | -29.4 | 3.0000 |
WY | 22.0 | 369.9 | -0.592337 | 0.000000 | 80.0 | 0.8000 |
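As a quick usage example (a sketch, not part of the original analysis), we can sort one column of this result to see which states' beer consumption grew the most between the first and last 10 years of the sample:
changes = df.groupby('st_name').agg(last_10_vs_first_10)
changes['c_beer'].sort_values(ascending=False).head()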
As a general tip, whenever you're trying to calculate differences across a categorical variable, consider whether a groupby could be helpful. Then, determine which column to groupby on (normally the categorical column you're interested in). Finally, determine which aggregation function may be most helpful.
Moreover, if we only want to see the results of the .groupby()
on a single column, it is common practice to select that column before applying the aggregation function, which cuts down on unnecessary computation. For example, if we just want to see the median beer consumption for each state, we could do df.groupby('st_name')['c_beer'].median()
.
df.groupby('st_name')['c_beer'].median()
st_name
AK 13666.5
AL 74930.5
AR 44731.0
AZ 98964.5
CA 620989.5
CO 83391.0
CT 58610.5
DC 16862.5
DE 16678.0
FL 327415.0
GA 131193.0
HI 29817.0
IA 66664.0
ID 23272.5
IL 280701.5
IN 119674.5
KS 49094.0
KY 71587.5
LA 107196.0
MA 131876.5
MD 98215.5
ME 26079.0
MI 210840.0
MN 98980.5
MO 122805.0
MS 56514.0
MT 23274.0
NC 127415.0
ND 15935.5
NE 39395.5
NH 35645.5
NJ 150161.5
NM 41557.5
NV 36316.0
NY 358784.0
OH 257559.5
OK 60139.0
OR 63020.0
PA 285790.5
RI 22784.5
SC 78973.0
SD 15490.0
TN 101125.0
TX 472811.5
UT 22402.5
VA 138821.0
VT 13601.0
WA 101889.0
WI 147526.5
WV 37934.5
WY 12071.0
Name: c_beer, dtype: float64
Selecting the relevant columns like this is especially important if your DataFrame has other string columns and you're applying an aggregation function that doesn't work with strings (for example, it makes no sense to take the mean of strings). If you try to use an aggregation function on a type of data it cannot work with, your code will error.
df
| | st_name | year | c_beer | beer_tax | btax_dollars | population | salestax |
|---|---|---|---|---|---|---|---|
0 | AL | 1970 | 33098 | 72.341130 | 2.370 | 3450 | 4.0 |
1 | AL | 1971 | 37598 | 69.304600 | 2.370 | 3497 | 4.0 |
2 | AL | 1972 | 42719 | 67.149190 | 2.370 | 3539 | 4.0 |
3 | AL | 1973 | 46203 | 63.217026 | 2.370 | 3580 | 4.0 |
4 | AL | 1974 | 49769 | 56.933796 | 2.370 | 3627 | 4.0 |
... | ... | ... | ... | ... | ... | ... | ... |
1703 | WY | 1999 | 12423 | 0.319894 | 0.045 | 492 | 4.0 |
1704 | WY | 2000 | 12595 | 0.309491 | 0.045 | 494 | 4.0 |
1705 | WY | 2001 | 12808 | 0.300928 | 0.045 | 494 | 4.0 |
1706 | WY | 2002 | 13191 | 0.296244 | 0.045 | 499 | 4.0 |
1707 | WY | 2003 | 15535 | 0.289643 | 0.045 | 501 | 4.0 |
1708 rows × 7 columns
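To make this concrete, here is a minimal sketch of two safe patterns: selecting only the numeric columns you care about, or (in recent pandas versions) asking .mean() to skip non-numeric columns entirely:
# Select only the numeric columns of interest before aggregating
df.groupby('st_name')[['c_beer', 'population']].mean().head()

# Or, in recent versions of pandas, skip non-numeric columns instead of erroring
df.groupby('st_name').mean(numeric_only=True).head()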
Filter#
Aggregation functions collapse each groupby object into a single row. Filter functions instead return a True/False
for each groupby object and only keep the rows from the groupby objects for which the function returned True. For example, let's say we want to keep all the states where the beer_tax
was at least 50% in at least one year. We could accomplish that using the following line of code.
df.groupby('st_name').filter(lambda x: (x['beer_tax'] >= 50).any())
| | st_name | year | c_beer | beer_tax | btax_dollars | population | salestax |
|---|---|---|---|---|---|---|---|
0 | AL | 1970 | 33098 | 72.341130 | 2.370 | 3450 | 4.0 |
1 | AL | 1971 | 37598 | 69.304600 | 2.370 | 3497 | 4.0 |
2 | AL | 1972 | 42719 | 67.149190 | 2.370 | 3539 | 4.0 |
3 | AL | 1973 | 46203 | 63.217026 | 2.370 | 3580 | 4.0 |
4 | AL | 1974 | 49769 | 56.933796 | 2.370 | 3627 | 4.0 |
... | ... | ... | ... | ... | ... | ... | ... |
1365 | SC | 1999 | 101782 | 12.283935 | 1.728 | 3975 | 5.0 |
1366 | SC | 2000 | 104116 | 11.884457 | 1.728 | 4023 | 5.0 |
1367 | SC | 2001 | 105525 | 11.555637 | 1.728 | 4060 | 5.0 |
1368 | SC | 2002 | 108000 | 11.375785 | 1.728 | 4104 | 5.0 |
1369 | SC | 2003 | 103058 | 11.122302 | 1.728 | 4147 | 5.0 |
102 rows × 7 columns
In the above code, the lambda
function defines a Python function that takes in a groupby object named x
. The function then extracts the beer_tax
Series from that groupby object, compares it to 50, and returns a boolean Series indicating whether each value is at least 50. Then, .any()
outputs whether any (in other words: one or more) of those comparisons are True. If yes, the lambda function outputs True (meaning the filter keeps the rows from that groupby object), and if no, the lambda function outputs False (meaning the filter removes the rows from that groupby object). As you can see from the code below, only Alabama, Georgia, and South Carolina have ever had a beer_tax
of at least 50%.
df.groupby('st_name').filter(lambda x: (x['beer_tax'] >= 50).any())['st_name'].unique()
array(['AL', 'GA', 'SC'], dtype=object)
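As another sketch of the same filtering pattern (the 10,000 threshold here is just for illustration), we could instead keep only the states whose average population over the sample exceeds 10 million (recall that population is in thousands):
# Keep groups whose mean population exceeds 10,000 thousand (i.e. 10 million people)
df.groupby('st_name').filter(lambda x: x['population'].mean() > 10000)['st_name'].unique()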
Multiple Columns and Aggregations#
Finally, we can also groupby multiple columns. To demonstrate this, let us first create another categorical variable by adding in a decade
column representing which decade the data is from.
df['decade'] = df['year'] // 10 * 10
df
| | st_name | year | c_beer | beer_tax | btax_dollars | population | salestax | decade |
|---|---|---|---|---|---|---|---|---|
0 | AL | 1970 | 33098 | 72.341130 | 2.370 | 3450 | 4.0 | 1970 |
1 | AL | 1971 | 37598 | 69.304600 | 2.370 | 3497 | 4.0 | 1970 |
2 | AL | 1972 | 42719 | 67.149190 | 2.370 | 3539 | 4.0 | 1970 |
3 | AL | 1973 | 46203 | 63.217026 | 2.370 | 3580 | 4.0 | 1970 |
4 | AL | 1974 | 49769 | 56.933796 | 2.370 | 3627 | 4.0 | 1970 |
... | ... | ... | ... | ... | ... | ... | ... | ... |
1703 | WY | 1999 | 12423 | 0.319894 | 0.045 | 492 | 4.0 | 1990 |
1704 | WY | 2000 | 12595 | 0.309491 | 0.045 | 494 | 4.0 | 2000 |
1705 | WY | 2001 | 12808 | 0.300928 | 0.045 | 494 | 4.0 | 2000 |
1706 | WY | 2002 | 13191 | 0.296244 | 0.045 | 499 | 4.0 | 2000 |
1707 | WY | 2003 | 15535 | 0.289643 | 0.045 | 501 | 4.0 | 2000 |
1708 rows × 8 columns
Now, since decade
is also a categorical variable, we can groupby both the state and the decade.
df.groupby(['st_name','decade']).mean()
st_name | decade | year | c_beer | beer_tax | btax_dollars | population | salestax
---|---|---|---|---|---|---|---
AK | 1970 | 1974.5 | 7715.900 | 10.603964 | 0.56250 | 360.40 | 0.0 |
1980 | 1984.5 | 13321.300 | 6.391171 | 0.70688 | 498.00 | 0.0 | |
1990 | 1994.5 | 14393.600 | 5.005872 | 0.78750 | 598.50 | 0.0 | |
2000 | 2001.5 | 14537.500 | 6.272505 | 1.19250 | 637.75 | 0.0 | |
AL | 1970 | 1974.5 | 50665.600 | 55.847542 | 2.37000 | 3658.50 | 4.0 |
... | ... | ... | ... | ... | ... | ... | ... |
WV | 2000 | 2001.5 | 40984.500 | 2.653142 | 0.39920 | 1806.00 | 6.0 |
WY | 1970 | 1975.5 | 11349.625 | 0.989311 | 0.04500 | 393.25 | 3.0 |
1980 | 1984.5 | 13308.000 | 0.516437 | 0.04500 | 488.40 | 3.0 | |
1990 | 1994.5 | 11647.000 | 0.357562 | 0.04500 | 477.70 | 3.4 | |
2000 | 2001.5 | 13532.250 | 0.299077 | 0.04500 | 497.00 | 4.0 |
203 rows × 6 columns
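As a quick sketch of how to work with the resulting two-level index, .loc with a single state label pulls out that state's per-decade rows:
# Rows for Wyoming only, indexed by decade
df.groupby(['st_name', 'decade']).mean().loc['WY']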
Additionally, we can also aggregate using multiple aggregation functions at once.
df.groupby(['st_name','decade']).agg([min,max])
st_name | decade | year min | year max | c_beer min | c_beer max | beer_tax min | beer_tax max | btax_dollars min | btax_dollars max | population min | population max | salestax min | salestax max
---|---|---|---|---|---|---|---|---|---|---|---|---|---
AK | 1970 | 1970 | 1979 | 5372 | 9848 | 7.340821 | 13.735660 | 0.5625 | 0.5625 | 304 | 405 | 0.0 | 0.0 |
1980 | 1980 | 1989 | 10350 | 14711 | 5.522732 | 7.181146 | 0.5625 | 0.7875 | 405 | 547 | 0.0 | 0.0 | |
1990 | 1990 | 1999 | 13014 | 15419 | 4.478518 | 5.708654 | 0.7875 | 0.7875 | 553 | 625 | 0.0 | 0.0 | |
2000 | 2000 | 2003 | 14256 | 15008 | 4.147421 | 12.396732 | 0.7875 | 2.4075 | 628 | 649 | 0.0 | 0.0 | |
AL | 1970 | 1970 | 1979 | 33098 | 63999 | 38.661655 | 72.341130 | 2.3700 | 2.3700 | 3450 | 3866 | 4.0 | 4.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
WV | 2000 | 2000 | 2003 | 39513 | 42075 | 2.569457 | 2.745529 | 0.3992 | 0.3992 | 1802 | 1810 | 6.0 | 6.0 |
WY | 1970 | 1972 | 1979 | 7950 | 14246 | 0.734082 | 1.274985 | 0.0450 | 0.0450 | 347 | 454 | 3.0 | 3.0 |
1980 | 1980 | 1989 | 10936 | 15730 | 0.429793 | 0.646776 | 0.0450 | 0.0450 | 458 | 510 | 3.0 | 3.0 | |
1990 | 1990 | 1999 | 11253 | 12423 | 0.319894 | 0.407761 | 0.0450 | 0.0450 | 454 | 492 | 3.0 | 4.0 | |
2000 | 2000 | 2003 | 12595 | 15535 | 0.289643 | 0.309491 | 0.0450 | 0.0450 | 494 | 501 | 4.0 | 4.0 |
203 rows × 12 columns
Pivot Tables#
While we could groupby multiple columns above, the resulting dataset was honestly a bit messy. If you’re interested in the relationship between two categorical variables and how they affect a third numerical variable, a pivot table is often best. For example, let us use a pivot table to visualize the average beer consumption across the decades for each state.
pd.pivot_table(data=df, index='st_name', columns='decade',
               values='c_beer', aggfunc=np.mean).head(12)
st_name | 1970 | 1980 | 1990 | 2000
---|---|---|---|---
AK | 7715.9 | 13321.3 | 14393.6 | 14537.50 |
AL | 50665.6 | 73146.3 | 85399.2 | 94585.75 |
AR | 30060.5 | 43013.7 | 48867.0 | 51736.50 |
AZ | 57785.9 | 91379.6 | 115175.2 | 136681.75 |
CA | 460041.2 | 630967.9 | 637368.5 | 653313.75 |
CO | 60350.5 | 83809.1 | 90452.1 | 108031.75 |
CT | 55019.5 | 63606.1 | 58895.2 | 58480.50 |
DC | 17258.4 | 17290.4 | 15789.8 | 14452.25 |
DE | 12163.0 | 16415.8 | 17936.4 | 19900.25 |
FL | 186083.9 | 309504.6 | 364002.4 | 399874.25 |
GA | 80081.0 | 120273.0 | 153751.9 | 178871.50 |
HI | NaN | 30242.5 | 30353.2 | 28954.00 |
As you can see, in the pd.pivot_table()
function:
- The data parameter tells the function which DataFrame to get the data from.
- The index parameter says which categorical variable to use to form the index.
- The columns parameter says which categorical variable to use to form the columns.
- The values parameter says which numerical variable to aggregate.
- The aggfunc parameter says which function to use to perform the aggregation.
- Optionally, you can set the fill_value parameter to be the value you want to replace the NaNs with. For example, as we don't have data on Hawaii in the 1970s, you can see that datapoint is currently a NaN value (see the sketch below).
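For example, here is a sketch of the same pivot table with fill_value=0, so the missing Hawaii entry for the 1970s shows up as 0 instead of NaN:
pd.pivot_table(data=df, index='st_name', columns='decade',
               values='c_beer', aggfunc=np.mean, fill_value=0).head(12)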