This post is a set of source-reading notes on a Tensor Core GEMM implementation for the SM80 architecture. The code comes from this repository:

https://github.com/gau-nernst/learn-cuda

The repository has implementations for several architectures (SM80/SM120/SM100/CDNA3), all reaching cuBLAS-level performance, which makes it great study material. However, the README only gives a brief overview, so a careful pass through the code is still needed to really understand it.

1 Instructions

1.1 ldmatrix

ldmatrix is a data-movement instruction designed for 16-bit data. It directs the threads of a warp to load data from shared memory into per-thread registers, which are then consumed by mma instructions. Later architectures support other widths and shapes, but on SM80 ldmatrix only supports m8n8 transfers of 16-bit data, i.e. it moves 8x8 tiles, and a single instruction can move 1/2/4 of them.

__device__ inline
void ldmatrix_x2(uint32_t reg[2], uint32_t addr) {
    asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];"
                 : "=r"(reg[0]), "=r"(reg[1])
                 : "r"(addr));
}

__device__ inline
void ldmatrix_x4(uint32_t reg[4], uint32_t addr) {
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
                 : "=r"(reg[0]), "=r"(reg[1]), "=r"(reg[2]), "=r"(reg[3])
                 : "r"(addr));
}
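
For reference, the .x1 variant follows the same pattern (my sketch; the repo snippet above only defines the x2 and x4 wrappers):

__device__ inline
void ldmatrix_x1(uint32_t reg[1], uint32_t addr) {
    // loads a single 8x8 tile of 16-bit elements; only lanes 0-7 supply addresses
    asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];"
                 : "=r"(reg[0])
                 : "r"(addr));
}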

During the load, each thread reads one row of elements:

| .num | Threads 0–7 | Threads 8–15 | Threads 16–23 | Threads 24–31 |
|:-----|:------------|:-------------|:--------------|:--------------|
| .x1  | addr0–addr7 |              |               |               |
| .x2  | addr0–addr7 | addr8–addr15 |               |               |
| .x4  | addr0–addr7 | addr8–addr15 | addr16–addr23 | addr24–addr31 |

After the elements are read, however, they are distributed across the whole warp; you can think of it as each thread ending up with two adjacent elements:

(figure: ldmatrix fragment distribution across the 32 lanes of a warp)
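
A quick host-side sketch (mine, not from the repo; the mapping follows the PTX docs for the m8n8.b16 fragment) that prints which two elements of an 8x8 tile each lane ends up holding:

#include <cstdio>

int main() {
    // lane i holds row i/4 and the two adjacent 16-bit elements
    // starting at column (i%4)*2 -- one 32-bit register per lane
    for (int lane = 0; lane < 32; lane++) {
        int row = lane / 4, col = (lane % 4) * 2;
        printf("lane %2d: (%d,%d) (%d,%d)\n", lane, row, col, row, col + 1);
    }
}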

The .x2 case works the same way: each thread's reg[1] holds two adjacent elements of the second matrix. If the load instruction carries the optional .trans suffix, what you read instead is:

(figure: fragment distribution with the .trans suffix, i.e. the transposed layout)
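
The transposed load just adds the .trans qualifier to the PTX; a wrapper in the same style as the ones above would look like this (my sketch, not part of the repo snippet):

__device__ inline
void ldmatrix_x2_trans(uint32_t reg[2], uint32_t addr) {
    // same as ldmatrix_x2, but each 8x8 tile is transposed on the way in
    asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];"
                 : "=r"(reg[0]), "=r"(reg[1])
                 : "r"(addr));
}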

1.2 mma

mma supports different matrix shapes depending on the data type. For fp16 it supports m8n8k4, m16n8k8, and m16n8k16; for u8/s8 it supports m8n8k16, m16n8k16, and m16n8k32. Different mma shapes imply different data-loading patterns.

__device__ inline
void mma_m16n8k16(const uint32_t A[4], const uint32_t B[2], float D[4]) {
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
                 "{%0, %1, %2, %3}, "    // D
                 "{%4, %5, %6, %7}, "    // A
                 "{%8, %9}, "            // B
                 "{%10, %11, %12, %13};" // C
                 : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
                 : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
                   "r"(B[0]), "r"(B[1]),
                   "f"(D[0]), "f"(D[1]), "f"(D[2]), "f"(D[3]));
}
// "f" denotes a floating-point (f32) register
// "=" means write-only

For m16n8k16, A is a 16x16 matrix made of four 8x8 tiles; during the mma, each thread holds exactly the values that ldmatrix.x4 loads:

(figure: per-lane A fragment layout for m16n8k16)

B is a 16x8 matrix made of two 8x8 tiles; the data each thread must hold is as follows:

(figure: per-lane B fragment layout for m16n8k16)

Matrix C is 16x8, also two 8x8 tiles:

(figure: per-lane C fragment layout for m16n8k16)

Looking at these layouts, you can see that if B is stored row-major in shared memory, then ldmatrix.x2 with the .trans suffix reads exactly the data mma needs; if the data is stored column-major in shared memory, no suffix is needed and the plain load already yields what mma expects.
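
Putting 1.1 and 1.2 together: a sketch (mine, not from the repo) of one warp computing a single m16n8k16 tile. It assumes the wrappers above are in scope, and that cvta_shared is the repo's helper for converting a generic pointer to a .shared state-space address; A is stored row-major 16x16 and B as N x K (8x16 row-major), so no .trans is needed:

__device__ void mma_tile_example(const nv_bfloat16 *A_shm, // 16x16, row-major
                                 const nv_bfloat16 *B_shm, // 8x16, "N x K"
                                 float D[4], int lane_id) {
    uint32_t A_reg[4], B_reg[2];
    // each lane passes the address of one 8-element row of an 8x8 tile;
    // ldmatrix.x2 only uses the addresses supplied by lanes 0-15
    uint32_t A_addr = cvta_shared(A_shm + (lane_id % 16) * 16 + (lane_id / 16) * 8);
    uint32_t B_addr = cvta_shared(B_shm + (lane_id % 8) * 16 + (lane_id / 8) * 8);
    ldmatrix_x4(A_reg, A_addr);
    ldmatrix_x2(B_reg, B_addr);
    // D += A * B, where B is read column-major thanks to the N x K layout;
    // the caller must zero-initialize D
    mma_m16n8k16(A_reg, B_reg, D);
}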

1.3 cp.async

Asynchronously copies data from global memory to shared memory, bypassing L1 and the register file; synchronization is done via commit_group and wait_group.

__device__ inline
void cp_async(uint32_t dst, const void *src) {
    // .ca means cache to L1 and L2. .cg means cache to L2 only.
    // .cg only accepts cp-size=16
    // .ca results in significantly slower kernel, probably because it uses up L1 resources
    // + additional copy, which is unnecessary, since we already manually cache it in shared memory.
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" ::"r"(dst), "l"(src));
};

__device__ inline
void cp_async_commit_group() { asm volatile("cp.async.commit_group;"); };

template <int N>
__device__ inline
void cp_async_wait_group() { asm volatile("cp.async.wait_group %0;" ::"n"(N)); };

__device__ inline
void cp_async_wait_all() { asm volatile("cp.async.wait_all;"); };

// {} denotes a register list
// [] denotes register-indirect addressing; %0 without [] is the register's value
// "n" denotes an immediate/constant; "l" denotes a 64-bit integer (long long)
// wait_group 0 waits for everything to finish; N allows N groups to still be in flight
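
A minimal usage sketch (mine, built on the wrappers above; section 2.4 shows the real pipelined pattern): issue a batch of copies, commit them as one group, then wait before reading shared memory:

__device__ void copy_example(uint32_t dst_shm, const nv_bfloat16 *src, int tid) {
    cp_async(dst_shm + tid * 16, src + tid * 8); // 16 bytes = 8 bf16 elements per thread
    cp_async_commit_group();
    cp_async_wait_group<0>(); // 0 = wait until no group is left in flight
    __syncthreads();          // make the data visible to the whole block
}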

2 matmul_sm80

Benchmarks on a 5060 laptop GPU with 4096x4096 matrices are below; the device peak is 36.13 TFLOPS (https://www.waredb.com/zh/ranking/memory/bandwidth/pc):

| Kernel name                                            | TFLOPS | % SOL  |
|:-------------------------------------------------------|:-------|:-------|
| CuBLAS 13.0 (via PyTorch 2.10) | 35.01 | 96.29% |
| Inductor Triton (PyTorch 2.10) | nan | nan% |
| v1 (block+warp tiling, vectorized load) | 26.35 | 72.92% |
| v2 (`cp.async`) | 30.36 | 84.02% |
| v3 (pad shared memory) | 28.38 | 78.54% |
| v4 (swizzle shared memory) | 30.80 | 85.25% |
| v5 (`ldmatrix.x4` for B, optimize address computation) | 30.25 | 83.73% |
| v6 (2-stage pipelining) | 30.53 | 84.51% |
| v7 (better swizzling logic, unroll prefetch stages) | 31.02 | 85.86% |
| v8 (threadblock swizzling) | 33.12 | 91.05% |
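
As a sanity check on the table, the two columns relate to elapsed time by simple arithmetic (a hypothetical helper, not from the repo; the 3.9 ms input is a made-up example):

#include <cstdio>

int main() {
    const double flops = 2.0 * 4096 * 4096 * 4096; // a GEMM performs 2*M*N*K FLOPs
    const double peak_tflops = 36.13;              // device peak from above
    const double ms = 3.9;                         // example elapsed time
    const double tflops = flops / (ms * 1e-3) / 1e12;
    printf("%.2f TFLOPS, %.2f%% SOL\n", tflops, tflops / peak_tflops * 100);
}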

2.1 block + warp tiling & cp.async

v2 is the most basic mma + cp.async implementation. Its entry point is below; the configuration is essentially the same as a CUDA-core sgemm, with each block computing one 128x128 tile of C. TB_SIZE is the threadblock size. The tile offsets throughout the computation can be roughly illustrated as follows:

(figure: tile offsets of the A, B, and C blocks and warps)

For block C10, the computation is A1 x B0. Since the arithmetic happens at warp level, each warp gets a part of the block's C10; with the code below a block contains 4 warps, so WARP(10) in the figure computes the bottom-left part of C10. When reading A1 and B0 from global memory, each round reads BM x BK and BN x BK tiles. This read is done by the whole block, so the row-by-row copy is spread across all the block's threads; row-major versus column-major (transposed B) only changes the stride. Once the read completes, each warp locates its tile in shared memory and computes on it.

void matmul_v2(const nv_bfloat16 *A, const nv_bfloat16 *B, nv_bfloat16 *C, int M, int N, int K) {
    assert(is_power_of_two(M) && "M must be a power of 2");
    assert(is_power_of_two(N) && "N must be a power of 2");
    assert(is_power_of_two(K) && "K must be a power of 2");

    // 4 warps
    const int BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64;
    const int NUM_WARP_M = 2, NUM_WARP_N = 2;
    const int SHM_STRIDE = BLOCK_K; // no padding
    const bool use_cp_async = true;
    const bool use_swizzle = false;

    auto kernel =
        matmul_v1_kernel<BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARP_M, NUM_WARP_N, SHM_STRIDE, use_cp_async, use_swizzle>;

    const int TB_SIZE = NUM_WARP_M * NUM_WARP_N * WARP_SIZE;
    const int grid_size = cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N);
    const int shm_size = (BLOCK_M + BLOCK_N) * SHM_STRIDE * sizeof(nv_bfloat16);

    launch_kernel(kernel, grid_size, TB_SIZE, shm_size, A, B, C, M, N, K);
}

With NUM_WARP_M x NUM_WARP_N warps per block, BLOCK_M must be divisible by NUM_WARP_M so the block tile can be split across warps, and each warp's WARP_M must in turn be divisible by MMA_M for mma to be usable. The basic constraint checks:

template <int BLOCK_M, int BLOCK_N, int BLOCK_K, int NUM_WARP_M, int NUM_WARP_N,
          int SHM_STRIDE, bool use_cp_async, bool use_swizzle> // template header implied by the launcher above
__global__
void matmul_v1_kernel(const nv_bfloat16 *A, const nv_bfloat16 *B, nv_bfloat16 *C, int M, int N, int K) {
    constexpr int MMA_M = 16;
    constexpr int MMA_N = 8;
    constexpr int MMA_K = 16;
    static_assert(BLOCK_M % NUM_WARP_M == 0);
    static_assert(BLOCK_N % NUM_WARP_N == 0);
    static_assert(BLOCK_K % MMA_K == 0);
    constexpr int WARP_M = BLOCK_M / NUM_WARP_M;
    constexpr int WARP_N = BLOCK_N / NUM_WARP_N;
    static_assert(WARP_M % MMA_M == 0);
    static_assert(WARP_N % MMA_N == 0);
    static_assert(use_cp_async || !use_swizzle); // use_swizzle=true requires use_cp_async=true
    constexpr int TB_SIZE = NUM_WARP_M * NUM_WARP_N * WARP_SIZE;

    // each warp will do (NUM_MMA_M * NUM_MMA_N) MMAs
    constexpr int NUM_MMA_M = WARP_M / MMA_M;
    constexpr int NUM_MMA_N = WARP_N / MMA_N;

Without optimizing for L2 hit rate, we simply let blocks compute C in row-major order; the basic offsets:

const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int warp_id = tid / WARP_SIZE;
const int lane_id = tid % WARP_SIZE;

// TODO: threadblock swizzling to improve L2 cache hit rate
const int num_blocks_n = cdiv(N, BLOCK_N);
const int bid_m = bid / num_blocks_n;
const int bid_n = bid % num_blocks_n;
const int offset_m = bid_m * BLOCK_M;
const int offset_n = bid_n * BLOCK_N;

const int warp_id_m = warp_id / NUM_WARP_N;
const int warp_id_n = warp_id % NUM_WARP_N;

// A is row-major, B is column-major, C is row-major
A += offset_m * K;
B += offset_n * K;
C += (offset_m + warp_id_m * WARP_M) * N + (offset_n + warp_id_n * WARP_N);

Here A and B only need the block offset, because loading A and B into shared memory is done cooperatively by the whole block, not split by warp, whereas computing C and writing it back is done per warp.

The shared-memory footprint is the same as in the sgemm: (BLOCK_M + BLOCK_N) * SHM_STRIDE * sizeof(nv_bfloat16) = (128 + 128) * 64 * 2 B = 32 KB.

extern __shared__ nv_bfloat16 shm[];
nv_bfloat16 *A_shared = shm; // BLOCK_M * BLOCK_K
nv_bfloat16 *B_shared = A_shared + (BLOCK_M * SHM_STRIDE); // BLOCK_N * BLOCK_K

The accumulation registers are FP32, so each thread needs MMA_M * MMA_N / WARP_SIZE of them. The loaded A and B are BF16/FP16 packed into 32-bit registers, hence the factor sizeof(nv_bfloat16) / 4 (two elements per register). And because K is the outer loop of the iteration while M and N are inner, the accumulators must cover all NUM_MMA_M x NUM_MMA_N tiles a warp computes. Plugging in the m16n8k16 shapes: num_acc_regs = 16*8/32 = 4, num_A_regs = 16*16*2/4/32 = 4, num_B_regs = 8*16*2/4/32 = 2.

// all registers are 32-bit (4-byte)
// - we accumulate to FP32, which is exactly 32-bit
// - our inputs are FP16/BF16, hence each register holds 2 elements
// - inputs and accumulators are distributed across 32 threads in a warp
// for m16n8k16, each thread holds
// - 4 output floats
// - 4 registers of input A (8 FP16/BF16 elements)
// - 2 registers of input B (4 FP16/BF16 elements)
constexpr int num_acc_regs = MMA_M * MMA_N / WARP_SIZE;
constexpr int num_A_regs = MMA_M * MMA_K * sizeof(nv_bfloat16) / 4 / WARP_SIZE;
constexpr int num_B_regs = MMA_N * MMA_K * sizeof(nv_bfloat16) / 4 / WARP_SIZE;
float acc[NUM_MMA_M][NUM_MMA_N][num_acc_regs] = {};

Reading data from global memory into shared memory resembles the sgemm, but uses cp.async:

for (int block_k = 0; block_k < K; block_k += BLOCK_K) {
    if constexpr (use_cp_async) {
        global_to_shared_async<TB_SIZE, BLOCK_M, BLOCK_K, SHM_STRIDE, use_swizzle>(A, K, A_shared, tid);
        global_to_shared_async<TB_SIZE, BLOCK_N, BLOCK_K, SHM_STRIDE, use_swizzle>(B, K, B_shared, tid);
        cp_async_wait_all();
    } else {
        global_to_shared<TB_SIZE, BLOCK_M, BLOCK_K, SHM_STRIDE>(A, K, A_shared, tid);
        global_to_shared<TB_SIZE, BLOCK_N, BLOCK_K, SHM_STRIDE>(B, K, B_shared, tid);
    }
    __syncthreads();

The copy implementation is below. Each cp.async can move 16 bytes, and the number of rounds follows from that: with BLOCK_M x BLOCK_K = 128x64 bf16 elements and TB_SIZE = 128 threads, each thread moves 8 elements per round, so num_iters = 128*64 / (128*8) = 8.

template <int TB_SIZE, int HEIGHT, int WIDTH, int OUT_STRIDE, bool use_swizzle>
__device__
void global_to_shared_async(const nv_bfloat16 *in, int in_stride, nv_bfloat16 *out, int tid) {
    constexpr int num_elems = 16 / sizeof(nv_bfloat16); // cp.async cp-size = 16

    // convert to shared state space outside of the loop
    // TODO: move this to kernel body
    uint32_t out_addr = cvta_shared(out); // .shared addresses are 32-bit, matching cp_async/swizzle below

    constexpr int num_iters = (HEIGHT * WIDTH) / (TB_SIZE * num_elems);
    for (int iter = 0; iter < num_iters; iter++) {
        const int idx = (iter * TB_SIZE + tid) * num_elems;
        const int row = idx / WIDTH;
        const int col = idx % WIDTH;

        uint32_t dst_addr = out_addr + (row * OUT_STRIDE + col) * sizeof(nv_bfloat16);
        if constexpr (use_swizzle)
            dst_addr = swizzle<OUT_STRIDE * sizeof(nv_bfloat16)>(dst_addr);
        cp_async(dst_addr, in + row * in_stride + col);
    }
}

The compute part first loads B, then streams in A tile by tile, using mma to finish the NUM_MMA_M x NUM_MMA_N tiles the warp is responsible for. Why load and hold B first? It is a register-count consideration: for a single m16n8k16 mma, B uses fewer registers than A, so caching B puts less pressure on the register file. But each warp computes NUM_MMA_M x NUM_MMA_N tiles, and with the current configuration:

  • BLOCK_M = BLOCK_N = 128
  • NUM_WARP_M = NUM_WARP_N = 2
  • each warp computes a 64 x 64 tile, so NUM_MMA_M = 64/16 = 4 and NUM_MMA_N = 64/8 = 8

the cached-B registers (NUM_MMA_N x num_B_regs = 8 x 2 = 16) happen to equal what caching A would take (NUM_MMA_M x num_A_regs = 4 x 4 = 16), so here A and B use the same number of registers and it makes no difference which one sits in the outer loop.

Before loading, the warp first locates the data it needs to process:

for (int mma_k = 0; mma_k < BLOCK_K; mma_k += MMA_K) {
    // for m16n8k8
    // <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-1688>
    //  A\B  [8x8-0]
    // [8x8-0]
    // [8x8-1]
    // where each [8x8] matrix can be loaded from shared memory with ldmatrix
    // <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix>

    // for m16n8k16
    // <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float>
    //        [8x8-0]
    //  A\B   [8x8-1]
    // [8x8-0][8x8-2]
    // [8x8-1][8x8-3]

    // select the tile this warp is responsible for
    const nv_bfloat16 *A_shm_warp = A_shared + (warp_id_m * WARP_M) * SHM_STRIDE + mma_k;
    const nv_bfloat16 *B_shm_warp = B_shared + (warp_id_n * WARP_N) * SHM_STRIDE + mma_k;

Notice that B is stored N x K in shared memory. This is because mma m16n8k16 requires row.col inputs, so we transpose the data on the way in; then ldmatrix needs no .trans and reads exactly what is required. Whether the source B is transposed (column-major) or row-major does not matter by itself; what matters is that it lands in shared memory as N x K. Computing A x B^T is identical to computing A times a column-major B, and row versus column major only changes the stride, so passing the natural stride to global_to_shared_async yields the N x K data.

When reasoning about this, first pin down the shape of the data we need, derive the strides from it, and only then account for the source layout. Here we need B as N x K, so the load's height and width are N and K; since the source is column-major K x N (physically an untransposed N x K array), the stride along N is K and along K is 1, and reading with those strides is correct.

uint32_t B_reg[NUM_MMA_N][num_B_regs]; // 32-bit, matching the ldmatrix_x2 wrapper
for (int mma_id_n = 0; mma_id_n < NUM_MMA_N; mma_id_n++) {
    // NOTE: we can reduce unnecessary address calculation if we know MMA_K=8 or 16
    // convert generic address to .shared state space address expected by inline PTX
    const nv_bfloat16 *B_ptr = B_shm_warp + (mma_id_n * MMA_N + (lane_id % 8)) * SHM_STRIDE + (lane_id / 8) * 8;
    uint32_t B_addr = cvta_shared(B_ptr);
    if constexpr (use_swizzle)
        B_addr = swizzle<SHM_STRIDE * sizeof(nv_bfloat16)>(B_addr);
    ldmatrix_x2(B_reg[mma_id_n], B_addr);
}

Loading A is similar. Once loaded, we can compute; after the computation finishes, the A and B pointers are advanced:

    for (int mma_id_m = 0; mma_id_m < NUM_MMA_M; mma_id_m++) {
        // load A to registers
        uint32_t A_reg[num_A_regs]; // 32-bit, matching the ldmatrix_x4 wrapper
        const nv_bfloat16 *A_ptr = A_shm_warp + (mma_id_m * MMA_M + (lane_id % 16)) * SHM_STRIDE + (lane_id / 16) * 8;
        uint32_t A_addr = cvta_shared(A_ptr);
        if constexpr (use_swizzle)
            A_addr = swizzle<SHM_STRIDE * sizeof(nv_bfloat16)>(A_addr);
        ldmatrix_x4(A_reg, A_addr);

        // call mma
        for (int mma_id_n = 0; mma_id_n < NUM_MMA_N; mma_id_n++)
            mma_m16n8k16(A_reg, B_reg[mma_id_n], acc[mma_id_m][mma_id_n]);
    }
    } // closes the mma_k loop
    __syncthreads();

    A += BLOCK_K;
    B += BLOCK_K;
} // closes the block_k loop

Finally, the writeback. After the computation, each thread's slice of the 16x8 C tile mirrors what ldmatrix reads: thread 0 holds (0,0) (0,1) of the first 8x8 matrix and (0,0) (0,1) of the second; thread 1 holds (0,2) (0,3) of the first and (0,2) (0,3) of the second. Each thread locates its positions, converts back to bfloat16, and writes out.

for (int mma_id_m = 0; mma_id_m < NUM_MMA_M; mma_id_m++)
    for (int mma_id_n = 0; mma_id_n < NUM_MMA_N; mma_id_n++) {
        const int row = mma_id_m * MMA_M + (lane_id / 4);
        const int col = mma_id_n * MMA_N + (lane_id % 4) * 2;
        nv_bfloat16 *C_local = C + row * N + col;

        float *regs = acc[mma_id_m][mma_id_n];
        reinterpret_cast<nv_bfloat162 *>(C_local)[0] = __float22bfloat162_rn({regs[0], regs[1]});         // c0 and c1
        reinterpret_cast<nv_bfloat162 *>(C_local + 8 * N)[0] = __float22bfloat162_rn({regs[2], regs[3]}); // c2 and c3
    }

The per-thread C data looks like the figure below, the same distribution ldmatrix would produce. Following this layout, each thread writes back its 4 values.

(figure: per-lane C fragment layout used by the writeback)

2.2 swizzle shared memory

In the loads above, reads from shared memory clearly cause bank conflicts: multiple threads read elements across multiple rows, and the row stride is a multiple of 32 banks. In ldmatrix each thread reads one row (2x8 = 16 B), effectively a float4 vectorized load, so conflicts are analyzed per quarter-warp; in other words, whether it is ldmatrix.x2 or .x4, we only ever need to consider a single 8x8 tile read. Without swizzling the read is an 8-way conflict, and we need to swizzle the column-block accesses into the pattern on the right:

(figure: bank access patterns without and with swizzling)
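
A tiny host-side check (mine, not from the repo) makes the 8-way conflict concrete: with SHM_STRIDE = 64 bf16 = 128 bytes, all eight row-start addresses of an 8x8 tile fall into the same bank:

#include <cstdio>

int main() {
    const int stride_bytes = 64 * 2; // BLOCK_K = 64 bf16 elements per row
    for (int row = 0; row < 8; row++) {
        int addr = row * stride_bytes;                       // byte address of the row start
        printf("row %d -> bank %d\n", row, (addr / 4) % 32); // prints bank 0 for every row
    }
}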

The swizzle changes the starting column of each write so that subsequent reads avoid bank conflicts. Applying the same swizzle on write and read keeps the mapping correct, because an XOR swizzle is fully invertible. How far the banks must be shifted also depends on the stride: if the stride is 64 B, one 32-bank row holds two data rows, so the shift only advances by one every two rows. Hence the logic below:

template <int STRIDE> // in bytes
__device__
uint32_t swizzle(uint32_t index) {
    // no swizzling needed
    if constexpr (STRIDE == 16)
        return index;
    // only accesses within a group of 8 rows matter, hence the % 8
    uint32_t row_idx = (index / STRIDE) % 8;
    // how many 4-bank (16 B) chunks to shift by, depending on the stride
    uint32_t bits_to_xor = row_idx / std::max(128 / STRIDE, 1);
    // shift by multiples of 4 banks
    return index ^ (bits_to_xor << 4);
}
| STRIDE (B) | natural bank sequence  | extra bank shift needed | bits_to_xor sequence |
|:-----------|:-----------------------|:------------------------|:---------------------|
| 128        | {0,0,0,0,0,0,0,0}      | +{0,4,8,12,16,20,24,28} | {0,1,2,3,4,5,6,7}    |
| 64         | {0,16,0,16,0,16,0,16}  | +{0,0,4,4,8,8,12,12}    | {0,0,1,1,2,2,3,3}    |
| 32         | {0,8,16,24,0,8,16,24}  | +{0,0,0,0,4,4,4,4}      | {0,0,0,0,1,1,1,1}    |
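
A quick host-side check (mine; it reuses the same swizzle logic, minus the __device__ qualifier) confirms that for STRIDE = 128 the eight row starts now land on eight distinct 4-bank groups:

#include <algorithm>
#include <cstdint>
#include <cstdio>

template <int STRIDE>
uint32_t swizzle(uint32_t index) {
    if constexpr (STRIDE == 16) return index;
    uint32_t row_idx = (index / STRIDE) % 8;
    uint32_t bits_to_xor = row_idx / std::max(128 / STRIDE, 1);
    return index ^ (bits_to_xor << 4);
}

int main() {
    for (uint32_t row = 0; row < 8; row++) {
        uint32_t addr = swizzle<128>(row * 128);
        printf("row %u -> bank %u\n", row, (addr / 4) % 32); // banks 0,4,8,...,28
    }
}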

2.3 thread block swizzle

This is the L2 optimization seen in the triton GEMM tutorial: instead of consecutive blocks computing one row of tiles, they compute tiles across several rows at once, which improves reuse of the B matrix in L2. It only requires swizzling bid_m and bid_n; no other code changes.

if constexpr (GROUP_M == 0) {
    // no swizzling
    bid_m = bid / grid_n;
    bid_n = bid % grid_n;
} else {
    // threadblock swizzling to improve L2 cache hit rate
    // <https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html>
    // each group is [GROUP_M, grid_n], tiled from top (small M) to bottom (large M).
    // the last group might be shorter than GROUP_M if grid_m % GROUP_M != 0.
    const int group_size = GROUP_M * grid_n;
    const int group_id = bid / group_size;
    const int group_off_m = group_id * GROUP_M;
    const int group_m = std::min(grid_m - group_off_m, GROUP_M); // actual group height

    bid_m = group_off_m + ((bid % group_size) % group_m);
    bid_n = (bid % group_size) / group_m;
}
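
To see what this does to the launch order, here is a host-side sketch (mine; grid_m = grid_n = 4 and GROUP_M = 2 are made-up toy values) that prints the (bid_m, bid_n) each linear bid maps to:

#include <algorithm>
#include <cstdio>

int main() {
    const int grid_m = 4, grid_n = 4, GROUP_M = 2;
    for (int bid = 0; bid < grid_m * grid_n; bid++) {
        // same arithmetic as the kernel code above
        const int group_size = GROUP_M * grid_n;
        const int group_id = bid / group_size;
        const int group_off_m = group_id * GROUP_M;
        const int group_m = std::min(grid_m - group_off_m, GROUP_M);
        const int bid_m = group_off_m + ((bid % group_size) % group_m);
        const int bid_n = (bid % group_size) / group_m;
        // consecutive bids now walk down a 2-row group before moving right
        printf("bid %2d -> (%d, %d)\n", bid, bid_m, bid_n);
    }
}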

2.4 pipeline

In the earlier versions, the time spent copying from global memory to shared memory is pure waiting and never truly asynchronous, because the following computation needs the data being moved. To overlap this latency we need extra shared-memory buffers, with each iteration fetching the data a future iteration will need. With NUM_STAGES = 2, for example, the prologue issues load_AB(0); iteration 0 then issues load_AB(1), waits until at most NUM_STAGES - 1 = 1 group is still in flight (i.e. stage 0 has landed), and computes on stage 0 while stage 1 streams in.

auto load_AB = [&](int k_iter) {
    if (k_iter < num_k_iters) {
        // select the correct shared memory buffer
        const int stage_id = k_iter % NUM_STAGES;
        global_to_shared_async<TB_SIZE, BLOCK_M, BLOCK_K>(A, K, A_shm + stage_id * AB_size, tid);
        global_to_shared_async<TB_SIZE, BLOCK_N, BLOCK_K>(B, K, B_shm + stage_id * AB_size, tid);

        // A/B pointer tracks position for global->shared load
        A += BLOCK_K;
        B += BLOCK_K;
    }
    cp_async_commit_group();
};

// initiate NUM_STAGES-1 stages
for (int stage = 0; stage < NUM_STAGES - 1; stage++)
    load_AB(stage);

// loop invariant: there are always NUM_STAGES - 1 prefetch stages in-flight
// thanks to pipelining, this loop now only has 1 __syncthreads()
for (int k_iter = 0; k_iter < num_k_iters; k_iter++) {
    // wait for previous MMA to finish using the shared buffer
    __syncthreads();

    // prefetch the next stage. add 1 more stage to the pipeline
    load_AB(k_iter + NUM_STAGES - 1);

    // wait for the 1st stage to finish. remove 1 stage from the pipeline
    // -> restore loop invariant
    cp_async_wait_group<NUM_STAGES - 1>();
    __syncthreads();

    // A shared->regs...

    // B shared->regs...

    // do MMA. NUM_STAGES-1 prefetch stages are still on-going
}