Python分多组求平均值的优雅操作

 
Category: Python

一个问题

最近遇到一个问题, 如何分组计算平均值, 例如, 对于随机生成的范围在0~1000的数据, 这里用下面的代码实现:

import random
import pandas as pd

N = 1000
random.seed(10)
a = [random.randint(0, N) for _ in range(N)]

对于这一组数据, 依据数据的值大小分成$N$个长度为$20$的区间, 对于每一个区间中的值计算平均值, 然后输出.

分析与解决

这个需求的话, 直观的思路当然是采用数组存每一个分组(通过if判断), 然后分别计算输出, 但是当你的数据量很大的时候(分的组数变多), 就不能这么操作了, 一个一个写if-elif简直是噩梦…

联想到pandas中一个很有名的函数.groupby(), 我决定使用这个函数来简化操作~

核心代码就是通过匿名函数的调用来进行的, 由于Python可以直接写两个比较判断的符号, 这样写起来就会方便很多

df = pd.DataFrame(a, index=a)
interval = 20
groupNum = len(a) // interval
ans = []

for i in range(groupNum):
    tmp = df.groupby(
        lambda x: i * interval < x <= (i + 1) * interval).mean().at[True, 0]
    print(tmp)
    ans.append(tmp)
print(ans)

注意, 上面的代码中, 第一行的索引设置是必要的, 因为groupb实际上就是对索引进行操作, 然后对数值进行后续操作.

核心代码是在第8行, 采用匿名函数的方法对index进行比较, 满足条件的话就归为一组, 然后计算平均值, 计算之后得到的一个DataFrame对象, 再通过取值方式取出计算出满足条件的值的均值即可.

简单推广

进一步, 如果对于同一索引下的其他值进行操作, 则只需要改变df的构造即可, 这里主要要说的是, 对于最后一个区间以及第一个区间, 如果分出的区间不能完全包含所有数据(如果存在很大的数据), 那么就不属于任何组了, 这时候我的思路是通过修改匿名函数的行为, 如下:

for i in range(groupNum):
    tmp = df.groupby(
        lambda x: i * interval < x <= (i + 1) * interval
        if i != groupNum - 1 else x > (i + 1) * interval).mean().at[True, 0]
    print(tmp)
    ans.append(tmp)

添加了第四行这个代码, 就能够满足所有的数据都在对应的分组了, 对于较小的数据也是类似的操作, 这里就不重复了.

效率对比

前面用到了groupby这个操作, 但是总感觉有资源浪费的地方, 因为计算平均值时候还对indexFalse的值进行了计算, 如果先提取出所有 分组为True的值, 然后再进行mean(), 会不会时间会缩短呢?

下面是采用get_group()方法取出所有True的分组之后再进行均值操作的代码, 通过timeit()方法的测试, 我发现这样操作反而不如先计算均值再从结果中取indexTrue的值, 这是为什么呢?

import random
import timeit
import pandas as pd


def test1():
    N = 10000
    random.seed(10)
    a = [random.randint(0, N) for _ in range(N)]
    df = pd.DataFrame(a, index=a)
    interval = 20
    groupNum = len(a) // interval
    ans = []
    for i in range(groupNum):
        tmp = df.groupby(
            lambda x: i * interval < x <= (i + 1) * interval).mean().at[True, 0]
            # lambda x: i * interval < x <= (i + 1) * interval).get_group(True).mean()[0]
        ans.append(tmp)
    return ans

print(timeit.timeit(test1, number=10), "s")

只需要解注释第17行即可. 经过我的运行, 我发现原来的操作竟然比采用get_group()之后的操作还要快大概1s左右, 可能问题就出在这个get_group()了, 以后再慢慢研究这个函数..