基础_模型迁移_CBIR_augmentation
在之前我们做过这样的研究:5图分类CBIR问题
import numpy as np
from keras.datasets import mnist
import gc
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Dropout, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.applications.vgg16 import VGG16
from keras.optimizers import SGD
from keras.utils.data_utils import get_file
import cv2
import h5py as h5py
import numpy as np
import os
import math
from matplotlib import pyplot as plt
#全局变量
RATIO = 0.2
train_dir = 'D:/dl4cv/datesets/littleCBIR/'
#根据分类总数确定one-hot总类
NUM_DENSE = 5
#训练总数
epochs = 10
def tran_y(y):
y_ohe = np.zeros(NUM_DENSE)
y_ohe[y] = 1
return y_ohe
#根据Ratio获得训练和测试数据集的图片地址和标签
##生成数据集,本例先验3**汽车、4**恐龙、5**大象、6**花、7**马
def get_files(file_dir, ratio):
'''
Args:
file_dir: file directory
Returns:
list of images and labels
'''
image_list = []
label_list = []
for file in os.listdir(file_dir):
if file[0:1]=='3':
image_list.append(file_dir + file)
label_list.append(0)
elif file[0:1]=='4':
image_list.append(file_dir + file)
label_list.append(1)
elif file[0:1]=='5':
image_list.append(file_dir + file)
label_list.append(2)
elif file[0:1]=='6':
image_list.append(file_dir + file)
label_list.append(3)
else:
image_list.append(file_dir + file)
label_list.append(4)
print('数据集导入完毕')
#图片list和标签list
#hstack 水平(按列顺序)把数组给堆叠起来
image_list = np.hstack(image_list)
label_list = np.hstack(label_list)
temp = np.array([image_list, label_list])
temp = temp.transpose()
np.random.shuffle(temp)
all_image_list = temp[:, 0]
all_label_list = temp[:, 1]
n_sample = len(all_label_list)
#根据比率,确定训练和测试数量
n_val = math.ceil(n_sample*ratio) # number of validation samples
n_train = n_sample - n_val # number of trainning samples
tra_images = []
val_images = []
#按照0-n_train为tra_images,后面位val_images的方式来排序
for index in range(n_train):
image = cv2.imread(all_image_list[index])
#灰度,然后缩放
image = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY)
image = cv2.resize(image,(48,48))#到底在这个地方修改,还是在后面修改,需要做具体实验
tra_images.append(image)
tra_labels = all_label_list[:n_train]
tra_labels = [int(float(i)) for i in tra_labels]
for index in range(n_val):
image = cv2.imread(all_image_list[n_train+index])
#灰度,然后缩放
image = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY)
image = cv2.resize(image,(32,32))
val_images.append(image)
val_labels = all_label_list[n_train:]
val_labels = [int(float(i)) for i in val_labels]
return np.array(tra_images),np.array(tra_labels),np.array(val_images),np.array(val_labels)
# colab+VGG要求至少48像素在现有数据集上,已经能够完成不错情况
ishape=48
#(X_train, y_train), (X_test, y_test) = mnist.load_data()
#获得数据集
#X_train, y_train, X_test, y_test = get_files(train_dir, RATIO)
#保持数据
##np.savez("D:\\dl4cv\\datesets\\littleCBIR.npz",X_train=X_train,y_train=y_train,X_test=X_test,y_test=y_test)
#读取数据
path='littleCBIR.npz'
#https://github.com/jsxyhelu/GOCW/raw/master/littleCBIR.npz
path = get_file(path,origin='https://github.com/jsxyhelu/GOCW/raw/master/littleCBIR.npz')
f = np.load(path)
X_train, y_train = f['X_train'], f['y_train']
X_test, y_test = f['X_test'], f['y_test']
X_train = [cv2.cvtColor(cv2.resize(i, (ishape, ishape)), cv2.COLOR_GRAY2BGR) for i in X_train]
X_train = np.concatenate([arr[np.newaxis] for arr in X_train]).astype('float32')
X_train /= 255.0
X_test = [cv2.cvtColor(cv2.resize(i, (ishape, ishape)), cv2.COLOR_GRAY2BGR) for i in X_test]
X_test = np.concatenate([arr[np.newaxis] for arr in X_test]).astype('float32')
X_test /= 255.0
y_train_ohe = np.array([tran_y(y_train[i]) for i in range(len(y_train))])
y_test_ohe = np.array([tran_y(y_test[i]) for i in range(len(y_test))])
y_train_ohe = y_train_ohe.astype('float32')
y_test_ohe = y_test_ohe.astype('float32')
model_vgg = VGG16(include_top = False, weights = 'imagenet', input_shape = (ishape, ishape, 3))
#for i, layer in enumerate(model_vgg.layers):
# if i<20:
for layer in model_vgg.layers:
layer.trainable = False
model = Flatten()(model_vgg.output)
model = Dense(4096, activation='relu', name='fc1')(model)
model = Dense(4096, activation='relu', name='fc2')(model)
model = Dropout(0.5)(model)
model = Dense(NUM_DENSE, activation = 'softmax', name='prediction')(model)
model_vgg_pretrain = Model(model_vgg.input, model, name = 'vgg16_pretrain')
#model_vgg_pretrain.summary()
print("vgg准备完毕\n")
sgd = SGD(lr = 0.05, decay = 1e-5)
model_vgg_pretrain.compile(loss = 'categorical_crossentropy', optimizer = sgd, metrics = ['accuracy'])
print("vgg开始训练\n")
log = model_vgg_pretrain.fit(X_train, y_train_ohe, validation_data = (X_test, y_test_ohe), epochs = epochs, batch_size = 64)
score = model_vgg_pretrain.evaluate(X_test, y_test_ohe, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
plt.figure('acc')
plt.subplot(2, 1, 1)
plt.plot(log.history['acc'],'r--',label='Training Accuracy')
plt.plot(log.history['val_acc'],'r-',label='Validation Accuracy')
plt.legend(loc='best')
plt.xlabel('Epochs')
plt.axis([0, epochs, 0.5, 1])
plt.figure('loss')
plt.subplot(2, 1, 2)
plt.plot(log.history['loss'],'b--',label='Training Loss')
plt.plot(log.history['val_loss'],'b-',label='Validation Loss')
plt.legend(loc='best')
plt.xlabel('Epochs')
plt.axis([0, epochs, 0, 1])
plt.show()
os.system("pause")
log = model_vgg_pretrain.fit_generator(img_generator.flow(X_train,y_train_ohe, batch_size= 128), steps_per_epoch = 400, epochs=10,validation_data=(X_test, y_test_ohe),workers=4)
# Install the PyDrive wrapper & import libraries.
# This only needs to be done once in a notebook.
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# Authenticate and create the PyDrive client.
# This only needs to be done once in a notebook.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)
# Create & upload a text file.
uploaded = drive.CreateFile()
uploaded.SetContentFile('5type4cbirMODEL.h5')
uploaded.Upload()
print('Uploaded file with ID {}'.format(uploaded.get('id')))

# Install the PyDrive wrapper & import libraries.
# This only needs to be done once per notebook.
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# Authenticate and create the PyDrive client.
# This only needs to be done once per notebook.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)
#根据文件名进行下载
file_id = '1qjxAm_QiXdSqBmyIoPl3bfnyLNJxwKo9'
downloaded = drive.CreateFile({'id': file_id})
print('Downloaded content "{}"'.format(downloaded.GetContentString()))
附件列表
基础_模型迁移_CBIR_augmentation的更多相关文章
- 使用 Azure PowerShell 将 IaaS 资源从经典部署模型迁移到 Azure Resource Manager
以下步骤演示了如何使用 Azure PowerShell 命令将基础结构即服务 (IaaS) 资源从经典部署模型迁移到 Azure Resource Manager 部署模型. 也可根据需要通过 Az ...
- 老李分享: 并行计算基础&编程模型与工具 1
老李分享: 并行计算基础&编程模型与工具 在当前计算机应用中,对高速并行计算的需求是广泛的,归纳起来,主要有三种类型的应用需求: 计算密集(Computer-Intensive)型应用,如 ...
- 算法基础_递归_求杨辉三角第m行第n个数字
问题描述: 算法基础_递归_求杨辉三角第m行第n个数字(m,n都从0开始) 解题源代码(这里打印出的是杨辉三角某一层的所有数字,没用大数,所以有上限,这里只写基本逻辑,要符合题意的话,把循环去掉就好) ...
- 规划将 IaaS 资源从经典部署模型迁移到 Azure Resource Manager
尽管 Azure 资源管理器提供了许多精彩功能,但请务必计划迁移,以确保一切顺利进行. 花时间进行规划可确保执行迁移活动时不会遇到问题. Note 以下指导的主要参与者为 Azure 客户顾问团队,以 ...
- 有关从经典部署模型迁移到 Azure Resource Manager 部署模型的常见问题
此迁移计划是否影响 Azure 虚拟机上运行的任何现有服务或应用程序? 不可以. VM(经典)是公开上市的完全受支持的服务. 你可以继续使用这些资源来拓展你在 Azure 上的足迹. 如果我近期不打算 ...
- 桥接模式_NAT模式_仅主机模式_模型图.ziw
2017年1月12日, 星期四 桥接模式_NAT模式_仅主机模式_模型图 null
- 使用 Azure CLI 将 IaaS 资源从经典部署模型迁移到 Azure Resource Manager 部署模型
以下步骤演示如何使用 Azure 命令行接口 (CLI) 命令将基础结构即服务 (IaaS) 资源从经典部署模型迁移到 Azure Resource Manager 部署模型. 本文中的操作需要 Az ...
- Flutter实战视频-移动电商-05.Dio基础_引入和简单的Get请求
05.Dio基础_引入和简单的Get请求 博客地址: https://jspang.com/post/FlutterShop.html#toc-4c7 第三方的http请求库叫做Dio https:/ ...
- Flutter实战视频-移动电商-08.Dio基础_伪造请求头获取数据
08.Dio基础_伪造请求头获取数据 上节课代码清楚 重新编写HomePage这个动态组件 开始写请求的方法 请求数据 .但是由于我们没加请求的头 所以没有返回数据 451就是表示请求错错误 创建请求 ...
随机推荐
- Go linux 实践2
今天,看看GO的高级语言特性-方法和接口 废话不多说,直接上代码 ************************************************* 1 package main 2 ...
- iOS UIViewController生命周期控制
具体流程,看下图: init方法在init方法中实例化必要的对象(遵从LazyLoad思想)init方法中初始化ViewController本身 loadView方法当view需要被展示而它却是nil ...
- java基础(三) -基本数据类型
变量就是申请内存来存储值.也就是说,当创建变量的时候,需要在内存中申请空间. 内存管理系统根据变量的类型为变量分配存储空间,分配的空间只能用来储存该类型数据. 因此,通过定义不同类型的变量,可以在内存 ...
- Selenium基础知识(九)验证码
关于Selenium处理验证码总结下: 1.去掉验证码(这个为了测试,去掉可能性不大) 2.万能验证码(让开发给做一个万能验证码,可能性也不大) 3.OCR光学识别,python包Python-tes ...
- div上下左右居中方法
方法一:在浏览器中只有一个div的情况 { position:fixed; position:fixed; ; ; ; ; margin:auto; } 方法一 方法二:一个父元素div和一个已知宽度 ...
- 域名重新绑定ip时访问地址NotFount404
情形描述:部署在A服务器IIS上的asp.net程序,搬迁到B服务器上,重新绑定域名和ip后.再访问网址时有些电脑能正常访问,而有些电脑报404 not found错误. 经分析发现是个人电脑网络设置 ...
- keras如何求分类问题中的准确率和召回率
https://www.zhihu.com/question/53294625 由于要用keras做一个多分类的问题,评价标准采用precision,recall,和f1_score:但是keras中 ...
- SQL性能优化前期准备-清除缓存、开启IO统计
文章来至:https://www.cnblogs.com/Ren_Lei/p/5669662.html 如果需要进行SQl Server下的SQL性能优化,需要准备以下内容: 一.SQL查询分析器设置 ...
- Azure Messaging-ServiceBus Messaging消息队列技术系列1-基本概念和架构
前段时间研究了Window Azure ServiceBus Messaging消息队列技术,搞了很多技术研究和代码验证,最近准备总结一下,分享给大家. 首先,Windows Azure提供了两种类型 ...
- Linux查看磁盘空间大小命令
Linux查看磁盘空间大小命令df -hldf -hl 查看磁盘剩余空间df -h 查看每个根路径的分区大小du -sh [目录名] 返回该目录的大小du -sm [文件夹] 返回该文件夹总M数du ...