simplify mi350x gemm / viz asm tests (#13984)

* mi350x gemm cleanup
* asm tests work
* simpler asm tests
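For orientation before the diff: the viz asm tests now drive hand-written RDNA3 instructions through Tensor.custom_kernel via a single asm_kernel helper, instead of building a ProgramSpec/ExecItem by hand. A minimal sketch of the new flow, using only names that appear in the hunks below (the one-instruction body is illustrative, and it needs an AMD device like the tests themselves):

import functools
from tinygrad import Tensor, Device
from extra.assembly.amd.autogen.rdna3.ins import *  # brings in s_endpgm and friends
from test.testextra.test_cfg_viz import asm_kernel

# instructions -> asm_kernel (a PROGRAM UOp) -> Tensor.custom_kernel -> realize
fxn = functools.partial(asm_kernel, insts=[s_endpgm()], name="noop", device=Device.DEFAULT)
out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
out.realize()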
@@ -4,7 +4,6 @@
 s_load_dwordx2 s[32:33], s[0:1], 0x10 // B
 // ** others kernel args
 s_load_dword s24, s[0:1], 0x18 // N
-s_load_dword s54, s[0:1], 0x1C // num work groups
 s_waitcnt lgkmcnt(0)
 // "info"
 s_mov_b32 s51, 1 // gemm_info = 1
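The dropped load was the only consumer of kernarg offset 0x1C, which is what lets the kernarg block shrink to 28 bytes in the next hunk. A rough sketch of the implied layout; the pointer fields at 0x0 and 0x8 are assumptions by analogy with B at 0x10:

import struct
C_ptr = A_ptr = B_ptr = 0  # placeholder device addresses
N = 4096                   # placeholder matrix size
kernargs = struct.pack("<QQQI",
  C_ptr,  # 0x00: assumed output pointer
  A_ptr,  # 0x08: assumed A pointer
  B_ptr,  # 0x10: B -> s_load_dwordx2 s[32:33]
  N)      # 0x18: N -> s_load_dword s24 (the 0x1C num-work-groups slot is gone)
assert len(kernargs) == 28  # matches the new .amdhsa_kernarg_size below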
@@ -13,7 +13,7 @@ INSTRUCTIONS
 # basic memory requirements
 .amdhsa_group_segment_fixed_size 133120
 .amdhsa_private_segment_fixed_size 0
-.amdhsa_kernarg_size 32
+.amdhsa_kernarg_size 28
 # register usage (RSRC1)
 .amdhsa_next_free_vgpr 504
 .amdhsa_next_free_sgpr 96
@@ -61,15 +61,10 @@ amdhsa.kernels:
         .size: 4
         .value_kind: by_value
         .value_type: u32
-      - .name: num_wg
-        .offset: 28
-        .size: 4
-        .value_kind: by_value
-        .value_type: u32
     .group_segment_fixed_size: 133120
     .private_segment_fixed_size: 0
     .kernarg_segment_align: 8
-    .kernarg_segment_size: 32
+    .kernarg_segment_size: 28
     .max_flat_workgroup_size: 256
     .sgpr_count: 88
     .sgpr_spill_count: 0
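The metadata edits above have to agree with each other: the kernarg segment size is the end offset of the last remaining argument, now N at 0x18. A quick consistency check, where the first three entries are the assumed pointer args from the sketch after the first hunk:

args = [(0x00, 8), (0x08, 8), (0x10, 8), (0x18, 4)]  # (offset, size) after dropping num_wg
assert max(off + sz for off, sz in args) == 28       # == .kernarg_segment_size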
@@ -43,10 +43,9 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
   src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text())

   sz = UOp.variable("SZ", 256, 8192)
-  wg = UOp.variable("WG", 1, 1024)

-  sink = UOp.sink(C.base, A.base, B.base, sz, wg, lidx, gidx, arg=KernelInfo(name="gemm"))
-  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
+  sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm"))
+  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)), arg=())

 C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]
@@ -57,7 +56,7 @@ eis = [si.lower() for si in sched]

 with Context(DEBUG=2):
   for ei in eis:
-    et = ei.run({"SZ":N, "WG":NUM_WG}, wait=True)
+    et = ei.run({"SZ":N}, wait=True)
     print(f"{(N*N*N*2 / et)*1e-12:.2f} REAL TFLOPS")

 # ** correctness
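As a sanity check on the TFLOPS line kept above: a square gemm of size N does 2*N^3 floating-point operations, so with et in seconds the expression prints tera-FLOP/s. For example:

N, et = 8192, 1e-3  # illustrative: an 8192^3 gemm finishing in one millisecond
print(f"{(N*N*N*2 / et)*1e-12:.2f} REAL TFLOPS")  # -> 1099.51 REAL TFLOPS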
@@ -3,24 +3,18 @@

 import numpy as np
 import unittest
-import subprocess, struct, math, textwrap, functools
-from tinygrad import Tensor, dtypes, Device, UOp
-from tinygrad.uop.ops import Ops, KernelInfo
+import subprocess, struct, math, functools
+from tinygrad import Tensor, dtypes, Device
 from tinygrad.helpers import getenv

 from extra.assembly.amd.autogen.rdna3.ins import *
 from extra.assembly.amd.asm import waitcnt
-from test.testextra.test_cfg_viz import template

-def custom_src(out:UOp, src:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
-  lidx = UOp.special(n_threads, "lidx0")
-  gidx = UOp.special(n_workgroups, "gidx0")
-  sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name="test"))
-  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
+from test.testextra.test_cfg_viz import asm_kernel

 def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
   out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
-  src = "\n".join(inst.disasm() for inst in [
+  insts = [
     s_load_b64(s[0:1], s[0:1], NULL),
     *asm,
     v_lshlrev_b32_e32(v[0], 2, v[0]),
@@ -28,9 +22,8 @@ def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
     #global_store_b32(v[0], v[1], s[0:1]),
     global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
     s_endpgm()
-  ])
-  src = template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src))
-  out = Tensor.custom_kernel(out, fxn=functools.partial(custom_src, src=src, device=out.device, n_threads=n_threads))[0]
+  ]
+  out = Tensor.custom_kernel(out, fxn=functools.partial(asm_kernel, name="test", insts=insts, device=out.device, n_threads=n_threads))[0]
   out.realize()
   return out.tolist()
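A hedged usage sketch of the refactored get_output: the stub stores vdst (v[1] by default) to out[lane] for every lane, so a broadcast move should read back uniformly. v_mov_b32_e32 is an assumption here; it is a standard RDNA3 encoding, but only v_lshlrev_b32_e32 appears in this diff:

vals = get_output([v_mov_b32_e32(v[1], 7)], n_threads=4)  # assumed autogen instruction
assert vals == [7, 7, 7, 7]  # each lane stored vdst == 7 via global_store_b32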
@@ -2,17 +2,15 @@
 # allow define from star imports

 import unittest
-import textwrap
+import textwrap, functools

 from tinygrad import Device, Tensor
-from tinygrad.uop.ops import UOp, Ops, track_rewrites
-from tinygrad.renderer import ProgramSpec
-from tinygrad.helpers import TracingKey, getenv
-from tinygrad.engine.realize import ExecItem, CompiledRunner
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
+from tinygrad.helpers import getenv

 from extra.assembly.amd.autogen.rdna3.ins import *
+from extra.assembly.amd.dsl import Inst

-# TODO: use the RDNA3 renderer when it's in master
 template = """.text
 .globl fn_name
 .p2align 8
@@ -57,14 +55,18 @@ amdhsa.kernels:
 .end_amdgpu_metadata
 """

-@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret))
-def run_asm(name:str, insts:list) -> ProgramSpec:
-  src = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts])
-  prg = ProgramSpec(name, src:=template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
-                    lib=Device[Device.DEFAULT].compiler.compile(src), global_size=[1, 1, 1], local_size=[1, 1, 1], globals=[0])
-  ei = ExecItem(UOp(Ops.SINK), [Tensor.empty(1).uop.buffer.ensure_allocated()], prg=CompiledRunner(prg))
-  ei.run()
-  return prg
+def asm_kernel(out:UOp, insts:list[str|Inst], name:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
+  lidx = UOp.special(n_threads, "lidx0")
+  gidx = UOp.special(n_workgroups, "gidx0")
+  sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name=name))
+  asm = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts])
+  src = template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(asm))
+  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)), arg=())
+
+def run_asm(name:str, insts:list) -> None:
+  fxn = functools.partial(asm_kernel, insts=insts, name=name, device=Device.DEFAULT)
+  out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
+  out.realize()

 @unittest.skipUnless(Device.DEFAULT == "AMD" and not getenv("AMD_LLVM"), "only on AMD with comgr")
 class TestCfg(unittest.TestCase):
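Taken together, run_asm gives the tests a one-liner for compile-and-launch, and because asm_kernel accepts list[str|Inst], raw mnemonic strings and DSL Inst objects can be mixed freely. An illustrative call (the instruction body is made up; s_endpgm comes from the star import above):

run_asm("smoke", [
  "s_nop 0",    # plain strings are spliced into the source verbatim
  s_endpgm(),   # Inst objects are rendered with .disasm()
])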