Files
tinygrad/test/amd/test_emu2_pcode.py
George Hotz e8bd432bf6 move amd emulator out of tree (#14740)
* move amd emulator out of tree

* move the readme too
2026-02-14 10:32:00 +08:00

341 lines
14 KiB
Python

"""Tests for the pcode parser."""
import unittest
from collections import defaultdict
from tinygrad.helpers import DEBUG
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops
from test.mockgpu.amd.emu import parse_pcode
from test.mockgpu.amd.pcode import parse_expr
from tinygrad.runtime.autogen.amd.rdna3.str_pcode import PCODE
from tinygrad.runtime.autogen.amd.rdna3.enum import VOP1Op, VOP2Op, SOP2Op, DSOp
def _srcs():
  """Return a fresh dict of zero-valued source operands for parse_pcode.

  S0, S1, S2, SCC and laneId are uint32 zeros; VCC is a uint64 zero.
  """
  def zero32(): return UOp.const(dtypes.uint32, 0)
  operands = {name: zero32() for name in ('S0', 'S1', 'S2', 'SCC')}
  operands['VCC'] = UOp.const(dtypes.uint64, 0)
  operands['laneId'] = zero32()
  return operands
class TestBasicParsing(unittest.TestCase):
  """Parse the pcode of a few common RDNA3 instructions and sanity-check the resulting assignment list."""
  def _assigns_for(self, op):
    # parse_pcode hands back (vars, assigns); these tests only inspect the assignment list
    _, assigns = parse_pcode(PCODE[op], _srcs())
    return assigns
  def test_v_add_f32(self):
    """V_ADD_F32 parses to exactly one assignment whose destination starts with D0."""
    assigns = self._assigns_for(VOP2Op.V_ADD_F32_E32)
    self.assertEqual(len(assigns), 1)
    dest, _ = assigns[0]
    self.assertTrue(dest.startswith('D0'))
  def test_v_lshlrev_b32(self):
    """V_LSHLREV_B32 parses to exactly one assignment."""
    self.assertEqual(len(self._assigns_for(VOP2Op.V_LSHLREV_B32_E32)), 1)
  def test_s_cselect_b32(self):
    """S_CSELECT_B32 (a ternary select) parses to exactly one assignment."""
    self.assertEqual(len(self._assigns_for(SOP2Op.S_CSELECT_B32)), 1)
  def test_v_add_co_ci_u32(self):
    """V_ADD_CO_CI_U32 (add with carry) parses to at least one assignment."""
    self.assertGreaterEqual(len(self._assigns_for(VOP2Op.V_ADD_CO_CI_U32_E32)), 1)
class TestWithSources(unittest.TestCase):
  """Test pcode parsing with concrete (non-zero) source operand values."""
  def test_v_add_f32_with_sources(self):
    """V_ADD_F32 with S0=1.0f / S1=2.0f bit patterns parses to a D0 assignment whose value is an ADD."""
    s0 = UOp.const(dtypes.uint32, 0x3f800000)  # 1.0f
    s1 = UOp.const(dtypes.uint32, 0x40000000)  # 2.0f
    _, assigns = parse_pcode(PCODE[VOP2Op.V_ADD_F32_E32], {'S0': s0, 'S1': s1})
    self.assertEqual(len(assigns), 1)
    dest, val = assigns[0]
    self.assertTrue(dest.startswith('D0'))
    # Result should be an ADD operation
    self.assertEqual(val.op, Ops.ADD)
  def test_v_mul_f32_with_sources(self):
    """V_MUL_F32 with S0=2.0f / S1=3.0f bit patterns parses to a D0 assignment whose value is a MUL."""
    s0 = UOp.const(dtypes.uint32, 0x40000000)  # 2.0f
    s1 = UOp.const(dtypes.uint32, 0x40400000)  # 3.0f
    _, assigns = parse_pcode(PCODE[VOP2Op.V_MUL_F32_E32], {'S0': s0, 'S1': s1})
    self.assertEqual(len(assigns), 1)
    dest, val = assigns[0]
    # destination is checked the same way as in the V_ADD_F32 test above
    self.assertTrue(dest.startswith('D0'))
    self.assertEqual(val.op, Ops.MUL)
class TestParseExpr(unittest.TestCase):
  """Exercise parse_expr directly on literals, variable lookups, binary operators and ternaries."""
  def test_integer_literals(self):
    """Decimal literals, with or without a U suffix, parse to constants carrying the integer value."""
    for text, expected in (('0', 0), ('42', 42), ('42U', 42)):
      self.assertEqual(parse_expr(text, {}).arg, expected)
  def test_negative_integers(self):
    """A negative literal parses to a dtypes.int constant holding the negative value."""
    neg = parse_expr('-1', {})
    self.assertEqual(neg.arg, -1)
    self.assertEqual(neg.dtype, dtypes.int)
  def test_float_literals(self):
    """An F-suffixed literal parses to a float32 constant."""
    flt = parse_expr('1.0F', {})
    self.assertEqual(flt.arg, 1.0)
    self.assertEqual(flt.dtype, dtypes.float32)
  def test_hex_literals(self):
    """0x-prefixed literals are read as base 16."""
    self.assertEqual(parse_expr('0xFF', {}).arg, 255)
  def test_variable_lookup(self):
    """A bare identifier evaluates to the value bound to it in the variables dict."""
    bound = UOp.const(dtypes.uint32, 42)
    self.assertEqual(parse_expr('x', {'x': bound}).arg, 42)
  def test_binary_ops(self):
    """'+' on variables yields an ADD node, while '-' on two literals is folded to a CONST."""
    env = {'a': UOp.const(dtypes.uint32, 10), 'b': UOp.const(dtypes.uint32, 5)}
    self.assertEqual(parse_expr('a + b', env).op, Ops.ADD)
    diff = parse_expr('10 - 5', {})
    self.assertEqual(diff.op, Ops.CONST)
    self.assertEqual(diff.arg, 5)
  def test_ternary(self):
    """'cond ? a : b' parses to a WHERE node."""
    env = {'cond': UOp.const(dtypes.bool, True), 'a': UOp.const(dtypes.uint32, 1), 'b': UOp.const(dtypes.uint32, 0)}
    self.assertEqual(parse_expr('cond ? a : b', env).op, Ops.WHERE)
class TestForLoopParsing(unittest.TestCase):
  """Test for-loop pcode parsing using the CLZ/CTZ instructions."""
  def test_clz_pcode_exists(self):
    """V_CLZ_I32_U32 pcode is available and contains a for loop."""
    pcode = PCODE.get(VOP1Op.V_CLZ_I32_U32_E32)
    self.assertIsNotNone(pcode)
    assert pcode is not None  # narrow Optional[str] for type checkers
    self.assertIn('for', pcode.lower())
  def test_clz_parsing(self):
    """CLZ pcode parses to a single D0 assignment built from nested WHEREs."""
    pcode = PCODE[VOP1Op.V_CLZ_I32_U32_E32]
    S0 = UOp.const(dtypes.uint32, 0xFFFFFFFF)  # All ones - CLZ should be 0
    _vrs, assigns = parse_pcode(pcode, {'S0': S0})
    self.assertEqual(len(assigns), 1)
    dest, val = assigns[0]
    self.assertTrue(dest.startswith('D0'))
    # Result should be a nested WHERE structure
    self.assertEqual(val.op, Ops.WHERE)
  def test_clz_with_zero(self):
    """CLZ with input 0: the innermost (default) value of the WHERE chain is -1."""
    pcode = PCODE[VOP1Op.V_CLZ_I32_U32_E32]
    S0 = UOp.const(dtypes.uint32, 0)
    _vrs, assigns = parse_pcode(pcode, {'S0': S0})
    # check the count before indexing so a wrong result fails cleanly instead of raising IndexError
    self.assertEqual(len(assigns), 1)
    val = assigns[0][1]
    # Follow WHERE false branches down to the default value. CASTs may wrap nodes anywhere along that
    # chain, so unwrap both kinds in one walk instead of assuming every WHERE precedes every CAST.
    while val.op in (Ops.WHERE, Ops.CAST):
      val = val.src[2] if val.op == Ops.WHERE else val.src[0]
    self.assertEqual(val.arg, -1)
  def test_ctz_parsing(self):
    """CTZ pcode parses to a single assignment (skipped if the pcode is unavailable)."""
    pcode = PCODE.get(VOP1Op.V_CTZ_I32_B32_E32)
    if pcode is None:
      self.skipTest("V_CTZ_I32_B32_E32 pcode not available")
    S0 = UOp.const(dtypes.uint32, 1)  # LSB set - CTZ should be 0
    _vrs, assigns = parse_pcode(pcode, {'S0': S0})
    self.assertEqual(len(assigns), 1)
class TestDSPcodePatterns(unittest.TestCase):
  """Test DS instruction pcode patterns."""
  @staticmethod
  def _lds():
    # mock LDS buffer: a PARAM of dtype dtypes.uint32.ptr(16384) with arg=3
    return UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
  @staticmethod
  def _two_addr_srcs(addr, offset0, offset1, **extra):
    """Source operands for the DS *_2ADDR_B32 ops.

    Builds uint32 constants ADDR/OFFSET0/OFFSET1, then any `extra` entries as given, then laneId=0.
    """
    u32 = lambda v: UOp.const(dtypes.uint32, v)
    srcs = {'ADDR': u32(addr), 'OFFSET0': u32(offset0), 'OFFSET1': u32(offset1), **extra}
    srcs['laneId'] = u32(0)
    return srcs
  @classmethod
  def _store_2addr_srcs(cls, addr, offset0, offset1):
    """Sources for DS_STORE_2ADDR_B32 with fixed payloads DATA=0xAAAAAAAA and DATA2=0xBBBBBBBB."""
    return cls._two_addr_srcs(addr, offset0, offset1,
                              DATA=UOp.const(dtypes.uint32, 0xAAAAAAAA), DATA2=UOp.const(dtypes.uint32, 0xBBBBBBBB))
  def _pcode(self, op) -> str:
    """Return the pcode string for `op`, failing the test if it is missing."""
    pcode = PCODE.get(op)
    self.assertIsNotNone(pcode)
    assert pcode is not None  # narrow Optional[str] for type checkers
    return pcode
  def test_ds_load_b32_pcode(self):
    """DS_LOAD_B32 pcode is available and mentions RETURN_DATA and MEM[."""
    pcode = self._pcode(DSOp.DS_LOAD_B32)
    self.assertIn('RETURN_DATA', pcode)
    self.assertIn('MEM[', pcode)
  def test_ds_store_b32_pcode(self):
    """DS_STORE_B32 pcode is available and mentions MEM[ and DATA."""
    pcode = self._pcode(DSOp.DS_STORE_B32)
    self.assertIn('MEM[', pcode)
    self.assertIn('DATA', pcode)
  def test_mem_read_parsing(self):
    """A MEM[addr].type read expression parses when an _lds buffer is provided."""
    vrs = {'_lds': self._lds(), 'ADDR': UOp.const(dtypes.uint32, 0), 'OFFSET': UOp.const(dtypes.uint32, 0)}
    result = parse_expr('MEM[ADDR + OFFSET].b32', vrs)
    # NOTE(review): presumably an INDEX/load into _lds; only a non-None result is asserted here
    self.assertIsNotNone(result)
  def test_ds_store_2addr_b32_parsing(self):
    """DS_STORE_2ADDR_B32 parses to two MEM writes."""
    pcode = self._pcode(DSOp.DS_STORE_2ADDR_B32)
    _, assigns = parse_pcode(pcode, self._store_2addr_srcs(0, 0, 1))
    # Should have 2 MEM write assignments
    self.assertEqual(len(assigns), 2)
    for dest, val in assigns:
      self.assertTrue(dest.startswith('MEM['))
      # val should be (addr, write_val) tuple
      self.assertIsInstance(val, tuple)
      self.assertEqual(len(val), 2)  # type: ignore[arg-type]
  def test_ds_load_2addr_b32_parsing(self):
    """DS_LOAD_2ADDR_B32 parses to the low and high halves of RETURN_DATA."""
    pcode = self._pcode(DSOp.DS_LOAD_2ADDR_B32)
    _, assigns = parse_pcode(pcode, self._two_addr_srcs(0, 0, 1, _lds=self._lds()))
    # Should have 2 RETURN_DATA assignments
    self.assertEqual(len(assigns), 2)
    self.assertEqual(assigns[0][0], 'RETURN_DATA[31:0]')
    self.assertEqual(assigns[1][0], 'RETURN_DATA[63:32]')
  def test_ds_store_address_calculation(self):
    """DS_STORE_2ADDR_B32 calculates correct addresses (offset * 4)."""
    pcode = self._pcode(DSOp.DS_STORE_2ADDR_B32)
    _, assigns = parse_pcode(pcode, self._store_2addr_srcs(100, 2, 5))
    # check the count before indexing so a wrong result fails cleanly instead of raising IndexError
    self.assertEqual(len(assigns), 2)
    # Check addresses: 100 + 2*4 = 108, 100 + 5*4 = 120
    # assigns[i][1] is (addr, val) tuple for MEM writes; mypy sees UOp
    self.assertEqual(assigns[0][1][0].simplify().arg, 108)  # type: ignore[index]
    self.assertEqual(assigns[1][1][0].simplify().arg, 120)  # type: ignore[index]
  def test_ds_store_data_values(self):
    """DS_STORE_2ADDR_B32 writes DATA and DATA2 unchanged."""
    pcode = self._pcode(DSOp.DS_STORE_2ADDR_B32)
    _, assigns = parse_pcode(pcode, self._store_2addr_srcs(0, 0, 1))
    # check the count before indexing so a wrong result fails cleanly instead of raising IndexError
    self.assertEqual(len(assigns), 2)
    # assigns[i][1] is (addr, val) tuple for MEM writes; mypy sees UOp
    # DATA[31:0] should preserve the value
    self.assertEqual(assigns[0][1][1].simplify().arg, 0xAAAAAAAA)  # type: ignore[index]
    self.assertEqual(assigns[1][1][1].simplify().arg, 0xBBBBBBBB)  # type: ignore[index]
class TestConditionalParsing(unittest.TestCase):
  """Conditional pcode parsing. Only the ternary form is exercised here; if/elsif/else blocks are not."""
  def test_ternary_in_assignment(self):
    """A ternary on the right-hand side of an assignment becomes a WHERE node."""
    # S_CSELECT_B32: D0.u32 = SCC ? S0.u32 : S1.u32
    cselect = PCODE[SOP2Op.S_CSELECT_B32]
    operands = {'S0': UOp.const(dtypes.uint32, 10), 'S1': UOp.const(dtypes.uint32, 20), 'SCC': UOp.const(dtypes.uint32, 1)}
    _vrs, assigns = parse_pcode(cselect, operands)
    self.assertEqual(len(assigns), 1)
    dest, val = assigns[0]
    self.assertTrue(dest.startswith('D0'))
    self.assertEqual(val.op, Ops.WHERE)
class TestAllPcode(unittest.TestCase):
  """Test that all pcode from all architectures can be parsed."""
  def _make_srcs(self):
    """Create dummy source variables for pcode parsing (mostly uint32/uint64 zeros, plus a mock LDS/VMEM buffer)."""
    u32, u64 = lambda v=0: UOp.const(dtypes.uint32, v), lambda v=0: UOp.const(dtypes.uint64, v)
    lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
    return {'laneId': u32(), 'laneID': u32(), 'S0': u32(), 'S1': u32(), 'S2': u32(), 'S3': u32(), 'SRC0': u32(),
            'D0': u32(), 'D1': u32(), 'DST': u32(), 'VDST': u32(), 'SDST': u32(),
            'VCC': u64(), 'VCCZ': u32(), 'EXEC': u64(), 'EXEC_LO': u32(), 'EXECZ': u32(), 'SCC': u32(),
            'SIMM16': u32(), 'SIMM32': u32(), 'OFFSET': u32(), 'OFFSET0': u32(), 'OFFSET1': u32(), 'offset1': u32(),
            'ADDR': u32(), 'ADDR_BASE': u32(), 'TADDR': u32(), 'DATA': u32(), 'DATA0': u32(), 'DATA1': u32(), 'DATA2': u32(),
            'VDATA': u32(), 'VDATA0': u32(), 'VDATA1': u32(), 'VDATA2': u32(), 'VDATA3': u32(),
            'OPSEL': u32(), 'OPSEL_HI': u32(), 'NEG': u32(), 'NEG_HI': u32(), 'CLAMP': u32(),
            'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'ROUND_TOWARD_ZERO': u32(),
            'ROUND_NEAREST_EVEN': u32(), 'WAVE_STATUS': u32(),
            'MAX_FLOAT_F32': u32(0x7f7fffff), 'Unsigned': u32(1), 'clampedLOD': u32(),
            '_lds': lds, '_vmem': lds, '_active': UOp.const(dtypes.bool, True)}
  def _parse_all_pcode(self, pcode_dict, arch: str, min_pct: float):
    """Parse every entry of `pcode_dict` and require at least `min_pct` percent to succeed.

    A RuntimeError from parse_pcode is treated as a known parser limitation: the op is counted as skipped and
    grouped by error message (the groups are printed when DEBUG >= 2). Any other exception fails the test immediately.
    """
    total = len(pcode_dict)
    # an empty table would otherwise raise ZeroDivisionError when computing the pass rate below
    self.assertGreater(total, 0, f"[{arch}] no pcode entries to parse")
    # NOTE(review): one srcs dict is shared by every parse_pcode call; if parse_pcode ever mutates it, state
    # would leak between ops — confirm against parse_pcode
    srcs = self._make_srcs()
    passed, skipped, errors = 0, 0, defaultdict(list)
    for op, pcode in pcode_dict.items():
      try:
        parse_pcode(pcode, srcs)
      except RuntimeError as e:
        skipped += 1
        errors[str(e)].append(op.name)
      except Exception as e:
        self.fail(f"[{arch}] {op.name}: {e}\nPcode: {pcode[:200]}")
      else:
        passed += 1
    pct = 100 * passed / total
    print(f"{arch}: {passed}/{total} ({pct:.1f}%) parsed, {skipped} skipped")
    if DEBUG >= 2:
      # most frequent error messages first, listing at most 5 op names for each
      for err, ops in sorted(errors.items(), key=lambda x: -len(x[1])):
        print(f" {err}: {', '.join(ops[:5])}{'...' if len(ops) > 5 else ''} ({len(ops)})")
    self.assertGreaterEqual(pct, min_pct, f"[{arch}] {pct:.1f}% < {min_pct}% threshold")
  def test_parse_all_cdna_pcode(self):
    from tinygrad.runtime.autogen.amd.cdna.str_pcode import PCODE as CDNA_PCODE
    self._parse_all_pcode(CDNA_PCODE, "CDNA", min_pct=60)
  def test_parse_all_rdna3_pcode(self):
    from tinygrad.runtime.autogen.amd.rdna3.str_pcode import PCODE as RDNA3_PCODE
    self._parse_all_pcode(RDNA3_PCODE, "RDNA3", min_pct=90)
  def test_parse_all_rdna4_pcode(self):
    from tinygrad.runtime.autogen.amd.rdna4.str_pcode import PCODE as RDNA4_PCODE
    self._parse_all_pcode(RDNA4_PCODE, "RDNA4", min_pct=65)
if __name__ == "__main__": unittest.main()  # run this module's tests when executed as a script