深度学习基础系列(十一)| Keras中图像增强技术详解
在深度学习中,数据短缺是我们经常面临的一个问题,虽然现在有不少公开数据集,但跟大公司掌握的海量数据集相比,数量上仍然偏少,而某些特定领域的数据采集更是非常困难。根据之前的学习可知,数据量少带来的最直接影响就是过拟合。那有没有办法在现有少量数据基础上,降低或解决过拟合问题呢?
答案是有的,就是数据增强技术。我们可以对现有的数据,如图片数据进行平移、翻转、旋转、缩放、亮度增强等操作,以生成新的图片来参与训练或测试。这种操作可以将图片数量提升数倍,由此大大降低了过拟合的可能。本文将详解图像增强技术在Keras中的原理和应用。
一、Keras中的ImageDataGenerator类
图像增强的官网地址是:https://keras.io/preprocessing/image/ ,API使用相对简单,功能也很强大。
先介绍的是ImageDataGenerator类,这个类定义了图片该如何进行增强操作,其API及参数定义如下:
keras.preprocessing.image.ImageDataGenerator(
featurewise_center=False, #输入值按照均值为0进行处理
samplewise_center=False, #每个样本的均值按0处理
featurewise_std_normalization=False, #输入值按照标准正态化处理
samplewise_std_normalization=False, #每个样本按照标准正态化处理
zca_whitening=False, # 是否开启增白
zca_epsilon=1e-06,
rotation_range=0, #图像随机旋转一定角度,最大旋转角度为设定值
width_shift_range=0.0, #图像随机水平平移,最大平移值为设定值。若值为小于1的float值,则可认为是按比例平移,若大于1,则平移的是像素;若值为整型,平移的也是像素;假设像素为2.0,则移动范围为[-1,1]之间
height_shift_range=0.0, #图像随机垂直平移,同上
brightness_range=None, # 图像随机亮度增强,给定一个含两个float值的list,亮度值取自上下限值间
shear_range=0.0, # 图像随机修剪
zoom_range=0.0, # 图像随机变焦
channel_shift_range=0.0,
fill_mode='nearest', #填充模式,默认为最近原则,比如一张图片向右平移,那么最左侧部分会被临近的图案覆盖
cval=0.0,
horizontal_flip=False, #图像随机水平翻转
vertical_flip=False, #图像随机垂直翻转
rescale=None, #缩放尺寸
preprocessing_function=None,
data_format=None,
validation_split=0.0,
dtype=None)
下文将以mnist和花类的数据集进行图片操作,其中花类(17种花,共1360张图片)数据集可见我的百度网盘: https://pan.baidu.com/s/1YDA_VOBlJSQEijcCoGC60w 。让我们以直观地方式看看各参数能带来什么样的图片变化。
随机旋转
我们可用mnist数据集对图片进行随机旋转,旋转的最大角度由参数定义。
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as K K.set_image_dim_ordering('th') (train_data, train_label), (test_data, test_label) = mnist.load_data()
train_data = train_data.reshape(train_data.shape[0], 1, 28, 28)
train_data = train_data.astype('float32') # 创建图像生成器,指定对图像操作的内容
datagen = ImageDataGenerator(rotation_range=90)
# 图像生成器要训练的数据
datagen.fit(train_data) # 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for batch_data, batch_label in datagen.flow(train_data, train_label, batch_size=9):
for i in range(0, 9):
# 创建一个 3*3的九宫格,以显示图片
pyplot.subplot(330 + 1 + i)
pyplot.imshow(batch_data[i].reshape(28, 28), cmap=pyplot.get_cmap('gray'))
pyplot.show()
break
生成结果为:
随机平移
我们可用花类数据集对图片进行随机平移,可以在垂直和水平方向上平移,平移最大值由参数定义。
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras.preprocessing.image import array_to_img IMAGE_SIZE = 224
NUM_CLASSES = 17
TRAIN_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/train'
TEST_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/test'
FLOWER_CLASSES = ['Bluebell', 'ButterCup', 'ColtsFoot', 'Cowslip', 'Crocus', 'Daffodil', 'Daisy',
'Dandelion', 'Fritillary', 'Iris', 'LilyValley', 'Pansy', 'Snowdrop', 'Sunflower',
'Tigerlily', 'tulip', 'WindFlower'] # 创建图像生成器,指定对图像操作的内容,平移的最大比例为50%
train_datagen = ImageDataGenerator(width_shift_range=0.5, height_shift_range=0.5) # 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for X_batch, y_batch in train_datagen.flow_from_directory(directory=TRAIN_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE),batch_size=9, classes=FLOWER_CLASSES):
for i in range(0, 9):
pyplot.subplot(330 + 1 + i)
pyplot.imshow(array_to_img(X_batch[i]))
pyplot.show()
break
生成结果为:
可以观察到,图片除了实现平移外,其原来的位置都被最近的图案给填充,因为默认给的填充方式是nearest。
随机亮度调整
我们可用花类数据集对图片进行随机亮度调整,亮度范围由参数定义。
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras.preprocessing.image import array_to_img IMAGE_SIZE = 224
NUM_CLASSES = 17
TRAIN_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/train'
TEST_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/test'
FLOWER_CLASSES = ['Bluebell', 'ButterCup', 'ColtsFoot', 'Cowslip', 'Crocus', 'Daffodil', 'Daisy',
'Dandelion', 'Fritillary', 'Iris', 'LilyValley', 'Pansy', 'Snowdrop', 'Sunflower',
'Tigerlily', 'tulip', 'WindFlower'] # 创建图像生成器,指定对图像操作的内容,亮度范围在0.1~10之间随机选择
train_datagen = ImageDataGenerator(brightness_range=[0.1, 10]) # 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for X_batch, y_batch in train_datagen.flow_from_directory(directory=TRAIN_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE),batch_size=9, classes=FLOWER_CLASSES):
for i in range(0, 9):
pyplot.subplot(330 + 1 + i)
pyplot.imshow(array_to_img(X_batch[i]))
pyplot.show()
break
生成结果为:
随机焦距调整
我们可用mnist数据集对图片进行随机焦距调整,焦距调整值由参数定义。
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as K K.set_image_dim_ordering('th') (train_data, train_label), (test_data, test_label) = mnist.load_data()
train_data = train_data.reshape(train_data.shape[0], 1, 28, 28)
train_data = train_data.astype('float32') # 创建图像生成器,指定对图像操作的内容,焦距值在0.1~1之间
datagen = ImageDataGenerator(zoom_range=[0.1, 1])
# 图像生成器要训练的数据
datagen.fit(train_data) # 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for batch_data, batch_label in datagen.flow(train_data, train_label, batch_size=9):
for i in range(0, 9):
# 创建一个 3*3的九宫格,以显示图片
pyplot.subplot(330 + 1 + i)
pyplot.imshow(batch_data[i].reshape(28, 28), cmap=pyplot.get_cmap('gray'))
pyplot.show()
break
生成结果为:
可以看出这跟相机调焦一样,可以放大或缩小焦距。
随机翻转
我们可用花类数据集对图片进行随机翻转。
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras.preprocessing.image import array_to_img IMAGE_SIZE = 224
NUM_CLASSES = 17
TRAIN_PATH = '/home/hutao/Documents/tensorflow/images/17flowerclasses/train'
TEST_PATH = '/home/hutao/Documents/tensorflow/images/17flowerclasses/test'
FLOWER_CLASSES = ['Bluebell', 'ButterCup', 'ColtsFoot', 'Cowslip', 'Crocus', 'Daffodil', 'Daisy',
'Dandelion', 'Fritillary', 'Iris', 'LilyValley', 'Pansy', 'Snowdrop', 'Sunflower',
'Tigerlily', 'tulip', 'WindFlower'] # 创建图像生成器,指定对图像操作的内容,图片随机翻转
train_datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True) # 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for X_batch, y_batch in train_datagen.flow_from_directory(directory=TRAIN_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE),batch_size=9, classes=FLOWER_CLASSES):
for i in range(0, 9):
pyplot.subplot(330 + 1 + i)
pyplot.imshow(array_to_img(X_batch[i]))
pyplot.show()
break
生成结果为:
从上图可看出,有些图片水平翻转了,有些是垂直翻转了。
ZCA图像增白
说实在我不太清楚该技术有何用,用花类图片实验结果显示zca不支持,可以用mnist数据集来看看效果。
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as K K.set_image_dim_ordering('th') (train_data, train_label), (test_data, test_label) = mnist.load_data()
train_data = train_data.reshape(train_data.shape[0], 1, 28, 28)
train_data = train_data.astype('float32') # 创建图像生成器,指定对图像操作的内容,增白图片
datagen = ImageDataGenerator(zca_whitening=True)
# 图像生成器要训练的数据
datagen.fit(train_data) # 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for batch_data, batch_label in datagen.flow(train_data, train_label, batch_size=9):
for i in range(0, 9):
# 创建一个 3*3的九宫格,以显示图片
pyplot.subplot(330 + 1 + i)
pyplot.imshow(batch_data[i].reshape(28, 28), cmap=pyplot.get_cmap('gray'))
pyplot.show()
break
生成结果为:
特征标准化
特征标准化的含义是使图片的像素均值为0,标准差为1,不过我试了多次,直观效果不明显。
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as K K.set_image_dim_ordering('th') (train_data, train_label), (test_data, test_label) = mnist.load_data()
train_data = train_data.reshape(train_data.shape[0], 1, 28, 28)
train_data = train_data.astype('float32') # 创建图像生成器,指定对图像操作的内容,允许图片标准化处理
datagen = ImageDataGenerator(featurewise_center=True, featurewise_std_normalization=True)
# 图像生成器要训练的数据
datagen.fit(train_data) # 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for batch_data, batch_label in datagen.flow(train_data, train_label, batch_size=9):
for i in range(0, 9):
# 创建一个 3*3的九宫格,以显示图片
pyplot.subplot(330 + 1 + i)
pyplot.imshow(batch_data[i].reshape(28, 28), cmap=pyplot.get_cmap('gray'))
pyplot.show()
break
生成结果为:
就个人而言,我倾向于在图像增强中使用旋转、亮度调整、翻转和平移操作。
二、Keras如何进行图像增强数据训练
在之前的文章中我已经展现过数据增强的使用。在Keras中,增强图片有三种来源:
- 图片来源于已知数据集,如mnist、cifar,数据格式为numpy格式;
- 图片来源于我们自己搜集的图片,如本文引入的花类数据集,其图片为jpg、png等格式;
- 图片来源于panda数据集;
其中数据来源已知数据集,其操作方法如下:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes) datagen = ImageDataGenerator(
featurewise_center=True,
featurewise_std_normalization=True,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True) #生成器绑定训练集
datagen.fit(x_train) # 模型绑定生成器,并不停地迭代产生数据,可指定迭代次数,假设图片总数为1000张,batch默认为32,则每次迭代需要产生1000/32=32个步骤
history = model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
steps_per_epoch=len(x_train) / 32, epochs=epochs)
数据来源图片集,其操作方法如下:
batch_size = 32
# 迭代50次
epochs = 50
# 依照模型规定,图片大小被设定为224
IMAGE_SIZE = 224
TRAIN_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/train'
TEST_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/test'
FLOWER_CLASSES = ['Bluebell', 'ButterCup', 'ColtsFoot', 'Cowslip', 'Crocus', 'Daffodil', 'Daisy','Dandelion', 'Fritillary', 'Iris', 'LilyValley', 'Pansy', 'Snowdrop', 'Sunflower','Tigerlily', 'tulip', 'WindFlower'] # 使用数据增强
train_datagen = ImageDataGenerator(rotation_range=90)
# 可指定输出图片大小,因为深度学习要求训练图片大小保持一致
train_generator = train_datagen.flow_from_directory(directory=TRAIN_PATH,
target_size=(IMAGE_SIZE, IMAGE_SIZE),
classes=FLOWER_CLASSES)
test_datagen = ImageDataGenerator()
test_generator = test_datagen.flow_from_directory(directory=TEST_PATH,
target_size=(IMAGE_SIZE, IMAGE_SIZE),
classes=FLOWER_CLASSES)
# 运行模型
history = model.fit_generator(train_generator, epochs=epochs, validation_data=test_generator)
需要说明的是,这些增强图片都是在内存中实时批量迭代生成的,不是一次性被读入内存,这样可以极大地节约内存空间,加快处理速度。若想保留中间过程生成的增强图片,可以在上述方法中添加保存路径等参数,此处不再赘述。
三、结论
本文介绍了如何在Keras中使用图像增强技术,对图片可以进行各种操作,以生成数倍于原图片的增强图片集。这些数据集可帮助我们有效地对抗过拟合问题,更好地生成理想的模型。
深度学习基础系列(十一)| Keras中图像增强技术详解的更多相关文章
- 深度学习基础系列(五)| 深入理解交叉熵函数及其在tensorflow和keras中的实现
在统计学中,损失函数是一种衡量损失和错误(这种损失与“错误地”估计有关,如费用或者设备的损失)程度的函数.假设某样本的实际输出为a,而预计的输出为y,则y与a之间存在偏差,深度学习的目的即是通过不断地 ...
- 深度学习基础系列(九)| Dropout VS Batch Normalization? 是时候放弃Dropout了
Dropout是过去几年非常流行的正则化技术,可有效防止过拟合的发生.但从深度学习的发展趋势看,Batch Normalizaton(简称BN)正在逐步取代Dropout技术,特别是在卷积层.本文将首 ...
- 转:LoadRunner中参数化技术详解
LoadRunner中参数化技术详解 LoadRunner在录制脚本的时候,只是忠实的记录了所有从客户端发送到服务器的数据,而在进行性能测试的时候,为了更接近真实的模拟现实应用,对于某些信息需要每次提 ...
- Tensorflow2(一)深度学习基础和tf.keras
代码和其他资料在 github 一.tf.keras概述 首先利用tf.keras实现一个简单的线性回归,如 \(f(x) = ax + b\),其中 \(x\) 代表学历,\(f(x)\) 代表收入 ...
- 深度学习基础系列(七)| Batch Normalization
Batch Normalization(批量标准化,简称BN)是近些年来深度学习优化中一个重要的手段.BN能带来如下优点: 加速训练过程: 可以使用较大的学习率: 允许在深层网络中使用sigmoid这 ...
- 深度学习基础系列(四)| 理解softmax函数
深度学习最终目的表现为解决分类或回归问题.在现实应用中,输出层我们大多采用softmax或sigmoid函数来输出分类概率值,其中二元分类可以应用sigmoid函数. 而在多元分类的问题中,我们默认采 ...
- 深度学习基础系列(十)| Global Average Pooling是否可以替代全连接层?
Global Average Pooling(简称GAP,全局池化层)技术最早提出是在这篇论文(第3.2节)中,被认为是可以替代全连接层的一种新技术.在keras发布的经典模型中,可以看到不少模型甚至 ...
- 深度学习基础系列(一)| 一文看懂用kersa构建模型的各层含义(掌握输出尺寸和可训练参数数量的计算方法)
我们在学习成熟网络模型时,如VGG.Inception.Resnet等,往往面临的第一个问题便是这些模型的各层参数是如何设置的呢?另外,我们如果要设计自己的网路模型时,又该如何设置各层参数呢?如果模型 ...
- 深度学习(PYTORCH)-3.sphereface-pytorch.lfw_eval.py详解
pytorch版本sphereface的原作者地址:https://github.com/clcarwin/sphereface_pytorch 由于接触深度学习不久,所以花了较长时间来阅读源码,以下 ...
随机推荐
- Redis(Remote Dictionary Server)入门
说说特性 存储结构:键值对支持多种数据类型,包括字符串类型,散列类型,列表类型,集合类型,有序集合类型. 内存存储与持久化:支持将内存中的数据异步写入磁盘中. 丰富的功能:支持为每个键值对设置生存时间 ...
- css3 加载动画效果
Loading 动画效果一 HTML 代码: <div class="spinner"> <div class="rect1&quo ...
- Django进阶(路由系统、中间件、缓存、Cookie和Session、Ajax发送数据
路由系统 1.每个路由规则对应一个view中的函数 url(r'^index/(\d*)', views.index), url(r'^manage/(?P<name>\w*)/(?P&l ...
- summernote 文本编辑器使用时,选择上传图片、链接、录像时,弹出的对话框被遮挡住
更多内容推荐微信公众号,欢迎关注: 即问题如下链接内的情况: http://bbs.csdn.net/topics/392004332 这个一般属于CSS中样式出现了问题,可以在点开的时候,F12查看 ...
- Linux configure关于交叉编译的参数设置【转】
转自:http://blog.csdn.net/darennet/article/details/9003005 configure的参数众多,一般包括如下 --srcdir=DIR 这个选项对安装没 ...
- Oracle基础结构认知—初识oracle【转】
Oracle服务器(oracle server)由实例和数据库组成.其中,实例就是所谓的关系型数据库管理系统(Relational Database Management System,RDBMS), ...
- MySQL 5.7.17 Group Relication(组复制)搭建手册【转】
本博文介绍了Group Replication的两种工作模式的架构.并详细介绍了Single-Master Mode的部署过程,以及如何切换到Multi-Master Mode.当然,文末给出了Gro ...
- oracle11g 创建id自增长监听器的步骤与问题
首先,我们通过sql/plus先建个TEST表 sql语句: CTEATE TABLE TEST( ID NUMBER, NAME VARCHAR2(20), PRIMARY KEY(ID) ); 通 ...
- ASP防XSS代码
原作是在GitHub上,基于Node.js所写.但是..ASP的JS引擎跟V8又有些不同..于是,嗯.. <% Function AntiXSS_VbsTrim(s) AntiXSS_VbsTr ...
- maven学习--生命周期
clean --清理项目 default --构建项目(最核心) ===========compile , test , package , install site --生成项目站点