From fd60626ea1d2eda8ea6d6083412771430f9cfdd2 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 15 Jan 2026 11:20:21 +0900 Subject: [PATCH] assembly/amd: refactor to use op_bits/op_regs (#14156) * assembly/amd: refactor to use op_bits/op_regs * remove that skip * remove another hack * remove another hack * precompute mask * more reg, less hasattr --- extra/assembly/amd/decode.py | 2 +- extra/assembly/amd/disasm.py | 97 ++++++++++---------- extra/assembly/amd/dsl.py | 108 ++++++++++------------- extra/assembly/amd/emu.py | 112 +++++++++++------------- extra/assembly/amd/sqtt.py | 4 +- extra/assembly/amd/test/test_formats.py | 22 ++--- 6 files changed, 165 insertions(+), 180 deletions(-) diff --git a/extra/assembly/amd/decode.py b/extra/assembly/amd/decode.py index fe92c55659..0a9c9b468a 100644 --- a/extra/assembly/amd/decode.py +++ b/extra/assembly/amd/decode.py @@ -16,7 +16,7 @@ def _matches_encoding(word: int, cls: type[Inst]) -> bool: enc = next(((n, f) for n, f in cls._fields if isinstance(f, FixedBitField) and n == 'encoding'), None) if enc is None: return False bf = enc[1] - return ((word >> bf.lo) & bf.mask()) == bf.default + return ((word >> bf.lo) & bf.mask) == bf.default # Order matters: more specific encodings first, VOP2 last (it's a catch-all for bit31=0) _RDNA_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP] diff --git a/extra/assembly/amd/disasm.py b/extra/assembly/amd/disasm.py index 5d71961a4a..1afba3c950 100644 --- a/extra/assembly/amd/disasm.py +++ b/extra/assembly/amd/disasm.py @@ -186,10 +186,10 @@ def _disasm_vop1(inst: VOP1) -> str: src = _vreg(inst.src0) if _unwrap(inst.src0) >= 256 else decode_src(_unwrap(inst.src0), cdna) vdst_raw = _unwrap(inst.vdst) return f"{name} {_fmt_sdst(vdst_raw - 256 if vdst_raw >= 256 else vdst_raw, 1, cdna)}, {src}" - # Use operand info for register sizes and 16-bit detection - dregs, sregs = inst.dst_regs(), inst.src_regs(0) - is16_dst = not cdna and inst.is_dst_16() - is16_src = not cdna and inst.is_src_16(0) + # Use get_field_bits for register sizes and 16-bit detection + bits = inst.canonical_op_bits + dregs, sregs = max(1, bits['d'] // 32), max(1, bits['s0'] // 32) + is16_dst, is16_src = not cdna and bits['d'] == 16, not cdna and bits['s0'] == 16 # v_cvt_pk_f32_fp8/bf8: pcode has None dst type but outputs 2 VGPRs if 'cvt_pk_f32_fp8' in name or 'cvt_pk_f32_bf8' in name: dregs, is16_src = 2, True # Format dst @@ -216,8 +216,9 @@ def _disasm_vop2(inst: VOP2) -> str: if cdna: name = _CDNA_DISASM_ALIASES.get(name, name) # apply CDNA aliases suf = "" if cdna or name.endswith('_e32') or (not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16_E32) else "_e32" lit = getattr(inst, '_literal', None) - # Use operand info for 16-bit detection - is16 = not cdna and inst.is_dst_16() + # Use get_field_bits for 16-bit detection + bits = inst.canonical_op_bits + is16 = not cdna and bits['d'] == 16 # fmaak/madak: dst = src0 * vsrc1 + K, fmamk/madmk: dst = src0 * K + vsrc1 if 'fmaak' in name or 'madak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32_E32, VOP2Op.V_FMAAK_F16_E32)): if lit is None: return f"op_{inst.op.value if hasattr(inst.op, 'value') else inst.op}" @@ -234,8 +235,9 @@ def _disasm_vop2(inst: VOP2) -> str: if cdna and name in _VOP2_CARRY_INOUT: return f"{name}{suf} {_vreg(inst.vdst)}, {vcc}, {_lit(inst, inst.src0)}, {_vreg(inst.vsrc1)}, {vcc}" # RDNA carry-in/out ops: v_add_co_ci_u32, etc. if not cdna and name in _VOP2_CARRY_INOUT_RDNA: return f"{name}{suf} {_vreg(inst.vdst)}, {vcc}, {_lit(inst, inst.src0)}, {_vreg(inst.vsrc1)}, {vcc}" - # Use pcode types for register sizes - pcode is complete for all VOP2 ops - dn, sn0, sn1 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1) + # Use get_field_bits for register sizes + regs = inst.canonical_op_regs + dn, sn0, sn1 = regs.get('d', 1), regs.get('s0', 1), regs.get('s1', 1) if dn > 1 or sn0 > 1 or sn1 > 1: dst = _vreg(inst.vdst, dn) src0 = _lit(inst, inst.src0) if _unwrap(inst.src0) == 255 else _fmt_src(inst.src0, sn0, cdna) @@ -245,9 +247,10 @@ def _disasm_vop2(inst: VOP2) -> str: def _disasm_vopc(inst: VOPC) -> str: name, cdna = inst.op_name.lower(), _is_cdna(inst) - # Use operand info for register sizes and 16-bit detection - r0, r1 = inst.src_regs(0), inst.src_regs(1) - is16 = inst.is_src_16(0) + # Use get_field_bits for register sizes and 16-bit detection + bits = inst.canonical_op_bits + r0, r1 = max(1, bits['s0'] // 32), max(1, bits['s1'] // 32) + is16 = bits['s0'] == 16 if cdna: s0 = _lit(inst, inst.src0) if _unwrap(inst.src0) == 255 else _fmt_src(inst.src0, r0, cdna) s1 = _vreg(inst.vsrc1, r1) if r1 > 1 else _vreg(inst.vsrc1) @@ -321,8 +324,8 @@ def _disasm_smem(inst: SMEM) -> str: soff_s = decode_src(inst.soffset, cdna) if inst.soffset != 124 else "null" if 'pc_rel' in name: return f"{name} {off_s}, {soff_s}, {_unwrap(inst.sdata)}" return f"{name} {sbase_str}, {off_s}, {soff_s}, {_unwrap(inst.sdata)}" - # Use operand info for register count - dst_n = inst.dst_regs() + # Use get_field_bits for register count + dst_n = inst.canonical_op_regs.get('d', 1) th, scope = getattr(inst, 'th', 0), getattr(inst, 'scope', 0) if is_rdna4: # RDNA4 uses th/scope instead of glc/dlc th_names = ['TH_LOAD_RT', 'TH_LOAD_NT', 'TH_LOAD_HT', 'TH_LOAD_LU'] @@ -343,8 +346,9 @@ def _disasm_flat(inst: FLAT) -> str: seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat' instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}" off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192) - # Use operand info: data_regs for stores/atomics, dst_regs for loads - w = inst.data_regs() if 'store' in name or 'atomic' in name else inst.dst_regs() + # Use get_field_bits: data for stores/atomics, d for loads + regs = inst.canonical_op_regs + w = regs.get('data', regs.get('d', 1)) if 'store' in name or 'atomic' in name else regs.get('d', 1) off_s = f" offset:{off_val}" if off_val else "" if cdna: mods = f"{off_s}{' glc' if inst.sc0 else ''}{' slc' if inst.nt else ''}" else: mods = f"{off_s}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}" @@ -372,8 +376,9 @@ def _disasm_ds(inst: DS) -> str: gds = " gds" if getattr(inst, 'gds', 0) else "" off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else "" off2 = (" offset0:" + str(inst.offset0) if inst.offset0 else "") + (" offset1:" + str(inst.offset1) if inst.offset1 else "") - # Use operand info: data_regs for stores/writes/atomics, dst_regs for loads - w = inst.data_regs() if 'store' in name or 'write' in name or ('load' not in name and 'read' not in name) else inst.dst_regs() + # Use get_field_bits: data for stores/writes/atomics, d for loads + regs = inst.canonical_op_regs + w = regs.get('data', regs.get('d', 1)) if 'store' in name or 'write' in name or ('load' not in name and 'read' not in name) else regs.get('d', 1) d0, d1, dst, addr = reg_fn(inst.data0, w), reg_fn(inst.data1, w), reg_fn(inst.vdst, w), _vreg(inst.addr) if name == 'ds_nop': return name @@ -388,11 +393,11 @@ def _disasm_ds(inst: DS) -> str: if name in ('ds_consume', 'ds_append'): return f"{name} {reg_fn(inst.vdst)}{off}{gds}" if 'gs_reg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {reg_fn(inst.data0)}{off}{gds}" if '2addr' in name: - if 'load' in name: return f"{name} {reg_fn(inst.vdst, inst.dst_regs())}, {addr}{off2}{gds}" + if 'load' in name: return f"{name} {reg_fn(inst.vdst, regs.get('d', 1))}, {addr}{off2}{gds}" if 'store' in name and 'xchg' not in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}" - return f"{name} {reg_fn(inst.vdst, inst.dst_regs())}, {addr}, {d0}, {d1}{off2}{gds}" + return f"{name} {reg_fn(inst.vdst, regs.get('d', 1))}, {addr}, {d0}, {d1}{off2}{gds}" if 'write2' in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}" - if 'read2' in name: return f"{name} {reg_fn(inst.vdst, inst.dst_regs())}, {addr}{off2}{gds}" + if 'read2' in name: return f"{name} {reg_fn(inst.vdst, regs.get('d', 1))}, {addr}{off2}{gds}" if 'load' in name: return f"{name} {reg_fn(inst.vdst)}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}" if 'store' in name and not _has(name, 'cmp', 'xchg'): return f"{name} {reg_fn(inst.data0)}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}" @@ -406,20 +411,21 @@ def _disasm_ds(inst: DS) -> str: def _disasm_vop3(inst: VOP3) -> str: op, name = inst.op, inst.op_name.lower() n_up = name.upper() + bits = inst.canonical_op_bits # RDNA4 v_s_* scalar VOP3 instructions - vdst is SGPR (VGPRField adds 256) if name.startswith('v_s_'): - src = _lit(inst, inst.src0) if _unwrap(inst.src0) == 255 else ("src_scc" if _unwrap(inst.src0) == 253 else _fmt_src(inst.src0, inst.src_regs(0))) + src = _lit(inst, inst.src0) if _unwrap(inst.src0) == 255 else ("src_scc" if _unwrap(inst.src0) == 253 else _fmt_src(inst.src0, max(1, bits['s0'] // 32))) if inst.neg & 1: src = f"-{src}" if inst.abs & 1: src = f"|{src}|" clamp = getattr(inst, 'cm', None) or getattr(inst, 'clmp', 0) vdst_raw = _unwrap(inst.vdst) return f"{name} s{vdst_raw - 256 if vdst_raw >= 256 else vdst_raw}, {src}" + (" clamp" if clamp else "") + _omod(inst.omod) - # Use operand info for register sizes and 16-bit detection - r0, r1, r2 = inst.src_regs(0), inst.src_regs(1), inst.src_regs(2) - dn = inst.dst_regs() - is16_d, is16_s, is16_s2 = inst.is_dst_16(), inst.is_src_16(0), inst.is_src_16(2) + # Use get_field_bits for register sizes and 16-bit detection + r0, r1, r2 = max(1, bits['s0'] // 32), max(1, bits['s1'] // 32), max(1, bits['s2'] // 32) + dn = max(1, bits['d'] // 32) + is16_d, is16_s, is16_s2 = bits['d'] == 16, bits['s0'] == 16, bits['s2'] == 16 s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, r0, is16_s) s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, r1, is16_s) @@ -460,20 +466,18 @@ def _disasm_vop3(inst: VOP3) -> str: return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}" def _disasm_vop3sd(inst: VOP3SD) -> str: - name = inst.op_name.lower() - # Use pcode types for register sizes - dn, sr0, sr1, sr2 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2) + name, regs = inst.op_name.lower(), inst.canonical_op_regs + dn, sr0, sr1, sr2 = regs['d'], regs['s0'], regs['s1'], regs['s2'] def src(v, neg, n): v = _unwrap(v) s = _lit(inst, v) if v == 255 else ("src_scc" if v == 253 else (_fmt_src(v, n) if n > 1 else _lit(inst, v))) return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s) s0, s1, s2 = src(inst.src0, inst.neg & 1, sr0), src(inst.src1, inst.neg & 2, sr1), src(inst.src2, inst.neg & 4, sr2) - dst = _vreg(inst.vdst, dn) # VOP3SD: _co_ ops (add/sub) without _ci_ have only 2 sources, all others (mad, div_scale, _co_ci_) have 3 sources has_only_two_srcs = '_co_' in name and '_ci_' not in name and 'mad' not in name srcs = f"{s0}, {s1}" if has_only_two_srcs else f"{s0}, {s1}, {s2}" clamp = getattr(inst, 'cm', None) or getattr(inst, 'clmp', 0) - return f"{name} {dst}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{' clamp' if clamp else ''}{_omod(inst.omod)}" + return f"{name} {_vreg(inst.vdst, dn)}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{' clamp' if clamp else ''}{_omod(inst.omod)}" def _disasm_vopd(inst: VOPD) -> str: lit = inst._literal or getattr(inst, 'literal', None) @@ -494,9 +498,9 @@ def _disasm_vop3p(inst: VOP3P) -> str: def get_src(v, sc): uv = _unwrap(v) return _lit(inst, uv) if uv == 255 else _fmt_src(uv, sc) - # Use operand info for register sizes - dn, s0n, s1n, s2n = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2) - src0, src1, src2, dst = get_src(inst.src0, s0n), get_src(inst.src1, s1n), get_src(inst.src2, s2n), _vreg(inst.vdst, dn) + # Use get_field_bits for register sizes + regs = inst.canonical_op_regs + src0, src1, src2, dst = get_src(inst.src0, regs['s0']), get_src(inst.src1, regs['s1']), get_src(inst.src2, regs['s2']), _vreg(inst.vdst, regs['d']) opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2) clamp = getattr(inst, 'cm', None) or getattr(inst, 'clmp', 0) if is_fma_mix: @@ -584,8 +588,9 @@ def _disasm_mimg(inst: MIMG) -> str: def _disasm_sop1(inst: SOP1) -> str: op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst) - # Use operand info for register sizes - dst_regs, src_regs = inst.dst_regs(), inst.src_regs(0) + # Use get_field_bits for register sizes + regs = inst.canonical_op_regs + dst_regs, src_regs = regs.get('d', 1), regs.get('s0', 1) src = _lit(inst, inst.ssrc0) if _unwrap(inst.ssrc0) == 255 else _fmt_src(inst.ssrc0, src_regs, cdna) if not cdna: if 'getpc_b64' in name: return f"{name} {_fmt_sdst(inst.sdst, 2)}" @@ -602,8 +607,9 @@ def _disasm_sop1(inst: SOP1) -> str: def _disasm_sop2(inst: SOP2) -> str: cdna, name = _is_cdna(inst), inst.op_name.lower() lit = getattr(inst, '_literal', None) - # Use operand info for register sizes - dn, s0n, s1n = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1) + # Use get_field_bits for register sizes + regs = inst.canonical_op_regs + dn, s0n, s1n = regs['d'], regs['s0'], regs['s1'] s0 = _lit(inst, inst.ssrc0) if _unwrap(inst.ssrc0) == 255 else _fmt_src(inst.ssrc0, s0n, cdna) s1 = _lit(inst, inst.ssrc1) if _unwrap(inst.ssrc1) == 255 else _fmt_src(inst.ssrc1, s1n, cdna) dst = _fmt_sdst(inst.sdst, dn, cdna) @@ -612,10 +618,9 @@ def _disasm_sop2(inst: SOP2) -> str: return f"{name} {dst}, {s0}, {s1}" def _disasm_sopc(inst: SOPC) -> str: - cdna = _is_cdna(inst) - s0_regs, s1_regs = inst.src_regs(0), inst.src_regs(1) # pcode types are complete for all SOPC ops - s0 = _lit(inst, inst.ssrc0) if _unwrap(inst.ssrc0) == 255 else _fmt_src(inst.ssrc0, s0_regs, cdna) - s1 = _lit(inst, inst.ssrc1) if _unwrap(inst.ssrc1) == 255 else _fmt_src(inst.ssrc1, s1_regs, cdna) + cdna, regs = _is_cdna(inst), inst.canonical_op_regs + s0 = _lit(inst, inst.ssrc0) if _unwrap(inst.ssrc0) == 255 else _fmt_src(inst.ssrc0, regs['s0'], cdna) + s1 = _lit(inst, inst.ssrc1) if _unwrap(inst.ssrc1) == 255 else _fmt_src(inst.ssrc1, regs['s1'], cdna) return f"{inst.op_name.lower()} {s0}, {s1}" def _disasm_sopk(inst: SOPK) -> str: @@ -636,7 +641,7 @@ def _disasm_sopk(inst: SOPK) -> str: return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1, cdna)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1, cdna)}, {hs}" if name in ('s_subvector_loop_begin', 's_subvector_loop_end'): return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}" - return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, 0x{inst.simm16:x}" # pcode types are complete for SOPK + return f"{name} {_fmt_sdst(inst.sdst, inst.canonical_op_regs['d'], cdna)}, 0x{inst.simm16:x}" def _disasm_vinterp(inst: VINTERP) -> str: mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp")) @@ -655,7 +660,7 @@ def _disasm_vbuffer(inst) -> str: name = inst.op_name.lower().replace('buffer_', 'buffer_').replace('tbuffer_', 'tbuffer_') w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \ ((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \ - {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], inst.dst_regs()) + {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], inst.canonical_op_regs['d']) if getattr(inst, 'tfe', 0): w += 1 vdata = _vreg(inst.vdata, w) if w else _vreg(inst.vdata) vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else (_vreg(inst.vaddr) if inst.offen or inst.idxen else 'off') @@ -714,7 +719,8 @@ try: s2 = "" dst = _vreg(inst.vdst) else: - dregs, r0, r1, r2 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2) + regs = inst.canonical_op_regs + dregs, r0, r1, r2 = regs['d'], regs['s0'], regs['s1'], regs['s2'] s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, r0), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, r1), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, r2) dst = _vreg(inst.vdst, dregs) if dregs > 1 else _vreg(inst.vdst) if op_val >= 512: @@ -738,7 +744,8 @@ try: if hasattr(op_val, 'value'): op_val = op_val.value name = inst.op_name.lower() or f'vop3b_op_{op_val}' n = inst.num_srcs() or _num_srcs(inst) - dregs, r0, r1, r2 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2) + regs = inst.canonical_op_regs + dregs, r0, r1, r2 = regs['d'], regs['s0'], regs['s1'], regs['s2'] s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, n=r0), _cdna_src(inst, inst.src1, inst.neg&2, n=r1), _cdna_src(inst, inst.src2, inst.neg&4, n=r2) dst = _vreg(inst.vdst, dregs) if dregs > 1 else _vreg(inst.vdst) sdst = _fmt_sdst(inst.sdst, 2, cdna=True) diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index 344db2ef0b..13902bb6af 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -96,13 +96,12 @@ bits = _Bits() class BitField: def __init__(self, hi: int, lo: int, default: int = 0): - self.hi, self.lo, self.default, self.name = hi, lo, default, None + self.hi, self.lo, self.default, self.name, self.mask = hi, lo, default, None, (1 << (hi - lo + 1)) - 1 def __set_name__(self, owner, name): self.name = name def __eq__(self, other) -> 'FixedBitField': if isinstance(other, int): return FixedBitField(self.hi, self.lo, other) return NotImplemented def enum(self, enum_cls) -> 'EnumBitField': return EnumBitField(self.hi, self.lo, enum_cls) - def mask(self) -> int: return (1 << (self.hi - self.lo + 1)) - 1 def encode(self, val) -> int: assert isinstance(val, int), f"BitField.encode expects int, got {type(val).__name__}" return val @@ -110,11 +109,11 @@ class BitField: def set(self, raw: int, val) -> int: if val is None: val = self.default encoded = self.encode(val) - if encoded < 0 or encoded > self.mask(): raise RuntimeError(f"field '{self.name}': value {encoded} doesn't fit in {self.hi - self.lo + 1} bits") - return (raw & ~(self.mask() << self.lo)) | (encoded << self.lo) + if encoded < 0 or encoded > self.mask: raise RuntimeError(f"field '{self.name}': value {encoded} doesn't fit in {self.hi - self.lo + 1} bits") + return (raw & ~(self.mask << self.lo)) | (encoded << self.lo) def __get__(self, obj, objtype=None): if obj is None: return self - return self.decode((obj._raw >> self.lo) & self.mask()) + return self.decode((obj._raw >> self.lo) & self.mask) class FixedBitField(BitField): def set(self, raw: int, val=None) -> int: @@ -203,6 +202,7 @@ class VDSTYField(BitField): # Operand info from XML # ══════════════════════════════════════════════════════════════ +import functools from extra.assembly.amd.autogen.rdna3.operands import OPERANDS as OPERANDS_RDNA3 from extra.assembly.amd.autogen.rdna4.operands import OPERANDS as OPERANDS_RDNA4 from extra.assembly.amd.autogen.cdna.operands import OPERANDS as OPERANDS_CDNA @@ -259,7 +259,7 @@ class Inst: if isinstance(field, SrcField) and val is not None and field.encode(val) + field._valid_range[0] == 255 and self._literal is None: self._literal = _f32(val) if isinstance(val, float) else val & 0xFFFFFFFF # Validate register sizes against operand info (skip special registers like NULL, VCC, EXEC) - for name, expected in self._get_field_sizes(vals).items(): + for name, expected in self.op_regs.items(): if (val := vals.get(name)) is None: continue if isinstance(val, Reg) and val.sz != expected and not (106 <= val.offset <= 127 or val.offset == 253): raise TypeError(f"{name} expects {expected} register(s), got {val.sz}") @@ -269,60 +269,49 @@ class Inst: @property def operands(self) -> dict: return OPERANDS.get(self.op, {}) if hasattr(self, 'op') else {} def _is_cdna(self) -> bool: return 'cdna' in type(self).__module__ - def _get_field_sizes(self, vals: dict) -> dict[str, int]: - """Map field names to expected register sizes based on operand info.""" - sizes = {k: (v[1] + 31) // 32 for k, v in self.operands.items()} - if not hasattr(self, 'op'): return sizes - name = self.op_name.lower() - # RDNA (WAVE32): condition masks and carry flags are 32-bit; CDNA (WAVE64) uses 64-bit + + @functools.cached_property + def op_bits(self) -> dict[str, int]: + """Get bit widths for each operand field, with WAVE32 and addr/saddr adjustments.""" + if not hasattr(self, 'op'): return {k: v[1] for k, v in self.operands.items()} + bits = {k: v[1] for k, v in self.operands.items()} + # RDNA (WAVE32): condition masks, carry flags, and compare results are 32-bit if not self._is_cdna(): - if 'cndmask' in name and 'src2' in sizes: sizes['src2'] = 1 + name = self.op_name.lower() + if 'cndmask' in name and 'src2' in bits: bits['src2'] = 32 if '_co_ci_' in name: - if 'src2' in sizes: sizes['src2'] = 1 - if 'sdst' in sizes: sizes['sdst'] = 1 + if 'src2' in bits: bits['src2'] = 32 + if 'sdst' in bits: bits['sdst'] = 32 + if 'cmp' in name and 'vdst' in bits: bits['vdst'] = 32 # GLOBAL/FLAT: addr is 32-bit if saddr is valid SGPR, 64-bit if saddr is NULL - # Check vals for saddr since some ops have the field but not in operand info - if 'addr' in sizes and ('saddr' in sizes or 'saddr' in vals): - saddr_val = vals.get('saddr') - if isinstance(saddr_val, Reg): saddr_val = saddr_val.offset - is_null_saddr = saddr_val in (None, 124, 125) # 124=NULL, 125=M0 - sizes['addr'] = 2 if is_null_saddr else 1 - # saddr is 2 SGPRs when not NULL, otherwise skip validation (NULL is special single reg) - if is_null_saddr: sizes.pop('saddr', None) - # MUBUF/MTBUF: vaddr is variable (0-2 regs depending on idxen/offen), vdata depends on format - if 'vaddr' in sizes: sizes.pop('vaddr') - if 'vdata' in sizes: sizes.pop('vdata') - # VOPC/VOP3 vdst for compares is wave-size dependent - if 'vdst' in sizes and 'cmp' in name: sizes.pop('vdst') - return sizes - def _field_bits(self, name: str) -> int: - """Get size in bits for a field from operand info.""" - return self.operands.get(name, (None, 0, None))[1] - def is_src_64(self, n: int) -> bool: - for name in (['src0', 'vsrc0', 'ssrc0'] if n == 0 else ['src1', 'vsrc1', 'ssrc1'] if n == 1 else ['src2']): - if name in self.operands: return self.operands[name][1] == 64 - return False - def is_src_16(self, n: int) -> bool: - for name in (['src0', 'vsrc0', 'ssrc0'] if n == 0 else ['src1', 'vsrc1', 'ssrc1'] if n == 1 else ['src2']): - if name in self.operands: return self.operands[name][1] == 16 - return False - def is_dst_16(self) -> bool: - for name in ['vdst', 'sdst', 'sdata']: - if name in self.operands: return self.operands[name][1] == 16 - return False - def dst_regs(self) -> int: - for name in ['vdst', 'sdst', 'sdata']: - if name in self.operands: return max(1, self.operands[name][1] // 32) - return 1 - def data_regs(self) -> int: - """Get data register count for memory ops (stores use 'data' field, loads use 'vdst').""" - for name in ['data', 'vdata', 'data0']: - if name in self.operands: return max(1, self.operands[name][1] // 32) - return self.dst_regs() # fallback to vdst for loads - def src_regs(self, n: int) -> int: - for name in (['src0', 'vsrc0', 'ssrc0'] if n == 0 else ['src1', 'vsrc1', 'ssrc1'] if n == 1 else ['src2']): - if name in self._field_sizes: return self._field_sizes[name] - return 1 + if 'addr' in bits and hasattr(self, 'saddr'): + saddr_val = self.saddr.offset if isinstance(self.saddr, Reg) else self.saddr + bits['addr'] = 64 if saddr_val in (None, 124, 125) else 32 # 124=NULL, 125=M0 + # MUBUF/MTBUF: vaddr size depends on offen/idxen (1 or 2 regs) + if 'vaddr' in bits and hasattr(self, 'offen') and hasattr(self, 'idxen'): + bits['vaddr'] = max(1, self.offen + self.idxen) * 32 + return bits + @property + def op_regs(self) -> dict[str, int]: + """Get register counts for each operand field.""" + return {k: max(1, v // 32) for k, v in self.op_bits.items()} + + @functools.cached_property + def canonical_op_bits(self) -> dict[str, int]: + """Get bit widths with canonical names: {'s0', 's1', 's2', 'd', 'data'}.""" + bits = {'d': 32, 's0': 32, 's1': 32, 's2': 32, 'data': 32} + for name, val in self.op_bits.items(): + if name in ('src0', 'vsrc0', 'ssrc0'): bits['s0'] = val + elif name in ('src1', 'vsrc1', 'ssrc1'): bits['s1'] = val + elif name == 'src2': bits['s2'] = val + elif name in ('vdst', 'sdst', 'sdata'): bits['d'] = val + elif name in ('data', 'vdata', 'data0'): bits['data'] = val + return bits + @property + def canonical_op_regs(self) -> dict[str, int]: + """Get register counts with canonical names: {'s0', 's1', 's2', 'd', 'data'}.""" + return {k: max(1, v // 32) for k, v in self.canonical_op_bits.items()} + def num_srcs(self) -> int: """Get number of source operands from operand info.""" ops = self.operands @@ -365,9 +354,8 @@ class Inst: def __hash__(self): return hash((type(self), self._raw, self._literal)) @property def _field_sizes(self) -> dict[str, int]: - """Get field sizes for repr - uses current field values.""" - vals = {name: getattr(self, name) for name, _ in self._fields} - return self._get_field_sizes(vals) + """Get field sizes for repr - uses op_regs.""" + return self.op_regs def __repr__(self): # collect (repr, is_default) pairs, strip trailing defaults so repr roundtrips with eval diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index b04c1661d7..61274b2d67 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -4,7 +4,7 @@ from __future__ import annotations import ctypes, functools from enum import IntEnum from tinygrad.runtime.autogen import hsa -from extra.assembly.amd.dsl import Inst, NULL, SCC, VCC_LO, VCC_HI, EXEC_LO, EXEC_HI +from extra.assembly.amd.dsl import Inst, NULL, SCC, VCC_LO, VCC_HI, EXEC_LO, EXEC_HI, v, s from extra.assembly.amd.pcode import _f32, _i32, _sext, _f16, _i16, _f64, _i64 from extra.assembly.amd.decode import decode_inst from extra.assembly.amd.pcode import compile_pseudocode @@ -15,32 +15,26 @@ from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, # Constants and helpers defined locally (not imported from dsl.py) MASK32, MASK64 = 0xFFFFFFFF, 0xFFFFFFFFFFFFFFFF FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0: 246, -4.0: 247} -def unwrap(v): return v.offset if hasattr(v, 'offset') else v class SGPRArray: - """SGPR array that accepts Reg or int index. Validates SGPR range (0-127).""" + """SGPR array indexed by Reg or int.""" __slots__ = ('_data',) def __init__(self, size: int): self._data = [0] * size - def _idx(self, key) -> int: - i = key.offset if hasattr(key, 'offset') else key - assert 0 <= i < 128, f"SGPR index {i} out of range 0-127" - return i - def __getitem__(self, key): return self._data[self._idx(key)] - def __setitem__(self, key, val): self._data[self._idx(key)] = val + def __getitem__(self, key): return self._data[getattr(key, 'offset', key)] + def __setitem__(self, key, val): self._data[getattr(key, 'offset', key)] = val def __len__(self): return len(self._data) def __iter__(self): return iter(self._data) class VGPRLane: - """Single lane of VGPRs that accepts Reg or int index. Validates VGPR range (256-511).""" + """Single lane of VGPRs indexed by Reg (offset 256-511) or int (0-255).""" __slots__ = ('_data',) def __init__(self, size: int): self._data = [0] * size - def _idx(self, key) -> int: - i = key.offset if hasattr(key, 'offset') else key - if i >= 256: i -= 256 # convert from src encoding to VGPR index - assert 0 <= i < 256, f"VGPR index {i} out of range 0-255" - return i - def __getitem__(self, key): return self._data[self._idx(key)] - def __setitem__(self, key, val): self._data[self._idx(key)] = val + def __getitem__(self, key): + i = getattr(key, 'offset', key) + return self._data[i - 256 if i >= 256 else i] + def __setitem__(self, key, val): + i = getattr(key, 'offset', key) + self._data[i - 256 if i >= 256 else i] = val def __len__(self): return len(self._data) def __iter__(self): return iter(self._data) @@ -59,12 +53,8 @@ _INLINE_CONSTS_F64 = _build_inline_consts(MASK64, _i64) # Helper: extract/write 16-bit half from/to 32-bit value def _src16(raw: int, is_hi: bool) -> int: return ((raw >> 16) & 0xffff) if is_hi else (raw & 0xffff) def _dst16(cur: int, val: int, is_hi: bool) -> int: return (cur & 0x0000ffff) | ((val & 0xffff) << 16) if is_hi else (cur & 0xffff0000) | (val & 0xffff) -def _vgpr_hi(src) -> bool: - off = src.offset if hasattr(src, 'offset') else src - return off >= 256 and ((off - 256) & 0x80) != 0 -def _vgpr_masked(src) -> int: - off = src.offset if hasattr(src, 'offset') else src - return ((off - 256) & 0x7f) + 256 if off >= 256 else off +def _vgpr_hi(src) -> bool: return src.offset >= 256 and ((src.offset - 256) & 0x80) != 0 +def _vgpr_masked(src): return v[(src.offset - 256) & 0x7f] if src.offset >= 256 else src # VOP3 source modifier: apply abs/neg to value def _mod_src(val: int, idx: int, neg: int, abs_: int, is64: bool = False) -> int: @@ -76,9 +66,10 @@ def _mod_src(val: int, idx: int, neg: int, abs_: int, is64: bool = False) -> int # Read source operand with VOP3 modifiers def _read_src(st, inst, src, idx: int, lane: int, neg: int, abs_: int, opsel: int) -> int: if src is None: return 0 - src_off = src.offset if hasattr(src, 'offset') else src - literal, regs, is_src_16 = inst._literal, inst.src_regs(idx), inst.is_src_16(idx) - if regs == 2: return _mod_src(st.rsrc64(src, lane, literal), idx, neg, abs_, is64=True) + src_off = src.offset + src_bits = inst.canonical_op_bits[f's{idx}'] + literal, is_src_64, is_src_16 = inst._literal, src_bits == 64, src_bits == 16 + if is_src_64: return _mod_src(st.rsrc64(src, lane, literal), idx, neg, abs_, is64=True) if isinstance(inst, VOP3P): opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2) if 'FMA_MIX' in inst.op_name: @@ -111,13 +102,11 @@ def _op_ndwords(name: str) -> int: # Helper: build multi-dword int from consecutive VGPRs def _vgpr_read(V: VGPRLane, reg, ndwords: int) -> int: - base = reg.offset if hasattr(reg, 'offset') else reg - return sum(V[base + i] << (32 * i) for i in range(ndwords)) + return sum(V[reg + i] << (32 * i) for i in range(ndwords)) # Helper: write multi-dword value to consecutive VGPRs def _vgpr_write(V: VGPRLane, reg, val: int, ndwords: int): - base = reg.offset if hasattr(reg, 'offset') else reg - for i in range(ndwords): V[base + i] = (val >> (32 * i)) & MASK32 + for i in range(ndwords): V[reg + i] = (val >> (32 * i)) & MASK32 # Memory access _valid_mem_ranges: list[tuple[int, int]] = [] @@ -200,26 +189,26 @@ class WaveState: def wsgpr(self, reg, v: int): if reg != NULL: self.sgpr[reg] = v & MASK32 def rsgpr64(self, reg) -> int: - off = reg.offset if hasattr(reg, 'offset') else reg - return self.rsgpr(off) | (self.rsgpr(off + 1) << 32) + off = reg.offset + return self.sgpr._data[off] | (self.sgpr._data[off + 1] << 32) def wsgpr64(self, reg, v: int): - off = reg.offset if hasattr(reg, 'offset') else reg - self.wsgpr(off, v & MASK32); self.wsgpr(off + 1, (v >> 32) & MASK32) + off = reg.offset + self.sgpr._data[off] = v & MASK32; self.sgpr._data[off + 1] = (v >> 32) & MASK32 def _rsrc_base(self, reg, lane: int, consts, literal: int): - v = reg.offset if hasattr(reg, 'offset') else reg - if v < SGPR_COUNT: return self.sgpr[v] - if v == SCC.offset: return self.scc - if v < 255: return consts[v - 128] - if v == 255: return literal - return self.vgpr[lane][v] if v <= 511 else 0 + off = reg.offset + if off < SGPR_COUNT: return self.sgpr._data[off] + if off == SCC.offset: return self.scc + if off < 255: return consts[off - 128] + if off == 255: return literal + return self.vgpr[lane]._data[off - 256] if off <= 511 else 0 def rsrc(self, reg, lane: int, literal: int = 0) -> int: return self._rsrc_base(reg, lane, _INLINE_CONSTS, literal) def rsrc_f16(self, reg, lane: int, literal: int = 0) -> int: return self._rsrc_base(reg, lane, _INLINE_CONSTS_F16, literal) def rsrc64(self, reg, lane: int, literal: int = 0) -> int: - v = reg.offset if hasattr(reg, 'offset') else reg - if 128 <= v < 255: return _INLINE_CONSTS_F64[v - 128] - if v == 255: return literal << 32 # 32-bit literal forms upper 32 bits of 64-bit value - return self.rsrc(v, lane, literal) | ((self.rsrc(v+1, lane, literal) if v < VCC_LO.offset or 256 <= v <= 511 else 0) << 32) + off = reg.offset + if 128 <= off < 255: return _INLINE_CONSTS_F64[off - 128] + if off == 255: return literal << 32 # 32-bit literal forms upper 32 bits of 64-bit value + return self.rsrc(reg, lane, literal) | ((self.rsrc(reg + 1, lane, literal) if off < VCC_LO.offset or 256 <= off <= 511 else 0) << 32) def pend_sgpr_lane(self, reg, lane: int, val: int): if reg not in self._pend_sgpr: self._pend_sgpr[reg] = 0 @@ -251,15 +240,15 @@ def exec_scalar(st: WaveState, inst: Inst): result = inst._fn(GlobalMem, addr & MASK64) if 'SDATA' in result: sdata = result['SDATA'] - for i in range(SMEM_DST_COUNT.get(inst.op, 1)): st.wsgpr(inst.sdata.offset + i, (sdata >> (i * 32)) & MASK32) + for i in range(SMEM_DST_COUNT.get(inst.op, 1)): st.wsgpr(inst.sdata + i, (sdata >> (i * 32)) & MASK32) st.pc += inst._words return 0 - # Build context - use inst methods to determine operand sizes + # Build context - use canonical_op_bits to determine operand sizes literal = inst._literal - s0 = st.rsrc64(ssrc0, 0, literal) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0, literal) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0)) - s1 = st.rsrc64(inst.ssrc1, 0, literal) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0, literal) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0) - d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0) + s0 = st.rsrc64(ssrc0, 0, literal) if inst.canonical_op_bits['s0'] == 64 else (st.rsrc(ssrc0, 0, literal) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0)) + s1 = st.rsrc64(inst.ssrc1, 0, literal) if inst.canonical_op_bits['s1'] == 64 else (st.rsrc(inst.ssrc1, 0, literal) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0) + d0 = st.rsgpr64(sdst) if inst.canonical_op_bits['d'] == 64 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0) literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else inst._literal # Call compiled function with int parameters @@ -267,7 +256,7 @@ def exec_scalar(st: WaveState, inst: Inst): # Apply results (already int values) if sdst is not None and 'D0' in result: - (st.wsgpr64 if inst.dst_regs() == 2 else st.wsgpr)(sdst, result['D0']) + (st.wsgpr64 if inst.canonical_op_bits['d'] == 64 else st.wsgpr)(sdst, result['D0']) if 'SCC' in result: st.scc = result['SCC'] & 1 if 'EXEC' in result: st.exec_mask = result['EXEC'] if 'PC' in result: @@ -286,7 +275,7 @@ def exec_scalar(st: WaveState, inst: Inst): def exec_vopd(st: WaveState, inst, V: VGPRLane, lane: int) -> None: """VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes).""" literal, vdstx = inst._literal, inst.vdstx - vdsty = (inst.vdsty << 1) | ((inst.vdstx.offset & 1) ^ 1) # vdsty is raw int from VDSTYField.decode + vdsty = v[(inst.vdsty << 1) | ((inst.vdstx.offset & 1) ^ 1)] # vdsty is raw int from VDSTYField.decode sx0, sx1, dx, sy0, sy1, dy = st.rsrc(inst.srcx0, lane, literal), V[inst.vsrcx1], V[vdstx], st.rsrc(inst.srcy0, lane, literal), V[inst.vsrcy1], V[vdsty] V[vdstx] = inst._fnx(sx0, sx1, 0, dx, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0'] V[vdsty] = inst._fny(sy0, sy1, 0, dy, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0'] @@ -311,17 +300,18 @@ def exec_ds(st: WaveState, inst, V: VGPRLane, lane: int) -> None: def exec_vop(st: WaveState, inst: Inst, V: VGPRLane, lane: int) -> None: """VOP1/VOP2/VOP3/VOP3SD/VOP3P/VOPC: standard ALU ops.""" + is_dst_16 = inst.canonical_op_bits['d'] == 16 if isinstance(inst, VOP3P): src0, src1, src2, vdst, dst_hi = inst.src0, inst.src1, inst.src2, inst.vdst, False neg, abs_, opsel = inst.neg, 0, inst.opsel elif isinstance(inst, VOP1): src0, src1, src2, vdst = inst.src0, None, None, inst.vdst - neg, abs_, opsel, dst_hi = 0, 0, 0, (inst.vdst.offset & 0x80) != 0 and inst.is_dst_16() - if inst.is_dst_16(): vdst = inst.vdst.offset & 0x7f + neg, abs_, opsel, dst_hi = 0, 0, 0, (inst.vdst.offset & 0x80) != 0 and is_dst_16 + if is_dst_16: vdst = v[inst.vdst.offset & 0x7f] elif isinstance(inst, VOP2): src0, src1, src2, vdst = inst.src0, inst.vsrc1, None, inst.vdst - neg, abs_, opsel, dst_hi = 0, 0, 0, (inst.vdst.offset & 0x80) != 0 and inst.is_dst_16() - if inst.is_dst_16(): vdst = inst.vdst.offset & 0x7f + neg, abs_, opsel, dst_hi = 0, 0, 0, (inst.vdst.offset & 0x80) != 0 and is_dst_16 + if is_dst_16: vdst = v[inst.vdst.offset & 0x7f] elif isinstance(inst, (VOP3, VOP3SD)): src0, src1, src2, vdst = inst.src0, inst.src1, (None if isinstance(inst, VOP3) and inst.op.value < 256 else inst.src2), inst.vdst neg, abs_, opsel, dst_hi = (inst.neg, inst.abs, inst.opsel, False) if isinstance(inst, VOP3) else (0, 0, 0, False) @@ -333,8 +323,8 @@ def exec_vop(st: WaveState, inst: Inst, V: VGPRLane, lane: int) -> None: s0 = _read_src(st, inst, src0, 0, lane, neg, abs_, opsel) s1 = _read_src(st, inst, src1, 1, lane, neg, abs_, opsel) s2 = _read_src(st, inst, src2, 2, lane, neg, abs_, opsel) - if isinstance(inst, VOP2) and inst.is_dst_16(): d0 = _src16(V[vdst], dst_hi) - elif inst.dst_regs() == 2: d0 = V[vdst] | (V[vdst + 1] << 32) + if isinstance(inst, VOP2) and is_dst_16: d0 = _src16(V[vdst], dst_hi) + elif inst.canonical_op_bits['d'] == 64: d0 = V[vdst] | (V[vdst + 1] << 32) else: d0 = V[vdst] if isinstance(inst, VOP3SD) and 'CO_CI' in inst.op_name: vcc_for_fn = st.rsgpr64(inst.src2) @@ -342,7 +332,7 @@ def exec_vop(st: WaveState, inst: Inst, V: VGPRLane, lane: int) -> None: else: vcc_for_fn = st.vcc src0_off = src0.offset if src0 is not None else 0 src0_idx = (src0_off - 256) if src0_off >= 256 else src0_off - vdst_off = vdst.offset if hasattr(vdst, 'offset') else vdst + vdst_off = vdst.offset extra_kwargs = {'opsel': opsel, 'opsel_hi': inst.opsel_hi | (inst.opsel_hi2 << 2)} if isinstance(inst, VOP3P) and 'FMA_MIX' in inst.op_name else {} result = inst._fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, inst._literal, st.vgpr, src0_idx, vdst_off, **extra_kwargs) @@ -359,8 +349,8 @@ def exec_vop(st: WaveState, inst: Inst, V: VGPRLane, lane: int) -> None: st.pend_sgpr_lane(vdst, lane, (result['D0'] >> lane) & 1) if not is_vopc: d0_val = result['D0'] - if inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32 - elif not isinstance(inst, VOP3P) and inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi) + if inst.canonical_op_bits['d'] == 64: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32 + elif not isinstance(inst, VOP3P) and is_dst_16: V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi) else: V[vdst] = d0_val & MASK32 # ═══════════════════════════════════════════════════════════════════════════════ @@ -468,7 +458,7 @@ def exec_workgroup(program: dict[int, Inst], workgroup_id: tuple[int, int, int], n_lanes = min(WAVE_SIZE, total_threads - wave_start) st = WaveState(lds, n_lanes) st.exec_mask = (1 << n_lanes) - 1 - st.wsgpr64(0, args_ptr) # s[0:1] = kernel arguments pointer + st.wsgpr64(s[0:1], args_ptr) # s[0:1] = kernel arguments pointer # COMPUTE_PGM_RSRC2: USER_SGPR_COUNT is where workgroup IDs start, ENABLE_SGPR_WORKGROUP_ID_X/Y/Z control which are passed sgpr_idx = (rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT_SHIFT if rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_X: st.sgpr[sgpr_idx] = workgroup_id[0]; sgpr_idx += 1 diff --git a/extra/assembly/amd/sqtt.py b/extra/assembly/amd/sqtt.py index 33c5106e51..667440d643 100644 --- a/extra/assembly/amd/sqtt.py +++ b/extra/assembly/amd/sqtt.py @@ -292,7 +292,7 @@ def _build_state_table() -> tuple[bytes, dict[int, type[PacketType]]]: for byte_val in range(256): for opcode, pkt_cls in enumerate(PACKET_TYPES): - if (byte_val & pkt_cls.encoding.mask()) == pkt_cls.encoding.default: + if (byte_val & pkt_cls.encoding.mask) == pkt_cls.encoding.default: table[byte_val] = opcode break @@ -310,7 +310,7 @@ _DECODE_INFO: dict[int, tuple] = {} for _opcode, _pkt_cls in OPCODE_TO_CLASS.items(): _delta_field = getattr(_pkt_cls, 'delta', None) _delta_lo = _delta_field.lo if _delta_field else 0 - _delta_mask = _delta_field.mask() if _delta_field else 0 + _delta_mask = _delta_field.mask if _delta_field else 0 _special = 1 if _opcode == _TS_DELTA_OR_MARK_OPCODE else (2 if _opcode == _TS_DELTA_SHORT_OPCODE else 0) _DECODE_INFO[_opcode] = (_pkt_cls, _pkt_cls._size_nibbles, _delta_lo, _delta_mask, _special) diff --git a/extra/assembly/amd/test/test_formats.py b/extra/assembly/amd/test/test_formats.py index 1b0b075c99..334f045f85 100644 --- a/extra/assembly/amd/test/test_formats.py +++ b/extra/assembly/amd/test/test_formats.py @@ -110,33 +110,33 @@ class TestMIMG(unittest.TestCase): """Test MIMG (image) instructions.""" def test_image_load_2d(self): - # image_load v[0:3], v[4:5], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D + # image_load v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D # GFX11: encoding: [0x04,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00] - inst = image_load(vdata=v[0:3], vaddr=v[4:5], srsrc=s[0:7], dmask=0xf, dim=1) # dim=1 is SQ_RSRC_IMG_2D + inst = image_load(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=1) # dim=1 is SQ_RSRC_IMG_2D self.assertEqual(inst.to_bytes(), bytes([0x04,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00])) def test_image_store_2d(self): - # image_store v[0:3], v[4:5], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D + # image_store v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D # GFX11: encoding: [0x04,0x0f,0x18,0xf0,0x04,0x00,0x00,0x00] - inst = image_store(vdata=v[0:3], vaddr=v[4:5], srsrc=s[0:7], dmask=0xf, dim=1) + inst = image_store(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=1) self.assertEqual(inst.to_bytes(), bytes([0x04,0x0f,0x18,0xf0,0x04,0x00,0x00,0x00])) def test_image_load_1d(self): - # image_load v[0:3], v4, s[0:7] dmask:0xf dim:SQ_RSRC_IMG_1D + # image_load v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_1D # GFX11: encoding: [0x00,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00] - inst = image_load(vdata=v[0:3], vaddr=v[4], srsrc=s[0:7], dmask=0xf, dim=0) # dim=0 is SQ_RSRC_IMG_1D + inst = image_load(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=0) # dim=0 is SQ_RSRC_IMG_1D self.assertEqual(inst.to_bytes(), bytes([0x00,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00])) def test_image_sample(self): - # image_sample v[0:3], v[4:5], s[0:7], s[8:11] dmask:0xf dim:SQ_RSRC_IMG_2D + # image_sample v[0:3], v[4:6], s[0:7], s[8:11] dmask:0xf dim:SQ_RSRC_IMG_2D # GFX11: encoding: [0x04,0x0f,0x6c,0xf0,0x04,0x00,0x00,0x08] - inst = image_sample(vdata=v[0:3], vaddr=v[4:5], srsrc=s[0:7], ssamp=s[8:11], dmask=0xf, dim=1) + inst = image_sample(vdata=v[0:3], vaddr=v[4:6], srsrc=s[0:7], ssamp=s[8:11], dmask=0xf, dim=1) self.assertEqual(inst.to_bytes(), bytes([0x04,0x0f,0x6c,0xf0,0x04,0x00,0x00,0x08])) def test_image_load_d16(self): - # image_load v[0:1], v[4:5], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D d16 + # image_load v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D d16 # GFX11: encoding: [0x04,0x0f,0x02,0xf0,0x04,0x00,0x00,0x00] - inst = image_load(vdata=v[0:1], vaddr=v[4:5], srsrc=s[0:7], dmask=0xf, dim=1, d16=1) + inst = image_load(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=1, d16=1) self.assertEqual(inst.to_bytes(), bytes([0x04,0x0f,0x02,0xf0,0x04,0x00,0x00,0x00])) @@ -387,7 +387,7 @@ class TestDetectFormat(unittest.TestCase): self.assertEqual(detect_format(tbuffer_load_format_x(v[0], v[1], s[0:3], s[5], format=22).to_bytes()), MTBUF) def test_detect_mimg(self): - self.assertEqual(detect_format(image_load(v[0:3], v[4:5], s[0:7], dmask=0xf, dim=1).to_bytes()), MIMG) + self.assertEqual(detect_format(image_load(v[0:3], v[4:7], s[0:7], dmask=0xf, dim=1).to_bytes()), MIMG) def test_detect_exp(self): self.assertEqual(detect_format(EXP(en=0xf, target=0, vsrc0=v[0], vsrc1=v[1], vsrc2=v[2], vsrc3=v[3]).to_bytes()), EXP)