import tensorflow as tf
import os
from matplotlib import pyplot as plt
import tensorflow.keras.datasets
from tensorflow.keras import Model
import numpy as np
from tensorflow.keras.layers import Dense,Flatten,BatchNormalization,Dropout,Conv2D,Activation,MaxPool2D
cifar10=tf.keras.datasets.cifar10
(x_train,y_train),(x_test,y_test)=cifar10.load_data()
x_train=x_train/255.
x_test=x_test/255. class AlexNet(Model):
def __init__(self):
super(AlexNet, self).__init__()
self.c1=Conv2D(filters=96,kernel_size=(3,3),strides=1,padding='valid')
self.b1=BatchNormalization()
self.a1=Activation('relu')
self.p1=MaxPool2D(pool_size=(3,3),strides=2) self.c2 = Conv2D(filters=384, kernel_size=(3, 3), strides=1, padding='same')
#self.b2 = BatchNormalization()
self.a2 = Activation('relu')
#self.p2 = MaxPool2D(pool_size=(3, 3), strides=2) self.c3 = Conv2D(filters=256, kernel_size=(3, 3), strides=1, padding='same')
# self.b2 = BatchNormalization()
self.a3 = Activation('relu')
self.p3 = MaxPool2D(pool_size=(3, 3), strides=2) self.flatten=Flatten()
self.f1 = Dense(2048,activation='relu')
self.d1=Dropout(0.5)
self.f2 = Dense(2048, activation='relu')
self.d2 = Dropout(0.5)
self.f3 = Dense(10, activation='softmax') def call(self,x): x = self.c1(x)
x = self.b1(x)
x = self.a1(x)
x = self.p1(x) x = self.c2(x)
x = self.a2(x) x = self.c3(x)
x = self.a3(x)
x = self.p3(x) x = self.flatten(x) x=self.f1(x)
x=self.d1(x)
x=self.f2(x)
x=self.d2(x)
y=self.f3(x)
return y model=AlexNet() model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy']) check_save_path='./checkpoint/AlexNet.ckpt'
if os.path.exists(check_save_path+'.index'):
print('-------------lodel the model------------')
model.load_weights(check_save_path) cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,save_best_only=True,
save_weights_only=True) history=model.fit(x_train,y_train,batch_size=128,epochs=5,validation_data=(x_test,y_test),
validation_freq=1,callbacks=[cp_callback]) model.summary() file=open('./AlexNet_wights.txt','w')
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.np()) + '\n')
file.close() ############可视化图像###############
acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['sparse_categorical_val_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss'] plt.subplot(1,2,1)
plt.plot(acc)
plt.plot(val_acc)
plt.legend() plt.subplot(1,2,2)
plt.plot(loss)
plt.plot(val_loss)
plt.legend() plt.show()

此代码运行较慢,单次遍历需要近15分钟,由此可见两层全连接层2048个神经元远远拖慢运行速度

AlexNet实现cifar10数据集分类的更多相关文章

  1. 第十三节,使用带有全局平均池化层的CNN对CIFAR10数据集分类

    这里使用的数据集仍然是CIFAR-10,由于之前写过一篇使用AlexNet对CIFAR数据集进行分类的文章,已经详细介绍了这个数据集,当时我们是直接把这些图片的数据文件下载下来,然后使用pickle进 ...

  2. 用pytorch进行CIFAR-10数据集分类

    CIFAR-10.(Canadian Institute for Advanced Research)是由 Alex Krizhevsky.Vinod Nair 与 Geoffrey Hinton 收 ...

  3. python实现HOG+SVM对CIFAR-10数据集分类(上)

    本博客只用于学习,如果有错误的地方,恳请指正,如需转载请注明出处. 看机器学习也是有一段时间了,这两天终于勇敢地踏出了第一步,实现了HOG+SVM对图片分类,具体代码可以在github上下载,http ...

  4. CIFAR-10数据集图像分类【PCA+基于最小错误率的贝叶斯决策】

    CIFAR-10和CIFAR-100均是带有标签的数据集,都出自于规模更大的一个数据集,他有八千万张小图片.而本次实验采用CIFAR-10数据集,该数据集共有60000张彩色图像,这些图像是32*32 ...

  5. Ubuntu+caffe训练cifar-10数据集

    1. 下载cifar-10数据库 ciffar-10数据集包含10种物体分类,50000张训练图片,10000张测试图片. 在终端执行指令下载cifar-10数据集(二进制文件): cd ~/caff ...

  6. caffe︱cifar-10数据集quick模型的官方案例

    准备拿几个caffe官方案例用来练习,就看到了caffe中的官方案例有cifar-10数据集.于是练习了一下,在CPU情况下构建quick模型.主要参考博客:liumaolincycle的博客 配置: ...

  7. 单向LSTM笔记, LSTM做minist数据集分类

    单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...

  8. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  9. Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes

    Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = ...

随机推荐

  1. PHP __construct() 函数

    实例 函数创建一个新的 SimpleXMLElement 对象,然后输出 body 节点的内容:高佣联盟 www.cgewang.com <?php $note=<<<XML ...

  2. Python性能分析与优化PDF高清完整版免费下载|百度云盘

    百度云盘|Python性能分析与优化PDF高清完整版免费下载 提取码:ubjt 内容简介 全面掌握Python代码性能分析和优化方法,消除性能瓶颈,迅速改善程序性能! 对于Python程序员来说,仅仅 ...

  3. C语言中的数据转换和定义常量

    一.数据转换 1.数据类型转换:C 语言中如果一个表达式中含有不同类型的常量和变量,在计算时,会将它们自动转换为同一种类型:在 C 语言中也可以对数据类型进行强制转换: 2.自动转换规则: a)浮点数 ...

  4. Guava基本工具--常见Object方法

    在Java中Object类是所有类的父类,其中有几个需要override的方法比如equals,hashCode和toString等方法.每次写这几个方法都要做很多重复性的判断, 很多类库提供了覆写这 ...

  5. C 语言学习 --2

    memset Declaration: void *memset(void *str, int c, size_t n); Copies the character c (an unsigned ch ...

  6. MyFirstJavaWeb

    源代码: Register.jsp <%@ page language="java" contentType="text/html; charset=utf-8&q ...

  7. Python datetime 转 JSON

    Python datetime 转 JSON Python 中将 datetime 转换为 JSON 类型,在使用 Django 时遇到的问题. 环境: Python2.7 代码: import js ...

  8. 网易云音乐ncm格式分析以及ncm与mp3格式转换

    目录 NCM格式分析 音频知识简介 两种可能 GitHub项目 格式分析 总体结构 密钥问题 代码分析 main函数 导入模块 dump函数 参考资料 代码完整版 转换工具 ncmdump ncmdu ...

  9. java_流程控制语句、权限修饰符

    判断语句 if语句第一种格式: if if(关系表达式){ 语句体; } if语句第二种格式: if…else if(关系表达式) { 语句体1; } else { 语句体2; } if语句第三种格式 ...

  10. java final关键字与static关键字

    一  final关键字 1.final修饰类不可以被继承,但是可以继承其他类. 例如: class Yy {} final class Fu extends Yy{} //可以继承Yy类 class ...