mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
assembly/amd: clean up pcode, jit pcode instead of static (#14001)
* assembly/amd: clean up pcode * regen * lil * jit the pcode * sendmsg * cleanups * inst prefetch lol
This commit is contained in:
File diff suppressed because it is too large
Load Diff
1421
extra/assembly/amd/autogen/cdna/str_pcode.py
Normal file
1421
extra/assembly/amd/autogen/cdna/str_pcode.py
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
1354
extra/assembly/amd/autogen/rdna3/str_pcode.py
Normal file
1354
extra/assembly/amd/autogen/rdna3/str_pcode.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
1229
extra/assembly/amd/autogen/rdna4/str_pcode.py
Normal file
1229
extra/assembly/amd/autogen/rdna4/str_pcode.py
Normal file
File diff suppressed because one or more lines are too long
@@ -5,7 +5,8 @@ import ctypes, functools
|
||||
from tinygrad.runtime.autogen import hsa
|
||||
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import COMPILED_FUNCTIONS
|
||||
from extra.assembly.amd.pcode import compile_pseudocode
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
|
||||
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
|
||||
|
||||
@@ -236,9 +237,8 @@ def exec_vopd(st: WaveState, inst, V: list, lane: int) -> None:
|
||||
"""VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)."""
|
||||
literal, vdstx, vdsty = inst._literal, inst.vdstx, (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
|
||||
sx0, sx1, dx, sy0, sy1, dy = st.rsrc(inst.srcx0, lane, literal), V[inst.vsrcx1], V[vdstx], st.rsrc(inst.srcy0, lane, literal), V[inst.vsrcy1], V[vdsty]
|
||||
opx, opy = _VOPD_TO_VOP[inst.opx], _VOPD_TO_VOP[inst.opy]
|
||||
V[vdstx] = COMPILED_FUNCTIONS[type(opx)][opx](sx0, sx1, 0, dx, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
|
||||
V[vdsty] = COMPILED_FUNCTIONS[type(opy)][opy](sy0, sy1, 0, dy, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
|
||||
V[vdstx] = inst._fnx(sx0, sx1, 0, dx, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
|
||||
V[vdsty] = inst._fny(sy0, sy1, 0, dy, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
|
||||
|
||||
def exec_flat(st: WaveState, inst, V: list, lane: int) -> None:
|
||||
"""FLAT/GLOBAL/SCRATCH memory ops."""
|
||||
@@ -359,15 +359,14 @@ def decode_program(data: bytes) -> dict[int, Inst]:
|
||||
result: dict[int, Inst] = {}
|
||||
i = 0
|
||||
while i < len(data):
|
||||
try: inst_class = detect_format(data[i:])
|
||||
except ValueError: break # stop at invalid instruction (padding/metadata after code)
|
||||
inst = inst_class.from_bytes(data[i:i+inst_class._size()+8]) # +8 for potential 64-bit literal
|
||||
inst = detect_format(data[i:]).from_bytes(data[i:])
|
||||
inst._words = inst.size() // 4
|
||||
|
||||
# Determine dispatch function and pcode function
|
||||
fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op)
|
||||
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_ENDPGM: inst._dispatch = dispatch_endpgm
|
||||
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_CODE_END: break
|
||||
elif isinstance(inst, SOPP) and inst.op == SOPPOp.S_ENDPGM: inst._dispatch = dispatch_endpgm
|
||||
elif isinstance(inst, SOPP) and inst.op == SOPPOp.S_BARRIER: inst._dispatch = dispatch_barrier
|
||||
elif isinstance(inst, SOPP) and inst.op in (SOPPOp.S_CLAUSE, SOPPOp.S_WAITCNT, SOPPOp.S_WAITCNT_DEPCTR, SOPPOp.S_SENDMSG, SOPPOp.S_SET_INST_PREFETCH_DISTANCE): inst._dispatch = dispatch_nop
|
||||
elif isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)): inst._dispatch = exec_scalar
|
||||
elif isinstance(inst, VOP1) and inst.op == VOP1Op.V_NOP: inst._dispatch = dispatch_nop
|
||||
elif isinstance(inst, VOP3P) and 'WMMA' in inst.op_name: inst._dispatch = dispatch_wmma
|
||||
@@ -378,11 +377,14 @@ def decode_program(data: bytes) -> dict[int, Inst]:
|
||||
elif isinstance(inst, DS): inst._dispatch = dispatch_lane(exec_ds)
|
||||
else: inst._dispatch = dispatch_lane(exec_vop)
|
||||
|
||||
# Validate pcode exists for instructions that need it (scalar/wave-level ops and VOPD don't need pcode)
|
||||
needs_pcode = inst._dispatch not in (dispatch_endpgm, dispatch_barrier, exec_scalar, dispatch_nop, dispatch_wmma,
|
||||
dispatch_writelane, dispatch_readlane, dispatch_lane(exec_vopd))
|
||||
if fn is None and inst.op_name and needs_pcode: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
|
||||
inst._fn = fn if fn else lambda *args, **kwargs: {}
|
||||
# Compile pcode for instructions that use it (not VOPD which has _fnx/_fny, not special dispatches)
|
||||
# VOPD needs separate functions for X and Y ops
|
||||
if isinstance(inst, VOPD):
|
||||
def _compile_vopd_op(op): return compile_pseudocode(type(op).__name__, op.name, PSEUDOCODE_STRINGS[type(op)][op])
|
||||
inst._fnx, inst._fny = _compile_vopd_op(_VOPD_TO_VOP[inst.opx]), _compile_vopd_op(_VOPD_TO_VOP[inst.opy])
|
||||
elif inst._dispatch not in (dispatch_endpgm, dispatch_barrier, dispatch_nop, dispatch_wmma, dispatch_writelane):
|
||||
assert type(inst.op) != int, f"inst op of {inst} is int"
|
||||
inst._fn = compile_pseudocode(type(inst.op).__name__, inst.op.name, PSEUDOCODE_STRINGS[type(inst.op)][inst.op])
|
||||
result[i // 4] = inst
|
||||
i += inst._words * 4
|
||||
return result
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# DSL for RDNA3 pseudocode - makes pseudocode expressions work directly as Python
|
||||
import struct, math
|
||||
import struct, math, re, functools
|
||||
from extra.assembly.amd.dsl import MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# HELPER FUNCTIONS
|
||||
# INTERNAL HELPERS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _div(a, b):
|
||||
@@ -11,143 +11,35 @@ def _div(a, b):
|
||||
except ZeroDivisionError:
|
||||
if a == 0.0 or math.isnan(a): return float("nan")
|
||||
return math.copysign(float("inf"), a * b) if b == 0.0 else float("inf") if a > 0 else float("-inf")
|
||||
def _to_f16_bits(v): return v if isinstance(v, int) else _i16(v)
|
||||
def _isnan(x):
|
||||
try: return math.isnan(float(x))
|
||||
except (TypeError, ValueError): return False
|
||||
def _check_nan_type(x, quiet_bit_expected, default):
|
||||
"""Check NaN type by examining quiet bit. Returns default if can't determine."""
|
||||
try:
|
||||
if not math.isnan(float(x)): return False
|
||||
if hasattr(x, '_reg') and hasattr(x, '_bits'):
|
||||
bits = x._reg._val & ((1 << x._bits) - 1)
|
||||
# NaN format: exponent all 1s, quiet bit, mantissa != 0
|
||||
# f16: exp[14:10]=31, quiet=bit9, mant[8:0] | f32: exp[30:23]=255, quiet=bit22, mant[22:0] | f64: exp[62:52]=2047, quiet=bit51, mant[51:0]
|
||||
exp_bits, quiet_pos, mant_mask = {16: (0x1f, 9, 0x3ff), 32: (0xff, 22, 0x7fffff), 64: (0x7ff, 51, 0xfffffffffffff)}.get(x._bits, (0,0,0))
|
||||
exp_shift = {16: 10, 32: 23, 64: 52}.get(x._bits, 0)
|
||||
if exp_bits and ((bits >> exp_shift) & exp_bits) == exp_bits and (bits & mant_mask) != 0:
|
||||
return ((bits >> quiet_pos) & 1) == quiet_bit_expected
|
||||
return default
|
||||
except (TypeError, ValueError): return False
|
||||
def _isquietnan(x): return _check_nan_type(x, 1, True) # quiet NaN has quiet bit = 1
|
||||
def _issignalnan(x): return _check_nan_type(x, 0, False) # signaling NaN has quiet bit = 0
|
||||
def _gt_neg_zero(a, b): return (a > b) or (a == 0 and b == 0 and not math.copysign(1, a) < 0 and math.copysign(1, b) < 0)
|
||||
def _lt_neg_zero(a, b): return (a < b) or (a == 0 and b == 0 and math.copysign(1, a) < 0 and not math.copysign(1, b) < 0)
|
||||
def _fma(a, b, c):
|
||||
try: return math.fma(a, b, c)
|
||||
except ValueError: return float('nan') # inf * 0 + c is NaN per IEEE 754
|
||||
def _signext(v): return v
|
||||
def _fpop(fn):
|
||||
def wrapper(x):
|
||||
x = float(x)
|
||||
if math.isnan(x) or math.isinf(x): return x
|
||||
result = float(fn(x))
|
||||
# Preserve sign of zero (IEEE 754: ceil(-0.0) = -0.0, ceil(-0.1) = -0.0)
|
||||
if result == 0.0: return math.copysign(0.0, x)
|
||||
return result
|
||||
return math.copysign(0.0, x) if result == 0.0 else result
|
||||
return wrapper
|
||||
trunc, floor, ceil = _fpop(math.trunc), _fpop(math.floor), _fpop(math.ceil)
|
||||
class _SafeFloat(float):
|
||||
"""Float subclass that uses _div for division to handle 0/inf correctly."""
|
||||
def __truediv__(self, o): return _div(float(self), float(o))
|
||||
def __rtruediv__(self, o): return _div(float(o), float(self))
|
||||
def sqrt(x): return _SafeFloat(math.sqrt(x)) if x >= 0 else _SafeFloat(float("nan"))
|
||||
def log2(x): return math.log2(x) if x > 0 else (float("-inf") if x == 0 else float("nan"))
|
||||
i32_to_f32 = u32_to_f32 = i32_to_f64 = u32_to_f64 = f32_to_f64 = f64_to_f32 = float
|
||||
def _f_to_int(f, lo, hi): f = float(f); return 0 if math.isnan(f) else (hi if f >= hi else lo if f <= lo else int(f))
|
||||
def f32_to_i32(f): return _f_to_int(f, -2147483648, 2147483647)
|
||||
def f32_to_u32(f): return _f_to_int(f, 0, 4294967295)
|
||||
f64_to_i32, f64_to_u32 = f32_to_i32, f32_to_u32
|
||||
def f32_to_f16(f):
|
||||
f = float(f)
|
||||
if math.isnan(f): return 0x7e00 # f16 NaN
|
||||
if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00 # f16 ±infinity
|
||||
try: return struct.unpack("<H", struct.pack("<e", f))[0]
|
||||
except OverflowError: return 0x7c00 if f > 0 else 0xfc00 # overflow -> ±infinity
|
||||
def _f16_to_f32_bits(bits): return struct.unpack("<e", struct.pack("<H", int(bits) & 0xffff))[0]
|
||||
def f16_to_f32(v): return v if isinstance(v, float) else _f16_to_f32_bits(v)
|
||||
def i16_to_f16(v): return f32_to_f16(float(_sext(int(v) & 0xffff, 16)))
|
||||
def u16_to_f16(v): return f32_to_f16(float(int(v) & 0xffff))
|
||||
def f16_to_i16(bits): f = _f16_to_f32_bits(bits); return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
|
||||
def f16_to_u16(bits): f = _f16_to_f32_bits(bits); return max(0, min(65535, int(f))) if not math.isnan(f) else 0
|
||||
def u8_to_u32(v): return int(v) & 0xff
|
||||
def u4_to_u32(v): return int(v) & 0xf
|
||||
def _sign(f): return 1 if math.copysign(1.0, f) < 0 else 0
|
||||
def _mantissa_f32(f): return struct.unpack("<I", struct.pack("<f", f))[0] & 0x7fffff if not (math.isinf(f) or math.isnan(f)) else 0
|
||||
def _ldexp(m, e): return math.ldexp(m, e)
|
||||
def isEven(x):
|
||||
x = float(x)
|
||||
if math.isinf(x) or math.isnan(x): return False
|
||||
return int(x) % 2 == 0
|
||||
def fract(x): return x - math.floor(x)
|
||||
PI = math.pi
|
||||
def _trig(fn, x):
|
||||
# V_SIN/COS_F32: hardware does frac on input cycles before computing
|
||||
if math.isinf(x) or math.isnan(x): return float("nan")
|
||||
frac_cycles = fract(x / (2 * math.pi))
|
||||
result = fn(frac_cycles * 2 * math.pi)
|
||||
# Hardware returns exactly 0 for cos(π/2), sin(π), etc. due to lookup table
|
||||
# Round very small results (below f32 precision) to exactly 0
|
||||
if abs(result) < 1e-7: return 0.0
|
||||
return result
|
||||
def sin(x): return _trig(math.sin, x)
|
||||
def cos(x): return _trig(math.cos, x)
|
||||
def pow(a, b):
|
||||
try: return a ** b
|
||||
except OverflowError: return float("inf") if b > 0 else 0.0
|
||||
def _brev(v, bits): return int(bin(v & ((1 << bits) - 1))[2:].zfill(bits)[::-1], 2)
|
||||
def _brev32(v): return _brev(v, 32)
|
||||
def _brev64(v): return _brev(v, 64)
|
||||
def _ctz(v, bits):
|
||||
v, n = int(v) & ((1 << bits) - 1), 0
|
||||
if v == 0: return bits
|
||||
while (v & 1) == 0: v >>= 1; n += 1
|
||||
return n
|
||||
def _ctz32(v): return _ctz(v, 32)
|
||||
def _ctz64(v): return _ctz(v, 64)
|
||||
def _exponent(f):
|
||||
# Handle TypedView (f16/f32/f64) to get correct exponent for that type
|
||||
if hasattr(f, '_bits') and hasattr(f, '_float') and f._float:
|
||||
raw = f._val
|
||||
if f._bits == 16: return (raw >> 10) & 0x1f # f16: 5-bit exponent
|
||||
if f._bits == 32: return (raw >> 23) & 0xff # f32: 8-bit exponent
|
||||
if f._bits == 64: return (raw >> 52) & 0x7ff # f64: 11-bit exponent
|
||||
# Fallback: convert to f32 and get exponent
|
||||
f = float(f)
|
||||
if math.isinf(f) or math.isnan(f): return 255
|
||||
if f == 0.0: return 0
|
||||
try: bits = struct.unpack("<I", struct.pack("<f", f))[0]; return (bits >> 23) & 0xff
|
||||
except: return 0
|
||||
def _is_denorm_f32(f):
|
||||
if not isinstance(f, float): f = _f32(int(f) & 0xffffffff)
|
||||
if math.isinf(f) or math.isnan(f) or f == 0.0: return False
|
||||
bits = struct.unpack("<I", struct.pack("<f", float(f)))[0]
|
||||
return (bits >> 23) & 0xff == 0
|
||||
def _is_denorm_f64(f):
|
||||
if not isinstance(f, float): f = _f64(int(f) & 0xffffffffffffffff)
|
||||
if math.isinf(f) or math.isnan(f) or f == 0.0: return False
|
||||
bits = struct.unpack("<Q", struct.pack("<d", float(f)))[0]
|
||||
return (bits >> 52) & 0x7ff == 0
|
||||
def v_min_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _lt_neg_zero(a, b) else b)
|
||||
def v_max_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _gt_neg_zero(a, b) else b)
|
||||
v_min_f16, v_max_f16 = v_min_f32, v_max_f32
|
||||
v_min_i32, v_max_i32 = min, max
|
||||
v_min_i16, v_max_i16 = min, max
|
||||
def v_min_u32(a, b): return min(a & MASK32, b & MASK32)
|
||||
def v_max_u32(a, b): return max(a & MASK32, b & MASK32)
|
||||
def v_min_u16(a, b): return min(a & 0xffff, b & 0xffff)
|
||||
def v_max_u16(a, b): return max(a & 0xffff, b & 0xffff)
|
||||
def v_min3_f32(a, b, c): return v_min_f32(v_min_f32(a, b), c)
|
||||
def v_max3_f32(a, b, c): return v_max_f32(v_max_f32(a, b), c)
|
||||
v_min3_f16, v_max3_f16 = v_min3_f32, v_max3_f32
|
||||
v_min3_i32, v_max3_i32, v_min3_i16, v_max3_i16 = min, max, min, max
|
||||
def v_min3_u32(a, b, c): return min(a & MASK32, b & MASK32, c & MASK32)
|
||||
def v_max3_u32(a, b, c): return max(a & MASK32, b & MASK32, c & MASK32)
|
||||
def v_min3_u16(a, b, c): return min(a & 0xffff, b & 0xffff, c & 0xffff)
|
||||
def v_max3_u16(a, b, c): return max(a & 0xffff, b & 0xffff, c & 0xffff)
|
||||
def ABSDIFF(a, b): return abs(int(a) - int(b))
|
||||
|
||||
# BF16 (bfloat16) conversion functions
|
||||
def _bf16(i):
|
||||
"""Convert bf16 bits to float. BF16 is just the top 16 bits of f32."""
|
||||
return struct.unpack("<f", struct.pack("<I", (i & 0xffff) << 16))[0]
|
||||
@@ -157,95 +49,21 @@ def _ibf16(f):
|
||||
if math.isinf(f): return 0x7f80 if f > 0 else 0xff80 # bf16 ±infinity
|
||||
try: return (struct.unpack("<I", struct.pack("<f", float(f)))[0] >> 16) & 0xffff
|
||||
except (OverflowError, struct.error): return 0x7f80 if f > 0 else 0xff80
|
||||
def bf16_to_f32(v): return _bf16(v) if isinstance(v, int) else float(v)
|
||||
def f32_to_bf16(f): return _ibf16(f)
|
||||
def _trig(fn, x):
|
||||
# V_SIN/COS_F32: hardware does frac on input cycles before computing
|
||||
if math.isinf(x) or math.isnan(x): return float("nan")
|
||||
frac_cycles = fract(x / (2 * math.pi))
|
||||
result = fn(frac_cycles * 2 * math.pi)
|
||||
# Hardware returns exactly 0 for cos(π/2), sin(π), etc. due to lookup table
|
||||
# Round very small results (below f32 precision) to exactly 0
|
||||
if abs(result) < 1e-7: return 0.0
|
||||
return result
|
||||
|
||||
# BYTE_PERMUTE for V_PERM_B32 - select bytes from 64-bit data based on selector
|
||||
def BYTE_PERMUTE(data, sel):
|
||||
"""Select a byte from 64-bit data based on selector value.
|
||||
sel 0-7: select byte from data (S1 is bytes 0-3, S0 is bytes 4-7 in {S0,S1})
|
||||
sel 8-11: sign-extend from specific bytes (8->byte1, 9->byte3, 10->byte5, 11->byte7)
|
||||
sel 12: constant 0x00
|
||||
sel >= 13: constant 0xFF"""
|
||||
sel = int(sel) & 0xff
|
||||
if sel <= 7: return (int(data) >> (sel * 8)) & 0xff
|
||||
if sel == 8: return 0xff if ((int(data) >> 15) & 1) else 0x00 # sign of byte 1
|
||||
if sel == 9: return 0xff if ((int(data) >> 31) & 1) else 0x00 # sign of byte 3
|
||||
if sel == 10: return 0xff if ((int(data) >> 47) & 1) else 0x00 # sign of byte 5
|
||||
if sel == 11: return 0xff if ((int(data) >> 63) & 1) else 0x00 # sign of byte 7
|
||||
if sel == 12: return 0x00
|
||||
return 0xff # sel >= 13
|
||||
class _SafeFloat(float):
|
||||
"""Float subclass that uses _div for division to handle 0/inf correctly."""
|
||||
def __truediv__(self, o): return _div(float(self), float(o))
|
||||
def __rtruediv__(self, o): return _div(float(o), float(self))
|
||||
|
||||
# v_sad_u8 helper for V_SAD instructions (sum of absolute differences of 4 bytes)
|
||||
def v_sad_u8(s0, s1, s2):
|
||||
"""V_SAD_U8: Sum of absolute differences of 4 byte pairs plus accumulator."""
|
||||
s0, s1, s2 = int(s0), int(s1), int(s2)
|
||||
result = s2
|
||||
for i in range(4):
|
||||
a = (s0 >> (i * 8)) & 0xff
|
||||
b = (s1 >> (i * 8)) & 0xff
|
||||
result += abs(a - b)
|
||||
return result & 0xffffffff
|
||||
|
||||
# v_msad_u8 helper (masked SAD - skip when reference byte is 0)
|
||||
def v_msad_u8(s0, s1, s2):
|
||||
"""V_MSAD_U8: Masked sum of absolute differences (skip if reference byte is 0)."""
|
||||
s0, s1, s2 = int(s0), int(s1), int(s2)
|
||||
result = s2
|
||||
for i in range(4):
|
||||
a = (s0 >> (i * 8)) & 0xff
|
||||
b = (s1 >> (i * 8)) & 0xff
|
||||
if b != 0: # Only add diff if reference (s1) byte is non-zero
|
||||
result += abs(a - b)
|
||||
return result & 0xffffffff
|
||||
def f16_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
|
||||
def f16_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
|
||||
def f32_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
|
||||
def f32_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
|
||||
def v_cvt_i16_f32(f): return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
|
||||
def v_cvt_u16_f32(f): return max(0, min(65535, int(f))) if not math.isnan(f) else 0
|
||||
def u32_to_u16(u): return int(u) & 0xffff
|
||||
def i32_to_i16(i): return ((int(i) + 32768) & 0xffff) - 32768
|
||||
def SAT8(v): return max(0, min(255, int(v)))
|
||||
def f32_to_u8(f): return max(0, min(255, int(f))) if not math.isnan(f) else 0
|
||||
def mantissa(f):
|
||||
if f == 0.0 or math.isinf(f) or math.isnan(f): return f
|
||||
m, _ = math.frexp(f)
|
||||
return m # AMD V_FREXP_MANT returns mantissa in [0.5, 1.0) range
|
||||
def signext_from_bit(val, bit):
|
||||
bit = int(bit)
|
||||
if bit == 0: return 0
|
||||
mask = (1 << bit) - 1
|
||||
val = int(val) & mask
|
||||
if val & (1 << (bit - 1)): return val - (1 << bit)
|
||||
return val
|
||||
|
||||
# Aliases used in pseudocode
|
||||
s_ff1_i32_b32, s_ff1_i32_b64 = _ctz32, _ctz64
|
||||
GT_NEG_ZERO, LT_NEG_ZERO = _gt_neg_zero, _lt_neg_zero
|
||||
isNAN = _isnan
|
||||
isQuietNAN = _isquietnan
|
||||
isSignalNAN = _issignalnan
|
||||
fma, ldexp, sign, exponent = _fma, _ldexp, _sign, _exponent
|
||||
def F(x):
|
||||
"""32'F(x) or 64'F(x) - interpret x as float. If x is int, treat as bit pattern."""
|
||||
if isinstance(x, int): return _f32(x) # int -> interpret as f32 bits
|
||||
if isinstance(x, TypedView): return x # preserve TypedView for bit-pattern checks
|
||||
return float(x) # already a float or float-like
|
||||
signext = lambda x: int(x) # sign-extend to full width - already handled by Python's arbitrary precision ints
|
||||
pack = lambda hi, lo: ((int(hi) & 0xffff) << 16) | (int(lo) & 0xffff)
|
||||
pack32 = lambda hi, lo: ((int(hi) & 0xffffffff) << 32) | (int(lo) & 0xffffffff)
|
||||
_pack, _pack32 = pack, pack32 # Aliases for internal use
|
||||
WAVE32, WAVE64 = True, False
|
||||
|
||||
# Float overflow/underflow constants
|
||||
OVERFLOW_F32 = float('inf')
|
||||
UNDERFLOW_F32 = 0.0
|
||||
OVERFLOW_F64 = float('inf')
|
||||
UNDERFLOW_F64 = 0.0
|
||||
MAX_FLOAT_F32 = 3.4028235e+38 # Largest finite float32
|
||||
|
||||
# INF object that supports .f16/.f32/.f64 access and comparison with floats
|
||||
class _Inf:
|
||||
f16 = f32 = f64 = float('inf')
|
||||
def __neg__(self): return _NegInf()
|
||||
@@ -260,26 +78,24 @@ class _NegInf:
|
||||
def __float__(self): return float('-inf')
|
||||
def __eq__(self, other): return float(other) == float('-inf') if not isinstance(other, _Inf) else False
|
||||
def __req__(self, other): return self.__eq__(other)
|
||||
INF = _Inf()
|
||||
|
||||
# Rounding mode placeholder
|
||||
class _RoundMode:
|
||||
NEAREST_EVEN = 0
|
||||
ROUND_MODE = _RoundMode()
|
||||
|
||||
# Helper functions for pseudocode
|
||||
def cvtToQuietNAN(x): return float('nan')
|
||||
DST = None # Placeholder, will be set in context
|
||||
|
||||
class _WaveMode:
|
||||
IEEE = False
|
||||
WAVE_MODE = _WaveMode()
|
||||
|
||||
class _DenormChecker:
|
||||
"""Comparator for denormalized floats. x == DENORM.f32 checks if x is denormalized."""
|
||||
def __init__(self, bits): self._bits = bits
|
||||
def _check(self, other):
|
||||
return _is_denorm_f64(float(other)) if self._bits == 64 else _is_denorm_f32(float(other))
|
||||
f = float(other)
|
||||
if math.isinf(f) or math.isnan(f) or f == 0.0: return False
|
||||
if self._bits == 64:
|
||||
bits = struct.unpack("<Q", struct.pack("<d", f))[0]
|
||||
return (bits >> 52) & 0x7ff == 0
|
||||
bits = struct.unpack("<I", struct.pack("<f", f))[0]
|
||||
return (bits >> 23) & 0xff == 0
|
||||
def __eq__(self, other): return self._check(other)
|
||||
def __req__(self, other): return self._check(other)
|
||||
def __ne__(self, other): return not self._check(other)
|
||||
@@ -287,7 +103,9 @@ class _DenormChecker:
|
||||
class _Denorm:
|
||||
f32 = _DenormChecker(32)
|
||||
f64 = _DenormChecker(64)
|
||||
DENORM = _Denorm()
|
||||
|
||||
_pack = lambda hi, lo: ((int(hi) & 0xffff) << 16) | (int(lo) & 0xffff)
|
||||
_pack32 = lambda hi, lo: ((int(hi) & 0xffffffff) << 32) | (int(lo) & 0xffffffff)
|
||||
|
||||
class TypedView:
|
||||
"""View into a Reg with typed access. Used for both full-width (Reg.u32) and slices (Reg[31:16])."""
|
||||
@@ -396,8 +214,6 @@ class TypedView:
|
||||
def __gt__(s, o): return float(s) > float(o) if s._float else int(s) > int(o)
|
||||
def __ge__(s, o): return float(s) >= float(o) if s._float else int(s) >= int(o)
|
||||
|
||||
SliceProxy = TypedView # Alias for compatibility
|
||||
|
||||
class Reg:
|
||||
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works. Supports up to 128 bits for DS_LOAD_B128."""
|
||||
__slots__ = ('_val',)
|
||||
@@ -466,5 +282,484 @@ class Reg:
|
||||
def __eq__(s, o): return s._val == int(o)
|
||||
def __ne__(s, o): return s._val != int(o)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PSEUDOCODE API - Functions and constants from AMD ISA pseudocode
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# Rounding and float operations
|
||||
trunc, floor, ceil = _fpop(math.trunc), _fpop(math.floor), _fpop(math.ceil)
|
||||
def sqrt(x): return _SafeFloat(math.sqrt(x)) if x >= 0 else _SafeFloat(float("nan"))
|
||||
def log2(x): return math.log2(x) if x > 0 else (float("-inf") if x == 0 else float("nan"))
|
||||
def fract(x): return x - math.floor(x)
|
||||
def sin(x): return _trig(math.sin, x)
|
||||
def cos(x): return _trig(math.cos, x)
|
||||
def pow(a, b):
|
||||
try: return a ** b
|
||||
except OverflowError: return float("inf") if b > 0 else 0.0
|
||||
def isEven(x):
|
||||
x = float(x)
|
||||
if math.isinf(x) or math.isnan(x): return False
|
||||
return int(x) % 2 == 0
|
||||
def mantissa(f):
|
||||
if f == 0.0 or math.isinf(f) or math.isnan(f): return f
|
||||
m, _ = math.frexp(f)
|
||||
return m # AMD V_FREXP_MANT returns mantissa in [0.5, 1.0) range
|
||||
def signext_from_bit(val, bit):
|
||||
bit = int(bit)
|
||||
if bit == 0: return 0
|
||||
mask = (1 << bit) - 1
|
||||
val = int(val) & mask
|
||||
if val & (1 << (bit - 1)): return val - (1 << bit)
|
||||
return val
|
||||
|
||||
# Type conversions
|
||||
i32_to_f32 = u32_to_f32 = i32_to_f64 = u32_to_f64 = f32_to_f64 = f64_to_f32 = float
|
||||
def f32_to_i32(f): return _f_to_int(f, -2147483648, 2147483647)
|
||||
def f32_to_u32(f): return _f_to_int(f, 0, 4294967295)
|
||||
f64_to_i32, f64_to_u32 = f32_to_i32, f32_to_u32
|
||||
def f32_to_f16(f):
|
||||
f = float(f)
|
||||
if math.isnan(f): return 0x7e00 # f16 NaN
|
||||
if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00 # f16 ±infinity
|
||||
try: return struct.unpack("<H", struct.pack("<e", f))[0]
|
||||
except OverflowError: return 0x7c00 if f > 0 else 0xfc00 # overflow -> ±infinity
|
||||
def f16_to_f32(v): return v if isinstance(v, float) else _f16_to_f32_bits(v)
|
||||
def i16_to_f16(v): return f32_to_f16(float(_sext(int(v) & 0xffff, 16)))
|
||||
def u16_to_f16(v): return f32_to_f16(float(int(v) & 0xffff))
|
||||
def f16_to_i16(bits): f = _f16_to_f32_bits(bits); return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
|
||||
def f16_to_u16(bits): f = _f16_to_f32_bits(bits); return max(0, min(65535, int(f))) if not math.isnan(f) else 0
|
||||
def bf16_to_f32(v): return _bf16(v) if isinstance(v, int) else float(v)
|
||||
def f32_to_bf16(f): return _ibf16(f)
|
||||
def u8_to_u32(v): return int(v) & 0xff
|
||||
def u4_to_u32(v): return int(v) & 0xf
|
||||
def u32_to_u16(u): return int(u) & 0xffff
|
||||
def i32_to_i16(i): return ((int(i) + 32768) & 0xffff) - 32768
|
||||
def f16_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
|
||||
def f16_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
|
||||
def f32_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
|
||||
def f32_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
|
||||
def v_cvt_i16_f32(f): return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
|
||||
def v_cvt_u16_f32(f): return max(0, min(65535, int(f))) if not math.isnan(f) else 0
|
||||
def SAT8(v): return max(0, min(255, int(v)))
|
||||
def f32_to_u8(f): return max(0, min(255, int(f))) if not math.isnan(f) else 0
|
||||
|
||||
# Min/max operations
|
||||
def v_min_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _lt_neg_zero(a, b) else b)
|
||||
def v_max_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _gt_neg_zero(a, b) else b)
|
||||
v_min_f16, v_max_f16 = v_min_f32, v_max_f32
|
||||
v_min_i32, v_max_i32 = min, max
|
||||
v_min_i16, v_max_i16 = min, max
|
||||
def v_min_u32(a, b): return min(a & MASK32, b & MASK32)
|
||||
def v_max_u32(a, b): return max(a & MASK32, b & MASK32)
|
||||
def v_min_u16(a, b): return min(a & 0xffff, b & 0xffff)
|
||||
def v_max_u16(a, b): return max(a & 0xffff, b & 0xffff)
|
||||
def v_min3_f32(a, b, c): return v_min_f32(v_min_f32(a, b), c)
|
||||
def v_max3_f32(a, b, c): return v_max_f32(v_max_f32(a, b), c)
|
||||
v_min3_f16, v_max3_f16 = v_min3_f32, v_max3_f32
|
||||
v_min3_i32, v_max3_i32, v_min3_i16, v_max3_i16 = min, max, min, max
|
||||
def v_min3_u32(a, b, c): return min(a & MASK32, b & MASK32, c & MASK32)
|
||||
def v_max3_u32(a, b, c): return max(a & MASK32, b & MASK32, c & MASK32)
|
||||
def v_min3_u16(a, b, c): return min(a & 0xffff, b & 0xffff, c & 0xffff)
|
||||
def v_max3_u16(a, b, c): return max(a & 0xffff, b & 0xffff, c & 0xffff)
|
||||
|
||||
# SAD/MSAD operations
|
||||
def ABSDIFF(a, b): return abs(int(a) - int(b))
|
||||
def v_sad_u8(s0, s1, s2):
|
||||
"""V_SAD_U8: Sum of absolute differences of 4 byte pairs plus accumulator."""
|
||||
s0, s1, s2 = int(s0), int(s1), int(s2)
|
||||
result = s2
|
||||
for i in range(4):
|
||||
a = (s0 >> (i * 8)) & 0xff
|
||||
b = (s1 >> (i * 8)) & 0xff
|
||||
result += abs(a - b)
|
||||
return result & 0xffffffff
|
||||
def v_msad_u8(s0, s1, s2):
|
||||
"""V_MSAD_U8: Masked sum of absolute differences (skip if reference byte is 0)."""
|
||||
s0, s1, s2 = int(s0), int(s1), int(s2)
|
||||
result = s2
|
||||
for i in range(4):
|
||||
a = (s0 >> (i * 8)) & 0xff
|
||||
b = (s1 >> (i * 8)) & 0xff
|
||||
if b != 0: # Only add diff if reference (s1) byte is non-zero
|
||||
result += abs(a - b)
|
||||
return result & 0xffffffff
|
||||
|
||||
def BYTE_PERMUTE(data, sel):
|
||||
"""Select a byte from 64-bit data based on selector value."""
|
||||
sel = int(sel) & 0xff
|
||||
if sel <= 7: return (int(data) >> (sel * 8)) & 0xff
|
||||
if sel == 8: return 0xff if ((int(data) >> 15) & 1) else 0x00
|
||||
if sel == 9: return 0xff if ((int(data) >> 31) & 1) else 0x00
|
||||
if sel == 10: return 0xff if ((int(data) >> 47) & 1) else 0x00
|
||||
if sel == 11: return 0xff if ((int(data) >> 63) & 1) else 0x00
|
||||
if sel == 12: return 0x00
|
||||
return 0xff
|
||||
|
||||
# Pseudocode functions
|
||||
def s_ff1_i32_b32(v): return _ctz(v, 32)
|
||||
def s_ff1_i32_b64(v): return _ctz(v, 64)
|
||||
GT_NEG_ZERO, LT_NEG_ZERO = _gt_neg_zero, _lt_neg_zero
|
||||
def isNAN(x):
|
||||
try: return math.isnan(float(x))
|
||||
except (TypeError, ValueError): return False
|
||||
def isQuietNAN(x): return _check_nan_type(x, 1, True)
|
||||
def isSignalNAN(x): return _check_nan_type(x, 0, False)
|
||||
def fma(a, b, c):
|
||||
try: return math.fma(a, b, c)
|
||||
except ValueError: return float('nan')
|
||||
def ldexp(m, e): return math.ldexp(m, e)
|
||||
def sign(f): return 1 if math.copysign(1.0, f) < 0 else 0
|
||||
def exponent(f):
|
||||
if hasattr(f, '_bits') and hasattr(f, '_float') and f._float:
|
||||
raw = f._val
|
||||
if f._bits == 16: return (raw >> 10) & 0x1f
|
||||
if f._bits == 32: return (raw >> 23) & 0xff
|
||||
if f._bits == 64: return (raw >> 52) & 0x7ff
|
||||
f = float(f)
|
||||
if math.isinf(f) or math.isnan(f): return 255
|
||||
if f == 0.0: return 0
|
||||
try: bits = struct.unpack("<I", struct.pack("<f", f))[0]; return (bits >> 23) & 0xff
|
||||
except: return 0
|
||||
def signext(x): return int(x)
|
||||
def cvtToQuietNAN(x): return float('nan')
|
||||
|
||||
def F(x):
|
||||
"""32'F(x) or 64'F(x) - interpret x as float. If x is int, treat as bit pattern."""
|
||||
if isinstance(x, int): return _f32(x)
|
||||
if isinstance(x, TypedView): return x
|
||||
return float(x)
|
||||
|
||||
# Constants
|
||||
PI = math.pi
|
||||
WAVE32, WAVE64 = True, False
|
||||
OVERFLOW_F32, UNDERFLOW_F32 = float('inf'), 0.0
|
||||
OVERFLOW_F64, UNDERFLOW_F64 = float('inf'), 0.0
|
||||
MAX_FLOAT_F32 = 3.4028235e+38
|
||||
INF = _Inf()
|
||||
ROUND_MODE = _RoundMode()
|
||||
WAVE_MODE = _WaveMode()
|
||||
DENORM = _Denorm()
|
||||
|
||||
# 2/PI with 1201 bits of precision for V_TRIG_PREOP_F64
|
||||
TWO_OVER_PI_1201 = Reg(0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# COMPILER: pseudocode -> Python (minimal transforms)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _compile_pseudocode(pseudocode: str) -> str:
|
||||
"""Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
|
||||
pseudocode = re.sub(r'\bpass\b', 'pass_', pseudocode) # 'pass' is Python keyword
|
||||
raw_lines = pseudocode.strip().split('\n')
|
||||
joined_lines: list[str] = []
|
||||
for line in raw_lines:
|
||||
line = line.strip()
|
||||
if joined_lines and (joined_lines[-1].rstrip().endswith(('||', '&&', '(', ',')) or
|
||||
(joined_lines[-1].count('(') > joined_lines[-1].count(')'))):
|
||||
joined_lines[-1] = joined_lines[-1].rstrip() + ' ' + line
|
||||
else:
|
||||
joined_lines.append(line)
|
||||
|
||||
lines = []
|
||||
indent, need_pass, in_first_match_loop = 0, False, False
|
||||
for line in joined_lines:
|
||||
line = line.split('//')[0].strip() # Strip C-style comments
|
||||
if not line: continue
|
||||
if line.startswith('if '):
|
||||
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line.startswith('elsif '):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line == 'else':
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
lines.append(' ' * indent + "else:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line.startswith('endif'):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
need_pass = False
|
||||
elif line.startswith('endfor'):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
need_pass, in_first_match_loop = False, False
|
||||
elif line.startswith('declare '):
|
||||
pass
|
||||
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
|
||||
start, end = _expr(m[2].strip()), _expr(m[3].strip())
|
||||
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
|
||||
indent += 1
|
||||
need_pass, in_first_match_loop = True, True
|
||||
elif '=' in line and not line.startswith('=='):
|
||||
need_pass = False
|
||||
line = line.rstrip(';')
|
||||
if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
|
||||
rhs = _expr(m[1])
|
||||
lines.append(' ' * indent + f"_full = {rhs}")
|
||||
lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
|
||||
lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
|
||||
elif any(op in line for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^=')):
|
||||
for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
|
||||
if op in line:
|
||||
lhs, rhs = line.split(op, 1)
|
||||
lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
|
||||
break
|
||||
else:
|
||||
lhs, rhs = line.split('=', 1)
|
||||
lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip()
|
||||
stmt = _assign(lhs_s, _expr(rhs_s))
|
||||
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
|
||||
stmt += "; break"
|
||||
lines.append(' ' * indent + stmt)
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _assign(lhs: str, rhs: str) -> str:
|
||||
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
|
||||
return f"{lhs} = Reg({rhs})"
|
||||
return f"{lhs} = {rhs}"
|
||||
|
||||
def _expr(e: str) -> str:
|
||||
e = e.strip()
|
||||
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
|
||||
e = re.sub(r'!([^=])', r' not \1', e)
|
||||
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
|
||||
def pack(m):
|
||||
hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
|
||||
return f'_pack({hi}, {lo})'
|
||||
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
|
||||
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
|
||||
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
|
||||
e = re.sub(r"\d+'[FIBU]\(", "(", e)
|
||||
e = re.sub(r'\bB\(', '(', e)
|
||||
e = re.sub(r'([0-9a-fA-Fx])ULL\b', r'\1', e)
|
||||
e = re.sub(r'([0-9a-fA-Fx])LL\b', r'\1', e)
|
||||
e = re.sub(r'([0-9a-fA-Fx])U\b', r'\1', e)
|
||||
e = re.sub(r'(\d\.?\d*)F\b', r'\1', e)
|
||||
e = re.sub(r'(\[laneId\])\.[uib]\d+', r'\1', e)
|
||||
e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
|
||||
e = re.sub(r'NAN\.f\d+', 'float("nan")', e)
|
||||
def convert_verilog_slice(m):
|
||||
start, width = m.group(1).strip(), m.group(2).strip()
|
||||
return f'[({start}) + ({width}) - 1 : ({start})]'
|
||||
e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e)
|
||||
def process_brackets(s):
|
||||
result, i = [], 0
|
||||
while i < len(s):
|
||||
if s[i] == '[':
|
||||
depth, start = 1, i + 1
|
||||
j = start
|
||||
while j < len(s) and depth > 0:
|
||||
if s[j] == '[': depth += 1
|
||||
elif s[j] == ']': depth -= 1
|
||||
j += 1
|
||||
inner = _expr(s[start:j-1])
|
||||
result.append('[' + inner + ']')
|
||||
i = j
|
||||
else:
|
||||
result.append(s[i])
|
||||
i += 1
|
||||
return ''.join(result)
|
||||
e = process_brackets(e)
|
||||
while '?' in e:
|
||||
depth, bracket, q = 0, 0, -1
|
||||
for i, c in enumerate(e):
|
||||
if c == '(': depth += 1
|
||||
elif c == ')': depth -= 1
|
||||
elif c == '[': bracket += 1
|
||||
elif c == ']': bracket -= 1
|
||||
elif c == '?' and depth == 0 and bracket == 0: q = i; break
|
||||
if q < 0: break
|
||||
depth, bracket, col = 0, 0, -1
|
||||
for i in range(q + 1, len(e)):
|
||||
if e[i] == '(': depth += 1
|
||||
elif e[i] == ')': depth -= 1
|
||||
elif e[i] == '[': bracket += 1
|
||||
elif e[i] == ']': bracket -= 1
|
||||
elif e[i] == ':' and depth == 0 and bracket == 0: col = i; break
|
||||
if col < 0: break
|
||||
cond, t, f = e[:q].strip(), e[q+1:col].strip(), e[col+1:].strip()
|
||||
e = f'(({t}) if ({cond}) else ({f}))'
|
||||
return e
|
||||
|
||||
def _apply_pseudocode_fixes(op_name: str, code: str) -> str:
|
||||
"""Apply known fixes for PDF pseudocode bugs."""
|
||||
if op_name == 'V_DIV_FMAS_F32':
|
||||
code = code.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
|
||||
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
|
||||
if op_name == 'V_DIV_FMAS_F64':
|
||||
code = code.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
|
||||
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
|
||||
if op_name == 'V_DIV_SCALE_F32':
|
||||
code = code.replace('D0.f32 = float("nan")', 'VCC = Reg(0x1); D0.f32 = float("nan")')
|
||||
code = code.replace('elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif False:\n pass')
|
||||
code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
|
||||
code = code.replace('elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)', 'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
|
||||
code = code.replace('elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
|
||||
if op_name == 'V_DIV_SCALE_F64':
|
||||
code = code.replace('D0.f64 = float("nan")', 'VCC = Reg(0x1); D0.f64 = float("nan")')
|
||||
code = code.replace('elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif False:\n pass')
|
||||
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
|
||||
code = code.replace('elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)', 'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
|
||||
code = code.replace('elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
|
||||
if op_name == 'V_DIV_FIXUP_F32':
|
||||
code = code.replace('D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
|
||||
'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
|
||||
if op_name == 'V_DIV_FIXUP_F64':
|
||||
code = code.replace('D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
|
||||
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
|
||||
if op_name == 'V_TRIG_PREOP_F64':
|
||||
code = code.replace('result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
|
||||
'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
|
||||
return code
|
||||
|
||||
def _generate_function(cls_name: str, op_name: str, pc: str, code: str) -> str:
|
||||
"""Generate a single compiled pseudocode function.
|
||||
Functions take int parameters and return dict of int values.
|
||||
Reg wrapping happens inside the function, only for registers actually used."""
|
||||
has_d1 = '{ D1' in pc
|
||||
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
|
||||
is_div_scale = 'DIV_SCALE' in op_name
|
||||
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
|
||||
is_ds = cls_name == 'DSOp'
|
||||
is_flat = cls_name in ('FLATOp', 'GLOBALOp', 'SCRATCHOp')
|
||||
is_smem = cls_name == 'SMEMOp'
|
||||
has_s_array = 'S[i]' in pc # FMA_MIX style: S[0], S[1], S[2] array access
|
||||
combined = code + pc
|
||||
|
||||
fn_name = f"_{cls_name}_{op_name}"
|
||||
|
||||
# Detect which registers are used/modified
|
||||
def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
|
||||
modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined))
|
||||
modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined))
|
||||
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
|
||||
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
|
||||
modifies_pc = bool(re.search(r'\bPC\s*=', combined))
|
||||
|
||||
# Build function signature and Reg init lines
|
||||
if is_smem:
|
||||
lines = [f"def {fn_name}(MEM, addr):"]
|
||||
reg_inits = ["ADDR=Reg(addr)", "SDATA=Reg(0)"]
|
||||
special_regs = []
|
||||
elif is_ds:
|
||||
lines = [f"def {fn_name}(MEM, addr, data0, data1, offset0, offset1):"]
|
||||
reg_inits = ["ADDR=Reg(addr)", "DATA0=Reg(data0)", "DATA1=Reg(data1)", "OFFSET0=Reg(offset0)", "OFFSET1=Reg(offset1)", "RETURN_DATA=Reg(0)"]
|
||||
special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
|
||||
elif is_flat:
|
||||
lines = [f"def {fn_name}(MEM, addr, vdata, vdst):"]
|
||||
reg_inits = ["ADDR=addr", "VDATA=Reg(vdata)", "VDST=Reg(vdst)", "RETURN_DATA=Reg(0)"]
|
||||
special_regs = [('DATA', 'VDATA')]
|
||||
elif has_s_array:
|
||||
# FMA_MIX style: needs S[i] array, opsel, opsel_hi for source selection (neg/neg_hi applied in emu.py before call)
|
||||
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None, opsel=0, opsel_hi=0):"]
|
||||
reg_inits = ["S0=Reg(s0)", "S1=Reg(s1)", "S2=Reg(s2)", "S=[S0,S1,S2]", "D0=Reg(d0)", "OPSEL=Reg(opsel)", "OPSEL_HI=Reg(opsel_hi)"]
|
||||
special_regs = []
|
||||
# Detect array declarations like "declare in : 32'F[3]" and create them (rename 'in' to 'ins' since 'in' is a keyword)
|
||||
if "in[" in combined:
|
||||
reg_inits.append("ins=[Reg(0),Reg(0),Reg(0)]")
|
||||
code = code.replace("in[", "ins[")
|
||||
else:
|
||||
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):"]
|
||||
# Only create Regs for registers actually used in the pseudocode
|
||||
reg_inits = []
|
||||
if 'S0' in combined: reg_inits.append("S0=Reg(s0)")
|
||||
if 'S1' in combined: reg_inits.append("S1=Reg(s1)")
|
||||
if 'S2' in combined: reg_inits.append("S2=Reg(s2)")
|
||||
if modifies_d0 or 'D0' in combined: reg_inits.append("D0=Reg(s0)" if is_div_scale else "D0=Reg(d0)")
|
||||
if modifies_scc or 'SCC' in combined: reg_inits.append("SCC=Reg(scc)")
|
||||
if modifies_vcc or 'VCC' in combined: reg_inits.append("VCC=Reg(vcc)")
|
||||
if modifies_exec or 'EXEC' in combined: reg_inits.append("EXEC=Reg(exec_mask)")
|
||||
if modifies_pc or 'PC' in combined: reg_inits.append("PC=Reg(pc) if pc is not None else None")
|
||||
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
|
||||
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
|
||||
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
|
||||
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
|
||||
|
||||
# Build init code
|
||||
init_parts = reg_inits.copy()
|
||||
for name, init in special_regs:
|
||||
if name in combined: init_parts.append(f"{name}={init}")
|
||||
if 'EXEC_LO' in code: init_parts.append("EXEC_LO=TypedView(EXEC, 31, 0)")
|
||||
if 'EXEC_HI' in code: init_parts.append("EXEC_HI=TypedView(EXEC, 63, 32)")
|
||||
if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_parts.append("VCCZ=Reg(1 if VCC._val == 0 else 0)")
|
||||
if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_parts.append("EXECZ=Reg(1 if EXEC._val == 0 else 0)")
|
||||
|
||||
# Add init line and separator
|
||||
if init_parts: lines.append(f" {'; '.join(init_parts)}")
|
||||
|
||||
# Add compiled pseudocode
|
||||
for line in code.split('\n'):
|
||||
if line.strip(): lines.append(f" {line}")
|
||||
|
||||
# Build result dict
|
||||
result_items = []
|
||||
if modifies_d0: result_items.append("'D0': D0._val")
|
||||
if modifies_scc: result_items.append("'SCC': SCC._val")
|
||||
if modifies_vcc: result_items.append("'VCC': VCC._val")
|
||||
if modifies_exec: result_items.append("'EXEC': EXEC._val")
|
||||
if has_d1: result_items.append("'D1': D1._val")
|
||||
if modifies_pc: result_items.append("'PC': PC._val")
|
||||
if is_smem and 'SDATA' in combined and re.search(r'^\s*SDATA[\.\[].*=', code, re.MULTILINE):
|
||||
result_items.append("'SDATA': SDATA._val")
|
||||
if is_ds and 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
|
||||
result_items.append("'RETURN_DATA': RETURN_DATA._val")
|
||||
if is_flat:
|
||||
if 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
|
||||
result_items.append("'RETURN_DATA': RETURN_DATA._val")
|
||||
if re.search(r'^\s*VDATA[\.\[].*=', code, re.MULTILINE):
|
||||
result_items.append("'VDATA': VDATA._val")
|
||||
lines.append(f" return {{{', '.join(result_items)}}}")
|
||||
return '\n'.join(lines)
|
||||
|
||||
# Build the globals dict for exec() - includes all pcode symbols
|
||||
_PCODE_GLOBALS = {
|
||||
'Reg': Reg, 'TypedView': TypedView, '_pack': _pack, '_pack32': _pack32,
|
||||
'ABSDIFF': ABSDIFF, 'BYTE_PERMUTE': BYTE_PERMUTE, 'DENORM': DENORM, 'F': F,
|
||||
'GT_NEG_ZERO': GT_NEG_ZERO, 'LT_NEG_ZERO': LT_NEG_ZERO, 'INF': INF,
|
||||
'MAX_FLOAT_F32': MAX_FLOAT_F32, 'OVERFLOW_F32': OVERFLOW_F32, 'OVERFLOW_F64': OVERFLOW_F64,
|
||||
'UNDERFLOW_F32': UNDERFLOW_F32, 'UNDERFLOW_F64': UNDERFLOW_F64,
|
||||
'PI': PI, 'ROUND_MODE': ROUND_MODE, 'WAVE_MODE': WAVE_MODE,
|
||||
'WAVE32': WAVE32, 'WAVE64': WAVE64, 'TWO_OVER_PI_1201': TWO_OVER_PI_1201,
|
||||
'SAT8': SAT8, 'trunc': trunc, 'floor': floor, 'ceil': ceil, 'sqrt': sqrt,
|
||||
'log2': log2, 'fract': fract, 'sin': sin, 'cos': cos, 'pow': pow,
|
||||
'isEven': isEven, 'mantissa': mantissa, 'signext_from_bit': signext_from_bit,
|
||||
'i32_to_f32': i32_to_f32, 'u32_to_f32': u32_to_f32, 'i32_to_f64': i32_to_f64,
|
||||
'u32_to_f64': u32_to_f64, 'f32_to_f64': f32_to_f64, 'f64_to_f32': f64_to_f32,
|
||||
'f32_to_i32': f32_to_i32, 'f32_to_u32': f32_to_u32, 'f64_to_i32': f64_to_i32,
|
||||
'f64_to_u32': f64_to_u32, 'f32_to_f16': f32_to_f16, 'f16_to_f32': f16_to_f32,
|
||||
'i16_to_f16': i16_to_f16, 'u16_to_f16': u16_to_f16, 'f16_to_i16': f16_to_i16,
|
||||
'f16_to_u16': f16_to_u16, 'bf16_to_f32': bf16_to_f32, 'f32_to_bf16': f32_to_bf16,
|
||||
'u8_to_u32': u8_to_u32, 'u4_to_u32': u4_to_u32, 'u32_to_u16': u32_to_u16,
|
||||
'i32_to_i16': i32_to_i16, 'f16_to_snorm': f16_to_snorm, 'f16_to_unorm': f16_to_unorm,
|
||||
'f32_to_snorm': f32_to_snorm, 'f32_to_unorm': f32_to_unorm,
|
||||
'v_cvt_i16_f32': v_cvt_i16_f32, 'v_cvt_u16_f32': v_cvt_u16_f32, 'f32_to_u8': f32_to_u8,
|
||||
'v_min_f32': v_min_f32, 'v_max_f32': v_max_f32, 'v_min_f16': v_min_f16, 'v_max_f16': v_max_f16,
|
||||
'v_min_i32': v_min_i32, 'v_max_i32': v_max_i32, 'v_min_i16': v_min_i16, 'v_max_i16': v_max_i16,
|
||||
'v_min_u32': v_min_u32, 'v_max_u32': v_max_u32, 'v_min_u16': v_min_u16, 'v_max_u16': v_max_u16,
|
||||
'v_min3_f32': v_min3_f32, 'v_max3_f32': v_max3_f32, 'v_min3_f16': v_min3_f16, 'v_max3_f16': v_max3_f16,
|
||||
'v_min3_i32': v_min3_i32, 'v_max3_i32': v_max3_i32, 'v_min3_i16': v_min3_i16, 'v_max3_i16': v_max3_i16,
|
||||
'v_min3_u32': v_min3_u32, 'v_max3_u32': v_max3_u32, 'v_min3_u16': v_min3_u16, 'v_max3_u16': v_max3_u16,
|
||||
'v_sad_u8': v_sad_u8, 'v_msad_u8': v_msad_u8,
|
||||
's_ff1_i32_b32': s_ff1_i32_b32, 's_ff1_i32_b64': s_ff1_i32_b64,
|
||||
'isNAN': isNAN, 'isQuietNAN': isQuietNAN, 'isSignalNAN': isSignalNAN,
|
||||
'fma': fma, 'ldexp': ldexp, 'sign': sign, 'exponent': exponent,
|
||||
'signext': signext, 'cvtToQuietNAN': cvtToQuietNAN,
|
||||
}
|
||||
|
||||
@functools.cache
|
||||
def compile_pseudocode(cls_name: str, op_name: str, pseudocode: str):
|
||||
"""Compile pseudocode string to executable function. Cached for performance."""
|
||||
code = _compile_pseudocode(pseudocode)
|
||||
code = _apply_pseudocode_fixes(op_name, code)
|
||||
fn_code = _generate_function(cls_name, op_name, pseudocode, code)
|
||||
fn_name = f"_{cls_name}_{op_name}"
|
||||
local_ns = {}
|
||||
exec(fn_code, _PCODE_GLOBALS, local_ns)
|
||||
return local_ns[fn_name]
|
||||
|
||||
@@ -38,161 +38,7 @@ FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'N
|
||||
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
|
||||
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+|(?:FLAT|GLOBAL|SCRATCH)_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||
|
||||
# Patterns that can't be handled by the DSL (require special handling in emu.py)
|
||||
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
|
||||
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
|
||||
'CVT_OFF_TABLE', 'ThreadMask',
|
||||
'S1[i', 'C.i32', 'thread_',
|
||||
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
|
||||
'BARRIER_STATE', 'ReallocVgprs',
|
||||
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
|
||||
'fp6', 'bf6', 'GS_REGS', 'M0.base', 'DS_DATA', '= 0..', 'sign(src', 'if no LDS', 'gds_base', 'vector mask',
|
||||
'SGPR_ADDR', 'INST_OFFSET', 'laneID'] # FLAT ops with non-standard vars
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# COMPILER: pseudocode -> Python (minimal transforms)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def compile_pseudocode(pseudocode: str) -> str:
|
||||
"""Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
|
||||
pseudocode = re.sub(r'\bpass\b', 'pass_', pseudocode) # 'pass' is Python keyword
|
||||
raw_lines = pseudocode.strip().split('\n')
|
||||
joined_lines: list[str] = []
|
||||
for line in raw_lines:
|
||||
line = line.strip()
|
||||
if joined_lines and (joined_lines[-1].rstrip().endswith(('||', '&&', '(', ',')) or
|
||||
(joined_lines[-1].count('(') > joined_lines[-1].count(')'))):
|
||||
joined_lines[-1] = joined_lines[-1].rstrip() + ' ' + line
|
||||
else:
|
||||
joined_lines.append(line)
|
||||
|
||||
lines = []
|
||||
indent, need_pass, in_first_match_loop = 0, False, False
|
||||
for line in joined_lines:
|
||||
line = line.split('//')[0].strip() # Strip C-style comments
|
||||
if not line: continue
|
||||
if line.startswith('if '):
|
||||
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line.startswith('elsif '):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line == 'else':
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
lines.append(' ' * indent + "else:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line.startswith('endif'):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
need_pass = False
|
||||
elif line.startswith('endfor'):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
need_pass, in_first_match_loop = False, False
|
||||
elif line.startswith('declare '):
|
||||
pass
|
||||
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
|
||||
start, end = _expr(m[2].strip()), _expr(m[3].strip())
|
||||
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
|
||||
indent += 1
|
||||
need_pass, in_first_match_loop = True, True
|
||||
elif '=' in line and not line.startswith('=='):
|
||||
need_pass = False
|
||||
line = line.rstrip(';')
|
||||
if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
|
||||
rhs = _expr(m[1])
|
||||
lines.append(' ' * indent + f"_full = {rhs}")
|
||||
lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
|
||||
lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
|
||||
elif any(op in line for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^=')):
|
||||
for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
|
||||
if op in line:
|
||||
lhs, rhs = line.split(op, 1)
|
||||
lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
|
||||
break
|
||||
else:
|
||||
lhs, rhs = line.split('=', 1)
|
||||
lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip()
|
||||
stmt = _assign(lhs_s, _expr(rhs_s))
|
||||
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
|
||||
stmt += "; break"
|
||||
lines.append(' ' * indent + stmt)
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _assign(lhs: str, rhs: str) -> str:
|
||||
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
|
||||
return f"{lhs} = Reg({rhs})"
|
||||
return f"{lhs} = {rhs}"
|
||||
|
||||
def _expr(e: str) -> str:
|
||||
e = e.strip()
|
||||
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
|
||||
e = re.sub(r'!([^=])', r' not \1', e)
|
||||
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
|
||||
def pack(m):
|
||||
hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
|
||||
return f'_pack({hi}, {lo})'
|
||||
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
|
||||
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
|
||||
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
|
||||
e = re.sub(r"\d+'[FIBU]\(", "(", e)
|
||||
e = re.sub(r'\bB\(', '(', e)
|
||||
e = re.sub(r'([0-9a-fA-Fx])ULL\b', r'\1', e)
|
||||
e = re.sub(r'([0-9a-fA-Fx])LL\b', r'\1', e)
|
||||
e = re.sub(r'([0-9a-fA-Fx])U\b', r'\1', e)
|
||||
e = re.sub(r'(\d\.?\d*)F\b', r'\1', e)
|
||||
e = re.sub(r'(\[laneId\])\.[uib]\d+', r'\1', e)
|
||||
e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
|
||||
e = re.sub(r'NAN\.f\d+', 'float("nan")', e)
|
||||
def convert_verilog_slice(m):
|
||||
start, width = m.group(1).strip(), m.group(2).strip()
|
||||
return f'[({start}) + ({width}) - 1 : ({start})]'
|
||||
e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e)
|
||||
def process_brackets(s):
|
||||
result, i = [], 0
|
||||
while i < len(s):
|
||||
if s[i] == '[':
|
||||
depth, start = 1, i + 1
|
||||
j = start
|
||||
while j < len(s) and depth > 0:
|
||||
if s[j] == '[': depth += 1
|
||||
elif s[j] == ']': depth -= 1
|
||||
j += 1
|
||||
inner = _expr(s[start:j-1])
|
||||
result.append('[' + inner + ']')
|
||||
i = j
|
||||
else:
|
||||
result.append(s[i])
|
||||
i += 1
|
||||
return ''.join(result)
|
||||
e = process_brackets(e)
|
||||
while '?' in e:
|
||||
depth, bracket, q = 0, 0, -1
|
||||
for i, c in enumerate(e):
|
||||
if c == '(': depth += 1
|
||||
elif c == ')': depth -= 1
|
||||
elif c == '[': bracket += 1
|
||||
elif c == ']': bracket -= 1
|
||||
elif c == '?' and depth == 0 and bracket == 0: q = i; break
|
||||
if q < 0: break
|
||||
depth, bracket, col = 0, 0, -1
|
||||
for i in range(q + 1, len(e)):
|
||||
if e[i] == '(': depth += 1
|
||||
elif e[i] == ')': depth -= 1
|
||||
elif e[i] == '[': bracket += 1
|
||||
elif e[i] == ']': bracket -= 1
|
||||
elif e[i] == ':' and depth == 0 and bracket == 0: col = i; break
|
||||
if col < 0: break
|
||||
cond, t, f = e[:q].strip(), e[q+1:col].strip(), e[col+1:].strip()
|
||||
e = f'(({t}) if ({cond}) else ({f}))'
|
||||
return e
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PDF PARSING WITH PAGE CACHING
|
||||
@@ -472,8 +318,8 @@ def _generate_ins_py(formats, enums, src_enum, doc_name) -> str:
|
||||
if "NULL" in src_names: lines.append("OFF = NULL\n")
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
|
||||
"""Generate gen_pcode.py content (compiled pseudocode functions)."""
|
||||
def _generate_str_pcode_py(enums, pseudocode, arch) -> str:
|
||||
"""Generate str_pcode.py content (raw pseudocode strings)."""
|
||||
# Get op enums for this arch (import from .ins which re-exports from .enum)
|
||||
import importlib
|
||||
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
|
||||
@@ -491,186 +337,35 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
|
||||
if key in defined_ops:
|
||||
for enum_cls, enum_val in defined_ops[key]: instructions[enum_cls][enum_val] = pc
|
||||
|
||||
# First pass: generate all function code
|
||||
fn_lines: list[str] = []
|
||||
all_fn_entries: dict = {}
|
||||
for enum_cls in OP_ENUMS:
|
||||
cls_name = enum_cls.__name__
|
||||
if not instructions.get(enum_cls): continue
|
||||
fn_entries = []
|
||||
for op, pc in instructions[enum_cls].items():
|
||||
if any(p in pc for p in UNSUPPORTED): continue
|
||||
try:
|
||||
code = compile_pseudocode(pc)
|
||||
code = _apply_pseudocode_fixes(op, code)
|
||||
fn_name, fn_code = _generate_function(cls_name, op, pc, code)
|
||||
fn_lines.append(fn_code)
|
||||
fn_entries.append((op, fn_name))
|
||||
except Exception as e: print(f" Warning: Failed to compile {op.name}: {e}")
|
||||
if fn_entries:
|
||||
all_fn_entries[enum_cls] = fn_entries
|
||||
fn_lines.append(f'{cls_name}_FUNCTIONS = {{')
|
||||
for op, fn_name in fn_entries: fn_lines.append(f" {cls_name}.{op.name}: {fn_name},")
|
||||
fn_lines.append('}\n')
|
||||
|
||||
fn_lines.append('COMPILED_FUNCTIONS = {')
|
||||
for enum_cls in OP_ENUMS:
|
||||
if all_fn_entries.get(enum_cls): fn_lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_FUNCTIONS,')
|
||||
fn_lines.append('}')
|
||||
|
||||
# Second pass: scan generated code for pcode imports
|
||||
fn_code_str = '\n'.join(fn_lines)
|
||||
import extra.assembly.amd.pcode as pcode_module
|
||||
pcode_exports = [name for name in dir(pcode_module) if not name.startswith('_') or name.startswith('_') and not name.startswith('__')]
|
||||
used_imports = sorted(name for name in pcode_exports if re.search(rf'\b{re.escape(name)}\b', fn_code_str))
|
||||
|
||||
# Build final output with explicit imports
|
||||
# Build string dictionaries for each enum
|
||||
lines = [f'''# autogenerated by pdf.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.amd.pdf --arch {arch}
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from extra.assembly.amd.autogen.{arch}.enum import {", ".join(enum_names)}
|
||||
from extra.assembly.amd.pcode import {", ".join(used_imports)}
|
||||
'''] + fn_lines
|
||||
''']
|
||||
all_dict_entries: dict = {}
|
||||
for enum_cls in OP_ENUMS:
|
||||
cls_name = enum_cls.__name__
|
||||
if not instructions.get(enum_cls): continue
|
||||
dict_entries = [(op, repr(pc)) for op, pc in instructions[enum_cls].items()]
|
||||
if dict_entries:
|
||||
all_dict_entries[enum_cls] = dict_entries
|
||||
lines.append(f'{cls_name}_PCODE = {{')
|
||||
for op, escaped in dict_entries: lines.append(f" {cls_name}.{op.name}: {escaped},")
|
||||
lines.append('}\n')
|
||||
|
||||
lines.append('PSEUDOCODE_STRINGS = {')
|
||||
for enum_cls in OP_ENUMS:
|
||||
if all_dict_entries.get(enum_cls): lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_PCODE,')
|
||||
lines.append('}')
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _apply_pseudocode_fixes(op, code: str) -> str:
|
||||
"""Apply known fixes for PDF pseudocode bugs."""
|
||||
if op.name == 'V_DIV_FMAS_F32':
|
||||
code = code.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
|
||||
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
|
||||
if op.name == 'V_DIV_FMAS_F64':
|
||||
code = code.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
|
||||
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
|
||||
if op.name == 'V_DIV_SCALE_F32':
|
||||
code = code.replace('D0.f32 = float("nan")', 'VCC = Reg(0x1); D0.f32 = float("nan")')
|
||||
code = code.replace('elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif False:\n pass')
|
||||
code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
|
||||
code = code.replace('elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)', 'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
|
||||
code = code.replace('elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
|
||||
if op.name == 'V_DIV_SCALE_F64':
|
||||
code = code.replace('D0.f64 = float("nan")', 'VCC = Reg(0x1); D0.f64 = float("nan")')
|
||||
code = code.replace('elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif False:\n pass')
|
||||
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
|
||||
code = code.replace('elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)', 'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
|
||||
code = code.replace('elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
|
||||
if op.name == 'V_DIV_FIXUP_F32':
|
||||
code = code.replace('D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
|
||||
'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
|
||||
if op.name == 'V_DIV_FIXUP_F64':
|
||||
code = code.replace('D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
|
||||
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
|
||||
if op.name == 'V_TRIG_PREOP_F64':
|
||||
code = code.replace('result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
|
||||
'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
|
||||
return code
|
||||
|
||||
def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]:
|
||||
"""Generate a single compiled pseudocode function.
|
||||
Functions take int parameters and return dict of int values.
|
||||
Reg wrapping happens inside the function, only for registers actually used."""
|
||||
has_d1 = '{ D1' in pc
|
||||
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
|
||||
is_div_scale = 'DIV_SCALE' in op.name
|
||||
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
|
||||
is_ds = cls_name == 'DSOp'
|
||||
is_flat = cls_name in ('FLATOp', 'GLOBALOp', 'SCRATCHOp')
|
||||
is_smem = cls_name == 'SMEMOp'
|
||||
has_s_array = 'S[i]' in pc # FMA_MIX style: S[0], S[1], S[2] array access
|
||||
combined = code + pc
|
||||
|
||||
fn_name = f"_{cls_name}_{op.name}"
|
||||
|
||||
# Detect which registers are used/modified
|
||||
def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
|
||||
modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined))
|
||||
modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined))
|
||||
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
|
||||
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
|
||||
modifies_pc = bool(re.search(r'\bPC\s*=', combined))
|
||||
|
||||
# Build function signature and Reg init lines
|
||||
if is_smem:
|
||||
lines = [f"def {fn_name}(MEM, addr):"]
|
||||
reg_inits = ["ADDR=Reg(addr)", "SDATA=Reg(0)"]
|
||||
special_regs = []
|
||||
elif is_ds:
|
||||
lines = [f"def {fn_name}(MEM, addr, data0, data1, offset0, offset1):"]
|
||||
reg_inits = ["ADDR=Reg(addr)", "DATA0=Reg(data0)", "DATA1=Reg(data1)", "OFFSET0=Reg(offset0)", "OFFSET1=Reg(offset1)", "RETURN_DATA=Reg(0)"]
|
||||
special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
|
||||
elif is_flat:
|
||||
lines = [f"def {fn_name}(MEM, addr, vdata, vdst):"]
|
||||
reg_inits = ["ADDR=addr", "VDATA=Reg(vdata)", "VDST=Reg(vdst)", "RETURN_DATA=Reg(0)"]
|
||||
special_regs = [('DATA', 'VDATA')]
|
||||
elif has_s_array:
|
||||
# FMA_MIX style: needs S[i] array, opsel, opsel_hi for source selection (neg/neg_hi applied in emu.py before call)
|
||||
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None, opsel=0, opsel_hi=0):"]
|
||||
reg_inits = ["S0=Reg(s0)", "S1=Reg(s1)", "S2=Reg(s2)", "S=[S0,S1,S2]", "D0=Reg(d0)", "OPSEL=Reg(opsel)", "OPSEL_HI=Reg(opsel_hi)"]
|
||||
special_regs = []
|
||||
# Detect array declarations like "declare in : 32'F[3]" and create them (rename 'in' to 'ins' since 'in' is a keyword)
|
||||
if "in[" in combined:
|
||||
reg_inits.append("ins=[Reg(0),Reg(0),Reg(0)]")
|
||||
code = code.replace("in[", "ins[")
|
||||
else:
|
||||
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):"]
|
||||
# Only create Regs for registers actually used in the pseudocode
|
||||
reg_inits = []
|
||||
if 'S0' in combined: reg_inits.append("S0=Reg(s0)")
|
||||
if 'S1' in combined: reg_inits.append("S1=Reg(s1)")
|
||||
if 'S2' in combined: reg_inits.append("S2=Reg(s2)")
|
||||
if modifies_d0 or 'D0' in combined: reg_inits.append("D0=Reg(s0)" if is_div_scale else "D0=Reg(d0)")
|
||||
if modifies_scc or 'SCC' in combined: reg_inits.append("SCC=Reg(scc)")
|
||||
if modifies_vcc or 'VCC' in combined: reg_inits.append("VCC=Reg(vcc)")
|
||||
if modifies_exec or 'EXEC' in combined: reg_inits.append("EXEC=Reg(exec_mask)")
|
||||
if modifies_pc or 'PC' in combined: reg_inits.append("PC=Reg(pc) if pc is not None else None")
|
||||
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
|
||||
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
|
||||
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
|
||||
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
|
||||
|
||||
# Build init code
|
||||
init_parts = reg_inits.copy()
|
||||
for name, init in special_regs:
|
||||
if name in combined: init_parts.append(f"{name}={init}")
|
||||
if 'EXEC_LO' in code: init_parts.append("EXEC_LO=SliceProxy(EXEC, 31, 0)")
|
||||
if 'EXEC_HI' in code: init_parts.append("EXEC_HI=SliceProxy(EXEC, 63, 32)")
|
||||
if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_parts.append("VCCZ=Reg(1 if VCC._val == 0 else 0)")
|
||||
if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_parts.append("EXECZ=Reg(1 if EXEC._val == 0 else 0)")
|
||||
|
||||
# Add init line and separator
|
||||
if init_parts: lines.append(f" {'; '.join(init_parts)}")
|
||||
lines.append(" # --- compiled pseudocode ---")
|
||||
|
||||
# Add compiled pseudocode
|
||||
for line in code.split('\n'):
|
||||
if line.strip(): lines.append(f" {line}")
|
||||
|
||||
# Build result dict
|
||||
result_items = []
|
||||
if modifies_d0: result_items.append("'D0': D0._val")
|
||||
if modifies_scc: result_items.append("'SCC': SCC._val")
|
||||
if modifies_vcc: result_items.append("'VCC': VCC._val")
|
||||
if modifies_exec: result_items.append("'EXEC': EXEC._val")
|
||||
if has_d1: result_items.append("'D1': D1._val")
|
||||
if modifies_pc: result_items.append("'PC': PC._val")
|
||||
if is_smem and 'SDATA' in combined and re.search(r'^\s*SDATA[\.\[].*=', code, re.MULTILINE):
|
||||
result_items.append("'SDATA': SDATA._val")
|
||||
if is_ds and 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
|
||||
result_items.append("'RETURN_DATA': RETURN_DATA._val")
|
||||
if is_flat:
|
||||
if 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
|
||||
result_items.append("'RETURN_DATA': RETURN_DATA._val")
|
||||
if re.search(r'^\s*VDATA[\.\[].*=', code, re.MULTILINE):
|
||||
result_items.append("'VDATA': VDATA._val")
|
||||
lines.append(f" return {{{', '.join(result_items)}}}\n")
|
||||
return fn_name, '\n'.join(lines)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# MAIN GENERATION
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def generate_arch(arch: str) -> dict:
|
||||
"""Generate enum.py, ins.py and gen_pcode.py for a single architecture."""
|
||||
"""Generate enum.py, ins.py and str_pcode.py for a single architecture."""
|
||||
urls = PDF_URLS[arch]
|
||||
if isinstance(urls, str): urls = [urls]
|
||||
|
||||
@@ -696,9 +391,9 @@ def generate_arch(arch: str) -> dict:
|
||||
ins_path.write_text(ins_content)
|
||||
print(f"Generated {ins_path}: {len(merged['formats'])} formats")
|
||||
|
||||
# Write gen_pcode.py (needs enum.py to exist first for imports)
|
||||
pcode_path = base_path / "gen_pcode.py"
|
||||
pcode_content = _generate_gen_pcode_py(merged["enums"], merged["pseudocode"], arch)
|
||||
# Write str_pcode.py (needs enum.py to exist first for imports)
|
||||
pcode_path = base_path / "str_pcode.py"
|
||||
pcode_content = _generate_str_pcode_py(merged["enums"], merged["pseudocode"], arch)
|
||||
pcode_path.write_text(pcode_content)
|
||||
print(f"Generated {pcode_path}: {len(merged['pseudocode'])} instructions")
|
||||
|
||||
|
||||
@@ -30,8 +30,8 @@ def get_llvm_objdump():
|
||||
class ExecContext:
|
||||
"""Context for running compiled pseudocode in tests."""
|
||||
def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=0xffffffff, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
|
||||
from extra.assembly.amd.pcode import Reg, MASK32, MASK64, SliceProxy
|
||||
self._Reg, self._MASK64, self._SliceProxy = Reg, MASK64, SliceProxy
|
||||
from extra.assembly.amd.pcode import Reg, MASK32, MASK64, TypedView
|
||||
self._Reg, self._MASK64, self._TypedView = Reg, MASK64, TypedView
|
||||
self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
|
||||
self.D0, self.D1 = Reg(d0), Reg(0)
|
||||
self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
|
||||
@@ -51,7 +51,7 @@ class ExecContext:
|
||||
ns.update({
|
||||
'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
|
||||
'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
|
||||
'EXEC_LO': self._SliceProxy(self.EXEC, 31, 0), 'EXEC_HI': self._SliceProxy(self.EXEC, 63, 32),
|
||||
'EXEC_LO': self._TypedView(self.EXEC, 31, 0), 'EXEC_HI': self._TypedView(self.EXEC, 63, 32),
|
||||
'tmp': self.tmp, 'saveexec': self.saveexec,
|
||||
'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
|
||||
'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32, 'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,
|
||||
|
||||
@@ -47,8 +47,7 @@ dev.synchronize()
|
||||
elapsed = time.perf_counter() - st
|
||||
|
||||
self.assertNotEqual(result.returncode, 0, "should have raised")
|
||||
self.assertTrue("NotImplementedError" in result.stderr or "ValueError" in result.stderr,
|
||||
f"expected NotImplementedError or ValueError in stderr")
|
||||
self.assertTrue("Error" in result.stderr, f"expected an error in stderr, got: {result.stderr[:500]}")
|
||||
# Should exit immediately, not wait for the full timeout
|
||||
self.assertLess(elapsed, 9.0, f"should exit immediately on emulator exception, took {elapsed:.1f}s")
|
||||
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tests for the RDNA3 pseudocode DSL."""
|
||||
import unittest
|
||||
from extra.assembly.amd.pcode import (Reg, TypedView, SliceProxy, MASK32, MASK64,
|
||||
_f32, _i32, _f16, _i16, f32_to_f16, _isnan, _bf16, _ibf16, bf16_to_f32, f32_to_bf16,
|
||||
BYTE_PERMUTE, v_sad_u8, v_msad_u8)
|
||||
from extra.assembly.amd.pdf import compile_pseudocode, _expr
|
||||
from extra.assembly.amd.pcode import (Reg, TypedView, TypedView, MASK32, MASK64,
|
||||
_f32, _i32, _f16, _i16, f32_to_f16, isNAN, _bf16, _ibf16, bf16_to_f32, f32_to_bf16,
|
||||
BYTE_PERMUTE, v_sad_u8, v_msad_u8, _compile_pseudocode, _expr, compile_pseudocode)
|
||||
from extra.assembly.amd.test.helpers import ExecContext
|
||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import _VOP3SDOp_V_DIV_SCALE_F32, _VOPCOp_V_CMP_CLASS_F32
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import VOP3SDOp_PCODE, VOPCOp_PCODE
|
||||
from extra.assembly.amd.autogen.rdna3.enum import VOP3SDOp, VOPCOp
|
||||
|
||||
# Compile pseudocode functions on demand for regression tests
|
||||
_VOP3SDOp_V_DIV_SCALE_F32 = compile_pseudocode('VOP3SDOp', 'V_DIV_SCALE_F32', VOP3SDOp_PCODE[VOP3SDOp.V_DIV_SCALE_F32])
|
||||
_VOPCOp_V_CMP_CLASS_F32 = compile_pseudocode('VOPCOp', 'V_CMP_CLASS_F32', VOPCOp_PCODE[VOPCOp.V_CMP_CLASS_F32])
|
||||
|
||||
class TestReg(unittest.TestCase):
|
||||
def test_u32_read(self):
|
||||
@@ -42,7 +46,7 @@ class TestReg(unittest.TestCase):
|
||||
class TestTypedView(unittest.TestCase):
|
||||
def test_bit_slice(self):
|
||||
r = Reg(0xDEADBEEF)
|
||||
# Slices return SliceProxy which supports .u32, .u16 etc (matching pseudocode like S1.u32[1:0].u32)
|
||||
# Slices return TypedView which supports .u32, .u16 etc (matching pseudocode like S1.u32[1:0].u32)
|
||||
self.assertEqual(r.u32[7:0].u32, 0xEF)
|
||||
self.assertEqual(r.u32[15:8].u32, 0xBE)
|
||||
self.assertEqual(r.u32[23:16].u32, 0xAD)
|
||||
@@ -67,7 +71,7 @@ class TestTypedView(unittest.TestCase):
|
||||
# S0.u32[S1.u32[4:0]] - access bit at position from another register
|
||||
s0 = Reg(0b11010101)
|
||||
s1 = Reg(3)
|
||||
bit_pos = s1.u32[4:0] # SliceProxy, int value = 3
|
||||
bit_pos = s1.u32[4:0] # TypedView, int value = 3
|
||||
bit_val = s0.u32[int(bit_pos)] # bit 3 of s0 = 0
|
||||
self.assertEqual(int(bit_pos), 3)
|
||||
self.assertEqual(bit_val, 0)
|
||||
@@ -85,7 +89,7 @@ class TestTypedView(unittest.TestCase):
|
||||
self.assertFalse(r1.u32 < r2.u32)
|
||||
self.assertTrue(r1.u32 != r2.u32)
|
||||
|
||||
class TestSliceProxy(unittest.TestCase):
|
||||
class TestTypedView(unittest.TestCase):
|
||||
def test_slice_read(self):
|
||||
r = Reg(0x56781234)
|
||||
self.assertEqual(r[15:0].u16, 0x1234)
|
||||
@@ -154,19 +158,19 @@ class TestExecContext(unittest.TestCase):
|
||||
self.assertEqual(ctx.SCC._val, 0)
|
||||
|
||||
def test_ternary(self):
|
||||
code = compile_pseudocode("D0.u32 = S0.u32 > S1.u32 ? 1'1U : 1'0U")
|
||||
code = _compile_pseudocode("D0.u32 = S0.u32 > S1.u32 ? 1'1U : 1'0U")
|
||||
ctx = ExecContext(s0=5, s1=3)
|
||||
ctx.run(code)
|
||||
self.assertEqual(ctx.D0._val, 1)
|
||||
|
||||
def test_pack(self):
|
||||
code = compile_pseudocode("D0 = { S1[15:0].u16, S0[15:0].u16 }")
|
||||
code = _compile_pseudocode("D0 = { S1[15:0].u16, S0[15:0].u16 }")
|
||||
ctx = ExecContext(s0=0x1234, s1=0x5678)
|
||||
ctx.run(code)
|
||||
self.assertEqual(ctx.D0._val, 0x56781234)
|
||||
|
||||
def test_tmp_with_typed_access(self):
|
||||
code = compile_pseudocode("""tmp = S0.u32 + S1.u32
|
||||
code = _compile_pseudocode("""tmp = S0.u32 + S1.u32
|
||||
D0.u32 = tmp.u32""")
|
||||
ctx = ExecContext(s0=100, s1=200)
|
||||
ctx.run(code)
|
||||
@@ -174,7 +178,7 @@ D0.u32 = tmp.u32""")
|
||||
|
||||
def test_s_add_u32_pattern(self):
|
||||
# Real pseudocode pattern from S_ADD_U32
|
||||
code = compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
|
||||
code = _compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
|
||||
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
|
||||
D0.u32 = tmp.u32""")
|
||||
# Test overflow case
|
||||
@@ -184,7 +188,7 @@ D0.u32 = tmp.u32""")
|
||||
self.assertEqual(ctx.SCC._val, 1) # Carry set
|
||||
|
||||
def test_s_add_u32_no_overflow(self):
|
||||
code = compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
|
||||
code = _compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
|
||||
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
|
||||
D0.u32 = tmp.u32""")
|
||||
ctx = ExecContext(s0=100, s1=200)
|
||||
@@ -206,7 +210,7 @@ D0.u32 = tmp.u32""")
|
||||
|
||||
def test_for_loop(self):
|
||||
# CTZ pattern - find first set bit
|
||||
code = compile_pseudocode("""tmp = -1
|
||||
code = _compile_pseudocode("""tmp = -1
|
||||
for i in 0 : 31 do
|
||||
if S0.u32[i] == 1 then
|
||||
tmp = i
|
||||
@@ -261,15 +265,15 @@ class TestPseudocodeRegressions(unittest.TestCase):
|
||||
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
|
||||
self.assertEqual(result['D0'] & 1, 0, "Signaling NaN should not match quiet NaN mask")
|
||||
|
||||
def test_isnan_with_typed_view(self):
|
||||
"""_isnan must work with TypedView objects, not just Python floats.
|
||||
Bug: _isnan checked isinstance(x, float) which returned False for TypedView."""
|
||||
def testisNAN_with_typed_view(self):
|
||||
"""isNAN must work with TypedView objects, not just Python floats.
|
||||
Bug: isNAN checked isinstance(x, float) which returned False for TypedView."""
|
||||
nan_reg = Reg(0x7fc00000) # quiet NaN
|
||||
normal_reg = Reg(0x3f800000) # 1.0
|
||||
inf_reg = Reg(0x7f800000) # +inf
|
||||
self.assertTrue(_isnan(nan_reg.f32), "_isnan should return True for NaN TypedView")
|
||||
self.assertFalse(_isnan(normal_reg.f32), "_isnan should return False for normal TypedView")
|
||||
self.assertFalse(_isnan(inf_reg.f32), "_isnan should return False for inf TypedView")
|
||||
self.assertTrue(isNAN(nan_reg.f32), "isNAN should return True for NaN TypedView")
|
||||
self.assertFalse(isNAN(normal_reg.f32), "isNAN should return False for normal TypedView")
|
||||
self.assertFalse(isNAN(inf_reg.f32), "isNAN should return False for inf TypedView")
|
||||
|
||||
class TestBF16(unittest.TestCase):
|
||||
"""Tests for BF16 (bfloat16) support."""
|
||||
@@ -308,7 +312,7 @@ class TestBF16(unittest.TestCase):
|
||||
self.assertAlmostEqual(float(r.bf16), 3.0, places=1)
|
||||
|
||||
def test_bf16_slice_property(self):
|
||||
"""Test SliceProxy.bf16 property."""
|
||||
"""Test TypedView.bf16 property."""
|
||||
r = Reg(0x40404040) # Two bf16 3.0 values
|
||||
self.assertAlmostEqual(r[15:0].bf16, 3.0, places=1)
|
||||
self.assertAlmostEqual(r[31:16].bf16, 3.0, places=1)
|
||||
|
||||
Reference in New Issue
Block a user