mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
use the asm dsl in remu hwtest.py (#13856)
* remu hw test with the asm dsl * simpler * nthreads and exec mask * cmp/cmpx * assembler error in s_mov_b32 * vopd in dsl?
This commit is contained in:
@@ -1,32 +1,37 @@
|
|||||||
|
# ruff: noqa: F405, F403
|
||||||
|
# allow define from star imports
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import unittest
|
import unittest
|
||||||
import subprocess, struct, math
|
import subprocess, struct, math, textwrap
|
||||||
from tinygrad import Tensor, dtypes, Device, UOp
|
from tinygrad import Tensor, dtypes, Device, UOp
|
||||||
|
from tinygrad.uop.ops import Ops
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv
|
||||||
from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble
|
from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble
|
||||||
from tinygrad.renderer import ProgramSpec
|
from tinygrad.renderer import ProgramSpec
|
||||||
from tinygrad.engine.realize import CompiledRunner
|
from tinygrad.engine.realize import CompiledRunner
|
||||||
|
|
||||||
def get_output(asm:str, n_threads:int=1):
|
from extra.assembly.rdna3.autogen import *
|
||||||
input_asm = "\n".join([ln if ln.strip().startswith('asm volatile') else f'asm volatile("{ln.strip().lstrip()}" : "+v"(a), "+v"(b));'
|
from extra.assembly.rdna3.asm import waitcnt
|
||||||
for ln in asm.strip().splitlines() if ln.strip()])
|
from test.testextra.test_cfg_viz import template
|
||||||
src = f"""
|
|
||||||
typedef long unsigned int size_t;
|
def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
|
||||||
extern "C" __attribute__((device, const)) size_t __ockl_get_local_id(unsigned int);
|
out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
|
||||||
extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, {n_threads}))) test(unsigned int* data0_1) {{
|
src = "\n".join(inst.disasm() for inst in [
|
||||||
int l = __ockl_get_local_id(0);
|
s_load_b64(s[0:1], s[0:1], NULL),
|
||||||
unsigned a = 0, b = 0, c = 0;
|
*asm,
|
||||||
{input_asm}
|
v_lshlrev_b32_e32(v[0], 2, v[0]),
|
||||||
unsigned res;
|
s_waitcnt(simm16=waitcnt(lgkmcnt=0)),
|
||||||
asm volatile("v_mov_b32 %0, %1" : "=v"(res) : "v"(a));
|
#global_store_b32(v[0], v[1], s[0:1]),
|
||||||
*(data0_1+l) = res;
|
global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
|
||||||
}}"""
|
s_endpgm()
|
||||||
t = Tensor.zeros(n_threads, dtype=dtypes.uint32).contiguous().realize()
|
])
|
||||||
prg = ProgramSpec("test", src, Device.DEFAULT, UOp.sink(t), global_size=[1, 1, 1], local_size=[n_threads, 1, 1])
|
prg = ProgramSpec("test", template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
|
||||||
|
global_size=[1, 1, 1], local_size=[n_threads, 1, 1], globals=[0])
|
||||||
car = CompiledRunner(prg)
|
car = CompiledRunner(prg)
|
||||||
if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib)
|
if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib)
|
||||||
car([t.uop.buffer], {}, wait=True)
|
car([out.uop.buffer], {}, wait=True)
|
||||||
return t.numpy()
|
return out.tolist()
|
||||||
|
|
||||||
def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0]
|
def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0]
|
||||||
def f32_from_bits(x:int) -> float: return struct.unpack('<f', struct.pack('<I', x))[0]
|
def f32_from_bits(x:int) -> float: return struct.unpack('<f', struct.pack('<I', x))[0]
|
||||||
@@ -37,83 +42,95 @@ class TestHW(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
if getenv("MOCKGPU"): subprocess.run(["cargo", "build", "--release", "--manifest-path", "./extra/remu/Cargo.toml"], check=True)
|
if getenv("MOCKGPU"): subprocess.run(["cargo", "build", "--release", "--manifest-path", "./extra/remu/Cargo.toml"], check=True)
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple_v_mov(self):
|
||||||
out = get_output("""
|
out = get_output([
|
||||||
v_mov_b32_e32 %1 42
|
v_mov_b32_e32(v[1], 2),
|
||||||
v_mov_b32_e32 %2 %1
|
])
|
||||||
""")[0]
|
self.assertEqual(out, [2])
|
||||||
np.testing.assert_equal(out, 42)
|
|
||||||
|
# assembler err
|
||||||
|
@unittest.expectedFailure
|
||||||
|
def test_simple_s_mov(self):
|
||||||
|
out = get_output([
|
||||||
|
s_mov_b32(s[7], 0x7fffffff),
|
||||||
|
v_mov_b32_e32(v[1], s[7]),
|
||||||
|
])
|
||||||
|
self.assertEqual(out, [2])
|
||||||
|
|
||||||
def test_exec_mov(self):
|
def test_exec_mov(self):
|
||||||
out = get_output("""
|
out = get_output([
|
||||||
v_mov_b32_e32 %1 42
|
v_mov_b32_e32(v[1], 42),
|
||||||
s_mov_b32_e32 exec_lo 0b10
|
s_mov_b32(EXEC_LO, 0b10),
|
||||||
v_mov_b32_e32 %1 10
|
v_mov_b32_e32(v[1], 10),
|
||||||
s_mov_b32_e32 exec_lo 0b11
|
s_mov_b32(EXEC_LO, 0b11),
|
||||||
v_mov_b32_e32 %2 %1
|
], n_threads=2)
|
||||||
""", n_threads=2)
|
|
||||||
np.testing.assert_equal(out, [42, 10])
|
np.testing.assert_equal(out, [42, 10])
|
||||||
|
|
||||||
def test_exec_cmp_vopc(self):
|
def test_exec_cmp_vopc(self):
|
||||||
out = get_output("""
|
out = get_output([
|
||||||
s_mov_b32 vcc_lo 0 // reset vcc
|
s_mov_b32(VCC_LO, 0), # reset vcc
|
||||||
v_mov_b32_e32 %1 42
|
v_mov_b32_e32(v[1], 42),
|
||||||
v_mov_b32_e32 %2 10
|
v_mov_b32_e32(v[2], 10),
|
||||||
s_mov_b32_e32 exec_lo 0b01
|
s_mov_b32(EXEC_LO, 0b01),
|
||||||
v_cmp_ne_u32 %1 %2
|
v_cmp_ne_u32_e32(v[1], v[2]),
|
||||||
s_mov_b32_e32 exec_lo 0b11
|
s_mov_b32(EXEC_LO, 0b11),
|
||||||
v_mov_b32_e32 %2 vcc_lo
|
v_mov_b32_e32(v[1], VCC_LO),
|
||||||
""", n_threads=2)
|
], n_threads=2)[0]
|
||||||
np.testing.assert_equal(out, 0b01)
|
np.testing.assert_equal(out, 1)
|
||||||
|
|
||||||
def test_exec_cmpx_vop3(self):
|
def test_exec_cmpx_vop3(self):
|
||||||
out = get_output("""
|
out = get_output([
|
||||||
s_mov_b32_e32 exec_lo 0b11
|
s_mov_b32(EXEC_LO, 0b11),
|
||||||
v_mov_b32_e32 %1 42
|
v_mov_b32_e32(v[1], 42),
|
||||||
v_mov_b32_e32 %2 10
|
v_mov_b32_e32(v[2], 10),
|
||||||
s_mov_b32_e32 exec_lo 0b01
|
s_mov_b32(EXEC_LO, 0b01),
|
||||||
v_cmpx_ne_u32 %1 %2
|
v_cmpx_ne_u32_e32(v[1], v[2]),
|
||||||
s_mov_b32_e32 s10 exec_lo
|
s_mov_b32(s[10], EXEC_LO),
|
||||||
s_mov_b32_e32 exec_lo 0b11
|
s_mov_b32(EXEC_LO, 0b11),
|
||||||
v_mov_b32_e32 %2 s10
|
v_mov_b32_e32(v[1], s[10]),
|
||||||
""", n_threads=2)[0]
|
], n_threads=2)[0]
|
||||||
np.testing.assert_equal(out & 0b11, 0b01)
|
np.testing.assert_equal(out & 0b11, 0b01)
|
||||||
|
|
||||||
def test_fmac_vop3_modifier(self):
|
def test_fmac_vop3_modifier(self):
|
||||||
init_state = f"""
|
init_state = [
|
||||||
asm volatile("v_mov_b32_e32 %1, {f16_to_bits(4.0)}" : "+v"(a));
|
v_mov_b32_e32(a:=v[1], f16_to_bits(4.0)),
|
||||||
asm volatile("v_mov_b32_e32 %1, {f16_to_bits(3.0)}" : "+v"(b));
|
v_mov_b32_e32(b:=v[2], f16_to_bits(3.0)),
|
||||||
asm volatile("v_mov_b32_e32 %1, {f16_to_bits(2.0)}" : "+v"(c));
|
v_mov_b32_e32(c:=v[3], f16_to_bits(2.0)),
|
||||||
"""
|
]
|
||||||
mov = """asm volatile("v_mov_b32_e32 %1, %2" : "+v"(c), "+v"(a));"""
|
def run_fmac(a, b): return get_output(init_state+[v_fmac_f16_e64(c, a, b)], vdst=c)[0]
|
||||||
def fmac(a, b, c): return f"""asm volatile("v_fmac_f16_e64 {c}, {a}, {b}" : "+v"(c) : "v"(a), "v"(b));"""+"\n"+mov
|
self.assertEqual(run_fmac(a, b), f16_to_bits(14.0))
|
||||||
self.assertEqual(get_output(init_state+"\n"+fmac("%1", "%2", "%3")), f16_to_bits(14.))
|
self.assertEqual(run_fmac(a, -b), f16_to_bits(-10.0))
|
||||||
self.assertEqual(get_output(init_state+"\n"+fmac("%1", "-%2", "%3")), f16_to_bits(-10.))
|
self.assertEqual(run_fmac(-a, -b), f16_to_bits(14.0))
|
||||||
self.assertEqual(get_output(init_state+"\n"+fmac("-%1", "-%2", "%3")), f16_to_bits(14.))
|
|
||||||
|
|
||||||
|
# assembler err
|
||||||
|
@unittest.expectedFailure
|
||||||
def test_s_abs_i32(self):
|
def test_s_abs_i32(self):
|
||||||
def s_abs_i32(x, y, dst="s10", scc=0):
|
def check(x, y, dst=s[10], scc=0):
|
||||||
for reg,val in [(dst, y), ("scc", scc)]:
|
for reg,val in [(dst, y), (SCC, scc)]:
|
||||||
self.assertEqual(get_output(f"""
|
self.assertEqual(get_output([
|
||||||
s_mov_b32_e32 {dst} {x}
|
s_mov_b32(dst, x),
|
||||||
s_abs_i32 {dst} {dst}
|
s_abs_i32(dst, dst),
|
||||||
v_mov_b32_e32 %2 {reg}
|
v_mov_b32_e32(v[1], reg)
|
||||||
""")[0], val)
|
])[0], val)
|
||||||
s_abs_i32(0x00000001, 0x00000001, scc=1)
|
|
||||||
s_abs_i32(0x7fffffff, 0x7fffffff, scc=1)
|
|
||||||
s_abs_i32(0x80000000, 0x80000000, scc=1)
|
|
||||||
s_abs_i32(0x80000001, 0x7fffffff, scc=1)
|
|
||||||
s_abs_i32(0x80000002, 0x7ffffffe, scc=1)
|
|
||||||
s_abs_i32(0xffffffff, 0x00000001, scc=1)
|
|
||||||
s_abs_i32(0, 0, scc=0)
|
|
||||||
|
|
||||||
|
check(0x00000001, 0x00000001, scc=1)
|
||||||
|
check(0x7fffffff, 0x7fffffff, scc=1)
|
||||||
|
check(0x80000000, 0x80000000, scc=1)
|
||||||
|
check(0x80000001, 0x7fffffff, scc=1)
|
||||||
|
check(0x80000002, 0x7ffffffe, scc=1)
|
||||||
|
check(0xffffffff, 0x00000001, scc=1)
|
||||||
|
check(0, 0, scc=0)
|
||||||
|
|
||||||
|
# how do I negate a VGPR operand?
|
||||||
|
@unittest.expectedFailure
|
||||||
def test_v_rcp_f32_neg_vop3(self):
|
def test_v_rcp_f32_neg_vop3(self):
|
||||||
def v_neg_rcp_f32(x:float, y:float):
|
def v_neg_rcp_f32(x:float, y:float):
|
||||||
out = get_output(f"""
|
out = get_output([
|
||||||
v_mov_b32_e32 %2 {f32_to_bits(x)}
|
v_mov_b32_e32(v[2], f32_to_bits(x)),
|
||||||
v_rcp_f32_e64 %2, -%2
|
v_rcp_f32_e64(v[2], -v[2]),
|
||||||
""")[0]
|
], vdst=v[2])[0]
|
||||||
assert out == f32_to_bits(y), f"{f32_from_bits(out)} != {y} / {out} != {f32_to_bits(y)}"
|
assert out == f32_to_bits(y), f"{f32_from_bits(out)} != {y} / {out} != {f32_to_bits(y)}"
|
||||||
|
|
||||||
v_neg_rcp_f32(math.inf, -0.0)
|
v_neg_rcp_f32(math.inf, -0.0)
|
||||||
v_neg_rcp_f32(-math.inf, 0.0)
|
v_neg_rcp_f32(-math.inf, 0.0)
|
||||||
v_neg_rcp_f32(0.0, -math.inf)
|
v_neg_rcp_f32(0.0, -math.inf)
|
||||||
@@ -121,26 +138,31 @@ class TestHW(unittest.TestCase):
|
|||||||
v_neg_rcp_f32(-2.0, 0.5)
|
v_neg_rcp_f32(-2.0, 0.5)
|
||||||
v_neg_rcp_f32(2.0, -0.5)
|
v_neg_rcp_f32(2.0, -0.5)
|
||||||
|
|
||||||
|
# how do I negate a VGPR operand?
|
||||||
|
@unittest.expectedFailure
|
||||||
def test_v_cndmask_b32_neg(self):
|
def test_v_cndmask_b32_neg(self):
|
||||||
def v_neg(x:int|float, y:float):
|
def v_neg(x:float, y:float):
|
||||||
# always pick -v1
|
out = get_output([
|
||||||
out = get_output(f"""
|
v_mov_b32_e32(v[1], f32_to_bits(x)),
|
||||||
v_mov_b32_e32 %2 {f32_to_bits(x)}
|
s_mov_b32(s[10], 1),
|
||||||
s_mov_b32_e32 s10 1
|
v_cndmask_b32_e32(v[1], v[1], -v[1], s[10]),
|
||||||
v_cndmask_b32 %2, %2, -%2 s10
|
])[0]
|
||||||
""")[0]
|
|
||||||
assert out == f32_to_bits(y), f"{f32_from_bits(out)} != {y} / {out} != {f32_to_bits(y)}"
|
assert out == f32_to_bits(y), f"{f32_from_bits(out)} != {y} / {out} != {f32_to_bits(y)}"
|
||||||
|
|
||||||
v_neg(-0.0, 0.0)
|
v_neg(-0.0, 0.0)
|
||||||
v_neg(0.0, -0.0)
|
v_neg(0.0, -0.0)
|
||||||
v_neg(2.0, -2.0)
|
v_neg(2.0, -2.0)
|
||||||
v_neg(math.inf, -math.inf)
|
v_neg(math.inf, -math.inf)
|
||||||
v_neg(-math.inf, math.inf)
|
v_neg(-math.inf, math.inf)
|
||||||
|
|
||||||
|
@unittest.skip("how does VOPD work in the dsl")
|
||||||
def test_v_subrev_wrap(self):
|
def test_v_subrev_wrap(self):
|
||||||
out = get_output("""
|
out = get_output([
|
||||||
v_dual_mov_b32 %1, 0xffffffff :: v_dual_mov_b32 %2, 0x0
|
#v_dual_mov_b32(v[1], 0xffffffff, v[2], 0x0),
|
||||||
v_subrev_co_u32 %2, vcc_lo, %2, %1
|
#v_dual_mov_b32(vdstx=v[1], srcx=0xffffffff, vdsty=v[2], srcy=0x0),
|
||||||
""")[0]
|
#VOPD(opx=VOPDOp.V_DUAL_MOV_B32, opy=VOPDOp.V_DUAL_MOV_B32, vdstx=v[1], srcx=0xffffffff, vdsty=v[2], srcy=0x0),
|
||||||
|
v_subrev_co_u32(v[2], VCC_LO, v[2], v[1]),
|
||||||
|
], vdst=v[2])[0]
|
||||||
self.assertEqual(out, 0xffff_ffff)
|
self.assertEqual(out, 0xffff_ffff)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -534,8 +534,8 @@ def parse_sqtt_print_packets(data: bytes, filter=DEFAULT_FILTER, verbose=True) -
|
|||||||
|
|
||||||
def parse(fn:str):
|
def parse(fn:str):
|
||||||
with Timing(f"unpickle {fn}: "): dat = pickle.load(open(fn, "rb"))
|
with Timing(f"unpickle {fn}: "): dat = pickle.load(open(fn, "rb"))
|
||||||
if getenv("ROCM", 0):
|
#if getenv("ROCM", 0):
|
||||||
with Timing(f"decode {fn}: "): ctx = decode(dat)
|
# with Timing(f"decode {fn}: "): ctx = decode(dat)
|
||||||
dat_sqtt = [x for x in dat if isinstance(x, ProfileSQTTEvent)]
|
dat_sqtt = [x for x in dat if isinstance(x, ProfileSQTTEvent)]
|
||||||
print(f"got {len(dat_sqtt)} SQTT events in {fn}")
|
print(f"got {len(dat_sqtt)} SQTT events in {fn}")
|
||||||
return dat_sqtt
|
return dat_sqtt
|
||||||
|
|||||||
Reference in New Issue
Block a user