度量学习——总结_不说话装高手H的博客-程序员宅基地_基于度量学习

技术标签: 深度学习  pytorch  人工智能  

传统方法

User guide: contents — metric-learn 0.6.2 documentation

深度学习

基于深度学习的度量学习方法大都由两个部分组成:特征提取模块和距离度量模块。距离度量模块的任务是使同一类样本间的距离更靠近,而不同类样本间的距离更远离。这一模块更多的实现方法是改进损失函数,对模型的学习更加“赏罚分明”。

基于正负样本对的方法

也可以称为基于对比学习的方法,抽出正负样本对学习。对比学习的方法现在正广泛的应用于学习更好的特征提取模块,即用自监督学习的方法来学习更好的特征表达,更强大的 backbone,如 MoCo、SimCLR 等。Contrastive Representation Learning | Lil'Log

而为了学习更好的距离度量模块,越来越多基于样本对的损失函数被提出:

Contrastive Loss

是最简单也最直观的损失函数:

直观上分析这个公式,当两个样本的标签相同时,模型的损失函数值为这两个样本在特征空间内的距离,这时梯度回传是为了使这两个样本更“靠近”。而当这两个样本的标签不同时,这里的 α 是 margin,当这两个样本对的在特征空间的距离大于 margin 时,就使损失为0,即不更新网络维持现状,当小于 margin 时,我们就惩罚模型,使这两个样本的距离不断逼近 margin。同时 margin 更重要的作用是避免模型欺骗损失函数,即将所有的样本都映射到特征空间的同一个点,学到一个“捷径”使损失不断接近0。

Triplet Loss

上面的方法只有一个样本对,Triplet Loss 则引入了正负样本对的概念。 

xa 称为 anchor 样本,xp 为正样本,xn 为负样本。这里的 margin 同样有避免模型将所有样本映射到特征空间中同一个点的作用。同时 Triplet Loss 的一个关键点是负例挖掘(Negative Samples Mining),将 anchor 样本与正样本间的距离尽可能限制到0附近,同时将 anchor 样本与负样本的距离推开至 margin 左右,

The only requirement is that given two positive examples of the same class and one negative example, the negative should be farther away than the positive by some margin.

Triplet Loss 实现:Triplet Loss and Online Triplet Mining in TensorFlow | Olivier Moindrot blog

Quadruplet Loss

是对 Triplet Loss 的更进一步,其中 xa、xp 和 xs 都属于同一类:

Structured Loss

Triplet Loss 中只考虑了一个负例,而忽略了其他负例。与 Triplet Loss 不同,Structured Loss 考虑的是 batch 中所有的正样本对,以及距离正样本对中两个点最近的负样本:

N-Pair Loss 也是考虑到 Triplet Loss 中的这个缺点而提出的,只不过处理方法与 Structured Loss 不同。

基于交叉熵的方法

上述基于对比方法的正负样本采样问题在多卡训练的情况下变得尤其复杂,而且不能保证具有相似标签的样本被很好地分开,为了解决这两个问题,越来越多基于交叉熵的方法被提出。这些方法都是基于最基本的交叉熵损失函数(也可以称为 softmax loss):

其中 wz+b 是输出分类结果前的那一个全连接层。将模型的输出,标签为 i 的特征向量 z,投影到类别 i 的权重 wi 上,从几何上来说这个结果就是 vector center ,即这个向量 z 在特征空间中的映射点。将在 MNIST 数据集上用 softmax loss 训练的网络中特征的分布可视化(左边为 train,右边为 val):

可以看到还有一些部分不能被很好地分开,即使是在这样简单的一个数据集上。

Center Loss

在 Softmax Loss 的基础上加上正则项,将不同类样本在特征空间内“推开”。

其中 z 是全连接层的输入,c 是可学习的向量,可以理解为每一个类中特征向量的移动平均值(moving mean vector)。

可以看到 Center Loss 将每一类聚类到类别中心,在特征空间内更好地将类别分开。

SphereFace

Center Loss 的问题是,我们不能预先知道数据集中各类别聚类中心在特征空间内是相互远离的,如果他们很靠近的话,这个正则项就起不了多大的作用。但是如果我们将每个类别的聚类中心放到距离特征空间“原点”相同的距离,即将聚类中心都映射到一个“圆”上,只要这个“圆”的半径够大,理论上我们就可以把各类别的聚类中心“推”的更加分散。这个“圆”其实就是一个超平面。这也就是 SphereFace 的做法,它是由 Softmax Loss 一步步进化而来:

由向量积的公式我们知道:

所以上述公式中的做法就是将各类别的权重 w 正则化,正则化后我们完成了将聚类中心“放置”在超平面上的操作,之后再将全连接层的偏置置0,为了可以更简单的分析。公式中的 θ 就是 z 和 wi 之间的夹角,它是大于0小于 pai 的。

在 Softmax Loss 推理时,我们将特征向量 z 通过全连接层,也就是将 z 映射到各个类别的权重上,哪一个结果大,那么它就属于哪一类。反应到下图中,将 z(图中的 x)在特征空间的映射向每一个类别的权重(W1 和 W2)做垂直平分线,这个交点到特征空间“原点”的距离即为分类依据,这个距离就是 Cosθ,也就是全连接层的输出

在 SphereFace 中,由于我们对 w 和 b 进行了一些操作,并且 z 进行了正则化是一个常数。所以当 z 和各类别的权重 w 之间的夹角 θ 更小,那么它就属于哪一类。这样的“决策边界”依然不能保证十分正确的分类,因为我们没有对 z 在特征空间的映射到各聚类中心点在超平面上的距离施加正则或者惩罚项。这也是 SphereFace 第二个创新点,margin μ。SphereFace 公式如下:

所以在推理时只有当 z 与一个类别 w 的夹角 μθ 大于与其他类别的夹角 θ 时,模型才会判定 z 属于这个类别。也就是说 θ 被限定到了如下区间:

通过损失函数影响模型,让模型将 z 映射到特征空间内更小的角度,这样在推理的时候可以更好地判别。下图可以看出 SphereFace 的效果,图中两个 w 之间的红线即为决策边界。

图1

SphereFace 开创了用角度距离来完成分类的先河,接下来的几种方法都是基于此提出。

CosFace

指出 SphereFace 用计算出的角度经过 Cos 函数输出特征向量的调整结果(或者说调整特征向量在特征空间内的映射),但是 Cos 函数不是单调的,所以给优化带来了困难。同时只通过角度的余弦值来判断属于哪一类的话,会导致类别间的距离有的大有的小,降低了区分能力。

CosFace 的做法是将特征向量 z 也进行正则化,同时加上两个超参数 s 和 m:

其中 s 是放缩系数,而 m 是 margin。直观理解是:将垂直平分线交点往特征空间“原点”拉近,类别中各样本在特征空间的映射就会更加靠近“原点”,如图1中每个类别画出的“圆弧”会更加小,及增大了决策边界的角度。

其中 s 和 m 的选择颇为讲究:

其中 K 是特征维度,C 是数据集中类别数,PW 是 expected minimum posterior probability of class center。随着类别数的增加,类别之间 cosine margin 的上限相应地减少。

ArcFace

论文:https://arxiv.org/pdf/1801.07698.pdf

对全连接层的分类输出进行调整,再计算交叉熵损失。

从 softmax loss 的看到这里,这张图就不难理解了。

对全连接层的输入和权重进行正则化后结果为 cosθ,再将其乘上一个超参数 s:

将 cos(θyi) 用 cos(θyi+m) 代替,这部分是 ArcFace 的核心,其背后的意义是是直接在角度空间(angular space)中最大化分类界限。而 CosFace 是将类别映射的更紧凑以期望来达到最大化分类界限的目的,与 ArcFace 在公式上的区别就是增加 margin 的位置。m 为超参数 margin:

下面代码算出的是 L3 中的指数函数 e 的输入,用 ArcFace 进行调整后输入到交叉熵损失中输出损失:

class ArcMarginProduct(nn.Module):
    r"""Implement of large margin arc distance: :
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            s: norm of input feature
            m: margin
            cos(theta + m)
        """
    def __init__(self, in_features, out_features, s=30.0,
                 m=0.30, easy_margin=False, ls_eps=0.0, device="cpu"):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        self.device = device

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight)) # F.normalize(input)、F.normalize(self.weight) 是公式中对输入和权重的正则化
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m    # 三角公式
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------
        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
        # one_hot = torch.zeros(cosine.size(), device=self.device)
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

AdaCos

论文:https://openaccess.thecvf.com/content_CVPR_2019/papers/Zhang_AdaCos_Adaptively_Scaling_Cosine_Logits_for_Effectively_Learning_Deep_Face_CVPR_2019_paper.pdf

这篇文章中对 ArcFace 中的两个超参数 s 和 m 进行了消融实验,P 是 softmax 后归一化的后验概率:

对 s:

当 s 过小时,模型的分类概率达不到 1,这样模型无法做出“自信”的判断,就导致损失函数惩罚了正例;而当 S 过大时,模型过于自信,这时损失函数无法正确地惩罚负例。

对 m:

当 m 过大,当 θ 变得稍大时,模型就不会将其判断为这一类,可以证明加上了 margin 的损失函数比不加 margin 的损失函数使模型的预测的细粒度更小。

Sub-center ArcFace

论文:https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123560715.pdf

解决了 Sphere Face、CosFace 和 ArcFace 对噪声数据敏感的问题,免去了数据清洗的工作更贴近日常生活中的数据。为每一类设定 K 个子中心而不是像之前的做法每一类中只有一个。这样样本的大部分都会靠近 dominant centers,而那些 noisy / hard sample 则会被推向其他的 undominant centers。

但这样做破坏了类内的紧致性,对此文章中的做法是当网络具有足够的识别能力后,直接去掉那些 undominant centers。同时引入了一个恒定的角度阈值来降低高置信噪声数据,在此之后在自动清理的数据集上从头开始重新训练模型。

其中

对应下列代码只是公式中 cosine, _ = torch.max(cosine_all, dim=2) 这个操作:

class ArcMarginProduct_subcenter(nn.Module):
    def __init__(self, in_features, out_features, k):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, features):
        cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
        cosine_all = cosine_all.view(-1, self.out_features, self.k)
        cosine, _ = torch.max(cosine_all, dim=2)
        return cosine

下面大部分都是 ArcFace 中的操作,其中 phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1)) 理解为引入了一个恒定的角度阈值来降低高置信噪声数据:

class ArcFaceLossAdaptiveMargin(nn.modules.Module):
    def __init__(self, margins, out_dim, s):
        super().__init__()
        self.crit = DenseCrossEntropy()
        self.s = s
        self.register_buffer('margins', torch.tensor(margins, device="cuda:0"))
        self.out_dim = out_dim

    def forward(self, logits, labels):
        # ms = []
        # ms = self.margins[labels.cpu().numpy()]
        ms = self.margins[labels]
        cos_m = torch.cos(ms)  # torch.from_numpy(np.cos(ms)).float().cuda()
        sin_m = torch.sin(ms)  # torch.from_numpy(np.sin(ms)).float().cuda()
        th = torch.cos(math.pi - ms)  # torch.from_numpy(np.cos(math.pi - ms)).float().cuda()
        mm = torch.sin(math.pi - ms) * ms  # torch.from_numpy(np.sin(math.pi - ms) * ms).float().cuda()
        labels = F.one_hot(labels, self.out_dim)
        labels = labels.half() if CFG.MIXED_PRECISION else labels.float()
        cosine = logits
        sine = torch.sqrt(1.0 - cosine * cosine)
        phi = cosine * cos_m.view(-1, 1) - sine * sin_m.view(-1, 1)
        phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1))
        output = (labels * phi) + ((1.0 - labels) * cosine)
        output *= self.s
        loss = self.crit(output, labels)
        return loss

ArcFace with Dynamic Margin

为了应对严重的类别不均衡而提出,样本更少的类应该具有更大的 margin,以期望更好地与其他类分开。每一类的 margin:

其中 a 和 b 控制着 margin 的上下界,n 为各类别的样本数,λ 控制着这个函数的形状。

参考

Deep Metric Learning: a (Long) Survey – Chan Kha Vu

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/AmbitionalH/article/details/123444279

智能推荐

gmpy2的安装(python)_oooooooohhhhh的博客-程序员宅基地_gmpy2

Gmpy2 安装:(三个方法)直接再pycharm里面安装——file——setting——python interpreter—“+”搜索gmpy2安装(这个步骤可能有点点小问题,我凭记忆来的,安装模板的位置应该都能找到,就不要纠结这个小错误了)。但是我用这个安装报错,使用的下面一个方法安装的 Win+R 打开windows命令窗口,使用pip install gmpy2安装成功 若使用这个pip直接安装不成功,https://www.lfd.uci.edu/~gohlke/pythonlibs/

cpu低端计算机配置清单,i3 4160/GTX750Ti剑灵/英雄联盟中低端组装机配置清单_婧在机器学习中的博客-程序员宅基地

现在的游戏对于电脑配置有着较高的要求,不然配置都带不动游戏,今天给大家推荐的是华硕B85 PRO GAMER主板、Intel酷睿i3-4160处理器以及影驰GTX750Ti大将版独立显卡搭配的组装电脑配置清单,下面华海电脑详细的介绍一下这款游戏中低端DIY电脑配置单,具体组装机配置清单如下:i3 4160/GTX750Ti剑灵/英雄联盟中低端组装机配置清单配件名称品牌型号参考价格处理器Intel酷...

EasyPlayer播放器浏览器ActiveX/OCX插件RTSP播放/抓拍/录像功能调用说明_xiejiashu的博客-程序员宅基地

EasyPlayerPro与EasyPlayer-RTSP新增ocx多窗口播放功能这里以EasyPlayerPro为例,使用方法如下:打开播放器文件夹,进入Bin/C++目录,可以看到reg.bat这个文件,以管理员身份运行 成功运行程序后,找到ocx mutiplayer.html文件,右键选择打开方式,使用ie浏览器打开ie浏览器会弹窗阻止javascript程序运行,看不到视...

html5+ barcode_scan-二维码扫描_weixin_41961749的博客-程序员宅基地

Hello H5+ barcodeBarcode模块管理条码扫描,提供常见的条码(二维码及一维码)的扫描识别功能,可调用设备的摄像头对条码图片扫描进行数据输入。通过plus.barcode可获取条码码管理对象。常量:QR: 条码类型常量,QR二维码,数值为0EAN13: 条码类型常量,EAN一维条形码码标准版,数值为1EAN8: 条码类型常量,ENA一维条形码简版,数值为2A...

Toolbar使用详解_昊帅的博客-程序员宅基地

使用ToolBar需要设置三点: 1.添加依赖库:compile 'com.android.support:appcompat-v7:26.0.0-alpha1'2.activity需要继承AppCompatActivity 3.在AndroidManifest.xml文件中,设置元素使用appcompat中的某个NoActionBar主题,从而来去除使用ActionBar来提供操作栏。<app

kali linux安装后常用软件安装介绍和部分错误解决_5t4rk的博客-程序员宅基地

kali linux 安装之后整理《kaili-linux 虚拟机详细安装简本中文版》http://goo.gl/XjLs1G【普及帖】kali linux 安装之后整理By sIyuAn ps 网上收集、自己整理一些,想安装啥自己节选也不定要全部都弄1.更新软件源:root权限:leafpad /etc/apt/sources.list 用 # 注释官方源,然后添加国

随便推点

厉害,我带的实习生仅用四步就整合好SpringSecurity+JWT实现登录认证_沉默王二的博客-程序员宅基地_暂未登录或token已经过期

小二是新来的实习生,作为技术 leader,我还是很负责任的,有什么锅都想甩给他,啊,不,一不小心怎么把心里话全说出来了呢?重来!小二是新来的实习生,作为技术 leader,我还是很负责任的,有什么好事都想着他,这不,我就安排了一个整合SpringSecurity+JWT实现登录认证的小任务交,没想到,他仅用四步就搞定了,这让我感觉倍有面。一、关于 SpringSecurity在 Spring Boot 出现之前,SpringSecurity 的使用场景是被另外一个安全管理框架 Shiro 牢牢霸占

yum 命令无法使用_枫_林的博客-程序员宅基地_yum命令无法使用

一、yum 命令无法使用[[email protected] ~]# yum install gitLoaded plugins: fastestmirror, langpacksLoading mirror speeds from cached hostfileCould not retrieve mirrorlist http://mirrorlist.centos.org/?release...

js获取 本周,本月,本季度,本年,上月,上周,上季度,去年_雨泽的博客-程序员宅基地

@author  YHC转载必须在文章第一行注明来源地址:  http://blog.csdn.net/yhc13429826359/article/details/8085641/** * 针对Ext的工具类 */var MrYangUtil=function(){ /*** * 获得当前时间 */ this.getCurrentDate=function(){ retur

2018VR眼镜,今天做了一个测评!_A556636928的博客-程序员宅基地

Vr眼镜是什么?近几年炒的非常火热的vr虚拟现实备受大家关注,各大科技公司也开始投产研发,推出新一代的vr眼镜,它到底是个什么东西?对我们的现实生活影响大不大?虚拟现实技术是一种可以创建和体验虚拟世界的计算机仿真系统它利用计算机生成一种模拟环境是一种多源信息融合的交互式的三维动态视景和实体行为的系统仿真使用户沉浸到该环境中。VR眼镜是什么?vr眼镜的分类外接式头戴设备用户体验较好,具备...