为了账号安全,请及时绑定邮箱和手机立即绑定

python:计算向量到矩阵每一行的欧氏距离的最快方法?

python:计算向量到矩阵每一行的欧氏距离的最快方法?

冉冉说 2023-06-13 16:27:47
考虑这个 python 代码,我在其中尝试计算向量到矩阵每一行的欧几里距离。与我能找到的使用 Tullio.jl 的最佳 Julia 版本相比,它非常慢。python 版本需要30s而 Julia 版本只需要75ms。我确信我在 Python 方面没有做得最好。有更快的解决方案吗?欢迎使用 Numba 和 numpy 解决方案。import numpy as np# generatea = np.random.rand(4000000, 128)b = np.random.rand(128)print(a.shape)print(b.shape)def lin_norm_ever(a, b):    return np.apply_along_axis(lambda x: np.linalg.norm(x - b), 1, a)import timet = time.time()res = lin_norm_ever(a, b)print(res.shape)elapsed = time.time() - tprint(elapsed)朱莉娅版本using Tulliofunction comp_tullio(a, c)    dist = zeros(Float32, size(a, 2))    @tullio dist[i] = (c[j] - a[j,i])^2    distend@time comp_tullio(a, c)@benchmark comp_tullio(a, c) # 75ms on my computer
查看完整描述

1 回答

?
慕沐林林

TA贡献2016条经验 获得超9个赞

为了获得最佳性能,我将在此示例中使用 Numba。我还添加了来自 Divakars 链接答案的 2 种方法以进行比较。


代码


import numpy as np

import numba as nb

from scipy.spatial.distance import cdist


@nb.njit(fastmath=True,parallel=True,cache=True)

def dist_1(mat,vec):

    res=np.empty(mat.shape[0],dtype=mat.dtype)

    for i in nb.prange(mat.shape[0]):

        acc=0

        for j in range(mat.shape[1]):

            acc+=(mat[i,j]-vec[j])**2

        res[i]=np.sqrt(acc)

    return res


#from https://stackoverflow.com/a/52364284/4045774

def dist_2(mat,vec):

    return cdist(mat, np.atleast_2d(vec)).ravel()


#from https://stackoverflow.com/a/52364284/4045774

def dist_3(mat,vec):

    M = mat.dot(vec)

    d = np.einsum('ij,ij->i',mat,mat) + np.inner(vec,vec) -2*M

    return np.sqrt(d)

时序


#Float64

a = np.random.rand(4000000, 128)

b = np.random.rand(128)

%timeit dist_1(a,b)

#122 ms ± 3.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit dist_2(a,b)

#484 ms ± 3.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit dist_3(a,b)

#432 ms ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#Float32

a = np.random.rand(4000000, 128).astype(np.float32)

b = np.random.rand(128).astype(np.float32)

%timeit dist_1(a,b)

#68.6 ms ± 414 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit dist_2(a,b)

#2.2 s ± 32.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

#looks like there is a costly type-casting to float64

%timeit dist_3(a,b)

#228 ms ± 8.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


查看完整回答
反对 回复 2023-06-13
  • 1 回答
  • 0 关注
  • 77 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信