Nvidia Tensor Core-WMMA API编程入门
1 WMMA (Warp-level Matrix Multiply Accumulate) API
template<typename Use, int m, int n, int k, typename T, typename Layout=void> class fragment; void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm, layout_t layout);
void store_matrix_sync(T* mptr, const fragment<...> &a, unsigned ldm, layout_t layout);
void fill_fragment(fragment<...> &a, const T& v);
void mma_sync(fragment<...> &d, const fragment<...> &a, const fragment<...> &b, const fragment<...> &c, bool satf=false);
- fragment:Tensor Core数据存储类,支持matrix_a、matrix_b和accumulator
- load_matrix_sync:Tensor Core数据加载API,支持将矩阵数据从global memory或shared memory加载到fragment
- store_matrix_sync:Tensor Core结果存储API,支持将计算结果从fragment存储到global memory或shared memory
- fill_fragment:fragment填充API,支持常数值填充
- mma_sync:Tensor Core矩阵乘计算API,支持D = AB + C或者C = AB + C
2 示例
2.1 CUDA Core
#define DIV_CEIL(x, y) (((x) + (y) - 1) / (y)) __global__ void naiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M,
size_t N, size_t K) {
size_t row = threadIdx.x + blockDim.x * blockIdx.x;
size_t col = threadIdx.y + blockDim.y * blockIdx.y;
if (row < M && col < N) {
half tmp = 0.0;
for (size_t i = 0; i < K; ++i) {
tmp += A[row * K + i] * B[i + col * K];
}
C[row * N + col] = tmp;
}
} void hgemmNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
dim3 block(16, 16);
dim3 grid(DIV_CEIL(M, block.x), DIV_CEIL(N, block.y)); naiveKernel<<<grid, block>>>(A, B, C, M, N, K);
}
2.2 Tensor Core
#include <mma.h> #define WARP_SIZE 32 #define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16 using namespace nvcuda; __global__ void wmmaNaiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M,
size_t N, size_t K) {
size_t warpM = (blockIdx.x * blockDim.x + threadIdx.x) / WARP_SIZE;
size_t warpN = (blockIdx.y * blockDim.y + threadIdx.y); wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> c_frag; wmma::fill_fragment(c_frag, 0.0f); for (size_t i = 0; i < K; i += WMMA_K) {
size_t aCol = i;
size_t aRow = warpM * WMMA_M;
size_t bCol = warpN * WMMA_N;
size_t bRow = i; if (aRow < M && aCol < K && bRow < K && bCol < N) {
wmma::load_matrix_sync(a_frag, A + aCol + aRow * K, K);
wmma::load_matrix_sync(b_frag, B + bRow + bCol * K, K); wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
} size_t cCol = warpN * WMMA_N;
size_t cRow = warpM * WMMA_M; if (cRow < M && cCol < N) {
wmma::store_matrix_sync(C + cCol + cRow * N, c_frag, N, wmma::mem_row_major);
}
} void hgemmWmmaNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
dim3 block(128, 4);
dim3 grid((M - 1) / (WMMA_M * block.x / WARP_SIZE) + 1, (N - 1) / (WMMA_N * block.y) + 1); wmmaNaiveKernel<<<grid, block>>>(A, B, C, M, N, K);
}
2.3 区别
- 计算层级:CUDA Core是线程级别,Tensor Core是warp级别
- 计算维度:CUDA Core是一维逐点计算,Tensor Core是二维逐tile计算
- 计算依赖:WMMA调用Tensor Core需要借助数据存储类fragment,CUDA Core不需要借助其他
3 底层代码
3.1 PTX
.visible .entry _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm(
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5
)
{
.reg .pred %p<8>;
.reg .b16 %rs<2>;
.reg .f32 %f<2>;
.reg .b32 %r<58>;
.reg .b64 %rd<28>; ld.param.u64 %rd9, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0];
ld.param.u64 %rd10, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1];
ld.param.u64 %rd11, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2];
ld.param.u64 %rd14, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3];
ld.param.u64 %rd12, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4];
ld.param.u64 %rd13, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5];
mov.u32 %r19, %ntid.x;
mov.u32 %r20, %ctaid.x;
mov.u32 %r21, %tid.x;
mad.lo.s32 %r22, %r20, %r19, %r21;
mov.u32 %r23, %ntid.y;
mov.u32 %r24, %ctaid.y;
mov.u32 %r25, %tid.y;
mad.lo.s32 %r26, %r24, %r23, %r25;
mov.f32 %f1, 0f00000000; { cvt.rn.f16.f32 %rs1, %f1;} mov.b32 %r50, {%rs1, %rs1};
mul.wide.u32 %rd1, %r26, 16;
shr.u32 %r27, %r22, 1;
and.b32 %r28, %r27, 2147483632;
cvt.u64.u32 %rd2, %r28;
setp.lt.u64 %p2, %rd2, %rd14;
setp.lt.u64 %p3, %rd1, %rd12;
and.pred %p1, %p2, %p3;
setp.eq.s64 %p4, %rd13, 0;
mov.u32 %r51, %r50;
mov.u32 %r52, %r50;
mov.u32 %r53, %r50;
@%p4 bra $L__BB0_5; mul.lo.s64 %rd3, %rd2, %rd13;
cvt.u32.u64 %r2, %rd13;
mul.lo.s64 %rd4, %rd1, %rd13;
cvta.to.global.u64 %rd5, %rd10;
cvta.to.global.u64 %rd6, %rd9;
mov.u64 %rd27, 0;
not.pred %p5, %p1;
mov.u32 %r51, %r50;
mov.u32 %r52, %r50;
mov.u32 %r53, %r50; $L__BB0_2:
@%p5 bra $L__BB0_4; add.s64 %rd16, %rd27, %rd3;
shl.b64 %rd17, %rd16, 1;
add.s64 %rd18, %rd6, %rd17;
wmma.load.a.sync.aligned.row.m16n16k16.global.f16 {%r29, %r30, %r31, %r32, %r33, %r34, %r35, %r36}, [%rd18], %r2;
add.s64 %rd19, %rd27, %rd4;
shl.b64 %rd20, %rd19, 1;
add.s64 %rd21, %rd5, %rd20;
wmma.load.b.sync.aligned.col.m16n16k16.global.f16 {%r37, %r38, %r39, %r40, %r41, %r42, %r43, %r44}, [%rd21], %r2;
wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 {%r53, %r52, %r51, %r50}, {%r29, %r30, %r31, %r32, %r33, %r34, %r35, %r36}, {%r37, %r38, %r39, %r40, %r41, %r42, %r43, %r44}, {%r53, %r52, %r51, %r50}; $L__BB0_4:
add.s64 %rd27, %rd27, 16;
setp.lt.u64 %p6, %rd27, %rd13;
@%p6 bra $L__BB0_2; $L__BB0_5:
not.pred %p7, %p1;
@%p7 bra $L__BB0_7; mul.lo.s64 %rd22, %rd2, %rd12;
add.s64 %rd23, %rd22, %rd1;
cvta.to.global.u64 %rd24, %rd11;
shl.b64 %rd25, %rd23, 1;
add.s64 %rd26, %rd24, %rd25;
cvt.u32.u64 %r45, %rd12;
wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%rd26], {%r53, %r52, %r51, %r50}, %r45; $L__BB0_7:
ret; }
不过我们主要关注WMMA相关的PTX指令,如下所示。可以看到这里正是Nvidia提供的WMMA PTX指令来调用Tensor Core,所以无论是使用WMMA API编程,还是使用WMMA PTX指令编程,底层差别不会太大。
wmma.load.a.sync.aligned.row.m16n16k16.global.f16
wmma.load.b.sync.aligned.col.m16n16k16.global.f16
wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16
wmma.store.d.sync.aligned.row.m16n16k16.global.f16
3.2 SASS
IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28]
S2R R0, SR_CTAID.X
ISETP.NE.U32.AND P2, PT, RZ, c[0x0][0x188], PT
ULDC.64 UR4, c[0x0][0x118]
CS2R R8, SRZ
S2R R10, SR_CTAID.Y
ISETP.NE.AND.EX P2, PT, RZ, c[0x0][0x18c], PT, P2
S2R R5, SR_TID.Y
S2R R3, SR_TID.X
IMAD R10, R10, c[0x0][0x4], R5
IMAD R0, R0, c[0x0][0x0], R3
IMAD.WIDE.U32 R10, R10, 0x10, RZ
CS2R R2, SRZ
SHF.R.U32.HI R0, RZ, 0x1, R0
ISETP.GE.U32.AND P0, PT, R10, c[0x0][0x180], PT
LOP3.LUT R13, R0, 0x7ffffff0, RZ, 0xc0, !PT
ISETP.GE.U32.AND.EX P0, PT, R11, c[0x0][0x184], PT, P0
ISETP.LT.U32.AND P1, PT, R13, c[0x0][0x178], PT
ISETP.LT.U32.AND.EX P0, PT, RZ, c[0x0][0x17c], !P0, P1
@!P2 BRA 0x7f1eaefc0160
BSSY B0, 0x7f1eaefc0160
IMAD.MOV.U32 R0, RZ, RZ, RZ
CS2R R8, SRZ
IMAD.MOV.U32 R15, RZ, RZ, RZ
IMAD.MOV.U32 R2, RZ, RZ, RZ
BSSY B1, 0x7f1eaefc0100
@!P0 BRA 0x7f1eaefc00f0
S2R R16, SR_LANEID
IMAD R17, R11, c[0x0][0x188], RZ
IMAD.MOV.U32 R14, RZ, RZ, R0
IMAD.MOV.U32 R23, RZ, RZ, c[0x0][0x188]
IMAD.WIDE.U32 R6, R10, c[0x0][0x188], R14
SHF.R.U32.HI R12, RZ, 0x1, R23
IMAD R17, R10, c[0x0][0x18c], R17
LEA R21, P2, R6, c[0x0][0x168], 0x1
IMAD.WIDE.U32 R4, R13, c[0x0][0x188], R14
IMAD.IADD R7, R7, 0x1, R17
IMAD.MOV.U32 R17, RZ, RZ, RZ
IMAD R5, R13, c[0x0][0x18c], R5
LEA.HI.X R7, R6, c[0x0][0x16c], R7, 0x1, P2
SHF.R.U32.HI R19, RZ, 0x2, R16
LOP3.LUT R16, R16, 0x3, RZ, 0xc0, !PT
IMAD.WIDE.U32 R16, R19, R12, R16
LEA R19, P1, R4, c[0x0][0x160], 0x1
LEA.HI.X R5, R4, c[0x0][0x164], R5, 0x1, P1
LEA R18, P1, R16, R19, 0x2
LEA R20, P2, R16, R21, 0x2
LEA.HI.X R19, R16, R5, R17, 0x2, P1
LEA.HI.X R21, R16, R7, R17, 0x2, P2
IMAD.WIDE.U32 R16, R23, 0x10, R18
LDG.E R4, [R18.64]
IMAD.WIDE.U32 R22, R23, 0x10, R20
LDG.E R24, [R20.64]
LDG.E R25, [R20.64+0x10]
LDG.E R6, [R18.64+0x10]
LDG.E R5, [R16.64]
LDG.E R7, [R16.64+0x10]
LDG.E R26, [R22.64]
LDG.E R27, [R22.64+0x10]
WARPSYNC 0xffffffff
HMMA.16816.F16 R8, R4, R24, R8
HMMA.16816.F16 R2, R4, R26, R2
NOP
BSYNC B1
IADD3 R0, P1, R0, 0x10, RZ
IMAD.X R15, RZ, RZ, R15, P1
ISETP.GE.U32.AND P1, PT, R0, c[0x0][0x188], PT
ISETP.GE.U32.AND.EX P1, PT, R15, c[0x0][0x18c], PT, P1
@!P1 BRA 0x7f1eaefbfe90
BSYNC B0
@!P0 EXIT
S2R R4, SR_LANEID
IMAD.MOV.U32 R15, RZ, RZ, c[0x0][0x180]
WARPSYNC 0xffffffff
IMAD.WIDE.U32 R10, R13, c[0x0][0x180], R10
SHF.R.U32.HI R15, RZ, 0x1, R15
IMAD.MOV.U32 R5, RZ, RZ, RZ
LEA R7, P0, R10, c[0x0][0x170], 0x1
IMAD R11, R13, c[0x0][0x184], R11
LEA.HI.X R11, R10, c[0x0][0x174], R11, 0x1, P0
SHF.R.U32.HI R0, RZ, 0x2, R4
LOP3.LUT R4, R4, 0x3, RZ, 0xc0, !PT
IMAD.WIDE.U32 R4, R0, R15, R4
LEA R6, P0, R4, R7, 0x2
LEA.HI.X R7, R4, R11, R5, 0x2, P0
IMAD.WIDE.U32 R4, R15, 0x20, R6
STG.E [R6.64], R8
STG.E [R4.64], R9
STG.E [R6.64+0x10], R2
STG.E [R4.64+0x10], R3
EXIT
BRA 0x7f1eaefc02b0
NOP
NOP
NOP
NOP
NOP
NOP
NOP
NOP
NOP
NOP
NOP
NOP
我们依然主要关注WMMA相关的SASS指令,如下所示。可以发现WMMA161616在底层是通过两个HMMA16816指令实现,同样地,SASS指令也是Nvidia提供的另一种调用Tensor Core的编程方法。
HMMA.16816.F16
4 其他
4.1 HGEMM优化
Nvidia Tensor Core-WMMA API编程入门的更多相关文章
- Mysql C语言API编程入门讲解
原文:Mysql C语言API编程入门讲解 软件开发中我们经常要访问数据库,存取数据,之前已经有网友提出让鸡啄米讲讲数据库编程的知识,本文就详细讲解如何使用Mysql的C语言API进行数据库编程. ...
- Windows API 编程入门
Windows 工作原理的中心思想就是“动态链接”概念.Windows 自身带有一大套函数,应用程序就是通过调用这些函数 来实现它的用户界面和在屏幕上显示文本和图形的.这些函数都是在动态链接库里实现的 ...
- NVIDIA Tensor Cores解析
NVIDIA Tensor Cores解析 高性能计算机和人工智能前所未有的加速 Tensor Cores支持混合精度计算,动态调整计算以加快吞吐量,同时保持精度.最新一代将这些加速功能扩展到各种工作 ...
- NVIDIA深度学习Tensor Core性能解析(下)
NVIDIA深度学习Tensor Core性能解析(下) DeepBench推理测试之RNN和Sparse GEMM DeepBench的最后一项推理测试是RNN和Sparse GEMM,虽然测试中可 ...
- NVIDIA深度学习Tensor Core性能解析(上)
NVIDIA深度学习Tensor Core性能解析(上) 本篇将通过多项测试来考验Volta架构,利用各种深度学习框架来了解Tensor Core的性能. 很多时候,深度学习这样的新领域会让人难以理解 ...
- 转载自~浮云比翼:Step by Step:Linux C多线程编程入门(基本API及多线程的同步与互斥)
Step by Step:Linux C多线程编程入门(基本API及多线程的同步与互斥) 介绍:什么是线程,线程的优点是什么 线程在Unix系统下,通常被称为轻量级的进程,线程虽然不是进程,但却可 ...
- 《ASP.NET Core跨平台开发从入门到实战》Web API自定义格式化protobuf
<ASP.NET Core跨平台开发从入门到实战>样章节 Web API自定义格式化protobuf. 样章 Protocol Buffers 是一种轻便高效的结构化数据存储格式,可以用于 ...
- 初识Django —Python API接口编程入门
初识Django —Python API接口编程入门 一.WEB架构的简单介绍 Django是什么? Django是一个开放源代码的Web应用框架,由Python写成.我们的目标是用Python语言, ...
- Storm编程入门API系列之Storm的Topology多个Workers数目控制实现
前期博客 Storm编程入门API系列之Storm的Topology默认Workers.默认executors和默认tasks数目 继续编写 StormTopologyMoreWorker.java ...
- Storm编程入门API系列之Storm的Topology多个Executors数目控制实现
前期博客 Storm编程入门API系列之Storm的Topology默认Workers.默认executors和默认tasks数目 Storm编程入门API系列之Storm的Topology多个Wor ...
随机推荐
- 2003031118—李伟—Python数据分析第七周作业—MySQL的安装以及使用
项目 MySQL的安装以及使用 课程班级博客链接 20级数据班(本) 这个作业要求链接 作业要求 博客名称 2003031118-李伟-Python数据分析第七周作业-MySQL的安装以及使用 ...
- 获取n位数m进制的随机数 js
js 获取n位数m进制的随机数 n 的取值范围为 0 < n > 1.7976931348623157e+308 (Number.MAX_VALUE) m的取值范围为 2 <= m ...
- python for houdini——python在houdini中的基础应用02
内容来源于网上视频 一.houdini python编译器 1.python shell 2.python source editor----代码可以随场景保存 构造的函数可以在外部通过hou.ses ...
- Vue3+Vite项目中 使用WindiCSS.
之前工作有了解过根据类名来写元素的样式,一听就发出疑问:这样写项目可读性恐怕不是很好吧... 之后来到杭州工作后,开始使用WindiCSS后发现 真香!!! 由于近期所写的项目都是自己一个人开发的 ...
- 什么是 SpringMvc
SpringMvc 是 spring 的一个模块,基于 MVC 的一个框架,无需中间整合层来整合
- OpenCV图像拼接函数
图像拼接函数 第一种方法:通过遍历图像,将待拼接的图像每个像素赋值给输出图像 //图像拼接函数 //imageVector 输入图像数组 //outputImage 输出图像 //colCount_ ...
- php运行找不到命令
这个跟环境path设置有关: 1. 找php.ini位置./www/wdlinux/apache_php-5.5.38/bin 2. 写入默认path: export PATH=$PATH:/www/ ...
- vue 作者在2022-2-7起宣布 vue3 正式作为默认版本
vue 作者在2022-2-7起宣布 vue3 正式作为默认版本 vue 作者尤雨溪在知乎上发布一篇文章,宣布 Vue3 将在 2022 年 2 月 7 日 成为新的默认版本! 并且还在文章中做出了一 ...
- kali更新源数字签名错误解决办法
apt-get update更新时出现错误,提示Release文件已经过期,无论是使用kali官方源还是阿里源.中科大源都报该错误. 网上查找相关资料,签名出错需要下载数字签名,方案如下: wget ...
- ngix安装与使用
主要是nginx的安装使用, 至于原理 1. 安装nginx(以及两个tomcat) 2. 使用nginx(测试负载均衡) 想要搭建的测试环境, 1.两个tomcat, 端口分别是80和8090(因为 ...