用Tensorflow搭建神经网络的一般步骤
用Tensorflow搭建神经网络的一般步骤如下:
① 导入模块
② 创建模型变量和占位符
③ 建立模型
④ 定义loss函数
⑤ 定义优化器(optimizer), 使 loss 达到最小
⑥ 引入激活函数, 即添加非线性因素 (线性回归问题跳过此步骤)
⑦ 训练模型
⑧ 检验模型
⑨ 使用模型预测数据
⑩ 保存模型
⑪ 使用Tensorboard的可视化功能
下面以一个简单的线性回归问题为例:
首先是训练模型的代码: train_model.py
# ① 导入模块
import tensorflow as tf # ② 创建模型的变量和占位符
W = tf.Variable([.3], dtype=tf.float32)
b = tf.Variable([-.3], dtype=tf.float32)
x = tf.placeholder(tf.float32, name="input_x")
y = tf.placeholder(tf.float32, name="input_y") # ③建立模型
linear_model = W*x + b
# 如果是矩阵相乘,可以写成:
# linear_model = tf.matmul(x, W)+b # matmul表示矩阵相乘 # ④ 定义loss函数
loss = tf.reduce_sum(tf.square(linear_model - y)) # ⑤ 定义优化器(optimizer), 使 loss 达到最小
learning_rate=0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
train = optimizer.minimize(loss) # ⑥ 引入激活函数, 即添加非线性因素。(线性回归问题跳过此步骤) # ⑦ 训练模型
# 假设模型是y=2x+1
x_train = [1, 2, 3, 4]
y_train = [3, 5, 7, 9] init = tf.global_variables_initializer() # 添加用于初始化变量的节点
sess = tf.Session()
sess.run(init) # 运行初始化操作
for step in range(1000):
sess.run(train, {x: x_train, y: y_train}) '''
第⑦步和第⑩步可以合并为:
for step in xrange(1000000):
sess.run(train, {x: x_train, y: y_train})
if step % 1000 == 0:
saver.save(sess, 'my-model', global_step=step)
''' # ⑧ 检验模型
curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train})
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
'''
W: [ 2.00000167] b: [ 0.99999553] loss: 1.29603e-11
''' # ⑨ 使用模型预测数据
x_predict = [-1, 0, 1, 2]
predicted_values=sess.run(linear_model, feed_dict={x:x_predict})
# 注意这么一种写法: predicted_values = [(W*x + b).eval(session=sess) for x in x_predict]
print("result:", predicted_values)
'''
result: [-1.0000062 0.99999553 2.99999714 4.99999905]
''' # ⑩ 保存模型
tf.add_to_collection("predict_network", linear_model)
saver = tf.train.Saver()
saver_path=saver.save(sess, "save/model.ckpt") # ⑪ 使用Tensorboard的可视化功能
# 定义保存日志的路径
path = "log" # 也可写成: path = "./log"
writer=tf.summary.FileWriter(path, sess.graph) sess.close()
然后是载入模型的代码: restore_model.py
import tensorflow as tf with tf.Session() as sess:
new_saver=tf.train.import_meta_graph("save/model.ckpt.meta")
new_saver.restore(sess,"save/model.ckpt")
# print(tf.get_collection("predict_network"))
restored_y=tf.get_collection("predict_network")[0] # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可 graph=tf.get_default_graph()
restored_x=graph.get_operation_by_name("input_x").outputs[0] predict_data = [-2, 3, 4]
predicted_result = sess.run(restored_y, feed_dict={restored_x:predict_data}) print("result:", predicted_result) # result: [-3.00000787 7.00000048 9.00000191]
用Tensorflow搭建神经网络的一般步骤的更多相关文章
- (转)一文学会用 Tensorflow 搭建神经网络
一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day ...
- 一文学会用 Tensorflow 搭建神经网络
http://www.jianshu.com/p/e112012a4b2d 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码 ...
- Tensorflow 搭建神经网络及tensorboard可视化
1. session对话控制 matrix1 = tf.constant([[3,3]]) matrix2 = tf.constant([[2],[2]]) product = tf.matmul(m ...
- kaggle赛题Digit Recognizer:利用TensorFlow搭建神经网络(附上K邻近算法模型预测)
一.前言 kaggle上有传统的手写数字识别mnist的赛题,通过分类算法,将图片数据进行识别.mnist数据集里面,包含了42000张手写数字0到9的图片,每张图片为28*28=784的像素,所以整 ...
- Tensorflow搭建神经网络及使用Tensorboard进行可视化
创建神经网络模型 1.构建神经网络结构,并进行模型训练 import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt ...
- tensorflow搭建神经网络
最简单的神经网络 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt date = np.linspa ...
- tensorflow搭建神经网络基本流程
定义添加神经层的函数 1.训练的数据2.定义节点准备接收数据3.定义神经层:隐藏层和预测层4.定义 loss 表达式5.选择 optimizer 使 loss 达到最小 然后对所有变量进行初始化,通过 ...
- 基于tensorflow搭建一个神经网络
一,tensorflow的简介 Tensorflow是一个采用数据流图,用于数值计算的 开源软件库.节点在图中表示数字操作,图中的线 则表示在节点间相互联系的多维数据数组,即张量 它灵活的架构让你可以 ...
- Tensorflow学习:(二)搭建神经网络
一.神经网络的实现过程 1.准备数据集,提取特征,作为输入喂给神经网络 2.搭建神经网络结构,从输入到输出 3.大量特征数据喂给 NN,迭代优化 NN 参数 4.使 ...
随机推荐
- Bigger-Mai 养成计划,前端基础学习之CSS
在标签上设置style属性: background-color: #2459a2; height: 48px; ... 编写css样式: 1. 标签的style属性 2. 写在head里面 style ...
- shell编程(五)之函数
function:函数 函数只有被调用才会执行如何调用:给定函数名 函数名出现的地方,会被自动替换为函数代码 函数的生命周期:被调用时创建,返回时终止return命令返回自定义状态结果 0:成功 1- ...
- Swagger和Postman的配置和使用
Swagger 1. 配置 pom文件添加swagger依赖,注意版本,2.8.0可以使用 <dependency> <groupId>io.springfox</gro ...
- cuda cudnn tensorflow-gpu安装
Ububtu18.04下载cuda9.0 下载好后得到: CUDA 9.0仅支持GCC 6.0及以下版本,而Ubuntu 18.04预装GCC版本为7.3,需要安装gcc-6与g++-6 查看当前版本 ...
- MySQL中使用union all获得并集的排序
项目中有时候因为某些不可逆转的原因使得表中存储的数据难以满足在页面中的展示要求.之前的项目上有文章内容的展示功能,文章分为三个状态待发布.已发布.已下线.他们在数据表中判断状态的字段(PROMOTE_ ...
- threejs绕轴转,粒子系统,控制器操作等(二)
前言:threejs系列的第二篇文章,也是一边学习一边总结: 1,一个物体绕着另一个物体转动 上一篇文中主要是物体自转,为了描述一个一个物体绕另一个物体转,这里我描述了一个月球绕地球公转,并且自转的场 ...
- java非阻塞NIO和阻塞IO
1 非阻塞NIO和阻塞IO 1.1 定义 阻塞IO:线程被阻塞,去处理一个读取和写入,中间如果有等待时间,则线程被占用,也不能处理其他任务: 非阻塞IO(new I ...
- 记一次webpack4.x项目配置
在自构建自己的个人页面的时候使用到webpack4,遇到了一些问题,查看了大佬们的文章以及官方文档,在这里总结一下. webpack比较基础的东西就不赘述了,代码里面的注释也会辅助说明,先看一下目录结 ...
- RobotFramework自动化测试框架-Selenium Web自动化(-)-Open Browser和Close Browser
Selenium出来已经有很多年了,从最初的Selenium1到后来的Selenium2,也变得越来越成熟,而且也已经被很多公司广泛使用.Selenium发展的过程中,分了很多模块,这里我们主要介绍W ...
- centos 7 安装二进制mysql 详细步骤
1 下载地址:https://cdn.mysql.com//Downloads/MySQL-5.7/mysql-5.7.24-linux-glibc2.12-x86_64.tar.gz 复制这个链接在 ...