为了账号安全,请及时绑定邮箱和手机立即绑定

在 PyTorch 中计算欧几里德距离而不是矩阵乘法

在 PyTorch 中计算欧几里德距离而不是矩阵乘法

MMMHUHU 2023-07-11 16:46:12
假设我们有 2 个矩阵:mat = torch.randn([20, 7]) * 100mat2 = torch.randn([7, 20]) * 100n, m = mat.shape最简单的常用矩阵乘法如下所示:def mat_vec_dot_product(mat, vect):    n, m = mat.shape        res = torch.zeros([n])    for i in range(n):        for j in range(m):            res[i] += mat[i][j] * vect[j]            return resres = torch.zeros([n, n])for k in range(n):    res[:, k] = mat_vec_dot_product(mat, mat2[:, k])    但是如果我需要应用 L2 范数而不是点积怎么办?代码如下:def mat_vec_l2_mult(mat, vect):    n, m = mat.shape        res = torch.zeros([n])    for i in range(n):        for j in range(m):            res[i] += (mat[i][j] - vect[j]) ** 2                res = res.sqrt()            return resfor k in range(n):    res[:, k] = mat_vec_l2_mult(mat, mat2[:, k])我们可以使用 Torch 或任何其他库以最佳方式做到这一点吗?因为简单的 O(n^3) Python 代码运行速度非常慢。
查看完整描述

2 回答

?
慕虎7371278

TA贡献1802条经验 获得超4个赞

用于torch.cdistL2 范数 - 欧氏距离

res = torch.cdist(mat, mat2.permute(1,0), p=2)

在这里,我曾经将frompermute的 dim 交换为mat27,2020,7


查看完整回答
反对 回复 2023-07-11
?
翻翻过去那场雪

TA贡献2065条经验 获得超13个赞

首先,PyTorch 中的矩阵乘法有一个内置运算符:@。因此,要将 mat 和 mat2 相乘,您只需执行以下操作:


mat @ mat2

(假设尺寸一致,应该可以工作)。


现在,要计算您似乎在第二个块中计算的平方差之和(SSD 或 L2 范数),您可以做一个简单的技巧。由于 L2 范数的平方||m_i - v||^2(其中m_i是矩阵的第 i 行M,v是向量)等于点积<m_i - v, m_i-v>- 根据您获得的点积的线性度:因此您可以通过以下方式<m_i,m_i> - 2<m_i,v> + <v,v>计算向量中每一行的 SSD:计算一次每行的 L2 范数平方、一次每行与向量之间的点积以及一次向量的 L2 范数。这可以在 中完成。然而,对于 2 个矩阵之间的 SSD,您仍然会得到MvO(n^2)O(n^3)。不过,可以通过向量化操作而不是使用循环来进行改进。这是 2 个矩阵的简单实现:


def mat_mat_l2_mult(mat,mat2):

    rows_norm = (torch.norm(mat, dim=1, p=2, keepdim=True)**2).repeat(1,mat2.shape[1])

    cols_norm = (torch.norm(mat2, dim=0, p=2, keepdim=True)**2).repeat(mat.shape[0], 1)

    rows_cols_dot_product = mat @ mat2

    ssd = rows_norm -2*rows_cols_dot_product + cols_norm

    return ssd.sqrt()


mat = torch.randn([20, 7])

mat2 = torch.randn([7,20])

print(mat_mat_l2_mult(mat, mat2))

所得矩阵的每个单元格将具有中每行和每列之间i,j差异的 L2 范数。imatjmat2


查看完整回答
反对 回复 2023-07-11
  • 2 回答
  • 0 关注
  • 136 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信