[Repost] Packing small training files into one large file for PyTorch (with a comparison of storage formats)
Copyright notice: this is an original article by the CSDN blogger "Liekkas Kono", released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original link:
https://blog.csdn.net/shiwanghualuo/article/details/120778553
=======================================================
Introduction
TensorFlow ships with a dedicated data format and reading module, tfrecord, which reads training data efficiently and keeps the GPU fully fed.
Caffe uses LMDB for data reading, which is also very efficient.
PyTorch reads data through DataLoader, but it is comparatively slow, especially when the dataset consists of many small files.
So the key question becomes: how do we read data efficiently in PyTorch and make full use of the GPU? For reference, a minimal small-file baseline is sketched right below.
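This is roughly what the plain small-file setup looks like. A minimal sketch only; the txt file of image paths and the one-folder-per-label layout are assumptions for illustration, not the article's exact code:

from pathlib import Path

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class SmallFileDataset(Dataset):
    """Baseline: one image file per sample, read from disk in every __getitem__."""

    def __init__(self, txt_path, transform=None):
        # each line of the txt file is a path such as xxx/3/00001.jpg,
        # where the parent folder name is the label
        with open(txt_path, 'r', encoding='utf-8') as f:
            self.img_paths = [line.rstrip('\n') for line in f]
        self.transform = transform

    def __getitem__(self, index):
        img_path = self.img_paths[index]
        label = int(Path(img_path).parent.name)
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.img_paths)


# every sample triggers one small-file read, which is what makes this
# baseline slow on datasets made of many small images
train_transforms = transforms.Compose([transforms.Resize((224, 224)),
                                       transforms.ToTensor()])
train_dataloader = DataLoader(SmallFileDataset('dataset/train.txt', train_transforms),
                              batch_size=32, shuffle=True, num_workers=4)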
TFRecord
- Can we borrow tfrecord from TensorFlow? There is no reason not to.
- A community implementation already exists; see: tfrecord (a usage sketch follows right after this list)
- There is also a hand-rolled version on Kaggle; see: PyTorch TFRecord-Loader
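Before the full write/read code below, a minimal sketch of how the third-party tfrecord package is typically wired into a PyTorch DataLoader. The description keys follow the schema written in this article ('image_raw', 'label'); the file path, index_path=None and the decode transform are assumptions for illustration:

import cv2
import numpy as np
import torch
from tfrecord.torch.dataset import TFRecordDataset


def decode(record):
    # decode the encoded image bytes back to HxWx3 and resize so that
    # samples can be stacked into a batch
    img = cv2.imdecode(record['image_raw'], cv2.IMREAD_COLOR)
    img = cv2.resize(img, (270, 388))
    img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
    label = int(np.asarray(record['label']).reshape(-1)[0])
    return img, label


description = {'image_raw': 'byte', 'label': 'int'}
dataset = TFRecordDataset('temp/val.tfrecords',
                          index_path=None,   # an index file enables shuffling/sharding
                          description=description,
                          transform=decode)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)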
Code for writing tfrecord files:
from pathlib import Path

import cv2
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from data_loader import TFRecordDataLoader


def read_txt(txt_path):
    with open(txt_path, 'r', encoding='utf-8') as f:
        data = f.readlines()
    data = list(map(lambda x: x.rstrip('\n'), data))
    return data


def bytes_to_numpy(image_bytes):
    image_np = np.frombuffer(image_bytes, dtype=np.uint8)
    image_np2 = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
    return image_np2


def list_record_features(tfrecords_path):
    """Inspect the structure of a tfrecords file.
    https://stackoverflow.com/questions/63562691/reading-a-tfrecord-file-where-features-that-were-used-to-encode-is-not-known

    Args:
        tfrecords_path (str): path of the tfrecords file

    Returns:
        dict: structure information
    """
    features = {}
    dataset = tf.data.TFRecordDataset([str(tfrecords_path)])
    data = next(iter(dataset))

    example = tf.train.Example()
    example_bytes = data.numpy()
    example.ParseFromString(example_bytes)

    for key, value in example.features.feature.items():
        kind = value.WhichOneof('kind')
        size = len(getattr(value, kind).value)
        if key in features:
            kind2, size2 = features[key]
            if kind != kind2:
                kind = None
            if size != size2:
                size = None
        features[key] = (kind, size)
    return features


class TFRecorder(object):
    def __init__(self) -> None:
        super().__init__()
        self.feature_dict = {
            'height': None,
            'width': None,
            'depth': None,
            'label': None,
            'image_raw': None
        }
        self.AUTO = tf.data.experimental.AUTOTUNE

    def image_to_feature(self, image_string, label):
        height, width, channel = tf.image.decode_image(image_string).shape
        self.feature_dict = {
            'height': self._int64_feature(height),
            'width': self._int64_feature(width),
            'depth': self._int64_feature(channel),
            'label': self._int64_feature(label),
            'image_raw': self._bytes_feature(image_string)
        }
        return tf.train.Example(features=tf.train.Features(feature=self.feature_dict))

    def write(self, save_path, img_label_dict):
        with tf.io.TFRecordWriter(save_path) as writer:
            for file_name, label in tqdm(img_label_dict.items()):
                img_string = open(file_name, 'rb').read()
                feature = self.image_to_feature(img_string, label)
                writer.write(feature.SerializeToString())

    def read(self, tfrecord_path):
        reader = tf.data.TFRecordDataset(tfrecord_path)
        dataset = reader.map(self._parse_image_function,
                             num_parallel_calls=self.AUTO)
        return dataset

    def _parse_image_function(self, example_proto):
        self.feature_dict = {
            'height': tf.io.FixedLenFeature([], tf.int64),
            'width': tf.io.FixedLenFeature([], tf.int64),
            'depth': tf.io.FixedLenFeature([], tf.int64),
            'label': tf.io.FixedLenFeature([], tf.int64),
            'image_raw': tf.io.FixedLenFeature([], tf.string)
        }
        example = tf.io.parse_single_example(example_proto,
                                             self.feature_dict)
        return example

    @staticmethod
    def _bytes_feature(value):
        """Returns a bytes_list from a string / byte."""
        if isinstance(value, type(tf.constant(0))):
            # BytesList won't unpack a string from an EagerTensor.
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    @staticmethod
    def _float_feature(value):
        """Returns a float_list from a float / double."""
        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

    @staticmethod
    def _int64_feature(value):
        """Returns an int64_list from a bool / enum / int / uint."""
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


if __name__ == '__main__':
    tfrecorder = TFRecorder()
    # val.txt holds the relative paths of the images
    img_path = read_txt('dataset/val.txt')
    # Path(v).parent.name: the label of the image
    img_label_dict = {v: int(Path(v).parent.name) for v in img_path}

    save_path = 'temp/val.tfrecords'
    tfrecorder.write(save_path, img_label_dict)

    dataset = tfrecorder.read('dataset/val.tfrecords')
    for v in dataset:
        img, label = v
        print('ok')

    # inspect the structure of an unknown tfrecords file
    list_record_features('xxxx.tfrecords')
Code for reading the tfrecord files for PyTorch training:
import cv2
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

AUTO = tf.data.experimental.AUTOTUNE


def bytes_to_numpy(image_bytes):
    image_np = np.frombuffer(image_bytes, dtype=np.uint8)
    image_np2 = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
    return image_np2


def read_labeled_tfrecord(example_proto):
    feature_dict = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'depth': tf.io.FixedLenFeature([], tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example_proto,
                                         feature_dict)
    img = tf.io.decode_image(example['image_raw'], channels=3,
                             expand_animations=False)
    img = tf.image.resize_with_crop_or_pad(img,
                                           target_height=388,
                                           target_width=270)
    return img, example['label']


def get_dataset(files, batch_size=16, repeat=False,
                cache=False, shuffle=False):
    ds = tf.data.TFRecordDataset(files, num_parallel_reads=AUTO)
    if cache:
        ds = ds.cache()

    if repeat:
        ds = ds.repeat()

    if shuffle:
        ds = ds.shuffle(1024 * 2)
        opt = tf.data.Options()
        opt.experimental_deterministic = False
        ds = ds.with_options(opt)

    ds = ds.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(AUTO)
    return tfds.as_numpy(ds)


def count_data_items(file):
    num_ds = tf.data.TFRecordDataset(file, num_parallel_reads=AUTO)
    num_ds = num_ds.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
    num_ds = num_ds.repeat(1)
    num_ds = num_ds.batch(1)

    c = 0
    for _ in num_ds:
        c += 1
    del num_ds
    return c


class TFRecordDataLoader:
    def __init__(self, files, batch_size=32, cache=False, train=True,
                 repeat=False, shuffle=False, labeled=True,
                 return_image_ids=True):
        self.ds = get_dataset(
            files,
            batch_size=batch_size,
            cache=cache,
            repeat=repeat,
            shuffle=shuffle,)

        if train:
            self.num_examples = count_data_items(files)

        self.batch_size = batch_size
        self.labeled = labeled
        self.return_image_ids = return_image_ids
        self._iterator = None

    def __iter__(self):
        if self._iterator is None:
            self._iterator = iter(self.ds)
        else:
            self._reset()
        return self._iterator

    def _reset(self):
        self._iterator = iter(self.ds)

    def __next__(self):
        batch = next(self._iterator)
        return batch

    def __len__(self):
        n_batches = self.num_examples // self.batch_size
        if self.num_examples % self.batch_size == 0:
            return n_batches
        else:
            return n_batches + 1


# Usage
train_txt_path = 'dataset/minist/train.tfrecords'
train_dataloader = TFRecordDataLoader(train_txt_path,
                                      batch_size=batch_size,
                                      shuffle=True)
for v in train_dataloader:
    pass
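The loader above yields numpy batches (via tfds.as_numpy), so they still need to be converted to torch tensors before being fed to a model. A small sketch of that bridge, assuming the NHWC uint8 layout produced by read_labeled_tfrecord; the scaling to [0, 1] is illustrative:

import torch

for imgs, labels in train_dataloader:
    # imgs: uint8 numpy array of shape (N, 388, 270, 3); labels: int64 numpy array
    imgs = torch.from_numpy(imgs).permute(0, 3, 1, 2).float() / 255.0  # NHWC -> NCHW
    labels = torch.from_numpy(labels).long()
    # forward pass / loss / backward ...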
LMDB
- Across the major forums, whenever speeding up small-file reading in PyTorch comes up, LMDB (Lightning Memory-Mapped Database) is bound to be mentioned. I gave it a try as well; the final verdict is given at the end of this article.
Writing to LMDB
import os
import pickle
from pathlib import Path

import cv2
import lmdb
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm

import utils


class SimpleDataset(Dataset):
    def __init__(self, txt_path, transform=None) -> None:
        self.img_paths = utils.read_txt(txt_path)
        self.transform = transform

    def __getitem__(self, index: int):
        img_path = self.img_paths[index]
        label = int(Path(img_path).parent.name)
        try:
            img = Image.open(img_path)
            img = img.convert('RGB')
        except:
            img = cv2.imread(img_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img)

        if self.transform:
            img = self.transform(img)
        img = np.array(img)
        return img, label

    def __len__(self) -> int:
        return len(self.img_paths)


class LMDB_Image:
    def __init__(self, image, label):
        # Dimensions of image for reconstruction - not really necessary
        # for this dataset, but some datasets may include images of
        # varying sizes
        self.channels = image.shape[2]
        self.size = image.shape[:2]

        self.image = image.tobytes()
        self.label = label

    def get_image(self):
        """ Returns the image as a numpy array. """
        image = np.frombuffer(self.image, dtype=np.uint8)
        return image.reshape(*self.size, self.channels)


def data2lmdb(dpath, name="train", txt_path=None,
              write_frequency=10, num_workers=4):
    dataset = SimpleDataset(txt_path=txt_path)
    data_loader = DataLoader(dataset, num_workers=num_workers,
                             collate_fn=lambda x: x)

    lmdb_path = os.path.join(dpath, "%s.lmdb" % name)
    isdir = os.path.isdir(lmdb_path)

    print("Generate LMDB to %s" % lmdb_path)
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776,  # in bytes
                   readonly=False,
                   meminit=False,
                   map_async=True)

    txn = db.begin(write=True)
    for idx, data in enumerate(tqdm(data_loader)):
        image, label = data[0]
        temp = LMDB_Image(image, label)
        txn.put(u'{}'.format(idx).encode('ascii'), pickle.dumps(temp))

        if idx % write_frequency == 0:
            print("[%d/%d]" % (idx, len(data_loader)))
            txn.commit()
            txn = db.begin(write=True)

    # finish iterating through the dataset
    txn.commit()

    keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
    with db.begin(write=True) as txn:
        txn.put(b'__keys__', pickle.dumps(keys))
        txn.put(b'__len__', pickle.dumps(len(keys)))

    print("Flushing database ...")
    db.sync()
    db.close()


if __name__ == '__main__':
    save_dir = 'dataset/minist'
    data2lmdb(save_dir, name='val', txt_path='dataset/minist/val.txt')
Reading from LMDB
class DatasetLMDB(Dataset):
    def __init__(self, db_path, transform=None):
        self.db_path = db_path
        self.env = lmdb.open(db_path,
                             subdir=os.path.isdir(db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin() as txn:
            self.length = pickle.loads(txn.get(b'__len__'))
            self.keys = pickle.loads(txn.get(b'__keys__'))
        self.transform = transform

    def __getitem__(self, index):
        with self.env.begin() as txn:
            byteflow = txn.get(self.keys[index])

        IMAGE = pickle.loads(byteflow)
        img, label = IMAGE.get_image(), IMAGE.label

        img = Image.fromarray(img).convert('RGB')
        # apply the transform before returning so the DataLoader can collate
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return self.length


# Usage
train_transforms = transforms.Compose([
    transforms.Resize((388, 270)),
    transforms.RandomChoice([
        transforms.RandomRotation(10),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomGrayscale(p=0.3),
        transforms.RandomPerspective(distortion_scale=0.6, p=0.5),
        transforms.ColorJitter(brightness=.5, hue=.3),
    ]),
    transforms.ToTensor(),
    normalize,
    transforms.RandomErasing(),
])

train_dataset = DatasetLMDB(train_txt_path, train_transforms)
train_dataloader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=n_worker,
                              pin_memory=True)
# do other things
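DatasetLMDB opens the LMDB environment in __init__, before the DataLoader forks its worker processes. One pattern that is often suggested for multi-worker LMDB reading is to open the environment lazily inside each worker instead, so that the handle is never shared across processes. This is only a sketch of that idea; whether it fixes the image-restore problem described in the conclusions has not been verified here:

class DatasetLMDBLazy(Dataset):
    """Same as DatasetLMDB, but the env is opened lazily per worker."""

    def __init__(self, db_path, transform=None):
        self.db_path = db_path
        self.env = None
        self.transform = transform

        # read the keys/length once with a temporary environment
        env = lmdb.open(db_path, subdir=os.path.isdir(db_path), readonly=True,
                        lock=False, readahead=False, meminit=False)
        with env.begin() as txn:
            self.length = pickle.loads(txn.get(b'__len__'))
            self.keys = pickle.loads(txn.get(b'__keys__'))
        env.close()

    def __getitem__(self, index):
        if self.env is None:  # first access inside this worker process
            self.env = lmdb.open(self.db_path, subdir=os.path.isdir(self.db_path),
                                 readonly=True, lock=False,
                                 readahead=False, meminit=False)
        with self.env.begin() as txn:
            byteflow = txn.get(self.keys[index])

        record = pickle.loads(byteflow)
        img = Image.fromarray(record.get_image()).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, record.label

    def __len__(self):
        return self.length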
Large binary file
- Another option is simply to read the existing dataset as raw bytes and append everything into one large bin file.
Writing the bin file
import cv2
import numpy as np
from tqdm import tqdm


def write_bin(save_bin_path, save_index_path, data):
    """Pack an existing file-based dataset into one large bin file.
    The byte offset, byte length and label of every sample are written
    to save_index_path, separated by \t.

    Args:
        save_bin_path (str): where to save the bin file
        save_index_path (str): where to save the index/label txt
        data (list): list of [image_path, label] pairs,
            e.g. [['xxx/1.jpg', 'cat'], ['xxx/2.jpg', 'dog']]
    """
    with open(save_bin_path, 'wb') as f_w, \
            open(save_index_path, 'w') as f_index:
        start_index = 0
        for img_path, label in tqdm(data):
            with open(img_path, 'rb') as f:
                img_bin = f.read()
            f_w.write(img_bin)

            len_bin = len(img_bin)
            f_index.write(f'{start_index}\t{len_bin}\t{label}\n')
            start_index += len_bin


def read_bin(bin_path, index_path):
    """Read the large bin file together with its index/label txt.

    Args:
        bin_path (str): path of the bin file
        index_path (str): path of the txt holding offsets, lengths and labels
    """
    with open(bin_path, 'rb') as f_bin, open(index_path, 'r') as f_index:
        index_lines = list(map(lambda x: x.strip(), f_index.readlines()))
        index_lines = list(map(lambda x: x.split('\t'), index_lines))

        for i, (start_index, length, label) in enumerate(index_lines):
            start_index = int(start_index)
            length = int(length)

            # move the file pointer to start_index
            f_bin.seek(start_index)
            # read length bytes
            img_bytes = f_bin.read(length)

            img = np.frombuffer(img_bytes, dtype='uint8')
            img = cv2.imdecode(img, -1)  # -1: cv.IMREAD_UNCHANGED

            # convert to PIL if needed
            # img = Image.fromarray(img)
            # img = img.convert('RGB')

            # save the image to verify
            # cv2.imwrite(f'temp/images/{i}.jpg', img)
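For training, the same seek-and-read logic can be wrapped in a Dataset so that every sample is sliced out of the bin file on demand. A minimal sketch; it assumes the index txt produced by write_bin and integer labels, and the random resampling on decode failure is just one way to skip broken records:

import random
from io import BytesIO

from PIL import Image
from torch.utils.data import Dataset


class BinDataset(Dataset):
    def __init__(self, index_path, bin_path, transform=None):
        with open(index_path, 'r', encoding='utf-8') as f:
            # each line: start_index \t length \t label
            self.index_info = [line.rstrip('\n').split('\t') for line in f]
        # note: this file handle is shared by all workers of a DataLoader
        self.f_bin = open(bin_path, 'rb')
        self.transform = transform

    def __getitem__(self, index):
        start_index, length, label = self.index_info[index]
        self.f_bin.seek(int(start_index))
        img_bytes = self.f_bin.read(int(length))
        try:
            img = Image.open(BytesIO(img_bytes)).convert('RGB')
        except Exception:
            # fall back to a random sample if the bytes cannot be decoded
            return self.__getitem__(random.randint(0, len(self) - 1))

        if self.transform:
            img = self.transform(img)
        return img, int(label)

    def __len__(self):
        return len(self.index_info)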
SQLite
- Using Python's built-in sqlite3 module as the storage format is another good choice.
Writing into a SQLite database
import sqlite3
from pathlib import Path

from tqdm import tqdm


def read_txt(txt_path):
    with open(txt_path, 'r', encoding='utf-8-sig') as f:
        data = list(map(lambda x: x.rstrip('\n'), f))
    return data


def img_to_bytes(img_path):
    with open(img_path, 'rb') as f:
        img_bytes = f.read()
    return img_bytes


class SQLiteWriter(object):
    def __init__(self, db_path):
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()

    def execute(self, sql, value=None):
        if value:
            self.cursor.execute(sql, value)
        else:
            self.cursor.execute(sql)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.cursor.close()
        self.conn.commit()
        self.conn.close()


if __name__ == '__main__':
    dataset_dir = Path('datasets/minist')
    save_db_dir = dataset_dir / 'sqlite'
    save_db_path = str(save_db_dir / 'val.db')

    # each line of val.txt is: image path\tlabel text, e.g. xxxx.jpg\txxxxxx
    img_paths = read_txt(str(dataset_dir / 'val.txt'))

    with SQLiteWriter(save_db_path) as db_writer:
        # create the table; define the columns to match your own dataset.
        # For the mapping between SQLite and Python types, see:
        # https://docs.python.org/zh-cn/3/library/sqlite3.html#sqlite-and-python-types
        # This demo is a text-recognition dataset: samples are images and
        # labels are the corresponding text, so the columns are
        # img_path: str (xxxx.jpg), img_data: image bytes, img_label: str (xxxxx)
        table_name = 'minist'
        create_table_sql = f'create table {table_name} (img_path TEXT primary key, img_data BLOB, img_label TEXT)'
        db_writer.execute(create_table_sql)

        # insert rows into the table; the values use ? placeholders
        insert_sql = f'insert into {table_name} (img_path, img_data, img_label) values(?, ?, ?)'
        for img_info in tqdm(img_paths):
            img_path, label = img_info.split('\t')
            img_full_path = str(dataset_dir / 'images' / img_path)
            img_data = img_to_bytes(img_full_path)
            db_writer.execute(insert_sql, (img_path, img_data, label))
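A quick way to sanity-check the written database is to query the row count and one sample back out. A small sketch that follows the db path and table name used above:

import sqlite3

conn = sqlite3.connect('datasets/minist/sqlite/val.db')
num_rows = conn.execute('select count(*) from minist').fetchone()[0]
img_path, img_data, img_label = conn.execute(
    'select * from minist where rowid=1').fetchone()
print(num_rows, img_path, img_label, len(img_data))
conn.close()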
Reading from the database
import sqlite3
from io import BytesIO

from PIL import Image
from torch.utils.data import DataLoader, Dataset


class SimpleDataset(Dataset):
    def __init__(self, db_path, transform=None) -> None:
        self.db_path = db_path
        self.conn = None
        self.establish_conn()

        # table name inside the database
        self.table_name = 'Synthetic_chinese_dataset'

        self.cursor.execute(f'select max(rowid) from {self.table_name}')
        self.nums = self.cursor.fetchall()[0][0]
        self.transform = transform

    def __getitem__(self, index: int):
        self.establish_conn()

        # query one row by rowid
        search_sql = f'select * from {self.table_name} where rowid=?'
        self.cursor.execute(search_sql, (index + 1, ))
        img_path, img_bytes, label = self.cursor.fetchone()

        # restore the image and label
        img = Image.open(BytesIO(img_bytes))
        img = img.convert('RGB')
        img = scale_resize_pillow(img, (320, 32))  # project-specific resize helper

        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self) -> int:
        return self.nums

    def establish_conn(self):
        if self.conn is None:
            self.conn = sqlite3.connect(self.db_path,
                                        check_same_thread=False,
                                        cached_statements=1024)
            self.cursor = self.conn.cursor()
        return self

    def close_conn(self):
        if self.conn is not None:
            self.cursor.close()
            self.conn.close()

            del self.conn
            self.conn = None
        return self


# --------------------------------------------------
train_dataset = SimpleDataset(train_db_path, train_transforms)
# when using the dataset, close the database connection manually before
# the DataLoader forks its worker processes
train_dataset.close_conn()
train_dataloader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=n_worker,
                              pin_memory=True,
                              sampler=train_sampler)
Final conclusions
TFRecord
- Storage size is unchanged after conversion, and the GPU can be fully utilized.
- tfrecord cannot be combined with other augmentation libraries (imgaug, OpenCV), so the available augmentations are quite limited.
LMDB
- Storage size grows dramatically after conversion (4.2 GB originally → 96 GB after conversion).
- With multi-process reading in PyTorch, the images cannot be restored to the originals correctly; I have not found a solution yet.
- Read efficiency is high enough to keep the GPU fully utilized.
Large binary file
- Storage size is unchanged after conversion.
- Likewise, multi-process reading in PyTorch fails to restore the images correctly; no solution found yet.
SQLite (recommended)
- Storage size is unchanged after conversion.
- Multi-process reading works correctly.
References
- pytorch-sqlite
- sqlite_dataset