使用tensorflow构造神经网络用来进行mnist数据集的分类

相比与上一节讲到的逻辑回归,神经网络比逻辑回归多了隐藏层,同时在每一个线性变化后添加了relu作为激活函数,

神经网络使用的损失值为softmax概率损失值,即为交叉熵损失值

代码:使用的是mnist数据集作为分类的测试数据,数据的维度为50000*784

第一步:载入mnist数据集

第二步:超参数的设置,输入图片的大小,分类的类别数,迭代的次数,每一个batch的大小

第三步:使用tf.placeholder() 进行输入数据的设置,进行数据的占位

第四步:使用tf.Variable() 设置里面设置tf.truncated_normal([inputSize, num_hidden], sttdv=0.1) 设置w的初始值,使用tf.Variable(tf.constant(0.1, [num_hidden])) 设置b,这一步主要是进行初始参数设置

第五步:使用tf.nn.relu(tf.matmul()+b) 构造第一层的网络,tf.nn.relu(tf.matmul() + b) 构造第二层的网络, tf.matmul() + b 构造输出层的得分

第六步:使用tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=score)) 来构造损失函数

第七步:使用tf.train.GradientDescentOptimizer().minimize(loss) 使用下降梯度降低损失值

第八步:使用tf.equal(tf.argmax(y, 1), tf.argmax(score, 1)) 即 tf.reduce_mean(tf.cast) 进行准确率的求解

第九步:进行循环,使用mnist.train.next_batch(batchSize) 读取部分数据

第十步:使用sess.run()执行梯度下降操作

第十一步:循环一千次,执行准确率的操作,并打印准确率

第十二步:使用验证集进行结果的验证

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data # 第一步数据读取
mnist = input_data.read_data_sets('/data', one_hot=True) # 第二步:超参数的设置
# 输入图片的大小
inputSize = 784
# 分类的类别数
num_classes = 10
# 隐藏层的个数
num_hidden = 50
# 迭代的次数
trainIteration = 10000
# 每一个batch值的大小
batch_size = 100
# 第二个隐藏层的个数
num_hidden_2 = 100 # 第三步:使用tf.placeholder()构造输入数据X 和 y
X = tf.placeholder(tf.float32, shape=[None, inputSize])
y = tf.placeholder(tf.float32, shape=[None, num_classes]) #第四步:初始化W和b参数
W1 = tf.Variable(tf.random_normal([inputSize, num_hidden], stddev=0.1), name='W1')
b1 = tf.Variable(tf.zeros([num_hidden]), name='b1')
# b1 = tf.Variable(tf.constant(0.1), [num_hidden])
W2 = tf.Variable(tf.random_normal([num_hidden, num_hidden_2], stddev=0.1), name='W2')
# b2 = tf.Variable(tf.constant(0.1), [num_classes])
b2 = tf.Variable(tf.zeros([num_hidden_2]), name='b2')
W3 = tf.Variable(tf.random_normal([num_hidden_2, num_classes], stddev=0.1), name='W3')
b3 = tf.Variable(tf.zeros([num_classes])) # 第五步:使用点乘获得第一层,第二层和最后的得分值
h1 = tf.nn.relu(tf.matmul(X, W1) + b1)
h2 = tf.nn.relu(tf.matmul(h1, W2) + b2)
y_pred = tf.matmul(h2, W3) + b3
# 第六步:使用tf.nn.softmax_cross_entropy_with_logits计算交叉熵损失值
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pred))
# 第七步:使用tf.train.GradientDescentOptimizer降低损失值loss
opt = tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss)
# 第八步:使用tf.argmax(y_pred, 1)找出每一行的最大值索引,tf.equal判断索引是否相等
correct_pred = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
# tf.cast将索引转换为float类型,使用tf.reduce_mean求均值
accr = tf.reduce_mean(tf.cast(correct_pred, 'float')) # 进行初始化
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init) for i in range(trainIteration):
# 第九步:获得一个batch的数据
batch = mnist.train.next_batch(batch_size)
# 第十步:使用sess.run()对损失值降低操作和损失值进行执行,从而进行参数更新
_, data_loss = sess.run([opt, loss], feed_dict={X: batch[0], y: batch[1]})
# 第十一步:每迭代1000次就进行打印准确率和损失值
if i % 1000 == 0:
accurracy = sess.run(accr, feed_dict={X: batch[0], y: batch[1]})
print('loss: %g accr: %g' % (data_loss, accurracy))
# 第十二步:使用测试数据进行训练结果的验证
batch = mnist.test.next_batch(batch_size)
accurracy = sess.run(accr, feed_dict={X: batch[0], y: batch[1]})
print('test accr %g'%(accurracy))

深度学习原理与框架-Tensorflow卷积神经网络-神经网络mnist分类的更多相关文章

  1. 深度学习原理与框架-Tensorflow卷积神经网络-cifar10图片分类(代码) 1.tf.nn.lrn(局部响应归一化操作) 2.random.sample(在列表中随机选值) 3.tf.one_hot(对标签进行one_hot编码)

    1.tf.nn.lrn(pool_h1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # 局部响应归一化,使用相同位置的前后的filter进行响应归一化操作 参数 ...

  2. 深度学习原理与框架-Tensorflow卷积神经网络-卷积神经网络mnist分类 1.tf.nn.conv2d(卷积操作) 2.tf.nn.max_pool(最大池化操作) 3.tf.nn.dropout(执行dropout操作) 4.tf.nn.softmax_cross_entropy_with_logits(交叉熵损失) 5.tf.truncated_normal(两个标准差内的正态分布)

    1. tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')  # 对数据进行卷积操作 参数说明:x表示输入数据,w表示卷积核, stride ...

  3. 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)

    1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...

  4. 深度学习原理与框架-Tensorflow基本操作-实现线性拟合

    代码:使用tensorflow进行数据点的线性拟合操作 第一步:使用np.random.normal生成正态分布的数据 第二步:将数据分为X_data 和 y_data 第三步:对参数W和b, 使用t ...

  5. 深度学习原理与框架-Tensorflow基本操作-变量常用操作 1.tf.random_normal(生成正态分布随机数) 2.tf.random_shuffle(进行洗牌操作) 3. tf.assign(赋值操作) 4.tf.convert_to_tensor(转换为tensor类型) 5.tf.add(相加操作) tf.divide(相乘操作) 6.tf.placeholder(输入数据占位

    1. 使用tf.random_normal([2, 3], mean=-1, stddev=4) 创建一个正态分布的随机数 参数说明:[2, 3]表示随机数的维度,mean表示平均值,stddev表示 ...

  6. 深度学习原理与框架-Tensorflow基本操作-Tensorflow中的变量

    1.tf.Variable([[1, 2]])  # 创建一个变量 参数说明:[[1, 2]] 表示输入的数据,为一行二列的数据 2.tf.global_variables_initializer() ...

  7. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

  8. 深度学习原理与框架-卷积神经网络-cifar10分类(图片分类代码) 1.数据读入 2.模型构建 3.模型参数训练

    卷积神经网络:下面要说的这个网络,由下面三层所组成 卷积网络:卷积层 + 激活层relu+ 池化层max_pool组成 神经网络:线性变化 + 激活层relu 神经网络: 线性变化(获得得分值) 代码 ...

  9. 深度学习原理与框架-卷积神经网络基本原理 1.卷积层的前向传播 2.卷积参数共享 3. 卷积后的维度计算 4. max池化操作 5.卷积流程图 6.卷积层的反向传播 7.池化层的反向传播

    卷积神经网络的应用:卷积神经网络使用卷积提取图像的特征来进行图像的分类和识别       分类                        相似图像搜索                        ...

随机推荐

  1. visual studio中新建和使用dll

    本文的目的是 创建一个最小化的dll并使用它 环境:win7 + vs2012 一个VS的解决方案(sln)下面可以有多个项目(project),所以这里新建一个解决方案,然后下面创建两个项目. 新建 ...

  2. SEO之HTML代码优化

    原文地址:http://www.admin5.com/article/20081128/117821.shtml   一.文档类型(DOCTYPE) XHTML1.0有三种DOCTYPE: 1.过渡型 ...

  3. 通过编写PHP代码并运用“正则表达式”来实现对试题文档进行去重复、排序

    通过编写PHP代码并运用“正则表达式”来实现对试题文档进行去重复.排序 <?php $subject = file_get_contents('test.txt'); $pattern = '/ ...

  4. 1136 A Delayed Palindrome (20 分)

    Consider a positive integer N written in standard notation with k+1 digits a​i​​ as a​k​​⋯a​1​​a​0​​ ...

  5. JVM内存调优

    JVM性能调优有很多设置,这个参考JVM参数即可. 主要调优的目的: 控制GC的行为.GC是一个后台处理,但是它也是会消耗系统性能的,因此经常会根据系统运行的程序的特性来更改GC行为 控制JVM堆栈大 ...

  6. 00001 - Linux 上的 Shebang 符号(#!)

    使用Linux或者unix系统的同学可能都对#!这个符号并不陌生,但是你真的了解它吗? 本文了将给你简单介绍一下Shebang(”#!”)这个符号. 首先,这个符号(#!)的名称,叫做”Shebang ...

  7. python类的全面介绍

    转载:全面介绍python面向对象的编程——类的基础 转载:类的实例方法.静态方法.类方法的区别

  8. linux驱动开发—基于Device tree机制的驱动编写

    前言Device Tree是一种用来描述硬件的数据结构,类似板级描述语言,起源于OpenFirmware(OF).在目前广泛使用的Linux kernel 2.6.x版本中,对于不同平台.不同硬件,往 ...

  9. Shiro 权限注解

      Shiro 权限注解:   Shiro 提供了相应的注解用于权限控制,如果使用这些注解就需要使用AOP 的功能来进行 判断,如Spring AOP:Shiro 提供了Spring AOP 集成用于 ...

  10. 电商系统架构总结2(Redis)

    二  Redis缓存 考虑到将来服务器的升级扩展,使用redis代替.net内置缓存是比较理想的选择.redis是非常成熟好用的缓存系统,安装配置非常简单,直接上官网下载安装包 安装启动就行了. 1 ...