EfficientNet迁移学习(一) —— 主程序(train.py)_python 预训练efficientnet迁移学习-程序员宅基地

技术标签: 深度学习  EfficientNet  图像分类迁移学习  分类网络  

项目背景

EfficientNet系列博客与ResNet迁移学习的任务一样,均是进行二分类的迁移学习,唯一差别在于网络结构的不同。其它相关的内容,参考下面链接:ResNet迁移学习(一)—— 主程序结构,部分内容如下图所示:

这里是引用

代码框架

下面我们主要介绍train.py的程序框架和功能模块接口。程序框架是构建工程的首要步骤,而功能模块是完善框架的内容。这样在构建工程的时候,思路会比较清晰,不至于无从下手。

1. 数据接口

 tf.set_random_seed(-1)

# ****************************************************************** #
#                   1. Python多线程数据读取与数据预处理                 #
# ****************************************************************** # 
train_queue = train_set_queue() 
valid_queue = valid_set_queue()

2. 网络结构接口

    # ****************************************************************** #
    #                   2. 构建静态图                                     #
    # ****************************************************************** #
    model = Model()
    print('variables: ', model.first_stage_trainable_var_list)

3. 开启会话,执行训练

第一步:如果要加载预训练模型,则需要首先恢复其相应的权重

# ****************************************************************** #
#                   3. 恢复权重, 开启会话, 执行静态图                   #
# ****************************************************************** #
    init = tf.global_variables_initializer()
    config_proto = tf.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    with tf.Session(config=config_proto) as sess:
        sess.run(init)
        try:
            print('=> Restoring weights from: %s ... ' % model.pre_trained_weight)
            model.loader.restore(sess, model.pre_trained_weight)
            print('成功加载模型')
        except:
            print('Load Pretrained Weight Failed !!!\n')
            print('=> %s does not exist !!!' % model.pre_trained_weight)
            print('=> Now it starts to train from scratch ...')
            cfg.Train.First_Stage_Epochs = 0  # 重头开始训练网络, 不需要预训练权重
        else:
            print('\nLoad Model Success !!!\n')

第二步:设置训练过程中的保存文件目录,将要收集的变量保存文件,可以用Tensorboard查看

        start_epoch = 0

        # 训练日志保存目录
        save_path = cfg.train.log
        if os.path.exists(save_path):
            print('目录非空, 清空目录')
            shutil.rmtree(save_path)
        train_writer = tf.summary.FileWriter(save_path + 'train', sess.graph)
        valid_writer = tf.summary.FileWriter(save_path + 'valid')

第三步:根据设置的Epoch,循环训练网络

        for epoch in range(start_epoch + 1, 1 + model.first_stage_epoch + model.second_stage_epoch):
            if epoch < model.first_stage_epoch:
                train_op = model.train_op_with_frozen_variables
            else:
                train_op = model.train_op_with_all_variables

            # Train Process
            for step in range(0, cfg.train.train_num // model.batch_size):
                # 每次从队列中取出一个batch的数据
                _, _, image, label = train_queue.get()

                # 统计每个batch中标签的二分类分布情况
                # result = Counter(label)
                train_params = [train_op,
                                model.loss,
                                model.accuracy,
                                model.merged,
                                model.learn_rate,
                                model.global_step]

                train_dict = {
     model.inputs: image,
                              model.label_c: label,
                              model.trainable: True}

                _, train_loss, train_accuracy, merge_train, lr_, global_step_ = sess.run(train_params,
                                                                                         feed_dict=train_dict)

                train_writer.add_summary(merge_train, global_step_)
                print('iter:%2d/%d || train loss:%.4f || train accuracy:%.4f || lr:%g ' % (step,
                                                                                           epoch,
                                                                                           train_loss,
                                                                                           train_accuracy,
                                                                                           lr_))
                # print('阴性样本:%d || 阳性样本:%d ' % (result[0], result[1]))

            # Valid Process
            valid = []
            label_ = []
            predict_ = []
            for valid_step in range(cfg.train.valid_num // model.batch_size):
                name, _, valid_image, valid_label = valid_queue.get()

                valid_params = [model.accuracy,
                                model.loss,
                                model.merged,
                                tf.argmax(model.logits, 1),
                                model.label_c]

                valid_dict = {
     model.inputs: valid_image,
                              model.label_c: valid_label,
                              model.trainable: False}

                valid_accuracy, valid_loss, merge_valid, predict, label_c = sess.run(valid_params, feed_dict=valid_dict)

                valid_writer.add_summary(merge_valid, global_step_)
                print('==================================================================================')
                print('step=%2d/%d || valid accuracy=%g || loss=%.4f ' % (valid_step,
                                                                          epoch,
                                                                          valid_accuracy,
                                                                          valid_loss))
                print('==================================================================================')
                print('标签值:', label_c)
                print('预测值:', predict)

                label_.append(label_c)
                predict_.append(predict)
                valid.append(valid_accuracy)

            # 计算所有验证集的混淆矩阵
            print('mean valid accuracy: ', np.mean(valid))
            confusion_matrix = tf.confusion_matrix(np.hstack(label_), np.hstack(predict_), num_classes=2)
            confusion_matrix_ = sess.run(confusion_matrix)

            TN = confusion_matrix_[0][0]
            FP = confusion_matrix_[0][1]
            FN = confusion_matrix_[1][0]
            TP = confusion_matrix_[1][1]

            acc = (TP + TN) / (TP + TN + FP + FN)
            sensitive = TP / (TP + FN)
            specify = TN / (TN + FP)

            print('混淆矩阵\n', confusion_matrix_)
            print('准确度, 灵敏度, 特异度: ', acc, sensitive, specify)

            # 保存每个epoch的模型
            if epoch >= 79:
                exit()
            model.saver.save(sess, cfg.train.save_model, global_step=epoch)

完整代码

from Model import *
from DataLoader import *
from config import cfg
from collections import Counter
import shutil
os.environ['CUDA_VISIBLE_DEVICES'] = '1'


def train():
    tf.set_random_seed(-1)

    # ****************************************************************** #
    #                   1. Python多线程数据读取与数据预处理                 #
    # ****************************************************************** #
    train_queue = train_set_queue()
    valid_queue = valid_set_queue()
    # for i in range(1000000):
    #     name1, _, image, label = train_queue.get()
    #     name2, _, valid_image, valid_label = valid_queue.get()
    #     print(name1)
    #     print(name2)
    #     print(len(name1))
    #     print(len(name2))
    #     print('==============================================')

    # ****************************************************************** #
    #                   2. 构建静态图                                     #
    # ****************************************************************** #
    model = Model()
    print('variables: ', model.first_stage_trainable_var_list)

    # ****************************************************************** #
    #                   3. 恢复权重, 开启会话, 执行静态图                   #
    # ****************************************************************** #
    init = tf.global_variables_initializer()
    config_proto = tf.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    with tf.Session(config=config_proto) as sess:
        sess.run(init)
        try:
            print('=> Restoring weights from: %s ... ' % model.pre_trained_weight)
            model.loader.restore(sess, model.pre_trained_weight)
            print('成功加载模型')
        except:
            print('Load Pretrained Weight Failed !!!\n')
            print('=> %s does not exist !!!' % model.pre_trained_weight)
            print('=> Now it starts to train from scratch ...')
            cfg.Train.First_Stage_Epochs = 0  # 重头开始训练网络, 不需要预训练权重
        else:
            print('\nLoad Model Success !!!\n')

        start_epoch = 0

        # 训练日志保存目录
        save_path = cfg.train.log
        if os.path.exists(save_path):
            print('目录非空, 清空目录')
            shutil.rmtree(save_path)
        train_writer = tf.summary.FileWriter(save_path + 'train', sess.graph)
        valid_writer = tf.summary.FileWriter(save_path + 'valid')

        for epoch in range(start_epoch + 1, 1 + model.first_stage_epoch + model.second_stage_epoch):
            if epoch < model.first_stage_epoch:
                train_op = model.train_op_with_frozen_variables
            else:
                train_op = model.train_op_with_all_variables

            # Train Process
            for step in range(0, cfg.train.train_num // model.batch_size):
                # 每次从队列中取出一个batch的数据
                _, _, image, label = train_queue.get()

                # 统计每个batch中标签的二分类分布情况
                # result = Counter(label)
                train_params = [train_op,
                                model.loss,
                                model.accuracy,
                                model.merged,
                                model.learn_rate,
                                model.global_step]

                train_dict = {
    model.inputs: image,
                              model.label_c: label,
                              model.trainable: True}

                _, train_loss, train_accuracy, merge_train, lr_, global_step_ = sess.run(train_params,
                                                                                         feed_dict=train_dict)

                train_writer.add_summary(merge_train, global_step_)
                print('iter:%2d/%d || train loss:%.4f || train accuracy:%.4f || lr:%g ' % (step,
                                                                                           epoch,
                                                                                           train_loss,
                                                                                           train_accuracy,
                                                                                           lr_))
                # print('阴性样本:%d || 阳性样本:%d ' % (result[0], result[1]))

            # Valid Process
            valid = []
            label_ = []
            predict_ = []
            for valid_step in range(cfg.train.valid_num // model.batch_size):
                name, _, valid_image, valid_label = valid_queue.get()

                valid_params = [model.accuracy,
                                model.loss,
                                model.merged,
                                tf.argmax(model.logits, 1),
                                model.label_c]

                valid_dict = {
    model.inputs: valid_image,
                              model.label_c: valid_label,
                              model.trainable: False}

                valid_accuracy, valid_loss, merge_valid, predict, label_c = sess.run(valid_params, feed_dict=valid_dict)

                valid_writer.add_summary(merge_valid, global_step_)
                print('==================================================================================')
                print('step=%2d/%d || valid accuracy=%g || loss=%.4f ' % (valid_step,
                                                                          epoch,
                                                                          valid_accuracy,
                                                                          valid_loss))
                print('==================================================================================')
                print('标签值:', label_c)
                print('预测值:', predict)

                label_.append(label_c)
                predict_.append(predict)
                valid.append(valid_accuracy)

            # 计算所有验证集的混淆矩阵
            print('mean valid accuracy: ', np.mean(valid))
            confusion_matrix = tf.confusion_matrix(np.hstack(label_), np.hstack(predict_), num_classes=2)
            confusion_matrix_ = sess.run(confusion_matrix)

            TN = confusion_matrix_[0][0]
            FP = confusion_matrix_[0][1]
            FN = confusion_matrix_[1][0]
            TP = confusion_matrix_[1][1]

            acc = (TP + TN) / (TP + TN + FP + FN)
            sensitive = TP / (TP + FN)
            specify = TN / (TN + FP)

            print('混淆矩阵\n', confusion_matrix_)
            print('准确度, 灵敏度, 特异度: ', acc, sensitive, specify)

            # 保存每个epoch的模型
            if epoch >= 79:
                exit()
            model.saver.save(sess, cfg.train.save_model, global_step=epoch)


if __name__ == '__main__':
    train()

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/kxh123456/article/details/109546778

智能推荐

class和struct的区别-程序员宅基地

文章浏览阅读101次。4.class可以有⽆参的构造函数,struct不可以,必须是有参的构造函数,⽽且在有参的构造函数必须初始。2.Struct适⽤于作为经常使⽤的⼀些数据组合成的新类型,表示诸如点、矩形等主要⽤来存储数据的轻量。1.Class⽐较适合⼤的和复杂的数据,表现抽象和多级别的对象层次时。2.class允许继承、被继承,struct不允许,只能继承接⼝。3.Struct有性能优势,Class有⾯向对象的扩展优势。3.class可以初始化变量,struct不可以。1.class是引⽤类型,struct是值类型。

android使用json后闪退,应用闪退问题:从json信息的解析开始就会闪退-程序员宅基地

文章浏览阅读586次。想实现的功能是点击顶部按钮之后按关键字进行搜索,已经可以从服务器收到反馈的json信息,但从json信息的解析开始就会闪退,加载listview也不知道行不行public abstract class loadlistview{public ListView plv;public String js;public int listlength;public int listvisit;public..._rton转json为什么会闪退

如何使用wordnet词典,得到英文句子的同义句_get_synonyms wordnet-程序员宅基地

文章浏览阅读219次。如何使用wordnet词典,得到英文句子的同义句_get_synonyms wordnet

系统项目报表导出功能开发_积木报表 多线程-程序员宅基地

文章浏览阅读521次。系统项目报表导出 导出任务队列表 + 定时扫描 + 多线程_积木报表 多线程

ajax 如何从服务器上获取数据?_ajax 获取http数据-程序员宅基地

文章浏览阅读1.1k次,点赞9次,收藏9次。使用AJAX技术的好处之一是它能够提供更好的用户体验,因为它允许在不重新加载整个页面的情况下更新网页的某一部分。另外,AJAX还使得开发人员能够创建更复杂、更动态的Web应用程序,因为它们可以在后台与服务器进行通信,而不需要打断用户的浏览体验。在Web开发中,AJAX(Asynchronous JavaScript and XML)是一种常用的技术,用于在不重新加载整个页面的情况下,从服务器获取数据并更新网页的某一部分。使用AJAX,你可以创建异步请求,从而提供更快的响应和更好的用户体验。_ajax 获取http数据

Linux图形终端与字符终端-程序员宅基地

文章浏览阅读2.8k次。登录退出、修改密码、关机重启_字符终端

随便推点

Python与Arduino绘制超声波雷达扫描_超声波扫描建模 python库-程序员宅基地

文章浏览阅读3.8k次,点赞3次,收藏51次。前段时间看到一位发烧友制作的超声波雷达扫描神器,用到了Arduino和Processing,可惜啊,我不会Processing更看不懂人家的程序,咋办呢?嘿嘿,所以我就换了个思路解决,因为我会一点Python啊,那就动手吧!在做这个案例之前先要搞明白一个问题:怎么将Arduino通过超声波检测到的距离反馈到Python端?这个嘛,我首先想到了串行通信接口。没错!就是串口。只要Arduino将数据发送给COM口,然后Python能从COM口读取到这个数据就可以啦!我先写了一个测试程序试了一下,OK!搞定_超声波扫描建模 python库

凯撒加密方法介绍及实例说明-程序员宅基地

文章浏览阅读4.2k次。端—端加密指信息由发送端自动加密,并且由TCP/IP进行数据包封装,然后作为不可阅读和不可识别的数据穿过互联网,当这些信息到达目的地,将被自动重组、解密,而成为可读的数据。不可逆加密算法的特征是加密过程中不需要使用密钥,输入明文后由系统直接经过加密算法处理成密文,这种加密后的数据是无法被解密的,只有重新输入明文,并再次经过同样不可逆的加密算法处理,得到相同的加密密文并被系统重新识别后,才能真正解密。2.使用时,加密者查找明文字母表中需要加密的消息中的每一个字母所在位置,并且写下密文字母表中对应的字母。_凯撒加密

工控协议--cip--协议解析基本记录_cip协议embedded_service_error-程序员宅基地

文章浏览阅读5.7k次。CIP报文解析常用到的几个字段:普通类型服务类型:[0x00], CIP对象:[0x02 Message Router], ioi segments:[XX]PCCC(带cmd和func)服务类型:[0x00], CIP对象:[0x02 Message Router], cmd:[0x101], fnc:[0x101]..._cip协议embedded_service_error

如何在vs2019及以后版本(如vs2022)上添加 添加ActiveX控件中的MFC类_vs添加mfc库-程序员宅基地

文章浏览阅读2.4k次,点赞9次,收藏13次。有时候我们在MFC项目开发过程中,需要用到一些微软已经提供的功能,如VC++使用EXCEL功能,这时候我们就能直接通过VS2019到如EXCEL.EXE方式,生成对应的OLE头文件,然后直接使用功能,那么,我们上篇文章中介绍了vs2017及以前的版本如何来添加。但由于微软某些方面考虑,这种方式已被放弃。从上图中可以看出,这一功能,在从vs2017版本15.9开始,后续版本已经删除了此功能。那么我们如果仍需要此功能,我们如何在新版本中添加呢。_vs添加mfc库

frame_size (1536) was not respected for a non-last frame_frame_size (1024) was not respected for a non-last-程序员宅基地

文章浏览阅读785次。用ac3编码,执行编码函数时报错入如下:[ac3 @ 0x7fed7800f200] frame_size (1536) was not respected for anon-last frame (avcodec_encode_audio2)用ac3编码时每次送入编码器的音频采样数应该是1536个采样,不然就会报上述错误。这个数字并非刻意固定,而是跟ac3内部的编码算法原理相关。全网找不到,国内音视频之路还有很长的路,音视频人一起加油吧~......_frame_size (1024) was not respected for a non-last frame

Android移动应用开发入门_在安卓移动应用开发中要在活动类文件中声迷你一个复选框变量-程序员宅基地

文章浏览阅读230次,点赞2次,收藏2次。创建Android应用程序一个项目里面可以有很多模块,而每一个模块就对应了一个应用程序。项目结构介绍_在安卓移动应用开发中要在活动类文件中声迷你一个复选框变量

推荐文章

热门文章

相关标签