# ruff: noqa: F405, F403 # allow define from star imports import unittest import textwrap, functools from tinygrad import Device, Tensor from tinygrad.uop.ops import UOp, Ops, KernelInfo from tinygrad.helpers import getenv from extra.assembly.amd.autogen.rdna3.ins import * from extra.assembly.amd.dsl import Inst template = """.text .globl fn_name .p2align 8 .type fn_name,@function fn_name: INSTRUCTION .rodata .p2align 6 .amdhsa_kernel fn_name .amdhsa_kernarg_size 8 .amdhsa_user_sgpr_kernarg_segment_ptr 1 .amdhsa_next_free_vgpr .amdgcn.next_free_vgpr .amdhsa_next_free_sgpr .amdgcn.next_free_sgpr .amdhsa_wavefront_size32 1 .end_amdhsa_kernel .amdgpu_metadata --- amdhsa.version: - 1 - 0 amdhsa.kernels: - .name: fn_name .symbol: fn_name.kd .group_segment_fixed_size: 0 .private_segment_fixed_size: 0 .wavefront_size: 32 .sgpr_count: 8 .vgpr_count: 8 .max_flat_workgroup_size: 1024 .kernarg_segment_align: 8 .kernarg_segment_size: 8 .args: - .address_space: global .name: a .offset: 0 .size: 8 .type_name: 'float*' .value_kind: global_buffer ... .end_amdgpu_metadata """ def asm_kernel(out:UOp, insts:list[str|Inst], name:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp: lidx = UOp.special(n_threads, "lidx0") gidx = UOp.special(n_workgroups, "gidx0") sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name=name)) asm = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts]) src = template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(asm)) return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)), arg=()) def run_asm(name:str, insts:list) -> None: fxn = functools.partial(asm_kernel, insts=insts, name=name, device=Device.DEFAULT) out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0] out.realize() @unittest.skipUnless(Device.DEFAULT == "AMD" and not getenv("AMD_LLVM"), "only on AMD with comgr") class TestCfg(unittest.TestCase): def setUp(self): arch = Device["AMD"].arch if not any(arch.startswith(a) for a in {"gfx11", "gfx12"}): self.skipTest(f"tests written for RDNA, got arch {arch}") def test_simple(self): run_asm("simple", [ "entry:", s_branch("bb1"), "bb1:", s_endpgm(), ]) def test_diamond(self): run_asm("diamond", [ "entry:", s_cmp_eq_i32(s[0], 0), s_cbranch_scc1("if"), s_branch("else"), "if:", s_nop(1), s_branch("end"), "else:", s_nop(0), "end:", s_endpgm(), ]) def test_loop(self): run_asm("simple_loop", [ "entry:", s_mov_b32(s[1], 4), "loop:", s_add_u32(s[1], s[1], -1), s_cmp_eq_i32(s[1], 0), s_cbranch_scc0("loop"), s_endpgm(), ]) def test_loop_branch(self): run_asm("loop_if", [ "entry:", s_mov_b32(s[1], 4), "loop:", s_add_u32(s[1], s[1], -1), s_cmp_eq_i32(s[1], 2), s_cbranch_scc1("cond"), s_branch("cont"), "cond:", s_add_u32(s[1], s[1], -2), "cont:", s_cmp_eq_i32(s[1], 0), s_cbranch_scc0("loop"), s_endpgm(), ]) def test_loop_break(self): run_asm("loop_break", [ "entry:", s_mov_b32(s[1], 8), "loop:", s_add_u32(s[1], s[1], -1), s_cmp_eq_i32(s[1], 5), s_cbranch_scc1("break"), s_cmp_eq_i32(s[1], 0), s_cbranch_scc0("loop"), "break:", s_endpgm(), ]) def test_switch(self): run_asm("switch_case", [ "entry:", s_cmp_eq_i32(s[0], 0), s_cbranch_scc1("case0"), s_cmp_eq_i32(s[0], 1), s_cbranch_scc1("case1"), s_branch("case2"), "case0:", s_nop(0), s_branch("join"), "case1:", s_nop(1), s_branch("join"), "case2:", s_nop(2), s_branch("join"), "join:", s_endpgm(), ]) def test_ping_pong(self): run_asm("ping_pong", [ "entry:", s_cmp_eq_i32(s[0], 0), s_cbranch_scc1("ping"), s_branch("pong"), "ping:", s_cmp_eq_i32(s[1], 0), s_cbranch_scc1("pong"), s_branch("end"), "pong:", s_cmp_eq_i32(s[2], 0), s_cbranch_scc1("ping"), "end:", s_endpgm(), ]) def test_colored_blocks(self): N = 10 asm = ["entry:", s_branch("init0"),] for i in range(N): asm += [f"init{i}:", s_mov_b32(s[1], i + 1), s_branch(loop:=f"loop{i}")] asm += [ f"{loop}:", s_nop(i & 7), s_add_u32(s[1], s[1], -1), s_cmp_eq_i32(s[1], 0), s_cbranch_scc0(loop), s_branch(f"init{i+1}" if i + 1 < N else "end"), ] asm += ["end:", s_endpgm()] run_asm("test_colored_blocks", asm) if __name__ == "__main__": unittest.main()