TensorFlow+实战Google深度学习框架学习笔记(5)----神经网络训练步骤
一、TensorFlow实战Google深度学习框架学习
1、步骤:
1、定义神经网络的结构和前向传播的输出结果。
2、定义损失函数以及选择反向传播优化的算法。
3、生成会话(session)并且在训练数据上反复运行反向传播优化算法。
2、代码:
来源:https://blog.csdn.net/longji/article/details/69472310
import tensorflow as tf
from numpy.random import RandomState # 1. 定义神经网络的参数,输入和输出节点
batch_size = 8
w1= tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2= tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
y_= tf.placeholder(tf.float32, shape=(None, 1), name='y-input') # 2. 定义前向传播过程,损失函数及反向传播算法
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)
cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy) # 3. 生成模拟数据集
rdm = RandomState(1)
X = rdm.rand(128,2)
Y = [[int(x1+x2 < 1)] for (x1, x2) in X] # 4. 创建一个会话来运行TensorFlow程序
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op) # 输出目前(未经训练)的参数取值。
print("w1:", sess.run(w1))
print("w2:", sess.run(w2))
print("\n") # 训练模型。
STEPS = 5000
for i in range(STEPS):
start = (i * batch_size) % 128
end = (i * batch_size) % 128 + batch_size
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 1000 == 0:
total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy)) # 输出训练后的参数取值。
print("\n")
print("w1:", sess.run(w1))
print("w2:", sess.run(w2)) ''' 输出形式
w1: [[-0.81131822 1.48459876 0.06532937]
[-2.4427042 0.0992484 0.59122431]]
w2: [[-0.81131822]
[ 1.48459876]
[ 0.06532937]] After 0 training step(s), cross entropy on all data is 0.0674925
After 1000 training step(s), cross entropy on all data is 0.0163385
After 2000 training step(s), cross entropy on all data is 0.00907547
After 3000 training step(s), cross entropy on all data is 0.00714436
After 4000 training step(s), cross entropy on all data is 0.00578471 w1: [[-1.9618274 2.58235407 1.68203783]
[-3.46817183 1.06982327 2.11789012]]
w2: [[-1.82471502]
[ 2.68546653]
[ 1.41819513]]
'''
二、莫烦大大的神经网络训练步骤:
1、def add_layer()
添加神经网络层:
import tensorflow as tf
#输入、输入大小、输出大小、激活函数 def add_layer( inputs, in_size, out_size ,activation_function=None) : #weight初始化时生成一个随机变量矩阵比0矩阵效果要好 Weights = tf.Variable( tf.random_normal ( [in_size, out_size])) #biases初始值最好也不要都为0,则biases值全部等于0.1 biases = tf.Variable( tf.zeros([1,out_size]) + 0.1) #相当于Y_predict Wx_plus_b = tf.matmul ( inputs,Weights ) +biases
#如果为线性则outputs不用改变,如果不为线性则用激活函数
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_functions ( Wx_plus_b)
reeturn outputs
2、建立神经网络
#定义数据 x_data = np.linspace ( -1,1,300) [:,np.newaxis]
noise = np.random.normal ( 0,0.05 , x_data.shape)
y_data = np.square( x_data) - 0.5 +noise #建立第一层layer
#一个输入层、一个隐藏层、一个输出层
#输入层:输入多少data就多少个神经元,这里的x只有一个特征属性,则输入层有1个神经元
#隐藏层:自己定义10个
#输出层:输出y只有1个输出 #None表示无论给多少个样本都可以
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None,1] ) #add_layer为上面自己建立的函数,这里建立隐藏层
l1 = add_layer( xs , 1 ,10 ,activation_function = tf.nn.relu)
#输出层
predition = add_layer( l1 ,10 , 1,activation_function = None) #算损失函数 , reduction_indices =[1] 按行求和
loss = tf.reduce_mean ( tf.square ( ys -prediction ),
reduction_indices =[1] ) #选择一个优化器,选择:梯度下降,需要给定一个学习率为0.1,通常要小于1
#优化器以0.1的学习效率要减少loss函数,使下一次结果更好
train_step = tf.train.GradientDecentOptimizer( 0.1).minimize (loss) #初始所有变量 init = tf.initialize_all_variables () sess = tf.Session() sess.run(init)
#重复学习1000次
for i in range(1000):
sess.run( train_step , feed_dict = {xs:x_data,ys:y_data})
#每50次打印loss
if i % 50 == 0:
print(sess.run(loss,feed_dict={x:x_data,ys:y_data})
TensorFlow+实战Google深度学习框架学习笔记(5)----神经网络训练步骤的更多相关文章
- [Tensorflow实战Google深度学习框架]笔记4
本系列为Tensorflow实战Google深度学习框架知识笔记,仅为博主看书过程中觉得较为重要的知识点,简单摘要下来,内容较为零散,请见谅. 2017-11-06 [第五章] MNIST数字识别问题 ...
- 1 如何使用pb文件保存和恢复模型进行迁移学习(学习Tensorflow 实战google深度学习框架)
学习过程是Tensorflow 实战google深度学习框架一书的第六章的迁移学习环节. 具体见我提出的问题:https://www.tensorflowers.cn/t/5314 参考https:/ ...
- 学习《TensorFlow实战Google深度学习框架 (第2版) 》中文PDF和代码
TensorFlow是谷歌2015年开源的主流深度学习框架,目前已得到广泛应用.<TensorFlow:实战Google深度学习框架(第2版)>为TensorFlow入门参考书,帮助快速. ...
- TensorFlow实战Google深度学习框架5-7章学习笔记
目录 第5章 MNIST数字识别问题 第6章 图像识别与卷积神经网络 第7章 图像数据处理 第5章 MNIST数字识别问题 MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会 ...
- TensorFlow实战Google深度学习框架-人工智能教程-自学人工智能的第二天-深度学习
自学人工智能的第一天 "TensorFlow 是谷歌 2015 年开源的主流深度学习框架,目前已得到广泛应用.本书为 TensorFlow 入门参考书,旨在帮助读者以快速.有效的方式上手 T ...
- 实现迁徙学习-《Tensorflow 实战Google深度学习框架》代码详解
为了实现迁徙学习,首先是数据集的下载 #利用curl下载数据集 curl -o flower_photos.tgz http://download.tensorflow.org/example_ima ...
- 2 (自我拓展)部署花的识别模型(学习tensorflow实战google深度学习框架)
kaggle竞赛的inception模型已经能够提取图像很好的特征,后续训练出一个针对当前图片数据的全连接层,进行花的识别和分类.这里见书即可,不再赘述. 书中使用google参加Kaggle竞赛的i ...
- TensorFlow实战第三课(可视化、加速神经网络训练)
matplotlib可视化 构件图形 用散点图描述真实数据之间的关系(plt.ion()用于连续显示) # plot the real data fig = plt.figure() ax = fig ...
- TensorFlow实战Google深度学习框架10-12章学习笔记
目录 第10章 TensorFlow高层封装 第11章 TensorBoard可视化 第12章 TensorFlow计算加速 第10章 TensorFlow高层封装 目前比较流行的TensorFlow ...
随机推荐
- C++继承与组合
转自https://blog.csdn.net/caoyan_12727/article/details/52337297 类的组合和继承一样,是软件重用的重要方式.组合和继承都是有效地利用已有类的资 ...
- SQLServer Oracle MySQL的区别
table tr:nth-child(odd){ background: #FFFFCC; font-size: 18px; } table tr:nth-child(even){ backgroun ...
- 实现el-dialog的拖拽,全屏,缩小功能
基于el-dialog, 封装了一下.,实在懒得写,所以直接把代码 粘出来了 大概粘了一下效果.自己体会把. 组件使用 <el-dialog v-dialogDrag ref="xhz ...
- [NoiPlus2016]天天爱跑步
巨坑 树剖学的好啊!---sfailsth 把一段路径拆成两段,向上和S->LCA,向下LCA->T 用维护重链什么的操作搞一下. sfailsth学长真不容易啊...考场上rush了4. ...
- Problem 9
Problem 9 # Problem_9.py """ A Pythagorean triplet is a set of three natural numbers, ...
- 00069_DateFormate
1.DateFormate类概述 (1)DateFormat 是日期/时间格式化子类的抽象类,它以与语言无关的方式格式化并解析日期或时间.日期/时间格式化子类(如 SimpleDateFormat类) ...
- 【hiho一下 第三周】KMP算法
[题目链接]:http://hihocoder.com/problemset/problem/1015 [题意] [题解] 把f数组,len1,len2数组一开始全都定义成char型 这酸爽. [Nu ...
- Android使用C代码
Android调用C代码 1.开发工具:Android studio 2.0 2.开发前准备: 2. 3. 4.下面我们就来开发我们的程序吧, [1]创建一个java类 package com.adm ...
- 如何利用eclipse实现批量修改文件的编码方式
在eclipse+Eclipse环境下,打开一个jsp文件,经常发现汉字无法显示,右键点击查看这个文件属性,发现文件的字符编码属性为ISO-8859-1. 目前的解决方法有:1. 手工把 ...
- CF802G Fake News (easy)
CF802G Fake News (easy) 题意翻译 给定一个字符串询问能否听过删除一些字母使其变为“heidi” 如果可以输出“YES”,不然为“NO” 题目描述 As it's the fir ...