# -*- coding: UTF-8 -*-
import keras
from keras import Model
from keras.applications import VGG16
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.layers import Flatten, Dense, Dropout, GlobalAveragePooling2D
from keras.models import load_model
from keras.preprocessing import image
from PIL import ImageFile
import numpy as np
import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator
from datetime import datetime
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
ImageFile.LOAD_TRUNCATED_IMAGES = True EPOCHS = 30
BATCH_SIZE = 16
DATA_TRAIN_PATH = 'D:/data/train' def Train():
#-------------准备数据--------------------------
#数据集目录应该是 train/LabelA/1.jpg train/LabelB/1.jpg这样
gen = ImageDataGenerator(rescale=1. / 255)
train_generator = gen.flow_from_directory(DATA_TRAIN_PATH, (224,224)), shuffle=False,
batch_size=BATCH_SIZE, class_mode='categorical') #-------------加载VGG模型并且添加自己的层----------------------
#这里自己添加的层需要不断调整超参数来提升结果,输出类别更改softmax层即可 #参数说明:inlucde_top:是否包含最上方的Dense层,input_shape:输入的图像大小(width,height,channel)
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = base_model.output
x=Flatten()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(1, activation='sigmoid')(x)
predictions = Dense(2, activation='softmax')(x)
model = Model(input=base_model.input, output=predictions) #-----------控制需要FineTune的层数,不FineTune的就直接冻结
for layer in base_model.layers:
layer.trainable = False #----------编译,设置优化器,损失函数,性能指标
model.compile(optimizer='rmsprop',
loss='binary_crossentropy', metrics=['accuracy']) #----------设置tensorboard,用来观察acc和loss的曲线---------------
tbCallBack = TensorBoard(log_dir='./logs/' + TIMESTAMP, # log 目录
histogram_freq=0, # 按照何等频率(epoch)来计算直方图,0为不计算
batch_size=16, # 用多大量的数据计算直方图
write_graph=True, # 是否存储网络结构图
write_grads=True, # 是否可视化梯度直方图
write_images=True, # 是否可视化参数
embeddings_freq=0,
embeddings_layer_names=None,
embeddings_metadata=None) #---------设置自动保存点,acc最好的时候就会自动保存一次,会覆盖之前的存档---------------
checkpoint = ModelCheckpoint(filepath='HatNewModel.h5', monitor='acc', mode='auto', save_best_only='True') #----------开始训练---------------------------------------------
model.fit_generator(generator=train_generator,
epochs=EPOCHS,
callbacks=[tbCallBack,checkpoint],
verbose=2
) #-------------预测单个图像--------------------------------------
def Predict(imgPath):
model = load_model(SAVE_MODEL_NAME)
img = image.load_img(imgPath, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
res = model.predict(x)
print(np.argmax(res, axis=1)[0])

以上运行环境:

Keras2.1.4

Tensorflow-gpu 1.5

CUDA9.0

cudnn7.0

python3.5

[深度学习]Keras利用VGG进行迁移学习模板的更多相关文章

  1. 深度学习原理与框架-Alexnet(迁移学习代码) 1.sys.argv[1:](控制台输入的参数获取第二个参数开始) 2.tf.split(对数据进行切分操作) 3.tf.concat(对数据进行合并操作) 4.tf.variable_scope(指定w的使用范围) 5.tf.get_variable(构造和获得参数) 6.np.load(加载.npy文件)

    1. sys.argv[1:]  # 在控制台进行参数的输入时,只使用第二个参数以后的数据 参数说明:控制台的输入:python test.py what, 使用sys.argv[1:],那么将获得w ...

  2. 吴裕雄--天生自然python Google深度学习框架:Tensorflow实现迁移学习

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  3. 基于深度学习和迁移学习的识花实践——利用 VGG16 的深度网络结构中的五轮卷积网络层和池化层,对每张图片得到一个 4096 维的特征向量,然后我们直接用这个特征向量替代原来的图片,再加若干层全连接的神经网络,对花朵数据集进行训练(属于模型迁移)

    基于深度学习和迁移学习的识花实践(转)   深度学习是人工智能领域近年来最火热的话题之一,但是对于个人来说,以往想要玩转深度学习除了要具备高超的编程技巧,还需要有海量的数据和强劲的硬件.不过 Tens ...

  4. 【深度学习系列】迁移学习Transfer Learning

    在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...

  5. 深度学习趣谈:什么是迁移学习?(附带Tensorflow代码实现)

    一.迁移学习的概念 什么是迁移学习呢?迁移学习可以由下面的这张图来表示: 这张图最左边表示了迁移学习也就是把已经训练好的模型和权重直接纳入到新的数据集当中进行训练,但是我们只改变之前模型的分类器(全连 ...

  6. TensorFlow迁移学习的识别花试验

    最近学习了TensorFlow,发现一个模型叫vgg16,然后搭建环境跑了一下,觉得十分神奇,而且准确率十分的高.又上了一节选修课,关于人工智能,老师让做一个关于人工智能的试验,于是觉得vgg16很不 ...

  7. 迁移学习(Transformer),面试看这些就够了!(附代码)

    1. 什么是迁移学习 迁移学习(Transformer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中.迁移学习是通过从已学习的相 ...

  8. PyTorch专栏(五):迁移学习

    专栏目录: 第一章:PyTorch之简介与下载 PyTorch简介 PyTorch环境搭建 第二章:PyTorch之60分钟入门 PyTorch入门 PyTorch自动微分 PyTorch神经网络 P ...

  9. 迁移学习( Transfer Learning )

    在传统的机器学习的框架下,学习的任务就是在给定充分训练数据的基础上来学习一个分类模型:然后利用这个学习到的模型来对测试文档进行分类与预测.然而,我们看到机器学习算法在当前的Web挖掘研究中存在着一个关 ...

  10. 【迁移学习】2010-A Survey on Transfer Learning

    资源:http://www.cse.ust.hk/TL/ 简介: 一个例子: 关于照片的情感分析. 源:比如你之前已经搜集了大量N种类型物品的图片进行了大量的人工标记(label),耗费了巨大的人力物 ...

随机推荐

  1. vlunhub靶场之EMPIRE: LUPINONE

    准备: 攻击机:虚拟机kali.本机win10. 靶机:EMPIRE: LUPINONE,网段地址我这里设置的桥接,所以与本机电脑在同一网段,下载地址:https://download.vulnhub ...

  2. 不妨试试更快更小更灵活Java开发框架Solon

    @ 目录 概述 定义 性能 架构 实战 Solon Web示例 Solon Mybatis-Plus示例 Solon WebSocket示例 Solon Remoting RPC示例 Solon Cl ...

  3. go-zero docker-compose搭建课件服务(四):生成Dockerfile

    0.转载 go-zero docker-compose 搭建课件服务(四):生成Dockerfile并在docker-compose中启动 0.1源码地址 https://github.com/liu ...

  4. PX01关于手机屏SPI触摸调试学习笔记

    上位机工具:http://www.xk-image.com/download/blog/0002_TP调试/LcdTools20210605.rar 调试案例:http://www.xk-image. ...

  5. 一篇文章带你了解服务器操作系统——Linux简单入门

    一篇文章带你了解服务器操作系统--Linux简单入门 Linux作为服务器的常用操作系统,身为工作人员自然是要有所了解的 在本篇中我们会简单介绍Linux的特点,安装,相关指令使用以及内部程序的安装等 ...

  6. html+css 面试题总结附答案

    行内元素有哪些? 块级元素有哪些? 块级元素:div p h1 ul li form table行内元素: a b br i span input select laber strong em img ...

  7. JWT中token的理解

    今天我们来聊一聊关于JWT授权的事情. JWT:Json Web Token.顾名思义,它是一种在Web中,使用Json来进行Token授权的方案. 既然没有找好密码,token是如何解决信任问题的呢 ...

  8. 部署redis-cluster

    1.环境准备 ☆ 每个Redis 节点采用相同的相同的Redis版本.相同的密码.硬件配置 ☆ 所有Redis服务器必须没有任何数据 #所有主从节点执行: [root@ubuntu2004 ~]#ba ...

  9. 基于LZO的高性能无损数据解压缩IP

    LZOAccel-D LZO Data Decompression Core/无损数据解压缩IP Core LZOAccel-D是一个无损数据解压缩引擎的FPGA硬件实现,兼容LZO 2.10标准. ...

  10. Docker | 专栏文章整理🎉🎉

    Docker Docker系列文章基本已经更新完毕,这是我从去年的学习笔记中整理出来的. 笔记稍微有点杂乱.随意,把它们整理成文章花费了不少力气.整理的过程也是我的一个再次学习的过程,同时也是为了方便 ...