对数据集的shuffle处理需要设置相应的buffer_size参数,相当于需要将相应数目的样本读入内存,且这部分内存会在训练过程中一直保持占用。完全的shuffle需要将整个数据集读入内存,这在大规模数据集的情况下是不现实的,故需要结合设备内存以及Batch大小将TFRecord文件随机划分为多个子文件,再对数据集做local shuffle(即设置相对较小的buffer_size,不小于单个子文件的样本数)。

Shuffle和划分

下文以一个异常检测数据集(正负样本不平衡)为例,在生成第一批TFRecord时,我将正负样本分别写入单独的TFrecord文件以备后续在对正负样本有不同处理策略的情况下无需再解析example_proto。比如在以下代码中,我对正负样本有不同的验证集比例,并将他们写入不同的验证集文件。

import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm as tqdm # TFRecord划分
raw_normal_dataset = tf.data.TFRecordDataset("normal_16_256.tfrecords","GZIP")
raw_anomaly_dataset = tf.data.TFRecordDataset("anomaly_16_256.tfrecords","GZIP")
normal_val_writer = tf.io.TFRecordWriter(r'ex_1/'+'normal_val_16_256.tfrecords',"GZIP")
anomaly_val_writer = tf.io.TFRecordWriter(r'ex_1/'+'anomaly_val_16_256.tfrecords',"GZIP")
train_writer_list = [tf.io.TFRecordWriter(r'ex_1/'+'train_16_256_{}.tfrecords'.format(i),"GZIP") for i in range(SUBFILE_NUM+1)]
with tqdm(total=LEN_NORMAL_DATASET+LEN_ANOMALY_DATASET) as pbar:
for example_proto in raw_normal_dataset:
# 划分训练集和测试集
if np.random.random() > 0.99: # 正样本测试集的比例
normal_val_writer.write(example_proto.numpy())
else:
train_writer_list[np.random.randint(0,SUBFILE_NUM+1)].write(example_proto.numpy())
pbar.update(1) for example_proto in raw_anomaly_dataset:
# 划分训练集和测试集
if np.random.random() > 0.7: # 负样本测试集的比例
anomaly_val_writer.write(example_proto.numpy())
else:
train_writer_list[np.random.randint(0,SUBFILE_NUM+1)].write(example_proto.numpy())
pbar.update(1)
normal_val_writer.close()
anomaly_val_writer.close()
for train_writer in train_writer_list:
train_writer.close()

读取

raw_train_dataset = tf.data.TFRecordDataset([r'ex_1/'+'train_16_256_{}.tfrecords'.format(i) for i in range(SUBFILE_NUM+1)],"GZIP")
raw_train_dataset = raw_train_dataset.shuffle(buffer_size=100000).batch(BATCH_SIZE)
parsed_train_dataset = raw_train_dataset.map(map_func=map_func) raw_normal_val_dataset = tf.data.TFRecordDataset(r'ex_1/'+'normal_val_16_256.tfrecords',"GZIP")
raw_anomaly_val_dataset = tf.data.TFRecordDataset(r'ex_1/'+'anomaly_val_16_256.tfrecords',"GZIP")
parsed_nomarl_val_dataset = raw_normal_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)
parsed_anomaly_val_dateset = raw_anomaly_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)

TFRecord的Shuffle、划分和读取的更多相关文章

  1. Tensorflow 中(批量)读取数据的案列分析及TFRecord文件的打包与读取

    内容概要: 单一数据读取方式: 第一种:slice_input_producer() # 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中 ...

  2. 更加清晰的TFRecord格式数据生成及读取

    TFRecords 格式数据文件处理流程 TFRecords 文件包含了 tf.train.Example 协议缓冲区(protocol buffer),协议缓冲区包含了特征 Features.Ten ...

  3. Tensorflow中使用tfrecord方式读取数据-深度学习-周振洋

    本博客默认读者对神经网络与Tensorflow有一定了解,对其中的一些术语不再做具体解释.并且本博客主要以图片数据为例进行介绍,如有错误,敬请斧正. 使用Tensorflow训练神经网络时,我们可以用 ...

  4. Spark技术内幕:Stage划分及提交源码分析

    http://blog.csdn.net/anzhsoft/article/details/39859463 当触发一个RDD的action后,以count为例,调用关系如下: org.apache. ...

  5. Spark技术内幕:Stage划分及提交源代码分析

    当触发一个RDD的action后.以count为例,调用关系例如以下: org.apache.spark.rdd.RDD#count org.apache.spark.SparkContext#run ...

  6. 第十二节,TensorFlow读取数据的几种方法以及队列的使用

    TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow图的起 ...

  7. TensorFlow中数据读取之tfrecords

    关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow ...

  8. tensorflow之数据读取探究(2)

    tensorflow之tfrecord数据读取 Tensorflow关于TFRecord格式文件的处理.模型的训练的架构为: 1.获取文件列表.创建文件队列:http://blog.csdn.net/ ...

  9. spark 笔记 15: ShuffleManager,shuffle map两端的stage/task的桥梁

    无论是Hadoop还是spark,shuffle操作都是决定其性能的重要因素.在不能减少shuffle的情况下,使用一个好的shuffle管理器也是优化性能的重要手段. ShuffleManager的 ...

随机推荐

  1. 开源框架 WebFirst 一键生成项目,在线建表

    1.WebFirst框架描述 WebFirst  是果糖大数据团队开发的新一代 高性能 代码生成器&数据库设计工具,由.net core 3.1 + sqlsugar 开发 导入1000个表只 ...

  2. 深入HTTP请求流程

    1.HTTP协议介绍 HTTP协议(HyperText Transfer Protocol,超文本传输协议)是因特网上应用最为广泛的一种网络传输协议,它是从WEB服务器传输超文本标记语言(HTML)到 ...

  3. Atlas2.2.0编译、安装及使用(集成ElasticSearch,导入Hive数据)

    1.编译阶段 组件信息: 组件名称 版本 Atals 2.2.0 HBase 2.2.6 Hive 3.1.2 Hadoop 3.1.1 Kafka 2.11_2.4.1 Zookeeper 3.6. ...

  4. 实验:Python图形图像处理

    1. 准备一张照片,编写Python程序将该照片进行图像处理,分别输出以下效果的图片:(a)灰度图:(b)轮廓图: (c)变换RGB通道图:(d)旋转45度图. 2. 假设当前文件夹中data.csv ...

  5. 3┃音视频直播系统之浏览器中通过 WebRTC 直播视频实时录制回放下载

    一.录制分类 在音视频会议.在线教育等系统中,录制是一个特别重要的功能 录制一般分为服务端录制和客户端录制 服务端录制:优点是不用担心客户因自身电脑问题造成录制失败(如磁盘空间不足),也不会因录制时抢 ...

  6. [2018-03-04] 利用 Settings Sync 插件同步 VS Code 设置

    VS Code 已原生支持设置同步,本文仅备份记录 [2018-03-04] 早就听说这个插件了,今天用了一下,确实挺方便的.通过把配置文件创建为 Gist 上来实现了 VS Code 设置的同步,下 ...

  7. R-CNN学习笔记

    R-CNN学习笔记 step1:总览 步骤: 输入图片 先挑选大约2000个感兴趣区域(ROI)使用select search方法:[在输入的图像中寻找blobby regions(可能相同纹理,颜色 ...

  8. 【freertos】009-任务控制

    目录 前言 9.1 相对延时 9.1.1 函数原型 9.1.2 函数说明 9.1.3 参考例子 9.2 绝对延时 9.2.1 函数原型 9.2.2 函数说明 9.2.3 参考例子 9.3 获取任务优先 ...

  9. 【Python - pip source】工欲善其事,必先利其器 - 不要让 pip install timeout 成为你的烦恼

    目录 前言 一.原因 二.解决方法 2.1 思路 2.2 国内镜像源列举 2.3 具体解决过程 2.3.1 方法一:命令行(推荐) 2.3.2 方法二:创建文件 总结 前言 解决pip install ...

  10. 写selenium常用到的js代码

    selenium可以运行JavaScript代码,可以用一些JavaScript来辅助编写Selelnium代码. 1.scrollIntoView - 向下拉滚动条,使得某元素可见 IWebElem ...