本篇涉及的内容主要有小型常用的经典数据集的加载步骤,tensorflow提供了如下接口:keras.datasets、tf.data.Dataset.from_tensor_slices(shuffle、map、batch、repeat),涉及的数据集如下:boston housing、mnist/fashion mnist、cifar10/100、imdb

1.keras.datasets

通过该接口可以直接下载指定数据集。boston housing提供了和房价有关的一些因子(面积、居民来源等),mnist提供了手写数字的图片和对应label,fashion mnist提供了10种衣服的灰度图和对应label,cifar10/100是用来进行简单图像识别的数据集,分别包含10类物品和100类物品,imdb是一个类似于淘宝好评的数据集,即通过评语及其标注(好评或差评),来实现一个好评或差评的分类器。

注:通过该接口得到的数据集格式为numpy格式。

2.tf.data.Dataset.from_tensor_slices()

该方法可以用来进行数据的迭代,过程中可以直接将numpy格式转化为tensor格式,然后通过调用next(iter())方法实现迭代,使用示例如下:

# 加载数据集
(x,y),(x_test,y_test) = keras.datasets.mnist.load_data()
# 转化为tensor并实现迭代
db = tf.data.Dataset.from_tensor_slices(x_test)
# 打印迭代数据的shape
print(next(iter(db)).shape)
# 将img和label封装为同一次迭代
db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
print(next(iter(db))[0].shape)
print(next(iter(db))[1].shape)

3.shuffle

通过shuffle函数可以将数据集打散,从而提高模型的泛化能力,使用方法:db.shuffle(10000),参数设置范围,通常值设置比较大

4.map

# deep learning一般使用float32,而numpy格式多为float64,所以需要转化
def preprocess(x,y):
x = tf.cast(x,dtype=tf.float32)/255
y = tf.cast(y,dtype=tf.int32)
y = tf.one_hot(y,depth=10)
return x,y db2 = db.map(preprocess)
res = next(iter(db2))
print(res[0].shape,res[1].shape)

5.batch

db3 = db2.batch(32)
res = next(iter(db3))
print(res[0].shape,res[1].shape)

6.StopIteration

因为迭代多次后会到达数据集的末尾,如果不进行异常处理则会报StopIteration异常,如下处理方式就是错误的:

db_iter = iter(db3)
while True:
next(db_iter)

只要加上异常处理语句对db_iter重新赋值即可

tensorflow数据集加载的更多相关文章

  1. OFRecord 数据集加载

    OFRecord 数据集加载 在数据输入一文中知道了使用 DataLoader 及相关算子加载数据,往往效率更高,并且学习了如何使用 DataLoader 及相关算子. 在 OFrecord 数据格式 ...

  2. 什么是pytorch(4.数据集加载和处理)(翻译)

    数据集加载和处理 这里主要涉及两个包:torchvision.datasets 和torch.utils.data.Dataset 和DataLoader torchvision.datasets是一 ...

  3. tensorflow数据加载、模型训练及预测

    数据集 DNN 依赖于大量的数据.可以收集或生成数据,也可以使用可用的标准数据集.TensorFlow 支持三种主要的读取数据的方法,可以在不同的数据集中使用:本教程中用来训练建立模型的一些数据集介绍 ...

  4. Windows下pycharm远程连接服务器调试-tensorflow无法加载问题

    最近打算在win系统下使用pycharm开发程序,并远程连接服务器调试程序,其中在import tensorflow时报错如图所示(在远程服务器中执行程序正常): 直观错误为: ImportError ...

  5. Tensorflow模型加载与保存、Tensorboard简单使用

    先上代码: from __future__ import absolute_import from __future__ import division from __future__ import ...

  6. tensorflow学习笔记2:c++程序静态链接tensorflow库加载模型文件

    首先需要搞定tensorflow c++库,搜了一遍没有找到现成的包,于是下载tensorflow的源码开始编译: tensorflow的contrib中有一个makefile项目,极大的简化的接下来 ...

  7. TensorFlow模型加载与保存

    我们经常遇到训练时间很长,使用起来就是Weight和Bias.那么如何将训练和测试分开操作呢? TF给出了模型的加载与保存操作,看了网上都是很简单的使用了一下,这里给出一个神经网络的小程序去测试. 本 ...

  8. Tensorflow同时加载使用多个模型

    在Tensorflow中,所有操作对象都包装到相应的Session中的,所以想要使用不同的模型就需要将这些模型加载到不同的Session中并在使用的时候申明是哪个Session,从而避免由于Sessi ...

  9. PIE SDK 多数据源的复合数据集加载

    1. 功能简介 GIS遥感图像数据复合是将多种遥感图像数据融合成一种新的图像数据的技术,是目前遥感应用分析的前沿,PIESDK通过复合数据技术可以将多幅幅影像数据集(多光谱和全色数据)组合成一幅多波段 ...

随机推荐

  1. Zabbix:主动模式

    简介 Zabbix 是由 Alexei Vladishev 开发的一种网络监视.管理系统,基于 Server-Client 架构.可用于监视各种网络服务.服务器和网络机器等状态,官方站点:https: ...

  2. 1. 学习Linux操作系统

    1.熟练使用Linux命令行(鸟哥的Linux私房菜.Linux系统管理技术手册) 2.学会Linux程序设计(UNIX环境高级编程) 3.了解Linux内核机制(深入理解LINUX内核) 4.阅读L ...

  3. Go语言实现:【剑指offer】数据流中的中位数

    该题目来源于牛客网<剑指offer>专题. 如何得到一个数据流中的中位数?如果从数据流中读出奇数个数值,那么中位数就是所有数值排序之后位于中间的数值.如果从数据流中读出偶数个数值,那么中位 ...

  4. php 上传文件 示例

    <?php header("content-type:text/html;charset=utf-8"); error_reporting(E_ALL); //session ...

  5. PPT导出图片质量太差?简单操作直接导出印刷质地图片

    PPT导出图片质量太差?简单操作直接导出印刷质地图片    ​ PPT不仅可以用于展示文档,还可以用于简单图片合成处理,同时,PPT文档还可以全部导出为图片. 默认情况下,PPT导出的图片为96DPI ...

  6. centos 7 设置 本地更新源

    #yum-config-manager --disable \*--屏弊所有更新源#mkdir /r7iso# cd /run/media/{用户名}/CentOS\ 7\ x86_64/ #cp - ...

  7. vi/vim编辑器操作梳理

    1. vi/vim编辑器详细使用讲解 1.1 vi/vim编辑器的3种模式 1.2 vi/vim编辑器操作说明 参数/命令/模式 说明 ###普通模式   :set nu  显示行号 :set non ...

  8. 【转】Android之四大组件、六大布局、五大存储

    文章来自:http://blog.csdn.net/shenggaofei/article/details/52450668 一.四大组件: Android四大组件分别为activity.servic ...

  9. 【STM32H7教程】第62章 STM32H7的MDMA,DMA2D和通用DMA性能比较

    完整教程下载地址:http://www.armbbs.cn/forum.php?mod=viewthread&tid=86980 第62章       STM32H7的MDMA,DMA2D和通 ...

  10. for遍历用例数据时,报错:TypeError: list indices must be integers, not dict,'int' object is not iterable解决方法

    一:报错:TypeError: list indices must be integers, not dict for i in range(0,len(test_data)): suite.addT ...