From 230d08ec708a3e41edcc4119d7efbd1930236dc6 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Thu, 29 Jan 2026 17:11:24 +0300 Subject: [PATCH] test for am recovery and faults handling (#14421) * test for am recovery and faults handling * linter --- .../external_test_am_fault_recovery.py | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 test/external/external_test_am_fault_recovery.py diff --git a/test/external/external_test_am_fault_recovery.py b/test/external/external_test_am_fault_recovery.py new file mode 100644 index 0000000000..b637582e5d --- /dev/null +++ b/test/external/external_test_am_fault_recovery.py @@ -0,0 +1,137 @@ +# ruff: noqa: F405 +import unittest, subprocess, os +from extra.assembly.amd.autogen.rdna3.ins import * # noqa: F403 +from extra.assembly.amd.dsl import s, v, Inst, NULL + +def assemble_kernel(insts:list[Inst], name:str="test") -> str: + kd = {"next_free_vgpr": 8, "next_free_sgpr": 8, "wavefront_size32": 1, "user_sgpr_kernarg_segment_ptr": 1, "kernarg_size": 8} + disasm = "\n".join(inst.disasm() for inst in insts) + hsasrc = f".text\n.globl {name}\n.p2align 8\n.type {name},@function\n{name}:\n{disasm}\n" + return hsasrc + f".rodata\n.p2align 6\n.amdhsa_kernel {name}\n" + "\n".join(f".amdhsa_{k} {v}" for k, v in kd.items()) + "\n.end_amdhsa_kernel" + +def _run(code:str, timeout:float=15.0) -> subprocess.CompletedProcess: + # TODO: AM_RESET is required for now, so subprocesses + return subprocess.run(["python", "-c", code], env={**os.environ, "AMD": "1"}, capture_output=True, text=True, timeout=timeout) + +def _run_asm(asm_src:str) -> subprocess.CompletedProcess: + return _run('from tinygrad.device import Device; from tinygrad.runtime.ops_amd import AMDProgram; ' + 'from tinygrad.runtime.support.compiler_amd import HIPCompiler; dev = Device["AMD"]; ' + f'AMDProgram(dev, "test", HIPCompiler(dev.arch).compile("""{asm_src}"""))(' + 'dev.allocator.alloc(64), global_size=(1,1,1), local_size=(1,1,1), wait=True)') + +def _verify_recovery() -> subprocess.CompletedProcess: + return _run('from tinygrad import Tensor; t = Tensor([1.0, 2.0], device="AMD").realize(); assert (t + 1).numpy().tolist() == [2.0, 3.0]') + +_ILLEGAL_INST_ASM = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n.byte 0xff,0xff,0xff,0xff\ns_endpgm\n" \ + ".rodata\n.p2align 6\n.amdhsa_kernel test\n.amdhsa_next_free_vgpr 8\n.amdhsa_next_free_sgpr 8\n" \ + ".amdhsa_wavefront_size32 1\n.amdhsa_user_sgpr_kernarg_segment_ptr 1\n.amdhsa_kernarg_size 8\n.end_amdhsa_kernel" + +@unittest.skipIf(os.environ.get("AMD") != "1" or os.environ.get("MOCKGPU") == "1", "AMD with AM driver required") +class TestAMFaultRecovery(unittest.TestCase): + def _run_kernel(self, insts: list[Inst]) -> subprocess.CompletedProcess: return _run_asm(assemble_kernel(insts)) + + def _assert_fault_and_recovery(self, result:subprocess.CompletedProcess): + if result.stdout.strip(): print(f"\nstdout: {result.stdout.strip()}") + if result.stderr.strip(): print(f"\nstderr: {result.stderr.strip()}") + self.assertNotEqual(result.returncode, 0, f"Expected fault but succeeded: {result.stdout}") + self.assertEqual(_verify_recovery().returncode, 0) + + +class TestGlobalMemoryFaults(TestAMFaultRecovery): + def test_global_load_unmapped(self): + insts = [v_mov_b32_e32(v[0], 0xBEEF0000), v_mov_b32_e32(v[1], 0xDEAD), + global_load_b32(v[2], addr=v[0:1], saddr=NULL, offset=0), s_waitcnt(vmcnt=0), s_endpgm()] + self._assert_fault_and_recovery(self._run_kernel(insts)) + + def test_global_store_unmapped(self): + insts = [v_mov_b32_e32(v[0], 0xBEEF0000), v_mov_b32_e32(v[1], 0xDEAD), v_mov_b32_e32(v[2], 0x12345678), + global_store_b32(addr=v[0:1], data=v[2], saddr=NULL, offset=0), s_waitcnt(vmcnt=0), s_endpgm()] + self._assert_fault_and_recovery(self._run_kernel(insts)) + + def test_global_null_ptr(self): + insts = [v_mov_b32_e32(v[0], 0), v_mov_b32_e32(v[1], 0), + global_load_b32(v[2], addr=v[0:1], saddr=NULL, offset=0), s_waitcnt(vmcnt=0), s_endpgm()] + self._assert_fault_and_recovery(self._run_kernel(insts)) + + def test_global_misaligned_b64(self): + insts = [v_mov_b32_e32(v[0], 0xBEEF0001), v_mov_b32_e32(v[1], 0xDEAD), + global_load_b64(v[2:3], addr=v[0:1], saddr=NULL, offset=0), s_waitcnt(vmcnt=0), s_endpgm()] + self._assert_fault_and_recovery(self._run_kernel(insts)) + + def test_global_misaligned_b128(self): + insts = [v_mov_b32_e32(v[0], 0xBEEF0004), v_mov_b32_e32(v[1], 0xDEAD), + global_load_b128(v[2:5], addr=v[0:1], saddr=NULL, offset=0), s_waitcnt(vmcnt=0), s_endpgm()] + self._assert_fault_and_recovery(self._run_kernel(insts)) + + +class TestSMEMFaults(TestAMFaultRecovery): + def test_smem_null_base(self): + insts = [s_mov_b32(s[2], 0), s_mov_b32(s[3], 0), + s_load_b32(s[4], s[2:3], 0, soffset=NULL), s_waitcnt(lgkmcnt=0), s_endpgm()] + self._assert_fault_and_recovery(self._run_kernel(insts)) + + def test_smem_unmapped_address(self): + insts = [s_mov_b32(s[2], 0xBEEF0000), s_mov_b32(s[3], 0xDEAD), + s_load_b32(s[4], s[2:3], 0, soffset=NULL), s_waitcnt(lgkmcnt=0), s_endpgm()] + self._assert_fault_and_recovery(self._run_kernel(insts)) + + def test_smem_misaligned_b64(self): + insts = [s_mov_b32(s[2], 0xBEEF0004), s_mov_b32(s[3], 0xDEAD), + s_load_b64(s[4:5], s[2:3], 0, soffset=NULL), s_waitcnt(lgkmcnt=0), s_endpgm()] + self._assert_fault_and_recovery(self._run_kernel(insts)) + + def test_smem_misaligned_b128(self): + insts = [s_mov_b32(s[2], 0xBEEF0004), s_mov_b32(s[3], 0xDEAD), + s_load_b128(s[4:7], s[2:3], 0, soffset=NULL), s_waitcnt(lgkmcnt=0), s_endpgm()] + self._assert_fault_and_recovery(self._run_kernel(insts)) + + +class TestIllegalInstruction(TestAMFaultRecovery): + def test_malformed_encoding(self): + self._assert_fault_and_recovery(_run_asm(_ILLEGAL_INST_ASM)) + + +class TestFlatFaults(TestAMFaultRecovery): + def test_flat_load_unmapped(self): + insts = [v_mov_b32_e32(v[0], 0xBEEF0000), v_mov_b32_e32(v[1], 0xDEAD), + flat_load_b32(v[2], addr=v[0:1], saddr=NULL, offset=0), s_waitcnt(vmcnt=0, lgkmcnt=0), s_endpgm()] + self._assert_fault_and_recovery(self._run_kernel(insts)) + + def test_flat_store_unmapped(self): + insts = [v_mov_b32_e32(v[0], 0xBEEF0000), v_mov_b32_e32(v[1], 0xDEAD), v_mov_b32_e32(v[2], 0x12345678), + flat_store_b32(addr=v[0:1], data=v[2], saddr=NULL, offset=0), s_waitcnt(vmcnt=0, lgkmcnt=0), s_endpgm()] + self._assert_fault_and_recovery(self._run_kernel(insts)) + + +class TestAtomicFaults(TestAMFaultRecovery): + def test_global_atomic_unmapped(self): + insts = [v_mov_b32_e32(v[0], 0xBEEF0000), v_mov_b32_e32(v[1], 0xDEAD), v_mov_b32_e32(v[2], 1), + global_atomic_add_u32(addr=v[0:1], data=v[2], saddr=NULL, offset=0), s_waitcnt(vmcnt=0), s_endpgm()] + self._assert_fault_and_recovery(self._run_kernel(insts)) + + def test_flat_atomic_unmapped(self): + insts = [v_mov_b32_e32(v[0], 0xBEEF0000), v_mov_b32_e32(v[1], 0xDEAD), v_mov_b32_e32(v[2], 1), + flat_atomic_add_u32(addr=v[0:1], data=v[2], saddr=NULL, offset=0), s_waitcnt(vmcnt=0, lgkmcnt=0), s_endpgm()] + self._assert_fault_and_recovery(self._run_kernel(insts)) + + +class TestRecovery(TestAMFaultRecovery): + def test_recovery_after_memviol(self): + insts = [v_mov_b32_e32(v[0], 0xBEEF0000), v_mov_b32_e32(v[1], 0xDEAD), + global_load_b32(v[2], addr=v[0:1], saddr=NULL, offset=0), s_waitcnt(vmcnt=0), s_endpgm()] + self.assertNotEqual(self._run_kernel(insts).returncode, 0) + self.assertEqual(_verify_recovery().returncode, 0) + + def test_recovery_after_illegal_inst(self): + self.assertNotEqual(_run_asm(_ILLEGAL_INST_ASM).returncode, 0) + self.assertEqual(_verify_recovery().returncode, 0) + + def test_multiple_faults_recovery(self): + insts = [v_mov_b32_e32(v[0], 0xBEEF0000), v_mov_b32_e32(v[1], 0xDEAD), + global_load_b32(v[2], addr=v[0:1], saddr=NULL, offset=0), s_waitcnt(vmcnt=0), s_endpgm()] + for _ in range(3): + self.assertNotEqual(self._run_kernel(insts).returncode, 0) + self.assertEqual(_verify_recovery().returncode, 0) + +if __name__ == "__main__": + unittest.main()