use the asm dsl in remu hwtest.py (#13856)

* remu hw test with the asm dsl

* simpler

* nthreads and exec mask

* cmp/cmpx

* assembler error in s_mov_b32

* vopd in dsl?
This commit is contained in:
qazal
2025-12-28 11:32:41 +09:00
committed by GitHub
parent 784b919f7f
commit 2180eee5e4
2 changed files with 115 additions and 93 deletions

View File

@@ -1,32 +1,37 @@
# ruff: noqa: F405, F403
# allow define from star imports
import numpy as np
import unittest
import subprocess, struct, math
import subprocess, struct, math, textwrap
from tinygrad import Tensor, dtypes, Device, UOp
from tinygrad.uop.ops import Ops
from tinygrad.helpers import getenv
from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.realize import CompiledRunner
def get_output(asm:str, n_threads:int=1):
input_asm = "\n".join([ln if ln.strip().startswith('asm volatile') else f'asm volatile("{ln.strip().lstrip()}" : "+v"(a), "+v"(b));'
for ln in asm.strip().splitlines() if ln.strip()])
src = f"""
typedef long unsigned int size_t;
extern "C" __attribute__((device, const)) size_t __ockl_get_local_id(unsigned int);
extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, {n_threads}))) test(unsigned int* data0_1) {{
int l = __ockl_get_local_id(0);
unsigned a = 0, b = 0, c = 0;
{input_asm}
unsigned res;
asm volatile("v_mov_b32 %0, %1" : "=v"(res) : "v"(a));
*(data0_1+l) = res;
}}"""
t = Tensor.zeros(n_threads, dtype=dtypes.uint32).contiguous().realize()
prg = ProgramSpec("test", src, Device.DEFAULT, UOp.sink(t), global_size=[1, 1, 1], local_size=[n_threads, 1, 1])
from extra.assembly.rdna3.autogen import *
from extra.assembly.rdna3.asm import waitcnt
from test.testextra.test_cfg_viz import template
def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
src = "\n".join(inst.disasm() for inst in [
s_load_b64(s[0:1], s[0:1], NULL),
*asm,
v_lshlrev_b32_e32(v[0], 2, v[0]),
s_waitcnt(simm16=waitcnt(lgkmcnt=0)),
#global_store_b32(v[0], v[1], s[0:1]),
global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
s_endpgm()
])
prg = ProgramSpec("test", template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
global_size=[1, 1, 1], local_size=[n_threads, 1, 1], globals=[0])
car = CompiledRunner(prg)
if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib)
car([t.uop.buffer], {}, wait=True)
return t.numpy()
car([out.uop.buffer], {}, wait=True)
return out.tolist()
def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0]
def f32_from_bits(x:int) -> float: return struct.unpack('<f', struct.pack('<I', x))[0]
@@ -37,83 +42,95 @@ class TestHW(unittest.TestCase):
def setUp(self):
if getenv("MOCKGPU"): subprocess.run(["cargo", "build", "--release", "--manifest-path", "./extra/remu/Cargo.toml"], check=True)
def test_simple(self):
out = get_output("""
v_mov_b32_e32 %1 42
v_mov_b32_e32 %2 %1
""")[0]
np.testing.assert_equal(out, 42)
def test_simple_v_mov(self):
out = get_output([
v_mov_b32_e32(v[1], 2),
])
self.assertEqual(out, [2])
# assembler err
@unittest.expectedFailure
def test_simple_s_mov(self):
out = get_output([
s_mov_b32(s[7], 0x7fffffff),
v_mov_b32_e32(v[1], s[7]),
])
self.assertEqual(out, [2])
def test_exec_mov(self):
out = get_output("""
v_mov_b32_e32 %1 42
s_mov_b32_e32 exec_lo 0b10
v_mov_b32_e32 %1 10
s_mov_b32_e32 exec_lo 0b11
v_mov_b32_e32 %2 %1
""", n_threads=2)
out = get_output([
v_mov_b32_e32(v[1], 42),
s_mov_b32(EXEC_LO, 0b10),
v_mov_b32_e32(v[1], 10),
s_mov_b32(EXEC_LO, 0b11),
], n_threads=2)
np.testing.assert_equal(out, [42, 10])
def test_exec_cmp_vopc(self):
out = get_output("""
s_mov_b32 vcc_lo 0 // reset vcc
v_mov_b32_e32 %1 42
v_mov_b32_e32 %2 10
s_mov_b32_e32 exec_lo 0b01
v_cmp_ne_u32 %1 %2
s_mov_b32_e32 exec_lo 0b11
v_mov_b32_e32 %2 vcc_lo
""", n_threads=2)
np.testing.assert_equal(out, 0b01)
out = get_output([
s_mov_b32(VCC_LO, 0), # reset vcc
v_mov_b32_e32(v[1], 42),
v_mov_b32_e32(v[2], 10),
s_mov_b32(EXEC_LO, 0b01),
v_cmp_ne_u32_e32(v[1], v[2]),
s_mov_b32(EXEC_LO, 0b11),
v_mov_b32_e32(v[1], VCC_LO),
], n_threads=2)[0]
np.testing.assert_equal(out, 1)
def test_exec_cmpx_vop3(self):
out = get_output("""
s_mov_b32_e32 exec_lo 0b11
v_mov_b32_e32 %1 42
v_mov_b32_e32 %2 10
s_mov_b32_e32 exec_lo 0b01
v_cmpx_ne_u32 %1 %2
s_mov_b32_e32 s10 exec_lo
s_mov_b32_e32 exec_lo 0b11
v_mov_b32_e32 %2 s10
""", n_threads=2)[0]
out = get_output([
s_mov_b32(EXEC_LO, 0b11),
v_mov_b32_e32(v[1], 42),
v_mov_b32_e32(v[2], 10),
s_mov_b32(EXEC_LO, 0b01),
v_cmpx_ne_u32_e32(v[1], v[2]),
s_mov_b32(s[10], EXEC_LO),
s_mov_b32(EXEC_LO, 0b11),
v_mov_b32_e32(v[1], s[10]),
], n_threads=2)[0]
np.testing.assert_equal(out & 0b11, 0b01)
def test_fmac_vop3_modifier(self):
init_state = f"""
asm volatile("v_mov_b32_e32 %1, {f16_to_bits(4.0)}" : "+v"(a));
asm volatile("v_mov_b32_e32 %1, {f16_to_bits(3.0)}" : "+v"(b));
asm volatile("v_mov_b32_e32 %1, {f16_to_bits(2.0)}" : "+v"(c));
"""
mov = """asm volatile("v_mov_b32_e32 %1, %2" : "+v"(c), "+v"(a));"""
def fmac(a, b, c): return f"""asm volatile("v_fmac_f16_e64 {c}, {a}, {b}" : "+v"(c) : "v"(a), "v"(b));"""+"\n"+mov
self.assertEqual(get_output(init_state+"\n"+fmac("%1", "%2", "%3")), f16_to_bits(14.))
self.assertEqual(get_output(init_state+"\n"+fmac("%1", "-%2", "%3")), f16_to_bits(-10.))
self.assertEqual(get_output(init_state+"\n"+fmac("-%1", "-%2", "%3")), f16_to_bits(14.))
init_state = [
v_mov_b32_e32(a:=v[1], f16_to_bits(4.0)),
v_mov_b32_e32(b:=v[2], f16_to_bits(3.0)),
v_mov_b32_e32(c:=v[3], f16_to_bits(2.0)),
]
def run_fmac(a, b): return get_output(init_state+[v_fmac_f16_e64(c, a, b)], vdst=c)[0]
self.assertEqual(run_fmac(a, b), f16_to_bits(14.0))
self.assertEqual(run_fmac(a, -b), f16_to_bits(-10.0))
self.assertEqual(run_fmac(-a, -b), f16_to_bits(14.0))
# assembler err
@unittest.expectedFailure
def test_s_abs_i32(self):
def s_abs_i32(x, y, dst="s10", scc=0):
for reg,val in [(dst, y), ("scc", scc)]:
self.assertEqual(get_output(f"""
s_mov_b32_e32 {dst} {x}
s_abs_i32 {dst} {dst}
v_mov_b32_e32 %2 {reg}
""")[0], val)
s_abs_i32(0x00000001, 0x00000001, scc=1)
s_abs_i32(0x7fffffff, 0x7fffffff, scc=1)
s_abs_i32(0x80000000, 0x80000000, scc=1)
s_abs_i32(0x80000001, 0x7fffffff, scc=1)
s_abs_i32(0x80000002, 0x7ffffffe, scc=1)
s_abs_i32(0xffffffff, 0x00000001, scc=1)
s_abs_i32(0, 0, scc=0)
def check(x, y, dst=s[10], scc=0):
for reg,val in [(dst, y), (SCC, scc)]:
self.assertEqual(get_output([
s_mov_b32(dst, x),
s_abs_i32(dst, dst),
v_mov_b32_e32(v[1], reg)
])[0], val)
check(0x00000001, 0x00000001, scc=1)
check(0x7fffffff, 0x7fffffff, scc=1)
check(0x80000000, 0x80000000, scc=1)
check(0x80000001, 0x7fffffff, scc=1)
check(0x80000002, 0x7ffffffe, scc=1)
check(0xffffffff, 0x00000001, scc=1)
check(0, 0, scc=0)
# how do I negate a VGPR operand?
@unittest.expectedFailure
def test_v_rcp_f32_neg_vop3(self):
def v_neg_rcp_f32(x:float, y:float):
out = get_output(f"""
v_mov_b32_e32 %2 {f32_to_bits(x)}
v_rcp_f32_e64 %2, -%2
""")[0]
out = get_output([
v_mov_b32_e32(v[2], f32_to_bits(x)),
v_rcp_f32_e64(v[2], -v[2]),
], vdst=v[2])[0]
assert out == f32_to_bits(y), f"{f32_from_bits(out)} != {y} / {out} != {f32_to_bits(y)}"
v_neg_rcp_f32(math.inf, -0.0)
v_neg_rcp_f32(-math.inf, 0.0)
v_neg_rcp_f32(0.0, -math.inf)
@@ -121,26 +138,31 @@ class TestHW(unittest.TestCase):
v_neg_rcp_f32(-2.0, 0.5)
v_neg_rcp_f32(2.0, -0.5)
# how do I negate a VGPR operand?
@unittest.expectedFailure
def test_v_cndmask_b32_neg(self):
def v_neg(x:int|float, y:float):
# always pick -v1
out = get_output(f"""
v_mov_b32_e32 %2 {f32_to_bits(x)}
s_mov_b32_e32 s10 1
v_cndmask_b32 %2, %2, -%2 s10
""")[0]
def v_neg(x:float, y:float):
out = get_output([
v_mov_b32_e32(v[1], f32_to_bits(x)),
s_mov_b32(s[10], 1),
v_cndmask_b32_e32(v[1], v[1], -v[1], s[10]),
])[0]
assert out == f32_to_bits(y), f"{f32_from_bits(out)} != {y} / {out} != {f32_to_bits(y)}"
v_neg(-0.0, 0.0)
v_neg(0.0, -0.0)
v_neg(2.0, -2.0)
v_neg(math.inf, -math.inf)
v_neg(-math.inf, math.inf)
@unittest.skip("how does VOPD work in the dsl")
def test_v_subrev_wrap(self):
out = get_output("""
v_dual_mov_b32 %1, 0xffffffff :: v_dual_mov_b32 %2, 0x0
v_subrev_co_u32 %2, vcc_lo, %2, %1
""")[0]
out = get_output([
#v_dual_mov_b32(v[1], 0xffffffff, v[2], 0x0),
#v_dual_mov_b32(vdstx=v[1], srcx=0xffffffff, vdsty=v[2], srcy=0x0),
#VOPD(opx=VOPDOp.V_DUAL_MOV_B32, opy=VOPDOp.V_DUAL_MOV_B32, vdstx=v[1], srcx=0xffffffff, vdsty=v[2], srcy=0x0),
v_subrev_co_u32(v[2], VCC_LO, v[2], v[1]),
], vdst=v[2])[0]
self.assertEqual(out, 0xffff_ffff)
if __name__ == "__main__":

View File

@@ -534,8 +534,8 @@ def parse_sqtt_print_packets(data: bytes, filter=DEFAULT_FILTER, verbose=True) -
def parse(fn:str):
with Timing(f"unpickle {fn}: "): dat = pickle.load(open(fn, "rb"))
if getenv("ROCM", 0):
with Timing(f"decode {fn}: "): ctx = decode(dat)
#if getenv("ROCM", 0):
# with Timing(f"decode {fn}: "): ctx = decode(dat)
dat_sqtt = [x for x in dat if isinstance(x, ProfileSQTTEvent)]
print(f"got {len(dat_sqtt)} SQTT events in {fn}")
return dat_sqtt