mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
simplify mi350x gemm / viz asm tests (#13984)
* mi350x gemm cleanup * asm tests work * simpler asm tests
This commit is contained in:
@@ -3,24 +3,18 @@
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
import subprocess, struct, math, textwrap, functools
|
||||
from tinygrad import Tensor, dtypes, Device, UOp
|
||||
from tinygrad.uop.ops import Ops, KernelInfo
|
||||
import subprocess, struct, math, functools
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
from extra.assembly.amd.asm import waitcnt
|
||||
from test.testextra.test_cfg_viz import template
|
||||
|
||||
def custom_src(out:UOp, src:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
|
||||
lidx = UOp.special(n_threads, "lidx0")
|
||||
gidx = UOp.special(n_workgroups, "gidx0")
|
||||
sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name="test"))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
|
||||
from test.testextra.test_cfg_viz import asm_kernel
|
||||
|
||||
def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
|
||||
out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
|
||||
src = "\n".join(inst.disasm() for inst in [
|
||||
insts = [
|
||||
s_load_b64(s[0:1], s[0:1], NULL),
|
||||
*asm,
|
||||
v_lshlrev_b32_e32(v[0], 2, v[0]),
|
||||
@@ -28,9 +22,8 @@ def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
|
||||
#global_store_b32(v[0], v[1], s[0:1]),
|
||||
global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
|
||||
s_endpgm()
|
||||
])
|
||||
src = template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src))
|
||||
out = Tensor.custom_kernel(out, fxn=functools.partial(custom_src, src=src, device=out.device, n_threads=n_threads))[0]
|
||||
]
|
||||
out = Tensor.custom_kernel(out, fxn=functools.partial(asm_kernel, name="test", insts=insts, device=out.device, n_threads=n_threads))[0]
|
||||
out.realize()
|
||||
return out.tolist()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user