1 回答
TA贡献1799条经验 获得超6个赞
tf.scatter_nd_update 此示例(从此处的 tf 文档扩展)应该有所帮助。
您想首先将您的 row_indices 和 column_indices 组合成一个二维索引列表,这indices是tf.scatter_nd_update. 然后你输入了一个期望值列表,即updates.
ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))
indices = tf.constant([[0, 2], [2, 2]])
updates = tf.constant([1.0, 2.0])
update = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print sess.run(update)
Result:
[[ 0. 0. 1. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 2. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]]
专门针对您的数据,
ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))
changed_tensor = [[8.3356, 0., 8.457685 ],
[0., 6.103182, 8.602337 ],
[8.8974, 7.330564, 0. ],
[0., 3.8914037, 5.826657 ],
[8.8974, 0., 8.283971 ],
[6.103182, 3.0614321, 5.826657 ],
[7.330564, 0., 8.283971 ],
[6.103182, 3.8914037, 0. ]]
updates = tf.reshape(changed_tensor, shape=[-1])
sparse_indices = tf.constant(
[[1, 1],
[2, 1],
[5, 1],
[1, 2],
[2, 2],
[5, 2],
[1, 3],
[2, 3],
[5, 3],
[1, 2],
[4, 2],
[6, 2],
[1, 3],
[4, 3],
[6, 3],
[2, 2],
[3, 2],
[6, 2],
[2, 3],
[3, 3],
[6, 3],
[2, 2],
[4, 2],
[4, 2]])
update = tf.scatter_nd_update(ref, sparse_indices, updates)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print sess.run(update)
Result:
[[ 0. 0. 0. 0. ]
[ 0. 8.3355999 0. 8.8973999 ]
[ 0. 0. 6.10318184 7.33056402]
[ 0. 0. 3.06143212 0. ]
[ 0. 0. 0. 0. ]
[ 0. 8.45768547 8.60233688 0. ]
[ 0. 0. 5.82665682 8.28397083]
[ 0. 0. 0. 0. ]]
添加回答
举报
