pytorch数据加载
一、方法一
数据组织形式
dataset_name
----train
----val from torchvision import datasets, models, transforms # Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),} data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10) # Each epoch has a training and validation phase
for phase in ['train', 'val']:
if phase == 'train':
scheduler.step()
model.train() # Set model to training mode
else:
model.eval() # Set model to evaluate mode running_loss = 0.0
running_corrects = 0 # Iterate over data.
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device) # zero the parameter gradients
optimizer.zero_grad() # forward
# track history if only in train
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels) # backward + optimize only if in training phase
if phase == 'train':
loss.backward()
optimizer.step() # statistics
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase] print('{} Loss: {:.4f} Acc: {:.4f}'.format(
phase, epoch_loss, epoch_acc)) # deep copy the model
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict()) print()
二、方法二
自定路径+txt内写入的路径
txt内容,前面是图片路径,后面是label类别
生成txt代码
# -*-coding:utf-8-*-
"""
@Project: googlenet_classification
@File : create_labels_files.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : 2018-08-11 10:15:28
""" import os
import os.path def write_txt(content, filename, mode='w'):
"""保存txt数据
:param content:需要保存的数据,type->list
:param filename:文件名
:param mode:读写模式:'w' or 'a'
:return: void
"""
with open(filename, mode) as f:
for line in content:
str_line = ""
for col, data in enumerate(line):
if not col == len(line) - 1:
# 以空格作为分隔符
str_line = str_line + str(data) + " "
else:
# 每行最后一个数据用换行符“\n”
str_line = str_line + str(data) + "\n"
f.write(str_line) def get_files_list(dir):
'''
实现遍历dir目录下,所有文件(包含子文件夹的文件)
:param dir:指定文件夹目录
:return:包含所有文件的列表->list
'''
# parent:父目录, filenames:该目录下所有文件夹,filenames:该目录下的文件名
files_list = []
for parent, dirnames, filenames in os.walk(dir):
for filename in filenames:
# print("parent is: " + parent)
# print("filename is: " + filename)
# print(os.path.join(parent, filename)) # 输出rootdir路径下所有文件(包含子文件)信息
curr_file = parent.split(os.sep)[-1]
if curr_file == '':
labels = 0
elif curr_file == '':
labels = 1
elif curr_file == '':
labels = 2
elif curr_file == '':
labels = 3
elif curr_file == '':
labels = 4
elif curr_file == '':
labels = 5
elif curr_file == '':
labels = 6
elif curr_file == '':
labels = 7
elif curr_file == '':
labels = 8
files_list.append([os.path.join(curr_file, filename), labels])
return files_list if __name__ == '__main__':
train_dir = r'F:\WU_work\guandao\data\guandao20190904_10\train'
train_txt = r'F:\WU_work\guandao\data\guandao20190904_10/train.txt'
train_data = get_files_list(train_dir)
write_txt(train_data, train_txt, mode='w') val_dir = r'F:\WU_work\guandao\data\guandao20190904_10\validation'
val_txt = r'F:\WU_work\guandao\data\guandao20190904_10/val.txt'
val_data = get_files_list(val_dir)
write_txt(val_data, val_txt, mode='w')
# 构建MyDataset实例 img_path是一种可在txt图片路径前面加入的一种机制
#img_path是训练集或验证集路径,如F:\WU_work\guandao\data\guandao20190904_10\train
train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform)
数据加载
# -------------------------------------------- step 1/5 : 加载数据 -------------------------------------------
train_txt_path = './Data/train.txt'
valid_txt_path = './Data/valid.txt'
# 数据预处理设置
normMean = [0.4948052, 0.48568845, 0.44682974]
normStd = [0.24580306, 0.24236229, 0.2603115]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
transforms.Resize(224),
transforms.RandomCrop(224, padding=4),
transforms.ToTensor(),
normTransform
]) validTransform = transforms.Compose([
transforms.ToTensor(),
normTransform
]) # 构建MyDataset实例 img_path是一种可在txt图片路径前面加入的一种机制
train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform) # 构建DataLoder
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True)
valid_loader = torch.utils.data.DataLoader(dataset=valid_data, batch_size=16)
train_loader 是迭代器,每次返回图片和对应的label
pytorch数据加载的更多相关文章
- PyTorch数据加载处理
PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解 ...
- pytorch数据加载器
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, ...
- PyTorch 数据集类 和 数据加载类 的一些尝试
最近在学习PyTorch, 但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实 ...
- [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler
[源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler 目录 [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampl ...
- [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader
[源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 目录 [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 0x00 摘要 0x01 ...
- ScrollView嵌套ListView,GridView数据加载不全问题的解决
我们大家都知道ListView,GridView加载数据项,如果数据项过多时,就会显示滚动条.ScrollView组件里面只能包含一个组件,当ScrollView里面嵌套listView,GridVi ...
- python多种格式数据加载、处理与存储
多种格式数据加载.处理与存储 实际的场景中,我们会在不同的地方遇到各种不同的数据格式(比如大家熟悉的csv与txt,比如网页HTML格式,比如XML格式),我们来一起看看python如何和这些格式的数 ...
- flask+sqlite3+echarts3+ajax 异步数据加载
结构: /www | |-- /static |....|-- jquery-3.1.1.js |....|-- echarts.js(echarts3是单文件!!) | |-- /templates ...
- Entity Framework关联查询以及数据加载(延迟加载,预加载)
数据加载分为延迟加载和预加载 EF的关联实体加载有三种方式:Lazy Loading,Eager Loading,Explicit Loading,其中Lazy Loading和Explicit Lo ...
随机推荐
- 由一个空工程改为SpringBoot工程
1.先创建一个空的工程,创建springboot 工程 必须继承spring-boot-stater-parent 2.导入依赖 <parent> <groupId>org. ...
- POST请求转换为PUT或者Delete请求、处理post请求乱码的过滤器、Get请求乱码
在web.xml中配置 <!--配置HiddenHttpMethodFilter : 将所有的POST请求转换为PUT或者Delete请求 --><filter> <fi ...
- 并发编程--一堆锁,GIL,同步异步,Event事件
目录 一堆锁 死锁现象(*****) 递归锁 RLock (了解) 信号量 (了解) GIL(*****) 什么时GIL锁 为什么需要GIL锁 Cpython解释器与GC的问题 GIL锁带来的问题 多 ...
- postgresql怎么导入数据库
1.切换到postgres用户 : sudo su - postgres 2.在shell命令行下,创建数据库exampledb,并指定所有者为dbuser : sudo -u postgres - ...
- HBase的部署与其它相关组件(Hive和Phoenix)的集成
HBase的部署与其它相关组件(Hive和Phoenix)的集成 一.HBase部署 1.1.Zookeeper正常部署 首先保证Zookeeper集群的正常部署,并启动之: /opt/module/ ...
- Linux Mysql创建新用户并允许远程连接
第一步 登陆mysql: mysql-u 数据库用户名 -h 数据库IP -p 根据提示 输入数据库密码 第二步: GRANT ALL PRIVILEGES ON *.* TO '自定义用户名'@'% ...
- Flask基础之返回值与form表单提交
目录 1.Python 现阶段三大主流Web框架 Django Tornado Flask 对比 2.Flask的安装 3.Flask的第一个简单应用 4.Flask中的render_template ...
- zabbix--基本操作
zabbix 快速上手 示例一些zabbix的最基本的配置: 添加主机群组:添加主机:创建监控项:创建触发器 添加主机群组 参考官档:https://www.zabbix.com/documentat ...
- wentaolovesmeng.club
wentaolovesmeng.club wentaostudy.club
- Vuex之辅助函数
mapState.mapGetters.mapActions 如果我们不喜欢这种在页面上使用“this.$stroe.state.count”和“this.$store.dispatch('funNa ...