数据读入需求

我们在训练模型参数时想要从训练数据集中一次取出一小批数据(比如50条、100条)做梯度下降,不断地分批取出数据直到损失函数基本不再减小并且在训练集上的正确率足够高,取出的n条数据还要是预处理过的,一次取出的要包含输入数据和对应的lable,并且希望在达到训练效果之前可以不断地取出数据而不会因数据集取空了提前结束训练,最好取出的数据还是乱序的。

基于上面的要求,我们可以利用TensorFlow的dataset模块创建我们所需的数据集。

Dataset简介

TensorFlow程序数据导入的方法有多种。一是通过 feed_dict 传入具体值。二是利用tf的Queues创建数据队列,一次取出batch个数据进行训练,队列可以用多线程读数据,速度比较快,但是队列模块的用法比较复杂,要修改程序的时候就感觉很乱。

Dataset与队列相比就简单多了,Dataset(数据集) API 在 TensorFlow 1.4版本中已经从tf.contrib.data迁移到了tf.data之中,增加了对于Python的生成器的支持,官方强烈建议使用Dataset API 为 TensorFlow模型创建输入管道。

dataset用法

import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))

创建了一个dataset,这个dataset中含有5个元素1….,5,为了将5个元素取出,方法是从Dataset中示例化一个iterator,然后对iterator进行迭代。

iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
for i in range(5):
print(sess.run(one_element))

语句iterator = dataset.make_one_shot_iterator()从dataset中实例化了一个Iterator,这个Iterator是一个“one shot iterator”,即只能从头到尾读取一次。one_element = iterator.get_next()表示从iterator里取出一个元素。这里取5次后dataset里的元素就空了,再取的话就就会抛出tf.errors.OutOfRangeError异常。

除了one-hot iterator,tf还支持其他三种iterator

  • initializable
  • reinitializable
  • feedable

这三个迭代器比one-hot复杂,这里就不介绍他们了。

dataset元素变换

dataset数据集API还有一些操作元素的函数来满足我们的对输入数据的需求。

  • map
  • shuffle
  • batch
  • repeat

1. map

map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加1:

def add1(x):
return x+1 dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(add1)

2. shuffle

shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小:

dataset = dataset.shuffle(500)

3. batch

使用一次iterator返回一批数据的数量:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
dataset = dataset.batch(2)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
for i in range(10):
print(sess.run(one_element)) # 这样就一次获取两个数,可以取3次,第三次取到一个数

4. repeat

上面的代码取3次数就取完了,再取得话就会抛出异常,如果想重复取数,可以用dataset.repeat(count),count的值表示将全部的数在dataset中重复几次:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
dataset = dataset.batch(2).repeat(2)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.next()
with tf.Session() as sess:
for i in range(10):
print(sess.run(one_element))

这样就将5个数重复了两遍。这里需要注意的一点是它虽然重复了两次,但并不是可以取5次,一次取两个数,而是:[1,2], [3,4] , [5],  [1,2], [3,4] , [5] 。这样再取到数据集末尾的时候得到的数据数量不是我们设置的batch_size 条数据。要想重复取数并且每次得到的都是batch_size条数据,可以设置batch_size的大小能被总数据量整除。

repeat()中的参数如果是None,则可以无限取数。

读入图片和lable,创建自己的数据集

import tensorflow as tf
import os batch_size = 50
img_resize = [100,100]
epoch_num = None # dataset.repeat() 的参数,设置为None,可以不断取数

# 传入图片名,返回正则化后的图片的像素值
def read_img(img_name, lable):
image = tf.read_file(img_name)
image = tf.image.decode_jpeg(image)
image = tf.image.resize_images(image, img_resize)
image = tf.image.per_image_standardization(image)
return image,lable

# 传入图片所在的文件夹,图片名含有图片的lable,返回利用文件夹中图片创建的dataset
def create_dataset(path):
files = os.listdir(path) # 列出文件夹中所有的图片
img_names = []
lables = []
for f in files:
img_names.append(os.path.join(path,f)) # 图片的完整路径append到文件名list中
lable = f.split('.')[0]
lables.append([int(i) for i in lable]) # 根据规则得到图片的lable img_names = tf.convert_to_tensor(img_names, dtype=tf.string)
lables = tf.convert_to_tensor(lables, dtype=tf.float32) # 将图片名list和lable的list转换成Tensor类型
dataset = tf.data.Dataset.from_tensor_slices((img_names,lables)) # 创建dataset,传入的需要是tensor类型
dataset = dataset.map(read_img) # 传入read_img函数,将图片名转为像素
  
  # 将dataset打乱,设置一次获取batch_size条数据
dataset = dataset.shuffle(buffer_size=800).batch(batch_size).repeat(epoch_num)
return dataset
dataset = create_dataset('./img') # 图片所在的路径为./img
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next() # 创建dataset是batch_size 为多少这里一次就能获取多少个数据

在程序中,sess.run(one_element) 一次就能获取到batch_size条数据和对应的lable

参考链接

https://blog.csdn.net/ssmixi/article/details/80572813

https://www.jianshu.com/p/d80ea5d73446

tensorflow学习笔记--dataset使用,创建自己的数据集的更多相关文章

  1. TensorFlow学习笔记——LeNet-5(训练自己的数据集)

    在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...

  2. tensorflow学习笔记——自编码器及多层感知器

    1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...

  3. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  4. Tensorflow学习笔记No.5

    tf.data卷积神经网络综合应用实例 使用tf.data建立自己的数据集,并使用CNN卷积神经网络实现对卫星图像的二分类问题. 数据下载链接:https://pan.baidu.com/s/141z ...

  5. Tensorflow学习笔记No.7

    tf.data与自定义训练综合实例 使用tf.data自定义猫狗数据集,并使用自定义训练实现猫狗数据集的分类. 1.使用tf.data创建自定义数据集 我们使用kaggle上的猫狗数据以及tf.dat ...

  6. Tensorflow学习笔记No.8

    使用VGG16网络进行迁移学习 使用在ImageNet数据上预训练的VGG16网络模型对猫狗数据集进行分类识别. 1.预训练网络 预训练网络是一个保存好的,已经在大型数据集上训练好的卷积神经网络. 如 ...

  7. Tensorflow学习笔记No.10

    多输出模型 使用函数式API构建多输出模型完成多标签分类任务. 数据集下载链接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取码:2kbc 1.读 ...

  8. Tensorflow学习笔记No.11

    图像定位 图像定位是指在图像中将我们需要识别的部分使用定位框进行定位标记,本次主要讲述如何使用tensorflow2.0实现简单的图像定位任务. 我所使用的定位方法是训练神经网络使它输出定位框的四个顶 ...

  9. Tensorflow学习笔记2:About Session, Graph, Operation and Tensor

    简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...

随机推荐

  1. UniGUI设置背景图片(09)

    主要是Background和LoginBackground属性, 类似地Login窗口背景图也可这样修改 UniServerModule.MainFormDisplayMode:=  mfPage;/ ...

  2. java 获取(格式化)日期格式

    // 参考: https://www.cnblogs.com/blog5277/p/6407463.htmlpublic class DateTest { // 支持时分秒 private stati ...

  3. 虚拟机下安装 VMwareTools 实现宿主机和虚拟机的文件共享

    $ mount /dev/sr0 /media/ #点击 虚拟机 安装 VMwareTools 挂载 $ cd /media/ $ cp VMwareTools-10.1.6-5214329.tar. ...

  4. LeetCode中等题(一)

    题目一: 给出两个 非空 的链表用来表示两个非负的整数.其中,它们各自的位数是按照 逆序 的方式存储的,并且它们的每个节点只能存储 一位 数字. 如果,我们将这两个数相加起来,则会返回一个新的链表来表 ...

  5. 伪类:after,:before的用法

    :after和:before是css3中的伪类元素.用法是像元素的前或者后插入元素.以after为例: li:after{ content: ''; color: #ff0000; } 意思是向li元 ...

  6. VUE 父子组件之间通信传值 props和 $emit

    1.父组件传值给子组件 $props,子组件传值给父组件 $emit 父组件          <div id="app" >               <tr ...

  7. Hibernate一对多(多对一)外键设置汇总

    我打算在角色表(role)中添加一个帐号表(account)的外键(accountId),步骤如下: 1.首先在角色表(role)中添加列. 添加语句:alter table role add(acc ...

  8. Xeon 第一次训练赛 苏州大学ICPC集训队新生赛第二场(同步赛) [Cloned]

    A.给出一个字符串,求出连续的权值递增和,断开以后权值重新计数,水题 #include<iostream> #include<string> #include<cmath ...

  9. SpringCloud实战——(1)创建SpringCloud项目

    首先创建一个SpirngCloud工程,并添加公用依赖. <?xml version="1.0" encoding="UTF-8"?> <pr ...

  10. Ubuntu安装Orcale

    Linux_Ubuntu安装oracle总结 ---------转自 https://www.2cto.com/database/201305/215338.html 话说我花了一晚上才在ubuntu ...