improve asm dsl syntax (#13864)

* improve asm dsl syntax

* improve asm dsl syntax
This commit is contained in:
George Hotz
2025-12-28 20:04:59 -05:00
committed by GitHub
parent f5090192c8
commit d9603c1bee
4 changed files with 130 additions and 7 deletions

View File

@@ -434,7 +434,8 @@ def disasm(inst: Inst) -> str:
if op_name == 's_swappc_b64': return f"{op_name} {_fmt_sdst(sdst, 2)}, {_fmt_ssrc(ssrc0, 2)}"
if op_name in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'):
return f"{op_name} {_fmt_sdst(sdst, 2 if 'b64' in op_name else 1)}, sendmsg({MSG_NAMES.get(ssrc0, str(ssrc0))})"
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {_fmt_ssrc(ssrc0, src0_cnt)}"
ssrc0_str = fmt_src(ssrc0) if src0_cnt == 1 else _fmt_ssrc(ssrc0, src0_cnt)
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {ssrc0_str}"
if cls_name == 'SOP2':
sdst, ssrc0, ssrc1 = [unwrap(inst._values.get(f, 0)) for f in ('sdst', 'ssrc0', 'ssrc1')]
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {_fmt_ssrc(ssrc0, src0_cnt)}, {_fmt_ssrc(ssrc1, src1_cnt)}"
@@ -476,7 +477,8 @@ def parse_operand(op: str) -> tuple:
v = -int(m.group(1), 16) if op.startswith('-') else int(m.group(1), 16)
return (v, neg, abs_, hi_half)
if op in SPECIAL_REGS: return (SPECIAL_REGS[op], neg, abs_, hi_half)
if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', op): return (REG_MAP[m.group(1)][int(m.group(2)):int(m.group(3))+1], neg, abs_, hi_half)
if op == 'lit': return (RawImm(255), neg, abs_, hi_half) # literal marker (actual value comes from literal word)
if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', op): return (REG_MAP[m.group(1)][int(m.group(2)):int(m.group(3))], neg, abs_, hi_half)
if m := re.match(r'^([svt](?:tmp)?)(\d+)$', op):
reg = REG_MAP[m.group(1)][int(m.group(2))]
reg.hi = hi_half

View File

@@ -51,7 +51,7 @@ class _RegFactory(Generic[T]):
@overload
def __getitem__(self, key: slice) -> Reg: ...
def __getitem__(self, key: int | slice) -> Reg:
return self._cls(key.start, key.stop - key.start) if isinstance(key, slice) else self._cls(key)
return self._cls(key.start, key.stop - key.start + 1) if isinstance(key, slice) else self._cls(key)
def __repr__(self): return f"<{self._name} factory>"
class SGPR(Reg): pass
@@ -118,8 +118,30 @@ class Inst:
def __init__(self, *args, literal: int | None = None, **kwargs):
self._values, self._literal = dict(self._defaults), literal
self._values.update(zip([n for n in self._fields if n != 'encoding'], args))
self._values.update(kwargs)
# Map positional args to field names
field_names = [n for n in self._fields if n != 'encoding']
orig_args = dict(zip(field_names, args))
orig_args.update(kwargs)
self._values.update(orig_args)
# Validate register counts for SMEM instructions (before encoding)
if self.__class__.__name__ == 'SMEM':
op_val = orig_args.get(field_names[0]) if args else orig_args.get('op')
if op_val is not None:
if hasattr(op_val, 'value'): op_val = op_val.value
expected_cnt = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(op_val)
sdata_val = orig_args.get('sdata')
if expected_cnt is not None and isinstance(sdata_val, Reg) and sdata_val.count != expected_cnt:
raise ValueError(f"SMEM op {op_val} expects {expected_cnt} registers, got {sdata_val.count}")
# Validate register counts for SOP1 instructions (b32 = 1 reg, b64 = 2 regs)
if self.__class__.__name__ == 'SOP1':
op_val = orig_args.get(field_names[0]) if args else orig_args.get('op')
if op_val is not None and hasattr(op_val, 'name'):
expected = 2 if op_val.name.endswith('_B64') else 1
sdst_val, ssrc0_val = orig_args.get('sdst'), orig_args.get('ssrc0')
if isinstance(sdst_val, Reg) and sdst_val.count != expected:
raise ValueError(f"SOP1 {op_val.name} expects {expected} destination register(s), got {sdst_val.count}")
if isinstance(ssrc0_val, Reg) and ssrc0_val.count != expected:
raise ValueError(f"SOP1 {op_val.name} expects {expected} source register(s), got {ssrc0_val.count}")
# Type check and encode values
for name, val in list(self._values.items()):
if name == 'encoding': continue
@@ -143,6 +165,9 @@ class Inst:
# Track literal value if needed (encoded as 255)
if encoded == 255 and self._literal is None and isinstance(val, int) and not isinstance(val, IntEnum):
self._literal = val
elif encoded == 255 and self._literal is None and isinstance(val, float):
import struct
self._literal = struct.unpack('<I', struct.pack('<f', val))[0]
# Encode raw register fields for consistent repr
elif name in RAW_FIELDS:
if isinstance(val, Reg): self._values[name] = _encode_reg(val)
@@ -209,6 +234,12 @@ class Inst:
lit = f", literal={hex(self._literal)}" if self._literal is not None else ""
return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in items)}{lit})"
def __eq__(self, other):
if not isinstance(other, Inst): return NotImplemented
return self.__class__ == other.__class__ and self._values == other._values and self._literal == other._literal
def __hash__(self): return hash((self.__class__.__name__, tuple(sorted((k, repr(v)) for k, v in self._values.items())), self._literal))
def disasm(self) -> str:
from extra.assembly.rdna3.asm import disasm
return disasm(self)

View File

@@ -1,7 +1,7 @@
# do not change these tests. we need to fix bugs to make them pass
# the Inst constructor should be looking at the types of the fields to correctly set the value
import unittest
import unittest, struct
from extra.assembly.rdna3.autogen import *
from extra.assembly.rdna3.lib import Inst
from extra.assembly.rdna3.asm import asm
@@ -24,6 +24,33 @@ class TestIntegration(unittest.TestCase):
def test_load_b128(self):
self.inst = s_load_b128(s[4:7], s[0:1], NULL, 0)
def test_load_b128_wrong_size(self):
# this should have to be 4 regs on the loaded to
with self.assertRaises(Exception):
self.inst = s_load_b128(s[4:6], s[0:1], NULL, 0)
def test_mov_b32(self):
self.inst = s_mov_b32(s[80], s[0])
def test_mov_b64(self):
self.inst = s_mov_b64(s[80:81], s[0:1])
def test_mov_b32_wrong(self):
with self.assertRaises(Exception):
self.inst = s_mov_b32(s[80:81], s[0:1])
with self.assertRaises(Exception):
self.inst = s_mov_b32(s[80:81], s[0])
with self.assertRaises(Exception):
self.inst = s_mov_b32(s[80], s[0:1])
def test_mov_b64_wrong(self):
with self.assertRaises(Exception):
self.inst = s_mov_b64(s[80], s[0])
with self.assertRaises(Exception):
self.inst = s_mov_b64(s[80], s[0:1])
with self.assertRaises(Exception):
self.inst = s_mov_b64(s[80:81], s[0])
def test_load_b128_no_0(self):
self.inst = s_load_b128(s[4:7], s[0:1], NULL)
@@ -84,5 +111,68 @@ class TestIntegration(unittest.TestCase):
def test_dual_mul(self):
self.inst = v_dual_mul_f32(VOPDOp.V_DUAL_MUL_F32, vdstx=v[0], vdsty=v[1], srcx0=v[2], vsrcx1=v[3], srcy0=v[4], vsrcy1=v[5])
def test_simple_int_to_s(self):
self.inst = s_mov_b32(s[0], 3)
def test_complex_int_to_s(self):
self.inst = s_mov_b32(s[0], 0x235646)
def test_simple_float_to_s(self):
self.inst = s_mov_b32(s[0], 1.0)
def test_complex_float_to_s(self):
self.inst = s_mov_b32(s[0], 1337.0)
int_inst = s_mov_b32(s[0], struct.unpack("I", struct.pack("f", 1337.0))[0])
self.assertEqual(self.inst, int_inst)
class TestRegisterSliceSyntax(unittest.TestCase):
"""
Issue: Register slice syntax should use AMD assembly convention (inclusive end).
In AMD assembly, s[4:7] means registers s4, s5, s6, s7 (4 registers, inclusive).
The DSL should match this convention so that:
- s[4:7] gives 4 registers
- Disassembler output can be copied directly back into DSL code
Fix: Change _RegFactory.__getitem__ to use inclusive end:
key.stop - key.start + 1 (instead of key.stop - key.start)
"""
def test_register_slice_count(self):
# s[4:7] should give 4 registers: s4, s5, s6, s7 (AMD convention, inclusive)
reg = s[4:7]
self.assertEqual(reg.count, 4, "s[4:7] should give 4 registers (s4, s5, s6, s7)")
def test_register_slice_roundtrip(self):
# Round-trip: DSL -> disasm -> DSL should preserve register count
reg = s[4:7] # 4 registers in AMD convention
inst = s_load_b128(reg, s[0:1], NULL, 0)
disasm = inst.disasm()
# Disasm shows s[4:7] - user should be able to copy this back
self.assertIn("s[4:7]", disasm)
# And s[4:7] in DSL should give the same 4 registers
reg_from_disasm = s[4:7]
self.assertEqual(reg_from_disasm.count, 4, "s[4:7] from disasm should give 4 registers")
class TestInstructionEquality(unittest.TestCase):
"""
Issue: No __eq__ method - instruction comparison requires repr() workaround.
Two identical instructions should compare equal with ==, but currently:
inst1 == inst2 returns False
The test_handwritten.py works around this with:
self.assertEqual(repr(self.inst), repr(reasm))
"""
def test_identical_instructions_equal(self):
inst1 = v_mov_b32_e32(v[0], v[1])
inst2 = v_mov_b32_e32(v[0], v[1])
self.assertEqual(inst1, inst2, "identical instructions should be equal")
def test_different_instructions_not_equal(self):
inst1 = v_mov_b32_e32(v[0], v[1])
inst2 = v_mov_b32_e32(v[0], v[2])
self.assertNotEqual(inst1, inst2, "different instructions should not be equal")
if __name__ == "__main__":
unittest.main()

View File

@@ -189,7 +189,7 @@ class TestAsm(unittest.TestCase):
def test_asm_register_range(self):
"""Test parsing register ranges."""
inst = asm('s_load_b128 s[4:7], s[0:1], null')
self.assertEqual(inst.to_bytes(), s_load_b128(s[4:8], s[0:2], NULL).to_bytes())
self.assertEqual(inst.to_bytes(), s_load_b128(s[4:7], s[0:1], NULL).to_bytes())
def test_asm_matches_llvm(self):
"""Test asm() output matches LLVM assembler."""