为了账号安全,请及时绑定邮箱和手机立即绑定

训练期间每类验证的准确性

训练期间每类验证的准确性

跃然一笑 2023-05-23 10:15:08
Keras 在训练时给出了整体training和validation准确率。有没有办法在培训期间获得a per-class validation accuracy?更新:来自 Pycharm 的错误日志File "C:/Users/wj96hq/PycharmProjects/PedestrianClassification/Awareness.py", line 82, in <module>shuffle=True, callbacks=callbacks)File "C:\Users\wj96hq\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py", line 66, in _method_wrapperreturn method(self, *args, **kwargs)File "C:\Users\wj96hq\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py", line 876, in fitcallbacks.on_epoch_end(epoch, epoch_logs)File "C:\Users\wj96hq\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\callbacks.py", line 365, in on_epoch_endcallback.on_epoch_end(epoch, logs)File "C:/Users/wj96hq/PycharmProjects/PedestrianClassification/Awareness.py", line 36, in on_epoch_endx_test, y_test = self.validation_data[0], self.validation_data[1]TypeError: 'NoneType' object is not subscriptable
查看完整描述

3 回答

?
慕斯709654

TA贡献1840条经验 获得超5个赞

使用它来获得每类准确性:



model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])



class Metrics(keras.callbacks.Callback):

    def on_train_begin(self, logs={}):

        self._data = []


    def on_epoch_end(self, batch, logs={}):

        x_test, y_test = self.validation_data[0], self.validation_data[1]

        y_predict = np.asarray(model.predict(x_test))


        true = np.argmax(y_test, axis=1)

        pred = np.argmax(y_predict, axis=1)

        

        cm = confusion_matrix(true, pred)

        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

        self._data.append({

            'classLevelaccuracy':cm.diagonal() ,

        })

        return


    def get_data(self):

        return self._data


metrics = Metrics()

history = model.fit(x_train, y_train, epochs=100, validation_data=(x_test, y_test), callbacks=[metrics])

metrics.get_data()

您可以在指标类中更改代码。随心所欲..并且这个工作。你只是用来metrics.get_data()获取所有信息..


查看完整回答
反对 回复 2023-05-23
?
猛跑小猪

TA贡献1858条经验 获得超8个赞

好吧,准确性是一个global指标,没有per-class accuracy. 也许你的意思是,这就是orproportion of the class correctly identified的确切定义。TPRrecall


查看完整回答
反对 回复 2023-05-23
?
倚天杖

TA贡献1828条经验 获得超3个赞

如果您想获得某个类别或一组特定类别的准确性,掩码可能是一个很好的解决方案。看这段代码:


def cus_accuracy(real, pred):


    score = accuracy(real, pred)

    mask = tf.math.greater_equal(real, 5)

    mask = tf.cast(mask, dtype=real.dtype)

    score *= mask


    mask2 = tf.math.less_equal(real, 10)

    mask2 = tf.cast(mask2, dtype=real.dtype)

    score *= mask2


return tf.reduce_mean(score)

这个指标给出了 5 到 10 类的准确度。我用它来测量 seq2seq 模型中某些单词的准确度。


查看完整回答
反对 回复 2023-05-23
  • 3 回答
  • 0 关注
  • 105 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信