MNIST数据集分类简单版本
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#载入数据集
mnist = input_data.read_data_sets("/data/stu05/mnist_data",one_hot=True)
Extracting /data/stu05/mnist_data/train-images-idx3-ubyte.gz
Extracting /data/stu05/mnist_data/train-labels-idx1-ubyte.gz
Extracting /data/stu05/mnist_data/t10k-images-idx3-ubyte.gz
Extracting /data/stu05/mnist_data/t10k-labels-idx1-ubyte.gz
#每个批次的大小
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size
#定义两个placeholder,None=100,28*28=784,即100行,784列
x = tf.placeholder(tf.float32,[None,784])
#0-9个输出标签
y = tf.placeholder(tf.float32,[None,10])
#创建一个简单的神经网络,只有输入层和输出层
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([1,10]))
#softmax函数转化为概率值
prediction = tf.nn.softmax(tf.matmul(x,W)+b)
#二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
#初始化变量
init = tf.global_variables_initializer()
#tf.equal()比较函数大小是否相同,相同为True,不同为false;tf.argmax():求y=1在哪个位置,求概率最大在哪个位置
#argmax返回一维张量中最大的值所在的位置,结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
#求准确率
#cast转化类型,将布尔型转化为32位浮点型,True=1.0,False=0.0;再求平均值
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
with tf.Session() as sess:
sess.run(init)
#将所有图片训练21次
for epoch in range(21):
#训练一次所有的图片
for batch in range(n_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
#feed_dict传入训练集的图片和标签
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
#传入测试集的图片和标签
acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Iter"+str(epoch)+",Testing Accuracy:"+str(acc))
Iter0,Testing Accuracy:0.8303
Iter1,Testing Accuracy:0.8708
Iter2,Testing Accuracy:0.8821
Iter3,Testing Accuracy:0.8885
Iter4,Testing Accuracy:0.8941
Iter5,Testing Accuracy:0.8973
Iter6,Testing Accuracy:0.9001
Iter7,Testing Accuracy:0.9013
Iter8,Testing Accuracy:0.9038
Iter9,Testing Accuracy:0.9048
Iter10,Testing Accuracy:0.9068
Iter11,Testing Accuracy:0.9068
Iter12,Testing Accuracy:0.9084
Iter13,Testing Accuracy:0.9094
Iter14,Testing Accuracy:0.9097
Iter15,Testing Accuracy:0.9107
Iter16,Testing Accuracy:0.9118
Iter17,Testing Accuracy:0.9116
Iter18,Testing Accuracy:0.9127
Iter19,Testing Accuracy:0.9136
Iter20,Testing Accuracy:0.9146
MNIST数据集分类简单版本的更多相关文章
- 6.MNIST数据集分类简单版本
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 载入数据集 mnist = i ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- 3.keras-简单实现Mnist数据集分类
keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...
- 6.keras-基于CNN网络的Mnist数据集分类
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
- 深度学习(一)之MNIST数据集分类
任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...
- Tensorflow学习教程------普通神经网络对mnist数据集分类
首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...
- 神经网络MNIST数据集分类tensorboard
今天分享同样数据集的CNN处理方式,同时加上tensorboard,可以看到清晰的结构图,迭代1000次acc收敛到0.992 先放代码,注释比较详细,变量名字看单词就能知道啥意思 import te ...
- 卷积神经网络应用于MNIST数据集分类
先贴代码 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...
- MNIST数据集
一.MNIST数据集分类简单版本 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data # ...
随机推荐
- 网页开发中调用iframe中的函数或者是dom元素
iframe中的代码 <!DOCTYPE html> <html xmlns="http://www.w3.org/1999/xhtml"> <hea ...
- Docker学习笔记_Dockerfile基本知识
Dockerfile由一行行命令语句组成,并支持以#开头的注释行. 1.编写一个Dockerfile文件 创建一个空的Docker工作目录,进入该目录,使用sudo vim Dockerfile指令新 ...
- c# 获取客户端ip、mac、机器名、操作系统、浏览器信息
d using System; using System.Collections.Generic; using System.Linq; using System.Web; using System. ...
- Free GIS Software
Refer to There are lots of free gis software listed in the website: http://www.freegis.org/ http://w ...
- Date3.19
1.正则表达式的定义及使用2.Date类的用法3.Calendar类的用法========================================================1正则表达式的 ...
- 简单工厂(Simple Factory)模式
工厂模式专门负责将大量有共同接口的类实例化.工厂模式可以动态决定将哪一个类实例化,不必事先知道每次要实例化哪一个类.工厂模式有以下几种形态: 简单工厂(Simple Factory)模式 工厂方法(F ...
- AD对象DirectoryEntry本地开发
DirectoryEntry类如果需要在本地计算机开发需要满足以下条件: 1.本地计算机dns解析必须和AD域控制器的dns保持一致,如图: 2.必须模拟身份验证,才能操作查询AD用户 /// < ...
- Xcode更新至IOS 9 后错误处理
1.obtain an updated library from the vendor, or disable bitcode for this target. for architecture ar ...
- Django工程中使用echarts怎么循环遍历显示数据
前言: 后面要开发测试管理平台,需要用到数据可视化,所以研究了一下 先看下最后的图吧,单击最上方的按钮可以控制柱状图显隐 views.py # -*- coding: utf-8 -*- from _ ...
- 快速了解“云原生”(Cloud Native)和前端开发的技术结合点
欢迎访问网易云社区,了解更多网易技术产品运营经验. 后端视角,结合点就是通过前端流控缓解后端的压力,提升系统响应能力. 从一般意义理解,Cloud Native 是后端应用的事情,要搞的是系统解耦.横 ...