PyTorch笔记——实现线性回归完整代码&手动或自动计算梯度代码对比_pytorch线性回归代码-程序员宅基地

技术标签: python  AI编程  机器学习  pytorch  人工智能  线性回归  

参考资料:《深度学习框架PyTorch:入门与实践》
本文对此书中线性回归部分的代码进行注释解读,并补充手动反向传播过程中求解梯度的公式。

一、生成数据集完整代码

采用“假数据”:

# 设置随机数种子,保证在不同计算机上运行时下面的输出一致
t.manual_seed(1000)

def get_fake_data(batch_size=8):
    ''' 产生随机数据:y=x*2+3,加上了一些噪声 '''
    x = t.rand(batch_size, 1) * 20
    y = x * 2 + (1 + t.randn(batch_size, 1) * 3)
    return x, y

☆二、线性回归完整代码

自动计算梯度的代码在注释中:

# 如果括号内填一个1,则报错:mat2 must be a matrix
w = t.rand(1, 1)
b = t.zeros(1, 1)

# # 如果自动计算梯度
# # 注意requires_grad默认是False,不设置为True会在loss.backward()报错
# w = t.rand(1, 1, requires_grad=True)
# b = t.zeros(1, 1, requires_grad=True)

lr = 0.0001
losses = np.zeros(500)

for ii in range(500):
    x, y = get_fake_data(batch_size=32)
    
    # 前向传播,计算loss,采用均方误差
    # torch.mul是逐元素相乘;torch.mm是矩阵相乘
    y_pred = t.mm(x, w) + b.expand_as(y)
    loss = 0.5 * (y_pred - y) ** 2
    loss = loss.sum()
    losses[ii] = loss.item()
    
    # 反向传播,手动计算梯度
    dloss = 1
    dy_pred = dloss * (y_pred - y)
    dw = t.mm(x.t(), dy_pred)
    db = dy_pred.sum() # 注意b是标量,使用的时候扩展为元素全为b的向量
    
    # 更新参数
    w.sub_(lr * dw)
    b.sub_(lr * db)
    
#     # 如果自动计算梯度
#     loss.backward()
#     w.data.sub_(lr * w.grad.data)
#     b.data.sub_(lr * b.grad.data)
#     # 注意梯度清零
#     w.grad.data.zero_()
#     b.grad.data.zero_()
    
    # 每1000次训练画一次图
    if ii % 50 == 0:
        display.clear_output(wait=True)
        # predicted
        x = t.arange(0, 20).view(-1, 1).float()
        y = t.mm(x, w) + b.expand_as(x)
        plt.plot(x.numpy(), y.numpy())
        
        # true data
        x2, y2 = get_fake_data(batch_size=20)
        plt.scatter(x2.numpy(), y2.numpy())
        
        plt.xlim(0, 5)
        plt.ylim(0, 13)
        plt.show()
        plt.pause(0.5)
        
print(w.item(), b.item())
# print(w.data[0][0], b.data[0][0]) # 和上面等价

运行结果:
在这里插入图片描述

观察loss的变化:

plt.plot(losses)
plt.ylim(50, 500)

输出结果:
在这里插入图片描述
loss是在稳步变小。

三、手动计算梯度的公式

选自《深度学习》(花书):
在这里插入图片描述
记住上述公式 G B T GB^T GBT或者 A T G A^TG ATG即可,根据这个公式,根据这个公式,loss对w的梯度为 x T d y _ p r e d x^Tdy\_pred xTdy_pred.

四、关于输出为“nan nan”的情况

print(w.item(), b.item()),如果最后w和b的值输出都为nan,那么调小学习率就行了。我将学习率定为0.001都会遇到这个情况,定为0.0001就好了。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/umbrellalalalala/article/details/119945805

智能推荐

【OpenCV 例程200篇】64. 图像锐化——Sobel 算子_opencv sobel算子进行图像锐化-程序员宅基地

文章浏览阅读9.8k次,点赞8次,收藏48次。图像锐化的目的是增强图像的灰度跳变部分,使模糊的图像变得清晰。图像锐化也称为高通滤波,通过和增强高频,衰减和抑制低频。图像锐化常用于电子印刷、医学成像和工业检测。Sobel 算子是一种离散的微分算子,是高斯平滑和微分求导的联合运算,抗噪声能力强。Sobel 梯度算子很容易通过卷积操作 cv.filter2D 实现,OpenCV 也提供了函数 cv.Sobel 实现 Sobel 梯度算子。_opencv sobel算子进行图像锐化

【Window系统】安装FFmpeg教程_windows安装ffmpeg-程序员宅基地

文章浏览阅读4.1k次,点赞4次,收藏10次。到这里ffmpeg的配置就完成了。我们调用命令行(windows+R输入cmd)输入“ffmpeg –version”,如果出现以下结果则说明配置成功。记得点下方的确定,再关闭当前窗口再点确定,这样才能保存,千万记得不能点击取消。选择新建,把刚刚复制的bin路径粘贴进去,点击确定。_windows安装ffmpeg

The basics of swift-程序员宅基地

文章浏览阅读74次。原文出自:标哥的技术博客前言Swift是iOS、OS X和WatchOS平台新的开发语言。尽管如此,Swift有很多是与我们使用过的C和Objective-C开发经验是很像的。Swift提供了自己版本的C和Objective-C基础数据类型,包括整型Int、浮点型Double和Float、Boolean值Bool...

docker:如何将本地文件复制到docker容器内_docker拷贝文件到容器-程序员宅基地

文章浏览阅读4.1w次,点赞17次,收藏80次。如何将本地文件复制到docker容器内我们通过docker cp指令来将容器外文件传递到docker容器内1、查看容器IDdocker ps -a2、将本地文件复制到docker容器中docker cp 本地文件路径 容器ID/容器NAME:容器内路径举例:docker cp /Users/wuhanxue/Downloads/rabbitmq_delayed_message_exchange-3.9.0.ez 1faca6a70742:/opt/rabbitmq/plugins或者_docker拷贝文件到容器

网络工程师实战系统【NAT专题】-夏杰-专题视频课程-程序员宅基地

文章浏览阅读464次。通俗易懂讲解NAT技术。_网络工程师考试 夏杰 新浪

ROM开发-程序员宅基地

文章浏览阅读1.8k次。作者:X神之怒 链接:https://www.zhihu.com/question/20076944/answer/381539565 来源:知乎 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。1、Android系统是什么?Android是Google公司于2007年发布的基于Linux的移动终端系统平台。之所以说是移动终端,是因为现如今手机、MID、Tablet等..._rom开发

随便推点

NachOS线程ID的实现、最大线程数的实现和优先级的添加_nachos线程调度调度时,线程的产生和调度须同时进行,并且要构建它们的线程家族树。-程序员宅基地

文章浏览阅读3.4k次,点赞13次,收藏48次。NachOS线程的描述和优先级1.实验目的(1)通过阅读相关源码,掌握NachOS运行原理和编译方法;(2)完善NachOS下线程描述的内容。2.实验内容(1)为NachOS线程添加线程ID,并设置系统最大线程数;(2)为NachOS线程调度添加优先级,为实现基于优先级的调度做准备。3.实验方法(实验步骤)(1)理解NachOS线程的运行与调度原理,找到需要修改的代码(注:以下所有修改代码的部分,均是由vim修改完成);(2)对thread.h进行修改:在头文件处定义线程最大数MAX_SI_nachos线程调度调度时,线程的产生和调度须同时进行,并且要构建它们的线程家族树。

Vue - 关闭项目 ESlint 校验(非 Vscode 插件)_非vscode eslint-程序员宅基地

文章浏览阅读2.1k次。如果您最初创建项目时(或别人的项目)带有ESlint代码规范校验,本文为您带来如何一行代码进行关闭。_非vscode eslint

15.mvc和分页_mvc用vue分页-程序员宅基地

文章浏览阅读644次,点赞2次,收藏5次。MVC和分页第一节 MVC模式简介1.1 MVC概念​ 首先我们需要知道MVC模式并不是javaweb项目中独有的,MVC是一种软件工程中的一种设计模式,把软件系统分为三个基本部分:模型(Model)、视图(View)和控制器(Controller),即为MVC。它是一种软件设计的典范,最早为Trygve Reenskaug提出,为施乐帕罗奥多研究中心(Xerox PARC)的Sma..._mvc用vue分页

CentOS 7 安装 Hive_centos7.5安装hive-程序员宅基地

文章浏览阅读1.4k次。操作系统:CentOS 7Hive版本:2.3.6JDK版本:1.8Mysql版本:5.7安装前准备保证 hadoop 正常运行保证 Mysql 正常运行确保JDK 正常安装yum install java-1.8.0-openjdk创建hive数据库并为其授权在msyql数据库中创建hive的元数据库create database hive;..._centos7.5安装hive

eclipse 下载并配置maven_eclipse下载weavn-程序员宅基地

文章浏览阅读2.1k次,点赞5次,收藏21次。原文链接我的个人博客maven的下载到官网:http://maven.apache.org/download.cgi 请选择最新的版本下载解压后,再新建一个仓库目录。如下图配置相应的环境变量右键“计算机”,选择“属性”,之后点击“高级系统设置”,点击“环境变量”,来设置环境变量,有以下系统变量需要配置:新建系统变量 MAVEN_HOME 变量值:C:\Program Fil..._eclipse下载weavn

进出口流程 & 报关单据-程序员宅基地

文章浏览阅读892次。出口流程一. 委托人1. 需找货运代理公司2. 向代理公司询问价格 一般为 ALL IN 价格( 空运费+燃油费+战险费 ) 总费用 = ALL IN 价格 * ( 货物公斤数 ) ALL IN 价格等级: M (最低收费)空运货物最低收费,一般不足10KGS的货物价格。 N+ (低于45KGS且大于10KGS的货物)价格; 45+ (超..._海空联运如何报关

推荐文章

热门文章

相关标签