Pytorch_学习_cnn_多尺度 cnn pytorch-程序员宅基地

技术标签: pytorch  

直接代码:

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt


# Hyper parameters
torch.manual_seed(1)
EPOCH = 1  #训练一批次
BATCH_SICE = 50  # 一次50个
LR = 0.001  # 每次的学习率
DOWNLOAD_MNSIT = False  # 需要下载为True, 下载好了可以设置为False

# download data
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNSIT
)

test_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=False
)

# train_loader
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SICE,
    shuffle=True
)

# ready_for_test 前20000个
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.Tensor)[:2000]/255.
test_y = test_data.test_labels[:2000]
# print(test_x.size())
# print(test_y.size())
# print(test_x[1])
# print(test_y[1])

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1  = nn.Sequential(     # (1, 28, 28)
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2
            ),             # output shape : N = (W - F + 2P)/S + 1
            # W :图片大小 W*W 28, F:filter大小 F*F,5, S:步长, 1  P:padding 2
            # N = (28 - 5 + 2*2)/1 + 1 = 28 因为有16个filter
            # (16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
            # Maxpooling 2*2 向下采样取最大(想想如果4*4图片,采用2*2maxpooling,就变成了2*2图片)
            # 所以这里图片大小为 (16, 14, 14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),  # (32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2)  # (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # 展平多维的卷积图成 (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output

cnn = CNN()
print(cnn)



# training
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()


# training and testing
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):   # 分配 batch data, normalize x when iterate train_loader
        output = cnn(b_x)               # cnn output
        loss = loss_func(output, b_y)   # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

        if step % 50 == 0:
            print('train loss = ', loss.data.numpy())

 

 

结果

CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (out): Linear(in_features=1568, out_features=10, bias=True)
)
train loss =  2.3104544
train loss =  0.61840093
train loss =  0.12703253
train loss =  0.23725168
train loss =  0.4044273
train loss =  0.08451731
train loss =  0.19528298
train loss =  0.109065436
train loss =  0.123532385
train loss =  0.06730688
train loss =  0.2240683
train loss =  0.2114728
train loss =  0.024014007
train loss =  0.08469809
train loss =  0.21586336
train loss =  0.10181876
train loss =  0.043114547
train loss =  0.09106462
train loss =  0.055737924
train loss =  0.10089029
train loss =  0.032855053
train loss =  0.021929108
train loss =  0.025414439
train loss =  0.117736794

Process finished with exit code 0

 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/sinat_15355869/article/details/86552617

智能推荐

class和struct的区别-程序员宅基地

文章浏览阅读101次。4.class可以有⽆参的构造函数,struct不可以,必须是有参的构造函数,⽽且在有参的构造函数必须初始。2.Struct适⽤于作为经常使⽤的⼀些数据组合成的新类型,表示诸如点、矩形等主要⽤来存储数据的轻量。1.Class⽐较适合⼤的和复杂的数据,表现抽象和多级别的对象层次时。2.class允许继承、被继承,struct不允许,只能继承接⼝。3.Struct有性能优势,Class有⾯向对象的扩展优势。3.class可以初始化变量,struct不可以。1.class是引⽤类型,struct是值类型。

android使用json后闪退,应用闪退问题:从json信息的解析开始就会闪退-程序员宅基地

文章浏览阅读586次。想实现的功能是点击顶部按钮之后按关键字进行搜索,已经可以从服务器收到反馈的json信息,但从json信息的解析开始就会闪退,加载listview也不知道行不行public abstract class loadlistview{public ListView plv;public String js;public int listlength;public int listvisit;public..._rton转json为什么会闪退

如何使用wordnet词典,得到英文句子的同义句_get_synonyms wordnet-程序员宅基地

文章浏览阅读219次。如何使用wordnet词典,得到英文句子的同义句_get_synonyms wordnet

系统项目报表导出功能开发_积木报表 多线程-程序员宅基地

文章浏览阅读521次。系统项目报表导出 导出任务队列表 + 定时扫描 + 多线程_积木报表 多线程

ajax 如何从服务器上获取数据?_ajax 获取http数据-程序员宅基地

文章浏览阅读1.1k次,点赞9次,收藏9次。使用AJAX技术的好处之一是它能够提供更好的用户体验,因为它允许在不重新加载整个页面的情况下更新网页的某一部分。另外,AJAX还使得开发人员能够创建更复杂、更动态的Web应用程序,因为它们可以在后台与服务器进行通信,而不需要打断用户的浏览体验。在Web开发中,AJAX(Asynchronous JavaScript and XML)是一种常用的技术,用于在不重新加载整个页面的情况下,从服务器获取数据并更新网页的某一部分。使用AJAX,你可以创建异步请求,从而提供更快的响应和更好的用户体验。_ajax 获取http数据

Linux图形终端与字符终端-程序员宅基地

文章浏览阅读2.8k次。登录退出、修改密码、关机重启_字符终端

随便推点

Python与Arduino绘制超声波雷达扫描_超声波扫描建模 python库-程序员宅基地

文章浏览阅读3.8k次,点赞3次,收藏51次。前段时间看到一位发烧友制作的超声波雷达扫描神器,用到了Arduino和Processing,可惜啊,我不会Processing更看不懂人家的程序,咋办呢?嘿嘿,所以我就换了个思路解决,因为我会一点Python啊,那就动手吧!在做这个案例之前先要搞明白一个问题:怎么将Arduino通过超声波检测到的距离反馈到Python端?这个嘛,我首先想到了串行通信接口。没错!就是串口。只要Arduino将数据发送给COM口,然后Python能从COM口读取到这个数据就可以啦!我先写了一个测试程序试了一下,OK!搞定_超声波扫描建模 python库

凯撒加密方法介绍及实例说明-程序员宅基地

文章浏览阅读4.2k次。端—端加密指信息由发送端自动加密,并且由TCP/IP进行数据包封装,然后作为不可阅读和不可识别的数据穿过互联网,当这些信息到达目的地,将被自动重组、解密,而成为可读的数据。不可逆加密算法的特征是加密过程中不需要使用密钥,输入明文后由系统直接经过加密算法处理成密文,这种加密后的数据是无法被解密的,只有重新输入明文,并再次经过同样不可逆的加密算法处理,得到相同的加密密文并被系统重新识别后,才能真正解密。2.使用时,加密者查找明文字母表中需要加密的消息中的每一个字母所在位置,并且写下密文字母表中对应的字母。_凯撒加密

工控协议--cip--协议解析基本记录_cip协议embedded_service_error-程序员宅基地

文章浏览阅读5.7k次。CIP报文解析常用到的几个字段:普通类型服务类型:[0x00], CIP对象:[0x02 Message Router], ioi segments:[XX]PCCC(带cmd和func)服务类型:[0x00], CIP对象:[0x02 Message Router], cmd:[0x101], fnc:[0x101]..._cip协议embedded_service_error

如何在vs2019及以后版本(如vs2022)上添加 添加ActiveX控件中的MFC类_vs添加mfc库-程序员宅基地

文章浏览阅读2.4k次,点赞9次,收藏13次。有时候我们在MFC项目开发过程中,需要用到一些微软已经提供的功能,如VC++使用EXCEL功能,这时候我们就能直接通过VS2019到如EXCEL.EXE方式,生成对应的OLE头文件,然后直接使用功能,那么,我们上篇文章中介绍了vs2017及以前的版本如何来添加。但由于微软某些方面考虑,这种方式已被放弃。从上图中可以看出,这一功能,在从vs2017版本15.9开始,后续版本已经删除了此功能。那么我们如果仍需要此功能,我们如何在新版本中添加呢。_vs添加mfc库

frame_size (1536) was not respected for a non-last frame_frame_size (1024) was not respected for a non-last-程序员宅基地

文章浏览阅读785次。用ac3编码,执行编码函数时报错入如下:[ac3 @ 0x7fed7800f200] frame_size (1536) was not respected for anon-last frame (avcodec_encode_audio2)用ac3编码时每次送入编码器的音频采样数应该是1536个采样,不然就会报上述错误。这个数字并非刻意固定,而是跟ac3内部的编码算法原理相关。全网找不到,国内音视频之路还有很长的路,音视频人一起加油吧~......_frame_size (1024) was not respected for a non-last frame

Android移动应用开发入门_在安卓移动应用开发中要在活动类文件中声迷你一个复选框变量-程序员宅基地

文章浏览阅读230次,点赞2次,收藏2次。创建Android应用程序一个项目里面可以有很多模块,而每一个模块就对应了一个应用程序。项目结构介绍_在安卓移动应用开发中要在活动类文件中声迷你一个复选框变量

推荐文章

热门文章

相关标签