使用tensorflow构造神经网络用来进行mnist数据集的分类

相比与上一节讲到的逻辑回归,神经网络比逻辑回归多了隐藏层,同时在每一个线性变化后添加了relu作为激活函数,

神经网络使用的损失值为softmax概率损失值,即为交叉熵损失值

代码:使用的是mnist数据集作为分类的测试数据,数据的维度为50000*784

第一步:载入mnist数据集

第二步:超参数的设置,输入图片的大小,分类的类别数,迭代的次数,每一个batch的大小

第三步:使用tf.placeholder() 进行输入数据的设置,进行数据的占位

第四步:使用tf.Variable() 设置里面设置tf.truncated_normal([inputSize, num_hidden], sttdv=0.1) 设置w的初始值,使用tf.Variable(tf.constant(0.1, [num_hidden])) 设置b,这一步主要是进行初始参数设置

第五步:使用tf.nn.relu(tf.matmul()+b) 构造第一层的网络,tf.nn.relu(tf.matmul() + b) 构造第二层的网络, tf.matmul() + b 构造输出层的得分

第六步:使用tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=score)) 来构造损失函数

第七步:使用tf.train.GradientDescentOptimizer().minimize(loss) 使用下降梯度降低损失值

第八步:使用tf.equal(tf.argmax(y, 1), tf.argmax(score, 1)) 即 tf.reduce_mean(tf.cast) 进行准确率的求解

第九步:进行循环,使用mnist.train.next_batch(batchSize) 读取部分数据

第十步:使用sess.run()执行梯度下降操作

第十一步:循环一千次,执行准确率的操作,并打印准确率

第十二步:使用验证集进行结果的验证

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data # 第一步数据读取
mnist = input_data.read_data_sets('/data', one_hot=True) # 第二步:超参数的设置
# 输入图片的大小
inputSize = 784
# 分类的类别数
num_classes = 10
# 隐藏层的个数
num_hidden = 50
# 迭代的次数
trainIteration = 10000
# 每一个batch值的大小
batch_size = 100
# 第二个隐藏层的个数
num_hidden_2 = 100 # 第三步:使用tf.placeholder()构造输入数据X 和 y
X = tf.placeholder(tf.float32, shape=[None, inputSize])
y = tf.placeholder(tf.float32, shape=[None, num_classes]) #第四步:初始化W和b参数
W1 = tf.Variable(tf.random_normal([inputSize, num_hidden], stddev=0.1), name='W1')
b1 = tf.Variable(tf.zeros([num_hidden]), name='b1')
# b1 = tf.Variable(tf.constant(0.1), [num_hidden])
W2 = tf.Variable(tf.random_normal([num_hidden, num_hidden_2], stddev=0.1), name='W2')
# b2 = tf.Variable(tf.constant(0.1), [num_classes])
b2 = tf.Variable(tf.zeros([num_hidden_2]), name='b2')
W3 = tf.Variable(tf.random_normal([num_hidden_2, num_classes], stddev=0.1), name='W3')
b3 = tf.Variable(tf.zeros([num_classes])) # 第五步:使用点乘获得第一层,第二层和最后的得分值
h1 = tf.nn.relu(tf.matmul(X, W1) + b1)
h2 = tf.nn.relu(tf.matmul(h1, W2) + b2)
y_pred = tf.matmul(h2, W3) + b3
# 第六步:使用tf.nn.softmax_cross_entropy_with_logits计算交叉熵损失值
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pred))
# 第七步:使用tf.train.GradientDescentOptimizer降低损失值loss
opt = tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss)
# 第八步:使用tf.argmax(y_pred, 1)找出每一行的最大值索引,tf.equal判断索引是否相等
correct_pred = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
# tf.cast将索引转换为float类型,使用tf.reduce_mean求均值
accr = tf.reduce_mean(tf.cast(correct_pred, 'float')) # 进行初始化
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init) for i in range(trainIteration):
# 第九步:获得一个batch的数据
batch = mnist.train.next_batch(batch_size)
# 第十步:使用sess.run()对损失值降低操作和损失值进行执行,从而进行参数更新
_, data_loss = sess.run([opt, loss], feed_dict={X: batch[0], y: batch[1]})
# 第十一步:每迭代1000次就进行打印准确率和损失值
if i % 1000 == 0:
accurracy = sess.run(accr, feed_dict={X: batch[0], y: batch[1]})
print('loss: %g accr: %g' % (data_loss, accurracy))
# 第十二步:使用测试数据进行训练结果的验证
batch = mnist.test.next_batch(batch_size)
accurracy = sess.run(accr, feed_dict={X: batch[0], y: batch[1]})
print('test accr %g'%(accurracy))

深度学习原理与框架-Tensorflow卷积神经网络-神经网络mnist分类的更多相关文章

  1. 深度学习原理与框架-Tensorflow卷积神经网络-cifar10图片分类(代码) 1.tf.nn.lrn(局部响应归一化操作) 2.random.sample(在列表中随机选值) 3.tf.one_hot(对标签进行one_hot编码)

    1.tf.nn.lrn(pool_h1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # 局部响应归一化,使用相同位置的前后的filter进行响应归一化操作 参数 ...

  2. 深度学习原理与框架-Tensorflow卷积神经网络-卷积神经网络mnist分类 1.tf.nn.conv2d(卷积操作) 2.tf.nn.max_pool(最大池化操作) 3.tf.nn.dropout(执行dropout操作) 4.tf.nn.softmax_cross_entropy_with_logits(交叉熵损失) 5.tf.truncated_normal(两个标准差内的正态分布)

    1. tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')  # 对数据进行卷积操作 参数说明:x表示输入数据,w表示卷积核, stride ...

  3. 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)

    1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...

  4. 深度学习原理与框架-Tensorflow基本操作-实现线性拟合

    代码:使用tensorflow进行数据点的线性拟合操作 第一步:使用np.random.normal生成正态分布的数据 第二步:将数据分为X_data 和 y_data 第三步:对参数W和b, 使用t ...

  5. 深度学习原理与框架-Tensorflow基本操作-变量常用操作 1.tf.random_normal(生成正态分布随机数) 2.tf.random_shuffle(进行洗牌操作) 3. tf.assign(赋值操作) 4.tf.convert_to_tensor(转换为tensor类型) 5.tf.add(相加操作) tf.divide(相乘操作) 6.tf.placeholder(输入数据占位

    1. 使用tf.random_normal([2, 3], mean=-1, stddev=4) 创建一个正态分布的随机数 参数说明:[2, 3]表示随机数的维度,mean表示平均值,stddev表示 ...

  6. 深度学习原理与框架-Tensorflow基本操作-Tensorflow中的变量

    1.tf.Variable([[1, 2]])  # 创建一个变量 参数说明:[[1, 2]] 表示输入的数据,为一行二列的数据 2.tf.global_variables_initializer() ...

  7. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

  8. 深度学习原理与框架-卷积神经网络-cifar10分类(图片分类代码) 1.数据读入 2.模型构建 3.模型参数训练

    卷积神经网络:下面要说的这个网络,由下面三层所组成 卷积网络:卷积层 + 激活层relu+ 池化层max_pool组成 神经网络:线性变化 + 激活层relu 神经网络: 线性变化(获得得分值) 代码 ...

  9. 深度学习原理与框架-卷积神经网络基本原理 1.卷积层的前向传播 2.卷积参数共享 3. 卷积后的维度计算 4. max池化操作 5.卷积流程图 6.卷积层的反向传播 7.池化层的反向传播

    卷积神经网络的应用:卷积神经网络使用卷积提取图像的特征来进行图像的分类和识别       分类                        相似图像搜索                        ...

随机推荐

  1. bzoj5011: [Jx2017]颜色

    Description 可怜有一个长度为n的正整数序列Ai,其中相同的正整数代表着相同的颜色. 现在可怜觉得这个序列太长了,于是她决定选择一些颜色把这些颜色的所有位置都删去. 删除颜色i可以定义为把所 ...

  2. 2.Linux技能要求

    Linux嵌入式工程师技能要求: 1.C语言                    具备C语言基础.理解C语言基础编程及高级编程,包括:数据类型.数组.指针.结构体.链表.文件操作.队列.栈.     ...

  3. python之路——2

    王二学习python的笔记以及记录,如有雷同,那也没事,欢迎交流,wx:wyb199594 复习 1.编译型:一次性将全部的代码编译成二进制文件 c c++ 优点:运行效率高 缺点:开发速度慢,不能跨 ...

  4. Android应用启动会白屏一下的解决办法

    设置透明样式,如下:<activity android:name="com.hongfans.cvi.ui.MainActivity" android:configChang ...

  5. MySql 索引优化实例

    查询语句 SELECT customer_id,title,content FROM `product_comment` WHERE audit_status=1 AND product_id=199 ...

  6. 清除win7桌面背景的图片位置下拉菜单的历史记录

    到注册表 清除win7桌面背景的图片位置下拉菜单的历史记录: 开始--->运行--->输入regedit,在弹出的注册表编辑器中,定位到如下位置 HKEY_CURRENT_USER\Sof ...

  7. Windows安装启动MySQL

    Win安装MySQL数据库将下载下来的mysql解压到指定目录下,cmd切换到bin目录下>mysqld -install 安装服务>mysqld -remove 卸载服务>net ...

  8. noteforjs

    轻量高效的开源JavaScript插件和库---<!-- TOC --> - [图片](#图片)- [布局](#布局)- [轮播图](#轮播图)- [弹出层](#弹出层)- [音频视频]( ...

  9. 移动端动态font-size

    /** * Created by shimin on 2017/8/18. *///计算dpr!function(win, lib) { var timer, doc = win.document, ...

  10. 10进制与16进制之间的转换 delphi

    delphi中有直接把10进制转换成16进制的函数: function   IntToHex(Value:   Integer;   Digits:   Integer):   string;   o ...