tesnorflow实现N个epoch训练数据读取的办法
https://blog.csdn.net/lujiandong1/article/details/53991373
方式一:不显示设置读取N个epoch的数据,而是使用循环,每次从训练的文件中随机读取一个batch_size的数据,直至最后读取的数据量达到N个epoch。说明,这个方式来实现epoch的输入是不合理。不是说每个样本都会被读取到的。
对于这个的解释,从数学上解释,比如说有放回的抽样,每次抽取一个样本,抽取N次,总样本数为N个。那么,这样抽取过一轮之后,该样本也是会有1/e的概率没有被抽取到。所以,如果使用这种方式去训练的话,理论上是没有用到全部的数据集去训练的,很可能会造成过拟合的现象。
我做了个小实验验证:
- import tensorflow as tf
- import numpy as np
- import datetime,sys
- from tensorflow.contrib import learn
- from model import CCPM
- training_epochs = 5
- train_num = 4
- # 运行Graph
- with tf.Session() as sess:
- #定义模型
- BATCH_SIZE = 2
- # 构建训练数据输入的队列
- # 生成一个先入先出队列和一个QueueRunner,生成文件名队列
- filenames = ['a.csv']
- filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
- # 定义Reader
- reader = tf.TextLineReader()
- key, value = reader.read(filename_queue)
- # 定义Decoder
- # 编码后的数据字段有24,其中22维是特征字段,2维是lable字段,label是二分类经过one-hot编码后的字段
- #更改了特征,使用不同的解析参数
- record_defaults = [[1]]*5
- col1,col2,col3,col4,col5 = tf.decode_csv(value,record_defaults=record_defaults)
- features = tf.pack([col1,col2,col3,col4])
- label = tf.pack([col5])
- example_batch, label_batch = tf.train.shuffle_batch([features,label], batch_size=BATCH_SIZE, capacity=20000, min_after_dequeue=4000, num_threads=2)
- sess.run(tf.initialize_all_variables())
- coord = tf.train.Coordinator()#创建一个协调器,管理线程
- threads = tf.train.start_queue_runners(coord=coord)#启动QueueRunner, 此时文件名队列已经进队。
- #开始一个epoch的训练
- for epoch in range(training_epochs):
- total_batch = int(train_num/BATCH_SIZE)
- #开始一个epoch的训练
- for i in range(total_batch):
- X,Y = sess.run([example_batch, label_batch])
- print X,':',Y
- coord.request_stop()
- coord.join(threads)
toy data a.csv:
说明:输出如下,可以看出并不是每个样本都被遍历5次,其实这样的话,对于DL的训练会产生很大的影响,并不是每个样本都被使用同样的次数。
方式二:显示设置epoch的数目
- #-*- coding:utf-8 -*-
- import tensorflow as tf
- import numpy as np
- import datetime,sys
- from tensorflow.contrib import learn
- from model import CCPM
- training_epochs = 5
- train_num = 4
- # 运行Graph
- with tf.Session() as sess:
- #定义模型
- BATCH_SIZE = 2
- # 构建训练数据输入的队列
- # 生成一个先入先出队列和一个QueueRunner,生成文件名队列
- filenames = ['a.csv']
- filename_queue = tf.train.string_input_producer(filenames, shuffle=True,num_epochs=training_epochs)
- # 定义Reader
- reader = tf.TextLineReader()
- key, value = reader.read(filename_queue)
- # 定义Decoder
- # 编码后的数据字段有24,其中22维是特征字段,2维是lable字段,label是二分类经过one-hot编码后的字段
- #更改了特征,使用不同的解析参数
- record_defaults = [[1]]*5
- col1,col2,col3,col4,col5 = tf.decode_csv(value,record_defaults=record_defaults)
- features = tf.pack([col1,col2,col3,col4])
- label = tf.pack([col5])
- example_batch, label_batch = tf.train.shuffle_batch([features,label], batch_size=BATCH_SIZE, capacity=20000, min_after_dequeue=4000, num_threads=2)
- sess.run(tf.initialize_local_variables())
- sess.run(tf.initialize_all_variables())
- coord = tf.train.Coordinator()#创建一个协调器,管理线程
- threads = tf.train.start_queue_runners(coord=coord)#启动QueueRunner, 此时文件名队列已经进队。
- try:
- #开始一个epoch的训练
- while not coord.should_stop():
- total_batch = int(train_num/BATCH_SIZE)
- #开始一个epoch的训练
- for i in range(total_batch):
- X,Y = sess.run([example_batch, label_batch])
- print X,':',Y
- except tf.errors.OutOfRangeError:
- print('Done training')
- finally:
- coord.request_stop()
- coord.join(threads)
说明:输出如下,可以看出每个样本都被访问5次,这才是合理的设置epoch数据的方式。
http://stats.stackexchange.com/questions/242004/why-do-neural-network-researchers-care-about-epochs
说明:这个博客也在探讨,为什么深度网络的训练中,要使用epoch,即要把训练样本全部过一遍.而不是随机有放回的从里面抽样batch_size个样本.在博客中,别人的实验结果是如果采用有放回抽样的这种方式来进行SGD的训练.其实网络见不到全部的数据集,推导过程如上所示.所以,网络的收敛速度比较慢.
tesnorflow实现N个epoch训练数据读取的办法的更多相关文章
- tensorflow读取训练数据方法
1. 预加载数据 Preloaded data # coding: utf-8 import tensorflow as tf # 设计Graph x1 = tf.constant([2, 3, 4] ...
- TensorFlow Distribution(分布式中的数据读取和训练)
本文目的 在介绍estimator分布式的时候,官方文档由于版本更新导致与接口不一致.具体是:在estimator分布式当中,使用dataset作为数据输入,在1.12版本中,数据训练只是datase ...
- TensorFlow实践笔记(一):数据读取
本文整理了TensorFlow中的数据读取方法,在TensorFlow中主要有三种方法读取数据: Feeding:由Python提供数据. Preloaded data:预加载数据. Reading ...
- 『TensorFlow』数据读取类_data.Dataset
一.资料 参考原文: TensorFlow全新的数据读取方式:Dataset API入门教程 API接口简介: TensorFlow的数据集 二.背景 注意,在TensorFlow 1.3中,Data ...
- tensorflow之数据读取探究(1)
Tensorflow中之前主要用的数据读取方式主要有: 建立placeholder,然后使用feed_dict将数据feed进placeholder进行使用.使用这种方法十分灵活,可以一下子将所有数据 ...
- TensorFlow数据读取方式:Dataset API
英文详细版参考:https://www.cnblogs.com/jins-note/p/10243716.html Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服 ...
- tensoflow数据读取
数据读取 TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFl ...
- TF Boys (TensorFlow Boys ) 养成记(二): TensorFlow 数据读取
TensorFlow 的 How-Tos,讲解了这么几点: 1. 变量:创建,初始化,保存,加载,共享: 2. TensorFlow 的可视化学习,(r0.12版本后,加入了Embedding Vis ...
- 详解Tensorflow数据读取有三种方式(next_batch)
转自:https://blog.csdn.net/lujiandong1/article/details/53376802 Tensorflow数据读取有三种方式: Preloaded data: 预 ...
随机推荐
- SourceTree 的初次使用的两个小问题
菜鸟才开始使用SourceTree,出现了两个小问题,特此整理一下,希望对各位新手有帮助.刚开始以为装了SourceTree就不用装git了,其实不然,不装git就会出现下面第一个问题: 1.新手使用 ...
- C#内存映射文件消息队列实战演练(MMF—MQ)
一.课程介绍 本次分享课程属于<C#高级编程实战技能开发宝典课程系列>中的一部分,阿笨后续会计划将实际项目中的一些比较实用的关于C#高级编程的技巧分享出来给大家进行学习,不断的收集.整理和 ...
- 在qemu模拟的aarch32上使用kgtp
KGTP 介绍 KGTP 是一个能在产品系统上实时分析 Linux 内核和应用程序(包括 Android)问题的全面动态跟踪器. 使用 KGTP 不需要 在 Linux 内核上打 PATCH 或者重新 ...
- ASP.NET Web API实践系列05,消息处理管道
ASP.NET Web API的消息处理管道可以理解为请求到达Controller之前.Controller返回响应之后的处理机制.之所以需要了解消息处理管道,是因为我们可以借助它来实现对请求和响应的 ...
- Delphi取UTC时间秒
自格林威治标准时间1970年1月1日00:00:00 至现在经过多少秒数时间模块Uses DateUtils;当前时间:中国是 +8时区,换成UTC 就要减掉8小时showMessage(intt ...
- python测试开发django-28.发送邮件send_mail
前言 django发邮件的功能很简单,只需简单的配置即可,发邮件的代码里面已经封装好了,调用send_mail()函数就可以了 实现多个邮件发送可以用send_mass_mail()函数 send_m ...
- python笔记33-python3连mysql增删改查
前言 做自动化测试的时候,注册了一个新用户,产生了多余的数据,下次用同一账号就无法注册了,这种情况该怎么办呢? 自动化测试都有个数据准备和数据清理的操作,如果因为此用例产生了多余的数据,就需要数据清理 ...
- FEC详解三
转自:http://blog.csdn.net/Stone_OverLooking/article/details/77752076 继续上文讲解: 3) 标准的RTP头结构如下所示: 其中第一个字节 ...
- HDR和bloom效果的区别和关系
什么是HDR? 谈论游戏画面时常说的HDR到底是什么呢?HDR,本身是High-Dynamic Range(高动态范围)的缩写,这本来是一个CG概念.HDR的含义,简单说,就是超越普通的 ...
- left join 注意事项
相信对于熟悉SQL的人来说,LEFT JOIN非常简单,采用的时候也很多,但是有个问题还是需要注意一下.假如一个主表M有多个从表的话A B C …..的话,并且每个表都有筛选条件,那么把筛选条件放到哪 ...