Keras != tf.keras

  • Keras是一个框架

  • datasets

  • layers

  • losses

  • metrics

  • optimizers

Outline1

  • Metrics

  • update_state

  • result().numpy()

  • reset_states

Metrics

Step1.Build a meter

acc_meter = metrics.Accuarcy()
loss_meter = metrics.Mean

Step2.Update data

loss_meter.update_state(loss)
acc_meter.update_state(y,pred)

Step3.Get Average data

print(step, 'loss:', loss_meter.result().numpy())
# ...
print(step,'Evaluate Acc:', total_correct/total, acc_meter.result().numpy()

Clear buffer

if step % 100 == 0:
print(step, 'loss:', loss_meter.result().numpy())
loss_meter.reset_states() # ... if step % 500 == 0:
total, total_correct = 0., 0
acc_meter.reset_states()

Outline2

  • Compile

  • Fit

  • Evaluate

  • Predict

Compile + Fit

Individual loss and optimize1

with tf.GradientTape() as tape:
x = tf.reshape(x, (-1, 28*28))
out = network(x)
y_onehot = tf.one_hot(y, depth=10)
loss = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, out, from_logits=True)) grads = tape.gradient(loss, network.trainable_variables)
optimizer.apply_gradients(zip(grads, network.trainable_variables))

Now1

network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(fromlogits=True),
metircs=['accuracy'])

Individual epoch and step2

for epoch in range(epochs):
for step, (x, y) in enumerate(db):
# ...

Now2

network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(fromlogits=True),
metircs=['accuracy']) network.fit(db, epochs=10)

Standard Progressbar

Individual evaluation3

if step % 500 == 0:
total, total_correct = 0., 0 for step, (x, y) in enumerate(ds_val):
x = tf.reshape(x, (-1, 28*28))
out = network(x)
pred = tf.argmax(out, axis=1)
pred = tf.cast(pred, dtype=tf.int32)
correct = tf.equal(pred, y)
total_correct += tf.reduce_sum(tf.cast(correct, dtype=tf.int32)).numpy()
total += x.shape[0] print(step, 'Evaluate Acc:', total_correct/total)

Now3

network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(fromlogits=True),
metircs=['accuracy']) # validation_freq=2表示2个epochs做一次验证
network.fit(db, epochs=10, validation_data=ds_val, validation_freq=2)

Evaluation

Test

network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(fromlogits=True),
metircs=['accuracy']) # validation_freq=2表示2个epochs做一次验证
network.fit(db, epochs=10, validation_data=ds_val, validation_freq=2) network.evaluate(ds_val)

Predict

sample = next(iter(ds_val))
x = sample[0]
y = sample[1]
pred = network.predict(x)
y = tf.argmax(y, axis=1)
pred = tf.argmax(pre, axis=1) print(pred)
print(y)

Kera高层API的更多相关文章

  1. Flask 框架下 Jinja2 模板引擎高层 API 类——Environment

    Environment 类版本: 本文所描述的 Environment 类对应于 Jinja2-2.7 版本.   Environment 类功能: Environment 是 Jinja2 中的一个 ...

  2. 手写数字识别——利用keras高层API快速搭建并优化网络模型

    在<手写数字识别——手动搭建全连接层>一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配.梯度计算.准确度的统计等问题,但 ...

  3. Tcl脚本调用高层API实现仪表使用和主机创建配置的自己主动化測试用例

    #设置Chassis的基本參数,包含IP地址.port的数量等等 set chassisAddr 10.132.238.190 set islot 1 set portList {11 12} ;#端 ...

  4. Keras高层API之Metrics

    在tf.keras中,metrics其实就是起到了一个测量表的作用,即测量损失或者模型精度的变化.metrics的使用分为以下四步: step1:Build a meter acc_meter = m ...

  5. 理解 OpenStack + Ceph (3):Ceph RBD 接口和工具 [Ceph RBD API and Tools]

    本系列文章会深入研究 Ceph 以及 Ceph 和 OpenStack 的集成: (1)安装和部署 (2)Ceph RBD 接口和工具 (3)Ceph 物理和逻辑结构 (4)Ceph 的基础数据结构 ...

  6. 分布式消息队列kafka系列介绍 — 核心API介绍及实例

    原文地址:http://www.inter12.org/archives/834 一 PRODUCER的API 1.Producer的创建,依赖于ProducerConfig public Produ ...

  7. API设计原则(觉得太合适,转发做记录)

    API设计原则 对于云计算系统,系统API实际上处于系统设计的统领地位,正如本文前面所说,K8s集群系统每支持一项新功能,引入一项新技术,一定会新引入对应的API对象,支持对该功能的管理操作,理解掌握 ...

  8. TensorFlow高级API(tf.contrib.learn)及可视化工具TensorBoard的使用

    一.TensorFlow高层次机器学习API (tf.contrib.learn) 1.tf.contrib.learn.datasets.base.load_csv_with_header 加载cs ...

  9. 蓝牙中文API文档

    蓝牙是一种低成本.短距离的无线通信技术.对于那些希望创建个人局域网(PANs)的人们来说,蓝牙技术已经越来越流行了.每个个人局域网都在独立设备的周围被动态地创建,并且为蜂窝式电话和PDA等设备提供了自 ...

随机推荐

  1. Local Response Normalization 60 million parameters and 500,000 neurons

    CNN是工具,在图像识别中是发现图像中待识别对象的特征的工具,是剔除对识别结果无用信息的工具. ImageNet Classification with Deep Convolutional Neur ...

  2. 7-5 打印选课学生名单(25 point(s)) 【排序】

    7-5 打印选课学生名单(25 point(s)) 假设全校有最多40000名学生和最多2500门课程.现给出每个学生的选课清单,要求输出每门课的选课学生名单. 输入格式: 输入的第一行是两个正整数: ...

  3. Linux随笔-鸟哥Linux基础篇学习总结(全)

    Linux随笔-鸟哥Linux基础篇学习总结(全) 修改Linux系统语系:LANG-en_US,如果我们想让系统默认的语系变成英文的话我们可以修改系统配置文件:/etc/sysconfig/i18n ...

  4. C++中map容器的说明和使用技巧

    C++中map容器提供一个键值对容器,map与multimap差别仅仅在于multiple允许一个键对应多个值. 一.map的说明 1 头文件 #include <map> 2 定义 ma ...

  5. Java(一)——认识Java语言

    1.Java语言简介 Java是一种可以撰写跨平台应用程序的面向对象的程序设计语言,具有卓越的通用性.高效性.平台移植性和安全性.Sun 公司对 Java 编程语言的解释是:Java 编程语言是个简单 ...

  6. Relocation POJ-2923

    题目链接 题目意思: 有 n 个货物,并且知道了每个货物的重量,每次用载重量分别为c1,c2的火车装载,问最少需要运送多少次可以将货物运完. 分析:本题可以用二进制枚举所有不冲突的方案,再来dp 一下 ...

  7. keras中的Flatten和Reshape

    最近在看SSD源码的时候,就一直不理解,在模型构建的时候如果使用Flatten或者是Merge层,那么整个数据的shape就发生了变化,那么还可以对应起来么(可能你不知道我在说什么)?后来不知怎么的, ...

  8. 对Webview跨源攻击的理解

    首先是addJavaScriptInterface漏洞: 有时候访问手机百度贴吧网页版本,网页上会有个按钮提示用手机应用打开.这种交互通常都是使用JS来实现,而WebView已经提供了这样的方法,具体 ...

  9. kitti数据集标定文件解析

    1.kitti数据采集平台 KITTI数据集的数据采集平台装配有2个灰度摄像机,2个彩色摄像机,一个Velodyne64线3D激光雷达,4个光学镜头,以及1个GPS导航系统.图示为传感器的配置平面图, ...

  10. linux下监控用户的操作记录---录像播放性质

    想知道用户登陆系统后都操作了什么,怎么办? 别急,linux下有一个script工具,专门记录终端会话中所有输入输出结果,并存放到指定文件中. 先看看怎么录制吧! 1.创建日志存放目录 # mkdir ...