# 训练设置
# 使用GPU
caffe.set_device(gpu_id) # 若不设置,默认为0
caffe.set_mode_gpu()
# 使用CPU
caffe.set_mode_cpu() # 加载Solver,有两种常用方法
# 1. 无论模型中Slover类型是什么统一设置为SGD
solver = caffe.SGDSolver('/home/xxx/data/solver.prototxt')
# 2. 根据solver的prototxt中solver_type读取,默认为SGD
solver = caffe.get_solver('/home/xxx/data/solver.prototxt') # 训练模型
# 1.1 前向传播
solver.net.forward() # train net
solver.test_nets[0].forward() # test net (there can be more than one)
# 1.2 反向传播,计算梯度
solver.net.backward()
# 2. 进行一次前向传播一次反向传播并根据梯度更新参数
solver.step(1)
# 3. 根据solver文件中设置进行完整model训练
solver.solve()

如果想在训练过程中保存模型参数,调用

solver.net.save('mymodel.caffemodel')

caffe Python API 之Model训练的更多相关文章

  1. caffe Python API 之图片预处理

    # 设定图片的shape格式为网络data层格式 transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) ...

  2. caffe Python API 之BatchNormal

    net.bn = caffe.layers.BatchNorm( net.conv1, batch_norm_param=dict( moving_average_fraction=0.90, #滑动 ...

  3. caffe Python API 之可视化

    一.显示各层 # params显示:layer名,w,b for layer_name, param in net.params.items(): print layer_name + '\t' + ...

  4. caffe Python API 之中值转换

    # 编写一个函数,将二进制的均值转换为python的均值 def convert_mean(binMean,npyMean): blob = caffe.proto.caffe_pb2.BlobPro ...

  5. caffe Python API 之Solver定义

    from caffe.proto import caffe_pb2 s = caffe_pb2.SolverParameter() path='/home/xxx/data/' solver_file ...

  6. caffe Python API 之激活函数ReLU

    import sys import os sys.path.append("/projects/caffe-ssd/python") import caffe net = caff ...

  7. caffe Python API 之 数据输入层(Data,ImageData,HDF5Data)

    import sys sys.path.append('/projects/caffe-ssd/python') import caffe4 net = caffe.NetSpec() 一.Image ...

  8. caffe Python API 之上卷积层(Deconvolution)

    对于convolution: output = (input + 2 * p  - k)  / s + 1; 对于deconvolution: output = (input - 1) * s + k ...

  9. caffe Python API 之Inference

    #以SSD的检测测试为例 def detetion(image_dir,weight,deploy,resolution=300): caffe.set_mode_gpu() net = caffe. ...

随机推荐

  1. 题解 P1888 【三角函数】

    堆排序万岁! 小金羊又来水题了 #include <iostream> #include <queue> using namespace std; priority_queue ...

  2. Qt 多线程同步与通信

    Qt 多线程同步与通信 1 多线程同步 Qt提供了以下几个类来完成这一点:QMutex.QMutexLocker.QSemphore.QWaitCondition. 当然可能还包含QReadWrite ...

  3. hadoop(二)hadoop集群的搭建

    一.集群环境准备工作 1.修改主机名 在root 账户下 vi /etc/sysconfig/network   或者 sudo vi /etc/sysconfig/network 2.设置系统默认启 ...

  4. 简单的函数——Min_25筛

    %%yyb %%zsy 就是实现一下Min-25筛 筛积性函数的操作 首先要得到 $G(M,j)=\sum_{t=j}^{cnt} \sum_{e=1}^{p_t^{e+1}<=M} [\phi ...

  5. 【bzoj3230】相似子串

    Portal -->bzoj3230 Description 给你一个长度为\(n\)的字符串,把它的所有本质不同的子串按字典序大小排序,有\(m\)个询问,对于每一个询问\(x,y\)你需要回 ...

  6. IT(然而其实是。。hdu5244?)

    Time Limit: 3000 ms Memory Limit: 256 MB Description IT = Inverse Transform 两个长度为 \(2^n\) 的序列 \(a,b\ ...

  7. 音视频处理之H264编码标准20170906

    一. H264基础概念 1.名词解释 场和帧 :    视频的一场或一帧可用来产生一个编码图像.在电视中,为减少大面积闪烁现象,把一帧分成两个隔行的场. 片:             每个图象中,若干 ...

  8. bzoj2962 序列操作

    2962: 序列操作 Time Limit: 50 Sec  Memory Limit: 256 MBSubmit: 1145  Solved: 378[Submit][Status][Discuss ...

  9. bzoj 1513 POI2006 Tet-Tetris 3D 二维线段树+标记永久化

    1511: [POI2006]OKR-Periods of Words Time Limit: 5 Sec  Memory Limit: 64 MBSubmit: 351  Solved: 220[S ...

  10. C#中调用Dll动态链接库

    C#中调用Dll动态链接库 起始 受限于语言的不同,我们有的时候可能会用别人提供的函数及方法 或者其他的什么原因.反正就是要调!!! 恰巧别人所使用的的语言跟自己又不是一样的 这个时候想要调用别人的函 ...