[TensorFlow 2.0] Keras 简介
Keras 是一个用于构建和训练深度学习模型的高阶 API。它可用于快速设计原型、高级研究和生产。
keras的3个优点: 方便用户使用、模块化和可组合、易于扩展
简单点说就是,简单、好用、快(构建)
引用方法:
import tensorflow as tf
from tensorflow.keras import layers
简单构建一个模型
先上代码
input_x = tf.keras.Input(shape=(72,))
hidden1 = layers.Dense(32, activation='relu')(input_x)
hidden2 = layers.Dense(16, activation='relu')(hidden1)
pred = layers.Dense(10, activation='softmax')(hidden2)
model = tf.keras.Model(inputs=input_x, outputs=pred)
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
loss=tf.keras.losses.categorical_crossentropy,
metrics=['accuracy'])
model.fit(train_x, train_y, batch_size=32, epochs=5)
tf.keras.Input
实例化一个 Keras tensor.
doc:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Input
方法定义如下:
tf.keras.Input(
shape=None,
batch_size=None,
name=None,
dtype=None,
sparse=False,
tensor=None,
**kwargs
)
tf.keras.layers.Dense
定义你的神经网络
这里官方在文档里加了一个 “densely-connected”来形容
doc:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/layers/Dense

tf.keras.Model
Model groups layers into an object with training and inference features.
额,不知道该怎么翻译了
我的理解就是把你的之前定义的网络给链接起来
跟上上面的代码可能不太好理解,等看后面把模型按照面向对象的思想构建的时候,就方便理解了。
当然,你可以自己去看下文档
doc:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model
tf.keras.Model.compile
配置模型训练时的相关数据
配置模型训练时,使用的相关参数。
比如,学习率、loss 等等
doc:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model#compile

model.fit
进行模型训练
doc:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model#fit

模型-OOP
通过OOP的思想,进行设计我们的模型
先看代码
class MyModel(tf.keras.Model):
def __init__(self, num_classes=10):
super(MyModel, self).__init__(name='my_model')
self.num_classes = num_classes
self.layer1 = layers.Dense(32, activation='relu')
self.layer2 = layers.Dense(num_classes, activation='softmax')
def call(self, inputs):
h1 = self.layer1(inputs)
out = self.layer2(h1)
return out
def compute_output_shape(self, input_shape):
shape = tf.TensorShape(input_shape).as_list()
shape[-1] = self.num_classes
return tf.TensorShape(shape)
model = MyModel(num_classes=10)
model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),
loss=tf.keras.losses.categorical_crossentropy,
metrics=['accuracy'])
model.fit(train_x, train_y, batch_size=16, epochs=5)
如果我们需要通过类来构造我们的模型,那么以下几点是必须的
1.继承tf.keras.Model
2.在__init__中调用以下父类,并构造我们的模型
super(MyModel, self).__init__(name='my_model')
self.num_classes = num_classes
self.layer1 = layers.Dense(32, activation='relu')
self.layer2 = layers.Dense(num_classes, activation='softmax')
3.在call中实现forward
4.compute_output_shape这个方法,在文档中,说明只有当模型修改了输入数据的形状时,才需要进行定义,否则没有必要。
但具体的效果,没有找到样例来参考。如果有懂得大神,希望举个例子。
模型保存
单独保存权重
model.save_weights('./weights/model')
model.load_weights('./weights/model')
model.save_weights('./model.h5')
model.load_weights('./model.h5')
单独保存模型结构
#json
model.to_json()
#yam
model.to_yaml()
保存整个模型
model.save('all_model.h5')
[TensorFlow 2.0] Keras 简介的更多相关文章
- 『TensorFlow2.0正式版』TF2.0+Keras速成教程·零:开篇简介与环境准备
此篇教程参考自TensorFlow 2.0 + Keras Crash Course,在原文的基础上进行了适当的总结与改编,以适应于国内开发者的理解与使用,水平有限,如果写的不对的地方欢迎大家评论指出 ...
- TensorFlow 2.0 新特性
安装 TensorFlow 2.0 Alpha 本文仅仅介绍 Windows 的安装方式: pip install tensorflow==2.0.0-alpha0 # cpu 版本 pip inst ...
- 三分钟快速上手TensorFlow 2.0 (后续)——扩展和附录
TensorFlow Hub 模型复用 TF Hub 网站 打开主页 https://tfhub.dev/ ,在左侧有 Text.Image.Video 和 Publishers 等选项,可以选取关注 ...
- Tensorflow 2.0 深度学习实战 —— 详细介绍损失函数、优化器、激活函数、多层感知机的实现原理
前言 AI 人工智能包含了机器学习与深度学习,在前几篇文章曾经介绍过机器学习的基础知识,包括了监督学习和无监督学习,有兴趣的朋友可以阅读< Python 机器学习实战 >.而深度学习开始只 ...
- TensorFlow 2.0 快速入门指南 | iBooker·ApacheCN
原文:TensorFlow 2.0 Quick Start Guide 协议:CC BY-NC-SA 4.0 自豪地采用谷歌翻译 不要担心自己的形象,只关心如何实现目标.--<原则>,生活 ...
- 【深度学习与TensorFlow 2.0】卷积神经网络(CNN)
注:在很长一段时间,MNIST数据集都是机器学习界很多分类算法的benchmark.初学深度学习,在这个数据集上训练一个有效的卷积神经网络就相当于学习编程的时候打印出一行“Hello World!”. ...
- Tensorflow 2.0 datasets数据加载
导入包 import tensorflow as tf from tensorflow import keras 加载数据 tensorflow可以调用keras自带的datasets,很方便,就是有 ...
- 独家 | TensorFlow 2.0将把Eager Execution变为默认执行模式,你该转向动态计算图了
机器之心报道 作者:邱陆陆 8 月中旬,谷歌大脑成员 Martin Wicke 在一封公开邮件中宣布,新版本开源框架——TensorFlow 2.0 预览版将在年底之前正式发布.今日,在上海谷歌开发者 ...
- 极简安装 TensorFlow 2.0 GPU
前言 之前写了几篇关于 TensorFlow 1.x GPU 版本安装的博客,但几乎没怎么学习过.之前基本在搞 Machine Learning 和 Data Mining 方面的东西,极少用到 NN ...
随机推荐
- html-前端内容初识
HTML解释: HTML是英文Hyper Text Mark-up Language(超文本标记语言)的缩写,他是一种制作万维网页面标准语言(标记).相当于定义统一的规则(W3C),大家都来遵守他,这 ...
- struts2学习2
拦截器 //拦截器:第一种创建方式 //拦截器生命周期:随项目的启动而创建,随项目关闭而销毁 public class MyInterceptor implements Interceptor { @ ...
- 网络协议 9 - TCP协议(下)
上次了解了 TCP 建立连接与断开连接的过程,我们发现,TCP 会通过各种“套路”来保证传输数据的安全.除此之外,我们还大概了解了 TCP 包头格式所对应解决的五个问题:顺序问题.丢包问题.连接维护. ...
- 洛谷p1902刺杀大使题解
题目传送门 方法:二分答案+dfs 二分一个mid,此次刺杀的最大伤害,作为判断条件来dfs,二分,更新. 我们二分一个答案mid来表示一个界限,如果当前这个格子的伤害代价比mid小则可以走否则就不走 ...
- 【BigData】Java基础_创建一个订单类
需求描述 定义一个类,描述订单信息订单id订单所属用户(用户对象)订单所包含的商品(不定数量个商品对象)订单总金额订单应付金额: 总金额500~1000,打折85折 总金额1000~150 ...
- NIO (一) NIO是什么
参考文档:java为什么需要NIO:https://liuchi.coding.me/2017/08/01/浅谈Java为什么需要NIO/美团技术团队 NIO浅析:https://tech.meitu ...
- Guava cacha 机制及源码分析
1.ehcahce 什么时候用比较好:2.问题:当有个消息的key不在guava里面的话,如果大量的消息过来,会同时请求数据库吗?还是只有一个请求数据库,其他的等待第一个把数据从DB加载到Guava中 ...
- 可能是全网最好的MySQL重要知识点 | 面试必备
可能是全网最好的MySQL重要知识点 | 面试必备 mp.weixin.qq.com 点击蓝色“程序猿DD”关注我 回复“资源”获取独家整理的学习资料! 标题有点标题党的意思,但希望你在看了文章之后 ...
- cad.net GeometricExtents出错了 调试看不到文字
飞诗: 难道块不能取GeometricExtents GeometryExtentsBestFit 用这个解决 GeometryExtentsBestFit 对动态块也不准 com方式 ...
- PG数据库CPU和内存满负荷运转优化案
1.问题描述 某客户系统采用三层架构:数据库—应用服务—前端应用.其中数据库使用PostgreSQL 10.0作为数据库软件.自周四起,服务器的CPU与内存使用率持续处于过饱合状态,并因此导致了数次宕 ...