pytorch 语义分割loss_Focal Loss理论及PyTorch实现_第一千零一个人的博客-程序员宅基地

技术标签: pytorch 语义分割loss  

一、基本理论

采用soft - gamma: 在训练的过程中阶段性的增大gamma 可能会有更好的性能提升。

alpha 与每个类别在训练数据中的频率有关。

F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同。

F.nll_loss中实现了对于target的one-hot encoding,将其编码成与input shape相同的tensor,然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production。

基于alpha=1采用不同的gamma值进行实验的结果

focal loss解决了什么问题?

(1)不同类别不均衡

(2)难易样本不均衡

在retinanet中,除了使用呢focal loss外,还对初始化做了特殊处理,具体是怎么做的?

在retinanet中,对 classification subnet 的最后一层conv设置它的偏置b为:

b=−log((1−π)/π)

π代表先验概率,就是类别不平衡中个数少的那个类别占总数的百分比,在检测中就是代表object的anchor占所有anchor的比重,论文中设置的为0.01。

二、公式

标准的Cross Entropy 为:[图片上传失败...(image-286df1-1571884440851)]

Focal Loss 为:[图片上传失败...(image-460db1-1571884440851)]

其中,[图片上传失败...(image-d6c655-1571884440851)]

三、代码实现

一、来自Kaggle的实现(基于二分类交叉熵实现)

class FocalLoss(nn.Module):

def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):

super(FocalLoss, self).__init__()

self.alpha = alpha

self.gamma = gamma

self.logits = logits

self.reduce = reduce

def forward(self, inputs, targets):

if self.logits:

BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)

else:

BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)

pt = torch.exp(-BCE_loss)

F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

if self.reduce:

return torch.mean(F_loss)

else:

return F_loss

二、来自知乎大佬的实现:

import torch

import torch.nn as nn

import torch.nn.functional as F

from torch.autograd import Variable

class FocalLoss(nn.Module):

r"""

This criterion is a implemenation of Focal Loss, which is proposed in

Focal Loss for Dense Object Detection.

Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

The losses are averaged across observations for each minibatch.

Args:

alpha(1D Tensor, Variable) : the scalar factor for this criterion

gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),

putting more focus on hard, misclassified examples

size_average(bool): By default, the losses are averaged over observations for each minibatch.

However, if the field size_average is set to False, the losses are

instead summed for each minibatch.

"""

def __init__(self, class_num, alpha=None, gamma=2, size_average=True):

super(FocalLoss, self).__init__()

if alpha is None:

self.alpha = Variable(torch.ones(class_num, 1))

else:

if isinstance(alpha, Variable):

self.alpha = alpha

else:

self.alpha = Variable(alpha)

self.gamma = gamma

self.class_num = class_num

self.size_average = size_average

def forward(self, inputs, targets):

N = inputs.size(0)

C = inputs.size(1)

P = F.softmax(inputs)

class_mask = inputs.data.new(N, C).fill_(0)

class_mask = Variable(class_mask)

ids = targets.view(-1, 1)

class_mask.scatter_(1, ids.data, 1.)

#print(class_mask)

if inputs.is_cuda and not self.alpha.is_cuda:

self.alpha = self.alpha.cuda()

alpha = self.alpha[ids.data.view(-1)]

probs = (P*class_mask).sum(1).view(-1,1)

log_p = probs.log()

#print('probs size= {}'.format(probs.size()))

#print(probs)

batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p

#print('-----bacth_loss------')

#print(batch_loss)

if self.size_average:

loss = batch_loss.mean()

else:

loss = batch_loss.sum()

return loss

参考

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_33213911/article/details/112034826

智能推荐

Windows Embedded Standard 7 部署须知_捷杰耶夫的博客-程序员宅基地

作者:Joseph-Growth转自:http://blog.csdn.net/joseph_happy/article/details/7722798Windows Embedded Standard 7是微软新推出的Windows 7嵌入式版本,可在PC机上运行,占用磁盘空间较少,可自行定制功能。可在微软官方网站下载并申请180天测试序列号(若不输入序列号仅能用30天)。安装界

Tapestry_kobexing933的博客-程序员宅基地

http://www.infoq.com/articles/tapestry5-intro一篇非常好的关于Tapestry 5 的介绍文章,本文的目的是展示Tapestry 5 简单,强大的特性,并非为新手教学目的,请大家一定要了解....

服务框架HSF分析之三Consumer启动和处理_iteye_14085的博客-程序员宅基地

 前两篇文章为大家带来了HSF容器启动和Porvider的分享。这篇来分析下consumer端的运行机制。一. Consumer的启动1.     服务代理在HSFSpringConsumer的启动中会返回一个HSFServiceProxy的jdk动态代理,后续调用其实都是通过这个代理类来实现的。InvocationHandler handler = newHSFServ...

关于spring、hibernate 整合错误,请大神们帮忙看看哪里出错了_sinat_34810781的博客-程序员宅基地

在网上找了个项目,准备做 spring + hibernate 整合测试的时候,update()方法出现问题了。update()是用Spring+Hibernate,程序不报错,但是数据库里面的数据没有任何变化。save()是没有任何问题的,save()是 hibernate写的。@RunWith(SpringJUnit4ClassRunner.class) @ContextConfigur

Spring MVC如何访问到静态的文件,如jpg,js,css,png,gif? _iteye_19209的博客-程序员宅基地

如果你的DispatcherServlet拦截 *.do这样的URL,就不存在访问不到静态资源的问题。如果你的DispatcherServlet拦截“/”,拦截了所有的请求,同时对*.js,*.jpg的访问也就被拦截了。目的:可以正常访问静态文件,不要找不到静态文件报404。  方案一:激活Tomcat的defaultServlet来处理静态文件 Xml代码<...

前端与后台交互的数据格式有哪些?_say个嗨呀的博客-程序员宅基地_常用前后端数据交互格式

前后端的交互数据格式有:json、xml及from表格;现在主流的数据格式是json;xml基本不怎么用了;from不常用但是要知道

随便推点

java 如何获得 UTC 1970/01/01经过的毫秒数_冯@冯的博客-程序员宅基地_java utc 毫秒

UTC 1970-01-01经过的毫秒数 //UTC 时间操作 Calendar cal = Calendar.getInstance(); System.out.println(cal); //获取时区和 GMT-0 的时间差,偏移量 int offset = cal.get(Calendar.ZONE_OFFSET); System.out.println(offset); // 获取夏令时 时差

囤题 [补档]_dk810510的博客-程序员宅基地

图论网络流相关CF 653 D题目大意给定一张\(n\)个点, \(m\)条边的有向图(\(n \le 50\), \(m \le 500\)), 每条边都有容量限制. 你要找到至少\(x\)条路径, 使得每条路径点容量都为某个定值\(F\), 且经过任意一条边点所有路径的容量之和小于等于这条边的容量. 求\(F\)的最大值.题解我们令原图的边集为\(E = \{ \left...

从图片中提取曲线坐标数据--基于MATLAB_戈 扬的博客-程序员宅基地_如何根据曲线提取坐标点

转载: https://zhuanlan.zhihu.com/p/521120120.引言在读文献的时,经常遇到这样的情况:文章里提出的方法好有趣啊,好想拿文中用的数据来试试看看能不能得到相近的结果,可是文中只有根据原始数据绘制的曲线图,没有数据。如下图所示。此时,如果能从文中把这幅图截取下来,输入到一个函数中去,最后能返回从图片中提取到的曲线的坐标数据,岂不美哉。2.MATLAB程...

二叉搜索树(二叉排序树)_爱敲代码的三毛的博客-程序员宅基地_二叉搜索树

一.概念二叉搜索树又称二叉排序树,具有以下性质:若它的左子树不为空,则左子树上所有节点的值都小于根节点的值若它的右子树不为空,则右子树上所有节点的值都大于根节点的值它的左右子树也分别为二叉搜索树注意:二叉搜索树中序遍历的结果是有序的二、基本操作1.查找元素思路:二叉搜索树的左子树永远是比根节点小的,而它的右子树则都是比根节点大的值。当前节点比要找的大就往左走,当前元素比要找的小就往右走 public Node search(int key) { if(root ==

动手学深度学习之文本预处理_哈哈哈捧场王的博客-程序员宅基地_文本预处理 深度学习

文本预处理import collectionsimport refrom d2l import torch as d2l将数据集读取到由文本行组成的列表中d2l.DATA_HUB['time_machine'] = (d2l.DATA_URL + 'timemachine.txt', '090b5e7e70c295757f55df93cb0a180b9691891a') # load一本书def read_time_machi

Python 把多个 MP4 合成一个视频_AI悦创的博客-程序员宅基地_python将多个视频合成一个视频

这两天群里有个小伙伴有一个需求, 就是把很多个视频文件 合并成一个. 期间也找了各种软件, 如格式工厂, 但是只能一次合成50个文件, 小伙伴有几千个文件需要合成, 太繁琐; 又比如会声会影, 这个剪辑是很强大, 但是软件也很大, 对电脑配置要求也高. 我只需要拼接功能, 割鸡焉用牛刀?人生苦短 我用 Python????转念一想, Python 也很擅长图形处理, 那处理视频也不在话下吧, 于是就上网搜了搜, 果然找到了简单的办法~开始安装使用主要是利用 moviepy 这个库, 里面提供了丰富的

推荐文章

热门文章

相关标签