import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 批次大小
batch_size = 64
# 计算一个周期一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10]) # 创建一个简单的神经网络:784-10
W = tf.Variable(tf.truncated_normal([784,10], stddev=0.1))
b = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(x,W)+b) # 二次代价函数
loss = tf.losses.mean_squared_error(y, prediction)
# 使用梯度下降法
train = tf.train.GradientDescentOptimizer(0.3).minimize(loss) # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) with tf.Session() as sess:
# 变量初始化
sess.run(tf.global_variables_initializer())
# 周期epoch:所有数据训练一次,就是一个周期
for epoch in range(21):
for batch in range(n_batch):
# 获取一个批次的数据和标签
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={x:batch_xs,y:batch_ys})
# 每训练一个周期做一次测试
acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))

6.MNIST数据集分类简单版本的更多相关文章

  1. MNIST数据集分类简单版本

      import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist = ...

  2. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  3. 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...

  4. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  5. 深度学习(一)之MNIST数据集分类

    任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...

  6. Tensorflow学习教程------普通神经网络对mnist数据集分类

    首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...

  7. 神经网络MNIST数据集分类tensorboard

    今天分享同样数据集的CNN处理方式,同时加上tensorboard,可以看到清晰的结构图,迭代1000次acc收敛到0.992 先放代码,注释比较详细,变量名字看单词就能知道啥意思 import te ...

  8. 卷积神经网络应用于MNIST数据集分类

    先贴代码 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...

  9. MNIST数据集

    一.MNIST数据集分类简单版本 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data # ...

随机推荐

  1. A smooth collaborative recommender system 推荐系统-浅显了解

    characteristic: 1.Tracking user 2.personliza 3.面对的问题类似于分形学+混沌学(以有观无+窥一管而知全貌) 4.Data:high-volume.spar ...

  2. MATLAB实现图像的代数运算

    目录 1.使用求补运算对各类图像进行处理 2.利用imlincomb函数将图像的灰度值放大1.5倍 3.利用imlincomb函数计算两幅图像的平均值. 4.图像的加法运算 5.利用imnoise函数 ...

  3. hive配置元数据库mysql文件配置

    hive中conf/hive-site.xml文件配置(没有该文件则新建) <?xml version="1.0"?> <?xml-stylesheet type ...

  4. kernel32.dll 这个系统模块

    详细解读:远程线程注入DLL到PC版微信 一.远程线程注入的原理 1.其基础是在 Windows 系统中,每个 .exe 文件在双击打开时都会加载 kernel32.dll 这个系统模块,该模块中有一 ...

  5. hbase增量和全量备份

    1.星期五全量备份星期四23:59:59的数据,星期一全量备份到星期日23:59:59的数据,其他的增量备份,备份前一天00:00:00  -  23:59:59的数据 * * /opt/prodfu ...

  6. 面试总结 | 百度 NLP 实习生

    1. 项目简历:主要体现和招聘要求相关的工作,简历要精简,不要给过多冗余信息.对于每个项目,自己做过的工作,里面用到的方法,要很清楚,工作的motivation.意义等也要清楚. 这次面试中我的问题: ...

  7. 2019牛客暑期多校训练营(第八场)-C CDMA(递归构造)

    题目链接:https://ac.nowcoder.com/acm/contest/888/C 题意:输入m(为2的n次幂,n<=10),构造一个m*m的矩阵满足任意不同的两行的元素乘积和为0. ...

  8. 使用 WijmoJS 轻松实现撤消重做(Undo /Redo)

    使用 WijmoJS 轻松实现撤消重做(Undo /Redo) 在V2019.0 Update2 的全新版本中,WijmoJS能够轻松实现撤消和重做操作,使Web应用程序的使用更加友好.更加高效. 不 ...

  9. 【spring Boot】spring boot1.5以上版本@ConfigurationProperties取消location注解后的替代方案

    前言 =========================================== 初步接触Spring Boot ===================================== ...

  10. STL-set 容器以及迭代器的简单理解

    先说下set的基本操作和时间复杂度 begin()     ,返回set容器的第一个元素 end() ,返回set容器的最后一个元素 clear()        ,删除set容器中的所有的元素 em ...