本篇是对SM120架构下Tensor Core GEMM实现的源码解读笔记。代码来源于仓库:

https://github.com/gau-nernst/learn-cuda

在上一篇中,我们已经看过了SM80的Tensor Core GEMM,通过基本的优化方法和对cp.async,在5060上实现了大约33TFLOPs的性能,大约是峰值的90%,距离cublas以及峰值还有一些差距,这些差距来自于没有利用TMA来进行数据搬运。

TMA是Hopper架构新增加的拷贝DMA引擎,目的是解决大块显存拷贝中的多次地址计算等操作。TMA支持大块bulk的异步显存拷贝,可以减少拷贝指令数量,并支持多维度的显存块拷贝,sharedMemoryGlobalMemory的异步数据拷贝。

对于SM120,没有SM100的TMEM以及tcgen05,2-SM-MMA等,只要充分利用SM90-TMA就可以达到峰值性能,因此在利用了TMA后,gemm实现可以实现峰值性能。本文是对代码仓库02c-sm120的解读。

仍然是从指令开始,逐步介绍gemm实现,一些之前sm80代码实现已经包含的优化不再单独介绍。

1 指令

ldmatrix等指令略过。

1.1 elect.sync

elect.sync是用于从一个WARP中选出一个线程作为领袖线程的指令。因为TMA指令的发射是单线程完成的。单条指令设置的是谓词寄存器,指令原型如下:

1
2
//source_mask表示参加选举的线程,通常是0XFFFFFFFF
elect.sync dest_mask | dest_pred, source_mask;

配合一个条件赋值,来让一个线程拿到pred=1。

1
2
3
4
5
6
7
8
9
10
11
asm volatile(
"{\\n" // 1. 开始一个局部作用域(类似 C 语言的 {})
".reg .pred P;\\n" // 2. 声明一个临时的“谓词寄存器” P(布尔类型,只能是 0 或 1)
"elect.sync _|P, %1;\\n" // 3. 核心指令:选举
"@P mov.s32 %0, 1;\\n" // 4. 条件执行:如果我是选中的那个人,就把输出设为 1
"}" // 5. 结束作用域
: "+r"(pred) // 输出操作数 %0
: "r"(0xFFFF'FFFF) // 输入操作数 %1 (参与选举的线程掩码,全 1 表示全员参与)
);
//@P表示条件执行
//+r表示可读写

1.2 mbarrier

mbarrier是在SM80就引入的指令,但在SM90中开始扩展加强并配合TMA广泛使用。接下来是相关的指令:

mbarrier.init是TMA搬运的初始化指令,这个指令会设置初始计数和barrier的初始状态。这个初始化只由领袖线程调用。

1
2
3
4
5
6
7
8
9
mbarrier.init{.shared::cta}.b64  [addr], count;
///.b64 mbarrier在数据表示上表达为一个64bit的整数类型数据
///count表示参与同步的线程数量

//指令wrapper:
__device__ inline
void mbarrier_init(int addr, int count) {
asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(addr), "r"(count));
}

mbarrier.init初始化后,mbarrier.arrive表示达到的动作,相应的指令原型为:

1
2
3
4
5
6
7
8
9
10
mbarrier.arrive.release.cta.shared::cta.b64  _, [addr];
//.release arrive之前的写入可见
//_ 当前屏障的句柄,简单情况下不需要,所以丢弃
//: "memory"表示该条指令修改内存,确保编译器保证读写顺序

//指令wrapper:
__device__ inline
void mbarrier_arrive(int addr) {
asm volatile("mbarrier.arrive.release.cta.shared::cta.b64 _, [%0];" :: "r"(addr) : "memory");
}

这条指令会更新屏障的计数器,如果计数器归0,则屏障已完成,停止mbarrier等待的线程。

SM90引入的加强指令就是基于arrive添加一个增量的拷贝数量,原有的arrive只统计线程数量,而mbarrier.arrive.expect_tx会增加一个字节计数器值,当TMA搬运了指定字节数量的数据,barrier才会到达。

1
2
3
4
5
6
7
mbarrier.arrive.expect_tx.release.cta.shared::cta.b64  _, [addr], tx_count;

//指令wrapper:
__device__ inline
void mbarrier_arrive_expect_tx(int addr, int size) {
asm volatile("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 _, [%0], %1;" :: "r"(addr), "r"(size) : "memory");
}

上述指令配合的一个典型工作流是:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// 1. 选出领袖线程
if (elect_sync()) {
// 初始化屏障,只等 1 个领袖线程打卡
mbarrier_init(bar_ptr, 1);

// 还要等 2048 字节的 TMA 数据
mbarrier_arrive_expect_tx(bar_ptr, 2048);

// 发起 TMA 异步拷贝,关联该屏障
tma_copy_2d(dst, src, ..., bar_ptr);

// 领袖线程到达
mbarrier_arrive(bar_ptr);
}

// 2. 所有线程(包括非领袖)在这里等待
mbarrier_wait(bar_ptr, phase);

而mbarrier的等待如下所示,是一个轮询的过程,并可以通过phase控制相位。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
__device__ inline
void mbarrier_wait(int mbar_addr, int phase) {
int ticks = 0x989680; // this is optional
asm volatile(
"{\\n\\t"
".reg .pred P1;\\n\\t"
"LAB_WAIT:\\n\\t"
"mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1, %2;\\n\\t"
"@!P1 bra.uni LAB_WAIT;\\n\\t"
"}"
:: "r"(mbar_addr), "r"(phase), "r"(ticks)
);
}
//bra.uni (Branch Uniform): 这是一个优化过的跳转指令。

1.3 TMA

TMA指令的原型是cp.async.bulk.tensor,支持多维的数据拷贝,2d数据块的拷贝原型:

1
cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];

.mbarrier::complete_tx::bytes表示搬运完数据后,去修改[%4]的mbarrier状态,只有当所有字节(Transactions)都确认写入共享内存后,才算完成。其他的操作数映射:

  • [%0] dst: 共享内存的目标起始地址。
  • %1,指向 Tensor Map 的 64 位指针。
  • %2,%3: 张量的 2D 坐标 (x, y)。注意,这里是坐标而非偏移量。
  • [%4] (mbar): mbarrier 对象的地址。

该条指令会根据Tensor Map描述的逻辑去进行地址映射,硬件自动计算真正的物理地址,做数据搬运,这也是相较于SM80的cp.async的重要的优化。

wrapper:

1
2
3
4
5
6
7
__device__ inline
void tma_2d_g2s(int dst, const void *tmap_ptr, int x, int y, int mbar_addr) {
asm volatile("cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes "
"[%0], [%1, {%2, %3}], [%4];"
:: "r"(dst), "l"(tmap_ptr), "r"(x), "r"(y), "r"(mbar_addr)
: "memory");
}

2 matmul with TMA

1
2
3
4
5
M=4096, N=4096, K=4096
CuBLAS: 3.9797 ms 34.53 TFLOPS 95.58% SOL
Inductor Triton: 3.8247 ms 35.93 TFLOPS 99.46% SOL
v0: 3.8688 ms 35.53 TFLOPS 98.33% SOL cp.async
v1: 3.8039 ms 36.13 TFLOPS 100.00% SOL TMA

2.1 init tensor map

CutensorMap是一个描述符,描述tensor的size,步长等信息,以及搬运到sharedMemory的块和步长信息。在计算开始前,需要初始化CutensorMap。

  • rank:n Dimension张量,对于2D tensor,值为2
  • size:张量大小,最快变化维度到最慢变化维度,因此对于行主序数据,是width height
  • stride:步长,第一个维度默认是连续的,所以stride数组有rank-1个值
  • box_size:搬运到shared Memory的块大小
  • elem_stride:跳跃元素数,需要读取连续的数据时,都是1
  • CUtensorMapSwizzle:swizzle模式,和smem_width的字节数是对应的
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
static void init_tensor_map(
CUtensorMap *tmap_ptr,
const nv_bfloat16 *gmem_ptr,
uint64_t gmem_height, uint64_t gmem_width,
uint32_t smem_height, uint32_t smem_width
) {
constexpr uint32_t rank = 2;
uint64_t size[rank] = {gmem_width, gmem_height};
uint64_t stride[rank - 1] = {gmem_width * sizeof(nv_bfloat16)}; // in bytes
uint32_t box_size[rank] = {smem_width, smem_height};
uint32_t elem_stride[rank] = {1, 1};

CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
if (smem_width == 16)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B;
else if (smem_width == 32)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B;
else if (smem_width == 64)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B;

auto res = cuTensorMapEncodeTiled(
tmap_ptr,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
rank,
(void *)gmem_ptr,
size,
stride,
box_size,
elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,// 不使用交错存储
swizzle,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,// 是否通过 L2 预取提升优先级
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE //越界填充,NONE 表示由硬件自动补零
);
if (res != CUDA_SUCCESS) {
const char *error_msg_ptr;
if (cuGetErrorString(res, &error_msg_ptr) != CUDA_SUCCESS)
error_msg_ptr = "unable to get error string";
std::cerr << "cuTensorMapEncodeTiled error: " << error_msg_ptr << std::endl;
}
};

2.2 entrance

核函数的入口处进行init_tensor_map,传入初始化的A_tmap和B_tmap,并且需要为mbarrier的使用预留shared memory的空间,每个mbarrier占用8字节空间,需要nstage个。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
void matmul_v1(const nv_bfloat16 *A, const nv_bfloat16 *B, nv_bfloat16 *C, int M, int N, int K) {
assert(is_power_of_two(M) && "M must be a power of 2");
assert(is_power_of_two(N) && "N must be a power of 2");
assert(is_power_of_two(K) && "K must be a power of 2");

// tuned for 5090
const int BLOCK_M = 128, BLOCK_N = 64, BLOCK_K = 64;
const int NUM_WARP_M = 2, NUM_WARP_N = 2;
const int NUM_STAGES = 2;

// tuned for PRO 6000
// const int BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 32;
// const int NUM_WARP_M = 2, NUM_WARP_N = 2;
// const int NUM_STAGES = 3;

auto kernel = matmul_v1_kernel<BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARP_M, NUM_WARP_N, NUM_STAGES>;

const int TB_SIZE = NUM_WARP_M * NUM_WARP_N * WARP_SIZE;
const int grid_size = cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N);
const int smem_size = (BLOCK_M + BLOCK_N) * BLOCK_K * sizeof(nv_bfloat16) * NUM_STAGES
+ NUM_STAGES * 8; // mbar

CUtensorMap A_tmap, B_tmap;
init_tensor_map(&A_tmap, A, M, K, BLOCK_M, BLOCK_K);
init_tensor_map(&B_tmap, B, N, K, BLOCK_N, BLOCK_K);

launch_kernel(kernel, grid_size, TB_SIZE, smem_size, A_tmap, B_tmap, C, M, N, K);
}

2.3 kernel

核函数的开头仍然是对形状和大小的检查,确保数据尺寸是可处理的,和SM80是完全一样的,block swizzle也在这里:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
template <int BLOCK_M, int BLOCK_N, int BLOCK_K, int NUM_WARP_M, int NUM_WARP_N, int NUM_STAGES>
__launch_bounds__(NUM_WARP_M * NUM_WARP_N * WARP_SIZE) // maxThreadsPerBlock
__global__
void matmul_v1_kernel(
const __grid_constant__ CUtensorMap A_tmap,
const __grid_constant__ CUtensorMap B_tmap,
nv_bfloat16 *C,
int M, int N, int K
) {
constexpr int WARP_M = BLOCK_M / NUM_WARP_M;
constexpr int WARP_N = BLOCK_N / NUM_WARP_N;

static_assert(BLOCK_M % NUM_WARP_M == 0);
static_assert(BLOCK_N % NUM_WARP_N == 0);
static_assert(WARP_M % MMA_M == 0);
static_assert(WARP_N % MMA_N == 0);

const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int warp_id = warp_uniform(tid / WARP_SIZE);
const int lane_id = tid % WARP_SIZE;

const int warp_id_m = warp_id / NUM_WARP_N;
const int warp_id_n = warp_id % NUM_WARP_N;

// threadblock swizzling to improve L2 cache hit rate
// <https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html>
constexpr int GROUP_M = 8;
const int grid_m = cdiv(M, BLOCK_M);
const int grid_n = cdiv(N, BLOCK_N);

// each group is [GROUP_M, grid_n], tile from top (small M) to bottom (large M).
// the last group might be shorter than GROUP_M if grid_m % GROUP_M != 0.
const int group_size = GROUP_M * grid_n;
const int group_id = bid / group_size;
const int group_off_m = group_id * GROUP_M;
const int group_m = std::min(grid_m - group_off_m, GROUP_M); // actual group height

const int bid_m = warp_uniform(group_off_m + ((bid % group_size) % group_m));
const int bid_n = warp_uniform((bid % group_size) / group_m);

const int off_m = bid_m * BLOCK_M;
const int off_n = bid_n * BLOCK_N;

constexpr int A_size = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16);
constexpr int B_size = BLOCK_N * BLOCK_K * sizeof(nv_bfloat16);
constexpr int AB_size = A_size + B_size;

然后是sharememory的地址偏移处理,之前提到的PTX指令都接收的是u32的指针,所以这里要通过__cvta_generic_to_shared对地址进行转换,转换后区分global/shared等位将被移除,只剩偏移量。

1
2
3
4
5
6
// set up smem
extern __shared__ __align__(1024) char smem[];
const int smem_addr = static_cast<int>(__cvta_generic_to_shared(smem));
const int A_smem = smem_addr;
const int B_smem = A_smem + A_size;
const int mbar_addr = smem_addr + NUM_STAGES * AB_size;

mbarrier的初始化后要插入一个fence,因为SM内部会分为执行代理和异步代理,执行代理执行的mbarrier_init必须对于异步代理是可见的:

1
2
3
4
5
6
7
if (warp_id == 0 && elect_sync()) {
for (int i = 0; i < NUM_STAGES; i++)
mbarrier_init(mbar_addr + i * 8, 1);
asm volatile("fence.mbarrier_init.release.cluster;"); // visible to async proxy
}
// make mbarrier visible
__syncthreads();

寄存器声明:

1
2
3
4
// set up rmem
int A_rmem[WARP_M / MMA_M][BLOCK_K / MMA_K][4];
int B_rmem[WARP_N / MMA_N][BLOCK_K / MMA_K][2];
float acc[WARP_M / MMA_M][WARP_N / MMA_N][4] = {};

为后续计算读取准备的Asmem和Bsmem swizzle基址,这里可以回看SM80时绘制的图,对于WARP10来说,要处理的是A1的在sharedMem中下半的部分,B0在sharedMem上半的部分,而基址就是这两块的起点,后续线程只要更新K维度上的偏移。

1
2
3
// pre-compute address and swizzling used for ldmatrix
const int A_smem_thread = A_smem + swizzle<BLOCK_K * sizeof(nv_bfloat16)>(warp_id_m * WARP_M + (lane_id % 16), lane_id / 16);
const int B_smem_thread = B_smem + swizzle<BLOCK_K * sizeof(nv_bfloat16)>(warp_id_n * WARP_N + (lane_id % 8), lane_id / 8);

image-20260302173446925

读取函数这里就是mbarrier重点优化的地方,在SM80实现中,我们需要计算每个线程要搬运的数据的起始地址,然后通过cp.async共同启动搬运,并且需要多轮次搬运,而在使用TMA后,这里只需要warp0的一个线程启动TMA数据搬运。

1
2
3
4
5
6
7
8
9
auto load_AB = [&](int iter_k, int stage_id) {
if (warp_id == 0 && elect_sync()) {
const int this_mbar_addr = mbar_addr + stage_id * 8;
const int off_k = iter_k * BLOCK_K;
tma_2d_g2s(A_smem + stage_id * AB_size, &A_tmap, off_k, off_m, this_mbar_addr);
tma_2d_g2s(B_smem + stage_id * AB_size, &B_tmap, off_k, off_n, this_mbar_addr);
mbarrier_arrive_expect_tx(this_mbar_addr, AB_size);
}
};

计算部分则没有什么变化,仍然是从sharedMem中读取数据到寄存器,然后m16n8k16进行计算。但是这里我们不用ldmatrix_x2来读取数据,而是使用ldmatrix_x4来一次读取两个B tile,减少指令数量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
auto compute = [&](int stage_id) {
// A smem->rmem
for (int m = 0; m < WARP_M / MMA_M; m++)
for (int k = 0; k < BLOCK_K / MMA_K; k++) {
int addr = A_smem_thread + stage_id * AB_size;
addr += m * MMA_M * BLOCK_K * sizeof(nv_bfloat16);
ldmatrix_x4(A_rmem[m][k], addr ^ (k * 32));
}

// B smem->rmem
for (int n = 0; n < WARP_N / MMA_N; n++)
for (int k = 0; k < BLOCK_K / MMA_K; k += 2) {
int addr = B_smem_thread + stage_id * AB_size;
addr += n * MMA_N * BLOCK_K * sizeof(nv_bfloat16);
ldmatrix_x4(B_rmem[n][k], addr ^ (k * 32));
}

// MMA
for (int m = 0; m < WARP_M / MMA_M; m++)
for (int n = 0; n < WARP_N / MMA_N; n++)
for (int k = 0; k < BLOCK_K / MMA_K; k++)
mma_m16n8k16(A_rmem[m][k], B_rmem[n][k], acc[m][n]);
};

multi-stage流水线计算,对于数据搬运的同步,只需要warp0进行mbarrier,其他warp通过__syncthreads()进行等待,避免大量线程轮询,每次nstage个搬运结束后,都要翻转相位。在开始的时候需要先进行num_stage-1次预取,确保内存总线总是处于满载状态。虽然大多数时候num_stage是2,只有一次预取,但从pipeline设计的角度来说是要进行num_stage-1次预取,而不是1次预取就进入流水线。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
 const int num_k_iters = cdiv(K, BLOCK_K);

// initiate NUM_STAGES-1 TMA stages
for (int stage = 0; stage < NUM_STAGES - 1; stage++)
load_AB(stage, stage);

int stage = 0;
int phase = 0;

for (int iter_k = 0; iter_k < num_k_iters - (NUM_STAGES - 1); iter_k++) {
// wait for previous MMA
__syncthreads();

// issue TMA
const int prefetch_iter_k = iter_k + NUM_STAGES - 1;
load_AB(prefetch_iter_k, prefetch_iter_k % NUM_STAGES);

// wait for TMA
if (warp_id == 0)
mbarrier_wait(mbar_addr + stage * 8, phase);
__syncthreads();

// issue MMA
compute(stage);

// update stage and phase
stage = (stage + 1) % NUM_STAGES;
if (stage == 0)
phase ^= 1;
}

//last tile
for (int iter_k = num_k_iters - (NUM_STAGES - 1); iter_k < num_k_iters; iter_k++) {
if (warp_id == 0)
mbarrier_wait(mbar_addr + stage * 8, phase);
__syncthreads();

compute(stage);
stage = (stage + 1) % NUM_STAGES;
if (stage == 0)
phase ^= 1;
}

最后是数据的写回,与SM80的实现一致。

1
2
3
4
5
6
7
8
9
10
C += (off_m + warp_id_m * WARP_M) * N + (off_n + warp_id_n * WARP_N);
for (int m = 0; m < WARP_M / MMA_M; m++)
for (int n = 0; n < WARP_N / MMA_N; n++) {
const int row = m * MMA_M + (lane_id / 4); //4个线程持有一行的数据
const int col = n * MMA_N + (lane_id % 4) * 2; //每个线程在一行上有2个16B数据

float *regs = acc[m][n];
reinterpret_cast<nv_bfloat162 *>(C + (row + 0) * N + col)[0] = __float22bfloat162_rn({regs[0], regs[1]});
reinterpret_cast<nv_bfloat162 *>(C + (row + 8) * N + col)[0] = __float22bfloat162_rn({regs[2], regs[3]});
}