1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute one (BLOCK_M, BLOCK_N) tile of C = A @ B.

    A is addressed as (m, k) through ``stride_am``/``stride_ak``, B as (k, n)
    through ``stride_bk``/``stride_bn`` and C as (m, n) through
    ``stride_cm``/``stride_cn``. Partial products are accumulated in float32
    and written to C with a single store at the end.

    Program ids on axis 0 are mapped to output tiles in row-major order:
    ``pid // cdiv(N, BLOCK_N)`` selects the tile row and
    ``pid % cdiv(N, BLOCK_N)`` the tile column.
    NOTE(review): this presumes the launch grid has
    ``cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)`` programs on axis 0 -- confirm
    against the host-side launcher.

    All loads and the final store are bounds-masked, so M, N and K do not
    need to be multiples of the block sizes.
    """
    pid = tl.program_id(axis=0)
    grid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Edge tiles may extend past the matrix; these masks keep every memory
    # access inside the M x N output extent.
    row_in_bounds = offs_m[:, None] < M
    col_in_bounds = offs_n[None, :] < N

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k_start in range(0, K, BLOCK_K):
        k_idx = k_start + offs_k
        a_ptrs = a_ptr + (offs_m[:, None] * stride_am + k_idx[None, :] * stride_ak)
        b_ptrs = b_ptr + (k_idx[:, None] * stride_bk + offs_n[None, :] * stride_bn)

        # Out-of-range lanes are filled with 0.0 so they contribute nothing
        # to the dot product when K is not a multiple of BLOCK_K.
        a = tl.load(a_ptrs, mask=row_in_bounds & (k_idx[None, :] < K), other=0.0)
        b = tl.load(b_ptrs, mask=(k_idx[:, None] < K) & col_in_bounds, other=0.0)

        acc += tl.dot(a, b)

    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    tl.store(c_ptrs, acc, mask=row_in_bounds & col_in_bounds)
|