torch.embedding and EmbeddingBag 详解-程序员宅基地

技术标签: Pytorch  python  

Embedding

torch.embedding 实际上是一个查找表,一般用来存储词嵌入并通过indices从embedding中恢复词嵌入。

位置:

torch.nn.Embedding

参数及官方解释为:

  • num_embeddings (int): size of the dictionary of embeddings
  • embedding_dim (int): the size of each embedding vector
  • padding_idx (int, optional) :If given, pads the output with the embedding vector at padding_idx (initialized to zeros) whenever it encounters the index.
  • max_norm ((float, optional)):If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm.
  • norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
  • scale_grad_by_freq: If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False.
  • sparse (bool, optional) :If True, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for more details regarding sparse gradients.

Attributes:

  • weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) initialized from :math:\mathcal{N}(0, 1)

Shape:

  • Input: :math:(*), LongTensor of arbitrary shape containing the indices to extract
  • Output: :math:(*, H), where * is the input shape and :math:H=\text{embedding\_dim}

Examples::

>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

         [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])

可以看到当index 相同的时候输出的 embedding是相同的。例如第一个sample的index=2 和第二个sample的index=2,
也就是说对于同一个embedding, 输入的index相同,对应的tensor相同。

with padding

>>> # example with padding_idx
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0,2,0,5]])
>>> embedding(input)
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.1535, -2.0309,  0.9315],
         [ 0.0000,  0.0000,  0.0000],
         [-0.1655,  0.9897,  0.0635]]])

当有padding的时候,例如设置padding_idx = 0,也就是当第一个index 和第三个index = 0时,输出的tensor 自动padding为0,而index=2和index=5没有设置padding,所以输出没有被0 padding。

其中的一个classmethod

@classmethod
   def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
                       max_norm=None, norm_type=2., scale_grad_by_freq=False, sparse=False):
	r"""Creates Embedding instance from given 2-dimensional FloatTensor.

                Args:
                    embeddings (Tensor): FloatTensor containing weights for the Embedding.
                        First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
                    freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
                        Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
                    padding_idx (int, optional): See module initialization documentation.
                    max_norm (float, optional): See module initialization documentation.
                    norm_type (float, optional): See module initialization documentation. Default ``2``.
                    scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
                    sparse (bool, optional): See module initialization documentation.

                Examples::

                    >>> # FloatTensor containing pretrained weights
                    >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
                    >>> embedding = nn.Embedding.from_pretrained(weight)
                    >>> # Get embeddings for index 1
                    >>> input = torch.LongTensor([1])
                    >>> embedding(input)
                    tensor([[ 4.0000,  5.1000,  6.3000]])
                """

EmbeddingBag

Computes sums or means of ‘bags’ of embeddings, without instantiating the intermediate embeddings.

支持的三种mode

  • sum:is equivalent to ~torch.nn.Embedding followed by torch.sum(dim=0)
  • mean:is equivalent to ~torch.nn.Embedding followed by torch.mean(dim=0)
  • max:is equivalent to ~torch.nn.Embedding followed by torch.max(dim=0)
    但是用embeddingbag 的效率会更高。
    pytorch支持在forward pass 中增加 per-sample weights,但只在 mode == sum时支持。如果这个参数为0,在计算 weighted sum的时候所有的weight = 1,如果不为0,则按照设置的weight来计算weighted sum。
    其他参数和 embedding差不多
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_46559271/article/details/106356155

智能推荐

如何配置filezilla服务端和客户端_filezilla server for windows (32bit x86)-程序员宅基地

文章浏览阅读7.8k次,点赞3次,收藏9次。如何配置filezilla服务端和客户端百度‘filezilla server’下载最新版。注意点:下载的版本如果是32位的适用xp和win2003,百度首页的是适用于win7或更高的win系统。32和64内容无异。安装过程也是一样的。一、这里的filezilla包括服务端和客户端。我们先来用filezilla server 架设ftp服务端。看步骤。1选择标准版的就可以了。 _filezilla server for windows (32bit x86)

深度学习图像处理01:图像的本质-程序员宅基地

文章浏览阅读724次,点赞18次,收藏8次。深度学习作为一种强大的机器学习技术,已经成为图像处理领域的核心技术之一。通过模拟人脑处理信息的方式,深度学习能够从图像数据中学习到复杂的模式和特征,从而实现从简单的图像分类到复杂的场景理解等多种功能。要充分发挥深度学习在图像处理中的潜力,我们首先需要理解图像的本质。本文旨在深入探讨深度学习图像处理的基础概念,为初学者铺平通往高级理解的道路。我们将从最基础的问题开始:图像是什么?我们如何通过计算机来理解和处理图像?

数据探索阶段——对样本数据集的结构和规律进行分析_数据分析 规律集-程序员宅基地

文章浏览阅读62次。在收集到初步的样本数据之后,接下来该考虑的问题有:(1)样本数据集的数量和质量是否满足模型构建的要求。(2)是否出现从未设想过的数据状态。(3)是否有明显的规律和趋势。(4)各因素之间有什么样的关联性。解决方案:检验数据集的数据质量、绘制图表、计算某些特征量等,对样本数据集的结构和规律进行分析。从数据质量分析和数据特征分析两个角度出发。_数据分析 规律集

上传计算机桌面文件图标不见,关于桌面上图标都不见了这类问题的解决方法-程序员宅基地

文章浏览阅读8.9k次。关于桌面上图标都不见了这类问题的解决方法1、在桌面空白处右击鼠标-->排列图标-->勾选显示桌面图标。2、如果问题还没解决,那么打开任务管理器(同时按“Ctrl+Alt+Del”即可打开),点击“文件”→“新建任务”,在打开的“创建新任务”对话框中输入“explorer”,单击“确定”按钮后,稍等一下就可以见到桌面图标了。3、问题还没解决,按Windows键+R(或者点开始-->..._上传文件时候怎么找不到桌面图标

LINUX 虚拟网卡tun例子——修改_怎么设置tun的接收缓冲-程序员宅基地

文章浏览阅读1.5k次。参考:http://blog.csdn.net/zahuopuboss/article/details/9259283 #include #include #include #include #include #include #include #include #include #include #include #include _怎么设置tun的接收缓冲

UITextView 评论输入框 高度自适应-程序员宅基地

文章浏览阅读741次。创建一个inputView继承于UIView- (instancetype)initWithFrame:(CGRect)frame{ self = [superinitWithFrame:frame]; if (self) { self.backgroundColor = [UIColorcolorWithRed:0.13gre

随便推点

字符串基础面试题_java字符串相关面试题-程序员宅基地

文章浏览阅读594次。字符串面试题(2022)_java字符串相关面试题

VSCODE 实现远程GUI,显示plt.plot, 设置x11端口转发_vscode远程ssh连接服务器 python 显示plt-程序员宅基地

文章浏览阅读1.4w次,点赞12次,收藏21次。VSCODE 实现远程GUI,显示plt.plot, 设置x11端口转发问题服务器 linux ubuntu16.04本地 windows 10很多小伙伴发现VSCode不能显示figure,只有用自带的jupyter才能勉强个截图、或者转战远程桌面,这对数据分析极为不方便。在命令行键入xeyes(一个显示图像的命令)会failed,而桌面下会出现:但是Xshell能实现X11转发图像,有交互功能,但只能用Xshell输入命令plot,实在不方便。其实VScode有X11转发插件!!方法_vscode远程ssh连接服务器 python 显示plt

element-ui switch开关打开和关闭时的文字设置样式-程序员宅基地

文章浏览阅读3.3k次,点赞2次,收藏2次。element switch开关文字显示element中switch开关把on-text 和 off-text 属性改为 active-text 和 inactive-text 属性.怎么把文字描述显示在开关上?下面就是实现方法: 1 <el-table-column label="状态"> 2 <template slot-scope="scope">..._el-switch 不同状态显示不同字

HttpRequestUtil方法get、post、JsonToPost_httprequestutil.httpget-程序员宅基地

文章浏览阅读785次。java后台发起请求使用的工具类package com.cennavi.utils;import org.apache.http.Header;import org.apache.http.HttpResponse;import org.apache.http.HttpStatus;import org.apache.http.client.HttpClient;import org.apache.http.client.methods.HttpPost;import org.apach_httprequestutil.httpget

App-V轻量级应用程序虚拟化之三客户端测试-程序员宅基地

文章浏览阅读137次。在前两节我们部署了App-V Server并且序列化了相应的软件,现在可谓是万事俱备,只欠东风。在这篇博客里面主要介绍一下如何部署客户端并实现应用程序的虚拟化。在这里先简要的说一下应用虚拟化的工作原理吧!App-V Streaming 就是利用templateServer序列化出一个软件运行的虚拟环境,然后上传到app-v Server上,最后客户..._app-v 客户端