深度学习Pytorch(十一)——微调torchvision模型(1)_柚子味的羊的博客-程序员宅基地_torchvision 模型

技术标签: Pytorch  python  深度学习  pytorch  Python  

深度学习Pytorch(十一)——微调torchvision模型

一、简介

在本小节,深入探讨如何对torchvision进行微调和特征提取。所有模型都已经预先在1000类的magenet数据集上训练完成。 本节将深入介绍如何使用几个现代的CNN架构,并将直观展示如何微调任意的PyTorch模型。
本节将执行两种类型的迁移学习:

  • 微调:从预训练模型开始,更新我们新任务的所有模型参数,实质上是重新训练整个模型。
  • 特征提取:从预训练模型开始,仅更新从中导出预测的最终图层权重。它被称为特征提取,因为我们使用预训练的CNN作为固定 的特征提取器,并且仅改变输出层。

通常这两种迁移学习方法都会遵循一下步骤:

  • 初始化预训练模型
  • 重组最后一层,使其具有与新数据集类别数相同的输出数
  • 为优化算法定义想要的训练期间更新的参数
  • 运行训练步骤

二、导入相关包

from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision 
from torchvision import datasets,models,transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("Pytorch version:",torch.__version__)
print("torchvision version:",torchvision.__version__)

运行结果
在这里插入图片描述

三、数据输入

数据集——>我在这里

#%%输入
data_dir="D:\Python\Pytorch\data\hymenoptera_data"
# 从[resnet,alexnet,vgg,squeezenet,desenet,inception]
model_name='squeezenet'
# 数据集中类别数量
num_classes=2
# 训练的批量大小
batch_size=8
# 训练epoch数
num_epochs=15
# 用于特征提取的标志。为FALSE,微调整个模型,为TRUE只更新图层参数
feature_extract=True

四、辅助函数

1、模型训练和验证

  • train_model函数处理给定模型的训练和验证。作为输入,它需要PyTorch模型、数据加载器字典、损失函数、优化器、用于训练和验 证epoch数,以及当模型是初始模型时的布尔标志。
  • is_inception标志用于容纳 Inception v3 模型,因为该体系结构使用辅助输出, 并且整体模型损失涉及辅助输出和最终输出,如此处所述。 这个函数训练指定数量的epoch,并且在每个epoch之后运行完整的验证步骤。它还跟踪最佳性能的模型(从验证准确率方面),并在训练 结束时返回性能最好的模型。在每个epoch之后,打印训练和验证正确率。
#%%模型训练和验证
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False):
    since=time.time()
    val_acc_history=[]
    best_model_wts=copy.deepcopy(model.state_dict())
    best_acc=0.0
    for epoch in range(num_epochs):
        print('Epoch{}/{}'.format(epoch, num_epochs-1))
        print('-'*10)
        # 每个epoch都有一个训练和验证阶段
        for phase in['train','val']:
            if phase=='train':
                model.train()
            else:
                model.eval()
                
            running_loss=0.0
            running_corrects=0
            # 迭代数据
            for inputs,labels in dataloaders[phase]:
                inputs=inputs.to(device)
                labels=labels.to(device)
                # 梯度置零
                optimizer.zero_grad()
                # 向前传播
                with torch.set_grad_enabled(phase=='train'):
                    # 获取模型输出并计算损失,开始的特殊情况在训练中他有一个辅助输出
                    # 在训练模式下,通过将最终输出和辅助输出相加来计算损耗,在测试中值考虑最终输出
                    if is_inception and phase=='train':
                        outputs,aux_outputs=model(inputs)
                        loss1=criterion(outputs,labels)
                        loss2=criterion(aux_outputs,labels)
                        loss=loss1+0.4*loss2
                    else:
                        outputs=model(inputs)
                        loss=criterion(outputs,labels)
                        
                    _,preds=torch.max(outputs,1)
                    
                    if phase=='train':
                        loss.backward()
                        optimizer.step()
                        
                # 添加
                running_loss+=loss.item()*inputs.size(0)
                running_corrects+=torch.sum(preds==labels.data)
                
            epoch_loss=running_loss/len(dataloaders[phase].dataset)
            epoch_acc=running_corrects.double()/len(dataloaders[phase].dataset)
            
            print('{}loss : {:.4f} acc:{:.4f}'.format(phase, epoch_loss,epoch_acc))
            
            if phase=='train' and epoch_acc>best_acc:
                best_acc=epoch_acc
                best_model_wts=copy.deepcopy(model.state_dict())
            if phase=='val':
                val_acc_history.append(epoch_acc)
            
        print()

    time_elapsed=time.time()-since
    print('training complete in {:.0f}s'.format(time_elapsed//60, time_elapsed%60))
    print('best val acc:{:.4f}'.format(best_acc))
    
    model.load_state_dict(best_model_wts)
    return model,val_acc_history

2、设置模型参数的’.requires_grad属性’

当我们进行特征提取时,此辅助函数将模型中参数的 .requires_grad 属性设置为False。
默认情况下,当我们加载一个预训练模型时,所有参数都是 .requires_grad = True,如果我们从头开始训练或微调,这种设置就没问题。
但是,如果我们要运行特征提取并且只想为新初始化的层计算梯度,那么我们希望所有其他参数不需要梯度变化。

#%%设置模型参数的.require——grad属性
def set_parameter_requires_grad(model,feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.require_grad=False

靓仔今天先去跑步了,再不跑来不及了,先更这么多,后续明天继续~(感谢有人没有催更!感谢监督!希望继续监督!)

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_43368987/article/details/121170137

智能推荐

@Scope("prototype")_转角向右捡爱的博客-程序员宅基地

当我们在一个ACTION类里面写很多个方法的时候(其实是一种按功能划分模块编程的思想),每个方法的返回状态可能不一样,如果ACTION中不写@Scope("prototype"),有可能报找不到XXXACTION的错误!写上这个就表示每次请求都重新创建一个ACTION,与SINGALON对应,俗称“多例”。

A ResourcePool could not acquire a resource from its primary factory or source._有上进心的阿龙的博客-程序员宅基地

D:\install\jdk18\corretto-1.8.0_302\bin\java.exe -ea -Didea.test.cyclic.buffer.size=1048576 “-javaagent:D:\install\JetBrains\IntelliJ IDEA 2021.2.2\lib\idea_rt.jar=56331:D:\install\JetBrains\IntelliJ IDEA 2021.2.2\bin” -Dfile.encoding=UTF-8 -classpath “D:\

redis集群搭建_chengtuo1866的博客-程序员宅基地

Redis集群介绍 Redis 集群是一个提供在多个Redis间节点间共享数据的程序集。 Redis集群并不支持处理多个keys的命令,因为这需要在不同的节点间移动数据,从而达不到像Redis那样的性能,在高负载的情况下可能会导致不可预料的错误. Redis 集群通过分区来提供一定程度...

找论文,下载论文_ming***的博客-程序员宅基地

找论文搜索自己学校的校名---点进去图书馆---一般会有资源导航,里边会包含很多可以选择的网站,一般选择中国知网。点进去就到了知网的页面,然后就可以查找了。下载论文进入谷歌学术镜像---选择一个镜像---输入关键词,会查询出一些论文---选择一篇,点进去,找到文章的 DOI (可理解为论文的身份证)找到了文章的 DOI 之后,就复制下来。然后打开 SCI-HUB 网站(sci-hub proxy:sci-hub links 2020 download,chrome extension2020

Linux-目录-配置文件_adouwen3320的博客-程序员宅基地

命令:mount 设置挂载点/ 根 所有目录顶点├── bin 普通用户命令,二进制命令所在的目录 which 可用来查找命令├── boot Linux 内核及系统引导程序所需文件目录├── cgroup├── dev 设备目录├── etc 系统配文件(yum,rpm)配置文件默认路径fstab:开机自动挂载设置,实现开机要挂载的文件系统命令:blkid 来...

雪花算法生成数字id_分布式主键id生成方式之雪花算法_weixin_39932479的博客-程序员宅基地

在分布式环境下,通常都会有主键id或者单号创建的需求。要求在任何情况下生成的主键id或者单号都是不能够重复的,所以我们需要一种主键或者单号生成机制。这里有一下几种方法:1、数据库自增使用mysql数据库的主键id自增2、redis自增3、使用uuid4、雪花算法雪花算法(SnowFlake),是Twitter开源的分布式id生成算法。其核心思想就是:使用时间戳+工作机器id+序列号生成一个64 b...

随便推点

Smart Card_supergame111的博客-程序员宅基地

智能卡(SmartCard),也叫IC卡,它是一个带有微处理器和存储器等微型集成电路芯片的、具有标准规格的卡片。智能卡必须遵循一套标准,ISO7816是其中最重要的一个。下面将从以下几个方面展开,对Smart Card进行讨论:1. 电气特性2. 复位应答(ATR – Answer to Reset)3. T=0 传输协议电气特性:I

在javascript中,如何判断一个被多次encode 的url 已经被decode到原来的格式?_weixin_30493401的博客-程序员宅基地

%而不能被无限次decodeURIComponent 可以用%来进行判断转载于:https://www.cnblogs.com/zhouyideboke/p/11169705.html

你连原理都还没弄明白?java的基本单位_程序员小伊的博客-程序员宅基地

内容简介:本书一共15章,核心内容为SpringBoot、SpringCloud、Docker、RabbitMQ消息组件。其中,SpringBoot是SpringMVC技术的延伸,使用它进行程序开发会更简单,服务整合也会更容易。SpringCloud是当前微架构的核心技术方案,属于SpringBoot的技术延伸,它可以整合云服务,基于RabbitMQ和 GITHUB进行微服务管理。除此以外,本书还重点分析了OAuth统一认证服务的应用。由于笔记的内容太多,没办法全部展示出来,在此只截取部分内容展示。

minion java上传文件_Spring Boot 上传文件出错:java.io.IOException: The temporary upload location..._屁乎小铭的博客-程序员宅基地

前言,新鲜报错记录一下原因1.Springboot的应用服务在启动的时候,会生成在操作系统的/tmp目录下生成一个Tomcat.*的文件目录,用于"java.io.tmpdir"文件流操作TomcatEmbeddedServletContainerFactory2.程序对文件的操作时:会生成临时文件,暂存在临时文件中;linux系统的tmpwatch 命令会删除10天未使用的临时文件;长时间不操作...

java获取上下文名称_从Spring 应用上下文获取 Bean 的常用姿势_人亲卓玛的博客-程序员宅基地

1. 前言通常,在Spring应用程序中,当我们使用 @Bean,@Service,@Controller,@Configuration 或者其它特定的注解将 Bean 注入 Spring IoC 。然后我们可以使用 Spring 框架提供的 @Autowired 或者 JSR250、JSR330 规范注解来使用由 Spring IoC 管理的 Bean 。2. 从应用程序上下文中获取 Bean...

jsp里java代码报错_JavaEE-01 JSP动态网页基础_狐狸君raphael的博客-程序员宅基地

学习要点B/S架构的基本概念Web项目的创建和运行JSP页面元素MyEclipse创建和运行Web项目Web程序调试Web简史web前端技术演进三阶段WEB 1.0:信息广播。WEB 2.0:信息交互。 微博、博客等。WEB 5.0:移动互联网。动态网页服务器端技术演进主流web程序应用平台性能比较LAMPJavaEEASP.NET运行速度较快快一般开发速度非常快慢一般运行损耗一般较小较大难易程度...