mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
assembly/amd: continue refactors (#14386)
* simpler * merge * flat * no ctx * use the correct apis * dup code * write clean code * remove bad helpers * bits junk remove * junk remove * smem test * fix tests * correct fix + tests * Fmt matters it seems * wmma refactor * a lil more * kimi cleanups * line
This commit is contained in:
@@ -253,6 +253,15 @@ def _get_variant(cls, suffix: str):
|
||||
module = sys.modules.get(cls.__module__)
|
||||
return getattr(module, f"{cls.__name__}{suffix}", None) if module else None
|
||||
|
||||
def _canonical_name(name: str) -> str | None:
|
||||
"""Map operand name to canonical name."""
|
||||
if name in ('src0', 'vsrc0', 'ssrc0'): return 's0'
|
||||
if name in ('src1', 'vsrc1', 'ssrc1'): return 's1'
|
||||
if name == 'src2': return 's2'
|
||||
if name in ('vdst', 'sdst', 'sdata'): return 'd'
|
||||
if name in ('data', 'vdata', 'data0', 'vsrc'): return 'data'
|
||||
return None
|
||||
|
||||
class Inst:
|
||||
_fields: list[tuple[str, BitField]]
|
||||
_base_size: int
|
||||
@@ -368,12 +377,17 @@ class Inst:
|
||||
"""Get bit widths with canonical names: {'s0', 's1', 's2', 'd', 'data'}."""
|
||||
bits = {'d': 32, 's0': 32, 's1': 32, 's2': 32, 'data': 32}
|
||||
for name, val in self.op_bits.items():
|
||||
if name in ('src0', 'vsrc0', 'ssrc0'): bits['s0'] = val
|
||||
elif name in ('src1', 'vsrc1', 'ssrc1'): bits['s1'] = val
|
||||
elif name == 'src2': bits['s2'] = val
|
||||
elif name in ('vdst', 'sdst', 'sdata'): bits['d'] = val
|
||||
elif name in ('data', 'vdata', 'data0', 'vsrc'): bits['data'] = val
|
||||
if (cn := _canonical_name(name)): bits[cn] = val
|
||||
return bits
|
||||
|
||||
@functools.cached_property
|
||||
def canonical_operands(self) -> dict:
|
||||
"""Get operands with canonical names: {'s0', 's1', 's2', 'd', 'data'}."""
|
||||
result = {}
|
||||
for name, val in self.operands.items():
|
||||
if (cn := _canonical_name(name)): result[cn] = val
|
||||
return result
|
||||
|
||||
@property
|
||||
def canonical_op_regs(self) -> dict[str, int]:
|
||||
"""Get register counts with canonical names: {'s0', 's1', 's2', 'd', 'data'}."""
|
||||
|
||||
@@ -53,7 +53,7 @@ from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP1_SDST, VOP2, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC,
|
||||
DS, FLAT, GLOBAL, SCRATCH, VOPD, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOPDOp)
|
||||
from extra.assembly.amd.dsl import VCC_LO, EXEC_LO, SCC
|
||||
from extra.assembly.amd.autogen.common import OpType
|
||||
from extra.assembly.amd.autogen.common import Fmt, OpType
|
||||
from extra.assembly.amd.pcode import parse_block, _FUNCS
|
||||
|
||||
MASK32 = 0xFFFFFFFF
|
||||
@@ -70,14 +70,14 @@ def _split64(val: UOp) -> tuple[UOp, UOp]:
|
||||
return v64.cast(dtypes.uint32), (v64 >> UOp.const(dtypes.uint64, 32)).cast(dtypes.uint32)
|
||||
|
||||
_SRC_MOD_TYPES = {16: (dtypes.uint16, dtypes.half, 0x7FFF), 64: (dtypes.uint64, dtypes.float64, 0x7FFFFFFFFFFFFFFF), 32: (dtypes.uint32, dtypes.float32, 0x7FFFFFFF)}
|
||||
def _apply_src_mods(val: UOp, mod_bit: int, abs_bits: int, neg_bits: int, is_16bit: bool = False, is_64bit: bool = False) -> UOp:
|
||||
"""Apply abs/neg modifiers to source value based on operation type."""
|
||||
def _apply_src_mods(val: UOp, mod_bit: int, abs_bits: int, neg_bits: int, bits: int = 32) -> UOp:
|
||||
"""Apply abs/neg modifiers to source value based on bit width (16, 32, or 64)."""
|
||||
if not (abs_bits & (1 << mod_bit)) and not (neg_bits & (1 << mod_bit)): return val
|
||||
ut, ft, mask = _SRC_MOD_TYPES[16 if is_16bit else 64 if is_64bit else 32]
|
||||
fv = val.cast(ut).bitcast(ft) if is_16bit else val.bitcast(ft) if val.dtype == ut else val
|
||||
ut, ft, mask = _SRC_MOD_TYPES[bits]
|
||||
fv = val.cast(ut).bitcast(ft) if bits == 16 else val.bitcast(ft) if val.dtype == ut else val
|
||||
if abs_bits & (1 << mod_bit): fv = (fv.bitcast(ut) & UOp.const(ut, mask)).bitcast(ft)
|
||||
if neg_bits & (1 << mod_bit): fv = fv.neg()
|
||||
return fv.bitcast(ut).cast(dtypes.uint32) if is_16bit else fv.bitcast(ut)
|
||||
return fv.bitcast(ut).cast(dtypes.uint32) if bits == 16 else fv.bitcast(ut)
|
||||
|
||||
# Map VOPD ops to VOP2 ops for pcode lookup
|
||||
VOPD_TO_VOP2 = {
|
||||
@@ -95,11 +95,10 @@ PC_LO_IDX, PC_HI_IDX, SCRATCH_STRIDE_IDX = 256, 257, 259
|
||||
# SGPR buffer: 0-127 = SGPRs, 128-255 = inline constants, 256-259 = special registers
|
||||
SGPR_COUNT, VGPR_SIZE = 260, 256 * 32
|
||||
|
||||
def _is_16bit_op(op_name: str) -> bool: return any(x in op_name for x in ('B16', 'F16', 'I16', 'U16'))
|
||||
def _op_name(inst) -> str:
|
||||
if hasattr(inst, 'opx'): return f"{inst.opx.name}_{inst.opy.name}" # VOPD has opx/opy not op
|
||||
return inst.op.name if hasattr(inst.op, 'name') else str(inst.op)
|
||||
def _is_64bit_dest(dest: str) -> bool: return any(dest.endswith(x) for x in ('.b64', '.u64', '.i64', '.f64'))
|
||||
|
||||
def _to_u32(val: UOp) -> UOp:
|
||||
if val.dtype == dtypes.uint32: return val
|
||||
if val.dtype.itemsize == 4: return val.bitcast(dtypes.uint32) # same size: bitcast (float32->uint32)
|
||||
@@ -192,9 +191,9 @@ def _write_64bit(val: UOp, wfn, reg_or_addr, is_mem: bool, *args) -> list[UOp]:
|
||||
incr = 4 if is_mem else 1 # 4 bytes for memory addresses, 1 for register indices
|
||||
return [wfn(reg_or_addr, lo, *args), wfn(reg_or_addr + (UOp.const(reg_or_addr.dtype, incr) if isinstance(reg_or_addr, UOp) else incr), hi, *args)]
|
||||
|
||||
def _write_val(dest: str, val: UOp, wfn, reg_or_addr, *args, is_mem: bool = False) -> list[UOp]:
|
||||
"""Write value, splitting 64-bit if needed based on dest type suffix."""
|
||||
return _write_64bit(val, wfn, reg_or_addr, is_mem, *args) if _is_64bit_dest(dest) else [wfn(reg_or_addr, _to_u32(val), *args)]
|
||||
def _write_val(bits: int, val: UOp, wfn, reg_or_addr, *args, is_mem: bool = False) -> list[UOp]:
|
||||
"""Write value, splitting 64-bit if needed. bits=64 for 64-bit writes, otherwise 32-bit."""
|
||||
return _write_64bit(val, wfn, reg_or_addr, is_mem, *args) if bits == 64 else [wfn(reg_or_addr, _to_u32(val), *args)]
|
||||
|
||||
def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32, data_bits: int = 32) -> list[UOp]:
|
||||
"""Conditional memory store with sub-word support. Returns list of store UOps."""
|
||||
@@ -337,31 +336,38 @@ class _Ctx:
|
||||
offset = reg.cast(dtypes.int) * _c(32, dtypes.int) + lane.cast(dtypes.int)
|
||||
return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32))
|
||||
|
||||
def rsrc_dyn(self, off: UOp, lane: UOp, bits: int = 32, literal: UOp | None = None) -> UOp:
|
||||
"""Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256)."""
|
||||
is_vgpr, vgpr_reg = off >= _c(256), off - _c(256)
|
||||
def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False) -> UOp:
|
||||
"""Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256).
|
||||
If lane is None, only scalar access is supported (off must be < 256).
|
||||
is_f64: True for F64 operations where 64-bit literals go in high 32 bits."""
|
||||
is_float_const = (off >= _c(240)) & (off <= _c(248))
|
||||
sgpr_lo = self.rsgpr_dyn(off)
|
||||
vgpr_lo = self.rvgpr_dyn(vgpr_reg, lane, is_vgpr)
|
||||
|
||||
if lane is not None:
|
||||
is_vgpr, vgpr_reg = off >= _c(256), off - _c(256)
|
||||
vgpr_lo = self.rvgpr_dyn(vgpr_reg, lane, is_vgpr)
|
||||
vgpr_val = _u64(vgpr_lo, self.rvgpr_dyn(vgpr_reg + _c(1), lane, is_vgpr)) if bits == 64 else vgpr_lo
|
||||
|
||||
if bits == 64:
|
||||
vgpr_val = _u64(vgpr_lo, self.rvgpr_dyn(vgpr_reg + _c(1), lane, is_vgpr))
|
||||
sgpr_val = _u64(sgpr_lo, self.rsgpr_dyn(off + _c(1)))
|
||||
# Float constants: cast F32 to F64; integer inline: duplicate lo
|
||||
inline = is_float_const.where(sgpr_lo.bitcast(dtypes.float32).cast(dtypes.float64).bitcast(dtypes.uint64), _u64(sgpr_lo, sgpr_lo))
|
||||
if literal is not None: inline = off.eq(_c(255)).where(literal.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32), inline)
|
||||
# Integer inline constants: sign-extend 32-bit value from buffer to 64-bit
|
||||
# Float constants: cast F32 to F64
|
||||
int_inline = sgpr_lo.cast(dtypes.int32).cast(dtypes.int64)
|
||||
float_inline = sgpr_lo.bitcast(dtypes.float32).cast(dtypes.float64)
|
||||
# compute inline
|
||||
inline = is_float_const.where(float_inline.bitcast(dtypes.uint64), int_inline.bitcast(dtypes.uint64))
|
||||
# Literal handling: F64 VOP puts literal in high 32 bits; B64/I64/U64 VOP and SOP zero-extend
|
||||
if literal is not None:
|
||||
lit_val = literal.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32) if is_f64 else literal.cast(dtypes.uint64)
|
||||
inline = off.eq(_c(255)).where(lit_val, inline)
|
||||
scalar_val = (off < _c(128)).where(sgpr_val, inline)
|
||||
else:
|
||||
vgpr_val = vgpr_lo
|
||||
scalar_val = sgpr_lo
|
||||
if literal is not None: scalar_val = off.eq(_c(255)).where(literal, scalar_val)
|
||||
if bits == 16: # Float constants: cast F32 to F16
|
||||
scalar_val = is_float_const.where(scalar_val.bitcast(dtypes.float32).cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32), scalar_val)
|
||||
|
||||
return is_vgpr.where(vgpr_val, scalar_val)
|
||||
|
||||
def rsrc_dyn_sized(self, off: UOp, lane: UOp, sizes: dict, key: str, f16: bool = False, literal: UOp | None = None) -> UOp:
|
||||
return self.rsrc_dyn(off, lane, 64, literal) if sizes.get(key, 1) == 2 else self.rsrc_dyn(off, lane, 16 if f16 else 32, literal)
|
||||
return is_vgpr.where(vgpr_val, scalar_val) if lane is not None else scalar_val
|
||||
|
||||
def rpc(self) -> UOp:
|
||||
"""Read PC as 64-bit byte address."""
|
||||
@@ -509,36 +515,9 @@ def _compile_smem(inst: SMEM, ctx: _Ctx) -> UOp:
|
||||
return UOp.sink(*stores, *ctx.inc_pc())
|
||||
|
||||
def _compile_sop(inst: SOP1 | SOP2 | SOPC | SOPK, ctx: _Ctx) -> UOp:
|
||||
sizes = getattr(inst, 'op_regs', {})
|
||||
bits = inst.canonical_op_bits
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
|
||||
# Read source operands dynamically
|
||||
def rsrc_dyn_scalar(off: UOp, is_64bit: bool) -> UOp:
|
||||
"""Read scalar source with dynamic offset (SGPR or inline constant).
|
||||
For SOP, off is always 0-255 (SGPR or inline constant, never VGPR).
|
||||
SGPR buffer has 260 entries: 0-127=SGPRs, 128-255=inline constants, 256-259=special."""
|
||||
is_sgpr = off < _c(128)
|
||||
# For 64-bit: read SGPR pair if off < 128, else compute inline constant as 64-bit
|
||||
# (can't just read from buffer since buffer has 32-bit values)
|
||||
if is_64bit:
|
||||
sgpr_val = _u64(ctx.rsgpr_dyn(off), ctx.rsgpr_dyn(off + _c(1)))
|
||||
# Build inline constant: 128-192 = 0-64, 193-208 = -1 to -16
|
||||
inline_val = (off - _c(128)).cast(dtypes.uint64) # positive inline 0-64
|
||||
neg_val = (_c(192) - off).cast(dtypes.int64).cast(dtypes.uint64) # negative -1 to -16
|
||||
lit_val = literal.cast(dtypes.uint64) if literal is not None else UOp.const(dtypes.uint64, 0)
|
||||
# Select between sgpr, positive inline, negative inline, or literal
|
||||
is_neg_inline = (off >= _c(193)) & (off < _c(209))
|
||||
is_literal = off.eq(_c(255)) if literal is not None else UOp.const(dtypes.bool, False)
|
||||
val = is_sgpr.where(sgpr_val, is_neg_inline.where(neg_val, is_literal.where(lit_val, inline_val)))
|
||||
return val
|
||||
# 32-bit: read from SGPR buffer (inline constants 128-255 are pre-populated)
|
||||
# off is always 0-255 for SOP, all valid SGPR indices
|
||||
sgpr_val = ctx.rsgpr_dyn(off)
|
||||
# Handle literal (255) - literal value overrides the pre-populated 0
|
||||
if literal is not None:
|
||||
sgpr_val = off.eq(_c(255)).where(literal, sgpr_val)
|
||||
return sgpr_val
|
||||
|
||||
if isinstance(inst, SOPK):
|
||||
sdst_off = ctx.inst_field(SOPK.sdst)
|
||||
simm16 = ctx.inst_field(SOPK.simm16)
|
||||
@@ -549,21 +528,21 @@ def _compile_sop(inst: SOP1 | SOP2 | SOPC | SOPK, ctx: _Ctx) -> UOp:
|
||||
elif isinstance(inst, SOP1):
|
||||
sdst_off = ctx.inst_field(SOP1.sdst)
|
||||
ssrc0_off = ctx.inst_field(SOP1.ssrc0)
|
||||
srcs = {'S0': rsrc_dyn_scalar(ssrc0_off, sizes.get('ssrc0', 1) == 2)}
|
||||
dst_off, dst_size = sdst_off, sizes.get('sdst', 1)
|
||||
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal)}
|
||||
dst_off, dst_size = sdst_off, bits['d'] // 32
|
||||
elif isinstance(inst, SOP2):
|
||||
sdst_off = ctx.inst_field(SOP2.sdst)
|
||||
ssrc0_off = ctx.inst_field(SOP2.ssrc0)
|
||||
ssrc1_off = ctx.inst_field(SOP2.ssrc1)
|
||||
srcs = {'S0': rsrc_dyn_scalar(ssrc0_off, sizes.get('ssrc0', 1) == 2),
|
||||
'S1': rsrc_dyn_scalar(ssrc1_off, sizes.get('ssrc1', 1) == 2)}
|
||||
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
|
||||
'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
|
||||
if literal is not None: srcs['SIMM32'] = literal
|
||||
dst_off, dst_size = sdst_off, sizes.get('sdst', 1)
|
||||
dst_off, dst_size = sdst_off, bits['d'] // 32
|
||||
elif isinstance(inst, SOPC):
|
||||
ssrc0_off = ctx.inst_field(SOPC.ssrc0)
|
||||
ssrc1_off = ctx.inst_field(SOPC.ssrc1)
|
||||
srcs = {'S0': rsrc_dyn_scalar(ssrc0_off, sizes.get('ssrc0', 1) == 2),
|
||||
'S1': rsrc_dyn_scalar(ssrc1_off, sizes.get('ssrc1', 1) == 2)}
|
||||
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
|
||||
'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
|
||||
dst_off, dst_size = _c(0), 0 # SOPC writes to SCC, not sdst
|
||||
else:
|
||||
raise RuntimeError(f"unknown SOP type: {type(inst).__name__}")
|
||||
@@ -573,18 +552,17 @@ def _compile_sop(inst: SOP1 | SOP2 | SOPC | SOPK, ctx: _Ctx) -> UOp:
|
||||
def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
|
||||
op_name = _op_name(inst)
|
||||
if op_name == 'V_READFIRSTLANE_B32_E32': return ctx.compile_lane_pcode(inst.op, inst)
|
||||
lane, exec_mask, sizes = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset)), getattr(inst, 'op_regs', {})
|
||||
is_16bit = _is_16bit_op(op_name)
|
||||
lane, exec_mask, bits = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset)), inst.canonical_op_bits
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
vdst_reg = ctx.inst_field(VOP1.vdst)
|
||||
write_hi_half = is_16bit and (vdst_reg >= _c(128))
|
||||
write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128))
|
||||
if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg)
|
||||
elif write_hi_half: vdst_reg -= 128
|
||||
if isinstance(inst, VOP1):
|
||||
# Handle VOP1 hi-half source operand (src0 >= v[128] for 16-bit ops)
|
||||
src0_off = ctx.inst_field(VOP1.src0)
|
||||
s0 = ctx.rsrc_dyn_sized(src0_off, lane, sizes, 'src0', f16=is_16bit, literal=literal)
|
||||
if is_16bit:
|
||||
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal)
|
||||
if bits['s0'] == 16:
|
||||
src0_hi = src0_off >= _c(384)
|
||||
# Only compute hi-half when src0_off >= 384, use guarded index to prevent OOB access
|
||||
src0_reg = src0_hi.where(src0_off - _c(384), _c(0))
|
||||
@@ -592,14 +570,14 @@ def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
|
||||
srcs = {'S0': s0}
|
||||
else:
|
||||
vsrc1_reg = ctx.inst_field(VOP2.vsrc1)
|
||||
vsrc1_hi = is_16bit and (vsrc1_reg >= _c(128))
|
||||
vsrc1_hi = bits['s0'] == 16 and (vsrc1_reg >= _c(128))
|
||||
vsrc1_actual = _cond(vsrc1_hi, vsrc1_reg - _c(128), vsrc1_reg)
|
||||
s1 = _cond_hi16(vsrc1_hi, ctx.rvgpr_dyn(vsrc1_actual, lane))
|
||||
d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane)) # FMAC/FMAMK hi-half dest needs hi-half accumulator
|
||||
# Handle VOP2 hi-half src0 operand (src0 >= v[128] for 16-bit ops)
|
||||
src0_off = ctx.inst_field(VOP2.src0)
|
||||
s0 = ctx.rsrc_dyn(src0_off, lane, bits=16 if is_16bit else 32, literal=literal)
|
||||
if is_16bit:
|
||||
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal)
|
||||
if bits['s0'] == 16:
|
||||
src0_hi = src0_off >= _c(384)
|
||||
# Only compute hi-half when src0_off >= 384, use guarded index to prevent OOB access
|
||||
src0_reg = src0_hi.where(src0_off - _c(384), _c(0))
|
||||
@@ -611,41 +589,36 @@ def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half)
|
||||
|
||||
def _compile_vopc(inst: VOPC | VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
|
||||
exec_mask, op_name = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst)
|
||||
is_cmpx, is_16bit, is_64bit = 'CMPX' in op_name, _is_16bit_op(op_name), 'F64' in op_name
|
||||
is_vopc = hasattr(inst, 'vsrc1') # VOPC (e32) vs VOP3 (e64) format
|
||||
exec_mask, op_name, bits = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst), inst.canonical_op_bits
|
||||
is_cmpx, is_vopc = 'CMPX' in op_name, hasattr(inst, 'vsrc1') # is_vopc: e32 vs e64
|
||||
|
||||
# Handle both VOPC (vsrc1) and VOP3 (src1) instruction formats - read operands dynamically
|
||||
if is_vopc:
|
||||
src0_off = ctx.inst_field(VOPC.src0)
|
||||
vsrc1_off = ctx.inst_field(VOPC.vsrc1)
|
||||
# For 16-bit ops, vsrc1 >= 128 means hi-half of v[vsrc1-128]
|
||||
if is_16bit:
|
||||
if bits['s0'] == 16:
|
||||
vsrc1_hi = vsrc1_off >= _c(128)
|
||||
src1_off = _c(256) + vsrc1_hi.where(vsrc1_off - _c(128), vsrc1_off)
|
||||
else:
|
||||
vsrc1_hi = False
|
||||
src1_off = _c(256) + vsrc1_off
|
||||
src0_bits, src1_bits = (64, 64) if is_64bit else (32, 32)
|
||||
else:
|
||||
src0_off = ctx.inst_field(VOP3.src0)
|
||||
src1_off = ctx.inst_field(VOP3.src1)
|
||||
dst_off = ctx.inst_field(VOP3.vdst)
|
||||
vsrc1_hi = False
|
||||
_, src0_bits, _ = inst.operands.get('src0', (None, 32, None))
|
||||
_, src1_bits, _ = inst.operands.get('src1', (None, 32, None))
|
||||
is_16bit = src0_bits == 16 or src1_bits == 16
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
|
||||
is_float, pcode = any(x in op_name for x in ('_F32', '_F64', '_F16')), get_pcode(inst.op)
|
||||
is_float, is_f64, pcode = any(x in op_name for x in ('_F32', '_F64', '_F16')), '_F64' in op_name, get_pcode(inst.op)
|
||||
def get_cmp_bit(lane) -> UOp:
|
||||
lc = lane.cast(dtypes.int) if isinstance(lane, UOp) else _c(lane, dtypes.int)
|
||||
s0 = ctx.rsrc_dyn(src0_off, lc, src0_bits, literal)
|
||||
s1 = _cond_hi16(vsrc1_hi, ctx.rsrc_dyn(src1_off, lc, src1_bits, literal)) if is_16bit else ctx.rsrc_dyn(src1_off, lc, src1_bits, literal)
|
||||
if is_16bit and opsel: s0, s1 = _apply_opsel(s0, 0, opsel), _apply_opsel(s1, 1, opsel)
|
||||
if is_float and (abs_bits or neg_bits):
|
||||
s0 = _apply_src_mods(s0, 0, abs_bits, neg_bits, is_16bit, src0_bits == 64)
|
||||
s1 = _apply_src_mods(s1, 1, abs_bits, neg_bits, is_16bit, src1_bits == 64)
|
||||
s0 = ctx.rsrc_dyn(src0_off, lc, bits['s0'], literal, is_f64)
|
||||
s1 = _cond_hi16(vsrc1_hi, ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)) if bits['s0'] == 16 else ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)
|
||||
if bits['s0'] == 16 and opsel: s0, s1 = _apply_opsel(s0, 0, opsel), _apply_opsel(s1, 1, opsel)
|
||||
if is_float:
|
||||
s0 = _apply_src_mods(s0, 0, abs_bits, neg_bits, bits['s0'])
|
||||
s1 = _apply_src_mods(s1, 1, abs_bits, neg_bits, bits['s1'])
|
||||
for dest, val in parse_pcode(pcode, {'S0': s0, 'S1': s1, 'laneId': lc})[1]:
|
||||
if '[laneId]' in dest and ('D0' in dest or 'EXEC' in dest): return val.cast(dtypes.uint32)
|
||||
return _c(0)
|
||||
@@ -664,7 +637,7 @@ def _compile_vopc(inst: VOPC | VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int =
|
||||
|
||||
def _compile_vop3(inst: VOP3, ctx: _Ctx) -> UOp:
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
sizes = getattr(inst, 'op_regs', {})
|
||||
bits = inst.canonical_op_bits
|
||||
opsel, op_name = getattr(inst, 'opsel', 0) or 0, _op_name(inst)
|
||||
|
||||
# Lane operations
|
||||
@@ -677,53 +650,48 @@ def _compile_vop3(inst: VOP3, ctx: _Ctx) -> UOp:
|
||||
|
||||
# Regular VOP3 - read operands dynamically
|
||||
lane = ctx.range()
|
||||
is_f16_op = 'F16' in op_name
|
||||
vdst_reg = ctx.inst_field(VOP3.vdst)
|
||||
src0_off = ctx.inst_field(VOP3.src0)
|
||||
src1_off = ctx.inst_field(VOP3.src1)
|
||||
src2_off = ctx.inst_field(VOP3.src2) if inst.src2 is not None else None
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
src0 = ctx.rsrc_dyn_sized(src0_off, lane, sizes, 'src0', f16=is_f16_op, literal=literal)
|
||||
src1 = ctx.rsrc_dyn_sized(src1_off, lane, sizes, 'src1', f16=is_f16_op, literal=literal)
|
||||
src2 = ctx.rsrc_dyn_sized(src2_off, lane, sizes, 'src2', f16=is_f16_op, literal=literal) if src2_off is not None else None
|
||||
if _is_16bit_op(op_name):
|
||||
src0, src1 = _apply_opsel(src0, 0, opsel), _apply_opsel(src1, 1, opsel)
|
||||
if src2 is not None: src2 = _apply_opsel(src2, 2, opsel)
|
||||
ops = inst.canonical_operands
|
||||
src0 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src0), lane, bits['s0'], literal, 's0' in ops and ops['s0'][0] == Fmt.FMT_NUM_F64)
|
||||
src1 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src1), lane, bits['s1'], literal, 's1' in ops and ops['s1'][0] == Fmt.FMT_NUM_F64)
|
||||
src2 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src2), lane, bits['s2'], literal, 's2' in ops and ops['s2'][0] == Fmt.FMT_NUM_F64)
|
||||
if bits['s0'] == 16:
|
||||
src0 = _apply_opsel(src0, 0, opsel)
|
||||
src1 = _apply_opsel(src1, 1, opsel)
|
||||
src2 = _apply_opsel(src2, 2, opsel)
|
||||
abs_bits, neg_bits = getattr(inst, 'abs', 0) or 0, getattr(inst, 'neg', 0) or 0
|
||||
is_16bit_op = _is_16bit_op(op_name)
|
||||
if abs_bits or neg_bits:
|
||||
src0 = _apply_src_mods(src0, 0, abs_bits, neg_bits, is_16bit_op, sizes.get('src0', 1) == 2)
|
||||
if src1 is not None: src1 = _apply_src_mods(src1, 1, abs_bits, neg_bits, is_16bit_op, sizes.get('src1', 1) == 2)
|
||||
if src2 is not None: src2 = _apply_src_mods(src2, 2, abs_bits, neg_bits, is_16bit_op, sizes.get('src2', 1) == 2)
|
||||
srcs = {'S0': src0, 'S1': src1}
|
||||
if src2 is not None: srcs['S2'] = src2
|
||||
src0 = _apply_src_mods(src0, 0, abs_bits, neg_bits, bits['s0'])
|
||||
src1 = _apply_src_mods(src1, 1, abs_bits, neg_bits, bits['s1'])
|
||||
src2 = _apply_src_mods(src2, 2, abs_bits, neg_bits, bits['s2'])
|
||||
srcs = {'S0': src0, 'S1': src1, 'S2': src2}
|
||||
if inst.op in (VOP3Op.V_CNDMASK_B32_E64, VOP3Op.V_CNDMASK_B16) and src2 is not None: srcs['VCC'] = src2
|
||||
# FMAC instructions need D0 (accumulator) from destination register
|
||||
if 'FMAC' in op_name: srcs['D0'] = ctx.rvgpr_dyn(vdst_reg, lane)
|
||||
opsel_dst_hi = bool(opsel & 0b1000) and _is_16bit_op(op_name)
|
||||
opsel_dst_hi = bool(opsel & 0b1000) and bits['d'] == 16
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=opsel_dst_hi, clmp=getattr(inst, 'clmp', 0))
|
||||
|
||||
def _compile_vop3sd(inst: VOP3SD, ctx: _Ctx) -> UOp:
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
sizes, pcode = getattr(inst, 'op_regs', {}), get_pcode(inst.op)
|
||||
bits, pcode, ops = inst.canonical_op_bits, get_pcode(inst.op), inst.canonical_operands
|
||||
|
||||
# Read operands dynamically from instruction encoding
|
||||
vdst_reg = ctx.inst_field(VOP3SD.vdst)
|
||||
sdst_off = ctx.inst_field(VOP3SD.sdst)
|
||||
src0_off = ctx.inst_field(VOP3SD.src0)
|
||||
src1_off = ctx.inst_field(VOP3SD.src1)
|
||||
src2_off = ctx.inst_field(VOP3SD.src2) if inst.src2 is not None else None
|
||||
vdst_reg, sdst_off = ctx.inst_field(VOP3SD.vdst), ctx.inst_field(VOP3SD.sdst)
|
||||
src0_off, src1_off, src2_off = ctx.inst_field(VOP3SD.src0), ctx.inst_field(VOP3SD.src1), ctx.inst_field(VOP3SD.src2)
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
|
||||
has_carry_in = 'src2' in inst.operands and inst.operands['src2'][2] == OpType.OPR_SREG
|
||||
vcc_in_off = src2_off if has_carry_in and src2_off is not None else sdst_off
|
||||
has_carry_in = 's2' in ops and ops['s2'][2] == OpType.OPR_SREG
|
||||
vcc_in_off = src2_off if has_carry_in else sdst_off
|
||||
|
||||
def load_srcs(lane_uop):
|
||||
ret = {'VCC': ctx.rsgpr_dyn(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane_uop}
|
||||
ret['S0'] = ctx.rsrc_dyn(src0_off, lane_uop, bits['s0'], literal, ops['s0'][0] == Fmt.FMT_NUM_F64)
|
||||
ret['S1'] = ctx.rsrc_dyn(src1_off, lane_uop, bits['s1'], literal, ops['s1'][0] == Fmt.FMT_NUM_F64)
|
||||
if 's2' in ops: ret['S2'] = ctx.rsrc_dyn(src2_off, lane_uop, bits['s2'], literal, ops['s2'][0] == Fmt.FMT_NUM_F64)
|
||||
return ret
|
||||
|
||||
lane = ctx.range()
|
||||
src0 = ctx.rsrc_dyn_sized(src0_off, lane, sizes, 'src0', literal=literal)
|
||||
src1 = ctx.rsrc_dyn_sized(src1_off, lane, sizes, 'src1', literal=literal)
|
||||
src2 = ctx.rsrc_dyn_sized(src2_off, lane, sizes, 'src2', literal=literal) if src2_off is not None else None
|
||||
srcs = {'S0': src0, 'S1': src1, 'VCC': ctx.rsgpr_dyn(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane}
|
||||
if src2 is not None: srcs['S2'] = src2
|
||||
srcs = load_srcs(lane)
|
||||
_, assigns = parse_pcode(pcode, srcs)
|
||||
|
||||
has_per_lane_vcc = any('[laneId]' in dest for dest, _ in assigns if dest.startswith('VCC') or dest.startswith('D0.u64'))
|
||||
@@ -731,23 +699,15 @@ def _compile_vop3sd(inst: VOP3SD, ctx: _Ctx) -> UOp:
|
||||
# VCC computation: RANGE+REDUCE gets axis ID first (lower ID = runs first)
|
||||
# This ensures VCC reads source values BEFORE VGPR stores modify them
|
||||
def get_vcc_bit(lane_uop) -> UOp:
|
||||
s0, s1 = ctx.rsrc_dyn_sized(src0_off, lane_uop, sizes, 'src0', literal=literal), ctx.rsrc_dyn_sized(src1_off, lane_uop, sizes, 'src1', literal=literal)
|
||||
s2 = ctx.rsrc_dyn_sized(src2_off, lane_uop, sizes, 'src2', literal=literal) if src2_off is not None else None
|
||||
lane_srcs = {'S0': s0, 'S1': s1, 'VCC': ctx.rsgpr_dyn(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane_uop}
|
||||
if s2 is not None: lane_srcs['S2'] = s2
|
||||
vcc_bit = _c(0)
|
||||
for dest, val in parse_pcode(pcode, lane_srcs)[1]:
|
||||
for dest, val in parse_pcode(pcode, load_srcs(lane_uop))[1]:
|
||||
if dest.startswith('VCC') or (dest.startswith('D0.u64') and '[laneId]' in dest): vcc_bit = val.cast(dtypes.uint32)
|
||||
return vcc_bit
|
||||
final_vcc = ctx.unroll_lanes(get_vcc_bit, exec_mask)
|
||||
# VGPR stores: RANGE gets axis ID second (higher ID = runs after VCC loop)
|
||||
lane3 = ctx.range()
|
||||
s0, s1 = ctx.rsrc_dyn_sized(src0_off, lane3, sizes, 'src0', literal=literal), ctx.rsrc_dyn_sized(src1_off, lane3, sizes, 'src1', literal=literal)
|
||||
s2 = ctx.rsrc_dyn_sized(src2_off, lane3, sizes, 'src2', literal=literal) if src2_off is not None else None
|
||||
lane_srcs = {'S0': s0, 'S1': s1, 'VCC': ctx.rsgpr_dyn(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane3}
|
||||
if s2 is not None: lane_srcs['S2'] = s2
|
||||
d0_val = None
|
||||
for dest, val in parse_pcode(pcode, lane_srcs)[1]:
|
||||
for dest, val in parse_pcode(pcode, load_srcs(lane3))[1]:
|
||||
if dest.startswith('D0') and '[laneId]' not in dest: d0_val = val
|
||||
vgpr_stores = []
|
||||
if d0_val is not None:
|
||||
@@ -763,62 +723,52 @@ def _compile_vop3sd(inst: VOP3SD, ctx: _Ctx) -> UOp:
|
||||
else:
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, sdst_reg=inst.sdst.offset)
|
||||
|
||||
def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
|
||||
lane, exec_mask = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
# Read register fields dynamically for deduplication
|
||||
def _compile_wmma(inst: VOP3P, ctx: _Ctx) -> UOp:
|
||||
op_name = _op_name(inst)
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
vdst_reg = ctx.inst_field(VOP3P.vdst)
|
||||
src0_off = ctx.inst_field(VOP3P.src0)
|
||||
src1_off = ctx.inst_field(VOP3P.src1)
|
||||
src2_off = ctx.inst_field(VOP3P.src2) if hasattr(inst, 'src2') and inst.src2 is not None else None
|
||||
src0 = ctx.rsrc_dyn(src0_off, lane, 16)
|
||||
src1 = ctx.rsrc_dyn(src1_off, lane, 16)
|
||||
src2 = ctx.rsrc_dyn(src2_off, lane, 16) if src2_off is not None else None
|
||||
src0_r = ctx.inst_field(VOP3P.src0) - _c(256)
|
||||
src1_r = ctx.inst_field(VOP3P.src1) - _c(256)
|
||||
src2_r = ctx.inst_field(VOP3P.src2) - _c(256)
|
||||
is_f16_output = 'F16_16X16X16_F16' in op_name or 'BF16_16X16X16_BF16' in op_name # F16/BF16 output vs F32 output
|
||||
is_bf16 = 'BF16' in op_name
|
||||
cvt = _FUNCS['bf16_to_f32'] if is_bf16 else _FUNCS['f16_to_f32']
|
||||
def read_f16_mat(src):
|
||||
return [f for l in range(16) for r in range(8) for v in [ctx.rvgpr_dyn(src + _c(r), UOp.const(dtypes.int, l))]
|
||||
for f in [cvt(v & UOp.const(dtypes.uint32, 0xFFFF)), cvt(v >> UOp.const(dtypes.uint32, 16))]]
|
||||
mat_a, mat_b = read_f16_mat(src0_r), read_f16_mat(src1_r)
|
||||
if is_f16_output:
|
||||
# RDNA3 F16/BF16 output: uses 8 VGPRs (same as F32), f16/bf16 values in lo 16 bits of each VGPR
|
||||
# Layout: half16 per lane where even indices (0,2,4,...,14) = lo halves of VGPRs 0-7
|
||||
# Read accumulator: 8 regs × 32 lanes, each VGPR's lo 16 bits holds one f16/bf16
|
||||
mat_c = [cvt(ctx.rvgpr_dyn(src2_r + _c(i // 32), UOp.const(dtypes.int, i % 32)) & UOp.const(dtypes.uint32, 0xFFFF))
|
||||
for i in range(256)]
|
||||
mat_d = [sum(mat_a[row*16+k] * mat_b[col*16+k] for k in range(16)) + mat_c[row*16+col] for row in range(16) for col in range(16)]
|
||||
# Write f16/bf16 results to lo 16 bits of each VGPR
|
||||
def f32_to_f16_bits(v: UOp) -> UOp: return v.cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32)
|
||||
def f32_to_bf16_bits(v: UOp) -> UOp: return (v.bitcast(dtypes.uint32) >> UOp.const(dtypes.uint32, 16)) & UOp.const(dtypes.uint32, 0xFFFF)
|
||||
out_cvt = f32_to_bf16_bits if is_bf16 else f32_to_f16_bits
|
||||
stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), out_cvt(mat_d[i]), exec_mask) for i in range(256)]
|
||||
else:
|
||||
# F32 output: accumulator and output are f32
|
||||
mat_c = [ctx.rvgpr_dyn(src2_r + _c(i // 32), UOp.const(dtypes.int, i % 32)).bitcast(dtypes.float32) for i in range(256)]
|
||||
mat_d = [sum(mat_a[row*16+k] * mat_b[col*16+k] for k in range(16)) + mat_c[row*16+col] for row in range(16) for col in range(16)]
|
||||
stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), mat_d[i].bitcast(dtypes.uint32), exec_mask) for i in range(256)]
|
||||
return UOp.sink(*stores, *ctx.inc_pc())
|
||||
|
||||
def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
|
||||
op_name = _op_name(inst)
|
||||
if 'WMMA' in op_name and ('16X16X16_F16' in op_name or '16X16X16_BF16' in op_name): return _compile_wmma(inst, ctx)
|
||||
|
||||
lane = ctx.range()
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
vdst_reg = ctx.inst_field(VOP3P.vdst)
|
||||
src0 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src0), lane, 16)
|
||||
src1 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src1), lane, 16)
|
||||
src2 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src2), lane, 16)
|
||||
opsel, opsel_hi = getattr(inst, 'opsel', 0) or 0, getattr(inst, 'opsel_hi', 3) if getattr(inst, 'opsel_hi', 3) is not None else 3
|
||||
opsel_hi2 = getattr(inst, 'opsel_hi2', 1) if getattr(inst, 'opsel_hi2', 1) is not None else 1
|
||||
neg, neg_hi = getattr(inst, 'neg', 0) or 0, getattr(inst, 'neg_hi', 0) or 0
|
||||
def get_half_bits(val: UOp, use_hi: bool, apply_neg: bool = False) -> UOp:
|
||||
bits = ((val >> UOp.const(dtypes.uint32, 16)) if use_hi else val) & UOp.const(dtypes.uint32, 0xFFFF)
|
||||
if apply_neg: bits = bits.cast(dtypes.uint16).bitcast(dtypes.half).neg().bitcast(dtypes.uint16).cast(dtypes.uint32)
|
||||
return bits
|
||||
def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp:
|
||||
return get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit)) | (get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit)) << UOp.const(dtypes.uint32, 16))
|
||||
s0_new = build_remapped_src(src0, opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1)
|
||||
s1_new = build_remapped_src(src1, opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2)
|
||||
s2_new = build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4) if src2 is not None else None
|
||||
op_name = _op_name(inst)
|
||||
|
||||
# WMMA: Wave Matrix Multiply-Accumulate
|
||||
if 'WMMA' in op_name and ('16X16X16_F16' in op_name or '16X16X16_BF16' in op_name):
|
||||
# Dynamic register fields for deduplication
|
||||
src0_r = ctx.inst_field(VOP3P.src0) - _c(256)
|
||||
src1_r = ctx.inst_field(VOP3P.src1) - _c(256)
|
||||
src2_r = ctx.inst_field(VOP3P.src2) - _c(256)
|
||||
is_f16_output = 'F16_16X16X16_F16' in op_name or 'BF16_16X16X16_BF16' in op_name # F16/BF16 output vs F32 output
|
||||
is_bf16 = 'BF16' in op_name
|
||||
cvt = _FUNCS['bf16_to_f32'] if is_bf16 else _FUNCS['f16_to_f32']
|
||||
def read_f16_mat(src):
|
||||
return [f for l in range(16) for r in range(8) for v in [ctx.rvgpr_dyn(src + _c(r), UOp.const(dtypes.int, l))]
|
||||
for f in [cvt(v & UOp.const(dtypes.uint32, 0xFFFF)), cvt(v >> UOp.const(dtypes.uint32, 16))]]
|
||||
mat_a, mat_b = read_f16_mat(src0_r), read_f16_mat(src1_r)
|
||||
if is_f16_output:
|
||||
# RDNA3 F16/BF16 output: uses 8 VGPRs (same as F32), f16/bf16 values in lo 16 bits of each VGPR
|
||||
# Layout: half16 per lane where even indices (0,2,4,...,14) = lo halves of VGPRs 0-7
|
||||
# Read accumulator: 8 regs × 32 lanes, each VGPR's lo 16 bits holds one f16/bf16
|
||||
mat_c = [cvt(ctx.rvgpr_dyn(src2_r + _c(i // 32), UOp.const(dtypes.int, i % 32)) & UOp.const(dtypes.uint32, 0xFFFF))
|
||||
for i in range(256)]
|
||||
mat_d = [sum(mat_a[row*16+k] * mat_b[col*16+k] for k in range(16)) + mat_c[row*16+col] for row in range(16) for col in range(16)]
|
||||
# Write f16/bf16 results to lo 16 bits of each VGPR
|
||||
def f32_to_f16_bits(v: UOp) -> UOp: return v.cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32)
|
||||
def f32_to_bf16_bits(v: UOp) -> UOp: return (v.bitcast(dtypes.uint32) >> UOp.const(dtypes.uint32, 16)) & UOp.const(dtypes.uint32, 0xFFFF)
|
||||
out_cvt = f32_to_bf16_bits if is_bf16 else f32_to_f16_bits
|
||||
stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32),
|
||||
out_cvt(mat_d[i]), exec_mask) for i in range(256)]
|
||||
else:
|
||||
# F32 output: accumulator and output are f32
|
||||
mat_c = [ctx.rvgpr_dyn(src2_r + _c(i // 32), UOp.const(dtypes.int, i % 32)).bitcast(dtypes.float32) for i in range(256)]
|
||||
mat_d = [sum(mat_a[row*16+k] * mat_b[col*16+k] for k in range(16)) + mat_c[row*16+col] for row in range(16) for col in range(16)]
|
||||
stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), mat_d[i].bitcast(dtypes.uint32), exec_mask) for i in range(256)]
|
||||
return UOp.sink(*stores, *ctx.inc_pc())
|
||||
|
||||
if 'FMA_MIX' in op_name:
|
||||
combined_opsel_hi = (opsel_hi & 0x3) | ((opsel_hi2 & 0x1) << 2)
|
||||
@@ -836,12 +786,20 @@ def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
|
||||
return v ^ UOp.const(dtypes.uint32, 0x00008000) # f16 lo neg
|
||||
s0_mod = apply_neg_mix(apply_abs(src0, 1, 1, 1), 1, 1, 1)
|
||||
s1_mod = apply_neg_mix(apply_abs(src1, 2, 2, 2), 2, 2, 2)
|
||||
s2_mod = apply_neg_mix(apply_abs(src2, 4, 4, 4), 4, 4, 4) if src2 is not None else UOp.const(dtypes.uint32, 0)
|
||||
s2_mod = apply_neg_mix(apply_abs(src2, 4, 4, 4), 4, 4, 4)
|
||||
srcs = {'S0': s0_mod, 'S1': s1_mod, 'S2': s2_mod,
|
||||
'OPSEL_HI': UOp.const(dtypes.uint32, combined_opsel_hi), 'OPSEL': UOp.const(dtypes.uint32, opsel)}
|
||||
else:
|
||||
srcs = {'S0': s0_new, 'S1': s1_new}
|
||||
if s2_new is not None: srcs['S2'] = s2_new
|
||||
def get_half_bits(val: UOp, use_hi: bool, apply_neg: bool = False) -> UOp:
|
||||
bits = ((val >> UOp.const(dtypes.uint32, 16)) if use_hi else val) & UOp.const(dtypes.uint32, 0xFFFF)
|
||||
if apply_neg: bits = bits.cast(dtypes.uint16).bitcast(dtypes.half).neg().bitcast(dtypes.uint16).cast(dtypes.uint32)
|
||||
return bits
|
||||
def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp:
|
||||
return get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit)) | (get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit)) << UOp.const(dtypes.uint32, 16))
|
||||
s0_new = build_remapped_src(src0, opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1)
|
||||
s1_new = build_remapped_src(src1, opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2)
|
||||
s2_new = build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4)
|
||||
srcs = {'S0': s0_new, 'S1': s1_new, 'S2': s2_new}
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask)
|
||||
|
||||
def _compile_vopd(inst: VOPD, ctx: _Ctx) -> UOp:
|
||||
@@ -904,9 +862,8 @@ def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
|
||||
# Dynamic saddr - read field, NULL (124) or >= 128 means no saddr
|
||||
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(inst, 'saddr') else None
|
||||
|
||||
# Data width
|
||||
ndwords = 4 if '_B128' in op_name or 'B128' in op_name else 3 if '_B96' in op_name or 'B96' in op_name else 2 if '_B64' in op_name or 'B64' in op_name else 1
|
||||
is_64bit = ndwords >= 2 or '_U64' in op_name or '_I64' in op_name or '_F64' in op_name
|
||||
# Data width from canonical_op_bits (32/64/96/128), default to 32 for untyped ops
|
||||
data_bits_mem = inst.canonical_op_bits.get('data', 32)
|
||||
is_atomic, glc = 'ATOMIC' in op_name, getattr(inst, 'glc', 0)
|
||||
has_data1 = is_lds and hasattr(inst, 'data1') and inst.data1 is not None
|
||||
data1_reg = ctx.inst_field(DS.data1) if is_lds else _c(0)
|
||||
@@ -940,10 +897,13 @@ def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
|
||||
def make_srcs(lane: UOp) -> dict:
|
||||
addr = make_addr(lane)
|
||||
if is_lds:
|
||||
if 'B128' in op_name or 'B96' in op_name:
|
||||
if data_bits_mem == 128:
|
||||
data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA1': ctx.rvgpr_dyn(vdata_reg + _c(1), lane),
|
||||
'DATA2': ctx.rvgpr_dyn(vdata_reg + _c(2), lane), 'DATA3': ctx.rvgpr_dyn(vdata_reg + _c(3), lane)}
|
||||
elif 'B32' in op_name:
|
||||
elif data_bits_mem == 96:
|
||||
data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA1': ctx.rvgpr_dyn(vdata_reg + _c(1), lane),
|
||||
'DATA2': ctx.rvgpr_dyn(vdata_reg + _c(2), lane)}
|
||||
elif data_bits_mem == 32:
|
||||
data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA2': ctx.rvgpr_dyn(data1_reg, lane) if has_data1 else UOp.const(dtypes.uint32, 0)}
|
||||
else:
|
||||
data = {'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)),
|
||||
@@ -951,26 +911,27 @@ def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
|
||||
return {'ADDR': addr, 'ADDR_BASE': addr, 'OFFSET': offset, 'OFFSET0': offset0, 'OFFSET1': offset1, '_lds': mem, 'laneId': lane, **data}
|
||||
active = _lane_active(exec_mask, lane)
|
||||
if is_atomic:
|
||||
return {'ADDR': addr, 'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) if is_64bit else ctx.rvgpr_dyn(vdata_reg, lane),
|
||||
return {'ADDR': addr, 'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) if data_bits_mem == 64 else ctx.rvgpr_dyn(vdata_reg, lane),
|
||||
'_vmem': mem, '_active': active, 'laneId': lane}
|
||||
vdata = ctx.rvgpr_dyn(vdata_reg, lane).cast(dtypes.uint64) if 'STORE' in op_name else ctx.rvgpr_dyn(vdst_reg, lane) if 'D16' in op_name else UOp.const(dtypes.uint32, 0)
|
||||
if 'STORE' in op_name and ndwords >= 2: vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
|
||||
if 'STORE' in op_name and data_bits_mem >= 64: vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
|
||||
srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active, 'laneId': lane}
|
||||
for i in range(ndwords): srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
|
||||
for i in range(data_bits_mem // 32): srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
|
||||
return srcs
|
||||
|
||||
def make_stores(dest: str, val: UOp, lane: UOp, active: UOp, writes_return_data: bool) -> list[UOp]:
|
||||
# Parse bit width from dest format: MEM[...].b32 or RETURN_DATA[63:32].b64
|
||||
parts = dest.rsplit('.', 1)
|
||||
data_bits = int(parts[1][1:]) if len(parts) == 2 else 32
|
||||
if dest.startswith('MEM['):
|
||||
if is_lds or is_atomic: return _write_val(dest, val[1], wmem, val[0], active, is_mem=True)
|
||||
data_bits = 8 if '.b8' in dest else 16 if '.b16' in dest else 64 if '.b64' in dest else 32
|
||||
if is_lds or is_atomic: return _write_val(data_bits, val[1], wmem, val[0], active, is_mem=True)
|
||||
if is_scratch: return _mem_store_bytes(mem, val[0], val[1], active, data_bits)
|
||||
return _mem_store(mem, val[0], val[1], active, 64, data_bits)
|
||||
if dest.startswith('RETURN_DATA') and writes_return_data:
|
||||
if (m := re.match(r'RETURN_DATA\[(\d+)\s*:\s*(\d+)\]', dest)):
|
||||
bit_width, dword_idx = int(m.group(1)) - int(m.group(2)) + 1, int(m.group(2)) // 32
|
||||
is_64 = '.b64' if bit_width == 64 else ''
|
||||
return _write_val(is_64, val, lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e), vdst_reg + _c(dword_idx), lane, exec_mask)
|
||||
return _write_val(dest, val, lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e), vdst_reg, lane, exec_mask)
|
||||
return _write_val(bit_width, val, lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e), vdst_reg + _c(dword_idx), lane, exec_mask)
|
||||
return _write_val(data_bits, val, lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e), vdst_reg, lane, exec_mask)
|
||||
return []
|
||||
|
||||
# DS-specific: check for 2ADDR pattern needing separate ranges
|
||||
|
||||
@@ -717,6 +717,22 @@ def _subst_loop_var(line: str, loop_var: str, val: int) -> str:
|
||||
subst_parts = [str(val) if t.type == 'IDENT' and t.val == loop_var else t.val for t in result_toks if t.type != 'EOF']
|
||||
return ' '.join(subst_parts)
|
||||
|
||||
def _set_bits(old: UOp, val: UOp, width: int, offset: int) -> UOp:
|
||||
"""Set bits [offset:offset+width) in old to val, masking and shifting appropriately."""
|
||||
mask = _u32(((1 << width) - 1) << offset)
|
||||
v = (val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val) & _u32((1 << width) - 1)
|
||||
return (old & (mask ^ _u32(0xFFFFFFFF))) | (v << _u32(offset))
|
||||
|
||||
def _find_paren_end(s: str, start: int = 0, open_ch: str = '(', close_ch: str = ')') -> int:
|
||||
"""Find index of matching close paren, starting after the open paren at start."""
|
||||
depth = 0
|
||||
for j, ch in enumerate(s[start:], start):
|
||||
if ch == open_ch: depth += 1
|
||||
elif ch == close_ch:
|
||||
depth -= 1
|
||||
if depth == 0: return j
|
||||
return len(s)
|
||||
|
||||
def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: dict | None = None,
|
||||
assigns: list | None = None) -> tuple[int, dict[str, VarVal], UOp | None]:
|
||||
"""Parse a block of pcode. Returns (next_line, block_assigns, return_value).
|
||||
@@ -724,7 +740,6 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
if funcs is None: funcs = _FUNCS
|
||||
block_assigns: dict[str, VarVal] = {}
|
||||
i = start
|
||||
def ctx(): return {**vars, **block_assigns}
|
||||
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
@@ -738,7 +753,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
# return expr (lambda bodies)
|
||||
if first == 'return':
|
||||
rest = line[line.lower().find('return') + 6:].strip()
|
||||
return i + 1, block_assigns, parse_expr(rest, ctx(), funcs)
|
||||
return i + 1, block_assigns, parse_expr(rest, vars, funcs)
|
||||
|
||||
# for loop
|
||||
if first == 'for':
|
||||
@@ -747,21 +762,18 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
p.eat_val('for', 'IDENT')
|
||||
loop_var = p.eat('IDENT').val
|
||||
p.eat_val('in', 'IDENT')
|
||||
if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
|
||||
if p.at('NUM'):
|
||||
start_val = int(p.eat('NUM').val.rstrip('UuLl'))
|
||||
else:
|
||||
start_expr = p.parse()
|
||||
start_val = int(start_expr.arg) if start_expr.op == Ops.CONST else 0
|
||||
def parse_bound():
|
||||
if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
|
||||
if p.at('NUM'): return int(p.eat('NUM').val.rstrip('UuLl'))
|
||||
expr = p.parse()
|
||||
return int(expr.arg) if expr.op == Ops.CONST else 0
|
||||
start_val = parse_bound()
|
||||
p.eat('COLON')
|
||||
if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
|
||||
if p.at('NUM'):
|
||||
end_val = int(p.eat('NUM').val.rstrip('UuLl'))
|
||||
else:
|
||||
end_expr = p.parse()
|
||||
end_val = int(end_expr.arg) if end_expr.op == Ops.CONST else 0
|
||||
end_val = parse_bound()
|
||||
# Collect body
|
||||
i += 1; body_lines, depth = [], 1
|
||||
i += 1
|
||||
body_lines: list[str] = []
|
||||
depth = 1
|
||||
while i < len(lines) and depth > 0:
|
||||
btoks = tokenize(lines[i])
|
||||
if btoks[0].type == 'IDENT':
|
||||
@@ -775,7 +787,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
if found_var: vars[found_var] = block_assigns[found_var] = _const(dtypes.bool, False)
|
||||
for loop_i in range(start_val, end_val + 1):
|
||||
subst_lines = [_subst_loop_var(bl, loop_var, loop_i) for bl in body_lines if not (has_break and bl.strip().lower() == 'break')]
|
||||
_, iter_assigns, _ = parse_block(subst_lines, 0, {**vars, **block_assigns}, funcs, assigns)
|
||||
_, iter_assigns, _ = parse_block(subst_lines, 0, vars, funcs, assigns)
|
||||
if has_break:
|
||||
assert found_var is not None
|
||||
found = block_assigns.get(found_var, vars.get(found_var))
|
||||
@@ -791,7 +803,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
if bl_l.startswith('if ') and bl_l.endswith(' then'):
|
||||
if any(body_lines[k].strip().lower() == 'break' for k in range(j+1, len(body_lines))):
|
||||
cond_str = _subst_loop_var(bl.strip()[3:-5].strip(), loop_var, loop_i)
|
||||
cond = _to_bool(parse_expr(cond_str, {**vars, **block_assigns}, funcs))
|
||||
cond = _to_bool(parse_expr(cond_str, vars, funcs))
|
||||
block_assigns[found_var] = vars[found_var] = not_found.where(cond, found)
|
||||
break
|
||||
else:
|
||||
@@ -806,25 +818,17 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
# lambda definition
|
||||
if first != '{' and '=' in line and 'lambda' in line and any(t.type == 'IDENT' and t.val == 'lambda' for t in toks):
|
||||
name = toks[0].val
|
||||
body_start, depth = line[line.find('(', line.find('lambda')):], 0
|
||||
params_end = 0
|
||||
for j, ch in enumerate(body_start):
|
||||
if ch == '(': depth += 1
|
||||
elif ch == ')':
|
||||
depth -= 1
|
||||
if depth == 0: params_end = j + 1; break
|
||||
body_start = line[line.find('(', line.find('lambda')):]
|
||||
params_end = _find_paren_end(body_start) + 1
|
||||
params = [p.strip() for p in body_start[1:params_end-1].split(',') if p.strip()]
|
||||
rest = body_start[params_end:].strip()
|
||||
if rest.startswith('('):
|
||||
depth, body_end = 1, 1
|
||||
for j, ch in enumerate(rest[1:], 1):
|
||||
if ch == '(': depth += 1
|
||||
elif ch == ')':
|
||||
depth -= 1
|
||||
if depth == 0: body_end = j; break
|
||||
body = rest[1:body_end].strip()
|
||||
if depth > 0:
|
||||
body_lines_lst = [rest[1:]]
|
||||
body_end = _find_paren_end(rest)
|
||||
if body_end < len(rest): # found matching paren on same line
|
||||
body = rest[1:body_end].strip()
|
||||
i += 1
|
||||
else: # multiline body
|
||||
body_lines_lst, depth = [rest[1:]], 1
|
||||
i += 1
|
||||
while i < len(lines) and depth > 0:
|
||||
for j, ch in enumerate(lines[i]):
|
||||
@@ -835,21 +839,20 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
else: body_lines_lst.append(lines[i])
|
||||
i += 1
|
||||
body = '\n'.join(body_lines_lst).strip()
|
||||
else: i += 1
|
||||
vars[name] = ('lambda', params, body)
|
||||
continue
|
||||
|
||||
# MEM assignment: MEM[addr].type (+|-)?= value
|
||||
if first == 'mem' and toks[1].type == 'LBRACKET':
|
||||
j, addr_toks = _match_bracket(toks, 1)
|
||||
addr = parse_tokens(addr_toks, ctx(), funcs)
|
||||
addr = parse_tokens(addr_toks, vars, funcs)
|
||||
if j < len(toks) and toks[j].type == 'DOT': j += 1
|
||||
dt_name = toks[j].val if j < len(toks) and toks[j].type == 'IDENT' else 'u32'
|
||||
dt, j = DTYPES.get(dt_name, dtypes.uint32), j + 1
|
||||
compound_op = None
|
||||
if j < len(toks) and toks[j].type == 'ASSIGN_OP': compound_op = toks[j].val; j += 1
|
||||
elif j < len(toks) and toks[j].type == 'EQUALS': j += 1
|
||||
rhs = parse_tokens(toks[j:], ctx(), funcs)
|
||||
rhs = parse_tokens(toks[j:], vars, funcs)
|
||||
if compound_op:
|
||||
mem = vars.get('_vmem') if '_vmem' in vars else vars.get('_lds')
|
||||
if isinstance(mem, UOp):
|
||||
@@ -868,7 +871,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
if j < len(toks) and toks[j].type == 'LBRACKET':
|
||||
j, reg_toks = _match_bracket(toks, j)
|
||||
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
|
||||
ln, rg, val = parse_tokens(lane_toks, ctx(), funcs), parse_tokens(reg_toks, ctx(), funcs), parse_tokens(toks[j:], ctx(), funcs)
|
||||
ln, rg, val = parse_tokens(lane_toks, vars, funcs), parse_tokens(reg_toks, vars, funcs), parse_tokens(toks[j:], vars, funcs)
|
||||
if assigns is not None: assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(32) + _to_u32(ln), val)))
|
||||
i += 1; continue
|
||||
|
||||
@@ -884,7 +887,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
j += 3
|
||||
if j < len(toks) and toks[j].type == 'RBRACE': j += 1
|
||||
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
|
||||
val = parse_tokens(toks[j:], ctx(), funcs)
|
||||
val = parse_tokens(toks[j:], vars, funcs)
|
||||
lo_dt, hi_dt = DTYPES.get(lo_type, dtypes.uint64), DTYPES.get(hi_type, dtypes.uint32)
|
||||
lo_bits = 64 if lo_dt in (dtypes.uint64, dtypes.int64) else 32
|
||||
lo_val = val.cast(lo_dt) if val.dtype.itemsize * 8 <= lo_bits else (val & _const(val.dtype, (1 << lo_bits) - 1)).cast(lo_dt)
|
||||
@@ -894,7 +897,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
if assigns is not None: assigns.extend([(f'{lo_var}.{lo_type}', lo_val), (f'{hi_var}.{hi_type}', hi_val)])
|
||||
i += 1; continue
|
||||
|
||||
# Bit slice: var[hi:lo] = value or var.type[hi:lo] = value
|
||||
# Bit slice/index: var[hi:lo] = value, var.type[hi:lo] = value, or var[expr] = value
|
||||
if len(toks) >= 5 and toks[0].type == 'IDENT' and (toks[1].type == 'LBRACKET' or (toks[1].type == 'DOT' and toks[3].type == 'LBRACKET')):
|
||||
bracket_start = 2 if toks[1].type == 'LBRACKET' else 4
|
||||
j = bracket_start
|
||||
@@ -902,24 +905,33 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
while j < len(toks) and toks[j].type != 'RBRACKET':
|
||||
if toks[j].type == 'COLON': colon_pos = j
|
||||
j += 1
|
||||
if colon_pos is not None:
|
||||
var = toks[0].val
|
||||
if colon_pos is not None: # bit slice: var[hi:lo]
|
||||
hi_str = ' '.join(t.val for t in toks[bracket_start:colon_pos] if t.type != 'EOF')
|
||||
lo_str = ' '.join(t.val for t in toks[colon_pos+1:j] if t.type != 'EOF')
|
||||
try:
|
||||
hi, lo = max(int(eval(hi_str)), int(eval(lo_str))), min(int(eval(hi_str)), int(eval(lo_str)))
|
||||
var = toks[0].val
|
||||
hi_val, lo_val = int(eval(hi_str)), int(eval(lo_str))
|
||||
hi, lo = max(hi_val, lo_val), min(hi_val, lo_val)
|
||||
j += 1
|
||||
if j < len(toks) and toks[j].type == 'DOT': j += 2
|
||||
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
|
||||
val = parse_tokens(toks[j:], ctx(), funcs)
|
||||
val = parse_tokens(toks[j:], vars, funcs)
|
||||
dt_suffix = toks[2].val if toks[1].type == 'DOT' else None
|
||||
if assigns is not None: assigns.append((f'{var}[{hi}:{lo}]' + (f'.{dt_suffix}' if dt_suffix else ''), val))
|
||||
if var not in vars: vars[var] = _const(dtypes.uint64 if hi >= 32 else dtypes.uint32, 0)
|
||||
old = block_assigns.get(var, vars.get(var))
|
||||
mask = _u32(((1 << (hi - lo + 1)) - 1) << lo)
|
||||
block_assigns[var] = vars[var] = (old & (mask ^ _u32(0xFFFFFFFF))) | (_val_to_bits(val) << _u32(lo))
|
||||
block_assigns[var] = vars[var] = _set_bits(old, _val_to_bits(val), hi - lo + 1, lo)
|
||||
i += 1; continue
|
||||
except: pass
|
||||
elif toks[1].type == 'LBRACKET': # bit index: var[expr] (only for var[...], not var.type[...])
|
||||
existing = block_assigns.get(var, vars.get(var))
|
||||
if existing is not None and isinstance(existing, UOp) and not any(f'{var}{k}' in vars or f'{var}{k}' in block_assigns for k in range(8)):
|
||||
bit_toks = toks[2:j]
|
||||
j += 1
|
||||
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
|
||||
if j < len(toks):
|
||||
block_assigns[var] = vars[var] = _set_bit(existing, _to_u32(parse_tokens(bit_toks, vars, funcs)), parse_tokens(toks[j+1:], vars, funcs))
|
||||
i += 1; continue
|
||||
|
||||
# Array element: var{idx} = value
|
||||
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACE' and toks[2].type == 'NUM':
|
||||
@@ -927,7 +939,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
j = 4
|
||||
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
|
||||
if j < len(toks):
|
||||
val = parse_tokens(toks[j+1:], ctx(), funcs)
|
||||
val = parse_tokens(toks[j+1:], vars, funcs)
|
||||
existing = block_assigns.get(var, vars.get(var))
|
||||
if existing is not None and isinstance(existing, UOp):
|
||||
block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val)
|
||||
@@ -936,121 +948,103 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
i += 1; continue
|
||||
|
||||
# Compound assignment: var += or var -=
|
||||
for j, t in enumerate(toks):
|
||||
if t.type == 'ASSIGN_OP':
|
||||
assign_op = next((j for j, t in enumerate(toks) if t.type == 'ASSIGN_OP'), None)
|
||||
if assign_op is not None:
|
||||
var = toks[0].val
|
||||
old = block_assigns.get(var, vars.get(var, _u32(0)))
|
||||
rhs = parse_tokens(toks[assign_op+1:], vars, funcs)
|
||||
if rhs.dtype != old.dtype: rhs = rhs.cast(old.dtype)
|
||||
block_assigns[var] = vars[var] = (old + rhs) if toks[assign_op].val == '+=' else (old - rhs)
|
||||
i += 1; continue
|
||||
|
||||
# Typed element: var.type[idx] = value
|
||||
if len(toks) >= 7 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET' and toks[4].type == 'NUM':
|
||||
var, dt_name, idx = toks[0].val, toks[2].val, int(toks[4].val)
|
||||
dt = DTYPES.get(dt_name, dtypes.uint32)
|
||||
j = 6
|
||||
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
|
||||
if j < len(toks):
|
||||
val, old = parse_tokens(toks[j+1:], vars, funcs), block_assigns.get(var, vars.get(var, _u32(0)))
|
||||
bw = dt.itemsize * 8
|
||||
block_assigns[var] = vars[var] = _set_bits(old, val, bw, idx * bw)
|
||||
if assigns is not None: assigns.append((f'{var}.{dt_name}[{idx}]', val))
|
||||
i += 1; continue
|
||||
|
||||
# Dynamic bit: var.type[expr_with_brackets] = value
|
||||
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET':
|
||||
j, depth, has_inner = 4, 1, False
|
||||
while j < len(toks) and depth > 0:
|
||||
if toks[j].type == 'LBRACKET': depth += 1; has_inner = True
|
||||
elif toks[j].type == 'RBRACKET': depth -= 1
|
||||
j += 1
|
||||
if has_inner:
|
||||
var = toks[0].val
|
||||
old = block_assigns.get(var, vars.get(var, _u32(0)))
|
||||
rhs = parse_tokens(toks[j+1:], ctx(), funcs)
|
||||
if rhs.dtype != old.dtype: rhs = rhs.cast(old.dtype)
|
||||
block_assigns[var] = vars[var] = (old + rhs) if t.val == '+=' else (old - rhs)
|
||||
i += 1; break
|
||||
else:
|
||||
# Typed element: var.type[idx] = value
|
||||
if len(toks) >= 7 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET' and toks[4].type == 'NUM':
|
||||
var, dt_name, idx = toks[0].val, toks[2].val, int(toks[4].val)
|
||||
dt = DTYPES.get(dt_name, dtypes.uint32)
|
||||
j = 6
|
||||
bit_pos = _to_u32(parse_tokens(toks[4:j-1], vars, funcs))
|
||||
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
|
||||
if j < len(toks):
|
||||
val, old = parse_tokens(toks[j+1:], ctx(), funcs), block_assigns.get(var, vars.get(var, _u32(0)))
|
||||
bw, lo_bit = dt.itemsize * 8, idx * dt.itemsize * 8
|
||||
mask = _u32(((1 << bw) - 1) << lo_bit)
|
||||
block_assigns[var] = vars[var] = (old & (mask ^ _u32(0xFFFFFFFF))) | (((val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val) & _u32((1 << bw) - 1)) << _u32(lo_bit))
|
||||
if assigns is not None: assigns.append((f'{var}.{dt_name}[{idx}]', val))
|
||||
val = parse_tokens(toks[j+1:], vars, funcs)
|
||||
old = block_assigns.get(var, vars.get(var, _u32(0)))
|
||||
block_assigns[var] = vars[var] = _set_bit(old, bit_pos, val)
|
||||
i += 1; continue
|
||||
|
||||
# Dynamic bit: var.type[expr_with_brackets] = value
|
||||
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET':
|
||||
j, depth, has_inner = 4, 1, False
|
||||
while j < len(toks) and depth > 0:
|
||||
if toks[j].type == 'LBRACKET': depth += 1; has_inner = True
|
||||
elif toks[j].type == 'RBRACKET': depth -= 1
|
||||
j += 1
|
||||
if has_inner:
|
||||
var = toks[0].val
|
||||
bit_pos = _to_u32(parse_tokens(toks[4:j-1], ctx(), funcs))
|
||||
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
|
||||
if j < len(toks):
|
||||
val = parse_tokens(toks[j+1:], ctx(), funcs)
|
||||
old, mask = block_assigns.get(var, vars.get(var, _u32(0))), _u32(1) << bit_pos
|
||||
block_assigns[var] = vars[var] = (old | mask) if val.op == Ops.CONST and val.arg == 1 else \
|
||||
(old & (mask ^ _u32(0xFFFFFFFF))) if val.op == Ops.CONST and val.arg == 0 else _set_bit(old, bit_pos, val)
|
||||
i += 1; continue
|
||||
|
||||
# Bit index: var[expr] = value (bit assignment to existing scalar)
|
||||
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACKET':
|
||||
var = toks[0].val
|
||||
existing = block_assigns.get(var, vars.get(var))
|
||||
if existing is not None and isinstance(existing, UOp) and not any(f'{var}{k}' in vars or f'{var}{k}' in block_assigns for k in range(8)):
|
||||
j = 2
|
||||
while j < len(toks) and toks[j].type != 'RBRACKET': j += 1
|
||||
bit_toks = toks[2:j]
|
||||
j += 1
|
||||
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
|
||||
if j < len(toks):
|
||||
block_assigns[var] = vars[var] = _set_bit(existing, _to_u32(parse_tokens(bit_toks, ctx(), funcs)), parse_tokens(toks[j+1:], ctx(), funcs))
|
||||
i += 1; continue
|
||||
|
||||
# If/elsif/else - skip branches with statically false conditions (WAVE32/WAVE64)
|
||||
if first == 'if':
|
||||
def parse_cond(s, kw):
|
||||
ll = s.lower()
|
||||
return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), ctx(), funcs))
|
||||
def not_static_false(c): return c.op != Ops.CONST or c.arg is not False
|
||||
cond = parse_cond(line, 'if')
|
||||
conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not_static_false(cond) else []
|
||||
else_branch: tuple[UOp | None, dict[str, VarVal]] = (None, {})
|
||||
vars_snap = dict(vars)
|
||||
i += 1
|
||||
i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
|
||||
if conditions: conditions[0] = (cond, ret if ret is not None else branch)
|
||||
vars.clear(); vars.update(vars_snap)
|
||||
while i < len(lines):
|
||||
ltoks = tokenize(lines[i])
|
||||
if ltoks[0].type != 'IDENT': break
|
||||
lf = ltoks[0].val.lower()
|
||||
if lf == 'elsif':
|
||||
c = parse_cond(lines[i], 'elsif')
|
||||
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
|
||||
if not_static_false(c): conditions.append((c, ret if ret is not None else branch))
|
||||
vars.clear(); vars.update(vars_snap)
|
||||
elif lf == 'else':
|
||||
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
|
||||
else_branch = (ret, branch)
|
||||
vars.clear(); vars.update(vars_snap)
|
||||
elif lf == 'endif': i += 1; break
|
||||
else: break
|
||||
# Check if any branch returned a value (lambda-style)
|
||||
if any(isinstance(br, UOp) for _, br in conditions):
|
||||
result = else_branch[0]
|
||||
for c, rv in reversed(conditions):
|
||||
if isinstance(rv, UOp) and isinstance(result, UOp):
|
||||
if rv.dtype != result.dtype and rv.dtype.itemsize == result.dtype.itemsize: result = result.cast(rv.dtype)
|
||||
result = c.where(rv, result)
|
||||
return i, block_assigns, result
|
||||
# Main style: merge variable assignments with WHERE
|
||||
else_assigns = else_branch[1]
|
||||
all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
|
||||
for var in all_vars:
|
||||
res: Any = else_assigns.get(var, block_assigns.get(var, vars.get(var, _u32(0))))
|
||||
for cond, ba in reversed(conditions):
|
||||
if isinstance(ba, dict) and var in ba:
|
||||
tv = ba[var]
|
||||
if isinstance(tv, UOp) and isinstance(res, UOp):
|
||||
res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
|
||||
block_assigns[var] = vars[var] = res
|
||||
continue
|
||||
|
||||
# Regular assignment: var = value
|
||||
for j, t in enumerate(toks):
|
||||
if t.type == 'EQUALS':
|
||||
if any(toks[k].type == 'OP' and toks[k].val in ('<', '>', '!', '=') for k in range(j)): break
|
||||
base_var = toks[0].val
|
||||
block_assigns[base_var] = vars[base_var] = parse_tokens(toks[j+1:], ctx(), funcs)
|
||||
i += 1; break
|
||||
else: i += 1
|
||||
# If/elsif/else - skip branches with statically false conditions (WAVE32/WAVE64)
|
||||
if first == 'if':
|
||||
def parse_cond(s, kw):
|
||||
ll = s.lower()
|
||||
return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), vars, funcs))
|
||||
def not_static_false(c): return c.op != Ops.CONST or c.arg is not False
|
||||
cond = parse_cond(line, 'if')
|
||||
conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not_static_false(cond) else []
|
||||
else_branch: tuple[UOp | None, dict[str, VarVal]] = (None, {})
|
||||
vars_snap = dict(vars)
|
||||
i += 1
|
||||
i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
|
||||
if conditions: conditions[0] = (cond, ret if ret is not None else branch)
|
||||
vars.clear(); vars.update(vars_snap)
|
||||
while i < len(lines):
|
||||
ltoks = tokenize(lines[i])
|
||||
if ltoks[0].type != 'IDENT': break
|
||||
lf = ltoks[0].val.lower()
|
||||
if lf == 'elsif':
|
||||
c = parse_cond(lines[i], 'elsif')
|
||||
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
|
||||
if not_static_false(c): conditions.append((c, ret if ret is not None else branch))
|
||||
vars.clear(); vars.update(vars_snap)
|
||||
elif lf == 'else':
|
||||
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
|
||||
else_branch = (ret, branch)
|
||||
vars.clear(); vars.update(vars_snap)
|
||||
elif lf == 'endif': i += 1; break
|
||||
else: break
|
||||
# Check if any branch returned a value (lambda-style)
|
||||
if any(isinstance(br, UOp) for _, br in conditions):
|
||||
result = else_branch[0]
|
||||
for c, rv in reversed(conditions):
|
||||
if isinstance(rv, UOp) and isinstance(result, UOp):
|
||||
if rv.dtype != result.dtype and rv.dtype.itemsize == result.dtype.itemsize: result = result.cast(rv.dtype)
|
||||
result = c.where(rv, result)
|
||||
return i, block_assigns, result
|
||||
# Main style: merge variable assignments with WHERE
|
||||
else_assigns = else_branch[1]
|
||||
all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
|
||||
for var in all_vars:
|
||||
res: Any = else_assigns.get(var, block_assigns.get(var, vars.get(var, _u32(0))))
|
||||
for cond, ba in reversed(conditions):
|
||||
if isinstance(ba, dict) and var in ba:
|
||||
tv = ba[var]
|
||||
if isinstance(tv, UOp) and isinstance(res, UOp):
|
||||
res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
|
||||
block_assigns[var] = vars[var] = res
|
||||
continue
|
||||
continue
|
||||
|
||||
# Regular assignment: var = value
|
||||
for j, t in enumerate(toks):
|
||||
if t.type == 'EQUALS':
|
||||
if any(toks[k].type == 'OP' and toks[k].val in ('<', '>', '!', '=') for k in range(j)): break
|
||||
base_var = toks[0].val
|
||||
block_assigns[base_var] = vars[base_var] = parse_tokens(toks[j+1:], vars, funcs)
|
||||
i += 1; break
|
||||
else: i += 1
|
||||
return i, block_assigns, None
|
||||
|
||||
def parse_expr(expr: str, vars: dict[str, VarVal], funcs: dict | None = None) -> UOp:
|
||||
|
||||
@@ -138,6 +138,50 @@ class TestDS2AddrMore(unittest.TestCase):
|
||||
self.assertEqual(st.vgpr[0][4], 0x12345678, "v4 should be untouched")
|
||||
|
||||
|
||||
class TestDSB96(unittest.TestCase):
  """Tests for DS_STORE_B96 and DS_LOAD_B96 (96-bit / 3 dwords)."""

  def test_ds_store_load_b96(self):
    """DS_STORE_B96 stores 3 VGPRs, DS_LOAD_B96 loads them back."""
    # Seed v0..v2 with distinct patterns via s0 (one s_mov/v_mov pair per register).
    fills = [instr for reg, pat in enumerate((0x11111111, 0x22222222, 0x33333333))
             for instr in (s_mov_b32(s[0], pat), v_mov_b32_e32(v[reg], s[0]))]
    prog = [
      v_mov_b32_e32(v[10], 0),
      *fills,
      ds_store_b96(addr=v[10], data0=v[0:2]),
      s_waitcnt(lgkmcnt=0),
      ds_load_b96(addr=v[10], vdst=v[4:6]),
      s_waitcnt(lgkmcnt=0),
    ]
    state = run_program(prog, n_lanes=1)
    self.assertEqual(state.vgpr[0][4], 0x11111111, "v4 should have first dword")
    self.assertEqual(state.vgpr[0][5], 0x22222222, "v5 should have second dword")
    self.assertEqual(state.vgpr[0][6], 0x33333333, "v6 should have third dword")

  def test_ds_store_b96_with_offset(self):
    """DS_STORE_B96 with non-zero offset."""
    fills = [instr for reg, pat in enumerate((0xAAAAAAAA, 0xBBBBBBBB, 0xCCCCCCCC))
             for instr in (s_mov_b32(s[0], pat), v_mov_b32_e32(v[reg], s[0]))]
    prog = [
      v_mov_b32_e32(v[10], 0),
      *fills,
      DS(DSOp.DS_STORE_B96, addr=v[10], data0=v[0:2], offset0=12),
      s_waitcnt(lgkmcnt=0),
      DS(DSOp.DS_LOAD_B96, addr=v[10], vdst=v[4:6], offset0=12),
      s_waitcnt(lgkmcnt=0),
    ]
    state = run_program(prog, n_lanes=1)
    self.assertEqual(state.vgpr[0][4], 0xAAAAAAAA)
    self.assertEqual(state.vgpr[0][5], 0xBBBBBBBB)
    self.assertEqual(state.vgpr[0][6], 0xCCCCCCCC)
|
||||
|
||||
|
||||
class TestDSB128(unittest.TestCase):
|
||||
"""Tests for DS_STORE_B128 and DS_LOAD_B128 (128-bit / 4 dwords)."""
|
||||
|
||||
|
||||
@@ -265,6 +265,113 @@ class TestSLoadMultiDword(unittest.TestCase):
|
||||
self.assertEqual(st.sgpr[5], st.sgpr[9])
|
||||
|
||||
|
||||
class TestSLoadLarge(unittest.TestCase):
  """Tests for large s_load operations (s_load_b256, s_load_b512)."""

  def test_s_load_b256_basic(self):
    """s_load_b256 loads 8 consecutive dwords."""
    prog = [
      s_load_b64(s[2:3], s[80:81], 0, soffset=NULL),
      s_waitcnt(lgkmcnt=0),
      v_mov_b32_e32(v[0], 0),
      # Store 8 test values
      *[instr for i in range(8) for instr in (
        s_mov_b32(s[20], (i + 1) * 0x11111111),
        v_mov_b32_e32(v[2], s[20]),
        global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET + i * 4),
      )],
      s_waitcnt(vmcnt=0),
      *CACHE_INV,
      # Load all 8 dwords with s_load_b256
      s_load_b256(s[4:11], s[2:3], NULL, offset=TEST_OFFSET),
      s_waitcnt(lgkmcnt=0),
      s_mov_b32(s[2], 0), s_mov_b32(s[3], 0),
    ]
    state = run_program(prog, n_lanes=1)
    # s4..s11 should hold 0x11111111 .. 0x88888888 in order.
    for idx in range(8):
      self.assertEqual(state.sgpr[4 + idx], (idx + 1) * 0x11111111)

  def test_s_load_b512_basic(self):
    """s_load_b512 loads 16 consecutive dwords."""
    # Store 16 test values (use a pattern: 0x11111111, 0x22222222, ...).
    stores = []
    for i in range(16):
      stores += [
        s_mov_b32(s[20], (i + 1) * 0x11111111),
        v_mov_b32_e32(v[2], s[20]),
        global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET + i * 4),
      ]
    prog = [
      s_load_b64(s[2:3], s[80:81], 0, soffset=NULL),
      s_waitcnt(lgkmcnt=0),
      v_mov_b32_e32(v[0], 0),
      *stores,
      s_waitcnt(vmcnt=0),
      *CACHE_INV,
      # Load all 16 dwords with s_load_b512
      s_load_b512(s[64:79], s[2:3], NULL, offset=TEST_OFFSET),
      s_waitcnt(lgkmcnt=0),
      # Copy results to lower regs for verification (since st.sgpr only has 16 regs in test)
      s_mov_b32(s[4], s[64]),
      s_mov_b32(s[5], s[65]),
      s_mov_b32(s[6], s[78]),
      s_mov_b32(s[7], s[79]),
      s_mov_b32(s[2], 0), s_mov_b32(s[3], 0),
    ]
    state = run_program(prog, n_lanes=1)
    self.assertEqual(state.sgpr[4], 0x11111111, "first dword")
    self.assertEqual(state.sgpr[5], 0x22222222, "second dword")
    self.assertEqual(state.sgpr[6], 0xFFFFFFFF & (15 * 0x11111111), "15th dword")
    self.assertEqual(state.sgpr[7], 0xFFFFFFFF & (16 * 0x11111111), "16th dword")

  def test_s_load_b256_with_register_offset(self):
    """s_load_b256 with register offset should add reg offset to address."""
    prog = [
      s_load_b64(s[2:3], s[80:81], 0, soffset=NULL),
      s_waitcnt(lgkmcnt=0),
      v_mov_b32_e32(v[0], 0),
      # Store pattern at TEST_OFFSET+8: skip first 2 dwords
      *[instr for i in range(8) for instr in (
        s_mov_b32(s[20], (i + 1) * 0x11111111),
        v_mov_b32_e32(v[2], s[20]),
        global_store_b32(addr=v[0], data=v[2], saddr=s[2:3], offset=TEST_OFFSET + 8 + i * 4),
      )],
      s_waitcnt(vmcnt=0),
      *CACHE_INV,
      # Load with register offset 8
      s_mov_b32(s[20], 8),
      s_load_b256(s[4:11], s[2:3], s[20], offset=TEST_OFFSET),
      s_waitcnt(lgkmcnt=0),
      s_mov_b32(s[2], 0), s_mov_b32(s[3], 0),
    ]
    state = run_program(prog, n_lanes=1)
    self.assertEqual(state.sgpr[4], 0x11111111, "first dword at offset+8")
    self.assertEqual(state.sgpr[5], 0x22222222, "second dword at offset+8")
    self.assertEqual(state.sgpr[11], 0x88888888, "last dword at offset+8")
|
||||
|
||||
|
||||
class TestSLoadOffset(unittest.TestCase):
|
||||
"""Tests for s_load with different immediate offsets.
|
||||
|
||||
|
||||
@@ -719,5 +719,172 @@ class TestNullRegister(unittest.TestCase):
|
||||
self.assertEqual(st.scc, 0)
|
||||
|
||||
|
||||
class Test64BitSOP1InlineConstants(unittest.TestCase):
  """Tests for 64-bit SOP1 instructions with inline constants.

  Regression tests for bug where rsrc_dyn didn't properly handle 64-bit
  inline constants, incorrectly duplicating lo bits to hi instead of
  zero/sign-extending.
  """

  def _lo_hi(self, instructions):
    # Run with a single lane and return (v0, v1), i.e. the (lo, hi) dwords
    # the test program copied out of the scalar result.
    st = run_program(instructions, n_lanes=1)
    return st.vgpr[0][0], st.vgpr[0][1]

  def test_s_mov_b64_inline_0(self):
    """S_MOV_B64 with inline constant 0."""
    lo, hi = self._lo_hi([
      s_mov_b64(s[0:1], 0),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
    ])
    self.assertEqual(lo, 0)
    self.assertEqual(hi, 0)

  def test_s_mov_b64_inline_16(self):
    """S_MOV_B64 with inline constant 16 should set lo=16, hi=0."""
    lo, hi = self._lo_hi([
      s_mov_b64(s[0:1], 16),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
    ])
    self.assertEqual(lo, 16)
    self.assertEqual(hi, 0)

  def test_s_mov_b64_inline_64(self):
    """S_MOV_B64 with inline constant 64 (max positive)."""
    lo, hi = self._lo_hi([
      s_mov_b64(s[0:1], 64),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
    ])
    self.assertEqual(lo, 64)
    self.assertEqual(hi, 0)

  def test_s_mov_b64_inline_neg1(self):
    """S_MOV_B64 with inline constant -1 should sign-extend."""
    lo, hi = self._lo_hi([
      s_mov_b64(s[0:1], -1),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
    ])
    self.assertEqual(lo, 0xFFFFFFFF)
    self.assertEqual(hi, 0xFFFFFFFF)

  def test_s_mov_b64_inline_neg16(self):
    """S_MOV_B64 with inline constant -16 should sign-extend."""
    lo, hi = self._lo_hi([
      s_mov_b64(s[0:1], -16),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
    ])
    self.assertEqual(lo, 0xFFFFFFF0)
    self.assertEqual(hi, 0xFFFFFFFF)

  def test_s_mov_b64_float_const_1_0(self):
    """S_MOV_B64 with float inline constant 1.0 - casts F32 to F64."""
    lo, hi = self._lo_hi([
      s_mov_b64(s[0:1], 1.0),  # inline constant 242 (1.0f)
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
    ])
    # Hardware casts F32 to F64: 1.0f64 = 0x3FF0000000000000
    self.assertEqual(lo, 0x00000000)  # lo
    self.assertEqual(hi, 0x3FF00000)  # hi

  def test_s_or_b64_inline_constant(self):
    """S_OR_B64 with 64-bit inline constant."""
    lo, hi = self._lo_hi([
      s_mov_b64(s[0:1], 0),
      s_or_b64(s[2:3], s[0:1], 16),
      v_mov_b32_e32(v[0], s[2]),
      v_mov_b32_e32(v[1], s[3]),
    ])
    self.assertEqual(lo, 16)
    self.assertEqual(hi, 0)

  def test_s_and_b64_inline_constant(self):
    """S_AND_B64 with 64-bit inline constant."""
    lo, hi = self._lo_hi([
      s_mov_b32(s[0], 0xFFFFFFFF),
      s_mov_b32(s[1], 0xFFFFFFFF),
      s_and_b64(s[2:3], s[0:1], 16),
      v_mov_b32_e32(v[0], s[2]),
      v_mov_b32_e32(v[1], s[3]),
    ])
    self.assertEqual(lo, 16)
    self.assertEqual(hi, 0)
|
||||
|
||||
|
||||
class Test64BitSOPLiterals(unittest.TestCase):
  """Tests for 64-bit SOP instructions with 32-bit literals.

  Tests the behavior when a 64-bit SOP instruction uses a 32-bit literal
  (offset 255 in instruction encoding). The literal is zero-extended to 64 bits.
  """

  def _lo_hi(self, instructions):
    # Run with a single lane and return (v0, v1) = (lo, hi) copied by the program.
    st = run_program(instructions, n_lanes=1)
    return st.vgpr[0][0], st.vgpr[0][1]

  def test_s_mov_b64_literal(self):
    """S_MOV_B64 with 32-bit literal value - zero-extended to 64 bits."""
    lo, hi = self._lo_hi([
      s_mov_b64(s[0:1], 0x12345678),  # literal > 64, uses literal encoding
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
    ])
    self.assertEqual(lo, 0x12345678)
    self.assertEqual(hi, 0)

  def test_s_or_b64_literal(self):
    """S_OR_B64 with 32-bit literal value - zero-extended to 64 bits."""
    lo, hi = self._lo_hi([
      s_mov_b64(s[0:1], 0),
      s_or_b64(s[2:3], s[0:1], 0x12345678),  # literal
      v_mov_b32_e32(v[0], s[2]),
      v_mov_b32_e32(v[1], s[3]),
    ])
    self.assertEqual(lo, 0x12345678)
    self.assertEqual(hi, 0)

  def test_s_and_b64_literal(self):
    """S_AND_B64 with 32-bit literal value - zero-extended to 64 bits."""
    lo, hi = self._lo_hi([
      s_mov_b32(s[0], 0xFFFFFFFF),
      s_mov_b32(s[1], 0xFFFFFFFF),
      s_and_b64(s[2:3], s[0:1], 0x12345678),  # literal
      v_mov_b32_e32(v[0], s[2]),
      v_mov_b32_e32(v[1], s[3]),
    ])
    self.assertEqual(lo, 0x12345678)
    self.assertEqual(hi, 0)

  def test_s_mov_b64_literal_negative(self):
    """S_MOV_B64 with 0xFFFFFFFF literal - zero-extended (not sign-extended)."""
    lo, hi = self._lo_hi([
      s_mov_b64(s[0:1], 0xFFFFFFFF),  # -1 as 32-bit, but zero-extended to 64-bit
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
    ])
    self.assertEqual(lo, 0xFFFFFFFF)
    self.assertEqual(hi, 0)  # zero-extended, not sign-extended

  def test_s_mov_b64_literal_high_bit(self):
    """S_MOV_B64 with 0x80000000 literal - zero-extended (not sign-extended)."""
    lo, hi = self._lo_hi([
      s_mov_b64(s[0:1], 0x80000000),  # high bit set, but zero-extended
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
    ])
    self.assertEqual(lo, 0x80000000)
    self.assertEqual(hi, 0)  # zero-extended, not sign-extended
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
  unittest.main()
|
||||
|
||||
@@ -1359,6 +1359,43 @@ class TestF64ToI64Conversion(unittest.TestCase):
|
||||
self.assertEqual(result, 5000000000)
|
||||
|
||||
|
||||
class TestB64VOPLiteral(unittest.TestCase):
  """Tests for B64 VOP operations with literal encoding.

  B64 operations (like V_LSHLREV_B64) should zero-extend the literal to 64 bits,
  NOT put it in the high 32 bits like F64 operations do.
  """

  def test_v_lshlrev_b64_literal_shift_amount(self):
    """V_LSHLREV_B64 with literal shift amount (src0 is 32-bit)."""
    # Shift 1 left by 100 (0x64) - uses literal encoding for src0.
    # Shift amount is 100 & 63 = 36, so 1 << 36 = 0x1000000000.
    program = [
      s_mov_b32(s[0], 1),
      s_mov_b32(s[1], 0),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_lshlrev_b64(v[2:3], 100, v[0:1]),  # 100 > 64, uses literal encoding
    ]
    state = run_program(program, n_lanes=1)
    # lo = 0x00000000, hi = 0x00000010 = 1 << (36-32)
    self.assertEqual(state.vgpr[0][2], 0x00000000)
    self.assertEqual(state.vgpr[0][3], 0x00000010)

  def test_v_lshlrev_b64_literal_value(self):
    """V_LSHLREV_B64 with literal as the 64-bit value being shifted (src1).

    B64 literals are zero-extended (not shifted to high bits like F64).
    0xDEADBEEF << 4 = 0xDEADBEEF0 = lo=0xEADBEEF0, hi=0x0000000D
    """
    # Single-instruction program: shift the literal left by 4.
    state = run_program([v_lshlrev_b64(v[0:1], 4, 0xDEADBEEF)], n_lanes=1)
    self.assertEqual(state.vgpr[0][0], 0xEADBEEF0)  # lo
    self.assertEqual(state.vgpr[0][1], 0x0000000D)  # hi
|
||||
|
||||
|
||||
class TestWMMAMore(unittest.TestCase):
|
||||
"""More WMMA tests."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user