2 回答

TA贡献1802条经验 获得超5个赞
好的,在对不同类型的索引进行了大量实验之后,我让它工作了。事实证明,答案就在 Advanced Indexing 中。不幸的是,PyTorch 文档没有详细介绍高级索引。这是 Numpy 文档中的链接。
对于上述问题,此命令起到了作用:
>>> k_lst = torch.zeros([4,4,5])
>>> k_lst[torch.arange(4).unsqueeze(1), torch.arange(4), inp_list[:,:,1]] = inp_list[:,:,0].float()
>>> k_lst
tensor([[[ 0., 0., 1., 0., 0.],
[ 0., 0., 0., 0., 3.],
[-1., 0., 0., 0., 0.],
[ 0., 45., 0., 0., 0.]],
[[ 0., 0., 1., 0., 0.],
[ 0., 0., 0., 0., 3.],
[-1., 0., 0., 0., 0.],
[ 0., 45., 0., 0., 0.]],
[[ 0., 0., 1., 0., 0.],
[ 0., 0., 0., 0., 3.],
[-1., 0., 0., 0., 0.],
[ 0., 45., 0., 0., 0.]],
[[ 0., 0., 1., 0., 0.],
[ 0., 0., 0., 0., 3.],
[-1., 0., 0., 0., 0.],
[ 0., 45., 0., 0., 0.]]])
这正是我想要的。
我在搜索这个时学到了很多东西,我想与任何偶然发现这个问题的人分享这个。那么,为什么会这样呢?答案在于广播的工作方式。如果您查看所涉及的不同索引张量的形状,您会发现它们(必然)是可广播的。
>>> torch.arange(4).unsqueeze(1).shape, torch.arange(4).shape, inp_list[:,:,1].shape
(torch.Size([4, 1]), torch.Size([4]), torch.Size([4, 4]))
显然,要访问此处的 k_lst 等 3-D 张量的元素,我们需要 3 个索引 - 每个维度一个。如果你给算子3个相同形状的张量[],它可以通过从3个张量中匹配对应的元素得到一堆合法的索引。
如果这 3 个张量具有不同的形状,但可以广播(就像这里的情况),它会复制缺少张量的相关行/列所需的次数以获得具有相同形状的张量。
最终,就我而言,如果我们研究如何分配不同的值,这相当于做
k_lst[0,0,inp_list[0,0,1]] = inp_list[0,0,0].float()
k_lst[0,1,inp_list[0,1,1]] = inp_list[0,1,0].float()
k_lst[0,2,inp_list[0,2,1]] = inp_list[0,2,0].float()
k_lst[0,3,inp_list[0,3,1]] = inp_list[0,3,0].float()
k_lst[1,0,inp_list[1,0,1]] = inp_list[1,0,0].float()
k_lst[1,1,inp_list[1,1,1]] = inp_list[1,1,0].float()
.
.
.
k_lst[3,3,inp_list[3,3,1]] = inp_list[3,3,0].float()
这个格式让我想起了torch.Tensor.scatter(),但是如果能用它来解决这个问题,我还没想好怎么办。
添加回答
举报