2018百度之星开发者大赛-paddlepaddle学习(二)将数据保存为recordio文件并读取
paddlepaddle将数据保存为recordio文件并读取
因为有时候一次性将数据加载到内存中有可能太大,所以我们可以选择将数据转换成标准格式recordio文件并读取供我们的网络利用,接下来记录一下如何保存数据为recordio,并读取。
将数据保存为RecordIO文件
import paddle.fluid as fluid
import numpy
def reader_creator():
def __impl__():
for i in range(1000):
yield [
numpy.random.random(size=[3,224,224], dtype="float32"),
numpy.random.random(size=[1], dtype="int64")
]
return __impl__
img = fluid.layers.data(name="image", shape=[3, 224, 224])
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
feeder = fluid.DataFeeder(feed_list=[img, label], place=fluid.CPUPlace())
BATCH_SIZE = 32
reader = paddle.batch(reader_creator(), batch_size=BATCH_SIZE)
fluid.recordio_writer.convert_reader_to_recordio_file(
"train.recordio", feeder=feeder, reader_creator=reader)
乍看也能看懂,非常合理,但是我这么保存以后就出现问题,在后面的读取数据的时候,一般我们会把[3,244,244]的图片送入卷积层,但是会报错,提示维数至少为4维,这点跟tensorflow一样,第一维是维数,那么应该怎么办呢?我的方法是:
def reader_creator():
def __impl__():
for i in range(len(src_im_test)):
yield [
src_im_test[i],#shape=[3,244,244]
test_desmap[i],#shape=[3,244,244]
test_num[i] #shape=[1]
]
return __impl__
img = fluid.layers.data(name="image", shape=[-1,3, 244, 244])#注意这里要加-1
label = fluid.layers.data(name="label", shape=[-1,1,244, 244])
num = fluid.layers.data(name="num", shape=[1], dtype='int64')
feeder = fluid.DataFeeder(feed_list=[img, label, num], place=fluid.CPUPlace())
reader = paddle.batch(reader_creator(), batch_size=1)
fluid.recordio_writer.convert_reader_to_recordio_file(
"train.recordio", feeder=feeder, reader_creator=reader)
这里把batch_size 设为1,后面读取的时候可以自由组batch_size。
从RecordIO读取数据传入网络
这里是官网代码:
import paddle.fluid as fluid
file_obj = fluid.layers.open_files(
filenames=["train.recordio"],
shape=[[3, 224, 224], [1]],
lod_levels=[0, 0],
dtypes=["float32", "int64"],
pass_num=100
)
image, label = fluid.layers.read_file(file_obj)
对应于上面的官网例子,但是前面说过,这样子有问题(在我这里是不能送入卷积层),下面给出我的读取方法,和上面我的代码相对应:
import paddle.fluid as fluid
file_obj = fluid.layers.open_files(
filenames=["train.recordio"],
shapes = [[-1,3, 244, 244], [-1,1,244, 244],[-1, 1]],
dtypes=['float32','float32','int64'],
lod_levels=[0, 0, 0],
)
file_obj = fluid.layers.batch(file_obj, batch_size=9)
img, des_im, total_num = fluid.layers.read_file(file_obj)#这里的数据可以直接送入网络
conv1 = fluid.layers.conv2d(img, 1, 1)#如果前面保存时不指定-1,这里就会报错
loss =fluid.layers.reduce_mean(fluid.layers.square_error_cost(input=conv1,label=des_im))
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
loss_v, = exe.run(fetch_list=[loss])
print "loss is {}".format(loss_v)
这里需要说明的是,从文件里面读取的数据已经是网络可以识别的数据格式了,有兴趣的话可以fluid.layers.data生成的变量一起print出来看一下,是一样的类型。
结果:
loss is [190564.94]
注意,返回的是一个numpy array,这里可以修改exe.run里面的参数return_numpy=False来决定。
我们再来看一下:
num = exe.run(fetch_list=[total_num])
print "num is {}".format(num)
结果:
num is [[17]
[13]
[61]
[9]
[8]
[9]
[17]
[29]
[9]]
同样返回的也是numpy array,可以看出来是怎么组成batch的。
2018百度之星开发者大赛-paddlepaddle学习(二)将数据保存为recordio文件并读取的更多相关文章
- 2018百度之星开发者大赛-paddlepaddle学习
前言 本次比赛赛题是进行人流密度的估计,因为之前看过很多人体姿态估计和目标检测的论文,隐约感觉到可以用到这次比赛上来,所以趁着现在时间比较多,赶紧报名参加了一下比赛,比赛规定用paddlepaddle ...
- HDU6383 2018 “百度之星”程序设计大赛 - 初赛(B) 1004-p1m2 (二分)
原题地址 p1m2 Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 131072/131072 K (Java/Others)Total ...
- HDU6380 2018 “百度之星”程序设计大赛 - 初赛(B) A-degree (无环图=树)
原题地址 degree Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 131072/131072 K (Java/Others)Tot ...
- 2018"百度之星"程序设计大赛 - 资格赛 - 题集
1001 $ 1 \leq m \leq 10 $ 像是状压的复杂度. 于是我们(用二进制)枚举留下的问题集合 然后把这个集合和问卷们的答案集合 $ & $ 一下 就可以只留下被选中的问题的答 ...
- 2018"百度之星"程序设计大赛 - 资格赛hdu6349三原色(最小生成树)
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6349 题目: 三原色图 Time Limit: 1500/1000 MS (Java/Others) ...
- 2018 “百度之星”程序设计大赛 - 初赛(A)
第二题还算手稳+手快?最后勉强挤进前五百(期间看着自己从两百多掉到494名) 1001 度度熊拼三角 (hdoj 6374) 链接:http://acm.hdu.edu.cn/showprob ...
- 2018"百度之星"程序设计大赛 - 资格赛 A/B/E/F
调查问卷 Accepts: 505 Submissions: 2436 Time Limit: 6500/6000 MS (Java/Others) Memory Limit: 262144/ ...
- 2018"百度之星"程序设计大赛 - 资格赛 1002 子串查询
题面又是万能的毒毒熊... 实在不想写了,就只写了这题 记26个前缀和查询枚举最小值直接算 实在是氵的死 而且我忘记输出Case #%d 想了很久 >_< #include<bits ...
- 2018 “百度之星”程序设计大赛 - 初赛(A)度度熊学队列 list rope
c++ list使用 #include <cstdio> #include <cstdlib> #include <cmath> #include <cstr ...
随机推荐
- php json格式化输出
1.json格式是适用于多种语言的数据格式,通用性高 2.在php中将array格式的数据转化为json格式 3.默认情况下转化后的json格式为一个串,需要将这个串格式化成相应的样式输出 主要的函数 ...
- disconf实践(二)基于XML的分布式配置文件管理,不会自动reload
上一篇博文介绍了disconf web的搭建流程,这一篇就介绍disconf client通过配置xml文件来获取disconf管理端的配置信息. 1. 登录管理端,并新建APP,然后上传配置文件 2 ...
- Intellij IDEA 2018.2.2 SpringBoot热启动 (Maven)
一.IDEA 工具配置 1. 打开IDEA 设置界面,选择编译,按图打勾. 2 . 然后 Shift+Ctrl+Alt+/,选择Registry 3 . compiler.automake.allow ...
- JS与OC交互,JS中调用OC方法(获取JSContext的方式)
最近用到JS和OC原生方法调用的问题,查了许多资料都语焉不详,自己记录一下吧,如果有误欢迎联系我指出. JS中调用OC方法有三种方式: 1.通过获取JSContext的方式直接调用OC方法 2.通过继 ...
- ERP系统和MES系统的区别
公司说最近要上一套erp系统,说让我比较一下,erp系统哪个好,还有mes系统,我们适合上哪个系统,其实我还真的不太懂,刚接触erp跟mes的时候,对于两者的概念总是傻傻分不清楚,总是觉得既然都是为企 ...
- 19-3-15Python中闭包,迭代器,递归
函数名的使用 函数名可以当作值赋值给变量 函数名可以当作元素放到容器里 闭包 一个嵌套函数 在嵌套函数内的函数使用外部(非全局的变量) 满足以上两条就是闭包 python中闭包,会进行内存驻留,普通函 ...
- Mariadb相关
centos下Mariadb相关 MariaDB 是一个采用 Maria 存储引擎的MySQL分支版本,是由原来 MySQL 的作者Michael Widenius创办的公司所开发的免费开源的数据库服 ...
- 【Java web 容器resin的安装】
#resin的安装 #启动resin #访问resin监听的java web容器端口 resin修改端口监听号
- es6之扩展运算符 三个点(...)
对象的扩展运算符理解对象的扩展运算符其实很简单,只要记住一句话就可以: 对象中的扩展运算符(...)用于取出参数对象中的所有可遍历属性,拷贝到当前对象之中 let bar = { a: 1, b: 2 ...
- CDH升级 5.7.5 --> 5.13.3(tar包方式)
博客园首发,转载请注明出处:https://www.cnblogs.com/tzxxh/p/9123231.html 一.准备 1.关闭cdh中的服务 hdfs.yarn等所有服务:关闭 cm-ser ...