pytorch学习笔记(三十):RNN反向传播计算图公式推导_rnn中的计算图的求导-程序员宅基地

技术标签: 深度学习  人工智能  pytorch  神经网络  # pytorch  

前言

本节将介绍循环神经网络中梯度的计算和存储方法,即 通过时间反向传播(back-propagation through time)

正向传播在循环神经网络中比较直观,而通过时间反向传播其实是反向传播在循环神经网络中的具体应用。我们需要将循环神经网络按时间步展开,从而得到模型变量和参数之间的依赖关系,并依据链式法则应用反向传播计算并存储梯度。

1. 定义模型

简单起见,我们考虑一个无偏差项的循环神经网络,且激活函数为恒等映射( ϕ ( x ) = x \phi(x)=x ϕ(x)=x)。设时间步 t t t 的输入为单样本 x t ∈ R d \boldsymbol{x}_t \in \mathbb{R}^d xtRd,标签为 y t y_t yt,那么隐藏状态 h t ∈ R h \boldsymbol{h}_t \in \mathbb{R}^h htRh的计算表达式为

h t = W h x x t + W h h h t − 1 , \boldsymbol{h}_t = \boldsymbol{W}_{hx} \boldsymbol{x}_t + \boldsymbol{W}_{hh} \boldsymbol{h}_{t-1}, ht=Whxxt+Whhht1,

其中 W h x ∈ R h × d \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d} WhxRh×d W h h ∈ R h × h \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} WhhRh×h是隐藏层权重参数。设输出层权重参数 W q h ∈ R q × h \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h} WqhRq×h,时间步 t t t的输出层变量 o t ∈ R q \boldsymbol{o}_t \in \mathbb{R}^q otRq计算为

o t = W q h h t . \boldsymbol{o}_t = \boldsymbol{W}_{qh} \boldsymbol{h}_{t}. ot=Wqhht.

设时间步 t t t的损失为 ℓ ( o t , y t ) \ell(\boldsymbol{o}_t, y_t) (ot,yt)。时间步数为 T T T的损失函数 L L L定义为

L = 1 T ∑ t = 1 T ℓ ( o t , y t ) . L = \frac{1}{T} \sum_{t=1}^T \ell (\boldsymbol{o}_t, y_t). L=T1t=1T(ot,yt).

我们将 L L L称为有关给定时间步的数据样本的目标函数,并在本节后续讨论中简称为目标函数。

2. 模型计算图

为了可视化循环神经网络中模型变量和参数在计算中的依赖关系,我们可以绘制模型计算图,如图6.3所示。例如,时间步3的隐藏状态 h 3 \boldsymbol{h}_3 h3的计算依赖模型参数 W h x \boldsymbol{W}_{hx} Whx W h h \boldsymbol{W}_{hh} Whh、上一时间步隐藏状态 h 2 \boldsymbol{h}_2 h2以及当前时间步输入 x 3 \boldsymbol{x}_3 x3

在这里插入图片描述

3. 方法

刚刚提到,图6.3中的模型的参数是 W h x \boldsymbol{W}_{hx} Whx, W h h \boldsymbol{W}_{hh} Whh W q h \boldsymbol{W}_{qh} Wqh。与3.14节(正向传播、反向传播和计算图)中的类似,训练模型通常需要模型参数的梯度 ∂ L / ∂ W h x \partial L/\partial \boldsymbol{W}_{hx} L/Whx ∂ L / ∂ W h h \partial L/\partial \boldsymbol{W}_{hh} L/Whh ∂ L / ∂ W q h \partial L/\partial \boldsymbol{W}_{qh} L/Wqh
根据图6.3中的依赖关系,我们可以按照其中箭头所指的反方向依次计算并存储梯度。为了表述方便,我们采用运算符prod表达链式法则。

首先,目标函数有关各时间步输出层变量的梯度 ∂ L / ∂ o t ∈ R q \partial L/\partial \boldsymbol{o}_t \in \mathbb{R}^q L/otRq很容易计算:

∂ L ∂ o t = ∂ ℓ ( o t , y t ) T ⋅ ∂ o t . \frac{\partial L}{\partial \boldsymbol{o}_t} = \frac{\partial \ell (\boldsymbol{o}_t, y_t)}{T \cdot \partial \boldsymbol{o}_t}. otL=Tot(ot,yt).

下面,我们可以计算目标函数有关模型参数 W q h \boldsymbol{W}_{qh} Wqh的梯度 ∂ L / ∂ W q h ∈ R q × h \partial L/\partial \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h} L/WqhRq×h。根据图6.3, L L L通过 o 1 , … , o T \boldsymbol{o}_1, \ldots, \boldsymbol{o}_T o1,,oT依赖 W q h \boldsymbol{W}_{qh} Wqh。依据链式法则,

∂ L ∂ W q h = ∑ t = 1 T prod ( ∂ L ∂ o t , ∂ o t ∂ W q h ) = ∑ t = 1 T ∂ L ∂ o t h t ⊤ . \frac{\partial L}{\partial \boldsymbol{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{o}_t} \boldsymbol{h}_t^\top. WqhL=t=1Tprod(otL,Wqhot)=t=1TotLht.

其次,我们注意到隐藏状态之间也存在依赖关系。
在图6.3中, L L L只通过 o T \boldsymbol{o}_T oT依赖最终时间步 T T T的隐藏状态 h T \boldsymbol{h}_T hT。因此,我们先计算目标函数有关最终时间步隐藏状态的梯度 ∂ L / ∂ h T ∈ R h \partial L/\partial \boldsymbol{h}_T \in \mathbb{R}^h L/hTRh。依据链式法则,我们得到

∂ L ∂ h T = prod ( ∂ L ∂ o T , ∂ o T ∂ h T ) = W q h ⊤ ∂ L ∂ o T . \frac{\partial L}{\partial \boldsymbol{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_T}, \frac{\partial \boldsymbol{o}_T}{\partial \boldsymbol{h}_T} \right) = \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_T}. hTL=prod(oTL,hToT)=WqhoTL.

接下来对于时间步 t < T t < T t<T, 在图6.3中, L L L通过 h t + 1 \boldsymbol{h}_{t+1} ht+1 o t \boldsymbol{o}_t ot依赖 h t \boldsymbol{h}_t ht。依据链式法则,
目标函数有关时间步 t < T t < T t<T的隐藏状态的梯度 ∂ L / ∂ h t ∈ R h \partial L/\partial \boldsymbol{h}_t \in \mathbb{R}^h L/htRh需要按照时间步从大到小依次计算:
∂ L ∂ h t = prod ( ∂ L ∂ h t + 1 , ∂ h t + 1 ∂ h t ) + prod ( ∂ L ∂ o t , ∂ o t ∂ h t ) = W h h ⊤ ∂ L ∂ h t + 1 + W q h ⊤ ∂ L ∂ o t \frac{\partial L}{\partial \boldsymbol{h}_t} = \text{prod} (\frac{\partial L}{\partial \boldsymbol{h}_{t+1}}, \frac{\partial \boldsymbol{h}_{t+1}}{\partial \boldsymbol{h}_t}) + \text{prod} (\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{h}_t} ) = \boldsymbol{W}_{hh}^\top \frac{\partial L}{\partial \boldsymbol{h}_{t+1}} + \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_t} htL=prod(ht+1L,htht+1)+prod(otL,htot)=Whhht+1L+WqhotL

将上面的递归公式展开,对任意时间步 1 ≤ t ≤ T 1 \leq t \leq T 1tT,我们可以得到目标函数有关隐藏状态梯度的通项公式

∂ L ∂ h t = ∑ i = t T ( W h h ⊤ ) T − i W q h ⊤ ∂ L ∂ o T + t − i . \frac{\partial L}{\partial \boldsymbol{h}_t} = \sum_{i=t}^T {\left(\boldsymbol{W}_{hh}^\top\right)}^{T-i} \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_{T+t-i}}. htL=i=tT(Whh)TiWqhoT+tiL.

由上式中的指数项可见,当时间步数 T T T 较大或者时间步 t t t 较小时,目标函数有关隐藏状态的梯度较容易出现 衰减爆炸。这也会影响其他包含 ∂ L / ∂ h t \partial L / \partial \boldsymbol{h}_t L/ht项的梯度,例如隐藏层中模型参数的梯度 ∂ L / ∂ W h x ∈ R h × d \partial L / \partial \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d} L/WhxRh×d ∂ L / ∂ W h h ∈ R h × h \partial L / \partial \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} L/WhhRh×h
在图6.3中, L L L通过 h 1 , … , h T \boldsymbol{h}_1, \ldots, \boldsymbol{h}_T h1,,hT依赖这些模型参数。
依据链式法则,我们有

∂ L ∂ W h x = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h x ) = ∑ t = 1 T ∂ L ∂ h t x t ⊤ , ∂ L ∂ W h h = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h h ) = ∑ t = 1 T ∂ L ∂ h t h t − 1 ⊤ . \begin{aligned} \frac{\partial L}{\partial \boldsymbol{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{x}_t^\top,\\ \frac{\partial L}{\partial \boldsymbol{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{h}_{t-1}^\top. \end{aligned} WhxLWhhL=t=1Tprod(htL,Whxht)=t=1ThtLxt,=t=1Tprod(htL,Whhht)=t=1ThtLht1.

每次迭代中,我们在依次计算完以上各个梯度后,会将它们存储起来,从而避免重复计算。例如,由于隐藏状态梯度 ∂ L / ∂ h t \partial L/\partial \boldsymbol{h}_t L/ht被计算和存储,之后的模型参数梯度 ∂ L / ∂ W h x \partial L/\partial \boldsymbol{W}_{hx} L/Whx ∂ L / ∂ W h h \partial L/\partial \boldsymbol{W}_{hh} L/Whh的计算可以直接读取 ∂ L / ∂ h t \partial L/\partial \boldsymbol{h}_t L/ht的值,而无须重复计算它们。此外,反向传播中的梯度计算可能会依赖变量的当前值。它们正是通过正向传播计算出来的。
举例来说,参数梯度 ∂ L / ∂ W h h \partial L/\partial \boldsymbol{W}_{hh} L/Whh的计算需要依赖隐藏状态在时间步 t = 0 , … , T − 1 t = 0, \ldots, T-1 t=0,,T1的当前值 h t \boldsymbol{h}_t ht h 0 \boldsymbol{h}_0 h0是初始化得到的)。这些值是通过从输入层到输出层的正向传播计算并存储得到的。

小结

  • 通过时间反向传播是反向传播在循环神经网络中的具体应用。
  • 当总的时间步数较大或者当前时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸。
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_43328040/article/details/107876653

智能推荐

计算机毕业设计Java高校招生管理系统(源码+系统+mysql数据库+Lw文档)_考试招生录取系统伪代码-程序员宅基地

文章浏览阅读533次。计算机毕业设计Java高校招生管理系统(源码+系统+mysql数据库+Lw文档)前端技术:Layui、HTML、CSS、JS、JQuery等技术。JSP健身俱乐部网站设计与实现sqlserver和mysql。ssm基于javaweb开发数码产品推荐平台系统设计与实现。ssm基于HTML的“牧经校园疫情防控网站”的设计与实现。JSP计算机C语言学习网站的设计与实现sqlserver。JSP酒店餐饮管理系统的设计与实现sqlserver。springboot体育馆预定管理平台的设计与实现。_考试招生录取系统伪代码

RecyclerView GridLayoutManager 自适应宽高_gridlayoutmanager 自适应宽度-程序员宅基地

文章浏览阅读8.7k次,点赞2次,收藏2次。import android.content.Context;import android.support.v7.widget.GridLayoutManager;import android.support.v7.widget.RecyclerView;import android.util.TypedValue;/** * Created by Administrator on 2_gridlayoutmanager 自适应宽度

浅谈实时流平台Kafka的消息系统设计_kafka实时方案-程序员宅基地

文章浏览阅读1.2k次。Many users of Kafka process data in processing pipelines consisting of multiple stages, where raw input data is consumed from Kafka topics and then aggregated, enriched, or otherwise transformed into ..._kafka实时方案

计算机控制闪光灯,摄影技巧 闪灯篇 光圈控制主体 快门控制场景 闪光灯又该如何调整输出功率?...-程序员宅基地

文章浏览阅读469次。先设定相机,后设定闪光灯。在离机闪领域,M 模式是应用上的大宗。在拍摄时,我们会面临两个问题,一个是闪光灯的出力(输出功率),另一个则是相机的测光、曝光设定。基本上,我们是先决定相机的设定值,再决定闪光灯的出力问题。先参考环境光,再考虑闪光灯。思考一下:在闪光灯未触发时,相机的设定值本身不能让主体过曝!当主体已过曝,那么闪光灯进来时,结果还是过曝!但如果主体曝光不足,我们就用闪灯将他补足光线。如何..._光圈控制主体的明暗

.Net Core5.0 上传文件报错413 Request Entity Too Large_.net core 413 request entity too large-程序员宅基地

文章浏览阅读612次。开发环境:.Net Core 5.0 + MVC 进行开发.Net Core5.0 上传文件报错413 Request Entity Too Large_.net core 413 request entity too large

pythoninstaller打包 其他电脑无法运行_新手初学 py 后用 pyinstaller 打包程序后运行 exe 出现问题...-程序员宅基地

文章浏览阅读450次。用 python3.6 制作了个爬虫。在 pycham 里能正常运行,用了 requests,beautifulsoup,pandas,json,re,datetime 等第三方库但是 pyinstaller 打包完成没什么问题,运行 exe 程序就出现了错误:源码地址: https://gitee.com/rufengkj/zwu_educational_system/blob/master/s..._pyinstaller打包后别人的电脑不运行

随便推点

蒲公英 · JELLY技术周刊 Vol.06: Deno 1.0 发布前瞻,“真香定律”能否再现_在影片的不同阶段,鼠标交互有不同的效果,非常巧妙地将 web 技术与影片叙事结合到-程序员宅基地

文章浏览阅读849次。登高远眺天高地迥,觉宇宙之无穷基础技术Deno 1.0 即将发布,你需要知道的都在这里了Deno——来自 Node 之父 Ryan Dahl 的最新力作,在开源 2 年之际,终于将迎来 1.0 的正式版本。Deno 并不是 Node 的替代品,根据 Deno GitHub 官网上的介绍,Deno 是一款通用的 JavaScript/TypeScript 编程环境,它汇集了许多最出色的开源技术,并使用一个很小的可执行文件提供了全面的解决方案。如今的 Deno,基于 Rust,内置了 TypeS._在影片的不同阶段,鼠标交互有不同的效果,非常巧妙地将 web 技术与影片叙事结合到

mysql报错1708_mysql的AB及读写和集群-程序员宅基地

文章浏览阅读60次。Mysql的AB及读写第1章 Mysql的AB配置1.1 master配置1.2slave配置1.2.1 192.168.13.1901.2.2 192.168.13.1911.2.3 192.168.13.1921.2.4 192.168.13.1931.2.4 192.168.13.189第2章 读写分离2.1安装mycat2.1.1 server.xml2.1.1 schema.xml2..._hy000 1708

解决树莓派3B+:只有红灯常亮绿灯不亮_树莓派启动不了,绿灯一直不亮-程序员宅基地

文章浏览阅读1w次,点赞2次,收藏10次。一句话总结本文解决方法:可能是烧录软件有问题,使用Etcher烧录后,可正常开机;以下为解决该问题过程:最近重新给树莓派3B+安装系统,按照之前的方法往SD卡烧录系统:烧录方法:1 .SDFormatter格式化SD卡; 2.Win32DiskImager负责写入系统;接通电源后只有红灯亮,绿灯不亮,查询文章和问答,基本没有很好的解决这个问题:文章和问答常见总结:SD卡有问题,或者树莓派坏了;继续查找问题,知知乎上这篇文章如何给树莓派安装操作系统 - 知乎介绍安装SD卡的_树莓派启动不了,绿灯一直不亮

java反序列化耗时_java序列化方式性能比较-程序员宅基地

文章浏览阅读1.1k次。有一个很不错的工具http://github.com/eishay/jvm-serializers/,可以用它来评测各种流行的java序列化反序列化工具,使用上也很简单。想试试该工具的,下载源码后参考起README操作即可。而我更关心的是,是各种工具的性能对比,以作选择的一个衡量标准,也就是http://github.com/eishay/jvm-serializers/wiki的 图示和数据..._序列化反序列化耗时分析

r语言 c 函数返回值,R语言入门 输出函数 cat、print、paste等区别理解-程序员宅基地

文章浏览阅读2.8k次。一、 简介cat、print函数都是输出函数> cat("hello world")hello world>> print("hello world")[1] "hello world"print的输出有点像列表输出的未命名元素> alist=list(c(1,2,3,4,5),c('a','b','c','d','e'))> alist[[1]][1] 1 2 3 ..._r语言print

hadoop启动和运行中的error总结和处理方法-程序员宅基地

文章浏览阅读248次。错误一:2010-11-09 16:59:07,307 INFO org.apache.hadoop.ipc.Server: Error register getProtocolVersionjava.lang.IllegalArgumentException: Duplicate metricsName:getProtocolVersionat org.apa..._启动hadoop报错 got error reading edit log input stream