为了账号安全,请及时绑定邮箱和手机立即绑定

torch / np einsum 内部到底是如何工作的

torch / np einsum 内部到底是如何工作的

慕码人2483693 2023-07-18 15:24:41
torch.einsum这是有关GPU内部工作的查询。我知道如何使用einsum。它是执行所有可能的矩阵乘法,然后只选择相关的矩阵乘法,还是仅执行所需的计算?例如,考虑形状 的两个张量a和,我希望找到形状的每个相应张量 的点积。使用einsum,代码为:b(N,P)ni(1,P)torch.einsum('ij,ij->i',a,b)在不使用 einsum 的情况下,获取输出的另一种方法是:torch.diag(a @ b.t())现在,第二个代码应该比第一个代码执行更多的计算(例如, if N= 2000,它执行2000更多的计算)。然而,当我尝试对这两个操作进行计时时,它们完成所需的时间大致相同,这就引出了一个问题。是否einsum执行所有组合(如第二个代码),并挑选出相关值?要测试的示例代码:import timeimport torchfor i in range(100):  a = torch.rand(50000, 256).cuda()  b = torch.rand(50000, 256).cuda()  t1 = time.time()  val = torch.diag(a @ b.t())  t2 = time.time()  val2 = torch.einsum('ij,ij->i',a,b)  t3 = time.time()  print(t2-t1,t3-t2, torch.allclose(val,val2))
查看完整描述

2 回答

?
SMILET

TA贡献1796条经验 获得超4个赞

这可能与 GPU 可以并行计算a @ b.t(). 这意味着 GPU 实际上不必等待每个行列乘法计算完成即可计算下一个乘法。如果您检查 CPU,您会发现它torch.diag(a @ b.t())torch.einsum('ij,ij->i',a,b) 大型ab.



查看完整回答
反对 回复 2023-07-18
?
莫回无

TA贡献1865条经验 获得超7个赞

我不能代表,但几年前曾在一些细节上torch合作过。np.einsum然后它根据索引字符串构造一个自定义迭代器,仅执行必要的计算。从那时起,它以各种方式进行了重新设计,显然将问题转化为@可能的情况,从而利用了 BLAS(等)库调用。


In [147]: a = np.arange(12).reshape(3,4)

In [148]: b = a


In [149]: np.einsum('ij,ij->i', a,b)

Out[149]: array([ 14, 126, 366])

我不能确定在这种情况下使用了什么方法。通过“j”求和,还可以通过以下方式完成:


In [150]: (a*b).sum(axis=1)

Out[150]: array([ 14, 126, 366])

正如您所注意到的,最简单的方法dot创建一个更大的数组,我们可以从中拉出对角线:


In [151]: (a@b.T).shape

Out[151]: (3, 3)

但这不是正确的使用方法@。 通过提供高效的“批量”处理@进行扩展。np.dot所以i维度是批次一,也是j一dot。


In [152]: a[:,None,:]@b[:,:,None]

Out[152]: 

array([[[ 14]],


       [[126]],


       [[366]]])

In [156]: (a[:,None,:]@b[:,:,None])[:,0,0]

Out[156]: array([ 14, 126, 366])

换句话说,它使用 (3,1,4) 和 (3,4,1) 生成 (3,1,1),在共享大小 4 维度上进行乘积之和。


一些采样时间:


In [162]: timeit np.einsum('ij,ij->i', a,b)

7.07 µs ± 89.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [163]: timeit (a*b).sum(axis=1)

9.89 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [164]: timeit np.diag(a@b.T)

10.6 µs ± 31.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [165]: timeit (a[:,None,:]@b[:,:,None])[:,0,0]

5.18 µs ± 197 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


查看完整回答
反对 回复 2023-07-18
  • 2 回答
  • 0 关注
  • 89 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信