Pytorch交叉熵损失(CrossEntropyLoss)函数内部运算解析_crossentropyloss(reduction="mean")-程序员宅基地

技术标签: python  深度学习  pytorch  

  对于交叉熵损失函数的来由有很多资料可以参考,这里就不再赘述。本文主要尝试对交叉熵损失函数的内部运算做深度解析。

1. 函数介绍

  Pytorch官网中对交叉熵损失函数的介绍如下:

CLASS torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100,reduce=None, reduction=‘mean’, label_smoothing=0.0)

  该损失函数计算输入值(input)和目标值(target)之间的交叉熵损失。交叉熵损失函数可用于训练一个 C C C类别的分类问题。参数weight给定时,其为分配给每一个类别的权重的一维张量(Tensor)。当数据集分布不均衡时,这是很有用的。
  函数输入(input)应包含每一个类别的原始、非标准化分数。对于未批量化的输入,输入必须是大小为 ( C ) (C) C的张量, ( m i n i b a t c h , C ) (minibatch,C) minibatchC ( m i n i b a t c h , C , d 1 , d 2 , . . . , d K ) (minibatch,C,d_1 ,d_2 ,... ,d_K) minibatchCd1d2...dK,在K维情况下, K ≥ 1 K \geq1 K1
  函数目标值(target)有两种情况,本文只介绍其中较为有效的一种情况,即target为类索引
   本文以下内容均为target为类索引的情况。

  函数目标值(target)取值为在 [ 0 , C ) [0,C) [0C)之间的类索引, C C C为类别数。参数reduction设为'none'时,交叉熵损失可描述如下:
l ( x , y ) = L = { l 1 , . . . , l N } T , l n = − w y n l o g e x p ( x n , y n ) ∑ c = 1 C e x p ( x n , c ) ⋅ 1 { y n   / = i g n o r e _ i n d e x } (1) \large l(x,y) = L = \left \{ l_1,...,l_N \right \}^T, \\ \large l_n = -w_{yn}log\frac{exp(x_{n,y_n})}{\sum_{c=1}^{C}exp(x_{n,c})}\cdot 1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}\tag{1} l(x,y)=L={ l1,...,lN}T,ln=wynlogc=1Cexp(xn,c)exp(xn,yn)1{ yn/=ignore_index}(1)

  其中, x x x是输入, y y y是目标值, w w w是weight, C C C是类别数, N N N为batch size。在reduction不为'none'时(默认为'mean'),有:
l ( x , y ) = { ∑ n = 1 N 1 ∑ n = 1 N w y n ⋅ 1 { y n   / = i g n o r e _ i n d e x } l n , i f   r e d u c t i o n = ‘ m e a n ’ ; ∑ n = 1 N l n , i f   r e d u c t i o n = ‘ s u m ’ . (2) \large l(x,y) = \left\{\begin{matrix} \sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn} \cdot1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}}l_n, \quad if \, reduction=‘mean’; \\ \sum_{n=1}^{N}l_n, \quad if \, reduction=‘sum’ . \end{matrix}\right. \tag{2} l(x,y)=n=1Nn=1Nwyn1{ yn/=ignore_index}1ln,ifreduction=mean;n=1Nln,ifreduction=sum.(2)

 需要指出的是,在这种情况下的交叉熵损失等价于LogSoftmaxNLLLoss的组合。1

  因此,我们可以从LogSoftmaxNLLLoss来深度解析交叉熵损失函数的内部运算。

2. LogSoftmax函数

  LogSoftmax()函数2公式如下:
L o g S o f t m a x ( x i ) = l o g ( e x p ( x i ) ∑ j e x p ( x j ) ) (3) LogSoftmax(x_i) = log(\frac{exp(x_i)}{\sum_{j}exp(x_j)}) \tag{3} LogSoftmax(xi)=log(jexp(xj)exp(xi))(3)
  即,先对输入值进行Softmax归一化处理,然后对归一化值取对数。这部分对应公式(1)中的 log ⁡ e x p ( x n , y n ) ∑ c = 1 C e x p ( x n , c ) \textcolor{red}{\log\frac{exp(x_{n,y_n})}{\sum_{c=1}^{C}exp(x_{n,c})}} logc=1Cexp(xn,c)exp(xn,yn)

  代码示例如下:

>>> import torch.nn as nn
>>> SM = nn.Softmax(dim=1) #Softmax函数
>>> x = torch.tensor([[1.0,3.0,4.0],[7.0,3.0,8.0],[9.0,7.0,5.0]])
>>> x
tensor([[1., 3., 4.],
        [7., 3., 8.],
        [9., 7., 5.]])
 
>>> output_SM = SM(x) #第一步,对x进行Softmax归一化处理
>>> output_SM
#每一行元素相加之和等于1
tensor([[0.0351, 0.2595, 0.7054],
        [0.2676, 0.0049, 0.7275],
        [0.8668, 0.1173, 0.0159]]) 
>>> out_L_SM = torch.log(output_SM) #第二步,对输出取log
>>> out_L_SM
tensor([[-3.3490, -1.3490, -0.3490],
        [-1.3182, -5.3182, -0.3182],
        [-0.1429, -2.1429, -4.1429]])
        
#直接使用LogSoftmax函数,一步到位
>>> L_SM = nn.LogSoftmax(dim=1)
>>> out_L_SM_ = L_SM(x)
>>> out_L_SM_
tensor([[-3.3490, -1.3490, -0.3490],
        [-1.3182, -5.3182, -0.3182],
        [-0.1429, -2.1429, -4.1429]])

3. NLLLoss函数

  Pytorch中的NLLLoss函数3“名不副实”,虽然名为负对数似然函数,但其内部并没有进行对数计算,而只是对输入值求平均后取负(函数参数reduction为默认值'mean',参数weight为默认值'none'时)。

  官网介绍如下:

CLASS torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction=‘mean’)

  参数reduction值为'none'时:
l ( x , y ) = L = { l 1 , . . . , l N } T ,   l n = − w y n x n , y n , w c = w e i g h t [ c ] ⋅ 1 { c   / = i g n o r e _ i n d e x } , (4) \large l(x,y) = L = \left \{ l_1,...,l_N \right \}^T,\ l_n = -w_{yn}x_{n,yn}, w_c = weight[c]\cdot1\left \{ c\mathrlap{\,/}{=}ignore\_index\right \},\tag{4} l(x,y)=L={ l1,...,lN}T, ln=wynxn,yn,wc=weight[c]1{ c/=ignore_index},(4)
  其中, x x x为输入, y y y为目标值, w w w为weight, N N N为batch size。
  参数reduction值不为'none'时(默认为'mean'),有:
l ( x , y ) = { ∑ n = 1 N 1 ∑ n = 1 N w y n l n , i f   r e d u c t i o n = ‘ m e a n ’ ; ∑ n = 1 N l n , i f   r e d u c t i o n = ‘ s u m ’ . (5) \large l(x,y) = \left\{\begin{matrix} \sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn}}l_n, \quad if \, reduction=‘mean’; \\ \sum_{n=1}^{N}l_n, \quad if \, reduction=‘sum’ . \end{matrix}\right. \tag{5} l(x,y)=n=1Nn=1Nwyn1ln,ifreduction=mean;n=1Nln,ifreduction=sum.(5)
  可以看出,当reduction'mean'时,即是对 l n l_n ln求加权平均值。weight参数默认为1,因此默认情况下,即是对 l n l_n ln求平均值。又 l n = − w y n x n , y n l_n = -w_{yn}x_{n,yn} ln=wynxn,yn,所以weight为默认值1时, l n = − x n , y n l_n=-x_{n,yn} ln=xn,yn。故此时,即是 x x x求平均后取负。 这部分对于公式(2)中的 ∑ n = 1 N 1 ∑ n = 1 N w y n ⋅ 1 { y n   / = i g n o r e _ i n d e x } l n \textcolor{red}{\sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn} \cdot1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}}l_n} n=1Nn=1Nwyn1{ yn/=ignore_index}1ln

  实例代码验证如下:

>>> import torch
>>> NLLLoss = torch.nn.NLLLoss() #Pytorch负对数似然损失函数
>>> input = torch.randn(3,3)
>>>input
tensor([[1.4550, 2.3858, 1.1724],
        [0.4952, 1.5870, 0.9594],
        [1.4170, 0.4525, 0.2519]])
        
>>>target = torch.tensor([1,0,2]) #类索引目标值
>>> loss = NLLLoss(input, target)
>>> loss
tensor(-1.0443)

  平均取负有: V a l u e = − 1 3 ( 2.3858 + 0.4952 + 0.2519 ) = − 1.0443 Value = -\frac{1}{3}\left ( 2.3858+0.4952+0.2519 \right ) =-1.0443 Value=31(2.3858+0.4952+0.2519)=1.0443
  显然,平均取负结果和NLLLoss运算结果相同。

注:笔者窃以为,公式(5)中上式可写为 ∑ n = 1 N l n ∑ n = 1 N w y n \frac{\sum_{n=1}^{N}l_n}{\sum_{n=1}^{N}w_{yn}} n=1Nwynn=1Nln,如此则更容易理解。公式(2)同理。

4. 小结

  本文通过将CrossEntropyLoss拆解为LogSoftmaxNLLLoss两步,对交叉熵损失内部计算做了深度的解析,以更清晰地理解交叉熵损失函数。需要指出的是,本文所介绍的内容,只是对于CrossEntropyLoss的target为类索引的情况,CrossEntropyLoss的target还可以是每个类别的概率(Probabilities for each class),这种情况有所不同。


  学习总结,以作分享,如有不妥,敬请指出。


Reference


  1. https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss

  2. https://pytorch.org/docs/stable/generated/torch.nn.LogSoftmax.html?highlight=logsoftmax#torch.nn.LogSoftmax

  3. https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html?highlight=nllloss#torch.nn.NLLLoss

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/XIAOSHUCONG/article/details/123527257

智能推荐

图像梯度-sobel算子-程序员宅基地

文章浏览阅读1w次,点赞12次,收藏61次。(1)理论部分x 水平方向的梯度, 其实也就是右边 - 左边,有的权重为1,有的为2 。若是计算出来的值很大 说明是一个边界 。y 竖直方向的梯度,其实也就是下面减上面,权重1,或2 。若是计算出来的值很大 说明是一个边界 。图像的梯度为:有时简化为:即:(2)程序部分函数:Sobelddepth 通常取 -1,但是会导致结果溢出,检测不出边缘,故使..._sobel算子

cuda10.1和cudnn7.6.5百度网盘下载链接(Linux版)_cudnn7.6网盘下载-程序员宅基地

文章浏览阅读3.6k次,点赞17次,收藏8次。cuda10.1和cudnn7.6.5百度网盘下载链接(Linux版)在官网下载不仅慢,,,主要是还总失败。。终于下载成功了,这里给出百度网盘下载链接,希望可以帮到别人百度网盘下载链接提取码: vyg5_cudnn7.6网盘下载

Python正则表达式大全-程序员宅基地

文章浏览阅读9.3w次,点赞69次,收藏427次。定义:正则表达式是对字符串(包括普通字符(例如,a 到 z 之间的字母)和特殊字符(称为“元字符”))操作的一种逻辑公式,就是用事先定义好的一些特定字符、及这些特定字符的组合,组成一个“规则字符串”,这个“规则字符串”用来表达对字符串的一种过滤逻辑。正则表达式是一种文本模式,该模式描述在搜索文本时要匹配的一个或多个字符串。上面都是官方的说明,我自己的理解是(仅供参考):通过事先规定好一些特殊字符的匹配规则,然后利用这些字符进行组合来匹配各种复杂的字符串场景。比如现在的爬虫和数据分析,字符串校验等等都需要用_python正则表达式

Vue之条件渲染_条件渲染的基本概念-程序员宅基地

文章浏览阅读973次。条件渲染就是在指定的条件下,渲染出指定的UI。比如当我们显示主页的时候,应该隐藏掉登录等一系列不相干的UI元素。即UI元素只在特定条件下进行显示。而在VUE3中,这种UI元素的显示和隐藏可以通过两个关键字,`v-if` 和`v-show`来实现。但是虽然实现的功能一样,但他们两者有着一些细微的区别。总结起来这个区别就是:v-show控制UI元素隐藏时只是将UI的显示状态变成了不可见,实际上这个UI是存在的,但是v-if隐藏UI元素时则是直接干掉了这个UI元素,使其不显示_条件渲染的基本概念

直播+录播_直播加录播是-程序员宅基地

文章浏览阅读1.2k次。什么是直播回放?简单的说就是腾讯视频【支持将已经直播结束的节目再次播放,】方便你随时观看。目前此功能试运营阶段,最多可查看过往2小时内的节目,后续将可支持最多48小时内的节目。直播+回放+看点+预订功能,全面打通直播节目的过去、现在和未来。 直播回放和暂停功能目前已经同时在PC客户端播放器(2012 Beta2以上版本)和WEB网页的Flash播放器上线,操作简单便捷,如下图_直播加录播是

VDM Alloy 20与 Alloy 926超级不锈钢的化学成分及特性_vdm alloy 36元素含量-程序员宅基地

文章浏览阅读626次。VDM Alloy 20与 Alloy 926超级不锈钢的化学成分及特性化学成分概览材料特性VDM Alloy 20的材料特性包括:• 卓越的抗硫酸和磷酸腐蚀性能• 良好的抗晶间腐蚀能力• 出色的抗氯离子引起的应力腐蚀开裂能力• 良好的抗点蚀和缝隙腐蚀能力• 在室温以及高达500°C的高温下具有良好的机械性能VDM Alloy 926的材料特性包括:出色的抗点蚀和缝隙腐蚀能力与其他奥氏体不锈钢相比,抗应力腐蚀开裂的能力有所提高与氧化性和还原性介质接触时具有良好的耐_vdm alloy 36元素含量

随便推点

ico引入方法_arco的ico怎么导入-程序员宅基地

文章浏览阅读1.2k次。打开下面的网站后,挑选要使用的,https://icomoon.io/app/#/select/image下载后 解压 ,先把fonts里面的文件复制到项目fonts文件夹中去,然后打开其中的style.css文件找到类似下面的代码@font-face {font-family: ‘icomoon’;src: url(’…/fonts/icomoon.eot?r069d6’);s..._arco的ico怎么导入

Microsoft Visual Studio 2010(VS2010)正式版 CDKEY_visual_studio_2010_professional key-程序员宅基地

文章浏览阅读1.9k次。Microsoft Visual Studio 2010(VS2010)正式版 CDKEY / SN:YCFHQ-9DWCY-DKV88-T2TMH-G7BHP企业版、旗舰版都适用推荐直接下载电驴资源的vs旗舰版然后安装,好用方便且省时!) MSDN VS2010 Ultimate 简体中文正式旗舰版破解版下载(附序列号) visual studio 2010正_visual_studio_2010_professional key

互联网医疗的定义及架构-程序员宅基地

文章浏览阅读3.2k次,点赞2次,收藏17次。导读:互联网医疗是指综合利用大数据、云计算等信息技术使得传统医疗产业与互联网、物联网、人工智能等技术应用紧密集合,形成诊前咨询、诊中诊疗、诊后康复保健、慢性病管理、健康预防等大健康生态深度..._线上医疗的定义

计算机毕业设计 基于大数据的智能家居销量数据分析系统的设计与实现 Java实战项目 附源码+文档+视频讲解_基于大数据的智能家居销售数据分析系统 开题报告-程序员宅基地

文章浏览阅读1k次,点赞8次,收藏4次。随着科技的不断发展,智能家居系统已经成为了人们生活中不可或缺的一部分。而随着智能家居销量的不断增加,如何对这些数据进行有效的分析和利用也成为了当前亟待解决的问题。因此,本文提出了一种基于大数据的智能家居销量数据分析系统的设计与实现。该系统主要分为前台和后台两个部分,用户可以通过前台进行注册登录、查看冰箱信息、获取智能家居资讯等操作,管理员则可以通过后台进行用户管理、家电信息管理、系统管理等工作。通过对智能家居销量数据的分析,可以帮助企业更好地了解市场需求,优化产品设计和生产,提高销售效益。_基于大数据的智能家居销售数据分析系统 开题报告

异常:PKIX path building failed: sun.security.provider.certpath.SunCertPathBuilderException:-程序员宅基地

文章浏览阅读3.6w次,点赞2次,收藏19次。问题java使用httpclient或者restTemplate进行https请求时,出现如下异常:javax.net.ssl.SSLHandshakeException: sun.security.validator.ValidatorException: PKIX path building failed: sun.security.provider.certpath.SunCertP..._pkix path building failed: sun.security.provider.certpath.su

c# 窗体开发2 高级控件的使用_tooltiptext c#-程序员宅基地

文章浏览阅读794次,点赞2次,收藏9次。1.单选按钮(RadioButton)同一组中其他单选按钮不能同时选定分组形式:panel GoupBox 窗体方法: 属性 说明 Appearance RadioButton 控件的显示与命令按钮相似 Checked 确定是否已选定控件 方法 ..._tooltiptext c#

推荐文章

热门文章

相关标签