深度学习趣谈:什么是迁移学习?(附带Tensorflow代码实现)
一.迁移学习的概念
什么是迁移学习呢?迁移学习可以由下面的这张图来表示:

这张图最左边表示了迁移学习也就是把已经训练好的模型和权重直接纳入到新的数据集当中进行训练,但是我们只改变之前模型的分类器(全连接层和softmax/sigmoid),这样就可以节省训练的时间的到一个新训练的模型了!
但是为什么可以这么做呢?
二.为什么可以使用迁移学习?
一般在图像分类的问题当中,卷积神经网络最前面的层用于识别图像最基本的特征,比如物体的轮廓,颜色,纹理等等,而后面的层才是提取图像抽象特征的关键,因此最好的办法是我们只需要保留卷积神经网络当中底层的权重,对顶层和新的分类器进行训练即可。那么在图像分类问题当中,我们如何使用迁移学习呢?一般使用迁移学习,也就是预训练神经网络的步骤如下;
1.冻结预训练网络的卷积层权重
2.置换旧的全连接层,换上新的全连接层和分类器
3.解冻部分顶部的卷积层,保留底部卷积神经网络的权重
4.同时对卷积层和全连接层的顶层进行联合训练,得到新的网络权重
既然我们知道了迁移学习的基本特点,何不试试看呢?
三.迁移学习的代码实现
我们使用迁移学习的方法来进行猫狗图像的分类识别,猫猫的图像在我的文件夹里如下图所示:

然后导包:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
import glob
import os
获取图片的路径,标签,制作batch数据,图片的路径我存放在了F盘下的train文件夹下,路径为:F://UNIVERSITY STUDY/AI/dataset/catdog/train/。
代码如下:
keras=tf.keras
layers=tf.keras.layers
#得到图片的所有label
train_image_label=[int(p.split("\\")[1]=='cat') for p in train_image_path ] #现在我们的jpg文件进行解码,变成三维矩阵
def load_preprosess_image(path,label):
#读取路径
image=tf.io.read_file(path)
#解码
image=tf.image.decode_jpeg(image,channels=3)#彩色图像为3个channel
#将图像改变为同样的大小,利用裁剪或者扭曲,这里应用了扭曲
image=tf.image.resize(image,[360,360])
#随机裁剪图像
image=tf.image.random_crop(image,[256,256,3])
#随机上下翻转图像
image=tf.image.random_flip_left_right(image)
#随机上下翻转
image=tf.image.random_flip_up_down(image)
#随机改变图像的亮度
image=tf.image.random_brightness(image,0.5)
#随机改变对比度
image=tf.image.random_contrast(image,0,1)
#改变数据类型
image=tf.cast(image,tf.float32)
#将图像进行归一化
image=image/255
#现在还需要对label进行处理,我们现在是列表[1,2,3],
#需要变成[[1].[2].[3]]
label=tf.reshape(label,[1])
return image,label train_image_ds=tf.data.Dataset.from_tensor_slices((train_image_path,train_image_label))
AUTOTUNE=tf.data.experimental.AUTOTUNE#根据计算机性能进行运算速度的调整
train_image_ds=train_image_ds.map(load_preprosess_image,num_parallel_calls=AUTOTUNE)
#现在train_image_ds就读取进来了,现在进行乱序和batchsize的规定
BATCH_SIZE=32
train_count=len(train_image_path)
#现在设置batch和乱序
train_image_ds=train_image_ds.shuffle(train_count).batch(BATCH_SIZE)
train_image_ds=train_image_ds.prefetch(AUTOTUNE)#预处理一部分处理,准备读取 imags,labels=iter(train_image_ds).next()#放到生成器里,单独取出数据
plt.imshow(imags[30])
显示出制作batch数据当中的猫猫图片:

搭建网络架构,引入经典图像分类模型VGG16,同时调用VGG16预训练网络的权重。最后调整卷积层的最后三层为可训练的,也就是说顶层的卷积神经网路可以和全连接层分类器一起进行联合训练:
conv_base=keras.applications.VGG16(weights='imagenet',include_top=False)
#weights设置为imagenet表示使用imagebnet训练出来的权重,如果填写False表示不使用权重
#仅适用网络架构,include_top表示是否使用用于分类的全连接层
#我们在这个卷积层上添加全连接层和输出层即可
model=keras.Sequential()
model.add(conv_base)
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Dense(512,activation='relu'))
model.add(layers.Dense(1,activation='sigmoid')) conv_base.trainable=True#一共有19层
for layer in conv_base.layers[:-3]:
layer.trainable=False
#从第一层到倒数第三层重新设置为是不可训练的,现在卷积的顶层已经解冻,开始联合训练 #编译这个网络
model.compile(optimizer=keras.optimizers.Adam(lr=0.001),
loss='binary_crossentropy',
metrics=['acc']) history=model.fit(
train_image_ds,
steps_per_epoch=train_count//BATCH_SIZE,
epochs=1
)
仅仅训练一个epoch的结果如下所示;
Train for 62 steps
62/62 [==============================] - 469s 8s/step - loss: 0.6323 - acc: 0.6159
一次迭代准确率已经达到了百分之六十。怎么样呢?你现在对迁移学习有一定的感觉了吗?
深度学习趣谈:什么是迁移学习?(附带Tensorflow代码实现)的更多相关文章
- 《趣谈 Linux 操作系统》学习笔记(一):为什么要学 Linux 及学习路径
前言:学习的课程来自极客时间的专栏<趣谈 Linux 操作系统>,作者用形象化的比喻和丰富的图片让课程变得比较易懂,为了避免知识看过就忘,打算通过写学习笔记的形式记录自己的学习过程. Li ...
- Linux内核学习趣谈
本文原创是freas_1990,转载请标明出处:http://blog.csdn.net/freas_1990/article/details/9304991 从大二开始学习Linux内核,到现在已经 ...
- 深度学习原理与框架-Alexnet(迁移学习代码) 1.sys.argv[1:](控制台输入的参数获取第二个参数开始) 2.tf.split(对数据进行切分操作) 3.tf.concat(对数据进行合并操作) 4.tf.variable_scope(指定w的使用范围) 5.tf.get_variable(构造和获得参数) 6.np.load(加载.npy文件)
1. sys.argv[1:] # 在控制台进行参数的输入时,只使用第二个参数以后的数据 参数说明:控制台的输入:python test.py what, 使用sys.argv[1:],那么将获得w ...
- 《趣谈 Linux 操作系统》学习笔记(二):对 Linux 操作系统的理解
首先,我们知道操作系统是管理和控制计算机硬件与软件资源的计算机程序.这里把操作系统想象为一个软件外包公司,其内核就相当于这家外包公司的老板,那么我们可以把自己的角色切换成这家外包公司的老板,设身处地的 ...
- 【深度学习系列】迁移学习Transfer Learning
在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...
- 深挖计算机基础:趣谈Linux操作系统学习笔记
参考极客时间专栏<趣谈Linux操作系统>学习笔记 核心原理篇:内存管理 趣谈Linux操作系统学习笔记:第二十讲 趣谈Linux操作系统学习笔记:第二十一讲 趣谈Linux操作系统学习笔 ...
- 用迁移学习创造的通用语言模型ULMFiT,达到了文本分类的最佳水平
https://www.jqr.com/article/000225 这篇文章的目的是帮助新手和外行人更好地了解我们新论文,我们的论文展示了如何用更少的数据自动将文本分类,同时精确度还比原来的方法高. ...
- [DeeplearningAI笔记]卷积神经网络2.9-2.10迁移学习与数据增强
4.2深度卷积网络 觉得有用的话,欢迎一起讨论相互学习~Follow Me 2.9迁移学习 迁移学习的基础知识已经介绍过,本篇博文将介绍提高的部分. 提高迁移学习的速度 可以将迁移学习模型冻结的部分看 ...
- 迁移学习(Transformer),面试看这些就够了!(附代码)
1. 什么是迁移学习 迁移学习(Transformer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中.迁移学习是通过从已学习的相 ...
随机推荐
- JMeter+Grafana+Influxdb搭建可视化性能测试监控平台(使用了docker)
[运行自定义镜像搭建监控平台] 继上一篇的帖子 ,上一篇已经展示了如何自定义docker镜像,大家操作就行 或者 用我已经自定义好了的镜像,直接pull就行 下面我简单介绍pull下来后如何使用 拉取 ...
- 13.DRF-版本
Django rest framework源码分析(4)----版本 版本 新建一个工程Myproject和一个app名为api (1)api/models.py from django.db imp ...
- 用VMware克隆CentOS 6.5如何进行网络设置
我们使用虚拟机的克隆工具克隆出了一个电脑,电脑连接采用nat方式 111电脑对于的ip地址设置如下 [root@localhost ~]# cd /etc/sysconfig/network-scri ...
- redis高级命令3哨兵模式
redis的哨兵模式 现在我们在从服务器1.222上让该从服务器作为哨兵 首先将redis安装包文件下的sentinel.conf文件复制到/usr/local/redis/etc目录下 然后修改se ...
- Python 简明教程 --- 14,Python 数据结构进阶
微信公众号:码农充电站pro 个人主页:https://codeshellme.github.io 如果你发现特殊情况太多,那很可能是用错算法了. -- Carig Zerouni 目录 前几节我们介 ...
- 入门大数据---HDFS,Zookeeper,ZookeeperFailOverController(简称:ZKFC),JournalNode是什么?
HDFS介绍: 简述: Hadoop Distributed File System(HDFS)是一种分布式文件系统,设计用于在商用硬件上运行.它与现有的分布式文件系统有许多相似之处.但是,与其他分布 ...
- Spring9——通过用Aware接口使用Spring底层组件、环境切换
通过用Aware接口使用Spring底层组件 能够供我们使用的组件,都是Aware的子接口. ApplicationContextAware:实现步骤: (1)实现Applic ...
- Python实用笔记 (21)面向对象编程——获取对象信息
当我们拿到一个对象的引用时,如何知道这个对象是什么类型.有哪些方法呢? 使用type() 首先,我们来判断对象类型,使用type()函数: 基本类型都可以用type()判断: >>> ...
- 传统声学模型之HMM和GMM
声学模型是指给定声学符号(音素)的情况下对音频特征建立的模型. 数学表达 用 \(X\) 表示音频特征向量 (观察向量),用 \(S\) 表示音素 (隐藏/内部状态),声学模型表示为 \(P(X|S) ...
- 打造属于你的聊天室(WebSocket)
SpringBoot 是为了简化 Spring 应用的创建.运行.调试.部署等一系列问题而诞生的产物,自动装配的特性让我们可以更好的关注业务本身而不是外部的XML配置,我们只需遵循规范,引入相关的依赖 ...