From f6de9095a0d4fc4721c5f492fe58d2bfdd3389f9 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 27 Dec 2025 02:15:16 +0900 Subject: [PATCH] switch asm tests to dsl (#13840) * switch asm tests to dsl * labeled basic blocks also work * indenting for basic blocks * allow define from star import --- test/testextra/test_cfg_viz.py | 186 +++++++++++++++++---------------- 1 file changed, 96 insertions(+), 90 deletions(-) diff --git a/test/testextra/test_cfg_viz.py b/test/testextra/test_cfg_viz.py index 9767d2ceee..2aa4c03f4f 100644 --- a/test/testextra/test_cfg_viz.py +++ b/test/testextra/test_cfg_viz.py @@ -1,3 +1,6 @@ +# ruff: noqa: F405, F403 +# allow define from star imports + import unittest import textwrap @@ -7,6 +10,8 @@ from tinygrad.renderer import ProgramSpec from tinygrad.helpers import TracingKey, getenv from tinygrad.engine.realize import ExecItem, CompiledRunner +from extra.assembly.rdna3.autogen import * + # TODO: use the RDNA3 renderer when it's in master template = """.text .globl fn_name @@ -53,7 +58,8 @@ amdhsa.kernels: """ @track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret)) -def run_asm(name:str, src:str) -> ProgramSpec: +def run_asm(name:str, insts:list) -> ProgramSpec: + src = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts]) prg = ProgramSpec(name, template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK), global_size=[1, 1, 1], local_size=[1, 1, 1], globals=[0]) ei = ExecItem(UOp(Ops.SINK), [Tensor.empty(1).uop.buffer.ensure_allocated()], prg=CompiledRunner(prg)) @@ -68,107 +74,107 @@ class TestCfg(unittest.TestCase): self.skipTest(f"tests written for RDNA, got arch {arch}") def test_simple(self): - run_asm("simple", """ - entry: - s_branch bb1 - bb1: - s_endpgm - """) + run_asm("simple", [ + "entry:", + s_branch("bb1"), + "bb1:", + s_endpgm(), + ]) def test_diamond(self): - run_asm("diamond", """ - entry: - s_cmp_eq_i32 s0, 0 - s_cbranch_scc1 if - s_branch else - if: - s_nop 1 - s_branch end - else: - s_nop 0 - end: - s_endpgm - """) + run_asm("diamond", [ + "entry:", + s_cmp_eq_i32(s[0], 0), + s_cbranch_scc1("if"), + s_branch("else"), + "if:", + s_nop(1), + s_branch("end"), + "else:", + s_nop(0), + "end:", + s_endpgm(), + ]) def test_loop(self): - run_asm("simple_loop", """ - entry: - s_mov_b32 s1, 4 - loop: - s_add_u32 s1, s1, -1 - s_cmp_eq_i32 s1, 0 - s_cbranch_scc0 loop - s_endpgm - """) + run_asm("simple_loop", [ + "entry:", + s_mov_b32(s[1], 4), + "loop:", + s_add_u32(s[1], s[1], -1), + s_cmp_eq_i32(s[1], 0), + s_cbranch_scc0("loop"), + s_endpgm(), + ]) def test_loop_branch(self): - run_asm("loop_if", """ - entry: - s_mov_b32 s1, 4 - loop: - s_add_u32 s1, s1, -1 - s_cmp_eq_i32 s1, 2 - s_cbranch_scc1 cond - s_branch cont - cond: - s_add_u32 s1, s1, -2 - cont: - s_cmp_eq_i32 s1, 0 - s_cbranch_scc0 loop - s_endpgm - """) + run_asm("loop_if", [ + "entry:", + s_mov_b32(s[1], 4), + "loop:", + s_add_u32(s[1], s[1], -1), + s_cmp_eq_i32(s[1], 2), + s_cbranch_scc1("cond"), + s_branch("cont"), + "cond:", + s_add_u32(s[1], s[1], -2), + "cont:", + s_cmp_eq_i32(s[1], 0), + s_cbranch_scc0("loop"), + s_endpgm(), + ]) def test_loop_break(self): - run_asm("loop_break", """ - entry: - s_mov_b32 s1, 8 - loop: - s_add_u32 s1, s1, -1 - s_cmp_eq_i32 s1, 5 - s_cbranch_scc1 break - s_cmp_eq_i32 s1, 0 - s_cbranch_scc0 loop - break: - s_endpgm - """) + run_asm("loop_break", [ + "entry:", + s_mov_b32(s[1], 8), + "loop:", + s_add_u32(s[1], s[1], -1), + s_cmp_eq_i32(s[1], 5), + s_cbranch_scc1("break"), + s_cmp_eq_i32(s[1], 0), + s_cbranch_scc0("loop"), + "break:", + s_endpgm(), + ]) def test_switch(self): - run_asm("switch_case", """ - entry: - s_cmp_eq_i32 s0, 0 - s_cbranch_scc1 case0 - s_cmp_eq_i32 s0, 1 - s_cbranch_scc1 case1 - s_branch case2 - case0: - s_nop 0 - s_branch join - case1: - s_nop 1 - s_branch join - case2: - s_nop 2 - s_branch join - join: - s_endpgm - """) + run_asm("switch_case", [ + "entry:", + s_cmp_eq_i32(s[0], 0), + s_cbranch_scc1("case0"), + s_cmp_eq_i32(s[0], 1), + s_cbranch_scc1("case1"), + s_branch("case2"), + "case0:", + s_nop(0), + s_branch("join"), + "case1:", + s_nop(1), + s_branch("join"), + "case2:", + s_nop(2), + s_branch("join"), + "join:", + s_endpgm(), + ]) def test_ping_pong(self): - run_asm("ping_pong", """ - entry: - s_cmp_eq_i32 s0, 0 - s_cbranch_scc1 ping - s_branch pong - ping: - s_cmp_eq_i32 s1, 0 - s_cbranch_scc1 pong - s_branch end - pong: - s_cmp_eq_i32 s2, 0 - s_cbranch_scc1 ping - end: - s_endpgm - """) + run_asm("ping_pong", [ + "entry:", + s_cmp_eq_i32(s[0], 0), + s_cbranch_scc1("ping"), + s_branch("pong"), + "ping:", + s_cmp_eq_i32(s[1], 0), + s_cbranch_scc1("pong"), + s_branch("end"), + "pong:", + s_cmp_eq_i32(s[2], 0), + s_cbranch_scc1("ping"), + "end:", + s_endpgm(), + ]) if __name__ == "__main__": unittest.main()