# 机器学习笔记-多类逻辑回归-使用gluon

import matplotlib.pyplot as plt
import mxnet as mx
from mxnet import autograd
from mxnet import gluon
from mxnet import ndarray as nd

def transform(data, label):
    """Convert one sample to float32 and scale the data by 1/255.

    Args:
        data: Array-like object exposing ``astype``; its values are divided
            by 255 after the cast (presumably 8-bit pixel intensities, which
            would map to the range [0, 1] -- confirm against the dataset).
        label: Array-like object exposing ``astype``.

    Returns:
        A ``(data, label)`` tuple, both cast to float32.
    """
    return data.astype('float32') / 255, label.astype('float32')

# Load the Fashion-MNIST training split, passing `transform` as the
# per-sample transform.
mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)

# Same dataset with train=False, i.e. the held-out split used for testing.
mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)

def show_images(images):
    """Display a batch of images side by side in a single row.

    Args:
        images: Batch whose first axis indexes the images. Each
            ``images[i]`` must support ``reshape((28, 28))`` followed by
            ``asnumpy()``.
    """
    n = images.shape[0]
    # squeeze=False keeps the axes in a 2-D (1, n) array even when n == 1.
    # With the default, a single bare Axes is returned for n == 1 and
    # indexing it fails.
    _, axes = plt.subplots(1, n, figsize=(15, 15), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(images[i].reshape((28, 28)).asnumpy())
        # Hide the tick marks and labels on both axes.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

def get_text_labels(label):
    """Map numeric class indices to their text names.

    Args:
        label: Iterable of class indices. Each element is passed through
            ``int()`` before the lookup, so float-valued indices such as
            ``3.0`` are accepted.

    Returns:
        A list with one name per element of `label`, in the same order.

    Raises:
        IndexError: If an index falls outside the ten known classes.
    """
    text_labels = [
        'T 恤', '长 裤', '套头衫', '裙 子', '外 套',
        '凉 鞋', '衬 衣', '运动鞋', '包 包', '短 靴',
    ]
    return [text_labels[int(i)] for i in label]

# Preview: take the first ten training samples, print their shape and
# numeric labels, display the images, then print the text labels.
data, label = mnist_train[0:10]

print('example shape: ', data.shape, 'label:', label)

show_images(data)

print(get_text_labels(label))

# Number of samples per mini-batch; also passed to trainer.step().
batch_size = 256

# Mini-batch iterators consumed by the training loop and evaluate_accuracy.
# Only the training set is shuffled.
train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)
test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)

# Define the model: flatten each image, then apply a single dense layer with
# ten outputs (one per entry in the text-label table).
net = gluon.nn.Sequential()
with net.name_scope():
    net.add(gluon.nn.Flatten())
    net.add(gluon.nn.Dense(10))
net.initialize()

# Loss applied to the raw network outputs; the softmax is part of the loss,
# so the network itself has no softmax layer.
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

# Define the trainer: SGD over all of the network's parameters with a
# learning rate of 0.5.

trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5})

def accuracy(output, label):
    """Return the fraction of predictions that match `label`.

    Args:
        output: Per-class scores; the predicted class of each row is the
            arg-max along axis 1.
        label: Class index for each row of `output`.

    Returns:
        The mean of the element-wise equality, as a Python scalar.
    """
    return nd.mean(output.argmax(axis=1) == label).asscalar()

def _get_batch(batch):
    """Split one batch into its ``(data, label)`` pair.

    Args:
        batch: Either an ``mx.io.DataBatch``, from which the first entries
            of ``.data`` and ``.label`` are taken, or any object that
            unpacks into exactly two items.

    Returns:
        A ``(data, label)`` tuple.
    """
    if isinstance(batch, mx.io.DataBatch):
        data = batch.data[0]
        label = batch.label[0]
    else:
        data, label = batch
    return data, label

def evaluate_accuracy(data_iterator, net):
    """Return the accuracy of `net` averaged over the batches of an iterator.

    Every batch carries the same weight in the average, whatever its size.

    Args:
        data_iterator: Iterable of batches accepted by ``_get_batch``. An
            ``mx.io.MXDataIter`` is reset before iterating.
        net: Callable mapping a data batch to per-class scores.

    Returns:
        The mean per-batch accuracy, or 0.0 when the iterator yields no
        batches.
    """
    if isinstance(data_iterator, mx.io.MXDataIter):
        data_iterator.reset()
    acc = 0.
    # Count batches explicitly so an empty iterator does not leave the
    # divisor undefined.
    num_batches = 0
    for batch in data_iterator:
        data, label = _get_batch(batch)
        output = net(data)
        acc += accuracy(output, label)
        num_batches += 1
    if num_batches == 0:
        return 0.
    return acc / num_batches

# Train for five passes over the training set, reporting the average
# per-batch loss and accuracy after each pass.
for epoch in range(5):
    train_loss = 0.
    train_acc = 0.
    for data, label in train_data:
        # The forward pass and the loss must run under autograd.record();
        # otherwise no graph is recorded for loss.backward() to differentiate.
        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        # Let the trainer apply one parameter update for this batch.
        trainer.step(batch_size)
        train_loss += nd.mean(loss).asscalar()
        train_acc += accuracy(output, label)

    test_acc = evaluate_accuracy(test_data, net)
    print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
        epoch, train_loss/len(train_data), train_acc/len(train_data), test_acc))

# Compare true and predicted labels on the first ten test samples.
data, label = mnist_test[0:10]

show_images(data)

print('true labels')

print(get_text_labels(label))

# The predicted class of each sample is the arg-max of the network output
# along axis 1.
predicted_labels = net(data).argmax(axis=1)

print('predicted labels')

print(get_text_labels(predicted_labels.asnumpy()))

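# Note: the parts that differ from the previous post have been annotated with
# comments. The output is identical to that of the previous post, so the
# figures are not repeated here.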

