use the asm dsl in remu hwtest.py (#13856)

* remu hw test with the asm dsl

* simpler

* nthreads and exec mask

* cmp/cmpx

* assembler error in s_mov_b32

* vopd in dsl?
This commit is contained in:
qazal
2025-12-28 11:32:41 +09:00
committed by GitHub
parent 784b919f7f
commit 2180eee5e4
2 changed files with 115 additions and 93 deletions

View File

@@ -1,32 +1,37 @@
# ruff: noqa: F405, F403
# allow define from star imports
import numpy as np import numpy as np
import unittest import unittest
import subprocess, struct, math import subprocess, struct, math, textwrap
from tinygrad import Tensor, dtypes, Device, UOp from tinygrad import Tensor, dtypes, Device, UOp
from tinygrad.uop.ops import Ops
from tinygrad.helpers import getenv from tinygrad.helpers import getenv
from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble
from tinygrad.renderer import ProgramSpec from tinygrad.renderer import ProgramSpec
from tinygrad.engine.realize import CompiledRunner from tinygrad.engine.realize import CompiledRunner
def get_output(asm:str, n_threads:int=1): from extra.assembly.rdna3.autogen import *
input_asm = "\n".join([ln if ln.strip().startswith('asm volatile') else f'asm volatile("{ln.strip().lstrip()}" : "+v"(a), "+v"(b));' from extra.assembly.rdna3.asm import waitcnt
for ln in asm.strip().splitlines() if ln.strip()]) from test.testextra.test_cfg_viz import template
src = f"""
typedef long unsigned int size_t; def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
extern "C" __attribute__((device, const)) size_t __ockl_get_local_id(unsigned int); out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, {n_threads}))) test(unsigned int* data0_1) {{ src = "\n".join(inst.disasm() for inst in [
int l = __ockl_get_local_id(0); s_load_b64(s[0:1], s[0:1], NULL),
unsigned a = 0, b = 0, c = 0; *asm,
{input_asm} v_lshlrev_b32_e32(v[0], 2, v[0]),
unsigned res; s_waitcnt(simm16=waitcnt(lgkmcnt=0)),
asm volatile("v_mov_b32 %0, %1" : "=v"(res) : "v"(a)); #global_store_b32(v[0], v[1], s[0:1]),
*(data0_1+l) = res; global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
}}""" s_endpgm()
t = Tensor.zeros(n_threads, dtype=dtypes.uint32).contiguous().realize() ])
prg = ProgramSpec("test", src, Device.DEFAULT, UOp.sink(t), global_size=[1, 1, 1], local_size=[n_threads, 1, 1]) prg = ProgramSpec("test", template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
global_size=[1, 1, 1], local_size=[n_threads, 1, 1], globals=[0])
car = CompiledRunner(prg) car = CompiledRunner(prg)
if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib) if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib)
car([t.uop.buffer], {}, wait=True) car([out.uop.buffer], {}, wait=True)
return t.numpy() return out.tolist()
def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0] def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0]
def f32_from_bits(x:int) -> float: return struct.unpack('<f', struct.pack('<I', x))[0] def f32_from_bits(x:int) -> float: return struct.unpack('<f', struct.pack('<I', x))[0]
@@ -37,83 +42,95 @@ class TestHW(unittest.TestCase):
def setUp(self): def setUp(self):
if getenv("MOCKGPU"): subprocess.run(["cargo", "build", "--release", "--manifest-path", "./extra/remu/Cargo.toml"], check=True) if getenv("MOCKGPU"): subprocess.run(["cargo", "build", "--release", "--manifest-path", "./extra/remu/Cargo.toml"], check=True)
def test_simple(self): def test_simple_v_mov(self):
out = get_output(""" out = get_output([
v_mov_b32_e32 %1 42 v_mov_b32_e32(v[1], 2),
v_mov_b32_e32 %2 %1 ])
""")[0] self.assertEqual(out, [2])
np.testing.assert_equal(out, 42)
# assembler err
@unittest.expectedFailure
def test_simple_s_mov(self):
out = get_output([
s_mov_b32(s[7], 0x7fffffff),
v_mov_b32_e32(v[1], s[7]),
])
self.assertEqual(out, [2])
def test_exec_mov(self): def test_exec_mov(self):
out = get_output(""" out = get_output([
v_mov_b32_e32 %1 42 v_mov_b32_e32(v[1], 42),
s_mov_b32_e32 exec_lo 0b10 s_mov_b32(EXEC_LO, 0b10),
v_mov_b32_e32 %1 10 v_mov_b32_e32(v[1], 10),
s_mov_b32_e32 exec_lo 0b11 s_mov_b32(EXEC_LO, 0b11),
v_mov_b32_e32 %2 %1 ], n_threads=2)
""", n_threads=2)
np.testing.assert_equal(out, [42, 10]) np.testing.assert_equal(out, [42, 10])
def test_exec_cmp_vopc(self): def test_exec_cmp_vopc(self):
out = get_output(""" out = get_output([
s_mov_b32 vcc_lo 0 // reset vcc s_mov_b32(VCC_LO, 0), # reset vcc
v_mov_b32_e32 %1 42 v_mov_b32_e32(v[1], 42),
v_mov_b32_e32 %2 10 v_mov_b32_e32(v[2], 10),
s_mov_b32_e32 exec_lo 0b01 s_mov_b32(EXEC_LO, 0b01),
v_cmp_ne_u32 %1 %2 v_cmp_ne_u32_e32(v[1], v[2]),
s_mov_b32_e32 exec_lo 0b11 s_mov_b32(EXEC_LO, 0b11),
v_mov_b32_e32 %2 vcc_lo v_mov_b32_e32(v[1], VCC_LO),
""", n_threads=2) ], n_threads=2)[0]
np.testing.assert_equal(out, 0b01) np.testing.assert_equal(out, 1)
def test_exec_cmpx_vop3(self): def test_exec_cmpx_vop3(self):
out = get_output(""" out = get_output([
s_mov_b32_e32 exec_lo 0b11 s_mov_b32(EXEC_LO, 0b11),
v_mov_b32_e32 %1 42 v_mov_b32_e32(v[1], 42),
v_mov_b32_e32 %2 10 v_mov_b32_e32(v[2], 10),
s_mov_b32_e32 exec_lo 0b01 s_mov_b32(EXEC_LO, 0b01),
v_cmpx_ne_u32 %1 %2 v_cmpx_ne_u32_e32(v[1], v[2]),
s_mov_b32_e32 s10 exec_lo s_mov_b32(s[10], EXEC_LO),
s_mov_b32_e32 exec_lo 0b11 s_mov_b32(EXEC_LO, 0b11),
v_mov_b32_e32 %2 s10 v_mov_b32_e32(v[1], s[10]),
""", n_threads=2)[0] ], n_threads=2)[0]
np.testing.assert_equal(out & 0b11, 0b01) np.testing.assert_equal(out & 0b11, 0b01)
def test_fmac_vop3_modifier(self): def test_fmac_vop3_modifier(self):
init_state = f""" init_state = [
asm volatile("v_mov_b32_e32 %1, {f16_to_bits(4.0)}" : "+v"(a)); v_mov_b32_e32(a:=v[1], f16_to_bits(4.0)),
asm volatile("v_mov_b32_e32 %1, {f16_to_bits(3.0)}" : "+v"(b)); v_mov_b32_e32(b:=v[2], f16_to_bits(3.0)),
asm volatile("v_mov_b32_e32 %1, {f16_to_bits(2.0)}" : "+v"(c)); v_mov_b32_e32(c:=v[3], f16_to_bits(2.0)),
""" ]
mov = """asm volatile("v_mov_b32_e32 %1, %2" : "+v"(c), "+v"(a));""" def run_fmac(a, b): return get_output(init_state+[v_fmac_f16_e64(c, a, b)], vdst=c)[0]
def fmac(a, b, c): return f"""asm volatile("v_fmac_f16_e64 {c}, {a}, {b}" : "+v"(c) : "v"(a), "v"(b));"""+"\n"+mov self.assertEqual(run_fmac(a, b), f16_to_bits(14.0))
self.assertEqual(get_output(init_state+"\n"+fmac("%1", "%2", "%3")), f16_to_bits(14.)) self.assertEqual(run_fmac(a, -b), f16_to_bits(-10.0))
self.assertEqual(get_output(init_state+"\n"+fmac("%1", "-%2", "%3")), f16_to_bits(-10.)) self.assertEqual(run_fmac(-a, -b), f16_to_bits(14.0))
self.assertEqual(get_output(init_state+"\n"+fmac("-%1", "-%2", "%3")), f16_to_bits(14.))
# assembler err
@unittest.expectedFailure
def test_s_abs_i32(self): def test_s_abs_i32(self):
def s_abs_i32(x, y, dst="s10", scc=0): def check(x, y, dst=s[10], scc=0):
for reg,val in [(dst, y), ("scc", scc)]: for reg,val in [(dst, y), (SCC, scc)]:
self.assertEqual(get_output(f""" self.assertEqual(get_output([
s_mov_b32_e32 {dst} {x} s_mov_b32(dst, x),
s_abs_i32 {dst} {dst} s_abs_i32(dst, dst),
v_mov_b32_e32 %2 {reg} v_mov_b32_e32(v[1], reg)
""")[0], val) ])[0], val)
s_abs_i32(0x00000001, 0x00000001, scc=1)
s_abs_i32(0x7fffffff, 0x7fffffff, scc=1)
s_abs_i32(0x80000000, 0x80000000, scc=1)
s_abs_i32(0x80000001, 0x7fffffff, scc=1)
s_abs_i32(0x80000002, 0x7ffffffe, scc=1)
s_abs_i32(0xffffffff, 0x00000001, scc=1)
s_abs_i32(0, 0, scc=0)
check(0x00000001, 0x00000001, scc=1)
check(0x7fffffff, 0x7fffffff, scc=1)
check(0x80000000, 0x80000000, scc=1)
check(0x80000001, 0x7fffffff, scc=1)
check(0x80000002, 0x7ffffffe, scc=1)
check(0xffffffff, 0x00000001, scc=1)
check(0, 0, scc=0)
# how do I negate a VGPR operand?
@unittest.expectedFailure
def test_v_rcp_f32_neg_vop3(self): def test_v_rcp_f32_neg_vop3(self):
def v_neg_rcp_f32(x:float, y:float): def v_neg_rcp_f32(x:float, y:float):
out = get_output(f""" out = get_output([
v_mov_b32_e32 %2 {f32_to_bits(x)} v_mov_b32_e32(v[2], f32_to_bits(x)),
v_rcp_f32_e64 %2, -%2 v_rcp_f32_e64(v[2], -v[2]),
""")[0] ], vdst=v[2])[0]
assert out == f32_to_bits(y), f"{f32_from_bits(out)} != {y} / {out} != {f32_to_bits(y)}" assert out == f32_to_bits(y), f"{f32_from_bits(out)} != {y} / {out} != {f32_to_bits(y)}"
v_neg_rcp_f32(math.inf, -0.0) v_neg_rcp_f32(math.inf, -0.0)
v_neg_rcp_f32(-math.inf, 0.0) v_neg_rcp_f32(-math.inf, 0.0)
v_neg_rcp_f32(0.0, -math.inf) v_neg_rcp_f32(0.0, -math.inf)
@@ -121,26 +138,31 @@ class TestHW(unittest.TestCase):
v_neg_rcp_f32(-2.0, 0.5) v_neg_rcp_f32(-2.0, 0.5)
v_neg_rcp_f32(2.0, -0.5) v_neg_rcp_f32(2.0, -0.5)
# how do I negate a VGPR operand?
@unittest.expectedFailure
def test_v_cndmask_b32_neg(self): def test_v_cndmask_b32_neg(self):
def v_neg(x:int|float, y:float): def v_neg(x:float, y:float):
# always pick -v1 out = get_output([
out = get_output(f""" v_mov_b32_e32(v[1], f32_to_bits(x)),
v_mov_b32_e32 %2 {f32_to_bits(x)} s_mov_b32(s[10], 1),
s_mov_b32_e32 s10 1 v_cndmask_b32_e32(v[1], v[1], -v[1], s[10]),
v_cndmask_b32 %2, %2, -%2 s10 ])[0]
""")[0]
assert out == f32_to_bits(y), f"{f32_from_bits(out)} != {y} / {out} != {f32_to_bits(y)}" assert out == f32_to_bits(y), f"{f32_from_bits(out)} != {y} / {out} != {f32_to_bits(y)}"
v_neg(-0.0, 0.0) v_neg(-0.0, 0.0)
v_neg(0.0, -0.0) v_neg(0.0, -0.0)
v_neg(2.0, -2.0) v_neg(2.0, -2.0)
v_neg(math.inf, -math.inf) v_neg(math.inf, -math.inf)
v_neg(-math.inf, math.inf) v_neg(-math.inf, math.inf)
@unittest.skip("how does VOPD work in the dsl")
def test_v_subrev_wrap(self): def test_v_subrev_wrap(self):
out = get_output(""" out = get_output([
v_dual_mov_b32 %1, 0xffffffff :: v_dual_mov_b32 %2, 0x0 #v_dual_mov_b32(v[1], 0xffffffff, v[2], 0x0),
v_subrev_co_u32 %2, vcc_lo, %2, %1 #v_dual_mov_b32(vdstx=v[1], srcx=0xffffffff, vdsty=v[2], srcy=0x0),
""")[0] #VOPD(opx=VOPDOp.V_DUAL_MOV_B32, opy=VOPDOp.V_DUAL_MOV_B32, vdstx=v[1], srcx=0xffffffff, vdsty=v[2], srcy=0x0),
v_subrev_co_u32(v[2], VCC_LO, v[2], v[1]),
], vdst=v[2])[0]
self.assertEqual(out, 0xffff_ffff) self.assertEqual(out, 0xffff_ffff)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -534,8 +534,8 @@ def parse_sqtt_print_packets(data: bytes, filter=DEFAULT_FILTER, verbose=True) -
def parse(fn:str): def parse(fn:str):
with Timing(f"unpickle {fn}: "): dat = pickle.load(open(fn, "rb")) with Timing(f"unpickle {fn}: "): dat = pickle.load(open(fn, "rb"))
if getenv("ROCM", 0): #if getenv("ROCM", 0):
with Timing(f"decode {fn}: "): ctx = decode(dat) # with Timing(f"decode {fn}: "): ctx = decode(dat)
dat_sqtt = [x for x in dat if isinstance(x, ProfileSQTTEvent)] dat_sqtt = [x for x in dat if isinstance(x, ProfileSQTTEvent)]
print(f"got {len(dat_sqtt)} SQTT events in {fn}") print(f"got {len(dat_sqtt)} SQTT events in {fn}")
return dat_sqtt return dat_sqtt