mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
switch asm tests to dsl (#13840)
* switch asm tests to dsl * labeled basic blocks also work * indenting for basic blocks * allow define from star import
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
# ruff: noqa: F405, F403
|
||||
# allow define from star imports
|
||||
|
||||
import unittest
|
||||
import textwrap
|
||||
|
||||
@@ -7,6 +10,8 @@ from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.helpers import TracingKey, getenv
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
|
||||
from extra.assembly.rdna3.autogen import *
|
||||
|
||||
# TODO: use the RDNA3 renderer when it's in master
|
||||
template = """.text
|
||||
.globl fn_name
|
||||
@@ -53,7 +58,8 @@ amdhsa.kernels:
|
||||
"""
|
||||
|
||||
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret))
|
||||
def run_asm(name:str, src:str) -> ProgramSpec:
|
||||
def run_asm(name:str, insts:list) -> ProgramSpec:
|
||||
src = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts])
|
||||
prg = ProgramSpec(name, template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
|
||||
global_size=[1, 1, 1], local_size=[1, 1, 1], globals=[0])
|
||||
ei = ExecItem(UOp(Ops.SINK), [Tensor.empty(1).uop.buffer.ensure_allocated()], prg=CompiledRunner(prg))
|
||||
@@ -68,107 +74,107 @@ class TestCfg(unittest.TestCase):
|
||||
self.skipTest(f"tests written for RDNA, got arch {arch}")
|
||||
|
||||
def test_simple(self):
|
||||
run_asm("simple", """
|
||||
entry:
|
||||
s_branch bb1
|
||||
bb1:
|
||||
s_endpgm
|
||||
""")
|
||||
run_asm("simple", [
|
||||
"entry:",
|
||||
s_branch("bb1"),
|
||||
"bb1:",
|
||||
s_endpgm(),
|
||||
])
|
||||
|
||||
def test_diamond(self):
|
||||
run_asm("diamond", """
|
||||
entry:
|
||||
s_cmp_eq_i32 s0, 0
|
||||
s_cbranch_scc1 if
|
||||
s_branch else
|
||||
if:
|
||||
s_nop 1
|
||||
s_branch end
|
||||
else:
|
||||
s_nop 0
|
||||
end:
|
||||
s_endpgm
|
||||
""")
|
||||
run_asm("diamond", [
|
||||
"entry:",
|
||||
s_cmp_eq_i32(s[0], 0),
|
||||
s_cbranch_scc1("if"),
|
||||
s_branch("else"),
|
||||
"if:",
|
||||
s_nop(1),
|
||||
s_branch("end"),
|
||||
"else:",
|
||||
s_nop(0),
|
||||
"end:",
|
||||
s_endpgm(),
|
||||
])
|
||||
|
||||
def test_loop(self):
|
||||
run_asm("simple_loop", """
|
||||
entry:
|
||||
s_mov_b32 s1, 4
|
||||
loop:
|
||||
s_add_u32 s1, s1, -1
|
||||
s_cmp_eq_i32 s1, 0
|
||||
s_cbranch_scc0 loop
|
||||
s_endpgm
|
||||
""")
|
||||
run_asm("simple_loop", [
|
||||
"entry:",
|
||||
s_mov_b32(s[1], 4),
|
||||
"loop:",
|
||||
s_add_u32(s[1], s[1], -1),
|
||||
s_cmp_eq_i32(s[1], 0),
|
||||
s_cbranch_scc0("loop"),
|
||||
s_endpgm(),
|
||||
])
|
||||
|
||||
def test_loop_branch(self):
|
||||
run_asm("loop_if", """
|
||||
entry:
|
||||
s_mov_b32 s1, 4
|
||||
loop:
|
||||
s_add_u32 s1, s1, -1
|
||||
s_cmp_eq_i32 s1, 2
|
||||
s_cbranch_scc1 cond
|
||||
s_branch cont
|
||||
cond:
|
||||
s_add_u32 s1, s1, -2
|
||||
cont:
|
||||
s_cmp_eq_i32 s1, 0
|
||||
s_cbranch_scc0 loop
|
||||
s_endpgm
|
||||
""")
|
||||
run_asm("loop_if", [
|
||||
"entry:",
|
||||
s_mov_b32(s[1], 4),
|
||||
"loop:",
|
||||
s_add_u32(s[1], s[1], -1),
|
||||
s_cmp_eq_i32(s[1], 2),
|
||||
s_cbranch_scc1("cond"),
|
||||
s_branch("cont"),
|
||||
"cond:",
|
||||
s_add_u32(s[1], s[1], -2),
|
||||
"cont:",
|
||||
s_cmp_eq_i32(s[1], 0),
|
||||
s_cbranch_scc0("loop"),
|
||||
s_endpgm(),
|
||||
])
|
||||
|
||||
def test_loop_break(self):
|
||||
run_asm("loop_break", """
|
||||
entry:
|
||||
s_mov_b32 s1, 8
|
||||
loop:
|
||||
s_add_u32 s1, s1, -1
|
||||
s_cmp_eq_i32 s1, 5
|
||||
s_cbranch_scc1 break
|
||||
s_cmp_eq_i32 s1, 0
|
||||
s_cbranch_scc0 loop
|
||||
break:
|
||||
s_endpgm
|
||||
""")
|
||||
run_asm("loop_break", [
|
||||
"entry:",
|
||||
s_mov_b32(s[1], 8),
|
||||
"loop:",
|
||||
s_add_u32(s[1], s[1], -1),
|
||||
s_cmp_eq_i32(s[1], 5),
|
||||
s_cbranch_scc1("break"),
|
||||
s_cmp_eq_i32(s[1], 0),
|
||||
s_cbranch_scc0("loop"),
|
||||
"break:",
|
||||
s_endpgm(),
|
||||
])
|
||||
|
||||
def test_switch(self):
|
||||
run_asm("switch_case", """
|
||||
entry:
|
||||
s_cmp_eq_i32 s0, 0
|
||||
s_cbranch_scc1 case0
|
||||
s_cmp_eq_i32 s0, 1
|
||||
s_cbranch_scc1 case1
|
||||
s_branch case2
|
||||
case0:
|
||||
s_nop 0
|
||||
s_branch join
|
||||
case1:
|
||||
s_nop 1
|
||||
s_branch join
|
||||
case2:
|
||||
s_nop 2
|
||||
s_branch join
|
||||
join:
|
||||
s_endpgm
|
||||
""")
|
||||
run_asm("switch_case", [
|
||||
"entry:",
|
||||
s_cmp_eq_i32(s[0], 0),
|
||||
s_cbranch_scc1("case0"),
|
||||
s_cmp_eq_i32(s[0], 1),
|
||||
s_cbranch_scc1("case1"),
|
||||
s_branch("case2"),
|
||||
"case0:",
|
||||
s_nop(0),
|
||||
s_branch("join"),
|
||||
"case1:",
|
||||
s_nop(1),
|
||||
s_branch("join"),
|
||||
"case2:",
|
||||
s_nop(2),
|
||||
s_branch("join"),
|
||||
"join:",
|
||||
s_endpgm(),
|
||||
])
|
||||
|
||||
def test_ping_pong(self):
|
||||
run_asm("ping_pong", """
|
||||
entry:
|
||||
s_cmp_eq_i32 s0, 0
|
||||
s_cbranch_scc1 ping
|
||||
s_branch pong
|
||||
ping:
|
||||
s_cmp_eq_i32 s1, 0
|
||||
s_cbranch_scc1 pong
|
||||
s_branch end
|
||||
pong:
|
||||
s_cmp_eq_i32 s2, 0
|
||||
s_cbranch_scc1 ping
|
||||
end:
|
||||
s_endpgm
|
||||
""")
|
||||
run_asm("ping_pong", [
|
||||
"entry:",
|
||||
s_cmp_eq_i32(s[0], 0),
|
||||
s_cbranch_scc1("ping"),
|
||||
s_branch("pong"),
|
||||
"ping:",
|
||||
s_cmp_eq_i32(s[1], 0),
|
||||
s_cbranch_scc1("pong"),
|
||||
s_branch("end"),
|
||||
"pong:",
|
||||
s_cmp_eq_i32(s[2], 0),
|
||||
s_cbranch_scc1("ping"),
|
||||
"end:",
|
||||
s_endpgm(),
|
||||
])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user