Sub-group Swizzle分块矩阵乘优化原理分析

Sub-group Swizzle分块矩阵乘优化原理分析

因为要写成实验报告用来应付课程作业,因此写得罗里吧嗦的,见谅 :(

摘要

矩阵乘法是深度学习、科学计算等应用领域的基础算子,其优化方法是广泛研究的课题。对于大规模矩阵乘法,可以利用线性代数中矩阵分块乘法理论,将大规模矩阵乘法转化为分块矩阵乘。Sub-group Swizzle Block Matrix Multiplication是一种将矩阵乘法划分为子组,并重新映射线程块threadBlockIdx,提高GPU L2 Cache Hit Rate,以加速分块矩阵乘的优化方法。这种优化方法在Cutlass[1], triton[2]中被广泛实现,但是目前相关资料对这种优化方法原理的分析较少,甚至存在原理上的错误,相关论坛中的讨论[3]没有定论。本文分析了线程块与SM分配原则、线程块之间的执行顺序,解释了Sub-group Swizzle Matrix Multiplication优化L2 cache hit rate的原理。

引言

分块矩阵乘法


上图表示的是$A \times B = C$的矩阵乘法,假设三个矩阵$A,B,C$为大小相等的方形矩阵,每一个矩阵可以分为$5\times 5$个分块。以下是为每一个分块分配一个threadblock并行计算C矩阵的triton核函数。

1
2
3
4
5
6
7
8
9
@triton.JIT
def matmul_row_wise(...):
pid = tl.program_id(axis=0)
grid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // grid_n
pid_n = pid % grid_n

#compute every block of matrix C
.......

以上triton核函数分配了$5\times5$个线程块,每一个线程块取出A矩阵相应的分块行(a row of blocks in A),与B矩阵相应的分块列(a column of blocks in B),计算得到C矩阵对应的一个分块。threadblockIdx(Triton中称为Program Idx)按照列主序(row-wise)分布。但是,这种算法的L2 cache hit rate不高,当$B$矩阵列数$N$较大时,L2 cache命中率较低。

Sub-group Swizzle分块矩阵乘

Sub-group Swizzle Block Matmul分为两步:①将多个分块行分为一个sub-group;②更改threadblockIdx在结果矩阵C上的排列顺序。

如图二所示,在一个分组中,threadblockIdx按照如图所示的折线分布。例如在一个大小为2的分组中,threadblockIdx按照(0,0),(1,0),(0,1),(1,1)的顺序递增,在下一个分组中,threadblockIdx按照(2,0),(3,0),(2,1),(3,1)的顺序递增。以下为这种算法的triton实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@triton.jit
def matmul_kernel_group_col_wise(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,M, N, K,stride_am, stride_ak,stride_bk, stride_bn,
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
ACTIVATION: tl.constexpr #
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

#computer every block.......

基于分组的swizzle threadblockIdx映射方式能够提高L2 cache命中率。然而,没有相关材料论述其中的原理。材料[2]认为线程块之间是严格串行的,因此能够复用部分数据。但在CUDA编程模型中,所有线程块之间并不是严格串行的,线程块之间的执行顺序由GPU SM Scheduler和硬件资源所影响,存在较为复杂的关系。接下来将展开优化原理分析。

原理分析

ThreadBlock与SM映射关系

GPU按照ThreadBlockIdx顺序映射到SM(Stream Mutliprocessor),占满所有SM后,剩余的threadblock等待SM资源的释放,一旦SM空闲,则分配到该SM上。使用如下设备端代码可以获取每一个threadblock对应的SM的索引,从而逆向得到特定GPU上threadblockSM的映射关系。

1
2
3
4
5
6
__device__ unsigned int smid(void)
{
unsigned int sm;
asm("mov.u32 %0, %%smid;" : "=r"(sm));
return sm;
}

例如,GeForce 4090 GPU有9个完整的GPC(Graph Process cluster(每个GPC含有12个SM),2个不完整的GPC(每个GPC含10个SM),共$9\times12+2\times10=128$个SM。在实验程序中,由于资源限制,单个SM最多同时并行2个threadblock(residency=2)下图为GeForce RTX 4090SMthreadblock之间的映射关系。
[The mapping relationship between SMs and thread blocks on NVIDIA GeForce RTX 4090.png]
如上图所示,0~127threadblock被依次映射到0,0,2,2,4,4,6,6......126,126SM上,128~143threadblock被依次映射到1,1,3,3,5,5......15,15SM上,144~199threadblock被依次映射到17,19......127SM上,200~255threadblock被依次映射到17,19......127SMthreadblockIdx大于255的threadblock等待SM资源的释放,一旦SM空闲,则将threadblock映射到空闲的SM上。

我们称第一次调度中,将尽可能多的threadblock映射到SMs上,占满SM,为the first wave(第一波调度)[4]。所有SM被占满后,采用贪心调度策略(Greedy Schedule)将threadblock映射到SM,即一旦有SM空闲,立刻将threadblock映射上。使用如下设备端函数测出各线程块发射时间:

1
2
3
4
5
6
__device__ unsigned long long globaltime(void)
{
unsigned long long time;
asm("mov.u64 %0, %%globaltimer;" : "=l"(time));
return time;
}

[the first wave and Greedy Schedule.jpg]

由上图可知,由于GPU SM资源限制,GPU不能同时发射所有的threadblocks,因此不同批次发射到SM上的threadblocks之间存在串行关系。接下来分析不同批次的threadblock之间如何通过分组swizzle优化L2 cache命中率。

不同批次之间线程块访存优化

本节将对L2 cache访问进行量化分析,分析案例为:

  • 矩阵分块规模M,K,N = 512, 512, 512,数据精度为Fp16BLOCK SIZE M,K,N = 64,64,32
  • NVIDIA GeForce RTX 4090 L2 cache72MB,为便于分析,假设L2 cache容量正好能够缓存住256个threadblock计算所需的数据量。
  • 假设每一批次的所有threadblock同时结束。虽然一般来说,矩阵乘法中每个threadblock计算的数据量相同,但是各threadblock的耗时可能不是严格相等。为便于分析,此处做近似分析,下一节做补充说明。

[row wise 256blocks.jpg]
上图展示了按照行主序设置threadblockIdx的计算流程:①C0表示第一批同时启动的256个ThreadBlocks。256个threadblocks同时读取A0(一行分块),GPU会将其合并为一次读操作;256个threadBlocks一共需要读取B0(256列分块)。那么C0中256个threadblocks完成计算总共需要从Global Memory中读257行分块。此时,A0B0被缓存在L2 cache中。②C1中256个threadblocks需要读A0B1A0被缓存在L2 Cache中,不需要从Global Memory中读取。B1需要从Global Memory读取。C1完成计算需要从Global Memory中读取256行分块(忽略从L2 Cache中读取A0的开销,因为L2 Cache速度远快于Global Memory)。此时A0和B1被缓存在L2 Cache中(B0被替换)。同理,③C2完成计算需要从Global Memory中读取257行分块;④C3完成计算需要从Global Memory中读取257行分块。

综上,4批threadblocks完成计算,依次需要从Global Memory读取的数据量:257行分块$\rightarrow$ 256 $\rightarrow$ 257 $\rightarrow$ 256。

sub-group swizzle 256 blocks.png
上图展示了分组大小为2,按照折线顺序编排threadblockIdx的计算流程:①C0表示第一批同时启动的256个ThreadBlocks。256个threadblocks同时读取A0(一行分块),GPU会将其合并为一次读操作;256个threadBlocks一共需要读取B0(256列分块)。那么C0中256个threadblocks完成计算总共需要从Global Memory中读257行分块。此时,A0B0被缓存在L2 cache中。②C1中256个threadblocks需要读A1B0B0被缓存在L2 Cache中,不需要从Global Memory中读取。A1需要从Global Memory读取。C1完成计算需要从Global 1行分块(忽略从L2 Cache中读取B0的开销,因为L2 Cache速度远快于Global Memory)。此时A1和B0被缓存在L2 Cache中(A0被替换)。同理,③C2完成计算需要从Global Memory中读取257行分块;④C3完成计算需要从Global Memory中读取1行分块。

综上,4批threadblocks完成计算,依次需要从Global Memory读取的数据量:257行分块$\rightarrow$ 1 $\rightarrow$ 257 $\rightarrow$ 1。

由此可见,分组折线编排threadblockIdx能够较大地提升L2 Cache命中率。

同一批线程块的发射顺序

同一批线程块的发射顺序并不是严格并行的,而是大致上按照threadblockIdx的顺序发射(但不是严格顺序),如下图所示[5]:
Time-View of CTA Scheduling Across SMs.png
可以得出经验性的结论:让threadblockIdx相邻的线程块访问相邻地址的数据,更有可能提高L2 Cache命中率。

实验:性能比较

实验环境:

  • NVIDIA GeForce RTX 4090
  • cuda 12.1
  • CentOS 7.9

下图是两种不同的threadblockIdx编排顺序的算法性能对比图。当矩阵规模较小时,两种算法性能接近,在大规模矩阵乘时,sub-group swizzlerow-wise性能大约高2倍。

[row-wise vs sub-group swizzle.png]
L2 Cache Hit rate对比:
L2 cache hit rate compare.png

小规模矩阵乘法中,因为4090 GPU L2 Cache足够大(72MB),能够缓存住所有的数据,两种算法的L2 cache命中率都很高。大规模矩阵乘法中,两种算法有明显差异,M,K,N=16384下,两种算法的L2 cache命中率分别为83.22%, 95.70%M,K,N=32768规模下分别为83.28%, 95.64%

经过triton tuning后,sub-group swizzle矩阵乘与cuBLAS性能对比图。
[tuning triton vs cublas.png]

参考文献

[1] https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/threadblock/threadblock_swizzle.h
[2] https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
[3] https://github.com/NVIDIA/cutlass/issues/1017
[4] https://cs.rochester.edu/%7Esree/fermi-tbs/fermi-tbs.html
[5] Yang J, Wen M, Chen D, et al. HyFiSS: A Hybrid Fidelity Stall-Aware Simulator for GPGPUs[C]//2024 57th IEEE/ACM International Symposium on Microarchitecture (MICRO). IEEE, 2024: 168-185.