AlexNet实现cifar10数据集分类
import tensorflow as tf
import os
from matplotlib import pyplot as plt
import tensorflow.keras.datasets
from tensorflow.keras import Model
import numpy as np
from tensorflow.keras.layers import Dense,Flatten,BatchNormalization,Dropout,Conv2D,Activation,MaxPool2D
cifar10=tf.keras.datasets.cifar10
(x_train,y_train),(x_test,y_test)=cifar10.load_data()
x_train=x_train/255.
x_test=x_test/255. class AlexNet(Model):
def __init__(self):
super(AlexNet, self).__init__()
self.c1=Conv2D(filters=96,kernel_size=(3,3),strides=1,padding='valid')
self.b1=BatchNormalization()
self.a1=Activation('relu')
self.p1=MaxPool2D(pool_size=(3,3),strides=2) self.c2 = Conv2D(filters=384, kernel_size=(3, 3), strides=1, padding='same')
#self.b2 = BatchNormalization()
self.a2 = Activation('relu')
#self.p2 = MaxPool2D(pool_size=(3, 3), strides=2) self.c3 = Conv2D(filters=256, kernel_size=(3, 3), strides=1, padding='same')
# self.b2 = BatchNormalization()
self.a3 = Activation('relu')
self.p3 = MaxPool2D(pool_size=(3, 3), strides=2) self.flatten=Flatten()
self.f1 = Dense(2048,activation='relu')
self.d1=Dropout(0.5)
self.f2 = Dense(2048, activation='relu')
self.d2 = Dropout(0.5)
self.f3 = Dense(10, activation='softmax') def call(self,x): x = self.c1(x)
x = self.b1(x)
x = self.a1(x)
x = self.p1(x) x = self.c2(x)
x = self.a2(x) x = self.c3(x)
x = self.a3(x)
x = self.p3(x) x = self.flatten(x) x=self.f1(x)
x=self.d1(x)
x=self.f2(x)
x=self.d2(x)
y=self.f3(x)
return y model=AlexNet() model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy']) check_save_path='./checkpoint/AlexNet.ckpt'
if os.path.exists(check_save_path+'.index'):
print('-------------lodel the model------------')
model.load_weights(check_save_path) cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,save_best_only=True,
save_weights_only=True) history=model.fit(x_train,y_train,batch_size=128,epochs=5,validation_data=(x_test,y_test),
validation_freq=1,callbacks=[cp_callback]) model.summary() file=open('./AlexNet_wights.txt','w')
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.np()) + '\n')
file.close() ############可视化图像###############
acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['sparse_categorical_val_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss'] plt.subplot(1,2,1)
plt.plot(acc)
plt.plot(val_acc)
plt.legend() plt.subplot(1,2,2)
plt.plot(loss)
plt.plot(val_loss)
plt.legend() plt.show()
此代码运行较慢,单次遍历需要近15分钟,由此可见两层全连接层2048个神经元远远拖慢运行速度
AlexNet实现cifar10数据集分类的更多相关文章
- 第十三节,使用带有全局平均池化层的CNN对CIFAR10数据集分类
这里使用的数据集仍然是CIFAR-10,由于之前写过一篇使用AlexNet对CIFAR数据集进行分类的文章,已经详细介绍了这个数据集,当时我们是直接把这些图片的数据文件下载下来,然后使用pickle进 ...
- 用pytorch进行CIFAR-10数据集分类
CIFAR-10.(Canadian Institute for Advanced Research)是由 Alex Krizhevsky.Vinod Nair 与 Geoffrey Hinton 收 ...
- python实现HOG+SVM对CIFAR-10数据集分类(上)
本博客只用于学习,如果有错误的地方,恳请指正,如需转载请注明出处. 看机器学习也是有一段时间了,这两天终于勇敢地踏出了第一步,实现了HOG+SVM对图片分类,具体代码可以在github上下载,http ...
- CIFAR-10数据集图像分类【PCA+基于最小错误率的贝叶斯决策】
CIFAR-10和CIFAR-100均是带有标签的数据集,都出自于规模更大的一个数据集,他有八千万张小图片.而本次实验采用CIFAR-10数据集,该数据集共有60000张彩色图像,这些图像是32*32 ...
- Ubuntu+caffe训练cifar-10数据集
1. 下载cifar-10数据库 ciffar-10数据集包含10种物体分类,50000张训练图片,10000张测试图片. 在终端执行指令下载cifar-10数据集(二进制文件): cd ~/caff ...
- caffe︱cifar-10数据集quick模型的官方案例
准备拿几个caffe官方案例用来练习,就看到了caffe中的官方案例有cifar-10数据集.于是练习了一下,在CPU情况下构建quick模型.主要参考博客:liumaolincycle的博客 配置: ...
- 单向LSTM笔记, LSTM做minist数据集分类
单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes
Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = ...
随机推荐
- PHP crc32() 函数
实例 输出 crc32() 的结果:高佣联盟 www.cgewang.com <?php $str = crc32("Hello World!"); printf(" ...
- 4.23 子集 分数规划 二分 贪心 set 单峰函数 三分
思维题. 显然考虑爆搜.然后考虑n^2能做不能. 容易想到枚举中间的数字mid 然后往mid两边加数字 使其整个集合权值最大. 这里有一个比较显然的贪心就不再赘述了. 可以发现这样做对于集合是奇数的时 ...
- Pintech品致全新多功能MDO 704E系列示波器全新推出
2020年 7月,Pintech品致全新推出推出首款具有多个模拟通道和多个数字通道的示波器.每个模拟通道带宽为200 MHz,每个模拟通道采样率同时达1 GSa/s,在一台仪器中,实现精确.可重复的. ...
- Sharding-JDBC实现读写分离
参考资料:猿天地 https://mp.weixin.qq.com/s/kp2lJHpTMz4bDWkJYjVbOQ 作者:尹吉欢 技术选型:SpringBoot + Sharding-JDBC ...
- 【NOIP2016】天天爱跑步 题解(LCA+桶+树上差分)
题目链接 题目大意:给定一颗含有$n$个结点的树,每个结点有一个权值$w$.给定$m$条路径,如果一个点与路径的起点的距离恰好为$w$,那么$ans[i]++$.求所有结点的ans. 题目分析 暴力的 ...
- PHP 之 Composer 新手入门指南
自2012年3月1日发布以来,Composer因提供了PHP迫切需要的东西:依赖项管理而广受欢迎.实际上,Composer是将所有第三方软件(例如CSS框架,jQuery插件等)引入你的项目的一种方法 ...
- 解惑4:java是值传递还是引用传递
一.概述 曾经纠结了很久java的参数传递方式是什么样的,后面粗略的了解了一鳞半爪以后有了大概的印象:"传参数就是值传递,传对象就是引用传递",后面进一步查找了相关资料和文章以后, ...
- 用 Python 下载抖音无水印视频
说起抖音,大家或多或少应该都接触过,如果大家在上面下载过视频,一定知道我们下载的视频是带有水印的,那么我们有什么方式下载不带水印的视频呢?其实用 Python 就可以做到,下面我们来看一下. 很多人学 ...
- spring data jap的使用 1
最近一直在研究Spring Boot,今天为大家介绍下Spring Data JPA在Spring Boot中的应用,如有错误,欢迎大家指正. 先解释下什么是JPA JPA就是一个基于O/R映射的标准 ...
- 【API进阶之路】无法想象!大龄码农的硬盘里有这么多宝藏
摘要:通过把所需建立的工具库做成云容器化应用,用CCE引擎,通过API网关调用云容器引擎中的容器应用.不仅顺应了云原生的发展趋势,还能随时弹性扩容,满足公司规模化发展的需求. 公司开完年中会后,大家的 ...