yolov5损失函数的几点理解_yolov5 为什么sigmod()*2-程序员宅基地

技术标签: 重要  

yolov5损失函数的几点理解

所用代码:https://github.com/ultralytics/yolov5
参考文献:https://www.cnblogs.com/pprp/p/12590801.html
感谢知乎网友:Ancy贝贝

重要的代码块在build_targets内。

def build_targets(p, targets, model):
    # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
    det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module
    na, nt = det.na, targets.shape[0]  # number of anchors, targets
    tcls, tbox, indices, anch = [], [], [], []
    gain = torch.ones(7, device=targets.device)  # normalized to gridspace gain
    ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
    #将targets复制3份,每份分配一个anchor编号,如0,1,2. 也就是每个anchor分配一份targets。
    targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices

    g = 0.5  # bias
    # 这里off表示了5个偏移,原点不动,往右、往下、往左、往上。
    # 其中坐标原点在图像的左上角,x轴往右(列),y轴往下(行)。
    off = torch.tensor([[0, 0],
                        [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                        # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
                        ], device=targets.device).float() * g  # offsets

    for i in range(det.nl):
        #det.anchors在导入model的时候就除以了步长,因此此时anchor大小不是相对于原图,而是相对于对应特征层的尺寸
        anchors = det.anchors[i]
        gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain

        # Match targets to anchors
        #这里主要是将gt的cx,cy,w,h换算到当前特征层对应的尺寸,以便和该层的anchor大小相对应
        t = targets * gain
        if nt:
            # Matches
            #这个部分是计算gt和anchor的匹配程度
            #即w_gt/w_anchor  h_gt/h_anchor
            r = t[:, :, 4:6] / anchors[:, None]  # wh ratio
            #这里判断了r和1/r与model.hyp['anchor_t']的大小关系,即只有不大于这个数,也就是说gt与anchor的宽高差距不过大的时候,才认为匹配。代码中 model.hyp['anchor_t']=4
            j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t']  # compare       
            # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
			#将满足条件的targets筛选出来。          
            t = t[j]  # filter
            
            # Offsets
            #这个部分就是扩充targets的数量,将比较targets附近的4个点,选取最近的2个点作为新targets中心,新targets的w、h使用与原targets一致,只是中心点坐标的不同。
            gxy = t[:, 2:4]  # grid xy
            gxi = gain[[2, 3]] - gxy  # inverse
            j, k = ((gxy % 1. < g) & (gxy > 1.)).T
            l, m = ((gxi % 1. < g) & (gxi > 1.)).T
            j = torch.stack((torch.ones_like(j), j, k, l, m))
            t = t.repeat((5, 1, 1))[j] #筛选后t的数量是原来t的3倍。
            offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
        else:
            t = targets[0]
            offsets = 0

        # Define
        b, c = t[:, :2].long().T  # image, class
        gxy = t[:, 2:4]  # grid xy
        gwh = t[:, 4:6]  # grid wh
        gij = (gxy - offsets) #自己加的代码,方便查看gij的分布。
         plot_gxy(gxy=gij, scale_i=i, size=gain, flag='gij') #自己编的代码,用于查看gij的分布。
        gij = (gxy - offsets).long() #将所有targets中心点坐标进行偏移。

        gi, gj = gij.T  # grid xy indices

        # Append
        a = t[:, 6].long()  # anchor indices
        indices.append((b, a, gj, gi))  # image, anchor, grid indices
        tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
        anch.append(anchors[a])  # anchors
        tcls.append(c)  # class

    return tcls, tbox, indices, anch

下图是20x20的特征图上的gij的分布示意图,从图中可以看出每个targets都扩充了2个临近的targets。关于为什么扩充,我还没理解,有知道的网友欢迎留言。另外,知乎网友Ancy贝贝的理解是:之前通过筛选,去掉了一些匹配不上anchor的gt,本来正样本就比负样本少很多,经过筛选,少得更多了,所以每个gt扩充2个出来,增加正样本比例。
在这里插入图片描述

    # Regression
    pxy = ps[:, :2].sigmoid() * 2. - 0.5
    pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
    pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box
    giou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # giou(prediction, target)
    lbox += (1.0 - giou).mean()  # giou loss

在这里插入图片描述
代码中的pxy对应bxy,ps[:, :2]对应txy。由此可知bxy的取值范围是[-0.5,1.5]。因此有可能偏移到临近的单元格内,但偏移不多,不知道作者是什么考虑的。

在这里插入图片描述
代码中的pwh对应bwh,anchors[i]对应Pwh。因此可知bwh的范围是[0,4]*Pwh。这和前面
j = torch.max(r, 1. / r).max(2)[0] < model.hyp[‘anchor_t’] # model.hyp[‘anchor_t’]=4 是一致的。

Objectness

tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype) # giou ratio
此处 tobj[b, a, gj, gi]用giou(真实的是ciou)取代1,代表该点对应置信度。为什么要用giou来代替,我也没想明白,有知道的网友欢迎留言。

其余的部分比较好理解,在此不再赘述。

附:
plot_gxy的代码:

def plot_gxy(gxy, scale_i, size, flag):
    s = int(size[2].cpu().numpy())
    ax = plt.subplot(111)
    ax.axis([0, s, 0, s])
    lxx = np.arange(0, s + 1, 1)
    lxx = np.repeat(lxx, s + 1, axis=0)
    lxx = lxx.reshape(s + 1, s + 1)
    lyy = np.arange(0, s + 1, 1)
    lyy = np.repeat(lyy, s + 1, axis=0)
    lyy = lyy.reshape(s + 1, s + 1)
    lyy = lyy.T


    for i in range(len(lxx)):
        plt.plot(lxx[i], lyy[i], color='k', linewidth=0.05, linestyle='-')
        plt.plot(lyy[i], lxx[i], color='k', linewidth=0.05, linestyle='-')

    for i in range(len(gxy)):
        x1, y1 = gxy.cpu().numpy().T
        plt.scatter(x1, y1, s=0.02, color='k')

    ax = plt.gca()  # 获取到当前坐标轴信息
    ax.xaxis.set_ticks_position('top')  # 将X坐标轴移到上面
    ax.invert_yaxis()
    plt.savefig("gxy_{}_{}.png".format(scale_i, flag))
    plt.close()
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/tpz789/article/details/108844004

智能推荐

生活垃圾数据集(YOLO版)_垃圾回收数据集-程序员宅基地

文章浏览阅读1.6k次,点赞5次,收藏20次。【有害垃圾】:电池(1 号、2 号、5 号)、过期药品或内包装等;【可回收垃圾】:易拉罐、小号矿泉水瓶;【厨余垃圾】:小土豆、切过的白萝卜、胡萝卜,尺寸为电池大小;【其他垃圾】:瓷片、鹅卵石(小土豆大小)、砖块等。文件结构|----classes.txt # 标签种类|----data-txt\ # 数据集文件集合|----images\ # 数据集图片|----labels\ # yolo标签。_垃圾回收数据集

天气系统3------微服务_cityid=101280803-程序员宅基地

文章浏览阅读272次。之前写到 通过封装的API 已经可以做到使用redis进行缓存天气信息但是这一操作每次都由客户使用时才进行更新 不友好 所以应该自己实现半小时的定时存入redis 使用quartz框架 首先添加依赖build.gradle中// Quartz compile('org.springframework.boot:spring-boot-starter-quartz'..._cityid=101280803

python wxpython 不同Frame 之间的参数传递_wxpython frame.bind-程序员宅基地

文章浏览阅读1.8k次,点赞2次,收藏8次。对于使用触发事件来反应的按钮传递参数如下:可以通过lambda对function的参数传递:t.Bind(wx.EVT_BUTTON, lambda x, textctrl=t: self.input_fun(event=x, textctrl=textctrl))前提需要self.input_fun(self,event,t):传入参数而同时两个Frame之间的参数传..._wxpython frame.bind

cocos小游戏开发总结-程序员宅基地

文章浏览阅读1.9k次。最近接到一个任务要开发消消乐小游戏,当然首先就想到乐cocosCreator来作为开发工具。开发本身倒没有多少难点。消消乐的开发官网发行的书上有专门讲到。下面主要总结一下开发中遇到的问题以及解决方法屏幕适配由于设计尺寸是750*1336,如果适应高度,则在iphonX下,内容会超出屏幕宽度。按宽适应,iphon4下内容会超出屏幕高度。所以就需要根据屏幕比例来动态设置适配策略。 onLoad..._750*1336

ssm435银行贷款管理系统+vue_vue3重构信贷管理系统-程序员宅基地

文章浏览阅读745次,点赞21次,收藏21次。web项目的框架,通常更简单的数据源。21世纪的今天,随着社会的不断发展与进步,人们对于信息科学化的认识,已由低层次向高层次发展,由原来的感性认识向理性认识提高,管理工作的重要性已逐渐被人们所认识,科学化的管理,使信息存储达到准确、快速、完善,并能提高工作管理效率,促进其发展。论文主要是对银行贷款管理系统进行了介绍,包括研究的现状,还有涉及的开发背景,然后还对系统的设计目标进行了论述,还有系统的需求,以及整个的设计方案,对系统的设计以及实现,也都论述的比较细致,最后对银行贷款管理系统进行了一些具体测试。_vue3重构信贷管理系统

乌龟棋 题解-程序员宅基地

文章浏览阅读774次。题目描述原题目戳这里小明过生日的时候,爸爸送给他一副乌龟棋当作礼物。乌龟棋的棋盘是一行 NNN 个格子,每个格子上一个分数(非负整数)。棋盘第 111 格是唯一的起点,第 NNN 格是终点,游戏要求玩家控制一个乌龟棋子从起点出发走到终点。乌龟棋中 MMM 张爬行卡片,分成 444 种不同的类型( MMM 张卡片中不一定包含所有 444 种类型的卡片,见样例),每种类型的卡片上分别标有 1,2,3,41, 2, 3, 41,2,3,4 四个数字之一,表示使用这种卡片后,乌龟棋子将向前爬行相应的格子数

随便推点

python内存泄露的原因_Python服务端内存泄露的处理过程-程序员宅基地

文章浏览阅读1.5k次。吐槽内存泄露 ? 内存暴涨 ? OOM ?首先提一下我自己曾经历过多次内存泄露,到底有几次? 我自己心里悲伤的回想了下,造成线上影响的内存泄露事件有将近5次了,没上线就查出内存暴涨次数可能更多。这次不是最惨,相信也不会是最后的内存的泄露。有人说,内存泄露对于程序员来说,是个好事,也是个坏事。 怎么说? 好事在于,技术又有所长进,经验有所心得…. 毕竟不是所有程序员都写过OOM的服务…. 坏事..._python内存泄露

Sensor (draft)_draft sensor-程序员宅基地

文章浏览阅读747次。1.sensor typeTYPE_ACCELEROMETER=1 TYPE_MAGNETIC_FIELD=2 (what's value mean at x and z axis)TYPE_ORIENTATION=3TYPE_GYROSCOPE=4 TYPE_LIGHT=5(in )TYPE_PRESSURE=6TYPE_TEMPERATURE=7TYPE_PRO_draft sensor

【刘庆源码共享】稀疏线性系统求解算法MGMRES(m) 之 矩阵类定义三(C++)_gmres不构造矩阵-程序员宅基地

文章浏览阅读581次。/* * Copyright (c) 2009 湖南师范大学数计院 一心飞翔项目组 * All Right Reserved * * 文件名:matrix.cpp 定义Point、Node、Matrix类的各个方法 * 摘 要:定义矩阵类,包括矩阵的相关信息和方法 * * 作 者:刘 庆 * 修改日期:2009年7月19日21:15:12 **/

三分钟带你看完HTML5增强的【iframe元素】_iframe allow-top-navigation-程序员宅基地

文章浏览阅读1.7w次,点赞6次,收藏20次。HTML不再推荐页面中使用框架集,因此HTML5删除了&lt;frameset&gt;、&lt;frame&gt;和&lt;noframes&gt;这三个元素。不过HTML5还保留了&lt;iframe&gt;元素,该元素可以在普通的HTML页面中使用,生成一个行内框架,可以直接放在HTML页面的任意位置。除了指定id、class和style之外,还可以指定如下属性:src 指定一个UR..._iframe allow-top-navigation

Java之 Spring Cloud 微服务的链路追踪 Sleuth 和 Zipkin(第三个阶段)【三】【SpringBoot项目实现商品服务器端是调用】-程序员宅基地

文章浏览阅读785次,点赞29次,收藏12次。Zipkin 是 Twitter 的一个开源项目,它基于 Google Dapper 实现,它致力于收集服务的定时数据,以解决微服务架构中的延迟问题,包括数据的收集、存储、查找和展现。我们可以使用它来收集各个服务器上请求链路的跟踪数据,并通过它提供的 REST API 接口来辅助我们查询跟踪数据以实现对分布式系统的监控程序,从而及时地发现系统中出现的延迟升高问题并找出系统性能瓶颈的根源。除了面向开发的 API 接口之外,它也提供了方便的 UI 组件来帮助我们直观的搜索跟踪信息和分析请求链路明细,

烁博科技|浅谈视频安全监控行业发展_2018年8月由于某知名视频监控厂商多款摄像机存在安全漏洞-程序员宅基地

文章浏览阅读358次。“随着天网工程的建设,中国已经建成世界上规模最大的视频监控网,摄像头总 数超过2000万个,成为世界上最安全的国家。视频图像及配套数据已经应用在反恐维稳、治安防控、侦查破案、交通行政管理、服务民生等各行业各领域。烁博科技视频安全核心能力:精准智能数据采集能力:在建设之初即以应用需求为导向,开展点位选择、设备选型等布建工作,实现前端采集设备的精细化部署。随需而动的AI数据挖掘能力:让AI所需要的算力、算法、数据、服务都在应用需求的牵引下实现合理的调度,实现解析能力的最大化。完善的数据治理能力:面_2018年8月由于某知名视频监控厂商多款摄像机存在安全漏洞