【深度学习系列】——Fashion-MNIST数据集简介-程序员宅基地

技术标签: Fashion-MNIST  深度学习  # 深度学习  

1、数据集简介

\quad \quad 不同于MNIST手写数据集,Fashion-MNIST数据集包含了10个类别的图像,分别是:t-shirt(T恤),trouser(牛仔裤),pullover(套衫),dress(裙子),coat(外套),sandal(凉鞋),shirt(衬衫),sneaker(运动鞋),bag(包),ankle boot(短靴)。

2、获取数据集

1、torchvision包

\quad \quad torchvision包,是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。

\quad \quad torchvision主要由以下几部分构成:

1、torchvision.datasets:一些加载数据的函数及常用的数据集接口;

2、torchvision.models:包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;

3、torchvision.transforms:常用的图片转换,例如裁剪、旋转等;

4、torchvision.utils:其他的一些有用的方法


2、下载数据集

\quad \quad 下面,我们通过torchvision的torchvision.datasets包来下载这个数据集。第一次调用时会自动从网上获取数据。我们通过参数train来指定获取训练数据集或测试数据集(testing data set)。测试数据集也叫测试集(testing set),只用来评价模型的表现,并不用来训练模型。

\quad \quad 另外,我们还指定了参数transform=transform.ToTensor()使所有数据转换为Tensor,如果不进行转换则返回的是PIL图片。transform=transform.ToTensor()transform.ToTensor()将尺寸为( H ∗ W ∗ C H*W*C HWC)且数据位于(0,255)的PIL图片或者数据类型为np.uint8的Numpy数组转换为尺寸为( C ∗ H ∗ W C*H*W CHW)且数据类型为torch.float32且位于(0.0,1.0)的Tensor。其中,C为通道个数(如果图像是灰色的,则C=1;如果图像是彩色的,则C=3);H为图像高度;W为图像宽度。

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
from IPython import display
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

3、查看数据集

\quad \quad 上面的mnist_trainmnist_test都是torch.utils.datadatasets的子类,所以我们可以用len()来获取该数据集的大小,还可以用下标来获取具体的一个样本。训练集和测试机中的每一个类别的图像数分别为6000和1000。因为有10个类别,所以训练集和测试集的样本数分别为60000和10000。

#数据集大小
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))

<class ‘torchvision.datasets.mnist.FashionMNIST’>
60000 10000

\quad \quad 我们可以通过方括号[]来访问任意一个样本,下面获取第一个样本的图像和标签。

feature, label = mnist_train[0]
print(feature.shape, feature.dtype)  # Channel x Height X Width
print(label)

torch.Size([1, 28, 28]) torch.float32
9

\quad \quad 变量feature对应高和宽均为28像素的图像。每个像素的数值为0到255之间8位无符号整数(uint8)。它使用三维的NDArray存储。其中的最后一维是通道数。因为数据集中是灰度图像,所以通道数为1。为了表述简洁,我们将高和宽分别为 h h h w w w像素的图像的形状记为 h × w h \times w h×w(h,w)

4、图像可视化显示

1、首先定义一个函数,根据数值标签获取字符串标签

\quad \quad Fashion-MNIST中一共包括了10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数可以将数值标签转成相应的文本标签。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

2、定义函数,一个可以在一行里画出多张图像和对应标签的函数:

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def show_fashion_mnist(images, labels):
    display.set_matplotlib_formats('svg')#用矢量图进行展示
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

3、应用

现在,我们看一下训练数据集中前9个样本的图像内容和文本标签。

X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

在这里插入图片描述

5、读取小批量数据

\quad \quad 在以后的深度学习模型中,优化参数时经常会用到小批量梯度下降(深度学习常称为随机梯度下降)算法,因此我们需每次读取小批量数据来进行每次迭代与更新。

\quad \quad 我们将在训练数据集上训练模型,并将训练好的模型在测试数据集上评价模型的表现。前面说过,上面的mnist_train都是torch.utils.datadatasets的子类,所以我们可以将其传入torch.utils.data.DataLoader来创建一个读取小批量数据样本的DataLoader实例。

\quad \quad 在实践中,数据读取经常是训练的性能瓶颈,特别当模型较简单或者计算硬件性能较高时。PyTorch的DataLoader中一个很方便的功能是允许使用多进程来加速数据读取。这里我们通过参数num_workers来设置4个进程读取数据。

import sys
batch_size = 256#自己设定值大小
if sys.platform.startswith('win'):
    num_workers = 0  # 0表示不用额外的进程来加速读取数据
else:
    num_workers = 4

train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size, shuffle=True,num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test,batch_size=batch_size, shuffle=False,num_workers=num_workers)
                             

\quad \quad 我们将获取并读取Fashion-MNIST数据集的逻辑封装在d2lzh_pytorch.load_data_fashion_mnist函数中供后面章节调用。该函数将返回train_itertest_iter两个变量。随着本书内容的不断深入,我们会进一步改进该函数。它的完整实现将在“深度卷积神经网络(AlexNet)”一节中描述。

最后我们查看读取一遍训练数据需要的时间。

start = time.time()
for X, y in train_iter:
    continue
'%.2f sec' % (time.time() - start)

‘7.92 sec’

【完整代码】:

#导入所需包
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
from IPython import display
#1、下载数据集
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())
# 2、查看数据集大小及单个样本,方便设定参数
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))
feature, label = mnist_train[0]
print(feature.shape, feature.dtype)  # Channel x Height X Width
print(label)

#3、将数据集的数值标签转换为文本标签
# 本函数已保存在d2lzh_pytorch包中方便以后使用
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
# 4、同一行画出图片和标签,可视化
# 本函数已保存在d2lzh_pytorch包中方便以后使用
def show_fashion_mnist(images, labels):
    display.set_matplotlib_formats('svg')#用矢量图进行展示
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

# 5、查看部分样本并可视化
X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

# 小批量读取数据
batch_size = 256

if sys.platform.startswith('win'):
    num_workers = 0  # 0表示不用额外的进程来加速读取数据
else:
    num_workers = 4

train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size, shuffle=True,num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test,batch_size=batch_size, shuffle=False,num_workers=num_workers)
                             

参考资料:

动手学深度学习

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_45666566/article/details/107812603

智能推荐

学python的就业还是学c加加的就业广_学习Python就业有哪些方向?-程序员宅基地

文章浏览阅读159次。展开全部Python的就业方向主要分为五大块,分别是:发展方向一:Linux运维发展方向二:e69da5e6ba9062616964757a686964616f31333431356634Python Web网站工程师发展方向三:Python自动化测试发展方向四:数据分析发展方向五:人工智能Python具体会涉及到的职业岗位主要有:0、WEB开发Python拥有很多免费数据函数库、免费web网页模..._机算机一直学python不学c++将来能就业吗

#4.2混沌数学与混沌理论-程序员宅基地

文章浏览阅读3.9k次,点赞3次,收藏7次。一个序列试验数字,随着试验数据的增多,试验误差不是保持相对一致,而是被急剧放大,导致失去序列特征,进而进入混沌状态,这种现象被数学家进一步研究,产生混沌数学。“法国数学家昂利·庞加莱(Henri Poincare)发明了相空间概念,这是一个虚构的数学空间,表示给定动力学系统所有可能的运动。庞加莱这一大创新所带来的结果,是动力学可借助被称为吸引子(attractor)的几何形状来加以直观化。趋向于定态的系统,它具有的吸引子是一个点。趋向于周期性地重复同样行为的系统,它具有的吸引子是一个闭环。也就是说,闭环吸_混沌数学

jetcache fastjson 泛型复杂对象JSON序列 ,反序列化-程序员宅基地

文章浏览阅读268次,点赞4次,收藏3次。默认的虚拟化不能转换List 中的泛型数据类型, 似的从缓存拿取的list集合对象数据全部都转换成了JSONObject。增加JSON对象类型存储。增加JSON对象类型解析。

新手如何参加信息学竞赛NOIP,怎么入门(常见问题解答)?_如何参加noip-程序员宅基地

文章浏览阅读9.6k次,点赞12次,收藏75次。新手如何参加信息学竞赛NOIP,怎么入门(常见问题解答)?新手学信息学竞赛如何入门,知乎上有一个话题讨论,我们也引用一些比较优质的回答给各位同学和家长一些参考,结合一些常见学习问题作出总结。本篇文章摘自清北学堂noipnoi订阅号;2019NOIP夏令营报名正在进行中,可前往订阅号报名和咨询。问:高一新生如何准备信息竞赛?答1:下面七点按难度从低到高排序:1.跟着学校的步伐。关注任何动态。..._如何参加noip

APACHE-ATLAS-2.1.0简介(一)-程序员宅基地

文章浏览阅读1.8k次。ATLAS是Hadoop生态的数据治理和元数据框架,是一组可扩展的核心基础治理服务,使企业能够有效,高效地满足Hadoop生态中的合规性要求,并允许与整个企业数据生态系统集成。Apache Atlas为组织提供了开放的元数据管理和治理功能,以建立其数据资产的目录,对这些资产进行分类和治理,并为数据科学家,分析师和数据治理团队提供围绕这些数据资产的协作功能。...............

Java 解决同步数据时 Read timed out; nested exception is java.net.SocketTimeoutException: Read timed out_read timed out; nested exception is java.net.socke-程序员宅基地

文章浏览阅读1.3k次,点赞12次,收藏6次。Java 解决同步数据时 Read timed out; nested exception is java.net.SocketTimeoutException: Read timed out_read timed out; nested exception is java.net.sockettimeoutexception: read ti

随便推点

gitlab仓库完整迁移(代码,分支,提交记录)_gitlab new directory-程序员宅基地

文章浏览阅读2.7k次。背景代码仓库所在服务器因为异常断电关机,无法启动,需要进行gitlab工程代码迁移命令git clone --mirror <URL to my OLD repo location>cd <New directory where your OLD repo was cloned>git remote set-url origin <URL to my NEW repo location>git push -f origin..._gitlab new directory

完美解决丨 - [SyntaxError: invalid syntax](#SyntaxError-invalid-syntax)_invalid create index syntax, use `create index for-程序员宅基地

文章浏览阅读3.6k次,点赞76次,收藏5次。「SQL面试题库」是由不吃西红柿发起,全员免费参与的SQL学习活动。我每天发布1道SQL面试真题,从简单到困难,涵盖所有SQL知识点,我敢保证只要做完这100道题,不仅能轻松搞定面试,代码能力和工作效率也会有明显提升。_invalid create index syntax, use `create index for ...` instead. (line 1, co

[网络安全自学篇] 三十五.恶意代码攻击检测及恶意样本分析_如何在数据包中分析那些是带有恶意攻击的和正常语句-程序员宅基地

文章浏览阅读1.7w次,点赞14次,收藏78次。本文主要结合作者的《系统安全前沿》作业,相关论文及绿盟李东宏老师的博客,从产业界和学术界分别详细讲解恶意代码攻击溯源的相关知识。在学术界方面,用类似于综述来介绍攻击追踪溯源的不同方法;在产业界方面,主要参考李东宏老师从企业恶意样本分析的角度介绍溯源工作。关于攻击溯源的博客和论文都比较少,希望这篇文章对您有所帮助,如果文章中存在错误或理解不到位的地方,还请告知作者与海涵~_如何在数据包中分析那些是带有恶意攻击的和正常语句

如何带好一个团队?团队管理的要点有哪些?_如何管理好一个团队-程序员宅基地

文章浏览阅读2k次。以目标为基准,以结果为导向,一个目标明确的团队,项目成员的个人目标也会更加明确,从而发挥最大的效率。合理运用自己的权限是管理者必修的一门功课,因为你的一个决策会影响到员工的工作态度、发展等要素,所以一定要深思熟虑,对团队和团队中的每一个人负责。在项目管理中,使用项目管理软件可以实现全面,可视化管理,在软件上制定计划,统一分配任务,不仅能够提高效率,还能够实现更好的管理。作为管理者,要学会适度的放权,有的放矢,管理者手中的权限是为了团队更好的发展,肩负着重要的责任。实现有效沟通,避免信息孤岛的出现。_如何管理好一个团队

暗影精灵9休眠时间间歇性风扇转动解决方法_暗影精灵9睡眠风扇突然转-程序员宅基地

文章浏览阅读1.3k次。HP最近一次更新后本人的暗影精灵9在休眠的时候风扇经常性地突然高速转动,带来了不小的噪音困扰,查阅惠普社区后先是安装了最新的BIOS文件F.11版本,仍不能解决问题。目前方法:惠普管家上找到之前版本的BIOS文件(本人使用F.08)并下载,按向导安装并重启电脑,即可解决问题。即回退BIOS版本文件。这两天电脑不再出现休眠后风扇时不时地疯狂转动的现象,应该是个有效的方法。_暗影精灵9睡眠风扇突然转

常用Web漏洞扫描工具汇总(持续更新中)_网站扫描工具-程序员宅基地

文章浏览阅读1.1k次。常用Web漏洞扫描工具汇总_网站扫描工具

推荐文章

热门文章

相关标签