Keras != tf.keras

  • Keras是一个框架

  • datasets

  • layers

  • losses

  • metrics

  • optimizers

Outline1

  • Metrics

  • update_state

  • result().numpy()

  • reset_states

Metrics

Step1.Build a meter

acc_meter = metrics.Accuarcy()
loss_meter = metrics.Mean

Step2.Update data

loss_meter.update_state(loss)
acc_meter.update_state(y,pred)

Step3.Get Average data

print(step, 'loss:', loss_meter.result().numpy())
# ...
print(step,'Evaluate Acc:', total_correct/total, acc_meter.result().numpy()

Clear buffer

if step % 100 == 0:
print(step, 'loss:', loss_meter.result().numpy())
loss_meter.reset_states() # ... if step % 500 == 0:
total, total_correct = 0., 0
acc_meter.reset_states()

Outline2

  • Compile

  • Fit

  • Evaluate

  • Predict

Compile + Fit

Individual loss and optimize1

with tf.GradientTape() as tape:
x = tf.reshape(x, (-1, 28*28))
out = network(x)
y_onehot = tf.one_hot(y, depth=10)
loss = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, out, from_logits=True)) grads = tape.gradient(loss, network.trainable_variables)
optimizer.apply_gradients(zip(grads, network.trainable_variables))

Now1

network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(fromlogits=True),
metircs=['accuracy'])

Individual epoch and step2

for epoch in range(epochs):
for step, (x, y) in enumerate(db):
# ...

Now2

network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(fromlogits=True),
metircs=['accuracy']) network.fit(db, epochs=10)

Standard Progressbar

Individual evaluation3

if step % 500 == 0:
total, total_correct = 0., 0 for step, (x, y) in enumerate(ds_val):
x = tf.reshape(x, (-1, 28*28))
out = network(x)
pred = tf.argmax(out, axis=1)
pred = tf.cast(pred, dtype=tf.int32)
correct = tf.equal(pred, y)
total_correct += tf.reduce_sum(tf.cast(correct, dtype=tf.int32)).numpy()
total += x.shape[0] print(step, 'Evaluate Acc:', total_correct/total)

Now3

network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(fromlogits=True),
metircs=['accuracy']) # validation_freq=2表示2个epochs做一次验证
network.fit(db, epochs=10, validation_data=ds_val, validation_freq=2)

Evaluation

Test

network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(fromlogits=True),
metircs=['accuracy']) # validation_freq=2表示2个epochs做一次验证
network.fit(db, epochs=10, validation_data=ds_val, validation_freq=2) network.evaluate(ds_val)

Predict

sample = next(iter(ds_val))
x = sample[0]
y = sample[1]
pred = network.predict(x)
y = tf.argmax(y, axis=1)
pred = tf.argmax(pre, axis=1) print(pred)
print(y)

Kera高层API的更多相关文章

  1. Flask 框架下 Jinja2 模板引擎高层 API 类——Environment

    Environment 类版本: 本文所描述的 Environment 类对应于 Jinja2-2.7 版本.   Environment 类功能: Environment 是 Jinja2 中的一个 ...

  2. 手写数字识别——利用keras高层API快速搭建并优化网络模型

    在<手写数字识别——手动搭建全连接层>一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配.梯度计算.准确度的统计等问题,但 ...

  3. Tcl脚本调用高层API实现仪表使用和主机创建配置的自己主动化測试用例

    #设置Chassis的基本參数,包含IP地址.port的数量等等 set chassisAddr 10.132.238.190 set islot 1 set portList {11 12} ;#端 ...

  4. Keras高层API之Metrics

    在tf.keras中,metrics其实就是起到了一个测量表的作用,即测量损失或者模型精度的变化.metrics的使用分为以下四步: step1:Build a meter acc_meter = m ...

  5. 理解 OpenStack + Ceph (3):Ceph RBD 接口和工具 [Ceph RBD API and Tools]

    本系列文章会深入研究 Ceph 以及 Ceph 和 OpenStack 的集成: (1)安装和部署 (2)Ceph RBD 接口和工具 (3)Ceph 物理和逻辑结构 (4)Ceph 的基础数据结构 ...

  6. 分布式消息队列kafka系列介绍 — 核心API介绍及实例

    原文地址:http://www.inter12.org/archives/834 一 PRODUCER的API 1.Producer的创建,依赖于ProducerConfig public Produ ...

  7. API设计原则(觉得太合适,转发做记录)

    API设计原则 对于云计算系统,系统API实际上处于系统设计的统领地位,正如本文前面所说,K8s集群系统每支持一项新功能,引入一项新技术,一定会新引入对应的API对象,支持对该功能的管理操作,理解掌握 ...

  8. TensorFlow高级API(tf.contrib.learn)及可视化工具TensorBoard的使用

    一.TensorFlow高层次机器学习API (tf.contrib.learn) 1.tf.contrib.learn.datasets.base.load_csv_with_header 加载cs ...

  9. 蓝牙中文API文档

    蓝牙是一种低成本.短距离的无线通信技术.对于那些希望创建个人局域网(PANs)的人们来说,蓝牙技术已经越来越流行了.每个个人局域网都在独立设备的周围被动态地创建,并且为蜂窝式电话和PDA等设备提供了自 ...

随机推荐

  1. static 静态域 类域 静态方法 工厂方法 he use of the static keyword to create fields and methods that belong to the class, rather than to an instance of the class 非访问修饰符

    总结: 1.无论一个类实例化多少对象,它的静态变量只有一份拷贝: 静态域属于类,而非由类构造的实例化的对象,所有类的实例对象共享静态域. class Employee { private static ...

  2. 对ShortCut和TWMKey的研究

    TWMKey = packed record Msg: Cardinal; CharCode: Word; Unused: Word; KeyData: Longint; Result: Longin ...

  3. 设置开启telnet功能

    今天访问服务器的时候发现ip可以ping通,但是不能访问,就telnet一下端口吧,谁知系统逗我:

  4. 阿里妈妈-RAP项目的实践(2)

    接口详情 (id: 32872) Mock数据 接口名称 datalist1 请求类型 get 请求Url /datas/list1 接口描述 数据列表 请求参数列表 变量名 含义 类型 备注 响应参 ...

  5. hihocoder #1122 二分图二•二分图最大匹配之匈牙利算法(*【模板】应用 )

    梳理整个算法: 1. 依次枚举每一个点i: 2. 若点i尚未匹配,则以此点为起点查询一次交错路径. 最后即可得到最大匹配数. 在这个基础上仍然有两个可以优化的地方: 1.对于点的枚举:当我们枚举了所有 ...

  6. C++中如何计算程序运行的时间 (转载)

    转载地址:http://blog.csdn.net/wuxuguang123/article/details/8130081 一 个程序的功能通常有很多种方法来实现,怎么样的程序才算得上最优呢?举个例 ...

  7. Gym - 101147E E. Jumping —— bfs

    题目链接:http://codeforces.com/gym/101147/problem/E 题意:当人在第i个商店时,他可以向左或向右跳di段距离到达另一个商店(在范围之内),一个商店为一段距离. ...

  8. 恶心的struts标签,等我毕业设计弄完了,瞧我怎么收拾你。

    1.从java action中到页面中获取变量值的struts标签 获取从bean中定义的对象中属性的值: <s:property value="#request.cardTo.acc ...

  9. UniCode转码

    <script type="text/javascript"> var GB2312UnicodeConverter = { ToUnicode: function ( ...

  10. 「USACO06FEB」「LuoguP2858」奶牛零食Treats for the Cows(区间dp

    题目描述 FJ has purchased N (1 <= N <= 2000) yummy treats for the cows who get money for giving va ...