用Tensorflow搭建神经网络的一般步骤如下:

① 导入模块

② 创建模型变量和占位符

③ 建立模型

④ 定义loss函数

⑤ 定义优化器(optimizer), 使 loss 达到最小

⑥ 引入激活函数, 即添加非线性因素 (线性回归问题跳过此步骤)

⑦ 训练模型

⑧ 检验模型

⑨ 使用模型预测数据

⑩ 保存模型

⑪ 使用Tensorboard的可视化功能

下面以一个简单的线性回归问题为例:

首先是训练模型的代码: train_model.py

 # ① 导入模块
import tensorflow as tf # ② 创建模型的变量和占位符
W = tf.Variable([.3], dtype=tf.float32)
b = tf.Variable([-.3], dtype=tf.float32)
x = tf.placeholder(tf.float32, name="input_x")
y = tf.placeholder(tf.float32, name="input_y") # ③建立模型
linear_model = W*x + b
# 如果是矩阵相乘,可以写成:
# linear_model = tf.matmul(x, W)+b # matmul表示矩阵相乘 # ④ 定义loss函数
loss = tf.reduce_sum(tf.square(linear_model - y)) # ⑤ 定义优化器(optimizer), 使 loss 达到最小
learning_rate=0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
train = optimizer.minimize(loss) # ⑥ 引入激活函数, 即添加非线性因素。(线性回归问题跳过此步骤) # ⑦ 训练模型
# 假设模型是y=2x+1
x_train = [1, 2, 3, 4]
y_train = [3, 5, 7, 9] init = tf.global_variables_initializer() # 添加用于初始化变量的节点
sess = tf.Session()
sess.run(init) # 运行初始化操作
for step in range(1000):
sess.run(train, {x: x_train, y: y_train}) '''
第⑦步和第⑩步可以合并为:
for step in xrange(1000000):
sess.run(train, {x: x_train, y: y_train})
if step % 1000 == 0:
saver.save(sess, 'my-model', global_step=step)
''' # ⑧ 检验模型
curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train})
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
'''
W: [ 2.00000167] b: [ 0.99999553] loss: 1.29603e-11
''' # ⑨ 使用模型预测数据
x_predict = [-1, 0, 1, 2]
predicted_values=sess.run(linear_model, feed_dict={x:x_predict})
# 注意这么一种写法: predicted_values = [(W*x + b).eval(session=sess) for x in x_predict]
print("result:", predicted_values)
'''
result: [-1.0000062 0.99999553 2.99999714 4.99999905]
''' # ⑩ 保存模型
tf.add_to_collection("predict_network", linear_model)
saver = tf.train.Saver()
saver_path=saver.save(sess, "save/model.ckpt") # ⑪ 使用Tensorboard的可视化功能
# 定义保存日志的路径
path = "log" # 也可写成: path = "./log"
writer=tf.summary.FileWriter(path, sess.graph) sess.close()

然后是载入模型的代码: restore_model.py

 import tensorflow as tf

 with tf.Session() as sess:
new_saver=tf.train.import_meta_graph("save/model.ckpt.meta")
new_saver.restore(sess,"save/model.ckpt")
# print(tf.get_collection("predict_network"))
restored_y=tf.get_collection("predict_network")[0] # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可 graph=tf.get_default_graph()
restored_x=graph.get_operation_by_name("input_x").outputs[0] predict_data = [-2, 3, 4]
predicted_result = sess.run(restored_y, feed_dict={restored_x:predict_data}) print("result:", predicted_result) # result: [-3.00000787 7.00000048 9.00000191]

用Tensorflow搭建神经网络的一般步骤的更多相关文章

  1. (转)一文学会用 Tensorflow 搭建神经网络

    一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day ...

  2. 一文学会用 Tensorflow 搭建神经网络

    http://www.jianshu.com/p/e112012a4b2d 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码 ...

  3. Tensorflow 搭建神经网络及tensorboard可视化

    1. session对话控制 matrix1 = tf.constant([[3,3]]) matrix2 = tf.constant([[2],[2]]) product = tf.matmul(m ...

  4. kaggle赛题Digit Recognizer:利用TensorFlow搭建神经网络(附上K邻近算法模型预测)

    一.前言 kaggle上有传统的手写数字识别mnist的赛题,通过分类算法,将图片数据进行识别.mnist数据集里面,包含了42000张手写数字0到9的图片,每张图片为28*28=784的像素,所以整 ...

  5. Tensorflow搭建神经网络及使用Tensorboard进行可视化

    创建神经网络模型 1.构建神经网络结构,并进行模型训练 import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt ...

  6. tensorflow搭建神经网络

    最简单的神经网络 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt date = np.linspa ...

  7. tensorflow搭建神经网络基本流程

    定义添加神经层的函数 1.训练的数据2.定义节点准备接收数据3.定义神经层:隐藏层和预测层4.定义 loss 表达式5.选择 optimizer 使 loss 达到最小 然后对所有变量进行初始化,通过 ...

  8. 基于tensorflow搭建一个神经网络

    一,tensorflow的简介 Tensorflow是一个采用数据流图,用于数值计算的 开源软件库.节点在图中表示数字操作,图中的线 则表示在节点间相互联系的多维数据数组,即张量 它灵活的架构让你可以 ...

  9. Tensorflow学习:(二)搭建神经网络

    一.神经网络的实现过程 1.准备数据集,提取特征,作为输入喂给神经网络       2.搭建神经网络结构,从输入到输出       3.大量特征数据喂给 NN,迭代优化 NN 参数       4.使 ...

随机推荐

  1. 编程类-----matlab基础语法复习(2)

    2019年美赛准备:matlab基本题目运算 clear,clc %% 计算1/3 + 2/5 + ...3/7 +10/21 % i = 1; j = 3; ans = 0; % while i & ...

  2. java和js中int和String相互转换常用方法整理

    java中int和String的相互转换常用的几种方法: int  > String int i=10;String s="";第一种方法:s=i+""; ...

  3. caffe特征层可视化

    #参考1:https://blog.csdn.net/sushiqian/article/details/78614133#参考2:https://blog.csdn.net/thy_2014/art ...

  4. python3中argparse模块

    1.定义:argparse是python标准库里面用来处理命令行参数的库 2.命令行参数分为位置参数和选项参数:         位置参数就是程序根据该参数出现的位置来确定的              ...

  5. LGOJ P3834 【模板】可持久化线段树 1(主席树)

    代码 #include <cstdio> #include <iostream> #include <algorithm> using namespace std; ...

  6. Oh-My-Zsh及主题、插件安装与配置

    切换zsh Manjaro linux默认安装了zsh,其他可能需要先安装 cat /etc/shells #查看本地有哪几种shell chsh -s /bin/zsh #切换到zsh 默认终端启动 ...

  7. 在Linux(Debian)环境下搭建并运行GPU

    首先通过以下命令查看是否GPU驱动成功: 注意:需要在bash终端输入 import tensorflow as tf hello = tf.constant('Hello, TensorFlow!' ...

  8. loadrunner常用函数集锦

    一.三个复制函数的区别: strcpy 原型:extern char *strcpy(char *dest,char *src);用法:#i nclude功能:把src所指由NULL结束的字符串复制到 ...

  9. Tensorflow训练和预测中的BN层的坑

    以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...

  10. leecode第二百一十七题(存在重复元素)

    class Solution { public: bool containsDuplicate(vector<int>& nums) { set<int> s; for ...