Files
tinygrad/extra/assembly/amd/dsl.py
George Hotz 7ebda28692 assembly/amd: add CDNA support to asm (#13982)
* add CDNA support

* more cdna tests

* something

* fix more stuff

* more work

* simpler

* simplier

* cdna

* disasm

* less skip

* fixes

* simpler
2026-01-04 08:53:56 -08:00

572 lines
30 KiB
Python

# library for RDNA3 assembly DSL
# mypy: ignore-errors
from __future__ import annotations
import struct, math, re
from enum import IntEnum
from functools import cache
from typing import overload, Annotated, TypeVar, Generic
from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
# Common masks and bit conversion functions
MASK32 = 0xffffffff
MASK64 = 0xffffffffffffffff
MASK128 = (1 << 128) - 1
# Pre-compiled little-endian codecs, paired by width: a float format and the unsigned int of the same size.
_struct_f = struct.Struct("<f")
_struct_I = struct.Struct("<I")
_struct_e = struct.Struct("<e")
_struct_H = struct.Struct("<H")
_struct_d = struct.Struct("<d")
_struct_Q = struct.Struct("<Q")
def _f32(i):
  """Reinterpret the low 32 bits of ``i`` as a float32, flushing denormals to 0.0."""
  bits = i & MASK32
  exponent, mantissa = bits & 0x7f800000, bits & 0x007fffff
  # RDNA3 default mode: flush f32 denormals to zero (FTZ).
  # A denormal has a zero exponent (bits 23-30) and a non-zero mantissa (bits 0-22).
  if exponent == 0 and mantissa != 0:
    return 0.0
  (value,) = _struct_f.unpack(_struct_I.pack(bits))
  return value
def _i32(f):
if isinstance(f, int): f = float(f)
if math.isnan(f): return 0xffc00000 if math.copysign(1.0, f) < 0 else 0x7fc00000
if math.isinf(f): return 0x7f800000 if f > 0 else 0xff800000
try:
bits = _struct_I.unpack(_struct_f.pack(f))[0]
# RDNA3 default mode: flush f32 denormals to zero (FTZ)
if (bits & 0x7f800000) == 0 and (bits & 0x007fffff) != 0: return 0x80000000 if bits & 0x80000000 else 0
return bits
except (OverflowError, struct.error): return 0x7f800000 if f > 0 else 0xff800000
def _sext(v, b): return v - (1 << b) if v & (1 << (b - 1)) else v
def _f16(i):
  """Reinterpret the low 16 bits of ``i`` as a half-precision float."""
  packed = _struct_H.pack(i & 0xffff)
  (value,) = _struct_e.unpack(packed)
  return value
def _i16(f):
if math.isnan(f): return 0x7e00
if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00
try: return _struct_H.unpack(_struct_e.pack(f))[0]
except (OverflowError, struct.error): return 0x7c00 if f > 0 else 0xfc00
def _f64(i):
  """Reinterpret the low 64 bits of ``i`` as a double-precision float."""
  packed = _struct_Q.pack(i & MASK64)
  (value,) = _struct_d.unpack(packed)
  return value
def _i64(f):
if math.isnan(f): return 0x7ff8000000000000
if math.isinf(f): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000
try: return _struct_Q.unpack(_struct_d.pack(f))[0]
except (OverflowError, struct.error): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000
# Instruction spec - register counts and dtypes derived from instruction names
# Type suffix -> register count used by spec_regs(); 8- and 16-bit types still map to one register.
_REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16,
         'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2,
         'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1}
# Name patterns tried in this order by _suffix(); the first four capture two type tags as (group 1, group 2).
_CVT_RE = re.compile(r'CVT_([FIUB]\d+)_([FIUB]\d+)$')
_MAD_MUL_RE = re.compile(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$')
_PACK_RE = re.compile(r'PACK_([FIUB]\d+)_([FIUB]\d+)$')
_DST_SRC_RE = re.compile(r'_([FIUB]\d+)_([FIUB]\d+)$')
# Fallback: a single trailing type tag, which _suffix() uses for both destination and source.
_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512))$')
@cache
def _suffix(name: str) -> tuple[str | None, str | None]:
  """Extract the ``(dst, src)`` type tags from the end of an instruction name.

  Two-tag patterns are tried first, in the same fixed order; a lone trailing tag
  is returned for both positions. ``(None, None)`` means no pattern matched.
  """
  upper = name.upper()
  for pattern in (_CVT_RE, _MAD_MUL_RE, _PACK_RE, _DST_SRC_RE):
    hit = pattern.search(upper)
    if hit is not None:
      return hit.group(1), hit.group(2)
  single = _SINGLE_RE.search(upper)
  if single is None:
    return None, None
  tag = single.group(1)
  return tag, tag
# Per-instruction overrides consulted by spec_regs() before any name-based inference.
# Each tuple is (dst, src0, src1, src2) register counts; the order follows Inst.dst_regs()/src_regs(n).
_SPECIAL_REGS = {
  'V_LSHLREV_B64': (2, 1, 2, 1), 'V_LSHRREV_B64': (2, 1, 2, 1), 'V_ASHRREV_I64': (2, 1, 2, 1),
  'S_LSHL_B64': (2, 2, 1, 1), 'S_LSHR_B64': (2, 2, 1, 1), 'S_ASHR_I64': (2, 2, 1, 1),
  'S_BFE_U64': (2, 2, 1, 1), 'S_BFE_I64': (2, 2, 1, 1), 'S_BFM_B64': (2, 1, 1, 1),
  'S_BITSET0_B64': (2, 1, 1, 1), 'S_BITSET1_B64': (2, 1, 1, 1),
  'S_BITCMP0_B64': (1, 2, 1, 1), 'S_BITCMP1_B64': (1, 2, 1, 1),
  'V_LDEXP_F64': (2, 2, 1, 1), 'V_TRIG_PREOP_F64': (2, 2, 1, 1),
  'V_CMP_CLASS_F64': (1, 2, 1, 1), 'V_CMPX_CLASS_F64': (1, 2, 1, 1),
  'V_CMP_CLASS_F32': (1, 1, 1, 1), 'V_CMPX_CLASS_F32': (1, 1, 1, 1),
  'V_CMP_CLASS_F16': (1, 1, 1, 1), 'V_CMPX_CLASS_F16': (1, 1, 1, 1),
  'V_MAD_U64_U32': (2, 1, 1, 2), 'V_MAD_I64_I32': (2, 1, 1, 2),
  'V_QSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_U32_U8': (4, 2, 1, 4),
}
# Per-instruction overrides consulted by spec_dtype(): (dst, src0, src1, src2) type tags, None where no tag is given.
_SPECIAL_DTYPE = {
  'V_LSHLREV_B64': ('B64', 'U32', 'B64', None), 'V_LSHRREV_B64': ('B64', 'U32', 'B64', None), 'V_ASHRREV_I64': ('I64', 'U32', 'I64', None),
  'S_LSHL_B64': ('B64', 'B64', 'U32', None), 'S_LSHR_B64': ('B64', 'B64', 'U32', None), 'S_ASHR_I64': ('I64', 'I64', 'U32', None),
  'S_BFE_U64': ('U64', 'U64', 'U32', None), 'S_BFE_I64': ('I64', 'I64', 'U32', None),
  'S_BFM_B64': ('B64', 'U32', 'U32', None), 'S_BITSET0_B64': ('B64', 'U32', None, None), 'S_BITSET1_B64': ('B64', 'U32', None, None),
  'S_BITCMP0_B64': ('SCC', 'B64', 'U32', None), 'S_BITCMP1_B64': ('SCC', 'B64', 'U32', None),
  'V_LDEXP_F64': ('F64', 'F64', 'I32', None), 'V_TRIG_PREOP_F64': ('F64', 'F64', 'U32', None),
  'V_CMP_CLASS_F64': ('VCC', 'F64', 'U32', None), 'V_CMPX_CLASS_F64': ('EXEC', 'F64', 'U32', None),
  'V_CMP_CLASS_F32': ('VCC', 'F32', 'U32', None), 'V_CMPX_CLASS_F32': ('EXEC', 'F32', 'U32', None),
  'V_CMP_CLASS_F16': ('VCC', 'F16', 'U32', None), 'V_CMPX_CLASS_F16': ('EXEC', 'F16', 'U32', None),
  'V_MAD_U64_U32': ('U64', 'U32', 'U32', 'U64'), 'V_MAD_I64_I32': ('I64', 'I32', 'I32', 'I64'),
  'V_QSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'), 'V_MQSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'),
  'V_MQSAD_U32_U8': ('B128', 'B64', 'B64', 'B128'),
}
@cache
def spec_regs(name: str) -> tuple[int, int, int, int]:
  """Return ``(dst, src0, src1, src2)`` register counts for the instruction called ``name``.

  Explicit table entries win; otherwise counts come from the type tags in the name,
  with every source sharing the source tag's count and unknown tags counting as 1.
  """
  upper = name.upper()
  override = _SPECIAL_REGS.get(upper)
  if override is not None:
    return override
  # U8 SAD variants other than QSAD use single registers; 'QSAD' is also a substring of 'MQSAD'
  if 'SAD' in upper and 'U8' in upper and 'QSAD' not in upper:
    return 1, 1, 1, 1
  dst_suf, src_suf = _suffix(name)
  src_count = _REGS.get(src_suf, 1)
  return _REGS.get(dst_suf, 1), src_count, src_count, src_count
@cache
def spec_dtype(name: str) -> tuple[str | None, str | None, str | None, str | None]:
  """Return ``(dst, src0, src1, src2)`` type tags for the instruction called ``name``.

  Explicit table entries win. Compare instructions report 'VCC' (or 'EXEC' for the
  CMPX forms) as destination and no third source; everything else takes its tags from the name.
  """
  upper = name.upper()
  if upper in _SPECIAL_DTYPE:
    return _SPECIAL_DTYPE[upper]
  # U8/U16 SAD variants other than QSAD are all-U32; 'QSAD' is also a substring of 'MQSAD'
  if 'SAD' in upper and ('U8' in upper or 'U16' in upper) and 'QSAD' not in upper:
    return 'U32', 'U32', 'U32', 'U32'
  dst_suf, src_suf = _suffix(name)
  if '_CMPX_' in upper:
    return 'EXEC', src_suf, src_suf, None
  if '_CMP_' in upper:
    return 'VCC', src_suf, src_suf, None
  return dst_suf, src_suf, src_suf, src_suf
# A 16-bit / 64-bit type tag anywhere in an (upper-cased) name: the tag must be followed by '_' or the end of the name.
_F16_RE = re.compile(r'_[FIUB]16(?:_|$)')
_F64_RE = re.compile(r'_[FIUB]64(?:_|$)')
@cache
def spec_is_16bit(name: str) -> bool:
  """Tell whether ``name`` carries a 16-bit type tag and none of the excluding markers."""
  upper = name.upper()
  # packed / SAD / DOT2 style names are never reported as 16-bit
  if any(tag in upper for tag in ('SAD', 'PACK', '_PK_', 'SAT_PK', 'DOT2')):
    return False
  # any 32-bit tag in the name also rules it out
  if any(tag in upper for tag in ('_F32', '_I32', '_U32', '_B32')):
    return False
  return _F16_RE.search(upper) is not None
@cache
def spec_is_64bit(name: str) -> bool:
  """Tell whether the upper-cased ``name`` contains a 64-bit type tag."""
  found = _F64_RE.search(name.upper())
  return found is not None
# Substrings of an (upper-cased) instruction name that make spec_num_srcs() report three sources.
_3SRC = {'FMA', 'MAD', 'MIN3', 'MAX3', 'MED3', 'DIV_FIX', 'DIV_FMAS', 'DIV_SCALE', 'SAD', 'LERP', 'ALIGN', 'CUBE', 'BFE', 'BFI',
         'PERM_B32', 'PERMLANE', 'CNDMASK', 'XOR3', 'OR3', 'ADD3', 'LSHL_OR', 'AND_OR', 'LSHL_ADD', 'ADD_LSHL', 'XAD', 'MAXMIN',
         'MINMAX', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'}
# Checked before _3SRC, so a match here always yields two sources.
_2SRC = {'FMAC'}  # FMAC uses dst as implicit accumulator, so only 2 explicit sources
def spec_num_srcs(name: str) -> int:
  """Return the number of explicit sources (2 or 3) implied by the instruction name."""
  upper = name.upper()

  def mentions(tags) -> bool:
    return any(tag in upper for tag in tags)

  # the two-source list takes precedence over the three-source list
  if mentions(_2SRC) or not mentions(_3SRC):
    return 2
  return 3
def is_dtype_16(dt: str | None) -> bool:
  """True when ``dt`` is a type tag containing '16'; ``None`` is never 16-bit."""
  if dt is None:
    return False
  return '16' in dt
def is_dtype_64(dt: str | None) -> bool:
  """True when ``dt`` is a type tag containing '64'; ``None`` is never 64-bit."""
  if dt is None:
    return False
  return '64' in dt
# Bit field DSL
class BitField:
  """Descriptor for the inclusive bit range ``[hi:lo]`` of an instruction word.

  ``field == value`` does not compare: it returns the pair ``(field, value)``, which Inst uses
  to declare a fixed encoding. Because ``__eq__`` is overridden without ``__hash__``, instances
  are unhashable. The second argument of an ``Annotated[BitField, X]`` class annotation is
  cached as ``marker`` when the field is bound to a class.
  """
  def __init__(self, hi: int, lo: int, name: str | None = None):
    self.hi, self.lo, self.name, self._marker = hi, lo, name, None
    self._owner = None  # set by __set_name__; stays None for fields not bound to a class

  def __set_name__(self, owner, name):
    import typing
    self.name, self._owner = name, owner
    # Cache marker at class definition time
    hints = typing.get_type_hints(owner, include_extras=True)
    if name in hints:
      hint = hints[name]
      if typing.get_origin(hint) is Annotated:
        args = typing.get_args(hint)
        self._marker = args[1] if len(args) > 1 else None

  def __eq__(self, val: int) -> tuple[BitField, int]: return (self, val)  # type: ignore

  def mask(self) -> int:
    """Return an unshifted mask with one set bit per bit of the field."""
    return (1 << (self.hi - self.lo + 1)) - 1

  @property
  def marker(self) -> type | None: return self._marker

  @overload
  def __get__(self, obj: None, objtype: type) -> BitField: ...
  @overload
  def __get__(self, obj: object, objtype: type | None = None) -> int: ...
  def __get__(self, obj, objtype=None):
    """Return the descriptor on class access, else the unwrapped field value of ``obj``.

    When the marker is an IntEnum subclass the value is converted to an enum member;
    a value that is not a member of the chosen enum is returned as the plain int.
    """
    if obj is None: return self
    val = unwrap(obj._values.get(self.name, 0))
    marker = self.marker
    # Convert to IntEnum if marker is an IntEnum subclass
    if marker and isinstance(marker, type) and issubclass(marker, IntEnum):
      try:
        # VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
        if marker is VOP3Op:
          if val < 256: return VOPCOp(val)
          if val in Inst._VOP3SD_OPS: return VOP3SDOp(val)
        return marker(val)
      except ValueError: pass  # unknown opcode: fall back to the raw int, as Inst._precompute does
    return val
class _Bits:
  """Subscript helper: ``bits[hi:lo]`` builds a multi-bit BitField, ``bits[n]`` a single-bit one."""
  def __getitem__(self, key) -> BitField:
    if isinstance(key, slice):
      return BitField(key.start, key.stop)
    return BitField(key, key)
bits = _Bits()
# Source operand with modifiers - base class for anything that can be a src with neg/abs
class SrcMod:
  """A source operand value together with its negate and absolute-value flags."""
  __slots__ = ('val', 'neg', 'abs_')

  def __init__(self, val: int, neg: bool = False, abs_: bool = False):
    self.val = val
    self.neg = neg
    self.abs_ = abs_

  def __repr__(self):
    sign = '-' if self.neg else ''
    bar = '|' if self.abs_ else ''
    return f"{sign}{bar}{self.val}{bar}"

  def __neg__(self):
    # unary minus toggles the negate flag on a copy
    return SrcMod(self.val, not self.neg, self.abs_)

  def __abs__(self):
    # abs() sets the absolute-value flag on a copy
    return SrcMod(self.val, self.neg, True)
# Register types
class Reg(SrcMod):
  """A run of ``count`` consecutive registers starting at ``idx``.

  ``hi`` marks the register as selected through the ``.h`` accessor (``.l`` clears it).
  Negation, ``abs()``, ``.l`` and ``.h`` all return a new instance of the same subclass.
  """
  __slots__ = ('idx', 'count', 'hi')

  def __init__(self, idx: int, count: int = 1, hi: bool = False, neg: bool = False, abs_: bool = False):
    self.idx, self.count, self.hi = idx, count, hi
    super().__init__(idx, neg, abs_)

  def __repr__(self):
    # Mirror the DSL spelling accepted by the register factories: the factory object for
    # TTMP is named `ttmp`, and slice ends are inclusive (s[0:3] covers four registers).
    cls_name = self.__class__.__name__
    prefix = 'ttmp' if cls_name == 'TTMP' else cls_name.lower()[0]
    if self.count == 1: return f"{prefix}[{self.idx}]"
    return f"{prefix}[{self.idx}:{self.idx + self.count - 1}]"

  def __neg__(self): return self.__class__(self.idx, self.count, self.hi, not self.neg, self.abs_)
  def __abs__(self): return self.__class__(self.idx, self.count, self.hi, self.neg, True)

  @property
  def l(self): return self.__class__(self.idx, self.count, False, self.neg, self.abs_)
  @property
  def h(self): return self.__class__(self.idx, self.count, True, self.neg, self.abs_)
T = TypeVar('T', bound=Reg)
class _RegFactory(Generic[T]):
  """Subscriptable builder for one register class: ``f[n]`` is a single register, ``f[a:b]`` the inclusive range a..b."""
  def __init__(self, cls: type[T], name: str):
    self._cls = cls
    self._name = name

  @overload
  def __getitem__(self, key: int) -> Reg: ...
  @overload
  def __getitem__(self, key: slice) -> Reg: ...
  def __getitem__(self, key: int | slice) -> Reg:
    if not isinstance(key, slice):
      return self._cls(key)
    first = key.start
    # both slice ends are included in the register count
    return self._cls(first, key.stop - first + 1)

  def __repr__(self):
    return f"<{self._name} factory>"
# Concrete register classes. They add nothing to Reg; the subclass is what encoders and validators test with isinstance().
class SGPR(Reg): pass
class VGPR(Reg): pass
class TTMP(Reg): pass
# DSL entry points: s[n] / v[n] / ttmp[n] build one register, s[a:b] builds the inclusive range a..b.
s: _RegFactory[SGPR] = _RegFactory(SGPR, "SGPR")
v: _RegFactory[VGPR] = _RegFactory(VGPR, "VGPR")
ttmp: _RegFactory[TTMP] = _RegFactory(TTMP, "TTMP")
# Special registers as SrcMod objects (support -VCC_LO, abs(EXEC_LO), etc.)
# The numbers are source-operand codes; VCC and EXEC reuse the code of their _LO half,
# and OFF shares code 124 with NULL.
VCC_LO, VCC_HI, VCC = SrcMod(106), SrcMod(107), SrcMod(106)
EXEC_LO, EXEC_HI, EXEC = SrcMod(126), SrcMod(127), SrcMod(126)
SCC, M0, NULL, OFF = SrcMod(253), SrcMod(125), SrcMod(124), SrcMod(124)
# Field type markers (runtime classes for validation)
# Each marker is an empty class; BitField stores it as `marker` and Inst.__init__ compares it by identity.
class _SSrc: pass
class _Src: pass
class _Imm: pass
class _SImm: pass
class _VDSTYEnc: pass  # VOPD vdsty: encoded = actual >> 1, actual = (encoded << 1) | ((vdstx & 1) ^ 1)
class _SGPRField: pass
class _VGPRField: pass
# Type aliases for annotations - tells mypy it's a BitField while preserving marker info
SSrc = Annotated[BitField, _SSrc]
Src = Annotated[BitField, _Src]
Imm = Annotated[BitField, _Imm]
SImm = Annotated[BitField, _SImm]
VDSTYEnc = Annotated[BitField, _VDSTYEnc]
SGPRField = Annotated[BitField, _SGPRField]
VGPRField = Annotated[BitField, _VGPRField]
class RawImm:
  """Wrapper marking ``val`` as an already-encoded field value that must be used verbatim.

  Two RawImm objects are equal when their values are equal; a RawImm never equals a bare int.
  """
  def __init__(self, val: int): self.val = val
  def __repr__(self): return f"RawImm({self.val})"
  def __eq__(self, other): return isinstance(other, RawImm) and self.val == other.val
  # Defining __eq__ alone would set __hash__ to None; keep instances hashable and consistent with __eq__.
  def __hash__(self): return hash((RawImm, self.val))
def unwrap(val) -> int:
  """Reduce a field value to its underlying number; values of any other kind are returned unchanged."""
  # RawImm and non-register SrcMod objects (special registers like VCC_LO, NULL) both carry the number in .val
  if isinstance(val, RawImm) or (isinstance(val, SrcMod) and not isinstance(val, Reg)):
    return val.val
  # IntEnum members expose .value, registers expose .idx (checked in that order)
  for attr in ('value', 'idx'):
    if hasattr(val, attr):
      return getattr(val, attr)
  return val
# Encoding/decoding constants
# Float values that have a dedicated source-operand code, and the reverse map to their printed form.
FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0: 246, -4.0: 247}
FLOAT_DEC = {v: str(k) for k, v in FLOAT_ENC.items()}
# Source-operand codes printed by name in decode_src(); SPECIAL_PAIRS names the 64-bit pairs starting at those codes.
SPECIAL_GPRS = {106: "vcc_lo", 107: "vcc_hi", 124: "null", 125: "m0", 126: "exec_lo", 127: "exec_hi", 253: "scc"}
SPECIAL_PAIRS = {106: "vcc", 126: "exec"}
# Field names Inst.__init__ routes through _encode_src (SRC_FIELDS) or _encode_raw (RAW_FIELDS).
SRC_FIELDS = {'src0', 'src1', 'src2', 'ssrc0', 'ssrc1', 'soffset', 'srcx0', 'srcy0'}
RAW_FIELDS = {'vdata', 'vdst', 'vaddr', 'addr', 'data', 'data0', 'data1', 'sdst', 'sdata', 'vsrc1'}
def _encode_reg(val: Reg) -> int:
  """Return the register's index, offset by 108 when it is a TTMP register."""
  base = 108 if isinstance(val, TTMP) else 0
  return base + val.idx
def _is_inline_const(v: int) -> bool: return 0 <= v <= 127 or 128 <= v <= 208 or 240 <= v <= 255
def encode_src(val) -> int:
  """Map a DSL source operand to its source-field code.

  VGPRs are offset by 256, other registers use their register code, special SrcMod
  registers keep their code when it is an inline one, IntEnum members use ``.value``,
  floats with a dedicated code use it, and integers in -16..64 use the inline ranges.
  Anything else returns 255, the code for a trailing literal.
  """
  if isinstance(val, VGPR): return 256 + _encode_reg(val)
  if isinstance(val, Reg): return _encode_reg(val)
  if isinstance(val, SrcMod) and not isinstance(val, Reg): return val.val if _is_inline_const(val.val) else 255
  if hasattr(val, 'value'): return val.value  # IntEnum
  if isinstance(val, float):
    # Code 128 is integer 0, whose bits are +0.0. -0.0 compares equal to 0.0 but has the
    # sign bit set, so it must be emitted as a literal instead of silently becoming +0.0.
    if val == 0.0 and math.copysign(1.0, val) > 0: return 128
    return FLOAT_ENC.get(val, 255)
  if isinstance(val, int): return 128 + val if 0 <= val <= 64 else 192 - val if -16 <= val <= -1 else 255
  return 255
def decode_src(val: int) -> str:
  """Render a source-field code as text: register name, inline constant, ``lit`` for 255, or ``?N`` when unknown."""
  if val <= 105:
    return f"s{val}"
  # named special registers first, then the float inline constants
  for table in (SPECIAL_GPRS, FLOAT_DEC):
    if val in table:
      return table[val]
  if 108 <= val <= 123:
    return f"ttmp{val - 108}"
  if 128 <= val <= 192:
    return str(val - 128)
  if 193 <= val <= 208:
    return str(192 - val)
  if 256 <= val <= 511:
    return f"v{val - 256}"
  if val == 255:
    return "lit"
  return f"?{val}"
# Instruction base class
class Inst:
  """Base class for one encoded instruction: named bit fields plus an optional 32-bit literal.

  Subclasses declare their layout as class attributes that are either a ``BitField`` or a
  ``(BitField, int)`` pair; ``__init_subclass__`` collects them into ``_fields``. Field values
  live in ``_values``. A trailing literal is kept in ``_literal``; when the source that
  consumes it spans two registers it is stored shifted left by 32 and shifted back on output.
  """
  _fields: dict[str, BitField]
  _encoding: tuple[BitField, int] | None = None
  _defaults: dict[str, int] = {}  # shared class-level dict; __init__ copies it before use
  _values: dict[str, int | RawImm]
  _words: int  # size in 32-bit words, set by decode_program
  _literal: int | None

  def __init_subclass__(cls, **kwargs):
    """Collect the subclass's own BitField attributes into ``_fields`` and record a fixed ``encoding`` pair."""
    super().__init_subclass__(**kwargs)
    cls._fields = {n: v[0] if isinstance(v, tuple) else v for n, v in cls.__dict__.items()
                   if isinstance(v, BitField) or (isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], BitField))}
    if 'encoding' in cls._fields and isinstance(cls.__dict__.get('encoding'), tuple): cls._encoding = cls.__dict__['encoding']

  def _or_field(self, name: str, bit: int):
    """OR ``bit`` into the stored value of field ``name`` (a RawImm wrapper is dropped in the process)."""
    cur = self._values.get(name, 0)
    self._values[name] = (cur.val if isinstance(cur, RawImm) else cur) | bit

  def _encode_src(self, name: str, val):
    """Encode a source field, handling modifiers and literals."""
    encoded = encode_src(val)
    has_opsel = 'opsel' in self._fields
    if isinstance(val, Reg) and val.hi and not has_opsel: encoded |= 0x80  # hi bit in src for VOP1/2/C
    self._values[name] = RawImm(encoded)
    # Handle neg/abs/opsel modifiers
    if isinstance(val, SrcMod):
      mod_bit = {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0)
      if val.neg and 'neg' in self._fields: self._or_field('neg', mod_bit)
      if val.abs_ and 'abs' in self._fields: self._or_field('abs', mod_bit)
      if isinstance(val, Reg) and val.hi and has_opsel:
        self._or_field('opsel', {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0))
    # Track literal value if needed (only the first literal source is recorded)
    if encoded == 255 and self._literal is None:
      # Check if THIS source uses 64-bit encoding (not just src0)
      src_idx = {'src0': 0, 'src1': 1, 'src2': 2, 'ssrc0': 0, 'ssrc1': 1}.get(name, 0)
      src_regs = self.src_regs(src_idx)
      is_64 = src_regs == 2
      if isinstance(val, SrcMod) and not isinstance(val, Reg): lit32 = val.val & MASK32
      elif isinstance(val, int) and not isinstance(val, IntEnum): lit32 = val & MASK32
      elif isinstance(val, float): lit32 = (_i64(val) >> 32) if is_64 else _i32(val)  # f64: high 32 bits of f64 repr
      else: return
      self._literal = (lit32 << 32) if is_64 else lit32

  def _encode_raw(self, name: str, val):
    """Encode a raw register field (vdst, vdata, etc.)."""
    if isinstance(val, Reg):
      encoded = _encode_reg(val)
      # without an opsel field the hi half is flagged in bit 7 of the register code
      if val.hi and 'opsel' not in self._fields: encoded |= 0x80
      self._values[name] = encoded
      if name == 'vdst' and val.hi and 'opsel' in self._fields: self._or_field('opsel', 8)
    elif hasattr(val, 'value'): self._values[name] = val.value

  def _validate(self, orig_args: dict):
    """Format-specific validation. Override in subclass or check by class name.

    Raises:
      ValueError: if an SMEM ``sdata`` or an SOP1 ``sdst``/``ssrc0`` register group has the wrong size.
    """
    cls_name, op = self.__class__.__name__, orig_args.get('op')
    if hasattr(op, 'value'): op = op.value
    # SMEM: register count must match opcode
    if cls_name == 'SMEM' and op is not None:
      expected = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(op)
      sdata = orig_args.get('sdata')
      if expected and isinstance(sdata, Reg) and sdata.count != expected:
        raise ValueError(f"SMEM op {op} expects {expected} registers, got {sdata.count}")
    # SOP1: destination and source widths can differ (e.g. *_I32_B64, *_B64_B32), so take each
    # count from the instruction spec instead of assuming one width from the name's last tag.
    if cls_name == 'SOP1' and hasattr(orig_args.get('op'), 'name'):
      op_name = orig_args['op'].name
      dst_count, src_count = spec_regs(op_name)[:2]
      for fld, expected in (('sdst', dst_count), ('ssrc0', src_count)):
        arg = orig_args.get(fld)
        if isinstance(arg, Reg) and arg.count != expected:
          raise ValueError(f"SOP1 {op_name} expects {expected} register(s) for {fld}, got {arg.count}")

  def __init__(self, *args, literal: int | None = None, **kwargs):
    """Build an instruction from field values given positionally (in ``_fields`` order) or by keyword.

    Args:
      literal: raw 32-bit literal for a source already given as code 255.

    Raises:
      ValueError: from ``_validate`` on a register-count mismatch.
      TypeError: when a register of the wrong kind is passed to a typed field.
    """
    self._values, self._literal = dict(self._defaults), None
    field_names = [n for n in self._fields if n != 'encoding']
    # Map Python-friendly names to actual field names (abs_ -> abs for Python reserved word)
    if 'abs_' in kwargs: kwargs['abs'] = kwargs.pop('abs_')
    orig_args = dict(zip(field_names, args)) | kwargs
    self._values.update(orig_args)
    self._precompute()
    self._validate(orig_args)
    # Pre-shift literal for 64-bit sources (literal param is always raw 32-bit value from user)
    if literal is not None:
      # Find which source uses the literal (255) and check its register count
      for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]:
        v = orig_args.get(n)
        if (isinstance(v, RawImm) and v.val == 255) or (isinstance(v, int) and v == 255):
          self._literal = (literal << 32) if self.src_regs(idx) == 2 else literal
          break
      else:
        self._literal = literal  # fallback if no literal source found
    cls_name = self.__class__.__name__
    # Format-specific setup
    if cls_name == 'FLAT' and 'sve' in self._fields:
      seg = self._values.get('seg', 0)
      if (seg.val if isinstance(seg, RawImm) else seg) == 1 and isinstance(orig_args.get('addr'), VGPR): self._values['sve'] = 1
    if cls_name == 'VOP3P':
      op = orig_args.get('op')
      if hasattr(op, 'value'): op = op.value
      if op in (32, 33, 34) and 'opsel_hi' not in orig_args: self._values['opsel_hi'] = self._values['opsel_hi2'] = 0
    # Encode all fields (iterate over a snapshot: the encoders add and replace entries)
    for name, val in list(self._values.items()):
      if name == 'encoding': continue
      if isinstance(val, RawImm):
        if name in RAW_FIELDS: self._values[name] = val.val
        continue
      field = self._fields.get(name)
      marker = field.marker if field else None
      # Type validation
      if marker is _SGPRField and isinstance(val, VGPR): raise TypeError(f"field '{name}' requires SGPR, got VGPR")
      if marker is _VGPRField and not isinstance(val, VGPR): raise TypeError(f"field '{name}' requires VGPR, got {type(val).__name__}")
      if marker is _SSrc and isinstance(val, VGPR): raise TypeError(f"field '{name}' requires scalar source, got VGPR")
      # Encode by field type
      if name in SRC_FIELDS: self._encode_src(name, val)
      elif name in RAW_FIELDS: self._encode_raw(name, val)
      elif name == 'sbase': self._values[name] = (val.idx if isinstance(val, Reg) else val.val if isinstance(val, SrcMod) else val * 2) // 2
      elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = val.idx // 4
      elif marker is _VDSTYEnc and isinstance(val, VGPR): self._values[name] = val.idx >> 1
    self._precompute_fields()

  def _encode_field(self, name: str, val) -> int:
    """Return the integer stored in field ``name`` for ``val`` (not yet masked or shifted)."""
    if isinstance(val, RawImm): return val.val
    if isinstance(val, SrcMod) and not isinstance(val, Reg): return val.val  # Special regs like VCC_LO
    if name in {'srsrc', 'ssamp'}: return val.idx // 4 if isinstance(val, Reg) else val
    if name == 'sbase': return val.idx // 2 if isinstance(val, Reg) else val.val // 2 if isinstance(val, SrcMod) else val
    if name in RAW_FIELDS: return _encode_reg(val) if isinstance(val, Reg) else val
    if isinstance(val, Reg) or name in SRC_FIELDS: return encode_src(val)
    return val.value if hasattr(val, 'value') else val

  def to_int(self) -> int:
    """Assemble the instruction word (without any literal) from the fixed encoding and all stored fields."""
    word = (self._encoding[1] & self._encoding[0].mask()) << self._encoding[0].lo if self._encoding else 0
    for n, bf in self._fields.items():
      if n != 'encoding' and n in self._values: word |= (self._encode_field(n, self._values[n]) & bf.mask()) << bf.lo
    return word

  def _get_literal(self) -> int | None:
    """Return the first plain-int source value outside the inline ranges (-16..64), or None."""
    for n in SRC_FIELDS:
      if n in self._values and not isinstance(v := self._values[n], RawImm) and isinstance(v, int) and not isinstance(v, IntEnum) and not (0 <= v <= 64 or -16 <= v <= -1): return v
    return None

  def _is_64bit_op(self) -> bool:
    """Check if this instruction uses 64-bit operands (and thus 64-bit literals)."""
    op = self._values.get('op')
    if op is None: return False
    op_name = op.name if hasattr(op, 'name') else None
    # Look up op name from int if needed (happens in from_bytes path)
    if op_name is None and self.__class__.__name__ == 'VOP3':
      try: op_name = VOP3Op(op).name
      except ValueError: pass
    if op_name is None and self.__class__.__name__ == 'VOPC':
      try: op_name = VOPCOp(op).name
      except ValueError: pass
    if op_name is None: return False
    # V_LDEXP_F64 has 32-bit integer src1, so literal is 32-bit
    return op_name != 'V_LDEXP_F64' and op_name.endswith(('_F64', '_B64', '_I64', '_U64'))

  def to_bytes(self) -> bytes:
    """Serialize to little-endian bytes: the instruction word followed by a 4-byte literal when one is present."""
    result = self.to_int().to_bytes(self._size(), 'little')
    lit = self._get_literal() or getattr(self, '_literal', None)
    if lit is None: return result
    # For 64-bit sources, literal is stored in high 32 bits internally, but encoded as 4 bytes
    # Find which source uses the literal (255) and check its register count
    lit_src_is_64 = False
    for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]:
      if n not in self._values: continue
      v = self._values[n]
      if (isinstance(v, RawImm) and v.val == 255) or (isinstance(v, int) and v == 255):
        lit_src_is_64 = self.is_src_64(idx)
        break
    lit32 = (lit >> 32) if lit_src_is_64 else lit
    return result + (lit32 & MASK32).to_bytes(4, 'little')

  @classmethod
  def _size(cls) -> int:
    """Size of the instruction word in bytes: 4 for Inst32 subclasses, 12 for Inst96, otherwise 8."""
    return 4 if issubclass(cls, Inst32) else 12 if issubclass(cls, Inst96) else 8

  def size(self) -> int:
    """Encoded size in bytes including a recorded literal."""
    # Literal is always 4 bytes in the binary (for 64-bit ops, it's in high 32 bits)
    return self._size() + (4 if self._literal is not None else 0)

  @classmethod
  def from_int(cls, word: int):
    """Decode an instruction word: source fields become RawImm, every other field a plain int. No literal is read."""
    inst = object.__new__(cls)
    inst._values = {n: RawImm(v) if n in SRC_FIELDS else v for n, bf in cls._fields.items() if n != 'encoding' for v in [(word >> bf.lo) & bf.mask()]}
    inst._literal = None
    inst._precompute()
    inst._precompute_fields()
    return inst

  @classmethod
  def from_bytes(cls, data: bytes):
    """Decode an instruction from ``data``, reading the trailing 4-byte literal when the instruction has one.

    If ``data`` is too short to contain the literal, ``_literal`` is left as None.
    """
    import typing
    inst = cls.from_int(int.from_bytes(data[:cls._size()], 'little'))
    op_val = inst._values.get('op', 0)
    # Check for instructions that always have a literal constant (FMAMK/FMAAK/MADMK/MADAK, SETREG_IMM32)
    op_name = ''
    if cls.__name__ in ('VOP2', 'SOP2', 'SOPK') and 'op' in (hints := typing.get_type_hints(cls, include_extras=True)):
      if typing.get_origin(hints['op']) is typing.Annotated:
        try: op_name = typing.get_args(hints['op'])[1](op_val).name
        except (ValueError, TypeError): pass
    has_literal = any(x in op_name for x in ('FMAMK', 'FMAAK', 'MADMK', 'MADAK', 'SETREG_IMM32'))
    # VOPD fmaak/fmamk always have a literal (opx/opy value 1 or 2)
    opx, opy = inst._values.get('opx', 0), inst._values.get('opy', 0)
    has_literal = has_literal or (cls.__name__ == 'VOPD' and (opx in (1, 2) or opy in (1, 2)))
    for n in SRC_FIELDS:
      if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255: has_literal = True
    if has_literal:
      # For 64-bit ops, the literal is 32 bits placed in the HIGH 32 bits of the 64-bit value
      # (low 32 bits are zero). This is how AMD hardware interprets 32-bit literals for 64-bit ops.
      # Check which source uses the literal and whether THAT source is 64-bit
      if len(data) >= cls._size() + 4:
        lit32 = int.from_bytes(data[cls._size():cls._size()+4], 'little')
        # Find which source has literal (255) and check its register count
        lit_src_is_64 = False
        for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]:
          if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255:
            lit_src_is_64 = inst.src_regs(idx) == 2
            break
        inst._literal = (lit32 << 32) if lit_src_is_64 else lit32
    return inst

  def __repr__(self):
    # Use _fields order and exclude fields that are 0/default (for consistent repr after roundtrip)
    def is_zero(v): return (isinstance(v, int) and v == 0) or (isinstance(v, VGPR) and v.idx == 0 and v.count == 1)
    items = [(k, self._values[k]) for k in self._fields if k in self._values and k != 'encoding'
             and not (is_zero(self._values[k]) and k not in {'op'})]
    lit = f", literal={hex(self._literal)}" if self._literal is not None else ""
    return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in items)}{lit})"

  def __getattr__(self, name: str):
    # Only reached when normal lookup fails: unknown public names read as the unwrapped field value (default 0).
    if name.startswith('_'): raise AttributeError(name)
    return unwrap(self._values.get(name, 0))

  def lit(self, v: int, neg: bool = False) -> str:
    """Format source code ``v`` as text, substituting the recorded literal for code 255; ``neg`` prefixes '-'."""
    if v == 255 and self._literal is not None:
      # For 64-bit sources, literal is stored shifted - extract the 32-bit value
      lit32 = (self._literal >> 32) if self._literal > 0xffffffff else self._literal
      s = f"0x{lit32:x}"
    else:
      s = decode_src(v)
    return f"-{s}" if neg else s

  def __eq__(self, other):
    if not isinstance(other, Inst): return NotImplemented
    return self.__class__ == other.__class__ and self._values == other._values and self._literal == other._literal

  def __hash__(self):
    # Must agree with __eq__: an IntEnum member and the equal plain int compare equal inside
    # _values, but their repr() (and Enum's own hash) differ, so integers are normalised first.
    def key(v): return int(v) if isinstance(v, int) else v if isinstance(v, float) else repr(v)
    return hash((self.__class__.__name__, tuple(sorted((k, key(v)) for k, v in self._values.items())), self._literal))

  def disasm(self) -> str:
    """Return the disassembly text produced by ``extra.assembly.amd.asm.disasm`` (imported lazily)."""
    from extra.assembly.amd.asm import disasm
    return disasm(self)

  # Fallback class-name -> opcode enum map, used by _precompute when the 'op' field has no IntEnum marker
  _enum_map = {'VOP1': VOP1Op, 'VOP2': VOP2Op, 'VOP3': VOP3Op, 'VOP3SD': VOP3SDOp, 'VOP3P': VOP3POp, 'VOPC': VOPCOp,
               'SOP1': SOP1Op, 'SOP2': SOP2Op, 'SOPC': SOPCOp, 'SOPK': SOPKOp, 'SOPP': SOPPOp,
               'SMEM': SMEMOp, 'DS': DSOp, 'FLAT': FLATOp, 'MUBUF': MUBUFOp, 'MTBUF': MTBUFOp, 'MIMG': MIMGOp,
               'VOPD': VOPDOp, 'VINTERP': VINTERPOp}
  # VOP3 opcode numbers that are decoded as VOP3SDOp members
  _VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}

  def _precompute(self):
    """Precompute op, op_name, _spec_regs, _spec_dtype for fast access."""
    val = self._values.get('op')
    if val is None: self.op = None
    elif hasattr(val, 'name'): self.op = val
    else:
      cls_name = self.__class__.__name__
      # VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
      if cls_name == 'VOP3':
        try:
          if val < 256: self.op = VOPCOp(val)
          elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val)
          else: self.op = VOP3Op(val)
        except ValueError: self.op = val
      # Prefer BitField marker (class-specific enum) over _enum_map (generic RDNA3 enums)
      elif 'op' in self._fields and (marker := self._fields['op'].marker) and issubclass(marker, IntEnum):
        try: self.op = marker(val)
        except ValueError: self.op = val
      elif cls_name in self._enum_map:
        try: self.op = self._enum_map[cls_name](val)
        except ValueError: self.op = val
      else: self.op = val
    # an op without a name (None or an unknown int) yields an empty op_name
    self.op_name = self.op.name if hasattr(self.op, 'name') else ''
    self._spec_regs = spec_regs(self.op_name)
    self._spec_dtype = spec_dtype(self.op_name)

  def _precompute_fields(self):
    """Unwrap all field values as direct attributes for fast access."""
    for name, val in self._values.items():
      if name != 'op': setattr(self, name, unwrap(val))

  # Accessors over the precomputed spec tuples, which are laid out as (dst, src0, src1, src2).
  def dst_regs(self) -> int: return self._spec_regs[0]
  def src_regs(self, n: int) -> int: return self._spec_regs[n + 1]
  def num_srcs(self) -> int: return spec_num_srcs(self.op_name)
  def dst_dtype(self) -> str | None: return self._spec_dtype[0]
  def src_dtype(self, n: int) -> str | None: return self._spec_dtype[n + 1]
  def is_src_16(self, n: int) -> bool: return self._spec_regs[n + 1] == 1 and is_dtype_16(self._spec_dtype[n + 1])
  def is_src_64(self, n: int) -> bool: return self._spec_regs[n + 1] == 2
  def is_16bit(self) -> bool: return spec_is_16bit(self.op_name)
  def is_64bit(self) -> bool: return spec_is_64bit(self.op_name)
  def is_dst_16(self) -> bool: return self._spec_regs[0] == 1 and is_dtype_16(self._spec_dtype[0])
# Size tags read by Inst._size(): Inst32 subclasses encode to 4 bytes, Inst96 to 12, everything else to 8.
class Inst32(Inst): pass
class Inst64(Inst): pass
class Inst96(Inst): pass