深度学习原理与框架-Tensorflow卷积神经网络-神经网络mnist分类
使用tensorflow构造神经网络用来进行mnist数据集的分类
相比与上一节讲到的逻辑回归,神经网络比逻辑回归多了隐藏层,同时在每一个线性变化后添加了relu作为激活函数,
神经网络使用的损失值为softmax概率损失值,即为交叉熵损失值
代码:使用的是mnist数据集作为分类的测试数据,数据的维度为50000*784
第一步:载入mnist数据集
第二步:超参数的设置,输入图片的大小,分类的类别数,迭代的次数,每一个batch的大小
第三步:使用tf.placeholder() 进行输入数据的设置,进行数据的占位
第四步:使用tf.Variable() 设置里面设置tf.truncated_normal([inputSize, num_hidden], sttdv=0.1) 设置w的初始值,使用tf.Variable(tf.constant(0.1, [num_hidden])) 设置b,这一步主要是进行初始参数设置
第五步:使用tf.nn.relu(tf.matmul()+b) 构造第一层的网络,tf.nn.relu(tf.matmul() + b) 构造第二层的网络, tf.matmul() + b 构造输出层的得分
第六步:使用tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=score)) 来构造损失函数
第七步:使用tf.train.GradientDescentOptimizer().minimize(loss) 使用下降梯度降低损失值
第八步:使用tf.equal(tf.argmax(y, 1), tf.argmax(score, 1)) 即 tf.reduce_mean(tf.cast) 进行准确率的求解
第九步:进行循环,使用mnist.train.next_batch(batchSize) 读取部分数据
第十步:使用sess.run()执行梯度下降操作
第十一步:循环一千次,执行准确率的操作,并打印准确率
第十二步:使用验证集进行结果的验证
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data # 第一步数据读取
mnist = input_data.read_data_sets('/data', one_hot=True) # 第二步:超参数的设置
# 输入图片的大小
inputSize = 784
# 分类的类别数
num_classes = 10
# 隐藏层的个数
num_hidden = 50
# 迭代的次数
trainIteration = 10000
# 每一个batch值的大小
batch_size = 100
# 第二个隐藏层的个数
num_hidden_2 = 100 # 第三步:使用tf.placeholder()构造输入数据X 和 y
X = tf.placeholder(tf.float32, shape=[None, inputSize])
y = tf.placeholder(tf.float32, shape=[None, num_classes]) #第四步:初始化W和b参数
W1 = tf.Variable(tf.random_normal([inputSize, num_hidden], stddev=0.1), name='W1')
b1 = tf.Variable(tf.zeros([num_hidden]), name='b1')
# b1 = tf.Variable(tf.constant(0.1), [num_hidden])
W2 = tf.Variable(tf.random_normal([num_hidden, num_hidden_2], stddev=0.1), name='W2')
# b2 = tf.Variable(tf.constant(0.1), [num_classes])
b2 = tf.Variable(tf.zeros([num_hidden_2]), name='b2')
W3 = tf.Variable(tf.random_normal([num_hidden_2, num_classes], stddev=0.1), name='W3')
b3 = tf.Variable(tf.zeros([num_classes])) # 第五步:使用点乘获得第一层,第二层和最后的得分值
h1 = tf.nn.relu(tf.matmul(X, W1) + b1)
h2 = tf.nn.relu(tf.matmul(h1, W2) + b2)
y_pred = tf.matmul(h2, W3) + b3
# 第六步:使用tf.nn.softmax_cross_entropy_with_logits计算交叉熵损失值
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pred))
# 第七步:使用tf.train.GradientDescentOptimizer降低损失值loss
opt = tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss)
# 第八步:使用tf.argmax(y_pred, 1)找出每一行的最大值索引,tf.equal判断索引是否相等
correct_pred = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
# tf.cast将索引转换为float类型,使用tf.reduce_mean求均值
accr = tf.reduce_mean(tf.cast(correct_pred, 'float')) # 进行初始化
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init) for i in range(trainIteration):
# 第九步:获得一个batch的数据
batch = mnist.train.next_batch(batch_size)
# 第十步:使用sess.run()对损失值降低操作和损失值进行执行,从而进行参数更新
_, data_loss = sess.run([opt, loss], feed_dict={X: batch[0], y: batch[1]})
# 第十一步:每迭代1000次就进行打印准确率和损失值
if i % 1000 == 0:
accurracy = sess.run(accr, feed_dict={X: batch[0], y: batch[1]})
print('loss: %g accr: %g' % (data_loss, accurracy))
# 第十二步:使用测试数据进行训练结果的验证
batch = mnist.test.next_batch(batch_size)
accurracy = sess.run(accr, feed_dict={X: batch[0], y: batch[1]})
print('test accr %g'%(accurracy))
深度学习原理与框架-Tensorflow卷积神经网络-神经网络mnist分类的更多相关文章
- 深度学习原理与框架-Tensorflow卷积神经网络-cifar10图片分类(代码) 1.tf.nn.lrn(局部响应归一化操作) 2.random.sample(在列表中随机选值) 3.tf.one_hot(对标签进行one_hot编码)
1.tf.nn.lrn(pool_h1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # 局部响应归一化,使用相同位置的前后的filter进行响应归一化操作 参数 ...
- 深度学习原理与框架-Tensorflow卷积神经网络-卷积神经网络mnist分类 1.tf.nn.conv2d(卷积操作) 2.tf.nn.max_pool(最大池化操作) 3.tf.nn.dropout(执行dropout操作) 4.tf.nn.softmax_cross_entropy_with_logits(交叉熵损失) 5.tf.truncated_normal(两个标准差内的正态分布)
1. tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') # 对数据进行卷积操作 参数说明:x表示输入数据,w表示卷积核, stride ...
- 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)
1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...
- 深度学习原理与框架-Tensorflow基本操作-实现线性拟合
代码:使用tensorflow进行数据点的线性拟合操作 第一步:使用np.random.normal生成正态分布的数据 第二步:将数据分为X_data 和 y_data 第三步:对参数W和b, 使用t ...
- 深度学习原理与框架-Tensorflow基本操作-变量常用操作 1.tf.random_normal(生成正态分布随机数) 2.tf.random_shuffle(进行洗牌操作) 3. tf.assign(赋值操作) 4.tf.convert_to_tensor(转换为tensor类型) 5.tf.add(相加操作) tf.divide(相乘操作) 6.tf.placeholder(输入数据占位
1. 使用tf.random_normal([2, 3], mean=-1, stddev=4) 创建一个正态分布的随机数 参数说明:[2, 3]表示随机数的维度,mean表示平均值,stddev表示 ...
- 深度学习原理与框架-Tensorflow基本操作-Tensorflow中的变量
1.tf.Variable([[1, 2]]) # 创建一个变量 参数说明:[[1, 2]] 表示输入的数据,为一行二列的数据 2.tf.global_variables_initializer() ...
- 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)
1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...
- 深度学习原理与框架-卷积神经网络-cifar10分类(图片分类代码) 1.数据读入 2.模型构建 3.模型参数训练
卷积神经网络:下面要说的这个网络,由下面三层所组成 卷积网络:卷积层 + 激活层relu+ 池化层max_pool组成 神经网络:线性变化 + 激活层relu 神经网络: 线性变化(获得得分值) 代码 ...
- 深度学习原理与框架-卷积神经网络基本原理 1.卷积层的前向传播 2.卷积参数共享 3. 卷积后的维度计算 4. max池化操作 5.卷积流程图 6.卷积层的反向传播 7.池化层的反向传播
卷积神经网络的应用:卷积神经网络使用卷积提取图像的特征来进行图像的分类和识别 分类 相似图像搜索 ...
随机推荐
- go学习day1
go语言特性 1.垃圾回收 a.内存自动回收,再也不需要开发人员管理内存 b.开发人员专注业务实现,降低了心智负担 c.只需要new分配内存,不需要释放 2.天然并发 a.从语言层面支持并发,非常简单 ...
- 【C#】语音识别 - System.Speech
一个有趣的东西,今后可能用得上. C#语音识别:在命名空间 System.Speech下SpeechSynthesizer可以将文字转换成语音 贴出代码: public partial class F ...
- P3811 乘法逆元
传送 乘法逆元:ax ≡ 1 (mod p),其中x为a的逆元,求模意义下的乘法逆元,通常有一下几种方法: 1.拓展欧几里得(也就是exgcd) ax ≡ 1 (mod p) ax-py=1 这就变成 ...
- centos6.5网络虚拟化技术
一.配置KVM虚拟机NAT网络 1.创建脚本执行权限 下面是NAT启动脚本 # vi /etc/qemu-ifup-NAT 赋予权限 # chmod +x /etc/qemu-ifup-NAT 下载镜 ...
- python之路——6
王二学习python的笔记以及记录,如有雷同,那也没事,欢迎交流,wx:wyb199594 复习 增dic['age'] = 21 dic.setfault()删popcleardel popitem ...
- python之路——2
王二学习python的笔记以及记录,如有雷同,那也没事,欢迎交流,wx:wyb199594 复习 1.编译型:一次性将全部的代码编译成二进制文件 c c++ 优点:运行效率高 缺点:开发速度慢,不能跨 ...
- 视角同步NewViewTarget
SetViewTargetwithBlen说明: http://api.unrealengine.com/INT/BlueprintAPI/Game/Player/SetViewTargetwithB ...
- 网站简介-为什么网站的ICO图标更新后,ie浏览器没有更新过来?
为什么网站的ICO图标更新后,ie浏览器没有更新过来? 如何更新本地ico图标? 收藏夹里的网址访问后网站ico小图标怎么不会更新,还是没图标的. 如果制作了一个新的favicon.ico图标,并且已 ...
- Android短信收发(二)
接收SMS类,代码如下 //for receive SMS private SmsReceiver mSmsReceiver; @Override protected void onResume() ...
- virtual box + win7 + usb + share folder
1.enable virtaulization on BIOS 2.new machine setup, memory, harddisk size 3. 4.install extension pa ...