simplify mi350x gemm / viz asm tests (#13984)

* mi350x gemm cleanup

* asm tests work

* simpler asm tests
This commit is contained in:
qazal
2026-01-03 11:11:07 +09:00
committed by GitHub
parent 7cbafb2ef1
commit 2cc64d71b0
5 changed files with 27 additions and 39 deletions

View File

@@ -3,24 +3,18 @@
import numpy as np
import unittest
import subprocess, struct, math, textwrap, functools
from tinygrad import Tensor, dtypes, Device, UOp
from tinygrad.uop.ops import Ops, KernelInfo
import subprocess, struct, math, functools
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import getenv
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.asm import waitcnt
from test.testextra.test_cfg_viz import template
def custom_src(out:UOp, src:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
lidx = UOp.special(n_threads, "lidx0")
gidx = UOp.special(n_workgroups, "gidx0")
sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name="test"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
from test.testextra.test_cfg_viz import asm_kernel
def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
src = "\n".join(inst.disasm() for inst in [
insts = [
s_load_b64(s[0:1], s[0:1], NULL),
*asm,
v_lshlrev_b32_e32(v[0], 2, v[0]),
@@ -28,9 +22,8 @@ def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
#global_store_b32(v[0], v[1], s[0:1]),
global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
s_endpgm()
])
src = template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src))
out = Tensor.custom_kernel(out, fxn=functools.partial(custom_src, src=src, device=out.device, n_threads=n_threads))[0]
]
out = Tensor.custom_kernel(out, fxn=functools.partial(asm_kernel, name="test", insts=insts, device=out.device, n_threads=n_threads))[0]
out.realize()
return out.tolist()