How to Get Top N Rows within Each Group in Pandas?

In this post we will see how to get the top N rows of a data frame, where "top" is defined by the values of one variable within each group defined by another variable. Note that this is not the same as getting the top N rows according to one variable across the whole data frame.
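To see why the distinction matters, here is a small sketch on a toy data frame built from a handful of the 2007 gapminder rows that appear in the outputs later in this post (the frame and its column subset are only for illustration). Taking the top 2 rows of the whole frame with nlargest() returns the two countries with the highest life expectancy overall, so most continents are not represented at all.

```python
import pandas as pd

# Toy stand-in: a few 2007 gapminder rows copied from the outputs in this post
toy = pd.DataFrame({
    "country": ["Tunisia", "Canada", "Libya", "Japan", "Reunion", "Iceland", "Australia"],
    "continent": ["Africa", "Americas", "Africa", "Asia", "Africa", "Europe", "Oceania"],
    "lifeExp": [73.923, 80.653, 73.952, 82.603, 76.442, 81.757, 81.235],
})

# Top 2 rows by lifeExp in the WHOLE data frame, ignoring continents
overall_top2 = toy.nlargest(2, "lifeExp")
print(overall_top2)

# Only 2 of the 5 continents show up; Africa, for example, is missing entirely
print(overall_top2["continent"].nunique(), "of", toy["continent"].nunique(), "continents")
```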

Let us say we have the gapminder data frame, which has life expectancy values for countries in five continents, and we would like to see the top 3 countries (rows) with the highest life expectancy (in descending order) in each continent. Basically, we need the top N rows in each group.

We wrote an earlier post on getting the top N rows in a data frame, but this one has a slight twist 🙂 See the blog post “How to Select Top N Rows with the Largest Values in a Column(s) in Pandas?”

Top N Rows in Each Group: Step by Step

Getting the top N rows within each group involves multiple steps. First, let us go through them one at a time; later we will combine some of the steps.

Let us first load Pandas.

# load pandas
import pandas as pd

Next, let us load the gapminder data frame from the Carpentries site and filter it to contain only the data for the year 2007.

# load gapminder data and keep only the rows for 2007
data_url = 'http://bit.ly/2cLzoxH'
gapminder = pd.read_csv(data_url)
gapminder_2007 = gapminder[gapminder.year == 2007]

Since we want to find the top N countries with the highest life expectancy in each continent, let us group our data frame by “continent” using Pandas’s groupby() function. We save the resulting grouped object into a new variable.

df1 = gapminder_2007.groupby(["continent"])
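The grouped object is not a data frame itself; it describes how the rows are split into groups, and we have to call a method on it to get a result. A quick way to see what it holds is to count the rows per group and iterate over the (key, sub-data frame) pairs. The sketch below uses a toy frame with a few 2007 rows copied from this post's outputs instead of the downloaded data, and groups by the plain string "continent" so that each group key is a single value.

```python
import pandas as pd

# Toy stand-in: a few 2007 gapminder rows copied from the outputs in this post
toy = pd.DataFrame({
    "country": ["Tunisia", "Canada", "Libya", "Japan", "Reunion", "Iceland", "Australia"],
    "continent": ["Africa", "Americas", "Africa", "Asia", "Africa", "Europe", "Oceania"],
    "lifeExp": [73.923, 80.653, 73.952, 82.603, 76.442, 81.757, 81.235],
})

grouped = toy.groupby("continent")

# Number of groups and number of rows in each group
print(grouped.ngroups)
print(grouped.size())

# Iterating yields (group key, sub-DataFrame) pairs, with keys in sorted order
for continent, group in grouped:
    print(continent, group["country"].tolist())
```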

Remember, the resulting grouped object holds all the data, but split into groups (here, one per continent). Next, we take the grouped object and use Pandas’s apply() function to sort each group separately. Since we want the countries with the highest life expectancy, we sort by the variable “lifeExp” in descending order (ascending=False). Inside apply(), we use a lambda function to do the sorting.

df2 = df1.apply(lambda x: x.sort_values(["lifeExp"], ascending=False))

We have saved the resulting grouped and sorted data frame into another variable. If we look at its content, we can see that it has a MultiIndex: one level for the continent and another for the original row numbers.

df2

                country  year         pop continent  lifeExp     gdpPercap
continent                                                                 
Africa    1271  Reunion  2007    798094.0    Africa   76.442   7670.122558
          911     Libya  2007   6036914.0    Africa   73.952  12057.499280
          1571  Tunisia  2007  10276158.0    Africa   73.923   7092.923025
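To try the sorting step without downloading the data, here is a minimal sketch on a toy frame (a few 2007 rows copied from this post's outputs). One assumption worth flagging: recent pandas versions (2.2 and later) warn that apply() operating on the grouping column is deprecated, so this sketch selects the columns explicitly after groupby(), which keeps “continent” inside each group. With pandas 2.x the result has the same two-level index as df2 above.

```python
import pandas as pd

# Toy stand-in: a few 2007 gapminder rows copied from the outputs in this post
toy = pd.DataFrame({
    "country": ["Tunisia", "Canada", "Libya", "Japan", "Reunion", "Iceland", "Australia"],
    "continent": ["Africa", "Americas", "Africa", "Asia", "Africa", "Europe", "Oceania"],
    "lifeExp": [73.923, 80.653, 73.952, 82.603, 76.442, 81.757, 81.235],
})

# Select all columns explicitly (including "continent") so apply() sees the
# grouping column in each group without a deprecation warning
sorted_groups = toy.groupby("continent")[list(toy.columns)].apply(
    lambda x: x.sort_values("lifeExp", ascending=False)
)

# Two index levels: the continent and the original row label
print(sorted_groups.index.names)
print(sorted_groups.loc["Africa"])
```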

We can reset the index using the reset_index() function. Here we use the “drop=True” option so that the index levels are discarded instead of being added as columns to the resulting data frame.

df3 = df2.reset_index(drop=True)
df3.head()

   country  year         pop continent  lifeExp     gdpPercap
0  Reunion  2007    798094.0    Africa   76.442   7670.122558
1    Libya  2007   6036914.0    Africa   73.952  12057.499280
2  Tunisia  2007  10276158.0    Africa   73.923   7092.923025
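Why drop=True? Because “continent” is both an index level and a regular column of df2, a plain reset_index() would try to insert a second “continent” column, and pandas refuses to do that. The sketch below reproduces this on a small toy frame (a few rows copied from this post's outputs); the two-level index is built with set_index() here only to mimic the shape of df2, and the name df2_like is just for illustration.

```python
import pandas as pd

# Toy stand-in: a few 2007 gapminder rows copied from the outputs in this post
toy = pd.DataFrame({
    "country": ["Tunisia", "Canada", "Libya", "Japan", "Reunion", "Iceland", "Australia"],
    "continent": ["Africa", "Americas", "Africa", "Asia", "Africa", "Europe", "Oceania"],
    "lifeExp": [73.923, 80.653, 73.952, 82.603, 76.442, 81.757, 81.235],
})

# Mimic df2: an index with a "continent" level plus a "continent" column
df2_like = toy.set_index([toy["continent"], toy.index])

try:
    df2_like.reset_index()  # tries to add "continent" back as a column
    plain_reset_failed = False
except ValueError as err:
    plain_reset_failed = True
    print("reset_index() failed:", err)

# drop=True discards the index levels instead, giving a fresh 0..n-1 index
df3_like = df2_like.reset_index(drop=True)
print(df3_like.columns.tolist())
```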

So, now we have a sorted data frame. If we examine the rows, we see that the countries from Africa come first, with lifeExp in descending order, followed by the other continents, each also sorted by lifeExp in descending order.
Now we are ready to select N rows from each group, in this example each “continent”.
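We can check that each continent's rows really are in descending lifeExp order instead of eyeballing it. The sketch below builds a df3-like frame on a toy data set (a few rows copied from this post's outputs) with a two-key sort_values(), which is a different route than groupby() plus apply() but should give the same row order when there are no ties: continents in alphabetical order, lifeExp descending within each. It then tests each group with Series.is_monotonic_decreasing.

```python
import pandas as pd

# Toy stand-in: a few 2007 gapminder rows copied from the outputs in this post
toy = pd.DataFrame({
    "country": ["Tunisia", "Canada", "Libya", "Japan", "Reunion", "Iceland", "Australia"],
    "continent": ["Africa", "Americas", "Africa", "Asia", "Africa", "Europe", "Oceania"],
    "lifeExp": [73.923, 80.653, 73.952, 82.603, 76.442, 81.757, 81.235],
})

# Continents A to Z, and lifeExp from high to low within each continent
df3_like = (
    toy.sort_values(["continent", "lifeExp"], ascending=[True, False])
       .reset_index(drop=True)
)

# One True/False per continent: are its lifeExp values non-increasing?
is_sorted = df3_like.groupby("continent")["lifeExp"].apply(lambda s: s.is_monotonic_decreasing)
print(is_sorted.all())
print(df3_like["country"].tolist())
```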

We can use the groupby() function again with “continent” as the argument, followed by the head() function, to select the first N rows of each group. Since the rows within each continent are sorted by lifeExp, we get the top N rows with the highest lifeExp for each continent.

Let us see the topmost country, the one with the highest lifeExp, in each continent.

df3.groupby('continent').head(1)

Here is the topmost row in each group.

       country  year          pop continent  lifeExp     gdpPercap
0      Reunion  2007     798094.0    Africa   76.442   7670.122558
52      Canada  2007   33390141.0  Americas   80.653  36319.235010
77       Japan  2007  127467972.0      Asia   82.603  31656.068060
110    Iceland  2007     301931.0    Europe   81.757  36180.789190
140  Australia  2007   20434176.0   Oceania   81.235  34435.367440
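Passing a different number to head() gives the top N rows instead of just the top one. A detail worth knowing: groupby().head() keeps the rows in their existing order and the original row labels, and simply returns at most N rows per group, so a continent with fewer than N rows (like the single-row continents in the toy frame below, built from rows copied from this post's outputs) returns all of its rows.

```python
import pandas as pd

# Toy stand-in: a few 2007 gapminder rows copied from the outputs in this post
toy = pd.DataFrame({
    "country": ["Tunisia", "Canada", "Libya", "Japan", "Reunion", "Iceland", "Australia"],
    "continent": ["Africa", "Americas", "Africa", "Asia", "Africa", "Europe", "Oceania"],
    "lifeExp": [73.923, 80.653, 73.952, 82.603, 76.442, 81.757, 81.235],
})

# Same layout as df3: continents A to Z, lifeExp descending within each continent
df3_like = (
    toy.sort_values(["continent", "lifeExp"], ascending=[True, False])
       .reset_index(drop=True)
)

# Top 2 rows per continent; row labels from df3_like are kept as they are
top2 = df3_like.groupby("continent").head(2)
print(top2)

# Africa contributes 2 rows; every other continent in the toy frame has only 1 row
print(top2["continent"].value_counts().sort_index())
```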

We got the top N rows within each group in multiple steps. We can combine these steps by chaining the commands. Here we have chained them so that there are just two steps to get the top N rows within each group, and we get the same answer as above.

# get data frame sorted by life expectancy within each continent
g = gapminder_2007.groupby(["continent"]).apply(lambda x: x.sort_values(["lifeExp"], ascending=False)).reset_index(drop=True)
# select the top N rows within each continent (here N = 1)
g.groupby('continent').head(1)
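If you need this often, it can be wrapped in a small function. The top_n_per_group() helper below is hypothetical (it is not part of Pandas): it sorts the whole frame once by the value column and then takes head(n) per group, so it needs no apply() at all. As a cross-check, it compares the selected rows with SeriesGroupBy.nlargest(), which returns the top n values per group as a Series indexed by group and original row label. Both run on a toy frame with rows copied from this post's outputs.

```python
import pandas as pd


def top_n_per_group(df, group_col, value_col, n):
    """Hypothetical helper: top n rows by value_col within each group_col group."""
    return (
        df.sort_values(value_col, ascending=False)  # largest values first, overall
          .groupby(group_col)
          .head(n)                                   # first n rows of each group
          .sort_values([group_col, value_col], ascending=[True, False])  # tidy display order
    )


# Toy stand-in: a few 2007 gapminder rows copied from the outputs in this post
toy = pd.DataFrame({
    "country": ["Tunisia", "Canada", "Libya", "Japan", "Reunion", "Iceland", "Australia"],
    "continent": ["Africa", "Americas", "Africa", "Asia", "Africa", "Europe", "Oceania"],
    "lifeExp": [73.923, 80.653, 73.952, 82.603, 76.442, 81.757, 81.235],
})

top2 = top_n_per_group(toy, "continent", "lifeExp", 2)
print(top2)

# Cross-check with SeriesGroupBy.nlargest(): its last index level holds the original row labels
nl = toy.groupby("continent")["lifeExp"].nlargest(2)
print(sorted(nl.index.get_level_values(-1)) == sorted(top2.index))
```

If two rows in a group tie on the value, which one head() keeps depends on how sort_values() orders the tied rows; passing kind="mergesort" to sort_values() requests a stable sort.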