mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
improve asm dsl syntax (#13864)
* improve asm dsl syntax * improve asm dsl syntax
This commit is contained in:
@@ -434,7 +434,8 @@ def disasm(inst: Inst) -> str:
|
|||||||
if op_name == 's_swappc_b64': return f"{op_name} {_fmt_sdst(sdst, 2)}, {_fmt_ssrc(ssrc0, 2)}"
|
if op_name == 's_swappc_b64': return f"{op_name} {_fmt_sdst(sdst, 2)}, {_fmt_ssrc(ssrc0, 2)}"
|
||||||
if op_name in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'):
|
if op_name in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'):
|
||||||
return f"{op_name} {_fmt_sdst(sdst, 2 if 'b64' in op_name else 1)}, sendmsg({MSG_NAMES.get(ssrc0, str(ssrc0))})"
|
return f"{op_name} {_fmt_sdst(sdst, 2 if 'b64' in op_name else 1)}, sendmsg({MSG_NAMES.get(ssrc0, str(ssrc0))})"
|
||||||
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {_fmt_ssrc(ssrc0, src0_cnt)}"
|
ssrc0_str = fmt_src(ssrc0) if src0_cnt == 1 else _fmt_ssrc(ssrc0, src0_cnt)
|
||||||
|
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {ssrc0_str}"
|
||||||
if cls_name == 'SOP2':
|
if cls_name == 'SOP2':
|
||||||
sdst, ssrc0, ssrc1 = [unwrap(inst._values.get(f, 0)) for f in ('sdst', 'ssrc0', 'ssrc1')]
|
sdst, ssrc0, ssrc1 = [unwrap(inst._values.get(f, 0)) for f in ('sdst', 'ssrc0', 'ssrc1')]
|
||||||
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {_fmt_ssrc(ssrc0, src0_cnt)}, {_fmt_ssrc(ssrc1, src1_cnt)}"
|
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {_fmt_ssrc(ssrc0, src0_cnt)}, {_fmt_ssrc(ssrc1, src1_cnt)}"
|
||||||
@@ -476,7 +477,8 @@ def parse_operand(op: str) -> tuple:
|
|||||||
v = -int(m.group(1), 16) if op.startswith('-') else int(m.group(1), 16)
|
v = -int(m.group(1), 16) if op.startswith('-') else int(m.group(1), 16)
|
||||||
return (v, neg, abs_, hi_half)
|
return (v, neg, abs_, hi_half)
|
||||||
if op in SPECIAL_REGS: return (SPECIAL_REGS[op], neg, abs_, hi_half)
|
if op in SPECIAL_REGS: return (SPECIAL_REGS[op], neg, abs_, hi_half)
|
||||||
if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', op): return (REG_MAP[m.group(1)][int(m.group(2)):int(m.group(3))+1], neg, abs_, hi_half)
|
if op == 'lit': return (RawImm(255), neg, abs_, hi_half) # literal marker (actual value comes from literal word)
|
||||||
|
if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', op): return (REG_MAP[m.group(1)][int(m.group(2)):int(m.group(3))], neg, abs_, hi_half)
|
||||||
if m := re.match(r'^([svt](?:tmp)?)(\d+)$', op):
|
if m := re.match(r'^([svt](?:tmp)?)(\d+)$', op):
|
||||||
reg = REG_MAP[m.group(1)][int(m.group(2))]
|
reg = REG_MAP[m.group(1)][int(m.group(2))]
|
||||||
reg.hi = hi_half
|
reg.hi = hi_half
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class _RegFactory(Generic[T]):
|
|||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: slice) -> Reg: ...
|
def __getitem__(self, key: slice) -> Reg: ...
|
||||||
def __getitem__(self, key: int | slice) -> Reg:
|
def __getitem__(self, key: int | slice) -> Reg:
|
||||||
return self._cls(key.start, key.stop - key.start) if isinstance(key, slice) else self._cls(key)
|
return self._cls(key.start, key.stop - key.start + 1) if isinstance(key, slice) else self._cls(key)
|
||||||
def __repr__(self): return f"<{self._name} factory>"
|
def __repr__(self): return f"<{self._name} factory>"
|
||||||
|
|
||||||
class SGPR(Reg): pass
|
class SGPR(Reg): pass
|
||||||
@@ -118,8 +118,30 @@ class Inst:
|
|||||||
|
|
||||||
def __init__(self, *args, literal: int | None = None, **kwargs):
|
def __init__(self, *args, literal: int | None = None, **kwargs):
|
||||||
self._values, self._literal = dict(self._defaults), literal
|
self._values, self._literal = dict(self._defaults), literal
|
||||||
self._values.update(zip([n for n in self._fields if n != 'encoding'], args))
|
# Map positional args to field names
|
||||||
self._values.update(kwargs)
|
field_names = [n for n in self._fields if n != 'encoding']
|
||||||
|
orig_args = dict(zip(field_names, args))
|
||||||
|
orig_args.update(kwargs)
|
||||||
|
self._values.update(orig_args)
|
||||||
|
# Validate register counts for SMEM instructions (before encoding)
|
||||||
|
if self.__class__.__name__ == 'SMEM':
|
||||||
|
op_val = orig_args.get(field_names[0]) if args else orig_args.get('op')
|
||||||
|
if op_val is not None:
|
||||||
|
if hasattr(op_val, 'value'): op_val = op_val.value
|
||||||
|
expected_cnt = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(op_val)
|
||||||
|
sdata_val = orig_args.get('sdata')
|
||||||
|
if expected_cnt is not None and isinstance(sdata_val, Reg) and sdata_val.count != expected_cnt:
|
||||||
|
raise ValueError(f"SMEM op {op_val} expects {expected_cnt} registers, got {sdata_val.count}")
|
||||||
|
# Validate register counts for SOP1 instructions (b32 = 1 reg, b64 = 2 regs)
|
||||||
|
if self.__class__.__name__ == 'SOP1':
|
||||||
|
op_val = orig_args.get(field_names[0]) if args else orig_args.get('op')
|
||||||
|
if op_val is not None and hasattr(op_val, 'name'):
|
||||||
|
expected = 2 if op_val.name.endswith('_B64') else 1
|
||||||
|
sdst_val, ssrc0_val = orig_args.get('sdst'), orig_args.get('ssrc0')
|
||||||
|
if isinstance(sdst_val, Reg) and sdst_val.count != expected:
|
||||||
|
raise ValueError(f"SOP1 {op_val.name} expects {expected} destination register(s), got {sdst_val.count}")
|
||||||
|
if isinstance(ssrc0_val, Reg) and ssrc0_val.count != expected:
|
||||||
|
raise ValueError(f"SOP1 {op_val.name} expects {expected} source register(s), got {ssrc0_val.count}")
|
||||||
# Type check and encode values
|
# Type check and encode values
|
||||||
for name, val in list(self._values.items()):
|
for name, val in list(self._values.items()):
|
||||||
if name == 'encoding': continue
|
if name == 'encoding': continue
|
||||||
@@ -143,6 +165,9 @@ class Inst:
|
|||||||
# Track literal value if needed (encoded as 255)
|
# Track literal value if needed (encoded as 255)
|
||||||
if encoded == 255 and self._literal is None and isinstance(val, int) and not isinstance(val, IntEnum):
|
if encoded == 255 and self._literal is None and isinstance(val, int) and not isinstance(val, IntEnum):
|
||||||
self._literal = val
|
self._literal = val
|
||||||
|
elif encoded == 255 and self._literal is None and isinstance(val, float):
|
||||||
|
import struct
|
||||||
|
self._literal = struct.unpack('<I', struct.pack('<f', val))[0]
|
||||||
# Encode raw register fields for consistent repr
|
# Encode raw register fields for consistent repr
|
||||||
elif name in RAW_FIELDS:
|
elif name in RAW_FIELDS:
|
||||||
if isinstance(val, Reg): self._values[name] = _encode_reg(val)
|
if isinstance(val, Reg): self._values[name] = _encode_reg(val)
|
||||||
@@ -209,6 +234,12 @@ class Inst:
|
|||||||
lit = f", literal={hex(self._literal)}" if self._literal is not None else ""
|
lit = f", literal={hex(self._literal)}" if self._literal is not None else ""
|
||||||
return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in items)}{lit})"
|
return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in items)}{lit})"
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, Inst): return NotImplemented
|
||||||
|
return self.__class__ == other.__class__ and self._values == other._values and self._literal == other._literal
|
||||||
|
|
||||||
|
def __hash__(self): return hash((self.__class__.__name__, tuple(sorted((k, repr(v)) for k, v in self._values.items())), self._literal))
|
||||||
|
|
||||||
def disasm(self) -> str:
|
def disasm(self) -> str:
|
||||||
from extra.assembly.rdna3.asm import disasm
|
from extra.assembly.rdna3.asm import disasm
|
||||||
return disasm(self)
|
return disasm(self)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# do not change these tests. we need to fix bugs to make them pass
|
# do not change these tests. we need to fix bugs to make them pass
|
||||||
# the Inst constructor should be looking at the types of the fields to correctly set the value
|
# the Inst constructor should be looking at the types of the fields to correctly set the value
|
||||||
|
|
||||||
import unittest
|
import unittest, struct
|
||||||
from extra.assembly.rdna3.autogen import *
|
from extra.assembly.rdna3.autogen import *
|
||||||
from extra.assembly.rdna3.lib import Inst
|
from extra.assembly.rdna3.lib import Inst
|
||||||
from extra.assembly.rdna3.asm import asm
|
from extra.assembly.rdna3.asm import asm
|
||||||
@@ -24,6 +24,33 @@ class TestIntegration(unittest.TestCase):
|
|||||||
def test_load_b128(self):
|
def test_load_b128(self):
|
||||||
self.inst = s_load_b128(s[4:7], s[0:1], NULL, 0)
|
self.inst = s_load_b128(s[4:7], s[0:1], NULL, 0)
|
||||||
|
|
||||||
|
def test_load_b128_wrong_size(self):
|
||||||
|
# this should have to be 4 regs on the loaded to
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
self.inst = s_load_b128(s[4:6], s[0:1], NULL, 0)
|
||||||
|
|
||||||
|
def test_mov_b32(self):
|
||||||
|
self.inst = s_mov_b32(s[80], s[0])
|
||||||
|
|
||||||
|
def test_mov_b64(self):
|
||||||
|
self.inst = s_mov_b64(s[80:81], s[0:1])
|
||||||
|
|
||||||
|
def test_mov_b32_wrong(self):
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
self.inst = s_mov_b32(s[80:81], s[0:1])
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
self.inst = s_mov_b32(s[80:81], s[0])
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
self.inst = s_mov_b32(s[80], s[0:1])
|
||||||
|
|
||||||
|
def test_mov_b64_wrong(self):
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
self.inst = s_mov_b64(s[80], s[0])
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
self.inst = s_mov_b64(s[80], s[0:1])
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
self.inst = s_mov_b64(s[80:81], s[0])
|
||||||
|
|
||||||
def test_load_b128_no_0(self):
|
def test_load_b128_no_0(self):
|
||||||
self.inst = s_load_b128(s[4:7], s[0:1], NULL)
|
self.inst = s_load_b128(s[4:7], s[0:1], NULL)
|
||||||
|
|
||||||
@@ -84,5 +111,68 @@ class TestIntegration(unittest.TestCase):
|
|||||||
def test_dual_mul(self):
|
def test_dual_mul(self):
|
||||||
self.inst = v_dual_mul_f32(VOPDOp.V_DUAL_MUL_F32, vdstx=v[0], vdsty=v[1], srcx0=v[2], vsrcx1=v[3], srcy0=v[4], vsrcy1=v[5])
|
self.inst = v_dual_mul_f32(VOPDOp.V_DUAL_MUL_F32, vdstx=v[0], vdsty=v[1], srcx0=v[2], vsrcx1=v[3], srcy0=v[4], vsrcy1=v[5])
|
||||||
|
|
||||||
|
def test_simple_int_to_s(self):
|
||||||
|
self.inst = s_mov_b32(s[0], 3)
|
||||||
|
|
||||||
|
def test_complex_int_to_s(self):
|
||||||
|
self.inst = s_mov_b32(s[0], 0x235646)
|
||||||
|
|
||||||
|
def test_simple_float_to_s(self):
|
||||||
|
self.inst = s_mov_b32(s[0], 1.0)
|
||||||
|
|
||||||
|
def test_complex_float_to_s(self):
|
||||||
|
self.inst = s_mov_b32(s[0], 1337.0)
|
||||||
|
int_inst = s_mov_b32(s[0], struct.unpack("I", struct.pack("f", 1337.0))[0])
|
||||||
|
self.assertEqual(self.inst, int_inst)
|
||||||
|
|
||||||
|
class TestRegisterSliceSyntax(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Issue: Register slice syntax should use AMD assembly convention (inclusive end).
|
||||||
|
|
||||||
|
In AMD assembly, s[4:7] means registers s4, s5, s6, s7 (4 registers, inclusive).
|
||||||
|
The DSL should match this convention so that:
|
||||||
|
- s[4:7] gives 4 registers
|
||||||
|
- Disassembler output can be copied directly back into DSL code
|
||||||
|
|
||||||
|
Fix: Change _RegFactory.__getitem__ to use inclusive end:
|
||||||
|
key.stop - key.start + 1 (instead of key.stop - key.start)
|
||||||
|
"""
|
||||||
|
def test_register_slice_count(self):
|
||||||
|
# s[4:7] should give 4 registers: s4, s5, s6, s7 (AMD convention, inclusive)
|
||||||
|
reg = s[4:7]
|
||||||
|
self.assertEqual(reg.count, 4, "s[4:7] should give 4 registers (s4, s5, s6, s7)")
|
||||||
|
|
||||||
|
def test_register_slice_roundtrip(self):
|
||||||
|
# Round-trip: DSL -> disasm -> DSL should preserve register count
|
||||||
|
reg = s[4:7] # 4 registers in AMD convention
|
||||||
|
inst = s_load_b128(reg, s[0:1], NULL, 0)
|
||||||
|
disasm = inst.disasm()
|
||||||
|
# Disasm shows s[4:7] - user should be able to copy this back
|
||||||
|
self.assertIn("s[4:7]", disasm)
|
||||||
|
# And s[4:7] in DSL should give the same 4 registers
|
||||||
|
reg_from_disasm = s[4:7]
|
||||||
|
self.assertEqual(reg_from_disasm.count, 4, "s[4:7] from disasm should give 4 registers")
|
||||||
|
|
||||||
|
class TestInstructionEquality(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Issue: No __eq__ method - instruction comparison requires repr() workaround.
|
||||||
|
|
||||||
|
Two identical instructions should compare equal with ==, but currently:
|
||||||
|
inst1 == inst2 returns False
|
||||||
|
|
||||||
|
The test_handwritten.py works around this with:
|
||||||
|
self.assertEqual(repr(self.inst), repr(reasm))
|
||||||
|
"""
|
||||||
|
def test_identical_instructions_equal(self):
|
||||||
|
inst1 = v_mov_b32_e32(v[0], v[1])
|
||||||
|
inst2 = v_mov_b32_e32(v[0], v[1])
|
||||||
|
self.assertEqual(inst1, inst2, "identical instructions should be equal")
|
||||||
|
|
||||||
|
def test_different_instructions_not_equal(self):
|
||||||
|
inst1 = v_mov_b32_e32(v[0], v[1])
|
||||||
|
inst2 = v_mov_b32_e32(v[0], v[2])
|
||||||
|
self.assertNotEqual(inst1, inst2, "different instructions should not be equal")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -189,7 +189,7 @@ class TestAsm(unittest.TestCase):
|
|||||||
def test_asm_register_range(self):
|
def test_asm_register_range(self):
|
||||||
"""Test parsing register ranges."""
|
"""Test parsing register ranges."""
|
||||||
inst = asm('s_load_b128 s[4:7], s[0:1], null')
|
inst = asm('s_load_b128 s[4:7], s[0:1], null')
|
||||||
self.assertEqual(inst.to_bytes(), s_load_b128(s[4:8], s[0:2], NULL).to_bytes())
|
self.assertEqual(inst.to_bytes(), s_load_b128(s[4:7], s[0:1], NULL).to_bytes())
|
||||||
|
|
||||||
def test_asm_matches_llvm(self):
|
def test_asm_matches_llvm(self):
|
||||||
"""Test asm() output matches LLVM assembler."""
|
"""Test asm() output matches LLVM assembler."""
|
||||||
|
|||||||
Reference in New Issue
Block a user