mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-05 05:04:27 -05:00
simplify mi350x gemm / viz asm tests (#13984)
* mi350x gemm cleanup
* asm tests work
* simpler asm tests
@@ -4,7 +4,6 @@
 s_load_dwordx2 s[32:33], s[0:1], 0x10 // B
 // ** others kernel args
 s_load_dword s24, s[0:1], 0x18 // N
-s_load_dword s54, s[0:1], 0x1C // num work groups
 s_waitcnt lgkmcnt(0)
 // "info"
 s_mov_b32 s51, 1 // gemm_info = 1
@@ -13,7 +13,7 @@ INSTRUCTIONS
 # basic memory requirements
 .amdhsa_group_segment_fixed_size 133120
 .amdhsa_private_segment_fixed_size 0
-.amdhsa_kernarg_size 32
+.amdhsa_kernarg_size 28
 # register usage (RSRC1)
 .amdhsa_next_free_vgpr 504
 .amdhsa_next_free_sgpr 96
@@ -61,15 +61,10 @@ amdhsa.kernels:
         .size: 4
         .value_kind: by_value
         .value_type: u32
-      - .name: num_wg
-        .offset: 28
-        .size: 4
-        .value_kind: by_value
-        .value_type: u32
     .group_segment_fixed_size: 133120
     .private_segment_fixed_size: 0
     .kernarg_segment_align: 8
-    .kernarg_segment_size: 32
+    .kernarg_segment_size: 28
     .max_flat_workgroup_size: 256
     .sgpr_count: 88
     .sgpr_spill_count: 0
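Note (not part of the diff): dropping the num_wg by-value argument shrinks the kernarg segment from 32 to 28 bytes, which is what both .amdhsa_kernarg_size and .kernarg_segment_size now reflect. A minimal sketch of the resulting layout, assuming the two 8-byte pointers at 0x0/0x8 are the C and A buffers (the B pointer at 0x10 and N at 0x18 come from the asm comments above):

import struct

# hypothetical helper, not in the repo: pack the 28-byte kernarg segment
# C ptr @ 0x0, A ptr @ 0x8 (assumed from the sink order), B ptr @ 0x10, N @ 0x18
def pack_kernargs(c_ptr:int, a_ptr:int, b_ptr:int, n:int) -> bytes:
  buf = struct.pack("<QQQI", c_ptr, a_ptr, b_ptr, n)  # 8+8+8+4 bytes
  assert len(buf) == 28  # matches .amdhsa_kernarg_size 28
  return buf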
@@ -43,10 +43,9 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
   src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text())

   sz = UOp.variable("SZ", 256, 8192)
-  wg = UOp.variable("WG", 1, 1024)

-  sink = UOp.sink(C.base, A.base, B.base, sz, wg, lidx, gidx, arg=KernelInfo(name="gemm"))
-  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
+  sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm"))
+  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)), arg=())

 C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]
@@ -57,7 +56,7 @@ eis = [si.lower() for si in sched]

 with Context(DEBUG=2):
   for ei in eis:
-    et = ei.run({"SZ":N, "WG":NUM_WG}, wait=True)
+    et = ei.run({"SZ":N}, wait=True)
     print(f"{(N*N*N*2 / et)*1e-12:.2f} REAL TFLOPS")

 # ** correctness
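Note (not part of the diff): with WG gone, only the SZ variable is bound at run time via ei.run({"SZ":N}, wait=True), which returns the elapsed time in seconds. A small sketch of the FLOP arithmetic behind the print above (a square GEMM does N*N*N multiply-adds, i.e. 2*N^3 FLOPs):

# hypothetical helper, not in the repo: same math as the REAL TFLOPS print
def gemm_tflops(n:int, elapsed_s:float) -> float:
  return (n*n*n*2 / elapsed_s) * 1e-12

print(f"{gemm_tflops(8192, 0.5):.2f} TFLOPS")  # example numbers only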
@@ -3,24 +3,18 @@

 import numpy as np
 import unittest
-import subprocess, struct, math, textwrap, functools
-from tinygrad import Tensor, dtypes, Device, UOp
-from tinygrad.uop.ops import Ops, KernelInfo
+import subprocess, struct, math, functools
+from tinygrad import Tensor, dtypes, Device
 from tinygrad.helpers import getenv

 from extra.assembly.amd.autogen.rdna3.ins import *
 from extra.assembly.amd.asm import waitcnt
-from test.testextra.test_cfg_viz import template
-
-def custom_src(out:UOp, src:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
-  lidx = UOp.special(n_threads, "lidx0")
-  gidx = UOp.special(n_workgroups, "gidx0")
-  sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name="test"))
-  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
+from test.testextra.test_cfg_viz import asm_kernel

 def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
   out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
-  src = "\n".join(inst.disasm() for inst in [
+  insts = [
     s_load_b64(s[0:1], s[0:1], NULL),
     *asm,
     v_lshlrev_b32_e32(v[0], 2, v[0]),
@@ -28,9 +22,8 @@ def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
     #global_store_b32(v[0], v[1], s[0:1]),
     global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
     s_endpgm()
-  ])
-  src = template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src))
-  out = Tensor.custom_kernel(out, fxn=functools.partial(custom_src, src=src, device=out.device, n_threads=n_threads))[0]
+  ]
+  out = Tensor.custom_kernel(out, fxn=functools.partial(asm_kernel, name="test", insts=insts, device=out.device, n_threads=n_threads))[0]
   out.realize()
   return out.tolist()
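Note (not part of the diff): after this change get_output just builds a list of instruction objects and hands it to asm_kernel through Tensor.custom_kernel. A hedged sketch of a hypothetical test on top of it, using only a constructor that already appears in this diff and assuming v[0] holds the thread id when the injected instructions run (as the existing preamble implies):

# hypothetical test, not in the repo
class TestGetOutput(unittest.TestCase):
  def test_lshl_thread_id(self):
    # v[1] = thread_id << 3; get_output stores v[1] per thread by default (vdst=v[1])
    out = get_output([v_lshlrev_b32_e32(v[1], 3, v[0])], n_threads=4)
    self.assertEqual(out, [i << 3 for i in range(4)])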
@@ -2,17 +2,15 @@
 # allow define from star imports

 import unittest
-import textwrap
+import textwrap, functools

 from tinygrad import Device, Tensor
-from tinygrad.uop.ops import UOp, Ops, track_rewrites
-from tinygrad.renderer import ProgramSpec
-from tinygrad.helpers import TracingKey, getenv
-from tinygrad.engine.realize import ExecItem, CompiledRunner
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
+from tinygrad.helpers import getenv

 from extra.assembly.amd.autogen.rdna3.ins import *
 from extra.assembly.amd.dsl import Inst

 # TODO: use the RDNA3 renderer when it's in master
 template = """.text
 .globl fn_name
 .p2align 8
@@ -57,14 +55,18 @@ amdhsa.kernels:
 .end_amdgpu_metadata
 """

-@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret))
-def run_asm(name:str, insts:list) -> ProgramSpec:
-  src = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts])
-  prg = ProgramSpec(name, src:=template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
-                    lib=Device[Device.DEFAULT].compiler.compile(src), global_size=[1, 1, 1], local_size=[1, 1, 1], globals=[0])
-  ei = ExecItem(UOp(Ops.SINK), [Tensor.empty(1).uop.buffer.ensure_allocated()], prg=CompiledRunner(prg))
-  ei.run()
-  return prg
+def asm_kernel(out:UOp, insts:list[str|Inst], name:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
+  lidx = UOp.special(n_threads, "lidx0")
+  gidx = UOp.special(n_workgroups, "gidx0")
+  sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name=name))
+  asm = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts])
+  src = template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(asm))
+  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)), arg=())
+
+def run_asm(name:str, insts:list) -> None:
+  fxn = functools.partial(asm_kernel, insts=insts, name=name, device=Device.DEFAULT)
+  out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
+  out.realize()

 @unittest.skipUnless(Device.DEFAULT == "AMD" and not getenv("AMD_LLVM"), "only on AMD with comgr")
 class TestCfg(unittest.TestCase):
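Note (not part of the diff): run_asm is now a thin wrapper that builds a PROGRAM UOp via asm_kernel and realizes it through Tensor.custom_kernel, instead of hand-assembling a ProgramSpec and ExecItem. Since insts may mix Inst objects and plain strings, a CFG test could feed raw assembly text directly; a hedged sketch with hypothetical (but standard RDNA3) mnemonics:

# hypothetical example, not in the repo: a tiny branch for the CFG viz
run_asm("branch_smoke", [
  "s_mov_b32 s2, 0",
  "s_cmp_eq_u32 s2, 0",
  "s_cbranch_scc1 .L_done",   # taken branch, splits the kernel into basic blocks
  "s_nop 0",
  ".L_done:",
  "s_endpgm",
])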