EfficientDet 训练自己的数据集_efficientdet训练自己的数据-程序员宅基地

技术标签: python  深度学习  efficientDet  pytorch  json  

EfficientDet训练自己的数据集

项目安装

参考代码:https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch
安装及环境配置可参考作者介绍或者其他博客

数据准备

训练时需要将数据集转换为coco格式的数据集,本人使用的数据集为visdrone数据集,转换过程如下:txt->XML->coco.json

txt->XML

import os
from PIL import Image

# 把下面的路径改成你自己的路径即可
root_dir = "./VisDrone2019-DET-train/"
annotations_dir = root_dir+"annotations/"
image_dir = root_dir + "images/"
xml_dir = root_dir+"Annotations_XML/"
# 下面的类别也换成你自己数据类别,也可适用于其他的数据集转换
class_name = ['ignored regions','pedestrian','people','bicycle','car','van','truck','tricycle','awning-tricycle','bus','motor','others']

for filename in os.listdir(annotations_dir):
    fin = open(annotations_dir+filename, 'r')
    image_name = filename.split('.')[0]
    img = Image.open(image_dir+image_name+".jpg") # 若图像数据是“png”转换成“.png”即可
    xml_name = xml_dir+image_name+'.xml'
    with open(xml_name, 'w') as fout:
        fout.write('<annotation>'+'\n')
        
        fout.write('\t'+'<folder>VOC2007</folder>'+'\n')
        fout.write('\t'+'<filename>'+image_name+'.jpg'+'</filename>'+'\n')
        
        fout.write('\t'+'<source>'+'\n')
        fout.write('\t\t'+'<database>'+'VisDrone2018 Database'+'</database>'+'\n')
        fout.write('\t\t'+'<annotation>'+'VisDrone2018'+'</annotation>'+'\n')
        fout.write('\t\t'+'<image>'+'flickr'+'</image>'+'\n')
        fout.write('\t\t'+'<flickrid>'+'Unspecified'+'</flickrid>'+'\n')
        fout.write('\t'+'</source>'+'\n')
        
        fout.write('\t'+'<owner>'+'\n')
        fout.write('\t\t'+'<flickrid>'+'Haipeng Zhang'+'</flickrid>'+'\n')
        fout.write('\t\t'+'<name>'+'Haipeng Zhang'+'</name>'+'\n')
        fout.write('\t'+'</owner>'+'\n')
        
        fout.write('\t'+'<size>'+'\n')
        fout.write('\t\t'+'<width>'+str(img.size[0])+'</width>'+'\n')
        fout.write('\t\t'+'<height>'+str(img.size[1])+'</height>'+'\n')
        fout.write('\t\t'+'<depth>'+'3'+'</depth>'+'\n')
        fout.write('\t'+'</size>'+'\n')
        
        fout.write('\t'+'<segmented>'+'0'+'</segmented>'+'\n')

        for line in fin.readlines():
            line = line.split(',')
            fout.write('\t'+'<object>'+'\n')
            fout.write('\t\t'+'<name>'+class_name[int(line[5])]+'</name>'+'\n')
            fout.write('\t\t'+'<pose>'+'Unspecified'+'</pose>'+'\n')
            fout.write('\t\t'+'<truncated>'+line[6]+'</truncated>'+'\n')
            fout.write('\t\t'+'<difficult>'+str(int(line[7]))+'</difficult>'+'\n')
            fout.write('\t\t'+'<bndbox>'+'\n')
            fout.write('\t\t\t'+'<xmin>'+line[0]+'</xmin>'+'\n')
            fout.write('\t\t\t'+'<ymin>'+line[1]+'</ymin>'+'\n')
            # pay attention to this point!(0-based)
            fout.write('\t\t\t'+'<xmax>'+str(int(line[0])+int(line[2])-1)+'</xmax>'+'\n')
            fout.write('\t\t\t'+'<ymax>'+str(int(line[1])+int(line[3])-1)+'</ymax>'+'\n')
            fout.write('\t\t'+'</bndbox>'+'\n')
            fout.write('\t'+'</object>'+'\n')
             
        fin.close()
        fout.write('</annotation>')

XML->coco.json

    # coding=utf-8
import xml.etree.ElementTree as ET
import os
import json


voc_clses = ['aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor']


categories = []
for iind, cat in enumerate(voc_clses):
    cate = {
    }
    cate['supercategory'] = cat
    cate['name'] = cat
    cate['id'] = iind
    categories.append(cate)

def getimages(xmlname, id):
    sig_xml_box = []
    tree = ET.parse(xmlname)
    root = tree.getroot()
    images = {
    }
    for i in root:  # 遍历一级节点
        if i.tag == 'filename':
            file_name = i.text  # 0001.jpg
            # print('image name: ', file_name)
            images['file_name'] = file_name
        if i.tag == 'size':
            for j in i:
                if j.tag == 'width':
                    width = j.text
                    images['width'] = width
                if j.tag == 'height':
                    height = j.text
                    images['height'] = height
        if i.tag == 'object':
            for j in i:
                if j.tag == 'name':
                    cls_name = j.text
                cat_id = voc_clses.index(cls_name) + 1
                if j.tag == 'bndbox':
                    bbox = []
                    xmin = 0
                    ymin = 0
                    xmax = 0
                    ymax = 0
                    for r in j:
                        if r.tag == 'xmin':
                            xmin = eval(r.text)
                        if r.tag == 'ymin':
                            ymin = eval(r.text)
                        if r.tag == 'xmax':
                            xmax = eval(r.text)
                        if r.tag == 'ymax':
                            ymax = eval(r.text)
                    bbox.append(xmin)
                    bbox.append(ymin)
                    bbox.append(xmax - xmin)
                    bbox.append(ymax - ymin)
                    bbox.append(id)   # 保存当前box对应的image_id
                    bbox.append(cat_id)
                    # anno area
                    bbox.append((xmax - xmin) * (ymax - ymin) - 10.0)   # bbox的ares
                    # coco中的ares数值是 < w*h 的, 因为它其实是按segmentation的面积算的,所以我-10.0一下...
                    sig_xml_box.append(bbox)
                    # print('bbox', xmin, ymin, xmax - xmin, ymax - ymin, 'id', id, 'cls_id', cat_id)
    images['id'] = id
    # print ('sig_img_box', sig_xml_box)
    return images, sig_xml_box



def txt2list(txtfile):
    f = open(txtfile)
    l = []
    for line in f:
        l.append(line[:-1])
    return l


# voc2007xmls = 'anns'
voc2007xmls = '/data2/chenjia/data/VOCdevkit/VOC2007/Annotations'
# test_txt = 'voc2007/test.txt'
test_txt = '/data2/chenjia/data/VOCdevkit/VOC2007/ImageSets/Main/test.txt'
xml_names = txt2list(test_txt)
xmls = []
bboxes = []
ann_js = {
    }
for ind, xml_name in enumerate(xml_names):
    xmls.append(os.path.join(voc2007xmls, xml_name + '.xml'))
json_name = 'annotations/instances_voc2007val.json'
images = []
for i_index, xml_file in enumerate(xmls):
    image, sig_xml_bbox = getimages(xml_file, i_index)
    images.append(image)
    bboxes.extend(sig_xml_bbox)
ann_js['images'] = images
ann_js['categories'] = categories
annotations = []
for box_ind, box in enumerate(bboxes):
    anno = {
    }
    anno['image_id'] =  box[-3]
    anno['category_id'] = box[-2]
    anno['bbox'] = box[:-3]
    anno['id'] = box_ind
    anno['area'] = box[-1]
    anno['iscrowd'] = 0
    annotations.append(anno)
ann_js['annotations'] = annotations
json.dump(ann_js, open(json_name, 'w'), indent=4)  # indent=4 更加美观显示               

将生成的json及图片按照一下结构放置,注意修改json文件名称:

  • dadasets
    • visdrone2019
      • train2019
      • val2019
      • annotations
        • instances_train2019.json
        • instances_val2019.json

修改projects下coco.yml内容,按照自己的数据库情况修改

project_name: visdrone2019  # also the folder name of the dataset that under data_path folder
train_set: train2019
val_set: val2019
num_gpus: 1

# mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco.
mean: [0.373, 0.378, 0.364]
std: [0.191, 0.182, 0.194]

# this is coco anchors, change it if necessary
anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'

# must match your dataset's category_id.
# category_id is one_indexed,
# for example, index of 'car' here is 2, while category_id of is 3
obj_list: ["pedestrian","people","bicycle","car","van","truck","tricycle","awning-tricycle","bus","motor"]

训练

python train.py -c 2 --batch_size 8 --lr 1e-5 --num_epochs 10
–load_weights /path/to/your/weights/efficientdet-d2.pth

提前下载model文件,放置在文件夹中,建议d0,d1,d2(大了显存会溢出),如出现显存溢出情况,调整batch_size大小。
在这里插入图片描述

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/gu891221/article/details/105783719

智能推荐

Lua编程时遇到的一个错误:attempt to index upvalue-程序员宅基地

文章浏览阅读4k次。最近用OpenResty开发一个产品。新学Lua语言,写了不少脚本。前几天遇到这么一个运行时错误:2022/01/21 18:57:01 [error] 581744#0: *74 lua entry thread aborted: runtime error: /opt/lua/blacklist.lua:98: attempt to index upvalue 'actions' (a number value)简化一下,blacklist.lua代码如下:local actions =_attempt to index upvalue

【Unity API】3---GameObject_unity new gameobject()参数-程序员宅基地

文章浏览阅读425次。1.创建游戏物体的三种方法 public GameObject prefab; // Use this for initialization void Start () { //1.第一种创建方法 GameObject go = new GameObject("Cube"); //2.第二种 ,可以实例化特效或者角色等等 ..._unity new gameobject()参数

python知识图谱问答系统代码_医疗知识图谱问答系统探究(一)-程序员宅基地

文章浏览阅读522次。这是 阿拉灯神丁Vicky 的第 23 篇文章1、项目背景为通过项目实战增加对知识图谱的认识,几乎找了所有网上的开源项目及视频实战教程。果然,功夫不负有心人,找到了中科院软件所刘焕勇老师在github上的开源项目,基于知识图谱的医药领域问答项目QABasedOnMedicaKnowledgeGraph。用了两个晚上搭建了两套,Mac版与Windows版,哈哈,运行成功!!!从无到有搭建一个以疾病为..._chat_graph.py

hdu 3986 Harry Potter and the Final Battle(最短路+枚举删边)_3986 harry potter and the final battle 枚举+最短路(删掉任意-程序员宅基地

文章浏览阅读899次。Harry Potter and the Final BattleTime Limit: 5000/3000 MS (Java/Others) Memory Limit: 65536/65536 K (Java/Others)Total Submission(s): 1741 Accepted Submission(s): 487Problem Descript_3986 harry potter and the final battle 枚举+最短路(删掉任意一条边的最长最短

python开发节目程序_python获取央视节目信息-程序员宅基地

文章浏览阅读342次。# -*- coding: utf-8 -*-#---------------------------------------# 程序:cctv节目表抓取# 作者:lqf# 日期:2013-08-09# 语言:Python 2.7# 功能:抓取央视的节目列表信息#---------------------------------------import stringimport..._python获取电视直播节目单

如何用C语言实现OOP-程序员宅基地

文章浏览阅读1.6k次,点赞5次,收藏18次。我们知道面向对象的三大特性分别是:封装、继承、多态。很多语言例如:C++ 和 Java 等都是面向对象的编程语言,而我们通常说 C 是面向过程的语言,那么是否可以用 C 实现简单的面向对象..._c语言如何实现oop编程

随便推点

xamarin ios_使用Xamarin iOS和Android应用程序将数据存储在Google表格中-程序员宅基地

文章浏览阅读269次。xamarin ios 为您的应用程序打开机遇之门 (OPENS A WORLD OF OPPORTUNITIES FOR YOUR APPS)There have been questions on StackOverflow asking how to use Google Sheets to store data using Xamarin, as a quick and simple d..._xamarin里ios下载保存excel文件

Bean creation exception on FactoryBean type check: org.springframework.beans.factory.UnsatisfiedDepe_bean creation exception on non-lazy factorybean ty-程序员宅基地

文章浏览阅读1.9w次。Bean creation exception on FactoryBean type check: org.springframework.beans.factory.UnsatisfiedDependencyException: Error creating bean with name xxxspring bean创建失败我的是ssm项目,项目已启动就报错,mapper接口和mapper的..._bean creation exception on non-lazy factorybean type check: org.springframew

Dan Abramov - [Just JavaScript] 01 Mental Models(心智模型) 随便翻译一下_justjavascrpit-程序员宅基地

文章浏览阅读779次。是翻译的订阅邮件,非原创,下方有英文原文。看一下这段代码:let a = 10;let b = a;a = 0;运行后a和b的值是多少?在进一步阅读之前,先理解它。如果你已经写了一段时间的js,你可能会存疑:“我每天写的代码比这有难度多了,重点是啥?”本练习的目的不是要想你介绍这些变量,相反,假设你已经十分熟悉这些,本练习的目的是为了然你构建起相应的心智模型。什么是心智模型..._justjavascrpit

嵌入式系统的事件驱动型编程技术_[论文阅读笔记]区块链系统中智能合约技术综述...-程序员宅基地

文章浏览阅读276次。区块链系统中智能合约技术综述范吉立 李晓华 聂铁铮 于戈《计算机科学》2019年8月14页,56个参考文献框架1 引言2 区块链中的智能合约语言2.1 智能合约语言2.2 比特币脚本语言图2.3 以太坊灵完备型语言2.3.1 Solidity语言2.3.2 Serpent语言2.4 可验证型语言Pact2.5 超级账本智能合约语言2.6 开发语言的对比3 区块链中智能合约的实现技术3.1 嵌..._嵌入式事件驱动编程

python嵌入式开发实战_Python和PyQT来开发嵌入式ARM界面如何实现-程序员宅基地

文章浏览阅读386次。Python是一种跨平台的计算机程序设计语言。是一种面向对象的动态类型语言,最初被设计用于编写自动化脚本(shell),随着版本的不断更新和语言新功能的添加,越来越多被用于独立的、大型项目的开发1)。 简介随着Python在互联网人工智能领域的流行,大家也慢慢感受到Python开发的便利,本文就基于嵌入式ARM平台,介绍使用Python配合PyQT5模块来开发图形化应用程序。本文所演示的ARM平台..._qt for python可以写入嵌入式设备吗

python rabbitmq 多任务类型_rabbitmq常用的三种exchange类型和python库pika接入rabbitmq-程序员宅基地

文章浏览阅读108次。现在很多开源软件都提供了对应的web管理界面,rabbitmq也不例外,rabbitmq提供了一个web插件。当rabbit-server启动之后,即在浏览器中通过http://localhost:15672/地址访问页面,提供一个比命令rabbitmqctl更友好的学习rabbitmq的方式。可以简单方便的通过配置rabbitmq,并可以向exchange和queue中发消息来验证自己的理解。如..._python rabbitmq exchange_bind