官方提供的.flow_from_directory(directory)函数可以读取并训练大规模训练数据,基本可以满足大部分需求。但是在有些场合下,需要自己读取大规模数据以及对应标签,下面提供一种方法。

步骤0:导入相关

import random
import numpy as np
from keras.preprocessing.image import load_img,img_to_array
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model

步骤1:准备数据

#训练集样本路径
train_X = ["train/cat_1.jpg",
"train/cat_2.jpg",
"train/cat_3.jpg",
"train/dog_1.jpg",
"train/dog_2.jpg",
"train/dog_3.jpg"]
#验证集样本路径
val_X = ["val/cat_1.jpg",
"val/cat_2.jpg",
"val/cat_3.jpg",
"val/dog_1.jpg",
"val/dog_2.jpg",
"val/dog_3.jpg"] # 根据图片路径获取图片标签
def get_img_label(img_paths):
img_labels = [] for img_path in img_paths:
animal = img_path.split("/")[-1].split('_')[0]
if animal=='cat':
img_labels.append(0)
else:
img_labels.append(1) return img_labels # 读取图片
def load_batch_image(img_path, train_set = True, target_size=(224, 224)):
im = load_img(img_path, target_size=target_size)
if train_set:
return img_to_array(im) #converts image to numpy array
else:
return img_to_array(im)/255.0
# 建立一个数据迭代器
def GET_DATASET_SHUFFLE(X_samples, batch_size, train_set = True):
random.shuffle(X_samples) batch_num = int(len(X_samples) / batch_size)
max_len = batch_num * batch_size
X_samples = np.array(X_samples[:max_len])
y_samples = get_img_label(X_samples)
print(X_samples.shape) X_batches = np.split(X_samples, batch_num)
y_batches = np.split(y_samples, batch_num) for i in range(len(X_batches)):
if train_set:
x = np.array(list(map(load_batch_image, X_batches[i], [True for _ in range(batch_size)])))
else:
x = np.array(list(map(load_batch_image, X_batches[i], [False for _ in range(batch_size)])))
#print(x.shape)
y = np.array(y_batches[i])
yield x,y

步骤2:对训练数据进行数据增强处理

train_datagen = ImageDataGenerator(
rescale=1. / 255,
rotation_range=10,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)

步骤3:定义模型

model = Model(...)

步骤4:模型训练

n_epoch = 12
batch_size = 16
for e in range(n_epoch):
print("epoch", e)
batch_num = 0
loss_sum=np.array([0.0,0.0])
for X_train, y_train in GET_DATASET_SHUFFLE(train_X, batch_size, True): # chunks of 100 images
for X_batch, y_batch in train_datagen.flow(X_train, y_train, batch_size=batch_size): # chunks of 32 samples
loss = model.train_on_batch(X_batch, y_batch)
loss_sum += loss
batch_num += 1
break #手动break
if batch_num%200==0:
print("epoch %s, batch %s: train_loss = %.4f, train_acc = %.4f"%(e, batch_num, loss_sum[0]/200, loss_sum[1]/200))
loss_sum=np.array([0.0,0.0])
res = model.evaluate_generator(GET_DATASET_SHUFFLE(val_X, batch_size, False),int(len(val_X)/batch_size))
print("val_loss = %.4f, val_acc = %.4f: "%( res[0], res[1])) model.save("weight.h5")

另外,如果在训练的时候不需要做数据增强处理,那么训练就更加简单了,如下:

model.fit_generator(
GET_DATASET_SHUFFLE(train_X, batch_size, True),
epochs=10,
steps_per_epoch=int(len(train_X)/batch_size))

参考文献:

Training on Large Scale Image Datasets with Keras

使用Keras训练大规模数据集的更多相关文章

  1. Hinton胶囊网络后最新研究:用“在线蒸馏”训练大规模分布式神经网络

    Hinton胶囊网络后最新研究:用“在线蒸馏”训练大规模分布式神经网络 朱晓霞发表于目标检测和深度学习订阅 457 广告关闭 11.11 智慧上云 云服务器企业新用户优先购,享双11同等价格 立即抢购 ...

  2. Fast RCNN 训练自己数据集 (1编译配置)

    FastRCNN 训练自己数据集 (1编译配置) 转载请注明出处,楼燚(yì)航的blog,http://www.cnblogs.com/louyihang-loves-baiyan/ https:/ ...

  3. 使用caffe训练mnist数据集 - caffe教程实战(一)

    个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...

  4. 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集

    上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...

  5. keras训练cnn模型时loss为nan

    keras训练cnn模型时loss为nan 1.首先记下来如何解决这个问题的:由于我代码中 model.compile(loss='categorical_crossentropy', optimiz ...

  6. 使用py-faster-rcnn训练VOC2007数据集时遇到问题

    使用py-faster-rcnn训练VOC2007数据集时遇到如下问题: 1. KeyError: 'chair' File "/home/sai/py-faster-rcnn/tools/ ...

  7. Keras下载的数据集以及预训练模型保存在哪里

    Keras下载的数据集在以下目录中: root\\.keras\datasets Keras下载的预训练模型在以下目录中: root\\.keras\models 在win10系统来说,用户主目录是: ...

  8. YOLOV4在linux下训练自己数据集(亲测成功)

    最近推出了yolo-v4我也准备试着跑跑实验看看效果,看看大神的最新操作 这里不做打标签工作和配置cuda工作,需要的可以分别百度搜索   VOC格式数据集制作,cuda和cudnn配置 我们直接利用 ...

  9. Scaled-YOLOv4 快速开始,训练自定义数据集

    代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...

随机推荐

  1. Java程序第二次作业

    1.编写“人”类及其测试类.1.1 “人”类: 类名:Person 属性:姓名.性别.年龄.身份证号码 方法:在控制台输出各个信息1.2 测试类 类名:TestPerson 方法:main ...

  2. mysql 外键和子查询,视图

    1.mysql 外键约束 建表时生成外键   foreing key ('sid') references' student'('id'); 建表后添加外键  alter table' course ...

  3. DJango 前三天小结

    一 DJango 所有命令: 1下载: 控制台:pip install django== ​pip install django== -i 源解释器:找到解释器,点击加号搜索django 2创建项目; ...

  4. C语言-第4次作业得分

    作业链接:https://edu.cnblogs.com/campus/hljkj/CS201801/homework/2523 作业链接:https://edu.cnblogs.com/campus ...

  5. getParameter和getAttribute的区别

    1.getAttribute 是取得jsp中 用setAttribute 设定的attribute 2.parameter得到的是String:attribute得到的是object 3.reques ...

  6. bootstrap研究感想1

    我—>新人,特纯的新人,受到方大神的建议,开始写博客,写一些工作时敲代码时的感受,学习模仿大神时的感悟. -------------------------------------------- ...

  7. NPOI设置单元格背景色

    NPOI设置单元格背景色在网上有好多例子都是设置为NPOI内置的颜色值 但是想用rgb值来设置背景色,即:通过HSSFPalette类获取颜色值时会抛出异常:Could not Find free c ...

  8. 【分布式锁】redis实现

    转载:https://www.jianshu.com/p/c970cc710SETNX命令简介 SETNX key value 将key的值设为value,并且仅当key不存在. 若给定的key已经存 ...

  9. What You Can Learn from Actifio Logs

    The Actifio services generate many logs, some of which are useful for troubleshooting. This section ...

  10. angular2的ElementRef在组件中获取不到

    angular2的ElementRef在组件中获取不到 angular2不推荐操作dom,但是实际应用中不可避免的需要使用到dom操作,怎么操作,官方文档提供了一系列api(ElementRef,Vi ...