pytorch中调用C进行扩展,使得某些功能在CPU上运行更快;

第一步:编写头文件

/* src/my_lib.h */
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output);
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);

第二步:编写源文件

/* src/my_lib.c */
#include <TH/TH.h> int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
THFloatTensor *output)
{
if (!THFloatTensor_isSameSizeAs(input1, input2))
return ;
THFloatTensor_resizeAs(output, input1);
THFloatTensor_cadd(output, input1, 1.0, input2);
return ;
} int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
{
THFloatTensor_resizeAs(grad_input, grad_output);
THFloatTensor_fill(grad_input, );
return ;
}

注意:头文件TH就是pytorch底层代码的接口头文件,它是CPU模式,GPU下则为THC;

第三步:在同级目录下创建一个.py文件(比如叫“build.py”)

该文件用于对该C扩展模块进行编译(使用torch.util.ffi模块进行扩展编译);

# build.py
from torch.utils.ffi import create_extension
ffi = create_extension(
name='_ext.my_lib', # 输出文件地址及名称
headers='src/my_lib.h', # 编译.h文件地址及名称
sources=['src/my_lib.c'], # 编译.c文件地址及名称
with_cuda=False # 不使用cuda
)
ffi.build()

第四步:编写.py脚本调用编译好的C扩展模块

import torch
from torch.autograd import Function
from _ext import my_lib
import torch.nn as nn class MyAddFunction(Function):
def forward(self, input1, input2):
output = torch.FloatTensor()
my_lib.my_lib_add_forward(input1, input2, output)
return output def backward(self, grad_output):
grad_input = torch.FloatTensor()
my_lib.my_lib_add_backward(grad_input, grad_output)
return grad_input class MyAddModule(nn.Module):
def forward(self, input1, input2):
return MyAddFunction()(input1, input2) class MyNetWork(nn.Module):
def __init__(self):
super(MyNetWork, self).__init__()
self.add = MyAddModule() def forward(self, input1, input2):
return self.add(input1, input2) model = MyNetWork()
input1, input2 = torch.randn(5, 5), torch.randn(5, 5)
print(model(input1, input2))
print(input1 + input2)

至此,用这个简单的例子抛砖引玉~

pytorch中调用C进行扩展的更多相关文章

  1. tp中调用PHP系统扩展类

    例如使用Redis扩展类: use Reids; $redis = new Redis();

  2. PyTorch中的C++扩展

    今天要聊聊用 PyTorch 进行 C++ 扩展. 在正式开始前,我们需要了解 PyTorch 如何自定义module.这其中,最常见的就是在 python 中继承torch.nn.Module,用 ...

  3. iOS 中 h5 页面 iframe 调用高度自扩展问题及解决

    开发需求需要在 h5 中用 iframe 中调用一个其他公司开发的 html 页面. 简单的插入 <iframe /> 并设置宽高后,发现在 Android 手机浏览器上打开可以正常运行, ...

  4. C#中如果类的扩展方法和类本身的方法签名相同,那么会优先调用类本身的方法

    新建一个.NET Core项目,假如我们有如下代码: using System; namespace MethodOverload { static class DemoExtension { pub ...

  5. pytorch中使用cuda扩展

    以下面这个例子作为教程,实现功能是element-wise add: (pytorch中想调用cuda模块,还是用另外使用C编写接口脚本) 第一步:cuda编程的源文件和头文件 // mathutil ...

  6. Unity中调用Windows窗口句柄以及根据需求设置并且解决扩展屏窗体显示错乱/位置错误的Bug

    问题背景: 现在在搞PC端应用开发,我们开发中需要调用系统的窗口以及需要最大化最小化,缩放窗口拖拽窗口,以及设置窗口位置,去边框等功能 解决根据: 使用user32.dll解决 具体功能: Unity ...

  7. Pytorch中RoI pooling layer的几种实现

    Faster-RCNN论文中在RoI-Head网络中,将128个RoI区域对应的feature map进行截取,而后利用RoI pooling层输出7*7大小的feature map.在pytorch ...

  8. WebApi接口 - 如何在应用中调用webapi接口

    很高兴能再次和大家分享webapi接口的相关文章,本篇将要讲解的是如何在应用中调用webapi接口:对于大部分做内部管理系统及类似系统的朋友来说很少会去调用别人的接口,因此可能在这方面存在一些困惑,希 ...

  9. Mybatis中SqlMapper配置的扩展与应用(1)

    奋斗了好几个晚上调试程序,写了好几篇博客,终于建立起了Mybatis配置的扩展机制.虽然扩展机制是重要的,然而如果没有真正实用的扩展功能,那也至少是不那么鼓舞人心的,这篇博客就来举几个扩展的例子. 这 ...

随机推荐

  1. WebKit应用程序开发的起因

    公司的Web管理界面,用起来还可以,但是有一个问题一直无法解决. 在web页面上需要播放视频,由于比较这个功能比较老,不支持web模式播放,只支持CS模式,具体原因及不说了. 于是有了 winform ...

  2. children(),find()

    向下遍历 DOM 树 下面是两个用于向下遍历 DOM 树的 jQuery 方法: children() find() jQuery children() 方法 children() 方法返回被选元素的 ...

  3. Dynamics CRM 数据数量限制更改

    1.在CRM2016中如果想要导出超过10000记录数据,更新 MaxRecordsForExportToExcel  这个字段的值. SELECT MaxRecordsForExportToExce ...

  4. flask + websocket实现简单的单聊和群聊

    单聊 from flask import Flask,request,render_template from geventwebsocket.handler import WebSocketHand ...

  5. [51 Nod 1584] 加权约数和

    题意 求∑i=1N∑j=1Nmax(i,j)⋅σ1(ij)\large \sum_{i=1}^N\sum_{j=1}^Nmax(i,j)\cdot\sigma_1(ij)i=1∑N​j=1∑N​max ...

  6. Tips on Acoustic Signal Processing

    1.声音的三个主要的主观属性(即音量.音调.音色).音色(Timbre)是指不同的声音的频率表现在波形方面总是有与众不同的特性,音色的不同取决于不同的泛音.频率的高低决定声音的音调,振幅的大小决定声音 ...

  7. 洛谷 P4281 [AHOI2008] 紧急集合 题解

    挺好的一道题,本身不难,就把求两个点的LCA变为求三个点两两求LCA,不重合的点才是最优解.值得一提的是,最后对答案的处理运用差分的思想:假设两点 一点深度为d1,另一点 深度为d2,它们LCA深度为 ...

  8. WinDbg常用命令系列---|(进程状态)

    |(进程状态) 简介 (|) 命令显示指定进程的状态或当前正在调试你的所有进程. 使用形式 | Process 参数 Process 指定要显示的进程. 如果省略此参数,将显示所有正在调试的进程. 支 ...

  9. WinDbg常用命令系列---显示引用的内存(dda、ddp、ddu、dpa、dpp、dpu、dqa、dqp、dqu)

    命令dda, ddp, ddu, dpa, dpp, dpu, dqa, dqp, 和 dqu在指定位置显示指针,取消对该指针的引用,然后以各种格式显示结果位置的内存. ddp [Options] [ ...

  10. Theano入门笔记2:scan函数等

    1.Theano中的scan函数 目前先弱弱的认为:相当于symbolic的for循环吧,或者说计算图上的for循环,也可以用来替代repeat-until. 与scan相比,scan_checkpo ...