下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev))

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)
# 载入数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# 运行次数
max_steps = 3001
# 图片数量
image_num = 5000
# 文件路径
DIR = "D:/AIdata/tf_data/tf_test1/" sess = tf.Session() # 载入图片,
# tf.stack矩阵拼接函数,
embedding = tf.Variable(tf.stack(mnist.test.images[:image_num]),
trainable=False, name="embedding") def variable_summaries(var):
with tf.name_scope("summaries"):
mean = tf.reduce_mean(var)
with tf.name_scope("stddev"):
# 计算标准差
stddev = tf.sqrt(tf.reduce_mean(tf.square(var-mean)))
# 绘制标准差信息
tf.summary.scalar("stddev", stddev)
# 绘制最大值
tf.summary.scalar("max", tf.reduce_max(var))
tf.summary.scalar("min", tf.reduce_min(var))
# 绘制直方图信息
tf.summary.histogram("histogram", var) with tf.name_scope('Input'):
x = tf.placeholder(tf.float32, [None, 784], name="x_input")
y = tf.placeholder(tf.float32, [None, 10], name="y_input")
LR = tf.Variable(0.001, dtype=tf.float32) # 显示图片
with tf.name_scope("input_reshape"):
# 改变x的形状(28x28x1)
image_shape_input = tf.reshape(x, [-1, 28, 28, 1])
# 将图像写入summary,输出带图像的probuf
tf.summary.image("Input", image_shape_input, 10) with tf.name_scope('layer'):
with tf.name_scope('weights'):
W = tf.Variable(tf.zeros([784, 10]), name='W')
variable_summaries(W)
with tf.name_scope('biases'):
b = tf.Variable(tf.zeros([10]), name='b')
variable_summaries(b)
with tf.name_scope('wxb'):
# tf.matmul实现矩阵乘法功能
wxb = tf.matmul(x, W) + b
with tf.name_scope('softmax'):
prediction = tf.nn.softmax(wxb) with tf.name_scope("loss"):
# 交叉熵函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,
logits=prediction))
# 绘制loss值
tf.summary.scalar("loss", loss) with tf.name_scope("Train"):
# AdamOptimizer优化器
train_step = tf.train.AdamOptimizer(LR).minimize(loss) init_op = tf.global_variables_initializer()
sess.run(init_op) # 变量初始化 with tf.name_scope("Result"):
with tf.name_scope("correct_prediction"):
# 记录预测值和标签值对比结果
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
with tf.name_scope("Accuracy"):
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 绘制准确率
tf.summary.scalar("accuracy", accuracy) # 判断是否已存在metadata.tsv文件,若存在则删除
if tf.gfile.Exists(DIR+"projector/projector/metadata.tsv"):
tf.gfile.Remove(DIR+"projector/projector/metadata.tsv") # 创建并写入metadata.tsv文件
with open(DIR+"projector/projector/metadata.tsv", 'w') as f:
labels = sess.run(tf.argmax(mnist.test.labels[:], 1))
for i in range(image_num):
f.write(str(labels[i]) + '\n') # 合并默认图表管理summary
merged = tf.summary.merge_all() projector_writer = tf.summary.FileWriter(DIR+"/projector/projector", sess.graph)
# 定义saver对象,以保存和恢复模型变量
saver = tf.train.Saver()
# 定义配置
config = projector.ProjectorConfig()
embed = config.embeddings.add()
embed.tensor_name = embedding.name
# metadata_path文件路径
embed.metadata_path = DIR+"projector/projector/metadata.tsv"
# sprite image文件路径
embed.sprite.image_path = DIR+'projector/data/mnist_10k_sprite.png'
# sprite image中每一单个图像的大小
embed.sprite.single_image_dim.extend([28, 28])
# 写入可视化配置
projector.visualize_embeddings(projector_writer, config) for i in range(max_steps):
# 每个批次100个样本
batch_xs, batch_ys = mnist.train.next_batch(100)
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
summary, _ = sess.run([merged, train_step], feed_dict={x: batch_xs, y: batch_ys},
options=run_options, run_metadata=run_metadata)
projector_writer.add_run_metadata(run_metadata, 'step%03d' % i)
projector_writer.add_summary(summary, i) if i % 100 == 0:
sess.run(tf.assign(LR, 0.001))
acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Iter " + str(i) + ", Testing Accuracy= " + str(acc))
# 保存模型
saver.save(sess, DIR+'projector/projector/mnist_model.ckpt', global_step=max_steps)
projector_writer.close()
sess.close()

在cmd中输入tensorboard --logdir=tensorboard --logdir=D:\AIdata\tf_data\tf_test1\projector\projector  --host=127.0.0.1

在浏览器中输入http://127.0.0.1:6006打开,会显示如下内容

显示表(loss表, 权重W...)

显示图片信息

计算图

动态放映训练过程,可在此进行模型训练,动态的观看训练状态

一个简单的TensorFlow可视化MNIST数据集识别程序的更多相关文章

  1. TensorFlow 下 mnist 数据集的操作及可视化

    from tensorflow.examples.tutorials.mnist import input_data 首先需要连网下载数据集: mnsit = input_data.read_data ...

  2. Tensorflow可视化MNIST手写数字训练

    简述] 我们在学习编程语言时,往往第一个程序就是打印“Hello World”,那么对于人工智能学习系统平台来说,他的“Hello World”小程序就是MNIST手写数字训练了.MNIST是一个手写 ...

  3. 基于TensorFlow的MNIST数据集的实验

    一.MNIST实验内容 MNIST的实验比较简单,可以直接通过下面的程序加上程序上的部分注释就能很好的理解了,后面在完善具体的相关的数学理论知识,先记录在这里: 代码如下所示: import tens ...

  4. TensorFlow 训练MNIST数据集(2)—— 多层神经网络

    在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率.这次换一种神经网络(多层神经网络)来进行训练和测试. 1.获取MNIST数据 MNIST数据集只要一行代码 ...

  5. 《Hands-On Machine Learning with Scikit-Learn&TensorFlow》mnist数据集错误及解决方案

    最近在看这本书看到Chapter 3.Classification,是关于mnist数据集的分类,里面有个代码是 from sklearn.datasets import fetch_mldata m ...

  6. 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)

    1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...

  7. TensorFlow训练MNIST数据集(1) —— softmax 单层神经网络

    1.MNIST数据集简介 首先通过下面两行代码获取到TensorFlow内置的MNIST数据集: from tensorflow.examples.tutorials.mnist import inp ...

  8. 基于Keras 的VGG16神经网络模型的Mnist数据集识别并使用GPU加速

    这段话放在前面:之前一种用的Pytorch,用着还挺爽,感觉挺方便的,但是在最近文献的时候,很多实验都是基于Google 的Keras的,所以抽空学了下Keras,学了之后才发现Keras相比Pyto ...

  9. 基于 tensorflow 的 mnist 数据集预测

    1. tensorflow 基本使用方法 2. mnist 数据集简介与预处理 3. 聚类算法模型 4. 使用卷积神经网络进行特征生成 5. 训练网络模型生成结果 how to install ten ...

随机推荐

  1. 不同应用共享redis应用,但分数据库存储数据

    日常开发工作中,常常遇到这种情况 项目A ,需要使用redis 项目B ,也需使用redis …… 原来傻乎乎的在服务器上装几个redis,通过不同的端口号来进行使用 其实redis可用有16个数据库 ...

  2. html和css问题?

    1.说说你对语义化的理解?答,去掉或者丢失样式的时候能够让页面呈现出清晰的结构方便其他设备解析(如屏幕阅读器.盲人阅读器.移动设备)以意义的方式来渲染网页:便于团队开发和维护,语义化更具可读性,是下一 ...

  3. 在anaconda下安装已经下载好Opencv4的痛苦回忆

    来来回回装了很多回,今天终于一鼓作气把它安装好,记录一下过程. 准备: Opencv4的安装包,可以在官网上下载 anaconda——主要目的是在anaconda下的某个environment中安装最 ...

  4. python入门(十三):面向对象(继承、重写、公有、私有)

    1. 三种类定义的写法  class P1:#定义类   加不加()都可以    pass   class P2():                    #带(),且括号中为空,类定义 pass ...

  5. React中this.props的主要属性

    this.props主要包含:history属性.location属性.match属性 ①history属性又包含 ②location属性又包含 ③match属性又包含

  6. Exp2 后门原理与实践 20165110

    Exp2 后门原理与实践 一.实验要求 1.使用netcat获取主机操作Shell,cron启动 2.使用socat获取主机操作Shell, 任务计划启动 3.使用MSF meterpreter(或其 ...

  7. Java I/O - 对象的输入输出与序列化

    先说概念: 一.相关概念 序列化是Java提供的一种将对象写入到输出流.并在之后将其读回的机制. 序列化:把内存中的java对象转换成与平台无关的二进制字节序列,以便永久保存在磁盘上或通过网络进行传输 ...

  8. JMeter 通过JSON Extractor 插件来提取响应结果

    接口响应结果,通常为HTML.JSON格式的数据,对于HTML的响应结果的提取,可以通过正则表达式,也可以通过XPath 来提取. 对于JSON格式的数据,可以通过正则表达式.JSON Extract ...

  9. PHP+Mysql 实现数据库增删改查(原生)

    Mysql数据库创建 创建一个新闻列表的数据库: 1. 查询数据库 1.1. 创建文件dbconfig.php,保存常量 <?php define("HOST"," ...

  10. mysql启动服务

    mysql.server start 启动mysql服务mysql.server stop 停止mysql服务 mysql密码:123456Az_