mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
Merge remote-tracking branch 'origin/master' into asm_ucode
# Conflicts: # tinygrad/dtype.py
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# RDNA3 assembler and disassembler
|
||||
# RDNA3/CDNA assembler and disassembler
|
||||
from __future__ import annotations
|
||||
import re
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
|
||||
@@ -8,6 +8,8 @@ from extra.assembly.amd.autogen.rdna3 import ins
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
|
||||
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
|
||||
|
||||
def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__
|
||||
|
||||
def _matches_encoding(word: int, cls: type[Inst]) -> bool:
|
||||
"""Check if word matches the encoding pattern of an instruction class."""
|
||||
if cls._encoding is None: return False
|
||||
@@ -81,10 +83,11 @@ def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v)
|
||||
def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c)
|
||||
def _fmt_bits(label: str, val: int, count: int) -> str: return f"{label}:[{','.join(str((val >> i) & 1) for i in range(count))}]"
|
||||
|
||||
def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any_hi: bool) -> str:
|
||||
def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool) -> str:
|
||||
"""Format VOP3 source operand with modifiers."""
|
||||
if n > 1: s = _fmt_src(v, n)
|
||||
elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else (f"v{v - 256}.l" if any_hi else inst.lit(v))
|
||||
if v == 255: s = inst.lit(v) # literal constant takes priority
|
||||
elif n > 1: s = _fmt_src(v, n)
|
||||
elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else f"v{v - 256}.l"
|
||||
else: s = inst.lit(v)
|
||||
if abs_: s = f"|{s}|"
|
||||
return f"-{s}" if neg else s
|
||||
@@ -92,9 +95,11 @@ def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any
|
||||
def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
|
||||
"""Format op_sel modifier string."""
|
||||
if not need: return ""
|
||||
if is16_d and (opsel & 8): return f" op_sel:[1,1,1{',1' if n == 3 else ''}]"
|
||||
if n == 3: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{(opsel >> 3) & 1}]"
|
||||
return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1}]"
|
||||
# For VOP1 (n=1): op_sel:[src0_hi, dst_hi], for VOP2 (n=2): op_sel:[src0_hi, src1_hi, dst_hi], for VOP3 (n=3): op_sel:[src0_hi, src1_hi, src2_hi, dst_hi]
|
||||
dst_hi = (opsel >> 3) & 1
|
||||
if n == 1: return f" op_sel:[{opsel & 1},{dst_hi}]"
|
||||
if n == 2: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{dst_hi}]"
|
||||
return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{dst_hi}]"
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# DISASSEMBLER
|
||||
@@ -108,30 +113,41 @@ def _disasm_vop1(inst: VOP1) -> str:
|
||||
parts = name.split('_')
|
||||
is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
|
||||
src = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
|
||||
src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
|
||||
return f"{name}_e32 {dst}, {src}"
|
||||
|
||||
def _disasm_vop2(inst: VOP2) -> str:
|
||||
name = inst.op_name.lower()
|
||||
suf = "" if inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
suf = "" if not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
|
||||
lit = getattr(inst, '_literal', None)
|
||||
is16 = not cdna and inst.is_16bit()
|
||||
# fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1
|
||||
if inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}"
|
||||
if inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}"
|
||||
if inst.is_16bit(): return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if inst.op == VOP2Op.V_CNDMASK_B32 else "")
|
||||
if 'fmaak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)):
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}, 0x{lit:x}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{lit:x}"
|
||||
if 'fmamk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)):
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, 0x{lit:x}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{lit:x}, v{inst.vsrc1}"
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
vcc = "vcc" if cdna else "vcc_lo"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (f", {vcc}" if name == 'v_cndmask_b32' else "")
|
||||
|
||||
def _disasm_vopc(inst: VOPC) -> str:
|
||||
name = inst.op_name.lower()
|
||||
s0 = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
if cdna:
|
||||
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0))
|
||||
return f"{name}_e32 {s0}, v{inst.vsrc1}" if inst.op.value >= 128 else f"{name}_e32 vcc, {s0}, v{inst.vsrc1}"
|
||||
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
|
||||
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}"
|
||||
return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
|
||||
|
||||
NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
|
||||
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE}
|
||||
NO_ARG_SOPP = {SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
|
||||
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE, SOPPOp.S_TTRACEDATA}
|
||||
|
||||
def _disasm_sopp(inst: SOPP) -> str:
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in NO_ARG_SOPP: return name
|
||||
if inst.op == SOPPOp.S_ENDPGM: return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
|
||||
if inst.op == SOPPOp.S_WAITCNT:
|
||||
vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f
|
||||
p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
|
||||
@@ -154,12 +170,13 @@ def _disasm_smem(inst: SMEM) -> str:
|
||||
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
|
||||
|
||||
def _disasm_flat(inst: FLAT) -> str:
|
||||
name = inst.op_name.lower()
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
|
||||
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
|
||||
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
|
||||
w = inst.dst_regs() * (2 if 'cmpswap' in name else 1)
|
||||
mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
|
||||
if cdna: mods = f"{f' offset:{off_val}' if off_val else ''}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if inst.sc1 else ''}"
|
||||
else: mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
|
||||
# saddr
|
||||
if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
|
||||
elif inst.saddr == 124: saddr_s = ", off"
|
||||
@@ -172,8 +189,9 @@ def _disasm_flat(inst: FLAT) -> str:
|
||||
# addr width
|
||||
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2)
|
||||
data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w)
|
||||
glc_or_sc0 = inst.sc0 if cdna else inst.glc
|
||||
if 'atomic' in name:
|
||||
return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if inst.glc else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
|
||||
return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
|
||||
if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
|
||||
return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
|
||||
|
||||
@@ -211,7 +229,9 @@ def _disasm_vop3(inst: VOP3) -> str:
|
||||
# VOP3SD (shared encoding)
|
||||
if isinstance(op, VOP3SDOp):
|
||||
sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs
|
||||
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
|
||||
def src(v, neg, n):
|
||||
s = inst.lit(v) if v == 255 else (_fmt_src(v, n) if n > 1 else inst.lit(v))
|
||||
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
|
||||
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
|
||||
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
|
||||
@@ -225,18 +245,18 @@ def _disasm_vop3(inst: VOP3) -> str:
|
||||
is16_s2 = is16_s
|
||||
elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True
|
||||
elif 'pack_b32' in name: is16_s = is16_s2 = True
|
||||
elif 'sat_pk' in name: is16_d = True # v_sat_pk_* writes to 16-bit dest but takes 32-bit src
|
||||
else: is16_d = is16_s = is16_s2 = inst.is_16bit()
|
||||
|
||||
any_hi = inst.opsel != 0
|
||||
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s, any_hi)
|
||||
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s, any_hi)
|
||||
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2, any_hi)
|
||||
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s)
|
||||
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s)
|
||||
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2)
|
||||
|
||||
# Destination
|
||||
dn = inst.dst_regs()
|
||||
if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1)
|
||||
elif dn > 1: dst = _vreg(inst.vdst, dn)
|
||||
elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}"
|
||||
elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l"
|
||||
else: dst = f"v{inst.vdst}"
|
||||
|
||||
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
|
||||
@@ -244,7 +264,7 @@ def _disasm_vop3(inst: VOP3) -> str:
|
||||
need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s)
|
||||
|
||||
if inst.op < 256: # VOPC
|
||||
return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
|
||||
return f"{name}_e64 {s0}, {s1}{cl}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}{cl}"
|
||||
if inst.op < 384: # VOP2
|
||||
n = inst.num_srcs()
|
||||
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
|
||||
@@ -258,7 +278,9 @@ def _disasm_vop3(inst: VOP3) -> str:
|
||||
|
||||
def _disasm_vop3sd(inst: VOP3SD) -> str:
|
||||
name = inst.op_name.lower()
|
||||
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
|
||||
def src(v, neg, n):
|
||||
s = inst.lit(v) if v == 255 else (_fmt_src(v, n) if n > 1 else inst.lit(v))
|
||||
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
|
||||
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
|
||||
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
|
||||
@@ -268,16 +290,22 @@ def _disasm_vop3sd(inst: VOP3SD) -> str:
|
||||
def _disasm_vopd(inst: VOPD) -> str:
|
||||
lit = inst._literal or inst.literal
|
||||
vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), VOPDOp(inst.opx).name.lower(), VOPDOp(inst.opy).name.lower()
|
||||
def half(n, vd, s0, vs1): return f"{n} v{vd}, {inst.lit(s0)}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}" if 'mov' in n else f"{n} v{vd}, {inst.lit(s0)}, v{vs1}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}"
|
||||
def half(n, vd, s0, vs1):
|
||||
if 'mov' in n: return f"{n} v{vd}, {inst.lit(s0)}"
|
||||
# fmamk: dst = src0 * K + vsrc1, fmaak: dst = src0 * vsrc1 + K
|
||||
if 'fmamk' in n and lit: return f"{n} v{vd}, {inst.lit(s0)}, 0x{lit:x}, v{vs1}"
|
||||
if 'fmaak' in n and lit: return f"{n} v{vd}, {inst.lit(s0)}, v{vs1}, 0x{lit:x}"
|
||||
return f"{n} v{vd}, {inst.lit(s0)}, v{vs1}"
|
||||
return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}"
|
||||
|
||||
def _disasm_vop3p(inst: VOP3P) -> str:
|
||||
name = inst.op_name.lower()
|
||||
is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name
|
||||
def get_src(v, sc): return inst.lit(v) if v == 255 else _fmt_src(v, sc)
|
||||
if is_wmma:
|
||||
sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
|
||||
src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8)
|
||||
else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}"
|
||||
src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 8), _vreg(inst.vdst, 8)
|
||||
else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}"
|
||||
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
|
||||
if is_fma_mix:
|
||||
def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s)
|
||||
@@ -289,15 +317,17 @@ def _disasm_vop3p(inst: VOP3P) -> str:
|
||||
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
|
||||
|
||||
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
if cdna and name in ('buffer_wbl2', 'buffer_inv'): return name
|
||||
if not cdna and inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
|
||||
w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \
|
||||
((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \
|
||||
{'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1)
|
||||
if inst.tfe: w += 1
|
||||
if hasattr(inst, 'tfe') and inst.tfe: w += 1
|
||||
vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off"
|
||||
srsrc = _sreg_or_ttmp(inst.srsrc*4, 4)
|
||||
mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
|
||||
if cdna: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"sc0"),(inst.nt,"nt"),(inst.sc1,"sc1")] if c]
|
||||
else: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
|
||||
return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}"
|
||||
|
||||
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
|
||||
@@ -348,25 +378,36 @@ def _disasm_mimg(inst: MIMG) -> str:
|
||||
|
||||
def _disasm_sop1(inst: SOP1) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
|
||||
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}"
|
||||
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}"
|
||||
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}"
|
||||
src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
|
||||
if not _is_cdna(inst):
|
||||
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
|
||||
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {src}"
|
||||
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {src}"
|
||||
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {src}"
|
||||
|
||||
def _disasm_sop2(inst: SOP2) -> str:
|
||||
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}"
|
||||
|
||||
def _disasm_sopc(inst: SOPC) -> str:
|
||||
return f"{inst.op_name.lower()} {_fmt_src(inst.ssrc0, inst.src_regs(0))}, {_fmt_src(inst.ssrc1, inst.src_regs(1))}"
|
||||
s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
|
||||
s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))
|
||||
return f"{inst.op_name.lower()} {s0}, {s1}"
|
||||
|
||||
def _disasm_sopk(inst: SOPK) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}"
|
||||
if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32):
|
||||
op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
|
||||
# s_setreg_imm32_b32 has a 32-bit literal value
|
||||
if name == 's_setreg_imm32_b32' or (not cdna and op == SOPKOp.S_SETREG_IMM32_B32):
|
||||
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
|
||||
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
|
||||
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if op == SOPKOp.S_SETREG_B32 else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
|
||||
return f"{name} {hs}, 0x{inst._literal:x}"
|
||||
if not cdna and op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}"
|
||||
if (not cdna and op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32)) or (cdna and name in ('s_setreg_b32', 's_getreg_b32')):
|
||||
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
|
||||
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
|
||||
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
|
||||
if not cdna and op in (SOPKOp.S_SUBVECTOR_LOOP_BEGIN, SOPKOp.S_SUBVECTOR_LOOP_END):
|
||||
return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}"
|
||||
|
||||
def _disasm_vinterp(inst: VINTERP) -> str:
|
||||
@@ -388,7 +429,8 @@ SPEC_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), '
|
||||
FLOATS = {str(k): k for k in FLOAT_ENC} # Valid float literal strings: '0.5', '-0.5', '1.0', etc.
|
||||
REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp}
|
||||
SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512',
|
||||
's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512'}
|
||||
's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512',
|
||||
's_atc_probe', 's_atc_probe_buffer'}
|
||||
SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', 'm0': 'M0',
|
||||
'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', 'src_scc': 'SCC'}
|
||||
|
||||
@@ -579,3 +621,69 @@ def asm(text: str) -> Inst:
|
||||
except NameError:
|
||||
if m := re.match(r'^(v_\w+)(\(.*\))$', dsl): return eval(f"{m.group(1)}_e32{m.group(2)}", ns)
|
||||
raise
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# CDNA DISASSEMBLER SUPPORT
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
try:
|
||||
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P,
|
||||
SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS,
|
||||
FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op)
|
||||
|
||||
def _cdna_src(inst, v, neg, abs_=0, n=1):
|
||||
s = inst.lit(v) if v == 255 else _fmt_src(v, n)
|
||||
if abs_: s = f"|{s}|"
|
||||
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
|
||||
|
||||
def _disasm_vop3a(inst) -> str:
|
||||
name, n, cl, om = inst.op_name.lower(), inst.num_srcs(), " clamp" if inst.clmp else "", _omod(inst.omod)
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.src_regs(0)), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.src_regs(1)), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.src_regs(2))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
|
||||
if inst.op.value < 256: return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
|
||||
suf = "_e64" if inst.op.value < 512 else ""
|
||||
return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else (f"{name}{suf}" if name == 'v_nop' else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}" if n == 2 else f"{name}{suf} {dst}, {s0}{cl}{om}")
|
||||
|
||||
def _disasm_vop3b(inst) -> str:
|
||||
name, n = inst.op_name.lower(), inst.num_srcs()
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1), _cdna_src(inst, inst.src1, inst.neg&2), _cdna_src(inst, inst.src2, inst.neg&4)
|
||||
dst, suf = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}", "_e64" if 'co_' in name else ""
|
||||
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
|
||||
return f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{cl}{om}"
|
||||
|
||||
def _disasm_cdna_vop3p(inst) -> str:
|
||||
name, n, is_mfma = inst.op_name.lower(), inst.num_srcs(), 'mfma' in inst.op_name.lower() or 'smfmac' in inst.op_name.lower()
|
||||
get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc)
|
||||
if is_mfma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4; src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16), _vreg(inst.vdst, 16)
|
||||
else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}"
|
||||
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
|
||||
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \
|
||||
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
|
||||
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
|
||||
|
||||
_SEL = {0: 'BYTE_0', 1: 'BYTE_1', 2: 'BYTE_2', 3: 'BYTE_3', 4: 'WORD_0', 5: 'WORD_1', 6: 'DWORD'}
|
||||
_UNUSED = {0: 'UNUSED_PAD', 1: 'UNUSED_SEXT', 2: 'UNUSED_PRESERVE'}
|
||||
_DPP = {0x130: "wave_shl:1", 0x134: "wave_rol:1", 0x138: "wave_shr:1", 0x13c: "wave_ror:1", 0x140: "row_mirror", 0x141: "row_half_mirror", 0x142: "row_bcast:15", 0x143: "row_bcast:31"}
|
||||
|
||||
def _disasm_sdwa(inst) -> str:
|
||||
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
|
||||
except ValueError: name = f"vop1_op_{inst.vop_op}"
|
||||
src = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0)
|
||||
mods = [f"dst_sel:{_SEL[inst.dst_sel]}" for _ in [1] if inst.dst_sel != 6] + [f"dst_unused:{_UNUSED[inst.dst_u]}" for _ in [1] if inst.dst_u] + [f"src0_sel:{_SEL[inst.src0_sel]}" for _ in [1] if inst.src0_sel != 6]
|
||||
return f"{name}_sdwa v{inst.vdst}, {src}" + (" " + " ".join(mods) if mods else "")
|
||||
|
||||
def _disasm_dpp(inst) -> str:
|
||||
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
|
||||
except ValueError: name = f"vop1_op_{inst.vop_op}"
|
||||
src, ctrl = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0), inst.dpp_ctrl
|
||||
dpp = f"quad_perm:[{ctrl&3},{(ctrl>>2)&3},{(ctrl>>4)&3},{(ctrl>>6)&3}]" if ctrl < 0x100 else f"row_shl:{ctrl&0xf}" if ctrl < 0x110 else f"row_shr:{ctrl&0xf}" if ctrl < 0x120 else f"row_ror:{ctrl&0xf}" if ctrl < 0x130 else _DPP.get(ctrl, f"dpp_ctrl:0x{ctrl:x}")
|
||||
mods = [dpp] + [f"row_mask:0x{inst.row_mask:x}" for _ in [1] if inst.row_mask != 0xf] + [f"bank_mask:0x{inst.bank_mask:x}" for _ in [1] if inst.bank_mask != 0xf] + ["bound_ctrl:1" for _ in [1] if inst.bound_ctrl]
|
||||
return f"{name}_dpp v{inst.vdst}, {src} " + " ".join(mods)
|
||||
|
||||
# Register CDNA handlers - shared formats use merged disassemblers, CDNA-only formats use dedicated ones
|
||||
DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc,
|
||||
CDNA_SOP1: _disasm_sop1, CDNA_SOP2: _disasm_sop2, CDNA_SOPC: _disasm_sopc, CDNA_SOPK: _disasm_sopk, CDNA_SOPP: _disasm_sopp,
|
||||
CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat, CDNA_MUBUF: _disasm_buf, CDNA_MTBUF: _disasm_buf,
|
||||
VOP3A: _disasm_vop3a, VOP3B: _disasm_vop3b, CDNA_VOP3P: _disasm_cdna_vop3p, SDWA: _disasm_sdwa, DPP: _disasm_dpp})
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -7,20 +7,18 @@ import functools
|
||||
|
||||
# instruction formats
|
||||
class DPP(Inst64):
|
||||
encoding = bits[31:26] == 0b110110
|
||||
src1_sel = bits[58:56]
|
||||
src1_sext = bits[59]
|
||||
src1_neg = bits[60]
|
||||
src1_abs = bits[61]
|
||||
s1 = bits[63]
|
||||
offset0 = bits[7:0]
|
||||
offset1 = bits[15:8]
|
||||
op = bits[24:17]
|
||||
acc = bits[25]
|
||||
addr:VGPRField = bits[39:32]
|
||||
data0:VGPRField = bits[47:40]
|
||||
data1:VGPRField = bits[55:48]
|
||||
vdst:VGPRField = bits[63:56]
|
||||
encoding = bits[8:0] == 0b11111010
|
||||
vop_op = bits[16:9]
|
||||
vdst:VGPRField = bits[24:17]
|
||||
vop2_op = bits[31:25]
|
||||
src0:Src = bits[39:32]
|
||||
dpp_ctrl = bits[48:40]
|
||||
bound_ctrl = bits[51]
|
||||
src0_neg = bits[52]
|
||||
src0_abs = bits[53]
|
||||
src1_neg = bits[54]
|
||||
src1_abs = bits[55]
|
||||
bank_mask = bits[59:56]
|
||||
row_mask = bits[63:60]
|
||||
|
||||
class DS(Inst64):
|
||||
@@ -82,6 +80,10 @@ class MUBUF(Inst64):
|
||||
acc = bits[55]
|
||||
|
||||
class SDWA(Inst64):
|
||||
encoding = bits[8:0] == 0b11111001
|
||||
vop_op = bits[16:9]
|
||||
vdst:VGPRField = bits[24:17]
|
||||
vop2_op = bits[31:25]
|
||||
src0:Src = bits[39:32]
|
||||
dst_sel = bits[42:40]
|
||||
dst_u = bits[44:43]
|
||||
@@ -97,9 +99,6 @@ class SDWA(Inst64):
|
||||
src1_neg = bits[60]
|
||||
src1_abs = bits[61]
|
||||
s1 = bits[63]
|
||||
sdst:SGPRField = bits[46:40]
|
||||
sd = bits[47]
|
||||
row_mask = bits[63:60]
|
||||
|
||||
class SDWAB(Inst64):
|
||||
src0:Src = bits[39:32]
|
||||
|
||||
@@ -488,6 +488,8 @@ class SMEMOp(IntEnum):
|
||||
S_BUFFER_LOAD_B512 = 12
|
||||
S_GL1_INV = 32
|
||||
S_DCACHE_INV = 33
|
||||
S_ATC_PROBE = 34
|
||||
S_ATC_PROBE_BUFFER = 35
|
||||
|
||||
class SOP1Op(IntEnum):
|
||||
S_MOV_B32 = 0
|
||||
@@ -710,6 +712,8 @@ class SOPKOp(IntEnum):
|
||||
S_SETREG_B32 = 18
|
||||
S_SETREG_IMM32_B32 = 19
|
||||
S_CALL_B64 = 20
|
||||
S_SUBVECTOR_LOOP_BEGIN = 22
|
||||
S_SUBVECTOR_LOOP_END = 23
|
||||
S_WAITCNT_VSCNT = 24
|
||||
S_WAITCNT_VMCNT = 25
|
||||
S_WAITCNT_EXPCNT = 26
|
||||
@@ -751,6 +755,8 @@ class SOPPOp(IntEnum):
|
||||
S_SENDMSGHALT = 55
|
||||
S_INCPERFLEVEL = 56
|
||||
S_DECPERFLEVEL = 57
|
||||
S_TTRACEDATA = 58
|
||||
S_TTRACEDATA_IMM = 59
|
||||
S_ICACHE_INV = 60
|
||||
S_BARRIER = 61
|
||||
|
||||
|
||||
@@ -692,6 +692,8 @@ s_buffer_load_b256 = functools.partial(SMEM, SMEMOp.S_BUFFER_LOAD_B256)
|
||||
s_buffer_load_b512 = functools.partial(SMEM, SMEMOp.S_BUFFER_LOAD_B512)
|
||||
s_gl1_inv = functools.partial(SMEM, SMEMOp.S_GL1_INV)
|
||||
s_dcache_inv = functools.partial(SMEM, SMEMOp.S_DCACHE_INV)
|
||||
s_atc_probe = functools.partial(SMEM, SMEMOp.S_ATC_PROBE)
|
||||
s_atc_probe_buffer = functools.partial(SMEM, SMEMOp.S_ATC_PROBE_BUFFER)
|
||||
s_mov_b32 = functools.partial(SOP1, SOP1Op.S_MOV_B32)
|
||||
s_mov_b64 = functools.partial(SOP1, SOP1Op.S_MOV_B64)
|
||||
s_cmov_b32 = functools.partial(SOP1, SOP1Op.S_CMOV_B32)
|
||||
@@ -906,6 +908,8 @@ s_getreg_b32 = functools.partial(SOPK, SOPKOp.S_GETREG_B32)
|
||||
s_setreg_b32 = functools.partial(SOPK, SOPKOp.S_SETREG_B32)
|
||||
s_setreg_imm32_b32 = functools.partial(SOPK, SOPKOp.S_SETREG_IMM32_B32)
|
||||
s_call_b64 = functools.partial(SOPK, SOPKOp.S_CALL_B64)
|
||||
s_subvector_loop_begin = functools.partial(SOPK, SOPKOp.S_SUBVECTOR_LOOP_BEGIN)
|
||||
s_subvector_loop_end = functools.partial(SOPK, SOPKOp.S_SUBVECTOR_LOOP_END)
|
||||
s_waitcnt_vscnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_VSCNT)
|
||||
s_waitcnt_vmcnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_VMCNT)
|
||||
s_waitcnt_expcnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_EXPCNT)
|
||||
@@ -945,6 +949,8 @@ s_sendmsg = functools.partial(SOPP, SOPPOp.S_SENDMSG)
|
||||
s_sendmsghalt = functools.partial(SOPP, SOPPOp.S_SENDMSGHALT)
|
||||
s_incperflevel = functools.partial(SOPP, SOPPOp.S_INCPERFLEVEL)
|
||||
s_decperflevel = functools.partial(SOPP, SOPPOp.S_DECPERFLEVEL)
|
||||
s_ttracedata = functools.partial(SOPP, SOPPOp.S_TTRACEDATA)
|
||||
s_ttracedata_imm = functools.partial(SOPP, SOPPOp.S_TTRACEDATA_IMM)
|
||||
s_icache_inv = functools.partial(SOPP, SOPPOp.S_ICACHE_INV)
|
||||
s_barrier = functools.partial(SOPP, SOPPOp.S_BARRIER)
|
||||
v_interp_p10_f32 = functools.partial(VINTERP, VINTERPOp.V_INTERP_P10_F32)
|
||||
|
||||
@@ -458,10 +458,16 @@ class Inst:
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
import typing
|
||||
inst = cls.from_int(int.from_bytes(data[:cls._size()], 'little'))
|
||||
op_val = inst._values.get('op', 0)
|
||||
has_literal = cls.__name__ == 'VOP2' and op_val in (44, 45, 55, 56)
|
||||
has_literal = has_literal or (cls.__name__ == 'SOP2' and op_val in (69, 70))
|
||||
# Check for instructions that always have a literal constant (FMAMK/FMAAK/MADMK/MADAK, SETREG_IMM32)
|
||||
op_name = ''
|
||||
if cls.__name__ in ('VOP2', 'SOP2', 'SOPK') and 'op' in (hints := typing.get_type_hints(cls, include_extras=True)):
|
||||
if typing.get_origin(hints['op']) is typing.Annotated:
|
||||
try: op_name = typing.get_args(hints['op'])[1](op_val).name
|
||||
except (ValueError, TypeError): pass
|
||||
has_literal = any(x in op_name for x in ('FMAMK', 'FMAAK', 'MADMK', 'MADAK', 'SETREG_IMM32'))
|
||||
# VOPD fmaak/fmamk always have a literal (opx/opy value 1 or 2)
|
||||
opx, opy = inst._values.get('opx', 0), inst._values.get('opy', 0)
|
||||
has_literal = has_literal or (cls.__name__ == 'VOPD' and (opx in (1, 2) or opy in (1, 2)))
|
||||
@@ -475,7 +481,7 @@ class Inst:
|
||||
lit32 = int.from_bytes(data[cls._size():cls._size()+4], 'little')
|
||||
# Find which source has literal (255) and check its register count
|
||||
lit_src_is_64 = False
|
||||
for n, idx in [('src0', 0), ('src1', 1), ('src2', 2)]:
|
||||
for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]:
|
||||
if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255:
|
||||
lit_src_is_64 = inst.src_regs(idx) == 2
|
||||
break
|
||||
@@ -495,7 +501,12 @@ class Inst:
|
||||
return unwrap(self._values.get(name, 0))
|
||||
|
||||
def lit(self, v: int, neg: bool = False) -> str:
|
||||
s = f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v)
|
||||
if v == 255 and self._literal is not None:
|
||||
# For 64-bit sources, literal is stored shifted - extract the 32-bit value
|
||||
lit32 = (self._literal >> 32) if self._literal > 0xffffffff else self._literal
|
||||
s = f"0x{lit32:x}"
|
||||
else:
|
||||
s = decode_src(v)
|
||||
return f"-{s}" if neg else s
|
||||
|
||||
def __eq__(self, other):
|
||||
@@ -528,6 +539,10 @@ class Inst:
|
||||
elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val)
|
||||
else: self.op = VOP3Op(val)
|
||||
except ValueError: self.op = val
|
||||
# Prefer BitField marker (class-specific enum) over _enum_map (generic RDNA3 enums)
|
||||
elif 'op' in self._fields and (marker := self._fields['op'].marker) and issubclass(marker, IntEnum):
|
||||
try: self.op = marker(val)
|
||||
except ValueError: self.op = val
|
||||
elif cls_name in self._enum_map:
|
||||
try: self.op = self._enum_map[cls_name](val)
|
||||
except ValueError: self.op = val
|
||||
|
||||
@@ -194,13 +194,30 @@ def _parse_single_pdf(url: str):
|
||||
if fmt_name in formats:
|
||||
formats[fmt_name] = [(n, h, 14 if n == 'OP' else l, e, t) for n, h, l, e, t in formats[fmt_name]]
|
||||
if doc_name in ('RDNA3', 'RDNA3.5'):
|
||||
if 'SOPPOp' in enums: assert 8 not in enums['SOPPOp']; enums['SOPPOp'][8] = 'S_WAITCNT_DEPCTR'
|
||||
if 'SOPPOp' in enums:
|
||||
for k, v in {8: 'S_WAITCNT_DEPCTR', 58: 'S_TTRACEDATA', 59: 'S_TTRACEDATA_IMM'}.items():
|
||||
assert k not in enums['SOPPOp']; enums['SOPPOp'][k] = v
|
||||
if 'SOPKOp' in enums:
|
||||
for k, v in {22: 'S_SUBVECTOR_LOOP_BEGIN', 23: 'S_SUBVECTOR_LOOP_END'}.items():
|
||||
assert k not in enums['SOPKOp']; enums['SOPKOp'][k] = v
|
||||
if 'SMEMOp' in enums:
|
||||
for k, v in {34: 'S_ATC_PROBE', 35: 'S_ATC_PROBE_BUFFER'}.items():
|
||||
assert k not in enums['SMEMOp']; enums['SMEMOp'][k] = v
|
||||
if 'DSOp' in enums:
|
||||
for k, v in {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V', 27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'}.items():
|
||||
assert k not in enums['DSOp']; enums['DSOp'][k] = v
|
||||
if 'FLATOp' in enums:
|
||||
for k, v in {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}.items():
|
||||
assert k not in enums['FLATOp']; enums['FLATOp'][k] = v
|
||||
# CDNA SDWA/DPP: PDF only has modifier fields, need VOP1/VOP2 overlay for correct encoding
|
||||
if is_cdna:
|
||||
if 'SDWA' in formats:
|
||||
formats['SDWA'] = [('ENCODING', 8, 0, 0xf9, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None)] + \
|
||||
[f for f in formats['SDWA'] if f[0] not in ('ENCODING', 'SDST', 'SD', 'ROW_MASK')]
|
||||
if 'DPP' in formats:
|
||||
formats['DPP'] = [('ENCODING', 8, 0, 0xfa, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None),
|
||||
('SRC0', 39, 32, None, 'Src'), ('DPP_CTRL', 48, 40, None, None), ('BOUND_CTRL', 51, 51, None, None), ('SRC0_NEG', 52, 52, None, None), ('SRC0_ABS', 53, 53, None, None),
|
||||
('SRC1_NEG', 54, 54, None, None), ('SRC1_ABS', 55, 55, None, None), ('BANK_MASK', 59, 56, None, None), ('ROW_MASK', 63, 60, None, None)]
|
||||
|
||||
# Extract pseudocode for instructions
|
||||
all_text = '\n'.join(pdf.text(i) for i in range(instr_start, instr_end))
|
||||
|
||||
@@ -131,28 +131,19 @@ def _make_asm_test(name):
|
||||
|
||||
def _make_disasm_test(name):
|
||||
def test(self):
|
||||
_, fmt_cls, op_enum = LLVM_TEST_FILES[name]
|
||||
_, base_fmt_cls, base_op_enum = LLVM_TEST_FILES[name]
|
||||
# VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions)
|
||||
vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
|
||||
is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx')
|
||||
undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}}
|
||||
|
||||
# First pass: decode all instructions and collect disasm strings
|
||||
to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error)
|
||||
skipped = 0
|
||||
for asm_text, data in self.tests.get(name, []):
|
||||
if len(data) > fmt_cls._size(): continue
|
||||
temp_inst = fmt_cls.from_bytes(data)
|
||||
temp_op = temp_inst._values.get('op', 0)
|
||||
temp_op = temp_op.val if hasattr(temp_op, 'val') else temp_op
|
||||
if temp_op in undocumented.get(name, set()): skipped += 1; continue
|
||||
if name == 'sopp':
|
||||
simm16 = temp_inst._values.get('simm16', 0)
|
||||
simm16 = simm16.val if hasattr(simm16, 'val') else simm16
|
||||
sopp_no_imm = {48, 54, 53, 55, 60, 61, 62}
|
||||
if temp_op in sopp_no_imm and simm16 != 0: skipped += 1; continue
|
||||
# Detect VOP3 promotions in VOP1/VOP2/VOPC tests: VOP3 has bits [31:26]=0b110101 in first dword
|
||||
is_vop3_enc = name in ('vop1', 'vop2', 'vopc', 'vopcx') and len(data) >= 4 and (data[3] >> 2) == 0x35
|
||||
fmt_cls, op_enum = (VOP3, VOP3Op) if is_vop3_enc else (base_fmt_cls, base_op_enum)
|
||||
try:
|
||||
if fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
|
||||
if base_fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
|
||||
temp = VOP3.from_bytes(data)
|
||||
op_val = temp._values.get('op', 0)
|
||||
op_val = op_val.val if hasattr(op_val, 'val') else op_val
|
||||
@@ -188,7 +179,7 @@ def _make_disasm_test(name):
|
||||
if llvm_bytes is not None and llvm_bytes == data: passed += 1
|
||||
elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
|
||||
|
||||
print(f"{name.upper()} disasm: {passed} passed, {failed} failed" + (f", {skipped} skipped" if skipped else ""))
|
||||
print(f"{name.upper()} disasm: {passed} passed, {failed} failed")
|
||||
if failures[:10]: print(" " + "\n ".join(failures[:10]))
|
||||
self.assertEqual(failed, 0)
|
||||
return test
|
||||
|
||||
144
extra/assembly/amd/test/test_llvm_cdna.py
Normal file
144
extra/assembly/amd/test/test_llvm_cdna.py
Normal file
@@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test CDNA assembler/disassembler against LLVM test vectors."""
|
||||
import unittest, re, subprocess
|
||||
from tinygrad.helpers import fetch
|
||||
from extra.assembly.amd.autogen.cdna.ins import *
|
||||
from extra.assembly.amd.asm import disasm
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc
|
||||
|
||||
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
|
||||
|
||||
def parse_llvm_tests(text: str, mnemonic_filter: str = None, size_filter: int = None) -> list[tuple[str, bytes]]:
|
||||
"""Parse LLVM test format into (asm, expected_bytes) pairs."""
|
||||
tests, lines = [], text.split('\n')
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
if not line or line.startswith(('//', '.', ';')): continue
|
||||
asm_text = line.split('//')[0].strip()
|
||||
if not asm_text or (mnemonic_filter and not asm_text.startswith(mnemonic_filter)): continue
|
||||
for j in list(range(max(0, i - 3), i)) + list(range(i, min(i + 3, len(lines)))):
|
||||
if m := re.search(r'(?:VI9|GFX9|CHECK)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
elif m := re.search(r'CHECK[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
else: continue
|
||||
try:
|
||||
data = bytes.fromhex(hex_bytes)
|
||||
if size_filter is None or len(data) == size_filter: tests.append((asm_text, data))
|
||||
except ValueError: pass
|
||||
break
|
||||
return tests
|
||||
|
||||
# Use gfx9 tests for compatible scalar/vector formats and gfx90a/gfx942 tests for CDNA-specific instructions
|
||||
# Format: (filename, format_class, op_enum, mcpu, mnemonic_filter, size_filter)
|
||||
CDNA_TEST_FILES = {
|
||||
# Scalar ALU - encoding is stable across GFX9/CDNA
|
||||
'sop1': ('gfx9_asm_sop1.s', SOP1, SOP1Op, 'gfx940', None, None),
|
||||
'sop2': ('gfx9_asm_sop2.s', SOP2, SOP2Op, 'gfx940', None, None),
|
||||
'sopp': ('gfx9_asm_sopp.s', SOPP, SOPPOp, 'gfx940', None, None),
|
||||
'sopp_gfx9': ('sopp-gfx9.s', SOPP, SOPPOp, 'gfx940', None, None),
|
||||
'sopk': ('gfx9_asm_sopk.s', SOPK, SOPKOp, 'gfx940', None, None),
|
||||
'sopc': ('gfx9_asm_sopc.s', SOPC, SOPCOp, 'gfx940', None, None),
|
||||
# Vector ALU - encoding is mostly stable
|
||||
'vop1': ('gfx9_asm_vop1.s', VOP1, VOP1Op, 'gfx940', None, None),
|
||||
'vop1_gfx9': ('vop1-gfx9.s', VOP1, VOP1Op, 'gfx940', None, None),
|
||||
'vop2': ('gfx9_asm_vop2.s', VOP2, VOP2Op, 'gfx940', None, None),
|
||||
'vopc': ('gfx9_asm_vopc.s', VOPC, VOPCOp, 'gfx940', None, None),
|
||||
'vop3p': ('gfx9_asm_vop3p.s', VOP3P, VOP3POp, 'gfx940', None, None),
|
||||
'vop3_gfx9': ('vop3-gfx9.s', VOP3A, VOP3AOp, 'gfx940', None, 8), # Only 64-bit VOP3 instructions
|
||||
# Memory instructions
|
||||
'ds': ('gfx9_asm_ds.s', DS, DSOp, 'gfx940', None, None),
|
||||
'ds_gfx9': ('ds-gfx9.s', DS, DSOp, 'gfx940', None, None),
|
||||
# CDNA memory instructions (gfx90a has correct FLAT/MUBUF encodings with acc registers)
|
||||
'flat_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'flat_', None),
|
||||
'global_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'global_', None),
|
||||
'mubuf_gfx90a': ('gfx90a_ldst_acc.s', MUBUF, MUBUFOp, 'gfx90a', 'buffer_', None),
|
||||
'mubuf_gfx9': ('mubuf-gfx9.s', MUBUF, MUBUFOp, 'gfx940', None, None),
|
||||
'scratch_gfx942': ('flat-scratch-gfx942.s', FLAT, FLATOp, 'gfx942', 'scratch_', None),
|
||||
# CDNA-specific: MFMA/MAI instructions
|
||||
'mai': ('mai-gfx942.s', VOP3P, VOP3POp, 'gfx942', None, None),
|
||||
# SDWA and DPP format tests for VOP1 (VOP2 has different bit layout, tested separately)
|
||||
'sdwa_vop1': ('gfx9_asm_vop1.s', SDWA, VOP1Op, 'gfx940', None, None),
|
||||
'dpp_vop1': ('gfx9_asm_vop1.s', DPP, VOP1Op, 'gfx940', None, None),
|
||||
}
|
||||
|
||||
class TestLLVMCDNA(unittest.TestCase):
|
||||
"""Test CDNA instruction format decode/encode roundtrip and disassembly."""
|
||||
tests: dict[str, list[tuple[str, bytes]]] = {}
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
for name, (filename, _, _, _, mnemonic_filter, size_filter) in CDNA_TEST_FILES.items():
|
||||
try:
|
||||
data = fetch(f"{LLVM_BASE}/{filename}").read_bytes()
|
||||
cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'), mnemonic_filter, size_filter)
|
||||
except Exception as e:
|
||||
print(f"Warning: couldn't fetch {filename}: {e}")
|
||||
cls.tests[name] = []
|
||||
|
||||
def _get_val(v): return v.val if hasattr(v, 'val') else v
|
||||
|
||||
def _filter_and_decode(tests, fmt_cls, op_enum):
|
||||
"""Filter tests and decode instructions, yielding (asm_text, data, decoded, error)."""
|
||||
fn, is_sdwa, is_dpp = fmt_cls.__name__, fmt_cls.__name__ == 'SDWA', fmt_cls.__name__ == 'DPP'
|
||||
for asm_text, data in tests:
|
||||
has_lit = False
|
||||
# SDWA/DPP format tests: only accept matching 8-byte instructions
|
||||
if is_sdwa:
|
||||
if len(data) != 8 or data[0] != 0xf9: continue
|
||||
elif is_dpp:
|
||||
if len(data) != 8 or data[0] != 0xfa: continue
|
||||
elif fmt_cls._size() == 4 and len(data) == 8:
|
||||
if data[0] in (0xf9, 0xfa): continue # Skip SDWA/DPP (tested separately)
|
||||
has_lit = data[0] == 255 or (len(data) >= 2 and data[1] == 255 and fn in ('SOP2', 'SOPC'))
|
||||
if fn == 'SOPK': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 23) & 0x1f) == 20
|
||||
if fn == 'VOP2': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 25) & 0x3f) in (23, 24, 36, 37)
|
||||
if not has_lit: continue
|
||||
if len(data) > fmt_cls._size() + (4 if has_lit else 0): continue
|
||||
try:
|
||||
decoded = fmt_cls.from_bytes(data)
|
||||
# For SDWA/DPP, opcode location depends on VOP1 vs VOP2
|
||||
if is_sdwa or is_dpp:
|
||||
vop2_op = _get_val(decoded._values.get('vop2_op', 0))
|
||||
op_val = _get_val(decoded._values.get('vop_op', 0)) if vop2_op == 0x3f else vop2_op
|
||||
else:
|
||||
op_val = _get_val(decoded._values.get('op', 0))
|
||||
try: op_enum(op_val)
|
||||
except ValueError: continue
|
||||
yield asm_text, data, decoded, None
|
||||
except Exception as e:
|
||||
yield asm_text, data, None, str(e)
|
||||
|
||||
def _make_roundtrip_test(name):
|
||||
def test(self):
|
||||
_, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
|
||||
passed, failed, failures = 0, 0, []
|
||||
for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
|
||||
if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue
|
||||
if decoded.to_bytes()[:len(data)] == data: passed += 1
|
||||
else: failed += 1; failures.append(f"'{asm_text}': orig={data.hex()} reenc={decoded.to_bytes()[:len(data)].hex()}")
|
||||
print(f"CDNA {name.upper()} roundtrip: {passed} passed, {failed} failed")
|
||||
if failures[:5]: print(" " + "\n ".join(failures[:5]))
|
||||
self.assertEqual(failed, 0)
|
||||
return test
|
||||
|
||||
def _make_disasm_test(name):
|
||||
def test(self):
|
||||
_, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
|
||||
passed, failed, failures = 0, 0, []
|
||||
for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
|
||||
if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue
|
||||
if decoded.to_bytes()[:len(data)] != data: failed += 1; failures.append(f"'{asm_text}': roundtrip failed"); continue
|
||||
if not (disasm_text := disasm(decoded)) or not disasm_text.strip(): failed += 1; failures.append(f"'{asm_text}': empty disassembly"); continue
|
||||
passed += 1
|
||||
print(f"CDNA {name.upper()} disasm: {passed} passed, {failed} failed")
|
||||
if failures[:5]: print(" " + "\n ".join(failures[:5]))
|
||||
self.assertEqual(failed, 0)
|
||||
return test
|
||||
|
||||
for name in CDNA_TEST_FILES:
|
||||
setattr(TestLLVMCDNA, f'test_{name}_roundtrip', _make_roundtrip_test(name))
|
||||
setattr(TestLLVMCDNA, f'test_{name}_disasm', _make_disasm_test(name))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -131,7 +131,7 @@ class TestDType(unittest.TestCase):
|
||||
def test_finfo(self):
|
||||
if self.DTYPE not in [dtypes.float16, dtypes.float32, dtypes.float64]: return
|
||||
info = np.finfo(_to_np_dtype(self.DTYPE))
|
||||
self.assertEqual(info.bits, self.DTYPE.itemsize*8)
|
||||
self.assertEqual(info.bits, self.DTYPE.bitsize)
|
||||
self.assertEqual((info.nexp, info.nmant), dtypes.finfo(self.DTYPE))
|
||||
|
||||
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
|
||||
|
||||
@@ -452,5 +452,23 @@ class TestImageSimplification(unittest.TestCase):
|
||||
load = get_load_image_uop((32, 1024, 4), valid, (alu0, alu1))
|
||||
self.check(load, "(lidx1<7)", "((gidx0*2+lidx1*512+(lidx0*8192+r0*4096)+-11711)//4%1024)", "(lidx0*2+r0+-3)")
|
||||
|
||||
class TestUnfoldableImageChannelSelection(unittest.TestCase):
|
||||
def _count_nans(self, load):
|
||||
with Context(NOOPT=1, SPEC=0):
|
||||
result = full_rewrite_to_sink(load.sink()).src[0]
|
||||
return sum(1 for u in result.toposort() if u.op is Ops.CONST and u.arg != u.arg)
|
||||
|
||||
def test_bounded_channel_no_nan(self):
|
||||
# unfoldable image load with bounded idx % 4 range [0,1] -> no NAN fallback needed
|
||||
lidx = Special("lidx", 2)
|
||||
load = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 10, 4)), arg=0).index(lidx, ptr=True), UOp.const(dtypes.float, 0)))
|
||||
self.assertEqual(self._count_nans(load), 0)
|
||||
|
||||
def test_unbounded_channel_has_nan(self):
|
||||
# variable with negative range -> x % 4 can be negative -> needs NAN fallback
|
||||
x = Variable("x", -10, 10)
|
||||
load = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 10, 4)), arg=0).index(x, ptr=True), UOp.const(dtypes.float, 0)))
|
||||
self.assertEqual(self._count_nans(load), 1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -197,7 +197,12 @@ def image_fixup(ls:UOp):
|
||||
oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % image_dtype.shape[1], (x // (4*image_dtype.shape[1]))))
|
||||
idx = idx.replace(src=(idx.src[0], oidx.valid(valid)))
|
||||
vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
|
||||
return functools.reduce(lambda ret, i: (x % 4).ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))
|
||||
# image pixels have 4 channels (.xyzw), select channel based on x % 4
|
||||
x_mod_4 = x % 4
|
||||
def sel(ret, i): return x_mod_4.ne(i).where(ret, vec_load.gep(i))
|
||||
# if x is non-negative, x % 4 is in [0, 3] and we can skip NAN fallback
|
||||
if x_mod_4.vmin >= 0: return functools.reduce(sel, range(x_mod_4.vmin+1, x_mod_4.vmax+1), vec_load.gep(x_mod_4.vmin))
|
||||
return functools.reduce(sel, range(4), ls.const_like(float('nan')))
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -38,16 +38,18 @@ class AddrSpace(Enum):
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class DType(metaclass=DTypeMetaClass):
|
||||
priority: int # this determines when things get upcasted
|
||||
itemsize: int
|
||||
bitsize: int
|
||||
name: str
|
||||
fmt: FmtStr|None
|
||||
count: int
|
||||
_scalar: DType|None
|
||||
@property
|
||||
def itemsize(self) -> int: return (self.bitsize + 7) // 8
|
||||
@staticmethod
|
||||
def new(priority:int, itemsize:int, name:str, fmt:FmtStr|None): return DType(priority, itemsize, name, fmt, 1, None)
|
||||
def new(priority:int, bitsize:int, name:str, fmt:FmtStr|None): return DType(priority, bitsize, name, fmt, 1, None)
|
||||
def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
|
||||
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count != 1 else "")
|
||||
def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
|
||||
def __lt__(self, o:DType): return (self.priority, self.bitsize, self.name, self.fmt, self.count) < (o.priority, o.bitsize, o.name, o.fmt, o.count)
|
||||
@property
|
||||
def base(self): return self
|
||||
@property
|
||||
@@ -56,9 +58,9 @@ class DType(metaclass=DTypeMetaClass):
|
||||
def vec(self, sz:int) -> DType:
|
||||
assert self.count == 1, f"can't vectorize {self} with size {sz}"
|
||||
if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
|
||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
|
||||
return DType(self.priority, self.bitsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
|
||||
def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
|
||||
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
|
||||
return PtrDType(self.priority, self.bitsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
|
||||
def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
|
||||
def nbytes(self) -> int: raise RuntimeError("only ptr types have nbytes")
|
||||
@property
|
||||
@@ -79,8 +81,8 @@ class PtrDType(DType):
|
||||
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
|
||||
if sz == 1: return self # sz=1 is a scalar
|
||||
if isinstance(self, ImageDType):
|
||||
return ImageDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size, self.shape)
|
||||
return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size)
|
||||
return ImageDType(self.priority, self.bitsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size, self.shape)
|
||||
return type(self)(self.priority, self.bitsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size)
|
||||
def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType: raise RuntimeError("can't make a pointer from a pointer")
|
||||
def nbytes(self) -> int:
|
||||
if self.size == -1: raise RuntimeError("can't get nbytes of a pointer with unlimited size")
|
||||
@@ -142,12 +144,12 @@ class dtypes:
|
||||
@staticmethod
|
||||
@functools.cache
|
||||
def min(dtype:DType):
|
||||
if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().itemsize*8-1)
|
||||
if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().bitsize-1)
|
||||
return -float("inf") if dtypes.is_float(dtype) else False
|
||||
@staticmethod
|
||||
@functools.cache
|
||||
def max(dtype:DType):
|
||||
if dtypes.is_int(dtype): return 2**(dtype.scalar().itemsize*8)-1+dtypes.min(dtype)
|
||||
if dtypes.is_int(dtype): return 2**(dtype.scalar().bitsize)-1+dtypes.min(dtype)
|
||||
return float("inf") if dtypes.is_float(dtype) else True
|
||||
@staticmethod
|
||||
def finfo(dtype:DType) -> tuple[int, int]:
|
||||
@@ -158,25 +160,23 @@ class dtypes:
|
||||
@staticmethod
|
||||
def fields() -> dict[str, DType]: return DTYPES_DICT
|
||||
void: Final[DType] = DType.new(-1, 0, "void", None)
|
||||
index: Final[DType] = DType.new(-1,100, "index", None)
|
||||
index: Final[DType] = DType.new(-1, 800, "index", None)
|
||||
bool: Final[DType] = DType.new(0, 1, "bool", '?')
|
||||
int8: Final[DType] = DType.new(1, 1, "signed char", 'b')
|
||||
uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B')
|
||||
int16: Final[DType] = DType.new(3, 2, "short", 'h')
|
||||
uint16: Final[DType] = DType.new(4, 2, "unsigned short", 'H')
|
||||
int24: Final[DType] = DType.new(5, 3, "int24", None)
|
||||
uint24: Final[DType] = DType.new(6, 3, "uint24", None)
|
||||
int32: Final[DType] = DType.new(7, 4, "int", 'i')
|
||||
uint32: Final[DType] = DType.new(8, 4, "unsigned int", 'I')
|
||||
int64: Final[DType] = DType.new(9, 8, "long", 'q')
|
||||
uint64: Final[DType] = DType.new(10, 8, "unsigned long", 'Q')
|
||||
fp8e4m3: Final[DType] = DType.new(11, 1, "float8_e4m3", None)
|
||||
fp8e5m2: Final[DType] = DType.new(12, 1, "float8_e5m2", None)
|
||||
float16: Final[DType] = DType.new(13, 2, "half", 'e')
|
||||
int8: Final[DType] = DType.new(1, 8, "signed char", 'b')
|
||||
uint8: Final[DType] = DType.new(2, 8, "unsigned char", 'B')
|
||||
int16: Final[DType] = DType.new(3, 16, "short", 'h')
|
||||
uint16: Final[DType] = DType.new(4, 16, "unsigned short", 'H')
|
||||
int32: Final[DType] = DType.new(5, 32, "int", 'i')
|
||||
uint32: Final[DType] = DType.new(6, 32, "unsigned int", 'I')
|
||||
int64: Final[DType] = DType.new(7, 64, "long", 'q')
|
||||
uint64: Final[DType] = DType.new(8, 64, "unsigned long", 'Q')
|
||||
fp8e4m3: Final[DType] = DType.new(9, 8, "float8_e4m3", None)
|
||||
fp8e5m2: Final[DType] = DType.new(10, 8, "float8_e5m2", None)
|
||||
float16: Final[DType] = DType.new(11, 16, "half", 'e')
|
||||
# bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
|
||||
bfloat16: Final[DType] = DType.new(14, 2, "__bf16", None)
|
||||
float32: Final[DType] = DType.new(15, 4, "float", 'f')
|
||||
float64: Final[DType] = DType.new(16, 8, "double", 'd')
|
||||
bfloat16: Final[DType] = DType.new(12, 16, "__bf16", None)
|
||||
float32: Final[DType] = DType.new(13, 32, "float", 'f')
|
||||
float64: Final[DType] = DType.new(14, 64, "double", 'd')
|
||||
|
||||
# dtype aliases
|
||||
half = float16; float = float32; double = float64 # noqa: E702
|
||||
@@ -185,9 +185,9 @@ class dtypes:
|
||||
|
||||
# NOTE: these are image dtypes
|
||||
@staticmethod
|
||||
def imageh(shp, pitch=-1): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)
|
||||
def imageh(shp, pitch=-1): return ImageDType(100, 16, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)
|
||||
@staticmethod
|
||||
def imagef(shp, pitch=-1): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)
|
||||
def imagef(shp, pitch=-1): return ImageDType(100, 32, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)
|
||||
|
||||
default_float: ClassVar[DType] = float32
|
||||
default_int: ClassVar[DType] = int32
|
||||
|
||||
@@ -518,7 +518,7 @@ class AMDHIPRenderer(CStyleLanguage):
|
||||
prefix.append("typedef long unsigned int size_t;")
|
||||
ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
|
||||
ocml_ops = {Ops.EXP2: ("exp2", "pure"), Ops.LOG2: ("log2", "pure"), Ops.SQRT: ("sqrt", "const"), Ops.SIN: ("sin", ""), Ops.TRUNC: ("trunc", "")}
|
||||
ocml = [(f"__ocml_{ocml_ops[op][0]}_f{dt.itemsize * 8}", dt.name, dt.name, ocml_ops[op][1])
|
||||
ocml = [(f"__ocml_{ocml_ops[op][0]}_f{dt.bitsize}", dt.name, dt.name, ocml_ops[op][1])
|
||||
for op, dt in dedup((u.op, u.dtype.scalar()) for u in uops) if op in ocml_ops and dt in (dtypes.half, dtypes.float, dtypes.double)]
|
||||
if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
|
||||
if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#define half _Float16")
|
||||
|
||||
@@ -12,7 +12,7 @@ def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer
|
||||
|
||||
def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents if isinstance(t, PtrDType) else {
|
||||
**{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]},
|
||||
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.itemsize*8)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
|
||||
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
|
||||
|
||||
# alu ops, aop[<dtype>][<op>]
|
||||
u_aop = { Ops.ADD: "iadd", Ops.MUL: "imul", Ops.IDIV: "udiv", Ops.MOD: "umod", Ops.CMPLT: "ult", Ops.CMPNE: "ine", Ops.CMPEQ: "ieq", Ops.OR: "ior",
|
||||
@@ -26,7 +26,7 @@ def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else (
|
||||
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
|
||||
if isinstance(it, PtrDType) and ot == dtypes.long: return src
|
||||
if ot == dtypes.bool: return nalu(b, c(it, False)+'ne'+('u' if c(it) == 'f' else ''), src, nimm(b, 0, it))
|
||||
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.itemsize*8}", src)
|
||||
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
|
||||
|
||||
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
|
||||
nif = mesa.nir_push_if(b, cond)
|
||||
@@ -71,12 +71,12 @@ def nimm_set(imm:mesa.nir_def, x, dtype:DType):
|
||||
instr = ctypes.cast(imm.parent_instr, ctypes.POINTER(mesa.nir_load_const_instr))
|
||||
struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, x)
|
||||
|
||||
@nir_instr(nc=1, bs=lambda dtype: 1 if dtype == dtypes.bool else dtype.itemsize * 8)
|
||||
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
|
||||
def nimm(b:mesa.nir_builder, x, dtype:DType) -> mesa.nir_def:
|
||||
nimm_set(getattr((instr:=mesa.nir_load_const_instr_create(b.shader, 1, 1 if dtype==dtypes.bool else dtype.itemsize * 8)).contents, "def"), x, dtype)
|
||||
nimm_set(getattr((instr:=mesa.nir_load_const_instr_create(b.shader, 1, dtype.bitsize)).contents, "def"), x, dtype)
|
||||
return instr
|
||||
@nir_instr(nc=1, bs=lambda dtype: 1 if dtype == dtypes.bool else dtype.itemsize * 8)
|
||||
def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, 1 if dtype == dtypes.bool else dtype.itemsize * 8)
|
||||
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
|
||||
def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, dtype.bitsize)
|
||||
|
||||
deref_var = nir_instr(nc=1, bs=32, modes=lambda var:var.data.mode, type=lambda var:var.type, var=lambda var:ctypes.pointer(var))( # pylint: disable=W0108
|
||||
lambda b, var: mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_var))
|
||||
@@ -86,7 +86,7 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
|
||||
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
|
||||
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
|
||||
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
|
||||
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.itemsize*8//dtype.count, num_components=lambda dtype:dtype.count,
|
||||
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
|
||||
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
|
||||
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ class AMDev(PCIDevImplBase):
|
||||
# Init hw for IP blocks where it is needed
|
||||
if not self.partial_boot:
|
||||
if self.psp.is_sos_alive() and self.smu.is_smu_alive():
|
||||
if self.gmc.xgmi_seg_sz > 0:
|
||||
if self.is_hive():
|
||||
if reset_mode: return # in reset mode, do not raise
|
||||
raise RuntimeError("Malformed state. Use extra/amdpci/hive_reset.py to reset the hive")
|
||||
self.smu.mode1_reset()
|
||||
@@ -221,6 +221,8 @@ class AMDev(PCIDevImplBase):
|
||||
self.smu.set_clocks(level=0)
|
||||
self.ih.interrupt_handler()
|
||||
|
||||
def is_hive(self) -> bool: return self.gmc.xgmi_seg_sz > 0
|
||||
|
||||
def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr
|
||||
def paddr2xgmi(self, paddr:int) -> int: return self.gmc.paddr_base + paddr
|
||||
def xgmi2paddr(self, xgmi_paddr:int) -> int: return xgmi_paddr - self.gmc.paddr_base
|
||||
|
||||
@@ -57,8 +57,8 @@ class AM_GMC(AM_IP):
|
||||
|
||||
self.trans_futher = self.adev.ip_ver[am.GC_HWIP] < (10, 0, 0)
|
||||
|
||||
# GFX11/GFX12 has 44-bit address space
|
||||
self.address_space_mask = (1 << 44) - 1
|
||||
# mi3xx has 48-bit, others have 44-bit address space
|
||||
self.address_space_mask = (1 << (48 if self.adev.ip_ver[am.GC_HWIP][:2] == (9,4) else 44)) - 1
|
||||
|
||||
self.memscratch_xgmi_paddr = self.adev.paddr2xgmi(self.adev.mm.palloc(0x1000, zero=False, boot=True))
|
||||
self.dummy_page_xgmi_paddr = self.adev.paddr2xgmi(self.adev.mm.palloc(0x1000, zero=False, boot=True))
|
||||
@@ -183,7 +183,8 @@ class AM_SMU(AM_IP):
|
||||
if self.adev.ip_ver[am.MP0_HWIP] >= (14,0,0): self._send_msg(__DEBUGSMC_MSG_Mode1Reset:=2, 0, debug=True)
|
||||
elif self.adev.ip_ver[am.MP0_HWIP] in {(13,0,6), (13,0,12)}: self._send_msg(self.smu_mod.PPSMC_MSG_GfxDriverReset, 1)
|
||||
else: self._send_msg(self.smu_mod.PPSMC_MSG_Mode1Reset, 0)
|
||||
time.sleep(0.5) # 500ms
|
||||
|
||||
if not self.adev.is_hive(): time.sleep(0.5) # 500ms
|
||||
|
||||
def read_table(self, table_t, arg):
|
||||
if self.adev.ip_ver[am.MP0_HWIP] in {(13,0,6),(13,0,12)}: self._send_msg(self.smu_mod.PPSMC_MSG_GetMetricsTable, arg)
|
||||
|
||||
@@ -604,7 +604,7 @@ class Tensor(OpMixin):
|
||||
bits = bits.bitcast(uint_dtype)
|
||||
# only randomize the mantissa bits and set the exponent to 1
|
||||
one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
|
||||
bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
|
||||
bits = bits.rshift(dtype.bitsize - nmant).bitwise_or(one)
|
||||
# bitcast back to the original dtype and reshape
|
||||
out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape).requires_grad_(kwargs.get("requires_grad"))
|
||||
return out.contiguous() if contiguous else out
|
||||
|
||||
@@ -186,7 +186,6 @@ commutative = PatternMatcher([
|
||||
|
||||
symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
# ** boolean algebra **
|
||||
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
|
||||
# TODO: make a more general or folder like simplify_valid
|
||||
(UPat.var("x", dtype=dtypes.bool) | UPat.var("x").logical_not(), lambda x: x.const_like(True)), # x|!x -> True
|
||||
# ** combine terms **
|
||||
|
||||
Reference in New Issue
Block a user