MNIST数据集分类简单版本
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#载入数据集
mnist = input_data.read_data_sets("/data/stu05/mnist_data",one_hot=True)
Extracting /data/stu05/mnist_data/train-images-idx3-ubyte.gz
Extracting /data/stu05/mnist_data/train-labels-idx1-ubyte.gz
Extracting /data/stu05/mnist_data/t10k-images-idx3-ubyte.gz
Extracting /data/stu05/mnist_data/t10k-labels-idx1-ubyte.gz
#每个批次的大小
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size
#定义两个placeholder,None=100,28*28=784,即100行,784列
x = tf.placeholder(tf.float32,[None,784])
#0-9个输出标签
y = tf.placeholder(tf.float32,[None,10])
#创建一个简单的神经网络,只有输入层和输出层
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([1,10]))
#softmax函数转化为概率值
prediction = tf.nn.softmax(tf.matmul(x,W)+b)
#二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
#初始化变量
init = tf.global_variables_initializer()
#tf.equal()比较函数大小是否相同,相同为True,不同为false;tf.argmax():求y=1在哪个位置,求概率最大在哪个位置
#argmax返回一维张量中最大的值所在的位置,结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
#求准确率
#cast转化类型,将布尔型转化为32位浮点型,True=1.0,False=0.0;再求平均值
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
with tf.Session() as sess:
sess.run(init)
#将所有图片训练21次
for epoch in range(21):
#训练一次所有的图片
for batch in range(n_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
#feed_dict传入训练集的图片和标签
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
#传入测试集的图片和标签
acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Iter"+str(epoch)+",Testing Accuracy:"+str(acc))
Iter0,Testing Accuracy:0.8303
Iter1,Testing Accuracy:0.8708
Iter2,Testing Accuracy:0.8821
Iter3,Testing Accuracy:0.8885
Iter4,Testing Accuracy:0.8941
Iter5,Testing Accuracy:0.8973
Iter6,Testing Accuracy:0.9001
Iter7,Testing Accuracy:0.9013
Iter8,Testing Accuracy:0.9038
Iter9,Testing Accuracy:0.9048
Iter10,Testing Accuracy:0.9068
Iter11,Testing Accuracy:0.9068
Iter12,Testing Accuracy:0.9084
Iter13,Testing Accuracy:0.9094
Iter14,Testing Accuracy:0.9097
Iter15,Testing Accuracy:0.9107
Iter16,Testing Accuracy:0.9118
Iter17,Testing Accuracy:0.9116
Iter18,Testing Accuracy:0.9127
Iter19,Testing Accuracy:0.9136
Iter20,Testing Accuracy:0.9146
MNIST数据集分类简单版本的更多相关文章
- 6.MNIST数据集分类简单版本
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 载入数据集 mnist = i ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- 3.keras-简单实现Mnist数据集分类
keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...
- 6.keras-基于CNN网络的Mnist数据集分类
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
- 深度学习(一)之MNIST数据集分类
任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...
- Tensorflow学习教程------普通神经网络对mnist数据集分类
首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...
- 神经网络MNIST数据集分类tensorboard
今天分享同样数据集的CNN处理方式,同时加上tensorboard,可以看到清晰的结构图,迭代1000次acc收敛到0.992 先放代码,注释比较详细,变量名字看单词就能知道啥意思 import te ...
- 卷积神经网络应用于MNIST数据集分类
先贴代码 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...
- MNIST数据集
一.MNIST数据集分类简单版本 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data # ...
随机推荐
- c++策略模式(Strategy Method)
别人的博客再讲策略模式时都会讲三国,策略类就是赵云的锦囊,锦囊里装着若干妙计.在打仗时想要用什么妙计,直接从锦囊里去取. 锦囊类: class context { public: context(IS ...
- Angular23 loading组件、路由配置、子路由配置、路由懒加载配置
1 需求 由于Angular是单页面的应用,所以在进行数据刷新是进行的局部刷新:在进行数据刷新时从浏览器发出请求到后台响应数据是有时间延迟的,所以在这段时间就需要进行遮罩处理来提示用户系统正在请求数据 ...
- Servlet和JSP简述
什么是Servlet和JSP 用Java开发Web应用程序时用到的技术主要有两种,即Servlet和JSP. Servlet是在服务器端执行的Java程序,一个被称为Servlet容器的程序(其实就是 ...
- hook NtQueryDirectoryFile实现文件隐藏
一.NtQueryDirectoryFile函数功能(NT系列函数) NtQueryDirectoryFile函数:在一个给定的文件句柄,该函数返回该文件句柄指定目录下的不同文件的各种信息. 根据传入 ...
- Jackson-将对象转为Json字符串
SpringMVC-处理JSON 1.引入jackson依赖 <properties> <jackson.version>1.9.13</jackson.version& ...
- 7. Smali基础语法总结
最近在学习Android 移动安全逆向方面,逆向首先要看懂代码,Android4.4之前一直使用的是 Dalivk虚拟机,而Smali是用于Dalivk的反汇编程序的实现. Smali 支持注解,调试 ...
- 一步一步带你构建第一个 Laravel 项目
参考链接:https://laravel-news.com/your-first-laravel-application 简介 按照以下的步骤,你会创建一个简易的链接分享网站. 安装 Laravel ...
- css属性position的运用
随着web标准的规范化,网页的布局也随之千变万化.各种复杂漂亮有创意的页面布局冲 击这人们的视野,相比以前的table布局那就不是一等级的事儿.这个很大一部分功劳是css 样式的引入.而这个多样性布局 ...
- C#转java
懂C#的话,转Java也不是那么难,毕竟,语言语法还是相似的.尝试了下Java,说说自己的体会吧. 一,Java和C#都是完全面向对象的语言.在面向对象编程的三大原则方面,这两种语言接近得不能再接近. ...
- 一个基于 .NET Core 2.0 开发的简单易用的快速开发框架 - LinFx
LinFx 一个基于 .NET Core 2.0 开发的简单易用的快速开发框架,遵循领域驱动设计(DDD)规范约束,提供实现事件驱动.事件回溯.响应式等特性的基础设施.让开发者享受到正真意义的面向对象 ...