# 训练设置
# 使用GPU
caffe.set_device(gpu_id) # 若不设置,默认为0
caffe.set_mode_gpu()
# 使用CPU
caffe.set_mode_cpu() # 加载Solver,有两种常用方法
# 1. 无论模型中Slover类型是什么统一设置为SGD
solver = caffe.SGDSolver('/home/xxx/data/solver.prototxt')
# 2. 根据solver的prototxt中solver_type读取,默认为SGD
solver = caffe.get_solver('/home/xxx/data/solver.prototxt') # 训练模型
# 1.1 前向传播
solver.net.forward() # train net
solver.test_nets[0].forward() # test net (there can be more than one)
# 1.2 反向传播,计算梯度
solver.net.backward()
# 2. 进行一次前向传播一次反向传播并根据梯度更新参数
solver.step(1)
# 3. 根据solver文件中设置进行完整model训练
solver.solve()

如果想在训练过程中保存模型参数,调用

solver.net.save('mymodel.caffemodel')

caffe Python API 之Model训练的更多相关文章

  1. caffe Python API 之图片预处理

    # 设定图片的shape格式为网络data层格式 transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) ...

  2. caffe Python API 之BatchNormal

    net.bn = caffe.layers.BatchNorm( net.conv1, batch_norm_param=dict( moving_average_fraction=0.90, #滑动 ...

  3. caffe Python API 之可视化

    一.显示各层 # params显示:layer名,w,b for layer_name, param in net.params.items(): print layer_name + '\t' + ...

  4. caffe Python API 之中值转换

    # 编写一个函数,将二进制的均值转换为python的均值 def convert_mean(binMean,npyMean): blob = caffe.proto.caffe_pb2.BlobPro ...

  5. caffe Python API 之Solver定义

    from caffe.proto import caffe_pb2 s = caffe_pb2.SolverParameter() path='/home/xxx/data/' solver_file ...

  6. caffe Python API 之激活函数ReLU

    import sys import os sys.path.append("/projects/caffe-ssd/python") import caffe net = caff ...

  7. caffe Python API 之 数据输入层(Data,ImageData,HDF5Data)

    import sys sys.path.append('/projects/caffe-ssd/python') import caffe4 net = caffe.NetSpec() 一.Image ...

  8. caffe Python API 之上卷积层(Deconvolution)

    对于convolution: output = (input + 2 * p  - k)  / s + 1; 对于deconvolution: output = (input - 1) * s + k ...

  9. caffe Python API 之Inference

    #以SSD的检测测试为例 def detetion(image_dir,weight,deploy,resolution=300): caffe.set_mode_gpu() net = caffe. ...

随机推荐

  1. 【bzoj2121】字符串游戏 区间dp

    题目描述 给你一个字符串L和一个字符串集合S,如果S的某个子串在S集合中,那么可以将其删去,剩余的部分拼到一起成为新的L串.问:最后剩下的串长度的最小值. 输入 输入的第一行包含一个字符串,表示L. ...

  2. BZOJ5074 小B的数字

    对bi取log,则相当于Σbi<=min{bi*ai}.注意到值域很小,那么如果有解,使其成立的最小的Σbi不会很大,大胆猜想不超过Σai.然而一点也不会(xiang)证.暴力枚举就好了. #i ...

  3. React Native工程中TSLint静态检查工具的探索之路

    建立的代码规范没人遵守,项目中遍地风格迥异的代码,你会不会抓狂? 通过测试用例的程序还会出现Bug,而原因仅仅是自己犯下的低级错误,你会不会抓狂? 某种代码写法存在问题导致崩溃时,只能全工程检查代码, ...

  4. 输入三个数a,b,n,输出a和b不大于n的公倍数的个数

    题:输入三个数a,b,n,输出a和b不大于n的公倍数的所有个数. 这题的思想是先求得a和b的最大公约数,然后用a和b的积除以最大公约数,得到最小公倍数,再持续加上最小公倍数,直到超过n,记下n的个数. ...

  5. php安装gd库

    安装gd需要以下库: gd-2.0.33.tar.gz http://www.boutell.com/gd/ jpegsrc.v6b.tar.gz http://www.ijg.org/ libpng ...

  6. bzoj2314: 士兵的放置(树形DP)

    0表示被父亲控制,1表示被儿子控制,2表示被自己控制.f表示最少士兵数,g表示方案数. 转移贼难写,写了好久之后写不下去了,看了一眼题解,学习了...原来还可以这么搞 比如求f[i][1]的时候,要在 ...

  7. python学习(十七) 爬取MM图片

    这一篇巩固前几篇文章的学到的技术,利用urllib库爬取美女图片,其中采用了多线程,文件读写,目录匹配,正则表达式解析,字符串拼接等知识,这些都是前文提到的,综合运用一下,写个爬虫示例爬取美女图片.先 ...

  8. Splay 区间操作(二)

    首先基本操作如下: 删除第rank个点 void Remove(int id){//删除第rank个点 rank++; int x = find(root, rank - 1); splay(x, 0 ...

  9. JS获取URL中参数值(QueryString)的4种方法分享

    方法一:正则法 function getQueryString(name) { var reg = new RegExp('(^|&)' + name + '=([^&]*)(& ...

  10. Tensorflow BatchNormalization详解:4_使用tf.nn.batch_normalization函数实现Batch Normalization操作

    使用tf.nn.batch_normalization函数实现Batch Normalization操作 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearnin ...