【机器学习炼丹术】的炼丹总群已经快满了,要加入的快联系炼丹兄WX:cyx645016617

参考目录:

之前讲过了如何用tensorflow构建数据集,然后这一节课讲解如何用Tensorflow2.0来创建模型。

TF2.0中创建模型的API基本上都放到了它的Keras中了,Keras可以理解为TF的高级API,里面封装了很多的常见网络层、常见损失函数等。 后续会详细介绍keras的全面功能,本篇文章讲解如何构建模型。

1 创建自定义网络层

import tensorflow as tf
import tensorflow.keras as keras class MyLayer(keras.layers.Layer):
def __init__(self, input_dim=32, output_dim=32):
super(MyLayer, self).__init__() w_init = tf.random_normal_initializer()
self.weight = tf.Variable(
initial_value=w_init(shape=(input_dim, output_dim), dtype=tf.float32),
trainable=True) # 如果是false则是不参与梯度下降的变量 b_init = tf.zeros_initializer()
self.bias = tf.Variable(initial_value=b_init(
shape=(output_dim), dtype=tf.float32), trainable=True) def call(self, inputs):
return tf.matmul(inputs, self.weight) + self.bias x = tf.ones((3,5))
my_layer = MyLayer(input_dim=5,
output_dim=10)
out = my_layer(x)
print(out.shape)
>>> (3, 10)

这个就是定义了一个TF的网络层,其实可以看出来和PyTorch定义的方式非常的类似:

  • 这个类要继承tf.keras.layers.Layer,这个pytorch中要继承torch.nn.Module类似;
  • 网络层的组件在__def__中定义,和pytorch的模型类相同;
  • call()和pytorch中的forward()的类似。

上面代码中实现的是一个全连接层的定义,其中可以看到使用tf.random_normal_initializer()来作为参数的初始化器,然后用tf.Variable来产生网络层中的权重变量,通过trainable=True这个参数说明这个权重变量是一个参与梯度下降的可以训练的变量。

我通过tf.ones((3,5))产生一个shape为[3,5]的一个全是1的张量,这里面第一维度的3表示有3个样本,第二维度的5就是表示要放入全连接层的数据(全连接层的输入是5个神经元);然后设置的全连接层的输出神经元数量是10,所以最后的输出是(3,10)。

2 创建一个完整的CNN

import tensorflow as tf
import tensorflow.keras as keras class CBR(keras.layers.Layer):
def __init__(self,output_dim):
super(CBR,self).__init__()
self.conv = keras.layers.Conv2D(filters=output_dim, kernel_size=4, padding='same', strides=1)
self.bn = keras.layers.BatchNormalization(axis=3)
self.ReLU = keras.layers.ReLU() def call(self, inputs):
inputs = self.conv(inputs)
inputs = self.ReLU(self.bn(inputs))
return inputs class MyNet(keras.Model):
def __init__ (self,input_dim=3):
super(MyNet,self).__init__()
self.cbr1 = CBR(16)
self.maxpool1 = keras.layers.MaxPool2D(pool_size=(2,2))
self.cbr2 = CBR(32)
self.maxpool2 = keras.layers.MaxPool2D(pool_size=(2,2)) def call(self, inputs):
inputs = self.maxpool1(self.cbr1(inputs))
inputs = self.maxpool2(self.cbr2(inputs))
return inputs model = MyNet(3)
data = tf.random.normal((16,224,224,3))
output = model(data)
print(output.shape)
>>> (16, 56, 56, 32)

这个是构建了一个非常简单的卷积网络,结构是常见的:卷积层+BN层+ReLU层。可以发现这里继承的一个tf.keras.Model这个类。

2.1 keras.Model vs keras.layers.Layer

Model比Layer的功能更多,反过来说,Layer的功能更精简专一。

  • Layer:仅仅用作张量的操作,输入一个张量,输出也要求是一个张量,对张量的操作都可以用Layer来封装;
  • Model:一个更加复杂的结构,由多个Layer组成。 Model的话,可以使用.fit(),.evaluate().predict()等方法来快速训练。保存和加载模型也是在Model这个级别进行的。

现在说一说上面的代码和pytorch中的区别,作为一个对比学习、也作为一个对pytorch的回顾:

  • 卷积层Conv2D中,Keras中不用输入输入的通道数,filters就是卷积后的输出特征图的通道数;而PyTorch的卷积层是需要输入两个通道数的参数,一个是输入特征图的通道数,一个是输出特征图的通道数;
  • keras.layers.BatchNormalization(axis=3)是BN层,这里的axis=3说明第三个维度(从0开始计数)是通道数,是需要作为批归一化的维度(这个了解BN算法的朋友应该可以理解吧,不了解的话去重新看我之前剖析BN层算法的那个文章吧,在文章末尾有相关链接)。pytorch的图像的四个维度是:
\[【样本数量,通道数,width,height】
\]

而tensorflow是:

\[【样本数量,width,height,通道数】
\]

总之,学了pytorch之后,再看keras的话,对照的keras的API,很多东西都直接就会了,两者的API越来越相似了。

上面最后输出是(16, 56, 56, 32),输入的是\(224\times 224\)的维度,然后经过两个最大池化层,就变成了\(56\times 56\)了。

到此为止,我们现在应该是可以用keras来构建模型了。

【小白学PyTorch】18 TF2构建自定义模型的更多相关文章

  1. 【小白学PyTorch】4 构建模型三要素与权重初始化

    文章目录: 目录 1 模型三要素 2 参数初始化 3 完整运行代码 4 尺寸计算与参数计算 1 模型三要素 三要素其实很简单 必须要继承nn.Module这个类,要让PyTorch知道这个类是一个Mo ...

  2. 【小白学PyTorch】20 TF2的eager模式与求导

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  3. 【小白学PyTorch】19 TF2模型的存储与载入

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  4. 【小白学PyTorch】15 TF2实现一个简单的服装分类任务

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  5. 小白学PyTorch 动态图与静态图的浅显理解

    文章来自公众号[机器学习炼丹术],回复"炼丹"即可获得海量学习资料哦! 目录 1 动态图的初步推导 2 动态图的叶子节点 3. grad_fn 4 静态图 本章节缕一缕PyTorc ...

  6. 【小白学PyTorch】5 torchvision预训练模型与数据集全览

    文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...

  7. 【小白学PyTorch】6 模型的构建访问遍历存储(附代码)

    文章转载自微信公众号:机器学习炼丹术.欢迎大家关注,这是我的学习分享公众号,100+原创干货. 文章目录: 目录 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 ...

  8. 【小白学PyTorch】16 TF2读取图片的方法

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.NLP等多个学术交流分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx645016617. 参考 ...

  9. 【小白学PyTorch】11 MobileNet详解及PyTorch实现

    文章来自微信公众号[机器学习炼丹术].我是炼丹兄,欢迎加我微信好友交流学习:cyx645016617. @ 目录 1 背景 2 深度可分离卷积 2.2 一般卷积计算量 2.2 深度可分离卷积计算量 2 ...

随机推荐

  1. soso官方:搜索引擎的对检索结果常用的评测方法

    http://www.wocaoseo.com/thread-188-1-1.html       很久很久以前,搜索引擎还不象今天的百花齐放,人们对它的要求较低,只要它能把互连网上相关的网站搜出来, ...

  2. Htmlcss学习笔记1——html基础

    Hyper text markup language 超文本标签语言.不是一种编程语言,而是一种标记语言标记语言是一套标记标签 开发工具 chrome subline vscode photoshop ...

  3. asterisk 传真服务器配置

    摘要: asterisk 可以作为电子传真服务器,进行收发电子传真.但是配置起来,比较麻烦,需要一番折腾.在这儿分享一下电子传真的配置,希望对朋友们有所帮助. 正题: asterisk 如果需要收发电 ...

  4. 据说是最好的记忆工具——Anki

    http://www.ankichina.net/ .u1s1,确实挺好用,自建题库,全程自助. 可以插入文字.图片.音频,会安排合理的复习频率,可以随时同步,电脑手机版本全.

  5. Codeforces1365

    AC代码 A. Matrix Game 对于给定矩阵,剩余可用的位置的数目是确定的,根据奇偶性判断就完事了. B. Trouble Sort 如果数组\(b\)有0有1,那么Yes.否则只有数组\(a ...

  6. [BUUOJ记录] [ACTF2020 新生赛]BackupFile、Exec

    两道题都比较简单,所以放到一块记下来吧,不是水博客,师傅们轻点打 BackupFile 题目提示“Try to find out source file!”,访问备份文件/index.php.bak获 ...

  7. JVM学习第三天(JVM的执行子系统)之类加载机制

    好几天没有学习了,前几天因为导出的事情,一直在忙,今天继续学习, 其实今天我也遇到了一个问题,如果有会的兄弟可以评论留给我谢谢; 问题:fastJSON中JSONObject.parseObject做 ...

  8. 【Azure DevOps系列】使ASP.NET Core应用程序托管到Azure Web App Service

    使用Azure DevOps Project设置ASP.NET项目 我们需要先在Azure面板中创建一个Azure WebApp服务,此处步骤我将省略,然后点击部署中心如下图所示: 此处我选择的是Az ...

  9. JS实现串行请求

    使用async和await var fn = async function(promiseArr) { for(let i = 0,len = arr.length; i<len; i++) { ...

  10. mac如何安装YaPi

    首先介绍一下YaPi是干什么的. 帮助开发者轻松创建.发布.维护 API,YApi 还为用户提供了优秀的交互体验,开发人员只需利用平台提供的接口数据写入工具以及简单的点击操作就可以实现接口的管理.免费 ...