TensorFlow+实战Google深度学习框架学习笔记(5)----神经网络训练步骤
一、TensorFlow实战Google深度学习框架学习
1、步骤:
1、定义神经网络的结构和前向传播的输出结果。
2、定义损失函数以及选择反向传播优化的算法。
3、生成会话(session)并且在训练数据上反复运行反向传播优化算法。
2、代码:
来源:https://blog.csdn.net/longji/article/details/69472310
import tensorflow as tf
from numpy.random import RandomState # 1. 定义神经网络的参数,输入和输出节点
batch_size = 8
w1= tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2= tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
y_= tf.placeholder(tf.float32, shape=(None, 1), name='y-input') # 2. 定义前向传播过程,损失函数及反向传播算法
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)
cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy) # 3. 生成模拟数据集
rdm = RandomState(1)
X = rdm.rand(128,2)
Y = [[int(x1+x2 < 1)] for (x1, x2) in X] # 4. 创建一个会话来运行TensorFlow程序
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op) # 输出目前(未经训练)的参数取值。
print("w1:", sess.run(w1))
print("w2:", sess.run(w2))
print("\n") # 训练模型。
STEPS = 5000
for i in range(STEPS):
start = (i * batch_size) % 128
end = (i * batch_size) % 128 + batch_size
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 1000 == 0:
total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy)) # 输出训练后的参数取值。
print("\n")
print("w1:", sess.run(w1))
print("w2:", sess.run(w2)) ''' 输出形式
w1: [[-0.81131822 1.48459876 0.06532937]
[-2.4427042 0.0992484 0.59122431]]
w2: [[-0.81131822]
[ 1.48459876]
[ 0.06532937]] After 0 training step(s), cross entropy on all data is 0.0674925
After 1000 training step(s), cross entropy on all data is 0.0163385
After 2000 training step(s), cross entropy on all data is 0.00907547
After 3000 training step(s), cross entropy on all data is 0.00714436
After 4000 training step(s), cross entropy on all data is 0.00578471 w1: [[-1.9618274 2.58235407 1.68203783]
[-3.46817183 1.06982327 2.11789012]]
w2: [[-1.82471502]
[ 2.68546653]
[ 1.41819513]]
'''
二、莫烦大大的神经网络训练步骤:
1、def add_layer()
添加神经网络层:
import tensorflow as tf
#输入、输入大小、输出大小、激活函数 def add_layer( inputs, in_size, out_size ,activation_function=None) : #weight初始化时生成一个随机变量矩阵比0矩阵效果要好 Weights = tf.Variable( tf.random_normal ( [in_size, out_size])) #biases初始值最好也不要都为0,则biases值全部等于0.1 biases = tf.Variable( tf.zeros([1,out_size]) + 0.1) #相当于Y_predict Wx_plus_b = tf.matmul ( inputs,Weights ) +biases
#如果为线性则outputs不用改变,如果不为线性则用激活函数
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_functions ( Wx_plus_b)
reeturn outputs
2、建立神经网络
#定义数据 x_data = np.linspace ( -1,1,300) [:,np.newaxis]
noise = np.random.normal ( 0,0.05 , x_data.shape)
y_data = np.square( x_data) - 0.5 +noise #建立第一层layer
#一个输入层、一个隐藏层、一个输出层
#输入层:输入多少data就多少个神经元,这里的x只有一个特征属性,则输入层有1个神经元
#隐藏层:自己定义10个
#输出层:输出y只有1个输出 #None表示无论给多少个样本都可以
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None,1] ) #add_layer为上面自己建立的函数,这里建立隐藏层
l1 = add_layer( xs , 1 ,10 ,activation_function = tf.nn.relu)
#输出层
predition = add_layer( l1 ,10 , 1,activation_function = None) #算损失函数 , reduction_indices =[1] 按行求和
loss = tf.reduce_mean ( tf.square ( ys -prediction ),
reduction_indices =[1] ) #选择一个优化器,选择:梯度下降,需要给定一个学习率为0.1,通常要小于1
#优化器以0.1的学习效率要减少loss函数,使下一次结果更好
train_step = tf.train.GradientDecentOptimizer( 0.1).minimize (loss) #初始所有变量 init = tf.initialize_all_variables () sess = tf.Session() sess.run(init)
#重复学习1000次
for i in range(1000):
sess.run( train_step , feed_dict = {xs:x_data,ys:y_data})
#每50次打印loss
if i % 50 == 0:
print(sess.run(loss,feed_dict={x:x_data,ys:y_data})
TensorFlow+实战Google深度学习框架学习笔记(5)----神经网络训练步骤的更多相关文章
- [Tensorflow实战Google深度学习框架]笔记4
本系列为Tensorflow实战Google深度学习框架知识笔记,仅为博主看书过程中觉得较为重要的知识点,简单摘要下来,内容较为零散,请见谅. 2017-11-06 [第五章] MNIST数字识别问题 ...
- 1 如何使用pb文件保存和恢复模型进行迁移学习(学习Tensorflow 实战google深度学习框架)
学习过程是Tensorflow 实战google深度学习框架一书的第六章的迁移学习环节. 具体见我提出的问题:https://www.tensorflowers.cn/t/5314 参考https:/ ...
- 学习《TensorFlow实战Google深度学习框架 (第2版) 》中文PDF和代码
TensorFlow是谷歌2015年开源的主流深度学习框架,目前已得到广泛应用.<TensorFlow:实战Google深度学习框架(第2版)>为TensorFlow入门参考书,帮助快速. ...
- TensorFlow实战Google深度学习框架5-7章学习笔记
目录 第5章 MNIST数字识别问题 第6章 图像识别与卷积神经网络 第7章 图像数据处理 第5章 MNIST数字识别问题 MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会 ...
- TensorFlow实战Google深度学习框架-人工智能教程-自学人工智能的第二天-深度学习
自学人工智能的第一天 "TensorFlow 是谷歌 2015 年开源的主流深度学习框架,目前已得到广泛应用.本书为 TensorFlow 入门参考书,旨在帮助读者以快速.有效的方式上手 T ...
- 实现迁徙学习-《Tensorflow 实战Google深度学习框架》代码详解
为了实现迁徙学习,首先是数据集的下载 #利用curl下载数据集 curl -o flower_photos.tgz http://download.tensorflow.org/example_ima ...
- 2 (自我拓展)部署花的识别模型(学习tensorflow实战google深度学习框架)
kaggle竞赛的inception模型已经能够提取图像很好的特征,后续训练出一个针对当前图片数据的全连接层,进行花的识别和分类.这里见书即可,不再赘述. 书中使用google参加Kaggle竞赛的i ...
- TensorFlow实战第三课(可视化、加速神经网络训练)
matplotlib可视化 构件图形 用散点图描述真实数据之间的关系(plt.ion()用于连续显示) # plot the real data fig = plt.figure() ax = fig ...
- TensorFlow实战Google深度学习框架10-12章学习笔记
目录 第10章 TensorFlow高层封装 第11章 TensorBoard可视化 第12章 TensorFlow计算加速 第10章 TensorFlow高层封装 目前比较流行的TensorFlow ...
随机推荐
- php实现多图上传功能
总共三个文化 index.html conn.php upload.php index.html代码: <html> <head>上传文件</head> & ...
- 最快理解 - IO多路复用:select / poll / epoll 的区别.
目录 第一个解决方案(多线程) 第二个解决方案(select) 第三个解决方案(poll) 最终解决方案(epoll) 客栈遇到的问题 从开始学习编程后,我就想开一个 Hello World 餐厅,由 ...
- 2、链接数据库+mongodb基础命令行+小demo
链接数据库并且打印出数据的流程:1.在CMD里面输入 mongod 2.在CMD里面输入 mongo 3.在输入mongodb命令行里面进行操作,首先输入 show dbs 来查看是否能够链接得上库4 ...
- 【hdu 4135】Co-prime
[题目链接]:http://acm.hdu.edu.cn/showproblem.php?pid=4135 [题意] 让你求出[a..b]这个区间内和N互质的数的个数; [题解] 利用前缀和,求出[1 ...
- 最大团&稳定婚姻系列
[HDU] 1530 Maximum Clique 1435 Stable Match 3585 maximum shortest distance 二分+最大团 1522 Marriage is ...
- HDU 4345
细心点想,就明白了,题目是求和为N的各数的最小公倍数的种数.其实就是求N以内的各素数的不同的组合(包含他们的次方),当然,是不能超过N的.用Dp能解决.和背包差不多. #include <ios ...
- php+mysql 最简单的留言板
学完了记得动手操作. 測试地址(未过滤) <html> <body> <head><meta http-equiv="Content-Type&qu ...
- .NET泛型初探
总所周知,.NET出现在.net framework 2.0,为什么要在2.0引入泛型那,因为微软在开始开发.net框架时并没有想过多个类型参数传输时对方法的重构,这样一来,开发人员就要面对传输多种类 ...
- Spring整合TimerTask实现定时任务调度
一. 前言 近期在公司的项目中用到了定时任务, 本篇博文将会对TimerTask定时任务进行总结, 事实上TimerTask在实际项目中用的不多, 由于它不能在指定时间执行, 仅仅能让程序依照某一个频 ...
- 网页爬虫框架jsoup介绍
序言:在不知道jsoup框架前,因为项目需求.须要定时抓取其它站点上的内容.便想到用HttpClient方式获取指定站点的内容.这样的方法比較笨,就是通过url请求指定站点.依据指定站点返回文本解析. ...