diff --git a/extra/assembly/rdna3/asm.py b/extra/assembly/rdna3/asm.py index d8f4abe962..01351914e4 100644 --- a/extra/assembly/rdna3/asm.py +++ b/extra/assembly/rdna3/asm.py @@ -434,7 +434,8 @@ def disasm(inst: Inst) -> str: if op_name == 's_swappc_b64': return f"{op_name} {_fmt_sdst(sdst, 2)}, {_fmt_ssrc(ssrc0, 2)}" if op_name in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): return f"{op_name} {_fmt_sdst(sdst, 2 if 'b64' in op_name else 1)}, sendmsg({MSG_NAMES.get(ssrc0, str(ssrc0))})" - return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {_fmt_ssrc(ssrc0, src0_cnt)}" + ssrc0_str = fmt_src(ssrc0) if src0_cnt == 1 else _fmt_ssrc(ssrc0, src0_cnt) + return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {ssrc0_str}" if cls_name == 'SOP2': sdst, ssrc0, ssrc1 = [unwrap(inst._values.get(f, 0)) for f in ('sdst', 'ssrc0', 'ssrc1')] return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {_fmt_ssrc(ssrc0, src0_cnt)}, {_fmt_ssrc(ssrc1, src1_cnt)}" @@ -476,7 +477,8 @@ def parse_operand(op: str) -> tuple: v = -int(m.group(1), 16) if op.startswith('-') else int(m.group(1), 16) return (v, neg, abs_, hi_half) if op in SPECIAL_REGS: return (SPECIAL_REGS[op], neg, abs_, hi_half) - if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', op): return (REG_MAP[m.group(1)][int(m.group(2)):int(m.group(3))+1], neg, abs_, hi_half) + if op == 'lit': return (RawImm(255), neg, abs_, hi_half) # literal marker (actual value comes from literal word) + if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', op): return (REG_MAP[m.group(1)][int(m.group(2)):int(m.group(3))], neg, abs_, hi_half) if m := re.match(r'^([svt](?:tmp)?)(\d+)$', op): reg = REG_MAP[m.group(1)][int(m.group(2))] reg.hi = hi_half diff --git a/extra/assembly/rdna3/lib.py b/extra/assembly/rdna3/lib.py index 103326f3a0..f93149eed0 100644 --- a/extra/assembly/rdna3/lib.py +++ b/extra/assembly/rdna3/lib.py @@ -51,7 +51,7 @@ class _RegFactory(Generic[T]): @overload def __getitem__(self, key: slice) -> Reg: ... def __getitem__(self, key: int | slice) -> Reg: - return self._cls(key.start, key.stop - key.start) if isinstance(key, slice) else self._cls(key) + return self._cls(key.start, key.stop - key.start + 1) if isinstance(key, slice) else self._cls(key) def __repr__(self): return f"<{self._name} factory>" class SGPR(Reg): pass @@ -118,8 +118,30 @@ class Inst: def __init__(self, *args, literal: int | None = None, **kwargs): self._values, self._literal = dict(self._defaults), literal - self._values.update(zip([n for n in self._fields if n != 'encoding'], args)) - self._values.update(kwargs) + # Map positional args to field names + field_names = [n for n in self._fields if n != 'encoding'] + orig_args = dict(zip(field_names, args)) + orig_args.update(kwargs) + self._values.update(orig_args) + # Validate register counts for SMEM instructions (before encoding) + if self.__class__.__name__ == 'SMEM': + op_val = orig_args.get(field_names[0]) if args else orig_args.get('op') + if op_val is not None: + if hasattr(op_val, 'value'): op_val = op_val.value + expected_cnt = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(op_val) + sdata_val = orig_args.get('sdata') + if expected_cnt is not None and isinstance(sdata_val, Reg) and sdata_val.count != expected_cnt: + raise ValueError(f"SMEM op {op_val} expects {expected_cnt} registers, got {sdata_val.count}") + # Validate register counts for SOP1 instructions (b32 = 1 reg, b64 = 2 regs) + if self.__class__.__name__ == 'SOP1': + op_val = orig_args.get(field_names[0]) if args else orig_args.get('op') + if op_val is not None and hasattr(op_val, 'name'): + expected = 2 if op_val.name.endswith('_B64') else 1 + sdst_val, ssrc0_val = orig_args.get('sdst'), orig_args.get('ssrc0') + if isinstance(sdst_val, Reg) and sdst_val.count != expected: + raise ValueError(f"SOP1 {op_val.name} expects {expected} destination register(s), got {sdst_val.count}") + if isinstance(ssrc0_val, Reg) and ssrc0_val.count != expected: + raise ValueError(f"SOP1 {op_val.name} expects {expected} source register(s), got {ssrc0_val.count}") # Type check and encode values for name, val in list(self._values.items()): if name == 'encoding': continue @@ -143,6 +165,9 @@ class Inst: # Track literal value if needed (encoded as 255) if encoded == 255 and self._literal is None and isinstance(val, int) and not isinstance(val, IntEnum): self._literal = val + elif encoded == 255 and self._literal is None and isinstance(val, float): + import struct + self._literal = struct.unpack(' str: from extra.assembly.rdna3.asm import disasm return disasm(self) diff --git a/extra/assembly/rdna3/test/test_handwritten.py b/extra/assembly/rdna3/test/test_handwritten.py index 2a0d17c44b..fc0c16a05f 100644 --- a/extra/assembly/rdna3/test/test_handwritten.py +++ b/extra/assembly/rdna3/test/test_handwritten.py @@ -1,7 +1,7 @@ # do not change these tests. we need to fix bugs to make them pass # the Inst constructor should be looking at the types of the fields to correctly set the value -import unittest +import unittest, struct from extra.assembly.rdna3.autogen import * from extra.assembly.rdna3.lib import Inst from extra.assembly.rdna3.asm import asm @@ -24,6 +24,33 @@ class TestIntegration(unittest.TestCase): def test_load_b128(self): self.inst = s_load_b128(s[4:7], s[0:1], NULL, 0) + def test_load_b128_wrong_size(self): + # this should have to be 4 regs on the loaded to + with self.assertRaises(Exception): + self.inst = s_load_b128(s[4:6], s[0:1], NULL, 0) + + def test_mov_b32(self): + self.inst = s_mov_b32(s[80], s[0]) + + def test_mov_b64(self): + self.inst = s_mov_b64(s[80:81], s[0:1]) + + def test_mov_b32_wrong(self): + with self.assertRaises(Exception): + self.inst = s_mov_b32(s[80:81], s[0:1]) + with self.assertRaises(Exception): + self.inst = s_mov_b32(s[80:81], s[0]) + with self.assertRaises(Exception): + self.inst = s_mov_b32(s[80], s[0:1]) + + def test_mov_b64_wrong(self): + with self.assertRaises(Exception): + self.inst = s_mov_b64(s[80], s[0]) + with self.assertRaises(Exception): + self.inst = s_mov_b64(s[80], s[0:1]) + with self.assertRaises(Exception): + self.inst = s_mov_b64(s[80:81], s[0]) + def test_load_b128_no_0(self): self.inst = s_load_b128(s[4:7], s[0:1], NULL) @@ -84,5 +111,68 @@ class TestIntegration(unittest.TestCase): def test_dual_mul(self): self.inst = v_dual_mul_f32(VOPDOp.V_DUAL_MUL_F32, vdstx=v[0], vdsty=v[1], srcx0=v[2], vsrcx1=v[3], srcy0=v[4], vsrcy1=v[5]) + def test_simple_int_to_s(self): + self.inst = s_mov_b32(s[0], 3) + + def test_complex_int_to_s(self): + self.inst = s_mov_b32(s[0], 0x235646) + + def test_simple_float_to_s(self): + self.inst = s_mov_b32(s[0], 1.0) + + def test_complex_float_to_s(self): + self.inst = s_mov_b32(s[0], 1337.0) + int_inst = s_mov_b32(s[0], struct.unpack("I", struct.pack("f", 1337.0))[0]) + self.assertEqual(self.inst, int_inst) + +class TestRegisterSliceSyntax(unittest.TestCase): + """ + Issue: Register slice syntax should use AMD assembly convention (inclusive end). + + In AMD assembly, s[4:7] means registers s4, s5, s6, s7 (4 registers, inclusive). + The DSL should match this convention so that: + - s[4:7] gives 4 registers + - Disassembler output can be copied directly back into DSL code + + Fix: Change _RegFactory.__getitem__ to use inclusive end: + key.stop - key.start + 1 (instead of key.stop - key.start) + """ + def test_register_slice_count(self): + # s[4:7] should give 4 registers: s4, s5, s6, s7 (AMD convention, inclusive) + reg = s[4:7] + self.assertEqual(reg.count, 4, "s[4:7] should give 4 registers (s4, s5, s6, s7)") + + def test_register_slice_roundtrip(self): + # Round-trip: DSL -> disasm -> DSL should preserve register count + reg = s[4:7] # 4 registers in AMD convention + inst = s_load_b128(reg, s[0:1], NULL, 0) + disasm = inst.disasm() + # Disasm shows s[4:7] - user should be able to copy this back + self.assertIn("s[4:7]", disasm) + # And s[4:7] in DSL should give the same 4 registers + reg_from_disasm = s[4:7] + self.assertEqual(reg_from_disasm.count, 4, "s[4:7] from disasm should give 4 registers") + +class TestInstructionEquality(unittest.TestCase): + """ + Issue: No __eq__ method - instruction comparison requires repr() workaround. + + Two identical instructions should compare equal with ==, but currently: + inst1 == inst2 returns False + + The test_handwritten.py works around this with: + self.assertEqual(repr(self.inst), repr(reasm)) + """ + def test_identical_instructions_equal(self): + inst1 = v_mov_b32_e32(v[0], v[1]) + inst2 = v_mov_b32_e32(v[0], v[1]) + self.assertEqual(inst1, inst2, "identical instructions should be equal") + + def test_different_instructions_not_equal(self): + inst1 = v_mov_b32_e32(v[0], v[1]) + inst2 = v_mov_b32_e32(v[0], v[2]) + self.assertNotEqual(inst1, inst2, "different instructions should not be equal") + + if __name__ == "__main__": unittest.main() diff --git a/extra/assembly/rdna3/test/test_integration.py b/extra/assembly/rdna3/test/test_integration.py index 39f9a272df..0d22640ba4 100644 --- a/extra/assembly/rdna3/test/test_integration.py +++ b/extra/assembly/rdna3/test/test_integration.py @@ -189,7 +189,7 @@ class TestAsm(unittest.TestCase): def test_asm_register_range(self): """Test parsing register ranges.""" inst = asm('s_load_b128 s[4:7], s[0:1], null') - self.assertEqual(inst.to_bytes(), s_load_b128(s[4:8], s[0:2], NULL).to_bytes()) + self.assertEqual(inst.to_bytes(), s_load_b128(s[4:7], s[0:1], NULL).to_bytes()) def test_asm_matches_llvm(self): """Test asm() output matches LLVM assembler."""