为了账号安全,请及时绑定邮箱和手机立即绑定

TensorFlow利用卷积神经网络在谷歌inception_v3模型基础上解决花朵分类问题

标签:
人工智能

本篇更多的是在代码实战方向,不会涉及太多的理论。本文主要针对TensorFlow和卷积神经网络有一定基础的同学,并对图像处理有一定的了解。

阅读本文你大概需要以下知识:

1.TensorFlow基础
2.TensorFlow实现卷积神经网络的前向传播过程
3.TFRecord数据格式
4.Dataset的使用
5.Slim的使用

好了废话不多说,下面开始。

一.数据准备

首先我们需要有一个让我们训练的数据集,这里谷歌已经帮我们做好了。这里要把数据集下载下来,打开命令行,执行如下命令:

wget http://download.tensorflow.org/example_image/flower_photo.tgz//解压tar xzf flower_photos.tgz

这里需要注意的是,文件最好是下载到你的工程目录下方便你的读取。什么?你还不会搭建TensorFlow程序?请移步https://www.tensorflow.org/install/
选择自己的操作系统,在这里我的是macOS。我使用的是Virtualenv来搭建TensorFlow运行环境。
数据集下载并解压后,我们可以看到大概是这个样子

webp

每一个文件夹里都是一个种类的花的图片,这里总共有五种花。
好了,数据有了?接下来该怎么办呢?当然是把数据进行预处理拉,你不会觉得我们的TensorFlow可以直接识别这些图片进行训练吧,hhhhhh。


二.数据预处理

接下来我们在目录下新建pre_data.python文件。TensorFlow对图片做处理一般是生成TFRecord文件。什么是TFRecord?后面我们会讲到。

首先我们要引入我们需要的库。

# glob模块的主要方法就是glob,该方法返回所有匹配的文件路径列表(list)import glob#os.path生成路径方便glob获取import os.path#这里主要用到随机数import numpy as np#引入tensorflow框架import tensorflow as tf#引入gflie对图片做处理from tensorflow.python.platform import gfile

相关库在我们这个程序中的功能都作了简单介绍,下面用到的时候我们会更加详细的说明。

大家都知道我们的数据集一般分训练,测试和验证数据集。观察上面的数据集,谷歌只是给出了每一种花的图片,并没有给去哪些我训练,哪些是测试,哪些是验证数据集。所以在这里我们要进行划分。

#输入图片地址INPUT_DATA = '../../flower_photos'#训练数据集OUTPUT_FILE = './path/to/output.tfrecords'#测试数据集OUTPUT_TEST_FILE = './path/to/output_test.tfrecords'#验证数据集OUTPUT_VALIDATION_FILE = './path/to/output_validation.tfrecords'#测试数据和验证数据的比例VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10

关于VALIDATION_PERCENTAGE和TEST_PERCENTAGE这两个常量,我们在后面的例子会给出。

下面我们就来定义处理数据的方法:

def create_image_lists(sess,testing_percentage,validation_percentage):
    #拿到INPUT_DATA文件夹下的所有目录(包括root)
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]    #如果是root_dir不需要做处理
    is_root_dir = True
    #定义图片对应的标签,从0-4分别代表不同的花
    current_label = 0
    #写入TFRecord的数据需要首先定义writer
    #这里定义三个writer分别存储训练,测试和验证数据
    writer = tf.python_io.TFRecordWriter(OUTPUT_FILE)
    writer_test = tf.python_io.TFRecordWriter(OUTPUT_TEST_FILE)
    writer_validation = tf.python_io.TFRecordWriter(OUTPUT_VALIDATION_FILE)    #循环目录
    for sub_dir in sub_dirs:        if is_root_dir:            #跳过根目录
            is_root_dir = False
            continue
        #定义空数组来装图片路径
        file_list = []        #生成查找路径
        dir_name = os.path.basename(sub_dir)
        file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + "jpg")        # extend合并两个数组
        # glob模块的主要方法就是glob,该方法返回所有匹配的文件路径列表(list)
        # 比如:glob.glob(r’c:*.txt’) 这里就是获得C盘下的所有txt文件
        file_list.extend(glob.glob(file_glob))        #路径下没有文件就跳过,不继续操作
        if not file_list: continue
        #这里我定义index来打印当前进度
        index = 0
        #file_list此时是图片路径列表
        for file_name in file_list:            #使用gfile从路径中读取图片
            image_raw_data = gfile.FastGFile(file_name, 'rb').read()            #对图像解码,解码结果为一个张量
            image = tf.image.decode_jpeg(image_raw_data)            #对图像矩阵进行归一化处理
            #因为为了将图片数据能够保存到 TFRecord 结构体中
            #所以需要将其图片矩阵转换成 string
            #所以为了在使用时能够转换回来
            #这里确定下数据格式为 tf.float32  
            if image.dtype != tf.float32:
                image = tf.image.convert_image_dtype(image, dtype=tf.float32)            # 将图片转化成299*299方便模型处理
            image = tf.image.resize_images(image, [299, 299])            #为了拿到图片的真实数据这里我们要运行一个session op
            image_value = sess.run(image)
           
            pixels = image_value.shape[1]            #存储在TFrecord里面的不能是array的形式
            #所以我们需要利用tostring()将上面的矩阵
            #转化成字符串
            #再通过tf.train.BytesList转化成可以存储的形式
            image_raw = image_value.tostring()            #存到features
            #随机划分测试集和训练集
            #这里存入TFRecord三个数据,图像的pixels像素
            #图像原张量,这里我们需要转成string
            #以及当前图像对应的标签
            example = tf.train.Example(features=tf.train.Features(feature={                'pixels': _int64_feature(pixels),                'label': _int64_feature(current_label),                'image_raw': _bytes_feature(image_raw)
            }))
            chance = np.random.randint(100)            #随机划分数据集
            if chance < validation_percentage:
                writer_validation.write(example.SerializeToString())            elif chance < (testing_percentage+validation_percentage):
                writer_test.write(example.SerializeToString())            else:
                writer.write(example.SerializeToString())            # print('example',index)
            index = index + 1

        #每一个文件夹下的所有图片都是一个类别
        #所以这里每遍历完一个文件夹,标签就增加1
        current_label += 1

    writer.close()
    writer_validation.close()
    writer_test.close()

运行上述程序需要一定时间,我的电脑比较烂,大概跑了三十分钟左右。这时候在你的./path/to目录下可以看到output.tfrecords,output_test.tfrecords,output_validation.tfrecords三个文件,分别存放了训练,测试和验证数据集。上述代码将所有图片划分成训练、验证和测试数据集。并且把图片从原始的jpg格式转换成inception-v3模型需要的299 * 299 * 3的数字矩阵。在数据处理完毕之后,通过以下命令可以下载谷歌提供好的Inception_v3模型。

wget http://download.tensorflow.org/models/inception_v3_2016_08_26.tar.gz//解压之后可以得到训练好的模型文件inception_v3.ckpttar xzf inception_v3_2016_08

二.训练

当新的数据集和已经训练好的模型都准备好之后,我们来写代码在谷歌inception_v3的基础上训练新数据集。

首先同样我们导入相关的库并且定义相关常量。在这里我们通过slim工具来直接加载模型,而不用自己再定义前向传播过程。

import numpy as npimport tensorflow as tfimport tensorflow.contrib.slim as slim# 加载通过TensorFlow-Silm定义好的 inception_v3模型import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3# 输入数据文件INPUT_DATA = './path/to/output.tfrecords'# 验证数据集VALIDATION_DATA = './path/to/output_validation.tfrecords'# 保存训练好的模型的路径ls = './path/to/save_model'# 谷歌提供的训练好的模型文件地址CKPT_FILE = './path/to/inception_v3.ckpt'TRAIN_FILE = './path/to/save_model'# 定义训练中使用的参数LEARNING_RATE = 0.01#组合batch的大小BATCH = 32#用于one_hot函数输出概率分布N_CLASSES = 5#打乱顺序,并设置出队和入队中元素最少的个数,这里是10000个shuffle_buffer = 10000# 不需要从谷歌模型中加载的参数,这里就是最后的全连接层。因为输出类别不一样,所以最后全连接层的参数也不一样CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'# 需要训练的网络层参数 这里就是最后的全连接层TRAINABLE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'

接下来我们定义几个辅助方法。首先因为我们的数据存在TFRecord里,需要定义方法从TFRecord解析数据。

def parse(record):
    features = tf.parse_single_example(
        record,
        features={            'image_raw': tf.FixedLenFeature([], tf.string),            'label': tf.FixedLenFeature([], tf.int64),            'pixels': tf.FixedLenFeature([], tf.int64)
        }
    )    #decode_raw用于解析TFRecord里面的字符串
    decoded_image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = features['label']    #要注意这里的decoded_image并不能直接进行reshape操作
    #之前我们在存储的时候,把图片进行了tostring()操作
    #这会导致图片的长度在原来基础上*8
    #后面我们要用到numpy的fromstring来处理
    return decoded_image, label

接下来定义两个方法。因为我们已经下载了谷歌训练好的inception_v3模型的参数,下面我们需要定义两个方法从里面加载参数。

#直接从inception_v3.ckpt中读取的参数def get_tuned_variables():
    #strip删除头尾字符,默认为空格
    exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(",")]
    variables_to_restore = []    #这里给出了所有slim模型下的参数
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:            if var.op.name.startswith(exclusion):
                excluded = True
                break
            if not excluded:
                variables_to_restore.append(var)        return variables_to_restore#需要重新训练的参数def get_trainable_variables():
    #strip删除头尾字符,默认为空格
    scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(",")]
    variables_to_train = []    # 枚举所有需要训练的参数前缀,并通过这些前缀找到所有的参数。
    for scope in scopes:      #从TRAINABLE_VARIABLES集合中获取名为scope的变量
      #也就是我们需要重新训练的参数
        variables = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)    return variables_to_train

这里我们就写完了所需要的工具函数,接下来我们定义主函数。主函数主要完成数据读取,模型定义,通过模型得出前向传播结果,通过损失函数计算损失,最后把损失交给优化器做处理。首先我们先来完成数据读取的代码,这里我们使用的是TensorFlow高层API Dataset。不清楚的可以去看一下Dataset的用法。

这里我们在训练的同时也对模型做了验证。所以我们需要加载训练和验证数据

#读取测试数据
    #利用TFRecordDataset读取TFRecord文件
    dataset = tf.data.TFRecordDataset([INPUT_DATA])    #解析TFRecord
    dataset = dataset.map(parse)    #把数据打乱顺序并组装成batch
    dataset = dataset.shuffle(shuffle_buffer).batch(BATCH)    #定义数据重复的次数
    NUM_EPOCHS = 10
    dataset = dataset.repeat(NUM_EPOCHS)    #定义迭代器来获取处理后的数据
    iterator = dataset.make_one_shot_iterator()    #迭代器开始迭代
    img, label = iterator.get_next()    #读取验证数据(同上)
    valida_dataset = tf.data.TFRecordDataset([VALIDATION_DATA])
    valida_dataset = valida_dataset.map(parse)
    valida_dataset = valida_dataset.batch(BATCH)
    valida_iterator = valida_dataset.make_one_shot_iterator()
    valida_img,valida_label = valida_iterator.get_next()    #定义inception-v3的输入,images为输入图片,label为每一张图片对应的标签
    #再解释下每一个维度 None为batch的大小,299为图片大小,3为通道
    images = tf.placeholder(tf.float32,[None,299,299,3],name='input_images')
    labels = tf.placeholder(tf.int64,[None],name='labels')

要注意上述定义的只是tensorflow的张量,保存的只是计算过程并没有具体的数据。只有运行session之后才会拿到具体的数据。

下面我们来通过slim加载inception-v3模型

 #定义inception-v3模型结构 inception_v3.ckpt里只有参数的取值
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):        #logits  inception_v3前向传播得到的结果
        logits,_ = inception_v3.inception_v3(images,num_classes=N_CLASSES)        #获取需要训练的变量
        trainable_variables = get_trainable_variables()        #这里用交叉熵作为损失函数,注意一下tf.losses.softmax_cross_entropy的参数
        # tf.losses.softmax_cross_entropy(
        #     onehot_labels,  # 注意此处参数名就叫 onehot_labels
        #     logits,
        #     weights=1.0,
        #     label_smoothing=0,
        #     scope=None,
        #     loss_collection=tf.GraphKeys.LOSSES,
        #     reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
        # )
        #这里要把labels转成one_hot类型,logits就是神经网络的输出        
        tf.losses.softmax_cross_entropy(tf.one_hot(labels,N_CLASSES),logits,weights=1.0)        #把计算的损失交给优化器处理
        train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(tf.losses.get_total_loss())        #计算正确率。
        with tf.name_scope('evaluation'):
            correct_prediction = tf.equal(tf.argmax(logits,1),labels)
            evaluation_step = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))        #定义加载模型的函数
        load_fn = slim.assign_from_checkpoint_fn(CKPT_FILE,get_tuned_variables(),ignore_missing_vars=True)        #定义保存新的训练好的模型的函数
        saver = tf.train.Saver()        with tf.Session() as sess:            #初始化所有变量
            init = tf.global_variables_initializer()
            sess.run(init)
            print('Loading tuned variables from %s'%CKPT_FILE)            #加载谷歌已经训练好的模型
            load_fn(sess)
            step = 0;            #在这里我们用一个while来循环训练,直到dataset里没有数据就结束循环
            while True:                try:                    if step % 30  == 0 or step + 1 == STEPS:                      #每30轮输出一次正确率
                        if step != 0:                            #每30轮保存一次当前模型的参数,以便中途训练中断可以继续
                            saver.save(sess,TRAIN_FILE,global_step=step)                       #运行session拿到真实图片的数据
                        valida_img_batch,valida_label_batch = sess.run([valida_img,valida_label])                        #上面有提到TFRecord里图片数据被转成了string,在这里转回来
                        valida_img_batch = np.fromstring(valida_img_batch, dtype=np.float32)                        #把图片张量拉成新的维度
                        valida_img_batch = tf.reshape(valida_img_batch, [32, 299, 299, 3])                        #用session运行上述操作,得到处理后的图片张量
                        valida_img_batch = sess.run(valida_img_batch)                        #把图片张量传到feed_dict算出正确率并显示
                        validation_accuracy = sess.run(evaluation_step,feed_dict={
                            images:valida_img_batch,
                            labels:valida_label_batch
                        })
                        print('Step %d: Validation accurary = %.1f%%'%(step,validation_accuracy*100.0))                    #下面是对训练数据的操作,同上
                    img_batch,label_batch = sess.run([img,label])
                    img_batch = np.fromstring(img_batch, dtype=np.float32)
                    img_batch = tf.reshape(img_batch, [32,299, 299, 3])
                    img_batch = sess.run(img_batch)

                    sess.run(train_step,feed_dict={
                        images:img_batch,
                        labels:label_batch
                    })                    #step仅仅用于记录
                    step = step + 1
                except tf.errors.OutOfRangeError:                    break

运行上述程序开始训练。在这里我暂时是使用cpu进行训练,训练过程大约3小时,可以得到类型下面的结果。

step 0:Validation accuracy = 12.5%
step 30:Validation accuracy = 22.2%
step 60:Validation accuracy = 63.2%
step 90:Validation accuracy = 79.8%
step 120:Validation accuracy = 86.4%
step 150:Validation accuracy = 88.5%
.....



作者:sidiWang
链接:https://www.jianshu.com/p/fc77879d3591


点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消