为了账号安全,请及时绑定邮箱和手机立即绑定

机器学习笔记-多类逻辑回归-使用gluon

标签:
机器学习

这次使用gluon让代码更精减,代码来自:https://zh.gluon.ai/chapter_supervised-learning/mlp-gluon.html

from mxnet import gluon

from mxnet import ndarray as nd

import matplotlib.pyplot as plt

import mxnet as mx

from mxnet import autograd

   

def transform(data, label):

    return data.astype('float32')/255, label.astype('float32')

   

mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)

mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)

   

def show_images(images):

    n = images.shape[0]

    _, figs = plt.subplots(1, n, figsize=(15, 15))

    for i in range(n):

        figs[i].imshow(images[i].reshape((28, 28)).asnumpy())

        figs[i].axes.get_xaxis().set_visible(False)

        figs[i].axes.get_yaxis().set_visible(False)

    plt.show()

 

def get_text_labels(label):

    text_labels = [

        'T 恤', '长 裤', '套头衫', '裙 子', '外 套',

        '凉 鞋', '衬 衣', '运动鞋', '包 包', '短 靴'

    ]

    return [text_labels[int(i)] for i in label]

   

data, label = mnist_train[0:10]

   

print('example shape: ', data.shape, 'label:', label)

show_images(data)

print(get_text_labels(label))

   

batch_size = 256

train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)

test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)

   

#计算模型

net = gluon.nn.Sequential()

with net.name_scope():

    net.add(gluon.nn.Flatten())

    net.add(gluon.nn.Dense(256, activation="relu"))

    net.add(gluon.nn.Dense(10))

net.initialize()

   

softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

 

#定义训练器

trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5})

  

def accuracy(output, label):

    return nd.mean(output.argmax(axis=1) == label).asscalar()

   

def _get_batch(batch):

    if isinstance(batch, mx.io.DataBatch):

        data = batch.data[0]

        label = batch.label[0]

    else:

        data, label = batch

    return data, label

   

def evaluate_accuracy(data_iterator, net):

    acc = 0.

    if isinstance(data_iterator, mx.io.MXDataIter):

        data_iterator.reset()

    for i, batch in enumerate(data_iterator):

        data, label = _get_batch(batch)

        output = net(data)

        acc += accuracy(output, label)

    return acc / (i+1)

   

for epoch in range(5):

    train_loss = 0.

    train_acc = 0.

    for data, label in train_data:

        with autograd.record():

            output = net(data)

            loss = softmax_cross_entropy(output, label)

        loss.backward()

        trainer.step(batch_size) #使用训练器,向"前"走一步

 

        train_loss += nd.mean(loss).asscalar()

        train_acc += accuracy(output, label)

 

    test_acc = evaluate_accuracy(test_data, net)

    print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (

        epoch, train_loss/len(train_data), train_acc/len(train_data), test_acc))

 

data, label = mnist_test[0:10]

show_images(data)

print('true labels')

print(get_text_labels(label))

   

predicted_labels = net(data).argmax(axis=1)

print('predicted labels')

print(get_text_labels(predicted_labels.asnumpy()))

 有变化的地方,已经加上了注释。运行效果,跟一篇完全相同,就不重复贴图了

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消