import tensorflow as tf
import os
from matplotlib import pyplot as plt
import tensorflow.keras.datasets
from tensorflow.keras import Model
import numpy as np
from tensorflow.keras.layers import Dense,Flatten,BatchNormalization,Dropout,Conv2D,Activation,MaxPool2D
cifar10=tf.keras.datasets.cifar10
(x_train,y_train),(x_test,y_test)=cifar10.load_data()
x_train=x_train/255.
x_test=x_test/255. class AlexNet(Model):
def __init__(self):
super(AlexNet, self).__init__()
self.c1=Conv2D(filters=96,kernel_size=(3,3),strides=1,padding='valid')
self.b1=BatchNormalization()
self.a1=Activation('relu')
self.p1=MaxPool2D(pool_size=(3,3),strides=2) self.c2 = Conv2D(filters=384, kernel_size=(3, 3), strides=1, padding='same')
#self.b2 = BatchNormalization()
self.a2 = Activation('relu')
#self.p2 = MaxPool2D(pool_size=(3, 3), strides=2) self.c3 = Conv2D(filters=256, kernel_size=(3, 3), strides=1, padding='same')
# self.b2 = BatchNormalization()
self.a3 = Activation('relu')
self.p3 = MaxPool2D(pool_size=(3, 3), strides=2) self.flatten=Flatten()
self.f1 = Dense(2048,activation='relu')
self.d1=Dropout(0.5)
self.f2 = Dense(2048, activation='relu')
self.d2 = Dropout(0.5)
self.f3 = Dense(10, activation='softmax') def call(self,x): x = self.c1(x)
x = self.b1(x)
x = self.a1(x)
x = self.p1(x) x = self.c2(x)
x = self.a2(x) x = self.c3(x)
x = self.a3(x)
x = self.p3(x) x = self.flatten(x) x=self.f1(x)
x=self.d1(x)
x=self.f2(x)
x=self.d2(x)
y=self.f3(x)
return y model=AlexNet() model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy']) check_save_path='./checkpoint/AlexNet.ckpt'
if os.path.exists(check_save_path+'.index'):
print('-------------lodel the model------------')
model.load_weights(check_save_path) cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,save_best_only=True,
save_weights_only=True) history=model.fit(x_train,y_train,batch_size=128,epochs=5,validation_data=(x_test,y_test),
validation_freq=1,callbacks=[cp_callback]) model.summary() file=open('./AlexNet_wights.txt','w')
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.np()) + '\n')
file.close() ############可视化图像###############
acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['sparse_categorical_val_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss'] plt.subplot(1,2,1)
plt.plot(acc)
plt.plot(val_acc)
plt.legend() plt.subplot(1,2,2)
plt.plot(loss)
plt.plot(val_loss)
plt.legend() plt.show()

此代码运行较慢,单次遍历需要近15分钟,由此可见两层全连接层2048个神经元远远拖慢运行速度

AlexNet实现cifar10数据集分类的更多相关文章

  1. 第十三节,使用带有全局平均池化层的CNN对CIFAR10数据集分类

    这里使用的数据集仍然是CIFAR-10,由于之前写过一篇使用AlexNet对CIFAR数据集进行分类的文章,已经详细介绍了这个数据集,当时我们是直接把这些图片的数据文件下载下来,然后使用pickle进 ...

  2. 用pytorch进行CIFAR-10数据集分类

    CIFAR-10.(Canadian Institute for Advanced Research)是由 Alex Krizhevsky.Vinod Nair 与 Geoffrey Hinton 收 ...

  3. python实现HOG+SVM对CIFAR-10数据集分类(上)

    本博客只用于学习,如果有错误的地方,恳请指正,如需转载请注明出处. 看机器学习也是有一段时间了,这两天终于勇敢地踏出了第一步,实现了HOG+SVM对图片分类,具体代码可以在github上下载,http ...

  4. CIFAR-10数据集图像分类【PCA+基于最小错误率的贝叶斯决策】

    CIFAR-10和CIFAR-100均是带有标签的数据集,都出自于规模更大的一个数据集,他有八千万张小图片.而本次实验采用CIFAR-10数据集,该数据集共有60000张彩色图像,这些图像是32*32 ...

  5. Ubuntu+caffe训练cifar-10数据集

    1. 下载cifar-10数据库 ciffar-10数据集包含10种物体分类,50000张训练图片,10000张测试图片. 在终端执行指令下载cifar-10数据集(二进制文件): cd ~/caff ...

  6. caffe︱cifar-10数据集quick模型的官方案例

    准备拿几个caffe官方案例用来练习,就看到了caffe中的官方案例有cifar-10数据集.于是练习了一下,在CPU情况下构建quick模型.主要参考博客:liumaolincycle的博客 配置: ...

  7. 单向LSTM笔记, LSTM做minist数据集分类

    单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...

  8. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  9. Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes

    Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = ...

随机推荐

  1. 2019 HL SC day10

    10天都过去了 4天都在全程懵逼.. 怎么可以这么难啊 我服了 现在想起依稀只记得一些结论 什么 反演? 什么后缀自动机?什么组合数的应用?什么神仙东西 ,不过讲课人的确都是神仙.(实名羡慕. mzx ...

  2. sqlzoo刷题 SELECT from Nobel Tutorial

    SELECT from Nobel Tutorial 1.Change the query shown so that it displays Nobel prizes for 1950. SELEC ...

  3. 41-native关键字的理解

    使用 native 关键字说明这个方法是原生函数,也就是这个方法是用 C/C++等非Java 语言实现的,并且被编译成了 DLL,由 java 去调用. (1)为什么要用 native 方法 java ...

  4. vue-cli脚手架的搭建

    1.安装node.js 2.安装cnpm npm install -g cnpm --registry=https://registry.npm.taobao.org 3.安装vue-cli npm ...

  5. 朴素贝叶斯分类器基本代码 && n折交叉优化

    自己也是刚刚入门.. 没脸把自己的代码放上去,先用别人的. 加上自己的解析,挺全面的,希望有用. import re import pandas as pd import numpy as np fr ...

  6. sftp与ftp的区别

    SFTP和FTP非常相似,都支持批量传输(一次传输多个文件),文件夹/目录导航,文件移动,文件夹/目录创建,文件删除等.但还是存在着差异,下面我们来看看SFTP和FTP之间的区别. 1. 安全通道FT ...

  7. 触发链模式之使用jdk的Observable和Observerver实现触发链模式(附JDK源码)

    首先看看JDK的Observer接口 public interface Observer { void update(Observable o, Object arg); } 也就一个更新的方法,这里 ...

  8. 小程序的优化代码的分析Promise方法

    代码优化,这里通过了wx.request请求轮播图的API,通过result结果里面的data数据我们可以看到massage里面装着我们的数据 通过图片可以用看到swiperList返回的三个元素的数 ...

  9. 第一个Mybatis

    第一个Mybatis 思路:搭建环境-->导入Mybatis-->编写代码-->测试 1.搭建环境 新建maven工程,配置xml文件 <?xml version=" ...

  10. CTF bossplayers 靶机

    WAYs: robots.txt文件提供线索,命令执行漏洞获得反弹shell suid命令提升权限 1:netdiscover 发现主机地址192.168.1.109 2:使用namp进行端口扫描发现 ...