TensorFlow keras 迁移学习






数据的读取
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator class TransferModel(object): def __init__(self):
#标准化和数据增强
self.train_generator = ImageDataGenerator(rescale=1.0/255.0)
self.test_generator = ImageDataGenerator(rescale=1.0/255.0)
#指定训练集数据和测试集数据目录
self.train_dir = "./data/train"
self.test_dir = "./data/test"
self.image_size = (224,224)
self.batch_size = 32 def get_loacl_data(self):
'''
读取本地的图片数据以及类别
:return:
'''
train_gen = self.train_generator.flow_from_directory(self.train_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
test_gen = self.test_generator.flow_from_directory(self.test_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True) return train_gen,test_gen if __name__ == '__main__':
tm = TransferModel()
train_gen,test_gen = tm.get_loacl_data()
print(train_gen)
迁移学习完整代码
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.python.keras.applications.vgg16 import VGG16, preprocess_input
import numpy as np class TransferModel(object): def __init__(self): # 定义训练和测试图片的变化方法,标准化以及数据增强
self.train_generator = ImageDataGenerator(rescale=1.0 / 255.0)
self.test_generator = ImageDataGenerator(rescale=1.0 / 255.0) # 指定训练数据和测试数据的目录
self.train_dir = "./data/train"
self.test_dir = "./data/test" # 定义图片训练相关网络参数
self.image_size = (224, 224)
self.batch_size = 32 # 定义迁移学习的基类模型
# 不包含VGG当中3个全连接层的模型加载并且加载了参数
# vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
self.base_model = VGG16(weights='imagenet', include_top=False) self.label_dict = {
'0': '汽车',
'1': '恐龙',
'2': '大象',
'3': '花',
'4': '马'
} def get_local_data(self):
"""
读取本地的图片数据以及类别
:return: 训练数据和测试数据迭代器
"""
# 使用flow_from_derectory
train_gen = self.train_generator.flow_from_directory(self.train_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
test_gen = self.test_generator.flow_from_directory(self.test_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
return train_gen, test_gen def refine_base_model(self):
"""
微调VGG结构,5blocks后面+全局平均池化(减少迁移学习的参数数量)+两个全连接层
:return:
"""
# 1、获取原notop模型得出
# [?, ?, ?, 512]
x = self.base_model.outputs[0] # 2、在输出后面增加我们结构
# [?, ?, ?, 512]---->[?, 1 * 1 * 512]
x = keras.layers.GlobalAveragePooling2D()(x) # 3、定义新的迁移模型
x = keras.layers.Dense(1024, activation=tf.nn.relu)(x)
y_predict = keras.layers.Dense(5, activation=tf.nn.softmax)(x) # model定义新模型
# VGG 模型的输入, 输出:y_predict
transfer_model = keras.models.Model(inputs=self.base_model.inputs, outputs=y_predict) return transfer_model def freeze_model(self):
"""
冻结VGG模型(5blocks)
冻结VGG的多少,根据你的数据量
:return:
"""
# self.base_model.layers 获取所有层,返回层的列表
for layer in self.base_model.layers:
layer.trainable = False def compile(self, model):
"""
编译模型
:return:
"""
model.compile(optimizer=keras.optimizers.Adam(),
loss=keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy'])
return None def fit_generator(self, model, train_gen, test_gen):
"""
训练模型,model.fit_generator()不是选择model.fit()
:return:
"""
# 每一次迭代准确率记录的h5文件
modelckpt = keras.callbacks.ModelCheckpoint('./ckpt/transfer_{epoch:02d}-{val_acc:.2f}.h5',
monitor='val_acc',
save_weights_only=True,
save_best_only=True,
mode='auto',
period=1) model.fit_generator(train_gen, epochs=3, validation_data=test_gen, callbacks=[modelckpt]) return None def predict(self, model):
"""
预测类别
:return:
""" # 加载模型,transfer_model
model.load_weights("./ckpt/transfer_02-0.93.h5") # 读取图片,处理
image = load_img("./1.jpg", target_size=(224, 224))
image.show()
image = img_to_array(image)
# print(image.shape)
# 四维(224, 224, 3)---》(1, 224, 224, 3)
img = image.reshape([1, image.shape[0], image.shape[1], image.shape[2]])
# print(img)
# model.predict() # 预测结果进行处理
image = preprocess_input(img)
predictions = model.predict(image)
print(predictions)
res = np.argmax(predictions, axis=1)
print("所预测的类别是:",self.label_dict[str(res[0])]) if __name__ == '__main__':
tm = TransferModel()
# 训练
# train_gen, test_gen = tm.get_local_data()
# # print(train_gen)
# # for data in train_gen:
# # print(data[0].shape, data[1].shape)
# # print(tm.base_model.summary())
# model = tm.refine_base_model()
# # print(model)
# tm.freeze_model()
# tm.compile(model)
#
# tm.fit_generator(model, train_gen, test_gen) # 测试
model = tm.refine_base_model() tm.predict(model)
TensorFlow keras 迁移学习的更多相关文章
- 『TensorFlow』迁移学习
完全版见github:TransforLearning 零.迁移学习 将一个领域的已经成熟的知识应用到其他的场景中称为迁移学习.用神经网络的角度来表述,就是一层层网络中每个节点的权重从一个训练好的网络 ...
- tensorflow实现迁移学习
此例程出自<TensorFlow实战Google深度学习框架>6.5.2小节 卷积神经网络迁移学习. 数据集来自http://download.tensorflow.org/example ...
- 吴裕雄--天生自然python Google深度学习框架:Tensorflow实现迁移学习
import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...
- ML.NET 示例:图像分类模型训练-首选API(基于原生TensorFlow迁移学习)
ML.NET 版本 API 类型 状态 应用程序类型 数据类型 场景 机器学习任务 算法 Microsoft.ML 1.5.0 动态API 最新 控制台应用程序和Web应用程序 图片文件 图像分类 基 ...
- TensorFlow从1到2(九)迁移学习
迁移学习基本概念 迁移学习是这两年比较火的一个话题,主要原因是在当前的机器学习中,样本数据的获取是成本最高的一块.而迁移学习可以有效的把原有的学习经验(对于模型就是模型本身及其训练好的权重值)带入到新 ...
- 深度学习应用系列(二) | 如何使用keras进行迁移学习,以训练和识别自己的图片集
本文的keras后台为tensorflow,介绍如何利用预编译的模型进行迁移学习,以训练和识别自己的图片集. 官网 https://keras.io/applications/ 已经介绍了各个基于Im ...
- 深度学习趣谈:什么是迁移学习?(附带Tensorflow代码实现)
一.迁移学习的概念 什么是迁移学习呢?迁移学习可以由下面的这张图来表示: 这张图最左边表示了迁移学习也就是把已经训练好的模型和权重直接纳入到新的数据集当中进行训练,但是我们只改变之前模型的分类器(全连 ...
- 常用深度学习框——Caffe/ TensorFlow / Keras/ PyTorch/MXNet
常用深度学习框--Caffe/ TensorFlow / Keras/ PyTorch/MXNet 一.概述 近几年来,深度学习的研究和应用的热潮持续高涨,各种开源深度学习框架层出不穷,包括Tenso ...
- 用tensorflow迁移学习猫狗分类
笔者这几天在跟着莫烦学习TensorFlow,正好到迁移学习(至于什么是迁移学习,看这篇),莫烦老师做的是预测猫和老虎尺寸大小的学习.作为一个有为的学生,笔者当然不能再预测猫啊狗啊的大小啦,正好之前正 ...
随机推荐
- redis吊锤面试官,这篇足够了!
原理篇 redis 时单线程的为什么还能那么快? 数据都在内存中,运算都是内存级别的运算. redis既然是单线程的为什么能处理那么多的并发数? 多路复用,操作系统时间轮训epoll 函数作为选择器, ...
- Java中for(;;)和while(true)的区别
while(true): public class Test { public static void main(String[] args) { while(true) { } } } 在?看看汇编 ...
- Hadoop调试记录(2)
自从上次调通hbase后很久没有碰hadoop了,今日想写一个mapreduce的小程序.于是先运行了下自带的wordcount示例程序,却报错了. 信息如下: kevin@ubuntu:~/usr/ ...
- RecyclerView 的简单使用
自从 Android 5.0 之后,google 推出了一个 RecyclerView 控件,他是 support-v7 包中的新组件,是一个强大的滑动组件,与经典的 ListView 相比,同样拥有 ...
- iOS开发 - 开发版+企业版无线发布一键打包
背景:项目进入快速迭代期,需要快速地交付出AdHoc版本和企业无线发布版本.每次打包都要来回切换bundle identifier和code signing,浪费很多时间. 示例项目名称名称为Test ...
- 【Python3爬虫】反反爬之破解同程旅游加密参数 antitoken
一.前言简介 在现在各个网站使用的反爬措施中,使用 JavaScript 加密算是很常用的了,通常会使用 JavaScript 加密某个参数,例如 token 或者 sign.在这次的例子中,就采取了 ...
- 1055 The World's Richest (25分)(水排序)
Forbes magazine publishes every year its list of billionaires based on the annual ranking of the wor ...
- ACL,NAT的使用
项目练习 练习一: 练习目的:通过配置路由器的dhcp功能使pc自动获取ip地址. Router>enable Router#configure terminal Router(config) ...
- 统计分析_集中趋势and离散程度
1.数组的集中趋势-如何定义数组的中心 1.1 常用几下几个指标来描述一个数组的集中趋势 均值-算术平均数 . 中位数-将数组升序或降序排列后,位于中间的数. 众数-数组中出现最多的数. 1.2 指标 ...
- python3(三十一)metaclass
""" """ __author__ = 'shaozhiqi' # 动态语言和静态语言最大的不同,就是函数和类的定义,不是编译时定义的,而 ...