assembly/amd: refactor to use op_bits/op_regs (#14156)

* assembly/amd: refactor to use op_bits/op_regs

* remove that skip

* remove another hack

* remove another hack

* precompute mask

* more reg, less hasattr
This commit is contained in:
George Hotz
2026-01-15 11:20:21 +09:00
committed by GitHub
parent add7da268f
commit fd60626ea1
6 changed files with 165 additions and 180 deletions

View File

@@ -16,7 +16,7 @@ def _matches_encoding(word: int, cls: type[Inst]) -> bool:
enc = next(((n, f) for n, f in cls._fields if isinstance(f, FixedBitField) and n == 'encoding'), None)
if enc is None: return False
bf = enc[1]
return ((word >> bf.lo) & bf.mask()) == bf.default
return ((word >> bf.lo) & bf.mask) == bf.default
# Order matters: more specific encodings first, VOP2 last (it's a catch-all for bit31=0)
_RDNA_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP]

View File

@@ -186,10 +186,10 @@ def _disasm_vop1(inst: VOP1) -> str:
src = _vreg(inst.src0) if _unwrap(inst.src0) >= 256 else decode_src(_unwrap(inst.src0), cdna)
vdst_raw = _unwrap(inst.vdst)
return f"{name} {_fmt_sdst(vdst_raw - 256 if vdst_raw >= 256 else vdst_raw, 1, cdna)}, {src}"
# Use operand info for register sizes and 16-bit detection
dregs, sregs = inst.dst_regs(), inst.src_regs(0)
is16_dst = not cdna and inst.is_dst_16()
is16_src = not cdna and inst.is_src_16(0)
# Use canonical_op_bits for register sizes and 16-bit detection
bits = inst.canonical_op_bits
dregs, sregs = max(1, bits['d'] // 32), max(1, bits['s0'] // 32)
is16_dst, is16_src = not cdna and bits['d'] == 16, not cdna and bits['s0'] == 16
# v_cvt_pk_f32_fp8/bf8: pcode has None dst type but outputs 2 VGPRs
if 'cvt_pk_f32_fp8' in name or 'cvt_pk_f32_bf8' in name: dregs, is16_src = 2, True
# Format dst
@@ -216,8 +216,9 @@ def _disasm_vop2(inst: VOP2) -> str:
if cdna: name = _CDNA_DISASM_ALIASES.get(name, name) # apply CDNA aliases
suf = "" if cdna or name.endswith('_e32') or (not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16_E32) else "_e32"
lit = getattr(inst, '_literal', None)
# Use operand info for 16-bit detection
is16 = not cdna and inst.is_dst_16()
# Use canonical_op_bits for 16-bit detection
bits = inst.canonical_op_bits
is16 = not cdna and bits['d'] == 16
# fmaak/madak: dst = src0 * vsrc1 + K, fmamk/madmk: dst = src0 * K + vsrc1
if 'fmaak' in name or 'madak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32_E32, VOP2Op.V_FMAAK_F16_E32)):
if lit is None: return f"op_{inst.op.value if hasattr(inst.op, 'value') else inst.op}"
@@ -234,8 +235,9 @@ def _disasm_vop2(inst: VOP2) -> str:
if cdna and name in _VOP2_CARRY_INOUT: return f"{name}{suf} {_vreg(inst.vdst)}, {vcc}, {_lit(inst, inst.src0)}, {_vreg(inst.vsrc1)}, {vcc}"
# RDNA carry-in/out ops: v_add_co_ci_u32, etc.
if not cdna and name in _VOP2_CARRY_INOUT_RDNA: return f"{name}{suf} {_vreg(inst.vdst)}, {vcc}, {_lit(inst, inst.src0)}, {_vreg(inst.vsrc1)}, {vcc}"
# Use pcode types for register sizes - pcode is complete for all VOP2 ops
dn, sn0, sn1 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1)
# Use canonical_op_regs for register sizes
regs = inst.canonical_op_regs
dn, sn0, sn1 = regs.get('d', 1), regs.get('s0', 1), regs.get('s1', 1)
if dn > 1 or sn0 > 1 or sn1 > 1:
dst = _vreg(inst.vdst, dn)
src0 = _lit(inst, inst.src0) if _unwrap(inst.src0) == 255 else _fmt_src(inst.src0, sn0, cdna)
@@ -245,9 +247,10 @@ def _disasm_vop2(inst: VOP2) -> str:
def _disasm_vopc(inst: VOPC) -> str:
name, cdna = inst.op_name.lower(), _is_cdna(inst)
# Use operand info for register sizes and 16-bit detection
r0, r1 = inst.src_regs(0), inst.src_regs(1)
is16 = inst.is_src_16(0)
# Use canonical_op_bits for register sizes and 16-bit detection
bits = inst.canonical_op_bits
r0, r1 = max(1, bits['s0'] // 32), max(1, bits['s1'] // 32)
is16 = bits['s0'] == 16
if cdna:
s0 = _lit(inst, inst.src0) if _unwrap(inst.src0) == 255 else _fmt_src(inst.src0, r0, cdna)
s1 = _vreg(inst.vsrc1, r1) if r1 > 1 else _vreg(inst.vsrc1)
@@ -321,8 +324,8 @@ def _disasm_smem(inst: SMEM) -> str:
soff_s = decode_src(inst.soffset, cdna) if inst.soffset != 124 else "null"
if 'pc_rel' in name: return f"{name} {off_s}, {soff_s}, {_unwrap(inst.sdata)}"
return f"{name} {sbase_str}, {off_s}, {soff_s}, {_unwrap(inst.sdata)}"
# Use operand info for register count
dst_n = inst.dst_regs()
# Use canonical_op_regs for register count
dst_n = inst.canonical_op_regs.get('d', 1)
th, scope = getattr(inst, 'th', 0), getattr(inst, 'scope', 0)
if is_rdna4: # RDNA4 uses th/scope instead of glc/dlc
th_names = ['TH_LOAD_RT', 'TH_LOAD_NT', 'TH_LOAD_HT', 'TH_LOAD_LU']
@@ -343,8 +346,9 @@ def _disasm_flat(inst: FLAT) -> str:
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
# Use operand info: data_regs for stores/atomics, dst_regs for loads
w = inst.data_regs() if 'store' in name or 'atomic' in name else inst.dst_regs()
# Use canonical_op_regs: data for stores/atomics, d for loads
regs = inst.canonical_op_regs
w = regs.get('data', regs.get('d', 1)) if 'store' in name or 'atomic' in name else regs.get('d', 1)
off_s = f" offset:{off_val}" if off_val else ""
if cdna: mods = f"{off_s}{' glc' if inst.sc0 else ''}{' slc' if inst.nt else ''}"
else: mods = f"{off_s}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
@@ -372,8 +376,9 @@ def _disasm_ds(inst: DS) -> str:
gds = " gds" if getattr(inst, 'gds', 0) else ""
off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else ""
off2 = (" offset0:" + str(inst.offset0) if inst.offset0 else "") + (" offset1:" + str(inst.offset1) if inst.offset1 else "")
# Use operand info: data_regs for stores/writes/atomics, dst_regs for loads
w = inst.data_regs() if 'store' in name or 'write' in name or ('load' not in name and 'read' not in name) else inst.dst_regs()
# Use canonical_op_regs: data for stores/writes/atomics, d for loads
regs = inst.canonical_op_regs
w = regs.get('data', regs.get('d', 1)) if 'store' in name or 'write' in name or ('load' not in name and 'read' not in name) else regs.get('d', 1)
d0, d1, dst, addr = reg_fn(inst.data0, w), reg_fn(inst.data1, w), reg_fn(inst.vdst, w), _vreg(inst.addr)
if name == 'ds_nop': return name
@@ -388,11 +393,11 @@ def _disasm_ds(inst: DS) -> str:
if name in ('ds_consume', 'ds_append'): return f"{name} {reg_fn(inst.vdst)}{off}{gds}"
if 'gs_reg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {reg_fn(inst.data0)}{off}{gds}"
if '2addr' in name:
if 'load' in name: return f"{name} {reg_fn(inst.vdst, inst.dst_regs())}, {addr}{off2}{gds}"
if 'load' in name: return f"{name} {reg_fn(inst.vdst, regs.get('d', 1))}, {addr}{off2}{gds}"
if 'store' in name and 'xchg' not in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
return f"{name} {reg_fn(inst.vdst, inst.dst_regs())}, {addr}, {d0}, {d1}{off2}{gds}"
return f"{name} {reg_fn(inst.vdst, regs.get('d', 1))}, {addr}, {d0}, {d1}{off2}{gds}"
if 'write2' in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
if 'read2' in name: return f"{name} {reg_fn(inst.vdst, inst.dst_regs())}, {addr}{off2}{gds}"
if 'read2' in name: return f"{name} {reg_fn(inst.vdst, regs.get('d', 1))}, {addr}{off2}{gds}"
if 'load' in name: return f"{name} {reg_fn(inst.vdst)}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
if 'store' in name and not _has(name, 'cmp', 'xchg'):
return f"{name} {reg_fn(inst.data0)}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}"
@@ -406,20 +411,21 @@ def _disasm_ds(inst: DS) -> str:
def _disasm_vop3(inst: VOP3) -> str:
op, name = inst.op, inst.op_name.lower()
n_up = name.upper()
bits = inst.canonical_op_bits
# RDNA4 v_s_* scalar VOP3 instructions - vdst is SGPR (VGPRField adds 256)
if name.startswith('v_s_'):
src = _lit(inst, inst.src0) if _unwrap(inst.src0) == 255 else ("src_scc" if _unwrap(inst.src0) == 253 else _fmt_src(inst.src0, inst.src_regs(0)))
src = _lit(inst, inst.src0) if _unwrap(inst.src0) == 255 else ("src_scc" if _unwrap(inst.src0) == 253 else _fmt_src(inst.src0, max(1, bits['s0'] // 32)))
if inst.neg & 1: src = f"-{src}"
if inst.abs & 1: src = f"|{src}|"
clamp = getattr(inst, 'cm', None) or getattr(inst, 'clmp', 0)
vdst_raw = _unwrap(inst.vdst)
return f"{name} s{vdst_raw - 256 if vdst_raw >= 256 else vdst_raw}, {src}" + (" clamp" if clamp else "") + _omod(inst.omod)
# Use operand info for register sizes and 16-bit detection
r0, r1, r2 = inst.src_regs(0), inst.src_regs(1), inst.src_regs(2)
dn = inst.dst_regs()
is16_d, is16_s, is16_s2 = inst.is_dst_16(), inst.is_src_16(0), inst.is_src_16(2)
# Use canonical_op_bits for register sizes and 16-bit detection
r0, r1, r2 = max(1, bits['s0'] // 32), max(1, bits['s1'] // 32), max(1, bits['s2'] // 32)
dn = max(1, bits['d'] // 32)
is16_d, is16_s, is16_s2 = bits['d'] == 16, bits['s0'] == 16, bits['s2'] == 16
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, r0, is16_s)
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, r1, is16_s)
@@ -460,20 +466,18 @@ def _disasm_vop3(inst: VOP3) -> str:
return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
def _disasm_vop3sd(inst: VOP3SD) -> str:
name = inst.op_name.lower()
# Use pcode types for register sizes
dn, sr0, sr1, sr2 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2)
name, regs = inst.op_name.lower(), inst.canonical_op_regs
dn, sr0, sr1, sr2 = regs['d'], regs['s0'], regs['s1'], regs['s2']
def src(v, neg, n):
v = _unwrap(v)
s = _lit(inst, v) if v == 255 else ("src_scc" if v == 253 else (_fmt_src(v, n) if n > 1 else _lit(inst, v)))
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
s0, s1, s2 = src(inst.src0, inst.neg & 1, sr0), src(inst.src1, inst.neg & 2, sr1), src(inst.src2, inst.neg & 4, sr2)
dst = _vreg(inst.vdst, dn)
# VOP3SD: _co_ ops (add/sub) without _ci_ have only 2 sources, all others (mad, div_scale, _co_ci_) have 3 sources
has_only_two_srcs = '_co_' in name and '_ci_' not in name and 'mad' not in name
srcs = f"{s0}, {s1}" if has_only_two_srcs else f"{s0}, {s1}, {s2}"
clamp = getattr(inst, 'cm', None) or getattr(inst, 'clmp', 0)
return f"{name} {dst}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{' clamp' if clamp else ''}{_omod(inst.omod)}"
return f"{name} {_vreg(inst.vdst, dn)}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{' clamp' if clamp else ''}{_omod(inst.omod)}"
def _disasm_vopd(inst: VOPD) -> str:
lit = inst._literal or getattr(inst, 'literal', None)
@@ -494,9 +498,9 @@ def _disasm_vop3p(inst: VOP3P) -> str:
def get_src(v, sc):
uv = _unwrap(v)
return _lit(inst, uv) if uv == 255 else _fmt_src(uv, sc)
# Use operand info for register sizes
dn, s0n, s1n, s2n = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2)
src0, src1, src2, dst = get_src(inst.src0, s0n), get_src(inst.src1, s1n), get_src(inst.src2, s2n), _vreg(inst.vdst, dn)
# Use canonical_op_regs for register sizes
regs = inst.canonical_op_regs
src0, src1, src2, dst = get_src(inst.src0, regs['s0']), get_src(inst.src1, regs['s1']), get_src(inst.src2, regs['s2']), _vreg(inst.vdst, regs['d'])
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
clamp = getattr(inst, 'cm', None) or getattr(inst, 'clmp', 0)
if is_fma_mix:
@@ -584,8 +588,9 @@ def _disasm_mimg(inst: MIMG) -> str:
def _disasm_sop1(inst: SOP1) -> str:
op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
# Use operand info for register sizes
dst_regs, src_regs = inst.dst_regs(), inst.src_regs(0)
# Use canonical_op_regs for register sizes
regs = inst.canonical_op_regs
dst_regs, src_regs = regs.get('d', 1), regs.get('s0', 1)
src = _lit(inst, inst.ssrc0) if _unwrap(inst.ssrc0) == 255 else _fmt_src(inst.ssrc0, src_regs, cdna)
if not cdna:
if 'getpc_b64' in name: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
@@ -602,8 +607,9 @@ def _disasm_sop1(inst: SOP1) -> str:
def _disasm_sop2(inst: SOP2) -> str:
cdna, name = _is_cdna(inst), inst.op_name.lower()
lit = getattr(inst, '_literal', None)
# Use operand info for register sizes
dn, s0n, s1n = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1)
# Use canonical_op_regs for register sizes
regs = inst.canonical_op_regs
dn, s0n, s1n = regs['d'], regs['s0'], regs['s1']
s0 = _lit(inst, inst.ssrc0) if _unwrap(inst.ssrc0) == 255 else _fmt_src(inst.ssrc0, s0n, cdna)
s1 = _lit(inst, inst.ssrc1) if _unwrap(inst.ssrc1) == 255 else _fmt_src(inst.ssrc1, s1n, cdna)
dst = _fmt_sdst(inst.sdst, dn, cdna)
@@ -612,10 +618,9 @@ def _disasm_sop2(inst: SOP2) -> str:
return f"{name} {dst}, {s0}, {s1}"
def _disasm_sopc(inst: SOPC) -> str:
cdna = _is_cdna(inst)
s0_regs, s1_regs = inst.src_regs(0), inst.src_regs(1) # pcode types are complete for all SOPC ops
s0 = _lit(inst, inst.ssrc0) if _unwrap(inst.ssrc0) == 255 else _fmt_src(inst.ssrc0, s0_regs, cdna)
s1 = _lit(inst, inst.ssrc1) if _unwrap(inst.ssrc1) == 255 else _fmt_src(inst.ssrc1, s1_regs, cdna)
cdna, regs = _is_cdna(inst), inst.canonical_op_regs
s0 = _lit(inst, inst.ssrc0) if _unwrap(inst.ssrc0) == 255 else _fmt_src(inst.ssrc0, regs['s0'], cdna)
s1 = _lit(inst, inst.ssrc1) if _unwrap(inst.ssrc1) == 255 else _fmt_src(inst.ssrc1, regs['s1'], cdna)
return f"{inst.op_name.lower()} {s0}, {s1}"
def _disasm_sopk(inst: SOPK) -> str:
@@ -636,7 +641,7 @@ def _disasm_sopk(inst: SOPK) -> str:
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1, cdna)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1, cdna)}, {hs}"
if name in ('s_subvector_loop_begin', 's_subvector_loop_end'):
return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, 0x{inst.simm16:x}" # pcode types are complete for SOPK
return f"{name} {_fmt_sdst(inst.sdst, inst.canonical_op_regs['d'], cdna)}, 0x{inst.simm16:x}"
def _disasm_vinterp(inst: VINTERP) -> str:
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
@@ -655,7 +660,7 @@ def _disasm_vbuffer(inst) -> str:
name = inst.op_name.lower().replace('buffer_', 'buffer_').replace('tbuffer_', 'tbuffer_')
w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \
((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \
{'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], inst.dst_regs())
{'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], inst.canonical_op_regs['d'])
if getattr(inst, 'tfe', 0): w += 1
vdata = _vreg(inst.vdata, w) if w else _vreg(inst.vdata)
vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else (_vreg(inst.vaddr) if inst.offen or inst.idxen else 'off')
@@ -714,7 +719,8 @@ try:
s2 = ""
dst = _vreg(inst.vdst)
else:
dregs, r0, r1, r2 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2)
regs = inst.canonical_op_regs
dregs, r0, r1, r2 = regs['d'], regs['s0'], regs['s1'], regs['s2']
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, r0), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, r1), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, r2)
dst = _vreg(inst.vdst, dregs) if dregs > 1 else _vreg(inst.vdst)
if op_val >= 512:
@@ -738,7 +744,8 @@ try:
if hasattr(op_val, 'value'): op_val = op_val.value
name = inst.op_name.lower() or f'vop3b_op_{op_val}'
n = inst.num_srcs() or _num_srcs(inst)
dregs, r0, r1, r2 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2)
regs = inst.canonical_op_regs
dregs, r0, r1, r2 = regs['d'], regs['s0'], regs['s1'], regs['s2']
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, n=r0), _cdna_src(inst, inst.src1, inst.neg&2, n=r1), _cdna_src(inst, inst.src2, inst.neg&4, n=r2)
dst = _vreg(inst.vdst, dregs) if dregs > 1 else _vreg(inst.vdst)
sdst = _fmt_sdst(inst.sdst, 2, cdna=True)

View File

@@ -96,13 +96,12 @@ bits = _Bits()
class BitField:
def __init__(self, hi: int, lo: int, default: int = 0):
self.hi, self.lo, self.default, self.name = hi, lo, default, None
self.hi, self.lo, self.default, self.name, self.mask = hi, lo, default, None, (1 << (hi - lo + 1)) - 1
def __set_name__(self, owner, name): self.name = name
def __eq__(self, other) -> 'FixedBitField':
if isinstance(other, int): return FixedBitField(self.hi, self.lo, other)
return NotImplemented
def enum(self, enum_cls) -> 'EnumBitField': return EnumBitField(self.hi, self.lo, enum_cls)
def mask(self) -> int: return (1 << (self.hi - self.lo + 1)) - 1
def encode(self, val) -> int:
assert isinstance(val, int), f"BitField.encode expects int, got {type(val).__name__}"
return val
@@ -110,11 +109,11 @@ class BitField:
def set(self, raw: int, val) -> int:
if val is None: val = self.default
encoded = self.encode(val)
if encoded < 0 or encoded > self.mask(): raise RuntimeError(f"field '{self.name}': value {encoded} doesn't fit in {self.hi - self.lo + 1} bits")
return (raw & ~(self.mask() << self.lo)) | (encoded << self.lo)
if encoded < 0 or encoded > self.mask: raise RuntimeError(f"field '{self.name}': value {encoded} doesn't fit in {self.hi - self.lo + 1} bits")
return (raw & ~(self.mask << self.lo)) | (encoded << self.lo)
def __get__(self, obj, objtype=None):
if obj is None: return self
return self.decode((obj._raw >> self.lo) & self.mask())
return self.decode((obj._raw >> self.lo) & self.mask)
class FixedBitField(BitField):
def set(self, raw: int, val=None) -> int:
@@ -203,6 +202,7 @@ class VDSTYField(BitField):
# Operand info from XML
# ══════════════════════════════════════════════════════════════
import functools
from extra.assembly.amd.autogen.rdna3.operands import OPERANDS as OPERANDS_RDNA3
from extra.assembly.amd.autogen.rdna4.operands import OPERANDS as OPERANDS_RDNA4
from extra.assembly.amd.autogen.cdna.operands import OPERANDS as OPERANDS_CDNA
@@ -259,7 +259,7 @@ class Inst:
if isinstance(field, SrcField) and val is not None and field.encode(val) + field._valid_range[0] == 255 and self._literal is None:
self._literal = _f32(val) if isinstance(val, float) else val & 0xFFFFFFFF
# Validate register sizes against operand info (skip special registers like NULL, VCC, EXEC)
for name, expected in self._get_field_sizes(vals).items():
for name, expected in self.op_regs.items():
if (val := vals.get(name)) is None: continue
if isinstance(val, Reg) and val.sz != expected and not (106 <= val.offset <= 127 or val.offset == 253):
raise TypeError(f"{name} expects {expected} register(s), got {val.sz}")
@@ -269,60 +269,49 @@ class Inst:
@property
def operands(self) -> dict: return OPERANDS.get(self.op, {}) if hasattr(self, 'op') else {}
def _is_cdna(self) -> bool: return 'cdna' in type(self).__module__
def _get_field_sizes(self, vals: dict) -> dict[str, int]:
"""Map field names to expected register sizes based on operand info."""
sizes = {k: (v[1] + 31) // 32 for k, v in self.operands.items()}
if not hasattr(self, 'op'): return sizes
name = self.op_name.lower()
# RDNA (WAVE32): condition masks and carry flags are 32-bit; CDNA (WAVE64) uses 64-bit
@functools.cached_property
def op_bits(self) -> dict[str, int]:
"""Get bit widths for each operand field, with WAVE32 and addr/saddr adjustments."""
if not hasattr(self, 'op'): return {k: v[1] for k, v in self.operands.items()}
bits = {k: v[1] for k, v in self.operands.items()}
# RDNA (WAVE32): condition masks, carry flags, and compare results are 32-bit
if not self._is_cdna():
if 'cndmask' in name and 'src2' in sizes: sizes['src2'] = 1
name = self.op_name.lower()
if 'cndmask' in name and 'src2' in bits: bits['src2'] = 32
if '_co_ci_' in name:
if 'src2' in sizes: sizes['src2'] = 1
if 'sdst' in sizes: sizes['sdst'] = 1
if 'src2' in bits: bits['src2'] = 32
if 'sdst' in bits: bits['sdst'] = 32
if 'cmp' in name and 'vdst' in bits: bits['vdst'] = 32
# GLOBAL/FLAT: addr is 32-bit if saddr is valid SGPR, 64-bit if saddr is NULL
# Check vals for saddr since some ops have the field but not in operand info
if 'addr' in sizes and ('saddr' in sizes or 'saddr' in vals):
saddr_val = vals.get('saddr')
if isinstance(saddr_val, Reg): saddr_val = saddr_val.offset
is_null_saddr = saddr_val in (None, 124, 125) # 124=NULL, 125=M0
sizes['addr'] = 2 if is_null_saddr else 1
# saddr is 2 SGPRs when not NULL, otherwise skip validation (NULL is special single reg)
if is_null_saddr: sizes.pop('saddr', None)
# MUBUF/MTBUF: vaddr is variable (0-2 regs depending on idxen/offen), vdata depends on format
if 'vaddr' in sizes: sizes.pop('vaddr')
if 'vdata' in sizes: sizes.pop('vdata')
# VOPC/VOP3 vdst for compares is wave-size dependent
if 'vdst' in sizes and 'cmp' in name: sizes.pop('vdst')
return sizes
def _field_bits(self, name: str) -> int:
"""Get size in bits for a field from operand info."""
return self.operands.get(name, (None, 0, None))[1]
def is_src_64(self, n: int) -> bool:
for name in (['src0', 'vsrc0', 'ssrc0'] if n == 0 else ['src1', 'vsrc1', 'ssrc1'] if n == 1 else ['src2']):
if name in self.operands: return self.operands[name][1] == 64
return False
def is_src_16(self, n: int) -> bool:
for name in (['src0', 'vsrc0', 'ssrc0'] if n == 0 else ['src1', 'vsrc1', 'ssrc1'] if n == 1 else ['src2']):
if name in self.operands: return self.operands[name][1] == 16
return False
def is_dst_16(self) -> bool:
for name in ['vdst', 'sdst', 'sdata']:
if name in self.operands: return self.operands[name][1] == 16
return False
def dst_regs(self) -> int:
for name in ['vdst', 'sdst', 'sdata']:
if name in self.operands: return max(1, self.operands[name][1] // 32)
return 1
def data_regs(self) -> int:
"""Get data register count for memory ops (stores use 'data' field, loads use 'vdst')."""
for name in ['data', 'vdata', 'data0']:
if name in self.operands: return max(1, self.operands[name][1] // 32)
return self.dst_regs() # fallback to vdst for loads
def src_regs(self, n: int) -> int:
for name in (['src0', 'vsrc0', 'ssrc0'] if n == 0 else ['src1', 'vsrc1', 'ssrc1'] if n == 1 else ['src2']):
if name in self._field_sizes: return self._field_sizes[name]
return 1
if 'addr' in bits and hasattr(self, 'saddr'):
saddr_val = self.saddr.offset if isinstance(self.saddr, Reg) else self.saddr
bits['addr'] = 64 if saddr_val in (None, 124, 125) else 32 # 124=NULL, 125=M0
# MUBUF/MTBUF: vaddr size depends on offen/idxen (1 or 2 regs)
if 'vaddr' in bits and hasattr(self, 'offen') and hasattr(self, 'idxen'):
bits['vaddr'] = max(1, self.offen + self.idxen) * 32
return bits
@property
def op_regs(self) -> dict[str, int]:
"""Get register counts for each operand field."""
return {k: max(1, v // 32) for k, v in self.op_bits.items()}
@functools.cached_property
def canonical_op_bits(self) -> dict[str, int]:
"""Get bit widths with canonical names: {'s0', 's1', 's2', 'd', 'data'}."""
bits = {'d': 32, 's0': 32, 's1': 32, 's2': 32, 'data': 32}
for name, val in self.op_bits.items():
if name in ('src0', 'vsrc0', 'ssrc0'): bits['s0'] = val
elif name in ('src1', 'vsrc1', 'ssrc1'): bits['s1'] = val
elif name == 'src2': bits['s2'] = val
elif name in ('vdst', 'sdst', 'sdata'): bits['d'] = val
elif name in ('data', 'vdata', 'data0'): bits['data'] = val
return bits
@property
def canonical_op_regs(self) -> dict[str, int]:
"""Get register counts with canonical names: {'s0', 's1', 's2', 'd', 'data'}."""
return {k: max(1, v // 32) for k, v in self.canonical_op_bits.items()}
def num_srcs(self) -> int:
"""Get number of source operands from operand info."""
ops = self.operands
@@ -365,9 +354,8 @@ class Inst:
def __hash__(self): return hash((type(self), self._raw, self._literal))
@property
def _field_sizes(self) -> dict[str, int]:
"""Get field sizes for repr - uses current field values."""
vals = {name: getattr(self, name) for name, _ in self._fields}
return self._get_field_sizes(vals)
"""Get field sizes for repr - uses op_regs."""
return self.op_regs
def __repr__(self):
# collect (repr, is_default) pairs, strip trailing defaults so repr roundtrips with eval

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import ctypes, functools
from enum import IntEnum
from tinygrad.runtime.autogen import hsa
from extra.assembly.amd.dsl import Inst, NULL, SCC, VCC_LO, VCC_HI, EXEC_LO, EXEC_HI
from extra.assembly.amd.dsl import Inst, NULL, SCC, VCC_LO, VCC_HI, EXEC_LO, EXEC_HI, v, s
from extra.assembly.amd.pcode import _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.decode import decode_inst
from extra.assembly.amd.pcode import compile_pseudocode
@@ -15,32 +15,26 @@ from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP,
# Constants and helpers defined locally (not imported from dsl.py)
MASK32, MASK64 = 0xFFFFFFFF, 0xFFFFFFFFFFFFFFFF
FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0: 246, -4.0: 247}
def unwrap(v): return v.offset if hasattr(v, 'offset') else v
class SGPRArray:
"""SGPR array that accepts Reg or int index. Validates SGPR range (0-127)."""
"""SGPR array indexed by Reg or int."""
__slots__ = ('_data',)
def __init__(self, size: int): self._data = [0] * size
def _idx(self, key) -> int:
i = key.offset if hasattr(key, 'offset') else key
assert 0 <= i < 128, f"SGPR index {i} out of range 0-127"
return i
def __getitem__(self, key): return self._data[self._idx(key)]
def __setitem__(self, key, val): self._data[self._idx(key)] = val
def __getitem__(self, key): return self._data[getattr(key, 'offset', key)]
def __setitem__(self, key, val): self._data[getattr(key, 'offset', key)] = val
def __len__(self): return len(self._data)
def __iter__(self): return iter(self._data)
class VGPRLane:
"""Single lane of VGPRs that accepts Reg or int index. Validates VGPR range (256-511)."""
"""Single lane of VGPRs indexed by Reg (offset 256-511) or int (0-255)."""
__slots__ = ('_data',)
def __init__(self, size: int): self._data = [0] * size
def _idx(self, key) -> int:
i = key.offset if hasattr(key, 'offset') else key
if i >= 256: i -= 256 # convert from src encoding to VGPR index
assert 0 <= i < 256, f"VGPR index {i} out of range 0-255"
return i
def __getitem__(self, key): return self._data[self._idx(key)]
def __setitem__(self, key, val): self._data[self._idx(key)] = val
def __getitem__(self, key):
i = getattr(key, 'offset', key)
return self._data[i - 256 if i >= 256 else i]
def __setitem__(self, key, val):
i = getattr(key, 'offset', key)
self._data[i - 256 if i >= 256 else i] = val
def __len__(self): return len(self._data)
def __iter__(self): return iter(self._data)
@@ -59,12 +53,8 @@ _INLINE_CONSTS_F64 = _build_inline_consts(MASK64, _i64)
# Helper: extract/write 16-bit half from/to 32-bit value
def _src16(raw: int, is_hi: bool) -> int: return ((raw >> 16) & 0xffff) if is_hi else (raw & 0xffff)
def _dst16(cur: int, val: int, is_hi: bool) -> int: return (cur & 0x0000ffff) | ((val & 0xffff) << 16) if is_hi else (cur & 0xffff0000) | (val & 0xffff)
def _vgpr_hi(src) -> bool:
off = src.offset if hasattr(src, 'offset') else src
return off >= 256 and ((off - 256) & 0x80) != 0
def _vgpr_masked(src) -> int:
off = src.offset if hasattr(src, 'offset') else src
return ((off - 256) & 0x7f) + 256 if off >= 256 else off
def _vgpr_hi(src) -> bool: return src.offset >= 256 and ((src.offset - 256) & 0x80) != 0
def _vgpr_masked(src): return v[(src.offset - 256) & 0x7f] if src.offset >= 256 else src
# VOP3 source modifier: apply abs/neg to value
def _mod_src(val: int, idx: int, neg: int, abs_: int, is64: bool = False) -> int:
@@ -76,9 +66,10 @@ def _mod_src(val: int, idx: int, neg: int, abs_: int, is64: bool = False) -> int
# Read source operand with VOP3 modifiers
def _read_src(st, inst, src, idx: int, lane: int, neg: int, abs_: int, opsel: int) -> int:
if src is None: return 0
src_off = src.offset if hasattr(src, 'offset') else src
literal, regs, is_src_16 = inst._literal, inst.src_regs(idx), inst.is_src_16(idx)
if regs == 2: return _mod_src(st.rsrc64(src, lane, literal), idx, neg, abs_, is64=True)
src_off = src.offset
src_bits = inst.canonical_op_bits[f's{idx}']
literal, is_src_64, is_src_16 = inst._literal, src_bits == 64, src_bits == 16
if is_src_64: return _mod_src(st.rsrc64(src, lane, literal), idx, neg, abs_, is64=True)
if isinstance(inst, VOP3P):
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
if 'FMA_MIX' in inst.op_name:
@@ -111,13 +102,11 @@ def _op_ndwords(name: str) -> int:
# Helper: build multi-dword int from consecutive VGPRs
def _vgpr_read(V: VGPRLane, reg, ndwords: int) -> int:
base = reg.offset if hasattr(reg, 'offset') else reg
return sum(V[base + i] << (32 * i) for i in range(ndwords))
return sum(V[reg + i] << (32 * i) for i in range(ndwords))
# Helper: write multi-dword value to consecutive VGPRs
def _vgpr_write(V: VGPRLane, reg, val: int, ndwords: int):
base = reg.offset if hasattr(reg, 'offset') else reg
for i in range(ndwords): V[base + i] = (val >> (32 * i)) & MASK32
for i in range(ndwords): V[reg + i] = (val >> (32 * i)) & MASK32
# Memory access
_valid_mem_ranges: list[tuple[int, int]] = []
@@ -200,26 +189,26 @@ class WaveState:
def wsgpr(self, reg, v: int):
if reg != NULL: self.sgpr[reg] = v & MASK32
def rsgpr64(self, reg) -> int:
off = reg.offset if hasattr(reg, 'offset') else reg
return self.rsgpr(off) | (self.rsgpr(off + 1) << 32)
off = reg.offset
return self.sgpr._data[off] | (self.sgpr._data[off + 1] << 32)
def wsgpr64(self, reg, v: int):
off = reg.offset if hasattr(reg, 'offset') else reg
self.wsgpr(off, v & MASK32); self.wsgpr(off + 1, (v >> 32) & MASK32)
off = reg.offset
self.sgpr._data[off] = v & MASK32; self.sgpr._data[off + 1] = (v >> 32) & MASK32
def _rsrc_base(self, reg, lane: int, consts, literal: int):
v = reg.offset if hasattr(reg, 'offset') else reg
if v < SGPR_COUNT: return self.sgpr[v]
if v == SCC.offset: return self.scc
if v < 255: return consts[v - 128]
if v == 255: return literal
return self.vgpr[lane][v] if v <= 511 else 0
off = reg.offset
if off < SGPR_COUNT: return self.sgpr._data[off]
if off == SCC.offset: return self.scc
if off < 255: return consts[off - 128]
if off == 255: return literal
return self.vgpr[lane]._data[off - 256] if off <= 511 else 0
def rsrc(self, reg, lane: int, literal: int = 0) -> int: return self._rsrc_base(reg, lane, _INLINE_CONSTS, literal)
def rsrc_f16(self, reg, lane: int, literal: int = 0) -> int: return self._rsrc_base(reg, lane, _INLINE_CONSTS_F16, literal)
def rsrc64(self, reg, lane: int, literal: int = 0) -> int:
v = reg.offset if hasattr(reg, 'offset') else reg
if 128 <= v < 255: return _INLINE_CONSTS_F64[v - 128]
if v == 255: return literal << 32 # 32-bit literal forms upper 32 bits of 64-bit value
return self.rsrc(v, lane, literal) | ((self.rsrc(v+1, lane, literal) if v < VCC_LO.offset or 256 <= v <= 511 else 0) << 32)
off = reg.offset
if 128 <= off < 255: return _INLINE_CONSTS_F64[off - 128]
if off == 255: return literal << 32 # 32-bit literal forms upper 32 bits of 64-bit value
return self.rsrc(reg, lane, literal) | ((self.rsrc(reg + 1, lane, literal) if off < VCC_LO.offset or 256 <= off <= 511 else 0) << 32)
def pend_sgpr_lane(self, reg, lane: int, val: int):
if reg not in self._pend_sgpr: self._pend_sgpr[reg] = 0
@@ -251,15 +240,15 @@ def exec_scalar(st: WaveState, inst: Inst):
result = inst._fn(GlobalMem, addr & MASK64)
if 'SDATA' in result:
sdata = result['SDATA']
for i in range(SMEM_DST_COUNT.get(inst.op, 1)): st.wsgpr(inst.sdata.offset + i, (sdata >> (i * 32)) & MASK32)
for i in range(SMEM_DST_COUNT.get(inst.op, 1)): st.wsgpr(inst.sdata + i, (sdata >> (i * 32)) & MASK32)
st.pc += inst._words
return 0
# Build context - use inst methods to determine operand sizes
# Build context - use canonical_op_bits to determine operand sizes
literal = inst._literal
s0 = st.rsrc64(ssrc0, 0, literal) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0, literal) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0))
s1 = st.rsrc64(inst.ssrc1, 0, literal) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0, literal) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0)
d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
s0 = st.rsrc64(ssrc0, 0, literal) if inst.canonical_op_bits['s0'] == 64 else (st.rsrc(ssrc0, 0, literal) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0))
s1 = st.rsrc64(inst.ssrc1, 0, literal) if inst.canonical_op_bits['s1'] == 64 else (st.rsrc(inst.ssrc1, 0, literal) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0)
d0 = st.rsgpr64(sdst) if inst.canonical_op_bits['d'] == 64 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else inst._literal
# Call compiled function with int parameters
@@ -267,7 +256,7 @@ def exec_scalar(st: WaveState, inst: Inst):
# Apply results (already int values)
if sdst is not None and 'D0' in result:
(st.wsgpr64 if inst.dst_regs() == 2 else st.wsgpr)(sdst, result['D0'])
(st.wsgpr64 if inst.canonical_op_bits['d'] == 64 else st.wsgpr)(sdst, result['D0'])
if 'SCC' in result: st.scc = result['SCC'] & 1
if 'EXEC' in result: st.exec_mask = result['EXEC']
if 'PC' in result:
@@ -286,7 +275,7 @@ def exec_scalar(st: WaveState, inst: Inst):
def exec_vopd(st: WaveState, inst, V: VGPRLane, lane: int) -> None:
"""VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)."""
literal, vdstx = inst._literal, inst.vdstx
vdsty = (inst.vdsty << 1) | ((inst.vdstx.offset & 1) ^ 1) # vdsty is raw int from VDSTYField.decode
vdsty = v[(inst.vdsty << 1) | ((inst.vdstx.offset & 1) ^ 1)] # vdsty is raw int from VDSTYField.decode
sx0, sx1, dx, sy0, sy1, dy = st.rsrc(inst.srcx0, lane, literal), V[inst.vsrcx1], V[vdstx], st.rsrc(inst.srcy0, lane, literal), V[inst.vsrcy1], V[vdsty]
V[vdstx] = inst._fnx(sx0, sx1, 0, dx, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
V[vdsty] = inst._fny(sy0, sy1, 0, dy, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
@@ -311,17 +300,18 @@ def exec_ds(st: WaveState, inst, V: VGPRLane, lane: int) -> None:
def exec_vop(st: WaveState, inst: Inst, V: VGPRLane, lane: int) -> None:
"""VOP1/VOP2/VOP3/VOP3SD/VOP3P/VOPC: standard ALU ops."""
is_dst_16 = inst.canonical_op_bits['d'] == 16
if isinstance(inst, VOP3P):
src0, src1, src2, vdst, dst_hi = inst.src0, inst.src1, inst.src2, inst.vdst, False
neg, abs_, opsel = inst.neg, 0, inst.opsel
elif isinstance(inst, VOP1):
src0, src1, src2, vdst = inst.src0, None, None, inst.vdst
neg, abs_, opsel, dst_hi = 0, 0, 0, (inst.vdst.offset & 0x80) != 0 and inst.is_dst_16()
if inst.is_dst_16(): vdst = inst.vdst.offset & 0x7f
neg, abs_, opsel, dst_hi = 0, 0, 0, (inst.vdst.offset & 0x80) != 0 and is_dst_16
if is_dst_16: vdst = v[inst.vdst.offset & 0x7f]
elif isinstance(inst, VOP2):
src0, src1, src2, vdst = inst.src0, inst.vsrc1, None, inst.vdst
neg, abs_, opsel, dst_hi = 0, 0, 0, (inst.vdst.offset & 0x80) != 0 and inst.is_dst_16()
if inst.is_dst_16(): vdst = inst.vdst.offset & 0x7f
neg, abs_, opsel, dst_hi = 0, 0, 0, (inst.vdst.offset & 0x80) != 0 and is_dst_16
if is_dst_16: vdst = v[inst.vdst.offset & 0x7f]
elif isinstance(inst, (VOP3, VOP3SD)):
src0, src1, src2, vdst = inst.src0, inst.src1, (None if isinstance(inst, VOP3) and inst.op.value < 256 else inst.src2), inst.vdst
neg, abs_, opsel, dst_hi = (inst.neg, inst.abs, inst.opsel, False) if isinstance(inst, VOP3) else (0, 0, 0, False)
@@ -333,8 +323,8 @@ def exec_vop(st: WaveState, inst: Inst, V: VGPRLane, lane: int) -> None:
s0 = _read_src(st, inst, src0, 0, lane, neg, abs_, opsel)
s1 = _read_src(st, inst, src1, 1, lane, neg, abs_, opsel)
s2 = _read_src(st, inst, src2, 2, lane, neg, abs_, opsel)
if isinstance(inst, VOP2) and inst.is_dst_16(): d0 = _src16(V[vdst], dst_hi)
elif inst.dst_regs() == 2: d0 = V[vdst] | (V[vdst + 1] << 32)
if isinstance(inst, VOP2) and is_dst_16: d0 = _src16(V[vdst], dst_hi)
elif inst.canonical_op_bits['d'] == 64: d0 = V[vdst] | (V[vdst + 1] << 32)
else: d0 = V[vdst]
if isinstance(inst, VOP3SD) and 'CO_CI' in inst.op_name: vcc_for_fn = st.rsgpr64(inst.src2)
@@ -342,7 +332,7 @@ def exec_vop(st: WaveState, inst: Inst, V: VGPRLane, lane: int) -> None:
else: vcc_for_fn = st.vcc
src0_off = src0.offset if src0 is not None else 0
src0_idx = (src0_off - 256) if src0_off >= 256 else src0_off
vdst_off = vdst.offset if hasattr(vdst, 'offset') else vdst
vdst_off = vdst.offset
extra_kwargs = {'opsel': opsel, 'opsel_hi': inst.opsel_hi | (inst.opsel_hi2 << 2)} if isinstance(inst, VOP3P) and 'FMA_MIX' in inst.op_name else {}
result = inst._fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, inst._literal, st.vgpr, src0_idx, vdst_off, **extra_kwargs)
@@ -359,8 +349,8 @@ def exec_vop(st: WaveState, inst: Inst, V: VGPRLane, lane: int) -> None:
st.pend_sgpr_lane(vdst, lane, (result['D0'] >> lane) & 1)
if not is_vopc:
d0_val = result['D0']
if inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif not isinstance(inst, VOP3P) and inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
if inst.canonical_op_bits['d'] == 64: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif not isinstance(inst, VOP3P) and is_dst_16: V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
else: V[vdst] = d0_val & MASK32
# ═══════════════════════════════════════════════════════════════════════════════
@@ -468,7 +458,7 @@ def exec_workgroup(program: dict[int, Inst], workgroup_id: tuple[int, int, int],
n_lanes = min(WAVE_SIZE, total_threads - wave_start)
st = WaveState(lds, n_lanes)
st.exec_mask = (1 << n_lanes) - 1
st.wsgpr64(0, args_ptr) # s[0:1] = kernel arguments pointer
st.wsgpr64(s[0:1], args_ptr) # s[0:1] = kernel arguments pointer
# COMPUTE_PGM_RSRC2: USER_SGPR_COUNT is where workgroup IDs start, ENABLE_SGPR_WORKGROUP_ID_X/Y/Z control which are passed
sgpr_idx = (rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT_SHIFT
if rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_X: st.sgpr[sgpr_idx] = workgroup_id[0]; sgpr_idx += 1

View File

@@ -292,7 +292,7 @@ def _build_state_table() -> tuple[bytes, dict[int, type[PacketType]]]:
for byte_val in range(256):
for opcode, pkt_cls in enumerate(PACKET_TYPES):
if (byte_val & pkt_cls.encoding.mask()) == pkt_cls.encoding.default:
if (byte_val & pkt_cls.encoding.mask) == pkt_cls.encoding.default:
table[byte_val] = opcode
break
@@ -310,7 +310,7 @@ _DECODE_INFO: dict[int, tuple] = {}
for _opcode, _pkt_cls in OPCODE_TO_CLASS.items():
_delta_field = getattr(_pkt_cls, 'delta', None)
_delta_lo = _delta_field.lo if _delta_field else 0
_delta_mask = _delta_field.mask() if _delta_field else 0
_delta_mask = _delta_field.mask if _delta_field else 0
_special = 1 if _opcode == _TS_DELTA_OR_MARK_OPCODE else (2 if _opcode == _TS_DELTA_SHORT_OPCODE else 0)
_DECODE_INFO[_opcode] = (_pkt_cls, _pkt_cls._size_nibbles, _delta_lo, _delta_mask, _special)

View File

@@ -110,33 +110,33 @@ class TestMIMG(unittest.TestCase):
"""Test MIMG (image) instructions."""
def test_image_load_2d(self):
# image_load v[0:3], v[4:5], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D
# image_load v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D
# GFX11: encoding: [0x04,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00]
inst = image_load(vdata=v[0:3], vaddr=v[4:5], srsrc=s[0:7], dmask=0xf, dim=1) # dim=1 is SQ_RSRC_IMG_2D
inst = image_load(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=1) # dim=1 is SQ_RSRC_IMG_2D
self.assertEqual(inst.to_bytes(), bytes([0x04,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00]))
def test_image_store_2d(self):
# image_store v[0:3], v[4:5], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D
# image_store v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D
# GFX11: encoding: [0x04,0x0f,0x18,0xf0,0x04,0x00,0x00,0x00]
inst = image_store(vdata=v[0:3], vaddr=v[4:5], srsrc=s[0:7], dmask=0xf, dim=1)
inst = image_store(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=1)
self.assertEqual(inst.to_bytes(), bytes([0x04,0x0f,0x18,0xf0,0x04,0x00,0x00,0x00]))
def test_image_load_1d(self):
# image_load v[0:3], v4, s[0:7] dmask:0xf dim:SQ_RSRC_IMG_1D
# image_load v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_1D
# GFX11: encoding: [0x00,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00]
inst = image_load(vdata=v[0:3], vaddr=v[4], srsrc=s[0:7], dmask=0xf, dim=0) # dim=0 is SQ_RSRC_IMG_1D
inst = image_load(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=0) # dim=0 is SQ_RSRC_IMG_1D
self.assertEqual(inst.to_bytes(), bytes([0x00,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00]))
def test_image_sample(self):
# image_sample v[0:3], v[4:5], s[0:7], s[8:11] dmask:0xf dim:SQ_RSRC_IMG_2D
# image_sample v[0:3], v[4:6], s[0:7], s[8:11] dmask:0xf dim:SQ_RSRC_IMG_2D
# GFX11: encoding: [0x04,0x0f,0x6c,0xf0,0x04,0x00,0x00,0x08]
inst = image_sample(vdata=v[0:3], vaddr=v[4:5], srsrc=s[0:7], ssamp=s[8:11], dmask=0xf, dim=1)
inst = image_sample(vdata=v[0:3], vaddr=v[4:6], srsrc=s[0:7], ssamp=s[8:11], dmask=0xf, dim=1)
self.assertEqual(inst.to_bytes(), bytes([0x04,0x0f,0x6c,0xf0,0x04,0x00,0x00,0x08]))
def test_image_load_d16(self):
# image_load v[0:1], v[4:5], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D d16
# image_load v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D d16
# GFX11: encoding: [0x04,0x0f,0x02,0xf0,0x04,0x00,0x00,0x00]
inst = image_load(vdata=v[0:1], vaddr=v[4:5], srsrc=s[0:7], dmask=0xf, dim=1, d16=1)
inst = image_load(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=1, d16=1)
self.assertEqual(inst.to_bytes(), bytes([0x04,0x0f,0x02,0xf0,0x04,0x00,0x00,0x00]))
@@ -387,7 +387,7 @@ class TestDetectFormat(unittest.TestCase):
self.assertEqual(detect_format(tbuffer_load_format_x(v[0], v[1], s[0:3], s[5], format=22).to_bytes()), MTBUF)
def test_detect_mimg(self):
self.assertEqual(detect_format(image_load(v[0:3], v[4:5], s[0:7], dmask=0xf, dim=1).to_bytes()), MIMG)
self.assertEqual(detect_format(image_load(v[0:3], v[4:7], s[0:7], dmask=0xf, dim=1).to_bytes()), MIMG)
def test_detect_exp(self):
self.assertEqual(detect_format(EXP(en=0xf, target=0, vsrc0=v[0], vsrc1=v[1], vsrc2=v[2], vsrc3=v[3]).to_bytes()), EXP)