from keras.applications.vgg16 import VGG16
from keras.models import Sequential
from keras.layers import Conv2D,MaxPool2D,Activation,Dropout,Flatten,Dense
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator,img_to_array,load_img
import numpy as np
 vgg16_model = VGG16(weights='imagenet',include_top=False, input_shape=(150,150,3))
 # 搭建全连接层
top_model = Sequential()
top_model.add(Flatten(input_shape=vgg16_model.output_shape[1:]))
top_model.add(Dense(256,activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(2,activation='softmax')) model = Sequential()
model.add(vgg16_model)
model.add(top_model)
train_datagen = ImageDataGenerator(
rotation_range = 40, # 随机旋转度数
width_shift_range = 0.2, # 随机水平平移
height_shift_range = 0.2,# 随机竖直平移
rescale = 1/255, # 数据归一化
shear_range = 20, # 随机错切变换
zoom_range = 0.2, # 随机放大
horizontal_flip = True, # 水平翻转
fill_mode = 'nearest', # 填充方式
)
test_datagen = ImageDataGenerator(
rescale = 1/255, # 数据归一化
)
batch_size = 32

# 生成训练数据
train_generator = train_datagen.flow_from_directory(
'image/train',
target_size=(150,150),
batch_size=batch_size,
) # 测试数据
test_generator = test_datagen.flow_from_directory(
'image/test',
target_size=(150,150),
batch_size=batch_size,
)
train_generator.class_indices
{'cat': 0, 'dog': 1}
 # 定义优化器,代价函数,训练过程中计算准确率
model.compile(optimizer=SGD(lr=1e-4,momentum=0.9),loss='categorical_crossentropy',metrics=['accuracy']) model.fit_generator(train_generator,steps_per_epoch=len(train_generator),epochs=20,validation_data=test_generator,validation_steps=len(test_generator))

# pip install h5py
model.save('model_vgg16.h5')

测试

from keras.models import load_model
import numpy as np label = np.array(['cat','dog'])
# 载入模型
model = load_model('model_vgg16.h5') # 导入图片
image = load_img('image/test/cat/cat.1003.jpg')
image

image = image.resize((150,150))
image = img_to_array(image)
image = image/255
image = np.expand_dims(image,0)
image.shape
(1, 150, 150, 3)
print(label[model.predict_classes(image)]
['cat']
 

使用VGG16完成猫狗分类的更多相关文章

  1. 使用ModelArts自动学习完成猫狗声音分类

    准备数据 点击下载猫狗声音数据集至本地: 解压,文件包结构大概如下图所示 data ├── test │ ├── cats │ │ ├── cat_20.wav │ │ ├── ...... │ │ ...

  2. paddlepaddle实现猫狗分类

    目录 1.预备工作 1.1 数据集准备 1.2 数据预处理 2.训练 2.1 模型 2.2 定义训练 2.3 训练 3.预测 4.参考文献 声明:这是我的个人学习笔记,大佬可以点评,指导,不喜勿喷.实 ...

  3. 人工智能——CNN卷积神经网络项目之猫狗分类

    首先先导入所需要的库 import sys from matplotlib import pyplot from tensorflow.keras.utils import to_categorica ...

  4. 用tensorflow迁移学习猫狗分类

    笔者这几天在跟着莫烦学习TensorFlow,正好到迁移学习(至于什么是迁移学习,看这篇),莫烦老师做的是预测猫和老虎尺寸大小的学习.作为一个有为的学生,笔者当然不能再预测猫啊狗啊的大小啦,正好之前正 ...

  5. Gluon炼丹(Kaggle 120种狗分类,迁移学习加双模型融合)

    这是在kaggle上的一个练习比赛,使用的是ImageNet数据集的子集. 注意,mxnet版本要高于0.12.1b2017112. 下载数据集. train.zip test.zip labels ...

  6. 猫狗分类--Tensorflow实现

    贴一张自己画的思维导图  数据集准备 kaggle猫狗大战数据集(训练),微软的不需要FQ 12500张cat 12500张dog 生成图片路径和标签的List step1:获取D:/Study/Py ...

  7. 1.keras实现-->自己训练卷积模型实现猫狗二分类(CNN)

    原数据集:包含 25000张猫狗图像,两个类别各有12500 新数据集:猫.狗 (照片大小不一样) 训练集:各1000个样本 验证集:各500个样本 测试集:各500个样本 1= 狗,0= 猫 # 将 ...

  8. wdcp lanmp 安装+搭建网站+安全狗安装 详细实用

    先说一下WDCP,其实就是一个集成环境,优点是有后台可视化面板操作,不像一般的linux似的 都要用代码命令! Linux 的PHP 环境一般就是两个搭配 [mysql+Apache+PHP]和[My ...

  9. 使用pytorch完成kaggle猫狗图像识别

    kaggle是一个为开发商和数据科学家提供举办机器学习竞赛.托管数据库.编写和分享代码的平台,在这上面有非常多的好项目.好资源可供机器学习.深度学习爱好者学习之用.碰巧最近入门了一门非常的深度学习框架 ...

随机推荐

  1. maven——将jar安装到本地仓库

    环境变量MAVEN_HOME配置正确后,cmd窗口执行此命令: mvn install:install-file -Dfile=C:\hehe.jar  -DgroupId=com.rockontro ...

  2. 手写Indexof

    String.prototype.indexO = function(st){ // console.log(this.length); let str = this; var j = 0; let ...

  3. [Comet OJ - Contest #7 D][52D 2417]机器学习题_斜率优化dp

    机器学习题 题目大意: 数据范围: 题解: 学长说是决策单调性? 直接斜率优化就好了嘛 首先发现的是,$A$和$B$的值必定是某两个$x$值. 那么我们就把,$y$的正负分成两个序列,$val1_i$ ...

  4. ffmpeg AVFrame结构体及其相关函数

    0. 简介 AVFrame中存储的是原始数据(例如视频的YUV, RGB, 音频的PCM), 此外还包含了一些相关的信息, 例如: 解码的时候存储了宏块类型表, QP表, 运动矢量等数据. 编码的时候 ...

  5. [读书笔记]Hadoop权威指南 第3版

    下面归纳概述了用于设置MapReduce作业输出的压缩格式的配置属性.如果MapReduce驱动使用了Tool接口,则可以通过命令行将这些属性传递给程序,这比通过程序代码来修改压缩属性更加简便. Ma ...

  6. (0)c++入门——认识指针与数组——指针即是内存中地址。

    初识指针 首先需要了解一个概念,计算机的内存(或者说是寄存器)都是有地址的. <c++ primer plus>一书P37中提到这样一个概念:为把信息存储在计算机中,程序必须记录3个基本属 ...

  7. 并不对劲的bzoj4538:loj2049:p3250:[HNOI2016]网络

    题意 有一棵\(n\)(\(n\leq 10^5\))个点的树,\(m\)(\(m\leq 2\times 10^5\))个操作.操作有三种:1.给出\(u,v,k\),表示加入一条从\(u\)到\( ...

  8. webSocket协议和Socket.IO

    一.Http无法轻松实现实时应用: ● HTTP协议是无状态的,服务器只会响应来自客户端的请求,但是它与客户端之间不具备持续连接. ● 我们可以非常轻松的捕获浏览器上发生的事件(比如用户点击了盒子), ...

  9. Digester库

    在之前所学习关于启动简单的Tomcat部分实现的代码中,我们使用一个启动类Bootstrap类 来实例化连接器.servlet容器.wrapper实例.和其他组件,然后调用各个对象的set方法将他们关 ...

  10. C语言写郑州大学校友通讯录

    #include <stdio.h> #include <string.h> #include <stdlib.h> #define LEN sizeof(stru ...