tensorflow学习笔记--dataset使用,创建自己的数据集
数据读入需求
我们在训练模型参数时想要从训练数据集中一次取出一小批数据(比如50条、100条)做梯度下降,不断地分批取出数据直到损失函数基本不再减小并且在训练集上的正确率足够高,取出的n条数据还要是预处理过的,一次取出的要包含输入数据和对应的lable,并且希望在达到训练效果之前可以不断地取出数据而不会因数据集取空了提前结束训练,最好取出的数据还是乱序的。
基于上面的要求,我们可以利用TensorFlow的dataset模块创建我们所需的数据集。
Dataset简介
TensorFlow程序数据导入的方法有多种。一是通过 feed_dict 传入具体值。二是利用tf的Queues创建数据队列,一次取出batch个数据进行训练,队列可以用多线程读数据,速度比较快,但是队列模块的用法比较复杂,要修改程序的时候就感觉很乱。
Dataset与队列相比就简单多了,Dataset(数据集) API 在 TensorFlow 1.4版本中已经从tf.contrib.data迁移到了tf.data之中,增加了对于Python的生成器的支持,官方强烈建议使用Dataset API 为 TensorFlow模型创建输入管道。
dataset用法
import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
创建了一个dataset,这个dataset中含有5个元素1….,5,为了将5个元素取出,方法是从Dataset中示例化一个iterator,然后对iterator进行迭代。
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
for i in range(5):
print(sess.run(one_element))
语句iterator = dataset.make_one_shot_iterator()从dataset中实例化了一个Iterator,这个Iterator是一个“one shot iterator”,即只能从头到尾读取一次。one_element = iterator.get_next()表示从iterator里取出一个元素。这里取5次后dataset里的元素就空了,再取的话就就会抛出tf.errors.OutOfRangeError异常。
除了one-hot iterator,tf还支持其他三种iterator
- initializable
- reinitializable
- feedable
这三个迭代器比one-hot复杂,这里就不介绍他们了。
dataset元素变换
dataset数据集API还有一些操作元素的函数来满足我们的对输入数据的需求。
- map
- shuffle
- batch
- repeat
1. map
map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加1:
def add1(x):
return x+1 dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(add1)
2. shuffle
shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小:
dataset = dataset.shuffle(500)
3. batch
使用一次iterator返回一批数据的数量:
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
dataset = dataset.batch(2)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
for i in range(10):
print(sess.run(one_element)) # 这样就一次获取两个数,可以取3次,第三次取到一个数
4. repeat
上面的代码取3次数就取完了,再取得话就会抛出异常,如果想重复取数,可以用dataset.repeat(count),count的值表示将全部的数在dataset中重复几次:
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
dataset = dataset.batch(2).repeat(2)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.next()
with tf.Session() as sess:
for i in range(10):
print(sess.run(one_element))
这样就将5个数重复了两遍。这里需要注意的一点是它虽然重复了两次,但并不是可以取5次,一次取两个数,而是:[1,2], [3,4] , [5], [1,2], [3,4] , [5] 。这样再取到数据集末尾的时候得到的数据数量不是我们设置的batch_size 条数据。要想重复取数并且每次得到的都是batch_size条数据,可以设置batch_size的大小能被总数据量整除。
repeat()中的参数如果是None,则可以无限取数。
读入图片和lable,创建自己的数据集
import tensorflow as tf
import os batch_size = 50
img_resize = [100,100]
epoch_num = None # dataset.repeat() 的参数,设置为None,可以不断取数
# 传入图片名,返回正则化后的图片的像素值
def read_img(img_name, lable):
image = tf.read_file(img_name)
image = tf.image.decode_jpeg(image)
image = tf.image.resize_images(image, img_resize)
image = tf.image.per_image_standardization(image)
return image,lable
# 传入图片所在的文件夹,图片名含有图片的lable,返回利用文件夹中图片创建的dataset
def create_dataset(path):
files = os.listdir(path) # 列出文件夹中所有的图片
img_names = []
lables = []
for f in files:
img_names.append(os.path.join(path,f)) # 图片的完整路径append到文件名list中
lable = f.split('.')[0]
lables.append([int(i) for i in lable]) # 根据规则得到图片的lable img_names = tf.convert_to_tensor(img_names, dtype=tf.string)
lables = tf.convert_to_tensor(lables, dtype=tf.float32) # 将图片名list和lable的list转换成Tensor类型
dataset = tf.data.Dataset.from_tensor_slices((img_names,lables)) # 创建dataset,传入的需要是tensor类型
dataset = dataset.map(read_img) # 传入read_img函数,将图片名转为像素
# 将dataset打乱,设置一次获取batch_size条数据
dataset = dataset.shuffle(buffer_size=800).batch(batch_size).repeat(epoch_num)
return dataset
dataset = create_dataset('./img') # 图片所在的路径为./img
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next() # 创建dataset是batch_size 为多少这里一次就能获取多少个数据
在程序中,sess.run(one_element) 一次就能获取到batch_size条数据和对应的lable
参考链接
https://blog.csdn.net/ssmixi/article/details/80572813
https://www.jianshu.com/p/d80ea5d73446
tensorflow学习笔记--dataset使用,创建自己的数据集的更多相关文章
- TensorFlow学习笔记——LeNet-5(训练自己的数据集)
在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...
- tensorflow学习笔记——自编码器及多层感知器
1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)
续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...
- Tensorflow学习笔记No.5
tf.data卷积神经网络综合应用实例 使用tf.data建立自己的数据集,并使用CNN卷积神经网络实现对卫星图像的二分类问题. 数据下载链接:https://pan.baidu.com/s/141z ...
- Tensorflow学习笔记No.7
tf.data与自定义训练综合实例 使用tf.data自定义猫狗数据集,并使用自定义训练实现猫狗数据集的分类. 1.使用tf.data创建自定义数据集 我们使用kaggle上的猫狗数据以及tf.dat ...
- Tensorflow学习笔记No.8
使用VGG16网络进行迁移学习 使用在ImageNet数据上预训练的VGG16网络模型对猫狗数据集进行分类识别. 1.预训练网络 预训练网络是一个保存好的,已经在大型数据集上训练好的卷积神经网络. 如 ...
- Tensorflow学习笔记No.10
多输出模型 使用函数式API构建多输出模型完成多标签分类任务. 数据集下载链接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取码:2kbc 1.读 ...
- Tensorflow学习笔记No.11
图像定位 图像定位是指在图像中将我们需要识别的部分使用定位框进行定位标记,本次主要讲述如何使用tensorflow2.0实现简单的图像定位任务. 我所使用的定位方法是训练神经网络使它输出定位框的四个顶 ...
- Tensorflow学习笔记2:About Session, Graph, Operation and Tensor
简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...
随机推荐
- 「POJ1734」Sightseeing trip
「POJ1734」Sightseeing trip 传送门 这题就是要我们求一个最小环并且按顺序输出一组解. 考虑 \(O(n^3)\) 地用 \(\text{Floyd}\) 求最小环: 考虑 \( ...
- FFmpeg笔记-基本使用
FFmpeg是目前最牛逼的开源跨平台音视频处理工具. 准备知识 我不是音视频编解码出身的,对于这一块非常的不了解,导致在学习FFmpeg的时候云里雾里的,所以学习之前最好看些资料对音视频编解码有点认识 ...
- 吴裕雄 Bootstrap 前端框架开发——Bootstrap 辅助类:"text-success" 类的文本样式
<!DOCTYPE html> <html> <head> <meta charset="utf-8"> <title> ...
- CentOS7 安装PHP7的swoole扩展:
一.绪 Swoole简介 PHP异步网络通信引擎 最终编译为so文件作为PHP的扩展 准备工作 Linux环境 PHP7 swoole2.1 redis 源码安装PHP7 源码安装swoole htt ...
- docker学习笔记-06:自定义DockerFile生成镜像
一.自定义centos的DockerFile 1.从阿里源里拉的centos镜像新建的容器实例中,没有vim编辑器和ifconfig命令,所以自定义centos的DockerFile,创建自己想要的镜 ...
- redis有序集合-zset
概念:它是在set的基础上增加了一个顺序属性,这一属性在添加修改元素的时候可以指定,每次指定后,zset会自动按新的值调整顺序.可以理解为有两列的mysql表,一列存储value,一列存储顺序,操作中 ...
- 前端学习笔记系列一:5 在项目中引入阿里图标icon
进入到阿里的图标库网站,里面有上百万种icon,https://www.iconfont.cn,需要注册一个帐号,然后进入到这个页面,在这里点击右下角的带加号的图标,创建一个新的项目,名称与你要使用图 ...
- 南邮CG-CTF Web记录
MYSQL(利用精度,传参为小数) robots.txt中的代码: <?php if($_GET[id]) { mysql_connect(SAE_MYSQL_HOST_M . ':' . SA ...
- vue 父组件向子组件传参(笔记)
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...
- LR_问题_平均响应时间解释,summary与analysis不一致----Summary Report中的时间说明
Summary是按整个场景的时间来做平均的,最大最小值,也是从整个场景中取出来的. (1) 平均响应时间:事物全部响应时间做平均计算 (2) 90%响应时间:将事物全部响应时间 ...