MNIST数据集分类简单版本
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#载入数据集
mnist = input_data.read_data_sets("/data/stu05/mnist_data",one_hot=True)
Extracting /data/stu05/mnist_data/train-images-idx3-ubyte.gz
Extracting /data/stu05/mnist_data/train-labels-idx1-ubyte.gz
Extracting /data/stu05/mnist_data/t10k-images-idx3-ubyte.gz
Extracting /data/stu05/mnist_data/t10k-labels-idx1-ubyte.gz
#每个批次的大小
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size
#定义两个placeholder,None=100,28*28=784,即100行,784列
x = tf.placeholder(tf.float32,[None,784])
#0-9个输出标签
y = tf.placeholder(tf.float32,[None,10])
#创建一个简单的神经网络,只有输入层和输出层
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([1,10]))
#softmax函数转化为概率值
prediction = tf.nn.softmax(tf.matmul(x,W)+b)
#二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
#初始化变量
init = tf.global_variables_initializer()
#tf.equal()比较函数大小是否相同,相同为True,不同为false;tf.argmax():求y=1在哪个位置,求概率最大在哪个位置
#argmax返回一维张量中最大的值所在的位置,结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
#求准确率
#cast转化类型,将布尔型转化为32位浮点型,True=1.0,False=0.0;再求平均值
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
with tf.Session() as sess:
sess.run(init)
#将所有图片训练21次
for epoch in range(21):
#训练一次所有的图片
for batch in range(n_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
#feed_dict传入训练集的图片和标签
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
#传入测试集的图片和标签
acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Iter"+str(epoch)+",Testing Accuracy:"+str(acc))
Iter0,Testing Accuracy:0.8303
Iter1,Testing Accuracy:0.8708
Iter2,Testing Accuracy:0.8821
Iter3,Testing Accuracy:0.8885
Iter4,Testing Accuracy:0.8941
Iter5,Testing Accuracy:0.8973
Iter6,Testing Accuracy:0.9001
Iter7,Testing Accuracy:0.9013
Iter8,Testing Accuracy:0.9038
Iter9,Testing Accuracy:0.9048
Iter10,Testing Accuracy:0.9068
Iter11,Testing Accuracy:0.9068
Iter12,Testing Accuracy:0.9084
Iter13,Testing Accuracy:0.9094
Iter14,Testing Accuracy:0.9097
Iter15,Testing Accuracy:0.9107
Iter16,Testing Accuracy:0.9118
Iter17,Testing Accuracy:0.9116
Iter18,Testing Accuracy:0.9127
Iter19,Testing Accuracy:0.9136
Iter20,Testing Accuracy:0.9146
MNIST数据集分类简单版本的更多相关文章
- 6.MNIST数据集分类简单版本
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 载入数据集 mnist = i ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- 3.keras-简单实现Mnist数据集分类
keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...
- 6.keras-基于CNN网络的Mnist数据集分类
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
- 深度学习(一)之MNIST数据集分类
任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...
- Tensorflow学习教程------普通神经网络对mnist数据集分类
首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...
- 神经网络MNIST数据集分类tensorboard
今天分享同样数据集的CNN处理方式,同时加上tensorboard,可以看到清晰的结构图,迭代1000次acc收敛到0.992 先放代码,注释比较详细,变量名字看单词就能知道啥意思 import te ...
- 卷积神经网络应用于MNIST数据集分类
先贴代码 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...
- MNIST数据集
一.MNIST数据集分类简单版本 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data # ...
随机推荐
- SpringBoot23 分模块开发
1 开发环境说明 JDK:1.8 MAVEN:3.5 IDEA:2017.2.5 SpringBoot:2.0.3.RELEASE 2 创建SpringBoot项目 2.1 项目信息 2.2 添加项目 ...
- 使用RVM更新Ruby 版本
http://rvm.io/rvm/install Install RVM (development version): \curl -sSL https://get.rvm.io | bash Wi ...
- 数据库sql 开窗函数
--本文采用Oracle数据库测试,前4个查询为一组,后2个查询为一组,每组前面的查询是为了推出最后的查询 --创建表,为了简化处理,字段类型都采用varcharcreate table tb_sc( ...
- css 样式文字溢出显示省略号
在table中使用溢出样式,table样式要设置为”table-layout: fixed“,即<table style="table-layout: fixed;"> ...
- Solidity string to uint
oraclize result以string格式返回,solidity没有uint(string)这样的强制转换功能,如果要解析其中的数字,可以用oraclize提供的parseInt方法: prag ...
- Glade编程
一.简介 如果有一种软件能将图形界面的设计及时地展现于开发人员的面前,而且在设计完后能直接看到界面的外观效果,这样就使程序员的主要精力集中于应用程序核心功能的开发上,这就是所谓的可视化编程思想. ...
- 9.TOP 子句--mysql limit
TOP 子句 TOP 子句用于规定要返回的记录的数目. 对于拥有数千条记录的大型表来说,TOP 子句是非常有用的. 注释:并非所有的数据库系统都支持 TOP 子句. MySQL 语法 SELECT c ...
- Yii2邮箱发送与配置
1配置邮箱 在 common/config/web.php中写入以下代码配置 Mail代理 return [ 'components' => [ ...//your code, //以下是 ma ...
- LINUX 查看CUP温度
在Linux下可以通过lm_sensors来查看CPU的温度(当然你的硬件首先要支持),要使用这个功能要有内核相关模块(比如I2C)的支持,下面说一下操作方法: 先看一下你的机器上是否安装了lm_se ...
- Entity Framework 6.0 Tutorials(2):Async query and Save
Async query and Save: You can take advantage of asynchronous execution of .Net 4.5 with Entity Frame ...