Pytorch dataset自定义【直播】2019 年县域农业大脑AI挑战赛---数据准备(二),Dataset定义
在我的torchvision库里介绍的博文(https://www.cnblogs.com/yjphhw/p/9773333.html)里说了对pytorch的dataset的定义方式。
本文相当于实现一个自定义的数据集,而这正是我们在做自己工程所需要的,我们总是用自己的数据嘛。
继承 from torch.utils.data import Dataset 类
然后实现 __len__(self) ,和 __getitem__(self,idx) 两个方法。以及数据增强也可以写入,数据增强想了想还是放到了Dataset里,
习惯上可能与常用的不同,但是觉得由于每种数据都有自己的增强方法所以,增强方法可以和数据集绑定到一起的。
接上一节我们通过切割,获取了2217个图像切片。
这就是我的FarmDataset
from torch.utils.data import Dataset, DataLoader
from PIL import Image,ImageEnhance
from osgeo import gdal
from torchvision import transforms
import glob
import torch as tc
import numpy as np class FarmDataset(Dataset):
def __init__(self,istrain=True,isaug=True):
self.istrain=istrain
self.trainxformat='./data/train/data1500/*.bmp'
self.trainyformat='./data/train/label1500/*.bmp'
self.testxformat='./data/test/*.png'
self.fns=glob.glob(self.trainxformat) if istrain else glob.glob(self.testxformat)
self.length=len(self.fns)
self.transforms=transforms
self.isaug=isaug def __len__(self):
#total length is 2217
return self.length
def __getitem__(self,idx):
if self.istrain: imgxname=self.fns[idx]
sampleimg = Image.open(imgxname)
imgyname=imgxname.replace('data1500','label1500')
targetimg = Image.open(imgyname).convert('L')
#sampleimg.save('original.bmp') #data augmentation
if self.isaug:
sampleimg,targetimg=self.imgtrans(sampleimg,targetimg) #check the result of dataautmentation
#sampleimg.save('sampletmp.bmp')
#targetimg.save('targettmp.bmp') sampleimg=transforms.ToTensor()(sampleimg)
#targetimg=transforms.ToTensor()(targetimg).squeeze(0).long()
targetimg=np.array(targetimg)
targetimg=tc.from_numpy(targetimg).long() #to tensor
#print(sampleimg.shape,targetimg.shape)
return sampleimg,targetimg
else:
return gdal.Open(self.fns[idx])
def imgtrans(self,x,y,outsize=1024):
'''input is a PIL image
image dataaugumentation
return also aPIL image。
'''
#rotate should consider y
degree=np.random.randint(360)
x=x.rotate(degree,resample=Image.NEAREST,fillcolor=0)
y=y.rotate(degree,resample=Image.NEAREST,fillcolor=0) #here should be carefull, in case of label damage #random do the input image augmentation
if np.random.random()>0.5:
#sharpness
factor=0.5+np.random.random()
enhancer=ImageEnhance.Sharpness(x)
x=enhancer.enhance(factor)
if np.random.random()>0.5:
#color augument
factor=0.5+np.random.random()
enhancer=ImageEnhance.Color(x)
x=enhancer.enhance(factor)
if np.random.random()>0.5:
#contrast augument
factor=0.5+np.random.random()
enhancer=ImageEnhance.Contrast(x)
x=enhancer.enhance(factor)
if np.random.random()>0.5:
#brightness
factor=0.5+np.random.random()
enhancer=ImageEnhance.Brightness(x)
x=enhancer.enhance(factor) #img flip
transtypes=[Image.FLIP_LEFT_RIGHT,Image.FLIP_TOP_BOTTOM,
Image.ROTATE_90,Image.ROTATE_180,Image.ROTATE_270]
transtype=transtypes[np.random.randint(len(transtypes))]
x = x.transpose(transtype)
y = y.transpose(transtype) #img resize between 0.8-1.2
w,h=x.size
factor=1+np.random.normal()/5
if factor>1.2: factor=1.2
if factor<0.8: factor=0.8
#print(factor,x.size)
x=x.resize((int(w*factor),int(h*factor)),Image.NEAREST)
y=y.resize((int(w*factor),int(h*factor)),Image.NEAREST) #random crop
w,h=x.size
stx=np.random.randint(w-outsize)
sty=np.random.randint(h-outsize)
#print((stx,sty,outsize,outsize))
x=x.crop((stx,sty,stx+outsize,sty+outsize)) #stx,sty,width,height
y=y.crop((stx,sty,stx+outsize,sty+outsize))
#print(x.size,y.size)
return x,y #return outsized pil image if __name__=='__main__':
d=FarmDataset(istrain=True)
x,y=d[2216]
print(x.shape)
print(y.shape)
输入的是个1500x1500的图像,输出的是增强后的1024x1024后的图像。
其实对于分割问题来看,以后这个就可以作为一个模板,修改修改就可以换到另一个数据集中。
放几张图片:
原始图像:

进行数据增强后可以得到的一系列:



经过check 发现没有的问题通过测试。

Pytorch dataset自定义【直播】2019 年县域农业大脑AI挑战赛---数据准备(二),Dataset定义的更多相关文章
- Pytorch 分割模型构建和训练【直播】2019 年县域农业大脑AI挑战赛---(四)模型构建和网络训练
对于分割网络,如果当成一个黑箱就是:输入一个3x1024x1024 输出4x1024x1024. 我没有使用二分类,直接使用了四分类. 分类网络使用了SegNet,没有加载预训练模型,参数也是默认初始 ...
- Pytorch 加载保存模型【直播】2019 年县域农业大脑AI挑战赛---(三)保存结果
在模型训练结束,结束后,通常是一个分割模型,输入 1024x1024 输出 4x1024x1024. 一种方法就是将整个图切块,然后每张预测,但是有个不好处就是可能在边界处断续. 由于这种切块再预测很 ...
- Pytorch【直播】2019 年县域农业大脑AI挑战赛---初级准备(一)切图
比赛地址:https://tianchi.aliyun.com/competition/entrance/231717/introduction 这次比赛给的图非常大5万x5万,在训练之前必须要进行数 ...
- XAF 框架中,自定义参数动作(Action),输入参数的控件可定义,用于选择组织及项目
XAF 框架中,如何生成一个自定义参数动作(Action),输入参数的控件可定义? 参考文档:https://documentation.devexpress.com/eXpressAppFramew ...
- “全栈2019”Java第八十九章:接口中能定义内部类吗?
难度 初级 学习时间 10分钟 适合人群 零基础 开发语言 Java 开发环境 JDK v11 IntelliJ IDEA v2018.3 文章原文链接 "全栈2019"Java第 ...
- 2019年全国高校计算机能力挑战赛 C语言程序设计决赛
2019年全国高校计算机能力挑战赛 C语言程序设计决赛 毕竟这个比赛是第一次举办,能理解.. 希望未来再举办时,能够再完善一下题面表述.数据范围. 话说区域赛获奖名额有点少吧.舍友花60块想混个创新创 ...
- 2019年全国高校计算机能力挑战赛初赛C语言解答
http://www.ncccu.org.cn 2019年全国高校计算机能力挑战赛分设大数据算法赛,人工智能算法赛,Office高级应用赛,程序设计赛4大赛项 C语言初赛解答 1:编程1 16.现有一 ...
- C# 将Excel里面的数据填充到DataSet中
/// <summary> /// 将Excel表里的数据填充到DataSet中 /// </summary> /// <param name="filenam ...
- [Pytorch]PyTorch Dataloader自定义数据读取
整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...
随机推荐
- PB调用.NET类库详解
要维护一个老的PB系统,有些地方用PB实在不方便,好在就张三.李四几个人用,每人装个.net框架. 设置.NET类COM可见 方式一:将整个程序集设成COM可见 方式二,只公开部分类 使用.Net框架 ...
- 微信公众平台接口获取时间戳为10位,java开发需转为13位
问题1:为什么会生成13位的时间戳,13位的时间戳和10时间戳分别是怎么来的 ? java的date默认精度是毫秒,也就是说生成的时间戳就是13位的,而像c++或者php生成的时间戳默认就是10位的, ...
- spark实验(三)--Spark和Hadoop的安装(1)
一.实验目的 (1)掌握在 Linux 虚拟机中安装 Hadoop 和 Spark 的方法: (2)熟悉 HDFS 的基本使用方法: (3)掌握使用 Spark 访问本地文件和 HDFS 文件的方法. ...
- Java中引用类型、对象的创建与销毁
引用类型 在java中,除了基本数据类型之外的,就是引用数据类型了,引用指的是对象的一个引用,通过引用可以操作对象,控制对象,向对象发送消息. 简单来说,引用可以访问对象的属性,并调用对象的方法 创建 ...
- ubuntu 虚拟机添加多个站点
我们安装好lamp环境,然后开始操作,比如一个站点叫test.ubuntu1.com,一个叫test.ubuntu2.com 1.修改hosts文件,路径/etc/hosts sudo vim /et ...
- C++98常用特性介绍——mutable关键字
讲mutable前,先讲一下const函数,讲const函数前,先讲一下函数前后加const的区别 一.C++函数前后加const的区别 1)函数前加const:普通函数或非静态成员函数前均可加con ...
- QRious入门
qrious是一款基于HTML5 Canvas的纯JS二维码生成插件.通过qrious.js可以快速生成各种二维码,你可以控制二维码的尺寸颜色,还可以将生成的二维码进行Base64编码. qrious ...
- 一个简单insert 语句执行 40ms 原因剖析
背景:一个简单的带有主键的insert 语句,居然要 40ms ,开发受不了,要求降低 因此我们要关注的的 数据从插入落地的IO 中间都干了什么 一.MySQL的文件 首先简单介绍一下MySQL的数据 ...
- php 基础 语句include和require的区别是什么?为避免多次包含同一文件,可用(?)语句代替它们?
require->require是无条件包含也就是如果一个流程里加入require,无论条件成立与否都会先执行 require include->include有返回值,而require没 ...
- vmware fusion nat网络模式设置固定ip
最近想在本地用虚拟环境搭一个k8s环境,但是发现虚拟机的ip会不定时自动变化,导致mosh客户端连接经常中断.于是就想让虚拟机的ip固定住,不再变动. mac 上的 vmware fusion 设置固 ...