用Tensorflow搭建神经网络的一般步骤
用Tensorflow搭建神经网络的一般步骤如下:
① 导入模块
② 创建模型变量和占位符
③ 建立模型
④ 定义loss函数
⑤ 定义优化器(optimizer), 使 loss 达到最小
⑥ 引入激活函数, 即添加非线性因素 (线性回归问题跳过此步骤)
⑦ 训练模型
⑧ 检验模型
⑨ 使用模型预测数据
⑩ 保存模型
⑪ 使用Tensorboard的可视化功能
下面以一个简单的线性回归问题为例:
首先是训练模型的代码: train_model.py
# ① 导入模块
import tensorflow as tf # ② 创建模型的变量和占位符
W = tf.Variable([.3], dtype=tf.float32)
b = tf.Variable([-.3], dtype=tf.float32)
x = tf.placeholder(tf.float32, name="input_x")
y = tf.placeholder(tf.float32, name="input_y") # ③建立模型
linear_model = W*x + b
# 如果是矩阵相乘,可以写成:
# linear_model = tf.matmul(x, W)+b # matmul表示矩阵相乘 # ④ 定义loss函数
loss = tf.reduce_sum(tf.square(linear_model - y)) # ⑤ 定义优化器(optimizer), 使 loss 达到最小
learning_rate=0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
train = optimizer.minimize(loss) # ⑥ 引入激活函数, 即添加非线性因素。(线性回归问题跳过此步骤) # ⑦ 训练模型
# 假设模型是y=2x+1
x_train = [1, 2, 3, 4]
y_train = [3, 5, 7, 9] init = tf.global_variables_initializer() # 添加用于初始化变量的节点
sess = tf.Session()
sess.run(init) # 运行初始化操作
for step in range(1000):
sess.run(train, {x: x_train, y: y_train}) '''
第⑦步和第⑩步可以合并为:
for step in xrange(1000000):
sess.run(train, {x: x_train, y: y_train})
if step % 1000 == 0:
saver.save(sess, 'my-model', global_step=step)
''' # ⑧ 检验模型
curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train})
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
'''
W: [ 2.00000167] b: [ 0.99999553] loss: 1.29603e-11
''' # ⑨ 使用模型预测数据
x_predict = [-1, 0, 1, 2]
predicted_values=sess.run(linear_model, feed_dict={x:x_predict})
# 注意这么一种写法: predicted_values = [(W*x + b).eval(session=sess) for x in x_predict]
print("result:", predicted_values)
'''
result: [-1.0000062 0.99999553 2.99999714 4.99999905]
''' # ⑩ 保存模型
tf.add_to_collection("predict_network", linear_model)
saver = tf.train.Saver()
saver_path=saver.save(sess, "save/model.ckpt") # ⑪ 使用Tensorboard的可视化功能
# 定义保存日志的路径
path = "log" # 也可写成: path = "./log"
writer=tf.summary.FileWriter(path, sess.graph) sess.close()
然后是载入模型的代码: restore_model.py
import tensorflow as tf with tf.Session() as sess:
new_saver=tf.train.import_meta_graph("save/model.ckpt.meta")
new_saver.restore(sess,"save/model.ckpt")
# print(tf.get_collection("predict_network"))
restored_y=tf.get_collection("predict_network")[0] # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可 graph=tf.get_default_graph()
restored_x=graph.get_operation_by_name("input_x").outputs[0] predict_data = [-2, 3, 4]
predicted_result = sess.run(restored_y, feed_dict={restored_x:predict_data}) print("result:", predicted_result) # result: [-3.00000787 7.00000048 9.00000191]
用Tensorflow搭建神经网络的一般步骤的更多相关文章
- (转)一文学会用 Tensorflow 搭建神经网络
一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day ...
- 一文学会用 Tensorflow 搭建神经网络
http://www.jianshu.com/p/e112012a4b2d 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码 ...
- Tensorflow 搭建神经网络及tensorboard可视化
1. session对话控制 matrix1 = tf.constant([[3,3]]) matrix2 = tf.constant([[2],[2]]) product = tf.matmul(m ...
- kaggle赛题Digit Recognizer:利用TensorFlow搭建神经网络(附上K邻近算法模型预测)
一.前言 kaggle上有传统的手写数字识别mnist的赛题,通过分类算法,将图片数据进行识别.mnist数据集里面,包含了42000张手写数字0到9的图片,每张图片为28*28=784的像素,所以整 ...
- Tensorflow搭建神经网络及使用Tensorboard进行可视化
创建神经网络模型 1.构建神经网络结构,并进行模型训练 import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt ...
- tensorflow搭建神经网络
最简单的神经网络 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt date = np.linspa ...
- tensorflow搭建神经网络基本流程
定义添加神经层的函数 1.训练的数据2.定义节点准备接收数据3.定义神经层:隐藏层和预测层4.定义 loss 表达式5.选择 optimizer 使 loss 达到最小 然后对所有变量进行初始化,通过 ...
- 基于tensorflow搭建一个神经网络
一,tensorflow的简介 Tensorflow是一个采用数据流图,用于数值计算的 开源软件库.节点在图中表示数字操作,图中的线 则表示在节点间相互联系的多维数据数组,即张量 它灵活的架构让你可以 ...
- Tensorflow学习:(二)搭建神经网络
一.神经网络的实现过程 1.准备数据集,提取特征,作为输入喂给神经网络 2.搭建神经网络结构,从输入到输出 3.大量特征数据喂给 NN,迭代优化 NN 参数 4.使 ...
随机推荐
- Bigger-Mai 养成计划,Python基础巩固二
模块初识1.标准库2.第三方库import sys sys.path #自己的本文件名不可为sys.py#输出模块存储的环境变量sys.argv #打印脚本的相对路径sys.argv[2] #取第二个 ...
- 【FJOI 20170305】省选模拟赛
题面被改成了个猪... T1猪猪划船(boat) [题目描述] 6只可爱的猪猪们一起旅游,其中有3只大猪A,B,C,他们的孩子为3只小猪a,b,c.由于猪猪们十分凶残,如果小猪在没有父母监护的情况下, ...
- 【新特性】JDK11
随着JDK11正式发布,带来了许多新的特性.本文主要介绍JDK11的部分新特性和新的API. 一.Local Var 在Lambda表达式中,可以使用var关键字来标识变量,变量类型由编译器自行推断. ...
- ModelBiner不验证某个属性
问题 使用MVC的同学十有八九都会遇到这个错误:从客户端(Content="<script>...")中检测到有潜在危险的Request.Form 值. 这个错误是在请 ...
- C语言--第4次作业
1.本章学习总结 1.1思维导图 1.2本章学习体会及代码量学习体会 1.2.1学习体会 初识数组:这几周第一次接触数组,感觉有点懵,是一个很陌生的知识点,但是运用范围及其广泛,大大简化了程序,增大了 ...
- 剑指offer 08:跳台阶
题目描述 一只青蛙一次可以跳上1级台阶,也可以跳上2级.求该青蛙跳上一个n级的台阶总共有多少种跳法(先后次序不同算不同的结果). public class Solution { public int ...
- 前端开发需要掌握的SEO的知识点
SEO 工作的目的 seo 的工作目的是为了让网站更利于让各大搜索引擎抓取和收录,增加产品的曝光率. SEO 注意事项 1. 网站 TDK 标签的设置.title,description,keywor ...
- postgres 11 单实例最大支持多少个database?
有人在pg8时代(10年前)问过,当时说10000个没问题,而且每个db会在/base下建立1个文件夹, 文件ext3只支持32000个子文件夹,所以这是上限了. 而现在早就ext4了,根本没有限制了 ...
- centos7安装python,mariaDB,django,nginx
0,安装centos7 centos默认不开启网卡,需要在安装时将ens33设置为on,或者后续通过vi ifcfg-ens33,找到onboot,设置为yes ssg登陆centos7时,如果提示W ...
- android手机短信获取
关于Android中对短信的一些相关操.我看到一个文章下面我就从标题中的三个方面来对Android系统中的短信操作进行一个简单地学习. 短信发送: 由于Android中对短信发送方法的优良封装,之后对 ...