6.MNIST数据集分类简单版本
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 批次大小
batch_size = 64
# 计算一个周期一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10]) # 创建一个简单的神经网络:784-10
W = tf.Variable(tf.truncated_normal([784,10], stddev=0.1))
b = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(x,W)+b) # 二次代价函数
loss = tf.losses.mean_squared_error(y, prediction)
# 使用梯度下降法
train = tf.train.GradientDescentOptimizer(0.3).minimize(loss) # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) with tf.Session() as sess:
# 变量初始化
sess.run(tf.global_variables_initializer())
# 周期epoch:所有数据训练一次,就是一个周期
for epoch in range(21):
for batch in range(n_batch):
# 获取一个批次的数据和标签
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={x:batch_xs,y:batch_ys})
# 每训练一个周期做一次测试
acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
6.MNIST数据集分类简单版本的更多相关文章
- MNIST数据集分类简单版本
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist = ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- 3.keras-简单实现Mnist数据集分类
keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...
- 6.keras-基于CNN网络的Mnist数据集分类
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
- 深度学习(一)之MNIST数据集分类
任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...
- Tensorflow学习教程------普通神经网络对mnist数据集分类
首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...
- 神经网络MNIST数据集分类tensorboard
今天分享同样数据集的CNN处理方式,同时加上tensorboard,可以看到清晰的结构图,迭代1000次acc收敛到0.992 先放代码,注释比较详细,变量名字看单词就能知道啥意思 import te ...
- 卷积神经网络应用于MNIST数据集分类
先贴代码 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...
- MNIST数据集
一.MNIST数据集分类简单版本 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data # ...
随机推荐
- python3 速查参考- python基础 8 -> 面向对象基础:类的创建与基础使用,类属性,property、类方法、静态方法、常用知识点概念(封装、继承等等见下一章)
基础概念 1.速查笔记: #-- 最普通的类 class C1(C2, C3): spam = 42 # 数据属性 def __init__(self, name): # 函数属性:构造函数 self ...
- DRF视图-5个扩展类以及GenericAPIView基类
视图 5个视图扩展类 视图拓展类的作用: 提供了几种后端视图(对数据资源进行曾删改查)处理流程的实现,如果需要编写的视图属于这五种,则视图可以通过继承相应的扩展类来复用代码,减少自己编写的代码量. 这 ...
- cisco上配置 pppoe拨号
r1(config)#vpdnenable 开启虚拟拨号VPDN r1(config)#vpdn-groupoffice 定义组 ...
- Python Network Security Programming-1
UNIX口令破解1.程序运行需求: 其中dictionary.txt文件为破解口令的字典文件,passwords.txt文件为临时存放UNIX系统密码的文件 2.程序源码: import crypt ...
- Netcat—瑞士军刀
netcat是网络工具中的瑞士军刀,它能通过TCP和UDP在网络中读写数据.通过与其他工具结合和重定向,你可以在脚本中以多种方式使用它.使用netcat命令所能完成的事情令人惊讶. netcat所做的 ...
- AD域环境取消密码复杂度和密码使用期限
本地组策略功能中设置密码永不过期的时候发现功能置灰了,不能设置: 这是因为创建域后默认本地组策略功能会被转移到域组策略管理里面,所以我们可以去组策略管理器里去更改组策略,因为一般本地策略的优先级别最低 ...
- 修改Jupyter Notebook默认目录
Jupyter Notebook每次打开都需要先进到相应的文件夹再打开 很不方便 首先进入到Jupyter的安装目录,我的是 D:\Anaconda3\Scripts 然后,输入命令: jupyter ...
- [转帖]查看Linux用的桌面是GNOME、KDE或者其他
http://superuser.com/questions/96151/how-do-i-check-whether-i-am-using-kde-or-gnome KDE 基于QT做的 已经越来越 ...
- centOS重启网络服务报错
1:启动网卡报错(Failed to start LSB: Bring up/down networking )解决办法总结 将 NetworkManager关闭, systemctl stop Ne ...
- ARST 第五周打卡
Algorithm : 做一个 leetcode 的算法题 /////////////////////////////////////////////////////////////////// // ...