mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
341 lines
14 KiB
Python
341 lines
14 KiB
Python
"""Tests for the pcode parser."""
|
|
import unittest
|
|
from collections import defaultdict
|
|
from tinygrad.helpers import DEBUG
|
|
from tinygrad.dtype import dtypes
|
|
from tinygrad.uop.ops import UOp, Ops
|
|
from test.mockgpu.amd.emu import parse_pcode
|
|
from test.mockgpu.amd.pcode import parse_expr
|
|
from tinygrad.runtime.autogen.amd.rdna3.str_pcode import PCODE
|
|
from tinygrad.runtime.autogen.amd.rdna3.enum import VOP1Op, VOP2Op, SOP2Op, DSOp
|
|
|
|
def _srcs():
|
|
"""Create minimal source variables for pcode parsing."""
|
|
def u32(v=0): return UOp.const(dtypes.uint32, v)
|
|
return {'S0': u32(), 'S1': u32(), 'S2': u32(), 'SCC': u32(), 'VCC': UOp.const(dtypes.uint64, 0), 'laneId': u32()}
|
|
|
|
class TestBasicParsing(unittest.TestCase):
|
|
"""Test basic pcode parsing for common instruction patterns."""
|
|
|
|
def test_v_add_f32(self):
|
|
"""Test parsing V_ADD_F32 pcode."""
|
|
_, assigns = parse_pcode(PCODE[VOP2Op.V_ADD_F32_E32], _srcs())
|
|
self.assertEqual(len(assigns), 1)
|
|
dest, _ = assigns[0]
|
|
self.assertTrue(dest.startswith('D0'))
|
|
|
|
def test_v_lshlrev_b32(self):
|
|
"""Test parsing V_LSHLREV_B32 pcode."""
|
|
_, assigns = parse_pcode(PCODE[VOP2Op.V_LSHLREV_B32_E32], _srcs())
|
|
self.assertEqual(len(assigns), 1)
|
|
|
|
def test_s_cselect_b32(self):
|
|
"""Test parsing S_CSELECT_B32 pcode with ternary."""
|
|
_, assigns = parse_pcode(PCODE[SOP2Op.S_CSELECT_B32], _srcs())
|
|
self.assertEqual(len(assigns), 1)
|
|
|
|
def test_v_add_co_ci_u32(self):
|
|
"""Test parsing V_ADD_CO_CI_U32 with carry."""
|
|
_, assigns = parse_pcode(PCODE[VOP2Op.V_ADD_CO_CI_U32_E32], _srcs())
|
|
self.assertGreaterEqual(len(assigns), 1)
|
|
|
|
class TestWithSources(unittest.TestCase):
|
|
"""Test pcode parsing with actual source operand values."""
|
|
|
|
def test_v_add_f32_with_sources(self):
|
|
"""Test V_ADD_F32 with actual float constants."""
|
|
s0 = UOp.const(dtypes.uint32, 0x3f800000) # 1.0f
|
|
s1 = UOp.const(dtypes.uint32, 0x40000000) # 2.0f
|
|
_, assigns = parse_pcode(PCODE[VOP2Op.V_ADD_F32_E32], {'S0': s0, 'S1': s1})
|
|
self.assertEqual(len(assigns), 1)
|
|
dest, val = assigns[0]
|
|
self.assertTrue(dest.startswith('D0'))
|
|
# Result should be an ADD operation
|
|
self.assertEqual(val.op, Ops.ADD)
|
|
|
|
def test_v_mul_f32_with_sources(self):
|
|
"""Test V_MUL_F32 with actual float constants."""
|
|
s0 = UOp.const(dtypes.uint32, 0x40000000) # 2.0f
|
|
s1 = UOp.const(dtypes.uint32, 0x40400000) # 3.0f
|
|
_, assigns = parse_pcode(PCODE[VOP2Op.V_MUL_F32_E32], {'S0': s0, 'S1': s1})
|
|
self.assertEqual(len(assigns), 1)
|
|
dest, val = assigns[0]
|
|
self.assertEqual(val.op, Ops.MUL)
|
|
|
|
class TestParseExpr(unittest.TestCase):
|
|
"""Test the parse_expr function directly."""
|
|
|
|
def test_integer_literals(self):
|
|
"""Test parsing integer literals."""
|
|
self.assertEqual(parse_expr('0', {}).arg, 0)
|
|
self.assertEqual(parse_expr('42', {}).arg, 42)
|
|
self.assertEqual(parse_expr('42U', {}).arg, 42)
|
|
|
|
def test_negative_integers(self):
|
|
"""Test parsing negative integer literals."""
|
|
result = parse_expr('-1', {})
|
|
self.assertEqual(result.arg, -1)
|
|
self.assertEqual(result.dtype, dtypes.int)
|
|
|
|
def test_float_literals(self):
|
|
"""Test parsing float literals."""
|
|
result = parse_expr('1.0F', {})
|
|
self.assertEqual(result.arg, 1.0)
|
|
self.assertEqual(result.dtype, dtypes.float32)
|
|
|
|
def test_hex_literals(self):
|
|
"""Test parsing hex literals."""
|
|
result = parse_expr('0xFF', {})
|
|
self.assertEqual(result.arg, 255)
|
|
|
|
def test_variable_lookup(self):
|
|
"""Test variable lookup in parse_expr."""
|
|
vrs = {'x': UOp.const(dtypes.uint32, 42)}
|
|
result = parse_expr('x', vrs)
|
|
self.assertEqual(result.arg, 42)
|
|
|
|
def test_binary_ops(self):
|
|
"""Test parsing binary operations."""
|
|
vrs = {'a': UOp.const(dtypes.uint32, 10), 'b': UOp.const(dtypes.uint32, 5)}
|
|
|
|
# Addition
|
|
result = parse_expr('a + b', vrs)
|
|
self.assertEqual(result.op, Ops.ADD)
|
|
|
|
# Subtraction with constant folding
|
|
result = parse_expr('10 - 5', {})
|
|
self.assertEqual(result.op, Ops.CONST)
|
|
self.assertEqual(result.arg, 5)
|
|
|
|
def test_ternary(self):
|
|
"""Test parsing ternary expressions."""
|
|
vrs = {'cond': UOp.const(dtypes.bool, True), 'a': UOp.const(dtypes.uint32, 1), 'b': UOp.const(dtypes.uint32, 0)}
|
|
result = parse_expr('cond ? a : b', vrs)
|
|
self.assertEqual(result.op, Ops.WHERE)
|
|
|
|
class TestForLoopParsing(unittest.TestCase):
|
|
"""Test for loop parsing (CLZ/CTZ patterns)."""
|
|
|
|
def test_clz_pcode_exists(self):
|
|
"""Verify CLZ pcode is available."""
|
|
pcode = PCODE.get(VOP1Op.V_CLZ_I32_U32_E32)
|
|
self.assertIsNotNone(pcode)
|
|
assert pcode is not None
|
|
self.assertIn('for', pcode.lower())
|
|
|
|
def test_clz_parsing(self):
|
|
"""Test CLZ pcode parsing produces correct structure."""
|
|
pcode = PCODE[VOP1Op.V_CLZ_I32_U32_E32]
|
|
S0 = UOp.const(dtypes.uint32, 0xFFFFFFFF) # All ones - CLZ should be 0
|
|
_vrs, assigns = parse_pcode(pcode, {'S0': S0})
|
|
|
|
self.assertEqual(len(assigns), 1)
|
|
dest, val = assigns[0]
|
|
self.assertTrue(dest.startswith('D0'))
|
|
# Result should be a nested WHERE structure
|
|
self.assertEqual(val.op, Ops.WHERE)
|
|
|
|
def test_clz_with_zero(self):
|
|
"""Test CLZ with input 0 - should return -1."""
|
|
pcode = PCODE[VOP1Op.V_CLZ_I32_U32_E32]
|
|
S0 = UOp.const(dtypes.uint32, 0)
|
|
_vrs, assigns = parse_pcode(pcode, {'S0': S0})
|
|
|
|
# Check that the innermost value (default) is -1 (may be wrapped in CAST)
|
|
val = assigns[0][1]
|
|
# Traverse to innermost WHERE
|
|
while val.op == Ops.WHERE:
|
|
val = val.src[2] # false branch
|
|
# Unwrap CAST if present
|
|
while val.op == Ops.CAST:
|
|
val = val.src[0]
|
|
self.assertEqual(val.arg, -1)
|
|
|
|
def test_ctz_parsing(self):
|
|
"""Test CTZ pcode parsing."""
|
|
pcode = PCODE.get(VOP1Op.V_CTZ_I32_B32_E32)
|
|
if pcode is None:
|
|
self.skipTest("V_CTZ_I32_B32_E32 pcode not available")
|
|
|
|
S0 = UOp.const(dtypes.uint32, 1) # LSB set - CTZ should be 0
|
|
_vrs, assigns = parse_pcode(pcode, {'S0': S0})
|
|
self.assertEqual(len(assigns), 1)
|
|
|
|
class TestDSPcodePatterns(unittest.TestCase):
|
|
"""Test DS instruction pcode patterns."""
|
|
|
|
def test_ds_load_b32_pcode(self):
|
|
"""Test DS_LOAD_B32 pcode is parseable."""
|
|
pcode = PCODE.get(DSOp.DS_LOAD_B32)
|
|
self.assertIsNotNone(pcode)
|
|
assert pcode is not None
|
|
self.assertIn('RETURN_DATA', pcode)
|
|
self.assertIn('MEM[', pcode)
|
|
|
|
def test_ds_store_b32_pcode(self):
|
|
"""Test DS_STORE_B32 pcode is parseable."""
|
|
pcode = PCODE.get(DSOp.DS_STORE_B32)
|
|
self.assertIsNotNone(pcode)
|
|
assert pcode is not None
|
|
self.assertIn('MEM[', pcode)
|
|
self.assertIn('DATA', pcode)
|
|
|
|
def test_mem_read_parsing(self):
|
|
"""Test MEM[addr].type read expression parsing."""
|
|
# Create a mock LDS buffer
|
|
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
|
|
addr = UOp.const(dtypes.uint32, 0)
|
|
vrs = {'_lds': lds, 'ADDR': addr, 'OFFSET': UOp.const(dtypes.uint32, 0)}
|
|
|
|
result = parse_expr('MEM[ADDR + OFFSET].b32', vrs)
|
|
# Should be an INDEX operation into LDS
|
|
self.assertIsNotNone(result)
|
|
|
|
def test_ds_store_2addr_b32_parsing(self):
|
|
"""Test DS_STORE_2ADDR_B32 pcode parsing produces MEM writes."""
|
|
pcode = PCODE.get(DSOp.DS_STORE_2ADDR_B32)
|
|
self.assertIsNotNone(pcode)
|
|
assert pcode is not None
|
|
srcs = {
|
|
'ADDR': UOp.const(dtypes.uint32, 0),
|
|
'OFFSET0': UOp.const(dtypes.uint32, 0),
|
|
'OFFSET1': UOp.const(dtypes.uint32, 1),
|
|
'DATA': UOp.const(dtypes.uint32, 0xAAAAAAAA),
|
|
'DATA2': UOp.const(dtypes.uint32, 0xBBBBBBBB),
|
|
}
|
|
srcs['laneId'] = UOp.const(dtypes.uint32, 0)
|
|
_, assigns = parse_pcode(pcode, srcs)
|
|
# Should have 2 MEM write assignments
|
|
self.assertEqual(len(assigns), 2)
|
|
for dest, val in assigns:
|
|
self.assertTrue(dest.startswith('MEM['))
|
|
# val should be (addr, write_val) tuple
|
|
self.assertIsInstance(val, tuple)
|
|
self.assertEqual(len(val), 2) # type: ignore[arg-type]
|
|
|
|
def test_ds_load_2addr_b32_parsing(self):
|
|
"""Test DS_LOAD_2ADDR_B32 pcode parsing produces RETURN_DATA assignments."""
|
|
pcode = PCODE.get(DSOp.DS_LOAD_2ADDR_B32)
|
|
self.assertIsNotNone(pcode)
|
|
assert pcode is not None
|
|
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
|
|
srcs = {
|
|
'ADDR': UOp.const(dtypes.uint32, 0),
|
|
'OFFSET0': UOp.const(dtypes.uint32, 0),
|
|
'OFFSET1': UOp.const(dtypes.uint32, 1),
|
|
'_lds': lds,
|
|
}
|
|
srcs['laneId'] = UOp.const(dtypes.uint32, 0)
|
|
_, assigns = parse_pcode(pcode, srcs)
|
|
# Should have 2 RETURN_DATA assignments
|
|
self.assertEqual(len(assigns), 2)
|
|
self.assertEqual(assigns[0][0], 'RETURN_DATA[31:0]')
|
|
self.assertEqual(assigns[1][0], 'RETURN_DATA[63:32]')
|
|
|
|
def test_ds_store_address_calculation(self):
|
|
"""Test DS_STORE_2ADDR_B32 calculates correct addresses (offset * 4)."""
|
|
pcode = PCODE.get(DSOp.DS_STORE_2ADDR_B32)
|
|
assert pcode is not None
|
|
srcs = {
|
|
'ADDR': UOp.const(dtypes.uint32, 100),
|
|
'OFFSET0': UOp.const(dtypes.uint32, 2),
|
|
'OFFSET1': UOp.const(dtypes.uint32, 5),
|
|
'DATA': UOp.const(dtypes.uint32, 0xAAAAAAAA),
|
|
'DATA2': UOp.const(dtypes.uint32, 0xBBBBBBBB),
|
|
}
|
|
srcs['laneId'] = UOp.const(dtypes.uint32, 0)
|
|
_, assigns = parse_pcode(pcode, srcs)
|
|
# Check addresses: 100 + 2*4 = 108, 100 + 5*4 = 120
|
|
# assigns[i][1] is (addr, val) tuple for MEM writes; mypy sees UOp
|
|
self.assertEqual(assigns[0][1][0].simplify().arg, 108) # type: ignore[index]
|
|
self.assertEqual(assigns[1][1][0].simplify().arg, 120) # type: ignore[index]
|
|
|
|
def test_ds_store_data_values(self):
|
|
"""Test DS_STORE_2ADDR_B32 uses correct data values."""
|
|
pcode = PCODE.get(DSOp.DS_STORE_2ADDR_B32)
|
|
assert pcode is not None
|
|
srcs = {
|
|
'ADDR': UOp.const(dtypes.uint32, 0),
|
|
'OFFSET0': UOp.const(dtypes.uint32, 0),
|
|
'OFFSET1': UOp.const(dtypes.uint32, 1),
|
|
'DATA': UOp.const(dtypes.uint32, 0xAAAAAAAA),
|
|
'DATA2': UOp.const(dtypes.uint32, 0xBBBBBBBB),
|
|
}
|
|
srcs['laneId'] = UOp.const(dtypes.uint32, 0)
|
|
_, assigns = parse_pcode(pcode, srcs)
|
|
# assigns[i][1] is (addr, val) tuple for MEM writes; mypy sees UOp
|
|
# DATA[31:0] should preserve the value
|
|
self.assertEqual(assigns[0][1][1].simplify().arg, 0xAAAAAAAA) # type: ignore[index]
|
|
self.assertEqual(assigns[1][1][1].simplify().arg, 0xBBBBBBBB) # type: ignore[index]
|
|
|
|
class TestConditionalParsing(unittest.TestCase):
|
|
"""Test conditional (if/elsif/else) pcode parsing."""
|
|
|
|
def test_ternary_in_assignment(self):
|
|
"""Test parsing ternary expression (which becomes WHERE)."""
|
|
# S_CSELECT_B32: D0.u32 = SCC ? S0.u32 : S1.u32
|
|
pcode = PCODE[SOP2Op.S_CSELECT_B32]
|
|
s0 = UOp.const(dtypes.uint32, 10)
|
|
s1 = UOp.const(dtypes.uint32, 20)
|
|
scc = UOp.const(dtypes.uint32, 1)
|
|
_vrs, assigns = parse_pcode(pcode, {'S0': s0, 'S1': s1, 'SCC': scc})
|
|
self.assertEqual(len(assigns), 1)
|
|
dest, val = assigns[0]
|
|
self.assertTrue(dest.startswith('D0'))
|
|
# Result should be a WHERE (ternary becomes WHERE)
|
|
self.assertEqual(val.op, Ops.WHERE)
|
|
|
|
class TestAllPcode(unittest.TestCase):
|
|
"""Test that all pcode from all architectures can be parsed."""
|
|
|
|
def _make_srcs(self):
|
|
"""Create dummy source variables for pcode parsing."""
|
|
u32, u64 = lambda v=0: UOp.const(dtypes.uint32, v), lambda v=0: UOp.const(dtypes.uint64, v)
|
|
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
|
|
return {'laneId': u32(), 'laneID': u32(), 'S0': u32(), 'S1': u32(), 'S2': u32(), 'S3': u32(), 'SRC0': u32(),
|
|
'D0': u32(), 'D1': u32(), 'DST': u32(), 'VDST': u32(), 'SDST': u32(),
|
|
'VCC': u64(), 'VCCZ': u32(), 'EXEC': u64(), 'EXEC_LO': u32(), 'EXECZ': u32(), 'SCC': u32(),
|
|
'SIMM16': u32(), 'SIMM32': u32(), 'OFFSET': u32(), 'OFFSET0': u32(), 'OFFSET1': u32(), 'offset1': u32(),
|
|
'ADDR': u32(), 'ADDR_BASE': u32(), 'TADDR': u32(), 'DATA': u32(), 'DATA0': u32(), 'DATA1': u32(), 'DATA2': u32(),
|
|
'VDATA': u32(), 'VDATA0': u32(), 'VDATA1': u32(), 'VDATA2': u32(), 'VDATA3': u32(),
|
|
'OPSEL': u32(), 'OPSEL_HI': u32(), 'NEG': u32(), 'NEG_HI': u32(), 'CLAMP': u32(),
|
|
'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'ROUND_TOWARD_ZERO': u32(),
|
|
'ROUND_NEAREST_EVEN': u32(), 'WAVE_STATUS': u32(),
|
|
'MAX_FLOAT_F32': u32(0x7f7fffff), 'Unsigned': u32(1), 'clampedLOD': u32(),
|
|
'_lds': lds, '_vmem': lds, '_active': UOp.const(dtypes.bool, True)}
|
|
|
|
def _parse_all_pcode(self, pcode_dict, arch: str, min_pct: float):
|
|
"""Parse all pcode. RuntimeError = parser limitation (ok), other exceptions = real bugs."""
|
|
srcs = self._make_srcs()
|
|
passed, skipped, errors = 0, 0, defaultdict(list)
|
|
for op, pcode in pcode_dict.items():
|
|
try:
|
|
parse_pcode(pcode, srcs)
|
|
passed += 1
|
|
except RuntimeError as e:
|
|
skipped += 1
|
|
errors[str(e)].append(op.name)
|
|
except Exception as e: self.fail(f"[{arch}] {op.name}: {e}\nPcode: {pcode[:200]}")
|
|
total = len(pcode_dict)
|
|
pct = 100 * passed / total
|
|
print(f"{arch}: {passed}/{total} ({pct:.1f}%) parsed, {skipped} skipped")
|
|
if DEBUG >= 2:
|
|
for err, ops in sorted(errors.items(), key=lambda x: -len(x[1])):
|
|
print(f" {err}: {', '.join(ops[:5])}{'...' if len(ops) > 5 else ''} ({len(ops)})")
|
|
self.assertGreaterEqual(pct, min_pct, f"[{arch}] {pct:.1f}% < {min_pct}% threshold")
|
|
|
|
def test_parse_all_cdna_pcode(self):
|
|
from tinygrad.runtime.autogen.amd.cdna.str_pcode import PCODE as CDNA_PCODE
|
|
self._parse_all_pcode(CDNA_PCODE, "CDNA", min_pct=60)
|
|
|
|
def test_parse_all_rdna3_pcode(self):
|
|
from tinygrad.runtime.autogen.amd.rdna3.str_pcode import PCODE as RDNA3_PCODE
|
|
self._parse_all_pcode(RDNA3_PCODE, "RDNA3", min_pct=90)
|
|
|
|
def test_parse_all_rdna4_pcode(self):
|
|
from tinygrad.runtime.autogen.amd.rdna4.str_pcode import PCODE as RDNA4_PCODE
|
|
self._parse_all_pcode(RDNA4_PCODE, "RDNA4", min_pct=65)
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|