使用之前那个格式写法到后面层数多的话会很乱,所以编写了一个函数创建层,这样看起来可读性高点也更方便整理后期修改维护

#全连接层函数

def fcn_layer(
inputs, #输入数据
input_dim, #输入层神经元数量
output_dim,#输出层神经元数量
activation =None): #激活函数 W = tf.Variable(tf.truncated_normal([input_dim,output_dim],stddev = 0.1))
#以截断正态分布的随机初始化W
b = tf.Variable(tf.zeros([output_dim]))
#以0初始化b
XWb = tf.matmul(inputs,W)+b # Y=WX+B if(activation==None): #默认不使用激活函数
outputs =XWb
else:
outputs = activation(XWb) #代入参数选择的激活函数
return outputs #返回
#各层神经元数量设置
H1_NN = 256
H2_NN = 64
H3_NN = 32 #构建输入层
x = tf.placeholder(tf.float32,[None,784],name='X')
y = tf.placeholder(tf.float32,[None,10],name='Y')
#构建隐藏层
h1 = fcn_layer(x,784,H1_NN,tf.nn.relu)
h2 = fcn_layer(h1,H1_NN,H2_NN,tf.nn.relu)
h3 = fcn_layer(h2,H2_NN,H3_NN,tf.nn.relu)
#构建输出层
forward = fcn_layer(h3,H3_NN,10,None)
pred = tf.nn.softmax(forward)#输出层分类应用使用softmax当作激活函数

这样写方便后期维护 不必对着一群 W1 W2..... Wn

接下来记录一下保存模型的方法

#保存模型
save_step = 5 #储存模型力度
import os
ckpt_dir = '.ckpt_dir/'
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)

  5轮训练保存一次,以后大模型可以调高点,接下来需要在模型整合处修改一下

saver = tf.train.Saver() #声明完所有变量以后,调用tf.train.Saver开始记录

if(epochs+1) % save_step == 0:
  saver.save(sess, os.path.join(ckpt_dir,"mnist_h256_model_{:06d}.ckpt".format(epochs+1)))#储存模型
  print("mnist_h256_model_{:06d}.ckpt saved".format(epochs+1))#输出情况

至此储存模型结束

接下来是还原模型,要注意还原的模型层数和神经元数量大小需要和之前储存模型的大小一致。

第一步设置保存模型文件的路径

#必须指定存储位置
ckpt_dir = "/ckpt_dir/"

存盘只会保存最近的5次,恢复会恢复最新那一份

#恢复模型,创建会话

saver = tf.train.Saver()

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init) ckpt = tf.train.get_checkpoint_state(ckpt_dir)#选择模型保存路径
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess ,ckpt.model_checkpoint_path)#从已保存模型中读取参数
print("Restore model from"+ckpt.model_checkpoint_path)

 至此模型恢复完成 下面可以选择继续训练或者评估使用

最后附上完整代码

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
from time import time
mnist = input_data.read_data_sets("data/",one_hot = True)
#导入Tensorflwo和mnist数据集等 常用库
#全连接层函数 def fcn_layer(
inputs, #输入数据
input_dim, #输入层神经元数量
output_dim,#输出层神经元数量
activation =None): #激活函数 W = tf.Variable(tf.truncated_normal([input_dim,output_dim],stddev = 0.1))
#以截断正态分布的随机初始化W
b = tf.Variable(tf.zeros([output_dim]))
#以0初始化b
XWb = tf.matmul(inputs,W)+b # Y=WX+B if(activation==None): #默认不使用激活函数
outputs =XWb
else:
outputs = activation(XWb) #代入参数选择的激活函数
return outputs #返回
#各层神经元数量设置
H1_NN = 256
H2_NN = 64
H3_NN = 32 #构建输入层
x = tf.placeholder(tf.float32,[None,784],name='X')
y = tf.placeholder(tf.float32,[None,10],name='Y')
#构建隐藏层
h1 = fcn_layer(x,784,H1_NN,tf.nn.relu)
h2 = fcn_layer(h1,H1_NN,H2_NN,tf.nn.relu)
h3 = fcn_layer(h2,H2_NN,H3_NN,tf.nn.relu)
#构建输出层
forward = fcn_layer(h3,H3_NN,10,None)
pred = tf.nn.softmax(forward)#输出层分类应用使用softmax当作激活函数
#损失函数使用交叉熵
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = forward,labels = y))
#设置训练参数
train_epochs = 50
batch_size = 50
total_batch = int(mnist.train.num_examples/batch_size) #随机抽取样本
learning_rate = 0.01
display_step = 1
#优化器
opimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function)
#定义准确率
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(pred,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
#保存模型
save_step = 5 #储存模型力度
import os
ckpt_dir = '.ckpt_dir/'
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
#开始训练
sess = tf.Session()
init = tf.global_variables_initializer()
saver = tf.train.Saver() #声明完所有变量以后,调用tf.train.Saver开始记录
startTime = time()
sess.run(init)
for epochs in range(train_epochs):
for batch in range(total_batch):
xs,ys = mnist.train.next_batch(batch_size)#读取批次数据
sess.run(opimizer,feed_dict={x:xs,y:ys})#执行批次数据训练 #total_batch个批次训练完成后,使用验证数据计算误差与准确率
loss,acc = sess.run([loss_function,accuracy],
feed_dict={
x:mnist.validation.images,
y:mnist.validation.labels})
#输出训练情况
if(epochs+1) % display_step == 0:
epochs += 1
print("Train Epoch:",epochs,
"Loss=",loss,"Accuracy=",acc)
if(epochs+1) % save_step == 0:
saver.save(sess, os.path.join(ckpt_dir,"mnist_h256_model_{:06d}.ckpt".format(epochs+1)))
print("mnist_h256_model_{:06d}.ckpt saved".format(epochs+1))
duration = time()-startTime
print("Trian Finshed takes:","{:.2f}".format(duration))#显示预测耗时
#评估模型
accu_test = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("model accuracy:",accu_test)
#恢复模型,创建会话 saver = tf.train.Saver() sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init) ckpt = tf.train.get_checkpoint_state(ckpt_dir)#选择模型保存路径
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess ,ckpt.model_checkpoint_path)#从已保存模型中读取参数
print("Restore model from"+ckpt.model_checkpoint_path)

完整代码

  

基于tensorflow使用全连接层函数实现多层神经网络并保存和读取模型的更多相关文章

  1. 深度学习原理与框架-卷积网络细节-图像分类与图像位置回归任务 1.模型加载 2.串接新的全连接层 3.使用SGD梯度对参数更新 4.模型结果测试 5.各个模型效果对比

    对于图像的目标检测任务:通常分为目标的类别检测和目标的位置检测 目标的类别检测使用的指标:准确率, 预测的结果是类别值,即cat 目标的位置检测使用的指标:欧式距离,预测的结果是(x, y, w, h ...

  2. 基于tensorflow实现mnist手写识别 (多层神经网络)

    标题党其实也不多,一个输入层,三个隐藏层,一个输出层 老样子先上代码 导入mnist的路径很长,现在还记不住 import tensorflow as tf import tensorflow.exa ...

  3. tensorflow 1.0 学习:池化层(pooling)和全连接层(dense)

    池化层定义在 tensorflow/python/layers/pooling.py. 有最大值池化和均值池化. 1.tf.layers.max_pooling2d max_pooling2d( in ...

  4. tensorflow 添加一个全连接层

    对于一个全连接层,tensorflow都为我们封装好了. 使用:tf.layers.dense() tf.layers.dense( inputs, units, activation=None, u ...

  5. keras channels_last、preprocess_input、全连接层Dense、SGD优化器、模型及编译

    channels_last 和 channels_first keras中 channels_last 和 channels_first 用来设定数据的维度顺序(image_data_format). ...

  6. resnet18全连接层改成卷积层

    想要尝试一下将resnet18最后一层的全连接层改成卷积层看会不会对网络效果和网络大小有什么影响 1.首先先对train.py中的更改是: train.py代码可见:pytorch实现性别检测 # m ...

  7. Caffe源码阅读(1) 全连接层

    Caffe源码阅读(1) 全连接层 发表于 2014-09-15   |   今天看全连接层的实现.主要看的是https://github.com/BVLC/caffe/blob/master/src ...

  8. 深度学习基础系列(十)| Global Average Pooling是否可以替代全连接层?

    Global Average Pooling(简称GAP,全局池化层)技术最早提出是在这篇论文(第3.2节)中,被认为是可以替代全连接层的一种新技术.在keras发布的经典模型中,可以看到不少模型甚至 ...

  9. TensorFlow------单层(全连接层)实现手写数字识别训练及测试实例

    TensorFlow之单层(全连接层)实现手写数字识别训练及测试实例: import tensorflow as tf from tensorflow.examples.tutorials.mnist ...

随机推荐

  1. spark-机器学习实践-K近邻应用实践一

    K近邻应用-异常检测应用 原理: 根据数据样本进行KMeans机器学习模型的建立,获取簇心点,以簇为单位,离簇心最远的第五个点的距离为阈值,大于这个值的为异常点,即获得数据异常. 如图:

  2. C++第七次作业

    关于计算器项目的总结: 一.就目前完成的计算器,包括界面的实现这部分,总体实现了简单计算的功能,但仍有很多不足之处: 需改进完善之处:1.关于界面可再优化: 2.界面放大时,无法自动聚焦(按钮等控件无 ...

  3. react-navigation 使用笔记 持续更新中

    目录 基本使用(此处基本使用仅针对导航头部而言,不包含tabbar等) header怎么和app中通信呢? React-Navigation是目前React-Native官方推荐的导航组件,代替了原用 ...

  4. 【洛谷】【前缀和+st表】P2629 好消息,坏消息

    [题目描述:] uim在公司里面当秘书,现在有n条消息要告知老板.每条消息有一个好坏度,这会影响老板的心情.告知完一条消息后,老板的心情等于之前老板的心情加上这条消息的好坏度.最开始老板的心情是0,一 ...

  5. SQL必知必会摘要

    数据检索 2.2 检索单个列 SELECT prod_name FROM Products; SQL语句不区分大小写   2.3 检索多个列 SELECT prod_name,prod_id,prod ...

  6. PAT B1018 锤子剪刀布 (20 分)

    大家应该都会玩“锤子剪刀布”的游戏:两人同时给出手势,胜负规则如图所示: 现给出两人的交锋记录,请统计双方的胜.平.负次数,并且给出双方分别出什么手势的胜算最大. 输入格式: 输入第 1 行给出正整数 ...

  7. Docker学习5-Services – 服务(未完待续)

    扩展应用程序并启用负载平衡, 为此,必须在分布式应用程序的层次结构中提升一级:服务.在分布式应用程序中,应用程序的不同部分称为“服务”.例如,一个视频共享站点,它可能包含用于将应用程序数据存储在数据库 ...

  8. JavaScript基础注意点

    1.每个语句结尾一定加上分号 2.JavaScript本身对嵌套{ }的层级没有限制,但是过多的嵌套无疑会大大增加看懂代码的难度.遇到这种情况,需要把部分代码抽出来,作为函数来调用,这样可以减少代码的 ...

  9. Fiddler抓包调试前端脚本代码

    0.写在前面的话 之前看了阮一峰老师关于互联网协议入门的博客,受益匪浅,接着再去体会了下HTTP协议,就想着看实际网络访问中的那些HTTP请求头和响应是什么样的.Chrome的调试工具的Network ...

  10. 使用clipboard插件实现div、textarea、input里面的内容复制到粘贴板

    一.引用clipboard的js文件 二.编写代码.data-clipboard-action=“copy”,代表要执行的动作是复制.data-clipboard-target里面要是要选择复制的元素 ...