mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
assembly/amd: clean up pcode, jit pcode instead of static (#14001)
* assembly/amd: clean up pcode * regen * lil * jit the pcode * sendmsg * cleanups * inst prefetch lol
This commit is contained in:
File diff suppressed because it is too large
Load Diff
1421
extra/assembly/amd/autogen/cdna/str_pcode.py
Normal file
1421
extra/assembly/amd/autogen/cdna/str_pcode.py
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
1354
extra/assembly/amd/autogen/rdna3/str_pcode.py
Normal file
1354
extra/assembly/amd/autogen/rdna3/str_pcode.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
1229
extra/assembly/amd/autogen/rdna4/str_pcode.py
Normal file
1229
extra/assembly/amd/autogen/rdna4/str_pcode.py
Normal file
File diff suppressed because one or more lines are too long
@@ -5,7 +5,8 @@ import ctypes, functools
|
|||||||
from tinygrad.runtime.autogen import hsa
|
from tinygrad.runtime.autogen import hsa
|
||||||
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||||
from extra.assembly.amd.asm import detect_format
|
from extra.assembly.amd.asm import detect_format
|
||||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import COMPILED_FUNCTIONS
|
from extra.assembly.amd.pcode import compile_pseudocode
|
||||||
|
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
|
||||||
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
|
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
|
||||||
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
|
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
|
||||||
|
|
||||||
@@ -236,9 +237,8 @@ def exec_vopd(st: WaveState, inst, V: list, lane: int) -> None:
|
|||||||
"""VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)."""
|
"""VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)."""
|
||||||
literal, vdstx, vdsty = inst._literal, inst.vdstx, (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
|
literal, vdstx, vdsty = inst._literal, inst.vdstx, (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
|
||||||
sx0, sx1, dx, sy0, sy1, dy = st.rsrc(inst.srcx0, lane, literal), V[inst.vsrcx1], V[vdstx], st.rsrc(inst.srcy0, lane, literal), V[inst.vsrcy1], V[vdsty]
|
sx0, sx1, dx, sy0, sy1, dy = st.rsrc(inst.srcx0, lane, literal), V[inst.vsrcx1], V[vdstx], st.rsrc(inst.srcy0, lane, literal), V[inst.vsrcy1], V[vdsty]
|
||||||
opx, opy = _VOPD_TO_VOP[inst.opx], _VOPD_TO_VOP[inst.opy]
|
V[vdstx] = inst._fnx(sx0, sx1, 0, dx, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
|
||||||
V[vdstx] = COMPILED_FUNCTIONS[type(opx)][opx](sx0, sx1, 0, dx, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
|
V[vdsty] = inst._fny(sy0, sy1, 0, dy, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
|
||||||
V[vdsty] = COMPILED_FUNCTIONS[type(opy)][opy](sy0, sy1, 0, dy, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
|
|
||||||
|
|
||||||
def exec_flat(st: WaveState, inst, V: list, lane: int) -> None:
|
def exec_flat(st: WaveState, inst, V: list, lane: int) -> None:
|
||||||
"""FLAT/GLOBAL/SCRATCH memory ops."""
|
"""FLAT/GLOBAL/SCRATCH memory ops."""
|
||||||
@@ -359,15 +359,14 @@ def decode_program(data: bytes) -> dict[int, Inst]:
|
|||||||
result: dict[int, Inst] = {}
|
result: dict[int, Inst] = {}
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(data):
|
while i < len(data):
|
||||||
try: inst_class = detect_format(data[i:])
|
inst = detect_format(data[i:]).from_bytes(data[i:])
|
||||||
except ValueError: break # stop at invalid instruction (padding/metadata after code)
|
|
||||||
inst = inst_class.from_bytes(data[i:i+inst_class._size()+8]) # +8 for potential 64-bit literal
|
|
||||||
inst._words = inst.size() // 4
|
inst._words = inst.size() // 4
|
||||||
|
|
||||||
# Determine dispatch function and pcode function
|
# Determine dispatch function and pcode function
|
||||||
fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op)
|
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_CODE_END: break
|
||||||
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_ENDPGM: inst._dispatch = dispatch_endpgm
|
elif isinstance(inst, SOPP) and inst.op == SOPPOp.S_ENDPGM: inst._dispatch = dispatch_endpgm
|
||||||
elif isinstance(inst, SOPP) and inst.op == SOPPOp.S_BARRIER: inst._dispatch = dispatch_barrier
|
elif isinstance(inst, SOPP) and inst.op == SOPPOp.S_BARRIER: inst._dispatch = dispatch_barrier
|
||||||
|
elif isinstance(inst, SOPP) and inst.op in (SOPPOp.S_CLAUSE, SOPPOp.S_WAITCNT, SOPPOp.S_WAITCNT_DEPCTR, SOPPOp.S_SENDMSG, SOPPOp.S_SET_INST_PREFETCH_DISTANCE): inst._dispatch = dispatch_nop
|
||||||
elif isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)): inst._dispatch = exec_scalar
|
elif isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)): inst._dispatch = exec_scalar
|
||||||
elif isinstance(inst, VOP1) and inst.op == VOP1Op.V_NOP: inst._dispatch = dispatch_nop
|
elif isinstance(inst, VOP1) and inst.op == VOP1Op.V_NOP: inst._dispatch = dispatch_nop
|
||||||
elif isinstance(inst, VOP3P) and 'WMMA' in inst.op_name: inst._dispatch = dispatch_wmma
|
elif isinstance(inst, VOP3P) and 'WMMA' in inst.op_name: inst._dispatch = dispatch_wmma
|
||||||
@@ -378,11 +377,14 @@ def decode_program(data: bytes) -> dict[int, Inst]:
|
|||||||
elif isinstance(inst, DS): inst._dispatch = dispatch_lane(exec_ds)
|
elif isinstance(inst, DS): inst._dispatch = dispatch_lane(exec_ds)
|
||||||
else: inst._dispatch = dispatch_lane(exec_vop)
|
else: inst._dispatch = dispatch_lane(exec_vop)
|
||||||
|
|
||||||
# Validate pcode exists for instructions that need it (scalar/wave-level ops and VOPD don't need pcode)
|
# Compile pcode for instructions that use it (not VOPD which has _fnx/_fny, not special dispatches)
|
||||||
needs_pcode = inst._dispatch not in (dispatch_endpgm, dispatch_barrier, exec_scalar, dispatch_nop, dispatch_wmma,
|
# VOPD needs separate functions for X and Y ops
|
||||||
dispatch_writelane, dispatch_readlane, dispatch_lane(exec_vopd))
|
if isinstance(inst, VOPD):
|
||||||
if fn is None and inst.op_name and needs_pcode: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
|
def _compile_vopd_op(op): return compile_pseudocode(type(op).__name__, op.name, PSEUDOCODE_STRINGS[type(op)][op])
|
||||||
inst._fn = fn if fn else lambda *args, **kwargs: {}
|
inst._fnx, inst._fny = _compile_vopd_op(_VOPD_TO_VOP[inst.opx]), _compile_vopd_op(_VOPD_TO_VOP[inst.opy])
|
||||||
|
elif inst._dispatch not in (dispatch_endpgm, dispatch_barrier, dispatch_nop, dispatch_wmma, dispatch_writelane):
|
||||||
|
assert type(inst.op) != int, f"inst op of {inst} is int"
|
||||||
|
inst._fn = compile_pseudocode(type(inst.op).__name__, inst.op.name, PSEUDOCODE_STRINGS[type(inst.op)][inst.op])
|
||||||
result[i // 4] = inst
|
result[i // 4] = inst
|
||||||
i += inst._words * 4
|
i += inst._words * 4
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
# DSL for RDNA3 pseudocode - makes pseudocode expressions work directly as Python
|
# DSL for RDNA3 pseudocode - makes pseudocode expressions work directly as Python
|
||||||
import struct, math
|
import struct, math, re, functools
|
||||||
from extra.assembly.amd.dsl import MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
from extra.assembly.amd.dsl import MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════════════
|
||||||
# HELPER FUNCTIONS
|
# INTERNAL HELPERS
|
||||||
# ═══════════════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
def _div(a, b):
|
def _div(a, b):
|
||||||
@@ -11,143 +11,35 @@ def _div(a, b):
|
|||||||
except ZeroDivisionError:
|
except ZeroDivisionError:
|
||||||
if a == 0.0 or math.isnan(a): return float("nan")
|
if a == 0.0 or math.isnan(a): return float("nan")
|
||||||
return math.copysign(float("inf"), a * b) if b == 0.0 else float("inf") if a > 0 else float("-inf")
|
return math.copysign(float("inf"), a * b) if b == 0.0 else float("inf") if a > 0 else float("-inf")
|
||||||
def _to_f16_bits(v): return v if isinstance(v, int) else _i16(v)
|
|
||||||
def _isnan(x):
|
|
||||||
try: return math.isnan(float(x))
|
|
||||||
except (TypeError, ValueError): return False
|
|
||||||
def _check_nan_type(x, quiet_bit_expected, default):
|
def _check_nan_type(x, quiet_bit_expected, default):
|
||||||
"""Check NaN type by examining quiet bit. Returns default if can't determine."""
|
|
||||||
try:
|
try:
|
||||||
if not math.isnan(float(x)): return False
|
if not math.isnan(float(x)): return False
|
||||||
if hasattr(x, '_reg') and hasattr(x, '_bits'):
|
if hasattr(x, '_reg') and hasattr(x, '_bits'):
|
||||||
bits = x._reg._val & ((1 << x._bits) - 1)
|
bits = x._reg._val & ((1 << x._bits) - 1)
|
||||||
# NaN format: exponent all 1s, quiet bit, mantissa != 0
|
|
||||||
# f16: exp[14:10]=31, quiet=bit9, mant[8:0] | f32: exp[30:23]=255, quiet=bit22, mant[22:0] | f64: exp[62:52]=2047, quiet=bit51, mant[51:0]
|
|
||||||
exp_bits, quiet_pos, mant_mask = {16: (0x1f, 9, 0x3ff), 32: (0xff, 22, 0x7fffff), 64: (0x7ff, 51, 0xfffffffffffff)}.get(x._bits, (0,0,0))
|
exp_bits, quiet_pos, mant_mask = {16: (0x1f, 9, 0x3ff), 32: (0xff, 22, 0x7fffff), 64: (0x7ff, 51, 0xfffffffffffff)}.get(x._bits, (0,0,0))
|
||||||
exp_shift = {16: 10, 32: 23, 64: 52}.get(x._bits, 0)
|
exp_shift = {16: 10, 32: 23, 64: 52}.get(x._bits, 0)
|
||||||
if exp_bits and ((bits >> exp_shift) & exp_bits) == exp_bits and (bits & mant_mask) != 0:
|
if exp_bits and ((bits >> exp_shift) & exp_bits) == exp_bits and (bits & mant_mask) != 0:
|
||||||
return ((bits >> quiet_pos) & 1) == quiet_bit_expected
|
return ((bits >> quiet_pos) & 1) == quiet_bit_expected
|
||||||
return default
|
return default
|
||||||
except (TypeError, ValueError): return False
|
except (TypeError, ValueError): return False
|
||||||
def _isquietnan(x): return _check_nan_type(x, 1, True) # quiet NaN has quiet bit = 1
|
|
||||||
def _issignalnan(x): return _check_nan_type(x, 0, False) # signaling NaN has quiet bit = 0
|
|
||||||
def _gt_neg_zero(a, b): return (a > b) or (a == 0 and b == 0 and not math.copysign(1, a) < 0 and math.copysign(1, b) < 0)
|
def _gt_neg_zero(a, b): return (a > b) or (a == 0 and b == 0 and not math.copysign(1, a) < 0 and math.copysign(1, b) < 0)
|
||||||
def _lt_neg_zero(a, b): return (a < b) or (a == 0 and b == 0 and math.copysign(1, a) < 0 and not math.copysign(1, b) < 0)
|
def _lt_neg_zero(a, b): return (a < b) or (a == 0 and b == 0 and math.copysign(1, a) < 0 and not math.copysign(1, b) < 0)
|
||||||
def _fma(a, b, c):
|
|
||||||
try: return math.fma(a, b, c)
|
|
||||||
except ValueError: return float('nan') # inf * 0 + c is NaN per IEEE 754
|
|
||||||
def _signext(v): return v
|
|
||||||
def _fpop(fn):
|
def _fpop(fn):
|
||||||
def wrapper(x):
|
def wrapper(x):
|
||||||
x = float(x)
|
x = float(x)
|
||||||
if math.isnan(x) or math.isinf(x): return x
|
if math.isnan(x) or math.isinf(x): return x
|
||||||
result = float(fn(x))
|
result = float(fn(x))
|
||||||
# Preserve sign of zero (IEEE 754: ceil(-0.0) = -0.0, ceil(-0.1) = -0.0)
|
return math.copysign(0.0, x) if result == 0.0 else result
|
||||||
if result == 0.0: return math.copysign(0.0, x)
|
|
||||||
return result
|
|
||||||
return wrapper
|
return wrapper
|
||||||
trunc, floor, ceil = _fpop(math.trunc), _fpop(math.floor), _fpop(math.ceil)
|
|
||||||
class _SafeFloat(float):
|
|
||||||
"""Float subclass that uses _div for division to handle 0/inf correctly."""
|
|
||||||
def __truediv__(self, o): return _div(float(self), float(o))
|
|
||||||
def __rtruediv__(self, o): return _div(float(o), float(self))
|
|
||||||
def sqrt(x): return _SafeFloat(math.sqrt(x)) if x >= 0 else _SafeFloat(float("nan"))
|
|
||||||
def log2(x): return math.log2(x) if x > 0 else (float("-inf") if x == 0 else float("nan"))
|
|
||||||
i32_to_f32 = u32_to_f32 = i32_to_f64 = u32_to_f64 = f32_to_f64 = f64_to_f32 = float
|
|
||||||
def _f_to_int(f, lo, hi): f = float(f); return 0 if math.isnan(f) else (hi if f >= hi else lo if f <= lo else int(f))
|
def _f_to_int(f, lo, hi): f = float(f); return 0 if math.isnan(f) else (hi if f >= hi else lo if f <= lo else int(f))
|
||||||
def f32_to_i32(f): return _f_to_int(f, -2147483648, 2147483647)
|
|
||||||
def f32_to_u32(f): return _f_to_int(f, 0, 4294967295)
|
|
||||||
f64_to_i32, f64_to_u32 = f32_to_i32, f32_to_u32
|
|
||||||
def f32_to_f16(f):
|
|
||||||
f = float(f)
|
|
||||||
if math.isnan(f): return 0x7e00 # f16 NaN
|
|
||||||
if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00 # f16 ±infinity
|
|
||||||
try: return struct.unpack("<H", struct.pack("<e", f))[0]
|
|
||||||
except OverflowError: return 0x7c00 if f > 0 else 0xfc00 # overflow -> ±infinity
|
|
||||||
def _f16_to_f32_bits(bits): return struct.unpack("<e", struct.pack("<H", int(bits) & 0xffff))[0]
|
def _f16_to_f32_bits(bits): return struct.unpack("<e", struct.pack("<H", int(bits) & 0xffff))[0]
|
||||||
def f16_to_f32(v): return v if isinstance(v, float) else _f16_to_f32_bits(v)
|
|
||||||
def i16_to_f16(v): return f32_to_f16(float(_sext(int(v) & 0xffff, 16)))
|
|
||||||
def u16_to_f16(v): return f32_to_f16(float(int(v) & 0xffff))
|
|
||||||
def f16_to_i16(bits): f = _f16_to_f32_bits(bits); return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
|
|
||||||
def f16_to_u16(bits): f = _f16_to_f32_bits(bits); return max(0, min(65535, int(f))) if not math.isnan(f) else 0
|
|
||||||
def u8_to_u32(v): return int(v) & 0xff
|
|
||||||
def u4_to_u32(v): return int(v) & 0xf
|
|
||||||
def _sign(f): return 1 if math.copysign(1.0, f) < 0 else 0
|
|
||||||
def _mantissa_f32(f): return struct.unpack("<I", struct.pack("<f", f))[0] & 0x7fffff if not (math.isinf(f) or math.isnan(f)) else 0
|
|
||||||
def _ldexp(m, e): return math.ldexp(m, e)
|
|
||||||
def isEven(x):
|
|
||||||
x = float(x)
|
|
||||||
if math.isinf(x) or math.isnan(x): return False
|
|
||||||
return int(x) % 2 == 0
|
|
||||||
def fract(x): return x - math.floor(x)
|
|
||||||
PI = math.pi
|
|
||||||
def _trig(fn, x):
|
|
||||||
# V_SIN/COS_F32: hardware does frac on input cycles before computing
|
|
||||||
if math.isinf(x) or math.isnan(x): return float("nan")
|
|
||||||
frac_cycles = fract(x / (2 * math.pi))
|
|
||||||
result = fn(frac_cycles * 2 * math.pi)
|
|
||||||
# Hardware returns exactly 0 for cos(π/2), sin(π), etc. due to lookup table
|
|
||||||
# Round very small results (below f32 precision) to exactly 0
|
|
||||||
if abs(result) < 1e-7: return 0.0
|
|
||||||
return result
|
|
||||||
def sin(x): return _trig(math.sin, x)
|
|
||||||
def cos(x): return _trig(math.cos, x)
|
|
||||||
def pow(a, b):
|
|
||||||
try: return a ** b
|
|
||||||
except OverflowError: return float("inf") if b > 0 else 0.0
|
|
||||||
def _brev(v, bits): return int(bin(v & ((1 << bits) - 1))[2:].zfill(bits)[::-1], 2)
|
def _brev(v, bits): return int(bin(v & ((1 << bits) - 1))[2:].zfill(bits)[::-1], 2)
|
||||||
def _brev32(v): return _brev(v, 32)
|
|
||||||
def _brev64(v): return _brev(v, 64)
|
|
||||||
def _ctz(v, bits):
|
def _ctz(v, bits):
|
||||||
v, n = int(v) & ((1 << bits) - 1), 0
|
v, n = int(v) & ((1 << bits) - 1), 0
|
||||||
if v == 0: return bits
|
if v == 0: return bits
|
||||||
while (v & 1) == 0: v >>= 1; n += 1
|
while (v & 1) == 0: v >>= 1; n += 1
|
||||||
return n
|
return n
|
||||||
def _ctz32(v): return _ctz(v, 32)
|
|
||||||
def _ctz64(v): return _ctz(v, 64)
|
|
||||||
def _exponent(f):
|
|
||||||
# Handle TypedView (f16/f32/f64) to get correct exponent for that type
|
|
||||||
if hasattr(f, '_bits') and hasattr(f, '_float') and f._float:
|
|
||||||
raw = f._val
|
|
||||||
if f._bits == 16: return (raw >> 10) & 0x1f # f16: 5-bit exponent
|
|
||||||
if f._bits == 32: return (raw >> 23) & 0xff # f32: 8-bit exponent
|
|
||||||
if f._bits == 64: return (raw >> 52) & 0x7ff # f64: 11-bit exponent
|
|
||||||
# Fallback: convert to f32 and get exponent
|
|
||||||
f = float(f)
|
|
||||||
if math.isinf(f) or math.isnan(f): return 255
|
|
||||||
if f == 0.0: return 0
|
|
||||||
try: bits = struct.unpack("<I", struct.pack("<f", f))[0]; return (bits >> 23) & 0xff
|
|
||||||
except: return 0
|
|
||||||
def _is_denorm_f32(f):
|
|
||||||
if not isinstance(f, float): f = _f32(int(f) & 0xffffffff)
|
|
||||||
if math.isinf(f) or math.isnan(f) or f == 0.0: return False
|
|
||||||
bits = struct.unpack("<I", struct.pack("<f", float(f)))[0]
|
|
||||||
return (bits >> 23) & 0xff == 0
|
|
||||||
def _is_denorm_f64(f):
|
|
||||||
if not isinstance(f, float): f = _f64(int(f) & 0xffffffffffffffff)
|
|
||||||
if math.isinf(f) or math.isnan(f) or f == 0.0: return False
|
|
||||||
bits = struct.unpack("<Q", struct.pack("<d", float(f)))[0]
|
|
||||||
return (bits >> 52) & 0x7ff == 0
|
|
||||||
def v_min_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _lt_neg_zero(a, b) else b)
|
|
||||||
def v_max_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _gt_neg_zero(a, b) else b)
|
|
||||||
v_min_f16, v_max_f16 = v_min_f32, v_max_f32
|
|
||||||
v_min_i32, v_max_i32 = min, max
|
|
||||||
v_min_i16, v_max_i16 = min, max
|
|
||||||
def v_min_u32(a, b): return min(a & MASK32, b & MASK32)
|
|
||||||
def v_max_u32(a, b): return max(a & MASK32, b & MASK32)
|
|
||||||
def v_min_u16(a, b): return min(a & 0xffff, b & 0xffff)
|
|
||||||
def v_max_u16(a, b): return max(a & 0xffff, b & 0xffff)
|
|
||||||
def v_min3_f32(a, b, c): return v_min_f32(v_min_f32(a, b), c)
|
|
||||||
def v_max3_f32(a, b, c): return v_max_f32(v_max_f32(a, b), c)
|
|
||||||
v_min3_f16, v_max3_f16 = v_min3_f32, v_max3_f32
|
|
||||||
v_min3_i32, v_max3_i32, v_min3_i16, v_max3_i16 = min, max, min, max
|
|
||||||
def v_min3_u32(a, b, c): return min(a & MASK32, b & MASK32, c & MASK32)
|
|
||||||
def v_max3_u32(a, b, c): return max(a & MASK32, b & MASK32, c & MASK32)
|
|
||||||
def v_min3_u16(a, b, c): return min(a & 0xffff, b & 0xffff, c & 0xffff)
|
|
||||||
def v_max3_u16(a, b, c): return max(a & 0xffff, b & 0xffff, c & 0xffff)
|
|
||||||
def ABSDIFF(a, b): return abs(int(a) - int(b))
|
|
||||||
|
|
||||||
# BF16 (bfloat16) conversion functions
|
|
||||||
def _bf16(i):
|
def _bf16(i):
|
||||||
"""Convert bf16 bits to float. BF16 is just the top 16 bits of f32."""
|
"""Convert bf16 bits to float. BF16 is just the top 16 bits of f32."""
|
||||||
return struct.unpack("<f", struct.pack("<I", (i & 0xffff) << 16))[0]
|
return struct.unpack("<f", struct.pack("<I", (i & 0xffff) << 16))[0]
|
||||||
@@ -157,95 +49,21 @@ def _ibf16(f):
|
|||||||
if math.isinf(f): return 0x7f80 if f > 0 else 0xff80 # bf16 ±infinity
|
if math.isinf(f): return 0x7f80 if f > 0 else 0xff80 # bf16 ±infinity
|
||||||
try: return (struct.unpack("<I", struct.pack("<f", float(f)))[0] >> 16) & 0xffff
|
try: return (struct.unpack("<I", struct.pack("<f", float(f)))[0] >> 16) & 0xffff
|
||||||
except (OverflowError, struct.error): return 0x7f80 if f > 0 else 0xff80
|
except (OverflowError, struct.error): return 0x7f80 if f > 0 else 0xff80
|
||||||
def bf16_to_f32(v): return _bf16(v) if isinstance(v, int) else float(v)
|
def _trig(fn, x):
|
||||||
def f32_to_bf16(f): return _ibf16(f)
|
# V_SIN/COS_F32: hardware does frac on input cycles before computing
|
||||||
|
if math.isinf(x) or math.isnan(x): return float("nan")
|
||||||
|
frac_cycles = fract(x / (2 * math.pi))
|
||||||
|
result = fn(frac_cycles * 2 * math.pi)
|
||||||
|
# Hardware returns exactly 0 for cos(π/2), sin(π), etc. due to lookup table
|
||||||
|
# Round very small results (below f32 precision) to exactly 0
|
||||||
|
if abs(result) < 1e-7: return 0.0
|
||||||
|
return result
|
||||||
|
|
||||||
# BYTE_PERMUTE for V_PERM_B32 - select bytes from 64-bit data based on selector
|
class _SafeFloat(float):
|
||||||
def BYTE_PERMUTE(data, sel):
|
"""Float subclass that uses _div for division to handle 0/inf correctly."""
|
||||||
"""Select a byte from 64-bit data based on selector value.
|
def __truediv__(self, o): return _div(float(self), float(o))
|
||||||
sel 0-7: select byte from data (S1 is bytes 0-3, S0 is bytes 4-7 in {S0,S1})
|
def __rtruediv__(self, o): return _div(float(o), float(self))
|
||||||
sel 8-11: sign-extend from specific bytes (8->byte1, 9->byte3, 10->byte5, 11->byte7)
|
|
||||||
sel 12: constant 0x00
|
|
||||||
sel >= 13: constant 0xFF"""
|
|
||||||
sel = int(sel) & 0xff
|
|
||||||
if sel <= 7: return (int(data) >> (sel * 8)) & 0xff
|
|
||||||
if sel == 8: return 0xff if ((int(data) >> 15) & 1) else 0x00 # sign of byte 1
|
|
||||||
if sel == 9: return 0xff if ((int(data) >> 31) & 1) else 0x00 # sign of byte 3
|
|
||||||
if sel == 10: return 0xff if ((int(data) >> 47) & 1) else 0x00 # sign of byte 5
|
|
||||||
if sel == 11: return 0xff if ((int(data) >> 63) & 1) else 0x00 # sign of byte 7
|
|
||||||
if sel == 12: return 0x00
|
|
||||||
return 0xff # sel >= 13
|
|
||||||
|
|
||||||
# v_sad_u8 helper for V_SAD instructions (sum of absolute differences of 4 bytes)
|
|
||||||
def v_sad_u8(s0, s1, s2):
|
|
||||||
"""V_SAD_U8: Sum of absolute differences of 4 byte pairs plus accumulator."""
|
|
||||||
s0, s1, s2 = int(s0), int(s1), int(s2)
|
|
||||||
result = s2
|
|
||||||
for i in range(4):
|
|
||||||
a = (s0 >> (i * 8)) & 0xff
|
|
||||||
b = (s1 >> (i * 8)) & 0xff
|
|
||||||
result += abs(a - b)
|
|
||||||
return result & 0xffffffff
|
|
||||||
|
|
||||||
# v_msad_u8 helper (masked SAD - skip when reference byte is 0)
|
|
||||||
def v_msad_u8(s0, s1, s2):
|
|
||||||
"""V_MSAD_U8: Masked sum of absolute differences (skip if reference byte is 0)."""
|
|
||||||
s0, s1, s2 = int(s0), int(s1), int(s2)
|
|
||||||
result = s2
|
|
||||||
for i in range(4):
|
|
||||||
a = (s0 >> (i * 8)) & 0xff
|
|
||||||
b = (s1 >> (i * 8)) & 0xff
|
|
||||||
if b != 0: # Only add diff if reference (s1) byte is non-zero
|
|
||||||
result += abs(a - b)
|
|
||||||
return result & 0xffffffff
|
|
||||||
def f16_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
|
|
||||||
def f16_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
|
|
||||||
def f32_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
|
|
||||||
def f32_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
|
|
||||||
def v_cvt_i16_f32(f): return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
|
|
||||||
def v_cvt_u16_f32(f): return max(0, min(65535, int(f))) if not math.isnan(f) else 0
|
|
||||||
def u32_to_u16(u): return int(u) & 0xffff
|
|
||||||
def i32_to_i16(i): return ((int(i) + 32768) & 0xffff) - 32768
|
|
||||||
def SAT8(v): return max(0, min(255, int(v)))
|
|
||||||
def f32_to_u8(f): return max(0, min(255, int(f))) if not math.isnan(f) else 0
|
|
||||||
def mantissa(f):
|
|
||||||
if f == 0.0 or math.isinf(f) or math.isnan(f): return f
|
|
||||||
m, _ = math.frexp(f)
|
|
||||||
return m # AMD V_FREXP_MANT returns mantissa in [0.5, 1.0) range
|
|
||||||
def signext_from_bit(val, bit):
|
|
||||||
bit = int(bit)
|
|
||||||
if bit == 0: return 0
|
|
||||||
mask = (1 << bit) - 1
|
|
||||||
val = int(val) & mask
|
|
||||||
if val & (1 << (bit - 1)): return val - (1 << bit)
|
|
||||||
return val
|
|
||||||
|
|
||||||
# Aliases used in pseudocode
|
|
||||||
s_ff1_i32_b32, s_ff1_i32_b64 = _ctz32, _ctz64
|
|
||||||
GT_NEG_ZERO, LT_NEG_ZERO = _gt_neg_zero, _lt_neg_zero
|
|
||||||
isNAN = _isnan
|
|
||||||
isQuietNAN = _isquietnan
|
|
||||||
isSignalNAN = _issignalnan
|
|
||||||
fma, ldexp, sign, exponent = _fma, _ldexp, _sign, _exponent
|
|
||||||
def F(x):
|
|
||||||
"""32'F(x) or 64'F(x) - interpret x as float. If x is int, treat as bit pattern."""
|
|
||||||
if isinstance(x, int): return _f32(x) # int -> interpret as f32 bits
|
|
||||||
if isinstance(x, TypedView): return x # preserve TypedView for bit-pattern checks
|
|
||||||
return float(x) # already a float or float-like
|
|
||||||
signext = lambda x: int(x) # sign-extend to full width - already handled by Python's arbitrary precision ints
|
|
||||||
pack = lambda hi, lo: ((int(hi) & 0xffff) << 16) | (int(lo) & 0xffff)
|
|
||||||
pack32 = lambda hi, lo: ((int(hi) & 0xffffffff) << 32) | (int(lo) & 0xffffffff)
|
|
||||||
_pack, _pack32 = pack, pack32 # Aliases for internal use
|
|
||||||
WAVE32, WAVE64 = True, False
|
|
||||||
|
|
||||||
# Float overflow/underflow constants
|
|
||||||
OVERFLOW_F32 = float('inf')
|
|
||||||
UNDERFLOW_F32 = 0.0
|
|
||||||
OVERFLOW_F64 = float('inf')
|
|
||||||
UNDERFLOW_F64 = 0.0
|
|
||||||
MAX_FLOAT_F32 = 3.4028235e+38 # Largest finite float32
|
|
||||||
|
|
||||||
# INF object that supports .f16/.f32/.f64 access and comparison with floats
|
|
||||||
class _Inf:
|
class _Inf:
|
||||||
f16 = f32 = f64 = float('inf')
|
f16 = f32 = f64 = float('inf')
|
||||||
def __neg__(self): return _NegInf()
|
def __neg__(self): return _NegInf()
|
||||||
@@ -260,26 +78,24 @@ class _NegInf:
|
|||||||
def __float__(self): return float('-inf')
|
def __float__(self): return float('-inf')
|
||||||
def __eq__(self, other): return float(other) == float('-inf') if not isinstance(other, _Inf) else False
|
def __eq__(self, other): return float(other) == float('-inf') if not isinstance(other, _Inf) else False
|
||||||
def __req__(self, other): return self.__eq__(other)
|
def __req__(self, other): return self.__eq__(other)
|
||||||
INF = _Inf()
|
|
||||||
|
|
||||||
# Rounding mode placeholder
|
|
||||||
class _RoundMode:
|
class _RoundMode:
|
||||||
NEAREST_EVEN = 0
|
NEAREST_EVEN = 0
|
||||||
ROUND_MODE = _RoundMode()
|
|
||||||
|
|
||||||
# Helper functions for pseudocode
|
|
||||||
def cvtToQuietNAN(x): return float('nan')
|
|
||||||
DST = None # Placeholder, will be set in context
|
|
||||||
|
|
||||||
class _WaveMode:
|
class _WaveMode:
|
||||||
IEEE = False
|
IEEE = False
|
||||||
WAVE_MODE = _WaveMode()
|
|
||||||
|
|
||||||
class _DenormChecker:
|
class _DenormChecker:
|
||||||
"""Comparator for denormalized floats. x == DENORM.f32 checks if x is denormalized."""
|
"""Comparator for denormalized floats. x == DENORM.f32 checks if x is denormalized."""
|
||||||
def __init__(self, bits): self._bits = bits
|
def __init__(self, bits): self._bits = bits
|
||||||
def _check(self, other):
|
def _check(self, other):
|
||||||
return _is_denorm_f64(float(other)) if self._bits == 64 else _is_denorm_f32(float(other))
|
f = float(other)
|
||||||
|
if math.isinf(f) or math.isnan(f) or f == 0.0: return False
|
||||||
|
if self._bits == 64:
|
||||||
|
bits = struct.unpack("<Q", struct.pack("<d", f))[0]
|
||||||
|
return (bits >> 52) & 0x7ff == 0
|
||||||
|
bits = struct.unpack("<I", struct.pack("<f", f))[0]
|
||||||
|
return (bits >> 23) & 0xff == 0
|
||||||
def __eq__(self, other): return self._check(other)
|
def __eq__(self, other): return self._check(other)
|
||||||
def __req__(self, other): return self._check(other)
|
def __req__(self, other): return self._check(other)
|
||||||
def __ne__(self, other): return not self._check(other)
|
def __ne__(self, other): return not self._check(other)
|
||||||
@@ -287,7 +103,9 @@ class _DenormChecker:
|
|||||||
class _Denorm:
|
class _Denorm:
|
||||||
f32 = _DenormChecker(32)
|
f32 = _DenormChecker(32)
|
||||||
f64 = _DenormChecker(64)
|
f64 = _DenormChecker(64)
|
||||||
DENORM = _Denorm()
|
|
||||||
|
_pack = lambda hi, lo: ((int(hi) & 0xffff) << 16) | (int(lo) & 0xffff)
|
||||||
|
_pack32 = lambda hi, lo: ((int(hi) & 0xffffffff) << 32) | (int(lo) & 0xffffffff)
|
||||||
|
|
||||||
class TypedView:
|
class TypedView:
|
||||||
"""View into a Reg with typed access. Used for both full-width (Reg.u32) and slices (Reg[31:16])."""
|
"""View into a Reg with typed access. Used for both full-width (Reg.u32) and slices (Reg[31:16])."""
|
||||||
@@ -396,8 +214,6 @@ class TypedView:
|
|||||||
def __gt__(s, o): return float(s) > float(o) if s._float else int(s) > int(o)
|
def __gt__(s, o): return float(s) > float(o) if s._float else int(s) > int(o)
|
||||||
def __ge__(s, o): return float(s) >= float(o) if s._float else int(s) >= int(o)
|
def __ge__(s, o): return float(s) >= float(o) if s._float else int(s) >= int(o)
|
||||||
|
|
||||||
SliceProxy = TypedView # Alias for compatibility
|
|
||||||
|
|
||||||
class Reg:
|
class Reg:
|
||||||
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works. Supports up to 128 bits for DS_LOAD_B128."""
|
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works. Supports up to 128 bits for DS_LOAD_B128."""
|
||||||
__slots__ = ('_val',)
|
__slots__ = ('_val',)
|
||||||
@@ -466,5 +282,484 @@ class Reg:
|
|||||||
def __eq__(s, o): return s._val == int(o)
|
def __eq__(s, o): return s._val == int(o)
|
||||||
def __ne__(s, o): return s._val != int(o)
|
def __ne__(s, o): return s._val != int(o)
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# PSEUDOCODE API - Functions and constants from AMD ISA pseudocode
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
# Rounding and float operations
|
||||||
|
trunc, floor, ceil = _fpop(math.trunc), _fpop(math.floor), _fpop(math.ceil)
|
||||||
|
def sqrt(x): return _SafeFloat(math.sqrt(x)) if x >= 0 else _SafeFloat(float("nan"))
|
||||||
|
def log2(x): return math.log2(x) if x > 0 else (float("-inf") if x == 0 else float("nan"))
|
||||||
|
def fract(x): return x - math.floor(x)
|
||||||
|
def sin(x): return _trig(math.sin, x)
|
||||||
|
def cos(x): return _trig(math.cos, x)
|
||||||
|
def pow(a, b):
|
||||||
|
try: return a ** b
|
||||||
|
except OverflowError: return float("inf") if b > 0 else 0.0
|
||||||
|
def isEven(x):
|
||||||
|
x = float(x)
|
||||||
|
if math.isinf(x) or math.isnan(x): return False
|
||||||
|
return int(x) % 2 == 0
|
||||||
|
def mantissa(f):
|
||||||
|
if f == 0.0 or math.isinf(f) or math.isnan(f): return f
|
||||||
|
m, _ = math.frexp(f)
|
||||||
|
return m # AMD V_FREXP_MANT returns mantissa in [0.5, 1.0) range
|
||||||
|
def signext_from_bit(val, bit):
|
||||||
|
bit = int(bit)
|
||||||
|
if bit == 0: return 0
|
||||||
|
mask = (1 << bit) - 1
|
||||||
|
val = int(val) & mask
|
||||||
|
if val & (1 << (bit - 1)): return val - (1 << bit)
|
||||||
|
return val
|
||||||
|
|
||||||
|
# Type conversions
|
||||||
|
i32_to_f32 = u32_to_f32 = i32_to_f64 = u32_to_f64 = f32_to_f64 = f64_to_f32 = float
|
||||||
|
def f32_to_i32(f): return _f_to_int(f, -2147483648, 2147483647)
|
||||||
|
def f32_to_u32(f): return _f_to_int(f, 0, 4294967295)
|
||||||
|
f64_to_i32, f64_to_u32 = f32_to_i32, f32_to_u32
|
||||||
|
def f32_to_f16(f):
|
||||||
|
f = float(f)
|
||||||
|
if math.isnan(f): return 0x7e00 # f16 NaN
|
||||||
|
if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00 # f16 ±infinity
|
||||||
|
try: return struct.unpack("<H", struct.pack("<e", f))[0]
|
||||||
|
except OverflowError: return 0x7c00 if f > 0 else 0xfc00 # overflow -> ±infinity
|
||||||
|
def f16_to_f32(v): return v if isinstance(v, float) else _f16_to_f32_bits(v)
|
||||||
|
def i16_to_f16(v):
  """Sign-extend the low 16 bits of v with _sext and encode the result with f32_to_f16."""
  signed = _sext(int(v) & 0xffff, 16)
  return f32_to_f16(float(signed))

def u16_to_f16(v):
  """Take the low 16 bits of v as an unsigned value and encode it with f32_to_f16."""
  unsigned = int(v) & 0xffff
  return f32_to_f16(float(unsigned))
|
||||||
|
def f16_to_i16(bits):
  """Decode f16 bits with _f16_to_f32_bits and convert to a value clamped to [-32768, 32767].

  NaN converts to 0. The clamp is applied before int() so that +/-infinity saturates
  instead of raising OverflowError.
  """
  f = _f16_to_f32_bits(bits)
  if math.isnan(f): return 0
  return int(max(-32768.0, min(32767.0, f)))

def f16_to_u16(bits):
  """Decode f16 bits with _f16_to_f32_bits and convert to a value clamped to [0, 65535].

  NaN converts to 0. The clamp is applied before int() so that +/-infinity saturates
  instead of raising OverflowError.
  """
  f = _f16_to_f32_bits(bits)
  if math.isnan(f): return 0
  return int(max(0.0, min(65535.0, f)))
|
||||||
|
def bf16_to_f32(v):
  """Integers are decoded by _bf16; every other value is converted with float()."""
  if isinstance(v, int): return _bf16(v)
  return float(v)

def f32_to_bf16(f):
  """Delegate to _ibf16."""
  return _ibf16(f)
|
||||||
|
def u8_to_u32(v):
  """Keep the low 8 bits of int(v)."""
  return int(v) % 0x100

def u4_to_u32(v):
  """Keep the low 4 bits of int(v)."""
  return int(v) % 0x10

def u32_to_u16(u):
  """Keep the low 16 bits of int(u)."""
  return int(u) % 0x10000

def i32_to_i16(i):
  """Wrap int(i) into the signed 16-bit range [-32768, 32767]."""
  low = int(i) & 0xffff
  return low - 0x10000 if low & 0x8000 else low
|
||||||
|
def f16_to_snorm(f):
  """Clamp f to [-1, 1], scale by 32767, round, and clamp the integer to [-32768, 32767]."""
  unit = max(-1.0, min(1.0, f))
  scaled = int(round(unit * 32767))
  return max(-32768, min(32767, scaled))

def f16_to_unorm(f):
  """Clamp f to [0, 1], scale by 65535, round, and clamp the integer to [0, 65535]."""
  unit = max(0.0, min(1.0, f))
  scaled = int(round(unit * 65535))
  return max(0, min(65535, scaled))

def f32_to_snorm(f):
  """Clamp f to [-1, 1], scale by 32767, round, and clamp the integer to [-32768, 32767]."""
  unit = max(-1.0, min(1.0, f))
  scaled = int(round(unit * 32767))
  return max(-32768, min(32767, scaled))

def f32_to_unorm(f):
  """Clamp f to [0, 1], scale by 65535, round, and clamp the integer to [0, 65535]."""
  unit = max(0.0, min(1.0, f))
  scaled = int(round(unit * 65535))
  return max(0, min(65535, scaled))
|
||||||
|
def v_cvt_i16_f32(f):
  """Truncate f toward zero and clamp to [-32768, 32767]; NaN converts to 0.

  The clamp happens in the float domain before int(), so +/-infinity saturates to the
  range limits instead of raising OverflowError.
  """
  if math.isnan(f): return 0
  return int(max(-32768.0, min(32767.0, f)))

def v_cvt_u16_f32(f):
  """Truncate f toward zero and clamp to [0, 65535]; NaN converts to 0.

  The clamp happens in the float domain before int(), so +/-infinity saturates to the
  range limits instead of raising OverflowError.
  """
  if math.isnan(f): return 0
  return int(max(0.0, min(65535.0, f)))
|
||||||
|
def SAT8(v):
  """Saturate int(v) to the unsigned 8-bit range [0, 255]."""
  n = int(v)
  if n < 0: return 0
  if n > 255: return 255
  return n
|
||||||
|
def f32_to_u8(f):
  """Truncate f toward zero and clamp to [0, 255]; NaN converts to 0.

  The clamp happens in the float domain before int(), so +/-infinity saturates instead
  of raising OverflowError.
  """
  if math.isnan(f): return 0
  return int(max(0.0, min(255.0, f)))
|
||||||
|
|
||||||
|
# Min/max operations
|
||||||
|
def v_min_f32(a, b):
  """Return the lesser of a and b as ordered by _lt_neg_zero; if one operand is NaN the other is returned."""
  if math.isnan(b): return a
  if math.isnan(a): return b
  return a if _lt_neg_zero(a, b) else b

def v_max_f32(a, b):
  """Return the greater of a and b as ordered by _gt_neg_zero; if one operand is NaN the other is returned."""
  if math.isnan(b): return a
  if math.isnan(a): return b
  return a if _gt_neg_zero(a, b) else b

# the f16 variants share the f32 implementations
v_min_f16 = v_min_f32
v_max_f16 = v_max_f32
|
||||||
|
# Signed integer min/max map directly onto the Python builtins.
v_min_i32 = min
v_max_i32 = max
v_min_i16 = min
v_max_i16 = max
|
||||||
|
def v_min_u32(a, b):
  """Minimum of a and b after masking both with MASK32."""
  lhs, rhs = a & MASK32, b & MASK32
  return min(lhs, rhs)

def v_max_u32(a, b):
  """Maximum of a and b after masking both with MASK32."""
  lhs, rhs = a & MASK32, b & MASK32
  return max(lhs, rhs)

def v_min_u16(a, b):
  """Minimum of a and b after masking both to their low 16 bits."""
  lhs, rhs = a & 0xffff, b & 0xffff
  return min(lhs, rhs)

def v_max_u16(a, b):
  """Maximum of a and b after masking both to their low 16 bits."""
  lhs, rhs = a & 0xffff, b & 0xffff
  return max(lhs, rhs)
|
||||||
|
def v_min3_f32(a, b, c):
  """Three-operand minimum: v_min_f32 applied to (a, b), then to that result and c."""
  first_two = v_min_f32(a, b)
  return v_min_f32(first_two, c)

def v_max3_f32(a, b, c):
  """Three-operand maximum: v_max_f32 applied to (a, b), then to that result and c."""
  first_two = v_max_f32(a, b)
  return v_max_f32(first_two, c)

# the f16 variants share the f32 implementations
v_min3_f16 = v_min3_f32
v_max3_f16 = v_max3_f32
|
||||||
|
# Three-operand signed integer min/max are the variadic Python builtins.
v_min3_i32 = min
v_max3_i32 = max
v_min3_i16 = min
v_max3_i16 = max
|
||||||
|
def v_min3_u32(a, b, c):
  """Minimum of the three operands after masking each with MASK32."""
  masked = (a & MASK32, b & MASK32, c & MASK32)
  return min(*masked)

def v_max3_u32(a, b, c):
  """Maximum of the three operands after masking each with MASK32."""
  masked = (a & MASK32, b & MASK32, c & MASK32)
  return max(*masked)

def v_min3_u16(a, b, c):
  """Minimum of the three operands after masking each to its low 16 bits."""
  masked = (a & 0xffff, b & 0xffff, c & 0xffff)
  return min(*masked)

def v_max3_u16(a, b, c):
  """Maximum of the three operands after masking each to its low 16 bits."""
  masked = (a & 0xffff, b & 0xffff, c & 0xffff)
  return max(*masked)
|
||||||
|
|
||||||
|
# SAD/MSAD operations
|
||||||
|
def ABSDIFF(a, b):
  """Absolute difference of int(a) and int(b)."""
  delta = int(a) - int(b)
  return -delta if delta < 0 else delta
|
||||||
|
def v_sad_u8(s0, s1, s2):
|
||||||
|
"""V_SAD_U8: Sum of absolute differences of 4 byte pairs plus accumulator."""
|
||||||
|
s0, s1, s2 = int(s0), int(s1), int(s2)
|
||||||
|
result = s2
|
||||||
|
for i in range(4):
|
||||||
|
a = (s0 >> (i * 8)) & 0xff
|
||||||
|
b = (s1 >> (i * 8)) & 0xff
|
||||||
|
result += abs(a - b)
|
||||||
|
return result & 0xffffffff
|
||||||
|
def v_msad_u8(s0, s1, s2):
  """V_MSAD_U8: like v_sad_u8, but byte lanes whose reference (s1) byte is 0 contribute nothing."""
  lhs, rhs, total = int(s0), int(s1), int(s2)
  for shift in (0, 8, 16, 24):
    ref = (rhs >> shift) & 0xff
    if ref == 0: continue  # masked lane: reference byte is zero
    total += abs(((lhs >> shift) & 0xff) - ref)
  return total & 0xffffffff
|
||||||
|
|
||||||
|
def BYTE_PERMUTE(data, sel):
  """Select one output byte from 64-bit data according to the low 8 bits of sel.

  Selectors 0-7 pick that byte; 8-11 replicate bit 15/31/47/63 into a full byte;
  12 gives 0x00; everything else gives 0xff.
  """
  selector = int(sel) & 0xff
  word = int(data)
  if selector <= 7: return (word >> (selector * 8)) & 0xff
  if selector <= 11:
    sign_pos = 16 * (selector - 8) + 15  # 8->15, 9->31, 10->47, 11->63
    return 0xff if (word >> sign_pos) & 1 else 0x00
  return 0x00 if selector == 12 else 0xff
|
||||||
|
|
||||||
|
# Pseudocode functions
|
||||||
|
def s_ff1_i32_b32(v):
  """Delegate to _ctz with a 32-bit operand width."""
  return _ctz(v, 32)

def s_ff1_i32_b64(v):
  """Delegate to _ctz with a 64-bit operand width."""
  return _ctz(v, 64)
|
||||||
|
# Upper-case names under which the pseudocode refers to the comparison helpers.
GT_NEG_ZERO = _gt_neg_zero
LT_NEG_ZERO = _lt_neg_zero
|
||||||
|
def isNAN(x):
  """Return True if float(x) is NaN; values float() rejects with TypeError/ValueError count as not-NaN."""
  try:
    as_float = float(x)
  except (TypeError, ValueError):
    return False
  return math.isnan(as_float)
|
||||||
|
def isQuietNAN(x):
  """Delegate to _check_nan_type(x, 1, True)."""
  return _check_nan_type(x, 1, True)

def isSignalNAN(x):
  """Delegate to _check_nan_type(x, 0, False)."""
  return _check_nan_type(x, 0, False)
|
||||||
|
def fma(a, b, c):
  """Fused multiply-add a*b + c that returns IEEE special values instead of raising.

  math.fma raises ValueError for an invalid operation (e.g. inf * 0) and OverflowError
  when finite inputs overflow; these map to NaN and a correctly signed infinity.
  On interpreters without math.fma (before Python 3.13) an unfused a*b + c is used.
  """
  fused = getattr(math, "fma", None)
  # unfused fallback: rounds twice, so it may differ from a true FMA in the last bit
  if fused is None: return float(a) * float(b) + float(c)
  try: return fused(a, b, c)
  except ValueError: return float('nan')
  # plain float arithmetic overflows silently to +/-inf, which supplies the sign
  except OverflowError: return math.copysign(float('inf'), float(a) * float(b) + float(c))
|
||||||
|
def ldexp(m, e):
  """Return m * 2**e; a result too large for a float becomes +/-inf (signed like m) instead of raising."""
  try: return math.ldexp(m, e)
  except OverflowError: return math.copysign(float('inf'), m)
|
||||||
|
def sign(f):
  """Return the sign bit of f as 1 (negative, including -0.0) or 0."""
  negative = math.copysign(1.0, f) < 0
  return int(negative)
|
||||||
|
def exponent(f):
  """Return the biased exponent field of f.

  Objects exposing _bits/_float/_val with a truthy _float have the field read straight
  from their raw bits (16-, 32- or 64-bit layouts). Anything else is converted to a
  Python float and treated as f32: inf/NaN give 255, zero gives 0.
  """
  if hasattr(f, '_bits') and hasattr(f, '_float') and f._float:
    raw = f._val
    if f._bits == 16: return (raw >> 10) & 0x1f
    if f._bits == 32: return (raw >> 23) & 0xff
    if f._bits == 64: return (raw >> 52) & 0x7ff
  f = float(f)
  if math.isinf(f) or math.isnan(f): return 255
  if f == 0.0: return 0
  try: bits = struct.unpack("<I", struct.pack("<f", f))[0]
  # only a value outside f32 range can fail to pack; keep the existing 0 result for it
  except (OverflowError, struct.error): return 0
  return (bits >> 23) & 0xff
|
||||||
|
def signext(x):
  """Return int(x); no width-specific extension is performed here."""
  as_int = int(x)
  return as_int

def cvtToQuietNAN(x):
  """Return a NaN float regardless of the argument."""
  quiet = float('nan')
  return quiet
|
||||||
|
|
||||||
|
def F(x):
  """32'F(x) or 64'F(x) - interpret x as float.

  An int is treated as a bit pattern and decoded with _f32, a TypedView is returned
  as-is, and anything else goes through float().
  """
  if isinstance(x, int): return _f32(x)
  return x if isinstance(x, TypedView) else float(x)
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
PI = math.pi
# wave size flags seen by the pseudocode: wave32 is on, wave64 is off
WAVE32 = True
WAVE64 = False
# overflow/underflow results: +inf and 0.0 for both float widths
OVERFLOW_F32 = float('inf')
UNDERFLOW_F32 = 0.0
OVERFLOW_F64 = float('inf')
UNDERFLOW_F64 = 0.0
MAX_FLOAT_F32 = 3.4028235e+38
# singleton helper objects referenced by name from the pseudocode
INF = _Inf()
ROUND_MODE = _RoundMode()
WAVE_MODE = _WaveMode()
DENORM = _Denorm()
|
||||||
|
|
||||||
# 2/PI with 1201 bits of precision for V_TRIG_PREOP_F64
|
# 2/PI with 1201 bits of precision for V_TRIG_PREOP_F64
|
||||||
TWO_OVER_PI_1201 = Reg(0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6)
|
TWO_OVER_PI_1201 = Reg(0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6)
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# COMPILER: pseudocode -> Python (minimal transforms)
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
def _compile_pseudocode(pseudocode: str) -> str:
  """Compile pseudocode to Python. Transforms are minimal - most syntax just works.

  Continuation lines are joined first, then each statement is translated:
  if/elsif/else/endif and for/endfor become Python blocks (one indent unit per level),
  'declare' lines are dropped, and assignments go through _expr/_assign.
  Returns the generated Python source as one newline-joined string.
  """
  pseudocode = re.sub(r'\bpass\b', 'pass_', pseudocode)  # 'pass' is Python keyword
  raw_lines = pseudocode.strip().split('\n')
  joined_lines: list[str] = []
  for line in raw_lines:
    line = line.strip()
    # a line continues the previous one if that one ends in an operator/opener or has unbalanced '('
    if joined_lines and (joined_lines[-1].rstrip().endswith(('||', '&&', '(', ',')) or
                         (joined_lines[-1].count('(') > joined_lines[-1].count(')'))):
      joined_lines[-1] = joined_lines[-1].rstrip() + ' ' + line
    else:
      joined_lines.append(line)

  def condition(text: str) -> str:
    # Remove only a trailing 'then' keyword. str.rstrip(' then') strips a character *set*,
    # so it would also eat trailing t/h/e/n letters of the condition itself.
    return _expr(re.sub(r'\s*\bthen\s*$', '', text))

  lines = []
  indent, need_pass, in_first_match_loop = 0, False, False
  for line in joined_lines:
    line = line.split('//')[0].strip()  # Strip C-style comments
    if not line: continue
    if line.startswith('if '):
      lines.append(' ' * indent + f"if {condition(line[3:])}:")
      indent += 1
      need_pass = True
    elif line.startswith('elsif '):
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      lines.append(' ' * indent + f"elif {condition(line[6:])}:")
      indent += 1
      need_pass = True
    elif line == 'else':
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      lines.append(' ' * indent + "else:")
      indent += 1
      need_pass = True
    elif line.startswith('endif'):
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      need_pass = False
    elif line.startswith('endfor'):
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      need_pass, in_first_match_loop = False, False
    elif line.startswith('declare '):
      pass
    elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
      start, end = _expr(m[2].strip()), _expr(m[3].strip())
      # pseudocode ranges are inclusive, hence the +1
      lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
      indent += 1
      need_pass, in_first_match_loop = True, True
    elif '=' in line and not line.startswith('=='):
      need_pass = False
      line = line.rstrip(';')
      if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
        # {D1, D0} = value: low 64 bits go to D0, bit 64 goes to D1
        rhs = _expr(m[1])
        lines.append(' ' * indent + f"_full = {rhs}")
        lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
        lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
      elif any(op in line for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^=')):
        for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
          if op in line:
            lhs, rhs = line.split(op, 1)
            lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
            break
      else:
        lhs, rhs = line.split('=', 1)
        lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip()
        stmt = _assign(lhs_s, _expr(rhs_s))
        # inside a for loop, storing the loop index 'i' into tmp/D0.i32 also ends the loop
        if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
          stmt += "; break"
        lines.append(' ' * indent + stmt)
  if need_pass: lines.append(' ' * indent + "pass")
  return '\n'.join(lines)
|
||||||
|
|
||||||
|
def _assign(lhs: str, rhs: str) -> str:
|
||||||
|
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
|
||||||
|
return f"{lhs} = Reg({rhs})"
|
||||||
|
return f"{lhs} = {rhs}"
|
||||||
|
|
||||||
|
def _expr(e: str) -> str:
|
||||||
|
e = e.strip()
|
||||||
|
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
|
||||||
|
e = re.sub(r'!([^=])', r' not \1', e)
|
||||||
|
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
|
||||||
|
def pack(m):
|
||||||
|
hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
|
||||||
|
return f'_pack({hi}, {lo})'
|
||||||
|
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
|
||||||
|
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
|
||||||
|
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
|
||||||
|
e = re.sub(r"\d+'[FIBU]\(", "(", e)
|
||||||
|
e = re.sub(r'\bB\(', '(', e)
|
||||||
|
e = re.sub(r'([0-9a-fA-Fx])ULL\b', r'\1', e)
|
||||||
|
e = re.sub(r'([0-9a-fA-Fx])LL\b', r'\1', e)
|
||||||
|
e = re.sub(r'([0-9a-fA-Fx])U\b', r'\1', e)
|
||||||
|
e = re.sub(r'(\d\.?\d*)F\b', r'\1', e)
|
||||||
|
e = re.sub(r'(\[laneId\])\.[uib]\d+', r'\1', e)
|
||||||
|
e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
|
||||||
|
e = re.sub(r'NAN\.f\d+', 'float("nan")', e)
|
||||||
|
def convert_verilog_slice(m):
|
||||||
|
start, width = m.group(1).strip(), m.group(2).strip()
|
||||||
|
return f'[({start}) + ({width}) - 1 : ({start})]'
|
||||||
|
e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e)
|
||||||
|
def process_brackets(s):
|
||||||
|
result, i = [], 0
|
||||||
|
while i < len(s):
|
||||||
|
if s[i] == '[':
|
||||||
|
depth, start = 1, i + 1
|
||||||
|
j = start
|
||||||
|
while j < len(s) and depth > 0:
|
||||||
|
if s[j] == '[': depth += 1
|
||||||
|
elif s[j] == ']': depth -= 1
|
||||||
|
j += 1
|
||||||
|
inner = _expr(s[start:j-1])
|
||||||
|
result.append('[' + inner + ']')
|
||||||
|
i = j
|
||||||
|
else:
|
||||||
|
result.append(s[i])
|
||||||
|
i += 1
|
||||||
|
return ''.join(result)
|
||||||
|
e = process_brackets(e)
|
||||||
|
while '?' in e:
|
||||||
|
depth, bracket, q = 0, 0, -1
|
||||||
|
for i, c in enumerate(e):
|
||||||
|
if c == '(': depth += 1
|
||||||
|
elif c == ')': depth -= 1
|
||||||
|
elif c == '[': bracket += 1
|
||||||
|
elif c == ']': bracket -= 1
|
||||||
|
elif c == '?' and depth == 0 and bracket == 0: q = i; break
|
||||||
|
if q < 0: break
|
||||||
|
depth, bracket, col = 0, 0, -1
|
||||||
|
for i in range(q + 1, len(e)):
|
||||||
|
if e[i] == '(': depth += 1
|
||||||
|
elif e[i] == ')': depth -= 1
|
||||||
|
elif e[i] == '[': bracket += 1
|
||||||
|
elif e[i] == ']': bracket -= 1
|
||||||
|
elif e[i] == ':' and depth == 0 and bracket == 0: col = i; break
|
||||||
|
if col < 0: break
|
||||||
|
cond, t, f = e[:q].strip(), e[q+1:col].strip(), e[col+1:].strip()
|
||||||
|
e = f'(({t}) if ({cond}) else ({f}))'
|
||||||
|
return e
|
||||||
|
|
||||||
|
def _apply_pseudocode_fixes(op_name: str, code: str) -> str:
  """Apply known fixes for PDF pseudocode bugs.

  `code` is the Python text already produced from the pseudocode for `op_name`. Each fix
  is a literal str.replace (plus, for DIV_SCALE, an appended statement), so it only takes
  effect when the generated text matches the pattern exactly, including whitespace.
  Opcodes without an entry here are returned unchanged.
  """
  # NOTE(review): the '\n ...' patterns below must match the indentation emitted by the
  # pseudocode compiler character for character - confirm after any change to either side.
  if op_name == 'V_DIV_FMAS_F32':
    # replace the fixed 2**32 scale with 2**64 or 2**-64, chosen by the exponent of S2
    code = code.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
                        'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
  if op_name == 'V_DIV_FMAS_F64':
    # same idea for f64: 2**128 or 2**-128 instead of the fixed 2**64
    code = code.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
                        'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
  if op_name == 'V_DIV_SCALE_F32':
    # every NaN result also sets VCC
    code = code.replace('D0.f32 = float("nan")', 'VCC = Reg(0x1); D0.f32 = float("nan")')
    # disable the denormal-S1 branch in place ...
    code = code.replace('elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif False:\n pass')
    # ... and re-check it after everything else; this appended NaN store runs after the
    # replace above, so it does not get the VCC prefix
    code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
    # the small-exponent scaling branch also sets VCC
    code = code.replace('elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)', 'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
    # the denormal-quotient branch only sets VCC; its nested scaling is removed
    code = code.replace('elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
  if op_name == 'V_DIV_SCALE_F64':
    # f64 counterparts of the V_DIV_SCALE_F32 fixes, applied in the same order
    code = code.replace('D0.f64 = float("nan")', 'VCC = Reg(0x1); D0.f64 = float("nan")')
    code = code.replace('elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif False:\n pass')
    code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
    code = code.replace('elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)', 'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
    code = code.replace('elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
  if op_name == 'V_DIV_FIXUP_F32':
    # when S0 is NaN, the final sign-applying store yields +/-OVERFLOW_F32 instead of +/-abs(S0)
    code = code.replace('D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
                        'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
  if op_name == 'V_DIV_FIXUP_F64':
    # f64 counterpart of the V_DIV_FIXUP_F32 fix
    code = code.replace('D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
                        'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
  if op_name == 'V_TRIG_PREOP_F64':
    # shift the 1201-bit constant, then keep its top 53 bits (>> 1201 - 53) and convert with float()
    code = code.replace('result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
                        'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
  return code
|
||||||
|
|
||||||
|
def _generate_function(cls_name: str, op_name: str, pc: str, code: str) -> str:
  """Generate a single compiled pseudocode function.

  Functions take int parameters and return dict of int values.
  Reg wrapping happens inside the function, only for registers actually used.

  Args:
    cls_name: opcode enum class name; selects the signature (SMEM, DS, FLAT-like, or default).
    op_name: opcode name; together with cls_name it forms the function name `_{cls_name}_{op_name}`.
    pc: the original pseudocode text, used only for feature detection.
    code: the Python statements compiled from `pc`, emitted as the function body.

  Returns:
    Python source text of the function: a `def` line, one init line, the body, and a
    final `return {...}` listing the registers detected as modified.
  """
  has_d1 = '{ D1' in pc
  is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
  is_div_scale = 'DIV_SCALE' in op_name
  has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
  is_ds = cls_name == 'DSOp'
  is_flat = cls_name in ('FLATOp', 'GLOBALOp', 'SCRATCHOp')
  is_smem = cls_name == 'SMEMOp'
  has_s_array = 'S[i]' in pc  # FMA_MIX style: S[0], S[1], S[2] array access
  combined = code + pc  # all detection below is plain substring/regex search over this text

  fn_name = f"_{cls_name}_{op_name}"

  # Detect which registers are used/modified
  # needs_init: name is mentioned somewhere but never assigned as a whole via `name = Reg(`
  def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
  modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined))
  modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined))
  modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
  modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
  modifies_pc = bool(re.search(r'\bPC\s*=', combined))

  # Build function signature and Reg init lines
  if is_smem:
    lines = [f"def {fn_name}(MEM, addr):"]
    reg_inits = ["ADDR=Reg(addr)", "SDATA=Reg(0)"]
    special_regs = []
  elif is_ds:
    lines = [f"def {fn_name}(MEM, addr, data0, data1, offset0, offset1):"]
    reg_inits = ["ADDR=Reg(addr)", "DATA0=Reg(data0)", "DATA1=Reg(data1)", "OFFSET0=Reg(offset0)", "OFFSET1=Reg(offset1)", "RETURN_DATA=Reg(0)"]
    # (alias, target) pairs: the alias is bound to the target only when it occurs in the text
    special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
  elif is_flat:
    lines = [f"def {fn_name}(MEM, addr, vdata, vdst):"]
    # note: ADDR is bound to the raw addr argument here, not wrapped in Reg
    reg_inits = ["ADDR=addr", "VDATA=Reg(vdata)", "VDST=Reg(vdst)", "RETURN_DATA=Reg(0)"]
    special_regs = [('DATA', 'VDATA')]
  elif has_s_array:
    # FMA_MIX style: needs S[i] array, opsel, opsel_hi for source selection (neg/neg_hi applied in emu.py before call)
    lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None, opsel=0, opsel_hi=0):"]
    reg_inits = ["S0=Reg(s0)", "S1=Reg(s1)", "S2=Reg(s2)", "S=[S0,S1,S2]", "D0=Reg(d0)", "OPSEL=Reg(opsel)", "OPSEL_HI=Reg(opsel_hi)"]
    special_regs = []
    # Detect array declarations like "declare in : 32'F[3]" and create them (rename 'in' to 'ins' since 'in' is a keyword)
    if "in[" in combined:
      reg_inits.append("ins=[Reg(0),Reg(0),Reg(0)]")
      code = code.replace("in[", "ins[")
  else:
    lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):"]
    # Only create Regs for registers actually used in the pseudocode
    reg_inits = []
    if 'S0' in combined: reg_inits.append("S0=Reg(s0)")
    if 'S1' in combined: reg_inits.append("S1=Reg(s1)")
    if 'S2' in combined: reg_inits.append("S2=Reg(s2)")
    # DIV_SCALE seeds D0 from s0 rather than from the incoming d0
    if modifies_d0 or 'D0' in combined: reg_inits.append("D0=Reg(s0)" if is_div_scale else "D0=Reg(d0)")
    if modifies_scc or 'SCC' in combined: reg_inits.append("SCC=Reg(scc)")
    if modifies_vcc or 'VCC' in combined: reg_inits.append("VCC=Reg(vcc)")
    if modifies_exec or 'EXEC' in combined: reg_inits.append("EXEC=Reg(exec_mask)")
    if modifies_pc or 'PC' in combined: reg_inits.append("PC=Reg(pc) if pc is not None else None")
    special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
                    ('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
    # NOTE(review): these two lines are placed inside this else-branch because the generated
    # saveexec init reads EXEC, which only this branch creates - confirm the nesting upstream.
    if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
    if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))

  # Build init code
  init_parts = reg_inits.copy()
  for name, init in special_regs:
    if name in combined: init_parts.append(f"{name}={init}")
  # views and derived flags are created after the registers they read
  if 'EXEC_LO' in code: init_parts.append("EXEC_LO=TypedView(EXEC, 31, 0)")
  if 'EXEC_HI' in code: init_parts.append("EXEC_HI=TypedView(EXEC, 63, 32)")
  if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_parts.append("VCCZ=Reg(1 if VCC._val == 0 else 0)")
  if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_parts.append("EXECZ=Reg(1 if EXEC._val == 0 else 0)")

  # Add init line and separator
  if init_parts: lines.append(f" {'; '.join(init_parts)}")

  # Add compiled pseudocode
  for line in code.split('\n'):
    if line.strip(): lines.append(f" {line}")

  # Build result dict
  result_items = []
  if modifies_d0: result_items.append("'D0': D0._val")
  if modifies_scc: result_items.append("'SCC': SCC._val")
  if modifies_vcc: result_items.append("'VCC': VCC._val")
  if modifies_exec: result_items.append("'EXEC': EXEC._val")
  if has_d1: result_items.append("'D1': D1._val")
  if modifies_pc: result_items.append("'PC': PC._val")
  # memory ops report their data registers only when the body stores to a field/slice of them
  if is_smem and 'SDATA' in combined and re.search(r'^\s*SDATA[\.\[].*=', code, re.MULTILINE):
    result_items.append("'SDATA': SDATA._val")
  if is_ds and 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
    result_items.append("'RETURN_DATA': RETURN_DATA._val")
  if is_flat:
    if 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
      result_items.append("'RETURN_DATA': RETURN_DATA._val")
    if re.search(r'^\s*VDATA[\.\[].*=', code, re.MULTILINE):
      result_items.append("'VDATA': VDATA._val")
  lines.append(f" return {{{', '.join(result_items)}}}")
  return '\n'.join(lines)
|
||||||
|
|
||||||
|
# Build the globals dict for exec() - includes all pcode symbols
|
||||||
|
# Passed as the globals mapping to exec() in compile_pseudocode, so each key is a name
# the generated functions can reference.
_PCODE_GLOBALS = {
  # register/view types and packing helpers
  'Reg': Reg, 'TypedView': TypedView, '_pack': _pack, '_pack32': _pack32,
  # constants and mode objects
  'ABSDIFF': ABSDIFF, 'BYTE_PERMUTE': BYTE_PERMUTE, 'DENORM': DENORM, 'F': F,
  'GT_NEG_ZERO': GT_NEG_ZERO, 'LT_NEG_ZERO': LT_NEG_ZERO, 'INF': INF,
  'MAX_FLOAT_F32': MAX_FLOAT_F32, 'OVERFLOW_F32': OVERFLOW_F32, 'OVERFLOW_F64': OVERFLOW_F64,
  'UNDERFLOW_F32': UNDERFLOW_F32, 'UNDERFLOW_F64': UNDERFLOW_F64,
  'PI': PI, 'ROUND_MODE': ROUND_MODE, 'WAVE_MODE': WAVE_MODE,
  'WAVE32': WAVE32, 'WAVE64': WAVE64, 'TWO_OVER_PI_1201': TWO_OVER_PI_1201,
  # math helpers
  'SAT8': SAT8, 'trunc': trunc, 'floor': floor, 'ceil': ceil, 'sqrt': sqrt,
  'log2': log2, 'fract': fract, 'sin': sin, 'cos': cos, 'pow': pow,
  'isEven': isEven, 'mantissa': mantissa, 'signext_from_bit': signext_from_bit,
  # type conversions
  'i32_to_f32': i32_to_f32, 'u32_to_f32': u32_to_f32, 'i32_to_f64': i32_to_f64,
  'u32_to_f64': u32_to_f64, 'f32_to_f64': f32_to_f64, 'f64_to_f32': f64_to_f32,
  'f32_to_i32': f32_to_i32, 'f32_to_u32': f32_to_u32, 'f64_to_i32': f64_to_i32,
  'f64_to_u32': f64_to_u32, 'f32_to_f16': f32_to_f16, 'f16_to_f32': f16_to_f32,
  'i16_to_f16': i16_to_f16, 'u16_to_f16': u16_to_f16, 'f16_to_i16': f16_to_i16,
  'f16_to_u16': f16_to_u16, 'bf16_to_f32': bf16_to_f32, 'f32_to_bf16': f32_to_bf16,
  'u8_to_u32': u8_to_u32, 'u4_to_u32': u4_to_u32, 'u32_to_u16': u32_to_u16,
  'i32_to_i16': i32_to_i16, 'f16_to_snorm': f16_to_snorm, 'f16_to_unorm': f16_to_unorm,
  'f32_to_snorm': f32_to_snorm, 'f32_to_unorm': f32_to_unorm,
  'v_cvt_i16_f32': v_cvt_i16_f32, 'v_cvt_u16_f32': v_cvt_u16_f32, 'f32_to_u8': f32_to_u8,
  # min/max families
  'v_min_f32': v_min_f32, 'v_max_f32': v_max_f32, 'v_min_f16': v_min_f16, 'v_max_f16': v_max_f16,
  'v_min_i32': v_min_i32, 'v_max_i32': v_max_i32, 'v_min_i16': v_min_i16, 'v_max_i16': v_max_i16,
  'v_min_u32': v_min_u32, 'v_max_u32': v_max_u32, 'v_min_u16': v_min_u16, 'v_max_u16': v_max_u16,
  'v_min3_f32': v_min3_f32, 'v_max3_f32': v_max3_f32, 'v_min3_f16': v_min3_f16, 'v_max3_f16': v_max3_f16,
  'v_min3_i32': v_min3_i32, 'v_max3_i32': v_max3_i32, 'v_min3_i16': v_min3_i16, 'v_max3_i16': v_max3_i16,
  'v_min3_u32': v_min3_u32, 'v_max3_u32': v_max3_u32, 'v_min3_u16': v_min3_u16, 'v_max3_u16': v_max3_u16,
  # SAD helpers, bit scans, NaN predicates and float utilities
  'v_sad_u8': v_sad_u8, 'v_msad_u8': v_msad_u8,
  's_ff1_i32_b32': s_ff1_i32_b32, 's_ff1_i32_b64': s_ff1_i32_b64,
  'isNAN': isNAN, 'isQuietNAN': isQuietNAN, 'isSignalNAN': isSignalNAN,
  'fma': fma, 'ldexp': ldexp, 'sign': sign, 'exponent': exponent,
  'signext': signext, 'cvtToQuietNAN': cvtToQuietNAN,
}
|
||||||
|
|
||||||
|
@functools.cache
def compile_pseudocode(cls_name: str, op_name: str, pseudocode: str):
  """Turn a pseudocode string into a callable Python function.

  The pseudocode is translated, patched with the per-opcode fixes, wrapped in a function
  definition, and exec'd with _PCODE_GLOBALS as its globals. Results are memoized on the
  argument triple by functools.cache.
  """
  python_body = _apply_pseudocode_fixes(op_name, _compile_pseudocode(pseudocode))
  source = _generate_function(cls_name, op_name, pseudocode, python_body)
  namespace = {}
  exec(source, _PCODE_GLOBALS, namespace)
  # the generated function is named after the opcode class and opcode
  return namespace[f"_{cls_name}_{op_name}"]
|
||||||
|
|||||||
@@ -38,161 +38,7 @@ FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'N
|
|||||||
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
|
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
|
||||||
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+|(?:FLAT|GLOBAL|SCRATCH)_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+|(?:FLAT|GLOBAL|SCRATCH)_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||||
|
|
||||||
# Patterns that can't be handled by the DSL (require special handling in emu.py)
|
|
||||||
UNSUPPORTED = [
  'SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
  'CVT_OFF_TABLE', 'ThreadMask', 'S1[i', 'C.i32', 'thread_',
  'if n.', 'DST.u32', 'addrd = DST', 'addr = DST', 'BARRIER_STATE', 'ReallocVgprs',
  'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
  'fp6', 'bf6', 'GS_REGS', 'M0.base', 'DS_DATA', '= 0..', 'sign(src', 'if no LDS', 'gds_base', 'vector mask',
  'SGPR_ADDR', 'INST_OFFSET', 'laneID',  # FLAT ops with non-standard vars
]
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════════════
|
|
||||||
# COMPILER: pseudocode -> Python (minimal transforms)
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
def compile_pseudocode(pseudocode: str) -> str:
  """Compile pseudocode to Python. Transforms are minimal - most syntax just works.

  Continuation lines are joined first, then each statement is translated:
  if/elsif/else/endif and for/endfor become Python blocks (one indent unit per level),
  'declare' lines are dropped, and assignments go through _expr/_assign.
  Returns the generated Python source as one newline-joined string.
  """
  pseudocode = re.sub(r'\bpass\b', 'pass_', pseudocode)  # 'pass' is Python keyword
  raw_lines = pseudocode.strip().split('\n')
  joined_lines: list[str] = []
  for line in raw_lines:
    line = line.strip()
    # a line continues the previous one if that one ends in an operator/opener or has unbalanced '('
    if joined_lines and (joined_lines[-1].rstrip().endswith(('||', '&&', '(', ',')) or
                         (joined_lines[-1].count('(') > joined_lines[-1].count(')'))):
      joined_lines[-1] = joined_lines[-1].rstrip() + ' ' + line
    else:
      joined_lines.append(line)

  def condition(text: str) -> str:
    # Remove only a trailing 'then' keyword. str.rstrip(' then') strips a character *set*,
    # so it would also eat trailing t/h/e/n letters of the condition itself.
    return _expr(re.sub(r'\s*\bthen\s*$', '', text))

  lines = []
  indent, need_pass, in_first_match_loop = 0, False, False
  for line in joined_lines:
    line = line.split('//')[0].strip()  # Strip C-style comments
    if not line: continue
    if line.startswith('if '):
      lines.append(' ' * indent + f"if {condition(line[3:])}:")
      indent += 1
      need_pass = True
    elif line.startswith('elsif '):
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      lines.append(' ' * indent + f"elif {condition(line[6:])}:")
      indent += 1
      need_pass = True
    elif line == 'else':
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      lines.append(' ' * indent + "else:")
      indent += 1
      need_pass = True
    elif line.startswith('endif'):
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      need_pass = False
    elif line.startswith('endfor'):
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      need_pass, in_first_match_loop = False, False
    elif line.startswith('declare '):
      pass
    elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
      start, end = _expr(m[2].strip()), _expr(m[3].strip())
      # pseudocode ranges are inclusive, hence the +1
      lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
      indent += 1
      need_pass, in_first_match_loop = True, True
    elif '=' in line and not line.startswith('=='):
      need_pass = False
      line = line.rstrip(';')
      if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
        # {D1, D0} = value: low 64 bits go to D0, bit 64 goes to D1
        rhs = _expr(m[1])
        lines.append(' ' * indent + f"_full = {rhs}")
        lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
        lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
      elif any(op in line for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^=')):
        for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
          if op in line:
            lhs, rhs = line.split(op, 1)
            lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
            break
      else:
        lhs, rhs = line.split('=', 1)
        lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip()
        stmt = _assign(lhs_s, _expr(rhs_s))
        # inside a for loop, storing the loop index 'i' into tmp/D0.i32 also ends the loop
        if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
          stmt += "; break"
        lines.append(' ' * indent + stmt)
  if need_pass: lines.append(' ' * indent + "pass")
  return '\n'.join(lines)
|
|
||||||
|
|
||||||
def _assign(lhs: str, rhs: str) -> str:
|
|
||||||
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
|
|
||||||
return f"{lhs} = Reg({rhs})"
|
|
||||||
return f"{lhs} = {rhs}"
|
|
||||||
|
|
||||||
def _expr(e: str) -> str:
|
|
||||||
e = e.strip()
|
|
||||||
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
|
|
||||||
e = re.sub(r'!([^=])', r' not \1', e)
|
|
||||||
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
|
|
||||||
def pack(m):
|
|
||||||
hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
|
|
||||||
return f'_pack({hi}, {lo})'
|
|
||||||
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
|
|
||||||
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
|
|
||||||
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
|
|
||||||
e = re.sub(r"\d+'[FIBU]\(", "(", e)
|
|
||||||
e = re.sub(r'\bB\(', '(', e)
|
|
||||||
e = re.sub(r'([0-9a-fA-Fx])ULL\b', r'\1', e)
|
|
||||||
e = re.sub(r'([0-9a-fA-Fx])LL\b', r'\1', e)
|
|
||||||
e = re.sub(r'([0-9a-fA-Fx])U\b', r'\1', e)
|
|
||||||
e = re.sub(r'(\d\.?\d*)F\b', r'\1', e)
|
|
||||||
e = re.sub(r'(\[laneId\])\.[uib]\d+', r'\1', e)
|
|
||||||
e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
|
|
||||||
e = re.sub(r'NAN\.f\d+', 'float("nan")', e)
|
|
||||||
def convert_verilog_slice(m):
|
|
||||||
start, width = m.group(1).strip(), m.group(2).strip()
|
|
||||||
return f'[({start}) + ({width}) - 1 : ({start})]'
|
|
||||||
e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e)
|
|
||||||
def process_brackets(s):
|
|
||||||
result, i = [], 0
|
|
||||||
while i < len(s):
|
|
||||||
if s[i] == '[':
|
|
||||||
depth, start = 1, i + 1
|
|
||||||
j = start
|
|
||||||
while j < len(s) and depth > 0:
|
|
||||||
if s[j] == '[': depth += 1
|
|
||||||
elif s[j] == ']': depth -= 1
|
|
||||||
j += 1
|
|
||||||
inner = _expr(s[start:j-1])
|
|
||||||
result.append('[' + inner + ']')
|
|
||||||
i = j
|
|
||||||
else:
|
|
||||||
result.append(s[i])
|
|
||||||
i += 1
|
|
||||||
return ''.join(result)
|
|
||||||
e = process_brackets(e)
|
|
||||||
while '?' in e:
|
|
||||||
depth, bracket, q = 0, 0, -1
|
|
||||||
for i, c in enumerate(e):
|
|
||||||
if c == '(': depth += 1
|
|
||||||
elif c == ')': depth -= 1
|
|
||||||
elif c == '[': bracket += 1
|
|
||||||
elif c == ']': bracket -= 1
|
|
||||||
elif c == '?' and depth == 0 and bracket == 0: q = i; break
|
|
||||||
if q < 0: break
|
|
||||||
depth, bracket, col = 0, 0, -1
|
|
||||||
for i in range(q + 1, len(e)):
|
|
||||||
if e[i] == '(': depth += 1
|
|
||||||
elif e[i] == ')': depth -= 1
|
|
||||||
elif e[i] == '[': bracket += 1
|
|
||||||
elif e[i] == ']': bracket -= 1
|
|
||||||
elif e[i] == ':' and depth == 0 and bracket == 0: col = i; break
|
|
||||||
if col < 0: break
|
|
||||||
cond, t, f = e[:q].strip(), e[q+1:col].strip(), e[col+1:].strip()
|
|
||||||
e = f'(({t}) if ({cond}) else ({f}))'
|
|
||||||
return e
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════════════
|
||||||
# PDF PARSING WITH PAGE CACHING
|
# PDF PARSING WITH PAGE CACHING
|
||||||
@@ -472,8 +318,8 @@ def _generate_ins_py(formats, enums, src_enum, doc_name) -> str:
|
|||||||
if "NULL" in src_names: lines.append("OFF = NULL\n")
|
if "NULL" in src_names: lines.append("OFF = NULL\n")
|
||||||
return '\n'.join(lines)
|
return '\n'.join(lines)
|
||||||
|
|
||||||
def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
|
def _generate_str_pcode_py(enums, pseudocode, arch) -> str:
|
||||||
"""Generate gen_pcode.py content (compiled pseudocode functions)."""
|
"""Generate str_pcode.py content (raw pseudocode strings)."""
|
||||||
# Get op enums for this arch (import from .ins which re-exports from .enum)
|
# Get op enums for this arch (import from .ins which re-exports from .enum)
|
||||||
import importlib
|
import importlib
|
||||||
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
|
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
|
||||||
@@ -491,186 +337,35 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
|
|||||||
if key in defined_ops:
|
if key in defined_ops:
|
||||||
for enum_cls, enum_val in defined_ops[key]: instructions[enum_cls][enum_val] = pc
|
for enum_cls, enum_val in defined_ops[key]: instructions[enum_cls][enum_val] = pc
|
||||||
|
|
||||||
# First pass: generate all function code
|
# Build string dictionaries for each enum
|
||||||
fn_lines: list[str] = []
|
|
||||||
all_fn_entries: dict = {}
|
|
||||||
for enum_cls in OP_ENUMS:
|
|
||||||
cls_name = enum_cls.__name__
|
|
||||||
if not instructions.get(enum_cls): continue
|
|
||||||
fn_entries = []
|
|
||||||
for op, pc in instructions[enum_cls].items():
|
|
||||||
if any(p in pc for p in UNSUPPORTED): continue
|
|
||||||
try:
|
|
||||||
code = compile_pseudocode(pc)
|
|
||||||
code = _apply_pseudocode_fixes(op, code)
|
|
||||||
fn_name, fn_code = _generate_function(cls_name, op, pc, code)
|
|
||||||
fn_lines.append(fn_code)
|
|
||||||
fn_entries.append((op, fn_name))
|
|
||||||
except Exception as e: print(f" Warning: Failed to compile {op.name}: {e}")
|
|
||||||
if fn_entries:
|
|
||||||
all_fn_entries[enum_cls] = fn_entries
|
|
||||||
fn_lines.append(f'{cls_name}_FUNCTIONS = {{')
|
|
||||||
for op, fn_name in fn_entries: fn_lines.append(f" {cls_name}.{op.name}: {fn_name},")
|
|
||||||
fn_lines.append('}\n')
|
|
||||||
|
|
||||||
fn_lines.append('COMPILED_FUNCTIONS = {')
|
|
||||||
for enum_cls in OP_ENUMS:
|
|
||||||
if all_fn_entries.get(enum_cls): fn_lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_FUNCTIONS,')
|
|
||||||
fn_lines.append('}')
|
|
||||||
|
|
||||||
# Second pass: scan generated code for pcode imports
|
|
||||||
fn_code_str = '\n'.join(fn_lines)
|
|
||||||
import extra.assembly.amd.pcode as pcode_module
|
|
||||||
pcode_exports = [name for name in dir(pcode_module) if not name.startswith('_') or name.startswith('_') and not name.startswith('__')]
|
|
||||||
used_imports = sorted(name for name in pcode_exports if re.search(rf'\b{re.escape(name)}\b', fn_code_str))
|
|
||||||
|
|
||||||
# Build final output with explicit imports
|
|
||||||
lines = [f'''# autogenerated by pdf.py - do not edit
|
lines = [f'''# autogenerated by pdf.py - do not edit
|
||||||
# to regenerate: python -m extra.assembly.amd.pdf --arch {arch}
|
# to regenerate: python -m extra.assembly.amd.pdf --arch {arch}
|
||||||
# ruff: noqa: E501
|
# ruff: noqa: E501
|
||||||
# mypy: ignore-errors
|
|
||||||
from extra.assembly.amd.autogen.{arch}.enum import {", ".join(enum_names)}
|
from extra.assembly.amd.autogen.{arch}.enum import {", ".join(enum_names)}
|
||||||
from extra.assembly.amd.pcode import {", ".join(used_imports)}
|
''']
|
||||||
'''] + fn_lines
|
all_dict_entries: dict = {}
|
||||||
|
for enum_cls in OP_ENUMS:
|
||||||
|
cls_name = enum_cls.__name__
|
||||||
|
if not instructions.get(enum_cls): continue
|
||||||
|
dict_entries = [(op, repr(pc)) for op, pc in instructions[enum_cls].items()]
|
||||||
|
if dict_entries:
|
||||||
|
all_dict_entries[enum_cls] = dict_entries
|
||||||
|
lines.append(f'{cls_name}_PCODE = {{')
|
||||||
|
for op, escaped in dict_entries: lines.append(f" {cls_name}.{op.name}: {escaped},")
|
||||||
|
lines.append('}\n')
|
||||||
|
|
||||||
|
lines.append('PSEUDOCODE_STRINGS = {')
|
||||||
|
for enum_cls in OP_ENUMS:
|
||||||
|
if all_dict_entries.get(enum_cls): lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_PCODE,')
|
||||||
|
lines.append('}')
|
||||||
return '\n'.join(lines)
|
return '\n'.join(lines)
|
||||||
|
|
||||||
def _apply_pseudocode_fixes(op, code: str) -> str:
|
|
||||||
"""Apply known fixes for PDF pseudocode bugs."""
|
|
||||||
if op.name == 'V_DIV_FMAS_F32':
|
|
||||||
code = code.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
|
|
||||||
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
|
|
||||||
if op.name == 'V_DIV_FMAS_F64':
|
|
||||||
code = code.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
|
|
||||||
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
|
|
||||||
if op.name == 'V_DIV_SCALE_F32':
|
|
||||||
code = code.replace('D0.f32 = float("nan")', 'VCC = Reg(0x1); D0.f32 = float("nan")')
|
|
||||||
code = code.replace('elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif False:\n pass')
|
|
||||||
code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
|
|
||||||
code = code.replace('elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)', 'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
|
|
||||||
code = code.replace('elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
|
|
||||||
if op.name == 'V_DIV_SCALE_F64':
|
|
||||||
code = code.replace('D0.f64 = float("nan")', 'VCC = Reg(0x1); D0.f64 = float("nan")')
|
|
||||||
code = code.replace('elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif False:\n pass')
|
|
||||||
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
|
|
||||||
code = code.replace('elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)', 'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
|
|
||||||
code = code.replace('elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
|
|
||||||
if op.name == 'V_DIV_FIXUP_F32':
|
|
||||||
code = code.replace('D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
|
|
||||||
'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
|
|
||||||
if op.name == 'V_DIV_FIXUP_F64':
|
|
||||||
code = code.replace('D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
|
|
||||||
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
|
|
||||||
if op.name == 'V_TRIG_PREOP_F64':
|
|
||||||
code = code.replace('result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
|
|
||||||
'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
|
|
||||||
return code
|
|
||||||
|
|
||||||
def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]:
|
|
||||||
"""Generate a single compiled pseudocode function.
|
|
||||||
Functions take int parameters and return dict of int values.
|
|
||||||
Reg wrapping happens inside the function, only for registers actually used."""
|
|
||||||
has_d1 = '{ D1' in pc
|
|
||||||
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
|
|
||||||
is_div_scale = 'DIV_SCALE' in op.name
|
|
||||||
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
|
|
||||||
is_ds = cls_name == 'DSOp'
|
|
||||||
is_flat = cls_name in ('FLATOp', 'GLOBALOp', 'SCRATCHOp')
|
|
||||||
is_smem = cls_name == 'SMEMOp'
|
|
||||||
has_s_array = 'S[i]' in pc # FMA_MIX style: S[0], S[1], S[2] array access
|
|
||||||
combined = code + pc
|
|
||||||
|
|
||||||
fn_name = f"_{cls_name}_{op.name}"
|
|
||||||
|
|
||||||
# Detect which registers are used/modified
|
|
||||||
def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
|
|
||||||
modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined))
|
|
||||||
modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined))
|
|
||||||
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
|
|
||||||
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
|
|
||||||
modifies_pc = bool(re.search(r'\bPC\s*=', combined))
|
|
||||||
|
|
||||||
# Build function signature and Reg init lines
|
|
||||||
if is_smem:
|
|
||||||
lines = [f"def {fn_name}(MEM, addr):"]
|
|
||||||
reg_inits = ["ADDR=Reg(addr)", "SDATA=Reg(0)"]
|
|
||||||
special_regs = []
|
|
||||||
elif is_ds:
|
|
||||||
lines = [f"def {fn_name}(MEM, addr, data0, data1, offset0, offset1):"]
|
|
||||||
reg_inits = ["ADDR=Reg(addr)", "DATA0=Reg(data0)", "DATA1=Reg(data1)", "OFFSET0=Reg(offset0)", "OFFSET1=Reg(offset1)", "RETURN_DATA=Reg(0)"]
|
|
||||||
special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
|
|
||||||
elif is_flat:
|
|
||||||
lines = [f"def {fn_name}(MEM, addr, vdata, vdst):"]
|
|
||||||
reg_inits = ["ADDR=addr", "VDATA=Reg(vdata)", "VDST=Reg(vdst)", "RETURN_DATA=Reg(0)"]
|
|
||||||
special_regs = [('DATA', 'VDATA')]
|
|
||||||
elif has_s_array:
|
|
||||||
# FMA_MIX style: needs S[i] array, opsel, opsel_hi for source selection (neg/neg_hi applied in emu.py before call)
|
|
||||||
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None, opsel=0, opsel_hi=0):"]
|
|
||||||
reg_inits = ["S0=Reg(s0)", "S1=Reg(s1)", "S2=Reg(s2)", "S=[S0,S1,S2]", "D0=Reg(d0)", "OPSEL=Reg(opsel)", "OPSEL_HI=Reg(opsel_hi)"]
|
|
||||||
special_regs = []
|
|
||||||
# Detect array declarations like "declare in : 32'F[3]" and create them (rename 'in' to 'ins' since 'in' is a keyword)
|
|
||||||
if "in[" in combined:
|
|
||||||
reg_inits.append("ins=[Reg(0),Reg(0),Reg(0)]")
|
|
||||||
code = code.replace("in[", "ins[")
|
|
||||||
else:
|
|
||||||
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):"]
|
|
||||||
# Only create Regs for registers actually used in the pseudocode
|
|
||||||
reg_inits = []
|
|
||||||
if 'S0' in combined: reg_inits.append("S0=Reg(s0)")
|
|
||||||
if 'S1' in combined: reg_inits.append("S1=Reg(s1)")
|
|
||||||
if 'S2' in combined: reg_inits.append("S2=Reg(s2)")
|
|
||||||
if modifies_d0 or 'D0' in combined: reg_inits.append("D0=Reg(s0)" if is_div_scale else "D0=Reg(d0)")
|
|
||||||
if modifies_scc or 'SCC' in combined: reg_inits.append("SCC=Reg(scc)")
|
|
||||||
if modifies_vcc or 'VCC' in combined: reg_inits.append("VCC=Reg(vcc)")
|
|
||||||
if modifies_exec or 'EXEC' in combined: reg_inits.append("EXEC=Reg(exec_mask)")
|
|
||||||
if modifies_pc or 'PC' in combined: reg_inits.append("PC=Reg(pc) if pc is not None else None")
|
|
||||||
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
|
|
||||||
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
|
|
||||||
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
|
|
||||||
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
|
|
||||||
|
|
||||||
# Build init code
|
|
||||||
init_parts = reg_inits.copy()
|
|
||||||
for name, init in special_regs:
|
|
||||||
if name in combined: init_parts.append(f"{name}={init}")
|
|
||||||
if 'EXEC_LO' in code: init_parts.append("EXEC_LO=SliceProxy(EXEC, 31, 0)")
|
|
||||||
if 'EXEC_HI' in code: init_parts.append("EXEC_HI=SliceProxy(EXEC, 63, 32)")
|
|
||||||
if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_parts.append("VCCZ=Reg(1 if VCC._val == 0 else 0)")
|
|
||||||
if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_parts.append("EXECZ=Reg(1 if EXEC._val == 0 else 0)")
|
|
||||||
|
|
||||||
# Add init line and separator
|
|
||||||
if init_parts: lines.append(f" {'; '.join(init_parts)}")
|
|
||||||
lines.append(" # --- compiled pseudocode ---")
|
|
||||||
|
|
||||||
# Add compiled pseudocode
|
|
||||||
for line in code.split('\n'):
|
|
||||||
if line.strip(): lines.append(f" {line}")
|
|
||||||
|
|
||||||
# Build result dict
|
|
||||||
result_items = []
|
|
||||||
if modifies_d0: result_items.append("'D0': D0._val")
|
|
||||||
if modifies_scc: result_items.append("'SCC': SCC._val")
|
|
||||||
if modifies_vcc: result_items.append("'VCC': VCC._val")
|
|
||||||
if modifies_exec: result_items.append("'EXEC': EXEC._val")
|
|
||||||
if has_d1: result_items.append("'D1': D1._val")
|
|
||||||
if modifies_pc: result_items.append("'PC': PC._val")
|
|
||||||
if is_smem and 'SDATA' in combined and re.search(r'^\s*SDATA[\.\[].*=', code, re.MULTILINE):
|
|
||||||
result_items.append("'SDATA': SDATA._val")
|
|
||||||
if is_ds and 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
|
|
||||||
result_items.append("'RETURN_DATA': RETURN_DATA._val")
|
|
||||||
if is_flat:
|
|
||||||
if 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
|
|
||||||
result_items.append("'RETURN_DATA': RETURN_DATA._val")
|
|
||||||
if re.search(r'^\s*VDATA[\.\[].*=', code, re.MULTILINE):
|
|
||||||
result_items.append("'VDATA': VDATA._val")
|
|
||||||
lines.append(f" return {{{', '.join(result_items)}}}\n")
|
|
||||||
return fn_name, '\n'.join(lines)
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════════════
|
||||||
# MAIN GENERATION
|
# MAIN GENERATION
|
||||||
# ═══════════════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
def generate_arch(arch: str) -> dict:
|
def generate_arch(arch: str) -> dict:
|
||||||
"""Generate enum.py, ins.py and gen_pcode.py for a single architecture."""
|
"""Generate enum.py, ins.py and str_pcode.py for a single architecture."""
|
||||||
urls = PDF_URLS[arch]
|
urls = PDF_URLS[arch]
|
||||||
if isinstance(urls, str): urls = [urls]
|
if isinstance(urls, str): urls = [urls]
|
||||||
|
|
||||||
@@ -696,9 +391,9 @@ def generate_arch(arch: str) -> dict:
|
|||||||
ins_path.write_text(ins_content)
|
ins_path.write_text(ins_content)
|
||||||
print(f"Generated {ins_path}: {len(merged['formats'])} formats")
|
print(f"Generated {ins_path}: {len(merged['formats'])} formats")
|
||||||
|
|
||||||
# Write gen_pcode.py (needs enum.py to exist first for imports)
|
# Write str_pcode.py (needs enum.py to exist first for imports)
|
||||||
pcode_path = base_path / "gen_pcode.py"
|
pcode_path = base_path / "str_pcode.py"
|
||||||
pcode_content = _generate_gen_pcode_py(merged["enums"], merged["pseudocode"], arch)
|
pcode_content = _generate_str_pcode_py(merged["enums"], merged["pseudocode"], arch)
|
||||||
pcode_path.write_text(pcode_content)
|
pcode_path.write_text(pcode_content)
|
||||||
print(f"Generated {pcode_path}: {len(merged['pseudocode'])} instructions")
|
print(f"Generated {pcode_path}: {len(merged['pseudocode'])} instructions")
|
||||||
|
|
||||||
|
|||||||
@@ -30,8 +30,8 @@ def get_llvm_objdump():
|
|||||||
class ExecContext:
|
class ExecContext:
|
||||||
"""Context for running compiled pseudocode in tests."""
|
"""Context for running compiled pseudocode in tests."""
|
||||||
def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=0xffffffff, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
|
def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=0xffffffff, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
|
||||||
from extra.assembly.amd.pcode import Reg, MASK32, MASK64, SliceProxy
|
from extra.assembly.amd.pcode import Reg, MASK32, MASK64, TypedView
|
||||||
self._Reg, self._MASK64, self._SliceProxy = Reg, MASK64, SliceProxy
|
self._Reg, self._MASK64, self._TypedView = Reg, MASK64, TypedView
|
||||||
self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
|
self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
|
||||||
self.D0, self.D1 = Reg(d0), Reg(0)
|
self.D0, self.D1 = Reg(d0), Reg(0)
|
||||||
self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
|
self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
|
||||||
@@ -51,7 +51,7 @@ class ExecContext:
|
|||||||
ns.update({
|
ns.update({
|
||||||
'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
|
'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
|
||||||
'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
|
'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
|
||||||
'EXEC_LO': self._SliceProxy(self.EXEC, 31, 0), 'EXEC_HI': self._SliceProxy(self.EXEC, 63, 32),
|
'EXEC_LO': self._TypedView(self.EXEC, 31, 0), 'EXEC_HI': self._TypedView(self.EXEC, 63, 32),
|
||||||
'tmp': self.tmp, 'saveexec': self.saveexec,
|
'tmp': self.tmp, 'saveexec': self.saveexec,
|
||||||
'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
|
'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
|
||||||
'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32, 'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,
|
'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32, 'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,
|
||||||
|
|||||||
@@ -47,8 +47,7 @@ dev.synchronize()
|
|||||||
elapsed = time.perf_counter() - st
|
elapsed = time.perf_counter() - st
|
||||||
|
|
||||||
self.assertNotEqual(result.returncode, 0, "should have raised")
|
self.assertNotEqual(result.returncode, 0, "should have raised")
|
||||||
self.assertTrue("NotImplementedError" in result.stderr or "ValueError" in result.stderr,
|
self.assertTrue("Error" in result.stderr, f"expected an error in stderr, got: {result.stderr[:500]}")
|
||||||
f"expected NotImplementedError or ValueError in stderr")
|
|
||||||
# Should exit immediately, not wait for the full timeout
|
# Should exit immediately, not wait for the full timeout
|
||||||
self.assertLess(elapsed, 9.0, f"should exit immediately on emulator exception, took {elapsed:.1f}s")
|
self.assertLess(elapsed, 9.0, f"should exit immediately on emulator exception, took {elapsed:.1f}s")
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Tests for the RDNA3 pseudocode DSL."""
|
"""Tests for the RDNA3 pseudocode DSL."""
|
||||||
import unittest
|
import unittest
|
||||||
from extra.assembly.amd.pcode import (Reg, TypedView, SliceProxy, MASK32, MASK64,
|
from extra.assembly.amd.pcode import (Reg, TypedView, TypedView, MASK32, MASK64,
|
||||||
_f32, _i32, _f16, _i16, f32_to_f16, _isnan, _bf16, _ibf16, bf16_to_f32, f32_to_bf16,
|
_f32, _i32, _f16, _i16, f32_to_f16, isNAN, _bf16, _ibf16, bf16_to_f32, f32_to_bf16,
|
||||||
BYTE_PERMUTE, v_sad_u8, v_msad_u8)
|
BYTE_PERMUTE, v_sad_u8, v_msad_u8, _compile_pseudocode, _expr, compile_pseudocode)
|
||||||
from extra.assembly.amd.pdf import compile_pseudocode, _expr
|
|
||||||
from extra.assembly.amd.test.helpers import ExecContext
|
from extra.assembly.amd.test.helpers import ExecContext
|
||||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import _VOP3SDOp_V_DIV_SCALE_F32, _VOPCOp_V_CMP_CLASS_F32
|
from extra.assembly.amd.autogen.rdna3.str_pcode import VOP3SDOp_PCODE, VOPCOp_PCODE
|
||||||
|
from extra.assembly.amd.autogen.rdna3.enum import VOP3SDOp, VOPCOp
|
||||||
|
|
||||||
|
# Compile pseudocode functions on demand for regression tests
|
||||||
|
_VOP3SDOp_V_DIV_SCALE_F32 = compile_pseudocode('VOP3SDOp', 'V_DIV_SCALE_F32', VOP3SDOp_PCODE[VOP3SDOp.V_DIV_SCALE_F32])
|
||||||
|
_VOPCOp_V_CMP_CLASS_F32 = compile_pseudocode('VOPCOp', 'V_CMP_CLASS_F32', VOPCOp_PCODE[VOPCOp.V_CMP_CLASS_F32])
|
||||||
|
|
||||||
class TestReg(unittest.TestCase):
|
class TestReg(unittest.TestCase):
|
||||||
def test_u32_read(self):
|
def test_u32_read(self):
|
||||||
@@ -42,7 +46,7 @@ class TestReg(unittest.TestCase):
|
|||||||
class TestTypedView(unittest.TestCase):
|
class TestTypedView(unittest.TestCase):
|
||||||
def test_bit_slice(self):
|
def test_bit_slice(self):
|
||||||
r = Reg(0xDEADBEEF)
|
r = Reg(0xDEADBEEF)
|
||||||
# Slices return SliceProxy which supports .u32, .u16 etc (matching pseudocode like S1.u32[1:0].u32)
|
# Slices return TypedView which supports .u32, .u16 etc (matching pseudocode like S1.u32[1:0].u32)
|
||||||
self.assertEqual(r.u32[7:0].u32, 0xEF)
|
self.assertEqual(r.u32[7:0].u32, 0xEF)
|
||||||
self.assertEqual(r.u32[15:8].u32, 0xBE)
|
self.assertEqual(r.u32[15:8].u32, 0xBE)
|
||||||
self.assertEqual(r.u32[23:16].u32, 0xAD)
|
self.assertEqual(r.u32[23:16].u32, 0xAD)
|
||||||
@@ -67,7 +71,7 @@ class TestTypedView(unittest.TestCase):
|
|||||||
# S0.u32[S1.u32[4:0]] - access bit at position from another register
|
# S0.u32[S1.u32[4:0]] - access bit at position from another register
|
||||||
s0 = Reg(0b11010101)
|
s0 = Reg(0b11010101)
|
||||||
s1 = Reg(3)
|
s1 = Reg(3)
|
||||||
bit_pos = s1.u32[4:0] # SliceProxy, int value = 3
|
bit_pos = s1.u32[4:0] # TypedView, int value = 3
|
||||||
bit_val = s0.u32[int(bit_pos)] # bit 3 of s0 = 0
|
bit_val = s0.u32[int(bit_pos)] # bit 3 of s0 = 0
|
||||||
self.assertEqual(int(bit_pos), 3)
|
self.assertEqual(int(bit_pos), 3)
|
||||||
self.assertEqual(bit_val, 0)
|
self.assertEqual(bit_val, 0)
|
||||||
@@ -85,7 +89,7 @@ class TestTypedView(unittest.TestCase):
|
|||||||
self.assertFalse(r1.u32 < r2.u32)
|
self.assertFalse(r1.u32 < r2.u32)
|
||||||
self.assertTrue(r1.u32 != r2.u32)
|
self.assertTrue(r1.u32 != r2.u32)
|
||||||
|
|
||||||
class TestSliceProxy(unittest.TestCase):
|
class TestTypedView(unittest.TestCase):
|
||||||
def test_slice_read(self):
|
def test_slice_read(self):
|
||||||
r = Reg(0x56781234)
|
r = Reg(0x56781234)
|
||||||
self.assertEqual(r[15:0].u16, 0x1234)
|
self.assertEqual(r[15:0].u16, 0x1234)
|
||||||
@@ -154,19 +158,19 @@ class TestExecContext(unittest.TestCase):
|
|||||||
self.assertEqual(ctx.SCC._val, 0)
|
self.assertEqual(ctx.SCC._val, 0)
|
||||||
|
|
||||||
def test_ternary(self):
|
def test_ternary(self):
|
||||||
code = compile_pseudocode("D0.u32 = S0.u32 > S1.u32 ? 1'1U : 1'0U")
|
code = _compile_pseudocode("D0.u32 = S0.u32 > S1.u32 ? 1'1U : 1'0U")
|
||||||
ctx = ExecContext(s0=5, s1=3)
|
ctx = ExecContext(s0=5, s1=3)
|
||||||
ctx.run(code)
|
ctx.run(code)
|
||||||
self.assertEqual(ctx.D0._val, 1)
|
self.assertEqual(ctx.D0._val, 1)
|
||||||
|
|
||||||
def test_pack(self):
|
def test_pack(self):
|
||||||
code = compile_pseudocode("D0 = { S1[15:0].u16, S0[15:0].u16 }")
|
code = _compile_pseudocode("D0 = { S1[15:0].u16, S0[15:0].u16 }")
|
||||||
ctx = ExecContext(s0=0x1234, s1=0x5678)
|
ctx = ExecContext(s0=0x1234, s1=0x5678)
|
||||||
ctx.run(code)
|
ctx.run(code)
|
||||||
self.assertEqual(ctx.D0._val, 0x56781234)
|
self.assertEqual(ctx.D0._val, 0x56781234)
|
||||||
|
|
||||||
def test_tmp_with_typed_access(self):
|
def test_tmp_with_typed_access(self):
|
||||||
code = compile_pseudocode("""tmp = S0.u32 + S1.u32
|
code = _compile_pseudocode("""tmp = S0.u32 + S1.u32
|
||||||
D0.u32 = tmp.u32""")
|
D0.u32 = tmp.u32""")
|
||||||
ctx = ExecContext(s0=100, s1=200)
|
ctx = ExecContext(s0=100, s1=200)
|
||||||
ctx.run(code)
|
ctx.run(code)
|
||||||
@@ -174,7 +178,7 @@ D0.u32 = tmp.u32""")
|
|||||||
|
|
||||||
def test_s_add_u32_pattern(self):
|
def test_s_add_u32_pattern(self):
|
||||||
# Real pseudocode pattern from S_ADD_U32
|
# Real pseudocode pattern from S_ADD_U32
|
||||||
code = compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
|
code = _compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
|
||||||
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
|
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
|
||||||
D0.u32 = tmp.u32""")
|
D0.u32 = tmp.u32""")
|
||||||
# Test overflow case
|
# Test overflow case
|
||||||
@@ -184,7 +188,7 @@ D0.u32 = tmp.u32""")
|
|||||||
self.assertEqual(ctx.SCC._val, 1) # Carry set
|
self.assertEqual(ctx.SCC._val, 1) # Carry set
|
||||||
|
|
||||||
def test_s_add_u32_no_overflow(self):
|
def test_s_add_u32_no_overflow(self):
|
||||||
code = compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
|
code = _compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
|
||||||
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
|
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
|
||||||
D0.u32 = tmp.u32""")
|
D0.u32 = tmp.u32""")
|
||||||
ctx = ExecContext(s0=100, s1=200)
|
ctx = ExecContext(s0=100, s1=200)
|
||||||
@@ -206,7 +210,7 @@ D0.u32 = tmp.u32""")
|
|||||||
|
|
||||||
def test_for_loop(self):
|
def test_for_loop(self):
|
||||||
# CTZ pattern - find first set bit
|
# CTZ pattern - find first set bit
|
||||||
code = compile_pseudocode("""tmp = -1
|
code = _compile_pseudocode("""tmp = -1
|
||||||
for i in 0 : 31 do
|
for i in 0 : 31 do
|
||||||
if S0.u32[i] == 1 then
|
if S0.u32[i] == 1 then
|
||||||
tmp = i
|
tmp = i
|
||||||
@@ -261,15 +265,15 @@ class TestPseudocodeRegressions(unittest.TestCase):
|
|||||||
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
|
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
|
||||||
self.assertEqual(result['D0'] & 1, 0, "Signaling NaN should not match quiet NaN mask")
|
self.assertEqual(result['D0'] & 1, 0, "Signaling NaN should not match quiet NaN mask")
|
||||||
|
|
||||||
def test_isnan_with_typed_view(self):
|
def testisNAN_with_typed_view(self):
|
||||||
"""_isnan must work with TypedView objects, not just Python floats.
|
"""isNAN must work with TypedView objects, not just Python floats.
|
||||||
Bug: _isnan checked isinstance(x, float) which returned False for TypedView."""
|
Bug: isNAN checked isinstance(x, float) which returned False for TypedView."""
|
||||||
nan_reg = Reg(0x7fc00000) # quiet NaN
|
nan_reg = Reg(0x7fc00000) # quiet NaN
|
||||||
normal_reg = Reg(0x3f800000) # 1.0
|
normal_reg = Reg(0x3f800000) # 1.0
|
||||||
inf_reg = Reg(0x7f800000) # +inf
|
inf_reg = Reg(0x7f800000) # +inf
|
||||||
self.assertTrue(_isnan(nan_reg.f32), "_isnan should return True for NaN TypedView")
|
self.assertTrue(isNAN(nan_reg.f32), "isNAN should return True for NaN TypedView")
|
||||||
self.assertFalse(_isnan(normal_reg.f32), "_isnan should return False for normal TypedView")
|
self.assertFalse(isNAN(normal_reg.f32), "isNAN should return False for normal TypedView")
|
||||||
self.assertFalse(_isnan(inf_reg.f32), "_isnan should return False for inf TypedView")
|
self.assertFalse(isNAN(inf_reg.f32), "isNAN should return False for inf TypedView")
|
||||||
|
|
||||||
class TestBF16(unittest.TestCase):
|
class TestBF16(unittest.TestCase):
|
||||||
"""Tests for BF16 (bfloat16) support."""
|
"""Tests for BF16 (bfloat16) support."""
|
||||||
@@ -308,7 +312,7 @@ class TestBF16(unittest.TestCase):
|
|||||||
self.assertAlmostEqual(float(r.bf16), 3.0, places=1)
|
self.assertAlmostEqual(float(r.bf16), 3.0, places=1)
|
||||||
|
|
||||||
def test_bf16_slice_property(self):
|
def test_bf16_slice_property(self):
|
||||||
"""Test SliceProxy.bf16 property."""
|
"""Test TypedView.bf16 property."""
|
||||||
r = Reg(0x40404040) # Two bf16 3.0 values
|
r = Reg(0x40404040) # Two bf16 3.0 values
|
||||||
self.assertAlmostEqual(r[15:0].bf16, 3.0, places=1)
|
self.assertAlmostEqual(r[15:0].bf16, 3.0, places=1)
|
||||||
self.assertAlmostEqual(r[31:16].bf16, 3.0, places=1)
|
self.assertAlmostEqual(r[31:16].bf16, 3.0, places=1)
|
||||||
|
|||||||
Reference in New Issue
Block a user