RDNA3 fp16 assembly gemm 85 TFLOPS (#13990)

This commit is contained in:
qazal
2026-01-03 18:34:23 +09:00
committed by GitHub
parent 6242a9d151
commit bd55507ee4
4 changed files with 3128 additions and 4 deletions

View File

@@ -140,11 +140,11 @@ def hand_spec_kernel3():
return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify()
def test_matmul(sink:UOp, N=N):
def test_matmul(sink:UOp, dtype=dtypes.float32, N=N):
rng = np.random.default_rng()
a = Tensor(rng.random((N, N), dtype=np.float32)-0.5)
b = Tensor(rng.random((N, N), dtype=np.float32)-0.5)
hc = Tensor.empty(N, N)
a = Tensor(rng.random((N, N), dtype=np.float32)-0.5, dtype=dtype)
b = Tensor(rng.random((N, N), dtype=np.float32)-0.5, dtype=dtype)
hc = Tensor.empty(N, N, dtype=dtype)
Tensor.realize(a, b, hc)
ei = ExecItem(sink, [t.uop.buffer for t in [hc, a, b]], prg=get_runner(Device.DEFAULT, sink))

3018
extra/gemm/asm/rdna3/gemm.s Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,76 @@
# AMD HSA code-object template for the RDNA3 (gfx11) fp16 GEMM kernel.
# The bare token "INSTRUCTIONS" below is a placeholder: matmul.py substitutes
# the contents of gemm.s for it before this file is assembled.
.text
.section .text.
.global gemm
.p2align 8
.type gemm,@function
gemm:
INSTRUCTIONS
# Kernel descriptor lives in .rodata; the loader resolves it as gemm.kd.
.section .rodata,"a",@progbits
.p2align 6, 0x0
.amdhsa_kernel gemm
# basic memory requirements
# 30336 bytes of group-segment (LDS) memory per workgroup, no scratch.
.amdhsa_group_segment_fixed_size 30336
.amdhsa_private_segment_fixed_size 0
# kernarg payload is three 8-byte buffer pointers (C, A, B) at offsets
# 0/8/16 = 24 bytes; 32 presumably rounds up for alignment — TODO confirm.
.amdhsa_kernarg_size 32
# register usage (RSRC1)
.amdhsa_next_free_vgpr 256
.amdhsa_next_free_sgpr 100
# workgroup / workitem IDs (RSRC2)
.amdhsa_system_sgpr_workgroup_id_x 1
.amdhsa_system_sgpr_workgroup_id_y 1
.amdhsa_system_sgpr_workgroup_id_z 1
# user SGPRs: kernarg ptr in s[0:1]
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_user_sgpr_count 2
# gfx10+ / gfx11 specifics (RSRC1[29..31])
# wave32 execution, WGP mode, memory-ordered, forward-progress guarantees.
.amdhsa_wavefront_size32 1
.amdhsa_workgroup_processor_mode 1
.amdhsa_memory_ordered 1
.amdhsa_forward_progress 1
# misc for gfx11
.amdhsa_dx10_clamp 1
.amdhsa_ieee_mode 1
.amdhsa_uses_dynamic_stack 0
.end_amdhsa_kernel
# HSA runtime metadata (YAML, msgpack-encoded by the assembler) describing
# the kernel arguments and resource usage.
# NOTE(review): the YAML indentation looks flattened by the diff view —
# verify nesting in the checked-in file.
.amdgpu_metadata
---
amdhsa.kernels:
- .args:
- .address_space: generic
.name: C
.offset: 0
.size: 8
.value_kind: global_buffer
.value_type: f16
- .address_space: generic
.name: A
.offset: 8
.size: 8
.value_kind: global_buffer
.value_type: f16
- .address_space: generic
.name: B
.offset: 16
.size: 8
.value_kind: global_buffer
.value_type: f16
.group_segment_fixed_size: 30336
.kernarg_segment_align: 8
.kernarg_segment_size: 32
.max_flat_workgroup_size: 128
.name: gemm
.private_segment_fixed_size: 0
# NOTE(review): .sgpr_count 70 here vs .amdhsa_next_free_sgpr 100 in the
# descriptor above — confirm which reflects actual register usage.
.sgpr_count: 70
.sgpr_spill_count: 0
.symbol: gemm.kd
.vgpr_count: 256
.vgpr_spill_count: 0
.wavefront_size: 32
amdhsa.version:
- 1
- 1
...
.end_amdgpu_metadata

View File

@@ -0,0 +1,30 @@
import math, pathlib
from tinygrad import Device, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from extra.gemm.amd_uop_matmul import test_matmul
# Problem size: C(N,N) = A(N,N) @ B(N,N) in fp16.
N = 4096
# TN: presumably the edge of the output tile each workgroup computes —
# NUM_WG below launches one workgroup per TN x TN tile of C.
TN = 96
THREADS_PER_WG = 128
# ceil(N/TN)^2 tiles cover the whole output matrix.
NUM_WG = math.ceil(N / TN) * math.ceil(N / TN)
dname:str = Device.DEFAULT
# Assembly template containing an "INSTRUCTIONS" placeholder; asm_kernel()
# splices gemm.s into it.
template:str = (pathlib.Path(__file__).parent/"template.s").read_text()
def asm_kernel() -> UOp:
  """Build a PROGRAM UOp wrapping the hand-written RDNA3 gemm assembly.

  The kernel text comes from splicing gemm.s into template.s at the
  INSTRUCTIONS placeholder; buffer placeholders and launch-dimension
  SPECIALs are gathered under a SINK that declares the kernel interface.
  """
  # launch dimensions: threads per workgroup and number of workgroups
  local_idx = UOp.special(THREADS_PER_WG, "lidx0")
  group_idx = UOp.special(NUM_WG, "gidx0")
  # buffer placeholders — slot 0 is the C output, slots 1/2 the A/B inputs
  buf_a = UOp.placeholder((N*N,), dtypes.half, slot=1)
  buf_b = UOp.placeholder((N*N,), dtypes.half, slot=2)
  buf_c = UOp.placeholder((N*N,), dtypes.half, slot=0)
  # render the final assembly source
  asm_body = (pathlib.Path(__file__).parent/"gemm.s").read_text()
  rendered = template.replace("INSTRUCTIONS", asm_body)
  sink = UOp.sink(buf_a, buf_b, buf_c, local_idx, group_idx, arg=KernelInfo(name="gemm"))
  program_srcs = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=rendered))
  return UOp(Ops.PROGRAM, src=program_srcs, arg=())
if __name__ == "__main__":
  # Run the assembled kernel through test_matmul with fp16 buffers
  # (presumably checks the result against a reference — see amd_uop_matmul).
  test_matmul(asm_kernel(), dtype=dtypes.half, N=N)