数据的读取

import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator class TransferModel(object): def __init__(self):
#标准化和数据增强
self.train_generator = ImageDataGenerator(rescale=1.0/255.0)
self.test_generator = ImageDataGenerator(rescale=1.0/255.0)
#指定训练集数据和测试集数据目录
self.train_dir = "./data/train"
self.test_dir = "./data/test"
self.image_size = (224,224)
self.batch_size = 32 def get_loacl_data(self):
'''
读取本地的图片数据以及类别
:return:
'''
train_gen = self.train_generator.flow_from_directory(self.train_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
test_gen = self.test_generator.flow_from_directory(self.test_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True) return train_gen,test_gen if __name__ == '__main__':
tm = TransferModel()
train_gen,test_gen = tm.get_loacl_data()
print(train_gen)

  迁移学习完整代码

import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.python.keras.applications.vgg16 import VGG16, preprocess_input
import numpy as np class TransferModel(object): def __init__(self): # 定义训练和测试图片的变化方法,标准化以及数据增强
self.train_generator = ImageDataGenerator(rescale=1.0 / 255.0)
self.test_generator = ImageDataGenerator(rescale=1.0 / 255.0) # 指定训练数据和测试数据的目录
self.train_dir = "./data/train"
self.test_dir = "./data/test" # 定义图片训练相关网络参数
self.image_size = (224, 224)
self.batch_size = 32 # 定义迁移学习的基类模型
# 不包含VGG当中3个全连接层的模型加载并且加载了参数
# vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
self.base_model = VGG16(weights='imagenet', include_top=False) self.label_dict = {
'0': '汽车',
'1': '恐龙',
'2': '大象',
'3': '花',
'4': '马'
} def get_local_data(self):
"""
读取本地的图片数据以及类别
:return: 训练数据和测试数据迭代器
"""
# 使用flow_from_derectory
train_gen = self.train_generator.flow_from_directory(self.train_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
test_gen = self.test_generator.flow_from_directory(self.test_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
return train_gen, test_gen def refine_base_model(self):
"""
微调VGG结构,5blocks后面+全局平均池化(减少迁移学习的参数数量)+两个全连接层
:return:
"""
# 1、获取原notop模型得出
# [?, ?, ?, 512]
x = self.base_model.outputs[0] # 2、在输出后面增加我们结构
# [?, ?, ?, 512]---->[?, 1 * 1 * 512]
x = keras.layers.GlobalAveragePooling2D()(x) # 3、定义新的迁移模型
x = keras.layers.Dense(1024, activation=tf.nn.relu)(x)
y_predict = keras.layers.Dense(5, activation=tf.nn.softmax)(x) # model定义新模型
# VGG 模型的输入, 输出:y_predict
transfer_model = keras.models.Model(inputs=self.base_model.inputs, outputs=y_predict) return transfer_model def freeze_model(self):
"""
冻结VGG模型(5blocks)
冻结VGG的多少,根据你的数据量
:return:
"""
# self.base_model.layers 获取所有层,返回层的列表
for layer in self.base_model.layers:
layer.trainable = False def compile(self, model):
"""
编译模型
:return:
"""
model.compile(optimizer=keras.optimizers.Adam(),
loss=keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy'])
return None def fit_generator(self, model, train_gen, test_gen):
"""
训练模型,model.fit_generator()不是选择model.fit()
:return:
"""
# 每一次迭代准确率记录的h5文件
modelckpt = keras.callbacks.ModelCheckpoint('./ckpt/transfer_{epoch:02d}-{val_acc:.2f}.h5',
monitor='val_acc',
save_weights_only=True,
save_best_only=True,
mode='auto',
period=1) model.fit_generator(train_gen, epochs=3, validation_data=test_gen, callbacks=[modelckpt]) return None def predict(self, model):
"""
预测类别
:return:
""" # 加载模型,transfer_model
model.load_weights("./ckpt/transfer_02-0.93.h5") # 读取图片,处理
image = load_img("./1.jpg", target_size=(224, 224))
image.show()
image = img_to_array(image)
# print(image.shape)
# 四维(224, 224, 3)---》(1, 224, 224, 3)
img = image.reshape([1, image.shape[0], image.shape[1], image.shape[2]])
# print(img)
# model.predict() # 预测结果进行处理
image = preprocess_input(img)
predictions = model.predict(image)
print(predictions)
res = np.argmax(predictions, axis=1)
print("所预测的类别是:",self.label_dict[str(res[0])]) if __name__ == '__main__':
tm = TransferModel()
# 训练
# train_gen, test_gen = tm.get_local_data()
# # print(train_gen)
# # for data in train_gen:
# # print(data[0].shape, data[1].shape)
# # print(tm.base_model.summary())
# model = tm.refine_base_model()
# # print(model)
# tm.freeze_model()
# tm.compile(model)
#
# tm.fit_generator(model, train_gen, test_gen) # 测试
model = tm.refine_base_model() tm.predict(model)

  

TensorFlow keras 迁移学习的更多相关文章

  1. 『TensorFlow』迁移学习

    完全版见github:TransforLearning 零.迁移学习 将一个领域的已经成熟的知识应用到其他的场景中称为迁移学习.用神经网络的角度来表述,就是一层层网络中每个节点的权重从一个训练好的网络 ...

  2. tensorflow实现迁移学习

    此例程出自<TensorFlow实战Google深度学习框架>6.5.2小节 卷积神经网络迁移学习. 数据集来自http://download.tensorflow.org/example ...

  3. 吴裕雄--天生自然python Google深度学习框架:Tensorflow实现迁移学习

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  4. ML.NET 示例:图像分类模型训练-首选API(基于原生TensorFlow迁移学习)

    ML.NET 版本 API 类型 状态 应用程序类型 数据类型 场景 机器学习任务 算法 Microsoft.ML 1.5.0 动态API 最新 控制台应用程序和Web应用程序 图片文件 图像分类 基 ...

  5. TensorFlow从1到2(九)迁移学习

    迁移学习基本概念 迁移学习是这两年比较火的一个话题,主要原因是在当前的机器学习中,样本数据的获取是成本最高的一块.而迁移学习可以有效的把原有的学习经验(对于模型就是模型本身及其训练好的权重值)带入到新 ...

  6. 深度学习应用系列(二) | 如何使用keras进行迁移学习,以训练和识别自己的图片集

    本文的keras后台为tensorflow,介绍如何利用预编译的模型进行迁移学习,以训练和识别自己的图片集. 官网 https://keras.io/applications/ 已经介绍了各个基于Im ...

  7. 深度学习趣谈:什么是迁移学习?(附带Tensorflow代码实现)

    一.迁移学习的概念 什么是迁移学习呢?迁移学习可以由下面的这张图来表示: 这张图最左边表示了迁移学习也就是把已经训练好的模型和权重直接纳入到新的数据集当中进行训练,但是我们只改变之前模型的分类器(全连 ...

  8. 常用深度学习框——Caffe/ TensorFlow / Keras/ PyTorch/MXNet

    常用深度学习框--Caffe/ TensorFlow / Keras/ PyTorch/MXNet 一.概述 近几年来,深度学习的研究和应用的热潮持续高涨,各种开源深度学习框架层出不穷,包括Tenso ...

  9. 用tensorflow迁移学习猫狗分类

    笔者这几天在跟着莫烦学习TensorFlow,正好到迁移学习(至于什么是迁移学习,看这篇),莫烦老师做的是预测猫和老虎尺寸大小的学习.作为一个有为的学生,笔者当然不能再预测猫啊狗啊的大小啦,正好之前正 ...

随机推荐

  1. [codevs]1250斐波那契数列<矩阵乘法&快速幂>

    题目描述 Description 定义:f0=f1=1, fn=fn-1+fn-2(n>=2).{fi}称为Fibonacci数列. 输入n,求fn mod q.其中1<=q<=30 ...

  2. 《 OO第一作业周期(前四周)总结 》

    作为一名软件工程的大学生,很高兴能够以这样一种方式,实现对博客编写零的突破.专业课老师也介绍了编写博客给我们带来的帮助,听了以后,我感觉到了培养出写博客的习惯,是一件多么有意义的事! 话不多说,让我们 ...

  3. SpringBoot登录判断

    <!-- html登录代码 --> <div class="box"> <div class="title">登录</ ...

  4. Let‘s play computer game(最短路 + dfs找出所有确定长度的最短路)

    Let's play computer game Description xxxxxxxxx在疫情期间迷上了一款游戏,这个游戏一共有nnn个地点(编号为1--n1--n1--n),他每次从一个地点移动 ...

  5. codeforces 1038a(找最长的前k个字母出现相同次数的字符串)

    codeforces 1038a You are given a string s of length n, which consists only of the first k letters of ...

  6. Step by Step!教你如何在k3s集群上使用Traefik 2.x

    本文来自边缘计算k3s社区 作者简介 Cello Spring,瑞士人.从电子起步,拥有电子工程学位.尔后开始关注计算机领域,在软件开发领域拥有多年的工作经验. Traefik是一个十分可靠的云原生动 ...

  7. JavaScript五子棋第二版

      这是博主做的一个移动端五子棋小游戏,请使用手机体验.由于希望能有迭代开发的感觉,所以暂时只支持双人对战且无其他提示及对战界面,只有胜利提示,悔棋.对战双方显示.人机对战.集成TS(用于学习).和局 ...

  8. css布局之盒模型

    盒模型 导读 随着网络技术的不断发展,人们已经不再只关注网页的功能,还追求网页的性能和美观,于是css应运而生,一个完美的网页必然有一个完美的布局,而css盒模型是网页布局的基石,所以了解它对网页制作 ...

  9. Python GUI——tkinter菜鸟编程(中)

    8. Radiobutton 选项按钮:可以用鼠标单击方式选取,一次只能有一个选项被选取. Radiobutton(父对象,options,-) 常用options参数: anchor,bg,bitm ...

  10. .NET Core项目部署到Linux(Centos7)(五)Centos 7安装.NET Core环境

    目录 1.前言 2.环境和软件的准备 3.创建.NET Core API项目 4.VMware Workstation虚拟机及Centos 7安装 5.Centos 7安装.NET Core环境 6. ...