Tensorflow学习笔记No.3
使用tf.data加载数据
tf.data是tensorflow2.0中加入的数据加载模块,是一个非常便捷的处理数据的模块。
这里简单介绍一些tf.data的使用方法。
1.加载tensorflow中自带的mnist数据并对数据进行一些简单的处理
1 (train_image, train_label), (test_image, test_label) = tf.keras.datasets.mnist.load_data()
2 train_image = train_image / 255
3 test_image = test_image / 255
2.使用tf.data.Dataset.from_tensor_slices()方法对数据进行切片处理
该函数是dataset核心函数之一,它的作用是把给定的元组、列表和张量等数据进行特征切片。切片的范围是从最外层维度开始的。如果有多个特征进行组合,那么一次切片是把每个组合的最外维度的数据切开,分成一组一组的。
1 ds_train_label = tf.data.Dataset.from_tensor_slices(train_label)
2 ds_train_label = tf.data.Dataset.from_tensor_slices(train_label)
3.使用tf.data.Dataset.zip()方法将image和label数据合并
tf.data.Dataset.zip()方法可将迭代对象中相对应(例如image对应label)的数据打包成一个元组,返回由这些元组组成的对象。
1 ds_train = tf.data.Dataset.zip((ds_train_image, ds_train_label))
这里ds_train中的数据就是由许多个(image, label)元组组成的。
事实上我们也可以直接把train_image与train_label进行合并,以元组的形式对train_image和train_label进行切片即可。
1 ds_trian = tf.data.Dataset.from_tensor_slices((train_image, train_label))
4.使用.shuffle().repeat().batch()方法对数据进行处理
1 ds_train = ds_train.shuffle(10000).repeat(count = 3).batch(64)
.shuffle()作用是将数据进行打乱操作,传入参数为buffer_size,改参数为设置“打乱缓存区大小”,也就是说程序会维持一个buffer_size大小的缓存,每次都会随机在这个缓存区抽取一定数量的数据。
.repeat()作用就是将数据重复使用多少次,参数是重复的次数,若无参数则无限重复。
.batch()作用是将数据打包成batch_size, 每batch_size个数据打包在一起作为一个epoch。
5.注意事项
在使用tf.data时,如果不设置数据的.repeat()的重复次数,数据会无限制重复,如果把这样的数据直接输入到神经网络中会导致内存不足程序无法终止等错误。此时,要在.fit()方法中加以限制。
1 history = model.fit(ds_train, epochs = 5, steps_per_epoch = step_per_epochs,
2 validation_data = ds_test, validation_steps = 10000 // 64
3 )
使用steps_per_epoch参数限制每个epochs的数据量。
使用validation_steps限制验证集中的数据量。
到这里tf.data的简单介绍就结束了,后续会更新tf.data中的更多内容。
Tensorflow学习笔记No.3的更多相关文章
- Tensorflow学习笔记2:About Session, Graph, Operation and Tensor
简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...
- Tensorflow学习笔记2019.01.22
tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...
- Tensorflow学习笔记2019.01.03
tensorflow学习笔记: 3.2 Tensorflow中定义数据流图 张量知识矩阵的一个超集. 超集:如果一个集合S2中的每一个元素都在集合S1中,且集合S1中可能包含S2中没有的元素,则集合S ...
- TensorFlow学习笔记之--[compute_gradients和apply_gradients原理浅析]
I optimizer.minimize(loss, var_list) 我们都知道,TensorFlow为我们提供了丰富的优化函数,例如GradientDescentOptimizer.这个方法会自 ...
- 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识
深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...
- 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别
深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...
- tensorflow学习笔记(4)-学习率
tensorflow学习笔记(4)-学习率 首先学习率如下图 所以在实际运用中我们会使用指数衰减的学习率 在tf中有这样一个函数 tf.train.exponential_decay(learning ...
- tensorflow学习笔记(3)前置数学知识
tensorflow学习笔记(3)前置数学知识 首先是神经元的模型 接下来是激励函数 神经网络的复杂度计算 层数:隐藏层+输出层 总参数=总的w+b 下图为2层 如下图 w为3*4+4个 b为4* ...
- tensorflow学习笔记(2)-反向传播
tensorflow学习笔记(2)-反向传播 反向传播是为了训练模型参数,在所有参数上使用梯度下降,让NN模型在的损失函数最小 损失函数:学过机器学习logistic回归都知道损失函数-就是预测值和真 ...
- tensorflow学习笔记(1)-基本语法和前向传播
tensorflow学习笔记(1) (1)tf中的图 图中就是一个计算图,一个计算过程. 图中的constant是个常量 计 ...
随机推荐
- SQL Node 1.05版
输出: select a.f1, b.f2 from table01 a, ( select a from tb ) b where a.f1=1 and b.f2=2 or b.f3=3 order ...
- Ajax跨域解决方案大全
题纲 关于跨域,有N种类型,本文只专注于ajax请求跨域(,ajax跨域只是属于浏览器"同源策略"中的一部分,其它的还有Cookie跨域iframe跨域,LocalStorage跨 ...
- 借助FreeHttp任意篡改Websocket报文(Websocket改包)
前言 作为Web应用中最常见的数据传输协议之一的Websocket,在我们日常工作中也势必会经常使用到,而在调试或测试中我们常常也有直接改变Websocket数据报文以确认其对应用的影响的需求,本文将 ...
- 单元测试unittest(基于数据驱动的框架:unittest+HTMLTestRunner/BeautifulReport+yaml+ddt)
一.定义 unittest单元测试框架不仅可以适用于单元测试,还可以适用WEB自动化测试用例的开发与执行,该测试框架可组织执行测试用例,并且提供了丰富的断言方法,判断测试用例是否通过,最终生成测试结果 ...
- python调用接口——requests模块
前提:安装pip install requests 导入import requests 1.get请求 result=requests.get(url,d).json() 或 .text 2. ...
- C#类库推荐 拼多多.Net SDK,开源免费!
背景介绍 近两年拼多多的发展非常迅速,即便口碑一般,也没有网页端,奈何我们已经全面小康,6亿月收入1000以下,9亿月收入2000以下,所以因为价格原因使用拼多多的用户也越来越多了.同样的,拼多多也开 ...
- canvas学习作业,模仿做一个祖玛的小游戏
这个游戏的原理我分为11个步骤,依次如下: 1.布局, 2.画曲线(曲线由两个半径不同的圆构成) 3.画曲线起点起始圆和曲线终点终止圆 4.起始的圆动起来, 5.起始的圆沿曲线走起来 6.起始的圆沿曲 ...
- 面试官:哪些场景会产生OOM?怎么解决?
这个面试题是一个朋友在面试的时候碰到的,什么时候会抛出OutOfMemery异常呢?初看好像挺简单的,其实深究起来考察的是对整个JVM的了解,而且这个问题从网上可以翻到一些乱七八糟的答案,其实在总结下 ...
- 云计算openstack共享组件——时间同步服务ntp(2)
一.标准时间讲解 地球分为东西十二个区域,共计 24 个时区 格林威治作为全球标准时间即 (GMT 时间 ),东时区以格林威治时区进行加,而西时区则为减. 地球的轨道并非正圆,在加上自转速度逐年递减, ...
- RDS、DDS 和 GaussDB 理不清?看这一篇足够了!
当前,华为云提供的数据库服务主要包括三大类:关系型数据库服务,非关系型数据库服务以及数据库工具服务.如下图所示: 关系型数据库和非关系型数据库均可分为开源和自研两大类.其中,自研数据库统一为Gauss ...