用Tensorflow搭建神经网络的一般步骤如下:

① 导入模块

② 创建模型变量和占位符

③ 建立模型

④ 定义loss函数

⑤ 定义优化器(optimizer), 使 loss 达到最小

⑥ 引入激活函数, 即添加非线性因素 (线性回归问题跳过此步骤)

⑦ 训练模型

⑧ 检验模型

⑨ 使用模型预测数据

⑩ 保存模型

⑪ 使用Tensorboard的可视化功能

下面以一个简单的线性回归问题为例:

首先是训练模型的代码: train_model.py

 # ① 导入模块
import tensorflow as tf # ② 创建模型的变量和占位符
W = tf.Variable([.3], dtype=tf.float32)
b = tf.Variable([-.3], dtype=tf.float32)
x = tf.placeholder(tf.float32, name="input_x")
y = tf.placeholder(tf.float32, name="input_y") # ③建立模型
linear_model = W*x + b
# 如果是矩阵相乘,可以写成:
# linear_model = tf.matmul(x, W)+b # matmul表示矩阵相乘 # ④ 定义loss函数
loss = tf.reduce_sum(tf.square(linear_model - y)) # ⑤ 定义优化器(optimizer), 使 loss 达到最小
learning_rate=0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
train = optimizer.minimize(loss) # ⑥ 引入激活函数, 即添加非线性因素。(线性回归问题跳过此步骤) # ⑦ 训练模型
# 假设模型是y=2x+1
x_train = [1, 2, 3, 4]
y_train = [3, 5, 7, 9] init = tf.global_variables_initializer() # 添加用于初始化变量的节点
sess = tf.Session()
sess.run(init) # 运行初始化操作
for step in range(1000):
sess.run(train, {x: x_train, y: y_train}) '''
第⑦步和第⑩步可以合并为:
for step in xrange(1000000):
sess.run(train, {x: x_train, y: y_train})
if step % 1000 == 0:
saver.save(sess, 'my-model', global_step=step)
''' # ⑧ 检验模型
curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train})
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
'''
W: [ 2.00000167] b: [ 0.99999553] loss: 1.29603e-11
''' # ⑨ 使用模型预测数据
x_predict = [-1, 0, 1, 2]
predicted_values=sess.run(linear_model, feed_dict={x:x_predict})
# 注意这么一种写法: predicted_values = [(W*x + b).eval(session=sess) for x in x_predict]
print("result:", predicted_values)
'''
result: [-1.0000062 0.99999553 2.99999714 4.99999905]
''' # ⑩ 保存模型
tf.add_to_collection("predict_network", linear_model)
saver = tf.train.Saver()
saver_path=saver.save(sess, "save/model.ckpt") # ⑪ 使用Tensorboard的可视化功能
# 定义保存日志的路径
path = "log" # 也可写成: path = "./log"
writer=tf.summary.FileWriter(path, sess.graph) sess.close()

然后是载入模型的代码: restore_model.py

 import tensorflow as tf

 with tf.Session() as sess:
new_saver=tf.train.import_meta_graph("save/model.ckpt.meta")
new_saver.restore(sess,"save/model.ckpt")
# print(tf.get_collection("predict_network"))
restored_y=tf.get_collection("predict_network")[0] # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可 graph=tf.get_default_graph()
restored_x=graph.get_operation_by_name("input_x").outputs[0] predict_data = [-2, 3, 4]
predicted_result = sess.run(restored_y, feed_dict={restored_x:predict_data}) print("result:", predicted_result) # result: [-3.00000787 7.00000048 9.00000191]

用Tensorflow搭建神经网络的一般步骤的更多相关文章

  1. (转)一文学会用 Tensorflow 搭建神经网络

    一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day ...

  2. 一文学会用 Tensorflow 搭建神经网络

    http://www.jianshu.com/p/e112012a4b2d 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码 ...

  3. Tensorflow 搭建神经网络及tensorboard可视化

    1. session对话控制 matrix1 = tf.constant([[3,3]]) matrix2 = tf.constant([[2],[2]]) product = tf.matmul(m ...

  4. kaggle赛题Digit Recognizer:利用TensorFlow搭建神经网络(附上K邻近算法模型预测)

    一.前言 kaggle上有传统的手写数字识别mnist的赛题,通过分类算法,将图片数据进行识别.mnist数据集里面,包含了42000张手写数字0到9的图片,每张图片为28*28=784的像素,所以整 ...

  5. Tensorflow搭建神经网络及使用Tensorboard进行可视化

    创建神经网络模型 1.构建神经网络结构,并进行模型训练 import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt ...

  6. tensorflow搭建神经网络

    最简单的神经网络 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt date = np.linspa ...

  7. tensorflow搭建神经网络基本流程

    定义添加神经层的函数 1.训练的数据2.定义节点准备接收数据3.定义神经层:隐藏层和预测层4.定义 loss 表达式5.选择 optimizer 使 loss 达到最小 然后对所有变量进行初始化,通过 ...

  8. 基于tensorflow搭建一个神经网络

    一,tensorflow的简介 Tensorflow是一个采用数据流图,用于数值计算的 开源软件库.节点在图中表示数字操作,图中的线 则表示在节点间相互联系的多维数据数组,即张量 它灵活的架构让你可以 ...

  9. Tensorflow学习:(二)搭建神经网络

    一.神经网络的实现过程 1.准备数据集,提取特征,作为输入喂给神经网络       2.搭建神经网络结构,从输入到输出       3.大量特征数据喂给 NN,迭代优化 NN 参数       4.使 ...

随机推荐

  1. web性能测试

    在公司Confluence上看到一篇好文,原链接已不能访问.先收藏 转帖自:http://blog.csdn.net/wxq8102/article/details/1735726 1.1基本概念并发 ...

  2. linux服务基础(一)之CentOS6编译安装httpd2.4

    安装http-2.4 Http依赖于apr-1.4+,apr-util-1.4+ CentOS6上默认是apr-1.3,apr-util1.3 先编译安装apr-1.5,apr-util-1.5 开始 ...

  3. UVA1030 Image Is Everything

    思路 如果两个面看到颜色不同,则这个正方体一定要被删掉 然后依次考虑每个面即可 注意坐标的映射 代码 #include <cstdio> #include <algorithm> ...

  4. (Code) Python implementation of phrase extraction from sentence

    import os import numpy as np import pandas as pd from tqdm import tqdm import numpy as np import str ...

  5. 拿取页面值 跟拿取value里面的值

    拿取页面输入框的数值 使用  val() val()设置或返回表单字段的值 拿取value里面的数值 value(); attr() 获取属性值

  6. EF中打印出执行sql语句

    不用非得去 SQL Server Profiler 中查看了 方法如下: dbContext.Database.Log+=c=>Console.WriteLine(c)

  7. Date中before和after方法的使用

    Date1.after(Date2),当Date1大于Date2时,返回TRUE,当小于等于时,返回false: Date1.before(Date2),当Date1小于Date2时,返回TRUE,当 ...

  8. selenium+Headless Chrome实现不弹出浏览器自动化登录

    目前由于phantomjs已经不维护了,而新版的Chrome(59+)推出了Headless模式,对爬虫来说尤其是定时任务的爬虫截屏之类的是一大好事. 不过按照网络上的一些方法来写的话,会报下面的错误 ...

  9. SparkStreaming流处理

    一.Spark Streaming的介绍 1.       流处理 流式处理(Stream Processing).流式处理就是指源源不断的数据流过系统时,系统能够不停地连续计算.所以流式处理没有什么 ...

  10. lua 5.3.5 安装/初体验

    安装 官网http://www.lua.org/start.html 参考  https://blog.csdn.net/qq_23954569/article/details/70879672 cd ...