diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py
index c0e35fba01..1eed49aae5 100644
--- a/extra/assembly/amd/dsl.py
+++ b/extra/assembly/amd/dsl.py
@@ -253,6 +253,15 @@ def _get_variant(cls, suffix: str):
   module = sys.modules.get(cls.__module__)
   return getattr(module, f"{cls.__name__}{suffix}", None) if module else None
 
+def _canonical_name(name: str) -> str | None:
+  """Map operand name to canonical name."""
+  if name in ('src0', 'vsrc0', 'ssrc0'): return 's0'
+  if name in ('src1', 'vsrc1', 'ssrc1'): return 's1'
+  if name == 'src2': return 's2'
+  if name in ('vdst', 'sdst', 'sdata'): return 'd'
+  if name in ('data', 'vdata', 'data0', 'vsrc'): return 'data'
+  return None
+
 class Inst:
   _fields: list[tuple[str, BitField]]
   _base_size: int
@@ -368,12 +377,17 @@ class Inst:
     """Get bit widths with canonical names: {'s0', 's1', 's2', 'd', 'data'}."""
     bits = {'d': 32, 's0': 32, 's1': 32, 's2': 32, 'data': 32}
     for name, val in self.op_bits.items():
-      if name in ('src0', 'vsrc0', 'ssrc0'): bits['s0'] = val
-      elif name in ('src1', 'vsrc1', 'ssrc1'): bits['s1'] = val
-      elif name == 'src2': bits['s2'] = val
-      elif name in ('vdst', 'sdst', 'sdata'): bits['d'] = val
-      elif name in ('data', 'vdata', 'data0', 'vsrc'): bits['data'] = val
+      if (cn := _canonical_name(name)): bits[cn] = val
     return bits
+
+  @functools.cached_property
+  def canonical_operands(self) -> dict:
+    """Get operands with canonical names: {'s0', 's1', 's2', 'd', 'data'}."""
+    result = {}
+    for name, val in self.operands.items():
+      if (cn := _canonical_name(name)): result[cn] = val
+    return result
+
   @property
   def canonical_op_regs(self) -> dict[str, int]:
     """Get register counts with canonical names: {'s0', 's1', 's2', 'd', 'data'}."""
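For reference, the renaming this change centralizes can be sketched standalone. Everything below is illustrative, not part of the patch: `_CANON` and `canonical_bits` stand in for `_canonical_name` and `canonical_op_bits`, and the sample `op_bits` dict is assumed rather than taken from the autogen tables.

```python
# Table-driven version of the same mapping: encoding-specific operand names
# (src0/vsrc0/ssrc0, ...) collapse to canonical keys (s0, s1, s2, d, data).
_CANON = {'src0': 's0', 'vsrc0': 's0', 'ssrc0': 's0',
          'src1': 's1', 'vsrc1': 's1', 'ssrc1': 's1',
          'src2': 's2',
          'vdst': 'd', 'sdst': 'd', 'sdata': 'd',
          'data': 'data', 'vdata': 'data', 'data0': 'data', 'vsrc': 'data'}

def canonical_bits(op_bits: dict[str, int]) -> dict[str, int]:
  bits = {'d': 32, 's0': 32, 's1': 32, 's2': 32, 'data': 32}  # default: everything 32-bit
  for name, val in op_bits.items():
    if (cn := _CANON.get(name)): bits[cn] = val
  return bits

# e.g. a V_ADD_F64-style encoding with 64-bit sources and destination:
assert canonical_bits({'src0': 64, 'src1': 64, 'vdst': 64}) == \
       {'d': 64, 's0': 64, 's1': 64, 's2': 32, 'data': 32}
```

With one mapping function, `canonical_op_bits`, `canonical_operands`, and `canonical_op_regs` all agree on names, which is what lets the emulator below drop its ad-hoc `_is_16bit_op`/`_is_64bit_dest` string checks.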
diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py
index 70d7161f2e..e77063fbe2 100644
--- a/extra/assembly/amd/emu.py
+++ b/extra/assembly/amd/emu.py
@@ -53,7 +53,7 @@ from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE
 from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP1_SDST, VOP2, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC, DS, FLAT, GLOBAL, SCRATCH, VOPD, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOPDOp)
 from extra.assembly.amd.dsl import VCC_LO, EXEC_LO, SCC
-from extra.assembly.amd.autogen.common import OpType
+from extra.assembly.amd.autogen.common import Fmt, OpType
 from extra.assembly.amd.pcode import parse_block, _FUNCS
 
 MASK32 = 0xFFFFFFFF
@@ -70,14 +70,14 @@ def _split64(val: UOp) -> tuple[UOp, UOp]:
   return v64.cast(dtypes.uint32), (v64 >> UOp.const(dtypes.uint64, 32)).cast(dtypes.uint32)
 
 _SRC_MOD_TYPES = {16: (dtypes.uint16, dtypes.half, 0x7FFF), 64: (dtypes.uint64, dtypes.float64, 0x7FFFFFFFFFFFFFFF), 32: (dtypes.uint32, dtypes.float32, 0x7FFFFFFF)}
-def _apply_src_mods(val: UOp, mod_bit: int, abs_bits: int, neg_bits: int, is_16bit: bool = False, is_64bit: bool = False) -> UOp:
-  """Apply abs/neg modifiers to source value based on operation type."""
+def _apply_src_mods(val: UOp, mod_bit: int, abs_bits: int, neg_bits: int, bits: int = 32) -> UOp:
+  """Apply abs/neg modifiers to source value based on bit width (16, 32, or 64)."""
   if not (abs_bits & (1 << mod_bit)) and not (neg_bits & (1 << mod_bit)): return val
-  ut, ft, mask = _SRC_MOD_TYPES[16 if is_16bit else 64 if is_64bit else 32]
-  fv = val.cast(ut).bitcast(ft) if is_16bit else val.bitcast(ft) if val.dtype == ut else val
+  ut, ft, mask = _SRC_MOD_TYPES[bits]
+  fv = val.cast(ut).bitcast(ft) if bits == 16 else val.bitcast(ft) if val.dtype == ut else val
   if abs_bits & (1 << mod_bit): fv = (fv.bitcast(ut) & UOp.const(ut, mask)).bitcast(ft)
   if neg_bits & (1 << mod_bit): fv = fv.neg()
-  return fv.bitcast(ut).cast(dtypes.uint32) if is_16bit else fv.bitcast(ut)
+  return fv.bitcast(ut).cast(dtypes.uint32) if bits == 16 else fv.bitcast(ut)
 
 # Map VOPD ops to VOP2 ops for pcode lookup
 VOPD_TO_VOP2 = {
@@ -95,11 +95,10 @@ PC_LO_IDX, PC_HI_IDX, SCRATCH_STRIDE_IDX = 256, 257, 259
 # SGPR buffer: 0-127 = SGPRs, 128-255 = inline constants, 256-259 = special registers
 SGPR_COUNT, VGPR_SIZE = 260, 256 * 32
 
-def _is_16bit_op(op_name: str) -> bool: return any(x in op_name for x in ('B16', 'F16', 'I16', 'U16'))
 def _op_name(inst) -> str:
   if hasattr(inst, 'opx'): return f"{inst.opx.name}_{inst.opy.name}"  # VOPD has opx/opy not op
   return inst.op.name if hasattr(inst.op, 'name') else str(inst.op)
-def _is_64bit_dest(dest: str) -> bool: return any(dest.endswith(x) for x in ('.b64', '.u64', '.i64', '.f64'))
+
 def _to_u32(val: UOp) -> UOp:
   if val.dtype == dtypes.uint32: return val
   if val.dtype.itemsize == 4: return val.bitcast(dtypes.uint32)  # same size: bitcast (float32->uint32)
@@ -192,9 +191,9 @@ def _write_64bit(val: UOp, wfn, reg_or_addr, is_mem: bool, *args) -> list[UOp]:
   incr = 4 if is_mem else 1  # 4 bytes for memory addresses, 1 for register indices
   return [wfn(reg_or_addr, lo, *args), wfn(reg_or_addr + (UOp.const(reg_or_addr.dtype, incr) if isinstance(reg_or_addr, UOp) else incr), hi, *args)]
 
-def _write_val(dest: str, val: UOp, wfn, reg_or_addr, *args, is_mem: bool = False) -> list[UOp]:
-  """Write value, splitting 64-bit if needed based on dest type suffix."""
-  return _write_64bit(val, wfn, reg_or_addr, is_mem, *args) if _is_64bit_dest(dest) else [wfn(reg_or_addr, _to_u32(val), *args)]
+def _write_val(bits: int, val: UOp, wfn, reg_or_addr, *args, is_mem: bool = False) -> list[UOp]:
+  """Write value, splitting 64-bit if needed. bits=64 for 64-bit writes, otherwise 32-bit."""
+  return _write_64bit(val, wfn, reg_or_addr, is_mem, *args) if bits == 64 else [wfn(reg_or_addr, _to_u32(val), *args)]
 
 def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32, data_bits: int = 32) -> list[UOp]:
   """Conditional memory store with sub-word support. Returns list of store UOps."""
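The abs/neg source modifiers reduce to sign-bit arithmetic on the raw operand encoding. Below is a minimal pure-Python sketch of the 32-bit case; only the `0x7FFFFFFF` mask comes from `_SRC_MOD_TYPES` above, while the helper name and the `struct`-based float round-trip are illustrative:

```python
import struct

ABS_MASK_32 = 0x7FFFFFFF  # same mask as _SRC_MOD_TYPES[32]: clearing the sign bit == abs()

def apply_mods_f32(raw: int, mod_bit: int, abs_bits: int, neg_bits: int) -> int:
  """Apply abs/neg to the raw 32-bit encoding of a float source (illustrative)."""
  if not (abs_bits & (1 << mod_bit)) and not (neg_bits & (1 << mod_bit)): return raw
  f = struct.unpack('<f', struct.pack('<I', raw))[0]   # bitcast u32 -> f32
  if abs_bits & (1 << mod_bit): f = abs(f)             # equivalent to raw & ABS_MASK_32
  if neg_bits & (1 << mod_bit): f = -f                 # equivalent to flipping the sign bit
  return struct.unpack('<I', struct.pack('<f', f))[0]  # bitcast f32 -> u32

raw = struct.unpack('<I', struct.pack('<f', -2.5))[0]
out = apply_mods_f32(raw, 0, abs_bits=1, neg_bits=0)
assert struct.unpack('<f', struct.pack('<I', out))[0] == 2.5
```

The UOp version does the same thing per width (16/64 use the narrower/wider masks), which is why collapsing the two booleans into one `bits` parameter loses nothing.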
@@ -337,31 +336,38 @@ class _Ctx:
     offset = reg.cast(dtypes.int) * _c(32, dtypes.int) + lane.cast(dtypes.int)
     return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32))
 
-  def rsrc_dyn(self, off: UOp, lane: UOp, bits: int = 32, literal: UOp | None = None) -> UOp:
-    """Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256)."""
-    is_vgpr, vgpr_reg = off >= _c(256), off - _c(256)
+  def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False) -> UOp:
+    """Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256).
+    If lane is None, only scalar access is supported (off must be < 256).
+    is_f64: True for F64 operations, where 64-bit literals go in the high 32 bits."""
     is_float_const = (off >= _c(240)) & (off <= _c(248))
     sgpr_lo = self.rsgpr_dyn(off)
-    vgpr_lo = self.rvgpr_dyn(vgpr_reg, lane, is_vgpr)
+
+    if lane is not None:
+      is_vgpr, vgpr_reg = off >= _c(256), off - _c(256)
+      vgpr_lo = self.rvgpr_dyn(vgpr_reg, lane, is_vgpr)
+      vgpr_val = _u64(vgpr_lo, self.rvgpr_dyn(vgpr_reg + _c(1), lane, is_vgpr)) if bits == 64 else vgpr_lo
     if bits == 64:
-      vgpr_val = _u64(vgpr_lo, self.rvgpr_dyn(vgpr_reg + _c(1), lane, is_vgpr))
       sgpr_val = _u64(sgpr_lo, self.rsgpr_dyn(off + _c(1)))
-      # Float constants: cast F32 to F64; integer inline: duplicate lo
-      inline = is_float_const.where(sgpr_lo.bitcast(dtypes.float32).cast(dtypes.float64).bitcast(dtypes.uint64), _u64(sgpr_lo, sgpr_lo))
-      if literal is not None: inline = off.eq(_c(255)).where(literal.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32), inline)
+      # Integer inline constants: sign-extend 32-bit value from buffer to 64-bit
+      # Float constants: cast F32 to F64
+      int_inline = sgpr_lo.cast(dtypes.int32).cast(dtypes.int64)
+      float_inline = sgpr_lo.bitcast(dtypes.float32).cast(dtypes.float64)
+      # select the float or integer inline encoding
+      inline = is_float_const.where(float_inline.bitcast(dtypes.uint64), int_inline.bitcast(dtypes.uint64))
+      # Literal handling: F64 VOP puts literal in high 32 bits; B64/I64/U64 VOP and SOP zero-extend
+      if literal is not None:
+        lit_val = literal.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32) if is_f64 else literal.cast(dtypes.uint64)
+        inline = off.eq(_c(255)).where(lit_val, inline)
       scalar_val = (off < _c(128)).where(sgpr_val, inline)
     else:
-      vgpr_val = vgpr_lo
      scalar_val = sgpr_lo
      if literal is not None: scalar_val = off.eq(_c(255)).where(literal, scalar_val)
      if bits == 16:
        # Float constants: cast F32 to F16
        scalar_val = is_float_const.where(scalar_val.bitcast(dtypes.float32).cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32), scalar_val)
-    return is_vgpr.where(vgpr_val, scalar_val)
-
-  def rsrc_dyn_sized(self, off: UOp, lane: UOp, sizes: dict, key: str, f16: bool = False, literal: UOp | None = None) -> UOp:
-    return self.rsrc_dyn(off, lane, 64, literal) if sizes.get(key, 1) == 2 else self.rsrc_dyn(off, lane, 16 if f16 else 32, literal)
+    return is_vgpr.where(vgpr_val, scalar_val) if lane is not None else scalar_val
 
   def rpc(self) -> UOp:
     """Read PC as 64-bit byte address."""
@@ -509,36 +515,9 @@ def _compile_smem(inst: SMEM, ctx: _Ctx) -> UOp:
   return UOp.sink(*stores, *ctx.inc_pc())
 
 def _compile_sop(inst: SOP1 | SOP2 | SOPC | SOPK, ctx: _Ctx) -> UOp:
-  sizes = getattr(inst, 'op_regs', {})
+  bits = inst.canonical_op_bits
   literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
 
-  # Read source operands dynamically
-  def rsrc_dyn_scalar(off: UOp, is_64bit: bool) -> UOp:
-    """Read scalar source with dynamic offset (SGPR or inline constant).
-    For SOP, off is always 0-255 (SGPR or inline constant, never VGPR).
-    SGPR buffer has 260 entries: 0-127=SGPRs, 128-255=inline constants, 256-259=special."""
-    is_sgpr = off < _c(128)
-    # For 64-bit: read SGPR pair if off < 128, else compute inline constant as 64-bit
-    # (can't just read from buffer since buffer has 32-bit values)
-    if is_64bit:
-      sgpr_val = _u64(ctx.rsgpr_dyn(off), ctx.rsgpr_dyn(off + _c(1)))
-      # Build inline constant: 128-192 = 0-64, 193-208 = -1 to -16
-      inline_val = (off - _c(128)).cast(dtypes.uint64)  # positive inline 0-64
-      neg_val = (_c(192) - off).cast(dtypes.int64).cast(dtypes.uint64)  # negative -1 to -16
-      lit_val = literal.cast(dtypes.uint64) if literal is not None else UOp.const(dtypes.uint64, 0)
-      # Select between sgpr, positive inline, negative inline, or literal
-      is_neg_inline = (off >= _c(193)) & (off < _c(209))
-      is_literal = off.eq(_c(255)) if literal is not None else UOp.const(dtypes.bool, False)
-      val = is_sgpr.where(sgpr_val, is_neg_inline.where(neg_val, is_literal.where(lit_val, inline_val)))
-      return val
-    # 32-bit: read from SGPR buffer (inline constants 128-255 are pre-populated)
-    # off is always 0-255 for SOP, all valid SGPR indices
-    sgpr_val = ctx.rsgpr_dyn(off)
-    # Handle literal (255) - literal value overrides the pre-populated 0
-    if literal is not None:
-      sgpr_val = off.eq(_c(255)).where(literal, sgpr_val)
-    return sgpr_val
-
   if isinstance(inst, SOPK):
     sdst_off = ctx.inst_field(SOPK.sdst)
     simm16 = ctx.inst_field(SOPK.simm16)
@@ -549,21 +528,21 @@ def _compile_sop(inst: SOP1 | SOP2 | SOPC | SOPK, ctx: _Ctx) -> UOp:
   elif isinstance(inst, SOP1):
     sdst_off = ctx.inst_field(SOP1.sdst)
     ssrc0_off = ctx.inst_field(SOP1.ssrc0)
-    srcs = {'S0': rsrc_dyn_scalar(ssrc0_off, sizes.get('ssrc0', 1) == 2)}
-    dst_off, dst_size = sdst_off, sizes.get('sdst', 1)
+    srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal)}
+    dst_off, dst_size = sdst_off, bits['d'] // 32
   elif isinstance(inst, SOP2):
     sdst_off = ctx.inst_field(SOP2.sdst)
     ssrc0_off = ctx.inst_field(SOP2.ssrc0)
     ssrc1_off = ctx.inst_field(SOP2.ssrc1)
-    srcs = {'S0': rsrc_dyn_scalar(ssrc0_off, sizes.get('ssrc0', 1) == 2),
-            'S1': rsrc_dyn_scalar(ssrc1_off, sizes.get('ssrc1', 1) == 2)}
+    srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
+            'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
     if literal is not None: srcs['SIMM32'] = literal
-    dst_off, dst_size = sdst_off, sizes.get('sdst', 1)
+    dst_off, dst_size = sdst_off, bits['d'] // 32
   elif isinstance(inst, SOPC):
     ssrc0_off = ctx.inst_field(SOPC.ssrc0)
     ssrc1_off = ctx.inst_field(SOPC.ssrc1)
-    srcs = {'S0': rsrc_dyn_scalar(ssrc0_off, sizes.get('ssrc0', 1) == 2),
-            'S1': rsrc_dyn_scalar(ssrc1_off, sizes.get('ssrc1', 1) == 2)}
+    srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
+            'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
     dst_off, dst_size = _c(0), 0  # SOPC writes to SCC, not sdst
   else:
     raise RuntimeError(f"unknown SOP type: {type(inst).__name__}")
@@ -573,18 +552,17 @@ def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
   op_name = _op_name(inst)
   if op_name == 'V_READFIRSTLANE_B32_E32': return ctx.compile_lane_pcode(inst.op, inst)
-  lane, exec_mask, sizes = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset)), getattr(inst, 'op_regs', {})
-  is_16bit = _is_16bit_op(op_name)
+  lane, exec_mask, bits = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset)), inst.canonical_op_bits
   literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
   vdst_reg = ctx.inst_field(VOP1.vdst)
-  write_hi_half = is_16bit and (vdst_reg >= _c(128))
+  write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128))
   if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg)
   elif write_hi_half: vdst_reg -= 128
   if isinstance(inst, VOP1):
     # Handle VOP1 hi-half source operand (src0 >= v[128] for 16-bit ops)
     src0_off = ctx.inst_field(VOP1.src0)
-    s0 = ctx.rsrc_dyn_sized(src0_off, lane, sizes, 'src0', f16=is_16bit, literal=literal)
-    if is_16bit:
+    s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal)
+    if bits['s0'] == 16:
       src0_hi = src0_off >= _c(384)
       # Only compute hi-half when src0_off >= 384, use guarded index to prevent OOB access
       src0_reg = src0_hi.where(src0_off - _c(384), _c(0))
@@ -592,14 +570,14 @@ def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
     srcs = {'S0': s0}
   else:
     vsrc1_reg = ctx.inst_field(VOP2.vsrc1)
-    vsrc1_hi = is_16bit and (vsrc1_reg >= _c(128))
+    vsrc1_hi = bits['s0'] == 16 and (vsrc1_reg >= _c(128))
     vsrc1_actual = _cond(vsrc1_hi, vsrc1_reg - _c(128), vsrc1_reg)
     s1 = _cond_hi16(vsrc1_hi, ctx.rvgpr_dyn(vsrc1_actual, lane))
     d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane))  # FMAC/FMAMK hi-half dest needs hi-half accumulator
     # Handle VOP2 hi-half src0 operand (src0 >= v[128] for 16-bit ops)
     src0_off = ctx.inst_field(VOP2.src0)
-    s0 = ctx.rsrc_dyn(src0_off, lane, bits=16 if is_16bit else 32, literal=literal)
-    if is_16bit:
+    s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal)
+    if bits['s0'] == 16:
       src0_hi = src0_off >= _c(384)
       # Only compute hi-half when src0_off >= 384, use guarded index to prevent OOB access
       src0_reg = src0_hi.where(src0_off - _c(384), _c(0))
@@ -611,41 +589,36 @@ def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
   return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half)
 
 def _compile_vopc(inst: VOPC | VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
-  exec_mask, op_name = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst)
-  is_cmpx, is_16bit, is_64bit = 'CMPX' in op_name, _is_16bit_op(op_name), 'F64' in op_name
-  is_vopc = hasattr(inst, 'vsrc1')  # VOPC (e32) vs VOP3 (e64) format
+  exec_mask, op_name, bits = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst), inst.canonical_op_bits
+  is_cmpx, is_vopc = 'CMPX' in op_name, hasattr(inst, 'vsrc1')  # is_vopc: e32 vs e64
   # Handle both VOPC (vsrc1) and VOP3 (src1) instruction formats - read operands dynamically
   if is_vopc:
     src0_off = ctx.inst_field(VOPC.src0)
     vsrc1_off = ctx.inst_field(VOPC.vsrc1)
     # For 16-bit ops, vsrc1 >= 128 means hi-half of v[vsrc1-128]
-    if is_16bit:
+    if bits['s0'] == 16:
       vsrc1_hi = vsrc1_off >= _c(128)
       src1_off = _c(256) + vsrc1_hi.where(vsrc1_off - _c(128), vsrc1_off)
     else:
       vsrc1_hi = False
       src1_off = _c(256) + vsrc1_off
-    src0_bits, src1_bits = (64, 64) if is_64bit else (32, 32)
   else:
     src0_off = ctx.inst_field(VOP3.src0)
     src1_off = ctx.inst_field(VOP3.src1)
     dst_off = ctx.inst_field(VOP3.vdst)
     vsrc1_hi = False
-    _, src0_bits, _ = inst.operands.get('src0', (None, 32, None))
-    _, src1_bits, _ = inst.operands.get('src1', (None, 32, None))
-    is_16bit = src0_bits == 16 or src1_bits == 16
   literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
-  is_float, pcode = any(x in op_name for x in ('_F32', '_F64', '_F16')), get_pcode(inst.op)
+  is_float, is_f64, pcode = any(x in op_name for x in ('_F32', '_F64', '_F16')), '_F64' in op_name, get_pcode(inst.op)
   def get_cmp_bit(lane) -> UOp:
     lc = lane.cast(dtypes.int) if isinstance(lane, UOp) else _c(lane, dtypes.int)
-    s0 = ctx.rsrc_dyn(src0_off, lc, src0_bits, literal)
-    s1 = _cond_hi16(vsrc1_hi, ctx.rsrc_dyn(src1_off, lc, src1_bits, literal)) if is_16bit else ctx.rsrc_dyn(src1_off, lc, src1_bits, literal)
-    if is_16bit and opsel: s0, s1 = _apply_opsel(s0, 0, opsel), _apply_opsel(s1, 1, opsel)
-    if is_float and (abs_bits or neg_bits):
-      s0 = _apply_src_mods(s0, 0, abs_bits, neg_bits, is_16bit, src0_bits == 64)
-      s1 = _apply_src_mods(s1, 1, abs_bits, neg_bits, is_16bit, src1_bits == 64)
+    s0 = ctx.rsrc_dyn(src0_off, lc, bits['s0'], literal, is_f64)
+    s1 = _cond_hi16(vsrc1_hi, ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)) if bits['s0'] == 16 else ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)
+    if bits['s0'] == 16 and opsel: s0, s1 = _apply_opsel(s0, 0, opsel), _apply_opsel(s1, 1, opsel)
+    if is_float:
+      s0 = _apply_src_mods(s0, 0, abs_bits, neg_bits, bits['s0'])
+      s1 = _apply_src_mods(s1, 1, abs_bits, neg_bits, bits['s1'])
    for dest, val in parse_pcode(pcode, {'S0': s0, 'S1': s1, 'laneId': lc})[1]:
      if '[laneId]' in dest and ('D0' in dest or 'EXEC' in dest): return val.cast(dtypes.uint32)
    return _c(0)
@@ -664,7 +637,7 @@ def _compile_vopc(inst: VOPC | VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int =
 
 def _compile_vop3(inst: VOP3, ctx: _Ctx) -> UOp:
   exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
-  sizes = getattr(inst, 'op_regs', {})
+  bits = inst.canonical_op_bits
   opsel, op_name = getattr(inst, 'opsel', 0) or 0, _op_name(inst)
 
   # Lane operations
@@ -677,53 +650,48 @@ def _compile_vop3(inst: VOP3, ctx: _Ctx) -> UOp:
 
   # Regular VOP3 - read operands dynamically
   lane = ctx.range()
-  is_f16_op = 'F16' in op_name
   vdst_reg = ctx.inst_field(VOP3.vdst)
-  src0_off = ctx.inst_field(VOP3.src0)
-  src1_off = ctx.inst_field(VOP3.src1)
-  src2_off = ctx.inst_field(VOP3.src2) if inst.src2 is not None else None
   literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
-  src0 = ctx.rsrc_dyn_sized(src0_off, lane, sizes, 'src0', f16=is_f16_op, literal=literal)
-  src1 = ctx.rsrc_dyn_sized(src1_off, lane, sizes, 'src1', f16=is_f16_op, literal=literal)
-  src2 = ctx.rsrc_dyn_sized(src2_off, lane, sizes, 'src2', f16=is_f16_op, literal=literal) if src2_off is not None else None
-  if _is_16bit_op(op_name):
-    src0, src1 = _apply_opsel(src0, 0, opsel), _apply_opsel(src1, 1, opsel)
-    if src2 is not None: src2 = _apply_opsel(src2, 2, opsel)
+  ops = inst.canonical_operands
+  src0 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src0), lane, bits['s0'], literal, 's0' in ops and ops['s0'][0] == Fmt.FMT_NUM_F64)
+  src1 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src1), lane, bits['s1'], literal, 's1' in ops and ops['s1'][0] == Fmt.FMT_NUM_F64)
+  src2 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src2), lane, bits['s2'], literal, 's2' in ops and ops['s2'][0] == Fmt.FMT_NUM_F64)
+  if bits['s0'] == 16:
+    src0 = _apply_opsel(src0, 0, opsel)
+    src1 = _apply_opsel(src1, 1, opsel)
+    src2 = _apply_opsel(src2, 2, opsel)
   abs_bits, neg_bits = getattr(inst, 'abs', 0) or 0, getattr(inst, 'neg', 0) or 0
-  is_16bit_op = _is_16bit_op(op_name)
-  if abs_bits or neg_bits:
-    src0 = _apply_src_mods(src0, 0, abs_bits, neg_bits, is_16bit_op, sizes.get('src0', 1) == 2)
-    if src1 is not None: src1 = _apply_src_mods(src1, 1, abs_bits, neg_bits, is_16bit_op, sizes.get('src1', 1) == 2)
-    if src2 is not None: src2 = _apply_src_mods(src2, 2, abs_bits, neg_bits, is_16bit_op, sizes.get('src2', 1) == 2)
-  srcs = {'S0': src0, 'S1': src1}
-  if src2 is not None: srcs['S2'] = src2
+  src0 = _apply_src_mods(src0, 0, abs_bits, neg_bits, bits['s0'])
+  src1 = _apply_src_mods(src1, 1, abs_bits, neg_bits, bits['s1'])
+  src2 = _apply_src_mods(src2, 2, abs_bits, neg_bits, bits['s2'])
+  srcs = {'S0': src0, 'S1': src1, 'S2': src2}
   if inst.op in (VOP3Op.V_CNDMASK_B32_E64, VOP3Op.V_CNDMASK_B16) and src2 is not None: srcs['VCC'] = src2
   # FMAC instructions need D0 (accumulator) from destination register
   if 'FMAC' in op_name: srcs['D0'] = ctx.rvgpr_dyn(vdst_reg, lane)
-  opsel_dst_hi = bool(opsel & 0b1000) and _is_16bit_op(op_name)
+  opsel_dst_hi = bool(opsel & 0b1000) and bits['d'] == 16
   return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=opsel_dst_hi, clmp=getattr(inst, 'clmp', 0))
 
 def _compile_vop3sd(inst: VOP3SD, ctx: _Ctx) -> UOp:
   exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
-  sizes, pcode = getattr(inst, 'op_regs', {}), get_pcode(inst.op)
+  bits, pcode, ops = inst.canonical_op_bits, get_pcode(inst.op), inst.canonical_operands
   # Read operands dynamically from instruction encoding
-  vdst_reg = ctx.inst_field(VOP3SD.vdst)
-  sdst_off = ctx.inst_field(VOP3SD.sdst)
-  src0_off = ctx.inst_field(VOP3SD.src0)
-  src1_off = ctx.inst_field(VOP3SD.src1)
-  src2_off = ctx.inst_field(VOP3SD.src2) if inst.src2 is not None else None
+  vdst_reg, sdst_off = ctx.inst_field(VOP3SD.vdst), ctx.inst_field(VOP3SD.sdst)
+  src0_off, src1_off, src2_off = ctx.inst_field(VOP3SD.src0), ctx.inst_field(VOP3SD.src1), ctx.inst_field(VOP3SD.src2)
   literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
-  has_carry_in = 'src2' in inst.operands and inst.operands['src2'][2] == OpType.OPR_SREG
-  vcc_in_off = src2_off if has_carry_in and src2_off is not None else sdst_off
+  has_carry_in = 's2' in ops and ops['s2'][2] == OpType.OPR_SREG
+  vcc_in_off = src2_off if has_carry_in else sdst_off
+
+  def load_srcs(lane_uop):
+    ret = {'VCC': ctx.rsgpr_dyn(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane_uop}
+    ret['S0'] = ctx.rsrc_dyn(src0_off, lane_uop, bits['s0'], literal, ops['s0'][0] == Fmt.FMT_NUM_F64)
+    ret['S1'] = ctx.rsrc_dyn(src1_off, lane_uop, bits['s1'], literal, ops['s1'][0] == Fmt.FMT_NUM_F64)
+    if 's2' in ops: ret['S2'] = ctx.rsrc_dyn(src2_off, lane_uop, bits['s2'], literal, ops['s2'][0] == Fmt.FMT_NUM_F64)
+    return ret
 
   lane = ctx.range()
-  src0 = ctx.rsrc_dyn_sized(src0_off, lane, sizes, 'src0', literal=literal)
-  src1 = ctx.rsrc_dyn_sized(src1_off, lane, sizes, 'src1', literal=literal)
-  src2 = ctx.rsrc_dyn_sized(src2_off, lane, sizes, 'src2', literal=literal) if src2_off is not None else None
-  srcs = {'S0': src0, 'S1': src1, 'VCC': ctx.rsgpr_dyn(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane}
-  if src2 is not None: srcs['S2'] = src2
+  srcs = load_srcs(lane)
 
   _, assigns = parse_pcode(pcode, srcs)
   has_per_lane_vcc = any('[laneId]' in dest for dest, _ in assigns if dest.startswith('VCC') or dest.startswith('D0.u64'))
@@ -731,23 +699,15 @@ def _compile_vop3sd(inst: VOP3SD, ctx: _Ctx) -> UOp:
     # VCC computation: RANGE+REDUCE gets axis ID first (lower ID = runs first)
     # This ensures VCC reads source values BEFORE VGPR stores modify them
    def get_vcc_bit(lane_uop) -> UOp:
-      s0, s1 = ctx.rsrc_dyn_sized(src0_off, lane_uop, sizes, 'src0', literal=literal), ctx.rsrc_dyn_sized(src1_off, lane_uop, sizes, 'src1', literal=literal)
-      s2 = ctx.rsrc_dyn_sized(src2_off, lane_uop, sizes, 'src2', literal=literal) if src2_off is not None else None
-      lane_srcs = {'S0': s0, 'S1': s1, 'VCC': ctx.rsgpr_dyn(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane_uop}
-      if s2 is not None: lane_srcs['S2'] = s2
      vcc_bit = _c(0)
-      for dest, val in parse_pcode(pcode, lane_srcs)[1]:
+      for dest, val in parse_pcode(pcode, load_srcs(lane_uop))[1]:
        if dest.startswith('VCC') or (dest.startswith('D0.u64') and '[laneId]' in dest): vcc_bit = val.cast(dtypes.uint32)
      return vcc_bit
    final_vcc = ctx.unroll_lanes(get_vcc_bit, exec_mask)
    # VGPR stores: RANGE gets axis ID second (higher ID = runs after VCC loop)
    lane3 = ctx.range()
-    s0, s1 = ctx.rsrc_dyn_sized(src0_off, lane3, sizes, 'src0', literal=literal), ctx.rsrc_dyn_sized(src1_off, lane3, sizes, 'src1', literal=literal)
-    s2 = ctx.rsrc_dyn_sized(src2_off, lane3, sizes, 'src2', literal=literal) if src2_off is not None else None
-    lane_srcs = {'S0': s0, 'S1': s1, 'VCC': ctx.rsgpr_dyn(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane3}
-    if s2 is not None: lane_srcs['S2'] = s2
    d0_val = None
-    for dest, val in parse_pcode(pcode, lane_srcs)[1]:
+    for dest, val in parse_pcode(pcode, load_srcs(lane3))[1]:
      if dest.startswith('D0') and '[laneId]' not in dest: d0_val = val
    vgpr_stores = []
    if d0_val is not None:
@@ -763,62 +723,52 @@ def _compile_vop3sd(inst: VOP3SD, ctx: _Ctx) -> UOp:
   else:
     return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, sdst_reg=inst.sdst.offset)
 
-def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
-  lane, exec_mask = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset))
-  # Read register fields dynamically for deduplication
+def _compile_wmma(inst: VOP3P, ctx: _Ctx) -> UOp:
+  op_name = _op_name(inst)
+  exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
   vdst_reg = ctx.inst_field(VOP3P.vdst)
-  src0_off = ctx.inst_field(VOP3P.src0)
-  src1_off = ctx.inst_field(VOP3P.src1)
-  src2_off = ctx.inst_field(VOP3P.src2) if hasattr(inst, 'src2') and inst.src2 is not None else None
-  src0 = ctx.rsrc_dyn(src0_off, lane, 16)
-  src1 = ctx.rsrc_dyn(src1_off, lane, 16)
-  src2 = ctx.rsrc_dyn(src2_off, lane, 16) if src2_off is not None else None
+  src0_r = ctx.inst_field(VOP3P.src0) - _c(256)
+  src1_r = ctx.inst_field(VOP3P.src1) - _c(256)
+  src2_r = ctx.inst_field(VOP3P.src2) - _c(256)
+  is_f16_output = 'F16_16X16X16_F16' in op_name or 'BF16_16X16X16_BF16' in op_name  # F16/BF16 output vs F32 output
+  is_bf16 = 'BF16' in op_name
+  cvt = _FUNCS['bf16_to_f32'] if is_bf16 else _FUNCS['f16_to_f32']
+  def read_f16_mat(src):
+    return [f for l in range(16) for r in range(8) for v in [ctx.rvgpr_dyn(src + _c(r), UOp.const(dtypes.int, l))]
+            for f in [cvt(v & UOp.const(dtypes.uint32, 0xFFFF)), cvt(v >> UOp.const(dtypes.uint32, 16))]]
+  mat_a, mat_b = read_f16_mat(src0_r), read_f16_mat(src1_r)
+  if is_f16_output:
+    # RDNA3 F16/BF16 output: uses 8 VGPRs (same as F32), f16/bf16 values in lo 16 bits of each VGPR
+    # Layout: half16 per lane where even indices (0,2,4,...,14) = lo halves of VGPRs 0-7
+    # Read accumulator: 8 regs × 32 lanes, each VGPR's lo 16 bits holds one f16/bf16
+    mat_c = [cvt(ctx.rvgpr_dyn(src2_r + _c(i // 32), UOp.const(dtypes.int, i % 32)) & UOp.const(dtypes.uint32, 0xFFFF))
+             for i in range(256)]
+    mat_d = [sum(mat_a[row*16+k] * mat_b[col*16+k] for k in range(16)) + mat_c[row*16+col] for row in range(16) for col in range(16)]
+    # Write f16/bf16 results to lo 16 bits of each VGPR
+    def f32_to_f16_bits(v: UOp) -> UOp: return v.cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32)
+    def f32_to_bf16_bits(v: UOp) -> UOp: return (v.bitcast(dtypes.uint32) >> UOp.const(dtypes.uint32, 16)) & UOp.const(dtypes.uint32, 0xFFFF)
+    out_cvt = f32_to_bf16_bits if is_bf16 else f32_to_f16_bits
+    stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), out_cvt(mat_d[i]), exec_mask) for i in range(256)]
+  else:
+    # F32 output: accumulator and output are f32
+    mat_c = [ctx.rvgpr_dyn(src2_r + _c(i // 32), UOp.const(dtypes.int, i % 32)).bitcast(dtypes.float32) for i in range(256)]
+    mat_d = [sum(mat_a[row*16+k] * mat_b[col*16+k] for k in range(16)) + mat_c[row*16+col] for row in range(16) for col in range(16)]
+    stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), mat_d[i].bitcast(dtypes.uint32), exec_mask) for i in range(256)]
+  return UOp.sink(*stores, *ctx.inc_pc())
+
+def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
+  op_name = _op_name(inst)
+  if 'WMMA' in op_name and ('16X16X16_F16' in op_name or '16X16X16_BF16' in op_name): return _compile_wmma(inst, ctx)
+
+  lane = ctx.range()
+  exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
+  vdst_reg = ctx.inst_field(VOP3P.vdst)
+  src0 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src0), lane, 16)
+  src1 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src1), lane, 16)
+  src2 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src2), lane, 16)
   opsel, opsel_hi = getattr(inst, 'opsel', 0) or 0, getattr(inst, 'opsel_hi', 3) if getattr(inst, 'opsel_hi', 3) is not None else 3
   opsel_hi2 = getattr(inst, 'opsel_hi2', 1) if getattr(inst, 'opsel_hi2', 1) is not None else 1
   neg, neg_hi = getattr(inst, 'neg', 0) or 0, getattr(inst, 'neg_hi', 0) or 0
-  def get_half_bits(val: UOp, use_hi: bool, apply_neg: bool = False) -> UOp:
-    bits = ((val >> UOp.const(dtypes.uint32, 16)) if use_hi else val) & UOp.const(dtypes.uint32, 0xFFFF)
-    if apply_neg: bits = bits.cast(dtypes.uint16).bitcast(dtypes.half).neg().bitcast(dtypes.uint16).cast(dtypes.uint32)
-    return bits
-  def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp:
-    return get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit)) | (get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit)) << UOp.const(dtypes.uint32, 16))
-  s0_new = build_remapped_src(src0, opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1)
-  s1_new = build_remapped_src(src1, opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2)
-  s2_new = build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4) if src2 is not None else None
-  op_name = _op_name(inst)
-
-  # WMMA: Wave Matrix Multiply-Accumulate
-  if 'WMMA' in op_name and ('16X16X16_F16' in op_name or '16X16X16_BF16' in op_name):
-    # Dynamic register fields for deduplication
-    src0_r = ctx.inst_field(VOP3P.src0) - _c(256)
-    src1_r = ctx.inst_field(VOP3P.src1) - _c(256)
-    src2_r = ctx.inst_field(VOP3P.src2) - _c(256)
-    is_f16_output = 'F16_16X16X16_F16' in op_name or 'BF16_16X16X16_BF16' in op_name  # F16/BF16 output vs F32 output
-    is_bf16 = 'BF16' in op_name
-    cvt = _FUNCS['bf16_to_f32'] if is_bf16 else _FUNCS['f16_to_f32']
-    def read_f16_mat(src):
-      return [f for l in range(16) for r in range(8) for v in [ctx.rvgpr_dyn(src + _c(r), UOp.const(dtypes.int, l))]
-              for f in [cvt(v & UOp.const(dtypes.uint32, 0xFFFF)), cvt(v >> UOp.const(dtypes.uint32, 16))]]
-    mat_a, mat_b = read_f16_mat(src0_r), read_f16_mat(src1_r)
-    if is_f16_output:
-      # RDNA3 F16/BF16 output: uses 8 VGPRs (same as F32), f16/bf16 values in lo 16 bits of each VGPR
-      # Layout: half16 per lane where even indices (0,2,4,...,14) = lo halves of VGPRs 0-7
-      # Read accumulator: 8 regs × 32 lanes, each VGPR's lo 16 bits holds one f16/bf16
-      mat_c = [cvt(ctx.rvgpr_dyn(src2_r + _c(i // 32), UOp.const(dtypes.int, i % 32)) & UOp.const(dtypes.uint32, 0xFFFF))
-               for i in range(256)]
-      mat_d = [sum(mat_a[row*16+k] * mat_b[col*16+k] for k in range(16)) + mat_c[row*16+col] for row in range(16) for col in range(16)]
-      # Write f16/bf16 results to lo 16 bits of each VGPR
-      def f32_to_f16_bits(v: UOp) -> UOp: return v.cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32)
-      def f32_to_bf16_bits(v: UOp) -> UOp: return (v.bitcast(dtypes.uint32) >> UOp.const(dtypes.uint32, 16)) & UOp.const(dtypes.uint32, 0xFFFF)
-      out_cvt = f32_to_bf16_bits if is_bf16 else f32_to_f16_bits
-      stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32),
-                out_cvt(mat_d[i]), exec_mask) for i in range(256)]
-    else:
-      # F32 output: accumulator and output are f32
-      mat_c = [ctx.rvgpr_dyn(src2_r + _c(i // 32), UOp.const(dtypes.int, i % 32)).bitcast(dtypes.float32) for i in range(256)]
-      mat_d = [sum(mat_a[row*16+k] * mat_b[col*16+k] for k in range(16)) + mat_c[row*16+col] for row in range(16) for col in range(16)]
-      stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), mat_d[i].bitcast(dtypes.uint32), exec_mask) for i in range(256)]
-    return UOp.sink(*stores, *ctx.inc_pc())
 
   if 'FMA_MIX' in op_name:
     combined_opsel_hi = (opsel_hi & 0x3) | ((opsel_hi2 & 0x1) << 2)
@@ -836,12 +786,20 @@ def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
       return v ^ UOp.const(dtypes.uint32, 0x00008000)  # f16 lo neg
     s0_mod = apply_neg_mix(apply_abs(src0, 1, 1, 1), 1, 1, 1)
     s1_mod = apply_neg_mix(apply_abs(src1, 2, 2, 2), 2, 2, 2)
-    s2_mod = apply_neg_mix(apply_abs(src2, 4, 4, 4), 4, 4, 4) if src2 is not None else UOp.const(dtypes.uint32, 0)
+    s2_mod = apply_neg_mix(apply_abs(src2, 4, 4, 4), 4, 4, 4)
     srcs = {'S0': s0_mod, 'S1': s1_mod, 'S2': s2_mod, 'OPSEL_HI': UOp.const(dtypes.uint32, combined_opsel_hi), 'OPSEL': UOp.const(dtypes.uint32, opsel)}
   else:
-    srcs = {'S0': s0_new, 'S1': s1_new}
-    if s2_new is not None: srcs['S2'] = s2_new
+    def get_half_bits(val: UOp, use_hi: bool, apply_neg: bool = False) -> UOp:
+      bits = ((val >> UOp.const(dtypes.uint32, 16)) if use_hi else val) & UOp.const(dtypes.uint32, 0xFFFF)
+      if apply_neg: bits = bits.cast(dtypes.uint16).bitcast(dtypes.half).neg().bitcast(dtypes.uint16).cast(dtypes.uint32)
+      return bits
+    def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp:
+      return get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit)) | (get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit)) << UOp.const(dtypes.uint32, 16))
+    s0_new = build_remapped_src(src0, opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1)
+    s1_new = build_remapped_src(src1, opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2)
+    s2_new = build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4)
+    srcs = {'S0': s0_new, 'S1': s1_new, 'S2': s2_new}
   return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask)
 
 def _compile_vopd(inst: VOPD, ctx: _Ctx) -> UOp:
@@ -904,9 +862,8 @@ def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
   # Dynamic saddr - read field, NULL (124) or >= 128 means no saddr
   saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(inst, 'saddr') else None
 
-  # Data width
-  ndwords = 4 if '_B128' in op_name or 'B128' in op_name else 3 if '_B96' in op_name or 'B96' in op_name else 2 if '_B64' in op_name or 'B64' in op_name else 1
-  is_64bit = ndwords >= 2 or '_U64' in op_name or '_I64' in op_name or '_F64' in op_name
+  # Data width from canonical_op_bits (32/64/96/128), default to 32 for untyped ops
+  data_bits_mem = inst.canonical_op_bits.get('data', 32)
   is_atomic, glc = 'ATOMIC' in op_name, getattr(inst, 'glc', 0)
   has_data1 = is_lds and hasattr(inst, 'data1') and inst.data1 is not None
   data1_reg = ctx.inst_field(DS.data1) if is_lds else _c(0)
@@ -940,10 +897,13 @@ def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
   def make_srcs(lane: UOp) -> dict:
     addr = make_addr(lane)
     if is_lds:
-      if 'B128' in op_name or 'B96' in op_name:
+      if data_bits_mem == 128:
        data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA1': ctx.rvgpr_dyn(vdata_reg + _c(1), lane),
                'DATA2': ctx.rvgpr_dyn(vdata_reg + _c(2), lane), 'DATA3': ctx.rvgpr_dyn(vdata_reg + _c(3), lane)}
-      elif 'B32' in op_name:
+      elif data_bits_mem == 96:
+        data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA1': ctx.rvgpr_dyn(vdata_reg + _c(1), lane),
+                'DATA2': ctx.rvgpr_dyn(vdata_reg + _c(2), lane)}
+      elif data_bits_mem == 32:
        data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA2': ctx.rvgpr_dyn(data1_reg, lane) if has_data1 else UOp.const(dtypes.uint32, 0)}
      else:
        data = {'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)),
@@ -951,26 +911,27 @@ def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
      return {'ADDR': addr, 'ADDR_BASE': addr, 'OFFSET': offset, 'OFFSET0': offset0, 'OFFSET1': offset1, '_lds': mem, 'laneId': lane, **data}
    active = _lane_active(exec_mask, lane)
    if is_atomic:
-      return {'ADDR': addr, 'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) if is_64bit else ctx.rvgpr_dyn(vdata_reg, lane),
+      return {'ADDR': addr, 'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) if data_bits_mem == 64 else ctx.rvgpr_dyn(vdata_reg, lane),
              '_vmem': mem, '_active': active, 'laneId': lane}
    vdata = ctx.rvgpr_dyn(vdata_reg, lane).cast(dtypes.uint64) if 'STORE' in op_name else ctx.rvgpr_dyn(vdst_reg, lane) if 'D16' in op_name else UOp.const(dtypes.uint32, 0)
-    if 'STORE' in op_name and ndwords >= 2: vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
+    if 'STORE' in op_name and data_bits_mem >= 64: vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
    srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active, 'laneId': lane}
-    for i in range(ndwords): srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
+    for i in range(data_bits_mem // 32): srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
    return srcs
 
   def make_stores(dest: str, val: UOp, lane: UOp, active: UOp, writes_return_data: bool) -> list[UOp]:
+    # Parse bit width from dest format: MEM[...].b32 or RETURN_DATA[63:32].b64
+    parts = dest.rsplit('.', 1)
+    data_bits = int(parts[1][1:]) if len(parts) == 2 else 32
    if dest.startswith('MEM['):
-      if is_lds or is_atomic: return _write_val(dest, val[1], wmem, val[0], active, is_mem=True)
-      data_bits = 8 if '.b8' in dest else 16 if '.b16' in dest else 64 if '.b64' in dest else 32
+      if is_lds or is_atomic: return _write_val(data_bits, val[1], wmem, val[0], active, is_mem=True)
      if is_scratch: return _mem_store_bytes(mem, val[0], val[1], active, data_bits)
      return _mem_store(mem, val[0], val[1], active, 64, data_bits)
    if dest.startswith('RETURN_DATA') and writes_return_data:
      if (m := re.match(r'RETURN_DATA\[(\d+)\s*:\s*(\d+)\]', dest)):
        bit_width, dword_idx = int(m.group(1)) - int(m.group(2)) + 1, int(m.group(2)) // 32
-        is_64 = '.b64' if bit_width == 64 else ''
-        return _write_val(is_64, val, lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e), vdst_reg + _c(dword_idx), lane, exec_mask)
-      return _write_val(dest, val, lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e), vdst_reg, lane, exec_mask)
+        return _write_val(bit_width, val, lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e), vdst_reg + _c(dword_idx), lane, exec_mask)
+      return _write_val(data_bits, val, lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e), vdst_reg, lane, exec_mask)
    return []
 
   # DS-specific: check for 2ADDR pattern needing separate ranges
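The new `make_stores` derives store width directly from the pcode destination's type suffix instead of substring checks. A small standalone sketch of that parse (the function name and the extra `isdigit` guard are illustrative, not from the patch):

```python
def dest_bits(dest: str, default: int = 32) -> int:
  """'MEM[ADDR].b16' -> 16, 'RETURN_DATA[63:32].b64' -> 64, no suffix -> default."""
  parts = dest.rsplit('.', 1)  # split off a trailing '.bNN' type suffix, if any
  return int(parts[1][1:]) if len(parts) == 2 and parts[1][1:].isdigit() else default

assert dest_bits('MEM[ADDR].b8') == 8
assert dest_bits('RETURN_DATA[63:32].b64') == 64
assert dest_bits('RETURN_DATA') == 32
```

One parse then feeds both `_write_val` (which only distinguishes 64 from everything else) and the sub-word memory stores, so the `.b8`/`.b16` chain and the `_is_64bit_dest` helper both go away.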
diff --git a/extra/assembly/amd/pcode.py b/extra/assembly/amd/pcode.py
index ba3e85413f..7103109d28 100644
--- a/extra/assembly/amd/pcode.py
+++ b/extra/assembly/amd/pcode.py
@@ -717,6 +717,22 @@ def _subst_loop_var(line: str, loop_var: str, val: int) -> str:
   subst_parts = [str(val) if t.type == 'IDENT' and t.val == loop_var else t.val for t in result_toks if t.type != 'EOF']
   return ' '.join(subst_parts)
 
+def _set_bits(old: UOp, val: UOp, width: int, offset: int) -> UOp:
+  """Set bits [offset:offset+width) in old to val, masking and shifting appropriately."""
+  mask = _u32(((1 << width) - 1) << offset)
+  v = (val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val) & _u32((1 << width) - 1)
+  return (old & (mask ^ _u32(0xFFFFFFFF))) | (v << _u32(offset))
+
+def _find_paren_end(s: str, start: int = 0, open_ch: str = '(', close_ch: str = ')') -> int:
+  """Find index of matching close paren, starting after the open paren at start."""
+  depth = 0
+  for j, ch in enumerate(s[start:], start):
+    if ch == open_ch: depth += 1
+    elif ch == close_ch:
+      depth -= 1
+      if depth == 0: return j
+  return len(s)
+
 def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: dict | None = None, assigns: list | None = None) -> tuple[int, dict[str, VarVal], UOp | None]:
   """Parse a block of pcode. Returns (next_line, block_assigns, return_value).
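`_set_bits` is an ordinary read-modify-write on a 32-bit word; the UOp version above lifts exactly this integer math into the graph. A plain-int sketch of the same computation (illustrative, not the module's code):

```python
def set_bits(old: int, val: int, width: int, offset: int) -> int:
  """Set bits [offset:offset+width) of old to val, keeping the rest."""
  mask = ((1 << width) - 1) << offset
  return (old & ~mask & 0xFFFFFFFF) | ((val & ((1 << width) - 1)) << offset)

assert set_bits(0xFFFFFFFF, 0xA, 4, 8) == 0xFFFFFAFF   # replace bits 8..11 with 0xA
assert set_bits(0x00000000, 0xFFFF, 16, 16) == 0xFFFF0000  # write the hi half-word
```

Factoring this out (and `_find_paren_end` for balanced-paren scanning) lets the bit-slice, typed-element, and dynamic-bit assignment paths below share one implementation instead of three inlined mask expressions.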
@@ -724,7 +740,6 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
   if funcs is None: funcs = _FUNCS
   block_assigns: dict[str, VarVal] = {}
   i = start
-  def ctx(): return {**vars, **block_assigns}
 
   while i < len(lines):
     line = lines[i]
@@ -738,7 +753,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
     # return expr (lambda bodies)
     if first == 'return':
       rest = line[line.lower().find('return') + 6:].strip()
-      return i + 1, block_assigns, parse_expr(rest, ctx(), funcs)
+      return i + 1, block_assigns, parse_expr(rest, vars, funcs)
 
     # for loop
     if first == 'for':
@@ -747,21 +762,18 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
       p.eat_val('for', 'IDENT')
       loop_var = p.eat('IDENT').val
       p.eat_val('in', 'IDENT')
-      if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
-      if p.at('NUM'):
-        start_val = int(p.eat('NUM').val.rstrip('UuLl'))
-      else:
-        start_expr = p.parse()
-        start_val = int(start_expr.arg) if start_expr.op == Ops.CONST else 0
+      def parse_bound():
+        if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
+        if p.at('NUM'): return int(p.eat('NUM').val.rstrip('UuLl'))
+        expr = p.parse()
+        return int(expr.arg) if expr.op == Ops.CONST else 0
+      start_val = parse_bound()
       p.eat('COLON')
-      if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
-      if p.at('NUM'):
-        end_val = int(p.eat('NUM').val.rstrip('UuLl'))
-      else:
-        end_expr = p.parse()
-        end_val = int(end_expr.arg) if end_expr.op == Ops.CONST else 0
+      end_val = parse_bound()
       # Collect body
-      i += 1; body_lines, depth = [], 1
+      i += 1
+      body_lines: list[str] = []
+      depth = 1
       while i < len(lines) and depth > 0:
         btoks = tokenize(lines[i])
         if btoks[0].type == 'IDENT':
@@ -775,7 +787,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
       if found_var: vars[found_var] = block_assigns[found_var] = _const(dtypes.bool, False)
       for loop_i in range(start_val, end_val + 1):
         subst_lines = [_subst_loop_var(bl, loop_var, loop_i) for bl in body_lines if not (has_break and bl.strip().lower() == 'break')]
-        _, iter_assigns, _ = parse_block(subst_lines, 0, {**vars, **block_assigns}, funcs, assigns)
+        _, iter_assigns, _ = parse_block(subst_lines, 0, vars, funcs, assigns)
         if has_break:
           assert found_var is not None
           found = block_assigns.get(found_var, vars.get(found_var))
@@ -791,7 +803,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
           if bl_l.startswith('if ') and bl_l.endswith(' then'):
             if any(body_lines[k].strip().lower() == 'break' for k in range(j+1, len(body_lines))):
               cond_str = _subst_loop_var(bl.strip()[3:-5].strip(), loop_var, loop_i)
-              cond = _to_bool(parse_expr(cond_str, {**vars, **block_assigns}, funcs))
+              cond = _to_bool(parse_expr(cond_str, vars, funcs))
               block_assigns[found_var] = vars[found_var] = not_found.where(cond, found)
               break
         else:
@@ -806,25 +818,17 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
     # lambda definition
     if first != '{' and '=' in line and 'lambda' in line and any(t.type == 'IDENT' and t.val == 'lambda' for t in toks):
       name = toks[0].val
-      body_start, depth = line[line.find('(', line.find('lambda')):], 0
-      params_end = 0
-      for j, ch in enumerate(body_start):
-        if ch == '(': depth += 1
-        elif ch == ')':
-          depth -= 1
-          if depth == 0: params_end = j + 1; break
+      body_start = line[line.find('(', line.find('lambda')):]
+      params_end = _find_paren_end(body_start) + 1
       params = [p.strip() for p in body_start[1:params_end-1].split(',') if p.strip()]
       rest = body_start[params_end:].strip()
       if rest.startswith('('):
-        depth, body_end = 1, 1
-        for j, ch in enumerate(rest[1:], 1):
-          if ch == '(': depth += 1
-          elif ch == ')':
-            depth -= 1
-            if depth == 0: body_end = j; break
-        body = rest[1:body_end].strip()
-        if depth > 0:
-          body_lines_lst = [rest[1:]]
+        body_end = _find_paren_end(rest)
+        if body_end < len(rest):  # found matching paren on same line
+          body = rest[1:body_end].strip()
+          i += 1
+        else:  # multiline body
+          body_lines_lst, depth = [rest[1:]], 1
          i += 1
          while i < len(lines) and depth > 0:
            for j, ch in enumerate(lines[i]):
@@ -835,21 +839,20 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
            else: body_lines_lst.append(lines[i])
            i += 1
          body = '\n'.join(body_lines_lst).strip()
-      else: i += 1
      vars[name] = ('lambda', params, body)
      continue
 
     # MEM assignment: MEM[addr].type (+|-)?= value
     if first == 'mem' and toks[1].type == 'LBRACKET':
       j, addr_toks = _match_bracket(toks, 1)
-      addr = parse_tokens(addr_toks, ctx(), funcs)
+      addr = parse_tokens(addr_toks, vars, funcs)
       if j < len(toks) and toks[j].type == 'DOT': j += 1
       dt_name = toks[j].val if j < len(toks) and toks[j].type == 'IDENT' else 'u32'
       dt, j = DTYPES.get(dt_name, dtypes.uint32), j + 1
       compound_op = None
       if j < len(toks) and toks[j].type == 'ASSIGN_OP': compound_op = toks[j].val; j += 1
       elif j < len(toks) and toks[j].type == 'EQUALS': j += 1
-      rhs = parse_tokens(toks[j:], ctx(), funcs)
+      rhs = parse_tokens(toks[j:], vars, funcs)
       if compound_op:
         mem = vars.get('_vmem') if '_vmem' in vars else vars.get('_lds')
         if isinstance(mem, UOp):
@@ -868,7 +871,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
       if j < len(toks) and toks[j].type == 'LBRACKET':
         j, reg_toks = _match_bracket(toks, j)
         if j < len(toks) and toks[j].type == 'EQUALS': j += 1
-        ln, rg, val = parse_tokens(lane_toks, ctx(), funcs), parse_tokens(reg_toks, ctx(), funcs), parse_tokens(toks[j:], ctx(), funcs)
+        ln, rg, val = parse_tokens(lane_toks, vars, funcs), parse_tokens(reg_toks, vars, funcs), parse_tokens(toks[j:], vars, funcs)
         if assigns is not None: assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(32) + _to_u32(ln), val)))
         i += 1; continue
@@ -884,7 +887,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
         j += 3
         if j < len(toks) and toks[j].type == 'RBRACE': j += 1
         if j < len(toks) and toks[j].type == 'EQUALS': j += 1
-        val = parse_tokens(toks[j:], ctx(), funcs)
+        val = parse_tokens(toks[j:], vars, funcs)
         lo_dt, hi_dt = DTYPES.get(lo_type, dtypes.uint64), DTYPES.get(hi_type, dtypes.uint32)
         lo_bits = 64 if lo_dt in (dtypes.uint64, dtypes.int64) else 32
         lo_val = val.cast(lo_dt) if val.dtype.itemsize * 8 <= lo_bits else (val & _const(val.dtype, (1 << lo_bits) - 1)).cast(lo_dt)
@@ -894,7 +897,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
       if assigns is not None: assigns.extend([(f'{lo_var}.{lo_type}', lo_val), (f'{hi_var}.{hi_type}', hi_val)])
       i += 1; continue
 
-    # Bit slice: var[hi:lo] = value or var.type[hi:lo] = value
+    # Bit slice/index: var[hi:lo] = value, var.type[hi:lo] = value, or var[expr] = value
     if len(toks) >= 5 and toks[0].type == 'IDENT' and (toks[1].type == 'LBRACKET' or (toks[1].type == 'DOT' and toks[3].type == 'LBRACKET')):
       bracket_start = 2 if toks[1].type == 'LBRACKET' else 4
       j = bracket_start
@@ -902,24 +905,33 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
       while j < len(toks) and toks[j].type != 'RBRACKET':
         if toks[j].type == 'COLON': colon_pos = j
         j += 1
-      if colon_pos is not None:
+      var = toks[0].val
+      if colon_pos is not None:  # bit slice: var[hi:lo]
        hi_str = ' '.join(t.val for t in toks[bracket_start:colon_pos] if t.type != 'EOF')
        lo_str = ' '.join(t.val for t in toks[colon_pos+1:j] if t.type != 'EOF')
        try:
-          hi, lo = max(int(eval(hi_str)), int(eval(lo_str))), min(int(eval(hi_str)), int(eval(lo_str)))
-          var = toks[0].val
+          hi_val, lo_val = int(eval(hi_str)), int(eval(lo_str))
+          hi, lo = max(hi_val, lo_val), min(hi_val, lo_val)
          j += 1
          if j < len(toks) and toks[j].type == 'DOT': j += 2
          if j < len(toks) and toks[j].type == 'EQUALS': j += 1
-          val = parse_tokens(toks[j:], ctx(), funcs)
+          val = parse_tokens(toks[j:], vars, funcs)
          dt_suffix = toks[2].val if toks[1].type == 'DOT' else None
          if assigns is not None: assigns.append((f'{var}[{hi}:{lo}]' + (f'.{dt_suffix}' if dt_suffix else ''), val))
          if var not in vars: vars[var] = _const(dtypes.uint64 if hi >= 32 else dtypes.uint32, 0)
          old = block_assigns.get(var, vars.get(var))
-          mask = _u32(((1 << (hi - lo + 1)) - 1) << lo)
-          block_assigns[var] = vars[var] = (old & (mask ^ _u32(0xFFFFFFFF))) | (_val_to_bits(val) << _u32(lo))
+          block_assigns[var] = vars[var] = _set_bits(old, _val_to_bits(val), hi - lo + 1, lo)
          i += 1; continue
        except: pass
+      elif toks[1].type == 'LBRACKET':  # bit index: var[expr] (only for var[...], not var.type[...])
+        existing = block_assigns.get(var, vars.get(var))
+        if existing is not None and isinstance(existing, UOp) and not any(f'{var}{k}' in vars or f'{var}{k}' in block_assigns for k in range(8)):
+          bit_toks = toks[2:j]
+          j += 1
+          while j < len(toks) and toks[j].type != 'EQUALS': j += 1
+          if j < len(toks):
+            block_assigns[var] = vars[var] = _set_bit(existing, _to_u32(parse_tokens(bit_toks, vars, funcs)), parse_tokens(toks[j+1:], vars, funcs))
+            i += 1; continue
 
     # Array element: var{idx} = value
     if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACE' and toks[2].type == 'NUM':
@@ -927,7 +939,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
       j = 4
       while j < len(toks) and toks[j].type != 'EQUALS': j += 1
       if j < len(toks):
-        val = parse_tokens(toks[j+1:], ctx(), funcs)
+        val = parse_tokens(toks[j+1:], vars, funcs)
         existing = block_assigns.get(var, vars.get(var))
         if existing is not None and isinstance(existing, UOp):
           block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val)
@@ -936,121 +948,103 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
       i += 1; continue
 
     # Compound assignment: var += or var -=
-    for j, t in enumerate(toks):
-      if t.type == 'ASSIGN_OP':
+    assign_op = next((j for j, t in enumerate(toks) if t.type == 'ASSIGN_OP'), None)
+    if assign_op is not None:
+      var = toks[0].val
+      old = block_assigns.get(var, vars.get(var, _u32(0)))
+      rhs = parse_tokens(toks[assign_op+1:], vars, funcs)
+      if rhs.dtype != old.dtype: rhs = rhs.cast(old.dtype)
+      block_assigns[var] = vars[var] = (old + rhs) if toks[assign_op].val == '+=' else (old - rhs)
+      i += 1; continue
+
+    # Typed element: var.type[idx] = value
+    if len(toks) >= 7 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET' and toks[4].type == 'NUM':
+      var, dt_name, idx = toks[0].val, toks[2].val, int(toks[4].val)
+      dt = DTYPES.get(dt_name, dtypes.uint32)
+      j = 6
+      while j < len(toks) and toks[j].type != 'EQUALS': j += 1
+      if j < len(toks):
+        val, old = parse_tokens(toks[j+1:], vars, funcs), block_assigns.get(var, vars.get(var, _u32(0)))
+        bw = dt.itemsize * 8
+        block_assigns[var] = vars[var] = _set_bits(old, val, bw, idx * bw)
+        if assigns is not None: assigns.append((f'{var}.{dt_name}[{idx}]', val))
+        i += 1; continue
+
+    # Dynamic bit: var.type[expr_with_brackets] = value
+    if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET':
+      j, depth, has_inner = 4, 1, False
+      while j < len(toks) and depth > 0:
+        if toks[j].type == 'LBRACKET': depth += 1; has_inner = True
+        elif toks[j].type == 'RBRACKET': depth -= 1
+        j += 1
+      if has_inner:
        var = toks[0].val
-        old = block_assigns.get(var, vars.get(var, _u32(0)))
-        rhs = parse_tokens(toks[j+1:], ctx(), funcs)
-        if rhs.dtype != old.dtype: rhs = rhs.cast(old.dtype)
-        block_assigns[var] = vars[var] = (old + rhs) if t.val == '+=' else (old - rhs)
-        i += 1; break
-    else:
-      # Typed element: var.type[idx] = value
-      if len(toks) >= 7 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET' and toks[4].type == 'NUM':
-        var, dt_name, idx = toks[0].val, toks[2].val, int(toks[4].val)
-        dt = DTYPES.get(dt_name, dtypes.uint32)
-        j = 6
        while j < len(toks) and toks[j].type != 'EQUALS': j += 1
        if j < len(toks):
-          val, old = parse_tokens(toks[j+1:], ctx(), funcs), block_assigns.get(var, vars.get(var, _u32(0)))
-          bw, lo_bit = dt.itemsize * 8, idx * dt.itemsize * 8
-          mask = _u32(((1 << bw) - 1) << lo_bit)
-          block_assigns[var] = vars[var] = (old & (mask ^ _u32(0xFFFFFFFF))) | (((val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val) & _u32((1 << bw) - 1)) << _u32(lo_bit))
-          if assigns is not None: assigns.append((f'{var}.{dt_name}[{idx}]', val))
+          val = parse_tokens(toks[j+1:], vars, funcs)
+          old = block_assigns.get(var, vars.get(var, _u32(0)))
+          block_assigns[var] = vars[var] = _set_bit(old, bit_pos, val)
          i += 1; continue
-      # Dynamic bit: var.type[expr_with_brackets] = value
-      if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET':
-        j, depth, has_inner = 4, 1, False
-        while j < len(toks) and depth > 0:
-          if toks[j].type == 'LBRACKET': depth += 1; has_inner = True
-          elif toks[j].type == 'RBRACKET': depth -= 1
-          j += 1
-        if has_inner:
-          var = toks[0].val
-          bit_pos = _to_u32(parse_tokens(toks[4:j-1], ctx(), funcs))
-          while j < len(toks) and toks[j].type != 'EQUALS': j += 1
-          if j < len(toks):
-            val = parse_tokens(toks[j+1:], ctx(), funcs)
-            old, mask = block_assigns.get(var, vars.get(var, _u32(0))), _u32(1) << bit_pos
-            block_assigns[var] = vars[var] = (old | mask) if val.op == Ops.CONST and val.arg == 1 else \
-              (old & (mask ^ _u32(0xFFFFFFFF))) if val.op == Ops.CONST and val.arg == 0 else _set_bit(old, bit_pos, val)
-            i += 1; continue
-
-      # Bit index: var[expr] = value (bit assignment to existing scalar)
-      if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACKET':
-        var = toks[0].val
-        existing = block_assigns.get(var, vars.get(var))
-        if existing is not None and isinstance(existing, UOp) and not any(f'{var}{k}' in vars or f'{var}{k}' in block_assigns for k in range(8)):
-          j = 2
-          while j < len(toks) and toks[j].type != 'RBRACKET': j += 1
-          bit_toks = toks[2:j]
-          j += 1
-          while j < len(toks) and toks[j].type != 'EQUALS': j += 1
-          if j < len(toks):
-            block_assigns[var] = vars[var] = _set_bit(existing, _to_u32(parse_tokens(bit_toks, ctx(), funcs)), parse_tokens(toks[j+1:], ctx(), funcs))
-            i += 1; continue
+        bit_pos = _to_u32(parse_tokens(toks[4:j-1], vars, funcs))
 
-      # If/elsif/else - skip branches with statically false conditions (WAVE32/WAVE64)
-      if first == 'if':
-        def parse_cond(s, kw):
-          ll = s.lower()
-          return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), ctx(), funcs))
-        ...
-        continue
-      continue
+    # If/elsif/else - skip branches with statically false conditions (WAVE32/WAVE64)
+    if first == 'if':
+      def parse_cond(s, kw):
+        ll = s.lower()
+        return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), vars, funcs))
+      def not_static_false(c): return c.op != Ops.CONST or c.arg is not False
+      cond = parse_cond(line, 'if')
+      conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not_static_false(cond) else []
+      else_branch: tuple[UOp | None, dict[str, VarVal]] = (None, {})
+      vars_snap = dict(vars)
+      i += 1
+      i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
+      if conditions: conditions[0] = (cond, ret if ret is not None else branch)
+      vars.clear(); vars.update(vars_snap)
+      while i < len(lines):
+        ltoks = tokenize(lines[i])
+        if ltoks[0].type != 'IDENT': break
+        lf = ltoks[0].val.lower()
+        if lf == 'elsif':
+          c = parse_cond(lines[i], 'elsif')
+          i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
+          if not_static_false(c): conditions.append((c, ret if ret is not None else branch))
+          vars.clear(); vars.update(vars_snap)
+        elif lf == 'else':
+          i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
+          else_branch = (ret, branch)
+          vars.clear(); vars.update(vars_snap)
+        elif lf == 'endif': i += 1; break
+        else: break
+      # Check if any branch returned a value (lambda-style)
+      if any(isinstance(br, UOp) for _, br in conditions):
+        result = else_branch[0]
+        for c, rv in reversed(conditions):
+          if isinstance(rv, UOp) and isinstance(result, UOp):
+            if rv.dtype != result.dtype and rv.dtype.itemsize == result.dtype.itemsize: result = result.cast(rv.dtype)
+            result = c.where(rv, result)
+        return i, block_assigns, result
+      # Main style: merge variable assignments with WHERE
+      else_assigns = else_branch[1]
+      all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
+      for var in all_vars:
+        res: Any = else_assigns.get(var, block_assigns.get(var, vars.get(var, _u32(0))))
+        for cond, ba in reversed(conditions):
+          if isinstance(ba, dict) and var in ba:
+            tv = ba[var]
+            if isinstance(tv, UOp) and isinstance(res, UOp):
+              res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
+        block_assigns[var] = vars[var] = res
      continue
-      continue
+
+    # Regular assignment: var = value
+    for j, t in enumerate(toks):
+      if t.type == 'EQUALS':
+        if any(toks[k].type == 'OP' and toks[k].val in ('<', '>', '!', '=') for k in range(j)): break
+        base_var = toks[0].val
+        block_assigns[base_var] = vars[base_var] = parse_tokens(toks[j+1:], vars, funcs)
+        i += 1; break
+    else: i += 1
   return i, block_assigns, None
 
 def parse_expr(expr: str, vars: dict[str, VarVal], funcs: dict | None = None) -> UOp:
addr=v[10], data0=v[0:2], offset0=12), + s_waitcnt(lgkmcnt=0), + DS(DSOp.DS_LOAD_B96, addr=v[10], vdst=v[4:6], offset0=12), + s_waitcnt(lgkmcnt=0), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][4], 0xAAAAAAAA) + self.assertEqual(st.vgpr[0][5], 0xBBBBBBBB) + self.assertEqual(st.vgpr[0][6], 0xCCCCCCCC) + + class TestDSB128(unittest.TestCase): """Tests for DS_STORE_B128 and DS_LOAD_B128 (128-bit / 4 dwords).""" diff --git a/extra/assembly/amd/test/hw/test_smem.py b/extra/assembly/amd/test/hw/test_smem.py index d05518019d..d76a7597b3 100644 --- a/extra/assembly/amd/test/hw/test_smem.py +++ b/extra/assembly/amd/test/hw/test_smem.py @@ -265,6 +265,113 @@ class TestSLoadMultiDword(unittest.TestCase): self.assertEqual(st.sgpr[5], st.sgpr[9]) +class TestSLoadLarge(unittest.TestCase): + """Tests for large s_load operations (s_load_b256, s_load_b512).""" + + def test_s_load_b256_basic(self): + """s_load_b256 loads 8 consecutive dwords.""" + instructions = [ + s_load_b64(s[2:3], s[80:81], 0, soffset=NULL), + s_waitcnt(lgkmcnt=0), + v_mov_b32_e32(v[0], 0), + # Store 8 test values + s_mov_b32(s[20], 0x11111111), + v_mov_b32_e32(v[2], s[20]), + global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET), + s_mov_b32(s[20], 0x22222222), + v_mov_b32_e32(v[2], s[20]), + global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+4), + s_mov_b32(s[20], 0x33333333), + v_mov_b32_e32(v[2], s[20]), + global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+8), + s_mov_b32(s[20], 0x44444444), + v_mov_b32_e32(v[2], s[20]), + global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+12), + s_mov_b32(s[20], 0x55555555), + v_mov_b32_e32(v[2], s[20]), + global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+16), + s_mov_b32(s[20], 0x66666666), + v_mov_b32_e32(v[2], s[20]), + global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+20), + s_mov_b32(s[20], 0x77777777), + v_mov_b32_e32(v[2], s[20]), + global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+24), + s_mov_b32(s[20], 0x88888888), + v_mov_b32_e32(v[2], s[20]), + global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+28), + s_waitcnt(vmcnt=0), + *CACHE_INV, + # Load all 8 dwords with s_load_b256 + s_load_b256(s[4:11], s[2:3], NULL, offset=TEST_OFFSET), + s_waitcnt(lgkmcnt=0), + s_mov_b32(s[2], 0), s_mov_b32(s[3], 0), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.sgpr[4], 0x11111111) + self.assertEqual(st.sgpr[5], 0x22222222) + self.assertEqual(st.sgpr[6], 0x33333333) + self.assertEqual(st.sgpr[7], 0x44444444) + self.assertEqual(st.sgpr[8], 0x55555555) + self.assertEqual(st.sgpr[9], 0x66666666) + self.assertEqual(st.sgpr[10], 0x77777777) + self.assertEqual(st.sgpr[11], 0x88888888) + + def test_s_load_b512_basic(self): + """s_load_b512 loads 16 consecutive dwords.""" + instructions = [ + s_load_b64(s[2:3], s[80:81], 0, soffset=NULL), + s_waitcnt(lgkmcnt=0), + v_mov_b32_e32(v[0], 0), + # Store 16 test values (use a pattern: 0x10, 0x20, ..., 0x100) + *[instr for i in range(16) for instr in [ + s_mov_b32(s[20], (i + 1) * 0x11111111), + v_mov_b32_e32(v[2], s[20]), + global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET + i * 4), + ]], + s_waitcnt(vmcnt=0), + *CACHE_INV, + # Load all 16 dwords with s_load_b512 + s_load_b512(s[64:79], s[2:3], NULL, offset=TEST_OFFSET), + s_waitcnt(lgkmcnt=0), + # Copy results to lower regs for verification (since st.sgpr only has 16 regs in 
test) + s_mov_b32(s[4], s[64]), + s_mov_b32(s[5], s[65]), + s_mov_b32(s[6], s[78]), + s_mov_b32(s[7], s[79]), + s_mov_b32(s[2], 0), s_mov_b32(s[3], 0), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.sgpr[4], 0x11111111, "first dword") + self.assertEqual(st.sgpr[5], 0x22222222, "second dword") + self.assertEqual(st.sgpr[6], 0xFFFFFFFF & (15 * 0x11111111), "15th dword") + self.assertEqual(st.sgpr[7], 0xFFFFFFFF & (16 * 0x11111111), "16th dword") + + def test_s_load_b256_with_register_offset(self): + """s_load_b256 with register offset should add reg offset to address.""" + instructions = [ + s_load_b64(s[2:3], s[80:81], 0, soffset=NULL), + s_waitcnt(lgkmcnt=0), + v_mov_b32_e32(v[0], 0), + # Store pattern at TEST_OFFSET+8: skip first 2 dwords + *[instr for i in range(8) for instr in [ + s_mov_b32(s[20], (i + 1) * 0x11111111), + v_mov_b32_e32(v[2], s[20]), + global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET + 8 + i * 4), + ]], + s_waitcnt(vmcnt=0), + *CACHE_INV, + # Load with register offset 8 + s_mov_b32(s[20], 8), + s_load_b256(s[4:11], s[2:3], s[20], offset=TEST_OFFSET), + s_waitcnt(lgkmcnt=0), + s_mov_b32(s[2], 0), s_mov_b32(s[3], 0), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.sgpr[4], 0x11111111, "first dword at offset+8") + self.assertEqual(st.sgpr[5], 0x22222222, "second dword at offset+8") + self.assertEqual(st.sgpr[11], 0x88888888, "last dword at offset+8") + + class TestSLoadOffset(unittest.TestCase): """Tests for s_load with different immediate offsets. diff --git a/extra/assembly/amd/test/hw/test_sop.py b/extra/assembly/amd/test/hw/test_sop.py index ca6adfd3f8..b8f340ad7b 100644 --- a/extra/assembly/amd/test/hw/test_sop.py +++ b/extra/assembly/amd/test/hw/test_sop.py @@ -719,5 +719,172 @@ class TestNullRegister(unittest.TestCase): self.assertEqual(st.scc, 0) +class Test64BitSOP1InlineConstants(unittest.TestCase): + """Tests for 64-bit SOP1 instructions with inline constants. + + Regression tests for bug where rsrc_dyn didn't properly handle 64-bit + inline constants, incorrectly duplicating lo bits to hi instead of + zero/sign-extending. 
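+
+  For example (values taken from the tests below), s_mov_b64(s[0:1], -16)
+  must give lo=0xFFFFFFF0, hi=0xFFFFFFFF (sign-extended), not the buggy
+  lo=hi=0xFFFFFFF0 duplication.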
+ """ + + def test_s_mov_b64_inline_0(self): + """S_MOV_B64 with inline constant 0.""" + instructions = [ + s_mov_b64(s[0:1], 0), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][0], 0) + self.assertEqual(st.vgpr[0][1], 0) + + def test_s_mov_b64_inline_16(self): + """S_MOV_B64 with inline constant 16 should set lo=16, hi=0.""" + instructions = [ + s_mov_b64(s[0:1], 16), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][0], 16) + self.assertEqual(st.vgpr[0][1], 0) + + def test_s_mov_b64_inline_64(self): + """S_MOV_B64 with inline constant 64 (max positive).""" + instructions = [ + s_mov_b64(s[0:1], 64), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][0], 64) + self.assertEqual(st.vgpr[0][1], 0) + + def test_s_mov_b64_inline_neg1(self): + """S_MOV_B64 with inline constant -1 should sign-extend.""" + instructions = [ + s_mov_b64(s[0:1], -1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][0], 0xFFFFFFFF) + self.assertEqual(st.vgpr[0][1], 0xFFFFFFFF) + + def test_s_mov_b64_inline_neg16(self): + """S_MOV_B64 with inline constant -16 should sign-extend.""" + instructions = [ + s_mov_b64(s[0:1], -16), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][0], 0xFFFFFFF0) + self.assertEqual(st.vgpr[0][1], 0xFFFFFFFF) + + def test_s_mov_b64_float_const_1_0(self): + """S_MOV_B64 with float inline constant 1.0 - casts F32 to F64.""" + instructions = [ + s_mov_b64(s[0:1], 1.0), # inline constant 242 (1.0f) + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + ] + st = run_program(instructions, n_lanes=1) + # Hardware casts F32 to F64: 1.0f64 = 0x3FF0000000000000 + self.assertEqual(st.vgpr[0][0], 0x00000000) # lo + self.assertEqual(st.vgpr[0][1], 0x3FF00000) # hi + + def test_s_or_b64_inline_constant(self): + """S_OR_B64 with 64-bit inline constant.""" + instructions = [ + s_mov_b64(s[0:1], 0), + s_or_b64(s[2:3], s[0:1], 16), + v_mov_b32_e32(v[0], s[2]), + v_mov_b32_e32(v[1], s[3]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][0], 16) + self.assertEqual(st.vgpr[0][1], 0) + + def test_s_and_b64_inline_constant(self): + """S_AND_B64 with 64-bit inline constant.""" + instructions = [ + s_mov_b32(s[0], 0xFFFFFFFF), + s_mov_b32(s[1], 0xFFFFFFFF), + s_and_b64(s[2:3], s[0:1], 16), + v_mov_b32_e32(v[0], s[2]), + v_mov_b32_e32(v[1], s[3]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][0], 16) + self.assertEqual(st.vgpr[0][1], 0) + + +class Test64BitSOPLiterals(unittest.TestCase): + """Tests for 64-bit SOP instructions with 32-bit literals. + + Tests the behavior when a 64-bit SOP instruction uses a 32-bit literal + (offset 255 in instruction encoding). The literal is zero-extended to 64 bits. 
+ """ + + def test_s_mov_b64_literal(self): + """S_MOV_B64 with 32-bit literal value - zero-extended to 64 bits.""" + instructions = [ + s_mov_b64(s[0:1], 0x12345678), # literal > 64, uses literal encoding + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][0], 0x12345678) + self.assertEqual(st.vgpr[0][1], 0) + + def test_s_or_b64_literal(self): + """S_OR_B64 with 32-bit literal value - zero-extended to 64 bits.""" + instructions = [ + s_mov_b64(s[0:1], 0), + s_or_b64(s[2:3], s[0:1], 0x12345678), # literal + v_mov_b32_e32(v[0], s[2]), + v_mov_b32_e32(v[1], s[3]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][0], 0x12345678) + self.assertEqual(st.vgpr[0][1], 0) + + def test_s_and_b64_literal(self): + """S_AND_B64 with 32-bit literal value - zero-extended to 64 bits.""" + instructions = [ + s_mov_b32(s[0], 0xFFFFFFFF), + s_mov_b32(s[1], 0xFFFFFFFF), + s_and_b64(s[2:3], s[0:1], 0x12345678), # literal + v_mov_b32_e32(v[0], s[2]), + v_mov_b32_e32(v[1], s[3]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][0], 0x12345678) + self.assertEqual(st.vgpr[0][1], 0) + + def test_s_mov_b64_literal_negative(self): + """S_MOV_B64 with 0xFFFFFFFF literal - zero-extended (not sign-extended).""" + instructions = [ + s_mov_b64(s[0:1], 0xFFFFFFFF), # -1 as 32-bit, but zero-extended to 64-bit + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][0], 0xFFFFFFFF) + self.assertEqual(st.vgpr[0][1], 0) # zero-extended, not sign-extended + + def test_s_mov_b64_literal_high_bit(self): + """S_MOV_B64 with 0x80000000 literal - zero-extended (not sign-extended).""" + instructions = [ + s_mov_b64(s[0:1], 0x80000000), # high bit set, but zero-extended + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][0], 0x80000000) + self.assertEqual(st.vgpr[0][1], 0) # zero-extended, not sign-extended + + if __name__ == '__main__': unittest.main() diff --git a/extra/assembly/amd/test/hw/test_vop3.py b/extra/assembly/amd/test/hw/test_vop3.py index 83bce16df7..be8ecdde6e 100644 --- a/extra/assembly/amd/test/hw/test_vop3.py +++ b/extra/assembly/amd/test/hw/test_vop3.py @@ -1359,6 +1359,43 @@ class TestF64ToI64Conversion(unittest.TestCase): self.assertEqual(result, 5000000000) +class TestB64VOPLiteral(unittest.TestCase): + """Tests for B64 VOP operations with literal encoding. + + B64 operations (like V_LSHLREV_B64) should zero-extend the literal to 64 bits, + NOT put it in the high 32 bits like F64 operations do. + """ + + def test_v_lshlrev_b64_literal_shift_amount(self): + """V_LSHLREV_B64 with literal shift amount (src0 is 32-bit).""" + # Shift 1 left by 100 (0x64) - uses literal encoding for src0 + # Shift amount is 100 & 63 = 36, so 1 << 36 = 0x1000000000 + instructions = [ + s_mov_b32(s[0], 1), + s_mov_b32(s[1], 0), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_lshlrev_b64(v[2:3], 100, v[0:1]), # 100 > 64, uses literal encoding + ] + st = run_program(instructions, n_lanes=1) + # lo = 0x00000000, hi = 0x00000010 = 1 << (36-32) + self.assertEqual(st.vgpr[0][2], 0x00000000) + self.assertEqual(st.vgpr[0][3], 0x00000010) + + def test_v_lshlrev_b64_literal_value(self): + """V_LSHLREV_B64 with literal as the 64-bit value being shifted (src1). + + B64 literals are zero-extended (not shifted to high bits like F64). 
+    0xDEADBEEF << 4 = 0xDEADBEEF0 = lo=0xEADBEEF0, hi=0x0000000D
+    """
+    instructions = [
+      v_lshlrev_b64(v[0:1], 4, 0xDEADBEEF),  # shift literal left by 4
+    ]
+    st = run_program(instructions, n_lanes=1)
+    self.assertEqual(st.vgpr[0][0], 0xEADBEEF0)  # lo
+    self.assertEqual(st.vgpr[0][1], 0x0000000D)  # hi
+
+
 class TestWMMAMore(unittest.TestCase):
   """More WMMA tests."""