数据的读取

import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator class TransferModel(object): def __init__(self):
#标准化和数据增强
self.train_generator = ImageDataGenerator(rescale=1.0/255.0)
self.test_generator = ImageDataGenerator(rescale=1.0/255.0)
#指定训练集数据和测试集数据目录
self.train_dir = "./data/train"
self.test_dir = "./data/test"
self.image_size = (224,224)
self.batch_size = 32 def get_loacl_data(self):
'''
读取本地的图片数据以及类别
:return:
'''
train_gen = self.train_generator.flow_from_directory(self.train_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
test_gen = self.test_generator.flow_from_directory(self.test_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True) return train_gen,test_gen if __name__ == '__main__':
tm = TransferModel()
train_gen,test_gen = tm.get_loacl_data()
print(train_gen)

  迁移学习完整代码

import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.python.keras.applications.vgg16 import VGG16, preprocess_input
import numpy as np class TransferModel(object): def __init__(self): # 定义训练和测试图片的变化方法,标准化以及数据增强
self.train_generator = ImageDataGenerator(rescale=1.0 / 255.0)
self.test_generator = ImageDataGenerator(rescale=1.0 / 255.0) # 指定训练数据和测试数据的目录
self.train_dir = "./data/train"
self.test_dir = "./data/test" # 定义图片训练相关网络参数
self.image_size = (224, 224)
self.batch_size = 32 # 定义迁移学习的基类模型
# 不包含VGG当中3个全连接层的模型加载并且加载了参数
# vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
self.base_model = VGG16(weights='imagenet', include_top=False) self.label_dict = {
'0': '汽车',
'1': '恐龙',
'2': '大象',
'3': '花',
'4': '马'
} def get_local_data(self):
"""
读取本地的图片数据以及类别
:return: 训练数据和测试数据迭代器
"""
# 使用flow_from_derectory
train_gen = self.train_generator.flow_from_directory(self.train_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
test_gen = self.test_generator.flow_from_directory(self.test_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
return train_gen, test_gen def refine_base_model(self):
"""
微调VGG结构,5blocks后面+全局平均池化(减少迁移学习的参数数量)+两个全连接层
:return:
"""
# 1、获取原notop模型得出
# [?, ?, ?, 512]
x = self.base_model.outputs[0] # 2、在输出后面增加我们结构
# [?, ?, ?, 512]---->[?, 1 * 1 * 512]
x = keras.layers.GlobalAveragePooling2D()(x) # 3、定义新的迁移模型
x = keras.layers.Dense(1024, activation=tf.nn.relu)(x)
y_predict = keras.layers.Dense(5, activation=tf.nn.softmax)(x) # model定义新模型
# VGG 模型的输入, 输出:y_predict
transfer_model = keras.models.Model(inputs=self.base_model.inputs, outputs=y_predict) return transfer_model def freeze_model(self):
"""
冻结VGG模型(5blocks)
冻结VGG的多少,根据你的数据量
:return:
"""
# self.base_model.layers 获取所有层,返回层的列表
for layer in self.base_model.layers:
layer.trainable = False def compile(self, model):
"""
编译模型
:return:
"""
model.compile(optimizer=keras.optimizers.Adam(),
loss=keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy'])
return None def fit_generator(self, model, train_gen, test_gen):
"""
训练模型,model.fit_generator()不是选择model.fit()
:return:
"""
# 每一次迭代准确率记录的h5文件
modelckpt = keras.callbacks.ModelCheckpoint('./ckpt/transfer_{epoch:02d}-{val_acc:.2f}.h5',
monitor='val_acc',
save_weights_only=True,
save_best_only=True,
mode='auto',
period=1) model.fit_generator(train_gen, epochs=3, validation_data=test_gen, callbacks=[modelckpt]) return None def predict(self, model):
"""
预测类别
:return:
""" # 加载模型,transfer_model
model.load_weights("./ckpt/transfer_02-0.93.h5") # 读取图片,处理
image = load_img("./1.jpg", target_size=(224, 224))
image.show()
image = img_to_array(image)
# print(image.shape)
# 四维(224, 224, 3)---》(1, 224, 224, 3)
img = image.reshape([1, image.shape[0], image.shape[1], image.shape[2]])
# print(img)
# model.predict() # 预测结果进行处理
image = preprocess_input(img)
predictions = model.predict(image)
print(predictions)
res = np.argmax(predictions, axis=1)
print("所预测的类别是:",self.label_dict[str(res[0])]) if __name__ == '__main__':
tm = TransferModel()
# 训练
# train_gen, test_gen = tm.get_local_data()
# # print(train_gen)
# # for data in train_gen:
# # print(data[0].shape, data[1].shape)
# # print(tm.base_model.summary())
# model = tm.refine_base_model()
# # print(model)
# tm.freeze_model()
# tm.compile(model)
#
# tm.fit_generator(model, train_gen, test_gen) # 测试
model = tm.refine_base_model() tm.predict(model)

  

TensorFlow keras 迁移学习的更多相关文章

  1. 『TensorFlow』迁移学习

    完全版见github:TransforLearning 零.迁移学习 将一个领域的已经成熟的知识应用到其他的场景中称为迁移学习.用神经网络的角度来表述,就是一层层网络中每个节点的权重从一个训练好的网络 ...

  2. tensorflow实现迁移学习

    此例程出自<TensorFlow实战Google深度学习框架>6.5.2小节 卷积神经网络迁移学习. 数据集来自http://download.tensorflow.org/example ...

  3. 吴裕雄--天生自然python Google深度学习框架:Tensorflow实现迁移学习

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  4. ML.NET 示例:图像分类模型训练-首选API(基于原生TensorFlow迁移学习)

    ML.NET 版本 API 类型 状态 应用程序类型 数据类型 场景 机器学习任务 算法 Microsoft.ML 1.5.0 动态API 最新 控制台应用程序和Web应用程序 图片文件 图像分类 基 ...

  5. TensorFlow从1到2(九)迁移学习

    迁移学习基本概念 迁移学习是这两年比较火的一个话题,主要原因是在当前的机器学习中,样本数据的获取是成本最高的一块.而迁移学习可以有效的把原有的学习经验(对于模型就是模型本身及其训练好的权重值)带入到新 ...

  6. 深度学习应用系列(二) | 如何使用keras进行迁移学习,以训练和识别自己的图片集

    本文的keras后台为tensorflow,介绍如何利用预编译的模型进行迁移学习,以训练和识别自己的图片集. 官网 https://keras.io/applications/ 已经介绍了各个基于Im ...

  7. 深度学习趣谈:什么是迁移学习?(附带Tensorflow代码实现)

    一.迁移学习的概念 什么是迁移学习呢?迁移学习可以由下面的这张图来表示: 这张图最左边表示了迁移学习也就是把已经训练好的模型和权重直接纳入到新的数据集当中进行训练,但是我们只改变之前模型的分类器(全连 ...

  8. 常用深度学习框——Caffe/ TensorFlow / Keras/ PyTorch/MXNet

    常用深度学习框--Caffe/ TensorFlow / Keras/ PyTorch/MXNet 一.概述 近几年来,深度学习的研究和应用的热潮持续高涨,各种开源深度学习框架层出不穷,包括Tenso ...

  9. 用tensorflow迁移学习猫狗分类

    笔者这几天在跟着莫烦学习TensorFlow,正好到迁移学习(至于什么是迁移学习,看这篇),莫烦老师做的是预测猫和老虎尺寸大小的学习.作为一个有为的学生,笔者当然不能再预测猫啊狗啊的大小啦,正好之前正 ...

随机推荐

  1. nginx使用手册+基本原理+优缺点

    一.nginx优点 1.反向代理 1.正向代理: 客户端和原始服务器(origin server)之间的服务器,为了从原始服务器取得内容,客户端向代理发送一个请求并指定目标(原始服务器),然后代理向原 ...

  2. SublimeのJedi (自动补全)

    关于 Sublime 3 - Jedi Package 的设置和使用方法 我是一枚小白,安装后 Sublime 后,想在码字时,达到如下效果: 打字时,自动提示相关内容 按Tab键,相关内容自动填充 ...

  3. Array.forEach原理,仿造一个类似功能

    Array.forEach原理,仿造一个类似功能 array.forEach // 设一个arr数组 let arr = [12,45,78,165,68,124]; let sum = 0; // ...

  4. 模块 configparser 配置文件生成修改

    此模块用于生成和修改常见配置文档 1.来看一个好多软件的常见配置文件格式如下***.ini [DEFAULT] ServerAliveInterval = 45 Compression = yes C ...

  5. React Hook挖坑

    React Hook挖坑 如果已经使用过 Hook,相信你一定回不去了,这种用函数的方式去编写有状态组件简直太爽啦. 如果还没使用过 Hook,那你要赶紧升级你的 React(v16.8+),投入 H ...

  6. Day13 流程控制

    Linux中的流程控制语句 一.if语句 1.单分支if条件语句 格式:if [ 条件判断式 ] then 程序     fi 注意:1.在Linux中是以if开头,fi结尾.其他地方一般是{开头,} ...

  7. 性能优化之三:将Dottrace过程加入持续集成

    之前分享过一篇如何做接口性能分析的文章,但是整个分析过程有点繁琐,需要写一个控制台程序调用被测接口,再预热.启动dottrace追踪,最后才能得到我们想要的性能分析报告.如果有办法一键生成性能分析报告 ...

  8. Java String与char

    1. char类型 + char 类型 = 字符对应的ASCII码值相加(数字): char类型 + String 类型 = 字符对应的ASCII码值相加(数字) + String 类型: Strin ...

  9. jq日历一周为单位轮播

    简单效果图: 代码如下: <!doctype html> <html lang="en"> <head> <meta charset=&q ...

  10. pm2 开机启动egg项目

    1.在服务器上安装PM2 npm install pm2 -g 2.对PM2进行更新 pm2 update 3.进入服务器中egg项目更目录,并新建server.js文件,并在其中写入以下代码 con ...