Sub-group Swizzle分块矩阵乘优化原理分析
Sub-group Swizzle分块矩阵乘优化原理分析
因为要写成实验报告用来应付课程作业,因此写得罗里吧嗦的,见谅 :(
摘要
矩阵乘法是深度学习、科学计算等应用领域的基础算子,其优化方法是广泛研究的课题。对于大规模矩阵乘法,可以利用线性代数中矩阵分块乘法理论,将大规模矩阵乘法转化为分块矩阵乘。Sub-group Swizzle Block Matrix Multiplication是一种将矩阵乘法划分为子组,并重新映射线程块threadBlockIdx,提高GPU L2 Cache Hit Rate,以加速分块矩阵乘的优化方法。这种优化方法在Cutlass[1], triton[2]中被广泛实现,但是目前相关资料对这种优化方法原理的分析较少,甚至存在原理上的错误,相关论坛中的讨论[3]没有定论。本文分析了线程块与SM分配原则、线程块之间的执行顺序,解释了Sub-group Swizzle Matrix Multiplication优化L2 cache hit rate的原理。
引言
分块矩阵乘法

上图表示的是$A \times B = C$的矩阵乘法,假设三个矩阵$A,B,C$为大小相等的方形矩阵,每一个矩阵可以分为$5\times 5$个分块。以下是为每一个分块分配一个threadblock并行计算C矩阵的triton核函数。
1 |
|
以上triton核函数分配了$5\times5$个线程块,每一个线程块取出A矩阵相应的分块行(a row of blocks in A),与B矩阵相应的分块列(a column of blocks in B),计算得到C矩阵对应的一个分块。threadblockIdx(Triton中称为Program Idx)按照列主序(row-wise)分布。但是,这种算法的L2 cache hit rate不高,当$B$矩阵列数$N$较大时,L2 cache命中率较低。
Sub-group Swizzle分块矩阵乘
Sub-group Swizzle Block Matmul分为两步:①将多个分块行分为一个sub-group;②更改threadblockIdx在结果矩阵C上的排列顺序。

如图二所示,在一个分组中,threadblockIdx按照如图所示的折线分布。例如在一个大小为2的分组中,threadblockIdx按照(0,0),(1,0),(0,1),(1,1)的顺序递增,在下一个分组中,threadblockIdx按照(2,0),(3,0),(2,1),(3,1)的顺序递增。以下为这种算法的triton实现:
1 |
|
基于分组的swizzle threadblockIdx映射方式能够提高L2 cache命中率。然而,没有相关材料论述其中的原理。材料[2]认为线程块之间是严格串行的,因此能够复用部分数据。但在CUDA编程模型中,所有线程块之间并不是严格串行的,线程块之间的执行顺序由GPU SM Scheduler和硬件资源所影响,存在较为复杂的关系。接下来将展开优化原理分析。
原理分析
ThreadBlock与SM映射关系
GPU按照ThreadBlockIdx顺序映射到SM(Stream Mutliprocessor),占满所有SM后,剩余的threadblock等待SM资源的释放,一旦SM空闲,则分配到该SM上。使用如下设备端代码可以获取每一个threadblock对应的SM的索引,从而逆向得到特定GPU上threadblock与SM的映射关系。
1 | __device__ unsigned int smid(void) |
例如,GeForce 4090 GPU有9个完整的GPC(Graph Process cluster(每个GPC含有12个SM),2个不完整的GPC(每个GPC含10个SM),共$9\times12+2\times10=128$个SM。在实验程序中,由于资源限制,单个SM最多同时并行2个threadblock(residency=2)下图为GeForce RTX 4090中SM与threadblock之间的映射关系。![[The mapping relationship between SMs and thread blocks on NVIDIA GeForce RTX 4090.png]](https://pic1.imgdb.cn/item/6836fe1458cb8da5c8164ddf.png)
如上图所示,0~127号threadblock被依次映射到0,0,2,2,4,4,6,6......126,126号SM上,128~143号threadblock被依次映射到1,1,3,3,5,5......15,15号SM上,144~199号threadblock被依次映射到17,19......127号SM上,200~255号threadblock被依次映射到17,19......127号SM。threadblockIdx大于255的threadblock等待SM资源的释放,一旦SM空闲,则将threadblock映射到空闲的SM上。
我们称第一次调度中,将尽可能多的threadblock映射到SMs上,占满SM,为the first wave(第一波调度)[4]。所有SM被占满后,采用贪心调度策略(Greedy Schedule)将threadblock映射到SM,即一旦有SM空闲,立刻将threadblock映射上。使用如下设备端函数测出各线程块发射时间:
1 | __device__ unsigned long long globaltime(void) |
![[the first wave and Greedy Schedule.jpg]](https://pic1.imgdb.cn/item/6836fe2e58cb8da5c8164de4.png)
由上图可知,由于GPU SM资源限制,GPU不能同时发射所有的threadblocks,因此不同批次发射到SM上的threadblocks之间存在串行关系。接下来分析不同批次的threadblock之间如何通过分组swizzle优化L2 cache命中率。
不同批次之间线程块访存优化
本节将对L2 cache访问进行量化分析,分析案例为:
- 矩阵分块规模
M,K,N = 512, 512, 512,数据精度为Fp16,BLOCK SIZE M,K,N = 64,64,32 NVIDIA GeForce RTX 4090L2 cache为72MB,为便于分析,假设L2 cache容量正好能够缓存住256个threadblock计算所需的数据量。- 假设每一批次的所有
threadblock同时结束。虽然一般来说,矩阵乘法中每个threadblock计算的数据量相同,但是各threadblock的耗时可能不是严格相等。为便于分析,此处做近似分析,下一节做补充说明。
![[row wise 256blocks.jpg]](https://pic1.imgdb.cn/item/6836fe4058cb8da5c8164de5.png)
上图展示了按照行主序设置threadblockIdx的计算流程:①C0表示第一批同时启动的256个ThreadBlocks。256个threadblocks同时读取A0(一行分块),GPU会将其合并为一次读操作;256个threadBlocks一共需要读取B0(256列分块)。那么C0中256个threadblocks完成计算总共需要从Global Memory中读257行分块。此时,A0和B0被缓存在L2 cache中。②C1中256个threadblocks需要读A0和B1。A0被缓存在L2 Cache中,不需要从Global Memory中读取。B1需要从Global Memory读取。C1完成计算需要从Global Memory中读取256行分块(忽略从L2 Cache中读取A0的开销,因为L2 Cache速度远快于Global Memory)。此时A0和B1被缓存在L2 Cache中(B0被替换)。同理,③C2完成计算需要从Global Memory中读取257行分块;④C3完成计算需要从Global Memory中读取257行分块。
综上,4批threadblocks完成计算,依次需要从Global Memory读取的数据量:257行分块$\rightarrow$ 256 $\rightarrow$ 257 $\rightarrow$ 256。

上图展示了分组大小为2,按照折线顺序编排threadblockIdx的计算流程:①C0表示第一批同时启动的256个ThreadBlocks。256个threadblocks同时读取A0(一行分块),GPU会将其合并为一次读操作;256个threadBlocks一共需要读取B0(256列分块)。那么C0中256个threadblocks完成计算总共需要从Global Memory中读257行分块。此时,A0和B0被缓存在L2 cache中。②C1中256个threadblocks需要读A1和B0。B0被缓存在L2 Cache中,不需要从Global Memory中读取。A1需要从Global Memory读取。C1完成计算需要从Global 1行分块(忽略从L2 Cache中读取B0的开销,因为L2 Cache速度远快于Global Memory)。此时A1和B0被缓存在L2 Cache中(A0被替换)。同理,③C2完成计算需要从Global Memory中读取257行分块;④C3完成计算需要从Global Memory中读取1行分块。
综上,4批threadblocks完成计算,依次需要从Global Memory读取的数据量:257行分块$\rightarrow$ 1 $\rightarrow$ 257 $\rightarrow$ 1。
由此可见,分组折线编排threadblockIdx能够较大地提升L2 Cache命中率。
同一批线程块的发射顺序
同一批线程块的发射顺序并不是严格并行的,而是大致上按照threadblockIdx的顺序发射(但不是严格顺序),如下图所示[5]:
可以得出经验性的结论:让threadblockIdx相邻的线程块访问相邻地址的数据,更有可能提高L2 Cache命中率。
实验:性能比较
实验环境:
NVIDIA GeForce RTX 4090cuda 12.1CentOS 7.9
下图是两种不同的threadblockIdx编排顺序的算法性能对比图。当矩阵规模较小时,两种算法性能接近,在大规模矩阵乘时,sub-group swizzle比row-wise性能大约高2倍。
![[row-wise vs sub-group swizzle.png]](https://pic1.imgdb.cn/item/6836feef58cb8da5c8164e11.png)
L2 Cache Hit rate对比:
小规模矩阵乘法中,因为4090 GPU L2 Cache足够大(72MB),能够缓存住所有的数据,两种算法的L2 cache命中率都很高。大规模矩阵乘法中,两种算法有明显差异,M,K,N=16384下,两种算法的L2 cache命中率分别为83.22%, 95.70%,M,K,N=32768规模下分别为83.28%, 95.64%。
经过triton tuning后,sub-group swizzle矩阵乘与cuBLAS性能对比图。![[tuning triton vs cublas.png]](https://pic1.imgdb.cn/item/6836ff3e58cb8da5c8164e2d.png)
参考文献
[1] https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/threadblock/threadblock_swizzle.h
[2] https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
[3] https://github.com/NVIDIA/cutlass/issues/1017
[4] https://cs.rochester.edu/%7Esree/fermi-tbs/fermi-tbs.html
[5] Yang J, Wen M, Chen D, et al. HyFiSS: A Hybrid Fidelity Stall-Aware Simulator for GPGPUs[C]//2024 57th IEEE/ACM International Symposium on Microarchitecture (MICRO). IEEE, 2024: 168-185.