官方提供的.flow_from_directory(directory)函数可以读取并训练大规模训练数据,基本可以满足大部分需求。但是在有些场合下,需要自己读取大规模数据以及对应标签,下面提供一种方法。

步骤0:导入相关

import random
import numpy as np
from keras.preprocessing.image import load_img,img_to_array
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model

步骤1:准备数据

#训练集样本路径
train_X = ["train/cat_1.jpg",
"train/cat_2.jpg",
"train/cat_3.jpg",
"train/dog_1.jpg",
"train/dog_2.jpg",
"train/dog_3.jpg"]
#验证集样本路径
val_X = ["val/cat_1.jpg",
"val/cat_2.jpg",
"val/cat_3.jpg",
"val/dog_1.jpg",
"val/dog_2.jpg",
"val/dog_3.jpg"] # 根据图片路径获取图片标签
def get_img_label(img_paths):
img_labels = [] for img_path in img_paths:
animal = img_path.split("/")[-1].split('_')[0]
if animal=='cat':
img_labels.append(0)
else:
img_labels.append(1) return img_labels # 读取图片
def load_batch_image(img_path, train_set = True, target_size=(224, 224)):
im = load_img(img_path, target_size=target_size)
if train_set:
return img_to_array(im) #converts image to numpy array
else:
return img_to_array(im)/255.0
# 建立一个数据迭代器
def GET_DATASET_SHUFFLE(X_samples, batch_size, train_set = True):
random.shuffle(X_samples) batch_num = int(len(X_samples) / batch_size)
max_len = batch_num * batch_size
X_samples = np.array(X_samples[:max_len])
y_samples = get_img_label(X_samples)
print(X_samples.shape) X_batches = np.split(X_samples, batch_num)
y_batches = np.split(y_samples, batch_num) for i in range(len(X_batches)):
if train_set:
x = np.array(list(map(load_batch_image, X_batches[i], [True for _ in range(batch_size)])))
else:
x = np.array(list(map(load_batch_image, X_batches[i], [False for _ in range(batch_size)])))
#print(x.shape)
y = np.array(y_batches[i])
yield x,y

步骤2:对训练数据进行数据增强处理

train_datagen = ImageDataGenerator(
rescale=1. / 255,
rotation_range=10,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)

步骤3:定义模型

model = Model(...)

步骤4:模型训练

n_epoch = 12
batch_size = 16
for e in range(n_epoch):
print("epoch", e)
batch_num = 0
loss_sum=np.array([0.0,0.0])
for X_train, y_train in GET_DATASET_SHUFFLE(train_X, batch_size, True): # chunks of 100 images
for X_batch, y_batch in train_datagen.flow(X_train, y_train, batch_size=batch_size): # chunks of 32 samples
loss = model.train_on_batch(X_batch, y_batch)
loss_sum += loss
batch_num += 1
break #手动break
if batch_num%200==0:
print("epoch %s, batch %s: train_loss = %.4f, train_acc = %.4f"%(e, batch_num, loss_sum[0]/200, loss_sum[1]/200))
loss_sum=np.array([0.0,0.0])
res = model.evaluate_generator(GET_DATASET_SHUFFLE(val_X, batch_size, False),int(len(val_X)/batch_size))
print("val_loss = %.4f, val_acc = %.4f: "%( res[0], res[1])) model.save("weight.h5")

另外,如果在训练的时候不需要做数据增强处理,那么训练就更加简单了,如下:

model.fit_generator(
GET_DATASET_SHUFFLE(train_X, batch_size, True),
epochs=10,
steps_per_epoch=int(len(train_X)/batch_size))

参考文献:

Training on Large Scale Image Datasets with Keras

使用Keras训练大规模数据集的更多相关文章

  1. Hinton胶囊网络后最新研究:用“在线蒸馏”训练大规模分布式神经网络

    Hinton胶囊网络后最新研究:用“在线蒸馏”训练大规模分布式神经网络 朱晓霞发表于目标检测和深度学习订阅 457 广告关闭 11.11 智慧上云 云服务器企业新用户优先购,享双11同等价格 立即抢购 ...

  2. Fast RCNN 训练自己数据集 (1编译配置)

    FastRCNN 训练自己数据集 (1编译配置) 转载请注明出处,楼燚(yì)航的blog,http://www.cnblogs.com/louyihang-loves-baiyan/ https:/ ...

  3. 使用caffe训练mnist数据集 - caffe教程实战(一)

    个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...

  4. 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集

    上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...

  5. keras训练cnn模型时loss为nan

    keras训练cnn模型时loss为nan 1.首先记下来如何解决这个问题的:由于我代码中 model.compile(loss='categorical_crossentropy', optimiz ...

  6. 使用py-faster-rcnn训练VOC2007数据集时遇到问题

    使用py-faster-rcnn训练VOC2007数据集时遇到如下问题: 1. KeyError: 'chair' File "/home/sai/py-faster-rcnn/tools/ ...

  7. Keras下载的数据集以及预训练模型保存在哪里

    Keras下载的数据集在以下目录中: root\\.keras\datasets Keras下载的预训练模型在以下目录中: root\\.keras\models 在win10系统来说,用户主目录是: ...

  8. YOLOV4在linux下训练自己数据集(亲测成功)

    最近推出了yolo-v4我也准备试着跑跑实验看看效果,看看大神的最新操作 这里不做打标签工作和配置cuda工作,需要的可以分别百度搜索   VOC格式数据集制作,cuda和cudnn配置 我们直接利用 ...

  9. Scaled-YOLOv4 快速开始,训练自定义数据集

    代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...

随机推荐

  1. spring注解注入:<context:component-scan>详解

    spring从2.5版本开始支持注解注入,注解注入可以省去很多的xml配置工作.由于注解是写入java代码中的,所以注解注入会失去一定的灵活性,我们要根据需要来选择是否启用注解注入. 我们首先看一个注 ...

  2. 1040 mysql Too many connections

    笔者在项目中遇到mysql 出现:1040 too many connections 异常,意思是超过数据库最大连接数,打不开表结构信息.笔者排除问题建议:1.查看程序代码是否存在BUG:2.检查代码 ...

  3. git 提交本地工程

    1> 进入github 建立repository 2> 得到git master 地址 3> 得到进入本地工程目录 右键->git bash here 4>执行 以下命令 ...

  4. vue---- v-bind指令

    v-bind指令用于给html标签设置属性. 基本用法 <div id="app"> <div v-bind:id="id1">文字&l ...

  5. Python基础:八、python基本数据类型

    一.什么是数据类型? 我们人类可以很容易的分清数字与字符的区别,但是计算机并不能,计算机虽然很强大,但从某种角度上来看又很傻,除非你明确告诉它,"1"是数字,"壹&quo ...

  6. ZZFAFA_BilibiliMusic_DownUrl

    OneDrive_DownFileUrl: FHProductionHK-BGM:https://1drv.ms/f/s!Ajs97XY1QSQ8cPXo36h4AK9XG7k CABAL&A ...

  7. Pac-Man 吃豆人

    发售年份 1980 平台 街机 开发商 南梦宫(Namco) 类型 迷宫 https://www.youtube.com/watch?v=dScq4P5gn4A

  8. 2018-2019-2 网络对抗技术 20165308 Exp2 后门原理与实践

    2018-2019-2 网络对抗技术 20165308 Exp2 后门原理与实践 1.实验内容 (3.5分) (1)使用netcat获取主机操作Shell,cron启动 (0.5分) (2)使用soc ...

  9. 利用chrome浏览器爬取数据

    相关的库自己下载吧,直接上代码 from selenium import webdriver from bs4 import BeautifulSoup import time #手动添加路径 pat ...

  10. redis 缓存击穿 看一篇成高手系列3

    什么是缓存击穿 在谈论缓存击穿之前,我们先来回忆下从缓存中加载数据的逻辑,如下图所示 因此,如果黑客每次故意查询一个在缓存内必然不存在的数据,导致每次请求都要去存储层去查询,这样缓存就失去了意义.如果 ...