simplify mi350x gemm / viz asm tests (#13984)

* mi350x gemm cleanup

* asm tests work

* simpler asm tests
Author: qazal
Date: 2026-01-03 11:11:07 +09:00
Committed by: GitHub
Parent: 7cbafb2ef1
Commit: 2cc64d71b0

5 changed files with 27 additions and 39 deletions

View File

@@ -4,7 +4,6 @@
s_load_dwordx2 s[32:33], s[0:1], 0x10 // B
// ** others kernel args
s_load_dword s24, s[0:1], 0x18 // N
-s_load_dword s54, s[0:1], 0x1C // num work groups
s_waitcnt lgkmcnt(0)
// "info"
s_mov_b32 s51, 1 // gemm_info = 1

View File

@@ -13,7 +13,7 @@ INSTRUCTIONS
# basic memory requirements
.amdhsa_group_segment_fixed_size 133120
.amdhsa_private_segment_fixed_size 0
-.amdhsa_kernarg_size 32
+.amdhsa_kernarg_size 28
# register usage (RSRC1)
.amdhsa_next_free_vgpr 504
.amdhsa_next_free_sgpr 96
@@ -61,15 +61,10 @@ amdhsa.kernels:
.size: 4
.value_kind: by_value
.value_type: u32
-- .name: num_wg
-.offset: 28
-.size: 4
-.value_kind: by_value
-.value_type: u32
.group_segment_fixed_size: 133120
.private_segment_fixed_size: 0
.kernarg_segment_align: 8
-.kernarg_segment_size: 32
+.kernarg_segment_size: 28
.max_flat_workgroup_size: 256
.sgpr_count: 88
.sgpr_spill_count: 0
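The metadata hunk stays consistent with the assembly change above: dropping the num_wg argument shrinks the kernarg block from 32 to 28 bytes. A minimal sketch of that arithmetic, assuming C and A are pointers at offsets 0x00 and 0x08 (only B at 0x10 and N at 0x18 are visible in the assembly):

# hedged sketch: kernarg layout implied by the hunks above (C/A offsets are assumptions)
kernargs = [("C", 0x00, 8), ("A", 0x08, 8), ("B", 0x10, 8), ("N", 0x18, 4)]  # num_wg at 0x1C is gone
assert kernargs[-1][1] + kernargs[-1][2] == 28  # matches .amdhsa_kernarg_size / .kernarg_segment_size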

View File

@@ -43,10 +43,9 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text())
sz = UOp.variable("SZ", 256, 8192)
-wg = UOp.variable("WG", 1, 1024)
-sink = UOp.sink(C.base, A.base, B.base, sz, wg, lidx, gidx, arg=KernelInfo(name="gemm"))
-return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
+sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm"))
+return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)), arg=())
C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]
@@ -57,7 +56,7 @@ eis = [si.lower() for si in sched]
with Context(DEBUG=2):
for ei in eis:
-et = ei.run({"SZ":N, "WG":NUM_WG}, wait=True)
+et = ei.run({"SZ":N}, wait=True)
print(f"{(N*N*N*2 / et)*1e-12:.2f} REAL TFLOPS")
# ** correctness
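The TFLOPS print divides the matmul's 2*N*N*N floating point operations by the measured kernel time in seconds. A worked example of that formula with made-up numbers (N and et below are illustrative, not measurements from this change):

N, et = 8192, 2.5e-3  # example matrix size and kernel time in seconds (made up)
flops = 2 * N**3  # one multiply plus one add per inner-product term
print(f"{flops / et * 1e-12:.2f} TFLOPS")  # about 440 TFLOPS for these example numbers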

View File

@@ -3,24 +3,18 @@
import numpy as np
import unittest
-import subprocess, struct, math, textwrap, functools
-from tinygrad import Tensor, dtypes, Device, UOp
-from tinygrad.uop.ops import Ops, KernelInfo
+import subprocess, struct, math, functools
+from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import getenv
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.asm import waitcnt
-from test.testextra.test_cfg_viz import template
-def custom_src(out:UOp, src:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
-lidx = UOp.special(n_threads, "lidx0")
-gidx = UOp.special(n_workgroups, "gidx0")
-sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name="test"))
-return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
+from test.testextra.test_cfg_viz import asm_kernel
def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
-src = "\n".join(inst.disasm() for inst in [
+insts = [
s_load_b64(s[0:1], s[0:1], NULL),
*asm,
v_lshlrev_b32_e32(v[0], 2, v[0]),
@@ -28,9 +22,8 @@ def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
#global_store_b32(v[0], v[1], s[0:1]),
global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
s_endpgm()
-])
-src = template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src))
-out = Tensor.custom_kernel(out, fxn=functools.partial(custom_src, src=src, device=out.device, n_threads=n_threads))[0]
+]
+out = Tensor.custom_kernel(out, fxn=functools.partial(asm_kernel, name="test", insts=insts, device=out.device, n_threads=n_threads))[0]
out.realize()
return out.tolist()
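With the template handling moved into asm_kernel, a test only supplies the instructions in the middle; every lane ends up storing vdst (v[1] by default) into the uint32 output buffer. A hypothetical caller, assuming the autogen ins module also exposes v_mov_b32_e32 (it is not shown in this diff):

# hedged sketch of a get_output caller inside the test class
def test_mov_imm(self):
  out = get_output([v_mov_b32_e32(v[1], 42)], n_threads=4)  # v1 <- 42 on every lane
  self.assertEqual(out, [42, 42, 42, 42])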

View File

@@ -2,17 +2,15 @@
# allow define from star imports
import unittest
-import textwrap
+import textwrap, functools
from tinygrad import Device, Tensor
-from tinygrad.uop.ops import UOp, Ops, track_rewrites
-from tinygrad.renderer import ProgramSpec
-from tinygrad.helpers import TracingKey, getenv
-from tinygrad.engine.realize import ExecItem, CompiledRunner
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
+from tinygrad.helpers import getenv
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.dsl import Inst
# TODO: use the RDNA3 renderer when it's in master
template = """.text
.globl fn_name
.p2align 8
@@ -57,14 +55,18 @@ amdhsa.kernels:
.end_amdgpu_metadata
"""
-@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret))
-def run_asm(name:str, insts:list) -> ProgramSpec:
-src = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts])
-prg = ProgramSpec(name, src:=template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
-lib=Device[Device.DEFAULT].compiler.compile(src), global_size=[1, 1, 1], local_size=[1, 1, 1], globals=[0])
-ei = ExecItem(UOp(Ops.SINK), [Tensor.empty(1).uop.buffer.ensure_allocated()], prg=CompiledRunner(prg))
-ei.run()
-return prg
+def asm_kernel(out:UOp, insts:list[str|Inst], name:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
+lidx = UOp.special(n_threads, "lidx0")
+gidx = UOp.special(n_workgroups, "gidx0")
+sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name=name))
+asm = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts])
+src = template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(asm))
+return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)), arg=())
+def run_asm(name:str, insts:list) -> None:
+fxn = functools.partial(asm_kernel, insts=insts, name=name, device=Device.DEFAULT)
+out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
+out.realize()
@unittest.skipUnless(Device.DEFAULT == "AMD" and not getenv("AMD_LLVM"), "only on AMD with comgr")
class TestCfg(unittest.TestCase):
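run_asm now goes through the same Tensor.custom_kernel path as the gemm script, so a TestCfg case only needs an instruction list; raw strings are passed through untouched since asm_kernel only calls disasm() on Inst objects. A hypothetical test method, assuming the autogen module exposes s_nop (s_endpgm is already used above):

# hedged sketch: a minimal kernel driven through asm_kernel/run_asm
def test_trivial_kernel(self):
  run_asm("trivial", [s_nop(0), "s_waitcnt lgkmcnt(0)", s_endpgm()])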