Caffe: MNIST 数据集格式转换、用 python 读写 LMDB 数据库_minist转换-程序员宅基地

技术标签: 深度学习  MNIST  Caffe  

Preface

这两天概览了一下卜居(赵永科)的《深度学习 21天实战caffe》,进入深度学习挺长时间的了。文章也看了不少,Caffe、Theano、Torch 也都用过。其实个人认为,这本书对于已经深入这个领域已定时间的人来说,帮助不大。本书讲述的只是“术“,有点像深度学习的说明书,讲的很浅。

但是翻了一翻,还是有点收获的,这个 MNIST 手写数字识别是深度学习入门很经典的例子。基本上所有的深度学习框架,在让初学者入门使用的时候都有这个例子。

我一直对 Caffe 中使用的 LMDB、LEVELDB 数据组织比较疑惑,很多时候不明白该怎么样组织图像数据、以及其对应的标签。之前都是按照别人的代码生成的,自己其实懵懵的。所以,我想通过 MNIST 输入数据生成过程,熟悉一下 LMDB、LEVELDB 的基本使用方法。

熟悉了 C++ 版本的转 lmdb 方式后,我会解析一下 python 版本的 lmdb 转换过程

最后 Reference 部分,列出了我这里面参考的文章。


MNIST 及其转 LMDB 数据库源码 create_mnist_data

MINIST(Mixed National Institute of Stanfords and Technology)是一个大型的手写数字数据库,广泛用于机器学习领域的训练和测试,由纽约大学 Yann LeCun 教授整理。MNIST 包括 60000 个训练集和 10000 个测试集,每张图都已经进行尺寸归一化,数字居中处理,固定尺寸为 28×28 。如下图所示:

这里写图片描述


MNIST 数据格式描述

MNIST 具体的文件格式描述如下面的表所示:

MNIST 原始数据文件

MNIST 原始数据文件

训练集图片文件格式描述(train-images-idx3-ubyte)

这里写图片描述

训练集标签文件格式描述(train-labels-idx1-ubyte)

这里写图片描述

测试集图片文件格式描述(t10k-images-idx3-ubyte)

这里写图片描述

测试集标签文件格式描述(t10k-labels-idx1-ubyte)

这里写图片描述

注意:图片文件中像素按行组织,像素值 0 表示背景(白色),像素值 255 表示前景(黑色)。

转换格式、create_mnist_data.cpp 源码解析

先说一下 Caffe 为什么采用 LMDB、LEVELDB,而不是直接读取原始数据?

原因是,一方面,数据类型多种多样,有二进制文件、文本文件、编码后的图像文件(如 JPEG、PNG、网络爬取的数据等),不可能用一套代码实现所有类型的输入数据读取,转换为统一格式可以简化数据读取层的实现;

另一方面,使用 LMDB、LEVELDB 可以提高磁盘 IO 利用率。



下载到的原始数据为二进制文件,需要转换为 LEVELDB 或 LMDB 才能被 Caffe 识别。
我们 Git 得到的 Caffe 中,在 examples/mnist/ 下有一个脚本文件:create_mnist.sh ,这个就可以将原始的二进制数据,生成 LMDB 格式数据。
运行后,会生成 examples/mnist/mnist_train_lmdb/examples/mnist/mnist_test_lmdb/ 这两个目录。每个目录下都有两个文件:data.mdblock.mdb

看一下脚本文件:create_mnist.sh 里面是什么:

#!/usr/bin/env sh
# This script converts the mnist data into lmdb/leveldb format,
# depending on the value assigned to $BACKEND.

EXAMPLE=examples/mnist
DATA=data/mnist
BUILD=build/examples/mnist

BACKEND="lmdb"

echo "Creating ${BACKEND}..."

rm -rf $EXAMPLE/mnist_train_${BACKEND}
rm -rf $EXAMPLE/mnist_test_${BACKEND}

$BUILD/convert_mnist_data.bin $DATA/train-images-idx3-ubyte \
  $DATA/train-labels-idx1-ubyte $EXAMPLE/mnist_train_${BACKEND} --backend=${BACKEND}
$BUILD/convert_mnist_data.bin $DATA/t10k-images-idx3-ubyte \
  $DATA/t10k-labels-idx1-ubyte $EXAMPLE/mnist_test_${BACKEND} --backend=${BACKEND}

echo "Done."


create_mnist_data.cpp 源码解析

可以看到,上面脚本最核心的部分,就是调用 convert_mnist_data.bin 这个可执行程序,对应的源文件为 examples/mnist/convert_mnist_data.cpp,对这个源代码的解读如下,深入这段代码可以更清楚的了解 LMDB 是如何生成的。

// 这段代码将 MNIST 数据集转换为(默认的)lmdb 或者 leveldb(--backend=leveldb) 格式,用于在使用 caffe 的时候读取数据
// 使用方法:
//    convert_mnist_data [FLAGS] input_image_file input_label_file output_db_file

// gflags: 命令行参数解析头文件
#include <gflags/gflags.h> 
// glog: 记录程序日志头文件
#include <glog/logging.h>
// 解析 *.prototxt 文件
#include <google/protobuf/text_format.h>

#include <leveldb/db.h>
#include <leveldb/write_batch.h>
#include <lmdb.h>
#include <stdint.h>
#include <sys/stat.h>

#include <fstream>  // NOLINT(readability/streams)
#include <string>

// 解析caffe中proto类型文件的头文件
#include "caffe/proto/caffe.pb.h"

using namespace caffe; // NOLINT(build/namespace)
using std::string;

// GFLAGS 工具定义命令行选项 backend, 默认值为 lmdb, 即: --backend=lmdb
DEFINE_string(backend, "lmdb", "The backend for storing the result");

// 大小端转换, MNIST 原始数据文件中 32 位整型值为大端存储, C/C++ 变量为小端存储,因此需要加入转换机制
uint32_t swap_endian(uint32_t val) {
    val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
    return (val << 16) | (val >> 16);
}

// 转换数据集函数
void convert_dataset(const char* image_filename, const char* label_filename,
        const char* db_path, const string& db_backend) {

    // 用 C++ 输入文件流以二进制方式打开
    // 定义, 打开图像文件 对象: image_file(读入的文件名, 读入方式), 此处以二进制的方式
    std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
    // 定义, 打开标签文件 对象: label_file(读入的文件名, 读入方式), 此处以二进制的方式
    std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);

    // CHECK: 用于检测文件能否正常打开函数
    CHECK(image_file) << "Unable to open file " << image_filename;
    CHECK(label_file) << "Unable to open file " << label_filename;

    // 读取魔数与基本信息
    // uint32_t 用 typedef 来自定义的一种数据类型, unsigned int32, 每个int32整数占用4个字节, 这样做是为了程序的可扩展性
    uint32_t magic; // 魔数
    uint32_t num_items; // 文件包含条目总数 
    uint32_t num_labels; // 标签值
    uint32_t rows; // 行数
    uint32_t cols; // 列数

    // 读取魔数: magic
    // image_file.read( 读取内容的指针, 读取的字节数 ) , magic 是一个 int32 类型的整数,每个占 4 个字节,所以这里指定为 4
    // reinterpret_cast 为 C++ 中定义的强制转换符, 这里把 &magic, 即 magic 的地址(一个 16 进制的数), 转变成 char 类型的指针
    image_file.read(reinterpret_cast<char*>(&magic), 4);
    // 大端到小端的转换
    magic = swap_endian(magic);
    // 校验图像文件中魔数是否为 2051, 不是则报错
    CHECK_EQ(magic, 2051) << "Incorrect image file magic.";
    // 同理, 校验标签文件中的魔数是否为 2049, 不是则报错
    label_file.read(reinterpret_cast<char*>(&magic), 4);
    magic = swap_endian(magic);
    CHECK_EQ(magic, 2049) << "Incorrect label file magic.";

    // 读取图片的数量: num_items
    image_file.read(reinterpret_cast<char*>(&num_items), 4);
    num_items = swap_endian(num_items); // 大端到小端转换
    // 读取图片标签的数量
    label_file.read(reinterpret_cast<char*>(&num_labels), 4);
    num_labels = swap_endian(num_labels); // 大端到小端转换
    // 图片数量应等于其标签数量, 检查两者是否相等
    CHECK_EQ(num_items, num_labels);
    // 读取图片的行大小
    image_file.read(reinterpret_cast<char*>(&rows), 4);
    rows = swap_endian(rows); // 大端到小端转换
    // 读取图片的列大小
    image_file.read(reinterpret_cast<char*>(&cols), 4);
    cols = swap_endian(cols); // 大端到小端转换

    // lmdb 相关句柄
    MDB_env *mdb_env;
    MDB_dbi mdb_dbi;
    MDB_val mdb_key, mdb_data;
    MDB_txn *mdb_txn;
    // leveldb 相关句柄
    leveldb::DB* db;
    leveldb::Options options;
    options.error_if_exists = true;
    options.create_if_missing = true;
    options.write_buffer_size = 268435456;
    level::WriteBatch* batch = NULL;

    // 打开 db
    if (db_backend == "leveldb") { // leveldb
        LOG(INFO) << "Opening leveldb " << db_path;
        leveldb::Status status = leveldb::DB::Open(
          options, db_path, &db);
        CHECK(status.ok()) << "Failed to open leveldb " << db_path << ". Is it already existing?";
        batch = new leveldb::WriteBatch();
    }else if (db_backend == "lmdb") { // lmdb
        LOG(INFO) << "Opening lmdb " << db_path;
        CHECK_EQ(mkdir(db_path, 0744), 0) << "mkdir " << db_path << "failed";
        CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
        CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS) << "mdb_env_set_mapsize failed"; // 1TB
        CHECK_EQ(mdb_env_open(mdb_env, db_path, 0, 0664), MDB_SUCCESS) << "mdb_env_open_failed";
        CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS) << "mdb_txn_begin failed";
        CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS) << "mdb_open failed. Does the lmdb already exist?";
    } else {
        LOG(FATAL) << "Unknown db backend " << db_backend;
    }

    // 将读取数据保存至 db
    char label;
    char* pixels = new char[rows * cols];
    int count = 0;
    const int kMaxKeyLength = 10;
    char key_cstr[kMaxKeyLength];
    string value;

    // 设置datum数据对象的结构,其结构和源图像结构相同
    Datum datum;
    datum.set_channels(1);
    datum.set_height(rows);
    datum.set_width(cols);

    // 输出 Log, 输出图片总数
    LOG(INFO) << "A total of " << num_items << " items.";
    // 输出 Log, 输出图片的行、列大小
    LOG(INFO) << "Rows: " << rows << " Cols: " << cols;

    // 读取图片数据以及 label 存入 protobuf 定义好的数据结构中,
    // 序列化成字符串储存到数据库中,
    // 这里为了减少单次操作带来的带宽成本(验证数据包完整等), 
    // 每 1000 次执行一次操作
    for (int item_id = 0; item_id < num_items; ++item_id) {
        // 从数据中读取 rows * cols 个字节, 图像中一个像素值(应该是 int8 类型)用一个字节表示即可
        image_file.read(pixels, rows * cols);
        // 读取标签
        label_file.read(&label, 1);
        // set_data 函数, 把源图像值放入 datum 对象
        datum.set_data(pixels, rows*cols);
        // set_label 函数, 把标签值放入 datum
        datum.set_label(label);

        // snprintf(str1, size_t, "format", str), 把 str 按照 format 的格式以字符串的形式写入 str1, size_t 表示写入的字符个数    
        // 这里是把 item_id 转换成 8 位长度的十进制整数,然后在变成字符串复制给 key_str, 如:item_id=1500(int), 则 key_cstr = 00015000(string, \0为字符串结束标志)
        snprintf(key_cstr, kMaxKeyLength, "%08d", item_id);
        datum.SerializeToString(&value);
        // 感觉是将 datum 中的值序列化成字符串,保存在变量 value 内,通过指针来给 value 赋值
        string keystr(key_cstr);

        // 放到数据库中
        if (db_backend == "leveldb") { // leveldb
            // 通过 batch 中的子方法 Put, 把数据写入 datum 中(此时在内存中)
            batch->Put(keystr, value);
        } else if (db_backend == "lmdb") { // lmdb
            // mv 应该是 move value, 应该是和 write() 和 read() 函数文件读写的方式一样, 以固定的子节长度按照地址进行读写操作
            // 获取 value 的子节长度, 类似 sizeof() 函数
            mdb_data.mv_size = value.size()
            // 把 value 的首个字符地址转换成空类型的指针
            mdb_data.mv_data = reinterpret_cast<void*>(&value[0]);
            mdb_key.mv_size = keystr.size();
            mdb_key.mv_data = reinterpret_cast<void*>(&keystr[0]);
            // 通过 mdb_put 函数把 mdb_key 和 mdb_data 所指向的数据,  写入到 mdb_dbi
            CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS) << "mdb_put failed";
        } else {
            LOG(FATAL) << "Unknown db backend " << db_back_end;
        }

        // 把 db 数据写入硬盘
        // 选择 1000 个样本放入一个 batch 中,通过 batch 以批量的方式把数据写入硬盘
        // 写入硬盘通过 db.write() 函数来实现
        if (++count % 1000 == 0) {
            // 批量提交更改
            if(db_backend == "leveldb") { // leveldb
                // 把batch写入到 db 中,然后删除 batch 并重新创建
                db->Write(leveldb::WriteOptions(), batch);
                delete batch;
                batch = new leveldb::WriteBatch();
            } else if (db_backend == "lmdb") { // lmdb
                // 通过 mdb_txn_commit 函数把 mdb_txn 数据写入到硬盘
                CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed";
                // 重新设置 mdb_txn 的写入位置, 追加(继续)写入
                CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS) << "mdb_txn_begin failed";
            } else {
                LOG(FATAL) << "Unknown db backend " << db_backend;
            }
        } // if (++count % 1000 == 0) 

    } // for (int item_id = 0; item_id < num_items; ++item_id)

    // 写最后一个 batch 
    if (count % 1000 != 0) {
        if (db_backend == "leveldb") { // leveldb
            db->Write(leveldb::WriteOptions(), batch);
            delete batch;
            delete db; // 删除临时变量,清理内存占用
        } else if (db_backend == "lmdb") { // lmdb
            CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed";
            // 关闭 mdb 数据对象变量
            mdb_close(mdb_env, mdb_dbi);
            // 关闭 mdb 操作环境变量
            mdb_env_close(mdb_env);
        } else {
            LOG(FATAL) << "Unknown db backend " << db_backend;
        }
        LOG(ERROE) << "Processed " << count << " files.";
    }

    delete[] pixels;
} // void convert_dataset(const char* image_filename, const char* label_filename, const char* db_path, const string& db_backend)

int main(int argc, char** argv) {
#ifndef GFLAGS_GFLAGS_H
    namespace gflags = google;
#endif

    gflags::SetUsageMessage("This script converts the MNIST dataset to \n"
        "the lmdb/leveldb format used by Caffe to load data. \n"
        "Usage:\n"
        "    convert_mnist_data [FLAGS] input_image_file input_label_file "
        "output_db_file\n"
        "The MNIST dataset could be downloaded at\n"
        "    http://yann.lecun.com/exdb/mnist/\n"
        "You should gunzip them after downloading,"
        "or directly use the data/mnist/get_mnist.sh\n");
    gflags::ParseCommandLineFlags(&argc, &argv, true);

    // FLAGS_backend 在前面通过 DEFINE_string 定义,是字符串类型
    const string& db_backend = FLAGS_backend;

    if (argc != 4) {
        gflags::ShowUsageWithFlagsRestrict(argv[0], "examples/mnist/convert_mnist_data");
    } else {
        google::InitGoogleLogging(argv[0]);
        convert_dataset(argv[1], argv[2], argv[3], db_backend);
    }

    return 0;
}


LMDB 句柄

变量 说明
MDB_dbi mdb_dbi 环境中一个数据库的句柄
MDB_env *mdb_env 整个数据环境的句柄
MDB_val mdb_key, mdb_data 存放要输入进数据库的数据值
MDB_txn *mdb_txn 数据库事物操作的句柄


LMDB 流程图

这里写图片描述

小端存储、大端存储(Little-Endian、Big-Endian)

上面的源码中,有一个函数是进行大端存储到小端存储的转换的。这部分没有计算机汇编的基础,一开始一头雾水……参考的一篇博客:http://www.cnblogs.com/passingcloudss/archive/2011/05/03/2035273.html

不同的CPU有不同的字节序类型,这些字节序是指整数在内存中保存的顺序。最常见的有两种:
1. Little-endian:将低序字节存储在起始地址(低位编址)
2. Big-endian:将高序字节存储在起始地址(高位编址)

LE(little-endian):
最符合人的思维的字节序,地址低位存储值的低位 ,地址高位存储值的高位 。
这种存储最符合人的思维的字节序,因为从人的第一观感来说,低位值小,就应该放在内存地址小的地方,也即内存地址低位。反之,高位值就应该放在内存地址大的地方,也即内存地址高位

BE(big-endian):
最直观的字节序,地址低位存储值的高位,地址高位存储值的低位
为什么说直观,不要考虑对应关系,只需要把内存地址从左到右按照由低到高的顺序写出,把值按照通常的高位到低位的顺序写出。两者对照,一个字节一个字节的填充进去 。

注: ×86 系列的 CPU 都是 Little-Endian 的字节序。

例子1:在内存中双字 0x01020304(DWORD) 的存储方式:
  内存地址为:4000 4001 4002 4003
  小端存储: 04 03 02 01
  大端存储: 01 02 03 04
注:每个地址存 1 个字节,每个字有 4 字节。2 位 16 进制数是 1 个字节(0xFF = 11111111)。

例子2:如果我们将 0x1234abcd 写入到以 0x0000 开始的内存中,则结果为:

big-endian little-endian
0x0000 0x12 0xcd
0x0001 0x23 0xab
0x0002 0xab 0x34
0x0003 0xcd 0x12


Python 读写 LMDB 格式图像数据

我想这部分才是很多人关心的,因为我们使用 caffe,将图像数据转换为 caffe 可以识别的数据格式是第一步。同时大多数都是通过 python 接口来转换数据格式的。

LMDB 数据库

Caffe 使用 LMDB 的情况大约有两类:

  • 第一类是 DataLayer 层中 使用的 训练集、验证集、测试集;
  • 第二类 就是 ./caffe/build/tools/extract_feature.bin 这种特征提取工具提取特征后,输出的特征文件。

LMDB 的全称是 Lighting Memory-Mapped Database(闪电般的内存映射数据库) 。它文件结构简单,一个文件夹,里面一个数据文件,一个锁文件。数据随意复制,随意传输。它的访问简单,不需要运行单独的数据管理进程。只要在访问的代码里引用 LMDB 库,访问时给文件路径即可。

Caffe 中使用的数据较为很简单,就是大量的矩阵/向量平铺开来。数据之间没有什么关联,数据内没有复杂的对象结构,就是向量和矩阵。既然数据并不复杂,Caffe 就选择了 LMDB 这个简单的数据库来存放数据。

上面提到了,Caffe 使用 LMDB 数据库有两点原因:

一方面是因为数据源的格式多样性,有文本文件、二进制文件图像文件等等,不可能用一个代码完成上述所有的数据格式。因此,通过 LMDB 数据库,转化成统一的数据格式可以简化数据读取层的实现。

第二个方面就是使用 LMDB 数据库可以大大的节约磁盘 IO 的时间开销。因为读取大量小文件的时间开销是相当大的,尤其是在机械硬盘上。
数据库单文件还能减少数据集复制、传输过程的开销。因为我们都有过体会,一个具有几万个、几十万个文件的数据集,不管是直接复制,还是打开再解包,过程都巨慢无比。LMDB 只有一个文件,你的介质有多快,就能复制多快,不会因为文件多而慢的令人心碎。

Caffe 中 Datum 数据结构

Caffe 并不是把向量和矩阵直接放进数据库的,而是将数据通过 caffe.proto 里定义的一个 datum 类来封装的。数据库里存放的是一个个 datum 序列化成的字符串。Datum 的定义如下

message Datum {
  optional int32 channels = 1;
  optional int32 height = 2;
  optional int32 width = 3;
  // the actual image data, in bytes
  optional bytes data = 4;
  optional int32 label = 5;
  // Optionally, the datum could also hold float data.
  repeated float float_data = 6;
  // If true data contains an encoded image that need to be decoded
  optional bool encoded = 7 [default = false];
}

一个 Datum 有三个维度,channnelsheightwidth,可以看作是少了 num 维度的 Blob
存放数据的地方有两个:bytes datafloat_data,分别存放整数型和浮点型数据。图像数据一般是整形,放在 bytes data 中,特征向量一般是浮点型,存放在 float_data 中。
label 里存放的是类别标签,是整数型。
encoded 标识数据是否需要被解码,因为里面可能存放的是 JPEG 或者 PNG 之类经过编码的数据。

Datum 这个数据结构将数据和标签封装在一起,兼容整形和浮点型数据。经过 protobuf 编译后,可以在 Python 和 C++ 中都提供高效的访问。
同时 protobuf 还为它提供了序列化、反序列化的功能。存放进 LMDB 的就是 Datum 序列化生成的字符串。

Caffe 中将图像写入 LMDB 数据库

我上面解析的 create_mnist_data.cpp 代码对于这部分是很有用的,特别是 LMDB 流程图中的 lmdb 数据操作函数,如打开一个 lmdb 数据库,写入数据等操作,python 中的使用类似,但比 C++ 的要简洁许多 。

下面通过代码来说明吧,这段代码是一个大牛写的教程:《A Practical Introduction to Deep Learning with Caffe and Python》,写的很清晰。

import os
import glob
import random
import numpy as np

import cv2

import caffe
from caffe.proto import caffe_pb2
import lmdb

#Size of images
IMAGE_WIDTH = 227
IMAGE_HEIGHT = 227

# train_lmdb、validation_lmdb 路径
train_lmdb = '/home/chenxp/Documents/vehicleID/val/train_lmdb'
validation_lmdb = '/home/chenxp/Documents/vehicleID/val/validation_lmdb'

# 如果存在了这个文件夹, 先删除
os.system('rm -rf  ' + train_lmdb)
os.system('rm -rf  ' + validation_lmdb)

# 读取图像
train_data = [img for img in glob.glob("/home/chenxp/Documents/vehicleID/val/query/*jpg")]
test_data = [img for img in glob.glob("/home/chenxp/Documents/vehicleID/val/query/*jpg")]

# Shuffle train_data
# 打乱数据的顺序
random.shuffle(train_data)

# 图像的变换, 直方图均衡化, 以及裁剪到 IMAGE_WIDTH x IMAGE_HEIGHT 的大小
def transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT):
    #Histogram Equalization
    img[:, :, 0] = cv2.equalizeHist(img[:, :, 0])
    img[:, :, 1] = cv2.equalizeHist(img[:, :, 1])
    img[:, :, 2] = cv2.equalizeHist(img[:, :, 2])

    #Image Resizing, 三次插值
    img = cv2.resize(img, (img_width, img_height), interpolation = cv2.INTER_CUBIC)
    return img

def make_datum(img, label):
    #image is numpy.ndarray format. BGR instead of RGB
    return caffe_pb2.Datum(
        channels=3,
        width=IMAGE_WIDTH,
        height=IMAGE_HEIGHT,
        label=label,
        data=np.rollaxis(img, 2).tobytes()) # or .tostring() if numpy < 1.9

# 打开 lmdb 环境, 生成一个数据文件,定义最大空间, 1e12 = 1000000000000.0
in_db = lmdb.open(train_lmdb, map_size=int(1e12)) 
with in_db.begin(write=True) as in_txn: # 创建操作数据库句柄
    for in_idx, img_path in enumerate(train_data):
        if in_idx %  6 == 0: # 只处理 5/6 的数据作为训练集
            continue         # 留下 1/6 的数据用作验证集
        # 读取图像. 做直方图均衡化、裁剪操作
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)

        if 'cat' in img_path: # 组织 label, 这里是如果文件名称中有 'cat', 标签就是 0
            label = 0         # 如果图像名称中没有 'cat', 有的是 'dog', 标签则为 1
        else:                 # 这里方, label 需要自己去组织
            label = 1         # 每次情况可能不一样, 灵活点

        datum = make_datum(img, label)
        # '{:0>5d}'.format(in_idx):
        #      lmdb的每一个数据都是由键值对构成的,
        #      因此生成一个用递增顺序排列的定长唯一的key
        in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString()) #调用句柄,写入内存
        print '{:0>5d}'.format(in_idx) + ':' + img_path

# 结束后记住释放资源,否则下次用的时候打不开。。。
in_db.close() 

# 创建验证集 lmdb 格式文件
print '\nCreating validation_lmdb'
in_db = lmdb.open(validation_lmdb, map_size=int(1e12))
with in_db.begin(write=True) as in_txn:
    for in_idx, img_path in enumerate(train_data):
        if in_idx % 6 != 0:
            continue
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)
        if 'cat' in img_path:
            label = 0
        else:
            label = 1
        datum = make_datum(img, label)
        in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString())
        print '{:0>5d}'.format(in_idx) + ':' + img_path
in_db.close()
print '\nFinished processing all images'

再展示一段生成 lmdb 的代码,来源自:http://deepdish.io/2015/04/28/creating-lmdb-in-python/
这段代码并没有用真实的图像数据来生成,二是用 numpy 中的 np.zeros() 生成了图像格式的数据:

import numpy as np
import lmdb
import caffe

N = 1000

# Let's pretend this is interesting data
X = np.zeros((N, 3, 32, 32), dtype=np.uint8)
y = np.zeros(N, dtype=np.int64)

# We need to prepare the database for the size. We'll set it 10 times
# greater than what we theoretically need. There is little drawback to
# setting this too big. If you still run into problem after raising
# this, you might want to try saving fewer entries in a single
# transaction.
map_size = X.nbytes * 10

env = lmdb.open('mylmdb', map_size=map_size)

with env.begin(write=True) as txn:
    # txn is a Transaction object
    for i in range(N):
        datum = caffe.proto.caffe_pb2.Datum()
        datum.channels = X.shape[1]
        datum.height = X.shape[2]
        datum.width = X.shape[3]
        datum.data = X[i].tobytes()  # or .tostring() if numpy < 1.9
        datum.label = int(y[i])
        str_id = '{:08}'.format(i)

        # The encode is only essential in Python 3
        txn.put(str_id.encode('ascii'), datum.SerializeToString())

运行上一段代码,会生成下面两个文件:

这里写图片描述


Caffe 从 LMDB 数据库中读取数据

下面就是从生成好的 lmdb 中读取数据了:

import numpy as np
import caffe
import lmdb
import cv2

# 打开 lmdb 数据库, 指定好位置
env = lmdd.open('mylmdb', readonly=True)
with env.begin() as txn:
    raw_datum = txn.get(b'00000000')

datum = caffe.proto.caffe_pb2.Datum()
datum.ParseFromString(raw_datum)

flat_x = np.fromstring(datum.data, dtype=np.uint8)
x = flat_x.reshape(datum.channels, datum.height, datum.width)
y = datum.label

print(datum.channels)
print 'label = ' + str(y) # y 为整型, 需要转成字符串

# C x H x W 转换到 H x W x C, 才能在 cv2 中显示
img = cv2.transpose(img, (1, 2, 0)) # 或者: img = x.transpose(1, 2, 0)
cv2.imshow("Image", img)
cv2.waitKey(0)

输出为:

这里写图片描述

下图是输出的图像……别笑……那是因为上面代码用 np.zeros() 生成的太小了:
这里写图片描述

可以迭代读取 <key, value>

with env.open() as txn:
    cursor = txn.cursor()
    for key, value in cursor:
        print(key, value)

下面代码用迭代循环 txn.cursor() 读取:

import caffe
from caffe.proto import caffe_pb2

import lmdb
import cv2
import numpy as np

lmdb_env = lmdb.open('mylmdb', readonly=True) # 打开数据文件
lmdb_txn = lmdb_env.begin() # 生成处理句柄
lmdb_cursor = lmdb_txn.cursor() # 生成迭代器指针
datum = caffe_pb2.Datum() # caffe 定义的数据类型

for key, value in lmdb_cursor: # 循环获取数据
    datum.ParseFromString(value) # 从 value 中读取 datum 数据

    label = datum.label
    data = caffe.io.datum_to_array(datum)
    print data.shape
    print datum.channels
    image = data.transpose(1, 2, 0)
    cv2.imshow('cv2.png', image)
    cv2.waitKey(0)

cv2.destroyAllWindows()
lmdb_env.close()


Reference

  1. 《深度学习 21天实战 Caffe》, 卜居
  2. Caffe1——Mnist数据集创建lmdb或leveldb类型的数据
  3. caffe源码阅读(1): 数据加载
  4. 愚见caffe中的LeNet
  5. 小端格式和大端格式(Little-Endian&Big-Endian)
  6. Creating an LMDB database in Python
  7. A Practical Introduction to Deep Learning with Caffe and Python
  8. 中科院自动化所博士@beanfrog:Write/Read lmdb file for caffe with python
  9. 利用caffe与lmdb读写图像数据
  10. Caffe中LMDB的使用
  11. Caffe: Reading LMDB from Python
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/u010167269/article/details/51915512

智能推荐

稀疏编码的数学基础与理论分析-程序员宅基地

文章浏览阅读290次,点赞8次,收藏10次。1.背景介绍稀疏编码是一种用于处理稀疏数据的编码技术,其主要应用于信息传输、存储和处理等领域。稀疏数据是指数据中大部分元素为零或近似于零的数据,例如文本、图像、音频、视频等。稀疏编码的核心思想是将稀疏数据表示为非零元素和它们对应的位置信息,从而减少存储空间和计算复杂度。稀疏编码的研究起源于1990年代,随着大数据时代的到来,稀疏编码技术的应用范围和影响力不断扩大。目前,稀疏编码已经成为计算...

EasyGBS国标流媒体服务器GB28181国标方案安装使用文档-程序员宅基地

文章浏览阅读217次。EasyGBS - GB28181 国标方案安装使用文档下载安装包下载,正式使用需商业授权, 功能一致在线演示在线API架构图EasySIPCMSSIP 中心信令服务, 单节点, 自带一个 Redis Server, 随 EasySIPCMS 自启动, 不需要手动运行EasySIPSMSSIP 流媒体服务, 根..._easygbs-windows-2.6.0-23042316使用文档

【Web】记录巅峰极客2023 BabyURL题目复现——Jackson原生链_原生jackson 反序列化链子-程序员宅基地

文章浏览阅读1.2k次,点赞27次,收藏7次。2023巅峰极客 BabyURL之前AliyunCTF Bypassit I这题考查了这样一条链子:其实就是Jackson的原生反序列化利用今天复现的这题也是大同小异,一起来整一下。_原生jackson 反序列化链子

一文搞懂SpringCloud,详解干货,做好笔记_spring cloud-程序员宅基地

文章浏览阅读734次,点赞9次,收藏7次。微服务架构简单的说就是将单体应用进一步拆分,拆分成更小的服务,每个服务都是一个可以独立运行的项目。这么多小服务,如何管理他们?(服务治理 注册中心[服务注册 发现 剔除])这么多小服务,他们之间如何通讯?这么多小服务,客户端怎么访问他们?(网关)这么多小服务,一旦出现问题了,应该如何自处理?(容错)这么多小服务,一旦出现问题了,应该如何排错?(链路追踪)对于上面的问题,是任何一个微服务设计者都不能绕过去的,因此大部分的微服务产品都针对每一个问题提供了相应的组件来解决它们。_spring cloud

Js实现图片点击切换与轮播-程序员宅基地

文章浏览阅读5.9k次,点赞6次,收藏20次。Js实现图片点击切换与轮播图片点击切换<!DOCTYPE html><html> <head> <meta charset="UTF-8"> <title></title> <script type="text/ja..._点击图片进行轮播图切换

tensorflow-gpu版本安装教程(过程详细)_tensorflow gpu版本安装-程序员宅基地

文章浏览阅读10w+次,点赞245次,收藏1.5k次。在开始安装前,如果你的电脑装过tensorflow,请先把他们卸载干净,包括依赖的包(tensorflow-estimator、tensorboard、tensorflow、keras-applications、keras-preprocessing),不然后续安装了tensorflow-gpu可能会出现找不到cuda的问题。cuda、cudnn。..._tensorflow gpu版本安装

随便推点

物联网时代 权限滥用漏洞的攻击及防御-程序员宅基地

文章浏览阅读243次。0x00 简介权限滥用漏洞一般归类于逻辑问题,是指服务端功能开放过多或权限限制不严格,导致攻击者可以通过直接或间接调用的方式达到攻击效果。随着物联网时代的到来,这种漏洞已经屡见不鲜,各种漏洞组合利用也是千奇百怪、五花八门,这里总结漏洞是为了更好地应对和预防,如有不妥之处还请业内人士多多指教。0x01 背景2014年4月,在比特币飞涨的时代某网站曾经..._使用物联网漏洞的使用者

Visual Odometry and Depth Calculation--Epipolar Geometry--Direct Method--PnP_normalized plane coordinates-程序员宅基地

文章浏览阅读786次。A. Epipolar geometry and triangulationThe epipolar geometry mainly adopts the feature point method, such as SIFT, SURF and ORB, etc. to obtain the feature points corresponding to two frames of images. As shown in Figure 1, let the first image be ​ and th_normalized plane coordinates

开放信息抽取(OIE)系统(三)-- 第二代开放信息抽取系统(人工规则, rule-based, 先抽取关系)_语义角色增强的关系抽取-程序员宅基地

文章浏览阅读708次,点赞2次,收藏3次。开放信息抽取(OIE)系统(三)-- 第二代开放信息抽取系统(人工规则, rule-based, 先关系再实体)一.第二代开放信息抽取系统背景​ 第一代开放信息抽取系统(Open Information Extraction, OIE, learning-based, 自学习, 先抽取实体)通常抽取大量冗余信息,为了消除这些冗余信息,诞生了第二代开放信息抽取系统。二.第二代开放信息抽取系统历史第二代开放信息抽取系统着眼于解决第一代系统的三大问题: 大量非信息性提取(即省略关键信息的提取)、_语义角色增强的关系抽取

10个顶尖响应式HTML5网页_html欢迎页面-程序员宅基地

文章浏览阅读1.1w次,点赞6次,收藏51次。快速完成网页设计,10个顶尖响应式HTML5网页模板助你一臂之力为了寻找一个优质的网页模板,网页设计师和开发者往往可能会花上大半天的时间。不过幸运的是,现在的网页设计师和开发人员已经开始共享HTML5,Bootstrap和CSS3中的免费网页模板资源。鉴于网站模板的灵活性和强大的功能,现在广大设计师和开发者对html5网站的实际需求日益增长。为了造福大众,Mockplus的小伙伴整理了2018年最..._html欢迎页面

计算机二级 考试科目,2018全国计算机等级考试调整,一、二级都增加了考试科目...-程序员宅基地

文章浏览阅读282次。原标题:2018全国计算机等级考试调整,一、二级都增加了考试科目全国计算机等级考试将于9月15-17日举行。在备考的最后冲刺阶段,小编为大家整理了今年新公布的全国计算机等级考试调整方案,希望对备考的小伙伴有所帮助,快随小编往下看吧!从2018年3月开始,全国计算机等级考试实施2018版考试大纲,并按新体系开考各个考试级别。具体调整内容如下:一、考试级别及科目1.一级新增“网络安全素质教育”科目(代..._计算机二级增报科目什么意思

conan简单使用_apt install conan-程序员宅基地

文章浏览阅读240次。conan简单使用。_apt install conan