# 训练设置
# 使用GPU
caffe.set_device(gpu_id) # 若不设置,默认为0
caffe.set_mode_gpu()
# 使用CPU
caffe.set_mode_cpu() # 加载Solver,有两种常用方法
# 1. 无论模型中Slover类型是什么统一设置为SGD
solver = caffe.SGDSolver('/home/xxx/data/solver.prototxt')
# 2. 根据solver的prototxt中solver_type读取,默认为SGD
solver = caffe.get_solver('/home/xxx/data/solver.prototxt') # 训练模型
# 1.1 前向传播
solver.net.forward() # train net
solver.test_nets[0].forward() # test net (there can be more than one)
# 1.2 反向传播,计算梯度
solver.net.backward()
# 2. 进行一次前向传播一次反向传播并根据梯度更新参数
solver.step(1)
# 3. 根据solver文件中设置进行完整model训练
solver.solve()

如果想在训练过程中保存模型参数,调用

solver.net.save('mymodel.caffemodel')

caffe Python API 之Model训练的更多相关文章

  1. caffe Python API 之图片预处理

    # 设定图片的shape格式为网络data层格式 transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) ...

  2. caffe Python API 之BatchNormal

    net.bn = caffe.layers.BatchNorm( net.conv1, batch_norm_param=dict( moving_average_fraction=0.90, #滑动 ...

  3. caffe Python API 之可视化

    一.显示各层 # params显示:layer名,w,b for layer_name, param in net.params.items(): print layer_name + '\t' + ...

  4. caffe Python API 之中值转换

    # 编写一个函数,将二进制的均值转换为python的均值 def convert_mean(binMean,npyMean): blob = caffe.proto.caffe_pb2.BlobPro ...

  5. caffe Python API 之Solver定义

    from caffe.proto import caffe_pb2 s = caffe_pb2.SolverParameter() path='/home/xxx/data/' solver_file ...

  6. caffe Python API 之激活函数ReLU

    import sys import os sys.path.append("/projects/caffe-ssd/python") import caffe net = caff ...

  7. caffe Python API 之 数据输入层(Data,ImageData,HDF5Data)

    import sys sys.path.append('/projects/caffe-ssd/python') import caffe4 net = caffe.NetSpec() 一.Image ...

  8. caffe Python API 之上卷积层(Deconvolution)

    对于convolution: output = (input + 2 * p  - k)  / s + 1; 对于deconvolution: output = (input - 1) * s + k ...

  9. caffe Python API 之Inference

    #以SSD的检测测试为例 def detetion(image_dir,weight,deploy,resolution=300): caffe.set_mode_gpu() net = caffe. ...

随机推荐

  1. 最大流Dinic算法模板(pascal)

    program rrr(input,output); const inf=; type pointer=^nodetype; nodetype=record t,c:longint; next,rev ...

  2. 【bzoj3992】[SDOI2015]序列统计 原根+NTT

    题目描述 求长度为 $n$ 的序列,每个数都是 $|S|$ 中的某一个,所有数的乘积模 $m$ 等于 $x$ 的序列数目模1004535809的值. 输入 一行,四个整数,N.M.x.|S|,其中|S ...

  3. jsp - redirect重定向 / forward转发

    redirect:请求重定向: 客户端行为,本质上为2次请求,地址栏改变,前一次请求对象不保存, 所以请求携带的数据会丢失. 举例:你去银行办事(forward.jsp),结果告诉你少带了东西,你得先 ...

  4. UVA - 11997(巧妙的优先队列)

    题意: 有k个整数数组,各包含k个元素,在每个数组中取一个元素加起来,可以得到kk个和,求这些和中最小的k个值 解析: 从简单的情况开始分析:经典方法,对原题没有思路,那么分析问题的简化版 这是对于两 ...

  5. Codeforces Round #406 (Div. 2)滚粗记

    A 一看到题,不是一道解不定方程的裸题吗,调了好久exgcd. 其实一个for就好了啊 B 一直WA ON TEST 7真是烦,一想会不会是编号太大了,又写了一个map版本,无用. 调了好久好久才发现 ...

  6. 【生成树,堆】【CF1095F】 Make It Connected

    Description 给定 \(n\) 个点,每个点有点权,连结两个点花费的代价为两点的点权和.另外有 \(m\) 条特殊边,参数为 \(x,y,z\).意为如果你选择这条边,就可以花费 \(z\) ...

  7. activiti学习-用户与用户组

    activiti学习笔记3-用户与用户组 2015年05月07日 14:43:06 cq1982 阅读数:4142更多 个人分类: activiti工作流引擎   (本博客都是纯文本手工代码,错误难免 ...

  8. Codeforces 894.D Ralph And His Tour in Binary Country

    D. Ralph And His Tour in Binary Country time limit per test 2.5 seconds memory limit per test 512 me ...

  9. Netfilter之连接跟踪实现机制初步分析

    Netfilter之连接跟踪实现机制初步分析 原文: http://blog.chinaunix.net/uid-22227409-id-2656910.html 什么是连接跟踪 连接跟踪(CONNT ...

  10. Libevent学习笔记(五) 根据例子学习bufferevent

    libevent中提供了一个Hello-world.c 的例子,从这个例子可以学习libevent是如何使用bufferevent的. 这个例子在Sample中 这个例子之前讲解过,这次主要看下buf ...