RSSC-LLAVA:基于XTuner用遥感数据微调LLAVA模型-程序员宅基地

技术标签: 人工智能  

RSSC-LLAVA:基于XTuner用遥感数据微调LLAVA模型

项目介绍:

基于UCM场景分类数据集构建简单的对话文本,随后利用Xtuner微调LLAVA模型,实现遥感图文问答,主要是场景分类。
代码分享:https://github.com/biscuit279/RSSC-LLAVA

实现步骤

1.数据准备

下载UCM场景分类数据集,包含21个类别,每个类别有100张图片
LAVVA模型的微调分为两步,第一步是做文本和图像的特征对齐,第二步是图文问答
第一步的时候需要准备两个json文件:

1.图像的文本描述,原作者用GPT生成的数据集,场景分类的话可以用脚本模拟一套

在这里插入图片描述
使用GPT3.5生成代码,prompt如下:

User
帮我写一段python代码:输入是一个Images文件夹,Images包含多个子文件夹,子文件夹的名称是类别名,子文件夹中装有多张图片。输出是一个json文件,格式如下:
[
  {
    "id": "004539375",
    "image": "00453/004539375.jpg",
    "conversations": [
      {
        "from": "human",
        "value": "Render a clear and concise summary of the photo.\n<image>"
      },
      {
        "from": "gpt",
        "value": "select luxury furniture 3 - inch gel memory foam mattress topper"
      }
    ]
  },
  {
    "id": "002239345",
    "image": "00223/002239345.jpg",
    "conversations": [
      {
        "from": "human",
        "value": "Write a terse but informative summary of the picture.\n<image>"
      },
      {
        "from": "gpt",
        "value": "a grey watch with an army style strap"
      }
    ]
  }
id字段是图片的名称,image字段是文件夹名称/图片名称,conversion字段的from不变,第一个value改成"What is it",第二个value改成"This is a photo of {class}"

生成的代码为

import os
import json

def process_images(input_folder):
    output_data = []
    for root, dirs, files in os.walk(input_folder):
        for file in files:
            if file.endswith('.jpg'):
                image_id = os.path.splitext(file)[0]
                image_path = os.path.join(root, file)
                class_name = os.path.basename(root)
                
                conversation_human = {
                    "from": "human",
                    "value": "Write a terse but informative summary of the picture.\n<image>"
                }
                
                conversation_gpt = {
                    "from": "gpt",
                    "value": f"This is a photo of {class_name}"
                }
                
                output_data.append({
                    "id": image_id,
                    "image": f"{class_name[:5]}/{image_id}.jpg",
                    "conversations": [conversation_human, conversation_gpt]
                })
    return output_data

def save_to_json(data, output_file):
    with open(output_file, 'w') as f:
        json.dump(data, f, indent=2)

if __name__ == "__main__":
    input_folder = "Images"
    output_file = "output.json"

    images_data = process_images(input_folder)
    save_to_json(images_data, output_file)

相关脚本已上传至github
2.meta数据集,记录图片的基本信息,主要是名称,URL,以及blip生成的caption
在这里插入图片描述
这个可能用不到,暂时不管

2.修改代码

主要修改config文件中的数据路径以及evaluation案例
随便选了一个数据集中的图片作为evaluation案例,未修改evaluattion_inputs
在这里插入图片描述
在这里插入图片描述

3.环境安装

从github下载xtuner项目并安装

git clone https://github.com/InternLM/xtuner.git
pip install -e '.[all]'

4.pretrain

修改好了以后,运行xtuner train llava_internlm_chat_7b_clip_vit_large_p14_336_e1_gpu8_pretrain --deepspeed deepspeed_zero2
会自动下载internlm7b模型
但是遇到了一个报错,似乎是网络问题
在这里插入图片描述
再次运行相同的指令,竟然无法复现这个错误,变成了一堆莫名其妙的提示
在这里插入图片描述
未完待续。。

2.4更新:
上次的报错其实就是因为那个warning,pandas包没装好,卸载重装即可解决

pip uninstall pandas
pip install pandas

注意在训练时一定要加` --deepspeed deepspeed_zero2``否则会报数据类型不匹配的错误
成功运行:
在这里插入图片描述
2100个图片文本对,一个epoch只需要5分钟左右,占31980m显存

换成56k,70个类别的数据集,用同样的脚本生成json,再次运行

5.finetune

finetune阶段每张图片需要有五组对应的问答,将生成pretrain数据的代码稍作修改,添加其他几种template即可。
总共选择11种问题的模板,9种回答的模板,每组对话的QA都是从模板中随机挑选的。
注意第一组QA中应该有"\n“

import os
import json
import random

Q_templates = ["Describe the image concisely.",
"Provide a brief description of the given image.",
"Offer a succinct explanation of the picture presented.",
"Summarize the visual content of the image.",
"Give a short and clear explanation of the subsequent image.",
"Share a concise interpretation of the image provided.",
"Present a compact description of the photo's key features.",
"Relay a brief, clear account of the picture shown.",
"Render a clear and concise summary of the photo.",
"Write a terse but informative summary of the picture.",
"Create a compact narrative representing the image presented."]

A_templates = ["This is a photo of a {}.",
"This is a satellite image of a {}. ",
"This is a land use image of a {}. ",
"This is a remote sensing image of a {}.", 
"Here is an aerial picture depicting {}.",
"Displayed is an aerial photo illustrating {}.",
"This image captures the aerial perspective of {}.",
"Presented is an aerial view of {}.",
"This picture shows {} from an aerial vantage point."]


def process_images(input_folder):
    output_data = []
    for root, dirs, files in os.walk(input_folder):
        for file in files:
            if file.endswith('.jpg'):
                image_id = os.path.splitext(file)[0]
                image_path = os.path.join(root, file)
                class_name = os.path.basename(root)
                
                conversation_human = {
                    "from": "human",
                    "value": "Write a terse but informative summary of the picture./n<image>"
                }
                
                conversation_gpt = {
                    "from": "gpt",
                    "value": f"This is an aerial image of {class_name}."
                }
                
                conversations = [conversation_human, conversation_gpt]
                Q_samples = random.sample(Q_templates, 4)
                A_samples = random.sample(A_templates, 4)

                for i in range(4):
                    conversation_human = {
                    "from": "human",
                    "value": f"{Q_samples[i]}"
                    }
                    conversation_gpt = {
                    "from": "gpt",
                    "value": f"{A_samples[i]}".format(class_name)
                    
                    }
                    # conversation_gpt['value'].replace('{class_name}', class_name)
                    # import ipdb;ipdb.set_trace()
                    conversations.append(conversation_human)
                    conversations.append(conversation_gpt)

                output_data.append({
                    "id": image_id,
                    "image": f"{image_id}.jpg",
                    "conversations": conversations
                })

    return output_data

def save_to_json(data, output_file):
    with open(output_file, 'w') as f:
        json.dump(data, f, indent=2)

得到的数据例如:

[
  {
    "id": "airplane1",
    "image": "airplane1.jpg",
    "conversations": [
      {
        "from": "human",
        "value": "Write a terse but informative summary of the picture./n<image>"
      },
      {
        "from": "gpt",
        "value": "This is an aerial image of airplane."
      },
      {
        "from": "human",
        "value": "Describe the image concisely."
      },
      {
        "from": "gpt",
        "value": "Presented is an aerial view of airplane."
      },
      {
        "from": "human",
        "value": "Provide a brief description of the given image."
      },
      {
        "from": "gpt",
        "value": "Displayed is an aerial photo illustrating airplane."
      },
      {
        "from": "human",
        "value": "Offer a succinct explanation of the picture presented."
      },
      {
        "from": "gpt",
        "value": "This is a land use image of a airplane. "
      },
      {
        "from": "human",
        "value": "Summarize the visual content of the image."
      },
      {
        "from": "gpt",
        "value": "This is a photo of a airplane."
      }
    ]
  }
  ]

将生成好的数据按以下格式组织:
在这里插入图片描述
随后修改finetune阶段的配置文件llava_internlm_chat_7b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune.py
仍然是需要修改数据、模型的路径和测试图片的位置
随后运行NPROC_PER_NODE=8 xtuner train llava_internlm_chat_7b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune --deepspeed deepspeed_zero2
将8改成实际的显卡数量
这一步使用qlora方法同时微调llm和vit,将会得到两个adapter

发现同一个工作空间,运行finetune的时候还需要重新下载一遍internlm7b模型,猜测应该是下载完之后直接加载,加载结束后就删除了本地文件。更正:默认会下载到/root/.cache/huggingface/hub/下
可以使用以下代码手动下载文件,这样下次运行就不需要重复下载了。

cd ~/RSSC
apt install git git-lfs -y
git lfs install
git lfs clone https://modelscope.cn/Shanghai_AI_Laboratory/internlm-chat-7b.git -b v1.0.3

如果是手动下载的,需要修改模型位置参数llm_name_or_path为模型存放的路径
在这里插入图片描述
但手动下载可能会报一个KeyError,可能是internlm模型的版本问题,暂时未找到合适的解决方案

仍然采用自动下载的方式。8卡A800跑56k张图片数据要一个小时
在这里插入图片描述

6.模型合并与部署

转换成huggingface格式:

xtuner convert pth_to_hf llava_internlm_chat_7b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune ./work_dirs/llava_internlm_chat_7b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune/iter_875.pth/ ./work_dirs/iter_875_hf

可以把config文件的指定成本地py文件,然后修改py文件中的模型路径,从而避免重复缓存模型
格式转换完成后,将会得到llm_adapter,projector,visual_encoder_adapter,可以分别与llm和vit合并
在这里插入图片描述

转换完成后,就可以进行对话了,分别输入llm模型,视觉模型,hf格式的llava模型,以及图片的路径,即可开始对话
xtuner chat internlm/internlm-chat-7b
–visual-encoder openai/clip-vit-large-patch14-336
–llava xtuner/llava-internlm-7b
–prompt-template internlm_chat
–image $IMAGE_PATH

发现用不同的问题,都可以输出正确的场景分类结果。
在这里插入图片描述

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_46212981/article/details/135972166

智能推荐

文本分类特征提取之Word2Vec-程序员宅基地

文章浏览阅读4.4w次,点赞11次,收藏56次。分类问题是人类所面临的一个非常重要且具有普遍意义的问题,我们生活中的很多问题归根到底都是分类问题。文本分类就是根据文本内容将其分到合适的类别,它是自然语言处理的一个十分重要的问题。文本分类主要应用于信息检索,机器翻译,自动文摘,信息过滤,邮件分类等任务。文本分类技术发展历史 1960-1970:那时主要通过人工+规则(关键词或者正则表达式)的方式,制定规则的人需要对某类目领域有足够的认知和了解。举_文本特征提取word2vec

libevent高并发网络编程 - 06_基于libevent的C++线程池实现_windows c++ 开发 客户端 libevent-程序员宅基地

文章浏览阅读1k次。本文利用libevent,实现一个C++线程池,,可自定义用户任务类,继承于任务task基类,重写任务基类的纯虚函数实现多态。比如将定义定义处理客户端的请求任务类,实现对客户端请求的并发处理。工作队列:可以理解为线程的队列,一个线程同时可以处理一个任务,空闲的线程回从任务队列取出任务执行。当工作队列空时,线程会睡眠。任务队列:用户将任务加入任务队列,然后通知工作队列,取出一个任务到线程中执行。_windows c++ 开发 客户端 libevent

工作缺点和不足及措施_【工作中存在的问题和不足及改进措施】_工作中的不足与改进_工作中不足及改进措施...-程序员宅基地

文章浏览阅读3.4w次,点赞3次,收藏11次。篇一:《工作中存在的不足及改进措施》通过近一段时间的工作,反省自身,还存在许多不足和缺点,现将近期的工作、学习中存在的不足和缺点简要总结如下:1、自身的专业业务水平不高,事故应急处理能力不强.虽然通过学习和工作经验的积累,在业务水平上有了一定的提高,但业务水平和工作经验与其它老同志比还是比较低.在日常工作中偏重于日常生产工作,也忽视了自身思想素质的提高,工作中争强当先的意识不强.2、工作上满足于正..._工作不足之处及改进措施

java读取大数据量Excel按需读取(按需加载,速度快)_java 读取大文件excel-程序员宅基地

文章浏览阅读2k次。常用的poi工具,如easy-excel,hutool读取excel是都是先将整个excel加载到内存中分析,然后再一行行遍历,当excel文件太大时读取的时间就会更长,如果我们只需要读取excel的前几行来进行预览就不能使用这种方式,应该按需读取。_java 读取大文件excel

HTML_常用标签测试_html标签检测-程序员宅基地

文章浏览阅读237次。HTML_常用标签测试_html标签检测

【优化模型】牛顿法求解非线性方程组-程序员宅基地

文章浏览阅读482次。牛顿法是一种用于求解非线性方程组的迭代优化方法。其基本原理是基于泰勒级数展开和一阶导数的近似,通过不断迭代修正初始猜测解来逼近方程组的解。Fx0其中,Fxf1​xf2​x...fn​xT是一个多元函数,xx1​x2​...xn​T是待求解的变量向量。牛顿法的基本思想是,在当前的迭代点xk​处,用一个一阶泰勒展开来近似fi​xfi​x≈fi​xk​j1∑n​∂xj​∂fi​xk。

随便推点

克里金插值法(kringing)与PHPnow集成开发环境_后端克里金插值分析-程序员宅基地

文章浏览阅读815次。文章目录摘要摘要_后端克里金插值分析

使用有道云笔记的三个技巧_有道云笔记如何建立 文档索引-程序员宅基地

文章浏览阅读3.3w次,点赞10次,收藏36次。我们在 Windows 操作系统中写文档,做笔记,通常使用 Windows 自带的记事本,可是记事本不支持插入图片,创建表格等功能,从而不得不使用 Office Word。不知道大家有没有这样的感觉,使用 Office Word 写文档,效率极低,需要一边敲字,一边使用鼠标排版,比如:在文章中给团队的名字“LSGO软件技术团队”加粗,就需要先用鼠标选中这个词语,然后点击工具栏中“B”形状的工具..._有道云笔记如何建立 文档索引

IP-guard 远程命令执行漏洞_ipg 漏洞-程序员宅基地

文章浏览阅读137次。IP-guard 远程命令执行漏洞_ipg 漏洞

IOT时代,数据安全更无侥幸-程序员宅基地

文章浏览阅读255次。2017年,全球数据泄露事件已不仅是呈翻倍的速度增长。16年的14亿条,到17年仅上半年的17亿条,这样的数据泄露规模你是否还在存在侥幸心理,就是那所谓的“怎么可能刚好落在我身上”。随着我们在工作、生活中的云化,就在今天,万物互联已经融入到我们每个人的生活中,相信在不就的将来,整个IOT时代也将会很快的到来。仔细回忆一下,今天我们所做的任何情都离不..._8,iot时代,数据安全有哪些新特征?

MySQL 详细学习教程【万字长文, 建议收藏】_mysql教程-程序员宅基地

文章浏览阅读6.7k次,点赞47次,收藏143次。存放文本时,也可以使用Text数据类型,可以将TEXT列视为VARCHAR列,注意Text不能有默认值,大小0-2^16字节;同一查询在同一事务中多次进行,由于其它提交事务所做的修改和删除,每次返回不同的结果集,则发生不可重复读;多个连接开启各自事务操作数据库中数据时,数据库系统要负责隔离操作,以保证各个连接在获取数据是的准确性;同一查询在同一个事务中多次执行,由于其它提交事务所做的插入操作,每次返回不同的结果集,此时发生幻读;同真是的表一样,视图包含列,其数据来自对应的真实表(基表)_mysql教程

GD32官方开发环境及固件库使用笔记(一)_gd32e23 开发环境-程序员宅基地

文章浏览阅读550次,点赞10次,收藏6次。GD32官方的开发环境(基于Eclipse)的使用。_gd32e23 开发环境

推荐文章

热门文章

相关标签