mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
assembly/amd: move Reg out of the pseudocode (#13934)
* assembly/amd: move Reg out of the pseudocode * remove extra * fix pcode tests * simpler pcode * simpler * simpler * cleaner * fix mypy
This commit is contained in:
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,9 @@
|
||||
# RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF
|
||||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
import ctypes, struct
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||
import ctypes
|
||||
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||
from extra.assembly.amd.pcode import Reg
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
|
||||
@@ -178,24 +179,21 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
|
||||
s0 = st.rsrc64(ssrc0, 0) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0))
|
||||
s1 = st.rsrc64(inst.ssrc1, 0) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0)
|
||||
d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
|
||||
exec_mask = st.exec_mask
|
||||
literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else st.literal
|
||||
|
||||
# Execute compiled function - pass PC in bytes for instructions that need it
|
||||
# For wave32, mask VCC and EXEC to 32 bits since only the lower 32 bits are relevant
|
||||
pc_bytes = st.pc * 4
|
||||
vcc32, exec32 = st.vcc & MASK32, exec_mask & MASK32
|
||||
result = fn(s0, s1, 0, d0, st.scc, vcc32, 0, exec32, literal, None, {}, pc=pc_bytes)
|
||||
# Create Reg objects for compiled function - mask VCC/EXEC to 32 bits for wave32
|
||||
result = fn(Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc & MASK32), 0, Reg(st.exec_mask & MASK32), literal, None, PC=Reg(st.pc * 4))
|
||||
|
||||
# Apply results
|
||||
if sdst is not None:
|
||||
(st.wsgpr64 if result.get('d0_64') else st.wsgpr)(sdst, result['d0'])
|
||||
if 'scc' in result: st.scc = result['scc']
|
||||
if 'exec' in result: st.exec_mask = result['exec']
|
||||
if 'new_pc' in result:
|
||||
# Apply results - extract values from returned Reg objects
|
||||
if sdst is not None and 'D0' in result:
|
||||
(st.wsgpr64 if inst.dst_regs() == 2 else st.wsgpr)(sdst, result['D0']._val)
|
||||
if 'SCC' in result: st.scc = result['SCC']._val & 1
|
||||
if 'EXEC' in result: st.exec_mask = result['EXEC']._val
|
||||
if 'PC' in result:
|
||||
# Convert absolute byte address to word delta
|
||||
# new_pc is where we want to go, st.pc is current position, inst._words will be added after
|
||||
new_pc_words = result['new_pc'] // 4
|
||||
pc_val = result['PC']._val
|
||||
new_pc = pc_val if pc_val < 0x8000000000000000 else pc_val - 0x10000000000000000
|
||||
new_pc_words = new_pc // 4
|
||||
return new_pc_words - st.pc - 1 # -1 because emulator adds inst_words (1 for scalar)
|
||||
return 0
|
||||
|
||||
@@ -260,24 +258,25 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
vdsty = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
|
||||
inputs = [(inst.opx, st.rsrc(inst.srcx0, lane), V[inst.vsrcx1], V[inst.vdstx], inst.vdstx),
|
||||
(inst.opy, st.rsrc(inst.srcy0, lane), V[inst.vsrcy1], V[vdsty], vdsty)]
|
||||
results = [(dst, fn(s0, s1, 0, d0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})['d0'])
|
||||
for vopd_op, s0, s1, d0, dst in inputs if (op := _VOPD_TO_VOP.get(vopd_op)) and (fn := compiled.get(type(op), {}).get(op))]
|
||||
for dst, val in results: V[dst] = val
|
||||
def exec_vopd(vopd_op, s0, s1, d0):
|
||||
op = _VOPD_TO_VOP[vopd_op]
|
||||
return compiled[type(op)][op](Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)['D0']._val
|
||||
for vopd_op, s0, s1, d0, dst in inputs: V[dst] = exec_vopd(vopd_op, s0, s1, d0)
|
||||
return
|
||||
|
||||
# VOP3SD: has extra scalar dest for carry output
|
||||
if isinstance(inst, VOP3SD):
|
||||
fn = compiled.get(VOP3SDOp, {}).get(inst.op)
|
||||
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
|
||||
fn = compiled[VOP3SDOp][inst.op]
|
||||
# Read sources based on register counts from inst properties
|
||||
def rsrc_n(src, regs): return st.rsrc64(src, lane) if regs == 2 else st.rsrc(src, lane)
|
||||
s0, s1, s2 = rsrc_n(inst.src0, inst.src_regs(0)), rsrc_n(inst.src1, inst.src_regs(1)), rsrc_n(inst.src2, inst.src_regs(2))
|
||||
# Carry-in ops use src2 as carry bitmask instead of VCC
|
||||
vcc = st.rsgpr64(inst.src2) if 'CO_CI' in inst.op_name else st.vcc
|
||||
result = fn(s0, s1, s2, V[inst.vdst], st.scc, vcc, lane, st.exec_mask, st.literal, None, {})
|
||||
V[inst.vdst] = result['d0'] & MASK32
|
||||
if result.get('d0_64'): V[inst.vdst + 1] = (result['d0'] >> 32) & MASK32
|
||||
if result.get('vcc_lane') is not None: st.pend_sgpr_lane(inst.sdst, lane, result['vcc_lane'])
|
||||
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(V[inst.vdst]), Reg(st.scc), Reg(vcc), lane, Reg(st.exec_mask), st.literal, None)
|
||||
d0_val = result['D0']._val
|
||||
V[inst.vdst] = d0_val & MASK32
|
||||
if inst.dst_regs() == 2: V[inst.vdst + 1] = (d0_val >> 32) & MASK32
|
||||
if 'VCC' in result: st.pend_sgpr_lane(inst.sdst, lane, (result['VCC']._val >> lane) & 1)
|
||||
return
|
||||
|
||||
# Get op enum and sources (None means "no source" for that operand)
|
||||
@@ -317,8 +316,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
if abs_ & (1<<i): srcs[i] = abs(srcs[i])
|
||||
if neg & (1<<i): srcs[i] = -srcs[i]
|
||||
result = srcs[0] * srcs[1] + srcs[2]
|
||||
V = st.vgpr[lane]
|
||||
V[inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
|
||||
st.vgpr[lane][inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
|
||||
return
|
||||
# VOP3P packed ops: opsel selects halves for lo, opsel_hi for hi; neg toggles f16 sign
|
||||
raws = [st.rsrc_f16(inst.src0, lane), st.rsrc_f16(inst.src1, lane), st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0]
|
||||
@@ -327,15 +325,13 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
hi_sels = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
|
||||
srcs = [((_src16(raws[i], hi_sels[i]) ^ (0x8000 if neg_hi & (1<<i) else 0)) << 16) |
|
||||
(_src16(raws[i], opsel & (1<<i)) ^ (0x8000 if neg & (1<<i) else 0)) for i in range(3)]
|
||||
fn = compiled.get(VOP3POp, {}).get(inst.op)
|
||||
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
|
||||
st.vgpr[lane][inst.vdst] = fn(srcs[0], srcs[1], srcs[2], 0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})['d0'] & MASK32
|
||||
result = compiled[VOP3POp][inst.op](Reg(srcs[0]), Reg(srcs[1]), Reg(srcs[2]), Reg(0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)
|
||||
st.vgpr[lane][inst.vdst] = result['D0']._val & MASK32
|
||||
return
|
||||
else: raise NotImplementedError(f"Unknown vector type {type(inst)}")
|
||||
|
||||
op_cls = type(inst.op)
|
||||
fn = compiled.get(op_cls, {}).get(inst.op)
|
||||
if fn is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
|
||||
if (fn := compiled.get(op_cls, {}).get(inst.op)) is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
|
||||
|
||||
# Read sources (with VOP3 modifiers if applicable)
|
||||
neg, abs_ = (getattr(inst, 'neg', 0), getattr(inst, 'abs', 0)) if isinstance(inst, VOP3) else (0, 0)
|
||||
@@ -377,24 +373,27 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
# Execute compiled function - pass src0_idx and vdst_idx for lane instructions
|
||||
# For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR)
|
||||
src0_idx = (src0 - 256) if src0 is not None and src0 >= 256 else (src0 if src0 is not None else 0)
|
||||
result = fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, st.literal, st.vgpr, {}, src0_idx, vdst)
|
||||
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(d0), Reg(st.scc), Reg(vcc_for_fn), lane, Reg(st.exec_mask), st.literal, st.vgpr, src0_idx, vdst)
|
||||
|
||||
# Apply results
|
||||
# Apply results - extract values from returned Reg objects
|
||||
if 'vgpr_write' in result:
|
||||
# Lane instruction wrote to VGPR: (lane, vgpr_idx, value)
|
||||
wr_lane, wr_idx, wr_val = result['vgpr_write']
|
||||
st.vgpr[wr_lane][wr_idx] = wr_val
|
||||
if 'vcc_lane' in result:
|
||||
if 'VCC' in result:
|
||||
# VOP2 carry ops write to VCC implicitly; VOPC/VOP3 write to vdst
|
||||
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, result['vcc_lane'])
|
||||
if 'exec_lane' in result:
|
||||
# V_CMPX instructions write to EXEC per-lane
|
||||
st.pend_sgpr_lane(EXEC_LO, lane, result['exec_lane'])
|
||||
if 'd0' in result and op_cls is not VOPCOp and 'vgpr_write' not in result:
|
||||
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC']._val >> lane) & 1)
|
||||
if 'EXEC' in result:
|
||||
# V_CMPX instructions write to EXEC per-lane (not to vdst)
|
||||
st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC']._val >> lane) & 1)
|
||||
elif op_cls is VOPCOp:
|
||||
# VOPC comparison result stored in D0 bitmask, extract lane bit (non-CMPX only)
|
||||
st.pend_sgpr_lane(vdst, lane, (result['D0']._val >> lane) & 1)
|
||||
if op_cls is not VOPCOp and 'vgpr_write' not in result:
|
||||
writes_to_sgpr = 'READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name
|
||||
d0_val = result['d0']
|
||||
d0_val = result['D0']._val
|
||||
if writes_to_sgpr: st.wsgpr(vdst, d0_val & MASK32)
|
||||
elif result.get('d0_64'): V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
|
||||
elif inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
|
||||
elif inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
|
||||
else: V[vdst] = d0_val & MASK32
|
||||
|
||||
|
||||
@@ -43,7 +43,10 @@ UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
|
||||
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
|
||||
'CVT_OFF_TABLE', 'ThreadMask',
|
||||
'S1[i', 'C.i32', 'S[i]', 'in[',
|
||||
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF
|
||||
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
|
||||
'BARRIER_STATE', 'ReallocVgprs',
|
||||
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
|
||||
'fp6', 'bf6'] # Malformed pseudocode from PDF
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# COMPILER: pseudocode -> Python (minimal transforms)
|
||||
@@ -51,6 +54,7 @@ UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
|
||||
|
||||
def compile_pseudocode(pseudocode: str) -> str:
|
||||
"""Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
|
||||
pseudocode = re.sub(r'\bpass\b', 'pass_', pseudocode) # 'pass' is Python keyword
|
||||
raw_lines = pseudocode.strip().split('\n')
|
||||
joined_lines: list[str] = []
|
||||
for line in raw_lines:
|
||||
@@ -113,7 +117,7 @@ def compile_pseudocode(pseudocode: str) -> str:
|
||||
break
|
||||
else:
|
||||
lhs, rhs = line.split('=', 1)
|
||||
lhs_s, rhs_s = lhs.strip(), rhs.strip()
|
||||
lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip()
|
||||
stmt = _assign(lhs_s, _expr(rhs_s))
|
||||
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
|
||||
stmt += "; break"
|
||||
@@ -533,52 +537,57 @@ def _apply_pseudocode_fixes(op, code: str) -> str:
|
||||
|
||||
def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]:
|
||||
"""Generate a single compiled pseudocode function."""
|
||||
is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
|
||||
has_d1 = '{ D1' in pc
|
||||
if has_d1: is_64 = True
|
||||
is_cmp = (cls_name in ('VOPCOp', 'VOP3Op')) and 'D0.u64[laneId]' in pc
|
||||
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
|
||||
is_div_scale = 'DIV_SCALE' in op.name
|
||||
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
|
||||
has_pc = 'PC' in pc
|
||||
combined = code + pc
|
||||
|
||||
fn_name = f"_{cls_name}_{op.name}"
|
||||
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0):"]
|
||||
for pc_line in pc.split('\n'): lines.append(f" # {pc_line}")
|
||||
# Function accepts Reg objects directly (uppercase names), laneId is passed directly as int
|
||||
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
|
||||
|
||||
regs = [('S0', 'Reg(s0)'), ('S1', 'Reg(s1)'), ('S2', 'Reg(s2)'),
|
||||
('D0', 'Reg(s0)' if is_div_scale else 'Reg(d0)'), ('D1', 'Reg(0)'),
|
||||
('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)'), ('EXEC', 'Reg(exec_mask)'),
|
||||
('tmp', 'Reg(0)'), ('saveexec', 'Reg(exec_mask)'), ('laneId', 'lane'),
|
||||
('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
|
||||
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)'), ('PC', 'Reg(pc)')]
|
||||
used = {name for name, _ in regs if name in combined}
|
||||
if 'EXEC_LO' in combined or 'EXEC_HI' in combined: used.add('EXEC')
|
||||
if 'VCCZ' in combined: used.add('VCC')
|
||||
if 'EXECZ' in combined: used.add('EXEC')
|
||||
for name, init in regs:
|
||||
if name in used: lines.append(f" {name} = {init}")
|
||||
if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
|
||||
if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
|
||||
if 'VCCZ' in combined: lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
|
||||
if 'EXECZ' in combined: lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
|
||||
lines.append(" # --- compiled pseudocode ---")
|
||||
for line in code.split('\n'): lines.append(f" {line}")
|
||||
lines.append(" # --- end pseudocode ---")
|
||||
d0_val, scc_val = ("D0._val" if 'D0' in used else "d0"), ("SCC._val & 1" if 'SCC' in used else "scc & 1")
|
||||
lines.append(f" result = {{'d0': {d0_val}, 'scc': {scc_val}}}")
|
||||
if has_sdst: lines.append(" result['vcc_lane'] = (VCC._val >> lane) & 1")
|
||||
elif 'VCC' in used: lines.append(" if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1")
|
||||
if is_cmpx: lines.append(" result['exec_lane'] = (EXEC._val >> lane) & 1")
|
||||
elif 'EXEC' in used: lines.append(" if EXEC._val != exec_mask: result['exec'] = EXEC._val")
|
||||
if is_cmp: lines.append(" result['vcc_lane'] = (D0._val >> lane) & 1")
|
||||
if is_64: lines.append(" result['d0_64'] = True")
|
||||
if has_d1: lines.append(" result['d1'] = D1._val & 1")
|
||||
if has_pc:
|
||||
lines.append(" _pc = PC._val if PC._val < 0x8000000000000000 else PC._val - 0x10000000000000000")
|
||||
lines.append(" result['new_pc'] = _pc")
|
||||
lines.append(" return result\n")
|
||||
# Registers that need special handling (not passed directly)
|
||||
# Only init if used but not first assigned as `name = Reg(...)` in the compiled code
|
||||
def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
|
||||
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
|
||||
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
|
||||
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
|
||||
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
|
||||
used = {name for name, _ in special_regs if name in combined}
|
||||
|
||||
# Detect which registers are modified (not just read) - look for assignments
|
||||
modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined))
|
||||
modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined))
|
||||
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
|
||||
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
|
||||
modifies_pc = bool(re.search(r'\bPC\s*=', combined))
|
||||
|
||||
# Build init code for special registers
|
||||
init_lines = []
|
||||
if is_div_scale: init_lines.append(" D0 = Reg(S0._val)")
|
||||
for name, init in special_regs:
|
||||
if name in used: init_lines.append(f" {name} = {init}")
|
||||
if 'EXEC_LO' in code: init_lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
|
||||
if 'EXEC_HI' in code: init_lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
|
||||
if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
|
||||
if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
|
||||
code_lines = [line for line in code.split('\n') if line.strip()]
|
||||
if init_lines:
|
||||
lines.extend(init_lines)
|
||||
if code_lines: lines.append(" # --- compiled pseudocode ---")
|
||||
for line in code_lines:
|
||||
lines.append(f" {line}")
|
||||
|
||||
# Build result dict - only include registers that are modified
|
||||
result_items = []
|
||||
if modifies_d0: result_items.append("'D0': D0")
|
||||
if modifies_scc: result_items.append("'SCC': SCC")
|
||||
if modifies_vcc: result_items.append("'VCC': VCC")
|
||||
if modifies_exec: result_items.append("'EXEC': EXEC")
|
||||
if has_d1: result_items.append("'D1': D1")
|
||||
if modifies_pc: result_items.append("'PC': PC")
|
||||
lines.append(f" return {{{', '.join(result_items)}}}\n")
|
||||
return fn_name, '\n'.join(lines)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
@@ -229,17 +229,18 @@ class TestPseudocodeRegressions(unittest.TestCase):
|
||||
"""Regression tests for pseudocode instruction emulation bugs."""
|
||||
|
||||
def test_v_div_scale_f32_vcc_always_returned(self):
|
||||
"""V_DIV_SCALE_F32 must always return vcc_lane, even when VCC=0 (no scaling needed).
|
||||
Bug: when VCC._val == vcc (both 0), vcc_lane wasn't returned, so VCC bits weren't written.
|
||||
"""V_DIV_SCALE_F32 must always return VCC, even when VCC=0 (no scaling needed).
|
||||
Bug: when VCC._val == vcc (both 0), VCC wasn't returned, so VCC bits weren't written.
|
||||
This caused division to produce wrong results for multiple lanes."""
|
||||
# Normal case: 1.0 / 3.0, no scaling needed, VCC should be 0
|
||||
s0 = 0x3f800000 # 1.0
|
||||
s1 = 0x40400000 # 3.0
|
||||
s2 = 0x3f800000 # 1.0 (numerator)
|
||||
result = _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0, None, {})
|
||||
# Must always have vcc_lane in result
|
||||
self.assertIn('vcc_lane', result, "V_DIV_SCALE_F32 must always return vcc_lane")
|
||||
self.assertEqual(result['vcc_lane'], 0, "vcc_lane should be 0 when no scaling needed")
|
||||
S0 = Reg(0x3f800000) # 1.0
|
||||
S1 = Reg(0x40400000) # 3.0
|
||||
S2 = Reg(0x3f800000) # 1.0 (numerator)
|
||||
D0, SCC, VCC, EXEC = Reg(0), Reg(0), Reg(0), Reg(0xffffffff)
|
||||
result = _VOP3SDOp_V_DIV_SCALE_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
|
||||
# Must always have VCC in result
|
||||
self.assertIn('VCC', result, "V_DIV_SCALE_F32 must always return VCC")
|
||||
self.assertEqual(result['VCC']._val & 1, 0, "VCC lane 0 should be 0 when no scaling needed")
|
||||
|
||||
def test_v_cmp_class_f32_detects_quiet_nan(self):
|
||||
"""V_CMP_CLASS_F32 must correctly identify quiet NaN vs signaling NaN.
|
||||
@@ -248,18 +249,22 @@ class TestPseudocodeRegressions(unittest.TestCase):
|
||||
signal_nan = 0x7f800001 # signaling NaN: exponent=255, bit22=0
|
||||
# Test quiet NaN detection (bit 1 in mask)
|
||||
s1_quiet = 0b0000000010 # bit 1 = quiet NaN
|
||||
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
|
||||
self.assertEqual(result['vcc_lane'], 1, "Should detect quiet NaN with quiet NaN mask")
|
||||
S0, S1, S2, D0, SCC, VCC, EXEC = Reg(quiet_nan), Reg(s1_quiet), Reg(0), Reg(0), Reg(0), Reg(0), Reg(0xffffffff)
|
||||
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
|
||||
self.assertEqual(result['D0']._val & 1, 1, "Should detect quiet NaN with quiet NaN mask")
|
||||
# Test signaling NaN detection (bit 0 in mask)
|
||||
s1_signal = 0b0000000001 # bit 0 = signaling NaN
|
||||
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
|
||||
self.assertEqual(result['vcc_lane'], 1, "Should detect signaling NaN with signaling NaN mask")
|
||||
S0, S1 = Reg(signal_nan), Reg(s1_signal)
|
||||
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
|
||||
self.assertEqual(result['D0']._val & 1, 1, "Should detect signaling NaN with signaling NaN mask")
|
||||
# Test that quiet NaN doesn't match signaling NaN mask
|
||||
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
|
||||
self.assertEqual(result['vcc_lane'], 0, "Quiet NaN should not match signaling NaN mask")
|
||||
S0, S1 = Reg(quiet_nan), Reg(s1_signal)
|
||||
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
|
||||
self.assertEqual(result['D0']._val & 1, 0, "Quiet NaN should not match signaling NaN mask")
|
||||
# Test that signaling NaN doesn't match quiet NaN mask
|
||||
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
|
||||
self.assertEqual(result['vcc_lane'], 0, "Signaling NaN should not match quiet NaN mask")
|
||||
S0, S1 = Reg(signal_nan), Reg(s1_quiet)
|
||||
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
|
||||
self.assertEqual(result['D0']._val & 1, 0, "Signaling NaN should not match quiet NaN mask")
|
||||
|
||||
def test_isnan_with_typed_view(self):
|
||||
"""_isnan must work with TypedView objects, not just Python floats.
|
||||
|
||||
Reference in New Issue
Block a user