simplify mi350x gemm / viz asm tests (#13984)

* mi350x gemm cleanup

* asm tests work

* simpler asm tests
qazal authored 2026-01-03 11:11:07 +09:00, committed by GitHub
parent 7cbafb2ef1
commit 2cc64d71b0
5 changed files with 27 additions and 39 deletions

View File

@@ -4,7 +4,6 @@
 s_load_dwordx2 s[32:33], s[0:1], 0x10 // B
 // ** others kernel args
 s_load_dword s24, s[0:1], 0x18 // N
-s_load_dword s54, s[0:1], 0x1C // num work groups
 s_waitcnt lgkmcnt(0)
 // "info"
 s_mov_b32 s51, 1 // gemm_info = 1

View File

@@ -13,7 +13,7 @@ INSTRUCTIONS
 # basic memory requirements
 .amdhsa_group_segment_fixed_size 133120
 .amdhsa_private_segment_fixed_size 0
-.amdhsa_kernarg_size 32
+.amdhsa_kernarg_size 28
 # register usage (RSRC1)
 .amdhsa_next_free_vgpr 504
 .amdhsa_next_free_sgpr 96
@@ -61,15 +61,10 @@ amdhsa.kernels:
     .size: 4
     .value_kind: by_value
     .value_type: u32
-  - .name: num_wg
-    .offset: 28
-    .size: 4
-    .value_kind: by_value
-    .value_type: u32
   .group_segment_fixed_size: 133120
   .private_segment_fixed_size: 0
   .kernarg_segment_align: 8
-  .kernarg_segment_size: 32
+  .kernarg_segment_size: 28
   .max_flat_workgroup_size: 256
   .sgpr_count: 88
   .sgpr_spill_count: 0
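
The 32 -> 28 byte change lines up with dropping the 4-byte u32 num_wg argument that sat at offset 28. A rough sanity-check sketch (illustrative only: the C and A pointer entries are assumptions, while the B and N offsets match the s_load comments in the gemm assembly above):

# hypothetical kernarg layout after this change; C/A entries are assumptions
kernargs = [
  ("C", 0x00, 8),  # output buffer pointer (assumed)
  ("A", 0x08, 8),  # input buffer pointer (assumed)
  ("B", 0x10, 8),  # s_load_dwordx2 s[32:33], s[0:1], 0x10 // B
  ("N", 0x18, 4),  # s_load_dword s24, s[0:1], 0x18 // N
]
assert max(off + size for _, off, size in kernargs) == 28  # new kernarg_segment_size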

View File

@@ -43,10 +43,9 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
   src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text())
   sz = UOp.variable("SZ", 256, 8192)
-  wg = UOp.variable("WG", 1, 1024)
-  sink = UOp.sink(C.base, A.base, B.base, sz, wg, lidx, gidx, arg=KernelInfo(name="gemm"))
-  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
+  sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm"))
+  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)), arg=())
 
 C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]
@@ -57,7 +56,7 @@ eis = [si.lower() for si in sched]
 with Context(DEBUG=2):
   for ei in eis:
-    et = ei.run({"SZ":N, "WG":NUM_WG}, wait=True)
+    et = ei.run({"SZ":N}, wait=True)
     print(f"{(N*N*N*2 / et)*1e-12:.2f} REAL TFLOPS")
 
 # ** correctness
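
As a worked example of the TFLOPS print (illustrative numbers only; et is the elapsed time in seconds returned by ei.run with wait=True):

N, et = 8192, 2.2e-3  # N is within the SZ variable's 256..8192 range
flops = N*N*N*2       # ~1.1e12 FLOPs for an NxN GEMM
print(f"{(flops / et)*1e-12:.2f} REAL TFLOPS")  # -> ~500 REAL TFLOPS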

View File

@@ -3,24 +3,18 @@
 import numpy as np
 import unittest
-import subprocess, struct, math, textwrap, functools
-from tinygrad import Tensor, dtypes, Device, UOp
-from tinygrad.uop.ops import Ops, KernelInfo
+import subprocess, struct, math, functools
+from tinygrad import Tensor, dtypes, Device
 from tinygrad.helpers import getenv
 from extra.assembly.amd.autogen.rdna3.ins import *
 from extra.assembly.amd.asm import waitcnt
-from test.testextra.test_cfg_viz import template
+from test.testextra.test_cfg_viz import asm_kernel
 
-def custom_src(out:UOp, src:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
-  lidx = UOp.special(n_threads, "lidx0")
-  gidx = UOp.special(n_workgroups, "gidx0")
-  sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name="test"))
-  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
 
 def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
   out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
-  src = "\n".join(inst.disasm() for inst in [
+  insts = [
     s_load_b64(s[0:1], s[0:1], NULL),
     *asm,
     v_lshlrev_b32_e32(v[0], 2, v[0]),
@@ -28,9 +22,8 @@ def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
     #global_store_b32(v[0], v[1], s[0:1]),
     global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
     s_endpgm()
-  ])
-  src = template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src))
-  out = Tensor.custom_kernel(out, fxn=functools.partial(custom_src, src=src, device=out.device, n_threads=n_threads))[0]
+  ]
+  out = Tensor.custom_kernel(out, fxn=functools.partial(asm_kernel, name="test", insts=insts, device=out.device, n_threads=n_threads))[0]
   out.realize()
   return out.tolist()
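
A hedged usage sketch of get_output after this change (hypothetical: v_mov_b32_e32 and its (dst, src) argument order are assumed to be provided by the autogen ins module, like the instructions used above):

# every lane writes vdst (v[1] by default) to out[thread id], so each of the 4 threads should read back 7
assert get_output([v_mov_b32_e32(v[1], 7)], n_threads=4) == [7, 7, 7, 7]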

View File

@@ -2,17 +2,15 @@
 # allow define from star imports
 import unittest
-import textwrap
+import textwrap, functools
 from tinygrad import Device, Tensor
-from tinygrad.uop.ops import UOp, Ops, track_rewrites
-from tinygrad.renderer import ProgramSpec
-from tinygrad.helpers import TracingKey, getenv
-from tinygrad.engine.realize import ExecItem, CompiledRunner
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
+from tinygrad.helpers import getenv
 from extra.assembly.amd.autogen.rdna3.ins import *
+from extra.assembly.amd.dsl import Inst
 
+# TODO: use the RDNA3 renderer when it's in master
 template = """.text
 .globl fn_name
 .p2align 8
@@ -57,14 +55,18 @@ amdhsa.kernels:
 .end_amdgpu_metadata
 """
 
-@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret))
-def run_asm(name:str, insts:list) -> ProgramSpec:
-  src = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts])
-  prg = ProgramSpec(name, src:=template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
-                    lib=Device[Device.DEFAULT].compiler.compile(src), global_size=[1, 1, 1], local_size=[1, 1, 1], globals=[0])
-  ei = ExecItem(UOp(Ops.SINK), [Tensor.empty(1).uop.buffer.ensure_allocated()], prg=CompiledRunner(prg))
-  ei.run()
-  return prg
+def asm_kernel(out:UOp, insts:list[str|Inst], name:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
+  lidx = UOp.special(n_threads, "lidx0")
+  gidx = UOp.special(n_workgroups, "gidx0")
+  sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name=name))
+  asm = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts])
+  src = template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(asm))
+  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)), arg=())
+
+def run_asm(name:str, insts:list) -> None:
+  fxn = functools.partial(asm_kernel, insts=insts, name=name, device=Device.DEFAULT)
+  out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
+  out.realize()
 
 @unittest.skipUnless(Device.DEFAULT == "AMD" and not getenv("AMD_LLVM"), "only on AMD with comgr")
 class TestCfg(unittest.TestCase):
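
With the shared asm_kernel helper, an individual test reduces to building an instruction list and calling run_asm. A minimal hypothetical example (only s_endpgm is taken from this diff; the test name is made up):

# hypothetical smoke test: compile and run a kernel that immediately exits
run_asm("noop", [s_endpgm()])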