基础_模型迁移_CBIR_augmentation
在之前我们做过这样的研究:5图分类CBIR问题
import numpy as np
from keras.datasets import mnist
import gc
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Dropout, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.applications.vgg16 import VGG16
from keras.optimizers import SGD
from keras.utils.data_utils import get_file
import cv2
import h5py as h5py
import numpy as np
import os
import math
from matplotlib import pyplot as plt
#全局变量
RATIO = 0.2
train_dir = 'D:/dl4cv/datesets/littleCBIR/'
#根据分类总数确定one-hot总类
NUM_DENSE = 5
#训练总数
epochs = 10
def tran_y(y):
y_ohe = np.zeros(NUM_DENSE)
y_ohe[y] = 1
return y_ohe
#根据Ratio获得训练和测试数据集的图片地址和标签
##生成数据集,本例先验3**汽车、4**恐龙、5**大象、6**花、7**马
def get_files(file_dir, ratio):
'''
Args:
file_dir: file directory
Returns:
list of images and labels
'''
image_list = []
label_list = []
for file in os.listdir(file_dir):
if file[0:1]=='3':
image_list.append(file_dir + file)
label_list.append(0)
elif file[0:1]=='4':
image_list.append(file_dir + file)
label_list.append(1)
elif file[0:1]=='5':
image_list.append(file_dir + file)
label_list.append(2)
elif file[0:1]=='6':
image_list.append(file_dir + file)
label_list.append(3)
else:
image_list.append(file_dir + file)
label_list.append(4)
print('数据集导入完毕')
#图片list和标签list
#hstack 水平(按列顺序)把数组给堆叠起来
image_list = np.hstack(image_list)
label_list = np.hstack(label_list)
temp = np.array([image_list, label_list])
temp = temp.transpose()
np.random.shuffle(temp)
all_image_list = temp[:, 0]
all_label_list = temp[:, 1]
n_sample = len(all_label_list)
#根据比率,确定训练和测试数量
n_val = math.ceil(n_sample*ratio) # number of validation samples
n_train = n_sample - n_val # number of trainning samples
tra_images = []
val_images = []
#按照0-n_train为tra_images,后面位val_images的方式来排序
for index in range(n_train):
image = cv2.imread(all_image_list[index])
#灰度,然后缩放
image = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY)
image = cv2.resize(image,(48,48))#到底在这个地方修改,还是在后面修改,需要做具体实验
tra_images.append(image)
tra_labels = all_label_list[:n_train]
tra_labels = [int(float(i)) for i in tra_labels]
for index in range(n_val):
image = cv2.imread(all_image_list[n_train+index])
#灰度,然后缩放
image = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY)
image = cv2.resize(image,(32,32))
val_images.append(image)
val_labels = all_label_list[n_train:]
val_labels = [int(float(i)) for i in val_labels]
return np.array(tra_images),np.array(tra_labels),np.array(val_images),np.array(val_labels)
# colab+VGG要求至少48像素在现有数据集上,已经能够完成不错情况
ishape=48
#(X_train, y_train), (X_test, y_test) = mnist.load_data()
#获得数据集
#X_train, y_train, X_test, y_test = get_files(train_dir, RATIO)
#保持数据
##np.savez("D:\\dl4cv\\datesets\\littleCBIR.npz",X_train=X_train,y_train=y_train,X_test=X_test,y_test=y_test)
#读取数据
path='littleCBIR.npz'
#https://github.com/jsxyhelu/GOCW/raw/master/littleCBIR.npz
path = get_file(path,origin='https://github.com/jsxyhelu/GOCW/raw/master/littleCBIR.npz')
f = np.load(path)
X_train, y_train = f['X_train'], f['y_train']
X_test, y_test = f['X_test'], f['y_test']
X_train = [cv2.cvtColor(cv2.resize(i, (ishape, ishape)), cv2.COLOR_GRAY2BGR) for i in X_train]
X_train = np.concatenate([arr[np.newaxis] for arr in X_train]).astype('float32')
X_train /= 255.0
X_test = [cv2.cvtColor(cv2.resize(i, (ishape, ishape)), cv2.COLOR_GRAY2BGR) for i in X_test]
X_test = np.concatenate([arr[np.newaxis] for arr in X_test]).astype('float32')
X_test /= 255.0
y_train_ohe = np.array([tran_y(y_train[i]) for i in range(len(y_train))])
y_test_ohe = np.array([tran_y(y_test[i]) for i in range(len(y_test))])
y_train_ohe = y_train_ohe.astype('float32')
y_test_ohe = y_test_ohe.astype('float32')
model_vgg = VGG16(include_top = False, weights = 'imagenet', input_shape = (ishape, ishape, 3))
#for i, layer in enumerate(model_vgg.layers):
# if i<20:
for layer in model_vgg.layers:
layer.trainable = False
model = Flatten()(model_vgg.output)
model = Dense(4096, activation='relu', name='fc1')(model)
model = Dense(4096, activation='relu', name='fc2')(model)
model = Dropout(0.5)(model)
model = Dense(NUM_DENSE, activation = 'softmax', name='prediction')(model)
model_vgg_pretrain = Model(model_vgg.input, model, name = 'vgg16_pretrain')
#model_vgg_pretrain.summary()
print("vgg准备完毕\n")
sgd = SGD(lr = 0.05, decay = 1e-5)
model_vgg_pretrain.compile(loss = 'categorical_crossentropy', optimizer = sgd, metrics = ['accuracy'])
print("vgg开始训练\n")
log = model_vgg_pretrain.fit(X_train, y_train_ohe, validation_data = (X_test, y_test_ohe), epochs = epochs, batch_size = 64)
score = model_vgg_pretrain.evaluate(X_test, y_test_ohe, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
plt.figure('acc')
plt.subplot(2, 1, 1)
plt.plot(log.history['acc'],'r--',label='Training Accuracy')
plt.plot(log.history['val_acc'],'r-',label='Validation Accuracy')
plt.legend(loc='best')
plt.xlabel('Epochs')
plt.axis([0, epochs, 0.5, 1])
plt.figure('loss')
plt.subplot(2, 1, 2)
plt.plot(log.history['loss'],'b--',label='Training Loss')
plt.plot(log.history['val_loss'],'b-',label='Validation Loss')
plt.legend(loc='best')
plt.xlabel('Epochs')
plt.axis([0, epochs, 0, 1])
plt.show()
os.system("pause")
log = model_vgg_pretrain.fit_generator(img_generator.flow(X_train,y_train_ohe, batch_size= 128), steps_per_epoch = 400, epochs=10,validation_data=(X_test, y_test_ohe),workers=4)
# Install the PyDrive wrapper & import libraries.
# This only needs to be done once in a notebook.
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# Authenticate and create the PyDrive client.
# This only needs to be done once in a notebook.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)
# Create & upload a text file.
uploaded = drive.CreateFile()
uploaded.SetContentFile('5type4cbirMODEL.h5')
uploaded.Upload()
print('Uploaded file with ID {}'.format(uploaded.get('id')))

# Install the PyDrive wrapper & import libraries.
# This only needs to be done once per notebook.
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# Authenticate and create the PyDrive client.
# This only needs to be done once per notebook.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)
#根据文件名进行下载
file_id = '1qjxAm_QiXdSqBmyIoPl3bfnyLNJxwKo9'
downloaded = drive.CreateFile({'id': file_id})
print('Downloaded content "{}"'.format(downloaded.GetContentString()))
附件列表
基础_模型迁移_CBIR_augmentation的更多相关文章
- 使用 Azure PowerShell 将 IaaS 资源从经典部署模型迁移到 Azure Resource Manager
以下步骤演示了如何使用 Azure PowerShell 命令将基础结构即服务 (IaaS) 资源从经典部署模型迁移到 Azure Resource Manager 部署模型. 也可根据需要通过 Az ...
- 老李分享: 并行计算基础&编程模型与工具 1
老李分享: 并行计算基础&编程模型与工具 在当前计算机应用中,对高速并行计算的需求是广泛的,归纳起来,主要有三种类型的应用需求: 计算密集(Computer-Intensive)型应用,如 ...
- 算法基础_递归_求杨辉三角第m行第n个数字
问题描述: 算法基础_递归_求杨辉三角第m行第n个数字(m,n都从0开始) 解题源代码(这里打印出的是杨辉三角某一层的所有数字,没用大数,所以有上限,这里只写基本逻辑,要符合题意的话,把循环去掉就好) ...
- 规划将 IaaS 资源从经典部署模型迁移到 Azure Resource Manager
尽管 Azure 资源管理器提供了许多精彩功能,但请务必计划迁移,以确保一切顺利进行. 花时间进行规划可确保执行迁移活动时不会遇到问题. Note 以下指导的主要参与者为 Azure 客户顾问团队,以 ...
- 有关从经典部署模型迁移到 Azure Resource Manager 部署模型的常见问题
此迁移计划是否影响 Azure 虚拟机上运行的任何现有服务或应用程序? 不可以. VM(经典)是公开上市的完全受支持的服务. 你可以继续使用这些资源来拓展你在 Azure 上的足迹. 如果我近期不打算 ...
- 桥接模式_NAT模式_仅主机模式_模型图.ziw
2017年1月12日, 星期四 桥接模式_NAT模式_仅主机模式_模型图 null
- 使用 Azure CLI 将 IaaS 资源从经典部署模型迁移到 Azure Resource Manager 部署模型
以下步骤演示如何使用 Azure 命令行接口 (CLI) 命令将基础结构即服务 (IaaS) 资源从经典部署模型迁移到 Azure Resource Manager 部署模型. 本文中的操作需要 Az ...
- Flutter实战视频-移动电商-05.Dio基础_引入和简单的Get请求
05.Dio基础_引入和简单的Get请求 博客地址: https://jspang.com/post/FlutterShop.html#toc-4c7 第三方的http请求库叫做Dio https:/ ...
- Flutter实战视频-移动电商-08.Dio基础_伪造请求头获取数据
08.Dio基础_伪造请求头获取数据 上节课代码清楚 重新编写HomePage这个动态组件 开始写请求的方法 请求数据 .但是由于我们没加请求的头 所以没有返回数据 451就是表示请求错错误 创建请求 ...
随机推荐
- python class 2
//test.py 1 class Employee: 2 'all employee' 3 empCount = 0 4 def __init__(s ...
- android 通过页面上关键字快速定位代码
这里定位微信关于页面, 当然可以直接获取当前最顶层activity 反编译apk后 搜索 strings.xml,找到对应id 搜索文件到用到id对应的成员变量,通常 是在R*.smali文件中 字符 ...
- Mysql修改字段类型,修改字段名
mysql修改字段名: ALTER TABLE 表名 CHANGE 旧字段名 新字段名 新数据类型; 参考:https://blog.csdn.net/u010002184/article/detai ...
- servlet 的servletconfig
- 003-SqlHelper.cs/Web.config
<?xml version="1.0" encoding="utf-8"?> <!-- 有关如何配置 ASP.NET 应用程序的详细信息,请访 ...
- caffe运行训练脚本时报错:Unknown bottom blob 'data' (layer 'conv1',bottom index 0)
报错的两种报错原因: 1.输入数的路径错误,需要将路径进行修改排查目录是否出错 2.训练原数据格式不对 3.train.prototxt文件中并未设置test层,而在solver层则设置了test的迭 ...
- c#之如何转换文本文件编码格式为utf-8
如代码: string content = File.ReadAllText(path, Encoding.Default); File.WriteAllText(path, content, Enc ...
- Python全栈-day12-day13-函数4
1.迭代器 1)定义:迭代取值的工具 2)优缺点 优点 a.提供一种不依赖索引取值的方法 b.同时一时刻内存在存在的值只有一个,更加省内存 缺点 a.取值麻烦,只能通过next方法一个一个地往后取 b ...
- CentOS下nginx+php的配置及nginx开机启动配置
关闭防火墙 (不然外链接是访问不了 apache) service iptables stop 关闭安全系统 SELinux( 不然报403 访问页面错误 ) 1.Nginx安装主要在于配置文件的修改 ...
- vs实现数据库数据迁移
public ActionResult About() { List<ChangeData.Models.old.adsinfo> adsinfo_new = new List<Mo ...