最近打PKU的HPCGAME用的代码,这里只用上了20个zmm寄存器,改变block的大小应该还能优化一下速度。

代码只考虑了方阵,其他非2^n次方阵要自己改代码。具体原理很简单,看看代码就差不多知道。

const int BLOCK_SIZE = 1024;
const int BLOCK_SIZE2 = 256; inline static void block_avx512_32x4( // AVX256效果不好,硬着头皮上吧(汇编上看好像还有12个寄存器没用上,还有优化空间)
int n, int K, //方阵大小
double* A, double* B, double* C)
{
__m512d c0000_0700,c0800_1500, c1600_2300, c2400_3100,
c0001_0701, c0801_1501, c1601_2301, c2401_3101,
c0002_0702, c0802_1502, c1602_2302, c2402_3102,
c0003_0703, c0803_1503, c1603_2303, c2403_3103; __m512d a0x_7x, a8x_15x, a16x_23x, a24x_31x,
bx0, bx1, bx2, bx3; double* c0001_0701_ptr = C + n;
double* c0002_0702_ptr = C + n * 2;
double* c0003_0703_ptr = C + n * 3; c0000_0700 = _mm512_load_pd(C);
c0800_1500 = _mm512_load_pd(C + 8);
c1600_2300 = _mm512_load_pd(C + 16);
c2400_3100 = _mm512_load_pd(C + 24); c0001_0701 = _mm512_load_pd(c0001_0701_ptr);
c0801_1501 = _mm512_load_pd(c0001_0701_ptr + 8);
c1601_2301 = _mm512_load_pd(c0001_0701_ptr + 16);
c2401_3101 = _mm512_load_pd(c0001_0701_ptr + 24); c0002_0702 = _mm512_load_pd(c0002_0702_ptr);
c0802_1502 = _mm512_load_pd(c0002_0702_ptr + 8);
c1602_2302 = _mm512_load_pd(c0002_0702_ptr + 16);
c2402_3102 = _mm512_load_pd(c0002_0702_ptr + 24); c0003_0703 = _mm512_load_pd(c0003_0703_ptr);
c0803_1503 = _mm512_load_pd(c0003_0703_ptr + 8);
c1603_2303 = _mm512_load_pd(c0003_0703_ptr + 16);
c2403_3103 = _mm512_load_pd(c0003_0703_ptr + 24); for (int x = 0; x < K; ++x)
{
a0x_7x = _mm512_load_pd(A);
a8x_15x = _mm512_load_pd(A + 8);
a16x_23x = _mm512_load_pd(A + 16);
a24x_31x = _mm512_load_pd(A + 24);
A += 32; bx0 = _mm512_broadcastsd_pd(_mm_load_sd(B++));
bx1 = _mm512_broadcastsd_pd(_mm_load_sd(B++));
bx2 = _mm512_broadcastsd_pd(_mm_load_sd(B++));
bx3 = _mm512_broadcastsd_pd(_mm_load_sd(B++)); c0000_0700 = _mm512_add_pd(_mm512_mul_pd(a0x_7x, bx0), c0000_0700);
c0800_1500 = _mm512_add_pd(_mm512_mul_pd(a8x_15x, bx0), c0800_1500);
c1600_2300 = _mm512_add_pd(_mm512_mul_pd(a16x_23x, bx0), c1600_2300);
c2400_3100 = _mm512_add_pd(_mm512_mul_pd(a24x_31x, bx0), c2400_3100); c0001_0701 = _mm512_add_pd(_mm512_mul_pd(a0x_7x, bx1), c0001_0701);
c0801_1501 = _mm512_add_pd(_mm512_mul_pd(a8x_15x, bx1), c0801_1501);
c1601_2301 = _mm512_add_pd(_mm512_mul_pd(a16x_23x, bx1), c1601_2301);
c2401_3101 = _mm512_add_pd(_mm512_mul_pd(a24x_31x, bx1), c2401_3101); c0002_0702 = _mm512_add_pd(_mm512_mul_pd(a0x_7x, bx2), c0002_0702);
c0802_1502 = _mm512_add_pd(_mm512_mul_pd(a8x_15x, bx2), c0802_1502);
c1602_2302 = _mm512_add_pd(_mm512_mul_pd(a16x_23x, bx2), c1602_2302);
c2402_3102 = _mm512_add_pd(_mm512_mul_pd(a24x_31x, bx2), c2402_3102); c0003_0703 = _mm512_add_pd(_mm512_mul_pd(a0x_7x, bx3), c0003_0703);
c0803_1503 = _mm512_add_pd(_mm512_mul_pd(a8x_15x, bx3), c0803_1503);
c1603_2303 = _mm512_add_pd(_mm512_mul_pd(a16x_23x, bx3), c1603_2303);
c2403_3103 = _mm512_add_pd(_mm512_mul_pd(a24x_31x, bx3), c2403_3103);
}
_mm512_storeu_pd(C, c0000_0700);
_mm512_storeu_pd(C + 8, c0800_1500);
_mm512_storeu_pd(C + 16, c1600_2300);
_mm512_storeu_pd(C + 24, c2400_3100); _mm512_storeu_pd(c0001_0701_ptr, c0001_0701);
_mm512_storeu_pd(c0001_0701_ptr + 8, c0801_1501);
_mm512_storeu_pd(c0001_0701_ptr + 16, c1601_2301);
_mm512_storeu_pd(c0001_0701_ptr + 24, c2401_3101); _mm512_storeu_pd(c0002_0702_ptr, c0002_0702);
_mm512_storeu_pd(c0002_0702_ptr + 8, c0802_1502);
_mm512_storeu_pd(c0002_0702_ptr + 16, c1602_2302);
_mm512_storeu_pd(c0002_0702_ptr + 24, c2402_3102); _mm512_storeu_pd(c0003_0703_ptr, c0003_0703);
_mm512_storeu_pd(c0003_0703_ptr + 8, c0803_1503);
_mm512_storeu_pd(c0003_0703_ptr + 16, c1603_2303);
_mm512_storeu_pd(c0003_0703_ptr + 24, c2403_3103);
} static inline void copy_avx512_b(int lda, const int K, double* b_src, double* b_dest) {
double* b_ptr0, * b_ptr1, * b_ptr2, * b_ptr3;
b_ptr0 = b_src;
b_ptr1 = b_ptr0 + lda;
b_ptr2 = b_ptr1 + lda;
b_ptr3 = b_ptr2 + lda; for (int i = 0; i < K; ++i)
{
*b_dest++ = *b_ptr0++;
*b_dest++ = *b_ptr1++;
*b_dest++ = *b_ptr2++;
*b_dest++ = *b_ptr3++;
}
} static inline void copy_avx512_a(int lda, const int K, double* a_src, double* a_dest) {
for (int i = 0; i < K; ++i)
{
memcpy(a_dest, a_src, 32 * 8);
a_dest += 32;
a_src += lda;
}
} static inline void do_block_avx512(int lda, int M, int N, int K, double* A, double* B, double* C)
{
double* A_block, * B_block;
A_block = (double*)_mm_malloc(M * K * sizeof(double), 64);
B_block = (double*)_mm_malloc(K * N * sizeof(double), 64); double* a_ptr, * b_ptr, * c; const int Nmax = N - 3;
int Mmax = M - 32; int i = 0, j = 0, p = 0; for (j = 0; j < Nmax; j += 4)
{
b_ptr = &B_block[j * K];
copy_avx512_b(lda, K, B + j * lda, b_ptr); // 将 B 展开
for (i = 0; i < Mmax; i += 32) {
a_ptr = &A_block[i * K];
if (j == 0) copy_avx512_a(lda, K, A + i, a_ptr); // 将 A 展开
c = C + i + j * lda;
block_avx512_32x4(lda, K, a_ptr, b_ptr, c);
}
}
_mm_free(A_block);
_mm_free(B_block);
} void gemm_avx512(int lda, double* A, double* B, double* C)
{
#pragma omp parallel for
for (int j = 0; j < lda; j += BLOCK_SIZE) { // j i k 序 内存读写更快
for (int i = 0; i < lda; i += BLOCK_SIZE) {
for (int k = 0; k < lda; k += BLOCK_SIZE) {
// 大分块里小分块
for (int jj = j; jj < j + BLOCK_SIZE; jj += BLOCK_SIZE2)
for (int ii = i; ii < i + BLOCK_SIZE; ii += BLOCK_SIZE2)
for (int kk = k; kk < k + BLOCK_SIZE; kk += BLOCK_SIZE2)
do_block_avx512(lda, BLOCK_SIZE2, BLOCK_SIZE2, BLOCK_SIZE2, A + ii + kk * lda, B + kk + jj * lda, C + ii + jj * lda);
}
}
}
}

AVX512加速矩阵乘法的更多相关文章

  1. 【POJ3613】Cow Relays 离散化+倍增+矩阵乘法

    题目大意:给定一个 N 个顶点,M 条边的无向图,求从起点到终点恰好经过 K 个点的最短路. 题解:设 \(d[1][i][j]\) 表示恰好经过一条边 i,j 两点的最短路,那么有 \(d[r+m] ...

  2. 如何使用矩阵乘法加速动态规划——以[SDOI2009]HH去散步为例

    对这个题目的最初理解 开始看到这个题,觉得很水,直接写了一个最简单地动态规划,就是定义 f[i][j]为到了i节点路径长度为j的路径总数, 转移的话使用Floyd算法的思想去转移,借助这个题目也理解了 ...

  3. [模板][题解][Luogu1939]矩阵乘法加速递推(详解)

    题目传送门 题目大意:计算数列a的第n项,其中: \[a[1] = a[2] = a[3] = 1\] \[a[i] = a[i-3] + a[i - 1]\] \[(n ≤ 2 \times 10^ ...

  4. BZOJ 1009 GT考试 (AC自动机 + 矩阵乘法加速dp)

    题目链接: https://www.lydsy.com/JudgeOnline/problem.php?id=1009 题意: 准考证号为\(n\)位数\(X_1X_2....X_n(0<=X_ ...

  5. 『公交线路 状压dp 矩阵乘法加速』

    公交线路 Description 小Z所在的城市有N个公交车站,排列在一条长(N-1)km的直线上,从左到右依次编号为1到N,相邻公交车站间的距离均为1km. 作为公交车线路的规划者,小Z调查了市民的 ...

  6. c++的矩阵乘法加速trick

    最近读RNNLM的源代码,发现其实现矩阵乘法时使用了一个trick,这里描述一下这个trick. 首先是正常版的矩阵乘法(其实是矩阵乘向量) void matrixXvector(float* des ...

  7. HDU 5607 graph(DP+矩阵乘法)

    [题目链接] http://bestcoder.hdu.edu.cn/contests/contest_showproblem.php?cid=663&pid=1002 [题意] 给定一个有向 ...

  8. BZOJ_1009_[HNOI2008]GT考试_KMP+矩阵乘法

    BZOJ_1009_[HNOI2008]GT考试_KMP+矩阵乘法 Description 阿申准备报名参加GT考试,准考证号为N位数X1X2....Xn(0<=Xi<=9),他不希望准考 ...

  9. Codeforces 1106F Lunar New Year and a Recursive Sequence | BSGS/exgcd/矩阵乘法

    我诈尸啦! 高三退役选手好不容易抛弃天利和金考卷打场CF,结果打得和shi一样--还因为queue太长而unrated了!一个学期不敲代码实在是忘干净了-- 没分该没分,考题还是要订正的 =v= 欢迎 ...

  10. [转]OpenBLAS项目与矩阵乘法优化

    课程内容 OpenBLAS项目介绍 矩阵乘法优化算法 一步步调优实现 以下为公开课完整视频,共64分钟: 以下为公开课内容的文字及 PPT 整理. 雷锋网的朋友们大家好,我是张先轶,今天主要介绍一下我 ...

随机推荐

  1. CoaXPress 协议的CRC及其具体实现

    CoaXPress CRC 在CXP协议中,CRC用在stream packet和control packet中,用于指示数据是否错误,如果是control packet, device发现CRC错误 ...

  2. java中sha1.md5,base64到底怎么回事

    MD5 Message Digest Algorithm MD5(中文名为消息摘要算法第五版)为计算机安全领域广泛使用的一种散列函数,用以提供消息的完整性保护.MD5用的是哈希函数,在计算机网络中应用 ...

  3. UG474

    为了对工程的资源利用率进行优化,我们首先需要知道当前工程对资源的利用率情况.在Vivado下,我们可以查看工程的资源利用率情况,在下面这张图中,其罗列出了整个工程所使用的资源情况.首先,下面我们需要一 ...

  4. GitHub访问地址映射更新的时候刷新DNS

    1.windows系统 上设置地址映射 Window系统本地可以安装 Git Bash 方便本地管理仓,或下载Git 上的代码,在访问Git的时候经常出现Git访问主页加载不了等问题.需要设置在本地设 ...

  5. 英语单词组件- 单词在句子中,上面显示中文下面显示音标 css样式

    原先效果: 改进demo效果 优化点 音标长度超出,或者中文超出,总宽度会按照最长的走 居中显示 再次优化 line-height: 22px; 加入这个 对齐中间行(字号大小会让绝对上下高度,对不齐 ...

  6. opus编解码的特色和优点

    概念原理   Opus是一个有损音频压缩的数字音频编码格式,由Xiph.Org基金会开发,之后由互联网工程任务组(IETF)进行标准化,目标是希望用单一格式包含声音和语音,取代Speex和Vorbis ...

  7. ETL工具-KETTLE教程实例实战3----转换(输入、输出)

    ETL工具-KETTLE教程实例实战3----转换(输入.输出) 欢迎关注笔者的公众号: java大师, 每日推送java.kettle运维等领域干货文章,关注即免费无套路附送 100G 海量学习.面 ...

  8. Elasticsearch - Docker安装Elasticsearch8.12.2

    前言 最近在学习 ES,所以需要在服务器上装一个单节点的 ES 服务器环境:centos 7.9 安装 下载镜像 目前最新版本是 8.12.2 docker pull docker.elastic.c ...

  9. AOSP源码编译—交换空间扩容

    编译AOSP源码的时候会出现提示如下: 意思是需要16G左右的内存(实际上编译会超过16G),而我们之前安装Ubuntu的时候只分配了8G,编 译一定会失败!此时需要添加虚拟内存(swap交换空间) ...

  10. C 可变参数函数分析(va_start,va_end,va_list...)

    PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 前置说明   本文作为本人csdn blog的主站的备份.(Bl ...