tf.nn.sparse_softmax_cross_entropy_with_logits 函数简介-程序员宅基地

技术标签: python  

欢迎关注微信公众号:python科技园

目录

1. 函数功能

2. 语法结构

3. 实现步骤

    第一步:计算 Softmax

    第二步:计算 Cross Entropy

4. 示例 sparse_softmax_cross_entropy_with_logits

5. 示例 softmax_cross_entropy_with_logits

6. 补充说明

 

函数讲解

1. 函数功能

多分类交叉熵计算函数。它适用于每个类别相互独立且排斥的情况,例如一幅图只能属于一类,而不能同时包含一条狗和一头大象。

 

2. 语法结构

tf.nn.sparse_softmax_cross_entropy_with_logits(labels=None, logits=None, name=None)

 

3. 实现步骤

函数如其名,具体分为两个步骤:Softmax 和 Cross Entropy。

 

第一步:计算 Softmax

在进行文本分类或图像识别等任务时,神经网络输出层的神经元个数通常就是我们要分类的类别数量。Softmax函数的作用就是将每个类别所对应的输出分量归一化,使分量的和为1,可以理解为output vector的输出分量值就是input data分类为每个类别的概率。

 

假设上面这个图中的 z_1, z_2, z_3  为一个三分类模型的output vector,为[3, 1, -3],分别代表类别1、类别2、类别3所对应的分量。经过Softmax函数作用后,将其转化为了 [0.88, 0.12, 0],代表输入的该样本被分到类别1的概率为0.88,分到类别2的概率为0.12,分到类别3的概率几乎为0。这就是 Softmax 函数的作用,Softmax 函数的表达式如下所示:

s_i = \frac{e^{v_i}}{\sum_{j=1}^{N}e^{v_j}}

 

第二步:计算 Cross Entropy

神经网络的输出层经过Softmax函数作用后,接下来就要计算loss了,多分类使用Cross Entropy作为loss function。tf.nn.sparse_softmax_cross_entropy_with_logits() 函数输入的 label 格式是一维向量。Cross Entropy 的表达式如下所示:

H_y(y') = - \sum_i y_i \log(y_i')

其中:y_i 为label中的第 i 个值,y_i' 为经 Softmax归一化输出的vector中的对应分量。当分类越准确时,y_i'  所对应的分量就会越接近于1,从而  H_y(y') 的值也就会越小。

 

PS:如果 label 是 one-hot 格式,则可以使用 tf.nn.softmax_cross_entropy_with_logits() 函数来进行Softmax和loss的计算。

 

4. 示例 sparse_softmax_cross_entropy_with_logits

import tensorflow as tf

labels_sparse = [0, 2, 1]
# 索引,即真实的类别
# 0表示第一个样本的类别属于第1类;
# 2表示第二个样本的类别属于第3类;
# 1表示第三个样本的类别属于第2类;

logits = tf.constant(value=[[3, 1, -3], [1, 4, 3], [2, 7, 5]],
          dtype=tf.float32, shape=[3, 3])


loss_sparse = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels_sparse,
    logits=logits)


with tf.compat.v1.Session() as sess:
    print("loss_sparse: \n", sess.run(loss_sparse))

结果如下所示:

loss_sparse:

[0.12910892, 1.3490121, 0.13284527]

 

根据计算步骤,使用第一个样本数据简单验证一下:

label_sparse = [0]
logit = tf.constant([3, 1, -3], dtype=tf.float64)

softmax_logit = tf.nn.softmax(logit)
softmax_cross_entropy_logit = -(tf.math.log(softmax_logit[label_sparse]))

with tf.compat.v1.Session() as sess:
    print("softmax_cross_entropy_logit: \n", sess.run(softmax_cross_entropy_logit))

结果如下所示:

softmax_cross_entropy_logit:

0.1291089088298506

可知该值和 loss_sparse 的第一个值一致。

 

5. 示例 softmax_cross_entropy_with_logits

import tensorflow as tf

labels = [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
# 索引,即真实的类别
# 0表示第一个样本的类别属于第1类;
# 2表示第二个样本的类别属于第3类;
# 1表示第三个样本的类别属于第2类;

logits = tf.constant(value=[[3, 1, -3], [1, 4, 3], [2, 7, 5]],
          dtype=tf.float32, shape=[3, 3])


loss = tf.nn.softmax_cross_entropy_with_logits(
    labels=labels,
    logits=logits)


with tf.compat.v1.Session() as sess:
    print("loss: \n", sess.run(loss))

 

结果如下所示:

loss: 

[0.12910892, 1.3490121, 0.13284527]

根据结果可知,sparse_softmax_cross_entropy_with_logits 和 softmax_cross_entropy_with_logits 两种方式的计算结果是相同的。

 

6. 补充说明

使用 tf.keras.models.Model 构建好模型后,在 compile 的时候:

(1)如果:

loss='sparse_categorical_crossentropy'

则构建的样本的label是数字编码,同sparse_softmax_cross_entropy_with_logits中的 labels_sparse = [0, 2, 1] 

 

(2)如果:

loss='categorical_crossentropy'

则构建的样本的label是one-hot编码,同 softmax_cross_entropy_with_logits中的 labels = [[1, 0, 0], [0, 0, 1], [0, 1, 0]]

 

 

参考:

1. https://tensorflow.google.cn/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits?hl=zh-cn

2. https://keras-cn-docs.readthedocs.io/zh_CN/latest/other/metrices/

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/wdh315172/article/details/106140608

智能推荐

更便捷更高效地生产影像地图瓦片_瓦片数据库生产-程序员宅基地

文章浏览阅读1k次。针对海量影像地图,采用地理处理建模构建切片业务流程,并应用“基于金字塔算法生成瓦片”的方案实施更高效的影像地图瓦片生产。_瓦片数据库生产

00942 ora 表存在_"ORA-00942: 表或视图不存在" 的原因和解决方法-程序员宅基地

文章浏览阅读1.3w次。1、问题产生的原因Oracle 是大小写敏感的,我们创自己写Sql脚本创建表的时候Oracle会自动将我们的表名,字段名转成大写。2、问题分析但是 Oracle 同样支持"" 语法,将表名或字段名加上""后,Oracle不会将其转换成大写。如果加上了"",那么我们采用一般的SQL语句查询则会产生“ORA-00942: 表或视图不存在 ”,因此SQL脚本中需要将表名也加上""。例如:select*..._ora00942表或视图不存在,但明明存在

ES2007的tomcat报错的原因_tomcat报错esservicecontraller-程序员宅基地

文章浏览阅读392次。方正ES2007,在能成功连接数据库的前提下,如果tomcat报错,那么你就可以看看你的ip地址是否仍然是自动获取ip,如果是请更改过来,使用具体的ip地址(类似192.168.1.11....),再跑一遍就Ok了。_tomcat报错esservicecontraller

es修改排序_java 如何实现ElasticSearch自定义排序-程序员宅基地

文章浏览阅读1k次。1、es版本用的是5.1由于需要使用es的script的inline功能,需要修改es yml的配置文件,增加如下配置使其支持inlinescript.inline: onscript.stored: onscript.file: onscript.engine.groovy.inline.aggs: on增加完成上述配置需要重启es 注:不同的es版本配置不同2、java代码# 定义传入scri..._es addsort

计算机主机故障有哪些,电脑硬件常见故障有哪些-程序员宅基地

文章浏览阅读1.7k次。电脑硬件常见故障有哪些计算机硬件是指计算机系统中由电子,机械和光电元件等组成的各种物理装置的总称。下面小编来给大家介绍电脑硬件常见故障,希望对大家有帮助!1、CPU温度是CPU常见的'一个问题,CPU温度过高的时候会出现电脑频繁重启的现象,而且是每次开机还未进入系统就重启了,每次重启的时间也越来越短,这个时候就很有可能是CPU温度过高。2、主板电脑有时候开机屏幕上什么也不显示,并且没有出现报错声时..._电脑主机故障的八种表现和对应的现象是什么

java 之 静态泛型方法_java static 泛型 方法-程序员宅基地

文章浏览阅读585次。java 之 静态泛型方法_java static 泛型 方法

随便推点

IJPay支付开源 让你的代码飞起来_ijpay 如何运行-程序员宅基地

文章浏览阅读631次。今天我要给大家推荐一个非常棒的支付开源项目,最近大家在不忙的时候可以学习一下,项目在GitHub有3.3k ,GITEE已有6.5k之多~特别说明:不依赖任何第三方 MVC 框架,仅仅作为工具使用简单快速完成支付模块的开发,可轻松嵌入到任何系统里。微信支付支持多商户多应用,普通商户模式与服务商商模式当然也支持境外商户、同时支持 Api-v3 与 Api-v2 版本的接口。支付宝支付支持多商户多应用,签名同时支持普通公钥方式与公钥证书方式目前封装好的SDK仅支持安卓 ,IOS还在开发中~~~~IJP_ijpay 如何运行

Vue 使用 Apache Echarts 绘制地图(省市、地区、自定义)_vue 地图-程序员宅基地

文章浏览阅读2.3w次,点赞77次,收藏260次。使用Apache Echarts绘制中国、省市级、自定义地图的方法_vue 地图

软件开发工具【十四】 之 常用建模工具_建模工具开发学习-程序员宅基地

文章浏览阅读5.4k次,点赞4次,收藏11次。感谢内容提供者:金牛区吴迪软件开发工作室接上一篇:软件开发工具【十三】 之 Eclipse插件的使用与开发文章目录一、UML建模介绍1.面向对象方法的出现和发展2.面向对象的一些概念3.面向对象方法的基本过程4.组件思想二、RATIONAL ROSE建模工具介绍1.RATIONAL 公司简介2.面向对象的分析设计和Rational Rose3.Rational Rose可视化建模的特点三、使用RATIONAL ROSE 建模1.UML建模的三大部分2.需求分析之用例图与活动图3.系统分析与设计四、E._建模工具开发学习

如何通过gdb查看反汇编代码_pwngdb 显示汇编代码-程序员宅基地

文章浏览阅读2.4w次,点赞15次,收藏77次。0x00 程序源码C代码如下:#include <stdio.h>int addme(int a, int b){ int c ; c = a+ b; return c;}int main(int argc, char const *argv[]){ int ret= 0; ret = addme(10,20); pri..._pwngdb 显示汇编代码

View的简介_view在编程中什么意思-程序员宅基地

文章浏览阅读1.6w次,点赞17次,收藏34次。认识一个新的事物,首先我们从概念上讲,我们需要知道,这个事物 是什么,这个事物有什么用途?对应到View 上,我们要搞明白 View 的定义以及工作原理。 1.View是什么? View是屏幕上的一块矩形区域,它负责用来显示一个区域,并且响应这个区域内的事件。可以说,手机屏幕上的任意一部分看的见得地方都是View,它很常见,比如 TextView 、ImageView 、Button_view在编程中什么意思

ES 内存使用和GC指标_es gc count多少算异常-程序员宅基地

文章浏览阅读3k次。摘录自:http://blog.csdn.net/yangwenbo214/article/details/74000458ES 内存使用和GC指标——主节点每30秒会去检查其他节点的状态,如果任何节点的垃圾回收时间超过30秒(Garbage collection duration),则会导致主节点任务该节点脱离集群。内存使用和GC指标在运行Elasticsearch时,内存是您..._es gc count多少算异常

推荐文章

热门文章

相关标签