diff --git a/extra/assembly/amd/asm.py b/extra/assembly/amd/asm.py index ed92f2fcfb..de2e2da6c8 100644 --- a/extra/assembly/amd/asm.py +++ b/extra/assembly/amd/asm.py @@ -1,4 +1,4 @@ -# RDNA3 assembler and disassembler +# RDNA3/CDNA assembler and disassembler from __future__ import annotations import re from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory @@ -8,6 +8,8 @@ from extra.assembly.amd.autogen.rdna3 import ins from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp) +def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__ + def _matches_encoding(word: int, cls: type[Inst]) -> bool: """Check if word matches the encoding pattern of an instruction class.""" if cls._encoding is None: return False @@ -81,10 +83,11 @@ def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v) def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c) def _fmt_bits(label: str, val: int, count: int) -> str: return f"{label}:[{','.join(str((val >> i) & 1) for i in range(count))}]" -def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any_hi: bool) -> str: +def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool) -> str: """Format VOP3 source operand with modifiers.""" - if n > 1: s = _fmt_src(v, n) - elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else (f"v{v - 256}.l" if any_hi else inst.lit(v)) + if v == 255: s = inst.lit(v) # literal constant takes priority + elif n > 1: s = _fmt_src(v, n) + elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else f"v{v - 256}.l" else: s = inst.lit(v) if abs_: s = f"|{s}|" return f"-{s}" if neg else s @@ -92,9 +95,11 @@ def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str: """Format op_sel modifier string.""" if not need: return "" - if is16_d and (opsel & 8): return f" op_sel:[1,1,1{',1' if n == 3 else ''}]" - if n == 3: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{(opsel >> 3) & 1}]" - return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1}]" + # For VOP1 (n=1): op_sel:[src0_hi, dst_hi], for VOP2 (n=2): op_sel:[src0_hi, src1_hi, dst_hi], for VOP3 (n=3): op_sel:[src0_hi, src1_hi, src2_hi, dst_hi] + dst_hi = (opsel >> 3) & 1 + if n == 1: return f" op_sel:[{opsel & 1},{dst_hi}]" + if n == 2: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{dst_hi}]" + return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{dst_hi}]" # ═══════════════════════════════════════════════════════════════════════════════ # DISASSEMBLER @@ -108,30 +113,41 @@ def _disasm_vop1(inst: VOP1) -> str: parts = name.split('_') is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name) dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}" - src = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0) + src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0) return f"{name}_e32 {dst}, {src}" def _disasm_vop2(inst: VOP2) -> str: - name = inst.op_name.lower() - suf = "" if inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32" + name, cdna = inst.op_name.lower(), _is_cdna(inst) + suf = "" if not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32" + lit = getattr(inst, '_literal', None) + is16 = not cdna and inst.is_16bit() # fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1 - if inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}" - if inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}" - if inst.is_16bit(): return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}" - return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if inst.op == VOP2Op.V_CNDMASK_B32 else "") + if 'fmaak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)): + if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}, 0x{lit:x}" + return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{lit:x}" + if 'fmamk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)): + if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, 0x{lit:x}, {_fmt_v16(inst.vsrc1, 0, 128)}" + return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{lit:x}, v{inst.vsrc1}" + if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}" + vcc = "vcc" if cdna else "vcc_lo" + return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (f", {vcc}" if name == 'v_cndmask_b32' else "") def _disasm_vopc(inst: VOPC) -> str: - name = inst.op_name.lower() - s0 = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0) + name, cdna = inst.op_name.lower(), _is_cdna(inst) + if cdna: + s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) + return f"{name}_e32 {s0}, v{inst.vsrc1}" if inst.op.value >= 128 else f"{name}_e32 vcc, {s0}, v{inst.vsrc1}" + s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0) s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}" return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}" -NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV, - SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE} +NO_ARG_SOPP = {SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV, + SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE, SOPPOp.S_TTRACEDATA} def _disasm_sopp(inst: SOPP) -> str: name = inst.op_name.lower() if inst.op in NO_ARG_SOPP: return name + if inst.op == SOPPOp.S_ENDPGM: return name if inst.simm16 == 0 else f"{name} {inst.simm16}" if inst.op == SOPPOp.S_WAITCNT: vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""] @@ -154,12 +170,13 @@ def _disasm_smem(inst: SMEM) -> str: return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc")) def _disasm_flat(inst: FLAT) -> str: - name = inst.op_name.lower() + name, cdna = inst.op_name.lower(), _is_cdna(inst) seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat' instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}" off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192) w = inst.dst_regs() * (2 if 'cmpswap' in name else 1) - mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}" + if cdna: mods = f"{f' offset:{off_val}' if off_val else ''}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if inst.sc1 else ''}" + else: mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}" # saddr if seg == 'flat' or inst.saddr == 0x7F: saddr_s = "" elif inst.saddr == 124: saddr_s = ", off" @@ -172,8 +189,9 @@ def _disasm_flat(inst: FLAT) -> str: # addr width addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2) data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w) + glc_or_sc0 = inst.sc0 if cdna else inst.glc if 'atomic' in name: - return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if inst.glc else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" + return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}" return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}" @@ -211,7 +229,9 @@ def _disasm_vop3(inst: VOP3) -> str: # VOP3SD (shared encoding) if isinstance(op, VOP3SDOp): sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs - def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s + def src(v, neg, n): + s = inst.lit(v) if v == 255 else (_fmt_src(v, n) if n > 1 else inst.lit(v)) + return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s) s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2)) dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}" srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}" @@ -225,18 +245,18 @@ def _disasm_vop3(inst: VOP3) -> str: is16_s2 = is16_s elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True elif 'pack_b32' in name: is16_s = is16_s2 = True + elif 'sat_pk' in name: is16_d = True # v_sat_pk_* writes to 16-bit dest but takes 32-bit src else: is16_d = is16_s = is16_s2 = inst.is_16bit() - any_hi = inst.opsel != 0 - s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s, any_hi) - s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s, any_hi) - s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2, any_hi) + s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s) + s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s) + s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2) # Destination dn = inst.dst_regs() if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1) elif dn > 1: dst = _vreg(inst.vdst, dn) - elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}" + elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" else: dst = f"v{inst.vdst}" cl, om = " clamp" if inst.clmp else "", _omod(inst.omod) @@ -244,7 +264,7 @@ def _disasm_vop3(inst: VOP3) -> str: need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s) if inst.op < 256: # VOPC - return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}" + return f"{name}_e64 {s0}, {s1}{cl}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}{cl}" if inst.op < 384: # VOP2 n = inst.num_srcs() os = _opsel_str(inst.opsel, n, need_opsel, is16_d) @@ -258,7 +278,9 @@ def _disasm_vop3(inst: VOP3) -> str: def _disasm_vop3sd(inst: VOP3SD) -> str: name = inst.op_name.lower() - def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s + def src(v, neg, n): + s = inst.lit(v) if v == 255 else (_fmt_src(v, n) if n > 1 else inst.lit(v)) + return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s) s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2)) dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}" srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}" @@ -268,16 +290,22 @@ def _disasm_vop3sd(inst: VOP3SD) -> str: def _disasm_vopd(inst: VOPD) -> str: lit = inst._literal or inst.literal vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), VOPDOp(inst.opx).name.lower(), VOPDOp(inst.opy).name.lower() - def half(n, vd, s0, vs1): return f"{n} v{vd}, {inst.lit(s0)}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}" if 'mov' in n else f"{n} v{vd}, {inst.lit(s0)}, v{vs1}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}" + def half(n, vd, s0, vs1): + if 'mov' in n: return f"{n} v{vd}, {inst.lit(s0)}" + # fmamk: dst = src0 * K + vsrc1, fmaak: dst = src0 * vsrc1 + K + if 'fmamk' in n and lit: return f"{n} v{vd}, {inst.lit(s0)}, 0x{lit:x}, v{vs1}" + if 'fmaak' in n and lit: return f"{n} v{vd}, {inst.lit(s0)}, v{vs1}, 0x{lit:x}" + return f"{n} v{vd}, {inst.lit(s0)}, v{vs1}" return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}" def _disasm_vop3p(inst: VOP3P) -> str: name = inst.op_name.lower() is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name + def get_src(v, sc): return inst.lit(v) if v == 255 else _fmt_src(v, sc) if is_wmma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8 - src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8) - else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}" + src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 8), _vreg(inst.vdst, 8) + else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}" opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2) if is_fma_mix: def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s) @@ -289,15 +317,17 @@ def _disasm_vop3p(inst: VOP3P) -> str: return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}" def _disasm_buf(inst: MUBUF | MTBUF) -> str: - name = inst.op_name.lower() - if inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name + name, cdna = inst.op_name.lower(), _is_cdna(inst) + if cdna and name in ('buffer_wbl2', 'buffer_inv'): return name + if not cdna and inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \ ((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \ {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1) - if inst.tfe: w += 1 + if hasattr(inst, 'tfe') and inst.tfe: w += 1 vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off" srsrc = _sreg_or_ttmp(inst.srsrc*4, 4) - mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c] + if cdna: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"sc0"),(inst.nt,"nt"),(inst.sc1,"sc1")] if c] + else: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c] return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}" def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int: @@ -348,25 +378,36 @@ def _disasm_mimg(inst: MIMG) -> str: def _disasm_sop1(inst: SOP1) -> str: op, name = inst.op, inst.op_name.lower() - if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}" - if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}" - if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}" - if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})" - return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}" + src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0)) + if not _is_cdna(inst): + if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}" + if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {src}" + if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {src}" + if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})" + return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {src}" def _disasm_sop2(inst: SOP2) -> str: return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}" def _disasm_sopc(inst: SOPC) -> str: - return f"{inst.op_name.lower()} {_fmt_src(inst.ssrc0, inst.src_regs(0))}, {_fmt_src(inst.ssrc1, inst.src_regs(1))}" + s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0)) + s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1)) + return f"{inst.op_name.lower()} {s0}, {s1}" def _disasm_sopk(inst: SOPK) -> str: - op, name = inst.op, inst.op_name.lower() - if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}" - if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32): + op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst) + # s_setreg_imm32_b32 has a 32-bit literal value + if name == 's_setreg_imm32_b32' or (not cdna and op == SOPKOp.S_SETREG_IMM32_B32): hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1 hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})" - return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if op == SOPKOp.S_SETREG_B32 else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}" + return f"{name} {hs}, 0x{inst._literal:x}" + if not cdna and op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}" + if (not cdna and op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32)) or (cdna and name in ('s_setreg_b32', 's_getreg_b32')): + hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1 + hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})" + return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}" + if not cdna and op in (SOPKOp.S_SUBVECTOR_LOOP_BEGIN, SOPKOp.S_SUBVECTOR_LOOP_END): + return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}" return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}" def _disasm_vinterp(inst: VINTERP) -> str: @@ -388,7 +429,8 @@ SPEC_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), ' FLOATS = {str(k): k for k in FLOAT_ENC} # Valid float literal strings: '0.5', '-0.5', '1.0', etc. REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp} SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512', - 's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512'} + 's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512', + 's_atc_probe', 's_atc_probe_buffer'} SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', 'm0': 'M0', 'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', 'src_scc': 'SCC'} @@ -579,3 +621,69 @@ def asm(text: str) -> Inst: except NameError: if m := re.match(r'^(v_\w+)(\(.*\))$', dsl): return eval(f"{m.group(1)}_e32{m.group(2)}", ns) raise + +# ═══════════════════════════════════════════════════════════════════════════════ +# CDNA DISASSEMBLER SUPPORT +# ═══════════════════════════════════════════════════════════════════════════════ + +try: + from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P, + SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS, + FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op) + + def _cdna_src(inst, v, neg, abs_=0, n=1): + s = inst.lit(v) if v == 255 else _fmt_src(v, n) + if abs_: s = f"|{s}|" + return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s) + + def _disasm_vop3a(inst) -> str: + name, n, cl, om = inst.op_name.lower(), inst.num_srcs(), " clamp" if inst.clmp else "", _omod(inst.omod) + s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.src_regs(0)), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.src_regs(1)), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.src_regs(2)) + dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}" + if inst.op.value < 256: return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}" + suf = "_e64" if inst.op.value < 512 else "" + return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else (f"{name}{suf}" if name == 'v_nop' else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}" if n == 2 else f"{name}{suf} {dst}, {s0}{cl}{om}") + + def _disasm_vop3b(inst) -> str: + name, n = inst.op_name.lower(), inst.num_srcs() + s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1), _cdna_src(inst, inst.src1, inst.neg&2), _cdna_src(inst, inst.src2, inst.neg&4) + dst, suf = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}", "_e64" if 'co_' in name else "" + cl, om = " clamp" if inst.clmp else "", _omod(inst.omod) + return f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{cl}{om}" + + def _disasm_cdna_vop3p(inst) -> str: + name, n, is_mfma = inst.op_name.lower(), inst.num_srcs(), 'mfma' in inst.op_name.lower() or 'smfmac' in inst.op_name.lower() + get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc) + if is_mfma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4; src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16), _vreg(inst.vdst, 16) + else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}" + opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2) + mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \ + ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else []) + return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}" + + _SEL = {0: 'BYTE_0', 1: 'BYTE_1', 2: 'BYTE_2', 3: 'BYTE_3', 4: 'WORD_0', 5: 'WORD_1', 6: 'DWORD'} + _UNUSED = {0: 'UNUSED_PAD', 1: 'UNUSED_SEXT', 2: 'UNUSED_PRESERVE'} + _DPP = {0x130: "wave_shl:1", 0x134: "wave_rol:1", 0x138: "wave_shr:1", 0x13c: "wave_ror:1", 0x140: "row_mirror", 0x141: "row_half_mirror", 0x142: "row_bcast:15", 0x143: "row_bcast:31"} + + def _disasm_sdwa(inst) -> str: + try: name = CDNA_VOP1Op(inst.vop_op).name.lower() + except ValueError: name = f"vop1_op_{inst.vop_op}" + src = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0) + mods = [f"dst_sel:{_SEL[inst.dst_sel]}" for _ in [1] if inst.dst_sel != 6] + [f"dst_unused:{_UNUSED[inst.dst_u]}" for _ in [1] if inst.dst_u] + [f"src0_sel:{_SEL[inst.src0_sel]}" for _ in [1] if inst.src0_sel != 6] + return f"{name}_sdwa v{inst.vdst}, {src}" + (" " + " ".join(mods) if mods else "") + + def _disasm_dpp(inst) -> str: + try: name = CDNA_VOP1Op(inst.vop_op).name.lower() + except ValueError: name = f"vop1_op_{inst.vop_op}" + src, ctrl = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0), inst.dpp_ctrl + dpp = f"quad_perm:[{ctrl&3},{(ctrl>>2)&3},{(ctrl>>4)&3},{(ctrl>>6)&3}]" if ctrl < 0x100 else f"row_shl:{ctrl&0xf}" if ctrl < 0x110 else f"row_shr:{ctrl&0xf}" if ctrl < 0x120 else f"row_ror:{ctrl&0xf}" if ctrl < 0x130 else _DPP.get(ctrl, f"dpp_ctrl:0x{ctrl:x}") + mods = [dpp] + [f"row_mask:0x{inst.row_mask:x}" for _ in [1] if inst.row_mask != 0xf] + [f"bank_mask:0x{inst.bank_mask:x}" for _ in [1] if inst.bank_mask != 0xf] + ["bound_ctrl:1" for _ in [1] if inst.bound_ctrl] + return f"{name}_dpp v{inst.vdst}, {src} " + " ".join(mods) + + # Register CDNA handlers - shared formats use merged disassemblers, CDNA-only formats use dedicated ones + DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc, + CDNA_SOP1: _disasm_sop1, CDNA_SOP2: _disasm_sop2, CDNA_SOPC: _disasm_sopc, CDNA_SOPK: _disasm_sopk, CDNA_SOPP: _disasm_sopp, + CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat, CDNA_MUBUF: _disasm_buf, CDNA_MTBUF: _disasm_buf, + VOP3A: _disasm_vop3a, VOP3B: _disasm_vop3b, CDNA_VOP3P: _disasm_cdna_vop3p, SDWA: _disasm_sdwa, DPP: _disasm_dpp}) +except ImportError: + pass diff --git a/extra/assembly/amd/autogen/cdna/ins.py b/extra/assembly/amd/autogen/cdna/ins.py index 671a4a81ae..0451136caf 100644 --- a/extra/assembly/amd/autogen/cdna/ins.py +++ b/extra/assembly/amd/autogen/cdna/ins.py @@ -7,20 +7,18 @@ import functools # instruction formats class DPP(Inst64): - encoding = bits[31:26] == 0b110110 - src1_sel = bits[58:56] - src1_sext = bits[59] - src1_neg = bits[60] - src1_abs = bits[61] - s1 = bits[63] - offset0 = bits[7:0] - offset1 = bits[15:8] - op = bits[24:17] - acc = bits[25] - addr:VGPRField = bits[39:32] - data0:VGPRField = bits[47:40] - data1:VGPRField = bits[55:48] - vdst:VGPRField = bits[63:56] + encoding = bits[8:0] == 0b11111010 + vop_op = bits[16:9] + vdst:VGPRField = bits[24:17] + vop2_op = bits[31:25] + src0:Src = bits[39:32] + dpp_ctrl = bits[48:40] + bound_ctrl = bits[51] + src0_neg = bits[52] + src0_abs = bits[53] + src1_neg = bits[54] + src1_abs = bits[55] + bank_mask = bits[59:56] row_mask = bits[63:60] class DS(Inst64): @@ -82,6 +80,10 @@ class MUBUF(Inst64): acc = bits[55] class SDWA(Inst64): + encoding = bits[8:0] == 0b11111001 + vop_op = bits[16:9] + vdst:VGPRField = bits[24:17] + vop2_op = bits[31:25] src0:Src = bits[39:32] dst_sel = bits[42:40] dst_u = bits[44:43] @@ -97,9 +99,6 @@ class SDWA(Inst64): src1_neg = bits[60] src1_abs = bits[61] s1 = bits[63] - sdst:SGPRField = bits[46:40] - sd = bits[47] - row_mask = bits[63:60] class SDWAB(Inst64): src0:Src = bits[39:32] diff --git a/extra/assembly/amd/autogen/rdna3/enum.py b/extra/assembly/amd/autogen/rdna3/enum.py index fc56dccda7..90bb9dff1c 100644 --- a/extra/assembly/amd/autogen/rdna3/enum.py +++ b/extra/assembly/amd/autogen/rdna3/enum.py @@ -488,6 +488,8 @@ class SMEMOp(IntEnum): S_BUFFER_LOAD_B512 = 12 S_GL1_INV = 32 S_DCACHE_INV = 33 + S_ATC_PROBE = 34 + S_ATC_PROBE_BUFFER = 35 class SOP1Op(IntEnum): S_MOV_B32 = 0 @@ -710,6 +712,8 @@ class SOPKOp(IntEnum): S_SETREG_B32 = 18 S_SETREG_IMM32_B32 = 19 S_CALL_B64 = 20 + S_SUBVECTOR_LOOP_BEGIN = 22 + S_SUBVECTOR_LOOP_END = 23 S_WAITCNT_VSCNT = 24 S_WAITCNT_VMCNT = 25 S_WAITCNT_EXPCNT = 26 @@ -751,6 +755,8 @@ class SOPPOp(IntEnum): S_SENDMSGHALT = 55 S_INCPERFLEVEL = 56 S_DECPERFLEVEL = 57 + S_TTRACEDATA = 58 + S_TTRACEDATA_IMM = 59 S_ICACHE_INV = 60 S_BARRIER = 61 diff --git a/extra/assembly/amd/autogen/rdna3/ins.py b/extra/assembly/amd/autogen/rdna3/ins.py index 1dd7c9f893..0c486bb4c7 100644 --- a/extra/assembly/amd/autogen/rdna3/ins.py +++ b/extra/assembly/amd/autogen/rdna3/ins.py @@ -692,6 +692,8 @@ s_buffer_load_b256 = functools.partial(SMEM, SMEMOp.S_BUFFER_LOAD_B256) s_buffer_load_b512 = functools.partial(SMEM, SMEMOp.S_BUFFER_LOAD_B512) s_gl1_inv = functools.partial(SMEM, SMEMOp.S_GL1_INV) s_dcache_inv = functools.partial(SMEM, SMEMOp.S_DCACHE_INV) +s_atc_probe = functools.partial(SMEM, SMEMOp.S_ATC_PROBE) +s_atc_probe_buffer = functools.partial(SMEM, SMEMOp.S_ATC_PROBE_BUFFER) s_mov_b32 = functools.partial(SOP1, SOP1Op.S_MOV_B32) s_mov_b64 = functools.partial(SOP1, SOP1Op.S_MOV_B64) s_cmov_b32 = functools.partial(SOP1, SOP1Op.S_CMOV_B32) @@ -906,6 +908,8 @@ s_getreg_b32 = functools.partial(SOPK, SOPKOp.S_GETREG_B32) s_setreg_b32 = functools.partial(SOPK, SOPKOp.S_SETREG_B32) s_setreg_imm32_b32 = functools.partial(SOPK, SOPKOp.S_SETREG_IMM32_B32) s_call_b64 = functools.partial(SOPK, SOPKOp.S_CALL_B64) +s_subvector_loop_begin = functools.partial(SOPK, SOPKOp.S_SUBVECTOR_LOOP_BEGIN) +s_subvector_loop_end = functools.partial(SOPK, SOPKOp.S_SUBVECTOR_LOOP_END) s_waitcnt_vscnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_VSCNT) s_waitcnt_vmcnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_VMCNT) s_waitcnt_expcnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_EXPCNT) @@ -945,6 +949,8 @@ s_sendmsg = functools.partial(SOPP, SOPPOp.S_SENDMSG) s_sendmsghalt = functools.partial(SOPP, SOPPOp.S_SENDMSGHALT) s_incperflevel = functools.partial(SOPP, SOPPOp.S_INCPERFLEVEL) s_decperflevel = functools.partial(SOPP, SOPPOp.S_DECPERFLEVEL) +s_ttracedata = functools.partial(SOPP, SOPPOp.S_TTRACEDATA) +s_ttracedata_imm = functools.partial(SOPP, SOPPOp.S_TTRACEDATA_IMM) s_icache_inv = functools.partial(SOPP, SOPPOp.S_ICACHE_INV) s_barrier = functools.partial(SOPP, SOPPOp.S_BARRIER) v_interp_p10_f32 = functools.partial(VINTERP, VINTERPOp.V_INTERP_P10_F32) diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index 0e34374fd8..f374f7116d 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -458,10 +458,16 @@ class Inst: @classmethod def from_bytes(cls, data: bytes): + import typing inst = cls.from_int(int.from_bytes(data[:cls._size()], 'little')) op_val = inst._values.get('op', 0) - has_literal = cls.__name__ == 'VOP2' and op_val in (44, 45, 55, 56) - has_literal = has_literal or (cls.__name__ == 'SOP2' and op_val in (69, 70)) + # Check for instructions that always have a literal constant (FMAMK/FMAAK/MADMK/MADAK, SETREG_IMM32) + op_name = '' + if cls.__name__ in ('VOP2', 'SOP2', 'SOPK') and 'op' in (hints := typing.get_type_hints(cls, include_extras=True)): + if typing.get_origin(hints['op']) is typing.Annotated: + try: op_name = typing.get_args(hints['op'])[1](op_val).name + except (ValueError, TypeError): pass + has_literal = any(x in op_name for x in ('FMAMK', 'FMAAK', 'MADMK', 'MADAK', 'SETREG_IMM32')) # VOPD fmaak/fmamk always have a literal (opx/opy value 1 or 2) opx, opy = inst._values.get('opx', 0), inst._values.get('opy', 0) has_literal = has_literal or (cls.__name__ == 'VOPD' and (opx in (1, 2) or opy in (1, 2))) @@ -475,7 +481,7 @@ class Inst: lit32 = int.from_bytes(data[cls._size():cls._size()+4], 'little') # Find which source has literal (255) and check its register count lit_src_is_64 = False - for n, idx in [('src0', 0), ('src1', 1), ('src2', 2)]: + for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]: if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255: lit_src_is_64 = inst.src_regs(idx) == 2 break @@ -495,7 +501,12 @@ class Inst: return unwrap(self._values.get(name, 0)) def lit(self, v: int, neg: bool = False) -> str: - s = f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v) + if v == 255 and self._literal is not None: + # For 64-bit sources, literal is stored shifted - extract the 32-bit value + lit32 = (self._literal >> 32) if self._literal > 0xffffffff else self._literal + s = f"0x{lit32:x}" + else: + s = decode_src(v) return f"-{s}" if neg else s def __eq__(self, other): @@ -528,6 +539,10 @@ class Inst: elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val) else: self.op = VOP3Op(val) except ValueError: self.op = val + # Prefer BitField marker (class-specific enum) over _enum_map (generic RDNA3 enums) + elif 'op' in self._fields and (marker := self._fields['op'].marker) and issubclass(marker, IntEnum): + try: self.op = marker(val) + except ValueError: self.op = val elif cls_name in self._enum_map: try: self.op = self._enum_map[cls_name](val) except ValueError: self.op = val diff --git a/extra/assembly/amd/pdf.py b/extra/assembly/amd/pdf.py index 7079c2b07b..1728ce578b 100644 --- a/extra/assembly/amd/pdf.py +++ b/extra/assembly/amd/pdf.py @@ -194,13 +194,30 @@ def _parse_single_pdf(url: str): if fmt_name in formats: formats[fmt_name] = [(n, h, 14 if n == 'OP' else l, e, t) for n, h, l, e, t in formats[fmt_name]] if doc_name in ('RDNA3', 'RDNA3.5'): - if 'SOPPOp' in enums: assert 8 not in enums['SOPPOp']; enums['SOPPOp'][8] = 'S_WAITCNT_DEPCTR' + if 'SOPPOp' in enums: + for k, v in {8: 'S_WAITCNT_DEPCTR', 58: 'S_TTRACEDATA', 59: 'S_TTRACEDATA_IMM'}.items(): + assert k not in enums['SOPPOp']; enums['SOPPOp'][k] = v + if 'SOPKOp' in enums: + for k, v in {22: 'S_SUBVECTOR_LOOP_BEGIN', 23: 'S_SUBVECTOR_LOOP_END'}.items(): + assert k not in enums['SOPKOp']; enums['SOPKOp'][k] = v + if 'SMEMOp' in enums: + for k, v in {34: 'S_ATC_PROBE', 35: 'S_ATC_PROBE_BUFFER'}.items(): + assert k not in enums['SMEMOp']; enums['SMEMOp'][k] = v if 'DSOp' in enums: for k, v in {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V', 27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'}.items(): assert k not in enums['DSOp']; enums['DSOp'][k] = v if 'FLATOp' in enums: for k, v in {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}.items(): assert k not in enums['FLATOp']; enums['FLATOp'][k] = v + # CDNA SDWA/DPP: PDF only has modifier fields, need VOP1/VOP2 overlay for correct encoding + if is_cdna: + if 'SDWA' in formats: + formats['SDWA'] = [('ENCODING', 8, 0, 0xf9, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None)] + \ + [f for f in formats['SDWA'] if f[0] not in ('ENCODING', 'SDST', 'SD', 'ROW_MASK')] + if 'DPP' in formats: + formats['DPP'] = [('ENCODING', 8, 0, 0xfa, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None), + ('SRC0', 39, 32, None, 'Src'), ('DPP_CTRL', 48, 40, None, None), ('BOUND_CTRL', 51, 51, None, None), ('SRC0_NEG', 52, 52, None, None), ('SRC0_ABS', 53, 53, None, None), + ('SRC1_NEG', 54, 54, None, None), ('SRC1_ABS', 55, 55, None, None), ('BANK_MASK', 59, 56, None, None), ('ROW_MASK', 63, 60, None, None)] # Extract pseudocode for instructions all_text = '\n'.join(pdf.text(i) for i in range(instr_start, instr_end)) diff --git a/extra/assembly/amd/test/test_llvm.py b/extra/assembly/amd/test/test_llvm.py index ca6010fdcf..5bd7779c9c 100644 --- a/extra/assembly/amd/test/test_llvm.py +++ b/extra/assembly/amd/test/test_llvm.py @@ -131,28 +131,19 @@ def _make_asm_test(name): def _make_disasm_test(name): def test(self): - _, fmt_cls, op_enum = LLVM_TEST_FILES[name] + _, base_fmt_cls, base_op_enum = LLVM_TEST_FILES[name] # VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions) vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770} is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx') - undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}} # First pass: decode all instructions and collect disasm strings to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error) - skipped = 0 for asm_text, data in self.tests.get(name, []): - if len(data) > fmt_cls._size(): continue - temp_inst = fmt_cls.from_bytes(data) - temp_op = temp_inst._values.get('op', 0) - temp_op = temp_op.val if hasattr(temp_op, 'val') else temp_op - if temp_op in undocumented.get(name, set()): skipped += 1; continue - if name == 'sopp': - simm16 = temp_inst._values.get('simm16', 0) - simm16 = simm16.val if hasattr(simm16, 'val') else simm16 - sopp_no_imm = {48, 54, 53, 55, 60, 61, 62} - if temp_op in sopp_no_imm and simm16 != 0: skipped += 1; continue + # Detect VOP3 promotions in VOP1/VOP2/VOPC tests: VOP3 has bits [31:26]=0b110101 in first dword + is_vop3_enc = name in ('vop1', 'vop2', 'vopc', 'vopcx') and len(data) >= 4 and (data[3] >> 2) == 0x35 + fmt_cls, op_enum = (VOP3, VOP3Op) if is_vop3_enc else (base_fmt_cls, base_op_enum) try: - if fmt_cls.__name__ in ('VOP3', 'VOP3SD'): + if base_fmt_cls.__name__ in ('VOP3', 'VOP3SD'): temp = VOP3.from_bytes(data) op_val = temp._values.get('op', 0) op_val = op_val.val if hasattr(op_val, 'val') else op_val @@ -188,7 +179,7 @@ def _make_disasm_test(name): if llvm_bytes is not None and llvm_bytes == data: passed += 1 elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}") - print(f"{name.upper()} disasm: {passed} passed, {failed} failed" + (f", {skipped} skipped" if skipped else "")) + print(f"{name.upper()} disasm: {passed} passed, {failed} failed") if failures[:10]: print(" " + "\n ".join(failures[:10])) self.assertEqual(failed, 0) return test diff --git a/extra/assembly/amd/test/test_llvm_cdna.py b/extra/assembly/amd/test/test_llvm_cdna.py new file mode 100644 index 0000000000..95bef12e13 --- /dev/null +++ b/extra/assembly/amd/test/test_llvm_cdna.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +"""Test CDNA assembler/disassembler against LLVM test vectors.""" +import unittest, re, subprocess +from tinygrad.helpers import fetch +from extra.assembly.amd.autogen.cdna.ins import * +from extra.assembly.amd.asm import disasm +from extra.assembly.amd.test.helpers import get_llvm_mc + +LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU" + +def parse_llvm_tests(text: str, mnemonic_filter: str = None, size_filter: int = None) -> list[tuple[str, bytes]]: + """Parse LLVM test format into (asm, expected_bytes) pairs.""" + tests, lines = [], text.split('\n') + for i, line in enumerate(lines): + line = line.strip() + if not line or line.startswith(('//', '.', ';')): continue + asm_text = line.split('//')[0].strip() + if not asm_text or (mnemonic_filter and not asm_text.startswith(mnemonic_filter)): continue + for j in list(range(max(0, i - 3), i)) + list(range(i, min(i + 3, len(lines)))): + if m := re.search(r'(?:VI9|GFX9|CHECK)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]): + hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '') + elif m := re.search(r'CHECK[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]): + hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '') + else: continue + try: + data = bytes.fromhex(hex_bytes) + if size_filter is None or len(data) == size_filter: tests.append((asm_text, data)) + except ValueError: pass + break + return tests + +# Use gfx9 tests for compatible scalar/vector formats and gfx90a/gfx942 tests for CDNA-specific instructions +# Format: (filename, format_class, op_enum, mcpu, mnemonic_filter, size_filter) +CDNA_TEST_FILES = { + # Scalar ALU - encoding is stable across GFX9/CDNA + 'sop1': ('gfx9_asm_sop1.s', SOP1, SOP1Op, 'gfx940', None, None), + 'sop2': ('gfx9_asm_sop2.s', SOP2, SOP2Op, 'gfx940', None, None), + 'sopp': ('gfx9_asm_sopp.s', SOPP, SOPPOp, 'gfx940', None, None), + 'sopp_gfx9': ('sopp-gfx9.s', SOPP, SOPPOp, 'gfx940', None, None), + 'sopk': ('gfx9_asm_sopk.s', SOPK, SOPKOp, 'gfx940', None, None), + 'sopc': ('gfx9_asm_sopc.s', SOPC, SOPCOp, 'gfx940', None, None), + # Vector ALU - encoding is mostly stable + 'vop1': ('gfx9_asm_vop1.s', VOP1, VOP1Op, 'gfx940', None, None), + 'vop1_gfx9': ('vop1-gfx9.s', VOP1, VOP1Op, 'gfx940', None, None), + 'vop2': ('gfx9_asm_vop2.s', VOP2, VOP2Op, 'gfx940', None, None), + 'vopc': ('gfx9_asm_vopc.s', VOPC, VOPCOp, 'gfx940', None, None), + 'vop3p': ('gfx9_asm_vop3p.s', VOP3P, VOP3POp, 'gfx940', None, None), + 'vop3_gfx9': ('vop3-gfx9.s', VOP3A, VOP3AOp, 'gfx940', None, 8), # Only 64-bit VOP3 instructions + # Memory instructions + 'ds': ('gfx9_asm_ds.s', DS, DSOp, 'gfx940', None, None), + 'ds_gfx9': ('ds-gfx9.s', DS, DSOp, 'gfx940', None, None), + # CDNA memory instructions (gfx90a has correct FLAT/MUBUF encodings with acc registers) + 'flat_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'flat_', None), + 'global_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'global_', None), + 'mubuf_gfx90a': ('gfx90a_ldst_acc.s', MUBUF, MUBUFOp, 'gfx90a', 'buffer_', None), + 'mubuf_gfx9': ('mubuf-gfx9.s', MUBUF, MUBUFOp, 'gfx940', None, None), + 'scratch_gfx942': ('flat-scratch-gfx942.s', FLAT, FLATOp, 'gfx942', 'scratch_', None), + # CDNA-specific: MFMA/MAI instructions + 'mai': ('mai-gfx942.s', VOP3P, VOP3POp, 'gfx942', None, None), + # SDWA and DPP format tests for VOP1 (VOP2 has different bit layout, tested separately) + 'sdwa_vop1': ('gfx9_asm_vop1.s', SDWA, VOP1Op, 'gfx940', None, None), + 'dpp_vop1': ('gfx9_asm_vop1.s', DPP, VOP1Op, 'gfx940', None, None), +} + +class TestLLVMCDNA(unittest.TestCase): + """Test CDNA instruction format decode/encode roundtrip and disassembly.""" + tests: dict[str, list[tuple[str, bytes]]] = {} + + @classmethod + def setUpClass(cls): + for name, (filename, _, _, _, mnemonic_filter, size_filter) in CDNA_TEST_FILES.items(): + try: + data = fetch(f"{LLVM_BASE}/{filename}").read_bytes() + cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'), mnemonic_filter, size_filter) + except Exception as e: + print(f"Warning: couldn't fetch {filename}: {e}") + cls.tests[name] = [] + +def _get_val(v): return v.val if hasattr(v, 'val') else v + +def _filter_and_decode(tests, fmt_cls, op_enum): + """Filter tests and decode instructions, yielding (asm_text, data, decoded, error).""" + fn, is_sdwa, is_dpp = fmt_cls.__name__, fmt_cls.__name__ == 'SDWA', fmt_cls.__name__ == 'DPP' + for asm_text, data in tests: + has_lit = False + # SDWA/DPP format tests: only accept matching 8-byte instructions + if is_sdwa: + if len(data) != 8 or data[0] != 0xf9: continue + elif is_dpp: + if len(data) != 8 or data[0] != 0xfa: continue + elif fmt_cls._size() == 4 and len(data) == 8: + if data[0] in (0xf9, 0xfa): continue # Skip SDWA/DPP (tested separately) + has_lit = data[0] == 255 or (len(data) >= 2 and data[1] == 255 and fn in ('SOP2', 'SOPC')) + if fn == 'SOPK': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 23) & 0x1f) == 20 + if fn == 'VOP2': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 25) & 0x3f) in (23, 24, 36, 37) + if not has_lit: continue + if len(data) > fmt_cls._size() + (4 if has_lit else 0): continue + try: + decoded = fmt_cls.from_bytes(data) + # For SDWA/DPP, opcode location depends on VOP1 vs VOP2 + if is_sdwa or is_dpp: + vop2_op = _get_val(decoded._values.get('vop2_op', 0)) + op_val = _get_val(decoded._values.get('vop_op', 0)) if vop2_op == 0x3f else vop2_op + else: + op_val = _get_val(decoded._values.get('op', 0)) + try: op_enum(op_val) + except ValueError: continue + yield asm_text, data, decoded, None + except Exception as e: + yield asm_text, data, None, str(e) + +def _make_roundtrip_test(name): + def test(self): + _, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name] + passed, failed, failures = 0, 0, [] + for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum): + if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue + if decoded.to_bytes()[:len(data)] == data: passed += 1 + else: failed += 1; failures.append(f"'{asm_text}': orig={data.hex()} reenc={decoded.to_bytes()[:len(data)].hex()}") + print(f"CDNA {name.upper()} roundtrip: {passed} passed, {failed} failed") + if failures[:5]: print(" " + "\n ".join(failures[:5])) + self.assertEqual(failed, 0) + return test + +def _make_disasm_test(name): + def test(self): + _, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name] + passed, failed, failures = 0, 0, [] + for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum): + if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue + if decoded.to_bytes()[:len(data)] != data: failed += 1; failures.append(f"'{asm_text}': roundtrip failed"); continue + if not (disasm_text := disasm(decoded)) or not disasm_text.strip(): failed += 1; failures.append(f"'{asm_text}': empty disassembly"); continue + passed += 1 + print(f"CDNA {name.upper()} disasm: {passed} passed, {failed} failed") + if failures[:5]: print(" " + "\n ".join(failures[:5])) + self.assertEqual(failed, 0) + return test + +for name in CDNA_TEST_FILES: + setattr(TestLLVMCDNA, f'test_{name}_roundtrip', _make_roundtrip_test(name)) + setattr(TestLLVMCDNA, f'test_{name}_disasm', _make_disasm_test(name)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_dtype.py b/test/test_dtype.py index b1eb663fa5..2104b67112 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -131,7 +131,7 @@ class TestDType(unittest.TestCase): def test_finfo(self): if self.DTYPE not in [dtypes.float16, dtypes.float32, dtypes.float64]: return info = np.finfo(_to_np_dtype(self.DTYPE)) - self.assertEqual(info.bits, self.DTYPE.itemsize*8) + self.assertEqual(info.bits, self.DTYPE.bitsize) self.assertEqual((info.nexp, info.nmant), dtypes.finfo(self.DTYPE)) def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None): diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 1b020e2a57..5e016bbcef 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -452,5 +452,23 @@ class TestImageSimplification(unittest.TestCase): load = get_load_image_uop((32, 1024, 4), valid, (alu0, alu1)) self.check(load, "(lidx1<7)", "((gidx0*2+lidx1*512+(lidx0*8192+r0*4096)+-11711)//4%1024)", "(lidx0*2+r0+-3)") +class TestUnfoldableImageChannelSelection(unittest.TestCase): + def _count_nans(self, load): + with Context(NOOPT=1, SPEC=0): + result = full_rewrite_to_sink(load.sink()).src[0] + return sum(1 for u in result.toposort() if u.op is Ops.CONST and u.arg != u.arg) + + def test_bounded_channel_no_nan(self): + # unfoldable image load with bounded idx % 4 range [0,1] -> no NAN fallback needed + lidx = Special("lidx", 2) + load = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 10, 4)), arg=0).index(lidx, ptr=True), UOp.const(dtypes.float, 0))) + self.assertEqual(self._count_nans(load), 0) + + def test_unbounded_channel_has_nan(self): + # variable with negative range -> x % 4 can be negative -> needs NAN fallback + x = Variable("x", -10, 10) + load = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 10, 4)), arg=0).index(x, ptr=True), UOp.const(dtypes.float, 0))) + self.assertEqual(self._count_nans(load), 1) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 5e6895fa1e..7d7c8ebc7c 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -197,7 +197,12 @@ def image_fixup(ls:UOp): oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % image_dtype.shape[1], (x // (4*image_dtype.shape[1])))) idx = idx.replace(src=(idx.src[0], oidx.valid(valid))) vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:]) - return functools.reduce(lambda ret, i: (x % 4).ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan'))) + # image pixels have 4 channels (.xyzw), select channel based on x % 4 + x_mod_4 = x % 4 + def sel(ret, i): return x_mod_4.ne(i).where(ret, vec_load.gep(i)) + # if x is non-negative, x % 4 is in [0, 3] and we can skip NAN fallback + if x_mod_4.vmin >= 0: return functools.reduce(sel, range(x_mod_4.vmin+1, x_mod_4.vmax+1), vec_load.gep(x_mod_4.vmin)) + return functools.reduce(sel, range(4), ls.const_like(float('nan'))) return None diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index fdefab243d..9a79696c10 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -38,16 +38,18 @@ class AddrSpace(Enum): @dataclass(frozen=True, eq=False) class DType(metaclass=DTypeMetaClass): priority: int # this determines when things get upcasted - itemsize: int + bitsize: int name: str fmt: FmtStr|None count: int _scalar: DType|None + @property + def itemsize(self) -> int: return (self.bitsize + 7) // 8 @staticmethod - def new(priority:int, itemsize:int, name:str, fmt:FmtStr|None): return DType(priority, itemsize, name, fmt, 1, None) + def new(priority:int, bitsize:int, name:str, fmt:FmtStr|None): return DType(priority, bitsize, name, fmt, 1, None) def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self)) def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count != 1 else "") - def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count) + def __lt__(self, o:DType): return (self.priority, self.bitsize, self.name, self.fmt, self.count) < (o.priority, o.bitsize, o.name, o.fmt, o.count) @property def base(self): return self @property @@ -56,9 +58,9 @@ class DType(metaclass=DTypeMetaClass): def vec(self, sz:int) -> DType: assert self.count == 1, f"can't vectorize {self} with size {sz}" if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar - return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self) + return DType(self.priority, self.bitsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self) def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType: - return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size) + return PtrDType(self.priority, self.bitsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size) def scalar(self) -> DType: return self._scalar if self._scalar is not None else self def nbytes(self) -> int: raise RuntimeError("only ptr types have nbytes") @property @@ -79,8 +81,8 @@ class PtrDType(DType): assert self.v == 1, f"can't vectorize ptr {self} with size {sz}" if sz == 1: return self # sz=1 is a scalar if isinstance(self, ImageDType): - return ImageDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size, self.shape) - return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size) + return ImageDType(self.priority, self.bitsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size, self.shape) + return type(self)(self.priority, self.bitsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size) def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType: raise RuntimeError("can't make a pointer from a pointer") def nbytes(self) -> int: if self.size == -1: raise RuntimeError("can't get nbytes of a pointer with unlimited size") @@ -142,12 +144,12 @@ class dtypes: @staticmethod @functools.cache def min(dtype:DType): - if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().itemsize*8-1) + if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().bitsize-1) return -float("inf") if dtypes.is_float(dtype) else False @staticmethod @functools.cache def max(dtype:DType): - if dtypes.is_int(dtype): return 2**(dtype.scalar().itemsize*8)-1+dtypes.min(dtype) + if dtypes.is_int(dtype): return 2**(dtype.scalar().bitsize)-1+dtypes.min(dtype) return float("inf") if dtypes.is_float(dtype) else True @staticmethod def finfo(dtype:DType) -> tuple[int, int]: @@ -158,25 +160,23 @@ class dtypes: @staticmethod def fields() -> dict[str, DType]: return DTYPES_DICT void: Final[DType] = DType.new(-1, 0, "void", None) - index: Final[DType] = DType.new(-1,100, "index", None) + index: Final[DType] = DType.new(-1, 800, "index", None) bool: Final[DType] = DType.new(0, 1, "bool", '?') - int8: Final[DType] = DType.new(1, 1, "signed char", 'b') - uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B') - int16: Final[DType] = DType.new(3, 2, "short", 'h') - uint16: Final[DType] = DType.new(4, 2, "unsigned short", 'H') - int24: Final[DType] = DType.new(5, 3, "int24", None) - uint24: Final[DType] = DType.new(6, 3, "uint24", None) - int32: Final[DType] = DType.new(7, 4, "int", 'i') - uint32: Final[DType] = DType.new(8, 4, "unsigned int", 'I') - int64: Final[DType] = DType.new(9, 8, "long", 'q') - uint64: Final[DType] = DType.new(10, 8, "unsigned long", 'Q') - fp8e4m3: Final[DType] = DType.new(11, 1, "float8_e4m3", None) - fp8e5m2: Final[DType] = DType.new(12, 1, "float8_e5m2", None) - float16: Final[DType] = DType.new(13, 2, "half", 'e') + int8: Final[DType] = DType.new(1, 8, "signed char", 'b') + uint8: Final[DType] = DType.new(2, 8, "unsigned char", 'B') + int16: Final[DType] = DType.new(3, 16, "short", 'h') + uint16: Final[DType] = DType.new(4, 16, "unsigned short", 'H') + int32: Final[DType] = DType.new(5, 32, "int", 'i') + uint32: Final[DType] = DType.new(6, 32, "unsigned int", 'I') + int64: Final[DType] = DType.new(7, 64, "long", 'q') + uint64: Final[DType] = DType.new(8, 64, "unsigned long", 'Q') + fp8e4m3: Final[DType] = DType.new(9, 8, "float8_e4m3", None) + fp8e5m2: Final[DType] = DType.new(10, 8, "float8_e5m2", None) + float16: Final[DType] = DType.new(11, 16, "half", 'e') # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16 - bfloat16: Final[DType] = DType.new(14, 2, "__bf16", None) - float32: Final[DType] = DType.new(15, 4, "float", 'f') - float64: Final[DType] = DType.new(16, 8, "double", 'd') + bfloat16: Final[DType] = DType.new(12, 16, "__bf16", None) + float32: Final[DType] = DType.new(13, 32, "float", 'f') + float64: Final[DType] = DType.new(14, 64, "double", 'd') # dtype aliases half = float16; float = float32; double = float64 # noqa: E702 @@ -185,9 +185,9 @@ class dtypes: # NOTE: these are image dtypes @staticmethod - def imageh(shp, pitch=-1): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch) + def imageh(shp, pitch=-1): return ImageDType(100, 16, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch) @staticmethod - def imagef(shp, pitch=-1): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch) + def imagef(shp, pitch=-1): return ImageDType(100, 32, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch) default_float: ClassVar[DType] = float32 default_int: ClassVar[DType] = int32 diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 6106efbf15..e704fdacce 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -518,7 +518,7 @@ class AMDHIPRenderer(CStyleLanguage): prefix.append("typedef long unsigned int size_t;") ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]] ocml_ops = {Ops.EXP2: ("exp2", "pure"), Ops.LOG2: ("log2", "pure"), Ops.SQRT: ("sqrt", "const"), Ops.SIN: ("sin", ""), Ops.TRUNC: ("trunc", "")} - ocml = [(f"__ocml_{ocml_ops[op][0]}_f{dt.itemsize * 8}", dt.name, dt.name, ocml_ops[op][1]) + ocml = [(f"__ocml_{ocml_ops[op][0]}_f{dt.bitsize}", dt.name, dt.name, ocml_ops[op][1]) for op, dt in dedup((u.op, u.dtype.scalar()) for u in uops) if op in ocml_ops and dt in (dtypes.half, dtypes.float, dtypes.double)] if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;") if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#define half _Float16") diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index 3729a0ef6d..cba6a1dd25 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -12,7 +12,7 @@ def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents if isinstance(t, PtrDType) else { **{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]}, - **{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.itemsize*8)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t] + **{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t] # alu ops, aop[][] u_aop = { Ops.ADD: "iadd", Ops.MUL: "imul", Ops.IDIV: "udiv", Ops.MOD: "umod", Ops.CMPLT: "ult", Ops.CMPNE: "ine", Ops.CMPEQ: "ieq", Ops.OR: "ior", @@ -26,7 +26,7 @@ def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ( def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def: if isinstance(it, PtrDType) and ot == dtypes.long: return src if ot == dtypes.bool: return nalu(b, c(it, False)+'ne'+('u' if c(it) == 'f' else ''), src, nimm(b, 0, it)) - return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.itemsize*8}", src) + return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src) def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable): nif = mesa.nir_push_if(b, cond) @@ -71,12 +71,12 @@ def nimm_set(imm:mesa.nir_def, x, dtype:DType): instr = ctypes.cast(imm.parent_instr, ctypes.POINTER(mesa.nir_load_const_instr)) struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, x) -@nir_instr(nc=1, bs=lambda dtype: 1 if dtype == dtypes.bool else dtype.itemsize * 8) +@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize) def nimm(b:mesa.nir_builder, x, dtype:DType) -> mesa.nir_def: - nimm_set(getattr((instr:=mesa.nir_load_const_instr_create(b.shader, 1, 1 if dtype==dtypes.bool else dtype.itemsize * 8)).contents, "def"), x, dtype) + nimm_set(getattr((instr:=mesa.nir_load_const_instr_create(b.shader, 1, dtype.bitsize)).contents, "def"), x, dtype) return instr -@nir_instr(nc=1, bs=lambda dtype: 1 if dtype == dtypes.bool else dtype.itemsize * 8) -def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, 1 if dtype == dtypes.bool else dtype.itemsize * 8) +@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize) +def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, dtype.bitsize) deref_var = nir_instr(nc=1, bs=32, modes=lambda var:var.data.mode, type=lambda var:var.type, var=lambda var:ctypes.pointer(var))( # pylint: disable=W0108 lambda b, var: mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_var)) @@ -86,7 +86,7 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1< 0: + if self.is_hive(): if reset_mode: return # in reset mode, do not raise raise RuntimeError("Malformed state. Use extra/amdpci/hive_reset.py to reset the hive") self.smu.mode1_reset() @@ -221,6 +221,8 @@ class AMDev(PCIDevImplBase): self.smu.set_clocks(level=0) self.ih.interrupt_handler() + def is_hive(self) -> bool: return self.gmc.xgmi_seg_sz > 0 + def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr def paddr2xgmi(self, paddr:int) -> int: return self.gmc.paddr_base + paddr def xgmi2paddr(self, xgmi_paddr:int) -> int: return xgmi_paddr - self.gmc.paddr_base diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 2eb8765049..1a719ce9b0 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -57,8 +57,8 @@ class AM_GMC(AM_IP): self.trans_futher = self.adev.ip_ver[am.GC_HWIP] < (10, 0, 0) - # GFX11/GFX12 has 44-bit address space - self.address_space_mask = (1 << 44) - 1 + # mi3xx has 48-bit, others have 44-bit address space + self.address_space_mask = (1 << (48 if self.adev.ip_ver[am.GC_HWIP][:2] == (9,4) else 44)) - 1 self.memscratch_xgmi_paddr = self.adev.paddr2xgmi(self.adev.mm.palloc(0x1000, zero=False, boot=True)) self.dummy_page_xgmi_paddr = self.adev.paddr2xgmi(self.adev.mm.palloc(0x1000, zero=False, boot=True)) @@ -183,7 +183,8 @@ class AM_SMU(AM_IP): if self.adev.ip_ver[am.MP0_HWIP] >= (14,0,0): self._send_msg(__DEBUGSMC_MSG_Mode1Reset:=2, 0, debug=True) elif self.adev.ip_ver[am.MP0_HWIP] in {(13,0,6), (13,0,12)}: self._send_msg(self.smu_mod.PPSMC_MSG_GfxDriverReset, 1) else: self._send_msg(self.smu_mod.PPSMC_MSG_Mode1Reset, 0) - time.sleep(0.5) # 500ms + + if not self.adev.is_hive(): time.sleep(0.5) # 500ms def read_table(self, table_t, arg): if self.adev.ip_ver[am.MP0_HWIP] in {(13,0,6),(13,0,12)}: self._send_msg(self.smu_mod.PPSMC_MSG_GetMetricsTable, arg) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 8af6036d16..f00b598d6a 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -604,7 +604,7 @@ class Tensor(OpMixin): bits = bits.bitcast(uint_dtype) # only randomize the mantissa bits and set the exponent to 1 one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype) - bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one) + bits = bits.rshift(dtype.bitsize - nmant).bitwise_or(one) # bitcast back to the original dtype and reshape out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape).requires_grad_(kwargs.get("requires_grad")) return out.contiguous() if contiguous else out diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 8a9310278b..d21347d9a3 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -186,7 +186,6 @@ commutative = PatternMatcher([ symbolic = symbolic_simple+commutative+PatternMatcher([ # ** boolean algebra ** - (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x # TODO: make a more general or folder like simplify_valid (UPat.var("x", dtype=dtypes.bool) | UPat.var("x").logical_not(), lambda x: x.const_like(True)), # x|!x -> True # ** combine terms **