Sample Classification Code of CIFAR-10 in Torch
From: http://torch.ch/blog/2015/07/30/cifar.html
require 'xlua'
require 'optim'
require 'nn'
require 'image'
local c = require 'trepl.colorize'

-- Command-line options; `opt` is global because train()/test() read it.
opt = lapp[[
-s,--save (default "logs") subdirectory to save logs
-b,--batchSize (default 128) batch size
-r,--learningRate (default 1) learning rate
--learningRateDecay (default 1e-7) learning rate decay
--weightDecay (default 0.0005) weightDecay
-m,--momentum (default 0.9) momentum
--epoch_step (default 25) epoch step
--model (default vgg_bn_drop) model name
--max_epoch (default 300) maximum number of iterations
--backend (default nn) backend
--type (default cuda) cuda/float/cl
]]

print(opt)

do -- data augmentation module
   -- nn.BatchFlip: while self.train is true, horizontally flips a random half
   -- of the images in each batch; otherwise passes the input through.
   local BatchFlip, parent = torch.class('nn.BatchFlip', 'nn.Module')

   function BatchFlip:__init()
      parent.__init(self)
      self.train = true
   end

   --- Flips the selected images of `input` in place and returns it as output.
   -- @param input batch tensor whose first dimension indexes the images
   -- @return self.output, which shares storage with `input` (no copy is made)
   function BatchFlip:updateOutput(input)
      if self.train then
         local bs = input:size(1)
         -- entries <= bs/2 of a random permutation of 1..bs select
         -- floor(bs/2) random images
         local flip_mask = torch.randperm(bs):le(bs / 2)
         for i = 1, bs do
            if flip_mask[i] == 1 then image.hflip(input[i], input[i]) end
         end
      end
      self.output:set(input)
      return self.output
   end
end

--- Converts a tensor or module to the type chosen by --type, loading the
-- matching backend package ('cunn' / 'clnn') when needed.
-- @param t tensor or nn module
-- @return the converted object
-- @raise error when --type is not cuda, float or cl
local function cast(t)
   if opt.type == 'cuda' then
      require 'cunn'
      return t:cuda()
   elseif opt.type == 'float' then
      return t:float()
   elseif opt.type == 'cl' then
      require 'clnn'
      return t:cl()
   else
      error('Unknown type '..opt.type)
   end
end

print(c.blue '==>' ..' configuring model')
local model = nn.Sequential()
-- module 1: random flips, executed on the CPU in float precision
model:add(nn.BatchFlip():float())
-- module 2: copies the FloatTensor batch into the tensor type selected by --type
model:add(cast(nn.Copy('torch.FloatTensor', torch.type(cast(torch.Tensor())))))
-- module 3: the network defined in models/<opt.model>.lua
model:add(cast(dofile('models/'..opt.model..'.lua')))
-- The input batch needs no gradient, so skip computing gradInput in the
-- Copy module (module 2).
model:get(2).updateGradInput = function(input) return end

if opt.backend == 'cudnn' then
   require 'cudnn'
   cudnn.benchmark = true
   -- only the network (module 3) is converted to cudnn layers
   cudnn.convert(model:get(3), cudnn)
end

print(model)
print(c.blue '==>' ..' loading data')
-------------------------------------------------------------------------------------------
---------------------------- Load the Train and Test data ---------------------------------
-------------------------------------------------------------------------------------------
-- The training set is read from 5 batch files of 10000 images each and the
-- test set from one file; every image arrives as a flat row of 3*32*32 = 3072
-- values and is reshaped to 3x32x32 at the end.
local trsize = 50000      -- training images to keep (at most 50000)
local tesize = 10000      -- test images to keep (at most 10000)
local nbatch = 5          -- number of training batch files
local batch_n = 10000     -- images per training batch file

-- load dataset
-- The buffer always holds all batches; it is cut down to trsize below.
trainData = {
   data = torch.Tensor(nbatch * batch_n, 3 * 32 * 32),
   labels = torch.Tensor(nbatch * batch_n),
   size = function() return trsize end
}
local trainData = trainData
for i = 0, nbatch - 1 do
   local subset = torch.load('cifar-10-batches-t7/data_batch_' .. (i+1) .. '.t7', 'ascii')
   -- the files appear to store one image per column, hence the transpose
   trainData.data[{ {i*batch_n+1, (i+1)*batch_n} }] = subset.data:t()
   trainData.labels[{ {i*batch_n+1, (i+1)*batch_n} }] = subset.labels
end
-- shift to 1-based class indices (the files presumably store 0..9)
trainData.labels = trainData.labels + 1

local subset = torch.load('cifar-10-batches-t7/test_batch.t7', 'ascii')
testData = {
   data = subset.data:t():double(),
   labels = subset.labels[1]:double(),
   size = function() return tesize end
}
local testData = testData
testData.labels = testData.labels + 1

-- resize dataset (if using small version)
trainData.data = trainData.data[{ {1,trsize} }]
trainData.labels = trainData.labels[{ {1,trsize} }]
testData.data = testData.data[{ {1,tesize} }]
testData.labels = testData.labels[{ {1,tesize} }]

-- reshape data: each flat 3072-value row becomes a 3x32x32 image
trainData.data = trainData.data:reshape(trsize,3,32,32)
testData.data = testData.data:reshape(tesize,3,32,32)
----------------------------------------------------------------------------------
----------------------------------------------------------------------------------
-- preprocessing data (color space + normalization)
----------------------------------------------------------------------------------
----------------------------------------------------------------------------------
print '<trainer> preprocessing data (color space + normalization)'
collectgarbage()

local normalization = nn.SpatialContrastiveNormalization(1, image.gaussian1D(7))

--- Converts every image of `dataset` from RGB to YUV in place and applies
-- local contrastive normalization to the Y channel.
-- @param dataset table with `data` (N x 3 x H x W) and a `size()` function
local function rgb_to_yuv_local_y(dataset)
   for i = 1, dataset:size() do
      xlua.progress(i, dataset:size())
      -- rgb -> yuv
      local yuv = image.rgb2yuv(dataset.data[i])
      -- normalize y locally; yuv[{{1}}] keeps the channel dimension (1 x H x W)
      yuv[1] = normalization(yuv[{{1}}])
      dataset.data[i] = yuv
   end
end

--- Standardizes one channel of `data` in place: (x - mean) / std.
-- @param data N x C x H x W tensor
-- @param channel index along dimension 2
local function standardize_channel(data, channel, mean, std)
   local plane = data:select(2, channel)
   plane:add(-mean)
   plane:div(std)
end

-- preprocess trainSet
rgb_to_yuv_local_y(trainData)
-- normalize u globally:
local mean_u = trainData.data:select(2,2):mean()
local std_u = trainData.data:select(2,2):std()
standardize_channel(trainData.data, 2, mean_u, std_u)
-- normalize v globally:
local mean_v = trainData.data:select(2,3):mean()
local std_v = trainData.data:select(2,3):std()
standardize_channel(trainData.data, 3, mean_v, std_v)

trainData.mean_u = mean_u
trainData.std_u = std_u
trainData.mean_v = mean_v
trainData.std_v = std_v

-- preprocess testSet with the statistics computed on the training set
rgb_to_yuv_local_y(testData)
standardize_channel(testData.data, 2, mean_u, std_u)
standardize_channel(testData.data, 3, mean_v, std_v)
----------------------------------------------------------------------------------
----------------------------- END ------------------------------------------------
-- The first network module (BatchFlip) runs on FloatTensors.
trainData.data = trainData.data:float()
testData.data = testData.data:float()

-- CIFAR-10 has 10 classes.
confusion = optim.ConfusionMatrix(10)

print('Will save at '..opt.save)
paths.mkdir(opt.save)
testLogger = optim.Logger(paths.concat(opt.save, 'test.log'))
testLogger:setNames{'% mean class accuracy (train set)', '% mean class accuracy (test set)'}
testLogger.showPlot = false

parameters, gradParameters = model:getParameters()

print(c.blue'==>' ..' setting criterion')
criterion = cast(nn.CrossEntropyCriterion())

print(c.blue'==>' ..' configuring optimizer')
optimState = {
   learningRate = opt.learningRate,
   weightDecay = opt.weightDecay,
   momentum = opt.momentum,
   learningRateDecay = opt.learningRateDecay,
}

--- Runs one epoch of SGD over a random permutation of the training set.
-- Halves the learning rate every opt.epoch_step epochs, stores the training
-- accuracy (percent) in the global train_acc and advances the global epoch.
function train()
   model:training()
   epoch = epoch or 1

   -- drop learning rate every "epoch_step" epochs
   if epoch % opt.epoch_step == 0 then
      optimState.learningRate = optimState.learningRate / 2
   end

   print(c.blue '==>'.." online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. ']')

   local targets = cast(torch.FloatTensor(opt.batchSize))
   local indices = torch.randperm(trainData.data:size(1)):long():split(opt.batchSize)
   -- drop a trailing partial batch so every batch matches the size of
   -- `targets`; a full last batch is kept
   if #indices > 0 and indices[#indices]:size(1) < opt.batchSize then
      indices[#indices] = nil
   end

   local tic = torch.tic()
   for t, v in ipairs(indices) do
      xlua.progress(t, #indices)

      local inputs = trainData.data:index(1, v)
      targets:copy(trainData.labels:index(1, v))

      -- evaluated by optim.sgd: returns the loss and the gradient at x
      local feval = function(x)
         if x ~= parameters then parameters:copy(x) end
         gradParameters:zero()

         local outputs = model:forward(inputs)
         local f = criterion:forward(outputs, targets)
         local df_do = criterion:backward(outputs, targets)
         model:backward(inputs, df_do)

         confusion:batchAdd(outputs, targets)

         return f, gradParameters
      end
      optim.sgd(feval, parameters, optimState)
   end

   confusion:updateValids()
   print(('Train accuracy: '..c.cyan'%.2f'..' %%\t time: %.2f s'):format(
         confusion.totalValid * 100, torch.toc(tic)))

   train_acc = confusion.totalValid * 100

   confusion:zero()
   epoch = epoch + 1
end

--- Evaluates the model on the test set and logs the accuracy.
-- When the log plot exists, also writes <opt.save>/report.html. Saves the
-- network (module 3) every 50 epochs.
function test()
   -- disable flips, dropouts and batch normalization
   model:evaluate()
   print(c.blue '==>'.." testing")
   local bs = 125
   local n = testData.data:size(1)
   for i = 1, n, bs do
      -- the last chunk is shorter when n is not a multiple of bs
      local len = math.min(bs, n - i + 1)
      local outputs = model:forward(testData.data:narrow(1, i, len))
      confusion:batchAdd(outputs, testData.labels:narrow(1, i, len))
   end

   confusion:updateValids()
   print('Test accuracy:', confusion.totalValid * 100)

   if testLogger then
      paths.mkdir(opt.save)
      testLogger:add{train_acc, confusion.totalValid * 100}
      testLogger:style{'-','-'}
      testLogger:plot()

      if paths.filep(opt.save..'/test.log.eps') then
         local base64im
         do
            -- shells out to `convert` (presumably ImageMagick) and `openssl`
            os.execute(('convert -density 200 %s/test.log.eps %s/test.png'):format(opt.save,opt.save))
            os.execute(('openssl base64 -in %s/test.png -out %s/test.base64'):format(opt.save,opt.save))
            local f = io.open(opt.save..'/test.base64')
            if f then
               base64im = f:read'*all'
               f:close()
            end
         end

         local file = io.open(opt.save..'/report.html','w')
         if file then
            -- base64im is nil when the conversion failed; '%s' rejects nil
            -- on Lua 5.1/LuaJIT, so fall back to an empty image
            file:write(([[
<!DOCTYPE html>
<html>
<body>
<title>%s - %s</title>
<img src="data:image/png;base64,%s">
<h4>optimState:</h4>
<table>
]]):format(opt.save, epoch, base64im or ''))
            for k, v in pairs(optimState) do
               if torch.type(v) == 'number' then
                  file:write('<tr><td>'..k..'</td><td>'..v..'</td></tr>\n')
               end
            end
            file:write'</table><pre>\n'
            file:write(tostring(confusion)..'\n')
            file:write(tostring(model)..'\n')
            file:write'</pre></body></html>'
            file:close()
         end
      end
   end

   -- save model every 50 epochs (module 3 only, without the flip/copy wrappers)
   if epoch % 50 == 0 then
      local filename = paths.concat(opt.save, 'model.net')
      print('==> saving model to '..filename)
      torch.save(filename, model:get(3):clearState())
   end

   confusion:zero()
end

for i = 1, opt.max_epoch do
   train()
   test()
end
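Above is the original version of the code.
Why is it written like this?
It cannot run as-is: in this copy the numeric literals were stripped (e.g. `input:size()`, `bs/)`, `epoch % == then`), and several lines were merged, so parts of statements ended up inside `--` comments.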
Sample Classification Code of CIFAR-10 in Torch的更多相关文章
- 【翻译】TensorFlow卷积神经网络识别CIFAR 10Convolutional Neural Network (CNN)| CIFAR 10 TensorFlow
原网址:https://data-flair.training/blogs/cnn-tensorflow-cifar-10/ by DataFlair Team · Published May 21, ...
- code::blocks(版本10.05) 配置opencv2.4.3
(1)首先下载opencv2.4.3, 解压缩到D:下: (2)配置code::blocks, 具体操作如下: 第一步, 配置compiler, 操作步骤为Settings -> Compil ...
- code::blocks(版本号10.05) 配置opencv2.4.3
(1)首先下载opencv2.4.3, 解压缩到D:下: (2)配置code::blocks, 详细操作例如以下: 第一步, 配置compiler, 操作步骤为Settings -> Comp ...
- DL Practice:Cifar 10分类
Step 1:数据加载和处理 一般使用深度学习框架会经过下面几个流程: 模型定义(包括损失函数的选择)——>数据处理和加载——>训练(可能包括训练过程可视化)——>测试 所以自己写代 ...
- 【神经网络与深度学习】基于Windows+Caffe的Minst和CIFAR—10训练过程说明
Minst训练 我的路径:G:\Caffe\Caffe For Windows\examples\mnist 对于新手来说,初步完成环境的配置后,一脸茫然.不知如何跑Demo,有么有!那么接下来的教 ...
- Oracle Applications Multiple Organizations Access Control for Custom Code
档 ID 420787.1 White Paper Oracle Applications Multiple Organizations Access Control for Custom Code ...
- UWP深入学习六:Build better apps: Windows 10 by 10 development series
Promotion in the Windows Store In this article, I walk through how to Give your Store listing a mak ...
- Removing Columns 分类: 贪心 CF 2015-08-08 16:10 10人阅读 评论(0) 收藏
Removing Columns time limit per test 2 seconds memory limit per test 256 megabytes input standard in ...
- CV code references
转:http://www.sigvc.org/bbs/thread-72-1-1.html 一.特征提取Feature Extraction: SIFT [1] [Demo program][SI ...
随机推荐
- 【转】HTTP429
转载:http://codewa.com/question/45600.html Q:How to avoid HTTP error 429 (Too Many Requests) python Q: ...
- Java基础语法(三)
七.方法 定义: 方法就是完成特定功能的代码块 在很多语言里面都有函数的定义 函数在Java中被称为方法 格式: 修饰符 返回值类型 方法名(参数类型 参数名1,参数类型 参数名2…) { 函数体; ...
- UITableView 的坑
1.cell的view和contentView的区别 1.1 addSubView UITableViewCell实例上添加子视图,有两种方式:[cell addSubview:view]或[cell ...
- QTCreator 调试:unknown debugger type "No engine"
[1]QTCreator调试,应用程序输出:unknown debugger type "No engine" 如图:下断点->调试程序->应用程序输出 说明:调试器无 ...
- 字符编码几个缩写 ACR CCS CEF CES TES
摘自https://zhuanlan.zhihu.com/p/27012967 5. 在Unicode Technical Report (UTR统一码技术报告) #17<UNICODE CHA ...
- Django之真正创建一个django项目
真正创建一个django项目 1 创建Django项目 :new-project 2 创建APP : python manager.py startapp app01 3 setting 配 ...
- PHP框架CI CodeIgniter 的log_message开启日志记录方法
PHP框架CI CodeIgniter 的log_message开启日志记录方法 第一步:index.php文件,修改环境为开发环境define(‘ENVIRONMENT’, ‘development ...
- java 数组和集合
1.概念说明 区别:数组固定长度的,集合,数组的长度是可以变化的. List,继承Collection,可重复.有序的对象 Set,继承Collection,不可重复.无序的对象 Map,键值对,提供 ...
- 51Nod 1212 无向图最小生成树 (路径压缩)
N个点M条边的无向连通图,每条边有一个权值,求该图的最小生成树. Input 第1行:2个数N,M中间用空格分隔,N为点的数量,M为边的数量.(2 <= N <= 1000, 1 < ...
- 像黑客一样使用Linux命令行(转载)
阅读目录 前言 使用 tmux 复用控制台窗口 在命令行中快速移动光标 在命令行中快速删除文本 快速查看和搜索历史命令 快速引用和修饰历史命令 录制屏幕并转换为 gif 动画图片 总结 回到顶部 前言 ...