MXNet学习:试用卷积-训练CIFAR-10数据集
第一次用卷积,看的别人的模型跑的CIFAR-10,不过吐槽一下。。。我觉着我的965m加速之后比我的cpu算起来没快多少。。正确率64%的样子,没达到模型里说的75%,不知道问题出在哪里
import numpy as np
import os
import mxnet as mx
import logging
import cPickle def unpickle(file):
with open(file,'rb') as fo:
dict = cPickle.load(fo)
return np.array(dict['data']).reshape(10000,3072),np.array(dict['labels']).reshape(10000) def to4d(img):
return img.reshape(img.shape[0],3,32,32).astype(np.float32)/255 def fit(batch_num,model,val_iter,batch_size):
(train_img, train_lbl) = unpickle('cifar-10/data_batch_'+str(batch_num))
train_iter = mx.io.NDArrayIter(to4d(train_img), train_lbl, batch_size, shuffle=True)
model.fit(
X=train_iter,
eval_data=val_iter,
batch_end_callback=mx.callback.Speedometer(batch_size,200)
) (val_img, val_lbl) = unpickle('cifar-10/test_batch') batch_size = 100
val_iter = mx.io.NDArrayIter(to4d(val_img),val_lbl,batch_size) data = mx.sym.Variable('data')
cv1 = mx.sym.Convolution(data=data,name='cv1',num_filter=32,kernel=(3,3))
act1 = mx.sym.Activation(data=cv1,name='relu1',act_type='relu')
poing1 = mx.sym.Pooling(data=act1,name='poing1',kernel=(2,2),pool_type='max')
do1 = mx.sym.Dropout(data=poing1,name='do1',p=0.25)
cv2 = mx.sym.Convolution(data=do1,name='cv2',num_filter=32,kernel=(3,3))
act2 = mx.sym.Activation(data=cv2,name='relu2',act_type='relu')
poing2 = mx.sym.Pooling(data=act2,name='poing2',kernel=(2,2),pool_type='avg')
do2 = mx.sym.Dropout(data=poing2,name='do2',p=0.25)
cv3 = mx.sym.Convolution(data=do2,name='cv3',num_filter=64,kernel=(3,3))
act3 = mx.sym.Activation(data=cv3,name='relu3',act_type='relu')
poing3 = mx.sym.Pooling(data=act3,name='poing3',kernel=(2,2),pool_type='avg')
do3 = mx.sym.Dropout(data=poing3,name='do3',p=0.25)
data = mx.sym.Flatten(data=do3)
fc1 = mx.sym.FullyConnected(data=data,name='fc1',num_hidden=64)
act4 = mx.sym.Activation(data=fc1,name='relu4',act_type='relu')
do4 = mx.sym.Dropout(data=act4,name='do4',p=0.25)
fc2 = mx.sym.FullyConnected(data=do4,name='fc2',num_hidden=10)
mlp = mx.sym.SoftmaxOutput(data=fc2,name='softmax') logging.getLogger().setLevel(logging.DEBUG) model = mx.model.FeedForward(
ctx=mx.gpu(0),
symbol=mlp,
num_epoch=10,
learning_rate=0.1
)
for batch_num in range(1,6):
fit(batch_num, model, val_iter, batch_size)
MXNet学习:试用卷积-训练CIFAR-10数据集的更多相关文章
- Keras学习:试用卷积-训练CIFAR-10数据集
import numpy as np import cPickle import keras as ks from keras.layers import Dense, Activation, Fla ...
- 【翻译】TensorFlow卷积神经网络识别CIFAR 10Convolutional Neural Network (CNN)| CIFAR 10 TensorFlow
原网址:https://data-flair.training/blogs/cnn-tensorflow-cifar-10/ by DataFlair Team · Published May 21, ...
- 可变卷积Deforable ConvNet 迁移训练自己的数据集 MXNet框架 GPU版
[引言] 最近在用可变卷积的rfcn 模型迁移训练自己的数据集, MSRA官方使用的MXNet框架 环境搭建及配置:http://www.cnblogs.com/andre-ma/p/8867031. ...
- 【神经网络与深度学习】基于Windows+Caffe的Minst和CIFAR—10训练过程说明
Minst训练 我的路径:G:\Caffe\Caffe For Windows\examples\mnist 对于新手来说,初步完成环境的配置后,一脸茫然.不知如何跑Demo,有么有!那么接下来的教 ...
- TensorFlow学习笔记——LeNet-5(训练自己的数据集)
在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...
- 深度学习之卷积神经网络(CNN)详解与代码实现(二)
用Tensorflow实现卷积神经网络(CNN) 本文系作者原创,转载请注明出处:https://www.cnblogs.com/further-further-further/p/10737065. ...
- YOLO训练自己的数据集的一些心得
YOLO训练自己的数据集 YOLO-darknet训练自己的数据 [Darknet][yolo v2]训练自己数据集的一些心得----VOC格式 YOLO模型训练可视化训练过程中的中间参数 项目开源代 ...
- 【Tensorflow系列】使用Inception_resnet_v2训练自己的数据集并用Tensorboard监控
[写在前面] 用Tensorflow(TF)已实现好的卷积神经网络(CNN)模型来训练自己的数据集,验证目前较成熟模型在不同数据集上的准确度,如Inception_V3, VGG16,Inceptio ...
- 深度学习之卷积神经网络(CNN)的应用-验证码的生成与识别
验证码的生成与识别 本文系作者原创,转载请注明出处:https://www.cnblogs.com/further-further-further/p/10755361.html 目录 1.验证码的制 ...
随机推荐
- js监听网页页面滑动滚动事件,实现导航栏自动显示或隐藏
/** * 页面滑动滚动事件 * @param e *///0为隐藏,1为显示var s = 1;function scrollFunc(e) { // e存在就用e不存在就用windon.event ...
- Redis 事务支持 ACID 么?
腾讯面试官:「数据库事务机制了解么?」 「内心独白:小意思,不就 ACID 嘛,转眼一想,我面试的可是技术专家,不会这么简单的问题吧」 程许远:「balabala-- 极其自信且从容淡定的说了一通.」 ...
- 【LeetCode】628. 三个数的最大乘积
解题思路 如果数组中全是正数或者全是负数,最大乘积就是最大的三个数的乘积.如果数组中既有正数又有负数,最大乘积可能是三个最大正数乘积,也可能是两个最小负数和最大正数的乘积.遍历数组找到最大的三个数和最 ...
- .NET 云原生架构师训练营(KestrelServer源码分析)--学习笔记
目录 目标 源码 目标 理解 KestrelServer 如何接收网络请求,网络请求如何转换成 http request context(C# 可识别) 源码 https://github.com/d ...
- [开发笔记usbTOcan]用树莓派搭建私有Git服务器
0 | 思路 在开始编程前,先创建一个版本管理库,以前一直用SVN,但目前用Git的还是比较,正好利用这个机会学习GIt. 想过使用Github提供的免费服务器,但项目目前还没有做开源的准备,于是就有 ...
- 【记录一个问题】一个golang中的BUG,为啥编译的时候无法发现,而单独跑测试用例就发现了
代码大致如下: func DoSomething(){ log.Printf("a=%s, b=%s, c=%s", a, b) //忘记少写一个参数.但是编译正常通过 } fun ...
- thinkpad s5 电源功率不足提示
相关答案 作者:路灯瓜 链接:https://www.zhihu.com/question/47551448/answer/122578101 来源:知乎 著作权归作者所有.商业转载请联系作者获得授权 ...
- 在DigitalOcean vps中安装vnstat监控流量,浏览器打开php代码。。。
由于DigitalOcean中没有发现可以观察已用流量的功能,有想知道自己的流量使用情况,所以安装了vnstat. 安装过程十分简单,见百度经验,官方主页等. 1.安装完vnstat后,直接命令vns ...
- MySQL基本数据类型之枚举与集合类型
目录 一:枚举 1.枚举 2.创建表(使用枚举) 3.表内添加数据 二:集合 1.集合 2.创建表(使用集合) 3.表内添加数据 一:枚举 1.枚举 枚举作用: 提前定义好数据之后 后续录入只能录定义 ...
- maven常用打包命令
常用maven命令 执行与构建过程(编译,测试,打包)相关的命令必须进入pom.xml所在位置执行 mvn clean:清理(打包好的程序放在生成的名为target的文件中,清理即删除文件中打包好的程 ...