Tensorflow中注意力机制的实现:AttentionCellWrapper_attention cell wark_酒酿小圆子~的博客-程序员宅基地

技术标签: Tensorflow  机器学习 & 深度学习  

背景知识

注意力机制最早被用于机器翻译领域,其本质类似于人类在认知事物时的注意力,后因其有效性被广泛用于计算机视觉、语音识别、序列预测等领域。
常见的注意力机制通常是基于Encoder-Decoder的,模型在Decoder阶段进行解码时会考虑编码阶段Encoder的所有隐藏状态。

AttentionCellWrapper理论基础

在Tensorflow中也有现成的注意力API可以使用,即AttentionCellWrapper,具体的实现代码是在tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py文件中。

值得注意的是,Tensoflow中AttentionCellWrapper的实现并不是基于Encoder-Decoder形式的,而是受启发于https://magenta.tensorflow.org/2016/07/15/lookback-rnn-attention-rnn这篇文章中的AttentionRNN。

这篇文章提出了一种单向RNN就能使用的Attention结构(这里我们称为AttentionRNN),在处理每一步的输入时,考虑前面N步的输出,经过映射加权后把这些历史信息加到本次输入的预测中。

In our version, where we don’t have an encoder-decoder, we just always look at the outputs from the last n steps when generating the output for the current step. The way we “look at” these steps is with an attention mechanism.

具体公式如下:
在这里插入图片描述
其中:

  • 矩阵W1、W2和向量v均为可学习的参数
  • hi为前面第i步输出的隐藏状态 ct为当前时刻的细胞状态
  • ui为长度为n的相关系数向量,对于前n个step每个step对应一个相关系数。
  • ai为注意力得分,可通过对相关系数进行softmax操作得到,文章中称ai为attention mask。
  • h’t为当前时刻经过attention后的输出,可通过对前n个step的隐藏状态以及对应的注意力的分加权求和得到。

参数描述

AttentionCellWrapper源码解析

class AttentionWrapper(rnn_cell_impl.RNNCell):
  def __init__(self,
               cell,
               attention_mechanism,
               attention_layer_size=None,
               alignment_history=False,
               cell_input_fn=None,
               output_attention=True,
               initial_cell_state=None,
               name=None,
               attention_layer=None):
  • cell: 被包裹的RNNCell实例;
  • attention_mechanism: attention机制实例,例如BahdanauAttention,也可以是多个attention实例组成的列表;
  • attention_layer_size: 是数字或者数字做成的列表,如果是 None(默认),直接使用加权求和得到的上下文向量 [公式] 作为输出(详见本小节最后的_compute_attention代码),如果不是None,那么将 [公式] 和cell的输出 cell_output进行concat并做线性变换(输出维度为attention_layer_size)再输出。
    这里所说的"输出"在代码里是用"attention"表示的,见本小节最后的_compute_attention函数代码。
  • alignment_history: 即是否将之前的alignments存储到 state 中,以便于后期进行可视化展示,默认False,一般设置为True。
  • cell_input_fn: 怎样处理输入。默认会将上一步的得到的输出与的实际输入进行concat操作作为输入。

参考资料:
Tensorflow中的AttentionCellWrapper:一种更通用的Attention机制
TensorFlow AttentionWrapper源码超详细图解
AttentionCellWrapper

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/u012856866/article/details/105395920

智能推荐

windows7安装oracle 10g安装过程及注意事项_kkkjjjkj的博客-程序员宅基地

windows7安装oracle 10g安装过程及注意事项。1.因为oracle 10g暂时没有与win7兼容的版本,我们可以通过对安装软件中某些文件的修改达到安装的目地。a)打开“/Oracle 10G /stage/prereq/db”路径,找到refhost.xml文件,打开,向其中添加如下代码并保存。 b)打开“/Oracle 10G /in

android中RXJava的基本使用_android中rxjava的使用_acxingyun的博客-程序员宅基地

整个Demo学习下RXJava。 直接上代码。 在gradle中添加依赖: compile 'io.reactivex:rxjava:1.0.14' compile 'io.reactivex:rxandroid:1.0.1'不控制线程package acxingyun.cetcs.com.rxjava;import android.app.Activity;import and_android中rxjava的使用

封装axios接口请求的思路_axios封装思路_「已注销」的博客-程序员宅基地

装axios接口请求的思路1、安装axiosnpm install axios2、在src目录下新建一个http文件夹3.在http下新建两个文件:api.js(用来统一数据请求接口);request.js(封装数据请求方法);4、request.js文件/* eslint-disable quotes *//* eslint-disable dot-notation */// 引入axiosimport axios from 'axios'import { Message } f_axios封装思路

目标检测(五):Faster-RCNN_目标检测网络最多检测多少目标_MajinWakeup的博客-程序员宅基地

目标检测(五):Faster-RCNN1 RPN介绍2 训练Faster-RCNN1主要是为了解决Fast-RCNN中region proposal的耗时问题。之前所用的selective search方法的耗时往往比检测本身的耗时还要长,因此本文提出了一种Region Proposal Network(RPN)用来获取region proposal,同时RPN与Fast-RCNN中的检测网络可..._目标检测网络最多检测多少目标

linux基线合规检查工具,基于Linux系统基线安全加固工具的研究_罐装核子可乐的博客-程序员宅基地

摘要:随着GNU/Linux操作系统的服务器在企业生产环境中所占比重的增加,其安全问题也逐渐受到人们关注.目前开源软件越来越受到企业欢迎,一来,开源软件可以节约项目经费,不涉及版权问题;二者,开源软件的安全性比闭源软件相对更加安全.但是在开源软件占据主导地位的大环境下(企业生产环境),开源软件并非绝对的安全.除了大家都知晓的系统漏洞方面的风险性,操作系统基线配置也是主机安全中的一个薄弱环节.操作系..._linux台式机合规性检测工具

简单语法学习_简单的语法学习_TIMBER熊猫人的博客-程序员宅基地

Markdown语法学习在这里声明一点的是,这些是一些常用的设置。推荐文本编辑器:typora功能一:标题的设置一个#号键加上空格为一标题,两个#号键加上空格为二标题,类推下去,一共的话就有六种标题。功能二:字体的设置*斜体设置:字体两边加上一个号粗体设置:字体两边加上两个*号***斜体加粗设置:**字体两边加上三个号~~删除线设置:~~字体两边加上两个~号功能三:引用的设置一个>号加上空格,就可以看到效果了功能四:分割线的设置方法一:三个*方法二:三个-功能五:_简单的语法学习

随便推点

GBase8s 查看数据库表空间信息_gbase查看数据库中最大的表-程序员宅基地

##onstat -d 查看数据库表空间信息onstat -d命令用于检查数据库空间的使用情况[gbasedbt@node13 ~]$ onstat -dYour evaluation license will expire on 2021-08-27 00:00:00GBase Database Server Version 12.10.FC4G1TL -- On-Line (Prim) -- Up 00:16:02 -- 2023104 KbytesDbspacesaddress _gbase查看数据库中最大的表

Block学习一:Block的实质_ai_pple的博客-程序员宅基地

这篇主要让我们理解Block编译之后变成了什么我们先创建一个类ABlock 只包含简单的带有Block的代码,如下:- (void)method { void (^stackBlock)(void) = ^{ NSLog(@"this is a block"); }; stackBlock();}然后我们打..._block的实质

chia备份及迁移_chia区块数据在什么文件_撞强的博客-程序员宅基地

chia涉及3个文件夹,默认在C盘下,且占用空间越来越大。C:\Users\Administrator\.chia ,数据存储目录,存放主网、钱包、日志、密钥等。C:\Users\Administrator\AppData\Local\chia-blockchain,程序C:\Users\Administrator\AppData\Roaming\Chia Blockchain,不详。迁移和备份需要同时把这3个目录进行处理。若准备清除数据重装chia,则对C:\Users\Adm_chia区块数据在什么文件

Flutter之用户点击行为捕获,及处理点击事件_flutter 捕获成功_长风朗月碎梦的博客-程序员宅基地

1.实现的功能捕获用户的点击行为打印出点击行为捕获拖动行为并用于拖动小球更换主题同时清空打印 的点击行为2.用户手势行为主要是GestureDetector 这个widget提供apichild: GestureDetector( onTap:()=> _printMsg("点击"), ..._flutter 捕获成功

ussd代码大全_魅族ussd补电代码 | 手游网游页游攻略大全_粉色精神分析学家的博客-程序员宅基地

发布时间:2016-08-12魅族MX5在魅族系列手机中一直是非常热门的,目前还是处于抢购状态,能抢到手的用户有限,掌握一些魅族MX5抢购技巧能够增加入手魅族MX5的据.下面99安卓网小编就来分享魅族MX5抢购时间和抢购技巧,供玩家参考. MX5作为 ...标签:魅族MX5 魅族MX5抢购发布时间:2015-12-19华为P9进入工程模式之后可以设置后台,单板查询等.怎么进入工程模式呢?99安卓网..._ussd代码可以设置几个

pjsip android内核,PJSip in android_德州小王子的博客-程序员宅基地

The accepted answer isn't entirely accurate. There are many desirable features missing from the Android SIP API that you may wish to achieve via a 3rd-party library.With respect to the aforementioned ...

推荐文章

热门文章

相关标签