mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
assembly/amd: only reg emu
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -89,49 +89,32 @@ def _get_compiled() -> dict:
|
||||
if _COMPILED is None: _COMPILED = get_compiled_functions()
|
||||
return _COMPILED
|
||||
|
||||
# Flag indices: (is_64, has_d1, is_cmp, is_cmpx, is_div_scale, has_sdst, uses_vcc, uses_exec, used_regs)
|
||||
_F_IS_64, _F_HAS_D1, _F_IS_CMP, _F_IS_CMPX, _F_IS_DIV_SCALE, _F_HAS_SDST, _F_USES_VCC, _F_USES_EXEC, _F_USED_REGS = range(9)
|
||||
def _run_pcode(fn, op_cls, op, s0, s1, s2, d0, scc, vcc, lane, exec_mask, vdst_idx):
|
||||
"""Create Regs, run pseudocode, extract results."""
|
||||
# Determine flags from op_cls and op.name
|
||||
is_div_scale = 'DIV_SCALE' in op.name
|
||||
is_64 = op.name.endswith(('_B64', '_I64', '_U64', '_F64')) or op.name in ('V_MAD_U64_U32', 'V_MAD_I64_I32')
|
||||
is_cmp = op_cls.__name__ == 'VOPCOp' and not op.name.startswith('V_CMPX')
|
||||
is_cmpx = op_cls.__name__ == 'VOPCOp' and op.name.startswith('V_CMPX')
|
||||
has_sdst = op_cls.__name__ == 'VOP3SDOp'
|
||||
|
||||
def _run_pcode(fn, flags, s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, vgprs, src0_idx, vdst_idx):
|
||||
"""Create Regs, run pseudocode, extract results based on flags."""
|
||||
used_regs = flags[_F_USED_REGS]
|
||||
# Create only the Regs that are needed
|
||||
regs = {}
|
||||
if 'S0' in used_regs: regs['S0'] = Reg(s0)
|
||||
if 'S1' in used_regs: regs['S1'] = Reg(s1)
|
||||
if 'S2' in used_regs: regs['S2'] = Reg(s2)
|
||||
if 'D0' in used_regs: regs['D0'] = Reg(s0 if flags[_F_IS_DIV_SCALE] else d0)
|
||||
if 'D1' in used_regs: regs['D1'] = Reg(0)
|
||||
if 'SCC' in used_regs: regs['SCC'] = Reg(scc)
|
||||
if 'VCC' in used_regs: regs['VCC'] = Reg(vcc)
|
||||
if 'EXEC' in used_regs: regs['EXEC'] = Reg(exec_mask)
|
||||
if 'tmp' in used_regs: regs['tmp'] = Reg(0)
|
||||
if 'saveexec' in used_regs: regs['saveexec'] = Reg(exec_mask)
|
||||
if 'laneId' in used_regs: regs['laneId'] = lane
|
||||
if 'SIMM16' in used_regs: regs['SIMM16'] = Reg(literal)
|
||||
if 'SIMM32' in used_regs: regs['SIMM32'] = Reg(literal)
|
||||
if 'SRC0' in used_regs: regs['SRC0'] = Reg(src0_idx)
|
||||
if 'VDST' in used_regs: regs['VDST'] = Reg(vdst_idx)
|
||||
if 'VGPR' in used_regs: regs['VGPR'] = vgprs
|
||||
# Create Regs - D0 gets s0 for DIV_SCALE (passthrough behavior)
|
||||
S0, S1, S2 = Reg(s0), Reg(s1), Reg(s2)
|
||||
D0, D1 = Reg(s0 if is_div_scale else d0), Reg(0)
|
||||
SCC, VCC, EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
|
||||
tmp = Reg(0)
|
||||
|
||||
# Call pseudocode with only the registers it needs
|
||||
ret = fn(**{r: regs[r] for r in used_regs})
|
||||
# Call pseudocode
|
||||
ret = fn(S0, S1, S2, D0, D1, SCC, VCC, EXEC, tmp, lane)
|
||||
|
||||
# Build result dict based on flags
|
||||
D0 = regs.get('D0')
|
||||
SCC = regs.get('SCC')
|
||||
VCC = regs.get('VCC')
|
||||
EXEC = regs.get('EXEC')
|
||||
D1 = regs.get('D1')
|
||||
|
||||
result = {'d0': D0._val if D0 is not None else d0, 'scc': (SCC._val & 1) if SCC is not None else (scc & 1)}
|
||||
if flags[_F_HAS_SDST]: result['vcc_lane'] = (VCC._val >> lane) & 1
|
||||
elif flags[_F_USES_VCC] and VCC is not None and VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1
|
||||
if flags[_F_IS_CMPX]: result['exec_lane'] = (EXEC._val >> lane) & 1
|
||||
elif flags[_F_USES_EXEC] and EXEC is not None and EXEC._val != exec_mask: result['exec'] = EXEC._val
|
||||
if flags[_F_IS_CMP]: result['vcc_lane'] = (D0._val >> lane) & 1
|
||||
if flags[_F_IS_64]: result['d0_64'] = True
|
||||
if flags[_F_HAS_D1]: result['d1'] = D1._val & 1
|
||||
# Build result
|
||||
result = {'d0': D0._val, 'scc': SCC._val & 1}
|
||||
if has_sdst or VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1
|
||||
if is_cmpx: result['exec_lane'] = (EXEC._val >> lane) & 1
|
||||
elif EXEC._val != exec_mask: result['exec'] = EXEC._val
|
||||
if is_cmp: result['vcc_lane'] = (D0._val >> lane) & 1
|
||||
if is_64: result['d0_64'] = True
|
||||
if D1._val: result['d1'] = D1._val & 1
|
||||
# V_WRITELANE_B32 returns (wr_lane, value) directly
|
||||
if ret is not None: result['vgpr_write'] = (ret[0], vdst_idx, ret[1])
|
||||
return result
|
||||
@@ -294,21 +277,20 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
|
||||
else: raise NotImplementedError(f"Unknown scalar type {inst_type}")
|
||||
|
||||
op = op_cls(inst.op)
|
||||
fn_flags = compiled.get(op_cls, {}).get(op)
|
||||
if fn_flags is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
fn, flags = fn_flags
|
||||
fn = compiled.get(op_cls, {}).get(op)
|
||||
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
|
||||
# Read sources - 64-bit ops need 64-bit source reads
|
||||
is_64bit_s0 = op.name.endswith(('_B64', '_I64', '_U64')) or '_U64_' in op.name or '_I64_' in op.name
|
||||
is_64bit_s0s1 = op_cls is SOPCOp and op in (SOPCOp.S_CMP_EQ_U64, SOPCOp.S_CMP_LG_U64)
|
||||
s0 = st.rsrc64(ssrc0, 0) if is_64bit_s0 or is_64bit_s0s1 else (st.rsrc(ssrc0, 0) if inst_type != SOPK else st.rsgpr(inst.sdst))
|
||||
is_64bit_sop2 = is_64bit_s0 and inst_type is SOP2
|
||||
s1 = st.rsrc64(inst.ssrc1, 0) if (is_64bit_sop2 or is_64bit_s0s1) else (st.rsrc(inst.ssrc1, 0) if inst_type in (SOP2, SOPC) else inst.simm16 if inst_type is SOPK else 0)
|
||||
s1 = st.rsrc64(inst.ssrc1, 0) if (is_64bit_sop2 or is_64bit_s0s1) else (st.rsrc(inst.ssrc1, 0) if inst_type in (SOP2, SOPC) else 0)
|
||||
s2 = inst.simm16 if inst_type is SOPK else 0 # SOPK: 16-bit immediate passed as S2
|
||||
d0 = st.rsgpr64(sdst) if (is_64bit_s0 or is_64bit_s0s1) and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
|
||||
literal = inst.simm16 if inst_type is SOPK else st.literal
|
||||
|
||||
# Execute and apply results
|
||||
result = _run_pcode(fn, flags, s0, s1, 0, d0, st.scc, st.vcc, 0, st.exec_mask, literal, None, 0, 0)
|
||||
result = _run_pcode(fn, op_cls, op, s0, s1, s2, d0, st.scc, st.vcc, 0, st.exec_mask, 0)
|
||||
if sdst is not None:
|
||||
if result.get('d0_64'): st.wsgpr64(sdst, result['d0'])
|
||||
else: st.wsgpr(sdst, result['d0'])
|
||||
@@ -360,16 +342,24 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
# Both ops execute simultaneously using pre-instruction values, so read all inputs first
|
||||
if inst_type is VOPD:
|
||||
vdsty = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
|
||||
# Read all source operands BEFORE any writes (dual-issue semantics)
|
||||
sx0, sx1 = st.rsrc(inst.srcx0, lane), V[inst.vsrcx1]
|
||||
sy0, sy1 = st.rsrc(inst.srcy0, lane), V[inst.vsrcy1]
|
||||
dx0, dy0 = V[inst.vdstx], V[vdsty]
|
||||
# FMAAK/FMAMK in VOPD use literal as S2
|
||||
literal = getattr(inst, '_literal', None) or 0
|
||||
res_x = res_y = None
|
||||
if (op_x := _VOPD_TO_VOP.get(inst.opx)):
|
||||
if (fn_flags := compiled.get(type(op_x), {}).get(op_x)):
|
||||
res_x = _run_pcode(fn_flags[0], fn_flags[1], sx0, sx1, 0, dx0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, 0, 0)
|
||||
if (fn := compiled.get(type(op_x), {}).get(op_x)):
|
||||
# opx 1=FMAMK, 2=FMAAK use literal
|
||||
sx2 = literal if inst.opx in (VOPDOp.V_DUAL_FMAMK_F32, VOPDOp.V_DUAL_FMAAK_F32) else 0
|
||||
res_x = _run_pcode(fn, type(op_x), op_x, sx0, sx1, sx2, dx0, st.scc, st.vcc, lane, st.exec_mask, 0)
|
||||
if (op_y := _VOPD_TO_VOP.get(inst.opy)):
|
||||
if (fn_flags := compiled.get(type(op_y), {}).get(op_y)):
|
||||
res_y = _run_pcode(fn_flags[0], fn_flags[1], sy0, sy1, 0, dy0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, 0, 0)
|
||||
if (fn := compiled.get(type(op_y), {}).get(op_y)):
|
||||
# opy 1=FMAMK, 2=FMAAK use literal
|
||||
sy2 = literal if inst.opy in (VOPDOp.V_DUAL_FMAMK_F32, VOPDOp.V_DUAL_FMAAK_F32) else 0
|
||||
res_y = _run_pcode(fn, type(op_y), op_y, sy0, sy1, sy2, dy0, st.scc, st.vcc, lane, st.exec_mask, 0)
|
||||
# Write results after both ops complete
|
||||
if res_x: V[inst.vdstx] = res_x['d0']
|
||||
if res_y: V[vdsty] = res_y['d0']
|
||||
return
|
||||
@@ -377,17 +367,19 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
# VOP3SD: has extra scalar dest for carry output
|
||||
if inst_type is VOP3SD:
|
||||
op = VOP3SDOp(inst.op)
|
||||
fn_flags = compiled.get(VOP3SDOp, {}).get(op)
|
||||
if fn_flags is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
fn, flags = fn_flags
|
||||
fn = compiled.get(VOP3SDOp, {}).get(op)
|
||||
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
|
||||
# For 64-bit src2 ops (V_MAD_U64_U32, V_MAD_I64_I32), read from consecutive registers
|
||||
mad64_ops = (VOP3SDOp.V_MAD_U64_U32, VOP3SDOp.V_MAD_I64_I32)
|
||||
if op in mad64_ops:
|
||||
s2 = (V[inst.src2 - 256] | (V[inst.src2 - 256 + 1] << 32)) if inst.src2 >= 256 else st.rsgpr64(inst.src2)
|
||||
d0 = V[inst.vdst]
|
||||
# For carry-in ops (V_*_CO_CI_*), src2 register contains the carry bitmask (not VCC).
|
||||
# The pseudocode uses VCC but in VOP3SD encoding, the actual carry source is inst.src2.
|
||||
carry_ops = (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32)
|
||||
vcc_for_exec = st.rsgpr64(inst.src2) if op in carry_ops else st.vcc
|
||||
result = _run_pcode(fn, flags, s0, s1, s2, d0, st.scc, vcc_for_exec, lane, st.exec_mask, st.literal, None, 0, inst.vdst)
|
||||
result = _run_pcode(fn, VOP3SDOp, op, s0, s1, s2, d0, st.scc, vcc_for_exec, lane, st.exec_mask, inst.vdst)
|
||||
if result.get('d0_64'):
|
||||
V[inst.vdst] = result['d0'] & 0xffffffff
|
||||
V[inst.vdst + 1] = (result['d0'] >> 32) & 0xffffffff
|
||||
@@ -402,15 +394,36 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
# Get op enum and sources (None means "no source" for that operand)
|
||||
if inst_type is VOP1:
|
||||
if inst.op == VOP1Op.V_NOP: return
|
||||
# V_READFIRSTLANE_B32: read from first active lane's VGPR -> SGPR (not in pseudocode - needs cross-lane access)
|
||||
if inst.op == VOP1Op.V_READFIRSTLANE_B32:
|
||||
first_lane = (st.exec_mask & -st.exec_mask).bit_length() - 1 if st.exec_mask else 0
|
||||
vgpr_idx = inst.src0 - 256 if inst.src0 >= 256 else inst.src0 # VGPR index
|
||||
st.wsgpr(inst.vdst, st.vgpr[first_lane][vgpr_idx])
|
||||
return
|
||||
op_cls, op, src0, src1, src2, vdst = VOP1Op, VOP1Op(inst.op), inst.src0, None, None, inst.vdst
|
||||
elif inst_type is VOP2:
|
||||
op_cls, op, src0, src1, src2, vdst = VOP2Op, VOP2Op(inst.op), inst.src0, inst.vsrc1 + 256, None, inst.vdst
|
||||
op_cls, op = VOP2Op, VOP2Op(inst.op)
|
||||
# FMAAK/FMAMK use inline literal constant as S2
|
||||
literal = getattr(inst, '_literal', None)
|
||||
src0, src1, src2, vdst = inst.src0, inst.vsrc1 + 256, literal, inst.vdst
|
||||
elif inst_type is VOP3:
|
||||
# VOP3 ops 0-255 are VOPC comparisons encoded as VOP3 (use VOPCOp pseudocode)
|
||||
if inst.op < 256:
|
||||
op_cls, op, src0, src1, src2, vdst = VOPCOp, VOPCOp(inst.op), inst.src0, inst.src1, None, inst.vdst
|
||||
else:
|
||||
op_cls, op, src0, src1, src2, vdst = VOP3Op, VOP3Op(inst.op), inst.src0, inst.src1, inst.src2, inst.vdst
|
||||
# V_READFIRSTLANE_B32 in VOP3 encoding - same as VOP1 but with VOP3 format
|
||||
if op == VOP3Op.V_READFIRSTLANE_B32:
|
||||
first_lane = (st.exec_mask & -st.exec_mask).bit_length() - 1 if st.exec_mask else 0
|
||||
vgpr_idx = inst.src0 - 256 if inst.src0 >= 256 else inst.src0
|
||||
st.wsgpr(inst.vdst, st.vgpr[first_lane][vgpr_idx])
|
||||
return
|
||||
# V_READLANE_B32: read from specific lane's VGPR -> SGPR (lane specified in src1)
|
||||
if op == VOP3Op.V_READLANE_B32:
|
||||
read_lane = st.rsrc(inst.src1, lane) & 0x1f # Lane to read from (5 bits)
|
||||
vgpr_idx = inst.src0 - 256 if inst.src0 >= 256 else inst.src0
|
||||
st.wsgpr(inst.vdst, st.vgpr[read_lane][vgpr_idx])
|
||||
return
|
||||
# V_PERM_B32: byte permutation - not in pseudocode PDF, implement directly
|
||||
# D0[byte_i] = selector[byte_i] < 8 ? {src0, src1}[selector[byte_i]] : (selector[byte_i] >= 0xD ? 0xFF : 0x00)
|
||||
if op == VOP3Op.V_PERM_B32:
|
||||
@@ -504,16 +517,15 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
s1 = (s1_hi << 16) | s1_lo
|
||||
s2 = (s2_hi << 16) | s2_lo
|
||||
vdst = inst.vdst
|
||||
fn_flags = compiled.get(VOP3POp, {}).get(op)
|
||||
if fn_flags is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
result = _run_pcode(fn_flags[0], fn_flags[1], s0, s1, s2, 0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, 0, vdst)
|
||||
fn = compiled.get(VOP3POp, {}).get(op)
|
||||
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
result = _run_pcode(fn, VOP3POp, op, s0, s1, s2, 0, st.scc, st.vcc, lane, st.exec_mask, vdst)
|
||||
st.vgpr[lane][vdst] = result['d0'] & 0xffffffff
|
||||
return
|
||||
else: raise NotImplementedError(f"Unknown vector type {inst_type}")
|
||||
|
||||
fn_flags = compiled.get(op_cls, {}).get(op)
|
||||
if fn_flags is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
fn, flags = fn_flags
|
||||
fn = compiled.get(op_cls, {}).get(op)
|
||||
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
|
||||
# Read sources (with VOP3 modifiers if applicable)
|
||||
neg, abs_ = (getattr(inst, 'neg', 0), getattr(inst, 'abs', 0)) if inst_type is VOP3 else (0, 0)
|
||||
@@ -528,25 +540,29 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
return val
|
||||
|
||||
# Determine if sources are 64-bit based on instruction type
|
||||
# For 64-bit shift ops: src0 is 32-bit (shift amount), src1 is 64-bit (value to shift)
|
||||
# For V_LDEXP_F64: src0 is 64-bit float, src1 is 32-bit integer exponent
|
||||
# For most other _B64/_I64/_U64/_F64 ops: all sources are 64-bit
|
||||
is_64bit_op = op.name.endswith(('_B64', '_I64', '_U64', '_F64'))
|
||||
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64,)
|
||||
is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64)
|
||||
is_16bit_src = op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS
|
||||
is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS
|
||||
is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS # VOP2 16-bit ops use f16 inline constants
|
||||
|
||||
if is_shift_64:
|
||||
s0 = mod_src(st.rsrc(src0, lane), 0)
|
||||
s1 = st.rsrc64(src1, lane) if src1 is not None else 0
|
||||
s0 = mod_src(st.rsrc(src0, lane), 0) # shift amount is 32-bit
|
||||
s1 = st.rsrc64(src1, lane) if src1 is not None else 0 # value to shift is 64-bit
|
||||
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
elif is_ldexp_64:
|
||||
s0 = mod_src64(st.rsrc64(src0, lane), 0)
|
||||
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
|
||||
s0 = mod_src64(st.rsrc64(src0, lane), 0) # mantissa is 64-bit float
|
||||
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0 # exponent is 32-bit int
|
||||
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
elif is_64bit_op:
|
||||
s0 = mod_src64(st.rsrc64(src0, lane), 0)
|
||||
s1 = mod_src64(st.rsrc64(src1, lane), 1) if src1 is not None else 0
|
||||
s2 = mod_src64(st.rsrc64(src2, lane), 2) if src2 is not None else 0
|
||||
elif is_16bit_src:
|
||||
# For 16-bit source ops, opsel bits select which half to use
|
||||
s0_raw = mod_src(st.rsrc(src0, lane), 0)
|
||||
s1_raw = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
|
||||
s2_raw = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
@@ -560,26 +576,31 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
else:
|
||||
s0 = mod_src(st.rsrc(src0, lane), 0)
|
||||
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
|
||||
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
# src2 can be a register index OR a raw literal value (for FMAAK/FMAMK)
|
||||
# If src2 > 511, it's a raw literal value, not a register index
|
||||
s2 = src2 if src2 is not None and src2 > 511 else (mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0)
|
||||
d0 = V[vdst] if not is_64bit_op else (V[vdst] | (V[vdst + 1] << 32))
|
||||
|
||||
# V_CNDMASK_B32: VOP3 encoding uses src2 as mask (not VCC)
|
||||
# V_CNDMASK_B32: VOP3 encoding uses src2 as mask (not VCC); VOP2 uses VCC implicitly
|
||||
vcc_for_fn = st.rsgpr64(src2) if op in (VOP3Op.V_CNDMASK_B32,) and inst_type is VOP3 and src2 is not None and src2 < 256 else st.vcc
|
||||
src0_idx = (src0 - 256) if src0 is not None and src0 >= 256 else (src0 if src0 is not None else 0)
|
||||
|
||||
# Execute pseudocode
|
||||
result = _run_pcode(fn, flags, s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, st.literal, st.vgpr, src0_idx, vdst)
|
||||
result = _run_pcode(fn, op_cls, op, s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, vdst)
|
||||
|
||||
# Apply results
|
||||
if 'vgpr_write' in result:
|
||||
# Lane instruction wrote to VGPR: (lane, vgpr_idx, value)
|
||||
wr_lane, wr_idx, wr_val = result['vgpr_write']
|
||||
st.vgpr[wr_lane][wr_idx] = wr_val
|
||||
if 'vcc_lane' in result:
|
||||
# VOP2 carry instructions write carry to VCC implicitly; VOPC writes to vdst
|
||||
vcc_dst = VCC_LO if op_cls is VOP2Op and op in (VOP2Op.V_ADD_CO_CI_U32, VOP2Op.V_SUB_CO_CI_U32, VOP2Op.V_SUBREV_CO_CI_U32) else vdst
|
||||
st.pend_sgpr_lane(vcc_dst, lane, result['vcc_lane'])
|
||||
if 'exec_lane' in result:
|
||||
# V_CMPX instructions write to EXEC per-lane
|
||||
st.pend_sgpr_lane(EXEC_LO, lane, result['exec_lane'])
|
||||
if 'd0' in result and op_cls not in (VOPCOp,) and 'vgpr_write' not in result:
|
||||
# V_READFIRSTLANE_B32 and V_READLANE_B32 write to SGPR, not VGPR
|
||||
writes_to_sgpr = op in (VOP1Op.V_READFIRSTLANE_B32,) or \
|
||||
(op_cls is VOP3Op and op in (VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32))
|
||||
is_16bit_dst = op in _VOP3_16BIT_DST_OPS or op in _VOP1_16BIT_DST_OPS
|
||||
@@ -589,6 +610,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
V[vdst] = result['d0'] & 0xffffffff
|
||||
V[vdst + 1] = (result['d0'] >> 32) & 0xffffffff
|
||||
elif is_16bit_dst and inst_type is VOP3:
|
||||
# VOP3 16-bit ops: opsel[3] controls hi/lo destination
|
||||
if opsel & 8: V[vdst] = (V[vdst] & 0x0000ffff) | ((result['d0'] & 0xffff) << 16)
|
||||
else: V[vdst] = (V[vdst] & 0xffff0000) | (result['d0'] & 0xffff)
|
||||
else:
|
||||
|
||||
@@ -657,9 +657,14 @@ def compile_pseudocode(pseudocode: str) -> str:
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _assign(lhs: str, rhs: str) -> str:
|
||||
"""Generate assignment. Bare tmp/SCC/etc modify existing Reg._val."""
|
||||
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec'):
|
||||
"""Generate assignment. Outputs modify Reg in-place via ._val."""
|
||||
# Output registers and tmp: modify in-place so caller sees changes
|
||||
if lhs in ('SCC', 'VCC', 'EXEC', 'D0', 'D1', 'tmp'):
|
||||
return f"{lhs}._val = int({rhs})"
|
||||
# saveexec needs to be a new Reg for typed accessor access
|
||||
if lhs == 'saveexec':
|
||||
return f"{lhs} = Reg(int({rhs}))"
|
||||
# Other locals: natural style
|
||||
return f"{lhs} = {rhs}"
|
||||
|
||||
def _expr(e: str) -> str:
|
||||
@@ -982,30 +987,15 @@ from extra.assembly.amd.pcode import *
|
||||
code = code.replace(
|
||||
'D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
|
||||
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
|
||||
# Detect flags for result handling (stored in metadata, not generated code)
|
||||
is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
|
||||
has_d1 = '{ D1' in pc
|
||||
if has_d1: is_64 = True
|
||||
is_cmp = cls_name == 'VOPCOp' and 'D0.u64[laneId]' in pc
|
||||
is_cmpx = cls_name == 'VOPCOp' and 'EXEC.u64[laneId]' in pc # V_CMPX writes to EXEC per-lane
|
||||
is_div_scale = 'DIV_SCALE' in op.name
|
||||
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
|
||||
combined = code + pc
|
||||
uses_vcc = 'VCC' in combined
|
||||
uses_exec = 'EXEC' in combined or 'EXEC_LO' in combined or 'EXEC_HI' in combined
|
||||
|
||||
# Determine which registers are used
|
||||
all_regs = ['S0', 'S1', 'S2', 'D0', 'D1', 'SCC', 'VCC', 'EXEC', 'tmp', 'saveexec', 'laneId', 'SIMM16', 'SIMM32', 'SRC0', 'VDST', 'VGPR']
|
||||
used_regs = [r for r in all_regs if r in combined]
|
||||
if 'EXEC_LO' in combined or 'EXEC_HI' in combined:
|
||||
if 'EXEC' not in used_regs: used_regs.append('EXEC')
|
||||
|
||||
# Generate pure pseudocode function - Regs passed directly as arguments
|
||||
# SIMM16/SIMM32 (inline literal constants) are passed as S2
|
||||
code = code.replace('SIMM16', 'S2').replace('SIMM32', 'S2')
|
||||
# Generate function with standard signature
|
||||
fn_name = f"_{cls_name}_{op.name}"
|
||||
lines.append(f"def {fn_name}({', '.join(used_regs)}):")
|
||||
lines.append(f"def {fn_name}(S0, S1, S2, D0, D1, SCC, VCC, EXEC, tmp, laneId):")
|
||||
for pc_line in pc.split('\n'):
|
||||
lines.append(f" # {pc_line}")
|
||||
# Add EXEC_LO/EXEC_HI if needed
|
||||
combined = code + pc
|
||||
if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
|
||||
if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
|
||||
code_lines = [line for line in code.split('\n') if line.strip()]
|
||||
@@ -1016,9 +1006,7 @@ from extra.assembly.amd.pcode import *
|
||||
lines.append(" pass")
|
||||
lines.append("")
|
||||
|
||||
# Build flags tuple: (is_64, has_d1, is_cmp, is_cmpx, is_div_scale, has_sdst, uses_vcc, uses_exec, used_regs)
|
||||
flags = (is_64, has_d1, is_cmp, is_cmpx, is_div_scale, has_sdst, uses_vcc, uses_exec, tuple(used_regs))
|
||||
fn_entries.append((op, fn_name, flags))
|
||||
fn_entries.append((op, fn_name))
|
||||
compiled_count += 1
|
||||
except Exception as e:
|
||||
print(f" Warning: Failed to compile {op.name}: {e}")
|
||||
@@ -1026,8 +1014,8 @@ from extra.assembly.amd.pcode import *
|
||||
|
||||
if fn_entries:
|
||||
lines.append(f'{cls_name}_FUNCTIONS = {{')
|
||||
for op, fn_name, flags in fn_entries:
|
||||
lines.append(f" {cls_name}.{op.name}: ({fn_name}, {flags}),")
|
||||
for op, fn_name in fn_entries:
|
||||
lines.append(f" {cls_name}.{op.name}: {fn_name},")
|
||||
lines.append('}')
|
||||
lines.append('')
|
||||
|
||||
@@ -1036,10 +1024,9 @@ from extra.assembly.amd.pcode import *
|
||||
if 'VOP3Op' in enum_names:
|
||||
lines.append('''
|
||||
# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode)
|
||||
def _VOP3Op_V_WRITELANE_B32(S0, S1):
|
||||
def _VOP3Op_V_WRITELANE_B32(S0, S1, S2, D0, D1, SCC, VCC, EXEC, tmp, laneId):
|
||||
return (int(S1) & 0x1f, int(S0) & 0xffffffff) # (wr_lane, value)
|
||||
# flags: (is_64, has_d1, is_cmp, is_cmpx, is_div_scale, has_sdst, uses_vcc, uses_exec, used_regs)
|
||||
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = (_VOP3Op_V_WRITELANE_B32, (False, False, False, False, False, False, False, False, ('S0', 'S1')))
|
||||
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
|
||||
''')
|
||||
|
||||
lines.append('COMPILED_FUNCTIONS = {')
|
||||
|
||||
@@ -234,8 +234,8 @@ class TestPseudocodeRegressions(unittest.TestCase):
|
||||
s0 = 0x3f800000 # 1.0
|
||||
s1 = 0x40400000 # 3.0
|
||||
s2 = 0x3f800000 # 1.0 (numerator)
|
||||
fn, flags = VOP3SDOp_FUNCTIONS[VOP3SDOp.V_DIV_SCALE_F32]
|
||||
result = _run_pcode(fn, flags, s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0, None, 0, 0)
|
||||
fn = VOP3SDOp_FUNCTIONS[VOP3SDOp.V_DIV_SCALE_F32]
|
||||
result = _run_pcode(fn, VOP3SDOp, VOP3SDOp.V_DIV_SCALE_F32, s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0)
|
||||
# Must always have vcc_lane in result
|
||||
self.assertIn('vcc_lane', result, "V_DIV_SCALE_F32 must always return vcc_lane")
|
||||
self.assertEqual(result['vcc_lane'], 0, "vcc_lane should be 0 when no scaling needed")
|
||||
@@ -245,20 +245,20 @@ class TestPseudocodeRegressions(unittest.TestCase):
|
||||
Bug: isQuietNAN and isSignalNAN both used math.isnan which can't distinguish them."""
|
||||
quiet_nan = 0x7fc00000 # quiet NaN: exponent=255, bit22=1
|
||||
signal_nan = 0x7f800001 # signaling NaN: exponent=255, bit22=0
|
||||
fn, flags = VOPCOp_FUNCTIONS[VOPCOp.V_CMP_CLASS_F32]
|
||||
fn = VOPCOp_FUNCTIONS[VOPCOp.V_CMP_CLASS_F32]
|
||||
# Test quiet NaN detection (bit 1 in mask)
|
||||
s1_quiet = 0b0000000010 # bit 1 = quiet NaN
|
||||
result = _run_pcode(fn, flags, quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, 0, 0)
|
||||
result = _run_pcode(fn, VOPCOp, VOPCOp.V_CMP_CLASS_F32, quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0)
|
||||
self.assertEqual(result['vcc_lane'], 1, "Should detect quiet NaN with quiet NaN mask")
|
||||
# Test signaling NaN detection (bit 0 in mask)
|
||||
s1_signal = 0b0000000001 # bit 0 = signaling NaN
|
||||
result = _run_pcode(fn, flags, signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, 0, 0)
|
||||
result = _run_pcode(fn, VOPCOp, VOPCOp.V_CMP_CLASS_F32, signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0)
|
||||
self.assertEqual(result['vcc_lane'], 1, "Should detect signaling NaN with signaling NaN mask")
|
||||
# Test that quiet NaN doesn't match signaling NaN mask
|
||||
result = _run_pcode(fn, flags, quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, 0, 0)
|
||||
result = _run_pcode(fn, VOPCOp, VOPCOp.V_CMP_CLASS_F32, quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0)
|
||||
self.assertEqual(result['vcc_lane'], 0, "Quiet NaN should not match signaling NaN mask")
|
||||
# Test that signaling NaN doesn't match quiet NaN mask
|
||||
result = _run_pcode(fn, flags, signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, 0, 0)
|
||||
result = _run_pcode(fn, VOPCOp, VOPCOp.V_CMP_CLASS_F32, signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0)
|
||||
self.assertEqual(result['vcc_lane'], 0, "Signaling NaN should not match quiet NaN mask")
|
||||
|
||||
def test_isnan_with_typed_view(self):
|
||||
|
||||
Reference in New Issue
Block a user