用Pytorch实现一个线性回归_pytorch实现线性回归并计算cost-程序员宅基地

技术标签: 深度学习  pytorch  线性回归  

问题描述

假设学生在期末考试中,如果他们花x个小时在一门课程上,他们将得到y分。

x (hours) y (points)
1 2
2 4
3 6
4

问题是在这门课上花费4个小时时,得到的分数是多少?


很显然这是一个回归的问题。
下面结局这个问题的方法,是机器学习训练任务的基本方法论,可以基于此构建出更为复杂的模型去处理更为复杂的任务。


代码与注释

import torch

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

class LinearModel(torch.nn.Module):	#将模型构建成一个类,并且继承自Moudle
    def __init__(self):	#构造函数
        super(LinearModel, self).__init__() #调用父类的构造,固定写法,这一步必须要有
        self.linear = torch.nn.Linear(1, 1) #Linear是一个类,类后面加括号意思是构建了一个对象,括号里面是的参数是权重和偏置
        
    def forward(self, x):	#前向传播函数
        y_pred = self.linear(x) #在一个对象的后面加括号,实现了一个可调用的对象,x送入Linear对象,执行w * x + b
        return y_pred
model = LinearModel()

# 损失函数,将向量里的损失进行求和,得到一个标量的损失值,MSELoss也继承自nn.Moudle
criterion = torch.nn.MSELoss(size_average=False)
# 优化器,不继承自Moudle,不会构建计算图,构建出的优化器知道要对哪些参数做优化,并且知道学习率
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练
for epoch in range(100):
    y_pred = model(x_data)	#先计算y_hat
    loss = criterion(y_pred, y_data) #计算损失
    print(epoch, loss) #loss是一个标量,一个对象,会自动调用__str__()函数,不会产生计算图,是安全的
    optimizer.zero_grad()	#梯度归0
    loss.backward()	#反向传播
    optimizer.step()	#step()用来做更新,根据预先设置的参数以及包含的梯度和学习率自动进行更新
    
# 输出 weight 和 bias
print('w = ', model.linear.weight.item()) #weight是一个矩阵,加上.item()让其只显示数值
print('b = ', model.linear.bias.item())

# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)

print('y_pred = ', y_test.data)

训练结果:
在这里插入图片描述
可以看到,通过不断的迭代,参数中的weight趋近于2,b趋近于0,最终的预测值趋近于8。
上面用的nn.Linear是一个线性模型,如下图所示:
在这里插入图片描述


总结

对于一个模型,都要将其构建成一个类,并且继承自Module,至少定义两个方法,一个是构造函数__init__(self),另一个是forward()。而backward会由Module里自动根据计算图计算。

对于整个训练搭建和训练的的步骤,总结为以下四步:

  1. Prepare dataset
  2. Design model using Class
  3. Construct loss and optimizer
  4. Training cycle

然后就是不断的前馈—反馈—更新----前馈—反馈—更新最后使Loss收敛。


参考资料

[1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=5

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/hshudoudou/article/details/127060601

智能推荐

each循环-程序员宅基地

文章浏览阅读1.7k次。一、each的两种写法(1)遍历元素节点 $(node).each(function(index,element) { console.log(index); console.log(element); }) (2)遍历数组,数据格式 $.e...

java截取字符串_java根据下标获取字符-程序员宅基地

文章浏览阅读1.5k次。1.根据下标截取:String str = str.substring(开始的下标,截取的长度)2.根据字符截取://截取_之前字符串:String str = str.substring(0, str.indexOf("_"));------------------------------------------------------------------//截取_之后字符串:String str = str.substring(0, str.indexOf("_"));String_java根据下标获取字符

iscsi配置_在linux操作系统上配置iscsi的服务器targetcli不小心退出了-程序员宅基地

文章浏览阅读1.7k次。一、环境配置(安装)服务器端:yum install targetcli #管理程序systemctl start target #打开服务客户端:yum install iscsi-initiator-utils #iscsi应用程序通常这个都安装过了。二、targetcli的配置服务器端: 首先要有一块需要共享的硬盘分区,这是我要共享的..._在linux操作系统上配置iscsi的服务器targetcli不小心退出了

QTreeWidget详细使用介绍_qtreewidget用法-程序员宅基地

文章浏览阅读1.1w次,点赞10次,收藏87次。QTreeWidget继承自QTreeView,是通过树形结构来展示数据结构的控件。1.QTreeWidget和QTreeView的区别QTreeView一般和相应的QXXModel合用,形成Model/View结构.QTreeView是一个视图类,你需要手动给其指定模型类,才能够显示数据。QTreeWidget继承自QTreeView,是封闭了默认Model的QTreeView,应用了模型/视图的编程方法,将数据和显示分开了。就灵活性来讲,QTreeView要灵活些。QTreewidg._qtreewidget用法

Android CameraX和SurfaceView的基本使用_camerax surfaceview-程序员宅基地

文章浏览阅读1.9k次。Jetpack CameraX 库 的 PreviewView 可以帮助您解决这一问题。通过在各种 Android 设备上提供开发者友好、一致且稳定的 API,使得展示相机的预览变得不再困难。_camerax surfaceview

链接器的奥秘:解析计算机系统中的Linker_链接器linker-程序员宅基地

文章浏览阅读310次。链接器作为计算机系统中重要的组成部分,负责将多个目标文件和库文件合并成可执行文件。它通过符号解析、地址重定位和符号重定位等步骤,完成了程序的连接和重定位工作。此外,链接器还可以进行一些优化,以提高程序的执行效率。链接器的工作原理和功能对于软件开发人员来说至关重要,在编写高效、可靠的程序时起着重要作用。希望本文能够帮助读者更好地理解和应用链接器,提升软件开发的水平。_链接器linker

随便推点

docker如何查看容器ip_docker 容器 ip-程序员宅基地

文章浏览阅读5.6k次。docker exec -it xxx sh 进入容器终端cat /etc/hosts 显示ip所在的配置文件_docker 容器 ip

Qt之按钮(QAbstractButton)-程序员宅基地

文章浏览阅读1.5k次。简述QAbstractButton是按钮控件的抽象基类,提供了按钮所共有的功能。这个类实现了一个抽象的按钮。对这个按钮进行子类化可以处理用户行为,以及指定按钮如何绘制。QAbstractButton提供了点击和勾选按钮。QRadioButton和QCheckBox类只提供了勾选按钮,QPushButton和QToolButton提供了点击按钮,如果需要的话,它们还可以提供切换行为。任何按钮...

Springboot集成Lettuce完成Redis cluster集群的key过期监控-程序员宅基地

文章浏览阅读2.7k次。经过几天各种方案的对比及实验,终于完成了Springboot集成Lettuce完成Redis cluster集群的key过期监控的代码,主要参考了如下的文章及代码:https://my.oschina.net/u/4134799/blog/3116221/printhttps://github.com/xfearless1201/api/blob/master/src/mai...

Vuex介绍&同步取值&异步问题_vuex同步异步获取数据的区别-程序员宅基地

文章浏览阅读5.5k次,点赞3次,收藏9次。前言:我们在之前就有了子类与父类之间的传参,又有利用总线进行传参,但两者都有一定的弊端。如:总线定义组件太多容易混淆等;所以接下来我们会利用VueX进行参数传。目录一:VueX简介一:VueX简介官方解释:Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式。可以想象为一个“前端数据库”(数据仓库), 让其在各个页面上实现数据的共享包括状态,并且可操作 Vuex分成五个部分: 1.State:单一状态树 2.Getters:状态获取 3..._vuex同步异步获取数据的区别

servlet返回数据的方法_servlet 返回值-程序员宅基地

文章浏览阅读2.6w次,点赞2次,收藏17次。servlet返回数据的方法方法1. RequestDispatcher.forward()界面跳转 HttpSession session =request.getSession(); Object obj = session.getAttribute (LoginConstants.LOGIN_USER); if (null !=..._servlet 返回值

golang string 去最后一个字符_golang 取字符串最后一个元素-程序员宅基地

文章浏览阅读1.8w次。package mainimport ( "fmt" "strings")func main() { fmt.Println("Hello, 世界") var s string s = "333," strings.TrimRight(s, ",") fmt.Println(s) s = strings.TrimRight(s, ",") fmt.Println(s)..._golang 取字符串最后一个元素

推荐文章

热门文章

相关标签