mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
improve asm dsl syntax (#13864)
* improve asm dsl syntax * improve asm dsl syntax
This commit is contained in:
@@ -434,7 +434,8 @@ def disasm(inst: Inst) -> str:
|
||||
if op_name == 's_swappc_b64': return f"{op_name} {_fmt_sdst(sdst, 2)}, {_fmt_ssrc(ssrc0, 2)}"
|
||||
if op_name in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'):
|
||||
return f"{op_name} {_fmt_sdst(sdst, 2 if 'b64' in op_name else 1)}, sendmsg({MSG_NAMES.get(ssrc0, str(ssrc0))})"
|
||||
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {_fmt_ssrc(ssrc0, src0_cnt)}"
|
||||
ssrc0_str = fmt_src(ssrc0) if src0_cnt == 1 else _fmt_ssrc(ssrc0, src0_cnt)
|
||||
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {ssrc0_str}"
|
||||
if cls_name == 'SOP2':
|
||||
sdst, ssrc0, ssrc1 = [unwrap(inst._values.get(f, 0)) for f in ('sdst', 'ssrc0', 'ssrc1')]
|
||||
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {_fmt_ssrc(ssrc0, src0_cnt)}, {_fmt_ssrc(ssrc1, src1_cnt)}"
|
||||
@@ -476,7 +477,8 @@ def parse_operand(op: str) -> tuple:
|
||||
v = -int(m.group(1), 16) if op.startswith('-') else int(m.group(1), 16)
|
||||
return (v, neg, abs_, hi_half)
|
||||
if op in SPECIAL_REGS: return (SPECIAL_REGS[op], neg, abs_, hi_half)
|
||||
if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', op): return (REG_MAP[m.group(1)][int(m.group(2)):int(m.group(3))+1], neg, abs_, hi_half)
|
||||
if op == 'lit': return (RawImm(255), neg, abs_, hi_half) # literal marker (actual value comes from literal word)
|
||||
if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', op): return (REG_MAP[m.group(1)][int(m.group(2)):int(m.group(3))], neg, abs_, hi_half)
|
||||
if m := re.match(r'^([svt](?:tmp)?)(\d+)$', op):
|
||||
reg = REG_MAP[m.group(1)][int(m.group(2))]
|
||||
reg.hi = hi_half
|
||||
|
||||
@@ -51,7 +51,7 @@ class _RegFactory(Generic[T]):
|
||||
@overload
|
||||
def __getitem__(self, key: slice) -> Reg: ...
|
||||
def __getitem__(self, key: int | slice) -> Reg:
|
||||
return self._cls(key.start, key.stop - key.start) if isinstance(key, slice) else self._cls(key)
|
||||
return self._cls(key.start, key.stop - key.start + 1) if isinstance(key, slice) else self._cls(key)
|
||||
def __repr__(self): return f"<{self._name} factory>"
|
||||
|
||||
class SGPR(Reg): pass
|
||||
@@ -118,8 +118,30 @@ class Inst:
|
||||
|
||||
def __init__(self, *args, literal: int | None = None, **kwargs):
|
||||
self._values, self._literal = dict(self._defaults), literal
|
||||
self._values.update(zip([n for n in self._fields if n != 'encoding'], args))
|
||||
self._values.update(kwargs)
|
||||
# Map positional args to field names
|
||||
field_names = [n for n in self._fields if n != 'encoding']
|
||||
orig_args = dict(zip(field_names, args))
|
||||
orig_args.update(kwargs)
|
||||
self._values.update(orig_args)
|
||||
# Validate register counts for SMEM instructions (before encoding)
|
||||
if self.__class__.__name__ == 'SMEM':
|
||||
op_val = orig_args.get(field_names[0]) if args else orig_args.get('op')
|
||||
if op_val is not None:
|
||||
if hasattr(op_val, 'value'): op_val = op_val.value
|
||||
expected_cnt = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(op_val)
|
||||
sdata_val = orig_args.get('sdata')
|
||||
if expected_cnt is not None and isinstance(sdata_val, Reg) and sdata_val.count != expected_cnt:
|
||||
raise ValueError(f"SMEM op {op_val} expects {expected_cnt} registers, got {sdata_val.count}")
|
||||
# Validate register counts for SOP1 instructions (b32 = 1 reg, b64 = 2 regs)
|
||||
if self.__class__.__name__ == 'SOP1':
|
||||
op_val = orig_args.get(field_names[0]) if args else orig_args.get('op')
|
||||
if op_val is not None and hasattr(op_val, 'name'):
|
||||
expected = 2 if op_val.name.endswith('_B64') else 1
|
||||
sdst_val, ssrc0_val = orig_args.get('sdst'), orig_args.get('ssrc0')
|
||||
if isinstance(sdst_val, Reg) and sdst_val.count != expected:
|
||||
raise ValueError(f"SOP1 {op_val.name} expects {expected} destination register(s), got {sdst_val.count}")
|
||||
if isinstance(ssrc0_val, Reg) and ssrc0_val.count != expected:
|
||||
raise ValueError(f"SOP1 {op_val.name} expects {expected} source register(s), got {ssrc0_val.count}")
|
||||
# Type check and encode values
|
||||
for name, val in list(self._values.items()):
|
||||
if name == 'encoding': continue
|
||||
@@ -143,6 +165,9 @@ class Inst:
|
||||
# Track literal value if needed (encoded as 255)
|
||||
if encoded == 255 and self._literal is None and isinstance(val, int) and not isinstance(val, IntEnum):
|
||||
self._literal = val
|
||||
elif encoded == 255 and self._literal is None and isinstance(val, float):
|
||||
import struct
|
||||
self._literal = struct.unpack('<I', struct.pack('<f', val))[0]
|
||||
# Encode raw register fields for consistent repr
|
||||
elif name in RAW_FIELDS:
|
||||
if isinstance(val, Reg): self._values[name] = _encode_reg(val)
|
||||
@@ -209,6 +234,12 @@ class Inst:
|
||||
lit = f", literal={hex(self._literal)}" if self._literal is not None else ""
|
||||
return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in items)}{lit})"
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Inst): return NotImplemented
|
||||
return self.__class__ == other.__class__ and self._values == other._values and self._literal == other._literal
|
||||
|
||||
def __hash__(self): return hash((self.__class__.__name__, tuple(sorted((k, repr(v)) for k, v in self._values.items())), self._literal))
|
||||
|
||||
def disasm(self) -> str:
|
||||
from extra.assembly.rdna3.asm import disasm
|
||||
return disasm(self)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# do not change these tests. we need to fix bugs to make them pass
|
||||
# the Inst constructor should be looking at the types of the fields to correctly set the value
|
||||
|
||||
import unittest
|
||||
import unittest, struct
|
||||
from extra.assembly.rdna3.autogen import *
|
||||
from extra.assembly.rdna3.lib import Inst
|
||||
from extra.assembly.rdna3.asm import asm
|
||||
@@ -24,6 +24,33 @@ class TestIntegration(unittest.TestCase):
|
||||
def test_load_b128(self):
|
||||
self.inst = s_load_b128(s[4:7], s[0:1], NULL, 0)
|
||||
|
||||
def test_load_b128_wrong_size(self):
|
||||
# this should have to be 4 regs on the loaded to
|
||||
with self.assertRaises(Exception):
|
||||
self.inst = s_load_b128(s[4:6], s[0:1], NULL, 0)
|
||||
|
||||
def test_mov_b32(self):
|
||||
self.inst = s_mov_b32(s[80], s[0])
|
||||
|
||||
def test_mov_b64(self):
|
||||
self.inst = s_mov_b64(s[80:81], s[0:1])
|
||||
|
||||
def test_mov_b32_wrong(self):
|
||||
with self.assertRaises(Exception):
|
||||
self.inst = s_mov_b32(s[80:81], s[0:1])
|
||||
with self.assertRaises(Exception):
|
||||
self.inst = s_mov_b32(s[80:81], s[0])
|
||||
with self.assertRaises(Exception):
|
||||
self.inst = s_mov_b32(s[80], s[0:1])
|
||||
|
||||
def test_mov_b64_wrong(self):
|
||||
with self.assertRaises(Exception):
|
||||
self.inst = s_mov_b64(s[80], s[0])
|
||||
with self.assertRaises(Exception):
|
||||
self.inst = s_mov_b64(s[80], s[0:1])
|
||||
with self.assertRaises(Exception):
|
||||
self.inst = s_mov_b64(s[80:81], s[0])
|
||||
|
||||
def test_load_b128_no_0(self):
|
||||
self.inst = s_load_b128(s[4:7], s[0:1], NULL)
|
||||
|
||||
@@ -84,5 +111,68 @@ class TestIntegration(unittest.TestCase):
|
||||
def test_dual_mul(self):
|
||||
self.inst = v_dual_mul_f32(VOPDOp.V_DUAL_MUL_F32, vdstx=v[0], vdsty=v[1], srcx0=v[2], vsrcx1=v[3], srcy0=v[4], vsrcy1=v[5])
|
||||
|
||||
def test_simple_int_to_s(self):
|
||||
self.inst = s_mov_b32(s[0], 3)
|
||||
|
||||
def test_complex_int_to_s(self):
|
||||
self.inst = s_mov_b32(s[0], 0x235646)
|
||||
|
||||
def test_simple_float_to_s(self):
|
||||
self.inst = s_mov_b32(s[0], 1.0)
|
||||
|
||||
def test_complex_float_to_s(self):
|
||||
self.inst = s_mov_b32(s[0], 1337.0)
|
||||
int_inst = s_mov_b32(s[0], struct.unpack("I", struct.pack("f", 1337.0))[0])
|
||||
self.assertEqual(self.inst, int_inst)
|
||||
|
||||
class TestRegisterSliceSyntax(unittest.TestCase):
|
||||
"""
|
||||
Issue: Register slice syntax should use AMD assembly convention (inclusive end).
|
||||
|
||||
In AMD assembly, s[4:7] means registers s4, s5, s6, s7 (4 registers, inclusive).
|
||||
The DSL should match this convention so that:
|
||||
- s[4:7] gives 4 registers
|
||||
- Disassembler output can be copied directly back into DSL code
|
||||
|
||||
Fix: Change _RegFactory.__getitem__ to use inclusive end:
|
||||
key.stop - key.start + 1 (instead of key.stop - key.start)
|
||||
"""
|
||||
def test_register_slice_count(self):
|
||||
# s[4:7] should give 4 registers: s4, s5, s6, s7 (AMD convention, inclusive)
|
||||
reg = s[4:7]
|
||||
self.assertEqual(reg.count, 4, "s[4:7] should give 4 registers (s4, s5, s6, s7)")
|
||||
|
||||
def test_register_slice_roundtrip(self):
|
||||
# Round-trip: DSL -> disasm -> DSL should preserve register count
|
||||
reg = s[4:7] # 4 registers in AMD convention
|
||||
inst = s_load_b128(reg, s[0:1], NULL, 0)
|
||||
disasm = inst.disasm()
|
||||
# Disasm shows s[4:7] - user should be able to copy this back
|
||||
self.assertIn("s[4:7]", disasm)
|
||||
# And s[4:7] in DSL should give the same 4 registers
|
||||
reg_from_disasm = s[4:7]
|
||||
self.assertEqual(reg_from_disasm.count, 4, "s[4:7] from disasm should give 4 registers")
|
||||
|
||||
class TestInstructionEquality(unittest.TestCase):
|
||||
"""
|
||||
Issue: No __eq__ method - instruction comparison requires repr() workaround.
|
||||
|
||||
Two identical instructions should compare equal with ==, but currently:
|
||||
inst1 == inst2 returns False
|
||||
|
||||
The test_handwritten.py works around this with:
|
||||
self.assertEqual(repr(self.inst), repr(reasm))
|
||||
"""
|
||||
def test_identical_instructions_equal(self):
|
||||
inst1 = v_mov_b32_e32(v[0], v[1])
|
||||
inst2 = v_mov_b32_e32(v[0], v[1])
|
||||
self.assertEqual(inst1, inst2, "identical instructions should be equal")
|
||||
|
||||
def test_different_instructions_not_equal(self):
|
||||
inst1 = v_mov_b32_e32(v[0], v[1])
|
||||
inst2 = v_mov_b32_e32(v[0], v[2])
|
||||
self.assertNotEqual(inst1, inst2, "different instructions should not be equal")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -189,7 +189,7 @@ class TestAsm(unittest.TestCase):
|
||||
def test_asm_register_range(self):
|
||||
"""Test parsing register ranges."""
|
||||
inst = asm('s_load_b128 s[4:7], s[0:1], null')
|
||||
self.assertEqual(inst.to_bytes(), s_load_b128(s[4:8], s[0:2], NULL).to_bytes())
|
||||
self.assertEqual(inst.to_bytes(), s_load_b128(s[4:7], s[0:1], NULL).to_bytes())
|
||||
|
||||
def test_asm_matches_llvm(self):
|
||||
"""Test asm() output matches LLVM assembler."""
|
||||
|
||||
Reference in New Issue
Block a user