实战 迁移学习 VGG19、ResNet50、InceptionV3 实践 猫狗大战 问题

 
参考博客:::https://blog.csdn.net/pengdali/article/details/79050662
 
 
2018年01月13日 12:52:14 pengdali 阅读数 10417
 

一、实践流程

1、数据预处理

主要是对训练数据进行随机偏移、转动等变换图像处理,这样可以尽可能让训练数据多样化

另外处理数据方式采用分批无序读取的形式,避免了数据按目录排序训练

  1.  
    #数据准备
  2.  
    def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
  3.  
    if is_train:
  4.  
    datagen = ImageDataGenerator(rescale=1./255,
  5.  
    zoom_range=0.25, rotation_range=15.,
  6.  
    channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,
  7.  
    horizontal_flip=True, fill_mode='constant')
  8.  
    else:
  9.  
    datagen = ImageDataGenerator(rescale=1./255)
  10.  
     
  11.  
    generator = datagen.flow_from_directory(
  12.  
    dir_path, target_size=(img_row, img_col),
  13.  
    batch_size=batch_size,
  14.  
    shuffle=is_train)
  15.  
     
  16.  
    return generator
2、载入现有模型

这个部分是核心工作,目的是使用ImageNet训练出的权重来做我们的特征提取器,注意这里后面的分类层去掉

  1.  
    base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,
  2.  
    input_shape=(img_rows, img_cols, color),
  3.  
    classes=nb_classes)

然后是冻结这些层,因为是训练好的

  1.  
    for layer in base_model.layers:
  2.  
    layer.trainable = False

而分类部分,需要我们根据现有需求来新定义的,这里可以根据实际情况自己进行调整,比如这样

  1.  
    x = base_model.output
  2.  
    # 添加自己的全链接分类层
  3.  
    x = GlobalAveragePooling2D()(x)
  4.  
    x = Dense(1024, activation='relu')(x)
  5.  
    predictions = Dense(nb_classes, activation='softmax')(x)

或者

  1.  
    x = base_model.output
  2.  
    #添加自己的全链接分类层
  3.  
    x = Flatten()(x)
  4.  
    predictions = Dense(nb_classes, activation='softmax')(x)
3、训练模型

这里我们用fit_generator函数,它可以避免了一次性加载大量的数据,并且生成器与模型将并行执行以提高效率。比如可以在CPU上进行实时的数据提升,同时在GPU上进行模型训练

  1.  
    history_ft = model.fit_generator(
  2.  
    train_generator,
  3.  
    steps_per_epoch=steps_per_epoch,
  4.  
    epochs=epochs,
  5.  
    validation_data=validation_generator,
  6.  
    validation_steps=validation_steps)

二、猫狗大战数据集

训练数据540M,测试数据270M,大家可以去官网下载

https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data

下载后把数据分成dog和cat两个目录来存放

三、训练

训练的时候会自动去下权值,比如vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5,但是如果我们已经下载好了的话,可以改源代码,让他直接读取我们的下载好的权值,比如在resnet50.py中

1、VGG19

vgg19的深度有26层,参数达到了549M,原模型最后有3个全连接层做分类器所以我还是加了一个1024的全连接层,训练10轮的情况达到了89%

2、ResNet50

ResNet50的深度达到了168层,但是参数只有99M,分类模型我就简单点,一层直接分类,训练10轮的达到了96%的准确率

3、inception_v3

InceptionV3的深度159层,参数92M,训练10轮的结果

这是一层直接分类的结果

这是加了一个512全连接的,大家可以随意调整测试

四、完整的代码

  1.  
    # -*- coding: utf-8 -*-
  2.  
    import os
  3.  
    from keras.utils import plot_model
  4.  
    from keras.applications.resnet50 import ResNet50
  5.  
    from keras.applications.vgg19 import VGG19
  6.  
    from keras.applications.inception_v3 import InceptionV3
  7.  
    from keras.layers import Dense,Flatten,GlobalAveragePooling2D
  8.  
    from keras.models import Model,load_model
  9.  
    from keras.optimizers import SGD
  10.  
    from keras.preprocessing.image import ImageDataGenerator
  11.  
    import matplotlib.pyplot as plt
  12.  
     
  13.  
    class PowerTransferMode:
  14.  
    #数据准备
  15.  
    def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
  16.  
    if is_train:
  17.  
    datagen = ImageDataGenerator(rescale=1./255,
  18.  
    zoom_range=0.25, rotation_range=15.,
  19.  
    channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,
  20.  
    horizontal_flip=True, fill_mode='constant')
  21.  
    else:
  22.  
    datagen = ImageDataGenerator(rescale=1./255)
  23.  
     
  24.  
    generator = datagen.flow_from_directory(
  25.  
    dir_path, target_size=(img_row, img_col),
  26.  
    batch_size=batch_size,
  27.  
    #class_mode='binary',
  28.  
    shuffle=is_train)
  29.  
     
  30.  
    return generator
  31.  
     
  32.  
    #ResNet模型
  33.  
    def ResNet50_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
  34.  
    color = 3 if RGB else 1
  35.  
    base_model = ResNet50(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),
  36.  
    classes=nb_classes)
  37.  
     
  38.  
    #冻结base_model所有层,这样就可以正确获得bottleneck特征
  39.  
    for layer in base_model.layers:
  40.  
    layer.trainable = False
  41.  
     
  42.  
    x = base_model.output
  43.  
    #添加自己的全链接分类层
  44.  
    x = Flatten()(x)
  45.  
    #x = GlobalAveragePooling2D()(x)
  46.  
    #x = Dense(1024, activation='relu')(x)
  47.  
    predictions = Dense(nb_classes, activation='softmax')(x)
  48.  
     
  49.  
    #训练模型
  50.  
    model = Model(inputs=base_model.input, outputs=predictions)
  51.  
    sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
  52.  
    model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
  53.  
     
  54.  
    #绘制模型
  55.  
    if is_plot_model:
  56.  
    plot_model(model, to_file='resnet50_model.png',show_shapes=True)
  57.  
     
  58.  
    return model
  59.  
     
  60.  
     
  61.  
    #VGG模型
  62.  
    def VGG19_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
  63.  
    color = 3 if RGB else 1
  64.  
    base_model = VGG19(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),
  65.  
    classes=nb_classes)
  66.  
     
  67.  
    #冻结base_model所有层,这样就可以正确获得bottleneck特征
  68.  
    for layer in base_model.layers:
  69.  
    layer.trainable = False
  70.  
     
  71.  
    x = base_model.output
  72.  
    #添加自己的全链接分类层
  73.  
    x = GlobalAveragePooling2D()(x)
  74.  
    x = Dense(1024, activation='relu')(x)
  75.  
    predictions = Dense(nb_classes, activation='softmax')(x)
  76.  
     
  77.  
    #训练模型
  78.  
    model = Model(inputs=base_model.input, outputs=predictions)
  79.  
    sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
  80.  
    model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
  81.  
     
  82.  
    # 绘图
  83.  
    if is_plot_model:
  84.  
    plot_model(model, to_file='vgg19_model.png',show_shapes=True)
  85.  
     
  86.  
    return model
  87.  
     
  88.  
    # InceptionV3模型
  89.  
    def InceptionV3_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True,
  90.  
    is_plot_model=False):
  91.  
    color = 3 if RGB else 1
  92.  
    base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,
  93.  
    input_shape=(img_rows, img_cols, color),
  94.  
    classes=nb_classes)
  95.  
     
  96.  
    # 冻结base_model所有层,这样就可以正确获得bottleneck特征
  97.  
    for layer in base_model.layers:
  98.  
    layer.trainable = False
  99.  
     
  100.  
    x = base_model.output
  101.  
    # 添加自己的全链接分类层
  102.  
    x = GlobalAveragePooling2D()(x)
  103.  
    x = Dense(1024, activation='relu')(x)
  104.  
    predictions = Dense(nb_classes, activation='softmax')(x)
  105.  
     
  106.  
    # 训练模型
  107.  
    model = Model(inputs=base_model.input, outputs=predictions)
  108.  
    sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
  109.  
    model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
  110.  
     
  111.  
    # 绘图
  112.  
    if is_plot_model:
  113.  
    plot_model(model, to_file='inception_v3_model.png', show_shapes=True)
  114.  
     
  115.  
    return model
  116.  
     
  117.  
    #训练模型
  118.  
    def train_model(self, model, epochs, train_generator, steps_per_epoch, validation_generator, validation_steps, model_url, is_load_model=False):
  119.  
    # 载入模型
  120.  
    if is_load_model and os.path.exists(model_url):
  121.  
    model = load_model(model_url)
  122.  
     
  123.  
    history_ft = model.fit_generator(
  124.  
    train_generator,
  125.  
    steps_per_epoch=steps_per_epoch,
  126.  
    epochs=epochs,
  127.  
    validation_data=validation_generator,
  128.  
    validation_steps=validation_steps)
  129.  
    # 模型保存
  130.  
    model.save(model_url,overwrite=True)
  131.  
    return history_ft
  132.  
     
  133.  
    # 画图
  134.  
    def plot_training(self, history):
  135.  
    acc = history.history['acc']
  136.  
    val_acc = history.history['val_acc']
  137.  
    loss = history.history['loss']
  138.  
    val_loss = history.history['val_loss']
  139.  
    epochs = range(len(acc))
  140.  
    plt.plot(epochs, acc, 'b-')
  141.  
    plt.plot(epochs, val_acc, 'r')
  142.  
    plt.title('Training and validation accuracy')
  143.  
    plt.figure()
  144.  
    plt.plot(epochs, loss, 'b-')
  145.  
    plt.plot(epochs, val_loss, 'r-')
  146.  
    plt.title('Training and validation loss')
  147.  
    plt.show()
  148.  
     
  149.  
     
  150.  
    if __name__ == '__main__':
  151.  
    image_size = 197
  152.  
    batch_size = 32
  153.  
     
  154.  
    transfer = PowerTransferMode()
  155.  
     
  156.  
    #得到数据
  157.  
    train_generator = transfer.DataGen('data/cat_dog_Dataset/train', image_size, image_size, batch_size, True)
  158.  
    validation_generator = transfer.DataGen('data/cat_dog_Dataset/test', image_size, image_size, batch_size, False)
  159.  
     
  160.  
    #VGG19
  161.  
    #model = transfer.VGG19_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)
  162.  
    #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'vgg19_model_weights.h5', is_load_model=False)
  163.  
     
  164.  
    #ResNet50
  165.  
    model = transfer.ResNet50_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)
  166.  
    history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'resnet50_model_weights.h5', is_load_model=False)
  167.  
     
  168.  
    #InceptionV3
  169.  
    #model = transfer.InceptionV3_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=True)
  170.  
    #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'inception_v3_model_weights.h5', is_load_model=False)
  171.  
     
  172.  
    # 训练的acc_loss图
  173.  
    transfer.plot_training(history_ft)

实战 迁移学习 VGG19、ResNet50、InceptionV3 实践 猫狗大战 问题的更多相关文章

  1. keras系列︱迁移学习:利用InceptionV3进行fine-tuning及预测、完美案例(五)

    引自:http://blog.csdn.net/sinat_26917383/article/details/72982230 之前在博客<keras系列︱图像多分类训练与利用bottlenec ...

  2. 基于深度学习和迁移学习的识花实践——利用 VGG16 的深度网络结构中的五轮卷积网络层和池化层,对每张图片得到一个 4096 维的特征向量,然后我们直接用这个特征向量替代原来的图片,再加若干层全连接的神经网络,对花朵数据集进行训练(属于模型迁移)

    基于深度学习和迁移学习的识花实践(转)   深度学习是人工智能领域近年来最火热的话题之一,但是对于个人来说,以往想要玩转深度学习除了要具备高超的编程技巧,还需要有海量的数据和强劲的硬件.不过 Tens ...

  3. 1、VGG16 2、VGG19 3、ResNet50 4、Inception V3 5、Xception介绍——迁移学习

    ResNet, AlexNet, VGG, Inception: 理解各种各样的CNN架构 本文翻译自ResNet, AlexNet, VGG, Inception: Understanding va ...

  4. 1 如何使用pb文件保存和恢复模型进行迁移学习(学习Tensorflow 实战google深度学习框架)

    学习过程是Tensorflow 实战google深度学习框架一书的第六章的迁移学习环节. 具体见我提出的问题:https://www.tensorflowers.cn/t/5314 参考https:/ ...

  5. Google Tensorflow 迁移学习 Inception-v3

    附上代码加数据地址 https://github.com/Liuyubao/transfer-learning ,欢迎参考. 一.Inception-V3模型 1.1 详细了解模型可参考以下论文: [ ...

  6. 【深度学习系列】迁移学习Transfer Learning

    在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...

  7. 迁移学习︱艺术风格转化:Artistic style-transfer+ubuntu14.0+caffe(only CPU)

    说起来这门技术大多是秀的成分高于实际,但是呢,其也可以作为图像增强的工具,看到一些比赛拿他作训练集扩充,还是一个比较好的思路.如何在caffe上面实现简单的风格转化呢? 好像网上的博文都没有说清楚,而 ...

  8. TensorFlow从1到2(九)迁移学习

    迁移学习基本概念 迁移学习是这两年比较火的一个话题,主要原因是在当前的机器学习中,样本数据的获取是成本最高的一块.而迁移学习可以有效的把原有的学习经验(对于模型就是模型本身及其训练好的权重值)带入到新 ...

  9. 『TensorFlow』迁移学习

    完全版见github:TransforLearning 零.迁移学习 将一个领域的已经成熟的知识应用到其他的场景中称为迁移学习.用神经网络的角度来表述,就是一层层网络中每个节点的权重从一个训练好的网络 ...

随机推荐

  1. IAR STM32F10x_StdPeriph_Driver 3.4转3.6.1库

    1.Fatal Error[Pe1696]: cannot open source file core_cmInstr.h STM32F10x_StdPeriph_Driver 3.4库移植换成3.6 ...

  2. 图解jvm--(一)jvm内存结构

    jvm内存结构 1.程序计数器 1.1 定义 Program Counter Register 程序计数器(寄存器) 作用,记住下一条jvm指令的执行地址 特点 是线程私有的 (唯一)不会存在内存溢出 ...

  3. css限制文字显示字数长度,超出部分自动用省略号显示,防止溢出到第二行

    为了保证页面的整洁美观,在很多的时候,我们常需要隐藏超出长度的文字.这在列表条目,题目,名称等地方常用到. 效果如下: 未限制显示长度,如果超出了会溢出到第二行里.严重影响用户体验和显示效果. 我们在 ...

  4. 通用dao的demo

          代码片段 1. [代码]整型映射工具 ? 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 package org.dave.common.databas ...

  5. StringGrid换行功能

    关闭stringgrid的defaultdrawing功能 StringGrid1.Cells[cCol,cRow] := '测试1'+#13#10+'测试2'; procedure TForm1.S ...

  6. docker-compose 修改zabbix images 添加微信报警插件 时间同步 中文乱码 添加grafana美化zabbix

    我们先来看一下我们要修改得  zabbix.yaml           github   https://github.com/bboysoulcn/awesome-dockercompose ve ...

  7. P1061 判断题

    P1061 判断题 转跳点:

  8. C语言整理复习——指针

    指针是C的精华,不会指针就等于没学C.但指针又是C里最难理解的部分,所以特意写下这篇博客整理思路. 一.指针类型的声明 C的数据类型由整型.浮点型.字符型.布尔型.指针这几部分构成.前四种类型比较好理 ...

  9. 【capstone/ropgadget】环境配置

    具体环境配置可参考 https://github.com/JonathanSalwan/ROPgadget/tree/master 作者给出的安装方式 但具体配置中出现了问题,如引用时出现如下错误: ...

  10. js的执行和调试

    JavaScript 是指在浏览器运行的脚本 脚本就是剧本,在指定场景,特定时间,规定角色的对白,动作,情绪的变化 并且js是同步的,单线程的执行脚本 同步异步 js的运行是同步的, 运行完第一行才会 ...