写入Tfrecord

        print("convert data into tfrecord:train\n")
out_file_train = "/home/huadong.wang/bo.yan/fudan_mtl/data/ace2005/bn_nw.train.tfrecord"
writer = tf.python_io.TFRecordWriter(out_file_train) for i in tqdm(range(len(data_train))):
record = tf.train.Example(features=tf.train.Features(feature={
'word_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_x[i].tostring()])),
'et_ids1': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_et1[i].tostring()])),
'et_ids2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_et2[i].tostring()])),
'position_ids1': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_p1[i].tostring()])),
'position_ids2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_p1[i].tostring()])),
'chunks': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_chunks[i].tostring()])),
'spath_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_spath[i].tostring()])),
'seq_len': tf.train.Feature(int64_list=tf.train.Int64List(value=[train_x_len[i]])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[np.argmax(train_relation[i])])),
'task': tf.train.Feature(int64_list=tf.train.Int64List(value=[np.int64(0)]))
}))
writer.write(record.SerializeToString())
writer.close()

  

解析tfrecord

def _parse_tfexample(serialized_example):
'''parse serialized tf.train.SequenceExample to tensors
context features : label, task
sequence features: sentence
'''
context_features={'label' : tf.FixedLenFeature([], tf.int64),
'task' : tf.FixedLenFeature([], tf.int64),
'seq_len': tf.FixedLenFeature([], tf.int64)}
sequence_features={'word_ids': tf.FixedLenSequenceFeature([], tf.int64),
'et_ids1': tf.FixedLenSequenceFeature([], tf.int64),
'et_ids2': tf.FixedLenSequenceFeature([], tf.int64),
'position_ids1': tf.FixedLenSequenceFeature([], tf.int64),
'position_ids2': tf.FixedLenSequenceFeature([], tf.int64),
'chunks': tf.FixedLenSequenceFeature([], tf.int64),
'spath_ids': tf.FixedLenSequenceFeature([], tf.int64),
}
context_dict, sequence_dict = tf.parse_single_sequence_example(
serialized_example,
context_features = context_features,
sequence_features = sequence_features) sentence = (sequence_dict['word_ids'],sequence_dict['et_ids1'],sequence_dict['et_ids2'],sequence_dict['position_ids1'],
sequence_dict['position_ids2'],sequence_dict['chunks'],sequence_dict['spath_ids'], context_dict['seq_len']) label = context_dict['label']
task = context_dict['task'] return task, label, sentence def read_tfrecord(epoch, batch_size):
for dataset in DATASETS:
train_record_file = os.path.join(OUT_DIR, dataset+'.train.tfrecord')
test_record_file = os.path.join(OUT_DIR, dataset+'.test.tfrecord') train_data = util.read_tfrecord(train_record_file,
epoch,
batch_size,
_parse_tfexample,
shuffle=True) test_data = util.read_tfrecord(test_record_file,
epoch,
batch_size,
_parse_tfexample,
shuffle=False)
yield train_data, test_data

模型中使用:

  def build_task_graph(self, data):
task_label, labels, sentence = data
# sentence = tf.nn.embedding_lookup(self.word_embed, sentence)
##########################
word_ids, et_ids1,et_ids2,position_ids1,position_ids2,chunks,spath_ids,seq_len = sentence
# sentence = word_ids
######################### self.word_ids = word_ids
self.position_ids1 = position_ids1
self.position_ids2 = position_ids2
self.et_ids1 = et_ids1
self.et_ids2 = et_ids2
self.chunks_ids = chunks
self.spath_ids = spath_ids
self.seq_len = seq_len sentence = self.add_embedding_layers()

  

 

关于Tfrecord的更多相关文章

  1. Tensorflow 处理libsvm格式数据生成TFRecord (parse libsvm data to TFRecord)

    #写libsvm格式 数据 write libsvm     #!/usr/bin/env python #coding=gbk # ================================= ...

  2. 学习笔记TF016:CNN实现、数据集、TFRecord、加载图像、模型、训练、调试

    AlexNet(Alex Krizhevsky,ILSVRC2012冠军)适合做图像分类.层自左向右.自上向下读取,关联层分为一组,高度.宽度减小,深度增加.深度增加减少网络计算量. 训练模型数据集 ...

  3. [TFRecord格式数据]利用TFRecords存储与读取带标签的图片

    利用TFRecords存储与读取带标签的图片 原创文章,转载请注明出处~ 觉得有用的话,欢迎一起讨论相互学习~Follow Me TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是 ...

  4. 深度学习原理与框架-Tfrecord数据集的读取与训练(代码) 1.tf.train.batch(获取batch图片) 2.tf.image.resize_image_with_crop_or_pad(图片压缩) 3.tf.train.per_image_stand..(图片标准化) 4.tf.train.string_input_producer(字符串入队列) 5.tf.TFRecord(读

    1.tf.train.batch(image, batch_size=batch_size, num_threads=1) # 获取一个batch的数据 参数说明:image表示输入图片,batch_ ...

  5. 深度学习原理与框架-Tfrecord数据集的制作 1.tf.train.Examples(数据转换为二进制) 3.tf.image.encode_jpeg(解码图片加码成jpeg) 4.tf.train.Coordinator(构建多线程通道) 5.threading.Thread(建立单线程) 6.tf.python_io.TFR(TFR读入器)

    1. 配套使用: tf.train.Examples将数据转换为二进制,提升IO效率和方便管理 对于int类型 : tf.train.Examples(features=tf.train.Featur ...

  6. 3. Tensorflow生成TFRecord

    1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...

  7. TFRecord文件的读写

    前言在跑通了官网的mnist和cifar10数据之后,笔者尝试着制作自己的数据集,并保存,读入,显示. TensorFlow可以支持cifar10的数据格式, 也提供了标准的TFRecord 格式,而 ...

  8. 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练

    将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...

  9. tfrecord

    制作自己的TFRecord数据集,读取,显示及代码详解 http://blog.csdn.net/miaomiaoyuan/article/details/56865361

  10. 3 TFRecord样例程序实战

    将图片数据写入Record文件 # 定义函数转化变量类型. def _int64_feature(value): return tf.train.Feature(int64_list=tf.train ...

随机推荐

  1. JVM系列之二:编译过程

    1. Java的编译和执行 编译包括两种情况: 1,源码编译成字节码2,字节码编译成本地机器码(符合本地系统专属的指令) 解释执行也包括两种情况: 1,源码解释执行2,字节码解释执行 解释和编译执行的 ...

  2. Linux tty驱动架构

    Linux tty子系统包含:tty核心,tty线路规程和tty驱动.tty核心是对整个tty设备的抽象,对用户提供统一的接口,tty线路规程是对传输数据的格式化,tty驱动则是面向tty设备的硬件驱 ...

  3. Redis哨兵、复制、集群的设计原理与区别

    一 前言 谈到Redis服务器的高可用,如何保证备份的机器是原始服务器的完整备份呢?这时候就需要哨兵和复制. 哨兵(Sentinel):可以管理多个Redis服务器,它提供了监控,提醒以及自动的故障转 ...

  4. | C语言I作业02

    C语言I博客作业02 标签: 18软件2班 李煦亮 问题 答案 这个作业属于那个课程 C语言程序设计I 这个作业要求在哪里 https://edu.cnblogs.com/campus/zswxy/C ...

  5. python 基础 ---- 面向对象

    ------   面向对象的思想 三个基本特征: 封装(封装属性方法可以减少耦合)继承(可以抬高开发效率) 多态 主要包括 : 类 : 描述具有相同的属性和方法的对象的集合  变量:   类变量/ 成 ...

  6. C语言 宏定义的1<<0 与 直接定义1 有什么区别

    [1]示例程序 如下示例代码: #include <stdio.h> #define TEST1 1 << 0 #define TEST2 (1 << 0) #de ...

  7. idea 设置默认的maven

    idea版本2019.2 设置maven 按照上图中的1-4顺序进行配置,就可以让以后每一个工程使用我们指定的配置了. 1:打开maven配置界面. 2:点击后面的三角符号,使maven列表显示,并在 ...

  8. es+logstash+kibana搭建

    1.简介 ELK(elasticsearch+logstash+kibana)是目前比较常用的日志分析系统,包括日志收集(logstash),日志存储搜索(elasticsearch),展示查询(ki ...

  9. jboss/wildfly安全域的密码加密和解密

    加密: java_path=$(source /opt/wildfly/bin/.Beta1.jar:/opt/wildfly/modules/system/layers/base/org/jboss ...

  10. 【华为云实战开发】10.经典的C++项目怎么在云端开发?【华为云技术分享】

    1 概述 1.1 文章目的 本文主要想为研发C++项目的企业或个人提供上云指导,通过本文中的示例项目 “音频解析器”,为开发者提供包括项目管理,代码托管,代码检查,编译构建,测试管理的操作指导,覆盖软 ...