https://blog.csdn.net/lujiandong1/article/details/53991373

方式一:不显示设置读取N个epoch的数据,而是使用循环,每次从训练的文件中随机读取一个batch_size的数据,直至最后读取的数据量达到N个epoch。说明,这个方式来实现epoch的输入是不合理。不是说每个样本都会被读取到的。

对于这个的解释,从数学上解释,比如说有放回的抽样,每次抽取一个样本,抽取N次,总样本数为N个。那么,这样抽取过一轮之后,该样本也是会有1/e的概率没有被抽取到。所以,如果使用这种方式去训练的话,理论上是没有用到全部的数据集去训练的,很可能会造成过拟合的现象。

我做了个小实验验证:

  1.  
    import tensorflow as tf
  2.  
    import numpy as np
  3.  
    import datetime,sys
  4.  
    from tensorflow.contrib import learn
  5.  
    from model import CCPM
  6.  
     
  7.  
    training_epochs = 5
  8.  
    train_num = 4
  9.  
    # 运行Graph
  10.  
    with tf.Session() as sess:
  11.  
     
  12.  
    #定义模型
  13.  
    BATCH_SIZE = 2
  14.  
    # 构建训练数据输入的队列
  15.  
    # 生成一个先入先出队列和一个QueueRunner,生成文件名队列
  16.  
    filenames = ['a.csv']
  17.  
    filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
  18.  
    # 定义Reader
  19.  
    reader = tf.TextLineReader()
  20.  
    key, value = reader.read(filename_queue)
  21.  
    # 定义Decoder
  22.  
    # 编码后的数据字段有24,其中22维是特征字段,2维是lable字段,label是二分类经过one-hot编码后的字段
  23.  
    #更改了特征,使用不同的解析参数
  24.  
    record_defaults = [[1]]*5
  25.  
    col1,col2,col3,col4,col5 = tf.decode_csv(value,record_defaults=record_defaults)
  26.  
    features = tf.pack([col1,col2,col3,col4])
  27.  
    label = tf.pack([col5])
  28.  
     
  29.  
    example_batch, label_batch = tf.train.shuffle_batch([features,label], batch_size=BATCH_SIZE, capacity=20000, min_after_dequeue=4000, num_threads=2)
  30.  
     
  31.  
    sess.run(tf.initialize_all_variables())
  32.  
    coord = tf.train.Coordinator()#创建一个协调器,管理线程
  33.  
    threads = tf.train.start_queue_runners(coord=coord)#启动QueueRunner, 此时文件名队列已经进队。
  34.  
    #开始一个epoch的训练
  35.  
    for epoch in range(training_epochs):
  36.  
    total_batch = int(train_num/BATCH_SIZE)
  37.  
    #开始一个epoch的训练
  38.  
    for i in range(total_batch):
  39.  
    X,Y = sess.run([example_batch, label_batch])
  40.  
    print X,':',Y
  41.  
    coord.request_stop()
  42.  
    coord.join(threads)

toy data a.csv:

说明:输出如下,可以看出并不是每个样本都被遍历5次,其实这样的话,对于DL的训练会产生很大的影响,并不是每个样本都被使用同样的次数。

方式二:显示设置epoch的数目

  1.  
    #-*- coding:utf-8 -*-
  2.  
    import tensorflow as tf
  3.  
    import numpy as np
  4.  
    import datetime,sys
  5.  
    from tensorflow.contrib import learn
  6.  
    from model import CCPM
  7.  
     
  8.  
    training_epochs = 5
  9.  
    train_num = 4
  10.  
    # 运行Graph
  11.  
    with tf.Session() as sess:
  12.  
     
  13.  
    #定义模型
  14.  
    BATCH_SIZE = 2
  15.  
    # 构建训练数据输入的队列
  16.  
    # 生成一个先入先出队列和一个QueueRunner,生成文件名队列
  17.  
    filenames = ['a.csv']
  18.  
    filename_queue = tf.train.string_input_producer(filenames, shuffle=True,num_epochs=training_epochs)
  19.  
    # 定义Reader
  20.  
    reader = tf.TextLineReader()
  21.  
    key, value = reader.read(filename_queue)
  22.  
    # 定义Decoder
  23.  
    # 编码后的数据字段有24,其中22维是特征字段,2维是lable字段,label是二分类经过one-hot编码后的字段
  24.  
    #更改了特征,使用不同的解析参数
  25.  
    record_defaults = [[1]]*5
  26.  
    col1,col2,col3,col4,col5 = tf.decode_csv(value,record_defaults=record_defaults)
  27.  
    features = tf.pack([col1,col2,col3,col4])
  28.  
    label = tf.pack([col5])
  29.  
     
  30.  
    example_batch, label_batch = tf.train.shuffle_batch([features,label], batch_size=BATCH_SIZE, capacity=20000, min_after_dequeue=4000, num_threads=2)
  31.  
    sess.run(tf.initialize_local_variables())
  32.  
    sess.run(tf.initialize_all_variables())
  33.  
    coord = tf.train.Coordinator()#创建一个协调器,管理线程
  34.  
    threads = tf.train.start_queue_runners(coord=coord)#启动QueueRunner, 此时文件名队列已经进队。
  35.  
    try:
  36.  
    #开始一个epoch的训练
  37.  
    while not coord.should_stop():
  38.  
    total_batch = int(train_num/BATCH_SIZE)
  39.  
    #开始一个epoch的训练
  40.  
    for i in range(total_batch):
  41.  
    X,Y = sess.run([example_batch, label_batch])
  42.  
    print X,':',Y
  43.  
    except tf.errors.OutOfRangeError:
  44.  
    print('Done training')
  45.  
    finally:
  46.  
    coord.request_stop()
  47.  
    coord.join(threads)

说明:输出如下,可以看出每个样本都被访问5次,这才是合理的设置epoch数据的方式。


http://stats.stackexchange.com/questions/242004/why-do-neural-network-researchers-care-about-epochs

说明:这个博客也在探讨,为什么深度网络的训练中,要使用epoch,即要把训练样本全部过一遍.而不是随机有放回的从里面抽样batch_size个样本.在博客中,别人的实验结果是如果采用有放回抽样的这种方式来进行SGD的训练.其实网络见不到全部的数据集,推导过程如上所示.所以,网络的收敛速度比较慢.

tesnorflow实现N个epoch训练数据读取的办法的更多相关文章

  1. tensorflow读取训练数据方法

    1. 预加载数据 Preloaded data # coding: utf-8 import tensorflow as tf # 设计Graph x1 = tf.constant([2, 3, 4] ...

  2. TensorFlow Distribution(分布式中的数据读取和训练)

    本文目的 在介绍estimator分布式的时候,官方文档由于版本更新导致与接口不一致.具体是:在estimator分布式当中,使用dataset作为数据输入,在1.12版本中,数据训练只是datase ...

  3. TensorFlow实践笔记(一):数据读取

    本文整理了TensorFlow中的数据读取方法,在TensorFlow中主要有三种方法读取数据: Feeding:由Python提供数据. Preloaded data:预加载数据. Reading ...

  4. 『TensorFlow』数据读取类_data.Dataset

    一.资料 参考原文: TensorFlow全新的数据读取方式:Dataset API入门教程 API接口简介: TensorFlow的数据集 二.背景 注意,在TensorFlow 1.3中,Data ...

  5. tensorflow之数据读取探究(1)

    Tensorflow中之前主要用的数据读取方式主要有: 建立placeholder,然后使用feed_dict将数据feed进placeholder进行使用.使用这种方法十分灵活,可以一下子将所有数据 ...

  6. TensorFlow数据读取方式:Dataset API

    英文详细版参考:https://www.cnblogs.com/jins-note/p/10243716.html Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服 ...

  7. tensoflow数据读取

    数据读取 TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFl ...

  8. TF Boys (TensorFlow Boys ) 养成记(二): TensorFlow 数据读取

    TensorFlow 的 How-Tos,讲解了这么几点: 1. 变量:创建,初始化,保存,加载,共享: 2. TensorFlow 的可视化学习,(r0.12版本后,加入了Embedding Vis ...

  9. 详解Tensorflow数据读取有三种方式(next_batch)

    转自:https://blog.csdn.net/lujiandong1/article/details/53376802 Tensorflow数据读取有三种方式: Preloaded data: 预 ...

随机推荐

  1. ARM汇编编程基础之一 —— 寄存器

    ARM的汇编编程,本质上就是针对CPU寄存器的编程,所以我们首先要弄清楚ARM有哪些寄存器?这些寄存器都是如何使用的? ARM寄存器分为2类,普通寄存器和状态寄存器 寄存器类别 寄存器在汇编中的名称 ...

  2. 【来龙去脉系列】深入理解DIP、IoC、DI以及IoC容器

    摘要 面向对象设计(OOD)有助于我们开发出高性能.易扩展以及易复用的程序.其中,OOD有一个重要的思想那就是依赖倒置原则(DIP),并由此引申出IoC.DI以及Ioc容器等概念.通过本文我们将一起学 ...

  3. Freescale OSBDM JM60仿真器 BGND Interface

    The BGND interface provides the standard 6 pin connection for the single wire BGND signal type devel ...

  4. USBDM Kinetis Debugger and Programmer

    Introduction The FRM-xxxx boards from Freescale includes a minimal SWD based debugging interface for ...

  5. AES CBC/CTR 加解密原理

    So, lets look at how CBC works first. The following picture shows the encryption when using CBC (in ...

  6. SQL Server DATEDIFF() 函数(SQL计算时间差)

    select  *   from   task_list  where 1=1 and    datediff(dd,carateTime,getdate()) =0      定义和用法 DATED ...

  7. linux 内核升级2 转

    linux内核升级 一.Linux内核概览 Linux是一个一体化内核(monolithic kernel)系统. 设备驱动程序可以完全访问硬件. Linux内的设备驱动程序可以方便地以模块化(mod ...

  8. Embarcadero RAD Studio XE5

    英巴卡迪诺 RAD Studio XE是终极应用程序开发套件,能以最快速方式为Windows.Mac OS X. .NET. PHP. Web和移动设备可视化开发数据丰富.界面美观的跨平台应用程序.R ...

  9. having只用来在group by之后,having不可单独用,必须和group by用。having只能对group by的结果进行操作

    having只能对group by的结果进行操作 having只能对group by的结果进行操作 having只能对group by的结果进行操作 having只用来在group by之后,havi ...

  10. 利用MPMoviePlayerViewController 播放视频 iOS

    方法一: @property (nonatomic, strong) MPMoviePlayerController *player; NSString *url = [[NSBundle mainB ...