# -*- coding: UTF-8 -*-
import keras
from keras import Model
from keras.applications import VGG16
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.layers import Flatten, Dense, Dropout, GlobalAveragePooling2D
from keras.models import load_model
from keras.preprocessing import image
from PIL import ImageFile
import numpy as np
import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator
from datetime import datetime
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
ImageFile.LOAD_TRUNCATED_IMAGES = True EPOCHS = 30
BATCH_SIZE = 16
DATA_TRAIN_PATH = 'D:/data/train' def Train():
#-------------准备数据--------------------------
#数据集目录应该是 train/LabelA/1.jpg train/LabelB/1.jpg这样
gen = ImageDataGenerator(rescale=1. / 255)
train_generator = gen.flow_from_directory(DATA_TRAIN_PATH, (224,224)), shuffle=False,
batch_size=BATCH_SIZE, class_mode='categorical') #-------------加载VGG模型并且添加自己的层----------------------
#这里自己添加的层需要不断调整超参数来提升结果,输出类别更改softmax层即可 #参数说明:inlucde_top:是否包含最上方的Dense层,input_shape:输入的图像大小(width,height,channel)
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = base_model.output
x=Flatten()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(1, activation='sigmoid')(x)
predictions = Dense(2, activation='softmax')(x)
model = Model(input=base_model.input, output=predictions) #-----------控制需要FineTune的层数,不FineTune的就直接冻结
for layer in base_model.layers:
layer.trainable = False #----------编译,设置优化器,损失函数,性能指标
model.compile(optimizer='rmsprop',
loss='binary_crossentropy', metrics=['accuracy']) #----------设置tensorboard,用来观察acc和loss的曲线---------------
tbCallBack = TensorBoard(log_dir='./logs/' + TIMESTAMP, # log 目录
histogram_freq=0, # 按照何等频率(epoch)来计算直方图,0为不计算
batch_size=16, # 用多大量的数据计算直方图
write_graph=True, # 是否存储网络结构图
write_grads=True, # 是否可视化梯度直方图
write_images=True, # 是否可视化参数
embeddings_freq=0,
embeddings_layer_names=None,
embeddings_metadata=None) #---------设置自动保存点,acc最好的时候就会自动保存一次,会覆盖之前的存档---------------
checkpoint = ModelCheckpoint(filepath='HatNewModel.h5', monitor='acc', mode='auto', save_best_only='True') #----------开始训练---------------------------------------------
model.fit_generator(generator=train_generator,
epochs=EPOCHS,
callbacks=[tbCallBack,checkpoint],
verbose=2
) #-------------预测单个图像--------------------------------------
def Predict(imgPath):
model = load_model(SAVE_MODEL_NAME)
img = image.load_img(imgPath, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
res = model.predict(x)
print(np.argmax(res, axis=1)[0])

以上运行环境:

Keras2.1.4

Tensorflow-gpu 1.5

CUDA9.0

cudnn7.0

python3.5

[深度学习]Keras利用VGG进行迁移学习模板的更多相关文章

  1. 深度学习原理与框架-Alexnet(迁移学习代码) 1.sys.argv[1:](控制台输入的参数获取第二个参数开始) 2.tf.split(对数据进行切分操作) 3.tf.concat(对数据进行合并操作) 4.tf.variable_scope(指定w的使用范围) 5.tf.get_variable(构造和获得参数) 6.np.load(加载.npy文件)

    1. sys.argv[1:]  # 在控制台进行参数的输入时,只使用第二个参数以后的数据 参数说明:控制台的输入:python test.py what, 使用sys.argv[1:],那么将获得w ...

  2. 吴裕雄--天生自然python Google深度学习框架:Tensorflow实现迁移学习

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  3. 基于深度学习和迁移学习的识花实践——利用 VGG16 的深度网络结构中的五轮卷积网络层和池化层,对每张图片得到一个 4096 维的特征向量,然后我们直接用这个特征向量替代原来的图片,再加若干层全连接的神经网络,对花朵数据集进行训练(属于模型迁移)

    基于深度学习和迁移学习的识花实践(转)   深度学习是人工智能领域近年来最火热的话题之一,但是对于个人来说,以往想要玩转深度学习除了要具备高超的编程技巧,还需要有海量的数据和强劲的硬件.不过 Tens ...

  4. 【深度学习系列】迁移学习Transfer Learning

    在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...

  5. 深度学习趣谈:什么是迁移学习?(附带Tensorflow代码实现)

    一.迁移学习的概念 什么是迁移学习呢?迁移学习可以由下面的这张图来表示: 这张图最左边表示了迁移学习也就是把已经训练好的模型和权重直接纳入到新的数据集当中进行训练,但是我们只改变之前模型的分类器(全连 ...

  6. TensorFlow迁移学习的识别花试验

    最近学习了TensorFlow,发现一个模型叫vgg16,然后搭建环境跑了一下,觉得十分神奇,而且准确率十分的高.又上了一节选修课,关于人工智能,老师让做一个关于人工智能的试验,于是觉得vgg16很不 ...

  7. 迁移学习(Transformer),面试看这些就够了!(附代码)

    1. 什么是迁移学习 迁移学习(Transformer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中.迁移学习是通过从已学习的相 ...

  8. PyTorch专栏(五):迁移学习

    专栏目录: 第一章:PyTorch之简介与下载 PyTorch简介 PyTorch环境搭建 第二章:PyTorch之60分钟入门 PyTorch入门 PyTorch自动微分 PyTorch神经网络 P ...

  9. 迁移学习( Transfer Learning )

    在传统的机器学习的框架下,学习的任务就是在给定充分训练数据的基础上来学习一个分类模型:然后利用这个学习到的模型来对测试文档进行分类与预测.然而,我们看到机器学习算法在当前的Web挖掘研究中存在着一个关 ...

  10. 【迁移学习】2010-A Survey on Transfer Learning

    资源:http://www.cse.ust.hk/TL/ 简介: 一个例子: 关于照片的情感分析. 源:比如你之前已经搜集了大量N种类型物品的图片进行了大量的人工标记(label),耗费了巨大的人力物 ...

随机推荐

  1. 一键体验 Istio

    背景介绍 Istio 是一种服务网格,是一种现代化的服务网络层,它提供了一种透明.独立于语言的方法,以灵活且轻松地实现应用网络功能自动化.它是一种管理构成云原生应用的不同微服务的常用解决方案.Isti ...

  2. laravel 报错 AUTH` failed: ERR Client sent AUTH, but no password is set

    明明没有设置redis密码.访问时候却报错 在代码里面的databases.php 改成这样就可以了.predis新版也会有取不到passwor的时候.改成我截图那样也可以.他默认取的是default ...

  3. winscp报错Server sent passive reply with unroutable address. Using server address instead

    找了一堆没用. 最后终于 1.使用winSCP连接ftp时,编辑会话,单击高级. 2.进入高级设置之后,单击连接,查看连接模式,把被动模式的勾,勾掉. 3.单击确定,然后保存配置,重新连接FTP,OK

  4. k8s集权IP更换

    -.背景描述 背景:在场内进行部署完成后标准版产品,打包服务器到客户现场后服务不能正常使用,因为客户现场的IP地址不能再使用场内的IP,导致部署完的产品环境在客户现场无法使用:此方案就是针对这一问题撰 ...

  5. 华为云 MRS 基于 Apache Hudi 极致查询优化的探索实践

    背景 湖仓一体(LakeHouse)是一种新的开放式架构,它结合了数据湖和数据仓库的最佳元素,是当下大数据领域的重要发展方向. 华为云早在2020年就开始着手相关技术的预研,并落地在华为云 Fusio ...

  6. 扫雷(哈希+bfs)

    扫雷 题目描述: 小明最近迷上了一款名为<扫雷>的游戏. 其中有一个关卡的任务如下: 在一个二维平面上放置着 n 个炸雷,第 i 个炸雷 (x\(_i\),y\(_i\),r\(_i\)) ...

  7. 16.python中的回收机制

    python中的垃圾回收机制是以引用计数器为主,标记清除和分代回收为辅的 + 缓存机制 1.引用计数器 在python内部维护了一个名为refchain的环状双向链表,在python中创建的任何对象都 ...

  8. java学习之JSP

    0x00前言 JSP:全拼写:java Server pages:java 服务器端页面 可以理解为一个特殊的页面:可以定义html代码也可以定义java的代码 定义:JSP是简化Servlet编写的 ...

  9. 【翻译】Thymeleaf – Spring Security集成模块

    原文链接:Thymeleaf - Spring Security integration modules 来源:thymeleaf/thymeleaf-extras-springsecurity自述文 ...

  10. TCN代码详解-Torch (误导纠正)

    TCN代码详解-Torch (误导纠正) 1. 绪论 TCN网络由Shaojie Bai, J. Zico Kolter, Vladlen Koltun 三人于2018提出.对于序列预测而言,通常考虑 ...