Pytorch交叉熵损失(CrossEntropyLoss)函数内部运算解析_crossentropyloss(reduction="mean")-程序员宅基地

技术标签: python  深度学习  pytorch  

  对于交叉熵损失函数的来由有很多资料可以参考,这里就不再赘述。本文主要尝试对交叉熵损失函数的内部运算做深度解析。

1. 函数介绍

  Pytorch官网中对交叉熵损失函数的介绍如下:

CLASS torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100,reduce=None, reduction=‘mean’, label_smoothing=0.0)

  该损失函数计算输入值(input)和目标值(target)之间的交叉熵损失。交叉熵损失函数可用于训练一个 C C C类别的分类问题。参数weight给定时,其为分配给每一个类别的权重的一维张量(Tensor)。当数据集分布不均衡时,这是很有用的。
  函数输入(input)应包含每一个类别的原始、非标准化分数。对于未批量化的输入,输入必须是大小为 ( C ) (C) C的张量, ( m i n i b a t c h , C ) (minibatch,C) minibatchC ( m i n i b a t c h , C , d 1 , d 2 , . . . , d K ) (minibatch,C,d_1 ,d_2 ,... ,d_K) minibatchCd1d2...dK,在K维情况下, K ≥ 1 K \geq1 K1
  函数目标值(target)有两种情况,本文只介绍其中较为有效的一种情况,即target为类索引
   本文以下内容均为target为类索引的情况。

  函数目标值(target)取值为在 [ 0 , C ) [0,C) [0C)之间的类索引, C C C为类别数。参数reduction设为'none'时,交叉熵损失可描述如下:
l ( x , y ) = L = { l 1 , . . . , l N } T , l n = − w y n l o g e x p ( x n , y n ) ∑ c = 1 C e x p ( x n , c ) ⋅ 1 { y n   / = i g n o r e _ i n d e x } (1) \large l(x,y) = L = \left \{ l_1,...,l_N \right \}^T, \\ \large l_n = -w_{yn}log\frac{exp(x_{n,y_n})}{\sum_{c=1}^{C}exp(x_{n,c})}\cdot 1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}\tag{1} l(x,y)=L={ l1,...,lN}T,ln=wynlogc=1Cexp(xn,c)exp(xn,yn)1{ yn/=ignore_index}(1)

  其中, x x x是输入, y y y是目标值, w w w是weight, C C C是类别数, N N N为batch size。在reduction不为'none'时(默认为'mean'),有:
l ( x , y ) = { ∑ n = 1 N 1 ∑ n = 1 N w y n ⋅ 1 { y n   / = i g n o r e _ i n d e x } l n , i f   r e d u c t i o n = ‘ m e a n ’ ; ∑ n = 1 N l n , i f   r e d u c t i o n = ‘ s u m ’ . (2) \large l(x,y) = \left\{\begin{matrix} \sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn} \cdot1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}}l_n, \quad if \, reduction=‘mean’; \\ \sum_{n=1}^{N}l_n, \quad if \, reduction=‘sum’ . \end{matrix}\right. \tag{2} l(x,y)=n=1Nn=1Nwyn1{ yn/=ignore_index}1ln,ifreduction=mean;n=1Nln,ifreduction=sum.(2)

 需要指出的是,在这种情况下的交叉熵损失等价于LogSoftmaxNLLLoss的组合。1

  因此,我们可以从LogSoftmaxNLLLoss来深度解析交叉熵损失函数的内部运算。

2. LogSoftmax函数

  LogSoftmax()函数2公式如下:
L o g S o f t m a x ( x i ) = l o g ( e x p ( x i ) ∑ j e x p ( x j ) ) (3) LogSoftmax(x_i) = log(\frac{exp(x_i)}{\sum_{j}exp(x_j)}) \tag{3} LogSoftmax(xi)=log(jexp(xj)exp(xi))(3)
  即,先对输入值进行Softmax归一化处理,然后对归一化值取对数。这部分对应公式(1)中的 log ⁡ e x p ( x n , y n ) ∑ c = 1 C e x p ( x n , c ) \textcolor{red}{\log\frac{exp(x_{n,y_n})}{\sum_{c=1}^{C}exp(x_{n,c})}} logc=1Cexp(xn,c)exp(xn,yn)

  代码示例如下:

>>> import torch.nn as nn
>>> SM = nn.Softmax(dim=1) #Softmax函数
>>> x = torch.tensor([[1.0,3.0,4.0],[7.0,3.0,8.0],[9.0,7.0,5.0]])
>>> x
tensor([[1., 3., 4.],
        [7., 3., 8.],
        [9., 7., 5.]])
 
>>> output_SM = SM(x) #第一步,对x进行Softmax归一化处理
>>> output_SM
#每一行元素相加之和等于1
tensor([[0.0351, 0.2595, 0.7054],
        [0.2676, 0.0049, 0.7275],
        [0.8668, 0.1173, 0.0159]]) 
>>> out_L_SM = torch.log(output_SM) #第二步,对输出取log
>>> out_L_SM
tensor([[-3.3490, -1.3490, -0.3490],
        [-1.3182, -5.3182, -0.3182],
        [-0.1429, -2.1429, -4.1429]])
        
#直接使用LogSoftmax函数,一步到位
>>> L_SM = nn.LogSoftmax(dim=1)
>>> out_L_SM_ = L_SM(x)
>>> out_L_SM_
tensor([[-3.3490, -1.3490, -0.3490],
        [-1.3182, -5.3182, -0.3182],
        [-0.1429, -2.1429, -4.1429]])

3. NLLLoss函数

  Pytorch中的NLLLoss函数3“名不副实”,虽然名为负对数似然函数,但其内部并没有进行对数计算,而只是对输入值求平均后取负(函数参数reduction为默认值'mean',参数weight为默认值'none'时)。

  官网介绍如下:

CLASS torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction=‘mean’)

  参数reduction值为'none'时:
l ( x , y ) = L = { l 1 , . . . , l N } T ,   l n = − w y n x n , y n , w c = w e i g h t [ c ] ⋅ 1 { c   / = i g n o r e _ i n d e x } , (4) \large l(x,y) = L = \left \{ l_1,...,l_N \right \}^T,\ l_n = -w_{yn}x_{n,yn}, w_c = weight[c]\cdot1\left \{ c\mathrlap{\,/}{=}ignore\_index\right \},\tag{4} l(x,y)=L={ l1,...,lN}T, ln=wynxn,yn,wc=weight[c]1{ c/=ignore_index},(4)
  其中, x x x为输入, y y y为目标值, w w w为weight, N N N为batch size。
  参数reduction值不为'none'时(默认为'mean'),有:
l ( x , y ) = { ∑ n = 1 N 1 ∑ n = 1 N w y n l n , i f   r e d u c t i o n = ‘ m e a n ’ ; ∑ n = 1 N l n , i f   r e d u c t i o n = ‘ s u m ’ . (5) \large l(x,y) = \left\{\begin{matrix} \sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn}}l_n, \quad if \, reduction=‘mean’; \\ \sum_{n=1}^{N}l_n, \quad if \, reduction=‘sum’ . \end{matrix}\right. \tag{5} l(x,y)=n=1Nn=1Nwyn1ln,ifreduction=mean;n=1Nln,ifreduction=sum.(5)
  可以看出,当reduction'mean'时,即是对 l n l_n ln求加权平均值。weight参数默认为1,因此默认情况下,即是对 l n l_n ln求平均值。又 l n = − w y n x n , y n l_n = -w_{yn}x_{n,yn} ln=wynxn,yn,所以weight为默认值1时, l n = − x n , y n l_n=-x_{n,yn} ln=xn,yn。故此时,即是 x x x求平均后取负。 这部分对于公式(2)中的 ∑ n = 1 N 1 ∑ n = 1 N w y n ⋅ 1 { y n   / = i g n o r e _ i n d e x } l n \textcolor{red}{\sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn} \cdot1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}}l_n} n=1Nn=1Nwyn1{ yn/=ignore_index}1ln

  实例代码验证如下:

>>> import torch
>>> NLLLoss = torch.nn.NLLLoss() #Pytorch负对数似然损失函数
>>> input = torch.randn(3,3)
>>>input
tensor([[1.4550, 2.3858, 1.1724],
        [0.4952, 1.5870, 0.9594],
        [1.4170, 0.4525, 0.2519]])
        
>>>target = torch.tensor([1,0,2]) #类索引目标值
>>> loss = NLLLoss(input, target)
>>> loss
tensor(-1.0443)

  平均取负有: V a l u e = − 1 3 ( 2.3858 + 0.4952 + 0.2519 ) = − 1.0443 Value = -\frac{1}{3}\left ( 2.3858+0.4952+0.2519 \right ) =-1.0443 Value=31(2.3858+0.4952+0.2519)=1.0443
  显然,平均取负结果和NLLLoss运算结果相同。

注:笔者窃以为,公式(5)中上式可写为 ∑ n = 1 N l n ∑ n = 1 N w y n \frac{\sum_{n=1}^{N}l_n}{\sum_{n=1}^{N}w_{yn}} n=1Nwynn=1Nln,如此则更容易理解。公式(2)同理。

4. 小结

  本文通过将CrossEntropyLoss拆解为LogSoftmaxNLLLoss两步,对交叉熵损失内部计算做了深度的解析,以更清晰地理解交叉熵损失函数。需要指出的是,本文所介绍的内容,只是对于CrossEntropyLoss的target为类索引的情况,CrossEntropyLoss的target还可以是每个类别的概率(Probabilities for each class),这种情况有所不同。


  学习总结,以作分享,如有不妥,敬请指出。


Reference


  1. https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss

  2. https://pytorch.org/docs/stable/generated/torch.nn.LogSoftmax.html?highlight=logsoftmax#torch.nn.LogSoftmax

  3. https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html?highlight=nllloss#torch.nn.NLLLoss

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/XIAOSHUCONG/article/details/123527257

智能推荐

hive使用适用场景_大数据入门:Hive应用场景-程序员宅基地

文章浏览阅读5.8k次。在大数据的发展当中,大数据技术生态的组件,也在不断地拓展开来,而其中的Hive组件,作为Hadoop的数据仓库工具,可以实现对Hadoop集群当中的大规模数据进行相应的数据处理。今天我们的大数据入门分享,就主要来讲讲,Hive应用场景。关于Hive,首先需要明确的一点就是,Hive并非数据库,Hive所提供的数据存储、查询和分析功能,本质上来说,并非传统数据库所提供的存储、查询、分析功能。Hive..._hive应用场景

zblog采集-织梦全自动采集插件-织梦免费采集插件_zblog 网页采集插件-程序员宅基地

文章浏览阅读496次。Zblog是由Zblog开发团队开发的一款小巧而强大的基于Asp和PHP平台的开源程序,但是插件市场上的Zblog采集插件,没有一款能打的,要么就是没有SEO文章内容处理,要么就是功能单一。很少有适合SEO站长的Zblog采集。人们都知道Zblog采集接口都是对Zblog采集不熟悉的人做的,很多人采取模拟登陆的方法进行发布文章,也有很多人直接操作数据库发布文章,然而这些都或多或少的产生各种问题,发布速度慢、文章内容未经严格过滤,导致安全性问题、不能发Tag、不能自动创建分类等。但是使用Zblog采._zblog 网页采集插件

Flink学习四:提交Flink运行job_flink定时运行job-程序员宅基地

文章浏览阅读2.4k次,点赞2次,收藏2次。restUI页面提交1.1 添加上传jar包1.2 提交任务job1.3 查看提交的任务2. 命令行提交./flink-1.9.3/bin/flink run -c com.qu.wc.StreamWordCount -p 2 FlinkTutorial-1.0-SNAPSHOT.jar3. 命令行查看正在运行的job./flink-1.9.3/bin/flink list4. 命令行查看所有job./flink-1.9.3/bin/flink list --all._flink定时运行job

STM32-LED闪烁项目总结_嵌入式stm32闪烁led实验总结-程序员宅基地

文章浏览阅读1k次,点赞2次,收藏6次。这个项目是基于STM32的LED闪烁项目,主要目的是让学习者熟悉STM32的基本操作和编程方法。在这个项目中,我们将使用STM32作为控制器,通过对GPIO口的控制实现LED灯的闪烁。这个STM32 LED闪烁的项目是一个非常简单的入门项目,但它可以帮助学习者熟悉STM32的编程方法和GPIO口的使用。在这个项目中,我们通过对GPIO口的控制实现了LED灯的闪烁。LED闪烁是STM32入门课程的基础操作之一,它旨在教学生如何使用STM32开发板控制LED灯的闪烁。_嵌入式stm32闪烁led实验总结

Debezium安装部署和将服务托管到systemctl-程序员宅基地

文章浏览阅读63次。本文介绍了安装和部署Debezium的详细步骤,并演示了如何将Debezium服务托管到systemctl以进行方便的管理。本文将详细介绍如何安装和部署Debezium,并将其服务托管到systemctl。解压缩后,将得到一个名为"debezium"的目录,其中包含Debezium的二进制文件和其他必要的资源。注意替换"ExecStart"中的"/path/to/debezium"为实际的Debezium目录路径。接下来,需要下载Debezium的压缩包,并将其解压到所需的目录。

Android 控制屏幕唤醒常亮或熄灭_android实现拿起手机亮屏-程序员宅基地

文章浏览阅读4.4k次。需求:在诗词曲文项目中,诗词整篇朗读的时候,文章没有读完会因为屏幕熄灭停止朗读。要求:在文章没有朗读完毕之前屏幕常亮,读完以后屏幕常亮关闭;1.权限配置:设置电源管理的权限。

随便推点

目标检测简介-程序员宅基地

文章浏览阅读2.3k次。目标检测简介、评估标准、经典算法_目标检测

记SQL server安装后无法连接127.0.0.1解决方法_sqlserver 127 0 01 无法连接-程序员宅基地

文章浏览阅读6.3k次,点赞4次,收藏9次。实训时需要安装SQL server2008 R所以我上网上找了一个.exe 的安装包链接:https://pan.baidu.com/s/1_FkhB8XJy3Js_rFADhdtmA提取码:ztki注:解压后1.04G安装时Microsoft需下载.NET,更新安装后会自动安装如下:点击第一个傻瓜式安装,唯一注意的是在修改路径的时候如下不可修改:到安装实例的时候就可以修改啦数据..._sqlserver 127 0 01 无法连接

js 获取对象的所有key值,用来遍历_js 遍历对象的key-程序员宅基地

文章浏览阅读7.4k次。1. Object.keys(item); 获取到了key之后就可以遍历的时候直接使用这个进行遍历所有的key跟valuevar infoItem={ name:'xiaowu', age:'18',}//的出来的keys就是[name,age]var keys=Object.keys(infoItem);2. 通常用于以下实力中 <div *ngFor="let item of keys"> <div>{{item}}.._js 遍历对象的key

粒子群算法(PSO)求解路径规划_粒子群算法路径规划-程序员宅基地

文章浏览阅读2.2w次,点赞51次,收藏310次。粒子群算法求解路径规划路径规划问题描述    给定环境信息,如果该环境内有障碍物,寻求起始点到目标点的最短路径, 并且路径不能与障碍物相交,如图 1.1.1 所示。1.2 粒子群算法求解1.2.1 求解思路    粒子群优化算法(PSO),粒子群中的每一个粒子都代表一个问题的可能解, 通过粒子个体的简单行为,群体内的信息交互实现问题求解的智能性。    在路径规划中,我们将每一条路径规划为一个粒子,每个粒子群群有 n 个粒 子,即有 n 条路径,同时,每个粒子又有 m 个染色体,即中间过渡点的_粒子群算法路径规划

量化评价:稳健的业绩评价指标_rar 海龟-程序员宅基地

文章浏览阅读353次。所谓稳健的评估指标,是指在评估的过程中数据的轻微变化并不会显著的影响一个统计指标。而不稳健的评估指标则相反,在对交易系统进行回测时,参数值的轻微变化会带来不稳健指标的大幅变化。对于不稳健的评估指标,任何对数据有影响的因素都会对测试结果产生过大的影响,这很容易导致数据过拟合。_rar 海龟

IAP在ARM Cortex-M3微控制器实现原理_value line devices connectivity line devices-程序员宅基地

文章浏览阅读607次,点赞2次,收藏7次。–基于STM32F103ZET6的UART通讯实现一、什么是IAP,为什么要IAPIAP即为In Application Programming(在应用中编程),一般情况下,以STM32F10x系列芯片为主控制器的设备在出厂时就已经使用J-Link仿真器将应用代码烧录了,如果在设备使用过程中需要进行应用代码的更换、升级等操作的话,则可能需要将设备返回原厂并拆解出来再使用J-Link重新烧录代码,这就增加了很多不必要的麻烦。站在用户的角度来说,就是能让用户自己来更换设备里边的代码程序而厂家这边只需要提供给_value line devices connectivity line devices