基于MNIST数据集使用TensorFlow训练一个包含一个隐含层的全连接神经网络
包含一个隐含层的全连接神经网络结构如下:

包含一个隐含层的神经网络结构图
以MNIST数据集为例,以上结构的神经网络训练如下:
#coding=utf-8
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf # 加载数据
mnist = input_data.read_data_sets('/home/workspace/python/tf/data/mnist', one_hot=True)
"""
# 创建模型
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b
"""
x = tf.placeholder(tf.float32, [None, 784])
W1 = tf.Variable(tf.truncated_normal([784, 500], stddev=0.1))
b1 = tf.Variable(tf.zeros([500]))
W2 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
b2 = tf.Variable(tf.zeros([10]))
layer1 = tf.nn.relu(tf.matmul(x, W1) + b1)
y = tf.matmul(layer1, W2) + b2 # 正确的样本标签
y_ = tf.placeholder(tf.float32, [None, 10]) # 损失函数选择softmax后的交叉熵,结果作为y的输出
cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) sess = tf.InteractiveSession()
tf.global_variables_initializer().run() # 训练过程
for _ in range(5000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
if _%1000 == 0:
# 使用测试集评估准确率
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print (sess.run(accuracy, feed_dict = {x: mnist.test.images,
y_: mnist.test.labels}))
注意:权重向量初始化时使用tf.truncated_normal,而不要使用tf.zeros
以上代码大概能得到97.98%的准确率。
软件版本
TensorFlow 1.0.1 + Python 2.7.12
基于MNIST数据集使用TensorFlow训练一个包含一个隐含层的全连接神经网络的更多相关文章
- 基于MNIST数据集使用TensorFlow训练一个没有隐含层的浅层神经网络
基础 在参考①中我们详细介绍了没有隐含层的神经网络结构,该神经网络只有输入层和输出层,并且输入层和输出层是通过全连接方式进行连接的.具体结构如下: 我们用此网络结构基于MNIST数据集(参考②)进行训 ...
- TensorFlow之DNN(二):全连接神经网络的加速技巧(Xavier初始化、Adam、Batch Norm、学习率衰减与梯度截断)
在上一篇博客<TensorFlow之DNN(一):构建“裸机版”全连接神经网络>中,我整理了一个用TensorFlow实现的简单全连接神经网络模型,没有运用加速技巧(小批量梯度下降不算哦) ...
- tensorflow中使用mnist数据集训练全连接神经网络-学习笔记
tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...
- 【TensorFlow/简单网络】MNIST数据集-softmax、全连接神经网络,卷积神经网络模型
初学tensorflow,参考了以下几篇博客: soft模型 tensorflow构建全连接神经网络 tensorflow构建卷积神经网络 tensorflow构建卷积神经网络 tensorflow构 ...
- 深度学习tensorflow实战笔记(1)全连接神经网络(FCN)训练自己的数据(从txt文件中读取)
1.准备数据 把数据放进txt文件中(数据量大的话,就写一段程序自己把数据自动的写入txt文件中,任何语言都能实现),数据之间用逗号隔开,最后一列标注数据的标签(用于分类),比如0,1.每一行表示一个 ...
- TensorFlow之DNN(一):构建“裸机版”全连接神经网络
博客断更了一周,干啥去了?想做个聊天机器人出来,去看教程了,然后大受打击,哭着回来补TensorFlow和自然语言处理的基础了.本来如意算盘打得挺响,作为一个初学者,直接看项目(不是指MINIST手写 ...
- Tensorflow 多层全连接神经网络
本节涉及: 身份证问题 单层网络的模型 多层全连接神经网络 激活函数 tanh 身份证问题新模型的代码实现 模型的优化 一.身份证问题 身份证号码是18位的数字[此处暂不考虑字母的情况],身份证倒数第 ...
- caffe中全卷积层和全连接层训练参数如何确定
今天来仔细讲一下卷基层和全连接层训练参数个数如何确定的问题.我们以Mnist为例,首先贴出网络配置文件: name: "LeNet" layer { name: "mni ...
- MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...
随机推荐
- js禁止页面滚动
开发移动端页面的时候有一个很比较常见的需求,在出现弹窗时,禁止滑动弹窗后面的主体页面.如何实现呢,往下看 js实现整个页面禁止滚动: document.body.addEventListener('t ...
- Git:一个简单示例
初始状态:两个分支master/dev都只有一个文件readme.txt 待解决问题:在master分支新增文件,并且修改readme.txt文件,将上述操作同步至远程master分支,最后同步到de ...
- makefile中 = := += 的区别
= 是最基本的赋值 := 是覆盖之前的值 ?= 是如果没有被赋值过就赋予等号后面的值 += 是添加等号后面的值 1.“=” make会将整个makefile展开后,再决定变量的值.也就是说,变量的值将 ...
- javascript 一些特殊的写法
数组+数组: ["f", "o", "o"]+[] 执行结果:"f,o,o" ["f", " ...
- LL(1),LR(0),SLR(1),LR(1),LALR(1)的 联系与区别
一:LR(0),SLR(1),规范LR(1),LALR(1)的关系 首先LL(1)分析法是自上而下的分析法.LR(0),LR(1),SLR(1),LALR(1)是自下而上的分析法. ...
- mongodb安装建议
1)软件包的选择 确保使用最新的稳定版本.目前我们线上使用的版本是2.4.6.MongoDB软件包下载页面http://www.mongodb.org/downloads. 确保线上环境总是使用64位 ...
- go web framework gin middleware 设计原理
场景:一个middleware可以具体为一个函数,而由前面的gin 路由分析可得,每一个路径都对有一个HandlersChain 与其对应. 那么实际上增加一个middleware的过程,就是将每一个 ...
- 【IDEA&&Eclipse】1、为何 IntelliJ IDEA 比 Eclipse 更适合于专业java开发者
圣战 有一些没有唯一正确答案的“永恒”的问题,例如哪个更好:是Windows还是Linux,Java还是C#:谁更强壮:Chuck Norris还是Van Damme. 其中的一个圣战便是Java I ...
- wav文件格式分析与详解
WAV文件是在PC机平台上很常见的.最经典的多媒体音频文件,最早于1991年8月出现在Windows 3.1操作系统上,文件扩展名为WAV,是WaveFom的简写,也称为波形文件,可直接存储声音波形, ...
- 关于attibutedText输出中文字符后的英文和数字进行分开解析的问题
上面的图应该很清楚 具体这个attibutedText 是做什么的就不说了 ,最初我查了资料发现有人和我一样的输出,把一个字符串的中英文分开打印出来是iOS关于UItextVIew和UIlabel的差 ...