data augmentation 几种方法总结

在深度学习中,有的时候训练集不够多,或者某一类数据较少,或者为了防止过拟合,让模型更加鲁棒性,data augmentation是一个不错的选择。

常见方法

Color Jittering:对颜色的数据增强:图像亮度、饱和度、对比度变化(此处对色彩抖动的理解不知是否得当);

PCA Jittering:首先按照RGB三个颜色通道计算均值和标准差,再在整个训练集上计算协方差矩阵,进行特征分解,得到特征向量和特征值,用来做PCA Jittering;

Random Scale:尺度变换;

Random Crop:采用随机图像差值方式,对图像进行裁剪、缩放;包括Scale Jittering方法(VGG及ResNet模型使用)或者尺度和长宽比增强变换;

Horizontal/Vertical Flip:水平/垂直翻转;

Shift:平移变换;

Rotation/Reflection:旋转/仿射变换;

Noise:高斯噪声、模糊处理;

Label shuffle:类别不平衡数据的增广,参见海康威视ILSVRC2016的report;另外,文中提出了一种Supervised Data Augmentation方法,有兴趣的朋友的可以动手实验下。

部分方法的具体实现

# -*- coding:utf-8 -*-
"""数据增强
1. 翻转变换 flip
2. 随机修剪 random crop
3. 色彩抖动 color jittering
4. 平移变换 shift
5. 尺度变换 scale
6. 对比度变换 contrast
7. 噪声扰动 noise
8. 旋转变换/反射变换 Rotation/reflection
""" from PIL import Image, ImageEnhance, ImageOps, ImageFile
import numpy as np
import random
import threading, os, time
import logging logger = logging.getLogger(__name__)
ImageFile.LOAD_TRUNCATED_IMAGES = True class DataAugmentation:
"""
包含数据增强的八种方式
""" def __init__(self):
pass @staticmethod
def openImage(image):
return Image.open(image, mode="r") @staticmethod
def randomRotation(image, mode=Image.BICUBIC):
"""
对图像进行随机任意角度(0~360度)旋转
:param mode 邻近插值,双线性插值,双三次B样条插值(default)
:param image PIL的图像image
:return: 旋转转之后的图像
"""
random_angle = np.random.randint(1, 360)
return image.rotate(random_angle, mode) @staticmethod
def randomCrop(image):
"""
对图像随意剪切,考虑到图像大小范围(68,68),使用一个一个大于(36*36)的窗口进行截图
:param image: PIL的图像image
:return: 剪切之后的图像 """
image_width = image.size[0]
image_height = image.size[1]
crop_win_size = np.random.randint(40, 68)
random_region = (
(image_width - crop_win_size) >> 1, (image_height - crop_win_size) >> 1, (image_width + crop_win_size) >> 1,
(image_height + crop_win_size) >> 1)
return image.crop(random_region) @staticmethod
def randomColor(image):
"""
对图像进行颜色抖动
:param image: PIL的图像image
:return: 有颜色色差的图像image
"""
random_factor = np.random.randint(0, 31) / 10. # 随机因子
color_image = ImageEnhance.Color(image).enhance(random_factor) # 调整图像的饱和度
random_factor = np.random.randint(10, 21) / 10. # 随机因子
brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor) # 调整图像的亮度
random_factor = np.random.randint(10, 21) / 10. # 随机因1子
contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor) # 调整图像对比度
random_factor = np.random.randint(0, 31) / 10. # 随机因子
return ImageEnhance.Sharpness(contrast_image).enhance(random_factor) # 调整图像锐度 @staticmethod
def randomGaussian(image, mean=0.2, sigma=0.3):
"""
对图像进行高斯噪声处理
:param image:
:return:
""" def gaussianNoisy(im, mean=0.2, sigma=0.3):
"""
对图像做高斯噪音处理
:param im: 单通道图像
:param mean: 偏移量
:param sigma: 标准差
:return:
"""
for _i in range(len(im)):
im[_i] += random.gauss(mean, sigma)
return im # 将图像转化成数组
img = np.asarray(image)
img.flags.writeable = True # 将数组改为读写模式
width, height = img.shape[:2]
img_r = gaussianNoisy(img[:, :, 0].flatten(), mean, sigma)
img_g = gaussianNoisy(img[:, :, 1].flatten(), mean, sigma)
img_b = gaussianNoisy(img[:, :, 2].flatten(), mean, sigma)
img[:, :, 0] = img_r.reshape([width, height])
img[:, :, 1] = img_g.reshape([width, height])
img[:, :, 2] = img_b.reshape([width, height])
return Image.fromarray(np.uint8(img)) @staticmethod
def saveImage(image, path):
image.save(path) def makeDir(path):
try:
if not os.path.exists(path):
if not os.path.isfile(path):
# os.mkdir(path)
os.makedirs(path)
return 0
else:
return 1
except Exception, e:
print str(e)
return -2 def imageOps(func_name, image, des_path, file_name, times=5):
funcMap = {"randomRotation": DataAugmentation.randomRotation,
"randomCrop": DataAugmentation.randomCrop,
"randomColor": DataAugmentation.randomColor,
"randomGaussian": DataAugmentation.randomGaussian
}
if funcMap.get(func_name) is None:
logger.error("%s is not exist", func_name)
return -1 for _i in range(0, times, 1):
new_image = funcMap[func_name](image)
DataAugmentation.saveImage(new_image, os.path.join(des_path, func_name + str(_i) + file_name)) opsList = {"randomRotation", "randomCrop", "randomColor", "randomGaussian"} def threadOPS(path, new_path):
"""
多线程处理事务
:param src_path: 资源文件
:param des_path: 目的地文件
:return:
"""
if os.path.isdir(path):
img_names = os.listdir(path)
else:
img_names = [path]
for img_name in img_names:
print img_name
tmp_img_name = os.path.join(path, img_name)
if os.path.isdir(tmp_img_name):
if makeDir(os.path.join(new_path, img_name)) != -1:
threadOPS(tmp_img_name, os.path.join(new_path, img_name))
else:
print 'create new dir failure'
return -1
# os.removedirs(tmp_img_name)
elif tmp_img_name.split('.')[1] != "DS_Store":
# 读取文件并进行操作
image = DataAugmentation.openImage(tmp_img_name)
threadImage = [0] * 5
_index = 0
for ops_name in opsList:
threadImage[_index] = threading.Thread(target=imageOps,
args=(ops_name, image, new_path, img_name,))
threadImage[_index].start()
_index += 1
time.sleep(0.2) if __name__ == '__main__':
threadOPS("/home/pic-image/train/12306train",
"/home/pic-image/train/12306train3")

参考文献

深度学习之图像的数据增强

知乎

data augmentation 总结的更多相关文章

  1. 深度学习中的Data Augmentation方法(转)基于keras

    在深度学习中,当数据量不够大时候,常常采用下面4中方法: 1. 人工增加训练集的大小. 通过平移, 翻转, 加噪声等方法从已有数据中创造出一批"新"的数据.也就是Data Augm ...

  2. 常见的数据扩充(data augmentation)方法

    G~L~M~R~S 一.data augmentation 常见的数据扩充(data augmentation)方法:文中图片均来自吴恩达教授的deeplearning.ai课程 1.Mirrorin ...

  3. (转)AutoML for Data Augmentation

    AutoML for Data Augmentation 2019-04-01 09:26:19 This blog is copied from: https://blog.insightdatas ...

  4. 图像数据增强 (Data Augmentation in Computer Vision)

    1.1 简介 深层神经网络一般都需要大量的训练数据才能获得比较理想的结果.在数据量有限的情况下,可以通过数据增强(Data Augmentation)来增加训练样本的多样性, 提高模型鲁棒性,避免过拟 ...

  5. Keras Data augmentation(数据扩充)

    在深度学习中,我们经常需要用到一些技巧(比如将图片进行旋转,翻转等)来进行data augmentation, 来减少过拟合. 在本文中,我们将主要介绍如何用深度学习框架keras来自动的进行data ...

  6. keras对图像数据进行增强 | keras data augmentation

    本文首发于个人博客https://kezunlin.me/post/8db507ff/,欢迎阅读最新内容! keras data augmentation Guide code # import th ...

  7. paper 147:Deep Learning -- Face Data Augmentation(一)

    1. 在深度学习中,当数据量不够大时候,常常采用下面4中方法:  (1)人工增加训练集的大小. 通过平移, 翻转, 加噪声等方法从已有数据中创造出一批"新"的数据.也就是Data ...

  8. 【48】数据扩充(Data augmentation)

    数据扩充(Data augmentation) 大部分的计算机视觉任务使用很多的数据,所以数据扩充是经常使用的一种技巧来提高计算机视觉系统的表现.我认为计算机视觉是一个相当复杂的工作,你需要输入图像的 ...

  9. Regularizing Deep Networks with Semantic Data Augmentation

    目录 概 主要内容 代码 Wang Y., Huang G., Song S., Pan X., Xia Y. and Wu C. Regularizing Deep Networks with Se ...

随机推荐

  1. cxGrid 循环选择条目

    Delphi DevExpress CxGrid 循环选择条目 整理出来的,直接复制粘贴即可使用 以下是从网络上复制粘帖到的,实践证明,利用以下代码进行获取选择行是错误的. 当我们利用 CxGrid进 ...

  2. hdu3729(二分图)

    比赛的时候没有想到二分图,一直在想dp和贪心. 原因是因为看到数据是100000所以直接就没有往二分图匹配上想. 现在想想. 因为二分图两边的太不对称了,60 和100000 , 如果用匈牙利算法考虑 ...

  3. SpringMVC如何接收json数据

    请求头:Content-Type=application/json数据如: {"mobile":"12345678912","smsContent&q ...

  4. cocos2d-X学习之主要类介绍:动作:CCAction

    引用自:http://www.cnblogs.com/lhming/archive/2012/07/01/2572238.html 类继承图: 主要函数: virtual CCObject *  co ...

  5. Java程序员面试题集(1-50

    下面的内容是对网上原有的Java面试题集及答案进行了全面修订之后给出的负责任的题目和答案,原来的题目中有很多重复题目和无价值的题目,还有不少的参考答案也是错误的,修改后的Java面试题集参照了JDK最 ...

  6. python多线程锁lock/Rlock/BoundedSemaphore/Condition/Event

    import time import threading lock = threading.RLock() n = 10 def task(arg): # 加锁,此区域的代码同一时刻只能有一个线程执行 ...

  7. 我的Android进阶之旅------>解决AES加密报错:java.security.InvalidKeyException: Unsupported key size: 18 bytes

    1.错误描述 今天使用AES进行加密时候,报错如下所示: 04-21 11:08:18.087 27501-27501/com.xtc.watch E/AESUtil.decryptAES:55: j ...

  8. MySQL数据库(5)- pymysql的使用、索引

    一.pymysql模块的使用 1.pymysql的下载和使用 之前我们都是通过MySQL自带的命令行客户端工具mysql来操作数据库,那如何在python程序中操作数据库呢?这就需要用到pymysql ...

  9. golang 实现并发计算文件数量

    package main import ( "fmt" "io/ioutil" "os" ) func listDir(path strin ...

  10. 2 TensorFlow入门笔记之建造神经网络并将结果可视化

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...