一、方法一
数据组织形式
dataset_name
----train
----val from torchvision import datasets, models, transforms # Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),} data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
    for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10) # Each epoch has a training and validation phase
for phase in ['train', 'val']:
if phase == 'train':
scheduler.step()
model.train() # Set model to training mode
else:
model.eval() # Set model to evaluate mode running_loss = 0.0
running_corrects = 0 # Iterate over data.
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device) # zero the parameter gradients
optimizer.zero_grad() # forward
# track history if only in train
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels) # backward + optimize only if in training phase
if phase == 'train':
loss.backward()
optimizer.step() # statistics
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase] print('{} Loss: {:.4f} Acc: {:.4f}'.format(
phase, epoch_loss, epoch_acc)) # deep copy the model
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict()) print()

二、方法二

自定路径+txt内写入的路径

txt内容,前面是图片路径,后面是label类别

生成txt代码

# -*-coding:utf-8-*-
"""
@Project: googlenet_classification
@File : create_labels_files.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : 2018-08-11 10:15:28
""" import os
import os.path def write_txt(content, filename, mode='w'):
"""保存txt数据
:param content:需要保存的数据,type->list
:param filename:文件名
:param mode:读写模式:'w' or 'a'
:return: void
"""
with open(filename, mode) as f:
for line in content:
str_line = ""
for col, data in enumerate(line):
if not col == len(line) - 1:
# 以空格作为分隔符
str_line = str_line + str(data) + " "
else:
# 每行最后一个数据用换行符“\n”
str_line = str_line + str(data) + "\n"
f.write(str_line) def get_files_list(dir):
'''
实现遍历dir目录下,所有文件(包含子文件夹的文件)
:param dir:指定文件夹目录
:return:包含所有文件的列表->list
'''
# parent:父目录, filenames:该目录下所有文件夹,filenames:该目录下的文件名
files_list = []
for parent, dirnames, filenames in os.walk(dir):
for filename in filenames:
# print("parent is: " + parent)
# print("filename is: " + filename)
# print(os.path.join(parent, filename)) # 输出rootdir路径下所有文件(包含子文件)信息
curr_file = parent.split(os.sep)[-1]
if curr_file == '':
labels = 0
elif curr_file == '':
labels = 1
elif curr_file == '':
labels = 2
elif curr_file == '':
labels = 3
elif curr_file == '':
labels = 4
elif curr_file == '':
labels = 5
elif curr_file == '':
labels = 6
elif curr_file == '':
labels = 7
elif curr_file == '':
labels = 8
files_list.append([os.path.join(curr_file, filename), labels])
return files_list if __name__ == '__main__':
train_dir = r'F:\WU_work\guandao\data\guandao20190904_10\train'
train_txt = r'F:\WU_work\guandao\data\guandao20190904_10/train.txt'
train_data = get_files_list(train_dir)
write_txt(train_data, train_txt, mode='w') val_dir = r'F:\WU_work\guandao\data\guandao20190904_10\validation'
val_txt = r'F:\WU_work\guandao\data\guandao20190904_10/val.txt'
val_data = get_files_list(val_dir)
write_txt(val_data, val_txt, mode='w')
    # 构建MyDataset实例 img_path是一种可在txt图片路径前面加入的一种机制
  #img_path是训练集或验证集路径,如F:\WU_work\guandao\data\guandao20190904_10\train
train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform)

数据加载

# -------------------------------------------- step 1/5 : 加载数据 -------------------------------------------
train_txt_path = './Data/train.txt'
valid_txt_path = './Data/valid.txt'
# 数据预处理设置
normMean = [0.4948052, 0.48568845, 0.44682974]
normStd = [0.24580306, 0.24236229, 0.2603115]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
transforms.Resize(224),
transforms.RandomCrop(224, padding=4),
transforms.ToTensor(),
normTransform
]) validTransform = transforms.Compose([
transforms.ToTensor(),
normTransform
]) # 构建MyDataset实例 img_path是一种可在txt图片路径前面加入的一种机制
train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform) # 构建DataLoder
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True)
valid_loader = torch.utils.data.DataLoader(dataset=valid_data, batch_size=16)
train_loader 是迭代器,每次返回图片和对应的label

pytorch数据加载的更多相关文章

  1. PyTorch数据加载处理

    PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解 ...

  2. pytorch数据加载器

    class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, ...

  3. PyTorch 数据集类 和 数据加载类 的一些尝试

    最近在学习PyTorch,  但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实 ...

  4. [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler

    [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler 目录 [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampl ...

  5. [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader

    [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 目录 [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 0x00 摘要 0x01 ...

  6. ScrollView嵌套ListView,GridView数据加载不全问题的解决

    我们大家都知道ListView,GridView加载数据项,如果数据项过多时,就会显示滚动条.ScrollView组件里面只能包含一个组件,当ScrollView里面嵌套listView,GridVi ...

  7. python多种格式数据加载、处理与存储

    多种格式数据加载.处理与存储 实际的场景中,我们会在不同的地方遇到各种不同的数据格式(比如大家熟悉的csv与txt,比如网页HTML格式,比如XML格式),我们来一起看看python如何和这些格式的数 ...

  8. flask+sqlite3+echarts3+ajax 异步数据加载

    结构: /www | |-- /static |....|-- jquery-3.1.1.js |....|-- echarts.js(echarts3是单文件!!) | |-- /templates ...

  9. Entity Framework关联查询以及数据加载(延迟加载,预加载)

    数据加载分为延迟加载和预加载 EF的关联实体加载有三种方式:Lazy Loading,Eager Loading,Explicit Loading,其中Lazy Loading和Explicit Lo ...

随机推荐

  1. windows查看某个端口被哪个进程占用

    找出端口对应的PID netstat -ano | findstr 8080 帮助命令netstat -? -a 显示所有连接和侦听端口. -n 以数字形式显示地址和端口号. -o 显示拥有的与每个连 ...

  2. 3.Javascript实现instanceof

    instanceof instanceof 用于判断某个对象是否是另一个对象(构造方法)的实例.instanceof会查找原型链,直到null如果还不是后面这个对象的实例的话就返回false,否则就返 ...

  3. Vue学习之组件切换及父子组件小结(八)

    一.组件切换: 1.v-if与v-else方式: <!DOCTYPE html> <html lang="en"> <head> <met ...

  4. JS JQUERY实现滚动条自动滚到底的方法

    $(function(){ var h = $(document).height()-$(window).height(); $(document).scrollTop(h); }); \ windo ...

  5. android中如何实现UI的实时更新---需要考虑电量和流量

    1.如果不考虑电量和流量的话,只需要在对应的activity里面继承Runnable,在run方法里面写一个while死循环,调用接口返回数据,如果数据发生了变化,就立即更新UI 2.需要考虑电量的话 ...

  6. WPE 过滤器 滤镜 用法

    过滤所有数值匹配的数据包,并修改指定的bit位 打开游戏 打开WPE 附加游戏进程 选项配置 用来配置抓取发送和接收包类型 先抓取发送包,也就是游戏中主动发给服务器的包 点击开始抓包 输入喊话内容 分 ...

  7. Cheat Engine 模糊数值

    打开游戏 玩到换枪为止 换枪 发现子弹数量是有限的200 扫描200 这是初次扫描 开两枪 剩余子弹数量194 再次扫描194 得到地址 尝试得到的这两个地址,经验证,第二个是我们想要的地址 重新开始 ...

  8. k8s网络配置管理

    docker容器的四种网络类型 1.桥接           2.联盟    3.主机    4.无 docker跨节点的容器通信必须通过NAT机制  宿主机上的容器一般都是私网地址 它可以通过宿主机 ...

  9. go test benchmark

    Benchtest的简单使用 一个简单的benchtest用例 // 以BenchmarkXXX类似命名,并传入b *testing.B 参数 func BenchmarkLoopSum(b *tes ...

  10. kafaka可视化工具kafkatool

    炒作就像动物世界的森林法则,专门攻击弱者,这种做法往往能够百发百中.                                                                   ...