为了账号安全,请及时绑定邮箱和手机立即绑定

从张量中获取最大索引

从张量中获取最大索引

慕哥6287543 2021-09-28 15:15:07
我的 CNN 产生了以下内容(来自model.predict()):Tensor("input_1:0", shape=(?, 2, 26, 1), dtype=float32)[9.9952221e-01 2.3613637e-04 1.9953270e-06 1.6922619e-05 2.2012556e-04 2.4441533e-07 3.5276526e-07 7.4913805e-07 4.0657511e-07 8.7760031e-07]我想从这个 numpy 数组中获取最大值的索引。现在,我已经尝试过这样做(x即上面的数组):result = x.index(max(x))相反,这会引发一个错误,指出此数据类型不支持.index?
查看完整描述

1 回答

?
三国纷争

TA贡献1804条经验 获得超7个赞

您可以简单地使用该np.argmax功能:


import numpy as np


preds = model.predict(test_data)

pred_class = np.argmax(preds, axis=-1)


查看完整回答
反对 回复 2021-09-28
  • 1 回答
  • 0 关注
  • 242 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信