[深度学习]-Dataset数据集加载
加载数据集dataloader
from torch.utils.data import DataLoader
form 自己写的dataset import Dataset
train_set = Dataset(train=True)
val_set = Dataset(train=False)
image_datasets = {
'train': train_set, 'val': val_set
}
batch_size = 4
dataloaders = {
'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2),
'val': DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
}
dataset_sizes = {
x: len(image_datasets[x]) for x in image_datasets.keys()
}
print(dataset_sizes)
for epoch in range(num_epochs):
for phase in ['train', 'val']:
if phase == 'train':
# for param_group in optimizer.param_groups:
# print("LR", param_group['lr'])
model.train()
else:
model.eval()
以上适用于train一遍test一遍的情况
或者分别加载训练和测试:
train_dataset = Dataset('train')
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True,
num_workers=2, collate_fn=collate_fn)
test_dataset = Dataset('eval')
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False,
num_workers=2, collate_fn=collate_fn)
自己写Dataset
from torch.utils.data import Dataset
import os
import cv2
import torch
import numpy as np
class Dataset(Dataset):
def __init__(self,train):
if train:
self.datapath = {'image': '/home/myy/code/Final_Project/data_train.txt', 'target':'/home/myy/code/Final_Project/gt_train.txt'}
else:
self.datapath = {'image': '/home/myy/code/Final_Project/data_test.txt', 'target':'/home/myy/code/Final_Project/gt_test.txt'}
# self.datapath = {'image': '/home/myy/code/Final_Project/test_small_data.txt', 'target':'/home/myy/code/Final_Project/test_small.txt'}
self.image_list, self.target_list = self.read_txt(self.datapath)
# 此处可以依据需要自己定义一些函数
# 注意调用前要加上`self.`
# 比如以下两个读取数据的函数,read_txt、read_json就是自己定义的
def read_txt(self,datapath):
im =[]
target_image = []
print(datapath)
with open(datapath['image'], 'r') as f:
image_list = f.readlines()
with open(datapath['target'], 'r') as f:
target_list = f.readlines()
return image_list, target_list
def read_json(save_path, encoding='utf8'):
jsondata = []
with open(save_path, 'r', encoding=encoding) as f:
content = f.read()
content = json.loads(content)
for key in content:
jsondata.append(content[key])
return jsondata
def __getitem__(self, item):
# 最核心的部分,经过处理,要返回输入和gt
return img, target
def __len__(self):
# 这可以根据具体情况修改,不写也行
return len(self.data)
[深度学习]-Dataset数据集加载的更多相关文章
- 什么是pytorch(4.数据集加载和处理)(翻译)
数据集加载和处理 这里主要涉及两个包:torchvision.datasets 和torch.utils.data.Dataset 和DataLoader torchvision.datasets是一 ...
- OFRecord 数据集加载
OFRecord 数据集加载 在数据输入一文中知道了使用 DataLoader 及相关算子加载数据,往往效率更高,并且学习了如何使用 DataLoader 及相关算子. 在 OFrecord 数据格式 ...
- 深入java虚拟机学习 -- 类的加载机制(续)
昨晚写 深入java虚拟机学习 -- 类的加载机制 都到1点半了,由于第二天还要工作,没有将上篇文章中的demo讲解写出来,今天抽时间补上昨晚的例子讲解. 这里我先把昨天的两份代码贴过来,重新看下: ...
- 【Java Web开发学习】Spring加载外部properties配置文件
[Java Web开发学习]Spring加载外部properties配置文件 转载:https://www.cnblogs.com/yangchongxing/p/9136505.html 1.声明属 ...
- Python3读取深度学习CIFAR-10数据集出现的若干问题解决
今天在看网上的视频学习深度学习的时候,用到了CIFAR-10数据集.当我兴高采烈的运行代码时,却发现了一些错误: # -*- coding: utf-8 -*- import pickle as p ...
- 深度学习常用数据集 API(包括 Fashion MNIST)
基准数据集 深度学习中经常会使用一些基准数据集进行一些测试.其中 MNIST, Cifar 10, cifar100, Fashion-MNIST 数据集常常被人们拿来当作练手的数据集.为了方便,诸如 ...
- Recorder︱深度学习小数据集表现、优化(Active Learning)、标注集网络获取
一.深度学习在小数据集的表现 深度学习在小数据集情况下获得好效果,可以从两个角度去解决: 1.降低偏差,图像平移等操作 2.降低方差,dropout.随机梯度下降 先来看看深度学习在小数据集上表现的具 ...
- PIE SDK 多数据源的复合数据集加载
1. 功能简介 GIS遥感图像数据复合是将多种遥感图像数据融合成一种新的图像数据的技术,是目前遥感应用分析的前沿,PIESDK通过复合数据技术可以将多幅幅影像数据集(多光谱和全色数据)组合成一幅多波段 ...
- tensorflow数据集加载
本篇涉及的内容主要有小型常用的经典数据集的加载步骤,tensorflow提供了如下接口:keras.datasets.tf.data.Dataset.from_tensor_slices(shuffl ...
随机推荐
- 入门Python数据分析最好的实战项目(一)分析篇
数据初探 首先导入要使用的科学计算包numpy,pandas,可视化matplotlib,seaborn,以及机器学习包sklearn. python学习交流群:660193417### import ...
- 关于 k 进制线性基
本质还是高斯消元,使其成为上三角矩阵.但是 \(k\) 不一定是质数. 但我们不需要保证已有数字不改变,只要维护的是一个上三角矩阵就行.所以我们可以利用更相减损让其中一个向量的最高位 \(= 0\) ...
- VScode中配置Java环境
vscode 中配置Java环境 转载说明:本篇文档原作者[@火星动力猿],文档出处来自哔哩哔哩-[教程]VScode中配置Java运行环境 转载请在开头或显眼位置标注转载信息. 1.下载VScode ...
- 牛客SQL刷题第三趴——SQL大厂面试真题
01 某音短视频 SQL156 各个视频的平均完播率 [描述]用户-视频互动表tb_user_video_log.(uid-用户ID, video_id-视频ID, start_time-开始观看时间 ...
- 2022-7-11 javascript学习 第七组 刘昀航
JavaScript是什么? 编程语言,脚本语言,依赖于某种容器来运行. JS是运行在浏览器上的,可以帮助我们去控制页面. Vue.js react.js jquery.js an ...
- 【一本通提高树链剖分】「ZJOI2008」树的统计
[ZJOI2008]树的统计 题目描述 一棵树上有 n n n 个节点,编号分别为 1 1 1 到 n n n,每个节点都有一个权值 w w w. 我们将以下面的形式来要求你对这棵树完成一些操作: I ...
- linux学习(小白篇)
当前服务器:centos 7 shell命令框:xshell 文件预览及上传:xftp (界面化软件,非常好用) 数据库连接:navicat 此文是在学习linux时做一个指令合集,方便自己查阅 进文 ...
- WPF 截图控件之绘制方框与椭圆(四) 「仿微信」
前言 接着上周写的截图控件继续更新 绘制方框与椭圆. 1.WPF实现截屏「仿微信」 2.WPF 实现截屏控件之移动(二)「仿微信」 3.WPF 截图控件之伸缩(三) 「仿微信」 正文 有开发者在B站反 ...
- Jittered采样类定义和测试
抖动采样算法测试,小图形看不出什么明显区别,还是上代码和测试图吧. 类声明: #pragma once #ifndef __JITTERED_HEADER__ #define __JITTERED_H ...
- Vue3系列11--Teleport传送组件
Teleport 是一种能够将我们的模板移动到 DOM 中 Vue app 之外的其他位置的技术,不受父级style.v-show等属性影响,但data.prop数据依旧能够共用的技术:类似于 Rea ...