assembly/amd: continue refactors (#14386)

* simpler

* merge

* flat

* no ctx

* use the correct apis

* dup code

* write clean code

* remove bad helpers

* bits junk remove

* junk remove

* smem test

* fix tests

* correct fix + tests

* Fmt matters it seems

* wmma refactor

* a lil more

* kimi cleanups

* line
George Hotz
2026-01-28 17:33:03 +08:00
committed by GitHub
parent 5bffa17f82
commit 202b74b369
7 changed files with 676 additions and 352 deletions


@@ -253,6 +253,15 @@ def _get_variant(cls, suffix: str):
module = sys.modules.get(cls.__module__)
return getattr(module, f"{cls.__name__}{suffix}", None) if module else None
def _canonical_name(name: str) -> str | None:
"""Map operand name to canonical name."""
if name in ('src0', 'vsrc0', 'ssrc0'): return 's0'
if name in ('src1', 'vsrc1', 'ssrc1'): return 's1'
if name == 'src2': return 's2'
if name in ('vdst', 'sdst', 'sdata'): return 'd'
if name in ('data', 'vdata', 'data0', 'vsrc'): return 'data'
return None
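As a standalone sketch, the mapping above reduces to a dict lookup; canonical_name here is a hypothetical rewrite for illustration, not the repo's helper:

_CANON = {'src0': 's0', 'vsrc0': 's0', 'ssrc0': 's0',
          'src1': 's1', 'vsrc1': 's1', 'ssrc1': 's1',
          'src2': 's2',
          'vdst': 'd', 'sdst': 'd', 'sdata': 'd',
          'data': 'data', 'vdata': 'data', 'data0': 'data', 'vsrc': 'data'}
def canonical_name(name: str) -> str | None:
    return _CANON.get(name)  # None for non-operand fields like 'offset'
assert canonical_name('vsrc0') == 's0' and canonical_name('offset') is None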
class Inst:
_fields: list[tuple[str, BitField]]
_base_size: int
@@ -368,12 +377,17 @@ class Inst:
"""Get bit widths with canonical names: {'s0', 's1', 's2', 'd', 'data'}."""
bits = {'d': 32, 's0': 32, 's1': 32, 's2': 32, 'data': 32}
for name, val in self.op_bits.items():
if name in ('src0', 'vsrc0', 'ssrc0'): bits['s0'] = val
elif name in ('src1', 'vsrc1', 'ssrc1'): bits['s1'] = val
elif name == 'src2': bits['s2'] = val
elif name in ('vdst', 'sdst', 'sdata'): bits['d'] = val
elif name in ('data', 'vdata', 'data0', 'vsrc'): bits['data'] = val
if (cn := _canonical_name(name)): bits[cn] = val
return bits
@functools.cached_property
def canonical_operands(self) -> dict:
"""Get operands with canonical names: {'s0', 's1', 's2', 'd', 'data'}."""
result = {}
for name, val in self.operands.items():
if (cn := _canonical_name(name)): result[cn] = val
return result
@property
def canonical_op_regs(self) -> dict[str, int]:
"""Get register counts with canonical names: {'s0', 's1', 's2', 'd', 'data'}."""

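For reference, the filter in canonical_operands reduces to one comprehension. This standalone sketch reuses canonical_name from the sketch above, with a hypothetical flat dict standing in for the real (Fmt, bits, OpType) operand tuples:

operands = {'vsrc0': 'v1', 'vsrc1': 'v2', 'vdst': 'v0', 'offset': 16}
canonical = {cn: v for name, v in operands.items() if (cn := canonical_name(name))}
assert canonical == {'s0': 'v1', 's1': 'v2', 'd': 'v0'}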

@@ -53,7 +53,7 @@ from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP1_SDST, VOP2, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC,
DS, FLAT, GLOBAL, SCRATCH, VOPD, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOPDOp)
from extra.assembly.amd.dsl import VCC_LO, EXEC_LO, SCC
from extra.assembly.amd.autogen.common import OpType
from extra.assembly.amd.autogen.common import Fmt, OpType
from extra.assembly.amd.pcode import parse_block, _FUNCS
MASK32 = 0xFFFFFFFF
@@ -70,14 +70,14 @@ def _split64(val: UOp) -> tuple[UOp, UOp]:
return v64.cast(dtypes.uint32), (v64 >> UOp.const(dtypes.uint64, 32)).cast(dtypes.uint32)
_SRC_MOD_TYPES = {16: (dtypes.uint16, dtypes.half, 0x7FFF), 64: (dtypes.uint64, dtypes.float64, 0x7FFFFFFFFFFFFFFF), 32: (dtypes.uint32, dtypes.float32, 0x7FFFFFFF)}
def _apply_src_mods(val: UOp, mod_bit: int, abs_bits: int, neg_bits: int, is_16bit: bool = False, is_64bit: bool = False) -> UOp:
"""Apply abs/neg modifiers to source value based on operation type."""
def _apply_src_mods(val: UOp, mod_bit: int, abs_bits: int, neg_bits: int, bits: int = 32) -> UOp:
"""Apply abs/neg modifiers to source value based on bit width (16, 32, or 64)."""
if not (abs_bits & (1 << mod_bit)) and not (neg_bits & (1 << mod_bit)): return val
ut, ft, mask = _SRC_MOD_TYPES[16 if is_16bit else 64 if is_64bit else 32]
fv = val.cast(ut).bitcast(ft) if is_16bit else val.bitcast(ft) if val.dtype == ut else val
ut, ft, mask = _SRC_MOD_TYPES[bits]
fv = val.cast(ut).bitcast(ft) if bits == 16 else val.bitcast(ft) if val.dtype == ut else val
if abs_bits & (1 << mod_bit): fv = (fv.bitcast(ut) & UOp.const(ut, mask)).bitcast(ft)
if neg_bits & (1 << mod_bit): fv = fv.neg()
return fv.bitcast(ut).cast(dtypes.uint32) if is_16bit else fv.bitcast(ut)
return fv.bitcast(ut).cast(dtypes.uint32) if bits == 16 else fv.bitcast(ut)
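A self-contained sketch of the 32-bit case on raw bit patterns; struct stands in for the UOp bitcasts, and the abs-then-neg order mirrors the code above:

import struct

def apply_src_mods32(bits32: int, slot: int, abs_bits: int, neg_bits: int) -> int:
    f = struct.unpack('<f', struct.pack('<I', bits32))[0]
    if abs_bits & (1 << slot): f = abs(f)   # clear the sign bit
    if neg_bits & (1 << slot): f = -f       # flip the sign
    return struct.unpack('<I', struct.pack('<f', f))[0]

assert apply_src_mods32(0xC0400000, 0, abs_bits=1, neg_bits=0) == 0x40400000  # |-3.0| = 3.0
assert apply_src_mods32(0x40400000, 0, abs_bits=0, neg_bits=1) == 0xC0400000  # -(3.0)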
# Map VOPD ops to VOP2 ops for pcode lookup
VOPD_TO_VOP2 = {
@@ -95,11 +95,10 @@ PC_LO_IDX, PC_HI_IDX, SCRATCH_STRIDE_IDX = 256, 257, 259
# SGPR buffer: 0-127 = SGPRs, 128-255 = inline constants, 256-259 = special registers
SGPR_COUNT, VGPR_SIZE = 260, 256 * 32
def _is_16bit_op(op_name: str) -> bool: return any(x in op_name for x in ('B16', 'F16', 'I16', 'U16'))
def _op_name(inst) -> str:
if hasattr(inst, 'opx'): return f"{inst.opx.name}_{inst.opy.name}" # VOPD has opx/opy not op
return inst.op.name if hasattr(inst.op, 'name') else str(inst.op)
def _is_64bit_dest(dest: str) -> bool: return any(dest.endswith(x) for x in ('.b64', '.u64', '.i64', '.f64'))
def _to_u32(val: UOp) -> UOp:
if val.dtype == dtypes.uint32: return val
if val.dtype.itemsize == 4: return val.bitcast(dtypes.uint32) # same size: bitcast (float32->uint32)
@@ -192,9 +191,9 @@ def _write_64bit(val: UOp, wfn, reg_or_addr, is_mem: bool, *args) -> list[UOp]:
incr = 4 if is_mem else 1 # 4 bytes for memory addresses, 1 for register indices
return [wfn(reg_or_addr, lo, *args), wfn(reg_or_addr + (UOp.const(reg_or_addr.dtype, incr) if isinstance(reg_or_addr, UOp) else incr), hi, *args)]
def _write_val(dest: str, val: UOp, wfn, reg_or_addr, *args, is_mem: bool = False) -> list[UOp]:
"""Write value, splitting 64-bit if needed based on dest type suffix."""
return _write_64bit(val, wfn, reg_or_addr, is_mem, *args) if _is_64bit_dest(dest) else [wfn(reg_or_addr, _to_u32(val), *args)]
def _write_val(bits: int, val: UOp, wfn, reg_or_addr, *args, is_mem: bool = False) -> list[UOp]:
"""Write value, splitting 64-bit if needed. bits=64 for 64-bit writes, otherwise 32-bit."""
return _write_64bit(val, wfn, reg_or_addr, is_mem, *args) if bits == 64 else [wfn(reg_or_addr, _to_u32(val), *args)]
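The lo/hi split behind _write_64bit, sketched on plain ints (low dword first, as the register-pair writes above assume):

def split64(val: int) -> tuple[int, int]:
    return val & 0xFFFFFFFF, (val >> 32) & 0xFFFFFFFF

lo, hi = split64(0x1122334455667788)
assert (lo, hi) == (0x55667788, 0x11223344)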
def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32, data_bits: int = 32) -> list[UOp]:
"""Conditional memory store with sub-word support. Returns list of store UOps."""
@@ -337,31 +336,38 @@ class _Ctx:
offset = reg.cast(dtypes.int) * _c(32, dtypes.int) + lane.cast(dtypes.int)
return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32))
def rsrc_dyn(self, off: UOp, lane: UOp, bits: int = 32, literal: UOp | None = None) -> UOp:
"""Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256)."""
is_vgpr, vgpr_reg = off >= _c(256), off - _c(256)
def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False) -> UOp:
"""Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256).
If lane is None, only scalar access is supported (off must be < 256).
is_f64: True for F64 operations where 64-bit literals go in high 32 bits."""
is_float_const = (off >= _c(240)) & (off <= _c(248))
sgpr_lo = self.rsgpr_dyn(off)
vgpr_lo = self.rvgpr_dyn(vgpr_reg, lane, is_vgpr)
if lane is not None:
is_vgpr, vgpr_reg = off >= _c(256), off - _c(256)
vgpr_lo = self.rvgpr_dyn(vgpr_reg, lane, is_vgpr)
vgpr_val = _u64(vgpr_lo, self.rvgpr_dyn(vgpr_reg + _c(1), lane, is_vgpr)) if bits == 64 else vgpr_lo
if bits == 64:
vgpr_val = _u64(vgpr_lo, self.rvgpr_dyn(vgpr_reg + _c(1), lane, is_vgpr))
sgpr_val = _u64(sgpr_lo, self.rsgpr_dyn(off + _c(1)))
# Float constants: cast F32 to F64; integer inline: duplicate lo
inline = is_float_const.where(sgpr_lo.bitcast(dtypes.float32).cast(dtypes.float64).bitcast(dtypes.uint64), _u64(sgpr_lo, sgpr_lo))
if literal is not None: inline = off.eq(_c(255)).where(literal.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32), inline)
# Integer inline constants: sign-extend 32-bit value from buffer to 64-bit
# Float constants: cast F32 to F64
int_inline = sgpr_lo.cast(dtypes.int32).cast(dtypes.int64)
float_inline = sgpr_lo.bitcast(dtypes.float32).cast(dtypes.float64)
# compute inline
inline = is_float_const.where(float_inline.bitcast(dtypes.uint64), int_inline.bitcast(dtypes.uint64))
# Literal handling: F64 VOP puts literal in high 32 bits; B64/I64/U64 VOP and SOP zero-extend
if literal is not None:
lit_val = literal.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32) if is_f64 else literal.cast(dtypes.uint64)
inline = off.eq(_c(255)).where(lit_val, inline)
scalar_val = (off < _c(128)).where(sgpr_val, inline)
else:
vgpr_val = vgpr_lo
scalar_val = sgpr_lo
if literal is not None: scalar_val = off.eq(_c(255)).where(literal, scalar_val)
if bits == 16: # Float constants: cast F32 to F16
scalar_val = is_float_const.where(scalar_val.bitcast(dtypes.float32).cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32), scalar_val)
return is_vgpr.where(vgpr_val, scalar_val)
def rsrc_dyn_sized(self, off: UOp, lane: UOp, sizes: dict, key: str, f16: bool = False, literal: UOp | None = None) -> UOp:
return self.rsrc_dyn(off, lane, 64, literal) if sizes.get(key, 1) == 2 else self.rsrc_dyn(off, lane, 16 if f16 else 32, literal)
return is_vgpr.where(vgpr_val, scalar_val) if lane is not None else scalar_val
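A plain-int sketch of the three 64-bit widening rules described above; struct models the bitcasts, and the encodings follow the comments in this hunk rather than a second source:

import struct

def inline_to_u64(raw32: int, is_float_const: bool) -> int:
    if is_float_const:  # offsets 240-248 hold an f32 pattern, widened to f64
        f = struct.unpack('<f', struct.pack('<I', raw32))[0]
        return struct.unpack('<Q', struct.pack('<d', f))[0]
    # integer inline constants sign-extend from 32 to 64 bits
    sext = raw32 - (1 << 32) if raw32 & 0x80000000 else raw32
    return sext & 0xFFFFFFFFFFFFFFFF

def literal_to_u64(lit32: int, is_f64: bool) -> int:
    return lit32 << 32 if is_f64 else lit32  # F64 VOP literal fills the high dword

assert inline_to_u64(0xFFFFFFFF, False) == 0xFFFFFFFFFFFFFFFF  # -1 sign-extends
assert inline_to_u64(0x3F800000, True) == 0x3FF0000000000000   # 1.0f -> 1.0
assert literal_to_u64(0x3FF00000, True) == 0x3FF0000000000000  # high-dword literal = 1.0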
def rpc(self) -> UOp:
"""Read PC as 64-bit byte address."""
@@ -509,36 +515,9 @@ def _compile_smem(inst: SMEM, ctx: _Ctx) -> UOp:
return UOp.sink(*stores, *ctx.inc_pc())
def _compile_sop(inst: SOP1 | SOP2 | SOPC | SOPK, ctx: _Ctx) -> UOp:
sizes = getattr(inst, 'op_regs', {})
bits = inst.canonical_op_bits
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
# Read source operands dynamically
def rsrc_dyn_scalar(off: UOp, is_64bit: bool) -> UOp:
"""Read scalar source with dynamic offset (SGPR or inline constant).
For SOP, off is always 0-255 (SGPR or inline constant, never VGPR).
SGPR buffer has 260 entries: 0-127=SGPRs, 128-255=inline constants, 256-259=special."""
is_sgpr = off < _c(128)
# For 64-bit: read SGPR pair if off < 128, else compute inline constant as 64-bit
# (can't just read from buffer since buffer has 32-bit values)
if is_64bit:
sgpr_val = _u64(ctx.rsgpr_dyn(off), ctx.rsgpr_dyn(off + _c(1)))
# Build inline constant: 128-192 = 0-64, 193-208 = -1 to -16
inline_val = (off - _c(128)).cast(dtypes.uint64) # positive inline 0-64
neg_val = (_c(192) - off).cast(dtypes.int64).cast(dtypes.uint64) # negative -1 to -16
lit_val = literal.cast(dtypes.uint64) if literal is not None else UOp.const(dtypes.uint64, 0)
# Select between sgpr, positive inline, negative inline, or literal
is_neg_inline = (off >= _c(193)) & (off < _c(209))
is_literal = off.eq(_c(255)) if literal is not None else UOp.const(dtypes.bool, False)
val = is_sgpr.where(sgpr_val, is_neg_inline.where(neg_val, is_literal.where(lit_val, inline_val)))
return val
# 32-bit: read from SGPR buffer (inline constants 128-255 are pre-populated)
# off is always 0-255 for SOP, all valid SGPR indices
sgpr_val = ctx.rsgpr_dyn(off)
# Handle literal (255) - literal value overrides the pre-populated 0
if literal is not None:
sgpr_val = off.eq(_c(255)).where(literal, sgpr_val)
return sgpr_val
if isinstance(inst, SOPK):
sdst_off = ctx.inst_field(SOPK.sdst)
simm16 = ctx.inst_field(SOPK.simm16)
@@ -549,21 +528,21 @@ def _compile_sop(inst: SOP1 | SOP2 | SOPC | SOPK, ctx: _Ctx) -> UOp:
elif isinstance(inst, SOP1):
sdst_off = ctx.inst_field(SOP1.sdst)
ssrc0_off = ctx.inst_field(SOP1.ssrc0)
srcs = {'S0': rsrc_dyn_scalar(ssrc0_off, sizes.get('ssrc0', 1) == 2)}
dst_off, dst_size = sdst_off, sizes.get('sdst', 1)
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal)}
dst_off, dst_size = sdst_off, bits['d'] // 32
elif isinstance(inst, SOP2):
sdst_off = ctx.inst_field(SOP2.sdst)
ssrc0_off = ctx.inst_field(SOP2.ssrc0)
ssrc1_off = ctx.inst_field(SOP2.ssrc1)
srcs = {'S0': rsrc_dyn_scalar(ssrc0_off, sizes.get('ssrc0', 1) == 2),
'S1': rsrc_dyn_scalar(ssrc1_off, sizes.get('ssrc1', 1) == 2)}
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
if literal is not None: srcs['SIMM32'] = literal
dst_off, dst_size = sdst_off, sizes.get('sdst', 1)
dst_off, dst_size = sdst_off, bits['d'] // 32
elif isinstance(inst, SOPC):
ssrc0_off = ctx.inst_field(SOPC.ssrc0)
ssrc1_off = ctx.inst_field(SOPC.ssrc1)
srcs = {'S0': rsrc_dyn_scalar(ssrc0_off, sizes.get('ssrc0', 1) == 2),
'S1': rsrc_dyn_scalar(ssrc1_off, sizes.get('ssrc1', 1) == 2)}
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
dst_off, dst_size = _c(0), 0 # SOPC writes to SCC, not sdst
else:
raise RuntimeError(f"unknown SOP type: {type(inst).__name__}")
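The inline-constant ranges the deleted helper spelled out (128-192 encode 0..64, 193-208 encode -1..-16), as a standalone sketch:

def sop_inline_constant(off: int) -> int | None:
    if 128 <= off <= 192: return off - 128   # positive inline 0..64
    if 193 <= off <= 208: return 192 - off   # negative inline -1..-16
    return None  # SGPR (<128), special register, or literal (255)

assert sop_inline_constant(129) == 1 and sop_inline_constant(193) == -1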
@@ -573,18 +552,17 @@ def _compile_sop(inst: SOP1 | SOP2 | SOPC | SOPK, ctx: _Ctx) -> UOp:
def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
op_name = _op_name(inst)
if op_name == 'V_READFIRSTLANE_B32_E32': return ctx.compile_lane_pcode(inst.op, inst)
lane, exec_mask, sizes = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset)), getattr(inst, 'op_regs', {})
is_16bit = _is_16bit_op(op_name)
lane, exec_mask, bits = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset)), inst.canonical_op_bits
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
vdst_reg = ctx.inst_field(VOP1.vdst)
write_hi_half = is_16bit and (vdst_reg >= _c(128))
write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128))
if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg)
elif write_hi_half: vdst_reg -= 128
if isinstance(inst, VOP1):
# Handle VOP1 hi-half source operand (src0 >= v[128] for 16-bit ops)
src0_off = ctx.inst_field(VOP1.src0)
s0 = ctx.rsrc_dyn_sized(src0_off, lane, sizes, 'src0', f16=is_16bit, literal=literal)
if is_16bit:
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal)
if bits['s0'] == 16:
src0_hi = src0_off >= _c(384)
# Only compute hi-half when src0_off >= 384, use guarded index to prevent OOB access
src0_reg = src0_hi.where(src0_off - _c(384), _c(0))
@@ -592,14 +570,14 @@ def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
srcs = {'S0': s0}
else:
vsrc1_reg = ctx.inst_field(VOP2.vsrc1)
vsrc1_hi = is_16bit and (vsrc1_reg >= _c(128))
vsrc1_hi = bits['s0'] == 16 and (vsrc1_reg >= _c(128))
vsrc1_actual = _cond(vsrc1_hi, vsrc1_reg - _c(128), vsrc1_reg)
s1 = _cond_hi16(vsrc1_hi, ctx.rvgpr_dyn(vsrc1_actual, lane))
d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane)) # FMAC/FMAMK hi-half dest needs hi-half accumulator
# Handle VOP2 hi-half src0 operand (src0 >= v[128] for 16-bit ops)
src0_off = ctx.inst_field(VOP2.src0)
s0 = ctx.rsrc_dyn(src0_off, lane, bits=16 if is_16bit else 32, literal=literal)
if is_16bit:
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal)
if bits['s0'] == 16:
src0_hi = src0_off >= _c(384)
# Only compute hi-half when src0_off >= 384, use guarded index to prevent OOB access
src0_reg = src0_hi.where(src0_off - _c(384), _c(0))
@@ -611,41 +589,36 @@ def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half)
def _compile_vopc(inst: VOPC | VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
exec_mask, op_name = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst)
is_cmpx, is_16bit, is_64bit = 'CMPX' in op_name, _is_16bit_op(op_name), 'F64' in op_name
is_vopc = hasattr(inst, 'vsrc1') # VOPC (e32) vs VOP3 (e64) format
exec_mask, op_name, bits = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst), inst.canonical_op_bits
is_cmpx, is_vopc = 'CMPX' in op_name, hasattr(inst, 'vsrc1') # is_vopc: e32 vs e64
# Handle both VOPC (vsrc1) and VOP3 (src1) instruction formats - read operands dynamically
if is_vopc:
src0_off = ctx.inst_field(VOPC.src0)
vsrc1_off = ctx.inst_field(VOPC.vsrc1)
# For 16-bit ops, vsrc1 >= 128 means hi-half of v[vsrc1-128]
if is_16bit:
if bits['s0'] == 16:
vsrc1_hi = vsrc1_off >= _c(128)
src1_off = _c(256) + vsrc1_hi.where(vsrc1_off - _c(128), vsrc1_off)
else:
vsrc1_hi = False
src1_off = _c(256) + vsrc1_off
src0_bits, src1_bits = (64, 64) if is_64bit else (32, 32)
else:
src0_off = ctx.inst_field(VOP3.src0)
src1_off = ctx.inst_field(VOP3.src1)
dst_off = ctx.inst_field(VOP3.vdst)
vsrc1_hi = False
_, src0_bits, _ = inst.operands.get('src0', (None, 32, None))
_, src1_bits, _ = inst.operands.get('src1', (None, 32, None))
is_16bit = src0_bits == 16 or src1_bits == 16
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
is_float, pcode = any(x in op_name for x in ('_F32', '_F64', '_F16')), get_pcode(inst.op)
is_float, is_f64, pcode = any(x in op_name for x in ('_F32', '_F64', '_F16')), '_F64' in op_name, get_pcode(inst.op)
def get_cmp_bit(lane) -> UOp:
lc = lane.cast(dtypes.int) if isinstance(lane, UOp) else _c(lane, dtypes.int)
s0 = ctx.rsrc_dyn(src0_off, lc, src0_bits, literal)
s1 = _cond_hi16(vsrc1_hi, ctx.rsrc_dyn(src1_off, lc, src1_bits, literal)) if is_16bit else ctx.rsrc_dyn(src1_off, lc, src1_bits, literal)
if is_16bit and opsel: s0, s1 = _apply_opsel(s0, 0, opsel), _apply_opsel(s1, 1, opsel)
if is_float and (abs_bits or neg_bits):
s0 = _apply_src_mods(s0, 0, abs_bits, neg_bits, is_16bit, src0_bits == 64)
s1 = _apply_src_mods(s1, 1, abs_bits, neg_bits, is_16bit, src1_bits == 64)
s0 = ctx.rsrc_dyn(src0_off, lc, bits['s0'], literal, is_f64)
s1 = _cond_hi16(vsrc1_hi, ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)) if bits['s0'] == 16 else ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)
if bits['s0'] == 16 and opsel: s0, s1 = _apply_opsel(s0, 0, opsel), _apply_opsel(s1, 1, opsel)
if is_float:
s0 = _apply_src_mods(s0, 0, abs_bits, neg_bits, bits['s0'])
s1 = _apply_src_mods(s1, 1, abs_bits, neg_bits, bits['s1'])
for dest, val in parse_pcode(pcode, {'S0': s0, 'S1': s1, 'laneId': lc})[1]:
if '[laneId]' in dest and ('D0' in dest or 'EXEC' in dest): return val.cast(dtypes.uint32)
return _c(0)
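For intuition, how the per-lane bits from get_cmp_bit fold into a wave32 mask, sketched on plain ints (the real code builds this with UOps over the unrolled lanes):

def pack_cmp_bits(lane_bits: list[int], exec_mask: int) -> int:
    out = 0
    for lane, bit in enumerate(lane_bits):
        if exec_mask & (1 << lane) and bit: out |= 1 << lane
    return out

assert pack_cmp_bits([1, 0, 1, 1], 0b1011) == 0b1001  # lane 2 is inactive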
@@ -664,7 +637,7 @@ def _compile_vopc(inst: VOPC | VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int =
def _compile_vop3(inst: VOP3, ctx: _Ctx) -> UOp:
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
sizes = getattr(inst, 'op_regs', {})
bits = inst.canonical_op_bits
opsel, op_name = getattr(inst, 'opsel', 0) or 0, _op_name(inst)
# Lane operations
@@ -677,53 +650,48 @@ def _compile_vop3(inst: VOP3, ctx: _Ctx) -> UOp:
# Regular VOP3 - read operands dynamically
lane = ctx.range()
is_f16_op = 'F16' in op_name
vdst_reg = ctx.inst_field(VOP3.vdst)
src0_off = ctx.inst_field(VOP3.src0)
src1_off = ctx.inst_field(VOP3.src1)
src2_off = ctx.inst_field(VOP3.src2) if inst.src2 is not None else None
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
src0 = ctx.rsrc_dyn_sized(src0_off, lane, sizes, 'src0', f16=is_f16_op, literal=literal)
src1 = ctx.rsrc_dyn_sized(src1_off, lane, sizes, 'src1', f16=is_f16_op, literal=literal)
src2 = ctx.rsrc_dyn_sized(src2_off, lane, sizes, 'src2', f16=is_f16_op, literal=literal) if src2_off is not None else None
if _is_16bit_op(op_name):
src0, src1 = _apply_opsel(src0, 0, opsel), _apply_opsel(src1, 1, opsel)
if src2 is not None: src2 = _apply_opsel(src2, 2, opsel)
ops = inst.canonical_operands
src0 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src0), lane, bits['s0'], literal, 's0' in ops and ops['s0'][0] == Fmt.FMT_NUM_F64)
src1 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src1), lane, bits['s1'], literal, 's1' in ops and ops['s1'][0] == Fmt.FMT_NUM_F64)
src2 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src2), lane, bits['s2'], literal, 's2' in ops and ops['s2'][0] == Fmt.FMT_NUM_F64)
if bits['s0'] == 16:
src0 = _apply_opsel(src0, 0, opsel)
src1 = _apply_opsel(src1, 1, opsel)
src2 = _apply_opsel(src2, 2, opsel)
abs_bits, neg_bits = getattr(inst, 'abs', 0) or 0, getattr(inst, 'neg', 0) or 0
is_16bit_op = _is_16bit_op(op_name)
if abs_bits or neg_bits:
src0 = _apply_src_mods(src0, 0, abs_bits, neg_bits, is_16bit_op, sizes.get('src0', 1) == 2)
if src1 is not None: src1 = _apply_src_mods(src1, 1, abs_bits, neg_bits, is_16bit_op, sizes.get('src1', 1) == 2)
if src2 is not None: src2 = _apply_src_mods(src2, 2, abs_bits, neg_bits, is_16bit_op, sizes.get('src2', 1) == 2)
srcs = {'S0': src0, 'S1': src1}
if src2 is not None: srcs['S2'] = src2
src0 = _apply_src_mods(src0, 0, abs_bits, neg_bits, bits['s0'])
src1 = _apply_src_mods(src1, 1, abs_bits, neg_bits, bits['s1'])
src2 = _apply_src_mods(src2, 2, abs_bits, neg_bits, bits['s2'])
srcs = {'S0': src0, 'S1': src1, 'S2': src2}
if inst.op in (VOP3Op.V_CNDMASK_B32_E64, VOP3Op.V_CNDMASK_B16) and src2 is not None: srcs['VCC'] = src2
# FMAC instructions need D0 (accumulator) from destination register
if 'FMAC' in op_name: srcs['D0'] = ctx.rvgpr_dyn(vdst_reg, lane)
opsel_dst_hi = bool(opsel & 0b1000) and _is_16bit_op(op_name)
opsel_dst_hi = bool(opsel & 0b1000) and bits['d'] == 16
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=opsel_dst_hi, clmp=getattr(inst, 'clmp', 0))
def _compile_vop3sd(inst: VOP3SD, ctx: _Ctx) -> UOp:
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
sizes, pcode = getattr(inst, 'op_regs', {}), get_pcode(inst.op)
bits, pcode, ops = inst.canonical_op_bits, get_pcode(inst.op), inst.canonical_operands
# Read operands dynamically from instruction encoding
vdst_reg = ctx.inst_field(VOP3SD.vdst)
sdst_off = ctx.inst_field(VOP3SD.sdst)
src0_off = ctx.inst_field(VOP3SD.src0)
src1_off = ctx.inst_field(VOP3SD.src1)
src2_off = ctx.inst_field(VOP3SD.src2) if inst.src2 is not None else None
vdst_reg, sdst_off = ctx.inst_field(VOP3SD.vdst), ctx.inst_field(VOP3SD.sdst)
src0_off, src1_off, src2_off = ctx.inst_field(VOP3SD.src0), ctx.inst_field(VOP3SD.src1), ctx.inst_field(VOP3SD.src2)
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
has_carry_in = 'src2' in inst.operands and inst.operands['src2'][2] == OpType.OPR_SREG
vcc_in_off = src2_off if has_carry_in and src2_off is not None else sdst_off
has_carry_in = 's2' in ops and ops['s2'][2] == OpType.OPR_SREG
vcc_in_off = src2_off if has_carry_in else sdst_off
def load_srcs(lane_uop):
ret = {'VCC': ctx.rsgpr_dyn(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane_uop}
ret['S0'] = ctx.rsrc_dyn(src0_off, lane_uop, bits['s0'], literal, ops['s0'][0] == Fmt.FMT_NUM_F64)
ret['S1'] = ctx.rsrc_dyn(src1_off, lane_uop, bits['s1'], literal, ops['s1'][0] == Fmt.FMT_NUM_F64)
if 's2' in ops: ret['S2'] = ctx.rsrc_dyn(src2_off, lane_uop, bits['s2'], literal, ops['s2'][0] == Fmt.FMT_NUM_F64)
return ret
lane = ctx.range()
src0 = ctx.rsrc_dyn_sized(src0_off, lane, sizes, 'src0', literal=literal)
src1 = ctx.rsrc_dyn_sized(src1_off, lane, sizes, 'src1', literal=literal)
src2 = ctx.rsrc_dyn_sized(src2_off, lane, sizes, 'src2', literal=literal) if src2_off is not None else None
srcs = {'S0': src0, 'S1': src1, 'VCC': ctx.rsgpr_dyn(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane}
if src2 is not None: srcs['S2'] = src2
srcs = load_srcs(lane)
_, assigns = parse_pcode(pcode, srcs)
has_per_lane_vcc = any('[laneId]' in dest for dest, _ in assigns if dest.startswith('VCC') or dest.startswith('D0.u64'))
@@ -731,23 +699,15 @@ def _compile_vop3sd(inst: VOP3SD, ctx: _Ctx) -> UOp:
# VCC computation: RANGE+REDUCE gets axis ID first (lower ID = runs first)
# This ensures VCC reads source values BEFORE VGPR stores modify them
def get_vcc_bit(lane_uop) -> UOp:
s0, s1 = ctx.rsrc_dyn_sized(src0_off, lane_uop, sizes, 'src0', literal=literal), ctx.rsrc_dyn_sized(src1_off, lane_uop, sizes, 'src1', literal=literal)
s2 = ctx.rsrc_dyn_sized(src2_off, lane_uop, sizes, 'src2', literal=literal) if src2_off is not None else None
lane_srcs = {'S0': s0, 'S1': s1, 'VCC': ctx.rsgpr_dyn(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane_uop}
if s2 is not None: lane_srcs['S2'] = s2
vcc_bit = _c(0)
for dest, val in parse_pcode(pcode, lane_srcs)[1]:
for dest, val in parse_pcode(pcode, load_srcs(lane_uop))[1]:
if dest.startswith('VCC') or (dest.startswith('D0.u64') and '[laneId]' in dest): vcc_bit = val.cast(dtypes.uint32)
return vcc_bit
final_vcc = ctx.unroll_lanes(get_vcc_bit, exec_mask)
# VGPR stores: RANGE gets axis ID second (higher ID = runs after VCC loop)
lane3 = ctx.range()
s0, s1 = ctx.rsrc_dyn_sized(src0_off, lane3, sizes, 'src0', literal=literal), ctx.rsrc_dyn_sized(src1_off, lane3, sizes, 'src1', literal=literal)
s2 = ctx.rsrc_dyn_sized(src2_off, lane3, sizes, 'src2', literal=literal) if src2_off is not None else None
lane_srcs = {'S0': s0, 'S1': s1, 'VCC': ctx.rsgpr_dyn(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane3}
if s2 is not None: lane_srcs['S2'] = s2
d0_val = None
for dest, val in parse_pcode(pcode, lane_srcs)[1]:
for dest, val in parse_pcode(pcode, load_srcs(lane3))[1]:
if dest.startswith('D0') and '[laneId]' not in dest: d0_val = val
vgpr_stores = []
if d0_val is not None:
@@ -763,62 +723,52 @@ def _compile_vop3sd(inst: VOP3SD, ctx: _Ctx) -> UOp:
else:
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, sdst_reg=inst.sdst.offset)
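For intuition on the per-lane VCC path above, a plain-int sketch of a carry-out add in the V_ADD_CO_U32 style — a sketch of the semantics, not the repo's pcode:

def add_co_u32(s0: list[int], s1: list[int]) -> tuple[list[int], int]:
    d, vcc = [], 0
    for lane, (a, b) in enumerate(zip(s0, s1)):
        t = a + b
        d.append(t & 0xFFFFFFFF)
        if t > 0xFFFFFFFF: vcc |= 1 << lane  # one carry bit per lane
    return d, vcc

assert add_co_u32([0xFFFFFFFF, 1], [1, 1]) == ([0, 2], 0b01)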
def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
lane, exec_mask = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset))
# Read register fields dynamically for deduplication
def _compile_wmma(inst: VOP3P, ctx: _Ctx) -> UOp:
op_name = _op_name(inst)
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
vdst_reg = ctx.inst_field(VOP3P.vdst)
src0_off = ctx.inst_field(VOP3P.src0)
src1_off = ctx.inst_field(VOP3P.src1)
src2_off = ctx.inst_field(VOP3P.src2) if hasattr(inst, 'src2') and inst.src2 is not None else None
src0 = ctx.rsrc_dyn(src0_off, lane, 16)
src1 = ctx.rsrc_dyn(src1_off, lane, 16)
src2 = ctx.rsrc_dyn(src2_off, lane, 16) if src2_off is not None else None
src0_r = ctx.inst_field(VOP3P.src0) - _c(256)
src1_r = ctx.inst_field(VOP3P.src1) - _c(256)
src2_r = ctx.inst_field(VOP3P.src2) - _c(256)
is_f16_output = 'F16_16X16X16_F16' in op_name or 'BF16_16X16X16_BF16' in op_name # F16/BF16 output vs F32 output
is_bf16 = 'BF16' in op_name
cvt = _FUNCS['bf16_to_f32'] if is_bf16 else _FUNCS['f16_to_f32']
def read_f16_mat(src):
return [f for l in range(16) for r in range(8) for v in [ctx.rvgpr_dyn(src + _c(r), UOp.const(dtypes.int, l))]
for f in [cvt(v & UOp.const(dtypes.uint32, 0xFFFF)), cvt(v >> UOp.const(dtypes.uint32, 16))]]
mat_a, mat_b = read_f16_mat(src0_r), read_f16_mat(src1_r)
if is_f16_output:
# RDNA3 F16/BF16 output: uses 8 VGPRs (same as F32), f16/bf16 values in lo 16 bits of each VGPR
# Layout: half16 per lane where even indices (0,2,4,...,14) = lo halves of VGPRs 0-7
# Read accumulator: 8 regs × 32 lanes, each VGPR's lo 16 bits holds one f16/bf16
mat_c = [cvt(ctx.rvgpr_dyn(src2_r + _c(i // 32), UOp.const(dtypes.int, i % 32)) & UOp.const(dtypes.uint32, 0xFFFF))
for i in range(256)]
mat_d = [sum(mat_a[row*16+k] * mat_b[col*16+k] for k in range(16)) + mat_c[row*16+col] for row in range(16) for col in range(16)]
# Write f16/bf16 results to lo 16 bits of each VGPR
def f32_to_f16_bits(v: UOp) -> UOp: return v.cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32)
def f32_to_bf16_bits(v: UOp) -> UOp: return (v.bitcast(dtypes.uint32) >> UOp.const(dtypes.uint32, 16)) & UOp.const(dtypes.uint32, 0xFFFF)
out_cvt = f32_to_bf16_bits if is_bf16 else f32_to_f16_bits
stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), out_cvt(mat_d[i]), exec_mask) for i in range(256)]
else:
# F32 output: accumulator and output are f32
mat_c = [ctx.rvgpr_dyn(src2_r + _c(i // 32), UOp.const(dtypes.int, i % 32)).bitcast(dtypes.float32) for i in range(256)]
mat_d = [sum(mat_a[row*16+k] * mat_b[col*16+k] for k in range(16)) + mat_c[row*16+col] for row in range(16) for col in range(16)]
stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), mat_d[i].bitcast(dtypes.uint32), exec_mask) for i in range(256)]
return UOp.sink(*stores, *ctx.inc_pc())
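The arithmetic the mat_d comprehension performs, restated on plain lists: D = A @ B^T + C on 16x16 tiles, with both A and B read as 16 rows of 16 values and B indexed by output column:

def wmma_16x16x16(A, B, C):
    return [[sum(A[r][k] * B[c][k] for k in range(16)) + C[r][c]
             for c in range(16)] for r in range(16)]

I16 = [[float(r == c) for c in range(16)] for r in range(16)]
Z16 = [[0.0] * 16 for _ in range(16)]
assert wmma_16x16x16(I16, I16, Z16) == I16  # I @ I^T + 0 = I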
def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
op_name = _op_name(inst)
if 'WMMA' in op_name and ('16X16X16_F16' in op_name or '16X16X16_BF16' in op_name): return _compile_wmma(inst, ctx)
lane = ctx.range()
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
vdst_reg = ctx.inst_field(VOP3P.vdst)
src0 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src0), lane, 16)
src1 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src1), lane, 16)
src2 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src2), lane, 16)
opsel, opsel_hi = getattr(inst, 'opsel', 0) or 0, getattr(inst, 'opsel_hi', 3) if getattr(inst, 'opsel_hi', 3) is not None else 3
opsel_hi2 = getattr(inst, 'opsel_hi2', 1) if getattr(inst, 'opsel_hi2', 1) is not None else 1
neg, neg_hi = getattr(inst, 'neg', 0) or 0, getattr(inst, 'neg_hi', 0) or 0
def get_half_bits(val: UOp, use_hi: bool, apply_neg: bool = False) -> UOp:
bits = ((val >> UOp.const(dtypes.uint32, 16)) if use_hi else val) & UOp.const(dtypes.uint32, 0xFFFF)
if apply_neg: bits = bits.cast(dtypes.uint16).bitcast(dtypes.half).neg().bitcast(dtypes.uint16).cast(dtypes.uint32)
return bits
def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp:
return get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit)) | (get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit)) << UOp.const(dtypes.uint32, 16))
s0_new = build_remapped_src(src0, opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1)
s1_new = build_remapped_src(src1, opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2)
s2_new = build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4) if src2 is not None else None
op_name = _op_name(inst)
# WMMA: Wave Matrix Multiply-Accumulate
if 'WMMA' in op_name and ('16X16X16_F16' in op_name or '16X16X16_BF16' in op_name):
# Dynamic register fields for deduplication
src0_r = ctx.inst_field(VOP3P.src0) - _c(256)
src1_r = ctx.inst_field(VOP3P.src1) - _c(256)
src2_r = ctx.inst_field(VOP3P.src2) - _c(256)
is_f16_output = 'F16_16X16X16_F16' in op_name or 'BF16_16X16X16_BF16' in op_name # F16/BF16 output vs F32 output
is_bf16 = 'BF16' in op_name
cvt = _FUNCS['bf16_to_f32'] if is_bf16 else _FUNCS['f16_to_f32']
def read_f16_mat(src):
return [f for l in range(16) for r in range(8) for v in [ctx.rvgpr_dyn(src + _c(r), UOp.const(dtypes.int, l))]
for f in [cvt(v & UOp.const(dtypes.uint32, 0xFFFF)), cvt(v >> UOp.const(dtypes.uint32, 16))]]
mat_a, mat_b = read_f16_mat(src0_r), read_f16_mat(src1_r)
if is_f16_output:
# RDNA3 F16/BF16 output: uses 8 VGPRs (same as F32), f16/bf16 values in lo 16 bits of each VGPR
# Layout: half16 per lane where even indices (0,2,4,...,14) = lo halves of VGPRs 0-7
# Read accumulator: 8 regs × 32 lanes, each VGPR's lo 16 bits holds one f16/bf16
mat_c = [cvt(ctx.rvgpr_dyn(src2_r + _c(i // 32), UOp.const(dtypes.int, i % 32)) & UOp.const(dtypes.uint32, 0xFFFF))
for i in range(256)]
mat_d = [sum(mat_a[row*16+k] * mat_b[col*16+k] for k in range(16)) + mat_c[row*16+col] for row in range(16) for col in range(16)]
# Write f16/bf16 results to lo 16 bits of each VGPR
def f32_to_f16_bits(v: UOp) -> UOp: return v.cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32)
def f32_to_bf16_bits(v: UOp) -> UOp: return (v.bitcast(dtypes.uint32) >> UOp.const(dtypes.uint32, 16)) & UOp.const(dtypes.uint32, 0xFFFF)
out_cvt = f32_to_bf16_bits if is_bf16 else f32_to_f16_bits
stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32),
out_cvt(mat_d[i]), exec_mask) for i in range(256)]
else:
# F32 output: accumulator and output are f32
mat_c = [ctx.rvgpr_dyn(src2_r + _c(i // 32), UOp.const(dtypes.int, i % 32)).bitcast(dtypes.float32) for i in range(256)]
mat_d = [sum(mat_a[row*16+k] * mat_b[col*16+k] for k in range(16)) + mat_c[row*16+col] for row in range(16) for col in range(16)]
stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), mat_d[i].bitcast(dtypes.uint32), exec_mask) for i in range(256)]
return UOp.sink(*stores, *ctx.inc_pc())
if 'FMA_MIX' in op_name:
combined_opsel_hi = (opsel_hi & 0x3) | ((opsel_hi2 & 0x1) << 2)
@@ -836,12 +786,20 @@ def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
return v ^ UOp.const(dtypes.uint32, 0x00008000) # f16 lo neg
s0_mod = apply_neg_mix(apply_abs(src0, 1, 1, 1), 1, 1, 1)
s1_mod = apply_neg_mix(apply_abs(src1, 2, 2, 2), 2, 2, 2)
s2_mod = apply_neg_mix(apply_abs(src2, 4, 4, 4), 4, 4, 4) if src2 is not None else UOp.const(dtypes.uint32, 0)
s2_mod = apply_neg_mix(apply_abs(src2, 4, 4, 4), 4, 4, 4)
srcs = {'S0': s0_mod, 'S1': s1_mod, 'S2': s2_mod,
'OPSEL_HI': UOp.const(dtypes.uint32, combined_opsel_hi), 'OPSEL': UOp.const(dtypes.uint32, opsel)}
else:
srcs = {'S0': s0_new, 'S1': s1_new}
if s2_new is not None: srcs['S2'] = s2_new
def get_half_bits(val: UOp, use_hi: bool, apply_neg: bool = False) -> UOp:
bits = ((val >> UOp.const(dtypes.uint32, 16)) if use_hi else val) & UOp.const(dtypes.uint32, 0xFFFF)
if apply_neg: bits = bits.cast(dtypes.uint16).bitcast(dtypes.half).neg().bitcast(dtypes.uint16).cast(dtypes.uint32)
return bits
def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp:
return get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit)) | (get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit)) << UOp.const(dtypes.uint32, 16))
s0_new = build_remapped_src(src0, opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1)
s1_new = build_remapped_src(src1, opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2)
s2_new = build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4)
srcs = {'S0': s0_new, 'S1': s1_new, 'S2': s2_new}
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask)
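What build_remapped_src does to a packed 32-bit source, on plain ints: each output half independently selects the lo or hi input half (the neg path is omitted here):

def remap_halves(src: int, lo_from_hi: bool, hi_from_hi: bool) -> int:
    lo = (src >> 16) & 0xFFFF if lo_from_hi else src & 0xFFFF
    hi = (src >> 16) & 0xFFFF if hi_from_hi else src & 0xFFFF
    return lo | (hi << 16)

assert remap_halves(0xAAAABBBB, lo_from_hi=True, hi_from_hi=False) == 0xBBBBAAAA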
def _compile_vopd(inst: VOPD, ctx: _Ctx) -> UOp:
@@ -904,9 +862,8 @@ def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
# Dynamic saddr - read field, NULL (124) or >= 128 means no saddr
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(inst, 'saddr') else None
# Data width
ndwords = 4 if '_B128' in op_name or 'B128' in op_name else 3 if '_B96' in op_name or 'B96' in op_name else 2 if '_B64' in op_name or 'B64' in op_name else 1
is_64bit = ndwords >= 2 or '_U64' in op_name or '_I64' in op_name or '_F64' in op_name
# Data width from canonical_op_bits (32/64/96/128), default to 32 for untyped ops
data_bits_mem = inst.canonical_op_bits.get('data', 32)
is_atomic, glc = 'ATOMIC' in op_name, getattr(inst, 'glc', 0)
has_data1 = is_lds and hasattr(inst, 'data1') and inst.data1 is not None
data1_reg = ctx.inst_field(DS.data1) if is_lds else _c(0)
@@ -940,10 +897,13 @@ def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
def make_srcs(lane: UOp) -> dict:
addr = make_addr(lane)
if is_lds:
if 'B128' in op_name or 'B96' in op_name:
if data_bits_mem == 128:
data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA1': ctx.rvgpr_dyn(vdata_reg + _c(1), lane),
'DATA2': ctx.rvgpr_dyn(vdata_reg + _c(2), lane), 'DATA3': ctx.rvgpr_dyn(vdata_reg + _c(3), lane)}
elif 'B32' in op_name:
elif data_bits_mem == 96:
data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA1': ctx.rvgpr_dyn(vdata_reg + _c(1), lane),
'DATA2': ctx.rvgpr_dyn(vdata_reg + _c(2), lane)}
elif data_bits_mem == 32:
data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA2': ctx.rvgpr_dyn(data1_reg, lane) if has_data1 else UOp.const(dtypes.uint32, 0)}
else:
data = {'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)),
@@ -951,26 +911,27 @@ def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
return {'ADDR': addr, 'ADDR_BASE': addr, 'OFFSET': offset, 'OFFSET0': offset0, 'OFFSET1': offset1, '_lds': mem, 'laneId': lane, **data}
active = _lane_active(exec_mask, lane)
if is_atomic:
return {'ADDR': addr, 'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) if is_64bit else ctx.rvgpr_dyn(vdata_reg, lane),
return {'ADDR': addr, 'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) if data_bits_mem == 64 else ctx.rvgpr_dyn(vdata_reg, lane),
'_vmem': mem, '_active': active, 'laneId': lane}
vdata = ctx.rvgpr_dyn(vdata_reg, lane).cast(dtypes.uint64) if 'STORE' in op_name else ctx.rvgpr_dyn(vdst_reg, lane) if 'D16' in op_name else UOp.const(dtypes.uint32, 0)
if 'STORE' in op_name and ndwords >= 2: vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
if 'STORE' in op_name and data_bits_mem >= 64: vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active, 'laneId': lane}
for i in range(ndwords): srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
for i in range(data_bits_mem // 32): srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
return srcs
def make_stores(dest: str, val: UOp, lane: UOp, active: UOp, writes_return_data: bool) -> list[UOp]:
# Parse bit width from dest format: MEM[...].b32 or RETURN_DATA[63:32].b64
parts = dest.rsplit('.', 1)
data_bits = int(parts[1][1:]) if len(parts) == 2 else 32
if dest.startswith('MEM['):
if is_lds or is_atomic: return _write_val(dest, val[1], wmem, val[0], active, is_mem=True)
data_bits = 8 if '.b8' in dest else 16 if '.b16' in dest else 64 if '.b64' in dest else 32
if is_lds or is_atomic: return _write_val(data_bits, val[1], wmem, val[0], active, is_mem=True)
if is_scratch: return _mem_store_bytes(mem, val[0], val[1], active, data_bits)
return _mem_store(mem, val[0], val[1], active, 64, data_bits)
if dest.startswith('RETURN_DATA') and writes_return_data:
if (m := re.match(r'RETURN_DATA\[(\d+)\s*:\s*(\d+)\]', dest)):
bit_width, dword_idx = int(m.group(1)) - int(m.group(2)) + 1, int(m.group(2)) // 32
is_64 = '.b64' if bit_width == 64 else ''
return _write_val(is_64, val, lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e), vdst_reg + _c(dword_idx), lane, exec_mask)
return _write_val(dest, val, lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e), vdst_reg, lane, exec_mask)
return _write_val(bit_width, val, lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e), vdst_reg + _c(dword_idx), lane, exec_mask)
return _write_val(data_bits, val, lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e), vdst_reg, lane, exec_mask)
return []
# DS-specific: check for 2ADDR pattern needing separate ranges
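The two dest formats make_stores parses, sketched standalone; the 'b'-prefix guard is an added safety check, not in the code above:

import re

def dest_bits(dest: str) -> int:
    parts = dest.rsplit('.', 1)
    return int(parts[1][1:]) if len(parts) == 2 and parts[1][:1] == 'b' else 32

def return_data_slice(dest: str) -> tuple[int, int] | None:
    if (m := re.match(r'RETURN_DATA\[(\d+)\s*:\s*(\d+)\]', dest)):
        return int(m.group(1)) - int(m.group(2)) + 1, int(m.group(2)) // 32  # (bits, dword)
    return None

assert dest_bits('MEM[ADDR].b64') == 64
assert return_data_slice('RETURN_DATA[63:32]') == (32, 1)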


@@ -717,6 +717,22 @@ def _subst_loop_var(line: str, loop_var: str, val: int) -> str:
subst_parts = [str(val) if t.type == 'IDENT' and t.val == loop_var else t.val for t in result_toks if t.type != 'EOF']
return ' '.join(subst_parts)
def _set_bits(old: UOp, val: UOp, width: int, offset: int) -> UOp:
"""Set bits [offset:offset+width) in old to val, masking and shifting appropriately."""
mask = _u32(((1 << width) - 1) << offset)
v = (val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val) & _u32((1 << width) - 1)
return (old & (mask ^ _u32(0xFFFFFFFF))) | (v << _u32(offset))
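On plain 32-bit ints the helper reduces to the classic mask-and-or:

def set_bits(old: int, val: int, width: int, offset: int) -> int:
    mask = ((1 << width) - 1) << offset
    return (old & ~mask & 0xFFFFFFFF) | ((val & ((1 << width) - 1)) << offset)

assert set_bits(0xFFFFFFFF, 0x5, width=4, offset=8) == 0xFFFFF5FF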
def _find_paren_end(s: str, start: int = 0, open_ch: str = '(', close_ch: str = ')') -> int:
"""Find index of matching close paren, starting after the open paren at start."""
depth = 0
for j, ch in enumerate(s[start:], start):
if ch == open_ch: depth += 1
elif ch == close_ch:
depth -= 1
if depth == 0: return j
return len(s)
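Usage sketch, calling the helper defined above with start pointing at an opening paren; the returned index is the matching close paren:

s = "lambda(a, (b, c)) = (a + b)"
end = _find_paren_end(s, s.find('('))
assert s[:end + 1] == "lambda(a, (b, c))"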
def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: dict | None = None,
assigns: list | None = None) -> tuple[int, dict[str, VarVal], UOp | None]:
"""Parse a block of pcode. Returns (next_line, block_assigns, return_value).
@@ -724,7 +740,6 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
if funcs is None: funcs = _FUNCS
block_assigns: dict[str, VarVal] = {}
i = start
def ctx(): return {**vars, **block_assigns}
while i < len(lines):
line = lines[i]
@@ -738,7 +753,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
# return expr (lambda bodies)
if first == 'return':
rest = line[line.lower().find('return') + 6:].strip()
return i + 1, block_assigns, parse_expr(rest, ctx(), funcs)
return i + 1, block_assigns, parse_expr(rest, vars, funcs)
# for loop
if first == 'for':
@@ -747,21 +762,18 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
p.eat_val('for', 'IDENT')
loop_var = p.eat('IDENT').val
p.eat_val('in', 'IDENT')
if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
if p.at('NUM'):
start_val = int(p.eat('NUM').val.rstrip('UuLl'))
else:
start_expr = p.parse()
start_val = int(start_expr.arg) if start_expr.op == Ops.CONST else 0
def parse_bound():
if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
if p.at('NUM'): return int(p.eat('NUM').val.rstrip('UuLl'))
expr = p.parse()
return int(expr.arg) if expr.op == Ops.CONST else 0
start_val = parse_bound()
p.eat('COLON')
if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
if p.at('NUM'):
end_val = int(p.eat('NUM').val.rstrip('UuLl'))
else:
end_expr = p.parse()
end_val = int(end_expr.arg) if end_expr.op == Ops.CONST else 0
end_val = parse_bound()
# Collect body
i += 1; body_lines, depth = [], 1
i += 1
body_lines: list[str] = []
depth = 1
while i < len(lines) and depth > 0:
btoks = tokenize(lines[i])
if btoks[0].type == 'IDENT':
@@ -775,7 +787,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
if found_var: vars[found_var] = block_assigns[found_var] = _const(dtypes.bool, False)
for loop_i in range(start_val, end_val + 1):
subst_lines = [_subst_loop_var(bl, loop_var, loop_i) for bl in body_lines if not (has_break and bl.strip().lower() == 'break')]
_, iter_assigns, _ = parse_block(subst_lines, 0, {**vars, **block_assigns}, funcs, assigns)
_, iter_assigns, _ = parse_block(subst_lines, 0, vars, funcs, assigns)
if has_break:
assert found_var is not None
found = block_assigns.get(found_var, vars.get(found_var))
@@ -791,7 +803,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
if bl_l.startswith('if ') and bl_l.endswith(' then'):
if any(body_lines[k].strip().lower() == 'break' for k in range(j+1, len(body_lines))):
cond_str = _subst_loop_var(bl.strip()[3:-5].strip(), loop_var, loop_i)
cond = _to_bool(parse_expr(cond_str, {**vars, **block_assigns}, funcs))
cond = _to_bool(parse_expr(cond_str, vars, funcs))
block_assigns[found_var] = vars[found_var] = not_found.where(cond, found)
break
else:
@@ -806,25 +818,17 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
# lambda definition
if first != '{' and '=' in line and 'lambda' in line and any(t.type == 'IDENT' and t.val == 'lambda' for t in toks):
name = toks[0].val
body_start, depth = line[line.find('(', line.find('lambda')):], 0
params_end = 0
for j, ch in enumerate(body_start):
if ch == '(': depth += 1
elif ch == ')':
depth -= 1
if depth == 0: params_end = j + 1; break
body_start = line[line.find('(', line.find('lambda')):]
params_end = _find_paren_end(body_start) + 1
params = [p.strip() for p in body_start[1:params_end-1].split(',') if p.strip()]
rest = body_start[params_end:].strip()
if rest.startswith('('):
depth, body_end = 1, 1
for j, ch in enumerate(rest[1:], 1):
if ch == '(': depth += 1
elif ch == ')':
depth -= 1
if depth == 0: body_end = j; break
body = rest[1:body_end].strip()
if depth > 0:
body_lines_lst = [rest[1:]]
body_end = _find_paren_end(rest)
if body_end < len(rest): # found matching paren on same line
body = rest[1:body_end].strip()
i += 1
else: # multiline body
body_lines_lst, depth = [rest[1:]], 1
i += 1
while i < len(lines) and depth > 0:
for j, ch in enumerate(lines[i]):
@@ -835,21 +839,20 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
else: body_lines_lst.append(lines[i])
i += 1
body = '\n'.join(body_lines_lst).strip()
else: i += 1
vars[name] = ('lambda', params, body)
continue
# MEM assignment: MEM[addr].type (+|-)?= value
if first == 'mem' and toks[1].type == 'LBRACKET':
j, addr_toks = _match_bracket(toks, 1)
addr = parse_tokens(addr_toks, ctx(), funcs)
addr = parse_tokens(addr_toks, vars, funcs)
if j < len(toks) and toks[j].type == 'DOT': j += 1
dt_name = toks[j].val if j < len(toks) and toks[j].type == 'IDENT' else 'u32'
dt, j = DTYPES.get(dt_name, dtypes.uint32), j + 1
compound_op = None
if j < len(toks) and toks[j].type == 'ASSIGN_OP': compound_op = toks[j].val; j += 1
elif j < len(toks) and toks[j].type == 'EQUALS': j += 1
rhs = parse_tokens(toks[j:], ctx(), funcs)
rhs = parse_tokens(toks[j:], vars, funcs)
if compound_op:
mem = vars.get('_vmem') if '_vmem' in vars else vars.get('_lds')
if isinstance(mem, UOp):
@@ -868,7 +871,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
if j < len(toks) and toks[j].type == 'LBRACKET':
j, reg_toks = _match_bracket(toks, j)
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
ln, rg, val = parse_tokens(lane_toks, ctx(), funcs), parse_tokens(reg_toks, ctx(), funcs), parse_tokens(toks[j:], ctx(), funcs)
ln, rg, val = parse_tokens(lane_toks, vars, funcs), parse_tokens(reg_toks, vars, funcs), parse_tokens(toks[j:], vars, funcs)
if assigns is not None: assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(32) + _to_u32(ln), val)))
i += 1; continue
@@ -884,7 +887,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
j += 3
if j < len(toks) and toks[j].type == 'RBRACE': j += 1
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
val = parse_tokens(toks[j:], ctx(), funcs)
val = parse_tokens(toks[j:], vars, funcs)
lo_dt, hi_dt = DTYPES.get(lo_type, dtypes.uint64), DTYPES.get(hi_type, dtypes.uint32)
lo_bits = 64 if lo_dt in (dtypes.uint64, dtypes.int64) else 32
lo_val = val.cast(lo_dt) if val.dtype.itemsize * 8 <= lo_bits else (val & _const(val.dtype, (1 << lo_bits) - 1)).cast(lo_dt)
@@ -894,7 +897,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
if assigns is not None: assigns.extend([(f'{lo_var}.{lo_type}', lo_val), (f'{hi_var}.{hi_type}', hi_val)])
i += 1; continue
# Bit slice: var[hi:lo] = value or var.type[hi:lo] = value
# Bit slice/index: var[hi:lo] = value, var.type[hi:lo] = value, or var[expr] = value
if len(toks) >= 5 and toks[0].type == 'IDENT' and (toks[1].type == 'LBRACKET' or (toks[1].type == 'DOT' and toks[3].type == 'LBRACKET')):
bracket_start = 2 if toks[1].type == 'LBRACKET' else 4
j = bracket_start
@@ -902,24 +905,33 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
while j < len(toks) and toks[j].type != 'RBRACKET':
if toks[j].type == 'COLON': colon_pos = j
j += 1
if colon_pos is not None:
var = toks[0].val
if colon_pos is not None: # bit slice: var[hi:lo]
hi_str = ' '.join(t.val for t in toks[bracket_start:colon_pos] if t.type != 'EOF')
lo_str = ' '.join(t.val for t in toks[colon_pos+1:j] if t.type != 'EOF')
try:
hi, lo = max(int(eval(hi_str)), int(eval(lo_str))), min(int(eval(hi_str)), int(eval(lo_str)))
var = toks[0].val
hi_val, lo_val = int(eval(hi_str)), int(eval(lo_str))
hi, lo = max(hi_val, lo_val), min(hi_val, lo_val)
j += 1
if j < len(toks) and toks[j].type == 'DOT': j += 2
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
val = parse_tokens(toks[j:], ctx(), funcs)
val = parse_tokens(toks[j:], vars, funcs)
dt_suffix = toks[2].val if toks[1].type == 'DOT' else None
if assigns is not None: assigns.append((f'{var}[{hi}:{lo}]' + (f'.{dt_suffix}' if dt_suffix else ''), val))
if var not in vars: vars[var] = _const(dtypes.uint64 if hi >= 32 else dtypes.uint32, 0)
old = block_assigns.get(var, vars.get(var))
mask = _u32(((1 << (hi - lo + 1)) - 1) << lo)
block_assigns[var] = vars[var] = (old & (mask ^ _u32(0xFFFFFFFF))) | (_val_to_bits(val) << _u32(lo))
block_assigns[var] = vars[var] = _set_bits(old, _val_to_bits(val), hi - lo + 1, lo)
i += 1; continue
except: pass
elif toks[1].type == 'LBRACKET': # bit index: var[expr] (only for var[...], not var.type[...])
existing = block_assigns.get(var, vars.get(var))
if existing is not None and isinstance(existing, UOp) and not any(f'{var}{k}' in vars or f'{var}{k}' in block_assigns for k in range(8)):
bit_toks = toks[2:j]
j += 1
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
if j < len(toks):
block_assigns[var] = vars[var] = _set_bit(existing, _to_u32(parse_tokens(bit_toks, vars, funcs)), parse_tokens(toks[j+1:], vars, funcs))
i += 1; continue
# Array element: var{idx} = value
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACE' and toks[2].type == 'NUM':
@@ -927,7 +939,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
j = 4
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
if j < len(toks):
val = parse_tokens(toks[j+1:], ctx(), funcs)
val = parse_tokens(toks[j+1:], vars, funcs)
existing = block_assigns.get(var, vars.get(var))
if existing is not None and isinstance(existing, UOp):
block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val)
@@ -936,121 +948,103 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
i += 1; continue
# Compound assignment: var += or var -=
for j, t in enumerate(toks):
if t.type == 'ASSIGN_OP':
assign_op = next((j for j, t in enumerate(toks) if t.type == 'ASSIGN_OP'), None)
if assign_op is not None:
var = toks[0].val
old = block_assigns.get(var, vars.get(var, _u32(0)))
rhs = parse_tokens(toks[assign_op+1:], vars, funcs)
if rhs.dtype != old.dtype: rhs = rhs.cast(old.dtype)
block_assigns[var] = vars[var] = (old + rhs) if toks[assign_op].val == '+=' else (old - rhs)
i += 1; continue
# Typed element: var.type[idx] = value
if len(toks) >= 7 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET' and toks[4].type == 'NUM':
var, dt_name, idx = toks[0].val, toks[2].val, int(toks[4].val)
dt = DTYPES.get(dt_name, dtypes.uint32)
j = 6
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
if j < len(toks):
val, old = parse_tokens(toks[j+1:], vars, funcs), block_assigns.get(var, vars.get(var, _u32(0)))
bw = dt.itemsize * 8
block_assigns[var] = vars[var] = _set_bits(old, val, bw, idx * bw)
if assigns is not None: assigns.append((f'{var}.{dt_name}[{idx}]', val))
i += 1; continue
# Dynamic bit: var.type[expr_with_brackets] = value
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET':
j, depth, has_inner = 4, 1, False
while j < len(toks) and depth > 0:
if toks[j].type == 'LBRACKET': depth += 1; has_inner = True
elif toks[j].type == 'RBRACKET': depth -= 1
j += 1
if has_inner:
var = toks[0].val
old = block_assigns.get(var, vars.get(var, _u32(0)))
rhs = parse_tokens(toks[j+1:], ctx(), funcs)
if rhs.dtype != old.dtype: rhs = rhs.cast(old.dtype)
block_assigns[var] = vars[var] = (old + rhs) if t.val == '+=' else (old - rhs)
i += 1; break
else:
# Typed element: var.type[idx] = value
if len(toks) >= 7 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET' and toks[4].type == 'NUM':
var, dt_name, idx = toks[0].val, toks[2].val, int(toks[4].val)
dt = DTYPES.get(dt_name, dtypes.uint32)
j = 6
bit_pos = _to_u32(parse_tokens(toks[4:j-1], vars, funcs))
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
if j < len(toks):
val, old = parse_tokens(toks[j+1:], ctx(), funcs), block_assigns.get(var, vars.get(var, _u32(0)))
bw, lo_bit = dt.itemsize * 8, idx * dt.itemsize * 8
mask = _u32(((1 << bw) - 1) << lo_bit)
block_assigns[var] = vars[var] = (old & (mask ^ _u32(0xFFFFFFFF))) | (((val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val) & _u32((1 << bw) - 1)) << _u32(lo_bit))
if assigns is not None: assigns.append((f'{var}.{dt_name}[{idx}]', val))
val = parse_tokens(toks[j+1:], vars, funcs)
old = block_assigns.get(var, vars.get(var, _u32(0)))
block_assigns[var] = vars[var] = _set_bit(old, bit_pos, val)
i += 1; continue
# Dynamic bit: var.type[expr_with_brackets] = value
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET':
j, depth, has_inner = 4, 1, False
while j < len(toks) and depth > 0:
if toks[j].type == 'LBRACKET': depth += 1; has_inner = True
elif toks[j].type == 'RBRACKET': depth -= 1
j += 1
if has_inner:
var = toks[0].val
bit_pos = _to_u32(parse_tokens(toks[4:j-1], ctx(), funcs))
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
if j < len(toks):
val = parse_tokens(toks[j+1:], ctx(), funcs)
old, mask = block_assigns.get(var, vars.get(var, _u32(0))), _u32(1) << bit_pos
block_assigns[var] = vars[var] = (old | mask) if val.op == Ops.CONST and val.arg == 1 else \
(old & (mask ^ _u32(0xFFFFFFFF))) if val.op == Ops.CONST and val.arg == 0 else _set_bit(old, bit_pos, val)
i += 1; continue
# Bit index: var[expr] = value (bit assignment to existing scalar)
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACKET':
var = toks[0].val
existing = block_assigns.get(var, vars.get(var))
if existing is not None and isinstance(existing, UOp) and not any(f'{var}{k}' in vars or f'{var}{k}' in block_assigns for k in range(8)):
j = 2
while j < len(toks) and toks[j].type != 'RBRACKET': j += 1
bit_toks = toks[2:j]
j += 1
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
if j < len(toks):
block_assigns[var] = vars[var] = _set_bit(existing, _to_u32(parse_tokens(bit_toks, ctx(), funcs)), parse_tokens(toks[j+1:], ctx(), funcs))
i += 1; continue
# If/elsif/else - skip branches with statically false conditions (WAVE32/WAVE64)
if first == 'if':
def parse_cond(s, kw):
ll = s.lower()
return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), ctx(), funcs))
def not_static_false(c): return c.op != Ops.CONST or c.arg is not False
cond = parse_cond(line, 'if')
conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not_static_false(cond) else []
else_branch: tuple[UOp | None, dict[str, VarVal]] = (None, {})
vars_snap = dict(vars)
i += 1
i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
if conditions: conditions[0] = (cond, ret if ret is not None else branch)
vars.clear(); vars.update(vars_snap)
while i < len(lines):
ltoks = tokenize(lines[i])
if ltoks[0].type != 'IDENT': break
lf = ltoks[0].val.lower()
if lf == 'elsif':
c = parse_cond(lines[i], 'elsif')
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
if not_static_false(c): conditions.append((c, ret if ret is not None else branch))
vars.clear(); vars.update(vars_snap)
elif lf == 'else':
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
else_branch = (ret, branch)
vars.clear(); vars.update(vars_snap)
elif lf == 'endif': i += 1; break
else: break
# Check if any branch returned a value (lambda-style)
if any(isinstance(br, UOp) for _, br in conditions):
result = else_branch[0]
for c, rv in reversed(conditions):
if isinstance(rv, UOp) and isinstance(result, UOp):
if rv.dtype != result.dtype and rv.dtype.itemsize == result.dtype.itemsize: result = result.cast(rv.dtype)
result = c.where(rv, result)
return i, block_assigns, result
# Main style: merge variable assignments with WHERE
else_assigns = else_branch[1]
all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
for var in all_vars:
res: Any = else_assigns.get(var, block_assigns.get(var, vars.get(var, _u32(0))))
for cond, ba in reversed(conditions):
if isinstance(ba, dict) and var in ba:
tv = ba[var]
if isinstance(tv, UOp) and isinstance(res, UOp):
res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
block_assigns[var] = vars[var] = res
continue
# Regular assignment: var = value
for j, t in enumerate(toks):
if t.type == 'EQUALS':
if any(toks[k].type == 'OP' and toks[k].val in ('<', '>', '!', '=') for k in range(j)): break
base_var = toks[0].val
block_assigns[base_var] = vars[base_var] = parse_tokens(toks[j+1:], ctx(), funcs)
i += 1; break
else: i += 1
# If/elsif/else - skip branches with statically false conditions (WAVE32/WAVE64)
if first == 'if':
def parse_cond(s, kw):
ll = s.lower()
return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), vars, funcs))
def not_static_false(c): return c.op != Ops.CONST or c.arg is not False
cond = parse_cond(line, 'if')
conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not_static_false(cond) else []
else_branch: tuple[UOp | None, dict[str, VarVal]] = (None, {})
vars_snap = dict(vars)
i += 1
i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
if conditions: conditions[0] = (cond, ret if ret is not None else branch)
vars.clear(); vars.update(vars_snap)
while i < len(lines):
ltoks = tokenize(lines[i])
if ltoks[0].type != 'IDENT': break
lf = ltoks[0].val.lower()
if lf == 'elsif':
c = parse_cond(lines[i], 'elsif')
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
if not_static_false(c): conditions.append((c, ret if ret is not None else branch))
vars.clear(); vars.update(vars_snap)
elif lf == 'else':
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
else_branch = (ret, branch)
vars.clear(); vars.update(vars_snap)
elif lf == 'endif': i += 1; break
else: break
# Check if any branch returned a value (lambda-style)
if any(isinstance(br, UOp) for _, br in conditions):
result = else_branch[0]
for c, rv in reversed(conditions):
if isinstance(rv, UOp) and isinstance(result, UOp):
if rv.dtype != result.dtype and rv.dtype.itemsize == result.dtype.itemsize: result = result.cast(rv.dtype)
result = c.where(rv, result)
return i, block_assigns, result
# Main style: merge variable assignments with WHERE
else_assigns = else_branch[1]
all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
for var in all_vars:
res: Any = else_assigns.get(var, block_assigns.get(var, vars.get(var, _u32(0))))
for cond, ba in reversed(conditions):
if isinstance(ba, dict) and var in ba:
tv = ba[var]
if isinstance(tv, UOp) and isinstance(res, UOp):
res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
block_assigns[var] = vars[var] = res
continue
# Regular assignment: var = value
for j, t in enumerate(toks):
if t.type == 'EQUALS':
if any(toks[k].type == 'OP' and toks[k].val in ('<', '>', '!', '=') for k in range(j)): break
base_var = toks[0].val
block_assigns[base_var] = vars[base_var] = parse_tokens(toks[j+1:], vars, funcs)
i += 1; break
else: i += 1
return i, block_assigns, None
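# Illustrative sketch (not part of this diff): the merge above folds an
# if/elsif/else chain into nested WHERE selects, first match wins. The helper
# name below is hypothetical.
def merge_branches(conditions, else_val):
  # conditions: list of (cond_uop, value_uop); fold in reverse so the first
  # condition ends up outermost: c1.where(v1, c2.where(v2, else_val))
  res = else_val
  for cond, val in reversed(conditions):
    res = cond.where(val, res)
  return res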
def parse_expr(expr: str, vars: dict[str, VarVal], funcs: dict | None = None) -> UOp:

View File

@@ -138,6 +138,50 @@ class TestDS2AddrMore(unittest.TestCase):
self.assertEqual(st.vgpr[0][4], 0x12345678, "v4 should be untouched")
class TestDSB96(unittest.TestCase):
"""Tests for DS_STORE_B96 and DS_LOAD_B96 (96-bit / 3 dwords)."""
def test_ds_store_load_b96(self):
"""DS_STORE_B96 stores 3 VGPRs, DS_LOAD_B96 loads them back."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[0], 0x11111111),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[0], 0x22222222),
v_mov_b32_e32(v[1], s[0]),
s_mov_b32(s[0], 0x33333333),
v_mov_b32_e32(v[2], s[0]),
ds_store_b96(addr=v[10], data0=v[0:2]),
s_waitcnt(lgkmcnt=0),
ds_load_b96(addr=v[10], vdst=v[4:6]),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 0x11111111, "v4 should have first dword")
self.assertEqual(st.vgpr[0][5], 0x22222222, "v5 should have second dword")
self.assertEqual(st.vgpr[0][6], 0x33333333, "v6 should have third dword")
def test_ds_store_b96_with_offset(self):
"""DS_STORE_B96 with non-zero offset."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[0], 0xAAAAAAAA),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[0], 0xBBBBBBBB),
v_mov_b32_e32(v[1], s[0]),
s_mov_b32(s[0], 0xCCCCCCCC),
v_mov_b32_e32(v[2], s[0]),
DS(DSOp.DS_STORE_B96, addr=v[10], data0=v[0:2], offset0=12),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_B96, addr=v[10], vdst=v[4:6], offset0=12),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 0xAAAAAAAA)
self.assertEqual(st.vgpr[0][5], 0xBBBBBBBB)
self.assertEqual(st.vgpr[0][6], 0xCCCCCCCC)
class TestDSB128(unittest.TestCase):
"""Tests for DS_STORE_B128 and DS_LOAD_B128 (128-bit / 4 dwords)."""

View File

@@ -265,6 +265,113 @@ class TestSLoadMultiDword(unittest.TestCase):
self.assertEqual(st.sgpr[5], st.sgpr[9])
class TestSLoadLarge(unittest.TestCase):
"""Tests for large s_load operations (s_load_b256, s_load_b512)."""
def test_s_load_b256_basic(self):
"""s_load_b256 loads 8 consecutive dwords."""
instructions = [
s_load_b64(s[2:3], s[80:81], 0, soffset=NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
# Store 8 test values
s_mov_b32(s[20], 0x11111111),
v_mov_b32_e32(v[2], s[20]),
global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET),
s_mov_b32(s[20], 0x22222222),
v_mov_b32_e32(v[2], s[20]),
global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+4),
s_mov_b32(s[20], 0x33333333),
v_mov_b32_e32(v[2], s[20]),
global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+8),
s_mov_b32(s[20], 0x44444444),
v_mov_b32_e32(v[2], s[20]),
global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+12),
s_mov_b32(s[20], 0x55555555),
v_mov_b32_e32(v[2], s[20]),
global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+16),
s_mov_b32(s[20], 0x66666666),
v_mov_b32_e32(v[2], s[20]),
global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+20),
s_mov_b32(s[20], 0x77777777),
v_mov_b32_e32(v[2], s[20]),
global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+24),
s_mov_b32(s[20], 0x88888888),
v_mov_b32_e32(v[2], s[20]),
global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET+28),
s_waitcnt(vmcnt=0),
*CACHE_INV,
# Load all 8 dwords with s_load_b256
s_load_b256(s[4:11], s[2:3], NULL, offset=TEST_OFFSET),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 0), s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[4], 0x11111111)
self.assertEqual(st.sgpr[5], 0x22222222)
self.assertEqual(st.sgpr[6], 0x33333333)
self.assertEqual(st.sgpr[7], 0x44444444)
self.assertEqual(st.sgpr[8], 0x55555555)
self.assertEqual(st.sgpr[9], 0x66666666)
self.assertEqual(st.sgpr[10], 0x77777777)
self.assertEqual(st.sgpr[11], 0x88888888)
def test_s_load_b512_basic(self):
"""s_load_b512 loads 16 consecutive dwords."""
instructions = [
s_load_b64(s[2:3], s[80:81], 0, soffset=NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
# Store 16 test values (pattern: (i+1) * 0x11111111, wrapping modulo 2^32)
*[instr for i in range(16) for instr in [
s_mov_b32(s[20], (i + 1) * 0x11111111),
v_mov_b32_e32(v[2], s[20]),
global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET + i * 4),
]],
s_waitcnt(vmcnt=0),
*CACHE_INV,
# Load all 16 dwords with s_load_b512
s_load_b512(s[64:79], s[2:3], NULL, offset=TEST_OFFSET),
s_waitcnt(lgkmcnt=0),
# Copy results to lower regs for verification (since st.sgpr only has 16 regs in test)
s_mov_b32(s[4], s[64]),
s_mov_b32(s[5], s[65]),
s_mov_b32(s[6], s[78]),
s_mov_b32(s[7], s[79]),
s_mov_b32(s[2], 0), s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[4], 0x11111111, "first dword")
self.assertEqual(st.sgpr[5], 0x22222222, "second dword")
self.assertEqual(st.sgpr[6], 0xFFFFFFFF & (15 * 0x11111111), "15th dword")
self.assertEqual(st.sgpr[7], 0xFFFFFFFF & (16 * 0x11111111), "16th dword")
def test_s_load_b256_with_register_offset(self):
"""s_load_b256 with register offset should add reg offset to address."""
instructions = [
s_load_b64(s[2:3], s[80:81], 0, soffset=NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
# Store pattern at TEST_OFFSET+8: skip first 2 dwords
*[instr for i in range(8) for instr in [
s_mov_b32(s[20], (i + 1) * 0x11111111),
v_mov_b32_e32(v[2], s[20]),
global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET + 8 + i * 4),
]],
s_waitcnt(vmcnt=0),
*CACHE_INV,
# Load with register offset 8
s_mov_b32(s[20], 8),
s_load_b256(s[4:11], s[2:3], s[20], offset=TEST_OFFSET),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 0), s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[4], 0x11111111, "first dword at offset+8")
self.assertEqual(st.sgpr[5], 0x22222222, "second dword at offset+8")
self.assertEqual(st.sgpr[11], 0x88888888, "last dword at offset+8")
class TestSLoadOffset(unittest.TestCase):
"""Tests for s_load with different immediate offsets.

View File

@@ -719,5 +719,172 @@ class TestNullRegister(unittest.TestCase):
self.assertEqual(st.scc, 0)
class Test64BitSOP1InlineConstants(unittest.TestCase):
"""Tests for 64-bit SOP1 instructions with inline constants.
Regression tests for bug where rsrc_dyn didn't properly handle 64-bit
inline constants, incorrectly duplicating lo bits to hi instead of
zero/sign-extending.
"""
def test_s_mov_b64_inline_0(self):
"""S_MOV_B64 with inline constant 0."""
instructions = [
s_mov_b64(s[0:1], 0),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 0)
self.assertEqual(st.vgpr[0][1], 0)
def test_s_mov_b64_inline_16(self):
"""S_MOV_B64 with inline constant 16 should set lo=16, hi=0."""
instructions = [
s_mov_b64(s[0:1], 16),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 16)
self.assertEqual(st.vgpr[0][1], 0)
def test_s_mov_b64_inline_64(self):
"""S_MOV_B64 with inline constant 64 (max positive)."""
instructions = [
s_mov_b64(s[0:1], 64),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 64)
self.assertEqual(st.vgpr[0][1], 0)
def test_s_mov_b64_inline_neg1(self):
"""S_MOV_B64 with inline constant -1 should sign-extend."""
instructions = [
s_mov_b64(s[0:1], -1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 0xFFFFFFFF)
self.assertEqual(st.vgpr[0][1], 0xFFFFFFFF)
def test_s_mov_b64_inline_neg16(self):
"""S_MOV_B64 with inline constant -16 should sign-extend."""
instructions = [
s_mov_b64(s[0:1], -16),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 0xFFFFFFF0)
self.assertEqual(st.vgpr[0][1], 0xFFFFFFFF)
def test_s_mov_b64_float_const_1_0(self):
"""S_MOV_B64 with float inline constant 1.0 - casts F32 to F64."""
instructions = [
s_mov_b64(s[0:1], 1.0), # inline constant 242 (1.0f)
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
]
st = run_program(instructions, n_lanes=1)
# Hardware casts F32 to F64: 1.0f64 = 0x3FF0000000000000
self.assertEqual(st.vgpr[0][0], 0x00000000) # lo
self.assertEqual(st.vgpr[0][1], 0x3FF00000) # hi
def test_s_or_b64_inline_constant(self):
"""S_OR_B64 with 64-bit inline constant."""
instructions = [
s_mov_b64(s[0:1], 0),
s_or_b64(s[2:3], s[0:1], 16),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 16)
self.assertEqual(st.vgpr[0][1], 0)
def test_s_and_b64_inline_constant(self):
"""S_AND_B64 with 64-bit inline constant."""
instructions = [
s_mov_b32(s[0], 0xFFFFFFFF),
s_mov_b32(s[1], 0xFFFFFFFF),
s_and_b64(s[2:3], s[0:1], 16),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 16)
self.assertEqual(st.vgpr[0][1], 0)
class Test64BitSOPLiterals(unittest.TestCase):
"""Tests for 64-bit SOP instructions with 32-bit literals.
Tests the behavior when a 64-bit SOP instruction uses a 32-bit literal
(source-operand encoding 255). The literal is zero-extended to 64 bits.
"""
def test_s_mov_b64_literal(self):
"""S_MOV_B64 with 32-bit literal value - zero-extended to 64 bits."""
instructions = [
s_mov_b64(s[0:1], 0x12345678), # outside inline-constant range (-16..64), uses literal encoding
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 0x12345678)
self.assertEqual(st.vgpr[0][1], 0)
def test_s_or_b64_literal(self):
"""S_OR_B64 with 32-bit literal value - zero-extended to 64 bits."""
instructions = [
s_mov_b64(s[0:1], 0),
s_or_b64(s[2:3], s[0:1], 0x12345678), # literal
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 0x12345678)
self.assertEqual(st.vgpr[0][1], 0)
def test_s_and_b64_literal(self):
"""S_AND_B64 with 32-bit literal value - zero-extended to 64 bits."""
instructions = [
s_mov_b32(s[0], 0xFFFFFFFF),
s_mov_b32(s[1], 0xFFFFFFFF),
s_and_b64(s[2:3], s[0:1], 0x12345678), # literal
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 0x12345678)
self.assertEqual(st.vgpr[0][1], 0)
def test_s_mov_b64_literal_negative(self):
"""S_MOV_B64 with 0xFFFFFFFF literal - zero-extended (not sign-extended)."""
instructions = [
s_mov_b64(s[0:1], 0xFFFFFFFF), # -1 as 32-bit, but zero-extended to 64-bit
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 0xFFFFFFFF)
self.assertEqual(st.vgpr[0][1], 0) # zero-extended, not sign-extended
def test_s_mov_b64_literal_high_bit(self):
"""S_MOV_B64 with 0x80000000 literal - zero-extended (not sign-extended)."""
instructions = [
s_mov_b64(s[0:1], 0x80000000), # high bit set, but zero-extended
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 0x80000000)
self.assertEqual(st.vgpr[0][1], 0) # zero-extended, not sign-extended
if __name__ == '__main__':
unittest.main()

View File

@@ -1359,6 +1359,43 @@ class TestF64ToI64Conversion(unittest.TestCase):
self.assertEqual(result, 5000000000)
class TestB64VOPLiteral(unittest.TestCase):
"""Tests for B64 VOP operations with literal encoding.
B64 operations (like V_LSHLREV_B64) should zero-extend the literal to 64 bits,
NOT put it in the high 32 bits like F64 operations do.
"""
def test_v_lshlrev_b64_literal_shift_amount(self):
"""V_LSHLREV_B64 with literal shift amount (src0 is 32-bit)."""
# Shift 1 left by 100 (0x64) - uses literal encoding for src0
# Shift amount is 100 & 63 = 36, so 1 << 36 = 0x1000000000
instructions = [
s_mov_b32(s[0], 1),
s_mov_b32(s[1], 0),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_lshlrev_b64(v[2:3], 100, v[0:1]), # 100 exceeds inline-constant range, uses literal encoding
]
st = run_program(instructions, n_lanes=1)
# lo = 0x00000000, hi = 0x00000010 = 1 << (36-32)
self.assertEqual(st.vgpr[0][2], 0x00000000)
self.assertEqual(st.vgpr[0][3], 0x00000010)
def test_v_lshlrev_b64_literal_value(self):
"""V_LSHLREV_B64 with literal as the 64-bit value being shifted (src1).
B64 literals are zero-extended (not shifted to high bits like F64).
0xDEADBEEF << 4 = 0xDEADBEEF0 = lo=0xEADBEEF0, hi=0x0000000D
"""
instructions = [
v_lshlrev_b64(v[0:1], 4, 0xDEADBEEF), # shift literal left by 4
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 0xEADBEEF0) # lo
self.assertEqual(st.vgpr[0][1], 0x0000000D) # hi
class TestWMMAMore(unittest.TestCase):
"""More WMMA tests."""