ONNX Runtime介绍_onnxruntime-程序员宅基地

技术标签: PyTorch  ONNX Runtime  Deep Learning  

      ONNX Runtime:由微软推出,用于优化和加速机器学习推理和训练,适用于ONNX模型,是一个跨平台推理和训练机器学习加速器(ONNX Runtime is a cross-platform inference and training machine-learning accelerator),源码地址:https://github.com/microsoft/onnxruntime,最新发布版本为v1.11.1,License为MIT:

      1.ONNX Runtime Inferencing:高性能推理引擎

      (1).可在不同的操作系统上运行,包括Windows、Linux、Mac、Android、iOS等;

      (2).可利用硬件增加性能,包括CUDA、TensorRT、DirectML、OpenVINO等;

      (3).支持PyTorch、TensorFlow等深度学习框架的模型,需先调用相应接口转换为ONNX模型;

      (4).在Python中训练,确可部署到C++/Java等应用程序中。

      2.ONNX Runtime Training:于2021年4月发布,可加快PyTorch对模型训练,可通过CUDA加速,目前多用于Linux平台。

      通过conda命令安装执行:

conda install -c conda-forge onnxruntime

      以下为测试代码:通过ResNet-50对图像进行分类

import numpy as np
import onnxruntime
import onnx
from onnx import numpy_helper
import urllib.request
import os
import tarfile
import json
import cv2

# reference: https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/inference_demos/resnet50_modelzoo_onnxruntime_inference.ipynb
def download_onnx_model():
    labels_file_name = "imagenet-simple-labels.json"
    model_tar_name = "resnet50v2.tar.gz"
    model_directory_name = "resnet50v2"

    if os.path.exists(model_tar_name) and os.path.exists(labels_file_name):
        print("files exist, don't need to download")
    else:
        print("files don't exist, need to download ...")

        onnx_model_url = "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.tar.gz"
        imagenet_labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"

        # retrieve our model from the ONNX model zoo
        urllib.request.urlretrieve(onnx_model_url, filename=model_tar_name)
        urllib.request.urlretrieve(imagenet_labels_url, filename=labels_file_name)

        print("download completed, start decompress ...")
        file = tarfile.open(model_tar_name)
        file.extractall("./")
        file.close()

    return model_directory_name, labels_file_name

def load_labels(path):
    with open(path) as f:
        data = json.load(f)
    return np.asarray(data)

def images_preprocess(images_path, images_name):
    input_data = []

    for name in images_name:
        img = cv2.imread(images_path + name)
        img = cv2.resize(img, (224, 224))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        data = np.array(img).transpose(2, 0, 1)
        #print(f"name: {name}, opencv image shape(h,w,c): {img.shape}, transpose shape(c,h,w): {data.shape}")
        # convert the input data into the float32 input
        data = data.astype('float32')

        # normalize
        mean_vec = np.array([0.485, 0.456, 0.406])
        stddev_vec = np.array([0.229, 0.224, 0.225])
        norm_data = np.zeros(data.shape).astype('float32')
        for i in range(data.shape[0]):
            norm_data[i,:,:] = (data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]

        # add batch channel
        norm_data = norm_data.reshape(1, 3, 224, 224).astype('float32')
        input_data.append(norm_data)

    return input_data

def softmax(x):
    x = x.reshape(-1)
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

def postprocess(result):
    return softmax(np.array(result)).tolist()

def inference(onnx_model, labels, input_data, images_name, images_label):
    session = onnxruntime.InferenceSession(onnx_model, None)
    # get the name of the first input of the model
    input_name = session.get_inputs()[0].name
    count = 0
    for data in input_data:
        print(f"{count+1}. image name: {images_name[count]}, actual value: {images_label[count]}")
        count += 1

        raw_result = session.run([], {input_name: data})

        res = postprocess(raw_result)

        idx = np.argmax(res)
        print(f"  result: idx: {idx}, label: {labels[idx]}, percentage: {round(res[idx]*100, 4)}%")

        sort_idx = np.flip(np.squeeze(np.argsort(res)))
        print("  top 5 labels are:", labels[sort_idx[:5]])

def main():
    model_directory_name, labels_file_name = download_onnx_model()

    labels = load_labels(labels_file_name)
    print("the number of categories is:", len(labels)) # 1000

    images_path = "../../data/image/"
    images_name = ["5.jpg", "6.jpg", "7.jpg", "8.jpg", "9.jpg", "10.jpg"]
    images_label = ["goldfish", "hen", "ostrich", "crocodile", "goose", "sheep"]
    if len(images_name) != len(images_label):
        print("Error: images count and labes'length don't match")
        return

    input_data = images_preprocess(images_path, images_name)

    onnx_model = model_directory_name + "/resnet50v2.onnx"
    inference(onnx_model, labels, input_data, images_name, images_label)

    print("test finish")

if __name__ == "__main__":
    main()

      测试图像如下所示:

      执行结果如下所示:

 

      GitHub: https://github.com/fengbingchun/PyTorch_Test

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/fengbingchun/article/details/125951896

智能推荐

874计算机科学基础综合,2018年四川大学874计算机科学专业基础综合之计算机操作系统考研仿真模拟五套题...-程序员宅基地

文章浏览阅读1.1k次。一、选择题1. 串行接口是指( )。A. 接口与系统总线之间串行传送,接口与I/0设备之间串行传送B. 接口与系统总线之间串行传送,接口与1/0设备之间并行传送C. 接口与系统总线之间并行传送,接口与I/0设备之间串行传送D. 接口与系统总线之间并行传送,接口与I/0设备之间并行传送【答案】C2. 最容易造成很多小碎片的可变分区分配算法是( )。A. 首次适应算法B. 最佳适应算法..._874 计算机科学专业基础综合题型

XShell连接失败:Could not connect to '192.168.191.128' (port 22): Connection failed._could not connect to '192.168.17.128' (port 22): c-程序员宅基地

文章浏览阅读9.7k次,点赞5次,收藏15次。连接xshell失败,报错如下图,怎么解决呢。1、通过ps -e|grep ssh命令判断是否安装ssh服务2、如果只有客户端安装了,服务器没有安装,则需要安装ssh服务器,命令:apt-get install openssh-server3、安装成功之后,启动ssh服务,命令:/etc/init.d/ssh start4、通过ps -e|grep ssh命令再次判断是否正确启动..._could not connect to '192.168.17.128' (port 22): connection failed.

杰理之KeyPage【篇】_杰理 空白芯片 烧入key文件-程序员宅基地

文章浏览阅读209次。00000000_杰理 空白芯片 烧入key文件

一文读懂ChatGPT,满足你对chatGPT的好奇心_引发对chatgpt兴趣的表述-程序员宅基地

文章浏览阅读475次。2023年初,“ChatGPT”一词在社交媒体上引起了热议,人们纷纷探讨它的本质和对社会的影响。就连央视新闻也对此进行了报道。作为新传专业的前沿人士,我们当然不能忽视这一热点。本文将全面解析ChatGPT,打开“技术黑箱”,探讨它对新闻与传播领域的影响。_引发对chatgpt兴趣的表述

中文字符频率统计python_用Python数据分析方法进行汉字声调频率统计分析-程序员宅基地

文章浏览阅读259次。用Python数据分析方法进行汉字声调频率统计分析木合塔尔·沙地克;布合力齐姑丽·瓦斯力【期刊名称】《电脑知识与技术》【年(卷),期】2017(013)035【摘要】该文首先用Python程序,自动获取基本汉字字符集中的所有汉字,然后用汉字拼音转换工具pypinyin把所有汉字转换成拼音,最后根据所有汉字的拼音声调,统计并可视化拼音声调的占比.【总页数】2页(13-14)【关键词】数据分析;数据可..._汉字声调频率统计

linux输出信息调试信息重定向-程序员宅基地

文章浏览阅读64次。最近在做一个android系统移植的项目,所使用的开发板com1是调试串口,就是说会有uboot和kernel的调试信息打印在com1上(ttySAC0)。因为后期要使用ttySAC0作为上层应用通信串口,所以要把所有的调试信息都给去掉。参考网上的几篇文章,自己做了如下修改,终于把调试信息重定向到ttySAC1上了,在这做下记录。参考文章有:http://blog.csdn.net/longt..._嵌入式rootfs 输出重定向到/dev/console

随便推点

uniapp 引入iconfont图标库彩色symbol教程_uniapp symbol图标-程序员宅基地

文章浏览阅读1.2k次,点赞4次,收藏12次。1,先去iconfont登录,然后选择图标加入购物车 2,点击又上角车车添加进入项目我的项目中就会出现选择的图标 3,点击下载至本地,然后解压文件夹,然后切换到uniapp打开终端运行注:要保证自己电脑有安装node(没有安装node可以去官网下载Node.js 中文网)npm i -g iconfont-tools(mac用户失败的话在前面加个sudo,password就是自己的开机密码吧)4,终端切换到上面解压的文件夹里面,运行iconfont-tools 这些可以默认也可以自己命名(我是自己命名的_uniapp symbol图标

C、C++ 对于char*和char[]的理解_c++ char*-程序员宅基地

文章浏览阅读1.2w次,点赞25次,收藏192次。char*和char[]都是指针,指向第一个字符所在的地址,但char*是常量的指针,char[]是指针的常量_c++ char*

Sublime Text2 使用教程-程序员宅基地

文章浏览阅读930次。代码编辑器或者文本编辑器,对于程序员来说,就像剑与战士一样,谁都想拥有一把可以随心驾驭且锋利无比的宝剑,而每一位程序员,同样会去追求最适合自己的强大、灵活的编辑器,相信你和我一样,都不会例外。我用过的编辑器不少,真不少~ 但却没有哪款让我特别心仪的,直到我遇到了 Sublime Text 2 !如果说“神器”是我能给予一款软件最高的评价,那么我很乐意为它封上这么一个称号。它小巧绿色且速度非

对10个整数进行按照从小到大的顺序排序用选择法和冒泡排序_对十个数进行大小排序java-程序员宅基地

文章浏览阅读4.1k次。一、选择法这是每一个数出来跟后面所有的进行比较。2.冒泡排序法,是两个相邻的进行对比。_对十个数进行大小排序java

物联网开发笔记——使用网络调试助手连接阿里云物联网平台(基于MQTT协议)_网络调试助手连接阿里云连不上-程序员宅基地

文章浏览阅读2.9k次。物联网开发笔记——使用网络调试助手连接阿里云物联网平台(基于MQTT协议)其实作者本意是使用4G模块来实现与阿里云物联网平台的连接过程,但是由于自己用的4G模块自身的限制,使得阿里云连接总是无法建立,已经联系客服返厂检修了,于是我在此使用网络调试助手来演示如何与阿里云物联网平台建立连接。一.准备工作1.MQTT协议说明文档(3.1.1版本)2.网络调试助手(可使用域名与服务器建立连接)PS:与阿里云建立连解释,最好使用域名来完成连接过程,而不是使用IP号。这里我跟阿里云的售后工程师咨询过,表示对应_网络调试助手连接阿里云连不上

<<<零基础C++速成>>>_无c语言基础c++期末速成-程序员宅基地

文章浏览阅读544次,点赞5次,收藏6次。运算符与表达式任何高级程序设计语言中,表达式都是最基本的组成部分,可以说C++中的大部分语句都是由表达式构成的。_无c语言基础c++期末速成