一个简单的TensorFlow可视化MNIST数据集识别程序
下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev))
# -*- coding: utf-8 -*- import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)
# 载入数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# 运行次数
max_steps = 3001
# 图片数量
image_num = 5000
# 文件路径
DIR = "D:/AIdata/tf_data/tf_test1/" sess = tf.Session() # 载入图片,
# tf.stack矩阵拼接函数,
embedding = tf.Variable(tf.stack(mnist.test.images[:image_num]),
trainable=False, name="embedding") def variable_summaries(var):
with tf.name_scope("summaries"):
mean = tf.reduce_mean(var)
with tf.name_scope("stddev"):
# 计算标准差
stddev = tf.sqrt(tf.reduce_mean(tf.square(var-mean)))
# 绘制标准差信息
tf.summary.scalar("stddev", stddev)
# 绘制最大值
tf.summary.scalar("max", tf.reduce_max(var))
tf.summary.scalar("min", tf.reduce_min(var))
# 绘制直方图信息
tf.summary.histogram("histogram", var) with tf.name_scope('Input'):
x = tf.placeholder(tf.float32, [None, 784], name="x_input")
y = tf.placeholder(tf.float32, [None, 10], name="y_input")
LR = tf.Variable(0.001, dtype=tf.float32) # 显示图片
with tf.name_scope("input_reshape"):
# 改变x的形状(28x28x1)
image_shape_input = tf.reshape(x, [-1, 28, 28, 1])
# 将图像写入summary,输出带图像的probuf
tf.summary.image("Input", image_shape_input, 10) with tf.name_scope('layer'):
with tf.name_scope('weights'):
W = tf.Variable(tf.zeros([784, 10]), name='W')
variable_summaries(W)
with tf.name_scope('biases'):
b = tf.Variable(tf.zeros([10]), name='b')
variable_summaries(b)
with tf.name_scope('wxb'):
# tf.matmul实现矩阵乘法功能
wxb = tf.matmul(x, W) + b
with tf.name_scope('softmax'):
prediction = tf.nn.softmax(wxb) with tf.name_scope("loss"):
# 交叉熵函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,
logits=prediction))
# 绘制loss值
tf.summary.scalar("loss", loss) with tf.name_scope("Train"):
# AdamOptimizer优化器
train_step = tf.train.AdamOptimizer(LR).minimize(loss) init_op = tf.global_variables_initializer()
sess.run(init_op) # 变量初始化 with tf.name_scope("Result"):
with tf.name_scope("correct_prediction"):
# 记录预测值和标签值对比结果
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
with tf.name_scope("Accuracy"):
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 绘制准确率
tf.summary.scalar("accuracy", accuracy) # 判断是否已存在metadata.tsv文件,若存在则删除
if tf.gfile.Exists(DIR+"projector/projector/metadata.tsv"):
tf.gfile.Remove(DIR+"projector/projector/metadata.tsv") # 创建并写入metadata.tsv文件
with open(DIR+"projector/projector/metadata.tsv", 'w') as f:
labels = sess.run(tf.argmax(mnist.test.labels[:], 1))
for i in range(image_num):
f.write(str(labels[i]) + '\n') # 合并默认图表管理summary
merged = tf.summary.merge_all() projector_writer = tf.summary.FileWriter(DIR+"/projector/projector", sess.graph)
# 定义saver对象,以保存和恢复模型变量
saver = tf.train.Saver()
# 定义配置
config = projector.ProjectorConfig()
embed = config.embeddings.add()
embed.tensor_name = embedding.name
# metadata_path文件路径
embed.metadata_path = DIR+"projector/projector/metadata.tsv"
# sprite image文件路径
embed.sprite.image_path = DIR+'projector/data/mnist_10k_sprite.png'
# sprite image中每一单个图像的大小
embed.sprite.single_image_dim.extend([28, 28])
# 写入可视化配置
projector.visualize_embeddings(projector_writer, config) for i in range(max_steps):
# 每个批次100个样本
batch_xs, batch_ys = mnist.train.next_batch(100)
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
summary, _ = sess.run([merged, train_step], feed_dict={x: batch_xs, y: batch_ys},
options=run_options, run_metadata=run_metadata)
projector_writer.add_run_metadata(run_metadata, 'step%03d' % i)
projector_writer.add_summary(summary, i) if i % 100 == 0:
sess.run(tf.assign(LR, 0.001))
acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Iter " + str(i) + ", Testing Accuracy= " + str(acc))
# 保存模型
saver.save(sess, DIR+'projector/projector/mnist_model.ckpt', global_step=max_steps)
projector_writer.close()
sess.close()
在cmd中输入tensorboard --logdir=tensorboard --logdir=D:\AIdata\tf_data\tf_test1\projector\projector --host=127.0.0.1

在浏览器中输入http://127.0.0.1:6006打开,会显示如下内容
显示表(loss表, 权重W...)

显示图片信息

计算图

动态放映训练过程,可在此进行模型训练,动态的观看训练状态

一个简单的TensorFlow可视化MNIST数据集识别程序的更多相关文章
- TensorFlow 下 mnist 数据集的操作及可视化
from tensorflow.examples.tutorials.mnist import input_data 首先需要连网下载数据集: mnsit = input_data.read_data ...
- Tensorflow可视化MNIST手写数字训练
简述] 我们在学习编程语言时,往往第一个程序就是打印“Hello World”,那么对于人工智能学习系统平台来说,他的“Hello World”小程序就是MNIST手写数字训练了.MNIST是一个手写 ...
- 基于TensorFlow的MNIST数据集的实验
一.MNIST实验内容 MNIST的实验比较简单,可以直接通过下面的程序加上程序上的部分注释就能很好的理解了,后面在完善具体的相关的数学理论知识,先记录在这里: 代码如下所示: import tens ...
- TensorFlow 训练MNIST数据集(2)—— 多层神经网络
在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率.这次换一种神经网络(多层神经网络)来进行训练和测试. 1.获取MNIST数据 MNIST数据集只要一行代码 ...
- 《Hands-On Machine Learning with Scikit-Learn&TensorFlow》mnist数据集错误及解决方案
最近在看这本书看到Chapter 3.Classification,是关于mnist数据集的分类,里面有个代码是 from sklearn.datasets import fetch_mldata m ...
- 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)
1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...
- TensorFlow训练MNIST数据集(1) —— softmax 单层神经网络
1.MNIST数据集简介 首先通过下面两行代码获取到TensorFlow内置的MNIST数据集: from tensorflow.examples.tutorials.mnist import inp ...
- 基于Keras 的VGG16神经网络模型的Mnist数据集识别并使用GPU加速
这段话放在前面:之前一种用的Pytorch,用着还挺爽,感觉挺方便的,但是在最近文献的时候,很多实验都是基于Google 的Keras的,所以抽空学了下Keras,学了之后才发现Keras相比Pyto ...
- 基于 tensorflow 的 mnist 数据集预测
1. tensorflow 基本使用方法 2. mnist 数据集简介与预处理 3. 聚类算法模型 4. 使用卷积神经网络进行特征生成 5. 训练网络模型生成结果 how to install ten ...
随机推荐
- CamStar insitexmlclient重新封装为.net Core类库
工作原因经常使用camstar的 InsiteXMLClient类库做二次开发,但是只能在4.X环境下使用,对于日益繁荣的.net core生态,花费了些时间把原有的类库重新封装为.net core ...
- django项目部署
1.布署前需要关闭调试.允许任何机器访问,在setting文件中设置 DEBUG = False ALLOW_HOSTS=['*',] 2.安装uWSGI pip install uwsgi 3.配置 ...
- python day04笔记总结
2019.4.1 S21 day04笔记总结 昨日内容补充 1.解释器/编译器 1.解释型语言.编译型语言 2.解释型:写完代码后提交给解释器,解释器将代码一行行执行.(边接收边解释/实时解释) 常用 ...
- 使用其他分支替换master分支
在提交混乱的时候, 导致master分支和远程仓库完全一致的时候,这时候解决这种问题可以创建一个新的分支, 再合并到master分支, 像这样: git checkout seotweaks git ...
- php 计算坐标点方圆周围多少米的坐标算法
//地球半径 6371千米 const EARTH_ROUNT = 6371; /** * @param $distance 方圆多少千米 默认500米 */ private function _ge ...
- 大数据实操3 - hadoop集群添加新节点
hadoop集群支持动态扩展,不需要停止原有集群节点就可以实现新节点的加入. 我是使用docker搭建的进群环境,制作了镜像文件,这里以我的工作基础为例子介绍集群中添加集群的方法 一.制作一个新节点 ...
- QVariant类
QVariant类: #include "widget.h" #include <QApplication> #include <QDebug> int m ...
- redis缓存与数据库一致性问题
一般来说,如果允许缓存可以稍微的跟数据库偶尔有不一致的情况,也就是说如果你的系统不是严格要求 “缓存+数据库” 必须保持一致性的话,最好不要做这个方案,即:读请求和写请求串行化,串到一个内存队列里去. ...
- sqlserver2017 SSAS配置远程访问不成功的问题
sqlserver2017 SSAS通过IIS配置远程访问一直访问不成功的解决办法: 出现这个问题的原因从微软给出的更新包中说的就是: 从 SQL Server 2017 开始,Analysis Se ...
- 【noip模拟赛4】Matrix67的派对 暴力dfs
[noip模拟赛4]Matrix67的派对 描述 Matrix67发现身高接近的人似乎更合得来.Matrix67举办的派对共有N(1<=N<=10)个人参加,Matrix67需要把他们 ...