pytorch数据加载
一、方法一
数据组织形式
dataset_name
----train
----val from torchvision import datasets, models, transforms # Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),} data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10) # Each epoch has a training and validation phase
for phase in ['train', 'val']:
if phase == 'train':
scheduler.step()
model.train() # Set model to training mode
else:
model.eval() # Set model to evaluate mode running_loss = 0.0
running_corrects = 0 # Iterate over data.
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device) # zero the parameter gradients
optimizer.zero_grad() # forward
# track history if only in train
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels) # backward + optimize only if in training phase
if phase == 'train':
loss.backward()
optimizer.step() # statistics
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase] print('{} Loss: {:.4f} Acc: {:.4f}'.format(
phase, epoch_loss, epoch_acc)) # deep copy the model
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict()) print()
二、方法二
自定路径+txt内写入的路径
txt内容,前面是图片路径,后面是label类别

生成txt代码
# -*-coding:utf-8-*-
"""
@Project: googlenet_classification
@File : create_labels_files.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : 2018-08-11 10:15:28
""" import os
import os.path def write_txt(content, filename, mode='w'):
"""保存txt数据
:param content:需要保存的数据,type->list
:param filename:文件名
:param mode:读写模式:'w' or 'a'
:return: void
"""
with open(filename, mode) as f:
for line in content:
str_line = ""
for col, data in enumerate(line):
if not col == len(line) - 1:
# 以空格作为分隔符
str_line = str_line + str(data) + " "
else:
# 每行最后一个数据用换行符“\n”
str_line = str_line + str(data) + "\n"
f.write(str_line) def get_files_list(dir):
'''
实现遍历dir目录下,所有文件(包含子文件夹的文件)
:param dir:指定文件夹目录
:return:包含所有文件的列表->list
'''
# parent:父目录, filenames:该目录下所有文件夹,filenames:该目录下的文件名
files_list = []
for parent, dirnames, filenames in os.walk(dir):
for filename in filenames:
# print("parent is: " + parent)
# print("filename is: " + filename)
# print(os.path.join(parent, filename)) # 输出rootdir路径下所有文件(包含子文件)信息
curr_file = parent.split(os.sep)[-1]
if curr_file == '':
labels = 0
elif curr_file == '':
labels = 1
elif curr_file == '':
labels = 2
elif curr_file == '':
labels = 3
elif curr_file == '':
labels = 4
elif curr_file == '':
labels = 5
elif curr_file == '':
labels = 6
elif curr_file == '':
labels = 7
elif curr_file == '':
labels = 8
files_list.append([os.path.join(curr_file, filename), labels])
return files_list if __name__ == '__main__':
train_dir = r'F:\WU_work\guandao\data\guandao20190904_10\train'
train_txt = r'F:\WU_work\guandao\data\guandao20190904_10/train.txt'
train_data = get_files_list(train_dir)
write_txt(train_data, train_txt, mode='w') val_dir = r'F:\WU_work\guandao\data\guandao20190904_10\validation'
val_txt = r'F:\WU_work\guandao\data\guandao20190904_10/val.txt'
val_data = get_files_list(val_dir)
write_txt(val_data, val_txt, mode='w')
# 构建MyDataset实例 img_path是一种可在txt图片路径前面加入的一种机制
#img_path是训练集或验证集路径,如F:\WU_work\guandao\data\guandao20190904_10\train
train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform)
数据加载
# -------------------------------------------- step 1/5 : 加载数据 -------------------------------------------
train_txt_path = './Data/train.txt'
valid_txt_path = './Data/valid.txt'
# 数据预处理设置
normMean = [0.4948052, 0.48568845, 0.44682974]
normStd = [0.24580306, 0.24236229, 0.2603115]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
transforms.Resize(224),
transforms.RandomCrop(224, padding=4),
transforms.ToTensor(),
normTransform
]) validTransform = transforms.Compose([
transforms.ToTensor(),
normTransform
]) # 构建MyDataset实例 img_path是一种可在txt图片路径前面加入的一种机制
train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform) # 构建DataLoder
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True)
valid_loader = torch.utils.data.DataLoader(dataset=valid_data, batch_size=16)
train_loader 是迭代器,每次返回图片和对应的label
pytorch数据加载的更多相关文章
- PyTorch数据加载处理
PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解 ...
- pytorch数据加载器
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, ...
- PyTorch 数据集类 和 数据加载类 的一些尝试
最近在学习PyTorch, 但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实 ...
- [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler
[源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler 目录 [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampl ...
- [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader
[源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 目录 [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 0x00 摘要 0x01 ...
- ScrollView嵌套ListView,GridView数据加载不全问题的解决
我们大家都知道ListView,GridView加载数据项,如果数据项过多时,就会显示滚动条.ScrollView组件里面只能包含一个组件,当ScrollView里面嵌套listView,GridVi ...
- python多种格式数据加载、处理与存储
多种格式数据加载.处理与存储 实际的场景中,我们会在不同的地方遇到各种不同的数据格式(比如大家熟悉的csv与txt,比如网页HTML格式,比如XML格式),我们来一起看看python如何和这些格式的数 ...
- flask+sqlite3+echarts3+ajax 异步数据加载
结构: /www | |-- /static |....|-- jquery-3.1.1.js |....|-- echarts.js(echarts3是单文件!!) | |-- /templates ...
- Entity Framework关联查询以及数据加载(延迟加载,预加载)
数据加载分为延迟加载和预加载 EF的关联实体加载有三种方式:Lazy Loading,Eager Loading,Explicit Loading,其中Lazy Loading和Explicit Lo ...
随机推荐
- HTML5深入学习之数据存储
概述 本来,数据存储都是由 cookie 完成的,但是 cookie 不适合大量数据的存储,cookie 速度慢且效率低. 现在,HMLT5提供了两种在客户端存储数据的办法: localStorage ...
- 渐进增强(progressive enhancement)、优雅降级(graceful degradation)
渐进增强 progressive enhancement: 针对低版本浏览器进行构建页面,保证最基本的功能,然后再针对高级浏览器进行效果.交互等改进和追加功能达到更好的用户体验. 优雅降级 grace ...
- 英语caement单词caement水泥
水泥石 又称净浆硬化体.是指 硬化后的水泥浆体,称为水泥石,在英语里是cement有时写作caement [1] ,是由胶凝体.未水化的水泥颗粒内核.毛细孔等组 成的非均质体. 中文名:水泥石 外 ...
- fastjson 将json字符串转化成List<Map<String, Object>>
亲测可行,如下: JSON.parseObject(jsonstr, new TypeReference<List<Map<String, Object>>>() ...
- js修改页面标题 title
如果对你有帮助的话麻烦点个[推荐]~最好还可以follow一下我的GitHub~感谢观看! /* * *添加首页description元数据meta标签 *创建一个meta元素,sName为该meta ...
- php + h5 实现socket推送技术
在socket出现之前已经有ajax定时请求.长轮询等方案,但都不能满足需求,socket就应用而生了. socket基本函数socket 总结下常用的socket函数 服务端: socket_cre ...
- Linux Kbuild文档(转)
转载链接:http://blog.chinaunix.net/uid-10221131-id-2943265.html Linux Kbuild文档 Linux Kbuild文档 V 0.1 tang ...
- 服务器安装python3环境
服务器安装python3环境 先安装相关包 yum install zlib-devel bzip2-devel openssl-devel ncurses-devel sqlite-devel re ...
- Nginx配置文件、优化详解
上篇<编译安装nginx>已将nginx安装好,这篇写nginx配置文件和部分优化参数. 查看nginx的配置文件路径,可以使用nginx配置文件检查命令nginx -t: [root@n ...
- iView学习笔记(三):表格搜索,过滤及隐藏列操作
iView学习笔记(三):表格搜索,过滤及隐藏某列操作 1.后端准备工作 环境说明 python版本:3.6.6 Django版本:1.11.8 数据库:MariaDB 5.5.60 新建Django ...