《TensorFlow2深度学习》学习笔记(四)对笔记二中的模型增加正确率展示
全部代码如下:(红色部分为与笔记二不同之处)
#1.Import the neccessary libraries needed
import numpy as np
import tensorflow as tf
import matplotlib
from matplotlib import pyplot as plt ######################################################################## #2.Set default parameters for plots
matplotlib.rcParams['font.size'] = 20
matplotlib.rcParams['figure.titlesize'] = 20
matplotlib.rcParams['figure.figsize'] = [9, 7]
matplotlib.rcParams['font.family'] = ['STKaiTi']
matplotlib.rcParams['axes.unicode_minus']=False ######################################################################## #3.Initialize Parameters #Initialize learning rate
lr = 1e-2 #----------------------changed
#Initialize batch size
batchsz = 512
#Initialize loss and accurate array
losses = []
accs = [] #----------------------changed
#Initialize the weights layers and the bias layers
w1=tf.Variable(tf.random.truncated_normal([784,256],stddev=0.1))
b1=tf.Variable(tf.zeros([256]))
w2=tf.Variable(tf.random.truncated_normal([256,128],stddev=0.1))
b2=tf.Variable(tf.zeros([128]))
w3=tf.Variable(tf.random.truncated_normal([128,10],stddev=0.1))
b3=tf.Variable(tf.zeros([10])) ########################################################################
#4.Define preprocess function #----------------------changed
def preprocess(x,y):
x=tf.cast(x,dtype=tf.float32)/255.
x=tf.reshape(x,[-1,28*28])
y=tf.cast(y,dtype=tf.int32)
#one_hot接受的输入为int32,输出为float32
y=tf.one_hot(y,depth=10)
return x,y ######################################################################## #5.Import the minist dataset offline
(x_train,y_train),(x_test,y_test)=tf.keras.datasets.mnist.load_data(path=r'F:\learning\machineLearning\TensorFlow2_deeplearning\forward_progression\mnist.npz')
train_db=tf.data.Dataset.from_tensor_slices((x_train,y_train))
train_db=train_db.shuffle(10000) #-----------------------changed
train_db=train_db.batch(batchsz)
train_db=train_db.map(preprocess)
#Control the epoch times
train_db=train_db.repeat(20) test_db=tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db=test_db.shuffle(1000).batch(batchsz).map(preprocess) ######################################################################## #The main function
def main():
for step,(x,y) in enumerate(train_db):#Or for x,y in train_db:
with tf.GradientTape() as tape: # tf.Variable
# layer1
h1 = x@w1 + b1
h1 = tf.nn.relu(h1)
# layer2
h2 = h1@w2 + b2
h2 = tf.nn.relu(h2)
# output
out = h2@w3 + b3
# compute loss
loss = tf.square(y-out)
# mean: scalar
loss = tf.reduce_mean(loss)
# compute gradients
grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
#Update the weights and the bias #-----------------------changed
for p, g in zip([w1, b1, w2, b2, w3, b3], grads):
p.assign_sub(lr * g) if step % 80 == 0:
print(step, 'loss:', float(loss))
losses.append(float(loss)) if step % 80 == 0: #-----------------------changed
total, total_correct = 0., 0
for x,y in test_db:
# layer1
h1 = x@w1 + b1
h1 = tf.nn.relu(h1)
# layer2
h2 = h1@w2 + b2
h2 = tf.nn.relu(h2)
# output
out = h2@w3 + b3
pred=tf.argmax(out,axis=1)
y=tf.argmax(y,axis=1)
correct=tf.equal(pred,y)
total_correct+=tf.reduce_sum(tf.cast(correct,dtype=tf.int32)).numpy()
total+=x.shape[0]
print(step,'Evaluate ACC:',total_correct/total)
accs.append(total_correct/total)
plt.figure()
x = [i*80 for i in range(len(losses))]
plt.plot(x, losses, color='C0', marker='s', label='训练')
plt.ylabel('MSE')
plt.xlabel('Step')
plt.legend() plt.figure()
plt.plot(x, accs, color='C1', marker='s', label='测试')
plt.ylabel('准确率')
plt.xlabel('Step')
plt.legend() plt.show()
if __name__ == '__main__':
main()
其中learning rate在此处改为了1e-2,经测试若为1e-3则accurate rate会增长较慢,在20epoch下最终会达到30~40%,而1e-2则会接近80%
并且通过.map(preprocess)方法预处理了train_db,包括将图片数据标准化到(0-1),reshape到[-1,28*28],将标签数据做one-hot处理,深度为10;通过train_db=train_db.repeat(20)代替了for epoch in range(20);用
for p, g in zip([w1, b1, w2, b2, w3, b3], grads):
p.assign_sub(lr * g)
w1.assign_sub(lr * grads[0])
b1.assign_sub(lr * grads[1])
w2.assign_sub(lr * grads[2])
b2.assign_sub(lr * grads[3])
w3.assign_sub(lr * grads[4])
b3.assign_sub(lr * grads[5])
《TensorFlow2深度学习》学习笔记(四)对笔记二中的模型增加正确率展示的更多相关文章
- ThinkPHP 学习笔记 ( 四 ) 数据库操作之关联模型 ( RelationMondel ) 和高级模型 ( AdvModel )
一.关联模型 ( RelationMondel ) 1.数据查询 ① HAS_ONE 查询 创建两张数据表评论表和文章表: tpk_comment , tpk_article .评论和文章的对应关系为 ...
- 深度学习课程笔记(十四)深度强化学习 --- Proximal Policy Optimization (PPO)
深度学习课程笔记(十四)深度强化学习 --- Proximal Policy Optimization (PPO) 2018-07-17 16:54:51 Reference: https://b ...
- 官网实例详解-目录和实例简介-keras学习笔记四
官网实例详解-目录和实例简介-keras学习笔记四 2018-06-11 10:36:18 wyx100 阅读数 4193更多 分类专栏: 人工智能 python 深度学习 keras 版权声明: ...
- C#可扩展编程之MEF学习笔记(四):见证奇迹的时刻
前面三篇讲了MEF的基础和基本到导入导出方法,下面就是见证MEF真正魅力所在的时刻.如果没有看过前面的文章,请到我的博客首页查看. 前面我们都是在一个项目中写了一个类来测试的,但实际开发中,我们往往要 ...
- iOS阶段学习第四天笔记(循环)
iOS学习(C语言)知识点整理笔记 一.分支结构 1.分支结构分为单分支 即:if( ){ } ;多分支 即:if( ){ }else{ } 两种 2.单分支 if表达式成立则执行{ }里的语句:双 ...
- IOS学习笔记(四)之UITextField和UITextView控件学习
IOS学习笔记(四)之UITextField和UITextView控件学习(博客地址:http://blog.csdn.net/developer_jiangqq) Author:hmjiangqq ...
- java之jvm学习笔记四(安全管理器)
java之jvm学习笔记四(安全管理器) 前面已经简述了java的安全模型的两个组成部分(类装载器,class文件校验器),接下来学习的是java安全模型的另外一个重要组成部分安全管理器. 安全管理器 ...
- Java学习笔记四---打包成双击可运行的jar文件
写笔记四前的脑回路是这样的: 前面的学习笔记二,提到3个环境变量,其中java_home好理解,就是jdk安装路径:classpath指向类文件的搜索路径:path指向可执行程序的搜索路径.这里的类文 ...
- Learning ROS for Robotics Programming Second Edition学习笔记(四) indigo devices
中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...
随机推荐
- EasyDSS高性能RTMP、HLS(m3u8)、HTTP-FLV、RTSP流媒体服务器软件实现的多码率视频点播功能说明
关于EasyDSS EasyDSS(http://www.easydss.com)流媒体解决方案采用业界优秀的流媒体框架模式设计,服务运行轻量.高效.稳定.可靠.易维护,支持RTMP直播.RTMP推送 ...
- LwIP应用开发笔记之七:LwIP无操作系统HTTP服务器
前面我们实现了TCP服务器和客户端的简单应用,接下来我们实现一个基于TCP协议的应用协议,那就是HTTP超文本传输协议 1. HTTP协议简介 超文本传输协议(Hyper Text Transf ...
- zabbix自动停用与开启agent
我们在升级环境时遇到了一个问题,那就是zabbix会自动发送邮件给领导,此时领导心里会嘎嘣一下,为了给领导营造一个良好的环境,减少不必要的告警邮件,减少嘎嘣次数,于是在升级之前,取消zabbix监控的 ...
- 晶体管放大电路与Multisim仿真学习笔记
前言 开始写点博客记录学习的点滴,第一篇就写基本的共射极放大电路吧. 很多教材都是偏重理论,而铃木雅臣著作的<晶体管电路设计>是一本很实用的书籍,个人十分推荐! 下面开始我的模电重温之旅吧 ...
- 【记录】【java】反射设值取值
1.设值 /** * 根据属性名设置属性值 * * @param fieldName * @param object * @return */ public boolean setFieldValue ...
- Mybaties的简单使用(全当做复习了)
在使用mybaties的时候,最容易忘掉的是他的动态SQL,不过网上有关这方面的文章很多. 在动态SQl中最常见的几种SQL的语法就是: if choose (when, otherwise) tri ...
- TCP/IP学习笔记1--概述,分组交换协议
1.TCP/IP 互联网是由许多独立发展的网络通信技术融合而成的,能够使它们不断融合并实现统一的正式TCP/IP技术,TCP/IP使通信协议的统称. TCP/IP协议模型(Transmission C ...
- Flume和 Sqoop
Sqoop简介 Sqoop是一种旨在有效地在Apache Hadoop和诸如关系数据库等结构化数据存储之间传输大量数据的工具 原理: 将导入或导出命令翻译成Mapreduce程序来实现. 在翻译出的M ...
- js中基本包装类型详情
基本包装类型 基本包装类型有Boolean,Number和string类型,每当读取一个基本类型值时,后台就会创建一个对应的基本包装类型对象. 从逻辑上,基本类型值不是对象,没有方法,但从技术上来看, ...
- 大数据之路【第十四篇】:数据挖掘--推荐算法(Mahout工具)
数据挖掘---推荐算法(Mahout工具) 一.简介 Apache顶级项目(2010.4) Hadoop上的开源机器学习库 可伸缩扩展的 Java库 推荐引擎(协同过滤).聚类和分类 二.机器学习介绍 ...