在之前我们做过这样的研究:5图分类CBIR问题

各不相同的 5类的图形,每类100张
import numpy as np
from keras.datasets import mnist
import gc

from keras.models import Sequential, Model
from keras.layers import Input, Dense, Dropout, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.applications.vgg16 import VGG16
from keras.optimizers import SGD
from keras.utils.data_utils import get_file
import cv2
import h5py as h5py 
import numpy as np

import os
import math
from matplotlib import pyplot as plt

#全局变量
RATIO = 0.2
train_dir = 'D:/dl4cv/datesets/littleCBIR/'

#根据分类总数确定one-hot总类
NUM_DENSE = 5
#训练总数
epochs = 10

def tran_y(y): 
    y_ohe = np.zeros(NUM_DENSE) 
    y_ohe[y] = 1 
    return y_ohe

#根据Ratio获得训练和测试数据集的图片地址和标签
##生成数据集,本例先验3**汽车、4**恐龙、5**大象、6**花、7**马
def get_files(file_dir, ratio):
    '''
    Args:
        file_dir: file directory
    Returns:
        list of images and labels
    '''
    image_list = []
    label_list = []
    for file in os.listdir(file_dir):
        if file[0:1]=='3':
            image_list.append(file_dir + file)
            label_list.append(0)
        elif file[0:1]=='4':
            image_list.append(file_dir + file)
            label_list.append(1)
        elif file[0:1]=='5':
            image_list.append(file_dir + file)
            label_list.append(2)
        elif file[0:1]=='6':
            image_list.append(file_dir + file)
            label_list.append(3)
        else:
            image_list.append(file_dir + file)
            label_list.append(4)
    print('数据集导入完毕')
    #图片list和标签list
    #hstack 水平(按列顺序)把数组给堆叠起来
    image_list = np.hstack(image_list)
    label_list = np.hstack(label_list)
    
    temp = np.array([image_list, label_list])
    temp = temp.transpose()
    np.random.shuffle(temp)   
    
    all_image_list = temp[:, 0]
    all_label_list = temp[:, 1]
    
    n_sample = len(all_label_list)
    #根据比率,确定训练和测试数量
    n_val = math.ceil(n_sample*ratio) # number of validation samples
    n_train = n_sample - n_val # number of trainning samples
    tra_images = []
    val_images = []
    #按照0-n_train为tra_images,后面位val_images的方式来排序
    
    for index in range(n_train):
        image = cv2.imread(all_image_list[index])
        #灰度,然后缩放
        image = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY)
        image = cv2.resize(image,(48,48))#到底在这个地方修改,还是在后面修改,需要做具体实验
        tra_images.append(image)

    tra_labels = all_label_list[:n_train]
    tra_labels = [int(float(i)) for i in tra_labels]

    for index in range(n_val):
        image = cv2.imread(all_image_list[n_train+index])
        #灰度,然后缩放
        image = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY)
        image = cv2.resize(image,(32,32))
        val_images.append(image)

    val_labels = all_label_list[n_train:]
    val_labels = [int(float(i)) for i in val_labels]
    return np.array(tra_images),np.array(tra_labels),np.array(val_images),np.array(val_labels)

# colab+VGG要求至少48像素在现有数据集上,已经能够完成不错情况
ishape=48
#(X_train, y_train), (X_test, y_test) = mnist.load_data() 
#获得数据集
#X_train, y_train, X_test, y_test = get_files(train_dir, RATIO)
#保持数据
##np.savez("D:\\dl4cv\\datesets\\littleCBIR.npz",X_train=X_train,y_train=y_train,X_test=X_test,y_test=y_test)
#读取数据
 
path='littleCBIR.npz'
#https://github.com/jsxyhelu/GOCW/raw/master/littleCBIR.npz
path = get_file(path,origin='https://github.com/jsxyhelu/GOCW/raw/master/littleCBIR.npz')
f = np.load(path)
X_train, y_train = f['X_train'], f['y_train']
X_test, y_test = f['X_test'], f['y_test']

X_train = [cv2.cvtColor(cv2.resize(i, (ishape, ishape)), cv2.COLOR_GRAY2BGR) for i in X_train] 
X_train = np.concatenate([arr[np.newaxis] for arr in X_train]).astype('float32') 
X_train /= 255.0

X_test = [cv2.cvtColor(cv2.resize(i, (ishape, ishape)), cv2.COLOR_GRAY2BGR) for i in X_test] 
X_test = np.concatenate([arr[np.newaxis] for arr in X_test]).astype('float32')
X_test /= 255.0

y_train_ohe = np.array([tran_y(y_train[i]) for i in range(len(y_train))]) 
y_test_ohe = np.array([tran_y(y_test[i]) for i in range(len(y_test))])
y_train_ohe = y_train_ohe.astype('float32')
y_test_ohe = y_test_ohe.astype('float32')

model_vgg = VGG16(include_top = False, weights = 'imagenet', input_shape = (ishape, ishape, 3)) 
#for i, layer in enumerate(model_vgg.layers): 
#    if i<20:
for layer in model_vgg.layers:
        layer.trainable = False
model = Flatten()(model_vgg.output) 
model = Dense(4096, activation='relu', name='fc1')(model)
model = Dense(4096, activation='relu', name='fc2')(model)
model = Dropout(0.5)(model)
model = Dense(NUM_DENSE, activation = 'softmax', name='prediction')(model) 
model_vgg_pretrain = Model(model_vgg.input, model, name = 'vgg16_pretrain')
#model_vgg_pretrain.summary()
print("vgg准备完毕\n")
sgd = SGD(lr = 0.05, decay = 1e-5) 
model_vgg_pretrain.compile(loss = 'categorical_crossentropy', optimizer = sgd, metrics = ['accuracy'])
print("vgg开始训练\n")
log = model_vgg_pretrain.fit(X_train, y_train_ohe, validation_data = (X_test, y_test_ohe), epochs = epochs, batch_size = 64)

score = model_vgg_pretrain.evaluate(X_test, y_test_ohe, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

plt.figure('acc')  
plt.subplot(2, 1, 1)  
plt.plot(log.history['acc'],'r--',label='Training Accuracy')  
plt.plot(log.history['val_acc'],'r-',label='Validation Accuracy')  
plt.legend(loc='best')  
plt.xlabel('Epochs')  
plt.axis([0, epochs, 0.5, 1])  
plt.figure('loss')  
plt.subplot(2, 1, 2)  
plt.plot(log.history['loss'],'b--',label='Training Loss')  
plt.plot(log.history['val_loss'],'b-',label='Validation Loss')  
plt.legend(loc='best')  
plt.xlabel('Epochs')  
plt.axis([0, epochs, 0, 1])  
  
plt.show() 
os.system("pause")


epoch = 10

Test loss: 0.5930496269464492
Test accuracy: 0.83

epoch = 30

Test loss: 1.254318968951702
Test accuracy: 0.68



原图修正
Test loss: 0.2156761786714196
Test accuracy: 0.93

应该说,主要代码问题不大,也摸索出来了系列方法,但是需要进一步丰富。今天,我在之前的基础上继续研究,解决以下问题:
1、在5Type4CBIR(https://www.kaggle.com/jsxyhelu/5type4cbir)做训练;
2、使用aumentation方法,提高前期效果;
3、不仅仅是曲线,),
)
简单几行代码,就可以解决这里的问题。如果确实有效,那么将被复用;
在不generate下,达到95
Epoch 30/30
400/400 [==============================] - 15s 36ms/step - loss: 0.1731 - acc: 0.9375 - val_loss: 0.1910 - val_acc: 0.9500
Test loss: 0.19096931144595147
Test accuracy: 0.95

​在对generate进一步理解后,我认为现在的结果应该能够更高,经过夜间训练,得到这样结果
 
Epoch 10/10
18/400 [>.............................] - ETA: 20:44 - loss: 0.0221 - acc: 0.9935
400/400 [==============================] - 1275s 3s/step - loss: 0.0449 - acc: 0.9875 - val_loss: 0.0667 - val_acc: 0.9700
Test loss: 0.06673358775209635
Test accuracy: 0.97


其中,核心代码及参数为:
log = model_vgg_pretrain.fit_generator(img_generator.flow(X_train,y_train_ohe, batch_size= 128), steps_per_epoch = 400, epochs=10,validation_data=(X_test, y_test_ohe),workers=4)

反观数据,可以发现在epoch = 7 的时候,就已经达到最好;这个曲线较前面的训练更为平滑,总体质量也更高。98的准确率对于消费级的应用是一个很好的结果。
augumentation的方法将会被经常复用。
3、不仅仅是曲线,结果要以可视化方法绘制出来;
这里的意思也很直接,就是我要真实地去验证数据,要把模型结果显示出来,而不仅仅是在模型里面fit。这个显示首先是在juypter里面直观的观看。这个是给我来看的。一个是要在后面的结果应用中能够下载,并且被我现在的软件来使用。
但是目前这一步无法完成,因为前期我使用OpenCV读入图像、预处理、压制成npz格式的,目前我认为这个效果很好。OpenCV的Mat格式目前无法在Juypter中进行显示。
4、最终的结果要能够下载下来;
colab中有一段可以下载的代码的,翻阅了一些资料,解决了这个问题:

# Install the PyDrive wrapper & import libraries.
# This only needs to be done once in a notebook.
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
# This only needs to be done once in a notebook.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# Create & upload a text file.
uploaded = drive.CreateFile()
uploaded.SetContentFile('5type4cbirMODEL.h5')
uploaded.Upload()
print('Uploaded file with ID {}'.format(uploaded.get('id')))



但是这样一个200m的东西下载下来也比较费时间,如果能够结合google driver直接使用是最好的。google driver就相当于colab的外存储。调用的代码应该类似这样:

# Install the PyDrive wrapper & import libraries.
# This only needs to be done once per notebook.
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
# This only needs to be done once per notebook.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)
#根据文件名进行下载
file_id = '1qjxAm_QiXdSqBmyIoPl3bfnyLNJxwKo9'
downloaded = drive.CreateFile({'id': file_id})
print('Downloaded content "{}"'.format(downloaded.GetContentString()))



​还是要通过id来进行调用。毕竟这个调用只会发生在内部情况下,可以理解。 


到了小结时间了:
昨天和今天在做CBIR的实验,得到目前这个结果,应该说能够解决一些问题,特别是模型训练这个方面的问题了,还是有很多收获的,当然主要的收获还是在代码里面:
1、agumenntation可以提高质量,fit_generate也是有技巧的(step_per_epoch),更不要说在数据集准备这块前期积累的技巧;
2、transform是可以用来解决垂直领域问题的,下一步就是需要验证;
3、研究问题的时候,首先是不能失去信心,然后就是要总结方法。


 

附件列表

基础_模型迁移_CBIR_augmentation的更多相关文章

  1. 使用 Azure PowerShell 将 IaaS 资源从经典部署模型迁移到 Azure Resource Manager

    以下步骤演示了如何使用 Azure PowerShell 命令将基础结构即服务 (IaaS) 资源从经典部署模型迁移到 Azure Resource Manager 部署模型. 也可根据需要通过 Az ...

  2. 老李分享: 并行计算基础&编程模型与工具 1

    老李分享: 并行计算基础&编程模型与工具   在当前计算机应用中,对高速并行计算的需求是广泛的,归纳起来,主要有三种类型的应用需求: 计算密集(Computer-Intensive)型应用,如 ...

  3. 算法基础_递归_求杨辉三角第m行第n个数字

    问题描述: 算法基础_递归_求杨辉三角第m行第n个数字(m,n都从0开始) 解题源代码(这里打印出的是杨辉三角某一层的所有数字,没用大数,所以有上限,这里只写基本逻辑,要符合题意的话,把循环去掉就好) ...

  4. 规划将 IaaS 资源从经典部署模型迁移到 Azure Resource Manager

    尽管 Azure 资源管理器提供了许多精彩功能,但请务必计划迁移,以确保一切顺利进行. 花时间进行规划可确保执行迁移活动时不会遇到问题. Note 以下指导的主要参与者为 Azure 客户顾问团队,以 ...

  5. 有关从经典部署模型迁移到 Azure Resource Manager 部署模型的常见问题

    此迁移计划是否影响 Azure 虚拟机上运行的任何现有服务或应用程序? 不可以. VM(经典)是公开上市的完全受支持的服务. 你可以继续使用这些资源来拓展你在 Azure 上的足迹. 如果我近期不打算 ...

  6. 桥接模式_NAT模式_仅主机模式_模型图.ziw

      2017年1月12日, 星期四 桥接模式_NAT模式_仅主机模式_模型图   null

  7. 使用 Azure CLI 将 IaaS 资源从经典部署模型迁移到 Azure Resource Manager 部署模型

    以下步骤演示如何使用 Azure 命令行接口 (CLI) 命令将基础结构即服务 (IaaS) 资源从经典部署模型迁移到 Azure Resource Manager 部署模型. 本文中的操作需要 Az ...

  8. Flutter实战视频-移动电商-05.Dio基础_引入和简单的Get请求

    05.Dio基础_引入和简单的Get请求 博客地址: https://jspang.com/post/FlutterShop.html#toc-4c7 第三方的http请求库叫做Dio https:/ ...

  9. Flutter实战视频-移动电商-08.Dio基础_伪造请求头获取数据

    08.Dio基础_伪造请求头获取数据 上节课代码清楚 重新编写HomePage这个动态组件 开始写请求的方法 请求数据 .但是由于我们没加请求的头 所以没有返回数据 451就是表示请求错错误 创建请求 ...

随机推荐

  1. 网络编程之Socket的TCP协议实现客户端与客户端之间的通信

    我认为当你学完某个知识点后,最好是做一个实实在在的小案例.这样才能更好对知识的运用与掌握 如果你看了我前两篇关于socket通信原理的入门文章.我相信对于做出我这个小案列是完全没有问题的!! 既然是小 ...

  2. 数据加密之RijndaelManaged加密

    #region RijndaelManaged加密 /// <summary> /// 加密数据 /// </summary> /// <param name=" ...

  3. sqli-labs(六)

    第十一关: 这关是一个登陆口,也是一个sql注入的漏洞,也就是常说的万能密码. 在输入框账号密码种分别输入 1'  和1'  页面会报错. 后台使用的单引符号进行的拼接.账号输入1' or '1'=' ...

  4. 使用Spark下的corr计算皮尔森相似度Pearson时,报错Can only zip RDDs with same number of elements in each partition....

    package com.huawei.bigdata.spark.examples import org.apache.spark.mllib.stat.Statistics import org.a ...

  5. jdk8新特性-亮瞎眼的lambda表达式

    jdk8之前,尤其是在写GUI程序的事件监听的时候,各种的匿名内部类,大把大把拖沓的代码,程序毫无美感可言!既然Java中一切皆为对象,那么,就类似于某些动态语言一样,函数也可以当成是对象啊!代码块也 ...

  6. jdbc连接oracle时使用的字符串格式

    两种方式: SID的方式: jdbc:oracle:thin:@[HOST][:PORT]:SID SERVICE_NAME的方式: jdbc:oracle:thin:@//[HOST][:PORT] ...

  7. 关于kingoroot这款软件

    弃了饱受诟病的kingroot系列的软件,又出现了一款名为kingoroot的软件. 大约一年之前用过kingoroot的apk版,成功为我的手机root了,而且其行为也并不是那么流氓,所以当时对其很 ...

  8. 软件常用设置(VC, eclipse ,nodejs)---自己备用

    留存复制使用 1.VC ----1.1VC项目设置 输出目录: $(SolutionDir)../bin/$(platform)/$(Configuration) $(ProjectDir)../bi ...

  9. 终极解决liunx GUI 无法显示中文的问题。

    为linux安装字体 Linux字体文件放在/usr/share/font/,只要将字体文件拷贝到这里就可以了.这里示例安装Windows的所有字体. 2,复制Windows下 的所有字体.cd命令切 ...

  10. django models数据库操作

    一.数据库操作 1.创建model表         基本结构 1 2 3 4 5 6 from django.db import models     class userinfo(models.M ...