tensorflow中数据批次划分示例教程
1.简介
将数据划分成若干批次的数据,可以使用tf.train或者tf.data.Dataset中的方法。
1.1 tf.train
tf.train.slice_input_producer(tensor_list,shuffle=True,seed=None,capacity=32)
tf.train.batch(tensors,batch_size,num_threads=1,capacity=32,allow_smaller_final_batch=False)
参数说明:
shuffle:为True时进行数据清洗
allow_smaller_final_batch:为True时将小于batch_size的批次值输出
-------------------------------------------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------------------------------------------
1.2 tf.data.Dataset
tf.data.Dataset是一个类,可以使用以下方法:
from_tensor_slices(tensors)
batch(batch_size,drop_remainder=False)
shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)
repeat(count=None)
make_one_shot_iterator() / get_next()
注:make_one_shot_iterator() / get_next()用于Dataset数据的迭代器
参数说明:
tensors:可以是列表、字典、元组等类型
drop_remainder:为False时表示不保留小于batch_size的批次,否则删除
buffer_size:数据清洗时使用的buffer大小
count:对应为epoch个数,为None时表示数据序列无限延续
2.示例
2.1 使用tf.train.slice_input_producer和tf.train.batch
import tensorflow as tf
import numpy as np
import math # 生成样例数据集
def generate_data():
num = 15
labels = np.asarray(range(num))
images = np.random.random([num, 5, 5, 3])
return images, labels # 打印样例信息
images, labels = generate_data()
print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape)) # 定义周期、批次、数据总量和遍历一次所有数据所需的迭代次数
n_epochs = 3
batch_size = 6
train_nums = 15
iterations = math.ceil(train_nums/batch_size) # 使用tf.train.slice_input_producer将所有数据放入队列,使用tf.train.batch划分队列中的数据
input_queue = tf.train.slice_input_producer([images, labels], shuffle=False)
image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, num_threads=1, capacity=32)
print('image_batch.shape={0}, label_batch.shape={1}'.format(image_batch.shape, label_batch.shape)) with tf.Session() as sess:
tf.global_variables_initializer().run()
# 启动队列线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord)
# 打印信息
for epoch in range(n_epochs):
for iteration in range(iterations):
cu_image_batch, cu_label_batch = sess.run([image_batch, label_batch])
print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch+1,iteration+1,cu_label_batch))
# 接收线程
coord.request_stop()
coord.join(threads) # 打印结果如下
images.shape=(15, 5, 5, 3), labels.shape=(15,)
image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14 0 1 2]
The 2 epoch, the 1 iteration, current batch is [3 4 5 6 7 8]
The 2 epoch, the 2 iteration, current batch is [ 9 10 11 12 13 14]
The 2 epoch, the 3 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 1 iteration, current batch is [ 6 7 8 9 10 11]
The 3 epoch, the 2 iteration, current batch is [12 13 14 0 1 2]
The 3 epoch, the 3 iteration, current batch is [3 4 5 6 7 8]
如果tf.train.slice_input_producer(shuffle=True),输出为乱序,结果如下:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
The 1 epoch, the 1 iteration, current batch is [ 2 5 8 11 3 10]
The 1 epoch, the 2 iteration, current batch is [ 9 12 7 1 14 13]
The 1 epoch, the 3 iteration, current batch is [0 6 4 2 3 6]
The 2 epoch, the 1 iteration, current batch is [11 10 12 14 13 5]
The 2 epoch, the 2 iteration, current batch is [8 1 0 9 4 7]
The 2 epoch, the 3 iteration, current batch is [10 13 1 4 12 3]
The 3 epoch, the 1 iteration, current batch is [ 2 8 5 9 14 7]
The 3 epoch, the 2 iteration, current batch is [ 0 11 6 1 14 9]
The 3 epoch, the 3 iteration, current batch is [11 6 12 7 0 13]
如果tf.train.batch(allow_smaller_final_batch=True),则会返回不足批次数目的数据,结果如下:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14]
The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 2 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 2 epoch, the 3 iteration, current batch is [12 13 14]
The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 3 epoch, the 3 iteration, current batch is [12 13 14]
2.2 使用tf.data.Dataset类
import tensorflow as tf
import numpy as np
import math # 生成样例数据集
def generate_data():
num = 15
labels = np.asarray(range(num))
images = np.random.random([num, 5, 5, 3])
return images, labels
# 打印样例信息
images, labels = generate_data()
print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape)) # 定义周期、批次、数据总数、遍历一次所有数据需的迭代次数
n_epochs = 3
batch_size = 6
train_nums = 15
iterations = math.ceil(train_nums/batch_size) # 使用from_tensor_slices将数据放入队列,使用batch和repeat划分数据批次,且让数据序列无限延续
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.batch(batch_size).repeat()
# 使用生成器make_one_shot_iterator和get_next取数据
iterator = dataset.make_one_shot_iterator()
next_iterator = iterator.get_next()
with tf.Session() as sess:
for epoch in range(n_epochs):
for iteration in range(iterations):
cu_image_batch, cu_label_batch = sess.run(next_iterator)
print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch+1,iteration+1,cu_label_batch)) # 结果如下:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14]
The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 2 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 2 epoch, the 3 iteration, current batch is [12 13 14]
The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 3 epoch, the 3 iteration, current batch is [12 13 14]
使用shuffle(),第23行修改为dataset = dataset.shuffle(100).batch(batch_size).repeat(),结果如下:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [ 7 4 10 8 3 11]
The 1 epoch, the 2 iteration, current batch is [ 0 2 12 13 14 5]
The 1 epoch, the 3 iteration, current batch is [6 9 1]
The 2 epoch, the 1 iteration, current batch is [ 6 14 7 9 3 8]
The 2 epoch, the 2 iteration, current batch is [13 5 12 1 11 2]
The 2 epoch, the 3 iteration, current batch is [ 0 4 10]
The 3 epoch, the 1 iteration, current batch is [10 8 13 12 3 14]
The 3 epoch, the 2 iteration, current batch is [ 6 9 2 5 1 11]
The 3 epoch, the 3 iteration, current batch is [0 4 7]
!!!
tensorflow中数据批次划分示例教程的更多相关文章
- TensorFlow中数据读取之tfrecords
关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow ...
- 大数据下基于Tensorflow框架的深度学习示例教程
近几年,信息时代的快速发展产生了海量数据,诞生了无数前沿的大数据技术与应用.在当今大数据时代的产业界,商业决策日益基于数据的分析作出.当数据膨胀到一定规模时,基于机器学习对海量复杂数据的分析更能产生较 ...
- TensorFlow中数据读取—如何载入样本
考虑到要是自己去做一个项目,那么第一步是如何把数据导入到代码中,何种形式呢?是否需要做预处理?官网中给的实例mnist,数据导入都是写好的模块,那么自己的数据呢? 一.从文件中读取数据(CSV文件.二 ...
- .NET 5/.NET Core使用EF Core 5连接MySQL数据库写入/读取数据示例教程
本文首发于<.NET 5/.NET Core使用EF Core 5(Entity Framework Core)连接MySQL数据库写入/读取数据示例教程> 前言 在.NET Core/. ...
- [开发技巧]·TensorFlow中numpy与tensor数据相互转化
[开发技巧]·TensorFlow中numpy与tensor数据相互转化 个人主页–> https://xiaosongshine.github.io/ - 问题描述 在我们使用TensorFl ...
- python操作txt文件中数据教程[4]-python去掉txt文件行尾换行
python操作txt文件中数据教程[4]-python去掉txt文件行尾换行 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文章 python操作txt文件中数据教程[1]-使用pyt ...
- python操作txt文件中数据教程[3]-python读取文件夹中所有txt文件并将数据转为csv文件
python操作txt文件中数据教程[3]-python读取文件夹中所有txt文件并将数据转为csv文件 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 python操作txt文件中 ...
- python操作txt文件中数据教程[2]-python提取txt文件
python操作txt文件中数据教程[2]-python提取txt文件中的行列元素 觉得有用的话,欢迎一起讨论相互学习~Follow Me 原始txt文件 程序实现后结果-将txt中元素提取并保存在c ...
- python操作txt文件中数据教程[1]-使用python读写txt文件
python操作txt文件中数据教程[1]-使用python读写txt文件 觉得有用的话,欢迎一起讨论相互学习~Follow Me 原始txt文件 程序实现后结果 程序实现 filename = '. ...
随机推荐
- https证书随记
下载证书之后: 1:域名跳转操作 <system.webServer> <rewrite> <rules> ...
- cherry-pick多个commitid时的顺序说明
有的时候,我们在一个分支上提交了几个commit,然后我们会需要提交到其他分支上,一般情况下,我们会采用的merge的方式来合并分支,另外一种方式是只需要其中几个提交时,我们会cherry-pick到 ...
- 动物管理员--zooKeeper-01
ZooKeeper集群角色介绍: 最典型集群模式:Master/Slave 模式(主备模式).在这种模式中,通常 Master 服务器作为主服务器提供写服务,其他的 Slave 服务器从服务器通过异步 ...
- 微信内置安卓x5浏览器请求超时自动重发问题处理小记
X5内核 请求超时后会自动阻止请求返回并由代理服务器将原参数重新发送请求到服务层代码.但由于第一次请求已经请求到服务器,会导致出现重复下单.支付等重大问题. 该问题由于腾讯x5浏览器会自动阻止第一次 ...
- HTML和CSS总结
语义化.我们学习网页制作时,常常会听到一个词,语义化.那么什么叫做语义化呢,说的通俗点就是:明白每个标签的用途(在什么情况下使用此标签合理)比如,网页上的文章的标题就可以用标题标签,网页上的各个栏目的 ...
- java之导入excel
接口: /** * * Description: 导入excel表 * @param map * @param request * @param session * @return * @author ...
- Java中static的用法解析
知识点1.static关键字a.可以修饰变量,方法,代码块b.修饰的变量和方法可以使用类名.变量名/类名.方法名调用c.static修饰的资源为静态资源,在类加载的时候执行d.在静态方法中只能调用静态 ...
- [c/c++] programming之路(19)、数组指针
一.指针运算 #include<stdio.h> #include<stdlib.h> void main0(){ ; int *p=&a; printf());//变 ...
- day01编程语言,计算机组成: 五大组成部分,计算机三大核心,进制,内存分布图,操作系统
本周内容 第一天: 计算机原理 操作系统 第二天: 编程语言 python入门:环境 - 编辑器 变量 基本数据类型 学习方法 鸡汤 - 干货wwwh:what | why | where | h ...
- 使用JS调用手机本地摄像头或者相册图片识别二维码/条形码
接着昨天的需求,不过这次不依赖微信,使用纯js唤醒手机本地摄像头或者选择手机相册图片,识别其中的二维码或者是条形码.昨天,我使用微信扫一扫识别,效果超棒的.不过如果依赖微信的话,又怎么实现呢,这里介绍 ...