pytorch数据加载
一、方法一
数据组织形式
dataset_name
----train
----val from torchvision import datasets, models, transforms # Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),} data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10) # Each epoch has a training and validation phase
for phase in ['train', 'val']:
if phase == 'train':
scheduler.step()
model.train() # Set model to training mode
else:
model.eval() # Set model to evaluate mode running_loss = 0.0
running_corrects = 0 # Iterate over data.
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device) # zero the parameter gradients
optimizer.zero_grad() # forward
# track history if only in train
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels) # backward + optimize only if in training phase
if phase == 'train':
loss.backward()
optimizer.step() # statistics
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase] print('{} Loss: {:.4f} Acc: {:.4f}'.format(
phase, epoch_loss, epoch_acc)) # deep copy the model
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict()) print()
二、方法二
自定路径+txt内写入的路径
txt内容,前面是图片路径,后面是label类别

生成txt代码
# -*-coding:utf-8-*-
"""
@Project: googlenet_classification
@File : create_labels_files.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : 2018-08-11 10:15:28
""" import os
import os.path def write_txt(content, filename, mode='w'):
"""保存txt数据
:param content:需要保存的数据,type->list
:param filename:文件名
:param mode:读写模式:'w' or 'a'
:return: void
"""
with open(filename, mode) as f:
for line in content:
str_line = ""
for col, data in enumerate(line):
if not col == len(line) - 1:
# 以空格作为分隔符
str_line = str_line + str(data) + " "
else:
# 每行最后一个数据用换行符“\n”
str_line = str_line + str(data) + "\n"
f.write(str_line) def get_files_list(dir):
'''
实现遍历dir目录下,所有文件(包含子文件夹的文件)
:param dir:指定文件夹目录
:return:包含所有文件的列表->list
'''
# parent:父目录, filenames:该目录下所有文件夹,filenames:该目录下的文件名
files_list = []
for parent, dirnames, filenames in os.walk(dir):
for filename in filenames:
# print("parent is: " + parent)
# print("filename is: " + filename)
# print(os.path.join(parent, filename)) # 输出rootdir路径下所有文件(包含子文件)信息
curr_file = parent.split(os.sep)[-1]
if curr_file == '':
labels = 0
elif curr_file == '':
labels = 1
elif curr_file == '':
labels = 2
elif curr_file == '':
labels = 3
elif curr_file == '':
labels = 4
elif curr_file == '':
labels = 5
elif curr_file == '':
labels = 6
elif curr_file == '':
labels = 7
elif curr_file == '':
labels = 8
files_list.append([os.path.join(curr_file, filename), labels])
return files_list if __name__ == '__main__':
train_dir = r'F:\WU_work\guandao\data\guandao20190904_10\train'
train_txt = r'F:\WU_work\guandao\data\guandao20190904_10/train.txt'
train_data = get_files_list(train_dir)
write_txt(train_data, train_txt, mode='w') val_dir = r'F:\WU_work\guandao\data\guandao20190904_10\validation'
val_txt = r'F:\WU_work\guandao\data\guandao20190904_10/val.txt'
val_data = get_files_list(val_dir)
write_txt(val_data, val_txt, mode='w')
# 构建MyDataset实例 img_path是一种可在txt图片路径前面加入的一种机制
#img_path是训练集或验证集路径,如F:\WU_work\guandao\data\guandao20190904_10\train
train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform)
数据加载
# -------------------------------------------- step 1/5 : 加载数据 -------------------------------------------
train_txt_path = './Data/train.txt'
valid_txt_path = './Data/valid.txt'
# 数据预处理设置
normMean = [0.4948052, 0.48568845, 0.44682974]
normStd = [0.24580306, 0.24236229, 0.2603115]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
transforms.Resize(224),
transforms.RandomCrop(224, padding=4),
transforms.ToTensor(),
normTransform
]) validTransform = transforms.Compose([
transforms.ToTensor(),
normTransform
]) # 构建MyDataset实例 img_path是一种可在txt图片路径前面加入的一种机制
train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform) # 构建DataLoder
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True)
valid_loader = torch.utils.data.DataLoader(dataset=valid_data, batch_size=16)
train_loader 是迭代器,每次返回图片和对应的label
pytorch数据加载的更多相关文章
- PyTorch数据加载处理
PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解 ...
- pytorch数据加载器
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, ...
- PyTorch 数据集类 和 数据加载类 的一些尝试
最近在学习PyTorch, 但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实 ...
- [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler
[源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler 目录 [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampl ...
- [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader
[源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 目录 [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 0x00 摘要 0x01 ...
- ScrollView嵌套ListView,GridView数据加载不全问题的解决
我们大家都知道ListView,GridView加载数据项,如果数据项过多时,就会显示滚动条.ScrollView组件里面只能包含一个组件,当ScrollView里面嵌套listView,GridVi ...
- python多种格式数据加载、处理与存储
多种格式数据加载.处理与存储 实际的场景中,我们会在不同的地方遇到各种不同的数据格式(比如大家熟悉的csv与txt,比如网页HTML格式,比如XML格式),我们来一起看看python如何和这些格式的数 ...
- flask+sqlite3+echarts3+ajax 异步数据加载
结构: /www | |-- /static |....|-- jquery-3.1.1.js |....|-- echarts.js(echarts3是单文件!!) | |-- /templates ...
- Entity Framework关联查询以及数据加载(延迟加载,预加载)
数据加载分为延迟加载和预加载 EF的关联实体加载有三种方式:Lazy Loading,Eager Loading,Explicit Loading,其中Lazy Loading和Explicit Lo ...
随机推荐
- zynq7020开发板+ Z-turn调试计划
参加米尔zynq7020开发板试用活动. 收到米尔z-turn板子后,焊接了一个JTAG转接板,以方便调试PL部分,对于后面的调试部分,主要分三个部分走:1.调试FPGA部分,实现逻辑控制外围简单的设 ...
- Python 列表推导式、矩阵、格式化输出
列表推导式 列表推导式提供了从列表.元组创建列表的简单途径.语法: [表达式 for语句 if语句] 创建并返回一个列表.if语句可选. 示例: list1=[1,2,3,4] #使用元组也行 lis ...
- java InputStream的使用
package cn.kongxh.io3;import java.io.File ;import java.io.InputStream ;import java.io.FileInputStrea ...
- Docker 0x04: Docker 基本使用
目录 Docker 基本使用 第一步:明确要使用容器运行的应用的镜像相关 第二步:运行一个官方nginx应用 第三步:单纯下载镜像,pull 第四步:设置国内docker-hub 第五步:列出已有镜像 ...
- Httpd服务入门知识-Httpd服务常见配置案例之显示服务器版本信息
Httpd服务入门知识-Httpd服务常见配置案例之显示服务器版本信息 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.httpd配置文件的组成 1>.主要组成 Globa ...
- OneDrive,在云端
应用场景 1.一份文档下班后还没编辑好,发送到自己的QQ/微信回家后继续编辑: 2.由于来回拷贝同一份文件,导致版本太多,忘记那个是最新版本了: 3.出门在外,客户突然需要一份重要文档,这份文件放在办 ...
- nohup 、&、 2>&1 命令分析
nohup的意思是不间断的运行,&的意思是后台运行,2>&1的意思是标准输出和错误输出都重定向到同一个文件. 简单地说nohup运行时即使关掉控制台,它该运行还是运行. http ...
- python预课05 爬虫初步学习+jieba分词+词云库+哔哩哔哩弹幕爬取示例(数据分析pandas)
结巴分词 import jieba """ pip install jieba 1.精确模式 2.全模式 3.搜索引擎模式 """ txt ...
- URI和URL的区别(转)
1说明: 这段时间写android的时候用到了URL和URI,有点分不清楚,于是做了一个系统性的学习.在这里将自己的学习笔记粘贴出来,希望对大家有帮助. 1)Java类库里有两个对应的类java.ne ...
- 学习Java书籍推荐和面试网站推荐
一.Java书籍推荐: 来自http://www.importnew.com/26932.html 1. 鸟哥的Linux私房菜—基础学习篇 3. Effective Java 6. Java并发编程 ...