这是一个图像分类的比赛CIFAR( CIFAR-10 - Object Recognition in Images )

首先我们需要下载数据文件,地址:

http://www.cs.toronto.edu/~kriz/cifar.html

CIFAR-10数据集包含10个类别的60000个32x32彩色图像,每个类别6000个图像。有50000张训练图像和10000张测试图像。

数据集分为五个训练批次和一个测试批次,每个批次具有10000张图像。测试批次包含每个类别中恰好1000张随机选择的图像。训练批次按随机顺序包含其余图像,但是某些训练批次可能包含比另一类更多的图像。在它们之间,培训批次精确地包含每个班级的5000张图像。

这些类是完全互斥的。汽车和卡车之间没有重叠。“汽车”包括轿车,SUV和类似的东西。“卡车”仅包括大型卡车。都不包括皮卡车。

详细代码:

1.导包

 import numpy as np

 # 序列化和反序列化
import pickle from sklearn.preprocessing import OneHotEncoder import warnings
warnings.filterwarnings('ignore') import tensorflow as tf

2.数据加载

 def unpickle(file):

with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='ISO-8859-1')
return dict # def unpickle(file):
# import pickle
# with open(file, 'rb') as fo:
# dict = pickle.load(fo, encoding='bytes')
# return dict labels = []
X_train = []
for i in range(1,6):
data = unpickle('./cifar-10-batches-py/data_batch_%d'%(i))
labels.append(data['labels'])
X_train.append(data['data']) # 将list类型转换为ndarray
y_train = np.array(labels).reshape(-1)
X_train = np.array(X_train) # reshape
X_train = X_train.reshape(-1,3072) # 目标值概率
one_hot = OneHotEncoder()
y_train =one_hot.fit_transform(y_train.reshape(-1,1)).toarray()
display(X_train.shape,y_train.shape)

3.构建神经网络

 X = tf.placeholder(dtype=tf.float32,shape = [None,3072])
y = tf.placeholder(dtype=tf.float32,shape = [None,10])
kp = tf.placeholder(dtype=tf.float32) def gen_v(shape):
return tf.Variable(tf.truncated_normal(shape = shape)) def conv(input_,filter_,b):
conv = tf.nn.relu(tf.nn.conv2d(input_,filter_,strides=[1,1,1,1],padding='SAME') + b)
return tf.nn.max_pool(conv,[1,3,3,1],[1,2,2,1],'SAME') def net_work(input_,kp): # 形状改变,4维
input_ = tf.reshape(input_,shape = [-1,32,32,3])
# 第一层
filter1 = gen_v(shape = [3,3,3,64])
b1 = gen_v(shape = [64])
conv1 = conv(input_,filter1,b1)
# 归一化
conv1 = tf.layers.batch_normalization(conv1,training=True) # 第二层
filter2 = gen_v([3,3,64,128])
b2 = gen_v(shape = [128])
conv2 = conv(conv1,filter2,b2)
conv2 = tf.layers.batch_normalization(conv2,training=True) # 第三层
filter3 = gen_v([3,3,128,256])
b3 = gen_v([256])
conv3 = conv(conv2,filter3,b3)
conv3 = tf.layers.batch_normalization(conv3,training=True) # 第一层全连接层
dense = tf.reshape(conv3,shape = [-1,4*4*256])
fc1_w = gen_v(shape = [4*4*256,1024])
fc1_b = gen_v([1024])
fc1 = tf.matmul(dense,fc1_w) + fc1_b
fc1 = tf.layers.batch_normalization(fc1,training=True)
fc1 = tf.nn.relu(fc1)
# fc1.shape = [-1,1024] # dropout
dp = tf.nn.dropout(fc1,keep_prob=kp) # 第二层全连接层
fc2_w = gen_v(shape = [1024,1024])
fc2_b = gen_v(shape = [1024])
fc2 = tf.nn.relu(tf.layers.batch_normalization(tf.matmul(dp,fc2_w) + fc2_b,training=True)) # 输出层
out_w = gen_v(shape = [1024,10])
out_b = gen_v(shape = [10])
out = tf.matmul(fc2,out_w) + out_b
return out

4.损失函数准确率

 out = net_work(X,kp)

 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=out))

 # 准确率
y_ = tf.nn.softmax(out) # equal 相当于 ==
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y,axis = -1),tf.argmax(y_,axis = 1)),tf.float16))
accuracy

5.最优化

 opt = tf.train.AdamOptimizer().minimize(loss)
opt

6.开启训练

 epoches = 50000
saver = tf.train.Saver() index = 0
def next_batch(X,y):
global index
batch_X = X[index*128:(index+1)*128]
batch_y = y[index*128:(index+1)*128]
index+=1
if index == 390:
index = 0
return batch_X,batch_y test = unpickle('./cifar-10-batches-py/test_batch')
y_test = test['labels']
y_test = np.array(y_test)
X_test = test['data']
y_test = one_hot.transform(y_test.reshape(-1,1)).toarray()
y_test[:10] with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(epoches):
batch_X,batch_y = next_batch(X_train,y_train)
opt_,loss_ = sess.run([opt,loss],feed_dict = {X:batch_X,y:batch_y,kp:0.5})
print('----------------------------',loss_)
if i % 100 == 0:
score_test = sess.run(accuracy,feed_dict = {X:X_test,y:y_test,kp:1.0})
score_train = sess.run(accuracy,feed_dict = {X:batch_X,y:batch_y,kp:1.0})
print('iter count:%d。mini_batch loss:%0.4f。训练数据上的准确率:%0.4f。测试数据上准确率:%0.4f'%
(i+1,loss_,score_train,score_test))


这个准确率只达到了百分之80

如果想提高准确率,还需要进一步优化,调参

利用卷积神经网络处理cifar图像分类的更多相关文章

  1. TensorFlow系列专题(十四): 手把手带你搭建卷积神经网络实现冰山图像分类

    目录: 冰山图片识别背景 数据介绍 数据预处理 模型搭建 结果分析 总结 一.冰山图片识别背景 这里我们要解决的任务是来自于Kaggle上的一道赛题(https://www.kaggle.com/c/ ...

  2. Neuromation新研究:利用卷积神经网络进行儿童骨龄评估

    近日,Neuromation 团队在 Medium 上撰文介绍其最新研究成果:利用卷积神经网络(CNN)评估儿童骨龄,这一自动骨龄评估系统可以得到与放射科专家相似或更好的结果.该团队评估了手骨不同区域 ...

  3. 【翻译】TensorFlow卷积神经网络识别CIFAR 10Convolutional Neural Network (CNN)| CIFAR 10 TensorFlow

    原网址:https://data-flair.training/blogs/cnn-tensorflow-cifar-10/ by DataFlair Team · Published May 21, ...

  4. 利用卷积神经网络(CNN)构造社区问答系统

    /* 版权声明:能够随意转载,转载时请标明文章原始出处和作者信息 .*/                                                     author: 张俊林 ...

  5. 利用卷积神经网络(VGG19)实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

  6. Tensorflow学习教程------利用卷积神经网络对mnist数据集进行分类_利用训练好的模型进行分类

    #coding:utf-8 import tensorflow as tf from PIL import Image,ImageFilter from tensorflow.examples.tut ...

  7. 利用卷积神经网络实现MNIST手写数据识别

    代码: import torch import torch.nn as nn import torch.utils.data as Data import torchvision # 数据库模块 im ...

  8. SIGAI深度学习第八集 卷积神经网络2

    讲授Lenet.Alexnet.VGGNet.GoogLeNet等经典的卷积神经网络.Inception模块.小尺度卷积核.1x1卷积核.使用反卷积实现卷积层可视化等. 大纲: LeNet网络 Ale ...

  9. 直白介绍卷积神经网络(CNN)【转】

    英文地址:https://ujjwalkarn.me/2016/08/11/intuitive-explanation-convnets/ 中文译文:http://mp.weixin.qq.com/s ...

随机推荐

  1. 洛谷 P1823 [COI2007] Patrik 音乐会的等待 题解

    P1823 [COI2007] Patrik 音乐会的等待 题目描述 N个人正在排队进入一个音乐会.人们等得很无聊,于是他们开始转来转去,想在队伍里寻找自己的熟人.队列中任意两个人A和B,如果他们是相 ...

  2. DIJ的优化,和spfa的优化

    SPFA和DIJ求最短路的算法的坑点一直是很多的.经常会让人搞不懂. 易错案例: 用重载运算符来排序,如: struct cmp { bool operator ()(int x, int y) co ...

  3. python中的zip函数的使用

    >>> x = [, , ] >>> y = [, , ] >>> z = [, , ] >>> xyz = list(zip( ...

  4. 「雅礼集训 2018 Day2」农民

    传送门 Description  「搞 OI 不如种田.」 小 D 在家种了一棵二叉树,第 ii 个结点的权值为 \(a_i\). 小 D 为自己种的树买了肥料,每天给树施肥. 可是几天后,小 D 却 ...

  5. NTT小结及原根求法

    注意 由于蒟蒻实在太弱了~^_^~暂时无法完成证明,仅能写出简单版总结 与FFT的区别 \(NTT\)与\(FFT\)的代码区别就是把单位根换成了原根,从而实现无精度误差与浮点数的巨大常数 原根具有单 ...

  6. css自定义属性和简单效果

    简介 CSS Variables,一个并不是那么新的东西,但对css来说绝对是一场革命. 之前使用变量的时候,需要借助sass.less等预处理工具来实现,现在我们可以直接使用css来声明变量. 兼容 ...

  7. 如何把ANSYS模型输出为CDB文件并导入FLUENT  【转载】

    转载自: http://linziok99.blog.163.com/blog/static/100157302009320134826/ 在main menu中选择Archive Model ,再点 ...

  8. 小福bbs-冲刺日志(第一天)

    [小福bbs-冲刺日志(第一天)] 这个作业属于哪个课程 班级链接 这个作业要求在哪里 作业要求的链接 团队名称 小福bbs 这个作业的目标 开会,安排具体工作 作业的正文 小福bbs-冲刺日志(第一 ...

  9. spring boot修改代码后无需重启设置,在开发时实现热部署

    Spring Boot在开发时实现热部署(开发时修改文件保存后自动重启应用)(spring-boot-devtools) 热部署是什么 大家都知道在项目开发过程中,常常会改动页面数据或者修改数据结构, ...

  10. 从Windows命令行启动MySQL

    SERVER: 从Windows命令行启动MySQL 可以从命令行手动启动MySQL服务器.可以在任何版本的Windows中实现. 要想从命令行启动mysqld服务器,你应当启动控制台窗口(或“DOS ...