1 回答
TA贡献2039条经验 获得超8个赞
有一些改进的空间
但永远不要期望 logsumexp 和标准求和一样快，因为 `exp` 是一项相当昂贵的操作。
例子
import numpy as np
#from version 0.43 until 0.47 this has to be set before importing numba
#Bug: https://github.com/numba/numba/issues/4689
from llvmlite import binding
binding.set_option('SVML', '-vector-library=SVML')
import numba as nb
@nb.njit(fastmath=True, parallel=False)
def logsum_exp_reduceat(arr, indices):
    """Compute log(sum(exp(segment))) for each segment of `arr` defined by `indices`.

    Segment i spans ``arr[indices[i]:indices[i + 1]]``; the last segment runs
    to the end of `arr`, matching the segmentation used by ``np.add.reduceat``.
    Unlike ``np.add.reduceat``, an empty segment (``indices[i] >= indices[i + 1]``)
    yields ``-inf`` (the log of an empty sum) rather than ``arr[indices[i]]``.

    Each segment's maximum is subtracted before exponentiating (the standard
    log-sum-exp shift), so large inputs do not overflow ``exp``.

    Parameters
    ----------
    arr : 1-D float array
        Values (typically log-domain) to combine.
    indices : 1-D integer array
        Segment start offsets, expected sorted; each must lie in
        ``[0, arr.shape[0]]``.

    Returns
    -------
    ndarray
        One result per entry of `indices`, with the same dtype as `arr`.
        An empty `indices` yields an empty array.

    Raises
    ------
    ValueError
        If any offset lies outside ``[0, arr.shape[0]]``.

    Notes
    -----
    ``fastmath=True`` lets LLVM assume no NaN/inf values, so `arr` should
    contain only finite values.
    """
    n = arr.shape[0]
    k = indices.shape[0]
    res = np.empty(k, dtype=arr.dtype)
    # Numba does not bounds-check array accesses by default, so validate the
    # offsets up front instead of reading or writing out of bounds.
    for i in range(k):
        if indices[i] < 0 or indices[i] > n:
            raise ValueError("indices must lie in [0, arr.shape[0]]")
    # Segments are independent; with parallel=False prange acts like range.
    for i in nb.prange(k):
        start = indices[i]
        # The last segment extends to the end of arr.
        stop = indices[i + 1] if i + 1 < k else n
        if start >= stop:
            # log of a sum with zero terms: log(0) == -inf.
            res[i] = -np.inf
        else:
            # Shift by the segment maximum so exp() cannot overflow.
            m = arr[start:stop].max()
            r = 0.0
            for j in range(start, stop):
                r += np.exp(arr[j] - m)
            res[i] = m + np.log(r)
    return res
计时
#small example where parallelization doesn't make sense
arr = np.random.uniform(0,0.1, 10_000)
log_arr = np.log(arr)
#use arrays if possible
indices = np.sort(np.random.randint(0, 10_000, 100))
%timeit logsum_exp_reduceat(arr, indices)
#without parallelzation 22 µs ± 173 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
#with parallelization 84.7 µs ± 32.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.add.reduceat(arr, indices)
#4.46 µs ± 61.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
#large example where parallelization makes sense
arr = np.random.uniform(0,0.1, 1000_000)
log_arr = np.log(arr)
indices = np.sort(np.random.randint(0, 1000_000, 100))
%timeit logsum_exp_reduceat(arr, indices)
#without parallelzation 1.57 ms ± 14.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
#with parallelization 409 µs ± 14.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.add.reduceat(arr, indices)
#340 µs ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
添加回答
举报
