使用tensorflow构造神经网络用来进行mnist数据集的分类

相比与上一节讲到的逻辑回归,神经网络比逻辑回归多了隐藏层,同时在每一个线性变化后添加了relu作为激活函数,

神经网络使用的损失值为softmax概率损失值,即为交叉熵损失值

代码:使用的是mnist数据集作为分类的测试数据,数据的维度为50000*784

第一步:载入mnist数据集

第二步:超参数的设置,输入图片的大小,分类的类别数,迭代的次数,每一个batch的大小

第三步:使用tf.placeholder() 进行输入数据的设置,进行数据的占位

第四步:使用tf.Variable() 设置里面设置tf.truncated_normal([inputSize, num_hidden], sttdv=0.1) 设置w的初始值,使用tf.Variable(tf.constant(0.1, [num_hidden])) 设置b,这一步主要是进行初始参数设置

第五步:使用tf.nn.relu(tf.matmul()+b) 构造第一层的网络,tf.nn.relu(tf.matmul() + b) 构造第二层的网络, tf.matmul() + b 构造输出层的得分

第六步:使用tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=score)) 来构造损失函数

第七步:使用tf.train.GradientDescentOptimizer().minimize(loss) 使用下降梯度降低损失值

第八步:使用tf.equal(tf.argmax(y, 1), tf.argmax(score, 1)) 即 tf.reduce_mean(tf.cast) 进行准确率的求解

第九步:进行循环,使用mnist.train.next_batch(batchSize) 读取部分数据

第十步:使用sess.run()执行梯度下降操作

第十一步:循环一千次,执行准确率的操作,并打印准确率

第十二步:使用验证集进行结果的验证

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data # 第一步数据读取
mnist = input_data.read_data_sets('/data', one_hot=True) # 第二步:超参数的设置
# 输入图片的大小
inputSize = 784
# 分类的类别数
num_classes = 10
# 隐藏层的个数
num_hidden = 50
# 迭代的次数
trainIteration = 10000
# 每一个batch值的大小
batch_size = 100
# 第二个隐藏层的个数
num_hidden_2 = 100 # 第三步:使用tf.placeholder()构造输入数据X 和 y
X = tf.placeholder(tf.float32, shape=[None, inputSize])
y = tf.placeholder(tf.float32, shape=[None, num_classes]) #第四步:初始化W和b参数
W1 = tf.Variable(tf.random_normal([inputSize, num_hidden], stddev=0.1), name='W1')
b1 = tf.Variable(tf.zeros([num_hidden]), name='b1')
# b1 = tf.Variable(tf.constant(0.1), [num_hidden])
W2 = tf.Variable(tf.random_normal([num_hidden, num_hidden_2], stddev=0.1), name='W2')
# b2 = tf.Variable(tf.constant(0.1), [num_classes])
b2 = tf.Variable(tf.zeros([num_hidden_2]), name='b2')
W3 = tf.Variable(tf.random_normal([num_hidden_2, num_classes], stddev=0.1), name='W3')
b3 = tf.Variable(tf.zeros([num_classes])) # 第五步:使用点乘获得第一层,第二层和最后的得分值
h1 = tf.nn.relu(tf.matmul(X, W1) + b1)
h2 = tf.nn.relu(tf.matmul(h1, W2) + b2)
y_pred = tf.matmul(h2, W3) + b3
# 第六步:使用tf.nn.softmax_cross_entropy_with_logits计算交叉熵损失值
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pred))
# 第七步:使用tf.train.GradientDescentOptimizer降低损失值loss
opt = tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss)
# 第八步:使用tf.argmax(y_pred, 1)找出每一行的最大值索引,tf.equal判断索引是否相等
correct_pred = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
# tf.cast将索引转换为float类型,使用tf.reduce_mean求均值
accr = tf.reduce_mean(tf.cast(correct_pred, 'float')) # 进行初始化
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init) for i in range(trainIteration):
# 第九步:获得一个batch的数据
batch = mnist.train.next_batch(batch_size)
# 第十步:使用sess.run()对损失值降低操作和损失值进行执行,从而进行参数更新
_, data_loss = sess.run([opt, loss], feed_dict={X: batch[0], y: batch[1]})
# 第十一步:每迭代1000次就进行打印准确率和损失值
if i % 1000 == 0:
accurracy = sess.run(accr, feed_dict={X: batch[0], y: batch[1]})
print('loss: %g accr: %g' % (data_loss, accurracy))
# 第十二步:使用测试数据进行训练结果的验证
batch = mnist.test.next_batch(batch_size)
accurracy = sess.run(accr, feed_dict={X: batch[0], y: batch[1]})
print('test accr %g'%(accurracy))

深度学习原理与框架-Tensorflow卷积神经网络-神经网络mnist分类的更多相关文章

  1. 深度学习原理与框架-Tensorflow卷积神经网络-cifar10图片分类(代码) 1.tf.nn.lrn(局部响应归一化操作) 2.random.sample(在列表中随机选值) 3.tf.one_hot(对标签进行one_hot编码)

    1.tf.nn.lrn(pool_h1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # 局部响应归一化,使用相同位置的前后的filter进行响应归一化操作 参数 ...

  2. 深度学习原理与框架-Tensorflow卷积神经网络-卷积神经网络mnist分类 1.tf.nn.conv2d(卷积操作) 2.tf.nn.max_pool(最大池化操作) 3.tf.nn.dropout(执行dropout操作) 4.tf.nn.softmax_cross_entropy_with_logits(交叉熵损失) 5.tf.truncated_normal(两个标准差内的正态分布)

    1. tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')  # 对数据进行卷积操作 参数说明:x表示输入数据,w表示卷积核, stride ...

  3. 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)

    1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...

  4. 深度学习原理与框架-Tensorflow基本操作-实现线性拟合

    代码:使用tensorflow进行数据点的线性拟合操作 第一步:使用np.random.normal生成正态分布的数据 第二步:将数据分为X_data 和 y_data 第三步:对参数W和b, 使用t ...

  5. 深度学习原理与框架-Tensorflow基本操作-变量常用操作 1.tf.random_normal(生成正态分布随机数) 2.tf.random_shuffle(进行洗牌操作) 3. tf.assign(赋值操作) 4.tf.convert_to_tensor(转换为tensor类型) 5.tf.add(相加操作) tf.divide(相乘操作) 6.tf.placeholder(输入数据占位

    1. 使用tf.random_normal([2, 3], mean=-1, stddev=4) 创建一个正态分布的随机数 参数说明:[2, 3]表示随机数的维度,mean表示平均值,stddev表示 ...

  6. 深度学习原理与框架-Tensorflow基本操作-Tensorflow中的变量

    1.tf.Variable([[1, 2]])  # 创建一个变量 参数说明:[[1, 2]] 表示输入的数据,为一行二列的数据 2.tf.global_variables_initializer() ...

  7. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

  8. 深度学习原理与框架-卷积神经网络-cifar10分类(图片分类代码) 1.数据读入 2.模型构建 3.模型参数训练

    卷积神经网络:下面要说的这个网络,由下面三层所组成 卷积网络:卷积层 + 激活层relu+ 池化层max_pool组成 神经网络:线性变化 + 激活层relu 神经网络: 线性变化(获得得分值) 代码 ...

  9. 深度学习原理与框架-卷积神经网络基本原理 1.卷积层的前向传播 2.卷积参数共享 3. 卷积后的维度计算 4. max池化操作 5.卷积流程图 6.卷积层的反向传播 7.池化层的反向传播

    卷积神经网络的应用:卷积神经网络使用卷积提取图像的特征来进行图像的分类和识别       分类                        相似图像搜索                        ...

随机推荐

  1. Javascript中的对象(八)

    一.如何编写可以计算的对象的属性名 我们都知道对象的属性访问分两种,键访问(["属性名"])和属性访问(.属性名,遵循标识符的命名规范) 对于动态属性名可以这样 var prefi ...

  2. C++进阶小结

    1.C++中类的不同存储区的对象的初始值 class test; class test { private: int i; int j; public: int geti() { return i; ...

  3. spring AOP 之一:spring AOP功能介绍

    一.AOP简介 AOP:是一种面向切面的编程范式,是一种编程思想,旨在通过分离横切关注点,提高模块化,可以跨越对象关注点.Aop的典型应用即spring的事务机制,日志记录.利用AOP可以对业务逻辑的 ...

  4. Html的本质及在web中的作用

    概要 本文以一个Socket程序为例由浅及深地揭示了Html的本质问题,同时介绍了作为web开发者我们在开发网站时需要做的事情 Html的本质以及开发需要的工作 1.服务器-客户端模型 其实,对于所有 ...

  5. 使用dtc把dtb的反编译为dts

    sudo apt-get install device-tree-compiler dtc -I dtb -O dts msm8976-v1.1-qrd.dtb > msm8976-v1.1-q ...

  6. 使用Dotfuscator混淆你的.net程序

    简介 众所周知C#等net框架的程序是无法防止反编译的,但可以通过混淆,让反编译出来的代码非常难看. Dotfuscator是微软推荐使用的第三方混淆器,用来保护你的net程序.可以在安装VS的时候顺 ...

  7. ERROR hive.HiveConfig: Could not load org.apache.hadoop.hive.conf.HiveConf. Make sure HIVE_CONF_DIR is set correctly.

    Sqoop导入mysql表中的数据到hive,出现如下错误:  ERROR hive.HiveConfig: Could not load org.apache.hadoop.hive.conf.Hi ...

  8. MYSQL存储过程中 使用变量 做表名

    ); DECLARE temp2 int; set temp1=m_tableName; set temp2=m_maxCount; set @sqlStr=CONCAT('select * from ...

  9. 图片Alpha预乘的作用[转]

    Premultiplied Alpha 这个概念做游戏开发的人都不会不知道.Xcode 的工程选项里有一项 Compress PNG Files,会对 PNG 进行 Premultiplied Alp ...

  10. js中基本数据类型和引用数据类型的区别

    1.基本数据类型和引用数据类型 ECMAScript包括两个不同类型的值:基本数据类型和引用数据类型. 基本数据类型指的是简单的数据段,引用数据类型指的是有多个值构成的对象. 当我们把变量赋值给一个变 ...