AlexNet实现cifar10数据集分类
import tensorflow as tf
import os
from matplotlib import pyplot as plt
import tensorflow.keras.datasets
from tensorflow.keras import Model
import numpy as np
from tensorflow.keras.layers import Dense,Flatten,BatchNormalization,Dropout,Conv2D,Activation,MaxPool2D
cifar10=tf.keras.datasets.cifar10
(x_train,y_train),(x_test,y_test)=cifar10.load_data()
x_train=x_train/255.
x_test=x_test/255. class AlexNet(Model):
def __init__(self):
super(AlexNet, self).__init__()
self.c1=Conv2D(filters=96,kernel_size=(3,3),strides=1,padding='valid')
self.b1=BatchNormalization()
self.a1=Activation('relu')
self.p1=MaxPool2D(pool_size=(3,3),strides=2) self.c2 = Conv2D(filters=384, kernel_size=(3, 3), strides=1, padding='same')
#self.b2 = BatchNormalization()
self.a2 = Activation('relu')
#self.p2 = MaxPool2D(pool_size=(3, 3), strides=2) self.c3 = Conv2D(filters=256, kernel_size=(3, 3), strides=1, padding='same')
# self.b2 = BatchNormalization()
self.a3 = Activation('relu')
self.p3 = MaxPool2D(pool_size=(3, 3), strides=2) self.flatten=Flatten()
self.f1 = Dense(2048,activation='relu')
self.d1=Dropout(0.5)
self.f2 = Dense(2048, activation='relu')
self.d2 = Dropout(0.5)
self.f3 = Dense(10, activation='softmax') def call(self,x): x = self.c1(x)
x = self.b1(x)
x = self.a1(x)
x = self.p1(x) x = self.c2(x)
x = self.a2(x) x = self.c3(x)
x = self.a3(x)
x = self.p3(x) x = self.flatten(x) x=self.f1(x)
x=self.d1(x)
x=self.f2(x)
x=self.d2(x)
y=self.f3(x)
return y model=AlexNet() model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy']) check_save_path='./checkpoint/AlexNet.ckpt'
if os.path.exists(check_save_path+'.index'):
print('-------------lodel the model------------')
model.load_weights(check_save_path) cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,save_best_only=True,
save_weights_only=True) history=model.fit(x_train,y_train,batch_size=128,epochs=5,validation_data=(x_test,y_test),
validation_freq=1,callbacks=[cp_callback]) model.summary() file=open('./AlexNet_wights.txt','w')
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.np()) + '\n')
file.close() ############可视化图像###############
acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['sparse_categorical_val_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss'] plt.subplot(1,2,1)
plt.plot(acc)
plt.plot(val_acc)
plt.legend() plt.subplot(1,2,2)
plt.plot(loss)
plt.plot(val_loss)
plt.legend() plt.show()
此代码运行较慢,单次遍历需要近15分钟,由此可见两层全连接层2048个神经元远远拖慢运行速度
AlexNet实现cifar10数据集分类的更多相关文章
- 第十三节,使用带有全局平均池化层的CNN对CIFAR10数据集分类
这里使用的数据集仍然是CIFAR-10,由于之前写过一篇使用AlexNet对CIFAR数据集进行分类的文章,已经详细介绍了这个数据集,当时我们是直接把这些图片的数据文件下载下来,然后使用pickle进 ...
- 用pytorch进行CIFAR-10数据集分类
CIFAR-10.(Canadian Institute for Advanced Research)是由 Alex Krizhevsky.Vinod Nair 与 Geoffrey Hinton 收 ...
- python实现HOG+SVM对CIFAR-10数据集分类(上)
本博客只用于学习,如果有错误的地方,恳请指正,如需转载请注明出处. 看机器学习也是有一段时间了,这两天终于勇敢地踏出了第一步,实现了HOG+SVM对图片分类,具体代码可以在github上下载,http ...
- CIFAR-10数据集图像分类【PCA+基于最小错误率的贝叶斯决策】
CIFAR-10和CIFAR-100均是带有标签的数据集,都出自于规模更大的一个数据集,他有八千万张小图片.而本次实验采用CIFAR-10数据集,该数据集共有60000张彩色图像,这些图像是32*32 ...
- Ubuntu+caffe训练cifar-10数据集
1. 下载cifar-10数据库 ciffar-10数据集包含10种物体分类,50000张训练图片,10000张测试图片. 在终端执行指令下载cifar-10数据集(二进制文件): cd ~/caff ...
- caffe︱cifar-10数据集quick模型的官方案例
准备拿几个caffe官方案例用来练习,就看到了caffe中的官方案例有cifar-10数据集.于是练习了一下,在CPU情况下构建quick模型.主要参考博客:liumaolincycle的博客 配置: ...
- 单向LSTM笔记, LSTM做minist数据集分类
单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes
Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = ...
随机推荐
- luogu P2607 [ZJOI2008]骑士 tarjan dp
LINK:骑士 本来是不打算写的 发现这道题在tarjan的时候有一个坑点 所以写出来记录一下. 可以发现图可能是不连通的 且一个连通块中是一个奇环树. 做法:类似tarjan找割点 然后把环给拉出来 ...
- win10 64位 汇编环境
masm6或者masm5 下载. dosbox 下载安装 为何要用这个呢,因为 机子是64位的,dosbox 模拟32位的用来执行生成的exe文件 masm 安装好后,有个bin文件:个人建议将其设置 ...
- 安卓APP开发的初步了解
今天成功安装了Android Studio 并且对APP的开发框架结构进行了初步了解 如上图:app基本结构情况 下面来仔细解释一下各个方面目录的作用 首先 manifests目录:包含Android ...
- Unity 笔记
摄像机 Main Camera 跟随主角移动,不看 UI 剧情摄像机 当进入剧情时,可以关闭 main camera,启用剧情摄像机,不看 UI UI 摄像机 看 UI Unity编辑器常用的sett ...
- PhpStorm配置Apache与php的运行环境详细教程
本文主要说明如何在phpstorm中配置已经安装好的PHP与apache.首先需要在本地安装php,这里我安装的是phpstudy 进入PHPstorm的界面点击file 下的settings 在La ...
- Java 设置、删除、获取Word文档背景(基于Spire.Cloud.SDK for Java)
本文介绍使用Spire.Cloud.SDK for Java 提供的BackgroundApi接口来操作Word文档背景的方法,可设置背景,包括设置颜色背景setBackgroundColor().图 ...
- 如何实现数据库CDP,即数据库连续数据保护
备份可以分为定期备份和实时备份.定期备份与实时备份相比存在两大劣势:一是备份需要时间窗口,对于很多24小时业务运行的机构,线上业务不允许有过多的业务系统停机去进行数据备份:二是定期备份无法保证数据丢失 ...
- 2020-05-08:mycat部署数据库集群的时候 遇到了哪些坑
福哥答案2020-05-08:答案仅供参考,来自群员 使用activity时,连接mycat设置进去的序列化的流程变量,反序列化会报错这个类型字段类型是blob类型,mycat对这种类型处理时有点问题
- C#LeetCode刷题-数组
数组篇 # 题名 刷题 通过率 难度 1 两数之和 C#LeetCode刷题之#1-两数之和(Two Sum) 43.1% 简单 4 两个排序数组的中位数 C#LeetCode刷题之#4-两个排序数组 ...
- JVM垃圾回收(GC)
JVM垃圾回收(GC) 1. 判断对象是否可以被回收 引用计数法:每个对象有一个引用计数属性,新增一个引用时计数加1,引用释放时计数减1,计数为0时可以回收.此方法简单,但无法解决对象相互循环引用的问 ...