pytorch数据加载
一、方法一
数据组织形式
dataset_name
----train
----val from torchvision import datasets, models, transforms # Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),} data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10) # Each epoch has a training and validation phase
for phase in ['train', 'val']:
if phase == 'train':
scheduler.step()
model.train() # Set model to training mode
else:
model.eval() # Set model to evaluate mode running_loss = 0.0
running_corrects = 0 # Iterate over data.
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device) # zero the parameter gradients
optimizer.zero_grad() # forward
# track history if only in train
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels) # backward + optimize only if in training phase
if phase == 'train':
loss.backward()
optimizer.step() # statistics
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase] print('{} Loss: {:.4f} Acc: {:.4f}'.format(
phase, epoch_loss, epoch_acc)) # deep copy the model
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict()) print()
二、方法二
自定路径+txt内写入的路径
txt内容,前面是图片路径,后面是label类别
生成txt代码
# -*-coding:utf-8-*-
"""
@Project: googlenet_classification
@File : create_labels_files.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : 2018-08-11 10:15:28
""" import os
import os.path def write_txt(content, filename, mode='w'):
"""保存txt数据
:param content:需要保存的数据,type->list
:param filename:文件名
:param mode:读写模式:'w' or 'a'
:return: void
"""
with open(filename, mode) as f:
for line in content:
str_line = ""
for col, data in enumerate(line):
if not col == len(line) - 1:
# 以空格作为分隔符
str_line = str_line + str(data) + " "
else:
# 每行最后一个数据用换行符“\n”
str_line = str_line + str(data) + "\n"
f.write(str_line) def get_files_list(dir):
'''
实现遍历dir目录下,所有文件(包含子文件夹的文件)
:param dir:指定文件夹目录
:return:包含所有文件的列表->list
'''
# parent:父目录, filenames:该目录下所有文件夹,filenames:该目录下的文件名
files_list = []
for parent, dirnames, filenames in os.walk(dir):
for filename in filenames:
# print("parent is: " + parent)
# print("filename is: " + filename)
# print(os.path.join(parent, filename)) # 输出rootdir路径下所有文件(包含子文件)信息
curr_file = parent.split(os.sep)[-1]
if curr_file == '':
labels = 0
elif curr_file == '':
labels = 1
elif curr_file == '':
labels = 2
elif curr_file == '':
labels = 3
elif curr_file == '':
labels = 4
elif curr_file == '':
labels = 5
elif curr_file == '':
labels = 6
elif curr_file == '':
labels = 7
elif curr_file == '':
labels = 8
files_list.append([os.path.join(curr_file, filename), labels])
return files_list if __name__ == '__main__':
train_dir = r'F:\WU_work\guandao\data\guandao20190904_10\train'
train_txt = r'F:\WU_work\guandao\data\guandao20190904_10/train.txt'
train_data = get_files_list(train_dir)
write_txt(train_data, train_txt, mode='w') val_dir = r'F:\WU_work\guandao\data\guandao20190904_10\validation'
val_txt = r'F:\WU_work\guandao\data\guandao20190904_10/val.txt'
val_data = get_files_list(val_dir)
write_txt(val_data, val_txt, mode='w')
# 构建MyDataset实例 img_path是一种可在txt图片路径前面加入的一种机制
#img_path是训练集或验证集路径,如F:\WU_work\guandao\data\guandao20190904_10\train
train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform)
数据加载
# -------------------------------------------- step 1/5 : 加载数据 -------------------------------------------
train_txt_path = './Data/train.txt'
valid_txt_path = './Data/valid.txt'
# 数据预处理设置
normMean = [0.4948052, 0.48568845, 0.44682974]
normStd = [0.24580306, 0.24236229, 0.2603115]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
transforms.Resize(224),
transforms.RandomCrop(224, padding=4),
transforms.ToTensor(),
normTransform
]) validTransform = transforms.Compose([
transforms.ToTensor(),
normTransform
]) # 构建MyDataset实例 img_path是一种可在txt图片路径前面加入的一种机制
train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform) # 构建DataLoder
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True)
valid_loader = torch.utils.data.DataLoader(dataset=valid_data, batch_size=16)
train_loader 是迭代器,每次返回图片和对应的label
pytorch数据加载的更多相关文章
- PyTorch数据加载处理
PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解 ...
- pytorch数据加载器
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, ...
- PyTorch 数据集类 和 数据加载类 的一些尝试
最近在学习PyTorch, 但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实 ...
- [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler
[源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler 目录 [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampl ...
- [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader
[源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 目录 [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 0x00 摘要 0x01 ...
- ScrollView嵌套ListView,GridView数据加载不全问题的解决
我们大家都知道ListView,GridView加载数据项,如果数据项过多时,就会显示滚动条.ScrollView组件里面只能包含一个组件,当ScrollView里面嵌套listView,GridVi ...
- python多种格式数据加载、处理与存储
多种格式数据加载.处理与存储 实际的场景中,我们会在不同的地方遇到各种不同的数据格式(比如大家熟悉的csv与txt,比如网页HTML格式,比如XML格式),我们来一起看看python如何和这些格式的数 ...
- flask+sqlite3+echarts3+ajax 异步数据加载
结构: /www | |-- /static |....|-- jquery-3.1.1.js |....|-- echarts.js(echarts3是单文件!!) | |-- /templates ...
- Entity Framework关联查询以及数据加载(延迟加载,预加载)
数据加载分为延迟加载和预加载 EF的关联实体加载有三种方式:Lazy Loading,Eager Loading,Explicit Loading,其中Lazy Loading和Explicit Lo ...
随机推荐
- windows查看某个端口被哪个进程占用
找出端口对应的PID netstat -ano | findstr 8080 帮助命令netstat -? -a 显示所有连接和侦听端口. -n 以数字形式显示地址和端口号. -o 显示拥有的与每个连 ...
- 3.Javascript实现instanceof
instanceof instanceof 用于判断某个对象是否是另一个对象(构造方法)的实例.instanceof会查找原型链,直到null如果还不是后面这个对象的实例的话就返回false,否则就返 ...
- Vue学习之组件切换及父子组件小结(八)
一.组件切换: 1.v-if与v-else方式: <!DOCTYPE html> <html lang="en"> <head> <met ...
- JS JQUERY实现滚动条自动滚到底的方法
$(function(){ var h = $(document).height()-$(window).height(); $(document).scrollTop(h); }); \ windo ...
- android中如何实现UI的实时更新---需要考虑电量和流量
1.如果不考虑电量和流量的话,只需要在对应的activity里面继承Runnable,在run方法里面写一个while死循环,调用接口返回数据,如果数据发生了变化,就立即更新UI 2.需要考虑电量的话 ...
- WPE 过滤器 滤镜 用法
过滤所有数值匹配的数据包,并修改指定的bit位 打开游戏 打开WPE 附加游戏进程 选项配置 用来配置抓取发送和接收包类型 先抓取发送包,也就是游戏中主动发给服务器的包 点击开始抓包 输入喊话内容 分 ...
- Cheat Engine 模糊数值
打开游戏 玩到换枪为止 换枪 发现子弹数量是有限的200 扫描200 这是初次扫描 开两枪 剩余子弹数量194 再次扫描194 得到地址 尝试得到的这两个地址,经验证,第二个是我们想要的地址 重新开始 ...
- k8s网络配置管理
docker容器的四种网络类型 1.桥接 2.联盟 3.主机 4.无 docker跨节点的容器通信必须通过NAT机制 宿主机上的容器一般都是私网地址 它可以通过宿主机 ...
- go test benchmark
Benchtest的简单使用 一个简单的benchtest用例 // 以BenchmarkXXX类似命名,并传入b *testing.B 参数 func BenchmarkLoopSum(b *tes ...
- kafaka可视化工具kafkatool
炒作就像动物世界的森林法则,专门攻击弱者,这种做法往往能够百发百中. ...