上一节中,我们采用了一个自定义的网络结构,从头开始训练猫狗大战分类器,最终在使用图像增强的方式下得到了82%的验证准确率。但是,想要将深度学习应用于小型图像数据集,通常不会贸然采用复杂网络并且从头开始训练(training from scratch),因为训练代价高,且很难避免过拟合问题。相对的,通常会采用一种更高效的方法——使用预训练网络。

预训练网络的使用通常有两种方式,一种是利用预训练网络简单提取图像的特征,之后可能会利用这些特征进行其他操作(比如和文本信息结合以用于image caption,或者简单的进行分类);另一种是对预训练的网络进行裁剪和微调,以适应自己的任务。

第一种方式训练代价极低,因为它就是简单提取个特征,不涉及训练;缺点是保存提取出来的特征需要占用一定空间,且无法使用图像增强(而图像增强对于防止小型数据集的过拟合非常重要)。第二种方式可以使用图像增强,但训练代价也会大幅增加。(当然相对于从头训练来说,使用预训练网络的训练代价肯定要低得多。)

这一节中我们以VGG16提取图像特征为例,展示第一种使用方式。该案例接着上一个例子,使用同样的数据集,利用keras中自带的VGG16模型提取图像特征,然后以这些图像特征为输入,训练一个小型分类器。

import numpy as np
from keras.applications.vgg16 import VGG16 #实例化一个VGG16卷积基
#输入维度根据需要自行指定,这里仍然采用上一个例子的维度,卷积基的输出是(None,4,4,512)
conv_base = VGG16(include_top=False, input_shape=(150,150,3))
#conv_base.summary() ###############单纯用VGG16卷积基直接提取特征,不使用图像增强####################
import os
from keras.preprocessing.image import ImageDataGenerator #定义提取图像特征的函数
datagen = ImageDataGenerator(rescale=1./255)
batch_size = 20
def extract_features(directory, sample_count):
#输入:文件路径,样本个数
#返回:指定个数的样本特征,以及对应的标签
features = np.zeros(shape=(sample_count, 4, 4, 512))
labels = np.zeros(shape=(sample_count))
generator = datagen.flow_from_directory(
directory,
target_size=(150,150),
batch_size=batch_size,
class_mode='binary')
i = 0
for inputs_batch, labels_batch in generator: #分别为(20,150,150,3) (20,)
features_batch = conv_base.predict(inputs_batch) #(20,4,4,512)
features[i * batch_size : (i + 1) * batch_size] = features_batch
labels[i * batch_size : (i + 1) * batch_size] = labels_batch
i += 1
if i * batch_size >= sample_count: #读取了指定样本个数后即退出
break
return features, labels #分别提取训练集、验证集、测试集的图像特征
train_dir = r'D:\KaggleDatasets\MyDatasets\dogs-vs-cats-small\train'
validation_dir = r'D:\KaggleDatasets\MyDatasets\dogs-vs-cats-small\validation'
test_dir = r'D:\KaggleDatasets\MyDatasets\dogs-vs-cats-small\test'
train_features, train_labels = extract_features(train_dir, 2000)
validation_features, validation_labels = extract_features(validation_dir, 1000)
test_features, test_labels = extract_features(test_dir, 1000) #将各自的图像特征展平,作为后续Dense层的输入
assert train_features.shape == (2000, 4, 4, 512)
assert validation_features.shape == (1000, 4, 4, 512)
assert test_features.shape == (1000, 4, 4, 512)
train_features = train_features.reshape(2000, 4*4*512)
validation_features = validation_features.reshape(1000, 4*4*512)
test_features = test_features.reshape(1000, 4*4*512) ###################定义并训练一个小型分类器#########################
from keras.models import Model
from keras.layers import Input, Dense, Dropout input = Input(shape=(4*4*512,))
X = Dense(256, activation='relu')(input)
X = Dropout(0.5)(X)
X = Dense(1, activation='sigmoid')(X) model = Model(inputs=input, outputs=X)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) H = model.fit(train_features, train_labels,
validation_data=(validation_features, validation_labels),
epochs=30, batch_size=64, verbose=1) #######################训练结果可视化############################
import matplotlib.pyplot as plt acc = H.history['acc']
val_acc = H.history['val_acc']
loss = H.history['loss']
val_loss = H.history['val_loss']
epoch = range(1, len(loss) + 1) fig, ax = plt.subplots(1, 2, figsize=(10,4))
fig.subplots_adjust(wspace=0.2)
ax[0].plot(epoch, loss, label='Train loss') #注意不要写成labels
ax[0].plot(epoch, val_loss, label='Validation loss')
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Loss')
ax[0].legend()
ax[1].plot(epoch, acc, label='Train acc')
ax[1].plot(epoch, val_acc, label='Validation acc')
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Accuracy')
ax[1].legend()
plt.show()

训练结果如下所示。可以看出,相对于上一个从头开始训练的猫狗分类任务,很轻松的就把验证集准确率由82%提高到90%左右,更重要的是,现在还没有使用重量级武器——图像增强。下一节,我们会使用第二种更常用更高效的方式——模型微调。

CNN基础二:使用预训练网络提取图像特征的更多相关文章

  1. AI:拿来主义——预训练网络(二)

    上一篇文章我们聊的是使用预训练网络中的一种方法,特征提取,今天我们讨论另外一种方法,微调模型,这也是迁移学习的一种方法. 微调模型 为什么需要微调模型?我们猜测和之前的实验,我们有这样的共识,数据量越 ...

  2. 原来CNN是这样提取图像特征的。。。

    对于即将到来的人工智能时代,作为一个有理想有追求的程序员,不懂深度学习(Deep Learning)这个超热的领域,会不会感觉马上就out了?作为机器学习的一个分支,深度学习同样需要计算机获得强大的学 ...

  3. VGG16提取图像特征 (torch7)

    VGG16提取图像特征 (torch7) VGG16 loadcaffe torch7 下载pretrained model,保存到当前目录下 th> caffemodel_url = 'htt ...

  4. CNN基础三:预训练模型的微调

    上一节中,我们利用了预训练的VGG网络卷积基,来简单的提取了图像的特征,并用这些特征作为输入,训练了一个小分类器. 这种方法好处在于简单粗暴,特征提取部分的卷积基不需要训练.但缺点在于,一是别人的模型 ...

  5. 深度学习tensorflow实战笔记 用预训练好的VGG-16模型提取图像特征

    1.首先就要下载模型结构 首先要做的就是下载训练好的模型结构和预训练好的模型,结构地址是:点击打开链接 模型结构如下: 文件test_vgg16.py可以用于提取特征.其中vgg16.npy是需要单独 ...

  6. AI:拿来主义——预训练网络(一)

    我们已经训练过几个神经网络了,识别手写数字,房价预测或者是区分猫和狗,那随之而来就有一个问题,这些训练出的网络怎么用,每个问题我都需要重新去训练网络吗?因为程序员都不太喜欢做重复的事情,因此答案肯定是 ...

  7. Pytorch如何用预训练模型提取图像特征

    方法很简单,你只需要将模型最后的全连接层改成Dropout即可. import torch from torchvision import models # load data x, y = get_ ...

  8. 学习TensorFlow,调用预训练好的网络(Alex, VGG, ResNet etc)

    视觉问题引入深度神经网络后,针对端对端的训练和预测网络,可以看是特征的表达和任务的决策问题(分类,回归等).当我们自己的训练数据量过小时,往往借助牛人已经预训练好的网络进行特征的提取,然后在后面加上自 ...

  9. 从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史(转载)

    转载 https://zhuanlan.zhihu.com/p/49271699 首发于深度学习前沿笔记 写文章   从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 张 ...

随机推荐

  1. 为什么集合类没有实现Cloneable和Serializable接口?

    为什么集合类没有实现Cloneable和Serializable接口? 克隆(cloning)或者是序列化(serialization)的语义和含义是跟具体的实现相关的.因此,应该由集合类的具体实现来 ...

  2. FMXUI TEXTVIEW代码设置IMAGEINDEX

    FMXUI作为一个开源的控件,真是DELPHIER的福音,向作者致敬.​TEXTVIEW非常好用,在属性面板中有ImageIndex属性,可以方便设置图标,在实际应用中图标状态需要改变,但在代码设置时 ...

  3. 详解如何定义SQL Server外关键字约束

    SQL Server外关键字约束定义了表之间的关系.当一个表中的一个列或多个列的组合和其它表中的主关键字定义相同时,就可以将这些列或列的组合定义为外关键字,并设定它适合哪个表中哪些列相关联.这样,当在 ...

  4. 圆周率Pi是如何计算出来的

    object SparkPi { def main(args: Array[String]) { val spark = SparkSession .builder .appName("Sp ...

  5. 【Elasticsearch】清空指定index/type下的数据

    1.postman请求接口 http://ip:端口/index/type/_delete_by_query?conflicts=proceed body为: { "query": ...

  6. Borůvka (Sollin) 算法求 MST 最小生成树

    基本思路: 用定点数组记录每个子树的最近邻居. 对于每一条边进行处理: 如果这条边连成的两个顶点同属于一个集合,则不处理,否则检测这条边连接的两个子树,如果是连接这两个子树的最小边,则更新 (合并). ...

  7. Flutter样式和布局控件简析(二)

    开始 继续接着分析Flutter相关的样式和布局控件,但是这次内容难度感觉比较高,怕有分析不到位的地方,所以这次仅仅当做一个参考,大家最好可以自己阅读一下代码,应该会有更深的体会. Sliver布局 ...

  8. Http协议面试题(总结)

    Http协议面试题(总结) 一.总结 一句话总结: 主要考常见的状态码,以及https,其它的多抓抓包就熟悉了 1.说一下什么是Http协议? 数据传输的格式规范:对器客户端和 服务器端之间数据传输的 ...

  9. Linux随笔 - 修改主机名

    1.临时修改主机名: hostname 主机名 修改只能临时有效,机器重启后会自动还原. 2.永久修改主机名: 修改hostname文件(路径:/etc/sysconfig/network),把hos ...

  10. 通过生成HFile导入HBase

    要实现DataFrame通过HFile导入HBase有两个关键步骤 第一个是要生成Hfile第二个是HFile导入HBase 测试DataFrame数据来自mysql,如果对读取mysql作为Data ...