为了账号安全,请及时绑定邮箱和手机立即绑定

如何根据pytorch中另一个张量的值将张量的某个值更改为零?

如何根据pytorch中另一个张量的值将张量的某个值更改为零?

四季花海 2022-12-06 15:12:02
我有两个张量:张量 a 和张量 b。如何根据张量 b 的值更改张量 a 的某些值?我知道下面的代码是正确的,但是当张量很大时它运行起来很慢。还有其他方法吗?import torcha = torch.rand(10).cuda()b = torch.rand(10).cuda()a[b > 0.5] = 0.
查看完整描述

2 回答

?
MMMHUHU

TA贡献1834条经验 获得超8个赞

对于这个确切的用例,还要考虑


a * (b <= 0.5)

这似乎是以下最快的


In [1]: import torch

   ...: a = torch.rand(3**10)

   ...: b = torch.rand(3**10)


In [2]: %timeit a[b > 0.5] = 0.

553 µs ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [3]: a = torch.rand(3**10)


In [4]: %timeit temp = torch.where(b > 0.5, torch.tensor(0.), a)

   ...:

49 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [5]: a = torch.rand(3**10)


In [6]: %timeit temp = (a * (b <= 0.5))

44 µs ± 381 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [7]: %timeit a.masked_fill_(b > 0.5, 0.)

244 µs ± 3.48 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


查看完整回答
反对 回复 2022-12-06
?
撒科打诨

TA贡献1934条经验 获得超2个赞

我想torch.where会更快我在 CPU 中的测量是结果。


import torch

a = torch.rand(3**10)

b = torch.rand(3**10)

%timeit a[b > 0.5] = 0.

852 µs ± 30.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit temp = torch.where(b > 0.5, torch.tensor(0.), a)

294 µs ± 4.51 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


查看完整回答
反对 回复 2022-12-06
  • 2 回答
  • 0 关注
  • 159 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信