mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
173 lines
3.8 KiB
Python
173 lines
3.8 KiB
Python
import unittest
|
|
import textwrap
|
|
|
|
from tinygrad import Device, Tensor
|
|
from tinygrad.uop.ops import UOp, Ops, track_rewrites
|
|
from tinygrad.renderer import ProgramSpec
|
|
from tinygrad.helpers import TracingKey
|
|
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
|
|
|
# TODO: use the RDNA3 renderer when it's in master
|
|
# Assembly skeleton for a single RDNA kernel. `fn_name` and `INSTRUCTION` are placeholders that
# run_asm() replaces with the kernel name and the kernel body.
# The .amdgpu_metadata section is YAML, so its indentation is significant: every key under a "- "
# item belongs to that list entry (one kernel taking one 8-byte global pointer argument `a`).
template = """.text
.globl fn_name
.p2align 8
.type fn_name,@function
fn_name:
INSTRUCTION

.rodata
.p2align 6
.amdhsa_kernel fn_name
  .amdhsa_user_sgpr_kernarg_segment_ptr 1
  .amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
  .amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
  .amdhsa_wavefront_size32 1
.end_amdhsa_kernel

.amdgpu_metadata
---
amdhsa.version:
  - 1
  - 0
amdhsa.kernels:
  - .name: fn_name
    .symbol: fn_name.kd
    .group_segment_fixed_size: 0
    .private_segment_fixed_size: 0
    .wavefront_size: 32
    .sgpr_count: 8
    .vgpr_count: 8
    .max_flat_workgroup_size: 1024
    .kernarg_segment_align: 8
    .kernarg_segment_size: 8
    .args:
      - .address_space: global
        .name: a
        .offset: 0
        .size: 8
        .type_name: 'float*'
        .value_kind: global_buffer
...
.end_amdgpu_metadata
"""
|
|
|
|
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret))
def run_asm(name:str, src:str) -> ProgramSpec:
  """Assemble `src` as a kernel called `name`, launch it once, and return its ProgramSpec.

  `name` replaces every `fn_name` placeholder in `template`, and the dedented `src` replaces
  `INSTRUCTION`. The kernel is compiled with CompiledRunner and run a single time with one freshly
  allocated 1-element buffer as its only argument, matching the single pointer argument `a` that
  the template's metadata declares. The @track_rewrites decorator keys the trace on the returned
  program's name.
  """
  kernel_asm = template.replace("fn_name", name)
  kernel_asm = kernel_asm.replace("INSTRUCTION", textwrap.dedent(src))
  spec = ProgramSpec(name, kernel_asm, Device.DEFAULT, UOp(Ops.SINK))
  # build the ExecItem pieces in the same order a single ExecItem(...) call expression evaluates them
  sink = UOp(Ops.SINK)
  out_buf = Tensor.empty(1).uop.buffer.ensure_allocated()
  runner = CompiledRunner(spec)
  item = ExecItem(sink, [out_buf], prg=runner)
  item.run()
  return spec
|
|
|
|
@unittest.skipUnless(Device.DEFAULT == "AMD", "only on AMD")
class TestCfg(unittest.TestCase):
  """Launch tiny hand-written RDNA kernels, each exercising a different control-flow-graph shape.

  Nothing is asserted: a test passes when run_asm assembles and runs its kernel without raising.
  run_asm is decorated with @track_rewrites keyed by the program name, so each run is tracked under
  the kernel's own name. Kernels that branch on s0/s1/s2 without writing them first branch on
  whatever those SGPRs hold at launch.
  """

  def setUp(self):
    # The assembly below is written for RDNA (gfx11 / gfx12); skip every other AMD target.
    arch = Device["AMD"].arch
    if not arch.startswith(("gfx11", "gfx12")):
      self.skipTest(f"tests written for RDNA, got arch {arch}")

  def test_simple(self):
    # straight line: entry -> bb1 via an unconditional branch
    run_asm("simple", """
    entry:
      s_branch bb1
    bb1:
      s_endpgm
    """)

  def test_diamond(self):
    # if/else diamond: entry -> {if, else} -> end, selected by s0 == 0
    run_asm("diamond", """
    entry:
      s_cmp_eq_i32 s0, 0
      s_cbranch_scc1 if
      s_branch else
    if:
      s_nop 1
      s_branch end
    else:
      s_nop 0
    end:
      s_endpgm
    """)

  def test_loop(self):
    # single-block loop: s1 counts 4 -> 0, taking the back edge to `loop` while s1 != 0
    run_asm("simple_loop", """
    entry:
      s_mov_b32 s1, 4
    loop:
      s_add_u32 s1, s1, -1
      s_cmp_eq_i32 s1, 0
      s_cbranch_scc0 loop
      s_endpgm
    """)

  def test_loop_branch(self):
    # loop containing an if: s1 goes 4 -> 3 -> 2; at 2 `cond` subtracts 2 more, reaching 0 and ending the loop
    run_asm("loop_if", """
    entry:
      s_mov_b32 s1, 4
    loop:
      s_add_u32 s1, s1, -1
      s_cmp_eq_i32 s1, 2
      s_cbranch_scc1 cond
      s_branch cont
    cond:
      s_add_u32 s1, s1, -2
    cont:
      s_cmp_eq_i32 s1, 0
      s_cbranch_scc0 loop
      s_endpgm
    """)

  def test_loop_break(self):
    # loop with two exits: s1 counts down from 8 and the early exit to `break` is taken when s1 == 5
    run_asm("loop_break", """
    entry:
      s_mov_b32 s1, 8
    loop:
      s_add_u32 s1, s1, -1
      s_cmp_eq_i32 s1, 5
      s_cbranch_scc1 break
      s_cmp_eq_i32 s1, 0
      s_cbranch_scc0 loop
    break:
      s_endpgm
    """)

  def test_switch(self):
    # three-way switch on s0 (0 -> case0, 1 -> case1, otherwise case2); every case jumps to `join`
    run_asm("switch_case", """
    entry:
      s_cmp_eq_i32 s0, 0
      s_cbranch_scc1 case0
      s_cmp_eq_i32 s0, 1
      s_cbranch_scc1 case1
      s_branch case2
    case0:
      s_nop 0
      s_branch join
    case1:
      s_nop 1
      s_branch join
    case2:
      s_nop 2
      s_branch join
    join:
      s_endpgm
    """)

  def test_ping_pong(self):
    # ping and pong branch to each other and entry can jump into either one: a cycle with two entry points
    # NOTE(review): this never terminates if s1 and s2 are both 0 at launch. It relies on the initial SGPR
    # contents (presumably s0:s1 hold the kernarg pointer) -- confirm
    run_asm("ping_pong", """
    entry:
      s_cmp_eq_i32 s0, 0
      s_cbranch_scc1 ping
      s_branch pong
    ping:
      s_cmp_eq_i32 s1, 0
      s_cbranch_scc1 pong
      s_branch end
    pong:
      s_cmp_eq_i32 s2, 0
      s_cbranch_scc1 ping
    end:
      s_endpgm
    """)
|
|
|
|
# run this module's tests when it is executed directly as a script
if __name__ == "__main__": unittest.main()
|