用Tensorflow搭建神经网络的一般步骤如下:

① 导入模块

② 创建模型变量和占位符

③ 建立模型

④ 定义loss函数

⑤ 定义优化器(optimizer), 使 loss 达到最小

⑥ 引入激活函数, 即添加非线性因素 (线性回归问题跳过此步骤)

⑦ 训练模型

⑧ 检验模型

⑨ 使用模型预测数据

⑩ 保存模型

⑪ 使用Tensorboard的可视化功能

下面以一个简单的线性回归问题为例:

首先是训练模型的代码: train_model.py

 # ① 导入模块
import tensorflow as tf # ② 创建模型的变量和占位符
W = tf.Variable([.3], dtype=tf.float32)
b = tf.Variable([-.3], dtype=tf.float32)
x = tf.placeholder(tf.float32, name="input_x")
y = tf.placeholder(tf.float32, name="input_y") # ③建立模型
linear_model = W*x + b
# 如果是矩阵相乘,可以写成:
# linear_model = tf.matmul(x, W)+b # matmul表示矩阵相乘 # ④ 定义loss函数
loss = tf.reduce_sum(tf.square(linear_model - y)) # ⑤ 定义优化器(optimizer), 使 loss 达到最小
learning_rate=0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
train = optimizer.minimize(loss) # ⑥ 引入激活函数, 即添加非线性因素。(线性回归问题跳过此步骤) # ⑦ 训练模型
# 假设模型是y=2x+1
x_train = [1, 2, 3, 4]
y_train = [3, 5, 7, 9] init = tf.global_variables_initializer() # 添加用于初始化变量的节点
sess = tf.Session()
sess.run(init) # 运行初始化操作
for step in range(1000):
sess.run(train, {x: x_train, y: y_train}) '''
第⑦步和第⑩步可以合并为:
for step in xrange(1000000):
sess.run(train, {x: x_train, y: y_train})
if step % 1000 == 0:
saver.save(sess, 'my-model', global_step=step)
''' # ⑧ 检验模型
curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train})
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
'''
W: [ 2.00000167] b: [ 0.99999553] loss: 1.29603e-11
''' # ⑨ 使用模型预测数据
x_predict = [-1, 0, 1, 2]
predicted_values=sess.run(linear_model, feed_dict={x:x_predict})
# 注意这么一种写法: predicted_values = [(W*x + b).eval(session=sess) for x in x_predict]
print("result:", predicted_values)
'''
result: [-1.0000062 0.99999553 2.99999714 4.99999905]
''' # ⑩ 保存模型
tf.add_to_collection("predict_network", linear_model)
saver = tf.train.Saver()
saver_path=saver.save(sess, "save/model.ckpt") # ⑪ 使用Tensorboard的可视化功能
# 定义保存日志的路径
path = "log" # 也可写成: path = "./log"
writer=tf.summary.FileWriter(path, sess.graph) sess.close()

然后是载入模型的代码: restore_model.py

 import tensorflow as tf

 with tf.Session() as sess:
new_saver=tf.train.import_meta_graph("save/model.ckpt.meta")
new_saver.restore(sess,"save/model.ckpt")
# print(tf.get_collection("predict_network"))
restored_y=tf.get_collection("predict_network")[0] # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可 graph=tf.get_default_graph()
restored_x=graph.get_operation_by_name("input_x").outputs[0] predict_data = [-2, 3, 4]
predicted_result = sess.run(restored_y, feed_dict={restored_x:predict_data}) print("result:", predicted_result) # result: [-3.00000787 7.00000048 9.00000191]

用Tensorflow搭建神经网络的一般步骤的更多相关文章

  1. (转)一文学会用 Tensorflow 搭建神经网络

    一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day ...

  2. 一文学会用 Tensorflow 搭建神经网络

    http://www.jianshu.com/p/e112012a4b2d 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码 ...

  3. Tensorflow 搭建神经网络及tensorboard可视化

    1. session对话控制 matrix1 = tf.constant([[3,3]]) matrix2 = tf.constant([[2],[2]]) product = tf.matmul(m ...

  4. kaggle赛题Digit Recognizer:利用TensorFlow搭建神经网络(附上K邻近算法模型预测)

    一.前言 kaggle上有传统的手写数字识别mnist的赛题,通过分类算法,将图片数据进行识别.mnist数据集里面,包含了42000张手写数字0到9的图片,每张图片为28*28=784的像素,所以整 ...

  5. Tensorflow搭建神经网络及使用Tensorboard进行可视化

    创建神经网络模型 1.构建神经网络结构,并进行模型训练 import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt ...

  6. tensorflow搭建神经网络

    最简单的神经网络 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt date = np.linspa ...

  7. tensorflow搭建神经网络基本流程

    定义添加神经层的函数 1.训练的数据2.定义节点准备接收数据3.定义神经层:隐藏层和预测层4.定义 loss 表达式5.选择 optimizer 使 loss 达到最小 然后对所有变量进行初始化,通过 ...

  8. 基于tensorflow搭建一个神经网络

    一,tensorflow的简介 Tensorflow是一个采用数据流图,用于数值计算的 开源软件库.节点在图中表示数字操作,图中的线 则表示在节点间相互联系的多维数据数组,即张量 它灵活的架构让你可以 ...

  9. Tensorflow学习:(二)搭建神经网络

    一.神经网络的实现过程 1.准备数据集,提取特征,作为输入喂给神经网络       2.搭建神经网络结构,从输入到输出       3.大量特征数据喂给 NN,迭代优化 NN 参数       4.使 ...

随机推荐

  1. 0x11栈之Editor

    参考链接:https://blog.csdn.net/SSLGZ_yyc/article/details/81700623 对顶栈的思想: 建立两个栈,栈A存储从序列开头到当前光标的位置的一段序列,栈 ...

  2. Linux 设置系统时间和时区1.Centos

  3. Python3 tkinter基础 Scrollbar pack 创建靠右、充满Y轴的垂直滚动条

             Python : 3.7.0          OS : Ubuntu 18.04.1 LTS         IDE : PyCharm 2018.2.4       Conda ...

  4. NPOI 的使用姿势

    NPOI 正确的使用姿势 主要是需要注意公式和日期类型的单元格的读取. /// <summary> /// 打开指定 Excel 文件 /// </summary> /// & ...

  5. 记一次VM虚拟机Ubuntu无法联网问题

    突然ubuntu获取不到ipv4地址,手动设置静态ip也ping不通本机, 在网上试了一堆的方法也不行,就怀疑是vm设置问题了.因为 作业环境我的VM需要经常性的改变桥接的网卡,所以检查了一 下这里, ...

  6. Spring定时任务@Scheduled注解使用方式

    1.开篇 spring的@Scheduled定时任务相信大家都是十分熟悉.最近在使用过程中发现了一些问题,写篇文章,和大家分享一下.结论在最后,不想看冗长过程的小伙伴可以直接拉到最后看结论. 2.简单 ...

  7. webpack-工程化工具

    一.简介 1.webpack 是 facebook 公司发布的一款工程化工具,早期有 react 使用. 2.核心理念: 一切都是资源,是资源我们就能模块化打包加载. 3.webpack 默认支持 c ...

  8. 【架构设计】Android:配置式金字塔架构

    最近做一个项目,在项目搭建之前,花了些许时间去思考一下如何搭建一个合适的架构.一开始的构思是希望能合理的把应用的各部分进行分离,使其像金字塔一样从上往下,下层为上层提供功能. 在平常项目中,总是有很多 ...

  9. HTML和CSS查缺补漏

    margin的问题: 1.margin-top向上传递 解决:1.父元素border边框,2.父元素使用overflow:hidden 3.为父元素或者子元素声明绝对定位,4.为父元素或者子元素声明浮 ...

  10. C#通过DocX创建word

    网上有一些基础的东西,但是比如插入图片,就没有找到方案,最终自己摸索出来的. 1.首先通过Nuget获取引用,关键字:“DocX” 2.示例代码 class Program { static void ...