2 回答

TA贡献1796条经验 获得超4个赞
第 1 部分:Pandas 和(也许)Numpy
比较您的method1b和method2:
method1b生成一个DataFrame,这可能是你想要的,
method2生成一个Numpy 数组,因此要获得完全可比较的结果,您应该随后从中生成一个DataFrame。
所以我将您的方法2更改为:
def method2():
masking = array_condition != 1
array1_new = array1[masking]
array2_new = array2[masking]
array3_new = array3[masking]
array_condition_new = array_condition[masking]
df_new = pd.DataFrame({ 'array_condition': array_condition[masking],
'array1': array1_new, 'array2': array2_new, 'array3': array3_new})
然后比较执行时间(使用%timeit)。
结果是我的method2 (扩展)版本的执行时间 比method1b长约5%(请自行检查)。
所以我的观点是,只要是单一的操作,可能还是和Pandas在一起比较好。
但是,如果您想在源 DataFrame 上按顺序执行几个操作和/或您对Numpy数组的结果感到满意,那么值得:
调用
arr = df.values
以获取底层Numpy数组。使用Numpy方法对其执行所有必需的操作。
(可选)从最终结果创建一个 DataFrame。
我尝试了method1b的Numpy版本:
def method3(): a = df.values arr = a[a[:,0] != 1]
但执行时间要长约40%。
原因可能是Numpy数组具有相同类型的所有元素,因此array_condition列被强制浮动,然后创建整个Numpy数组,这需要一些时间。
第 2 部分:Numpy 和 Numba
要考虑的替代方法是使用Numba包 - 一种即时 Python 编译器。
我做了这样的测试:
创建了一个Numpy数组(作为初步步骤):
a = df.values
原因是 JIT 编译的方法能够使用Numpy方法和类型,但不能使用Pandas的方法和类型。
为了执行测试,我使用了与上面几乎相同的方法,但使用了@njit注释(需要来自 numba import njit):
@njit def method4(): arr = a[a[:,0] != 1]
这次:
执行时间约为method1b时间的 45% 。
但由于
a = df.values
已经在测试循环之前执行过,因此这个结果是否与之前的测试有可比性存在疑问。
无论如何,自己尝试Numba,也许这对您来说是一个有趣的选择。

TA贡献1824条经验 获得超6个赞
您可能会发现在这里使用numpy.where很有用。它将布尔掩码转换为数组索引,使生活更便宜。将其与 numpy.vstack 结合可以实现一些内存便宜的操作:
def method3():
wh = np.where(array_condition == 1)
return np.vstack(tuple(col[wh] for col in (array1, array2, array3)))
这给出了以下时间:
>>> %timeit method2()
180 ms ± 6.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit method3()
96.9 ms ± 2.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
元组解包允许该操作在内存上相当轻,因为当对象被 vstack-ed 重新组合在一起时,它会更小。如果您需要直接从 DataFrame 中获取列,则以下代码段可能有用:
def method3b():
wh = np.where(array_condition == 1)
col_names = ['array1','array2','array3']
return np.vstack(tuple(col[wh] for col in tuple(df[col_name].to_numpy()
for col_name in col_names)))
这允许人们从 DataFrame 中按名称获取列,然后在运行中对这些列进行元组解包。速度差不多:
>>> %timeit method3b()
96.6 ms ± 3.09 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
添加回答
举报