数据读入需求

我们在训练模型参数时想要从训练数据集中一次取出一小批数据(比如50条、100条)做梯度下降,不断地分批取出数据直到损失函数基本不再减小并且在训练集上的正确率足够高,取出的n条数据还要是预处理过的,一次取出的要包含输入数据和对应的lable,并且希望在达到训练效果之前可以不断地取出数据而不会因数据集取空了提前结束训练,最好取出的数据还是乱序的。

基于上面的要求,我们可以利用TensorFlow的dataset模块创建我们所需的数据集。

Dataset简介

TensorFlow程序数据导入的方法有多种。一是通过 feed_dict 传入具体值。二是利用tf的Queues创建数据队列,一次取出batch个数据进行训练,队列可以用多线程读数据,速度比较快,但是队列模块的用法比较复杂,要修改程序的时候就感觉很乱。

Dataset与队列相比就简单多了,Dataset(数据集) API 在 TensorFlow 1.4版本中已经从tf.contrib.data迁移到了tf.data之中,增加了对于Python的生成器的支持,官方强烈建议使用Dataset API 为 TensorFlow模型创建输入管道。

dataset用法

import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))

创建了一个dataset,这个dataset中含有5个元素1….,5,为了将5个元素取出,方法是从Dataset中示例化一个iterator,然后对iterator进行迭代。

iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
for i in range(5):
print(sess.run(one_element))

语句iterator = dataset.make_one_shot_iterator()从dataset中实例化了一个Iterator,这个Iterator是一个“one shot iterator”,即只能从头到尾读取一次。one_element = iterator.get_next()表示从iterator里取出一个元素。这里取5次后dataset里的元素就空了,再取的话就就会抛出tf.errors.OutOfRangeError异常。

除了one-hot iterator,tf还支持其他三种iterator

  • initializable
  • reinitializable
  • feedable

这三个迭代器比one-hot复杂,这里就不介绍他们了。

dataset元素变换

dataset数据集API还有一些操作元素的函数来满足我们的对输入数据的需求。

  • map
  • shuffle
  • batch
  • repeat

1. map

map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加1:

def add1(x):
return x+1 dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(add1)

2. shuffle

shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小:

dataset = dataset.shuffle(500)

3. batch

使用一次iterator返回一批数据的数量:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
dataset = dataset.batch(2)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
for i in range(10):
print(sess.run(one_element)) # 这样就一次获取两个数,可以取3次,第三次取到一个数

4. repeat

上面的代码取3次数就取完了,再取得话就会抛出异常,如果想重复取数,可以用dataset.repeat(count),count的值表示将全部的数在dataset中重复几次:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
dataset = dataset.batch(2).repeat(2)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.next()
with tf.Session() as sess:
for i in range(10):
print(sess.run(one_element))

这样就将5个数重复了两遍。这里需要注意的一点是它虽然重复了两次,但并不是可以取5次,一次取两个数,而是:[1,2], [3,4] , [5],  [1,2], [3,4] , [5] 。这样再取到数据集末尾的时候得到的数据数量不是我们设置的batch_size 条数据。要想重复取数并且每次得到的都是batch_size条数据,可以设置batch_size的大小能被总数据量整除。

repeat()中的参数如果是None,则可以无限取数。

读入图片和lable,创建自己的数据集

import tensorflow as tf
import os batch_size = 50
img_resize = [100,100]
epoch_num = None # dataset.repeat() 的参数,设置为None,可以不断取数

# 传入图片名,返回正则化后的图片的像素值
def read_img(img_name, lable):
image = tf.read_file(img_name)
image = tf.image.decode_jpeg(image)
image = tf.image.resize_images(image, img_resize)
image = tf.image.per_image_standardization(image)
return image,lable

# 传入图片所在的文件夹,图片名含有图片的lable,返回利用文件夹中图片创建的dataset
def create_dataset(path):
files = os.listdir(path) # 列出文件夹中所有的图片
img_names = []
lables = []
for f in files:
img_names.append(os.path.join(path,f)) # 图片的完整路径append到文件名list中
lable = f.split('.')[0]
lables.append([int(i) for i in lable]) # 根据规则得到图片的lable img_names = tf.convert_to_tensor(img_names, dtype=tf.string)
lables = tf.convert_to_tensor(lables, dtype=tf.float32) # 将图片名list和lable的list转换成Tensor类型
dataset = tf.data.Dataset.from_tensor_slices((img_names,lables)) # 创建dataset,传入的需要是tensor类型
dataset = dataset.map(read_img) # 传入read_img函数,将图片名转为像素
  
  # 将dataset打乱,设置一次获取batch_size条数据
dataset = dataset.shuffle(buffer_size=800).batch(batch_size).repeat(epoch_num)
return dataset
dataset = create_dataset('./img') # 图片所在的路径为./img
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next() # 创建dataset是batch_size 为多少这里一次就能获取多少个数据

在程序中,sess.run(one_element) 一次就能获取到batch_size条数据和对应的lable

参考链接

https://blog.csdn.net/ssmixi/article/details/80572813

https://www.jianshu.com/p/d80ea5d73446

tensorflow学习笔记--dataset使用,创建自己的数据集的更多相关文章

  1. TensorFlow学习笔记——LeNet-5(训练自己的数据集)

    在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...

  2. tensorflow学习笔记——自编码器及多层感知器

    1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...

  3. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  4. Tensorflow学习笔记No.5

    tf.data卷积神经网络综合应用实例 使用tf.data建立自己的数据集,并使用CNN卷积神经网络实现对卫星图像的二分类问题. 数据下载链接:https://pan.baidu.com/s/141z ...

  5. Tensorflow学习笔记No.7

    tf.data与自定义训练综合实例 使用tf.data自定义猫狗数据集,并使用自定义训练实现猫狗数据集的分类. 1.使用tf.data创建自定义数据集 我们使用kaggle上的猫狗数据以及tf.dat ...

  6. Tensorflow学习笔记No.8

    使用VGG16网络进行迁移学习 使用在ImageNet数据上预训练的VGG16网络模型对猫狗数据集进行分类识别. 1.预训练网络 预训练网络是一个保存好的,已经在大型数据集上训练好的卷积神经网络. 如 ...

  7. Tensorflow学习笔记No.10

    多输出模型 使用函数式API构建多输出模型完成多标签分类任务. 数据集下载链接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取码:2kbc 1.读 ...

  8. Tensorflow学习笔记No.11

    图像定位 图像定位是指在图像中将我们需要识别的部分使用定位框进行定位标记,本次主要讲述如何使用tensorflow2.0实现简单的图像定位任务. 我所使用的定位方法是训练神经网络使它输出定位框的四个顶 ...

  9. Tensorflow学习笔记2:About Session, Graph, Operation and Tensor

    简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...

随机推荐

  1. Flask - 上下文管理(核心)

    参考 http://flask.pocoo.org/docs/1.0/advanced_foreword/#thread-locals-in-flask https://zhuanlan.zhihu. ...

  2. ws2_32.dll的妙用与删除 (禁网)

    ws2_32.dll是Windows Sockets应用程序接口,用于支持Internet和网络应用程序.程序运行时会自动调用ws2_32.dll文件,ws2_32.dll是个动态链接库文件位于系统文 ...

  3. springMVC的执行请求过程

    springMVC的运行流程: 1.用户发送请求至前端控制器DispatcherServlet 2.DispatcherServlet收到请求调用HandlerMapping处理器映射器 3.处理器映 ...

  4. 今日份学习: Spring - 事实标准

    笔记 Spring IOC Inverse of Control:控制反转 DI:Dependancy Injections:依赖注入 没有IOC的时候,各种依赖需要逐个按顺序创建. 有了IOC的之后 ...

  5. P1095 解码PAT准考证

    1095 解码PAT准考证 (25分)   PAT 准考证号由 4 部分组成: 第 1 位是级别,即 T 代表顶级:A 代表甲级:B 代表乙级: 第 2~4 位是考场编号,范围从 101 到 999: ...

  6. Python学习第三课——运算符

    # 运算符 + - * / **(幂) %(取余) //(取整) num=9%2 print("余数为"+(str)(num)) #运算结果为 1 num1=9//2 print( ...

  7. 解题报告+板子:luogu P3387 【模板】缩点

    题目链接:P3387 [模板]缩点 缩点板子,所谓\(dp\)就是拓扑排序(毕竟可以重走边),像\(SPFA\)一样松弛就好,就是重边极其烦人,还加了排序(绝对自己想的,然鹅拓扑的思路不是). 下面上 ...

  8. ES6转换ES5

    各大浏览器的最新版本,对 ES6 的支持可以查看kangax.github.io/es5-compat-table/es6/.随着时间的推移,支持度已经越来越高了,超过 90%的 ES6 语法特性都实 ...

  9. Eclipse设置jvm参数的三种方式

    方式1. 修改Elipse运行JRE默认JVM参数打开Eclipse,选择Window--Preferences...在对话框左边的树上双击Java,再双击Installed JREs,在右边选择前面 ...

  10. maven intall 命令用法

    作用:将自定义maven项目  打成maven依赖存放到本地库,我们可以在另一个项目pom文件中加入相应依赖,刷新mavne即可将其加入项目中使用 使用说明:win+R 打开命令窗口,将目录切换至项目 ...