AlexNet实现cifar10数据集分类
import tensorflow as tf
import os
from matplotlib import pyplot as plt
import tensorflow.keras.datasets
from tensorflow.keras import Model
import numpy as np
from tensorflow.keras.layers import Dense,Flatten,BatchNormalization,Dropout,Conv2D,Activation,MaxPool2D
cifar10=tf.keras.datasets.cifar10
(x_train,y_train),(x_test,y_test)=cifar10.load_data()
x_train=x_train/255.
x_test=x_test/255. class AlexNet(Model):
def __init__(self):
super(AlexNet, self).__init__()
self.c1=Conv2D(filters=96,kernel_size=(3,3),strides=1,padding='valid')
self.b1=BatchNormalization()
self.a1=Activation('relu')
self.p1=MaxPool2D(pool_size=(3,3),strides=2) self.c2 = Conv2D(filters=384, kernel_size=(3, 3), strides=1, padding='same')
#self.b2 = BatchNormalization()
self.a2 = Activation('relu')
#self.p2 = MaxPool2D(pool_size=(3, 3), strides=2) self.c3 = Conv2D(filters=256, kernel_size=(3, 3), strides=1, padding='same')
# self.b2 = BatchNormalization()
self.a3 = Activation('relu')
self.p3 = MaxPool2D(pool_size=(3, 3), strides=2) self.flatten=Flatten()
self.f1 = Dense(2048,activation='relu')
self.d1=Dropout(0.5)
self.f2 = Dense(2048, activation='relu')
self.d2 = Dropout(0.5)
self.f3 = Dense(10, activation='softmax') def call(self,x): x = self.c1(x)
x = self.b1(x)
x = self.a1(x)
x = self.p1(x) x = self.c2(x)
x = self.a2(x) x = self.c3(x)
x = self.a3(x)
x = self.p3(x) x = self.flatten(x) x=self.f1(x)
x=self.d1(x)
x=self.f2(x)
x=self.d2(x)
y=self.f3(x)
return y model=AlexNet() model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy']) check_save_path='./checkpoint/AlexNet.ckpt'
if os.path.exists(check_save_path+'.index'):
print('-------------lodel the model------------')
model.load_weights(check_save_path) cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,save_best_only=True,
save_weights_only=True) history=model.fit(x_train,y_train,batch_size=128,epochs=5,validation_data=(x_test,y_test),
validation_freq=1,callbacks=[cp_callback]) model.summary() file=open('./AlexNet_wights.txt','w')
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.np()) + '\n')
file.close() ############可视化图像###############
acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['sparse_categorical_val_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss'] plt.subplot(1,2,1)
plt.plot(acc)
plt.plot(val_acc)
plt.legend() plt.subplot(1,2,2)
plt.plot(loss)
plt.plot(val_loss)
plt.legend() plt.show()
此代码运行较慢,单次遍历需要近15分钟,由此可见两层全连接层2048个神经元远远拖慢运行速度
AlexNet实现cifar10数据集分类的更多相关文章
- 第十三节,使用带有全局平均池化层的CNN对CIFAR10数据集分类
这里使用的数据集仍然是CIFAR-10,由于之前写过一篇使用AlexNet对CIFAR数据集进行分类的文章,已经详细介绍了这个数据集,当时我们是直接把这些图片的数据文件下载下来,然后使用pickle进 ...
- 用pytorch进行CIFAR-10数据集分类
CIFAR-10.(Canadian Institute for Advanced Research)是由 Alex Krizhevsky.Vinod Nair 与 Geoffrey Hinton 收 ...
- python实现HOG+SVM对CIFAR-10数据集分类(上)
本博客只用于学习,如果有错误的地方,恳请指正,如需转载请注明出处. 看机器学习也是有一段时间了,这两天终于勇敢地踏出了第一步,实现了HOG+SVM对图片分类,具体代码可以在github上下载,http ...
- CIFAR-10数据集图像分类【PCA+基于最小错误率的贝叶斯决策】
CIFAR-10和CIFAR-100均是带有标签的数据集,都出自于规模更大的一个数据集,他有八千万张小图片.而本次实验采用CIFAR-10数据集,该数据集共有60000张彩色图像,这些图像是32*32 ...
- Ubuntu+caffe训练cifar-10数据集
1. 下载cifar-10数据库 ciffar-10数据集包含10种物体分类,50000张训练图片,10000张测试图片. 在终端执行指令下载cifar-10数据集(二进制文件): cd ~/caff ...
- caffe︱cifar-10数据集quick模型的官方案例
准备拿几个caffe官方案例用来练习,就看到了caffe中的官方案例有cifar-10数据集.于是练习了一下,在CPU情况下构建quick模型.主要参考博客:liumaolincycle的博客 配置: ...
- 单向LSTM笔记, LSTM做minist数据集分类
单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes
Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = ...
随机推荐
- 安装ElasticSearch遇到的深坑
实验需要ES,安装过程中遇到一些奇葩的问题,记录下.下面介绍下安装步骤: 第一步:安装java ES是运行在java虚拟机上面的,所以首先需要安装java环境,安装过程不再赘述,唯一需要注意的是ES对 ...
- 属性集 Properties
5.1 概述 java.util.Properties 继承于 Hashtable ,来表示一个持久的属性集.它使用键值结构存储数据,每个键及其对应值都是一个字符串.该类也被许多Java类使用,比如获 ...
- python 爬虫刷访问量
import urllib.requestimport time # 使用build_opener()是为了让python程序模仿浏览器进行访问opener = urllib.request.buil ...
- 用Python做一个简单的翻译工具
编程本身是跟年龄无关的一件事,不论你现在是十四五岁,还是四五十岁,如果你热爱它,并且愿意持续投入其中,必定会有所收获. 很多人学习python,不知道从何学起.很多人学习python,掌握了基本语法过 ...
- 六种酷炫Python运行进度条
本文介绍了目前6种比较常用的进度条,让大家都能直观地看到脚本运行最新的进展情况 很多人学习python,不知道从何学起.很多人学习python,掌握了基本语法过后,不知道在哪里寻找案例上手.很多已经做 ...
- 华为云的研究成果又双叒叕被MICCAI收录了!
摘要:2020年国际医学图像计算和计算机辅助干预会议(MICCAI 2020),论文接收结果已经公布:华为云医疗AI团队和华中科技大学合作的2篇研究成果入选. 语义/实例分割问题是近年来医学图像计算领 ...
- SpringBoot介绍,快速入门小例子,目录结构,不同的启动方式,SpringBoot常用注解
SpringBoot介绍 引言 为了使用ssm框架去开发,准备ssm框架的模板配置 为了Spring整合第三方框架,单独的去编写xml文件 导致ssm项目后期xml文件特别多,维护xml文件的成本也是 ...
- MATLAB通过ODBC连接数据库方法
MATLAB通过ODBC连接数据库方法 1.首先创建数据库,我在这里用到的是MySQL 8.0 2.建立ODBC数据源,参考链接: https://www.cnblogs.com/benpao1314 ...
- centos之hadoop的安装
Evernote Export 第一步 环境部署 参考 http://dblab.xmu.edu.cn/blog/install-hadoop-in-centos/ 1.创建hadoop用户 $su ...
- 谁先执行?props还是data或是其他? vue组件初始化的执行顺序详解
初入vue的朋友可能会疑惑,组件初始化的时候,created,props,data到底谁先执行? 今天,我就带大家从源码的角度看看到底谁先执行? 我们知道,vue是个实例 那我们就从new Vue() ...