TensorFlow keras 迁移学习






数据的读取
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator class TransferModel(object): def __init__(self):
#标准化和数据增强
self.train_generator = ImageDataGenerator(rescale=1.0/255.0)
self.test_generator = ImageDataGenerator(rescale=1.0/255.0)
#指定训练集数据和测试集数据目录
self.train_dir = "./data/train"
self.test_dir = "./data/test"
self.image_size = (224,224)
self.batch_size = 32 def get_loacl_data(self):
'''
读取本地的图片数据以及类别
:return:
'''
train_gen = self.train_generator.flow_from_directory(self.train_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
test_gen = self.test_generator.flow_from_directory(self.test_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True) return train_gen,test_gen if __name__ == '__main__':
tm = TransferModel()
train_gen,test_gen = tm.get_loacl_data()
print(train_gen)
迁移学习完整代码
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.python.keras.applications.vgg16 import VGG16, preprocess_input
import numpy as np class TransferModel(object): def __init__(self): # 定义训练和测试图片的变化方法,标准化以及数据增强
self.train_generator = ImageDataGenerator(rescale=1.0 / 255.0)
self.test_generator = ImageDataGenerator(rescale=1.0 / 255.0) # 指定训练数据和测试数据的目录
self.train_dir = "./data/train"
self.test_dir = "./data/test" # 定义图片训练相关网络参数
self.image_size = (224, 224)
self.batch_size = 32 # 定义迁移学习的基类模型
# 不包含VGG当中3个全连接层的模型加载并且加载了参数
# vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
self.base_model = VGG16(weights='imagenet', include_top=False) self.label_dict = {
'0': '汽车',
'1': '恐龙',
'2': '大象',
'3': '花',
'4': '马'
} def get_local_data(self):
"""
读取本地的图片数据以及类别
:return: 训练数据和测试数据迭代器
"""
# 使用flow_from_derectory
train_gen = self.train_generator.flow_from_directory(self.train_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
test_gen = self.test_generator.flow_from_directory(self.test_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
return train_gen, test_gen def refine_base_model(self):
"""
微调VGG结构,5blocks后面+全局平均池化(减少迁移学习的参数数量)+两个全连接层
:return:
"""
# 1、获取原notop模型得出
# [?, ?, ?, 512]
x = self.base_model.outputs[0] # 2、在输出后面增加我们结构
# [?, ?, ?, 512]---->[?, 1 * 1 * 512]
x = keras.layers.GlobalAveragePooling2D()(x) # 3、定义新的迁移模型
x = keras.layers.Dense(1024, activation=tf.nn.relu)(x)
y_predict = keras.layers.Dense(5, activation=tf.nn.softmax)(x) # model定义新模型
# VGG 模型的输入, 输出:y_predict
transfer_model = keras.models.Model(inputs=self.base_model.inputs, outputs=y_predict) return transfer_model def freeze_model(self):
"""
冻结VGG模型(5blocks)
冻结VGG的多少,根据你的数据量
:return:
"""
# self.base_model.layers 获取所有层,返回层的列表
for layer in self.base_model.layers:
layer.trainable = False def compile(self, model):
"""
编译模型
:return:
"""
model.compile(optimizer=keras.optimizers.Adam(),
loss=keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy'])
return None def fit_generator(self, model, train_gen, test_gen):
"""
训练模型,model.fit_generator()不是选择model.fit()
:return:
"""
# 每一次迭代准确率记录的h5文件
modelckpt = keras.callbacks.ModelCheckpoint('./ckpt/transfer_{epoch:02d}-{val_acc:.2f}.h5',
monitor='val_acc',
save_weights_only=True,
save_best_only=True,
mode='auto',
period=1) model.fit_generator(train_gen, epochs=3, validation_data=test_gen, callbacks=[modelckpt]) return None def predict(self, model):
"""
预测类别
:return:
""" # 加载模型,transfer_model
model.load_weights("./ckpt/transfer_02-0.93.h5") # 读取图片,处理
image = load_img("./1.jpg", target_size=(224, 224))
image.show()
image = img_to_array(image)
# print(image.shape)
# 四维(224, 224, 3)---》(1, 224, 224, 3)
img = image.reshape([1, image.shape[0], image.shape[1], image.shape[2]])
# print(img)
# model.predict() # 预测结果进行处理
image = preprocess_input(img)
predictions = model.predict(image)
print(predictions)
res = np.argmax(predictions, axis=1)
print("所预测的类别是:",self.label_dict[str(res[0])]) if __name__ == '__main__':
tm = TransferModel()
# 训练
# train_gen, test_gen = tm.get_local_data()
# # print(train_gen)
# # for data in train_gen:
# # print(data[0].shape, data[1].shape)
# # print(tm.base_model.summary())
# model = tm.refine_base_model()
# # print(model)
# tm.freeze_model()
# tm.compile(model)
#
# tm.fit_generator(model, train_gen, test_gen) # 测试
model = tm.refine_base_model() tm.predict(model)
TensorFlow keras 迁移学习的更多相关文章
- 『TensorFlow』迁移学习
完全版见github:TransforLearning 零.迁移学习 将一个领域的已经成熟的知识应用到其他的场景中称为迁移学习.用神经网络的角度来表述,就是一层层网络中每个节点的权重从一个训练好的网络 ...
- tensorflow实现迁移学习
此例程出自<TensorFlow实战Google深度学习框架>6.5.2小节 卷积神经网络迁移学习. 数据集来自http://download.tensorflow.org/example ...
- 吴裕雄--天生自然python Google深度学习框架:Tensorflow实现迁移学习
import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...
- ML.NET 示例:图像分类模型训练-首选API(基于原生TensorFlow迁移学习)
ML.NET 版本 API 类型 状态 应用程序类型 数据类型 场景 机器学习任务 算法 Microsoft.ML 1.5.0 动态API 最新 控制台应用程序和Web应用程序 图片文件 图像分类 基 ...
- TensorFlow从1到2(九)迁移学习
迁移学习基本概念 迁移学习是这两年比较火的一个话题,主要原因是在当前的机器学习中,样本数据的获取是成本最高的一块.而迁移学习可以有效的把原有的学习经验(对于模型就是模型本身及其训练好的权重值)带入到新 ...
- 深度学习应用系列(二) | 如何使用keras进行迁移学习,以训练和识别自己的图片集
本文的keras后台为tensorflow,介绍如何利用预编译的模型进行迁移学习,以训练和识别自己的图片集. 官网 https://keras.io/applications/ 已经介绍了各个基于Im ...
- 深度学习趣谈:什么是迁移学习?(附带Tensorflow代码实现)
一.迁移学习的概念 什么是迁移学习呢?迁移学习可以由下面的这张图来表示: 这张图最左边表示了迁移学习也就是把已经训练好的模型和权重直接纳入到新的数据集当中进行训练,但是我们只改变之前模型的分类器(全连 ...
- 常用深度学习框——Caffe/ TensorFlow / Keras/ PyTorch/MXNet
常用深度学习框--Caffe/ TensorFlow / Keras/ PyTorch/MXNet 一.概述 近几年来,深度学习的研究和应用的热潮持续高涨,各种开源深度学习框架层出不穷,包括Tenso ...
- 用tensorflow迁移学习猫狗分类
笔者这几天在跟着莫烦学习TensorFlow,正好到迁移学习(至于什么是迁移学习,看这篇),莫烦老师做的是预测猫和老虎尺寸大小的学习.作为一个有为的学生,笔者当然不能再预测猫啊狗啊的大小啦,正好之前正 ...
随机推荐
- PHP7内核(五):系统分析生命周期
上篇文章讲述了模块初始化阶段之前的准备工作,本篇我来详细介绍PHP生命周期的五个阶段. 一.模块初始化阶段 我们先来看一下该阶段的每个函数的作用. 1.1.sapi_initialize_reques ...
- Servlet读取前端的request payload
这几天遇见了一个很头疼的事,当我想用表单上传文件时,后端servlet读取不到文件的信息 网上搜索,说是要将form添加这个属性enctype="multipart/form-data&qu ...
- LM拟合算法
一. Levenberg-Marquardt算法 (1)y=a*e.^(-b*x)形式拟合 clear all % 计算函数f的雅克比矩阵,是解析式 syms a b y x real; f=a*e ...
- [一起读源码]走进C#并发队列ConcurrentQueue的内部世界
决定从这篇文章开始,开一个读源码系列,不限制平台语言或工具,任何自己感兴趣的都会写.前几天碰到一个小问题又读了一遍ConcurrentQueue的源码,那就拿C#中比较常用的并发队列Concurren ...
- IBN-Net: 提升模型的域自适应性
本文解读内容是IBN-Net, 笔者最初是在很多行人重识别的库中频繁遇到比如ResNet-ibn这样的模型,所以产生了阅读并研究这篇文章的兴趣,文章全称是: <Two at Once: Enha ...
- RocketMQ的高可用集群部署
RocketMQ的高可用集群部署 标签(空格分隔): 消息队列 部署 1. RocketMQ 集群物理部署结构 Rocket 物理部署结构 Name Server: 单点,供Producer和Cons ...
- SHTC3温湿度传感器的使用
1.SHTC3简单说明 SHTC3是一个检测温度.湿度的传感器,可以检测-40℃~125℃的温度范围和0%~100%的湿度范围. SHTC3的工作电压范围为:1.62V~3.6V. SHTC3使用的通 ...
- Bitmap之内存缓存和磁盘缓存详解
原文首发于微信公众号:躬行之(jzman-blog) Android 中缓存的使用比较普遍,使用相应的缓存策略可以减少流量的消耗,也可以在一定程度上提高应用的性能,如加载网络图片的情况,不应该每次都从 ...
- H5 布局 -- 让容器充满屏幕高度或自适应剩余高度
在前端页面布局中,经常会碰到要让容器充满整个屏幕高度或者剩余屏幕高度的需求.一般这时候都会想当然的使用 height:100% 这样的 CSS 来写. 这样写的话,当容器内内容很多的时候是没有问题的, ...
- 简单分析ucenter 会员同步登录通信原理
1.用户登录discuz,通过logging.php文件中的函数uc_user_login对post过来的数据进行验证,也就是对username和password进行验证. 2.如果验证成功,将调用位 ...