mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
RDNA3 fp16 assembly gemm 85 TFLOPS (#13990)
This commit is contained in:
@@ -140,11 +140,11 @@ def hand_spec_kernel3():
|
||||
|
||||
return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify()
|
||||
|
||||
def test_matmul(sink:UOp, N=N):
|
||||
def test_matmul(sink:UOp, dtype=dtypes.float32, N=N):
|
||||
rng = np.random.default_rng()
|
||||
a = Tensor(rng.random((N, N), dtype=np.float32)-0.5)
|
||||
b = Tensor(rng.random((N, N), dtype=np.float32)-0.5)
|
||||
hc = Tensor.empty(N, N)
|
||||
a = Tensor(rng.random((N, N), dtype=np.float32)-0.5, dtype=dtype)
|
||||
b = Tensor(rng.random((N, N), dtype=np.float32)-0.5, dtype=dtype)
|
||||
hc = Tensor.empty(N, N, dtype=dtype)
|
||||
Tensor.realize(a, b, hc)
|
||||
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in [hc, a, b]], prg=get_runner(Device.DEFAULT, sink))
|
||||
|
||||
3018
extra/gemm/asm/rdna3/gemm.s
Normal file
3018
extra/gemm/asm/rdna3/gemm.s
Normal file
File diff suppressed because it is too large
Load Diff
76
extra/gemm/asm/rdna3/template.s
Normal file
76
extra/gemm/asm/rdna3/template.s
Normal file
@@ -0,0 +1,76 @@
|
||||
.text
|
||||
.section .text.
|
||||
.global gemm
|
||||
.p2align 8
|
||||
.type gemm,@function
|
||||
|
||||
gemm:
|
||||
INSTRUCTIONS
|
||||
|
||||
.section .rodata,"a",@progbits
|
||||
.p2align 6, 0x0
|
||||
.amdhsa_kernel gemm
|
||||
# basic memory requirements
|
||||
.amdhsa_group_segment_fixed_size 30336
|
||||
.amdhsa_private_segment_fixed_size 0
|
||||
.amdhsa_kernarg_size 32
|
||||
# register usage (RSRC1)
|
||||
.amdhsa_next_free_vgpr 256
|
||||
.amdhsa_next_free_sgpr 100
|
||||
# workgroup / workitem IDs (RSRC2)
|
||||
.amdhsa_system_sgpr_workgroup_id_x 1
|
||||
.amdhsa_system_sgpr_workgroup_id_y 1
|
||||
.amdhsa_system_sgpr_workgroup_id_z 1
|
||||
# user SGPRs: kernarg ptr in s[0:1]
|
||||
.amdhsa_user_sgpr_kernarg_segment_ptr 1
|
||||
.amdhsa_user_sgpr_count 2
|
||||
# gfx10+ / gfx11 specifics (RSRC1[29..31])
|
||||
.amdhsa_wavefront_size32 1
|
||||
.amdhsa_workgroup_processor_mode 1
|
||||
.amdhsa_memory_ordered 1
|
||||
.amdhsa_forward_progress 1
|
||||
# misc for gfx11
|
||||
.amdhsa_dx10_clamp 1
|
||||
.amdhsa_ieee_mode 1
|
||||
.amdhsa_uses_dynamic_stack 0
|
||||
.end_amdhsa_kernel
|
||||
|
||||
.amdgpu_metadata
|
||||
---
|
||||
amdhsa.kernels:
|
||||
- .args:
|
||||
- .address_space: generic
|
||||
.name: C
|
||||
.offset: 0
|
||||
.size: 8
|
||||
.value_kind: global_buffer
|
||||
.value_type: f16
|
||||
- .address_space: generic
|
||||
.name: A
|
||||
.offset: 8
|
||||
.size: 8
|
||||
.value_kind: global_buffer
|
||||
.value_type: f16
|
||||
- .address_space: generic
|
||||
.name: B
|
||||
.offset: 16
|
||||
.size: 8
|
||||
.value_kind: global_buffer
|
||||
.value_type: f16
|
||||
.group_segment_fixed_size: 30336
|
||||
.kernarg_segment_align: 8
|
||||
.kernarg_segment_size: 32
|
||||
.max_flat_workgroup_size: 128
|
||||
.name: gemm
|
||||
.private_segment_fixed_size: 0
|
||||
.sgpr_count: 70
|
||||
.sgpr_spill_count: 0
|
||||
.symbol: gemm.kd
|
||||
.vgpr_count: 256
|
||||
.vgpr_spill_count: 0
|
||||
.wavefront_size: 32
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 1
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
30
extra/gemm/asm/rdna3/test.py
Normal file
30
extra/gemm/asm/rdna3/test.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import math, pathlib
|
||||
|
||||
from tinygrad import Device, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
|
||||
from extra.gemm.amd_uop_matmul import test_matmul
|
||||
|
||||
N = 4096
|
||||
TN = 96
|
||||
THREADS_PER_WG = 128
|
||||
NUM_WG = math.ceil(N / TN) * math.ceil(N / TN)
|
||||
|
||||
dname:str = Device.DEFAULT
|
||||
template:str = (pathlib.Path(__file__).parent/"template.s").read_text()
|
||||
|
||||
def asm_kernel() -> UOp:
|
||||
lidx = UOp.special(THREADS_PER_WG, "lidx0")
|
||||
gidx = UOp.special(NUM_WG, "gidx0")
|
||||
|
||||
a = UOp.placeholder((N*N,), dtypes.half, slot=1)
|
||||
b = UOp.placeholder((N*N,), dtypes.half, slot=2)
|
||||
c = UOp.placeholder((N*N,), dtypes.half, slot=0)
|
||||
|
||||
src = template.replace("INSTRUCTIONS", (pathlib.Path(__file__).parent/"gemm.s").read_text())
|
||||
|
||||
sink = UOp.sink(a, b, c, lidx, gidx, arg=KernelInfo(name="gemm"))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)), arg=())
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_matmul(asm_kernel(), dtype=dtypes.half, N=N)
|
||||
Reference in New Issue
Block a user