# 训练设置
# 使用GPU
caffe.set_device(gpu_id) # 若不设置,默认为0
caffe.set_mode_gpu()
# 使用CPU
caffe.set_mode_cpu() # 加载Solver,有两种常用方法
# 1. 无论模型中Slover类型是什么统一设置为SGD
solver = caffe.SGDSolver('/home/xxx/data/solver.prototxt')
# 2. 根据solver的prototxt中solver_type读取,默认为SGD
solver = caffe.get_solver('/home/xxx/data/solver.prototxt') # 训练模型
# 1.1 前向传播
solver.net.forward() # train net
solver.test_nets[0].forward() # test net (there can be more than one)
# 1.2 反向传播,计算梯度
solver.net.backward()
# 2. 进行一次前向传播一次反向传播并根据梯度更新参数
solver.step(1)
# 3. 根据solver文件中设置进行完整model训练
solver.solve()

如果想在训练过程中保存模型参数,调用

solver.net.save('mymodel.caffemodel')

caffe Python API 之Model训练的更多相关文章

  1. caffe Python API 之图片预处理

    # 设定图片的shape格式为网络data层格式 transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) ...

  2. caffe Python API 之BatchNormal

    net.bn = caffe.layers.BatchNorm( net.conv1, batch_norm_param=dict( moving_average_fraction=0.90, #滑动 ...

  3. caffe Python API 之可视化

    一.显示各层 # params显示:layer名,w,b for layer_name, param in net.params.items(): print layer_name + '\t' + ...

  4. caffe Python API 之中值转换

    # 编写一个函数,将二进制的均值转换为python的均值 def convert_mean(binMean,npyMean): blob = caffe.proto.caffe_pb2.BlobPro ...

  5. caffe Python API 之Solver定义

    from caffe.proto import caffe_pb2 s = caffe_pb2.SolverParameter() path='/home/xxx/data/' solver_file ...

  6. caffe Python API 之激活函数ReLU

    import sys import os sys.path.append("/projects/caffe-ssd/python") import caffe net = caff ...

  7. caffe Python API 之 数据输入层(Data,ImageData,HDF5Data)

    import sys sys.path.append('/projects/caffe-ssd/python') import caffe4 net = caffe.NetSpec() 一.Image ...

  8. caffe Python API 之上卷积层(Deconvolution)

    对于convolution: output = (input + 2 * p  - k)  / s + 1; 对于deconvolution: output = (input - 1) * s + k ...

  9. caffe Python API 之Inference

    #以SSD的检测测试为例 def detetion(image_dir,weight,deploy,resolution=300): caffe.set_mode_gpu() net = caffe. ...

随机推荐

  1. BZOJ 1799 同类分布(数位DP)

    给出a,b,求出[a,b]中各位数字之和能整除原数的数的个数.1<=a<=b<=1e18. 注意到各位数字之和最大是153.考虑枚举这个东西.那么需要统计的是[0,a-1]和[0,b ...

  2. hdu6021[BestCoder #93] MG loves string

    这场BC实在是有趣啊,T2是个没有什么算法但是细节坑的贪心+分类讨论乱搞,T3反而码起来很顺. 然后出现了T2过的人没有T3多的现象(T2:20人,T3:30人),而且T2的AC率是惨烈的不到3% ( ...

  3. 转---秒杀多线程第八篇 经典线程同步 信号量Semaphore

    阅读本篇之前推荐阅读以下姊妹篇: <秒杀多线程第四篇一个经典的多线程同步问题> <秒杀多线程第五篇经典线程同步关键段CS> <秒杀多线程第六篇经典线程同步事件Event& ...

  4. [BZOJ4653][NOI2016]区间 贪心+线段树

    4653: [Noi2016]区间 Time Limit: 60 Sec  Memory Limit: 256 MB Description 在数轴上有 n个闭区间 [l1,r1],[l2,r2],. ...

  5. 【BZOJ5306】[HAOI2018]染色(NTT)

    [BZOJ5306]染色(NTT) 题面 BZOJ 洛谷 题解 我们只需要考虑每一个\(W[i]\)的贡献就好了 令\(lim=min(M,\frac{N}{S})\) 那么,开始考虑每一个\(W[i ...

  6. 洛谷 3201 [HNOI2009]梦幻布丁 解题报告

    3201 [HNOI2009]梦幻布丁 题目描述 \(N\)个布丁摆成一行,进行\(M\)次操作.每次将某个颜色的布丁全部变成另一种颜色的,然后再询问当前一共有多少段颜色.例如颜色分别为\(1,2,2 ...

  7. codeforces765F Souvenirs

    本文版权归ljh2000和博客园共有,欢迎转载,但须保留此声明,并给出原文链接,谢谢合作. 本文作者:ljh2000 作者博客:http://www.cnblogs.com/ljh2000-jump/ ...

  8. Java之面向对象编程20170619

    /*************************************************************************************************** ...

  9. php前后端分离项目跨域问题解决办法

    由于之前一直没有做过前后端分离项目,导致走了不少弯路,而且还采用了一种及其不优雅的方法 (在第一次请求的时候把服务器返回的session id保存起来,后续请求的时候把该session id作为参数传 ...

  10. vmvare安装ubuntu后

    配置源: http://wiki.ubuntu.org.cn/%E6%BA%90%E5%88%97%E8%A1%A8#Trusty.2814.04.29.E7.89.88.E6.9C.AC 清理工作: ...