mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00

write python emulator from RDNA3 pseudocode in pdf (#13841)

* write python emulator from RDNA3 pseudocode in pdf
* emu2
* more emu
* working
* more pseudo
* progress
* cleanups
* delete junk
* delete stale files
* just emu
* work
* emu compare
* bemu
* cleanups and more failures
* revert bench emu
* fix emu cmp
* four tests fail
* bugfixes
* dsl
* ext
* refactor
* dsl
* div scale fix
* test_emu
* fix emu tests
* pcode
* test pcode
* top imports
* fix test_emu to use run_asm
* emu tests on real hardware
* more tests
* more emu tests
* more
* work
* work
* bug fix
* bugfixes
* fix fp16 gemm
* all ops tests pass in emulator
* fix llvm tests
* fix a few more tests
* fix mockgpu timeout

CLAUDE.md (10 lines changed)
@@ -79,7 +79,7 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"

 ## Common Environment Variables

-- `DEBUG=1-4` - Increasing verbosity
+- `DEBUG=1-7` - Increasing verbosity (7 shows assembly output)
 - `VIZ=1` - Enable graph visualization
 - `SPEC=1` - Enable UOp spec verification
 - `NOOPT=1` - Disable optimizations

@@ -100,6 +100,14 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
 - Run tests before proposing commits
 - Test with `SPEC=2` when modifying UOp-related code
+
+## Auto-generated Files (DO NOT EDIT)
+
+The following files are auto-generated and should never be edited manually:
+- `extra/assembly/rdna3/autogen/gen_pcode.py` - Generated by `python -m extra.assembly.rdna3.pcode`
+- `extra/assembly/rdna3/autogen/__init__.py` - Generated from AMD ISA definitions
+
+To add missing instruction implementations, add them to `extra/assembly/rdna3/emu.py` instead.

 ## Style Notes

 - 2-space indentation, 150 char line limit
@@ -1,254 +0,0 @@ (entire file deleted)
# Pure combinational ALU functions for RDNA3 emulation
from __future__ import annotations
import struct, math
from typing import Callable
from extra.assembly.rdna3.autogen import SOP1Op, SOP2Op, SOPCOp, SOPKOp, VOP1Op, VOP2Op, VOP3Op

# Format base offsets for unified opcode space
SOP2_BASE, SOP1_BASE, SOPC_BASE, SOPK_BASE = 0x000, 0x100, 0x200, 0x300
VOP2_BASE, VOP1_BASE = 0x100, 0x180

# Float conversion helpers
_I, _f, _H, _e = struct.Struct('<I'), struct.Struct('<f'), struct.Struct('<H'), struct.Struct('<e')
def f32(i: int) -> float: return _f.unpack(_I.pack(i & 0xffffffff))[0]
def i32(f: float) -> int:
  if math.isinf(f): return 0x7f800000 if f > 0 else 0xff800000
  try: return _I.unpack(_f.pack(f))[0]
  except (OverflowError, struct.error): return 0x7f800000 if f > 0 else 0xff800000
def f16(i: int) -> float: return _e.unpack(_H.pack(i & 0xffff))[0]
def i16(f: float) -> int:
  if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00
  try: return _H.unpack(_e.pack(f))[0]
  except (OverflowError, struct.error): return 0x7c00 if f > 0 else 0xfc00
def sext(v: int, b: int) -> int: return v - (1 << b) if v & (1 << (b-1)) else v
def clz(x: int) -> int: return 32 - x.bit_length() if x else 32
def cls(x: int) -> int: x &= 0xffffffff; return 31 if x in (0, 0xffffffff) else clz(~x & 0xffffffff if x >> 31 else x) - 1
def _cvt_i32_f32(v): return (0x7fffffff if v > 0 else 0x80000000) if math.isinf(v) else (0 if math.isnan(v) else max(-0x80000000, min(0x7fffffff, int(v))) & 0xffffffff)
def _cvt_u32_f32(v): return (0xffffffff if v > 0 else 0) if math.isinf(v) else (0 if math.isnan(v) or v < 0 else min(0xffffffff, int(v)))
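
To pin down the conventions these helpers encode, here is a small self-contained sanity check (editor's illustration; `f32`/`i32` are simplified re-derivations of the deleted helpers, without their infinity clamping):

```python
import struct

def f32(i): return struct.unpack('<f', struct.pack('<I', i & 0xffffffff))[0]
def i32(f): return struct.unpack('<I', struct.pack('<f', f))[0]
def sext(v, b): return v - (1 << b) if v & (1 << (b-1)) else v
def clz(x): return 32 - x.bit_length() if x else 32

assert f32(0x3f800000) == 1.0         # IEEE-754 single-precision bit pattern of 1.0
assert i32(-2.0) == 0xc0000000        # and back again through the same bits
assert sext(0xffffffff, 32) == -1     # two's-complement reinterpretation
assert clz(1) == 31 and clz(0) == 32  # leading-zero count, with the x == 0 special case
```
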

# SALU: op -> fn(s0, s1, scc_in) -> (result, scc_out)
SALU: dict[int, Callable] = {
  # SOP2
  SOP2_BASE + SOP2Op.S_ADD_U32: lambda a, b, scc: ((a + b) & 0xffffffff, int((a + b) >= 0x100000000)),
  SOP2_BASE + SOP2Op.S_SUB_U32: lambda a, b, scc: ((a - b) & 0xffffffff, int(b > a)),
  SOP2_BASE + SOP2Op.S_ADDC_U32: lambda a, b, scc: ((r := a + b + scc) & 0xffffffff, int(r >= 0x100000000)),
  SOP2_BASE + SOP2Op.S_SUBB_U32: lambda a, b, scc: ((a - b - scc) & 0xffffffff, int((b + scc) > a)),
  SOP2_BASE + SOP2Op.S_ADD_I32: lambda a, b, scc: ((r := sext(a, 32) + sext(b, 32)) & 0xffffffff, int(((a >> 31) == (b >> 31)) and ((a >> 31) != ((r >> 31) & 1)))),
  SOP2_BASE + SOP2Op.S_SUB_I32: lambda a, b, scc: ((r := sext(a, 32) - sext(b, 32)) & 0xffffffff, int(((a >> 31) != (b >> 31)) and ((a >> 31) != ((r >> 31) & 1)))),
  SOP2_BASE + SOP2Op.S_AND_B32: lambda a, b, scc: ((r := a & b), int(r != 0)),
  SOP2_BASE + SOP2Op.S_OR_B32: lambda a, b, scc: ((r := a | b), int(r != 0)),
  SOP2_BASE + SOP2Op.S_XOR_B32: lambda a, b, scc: ((r := a ^ b), int(r != 0)),
  SOP2_BASE + SOP2Op.S_AND_NOT1_B32: lambda a, b, scc: ((r := a & (~b & 0xffffffff)), int(r != 0)),
  SOP2_BASE + SOP2Op.S_OR_NOT1_B32: lambda a, b, scc: ((r := a | (~b & 0xffffffff)), int(r != 0)),
  SOP2_BASE + SOP2Op.S_LSHL_B32: lambda a, b, scc: ((r := (a << (b & 0x1f)) & 0xffffffff), int(r != 0)),
  SOP2_BASE + SOP2Op.S_LSHR_B32: lambda a, b, scc: ((r := a >> (b & 0x1f)), int(r != 0)),
  SOP2_BASE + SOP2Op.S_ASHR_I32: lambda a, b, scc: ((r := sext(a, 32) >> (b & 0x1f)) & 0xffffffff, int(r != 0)),
  SOP2_BASE + SOP2Op.S_MUL_I32: lambda a, b, scc: ((sext(a, 32) * sext(b, 32)) & 0xffffffff, scc),
  SOP2_BASE + SOP2Op.S_MUL_HI_U32: lambda a, b, scc: (((a * b) >> 32) & 0xffffffff, scc),
  SOP2_BASE + SOP2Op.S_MUL_HI_I32: lambda a, b, scc: (((sext(a, 32) * sext(b, 32)) >> 32) & 0xffffffff, scc),
  SOP2_BASE + SOP2Op.S_MIN_I32: lambda a, b, scc: (a, 1) if sext(a, 32) < sext(b, 32) else (b, 0),
  SOP2_BASE + SOP2Op.S_MIN_U32: lambda a, b, scc: (a, 1) if a < b else (b, 0),
  SOP2_BASE + SOP2Op.S_MAX_I32: lambda a, b, scc: (a, 1) if sext(a, 32) > sext(b, 32) else (b, 0),
  SOP2_BASE + SOP2Op.S_MAX_U32: lambda a, b, scc: (a, 1) if a > b else (b, 0),
  SOP2_BASE + SOP2Op.S_CSELECT_B32: lambda a, b, scc: (a if scc else b, scc),
  SOP2_BASE + SOP2Op.S_BFE_U32: lambda a, b, scc: ((r := ((a >> (b & 0x1f)) & ((1 << ((b >> 16) & 0x7f)) - 1)) if (b >> 16) & 0x7f else 0), int(r != 0)),
  SOP2_BASE + SOP2Op.S_BFE_I32: lambda a, b, scc: ((r := sext((a >> (b & 0x1f)) & ((1 << w) - 1), w) & 0xffffffff if (w := (b >> 16) & 0x7f) else 0), int(r != 0)),
  SOP2_BASE + SOP2Op.S_PACK_LL_B32_B16: lambda a, b, scc: ((a & 0xffff) | ((b & 0xffff) << 16), scc),
  SOP2_BASE + SOP2Op.S_PACK_LH_B32_B16: lambda a, b, scc: ((a & 0xffff) | (b & 0xffff0000), scc),
  SOP2_BASE + SOP2Op.S_PACK_HH_B32_B16: lambda a, b, scc: (((a >> 16) & 0xffff) | (b & 0xffff0000), scc),
  SOP2_BASE + SOP2Op.S_PACK_HL_B32_B16: lambda a, b, scc: (((a >> 16) & 0xffff) | ((b & 0xffff) << 16), scc),
  SOP2_BASE + SOP2Op.S_ADD_F32: lambda a, b, scc: (i32(f32(a) + f32(b)), scc),
  SOP2_BASE + SOP2Op.S_SUB_F32: lambda a, b, scc: (i32(f32(a) - f32(b)), scc),
  SOP2_BASE + SOP2Op.S_MUL_F32: lambda a, b, scc: (i32(f32(a) * f32(b)), scc),
  # SOP1
  SOP1_BASE + SOP1Op.S_MOV_B32: lambda a, b, scc: (a, scc),
  SOP1_BASE + SOP1Op.S_NOT_B32: lambda a, b, scc: ((r := (~a) & 0xffffffff), int(r != 0)),
  SOP1_BASE + SOP1Op.S_BREV_B32: lambda a, b, scc: (int(f'{a & 0xffffffff:032b}'[::-1], 2), scc),
  SOP1_BASE + SOP1Op.S_CLZ_I32_U32: lambda a, b, scc: (clz(a), scc),
  SOP1_BASE + SOP1Op.S_CLS_I32: lambda a, b, scc: (cls(a), scc),
  SOP1_BASE + SOP1Op.S_SEXT_I32_I8: lambda a, b, scc: (sext(a & 0xff, 8) & 0xffffffff, scc),
  SOP1_BASE + SOP1Op.S_SEXT_I32_I16: lambda a, b, scc: (sext(a & 0xffff, 16) & 0xffffffff, scc),
  SOP1_BASE + SOP1Op.S_ABS_I32: lambda a, b, scc: ((r := abs(sext(a, 32)) & 0xffffffff), int(r != 0)),
  SOP1_BASE + SOP1Op.S_CVT_F32_I32: lambda a, b, scc: (i32(float(sext(a, 32))), scc),
  SOP1_BASE + SOP1Op.S_CVT_F32_U32: lambda a, b, scc: (i32(float(a)), scc),
  SOP1_BASE + SOP1Op.S_CVT_I32_F32: lambda a, b, scc: (_cvt_i32_f32(f32(a)), scc),
  SOP1_BASE + SOP1Op.S_CVT_U32_F32: lambda a, b, scc: (_cvt_u32_f32(f32(a)), scc),
  SOP1_BASE + SOP1Op.S_CEIL_F32: lambda a, b, scc: (i32(math.ceil(f32(a))), scc),
  SOP1_BASE + SOP1Op.S_FLOOR_F32: lambda a, b, scc: (i32(math.floor(f32(a))), scc),
  SOP1_BASE + SOP1Op.S_TRUNC_F32: lambda a, b, scc: (i32(math.trunc(f32(a))), scc),
  SOP1_BASE + SOP1Op.S_RNDNE_F32: lambda a, b, scc: (i32(round(f32(a))), scc),
  SOP1_BASE + SOP1Op.S_CVT_F16_F32: lambda a, b, scc: (i16(f32(a)), scc),
  SOP1_BASE + SOP1Op.S_CVT_F32_F16: lambda a, b, scc: (i32(f16(a)), scc),
  # SOPC
  SOPC_BASE + SOPCOp.S_CMP_EQ_I32: lambda a, b, scc: (0, int(sext(a, 32) == sext(b, 32))),
  SOPC_BASE + SOPCOp.S_CMP_LG_I32: lambda a, b, scc: (0, int(sext(a, 32) != sext(b, 32))),
  SOPC_BASE + SOPCOp.S_CMP_GT_I32: lambda a, b, scc: (0, int(sext(a, 32) > sext(b, 32))),
  SOPC_BASE + SOPCOp.S_CMP_GE_I32: lambda a, b, scc: (0, int(sext(a, 32) >= sext(b, 32))),
  SOPC_BASE + SOPCOp.S_CMP_LT_I32: lambda a, b, scc: (0, int(sext(a, 32) < sext(b, 32))),
  SOPC_BASE + SOPCOp.S_CMP_LE_I32: lambda a, b, scc: (0, int(sext(a, 32) <= sext(b, 32))),
  SOPC_BASE + SOPCOp.S_CMP_EQ_U32: lambda a, b, scc: (0, int(a == b)),
  SOPC_BASE + SOPCOp.S_CMP_LG_U32: lambda a, b, scc: (0, int(a != b)),
  SOPC_BASE + SOPCOp.S_CMP_GT_U32: lambda a, b, scc: (0, int(a > b)),
  SOPC_BASE + SOPCOp.S_CMP_GE_U32: lambda a, b, scc: (0, int(a >= b)),
  SOPC_BASE + SOPCOp.S_CMP_LT_U32: lambda a, b, scc: (0, int(a < b)),
  SOPC_BASE + SOPCOp.S_CMP_LE_U32: lambda a, b, scc: (0, int(a <= b)),
  SOPC_BASE + SOPCOp.S_BITCMP0_B32: lambda a, b, scc: (0, int((a & (1 << (b & 0x1f))) == 0)),
  SOPC_BASE + SOPCOp.S_BITCMP1_B32: lambda a, b, scc: (0, int((a & (1 << (b & 0x1f))) != 0)),
  # SOPK
  SOPK_BASE + SOPKOp.S_MOVK_I32: lambda a, b, scc: (sext(b, 16) & 0xffffffff, scc),
  SOPK_BASE + SOPKOp.S_CMOVK_I32: lambda a, b, scc: ((sext(b, 16) & 0xffffffff) if scc else a, scc),
  SOPK_BASE + SOPKOp.S_ADDK_I32: lambda a, b, scc: ((r := sext(a, 32) + sext(b, 16)) & 0xffffffff, int(((a >> 31) == ((b >> 15) & 1)) and ((a >> 31) != ((r >> 31) & 1)))),
  SOPK_BASE + SOPKOp.S_MULK_I32: lambda a, b, scc: ((sext(a, 32) * sext(b, 16)) & 0xffffffff, scc),
  SOPK_BASE + SOPKOp.S_CMPK_EQ_I32: lambda a, b, scc: (0, int(sext(a, 32) == sext(b, 16))),
  SOPK_BASE + SOPKOp.S_CMPK_LG_I32: lambda a, b, scc: (0, int(sext(a, 32) != sext(b, 16))),
  SOPK_BASE + SOPKOp.S_CMPK_GT_I32: lambda a, b, scc: (0, int(sext(a, 32) > sext(b, 16))),
  SOPK_BASE + SOPKOp.S_CMPK_GE_I32: lambda a, b, scc: (0, int(sext(a, 32) >= sext(b, 16))),
  SOPK_BASE + SOPKOp.S_CMPK_LT_I32: lambda a, b, scc: (0, int(sext(a, 32) < sext(b, 16))),
  SOPK_BASE + SOPKOp.S_CMPK_LE_I32: lambda a, b, scc: (0, int(sext(a, 32) <= sext(b, 16))),
  SOPK_BASE + SOPKOp.S_CMPK_EQ_U32: lambda a, b, scc: (0, int(a == (b & 0xffff))),
  SOPK_BASE + SOPKOp.S_CMPK_LG_U32: lambda a, b, scc: (0, int(a != (b & 0xffff))),
  SOPK_BASE + SOPKOp.S_CMPK_GT_U32: lambda a, b, scc: (0, int(a > (b & 0xffff))),
  SOPK_BASE + SOPKOp.S_CMPK_GE_U32: lambda a, b, scc: (0, int(a >= (b & 0xffff))),
  SOPK_BASE + SOPKOp.S_CMPK_LT_U32: lambda a, b, scc: (0, int(a < (b & 0xffff))),
  SOPK_BASE + SOPKOp.S_CMPK_LE_U32: lambda a, b, scc: (0, int(a <= (b & 0xffff))),
}
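
Every entry above follows one calling convention: a pure function from `(s0, s1, scc_in)` to `(result, scc_out)`, keyed by a format base plus the opcode. A self-contained sketch of that dispatch pattern (the opcode numbers here are hypothetical, not the real `SOP2Op` values):

```python
from typing import Callable

SOP2_BASE = 0x000
S_ADD_U32, S_SUB_U32 = 0, 1  # hypothetical opcode numbers, for illustration only

SALU: dict[int, Callable] = {
  SOP2_BASE + S_ADD_U32: lambda a, b, scc: ((a + b) & 0xffffffff, int((a + b) >= 0x100000000)),
  SOP2_BASE + S_SUB_U32: lambda a, b, scc: ((a - b) & 0xffffffff, int(b > a)),
}

# 0xffffffff + 1 wraps to 0, and the carry comes back through SCC
result, scc = SALU[SOP2_BASE + S_ADD_U32](0xffffffff, 1, 0)
assert (result, scc) == (0, 1)
```
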

# VALU: op -> fn(s0, s1, s2) -> result
VALU: dict[int, Callable] = {
  # VOP2
  VOP2_BASE + VOP2Op.V_ADD_F32: lambda a, b, c: i32(f32(a) + f32(b)),
  VOP2_BASE + VOP2Op.V_SUB_F32: lambda a, b, c: i32(f32(a) - f32(b)),
  VOP2_BASE + VOP2Op.V_SUBREV_F32: lambda a, b, c: i32(f32(b) - f32(a)),
  VOP2_BASE + VOP2Op.V_MUL_F32: lambda a, b, c: i32(f32(a) * f32(b)),
  VOP2_BASE + VOP2Op.V_MIN_F32: lambda a, b, c: i32(min(f32(a), f32(b))),
  VOP2_BASE + VOP2Op.V_MAX_F32: lambda a, b, c: i32(max(f32(a), f32(b))),
  VOP2_BASE + VOP2Op.V_ADD_NC_U32: lambda a, b, c: (a + b) & 0xffffffff,
  VOP2_BASE + VOP2Op.V_SUB_NC_U32: lambda a, b, c: (a - b) & 0xffffffff,
  VOP2_BASE + VOP2Op.V_SUBREV_NC_U32: lambda a, b, c: (b - a) & 0xffffffff,
  VOP2_BASE + VOP2Op.V_AND_B32: lambda a, b, c: a & b,
  VOP2_BASE + VOP2Op.V_OR_B32: lambda a, b, c: a | b,
  VOP2_BASE + VOP2Op.V_XOR_B32: lambda a, b, c: a ^ b,
  VOP2_BASE + VOP2Op.V_XNOR_B32: lambda a, b, c: (~(a ^ b)) & 0xffffffff,
  VOP2_BASE + VOP2Op.V_LSHLREV_B32: lambda a, b, c: (b << (a & 0x1f)) & 0xffffffff,
  VOP2_BASE + VOP2Op.V_LSHRREV_B32: lambda a, b, c: b >> (a & 0x1f),
  VOP2_BASE + VOP2Op.V_ASHRREV_I32: lambda a, b, c: (sext(b, 32) >> (a & 0x1f)) & 0xffffffff,
  VOP2_BASE + VOP2Op.V_MIN_I32: lambda a, b, c: a if sext(a, 32) < sext(b, 32) else b,
  VOP2_BASE + VOP2Op.V_MAX_I32: lambda a, b, c: a if sext(a, 32) > sext(b, 32) else b,
  VOP2_BASE + VOP2Op.V_MIN_U32: lambda a, b, c: min(a, b),
  VOP2_BASE + VOP2Op.V_MAX_U32: lambda a, b, c: max(a, b),
  VOP2_BASE + VOP2Op.V_MUL_I32_I24: lambda a, b, c: (sext(a & 0xffffff, 24) * sext(b & 0xffffff, 24)) & 0xffffffff,
  VOP2_BASE + VOP2Op.V_MUL_HI_I32_I24: lambda a, b, c: ((sext(a & 0xffffff, 24) * sext(b & 0xffffff, 24)) >> 32) & 0xffffffff,
  VOP2_BASE + VOP2Op.V_MUL_U32_U24: lambda a, b, c: ((a & 0xffffff) * (b & 0xffffff)) & 0xffffffff,
  VOP2_BASE + VOP2Op.V_MUL_HI_U32_U24: lambda a, b, c: (((a & 0xffffff) * (b & 0xffffff)) >> 32) & 0xffffffff,
  VOP2_BASE + VOP2Op.V_CVT_PK_RTZ_F16_F32: lambda a, b, c: i16(f32(a)) | (i16(f32(b)) << 16),
  VOP2_BASE + VOP2Op.V_LDEXP_F16: lambda a, b, c: i16(math.ldexp(f16(a), sext(b, 32))),
  VOP2_BASE + VOP2Op.V_ADD_F16: lambda a, b, c: i16(f16(a) + f16(b)),
  VOP2_BASE + VOP2Op.V_SUB_F16: lambda a, b, c: i16(f16(a) - f16(b)),
  VOP2_BASE + VOP2Op.V_MUL_F16: lambda a, b, c: i16(f16(a) * f16(b)),
  VOP2_BASE + VOP2Op.V_MIN_F16: lambda a, b, c: i16(min(f16(a), f16(b))),
  VOP2_BASE + VOP2Op.V_MAX_F16: lambda a, b, c: i16(max(f16(a), f16(b))),
  # VOP1
  VOP1_BASE + VOP1Op.V_MOV_B32: lambda a, b, c: a,
  VOP1_BASE + VOP1Op.V_NOT_B32: lambda a, b, c: (~a) & 0xffffffff,
  VOP1_BASE + VOP1Op.V_BFREV_B32: lambda a, b, c: int(f'{a & 0xffffffff:032b}'[::-1], 2),
  VOP1_BASE + VOP1Op.V_CLZ_I32_U32: lambda a, b, c: clz(a),
  VOP1_BASE + VOP1Op.V_CLS_I32: lambda a, b, c: cls(a),
  VOP1_BASE + VOP1Op.V_CVT_F32_I32: lambda a, b, c: i32(float(sext(a, 32))),
  VOP1_BASE + VOP1Op.V_CVT_F32_U32: lambda a, b, c: i32(float(a)),
  VOP1_BASE + VOP1Op.V_CVT_I32_F32: lambda a, b, c: _cvt_i32_f32(f32(a)),
  VOP1_BASE + VOP1Op.V_CVT_U32_F32: lambda a, b, c: _cvt_u32_f32(f32(a)),
  VOP1_BASE + VOP1Op.V_CVT_F16_F32: lambda a, b, c: i16(f32(a)),
  VOP1_BASE + VOP1Op.V_CVT_F32_F16: lambda a, b, c: i32(f16(a)),
  VOP1_BASE + VOP1Op.V_RCP_F32: lambda a, b, c: i32(1.0 / f32(a) if f32(a) != 0 else math.copysign(float('inf'), f32(a))),
  VOP1_BASE + VOP1Op.V_RCP_IFLAG_F32: lambda a, b, c: i32(1.0 / f32(a) if f32(a) != 0 else math.copysign(float('inf'), f32(a))),
  VOP1_BASE + VOP1Op.V_RSQ_F32: lambda a, b, c: i32(1.0 / math.sqrt(f32(a)) if f32(a) > 0 else (float('nan') if f32(a) < 0 else float('inf'))),
  VOP1_BASE + VOP1Op.V_SQRT_F32: lambda a, b, c: i32(math.sqrt(f32(a)) if f32(a) >= 0 else float('nan')),
  VOP1_BASE + VOP1Op.V_LOG_F32: lambda a, b, c: i32(math.log2(f32(a)) if f32(a) > 0 else (float('-inf') if f32(a) == 0 else float('nan'))),
  VOP1_BASE + VOP1Op.V_EXP_F32: lambda a, b, c: i32(float('inf') if f32(a) > 128 else (0.0 if f32(a) < -150 else math.pow(2.0, f32(a)))),
  VOP1_BASE + VOP1Op.V_SIN_F32: lambda a, b, c: i32(math.sin(f32(a) * 2 * math.pi)),
  VOP1_BASE + VOP1Op.V_COS_F32: lambda a, b, c: i32(math.cos(f32(a) * 2 * math.pi)),
  VOP1_BASE + VOP1Op.V_FLOOR_F32: lambda a, b, c: i32(math.floor(f32(a))),
  VOP1_BASE + VOP1Op.V_CEIL_F32: lambda a, b, c: i32(math.ceil(f32(a))),
  VOP1_BASE + VOP1Op.V_TRUNC_F32: lambda a, b, c: i32(math.trunc(f32(a))),
  VOP1_BASE + VOP1Op.V_RNDNE_F32: lambda a, b, c: i32(round(f32(a))),
  VOP1_BASE + VOP1Op.V_FRACT_F32: lambda a, b, c: i32((v := f32(a)) - math.floor(v)),
  VOP1_BASE + VOP1Op.V_CVT_F32_UBYTE0: lambda a, b, c: i32(float(a & 0xff)),
  VOP1_BASE + VOP1Op.V_CVT_F32_UBYTE1: lambda a, b, c: i32(float((a >> 8) & 0xff)),
  VOP1_BASE + VOP1Op.V_CVT_F32_UBYTE2: lambda a, b, c: i32(float((a >> 16) & 0xff)),
  VOP1_BASE + VOP1Op.V_CVT_F32_UBYTE3: lambda a, b, c: i32(float((a >> 24) & 0xff)),
  VOP1_BASE + VOP1Op.V_FREXP_MANT_F32: lambda a, b, c: i32(math.frexp(v)[0] if (v := f32(a)) != 0 else 0.0),
  VOP1_BASE + VOP1Op.V_FREXP_EXP_I32_F32: lambda a, b, c: (math.frexp(v)[1] if (v := f32(a)) != 0 else 0) & 0xffffffff,
  # VOP3
  VOP3Op.V_FMA_F32: lambda a, b, c: i32(f32(a) * f32(b) + f32(c)),
  VOP3Op.V_DIV_FMAS_F32: lambda a, b, c: i32(f32(a) * f32(b) + f32(c)),
  VOP3Op.V_ADD3_U32: lambda a, b, c: (a + b + c) & 0xffffffff,
  VOP3Op.V_LSHL_ADD_U32: lambda a, b, c: ((a << (b & 0x1f)) + c) & 0xffffffff,
  VOP3Op.V_ADD_LSHL_U32: lambda a, b, c: ((a + b) << (c & 0x1f)) & 0xffffffff,
  VOP3Op.V_XOR3_B32: lambda a, b, c: a ^ b ^ c,
  VOP3Op.V_OR3_B32: lambda a, b, c: a | b | c,
  VOP3Op.V_AND_OR_B32: lambda a, b, c: (a & b) | c,
  VOP3Op.V_LSHL_OR_B32: lambda a, b, c: ((a << (b & 0x1f)) | c) & 0xffffffff,
  VOP3Op.V_XAD_U32: lambda a, b, c: ((a ^ b) + c) & 0xffffffff,
  VOP3Op.V_MAD_U32_U24: lambda a, b, c: ((a & 0xffffff) * (b & 0xffffff) + c) & 0xffffffff,
  VOP3Op.V_MAD_I32_I24: lambda a, b, c: (sext(a & 0xffffff, 24) * sext(b & 0xffffff, 24) + sext(c, 32)) & 0xffffffff,
  VOP3Op.V_BFE_U32: lambda a, b, c: (a >> (b & 0x1f)) & ((1 << (c & 0x1f)) - 1) if c & 0x1f else 0,
  VOP3Op.V_BFE_I32: lambda a, b, c: sext((a >> (b & 0x1f)) & ((1 << w) - 1), w) & 0xffffffff if (w := c & 0x1f) else 0,
  VOP3Op.V_ALIGNBIT_B32: lambda a, b, c: (((a << 32) | b) >> (c & 0x1f)) & 0xffffffff,
  VOP3Op.V_MUL_LO_U32: lambda a, b, c: (a * b) & 0xffffffff,
  VOP3Op.V_MUL_HI_U32: lambda a, b, c: ((a * b) >> 32) & 0xffffffff,
  VOP3Op.V_MUL_HI_I32: lambda a, b, c: ((sext(a, 32) * sext(b, 32)) >> 32) & 0xffffffff,
  VOP3Op.V_LDEXP_F32: lambda a, b, c: i32(math.ldexp(f32(a), sext(b, 32))),
  VOP3Op.V_DIV_FIXUP_F32: lambda a, b, c: i32(math.copysign(float('inf'), f32(c)) if f32(b) == 0.0 else f32(c) / f32(b)),
  VOP3Op.V_PACK_B32_F16: lambda a, b, c: (a & 0xffff) | ((b & 0xffff) << 16),
  VOP3Op.V_CVT_PK_RTZ_F16_F32: lambda a, b, c: i16(f32(a)) | (i16(f32(b)) << 16),
  VOP3Op.V_LSHLREV_B16: lambda a, b, c: ((b & 0xffff) << (a & 0xf)) & 0xffff,
  VOP3Op.V_LSHRREV_B16: lambda a, b, c: (b & 0xffff) >> (a & 0xf),
  VOP3Op.V_ASHRREV_I16: lambda a, b, c: (sext(b & 0xffff, 16) >> (a & 0xf)) & 0xffff,
  VOP3Op.V_ADD_NC_U16: lambda a, b, c: ((a & 0xffff) + (b & 0xffff)) & 0xffff,
  VOP3Op.V_SUB_NC_U16: lambda a, b, c: ((a & 0xffff) - (b & 0xffff)) & 0xffff,
  VOP3Op.V_MUL_LO_U16: lambda a, b, c: ((a & 0xffff) * (b & 0xffff)) & 0xffff,
  VOP3Op.V_MIN_U16: lambda a, b, c: min(a & 0xffff, b & 0xffff),
  VOP3Op.V_MAX_U16: lambda a, b, c: max(a & 0xffff, b & 0xffff),
  VOP3Op.V_MIN_I16: lambda a, b, c: (a & 0xffff) if sext(a & 0xffff, 16) < sext(b & 0xffff, 16) else (b & 0xffff),
  VOP3Op.V_MAX_I16: lambda a, b, c: (a & 0xffff) if sext(a & 0xffff, 16) > sext(b & 0xffff, 16) else (b & 0xffff),
  VOP3Op.V_MAD_U16: lambda a, b, c: ((a & 0xffff) * (b & 0xffff) + (c & 0xffff)) & 0xffff,
  VOP3Op.V_MAD_I16: lambda a, b, c: (sext(a & 0xffff, 16) * sext(b & 0xffff, 16) + sext(c & 0xffff, 16)) & 0xffff,
  VOP3Op.V_FMA_F16: lambda a, b, c: i16(f16(a) * f16(b) + f16(c)),
  VOP3Op.V_MIN3_I32: lambda a, b, c: sorted([sext(a, 32), sext(b, 32), sext(c, 32)])[0] & 0xffffffff,
  VOP3Op.V_MAX3_I32: lambda a, b, c: sorted([sext(a, 32), sext(b, 32), sext(c, 32)])[2] & 0xffffffff,
  VOP3Op.V_MED3_I32: lambda a, b, c: sorted([sext(a, 32), sext(b, 32), sext(c, 32)])[1] & 0xffffffff,
  VOP3Op.V_MIN3_F16: lambda a, b, c: i16(min(f16(a), f16(b), f16(c))),
  VOP3Op.V_MAX3_F16: lambda a, b, c: i16(max(f16(a), f16(b), f16(c))),
  VOP3Op.V_MED3_F16: lambda a, b, c: i16(sorted([f16(a), f16(b), f16(c)])[1]),
  VOP3Op.V_MIN3_U16: lambda a, b, c: min(a & 0xffff, b & 0xffff, c & 0xffff),
  VOP3Op.V_MAX3_U16: lambda a, b, c: max(a & 0xffff, b & 0xffff, c & 0xffff),
  VOP3Op.V_MED3_U16: lambda a, b, c: sorted([a & 0xffff, b & 0xffff, c & 0xffff])[1],
  VOP3Op.V_MIN3_I16: lambda a, b, c: sorted([sext(a & 0xffff, 16), sext(b & 0xffff, 16), sext(c & 0xffff, 16)])[0] & 0xffff,
  VOP3Op.V_MAX3_I16: lambda a, b, c: sorted([sext(a & 0xffff, 16), sext(b & 0xffff, 16), sext(c & 0xffff, 16)])[2] & 0xffff,
  VOP3Op.V_MED3_I16: lambda a, b, c: sorted([sext(a & 0xffff, 16), sext(b & 0xffff, 16), sext(c & 0xffff, 16)])[1] & 0xffff,
}

def _cmp8(a, b): return [False, a < b, a == b, a <= b, a > b, a != b, a >= b, True]
def _cmp6(a, b): return [a < b, a == b, a <= b, a > b, a != b, a >= b]

def vopc(op: int, s0: int, s1: int, s0_hi: int = 0, s1_hi: int = 0) -> int:
  base = op & 0x7f
  if 16 <= base <= 31:  # F32
    f0, f1, cmp, nan = f32(s0), f32(s1), base - 16, math.isnan(f32(s0)) or math.isnan(f32(s1))
    return int([False, f0<f1, f0==f1, f0<=f1, f0>f1, f0!=f1, f0>=f1, not nan, nan, f0<f1 or nan, f0==f1 or nan, f0<=f1 or nan, f0>f1 or nan, f0!=f1 or nan, f0>=f1 or nan, True][cmp])
  if 49 <= base <= 54: return int(_cmp6(sext(s0 & 0xffff, 16), sext(s1 & 0xffff, 16))[base - 49])  # I16
  if 57 <= base <= 62: return int(_cmp6(s0 & 0xffff, s1 & 0xffff)[base - 57])  # U16
  if 64 <= base <= 79:  # I32/U32
    cmp = (base - 64) % 8
    return int(_cmp8(sext(s0, 32), sext(s1, 32))[cmp] if base < 72 else _cmp8(s0, s1)[cmp])
  if 80 <= base <= 95:  # I64/U64
    s0_64, s1_64 = s0 | (s0_hi << 32), s1 | (s1_hi << 32)
    return int(_cmp8(sext(s0_64, 64), sext(s1_64, 64))[(base - 80) % 8] if base < 88 else _cmp8(s0_64, s1_64)[(base - 80) % 8])
  if base == 126:  # CLASS_F32
    f, mask = f32(s0), s1
    if math.isnan(f): return int(bool(mask & 0x3))
    if math.isinf(f): return int(bool(mask & (0x4 if f < 0 else 0x200)))
    if f == 0.0: return int(bool(mask & (0x20 if (s0 >> 31) & 1 else 0x40)))
    exp, sign = (s0 >> 23) & 0xff, (s0 >> 31) & 1
    return int(bool(mask & ((0x10 if sign else 0x80) if exp == 0 else (0x8 if sign else 0x100))))
  raise NotImplementedError(f"VOPC op {op} (base {base})")
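
The integer compare ranges decode their predicate by indexing an 8-entry table in the fixed order F, LT, EQ, LE, GT, NE, GE, T, which is exactly what `_cmp8` builds. A small self-contained illustration:

```python
def _cmp8(a, b): return [False, a < b, a == b, a <= b, a > b, a != b, a >= b, True]

# for bases 64..71 (the signed 32-bit compares) the index is (base - 64) % 8
assert _cmp8(5, 3)[4]                         # index 4 = GT: 5 > 3
assert _cmp8(3, 3)[2]                         # index 2 = EQ
assert not _cmp8(0, 0)[0] and _cmp8(0, 0)[7]  # index 0 is always False, index 7 always True
```
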

@@ -2638,16 +2638,16 @@ v_add_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_ADD_NC_U32)
 v_sub_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUB_NC_U32)
 v_subrev_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_NC_U32)
 v_fmac_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAC_F32)
-v_fmamk_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAMK_F32)
-v_fmaak_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAAK_F32)
+def v_fmamk_f32_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_FMAMK_F32, vdst, src0, vsrc1, literal=K)
+def v_fmaak_f32_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_FMAAK_F32, vdst, src0, vsrc1, literal=K)
 v_cvt_pk_rtz_f16_f32_e32 = functools.partial(VOP2, VOP2Op.V_CVT_PK_RTZ_F16_F32)
 v_add_f16_e32 = functools.partial(VOP2, VOP2Op.V_ADD_F16)
 v_sub_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUB_F16)
 v_subrev_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_F16)
 v_mul_f16_e32 = functools.partial(VOP2, VOP2Op.V_MUL_F16)
 v_fmac_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAC_F16)
-v_fmamk_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAMK_F16)
-v_fmaak_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAAK_F16)
+def v_fmamk_f16_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_FMAMK_F16, vdst, src0, vsrc1, literal=K)
+def v_fmaak_f16_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_FMAAK_F16, vdst, src0, vsrc1, literal=K)
 v_max_f16_e32 = functools.partial(VOP2, VOP2Op.V_MAX_F16)
 v_min_f16_e32 = functools.partial(VOP2, VOP2Op.V_MIN_F16)
 v_ldexp_f16_e32 = functools.partial(VOP2, VOP2Op.V_LDEXP_F16)

extra/assembly/rdna3/autogen/gen_pcode.py (16706 lines, new file)
File diff suppressed because it is too large

@@ -178,7 +178,15 @@ def generate(output_path: pathlib.Path|str|None = None) -> dict:
       suffix = "_e64"
     else:
       suffix = ""
-    lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
+    # FMAMK/FMAAK have a literal constant K that must be passed via literal= kwarg
+    # FMAMK: D = S0.f * K + S1.f (K is 3rd operand in assembly syntax)
+    # FMAAK: D = S0.f * S1.f + K (K is 4th operand in assembly syntax)
+    if name in ('V_FMAMK_F32', 'V_FMAMK_F16'):
+      lines.append(f"def {name.lower()}{suffix}(vdst, src0, K, vsrc1): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
+    elif name in ('V_FMAAK_F32', 'V_FMAAK_F16'):
+      lines.append(f"def {name.lower()}{suffix}(vdst, src0, vsrc1, K): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
+    else:
+      lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
   # export SrcEnum values, but skip DPP8/DPP16 which conflict with class names
   skip_exports = {'DPP8', 'DPP16'}
   lines += [""] + [f"{name} = SrcEnum.{name}" for _, name in sorted(src_enum.items()) if name not in skip_exports] + ["OFF = NULL\n"]
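
The special-casing exists because the inline constant K sits at a different operand position in the two fused ops, so a plain `functools.partial` cannot express either. In plain Python arithmetic (hypothetical values):

```python
S0, S1, K = 2.0, 3.0, 10.0
fmamk = S0 * K + S1  # V_FMAMK: D = S0.f * K + S1.f, K is the 3rd assembly operand -> 23.0
fmaak = S0 * S1 + K  # V_FMAAK: D = S0.f * S1.f + K, K is the 4th assembly operand -> 16.0
assert (fmamk, fmaak) == (23.0, 16.0)
```
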

@@ -169,11 +169,13 @@ class Inst:
       cur_neg = self._values.get('neg', 0)
       self._values['neg'] = (cur_neg.val if isinstance(cur_neg, RawImm) else cur_neg) | neg_bit
     # Track literal value if needed (encoded as 255)
+    # For 64-bit ops, store literal in high 32 bits (to match from_bytes decoding and to_bytes encoding)
     if encoded == 255 and self._literal is None and isinstance(val, int) and not isinstance(val, IntEnum):
-      self._literal = val
+      self._literal = (val << 32) if self._is_64bit_op() else val
     elif encoded == 255 and self._literal is None and isinstance(val, float):
       import struct
-      self._literal = struct.unpack('<I', struct.pack('<f', val))[0]
+      lit32 = struct.unpack('<I', struct.pack('<f', val))[0]
+      self._literal = (lit32 << 32) if self._is_64bit_op() else lit32
     # Encode raw register fields for consistent repr
     elif name in RAW_FIELDS:
       if isinstance(val, Reg): self._values[name] = _encode_reg(val)

@@ -206,13 +208,35 @@ class Inst:
     if n in self._values and not isinstance(v := self._values[n], RawImm) and isinstance(v, int) and not isinstance(v, IntEnum) and not (0 <= v <= 64 or -16 <= v <= -1): return v
     return None

+  def _is_64bit_op(self) -> bool:
+    """Check if this instruction uses 64-bit operands (and thus 64-bit literals).
+    Exception: V_LDEXP_F64 has 32-bit integer src1, so its literal is 32-bit."""
+    op = self._values.get('op')
+    if op is None: return False
+    # op may be an enum (from __init__) or an int (from from_int)
+    op_name = op.name if hasattr(op, 'name') else None
+    if op_name is None and self.__class__.__name__ == 'VOP3':
+      from extra.assembly.rdna3.autogen import VOP3Op
+      try: op_name = VOP3Op(op).name
+      except ValueError: pass
+    if op_name is None: return False
+    # V_LDEXP_F64 has 32-bit integer exponent in src1, so literal is 32-bit
+    if op_name == 'V_LDEXP_F64': return False
+    return op_name.endswith(('_F64', '_B64', '_I64', '_U64'))
+
   def to_bytes(self) -> bytes:
     result = self.to_int().to_bytes(self._size(), 'little')
-    return result + (lit & 0xffffffff).to_bytes(4, 'little') if (lit := self._get_literal() or getattr(self, '_literal', None)) else result
+    lit = self._get_literal() or getattr(self, '_literal', None)
+    if lit is None: return result
+    # For 64-bit ops, literal is stored in high 32 bits internally, but encoded as 4 bytes
+    lit32 = (lit >> 32) if self._is_64bit_op() else lit
+    return result + (lit32 & 0xffffffff).to_bytes(4, 'little')

   @classmethod
   def _size(cls) -> int: return 4 if issubclass(cls, Inst32) else 8
-  def size(self) -> int: return self._size() + (4 if self._literal is not None else 0)
+  def size(self) -> int:
+    # Literal is always 4 bytes in the binary (for 64-bit ops, it's in high 32 bits)
+    return self._size() + (4 if self._literal is not None else 0)

   @classmethod
   def from_int(cls, word: int):

@@ -229,7 +253,12 @@ class Inst:
     has_literal = has_literal or (cls.__name__ == 'SOP2' and op_val in (69, 70))
     for n in SRC_FIELDS:
       if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255: has_literal = True
-    if has_literal and len(data) >= cls._size() + 4: inst._literal = int.from_bytes(data[cls._size():cls._size()+4], 'little')
+    if has_literal:
+      # For 64-bit ops, the literal is 32 bits placed in the HIGH 32 bits of the 64-bit value
+      # (low 32 bits are zero). This is how AMD hardware interprets 32-bit literals for 64-bit ops.
+      if len(data) >= cls._size() + 4:
+        lit32 = int.from_bytes(data[cls._size():cls._size()+4], 'little')
+        inst._literal = (lit32 << 32) if inst._is_64bit_op() else lit32
     return inst
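
Why the high 32 bits: for a 64-bit operand the hardware zero-fills the low half and places the 32 encoded literal bits on top, so e.g. the literal bytes 0x3ff00000 denote the double 1.0. A self-contained check of that bit layout (standard IEEE-754, independent of the Inst class):

```python
import struct

lit32 = 0x3ff00000     # the 4 literal bytes as encoded in the instruction stream
operand = lit32 << 32  # stored internally with the low 32 bits zero
assert struct.unpack('<d', struct.pack('<Q', operand))[0] == 1.0
```
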

  def __repr__(self):

extra/assembly/rdna3/pcode.py (910 lines, new file)

@@ -0,0 +1,910 @@
# DSL for RDNA3 pseudocode - makes pseudocode expressions work directly as Python
import struct, math, re

# ═══════════════════════════════════════════════════════════════════════════════
# HELPER FUNCTIONS (previously in helpers.py)
# ═══════════════════════════════════════════════════════════════════════════════

def _f32(i): return struct.unpack("<f", struct.pack("<I", i & 0xffffffff))[0]
def _i32(f):
  if isinstance(f, int): f = float(f)
  if math.isnan(f): return 0xffc00000 if math.copysign(1.0, f) < 0 else 0x7fc00000
  if math.isinf(f): return 0x7f800000 if f > 0 else 0xff800000
  try: return struct.unpack("<I", struct.pack("<f", f))[0]
  except (OverflowError, struct.error): return 0x7f800000 if f > 0 else 0xff800000
def _div(a, b):
  try: return a / b
  except ZeroDivisionError:
    if a == 0.0 or math.isnan(a): return float("nan")
    return math.copysign(float("inf"), a * b) if b == 0.0 else float("inf") if a > 0 else float("-inf")
def _sext(v, b): return v - (1 << b) if v & (1 << (b - 1)) else v
def _f16(i): return struct.unpack("<e", struct.pack("<H", i & 0xffff))[0]
def _i16(f):
  if math.isnan(f): return 0x7e00
  if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00
  try: return struct.unpack("<H", struct.pack("<e", f))[0]
  except (OverflowError, struct.error): return 0x7c00 if f > 0 else 0xfc00
def _to_f16_bits(v): return v if isinstance(v, int) else _i16(v)
def _f64(i): return struct.unpack("<d", struct.pack("<Q", i & 0xffffffffffffffff))[0]
def _i64(f):
  if math.isnan(f): return 0x7ff8000000000000
  if math.isinf(f): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000
  try: return struct.unpack("<Q", struct.pack("<d", f))[0]
  except (OverflowError, struct.error): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000
def _isnan(x):
  try: return math.isnan(float(x))
  except (TypeError, ValueError): return False
def _isquietnan(x):
  """Check if x is a quiet NaN. For f32: exponent=255, bit22=1, mantissa!=0"""
  try:
    if not math.isnan(float(x)): return False
    # Get raw bits from TypedView or similar object with _reg attribute
    if hasattr(x, '_reg') and hasattr(x, '_bits'):
      bits = x._reg._val & ((1 << x._bits) - 1)
      if x._bits == 32:
        return ((bits >> 23) & 0xff) == 255 and ((bits >> 22) & 1) == 1 and (bits & 0x7fffff) != 0
      if x._bits == 64:
        return ((bits >> 52) & 0x7ff) == 0x7ff and ((bits >> 51) & 1) == 1 and (bits & 0xfffffffffffff) != 0
    return True  # Default to quiet NaN if we can't determine bit pattern
  except (TypeError, ValueError): return False
def _issignalnan(x):
  """Check if x is a signaling NaN. For f32: exponent=255, bit22=0, mantissa!=0"""
  try:
    if not math.isnan(float(x)): return False
    # Get raw bits from TypedView or similar object with _reg attribute
    if hasattr(x, '_reg') and hasattr(x, '_bits'):
      bits = x._reg._val & ((1 << x._bits) - 1)
      if x._bits == 32:
        return ((bits >> 23) & 0xff) == 255 and ((bits >> 22) & 1) == 0 and (bits & 0x7fffff) != 0
      if x._bits == 64:
        return ((bits >> 52) & 0x7ff) == 0x7ff and ((bits >> 51) & 1) == 0 and (bits & 0xfffffffffffff) != 0
    return False  # Default to not signaling if we can't determine bit pattern
  except (TypeError, ValueError): return False
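
The two predicates differ only in the f32 quiet bit (bit 22, or bit 51 for f64), which is why they reach into `TypedView._reg` for raw bits: once a value becomes a Python float, `math.isnan` can no longer tell the two kinds apart. A short illustration of the patterns being tested:

```python
import math, struct

def from_bits(u): return struct.unpack('<f', struct.pack('<I', u))[0]

quiet  = from_bits(0x7fc00000)  # exponent 255, bit 22 set
signal = from_bits(0x7fa00000)  # exponent 255, bit 22 clear, mantissa nonzero
assert math.isnan(quiet) and math.isnan(signal)  # indistinguishable as Python floats
```
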
def _gt_neg_zero(a, b): return (a > b) or (a == 0 and b == 0 and not math.copysign(1, a) < 0 and math.copysign(1, b) < 0)
def _lt_neg_zero(a, b): return (a < b) or (a == 0 and b == 0 and math.copysign(1, a) < 0 and not math.copysign(1, b) < 0)
def _fma(a, b, c): return a * b + c
def _signext(v): return v
def trunc(x):
  x = float(x)
  return x if math.isnan(x) or math.isinf(x) else float(math.trunc(x))
def floor(x):
  x = float(x)
  return x if math.isnan(x) or math.isinf(x) else float(math.floor(x))
def ceil(x):
  x = float(x)
  return x if math.isnan(x) or math.isinf(x) else float(math.ceil(x))
def sqrt(x): return math.sqrt(x) if x >= 0 else float("nan")
def log2(x): return math.log2(x) if x > 0 else (float("-inf") if x == 0 else float("nan"))
i32_to_f32 = u32_to_f32 = i32_to_f64 = u32_to_f64 = f32_to_f64 = f64_to_f32 = float
def f32_to_i32(f):
  f = float(f)
  if math.isnan(f): return 0
  if f >= 2147483647: return 2147483647
  if f <= -2147483648: return -2147483648
  return int(f)
def f32_to_u32(f):
  f = float(f)
  if math.isnan(f): return 0
  if f >= 4294967295: return 4294967295
  if f <= 0: return 0
  return int(f)
f64_to_i32 = f32_to_i32
f64_to_u32 = f32_to_u32
def f32_to_f16(f):
  f = float(f)
  if math.isnan(f): return 0x7e00  # f16 NaN
  if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00  # f16 ±infinity
  try: return struct.unpack("<H", struct.pack("<e", f))[0]
  except OverflowError: return 0x7c00 if f > 0 else 0xfc00  # overflow -> ±infinity
def _f16_to_f32_bits(bits): return struct.unpack("<e", struct.pack("<H", int(bits) & 0xffff))[0]
def f16_to_f32(v): return v if isinstance(v, float) else _f16_to_f32_bits(v)
def i16_to_f16(v): return f32_to_f16(float(_sext(int(v) & 0xffff, 16)))
def u16_to_f16(v): return f32_to_f16(float(int(v) & 0xffff))
def f16_to_i16(bits): f = _f16_to_f32_bits(bits); return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
def f16_to_u16(bits): f = _f16_to_f32_bits(bits); return max(0, min(65535, int(f))) if not math.isnan(f) else 0
def _sign(f): return 1 if math.copysign(1.0, f) < 0 else 0
def _mantissa_f32(f): return struct.unpack("<I", struct.pack("<f", f))[0] & 0x7fffff if not (math.isinf(f) or math.isnan(f)) else 0
def _ldexp(m, e): return math.ldexp(m, e)
def isEven(x): return int(x) % 2 == 0
def fract(x): return x - math.floor(x)
PI = math.pi
def sin(x):
  # V_SIN_F32: pseudocode does sin(input * 2π), but hardware does frac on the input first
  # So sin(1.0 * 2π) should be sin(frac(1.0) * 2π) = sin(0) = 0
  if math.isinf(x) or math.isnan(x): return float("nan")
  # The input x is already multiplied by 2π in the pseudocode, so we need to
  # extract the fractional cycle: frac(x / 2π) * 2π
  cycles = x / (2 * math.pi)
  frac_cycles = cycles - math.floor(cycles)
  return math.sin(frac_cycles * 2 * math.pi)
def cos(x):
  # V_COS_F32: same as sin, hardware does frac on input cycles
  if math.isinf(x) or math.isnan(x): return float("nan")
  cycles = x / (2 * math.pi)
  frac_cycles = cycles - math.floor(cycles)
  return math.cos(frac_cycles * 2 * math.pi)
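
The effect of the frac correction, assuming the `sin` defined above is in scope: one full period reduces to exactly sin(0), which plain `math.sin` cannot deliver because 2π is not exactly representable:

```python
import math

assert sin(2 * math.pi) == 0.0       # frac(2π / 2π) = frac(1.0) = 0.0 exactly
assert math.sin(2 * math.pi) != 0.0  # plain math.sin leaves rounding residue (~ -2.4e-16)
```
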
def pow(a, b):
  try: return a ** b
  except OverflowError: return float("inf") if b > 0 else 0.0
def _brev32(v): return int(bin(v & 0xffffffff)[2:].zfill(32)[::-1], 2)
def _brev64(v): return int(bin(v & 0xffffffffffffffff)[2:].zfill(64)[::-1], 2)
def _ctz32(v):
  v = int(v) & 0xffffffff
  if v == 0: return 32
  n = 0
  while (v & 1) == 0: v >>= 1; n += 1
  return n
def _ctz64(v):
  v = int(v) & 0xffffffffffffffff
  if v == 0: return 64
  n = 0
  while (v & 1) == 0: v >>= 1; n += 1
  return n
def _exponent(f):
  if math.isinf(f) or math.isnan(f): return 255
  if f == 0.0: return 0
  try: bits = struct.unpack("<I", struct.pack("<f", float(f)))[0]; return (bits >> 23) & 0xff
  except: return 0
def _is_denorm_f32(f):
  if not isinstance(f, float): f = _f32(int(f) & 0xffffffff)
  if math.isinf(f) or math.isnan(f) or f == 0.0: return False
  bits = struct.unpack("<I", struct.pack("<f", float(f)))[0]
  return (bits >> 23) & 0xff == 0
def _is_denorm_f64(f):
  if not isinstance(f, float): f = _f64(int(f) & 0xffffffffffffffff)
  if math.isinf(f) or math.isnan(f) or f == 0.0: return False
  bits = struct.unpack("<Q", struct.pack("<d", float(f)))[0]
  return (bits >> 52) & 0x7ff == 0
def v_min_f32(a, b):
  if math.isnan(b): return a
  if math.isnan(a): return b
  return a if _lt_neg_zero(a, b) else b
def v_max_f32(a, b):
  if math.isnan(b): return a
  if math.isnan(a): return b
  return a if _gt_neg_zero(a, b) else b
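
Unlike Python's built-in `min`/`max`, which simply keep whichever operand wins the (always false) NaN comparison, these follow the GPU rule: a NaN operand is ignored in favor of the other, and -0.0 orders below +0.0. Assuming the definitions above:

```python
import math

nan = float('nan')
assert v_min_f32(nan, 3.0) == 3.0 and v_max_f32(3.0, nan) == 3.0  # NaN operand is dropped
assert math.isnan(min(nan, 3.0))                                  # Python's min keeps the NaN
assert math.copysign(1.0, v_min_f32(0.0, -0.0)) < 0               # -0.0 is the smaller zero
```
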
def v_min_i32(a, b): return min(a, b)
def v_max_i32(a, b): return max(a, b)
def v_min_u32(a, b): return min(a & 0xffffffff, b & 0xffffffff)
def v_max_u32(a, b): return max(a & 0xffffffff, b & 0xffffffff)
v_min_f16 = v_min_f32
v_max_f16 = v_max_f32
v_min_i16 = v_min_i32
v_max_i16 = v_max_i32
def v_min_u16(a, b): return min(a & 0xffff, b & 0xffff)
def v_max_u16(a, b): return max(a & 0xffff, b & 0xffff)
def v_min3_f32(a, b, c): return v_min_f32(v_min_f32(a, b), c)
def v_max3_f32(a, b, c): return v_max_f32(v_max_f32(a, b), c)
def v_min3_i32(a, b, c): return min(a, b, c)
def v_max3_i32(a, b, c): return max(a, b, c)
def v_min3_u32(a, b, c): return min(a & 0xffffffff, b & 0xffffffff, c & 0xffffffff)
def v_max3_u32(a, b, c): return max(a & 0xffffffff, b & 0xffffffff, c & 0xffffffff)
v_min3_f16 = v_min3_f32
v_max3_f16 = v_max3_f32
v_min3_i16 = v_min3_i32
v_max3_i16 = v_max3_i32
def v_min3_u16(a, b, c): return min(a & 0xffff, b & 0xffff, c & 0xffff)
def v_max3_u16(a, b, c): return max(a & 0xffff, b & 0xffff, c & 0xffff)
def ABSDIFF(a, b): return abs(a - b)
def f16_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
def f16_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
def f32_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
def f32_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
def v_cvt_i16_f32(f): return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
def v_cvt_u16_f32(f): return max(0, min(65535, int(f))) if not math.isnan(f) else 0
def u32_to_u16(u): return int(u) & 0xffff
def i32_to_i16(i): return ((int(i) + 32768) & 0xffff) - 32768
def SAT8(v): return max(0, min(255, int(v)))
def f32_to_u8(f): return max(0, min(255, int(f))) if not math.isnan(f) else 0
def mantissa(f):
  if f == 0.0 or math.isinf(f) or math.isnan(f): return f
  m, _ = math.frexp(f)
  return math.copysign(m * 2.0, f)
def signext_from_bit(val, bit):
  bit = int(bit)
  if bit == 0: return 0
  mask = (1 << bit) - 1
  val = int(val) & mask
  if val & (1 << (bit - 1)): return val - (1 << bit)
  return val

# ═══════════════════════════════════════════════════════════════════════════════
# DSL EXPORTS
# ═══════════════════════════════════════════════════════════════════════════════

__all__ = [
  # Classes
  'Reg', 'SliceProxy', 'TypedView', 'ExecContext', 'compile_pseudocode',
  # Pack functions
  '_pack', '_pack32', 'pack', 'pack32',
  # Constants
  'WAVE32', 'WAVE64', 'MASK32', 'MASK64', 'WAVE_MODE', 'DENORM', 'OVERFLOW_F32', 'UNDERFLOW_F32',
  'OVERFLOW_F64', 'UNDERFLOW_F64', 'MAX_FLOAT_F32', 'ROUND_MODE', 'cvtToQuietNAN', 'DST', 'INF', 'PI',
  # Aliases for pseudocode
  's_ff1_i32_b32', 's_ff1_i32_b64', 'GT_NEG_ZERO', 'LT_NEG_ZERO',
  'isNAN', 'isQuietNAN', 'isSignalNAN', 'fma', 'ldexp', 'sign', 'exponent', 'F', 'signext',
  # Conversion functions
  '_f32', '_i32', '_f16', '_i16', '_f64', '_i64', '_sext', '_to_f16_bits', '_f16_to_f32_bits',
  'i32_to_f32', 'u32_to_f32', 'i32_to_f64', 'u32_to_f64', 'f32_to_f64', 'f64_to_f32',
  'f32_to_i32', 'f32_to_u32', 'f64_to_i32', 'f64_to_u32', 'f32_to_f16', 'f16_to_f32',
  'i16_to_f16', 'u16_to_f16', 'f16_to_i16', 'f16_to_u16', 'u32_to_u16', 'i32_to_i16',
  'f16_to_snorm', 'f16_to_unorm', 'f32_to_snorm', 'f32_to_unorm', 'v_cvt_i16_f32', 'v_cvt_u16_f32',
  'SAT8', 'f32_to_u8',
  # Math functions
  'trunc', 'floor', 'ceil', 'sqrt', 'log2', 'sin', 'cos', 'pow', 'fract', 'isEven', 'mantissa',
  # Min/max functions
  'v_min_f32', 'v_max_f32', 'v_min_i32', 'v_max_i32', 'v_min_u32', 'v_max_u32',
  'v_min_f16', 'v_max_f16', 'v_min_i16', 'v_max_i16', 'v_min_u16', 'v_max_u16',
  'v_min3_f32', 'v_max3_f32', 'v_min3_i32', 'v_max3_i32', 'v_min3_u32', 'v_max3_u32',
  'v_min3_f16', 'v_max3_f16', 'v_min3_i16', 'v_max3_i16', 'v_min3_u16', 'v_max3_u16',
  'ABSDIFF',
  # Bit manipulation
  '_brev32', '_brev64', '_ctz32', '_ctz64', '_exponent', '_is_denorm_f32', '_is_denorm_f64',
  '_sign', '_mantissa_f32', '_div', '_isnan', '_isquietnan', '_issignalnan', '_gt_neg_zero', '_lt_neg_zero', '_fma', '_ldexp', '_signext',
  'signext_from_bit',
]

# Aliases used in pseudocode
s_ff1_i32_b32, s_ff1_i32_b64 = _ctz32, _ctz64
GT_NEG_ZERO, LT_NEG_ZERO = _gt_neg_zero, _lt_neg_zero
isNAN = _isnan
isQuietNAN = _isquietnan
isSignalNAN = _issignalnan
fma, ldexp, sign, exponent = _fma, _ldexp, _sign, _exponent
def F(x):
  """32'F(x) or 64'F(x) - interpret x as float. If x is int, treat as bit pattern."""
  if isinstance(x, int): return _f32(x)  # int -> interpret as f32 bits
  if isinstance(x, TypedView): return x  # preserve TypedView for bit-pattern checks
  return float(x)  # already a float or float-like
signext = lambda x: x
pack = lambda hi, lo: ((int(hi) & 0xffff) << 16) | (int(lo) & 0xffff)
pack32 = lambda hi, lo: ((int(hi) & 0xffffffff) << 32) | (int(lo) & 0xffffffff)
_pack, _pack32 = pack, pack32  # Aliases for internal use
WAVE32, WAVE64 = True, False

# Float overflow/underflow constants
OVERFLOW_F32 = float('inf')
UNDERFLOW_F32 = 0.0
OVERFLOW_F64 = float('inf')
UNDERFLOW_F64 = 0.0
MAX_FLOAT_F32 = 3.4028235e+38  # Largest finite float32

# INF object that supports .f16/.f32/.f64 access and comparison with floats
class _Inf:
  f16 = f32 = f64 = float('inf')
  def __neg__(self): return _NegInf()
  def __pos__(self): return self
  def __eq__(self, other): return float(other) == float('inf') if not isinstance(other, _NegInf) else False
  def __req__(self, other): return self.__eq__(other)
class _NegInf:
  f16 = f32 = f64 = float('-inf')
  def __neg__(self): return _Inf()
  def __pos__(self): return self
  def __eq__(self, other): return float(other) == float('-inf') if not isinstance(other, _Inf) else False
  def __req__(self, other): return self.__eq__(other)
INF = _Inf()

# Rounding mode placeholder
class _RoundMode:
  NEAREST_EVEN = 0
ROUND_MODE = _RoundMode()

# Helper functions for pseudocode
def cvtToQuietNAN(x): return float('nan')
DST = None  # Placeholder, will be set in context

MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff

class _WaveMode:
  IEEE = False
WAVE_MODE = _WaveMode()

class _DenormChecker:
  """Comparator for denormalized floats. x == DENORM.f32 checks if x is denormalized."""
  def __init__(self, bits): self._bits = bits
  def _check(self, other):
    return _is_denorm_f64(float(other)) if self._bits == 64 else _is_denorm_f32(float(other))
  def __eq__(self, other): return self._check(other)
  def __req__(self, other): return self._check(other)
  def __ne__(self, other): return not self._check(other)

class _Denorm:
  f32 = _DenormChecker(32)
  f64 = _DenormChecker(64)
DENORM = _Denorm()

def _brev(v, bits):
  """Bit-reverse a value."""
  result = 0
  for i in range(bits): result |= ((v >> i) & 1) << (bits - 1 - i)
  return result

class SliceProxy:
  """Proxy for D0[31:16] that supports .f16/.u16 etc getters and setters."""
  __slots__ = ('_reg', '_high', '_low', '_reversed')
  def __init__(self, reg, high, low):
    self._reg = reg
    # Handle reversed slices like [0:31] which means bit-reverse
    if high < low: self._high, self._low, self._reversed = low, high, True
    else: self._high, self._low, self._reversed = high, low, False
  def _nbits(self): return self._high - self._low + 1
  def _mask(self): return (1 << self._nbits()) - 1
  def _get(self):
    v = (self._reg._val >> self._low) & self._mask()
    return _brev(v, self._nbits()) if self._reversed else v
  def _set(self, v):
    v = int(v)
    if self._reversed: v = _brev(v, self._nbits())
    self._reg._val = (self._reg._val & ~(self._mask() << self._low)) | ((v & self._mask()) << self._low)

  u8 = property(lambda s: s._get() & 0xff)
  u16 = property(lambda s: s._get() & 0xffff, lambda s, v: s._set(v))
  u32 = property(lambda s: s._get() & MASK32, lambda s, v: s._set(v))
  i16 = property(lambda s: _sext(s._get() & 0xffff, 16), lambda s, v: s._set(v))
  i32 = property(lambda s: _sext(s._get() & MASK32, 32), lambda s, v: s._set(v))
  f16 = property(lambda s: _f16(s._get()), lambda s, v: s._set(v if isinstance(v, int) else _i16(float(v))))
  f32 = property(lambda s: _f32(s._get()), lambda s, v: s._set(_i32(float(v))))
  b16, b32 = u16, u32

  def __int__(self): return self._get()
  def __index__(self): return self._get()

class TypedView:
  """View for S0.u32 that supports [4:0] slicing and [bit] access."""
  __slots__ = ('_reg', '_bits', '_signed', '_float')
  def __init__(self, reg, bits, signed=False, is_float=False):
    self._reg, self._bits, self._signed, self._float = reg, bits, signed, is_float

  @property
  def _val(self):
    mask = MASK64 if self._bits == 64 else MASK32 if self._bits == 32 else (1 << self._bits) - 1
    return self._reg._val & mask

  def __getitem__(self, key):
    if isinstance(key, slice):
      high, low = int(key.start), int(key.stop)
      return SliceProxy(self._reg, high, low)
    return (self._val >> int(key)) & 1

  def __setitem__(self, key, value):
    if isinstance(key, slice):
      high, low = int(key.start), int(key.stop)
      if high < low: high, low, value = low, high, _brev(int(value), low - high + 1)
      mask = (1 << (high - low + 1)) - 1
      self._reg._val = (self._reg._val & ~(mask << low)) | ((int(value) & mask) << low)
    elif value: self._reg._val |= (1 << int(key))
    else: self._reg._val &= ~(1 << int(key))

  def __int__(self): return _sext(self._val, self._bits) if self._signed else self._val
  def __index__(self): return int(self)
  def __trunc__(self): return int(float(self)) if self._float else int(self)
  def __float__(self):
    if self._float:
      return _f16(self._val) if self._bits == 16 else _f32(self._val) if self._bits == 32 else _f64(self._val)
    return float(int(self))

  # Arithmetic - floats use float(), ints use int()
  def __add__(s, o): return float(s) + float(o) if s._float else int(s) + int(o)
  def __radd__(s, o): return float(o) + float(s) if s._float else int(o) + int(s)
  def __sub__(s, o): return float(s) - float(o) if s._float else int(s) - int(o)
  def __rsub__(s, o): return float(o) - float(s) if s._float else int(o) - int(s)
  def __mul__(s, o): return float(s) * float(o) if s._float else int(s) * int(o)
  def __rmul__(s, o): return float(o) * float(s) if s._float else int(o) * int(s)
  def __truediv__(s, o): return _div(float(s), float(o)) if s._float else _div(int(s), int(o))
  def __rtruediv__(s, o): return _div(float(o), float(s)) if s._float else _div(int(o), int(s))
  def __pow__(s, o): return float(s) ** float(o) if s._float else int(s) ** int(o)
  def __rpow__(s, o): return float(o) ** float(s) if s._float else int(o) ** int(s)
  def __neg__(s): return -float(s) if s._float else -int(s)
  def __abs__(s): return abs(float(s)) if s._float else abs(int(s))

  # Bitwise - GPU shifts mask the shift amount to valid range
  def __and__(s, o): return int(s) & int(o)
  def __or__(s, o): return int(s) | int(o)
  def __xor__(s, o): return int(s) ^ int(o)
  def __invert__(s): return ~int(s)
  def __lshift__(s, o): n = int(o); return int(s) << n if 0 <= n < 64 else 0
  def __rshift__(s, o): n = int(o); return int(s) >> n if 0 <= n < 64 else 0
  def __rand__(s, o): return int(o) & int(s)
  def __ror__(s, o): return int(o) | int(s)
  def __rxor__(s, o): return int(o) ^ int(s)
  def __rlshift__(s, o): n = int(s); return int(o) << n if 0 <= n < 64 else 0
  def __rrshift__(s, o): n = int(s); return int(o) >> n if 0 <= n < 64 else 0

  # Comparison - handle _DenormChecker specially
  def __eq__(s, o):
    if isinstance(o, _DenormChecker): return o._check(s)
    return float(s) == float(o) if s._float else int(s) == int(o)
  def __ne__(s, o):
    if isinstance(o, _DenormChecker): return not o._check(s)
    return float(s) != float(o) if s._float else int(s) != int(o)
  def __lt__(s, o): return float(s) < float(o) if s._float else int(s) < int(o)
  def __le__(s, o): return float(s) <= float(o) if s._float else int(s) <= int(o)
  def __gt__(s, o): return float(s) > float(o) if s._float else int(s) > int(o)
  def __ge__(s, o): return float(s) >= float(o) if s._float else int(s) >= int(o)

  def __bool__(s): return bool(int(s))

class Reg:
  """GPU register: D0.f32 = S0.f32 + S1.f32 just works."""
  __slots__ = ('_val',)
  def __init__(self, val=0): self._val = int(val) & MASK64

  # Typed views
  u64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64))
  i64 = property(lambda s: TypedView(s, 64, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK64))
  b64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64))
  f64 = property(lambda s: TypedView(s, 64, is_float=True), lambda s, v: setattr(s, '_val', v if isinstance(v, int) else _i64(float(v))))
  u32 = property(lambda s: TypedView(s, 32), lambda s, v: setattr(s, '_val', int(v) & MASK32))
  i32 = property(lambda s: TypedView(s, 32, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK32))
  b32 = property(lambda s: TypedView(s, 32), lambda s, v: setattr(s, '_val', int(v) & MASK32))
  f32 = property(lambda s: TypedView(s, 32, is_float=True), lambda s, v: setattr(s, '_val', _i32(float(v))))
  u24 = property(lambda s: TypedView(s, 24))
  i24 = property(lambda s: TypedView(s, 24, signed=True))
  u16 = property(lambda s: TypedView(s, 16), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
  i16 = property(lambda s: TypedView(s, 16, signed=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
  b16 = property(lambda s: TypedView(s, 16), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
  f16 = property(lambda s: TypedView(s, 16, is_float=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _i16(float(v))) & 0xffff)))
  u8 = property(lambda s: TypedView(s, 8))
  i8 = property(lambda s: TypedView(s, 8, signed=True))

  def __getitem__(s, key):
    if isinstance(key, slice): return SliceProxy(s, int(key.start), int(key.stop))
    return (s._val >> int(key)) & 1

  def __setitem__(s, key, value):
    if isinstance(key, slice):
      high, low = int(key.start), int(key.stop)
      mask = (1 << (high - low + 1)) - 1
      s._val = (s._val & ~(mask << low)) | ((int(value) & mask) << low)
    elif value: s._val |= (1 << int(key))
    else: s._val &= ~(1 << int(key))

  def __int__(s): return s._val
  def __index__(s): return s._val
  def __bool__(s): return bool(s._val)

  # Arithmetic (for tmp = tmp + 1 patterns). Float operands trigger f32 interpretation.
  def __add__(s, o): return (_f32(s._val) + float(o)) if isinstance(o, float) else s._val + int(o)
  def __radd__(s, o): return (float(o) + _f32(s._val)) if isinstance(o, float) else int(o) + s._val
  def __sub__(s, o): return (_f32(s._val) - float(o)) if isinstance(o, float) else s._val - int(o)
  def __rsub__(s, o): return (float(o) - _f32(s._val)) if isinstance(o, float) else int(o) - s._val
  def __mul__(s, o): return (_f32(s._val) * float(o)) if isinstance(o, float) else s._val * int(o)
  def __rmul__(s, o): return (float(o) * _f32(s._val)) if isinstance(o, float) else int(o) * s._val
  def __and__(s, o): return s._val & int(o)
  def __rand__(s, o): return int(o) & s._val
  def __or__(s, o): return s._val | int(o)
  def __ror__(s, o): return int(o) | s._val
  def __xor__(s, o): return s._val ^ int(o)
  def __rxor__(s, o): return int(o) ^ s._val
  def __lshift__(s, o): n = int(o); return s._val << n if 0 <= n < 64 else 0
  def __rshift__(s, o): n = int(o); return s._val >> n if 0 <= n < 64 else 0
  def __invert__(s): return ~s._val

  # Comparison (for tmp >= 0x100000000 patterns)
  def __lt__(s, o): return s._val < int(o)
  def __le__(s, o): return s._val <= int(o)
  def __gt__(s, o): return s._val > int(o)
  def __ge__(s, o): return s._val >= int(o)
  def __eq__(s, o): return s._val == int(o)
  def __ne__(s, o): return s._val != int(o)
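
Putting `Reg`, `TypedView` and `SliceProxy` together, this is what a line of translated pseudocode manipulates. A sketch assuming the classes above are in scope (the values are the editor's):

```python
D0 = Reg()
D0.f32 = 1.0                      # setter stores the f32 bit pattern
assert int(D0.u32) == 0x3f800000

D0.u32 = 0xdeadbeef
assert D0[15:0].u16 == 0xbeef and D0[31:16].u16 == 0xdead  # bit-slice views
assert int(D0.i32) == -559038737                           # same bits, signed view
assert D0[0:31].u32 == 0xf77db57b                          # reversed slice = bit-reverse
```
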

# ═══════════════════════════════════════════════════════════════════════════════
# COMPILER: pseudocode -> Python (minimal transforms)
# ═══════════════════════════════════════════════════════════════════════════════

def compile_pseudocode(pseudocode: str) -> str:
  """Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
  # Join continuation lines (lines ending with || or && or open paren)
  raw_lines = pseudocode.strip().split('\n')
  joined_lines: list[str] = []
  for line in raw_lines:
    line = line.strip()
    if joined_lines and (joined_lines[-1].rstrip().endswith(('||', '&&', '(', ',')) or
                         (joined_lines[-1].count('(') > joined_lines[-1].count(')'))):
      joined_lines[-1] = joined_lines[-1].rstrip() + ' ' + line
    else:
      joined_lines.append(line)

  lines = []
  indent, need_pass = 0, False
  for line in joined_lines:
    line = line.strip()
    if not line or line.startswith('//'): continue

    # Control flow - only need pass before outdent (endif/endfor/else/elsif)
    if line.startswith('if '):
      lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
      indent += 1
      need_pass = True
    elif line.startswith('elsif '):
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:")
      indent += 1
      need_pass = True
    elif line == 'else':
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      lines.append(' ' * indent + "else:")
      indent += 1
      need_pass = True
    elif line.startswith('endif'):
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      need_pass = False
    elif line.startswith('endfor'):
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      need_pass = False
    elif line.startswith('declare '):
      pass
    elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
      start, end = _expr(m[2].strip()), _expr(m[3].strip())
      lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
      indent += 1
      need_pass = True
    elif '=' in line and not line.startswith('=='):
      need_pass = False
      line = line.rstrip(';')
      # Handle tuple unpacking: { D1.u1, D0.u64 } = expr
      if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
        rhs = _expr(m[1])
        lines.append(' ' * indent + f"_full = {rhs}")
        lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
        lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
      # Compound assignment
      elif any(op in line for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^=')):
        for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
          if op in line:
            lhs, rhs = line.split(op, 1)
            lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
            break
      else:
        lhs, rhs = line.split('=', 1)
        lines.append(' ' * indent + _assign(lhs.strip(), _expr(rhs.strip())))
  # If we ended with a control statement that needs a body, add pass
  if need_pass: lines.append(' ' * indent + "pass")
  return '\n'.join(lines)
|
||||
|
||||
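
# Illustrative compile (mirrors test_pcode.py's S_ADD_U32 pattern):
#   compile_pseudocode("tmp = 64'U(S0.u32) + 64'U(S1.u32)\n"
#                      "SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U\n"
#                      "D0.u32 = tmp.u32")
# produces:
#   tmp = Reg((S0.u32) + (S1.u32))
#   SCC = Reg(((1) if (tmp >= 0x100000000) else (0)))
#   D0.u32 = tmp.u32
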
def _assign(lhs: str, rhs: str) -> str:
  """Generate assignment. Bare tmp/SCC/etc get wrapped in Reg()."""
  if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec'):
    return f"{lhs} = Reg({rhs})"
  return f"{lhs} = {rhs}"

def _expr(e: str) -> str:
  """Expression transform: minimal - just fix syntax differences."""
  e = e.strip()
  e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
  e = re.sub(r'!([^=])', r' not \1', e)

  # Pack: { hi, lo } -> _pack(hi, lo)
  e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
  def pack(m):
    hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
    return f'_pack({hi}, {lo})'
  e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)

  # Literals: 1'0U -> 0, 32'I(x) -> (x), B(x) -> (x)
  e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
  e = re.sub(r"\d+'[FIBU]\(", "(", e)
  e = re.sub(r'\bB\(', '(', e)  # Bare B( without digit prefix
  e = re.sub(r'([0-9a-fA-Fx])ULL\b', r'\1', e)
  e = re.sub(r'([0-9a-fA-Fx])LL\b', r'\1', e)
  e = re.sub(r'([0-9a-fA-Fx])U\b', r'\1', e)
  e = re.sub(r'(\d\.?\d*)F\b', r'\1', e)
  # Remove redundant type suffix after lane access: VCC.u64[laneId].u64 -> VCC.u64[laneId]
  e = re.sub(r'(\[laneId\])\.[uib]\d+', r'\1', e)

  # Constants - INF is defined as an object supporting .f32/.f64 access
  e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
  e = re.sub(r'NAN\.f\d+', 'float("nan")', e)

  # Recursively process bracket contents to handle nested ternaries like S1.u32[x ? a : b]
  def process_brackets(s):
    result, i = [], 0
    while i < len(s):
      if s[i] == '[':
        # Find matching ]
        depth, start = 1, i + 1
        j = start
        while j < len(s) and depth > 0:
          if s[j] == '[': depth += 1
          elif s[j] == ']': depth -= 1
          j += 1
        inner = _expr(s[start:j-1])  # Recursively process bracket content
        result.append('[' + inner + ']')
        i = j
      else:
        result.append(s[i])
        i += 1
    return ''.join(result)
  e = process_brackets(e)

  # Ternary: a ? b : c -> (b if a else c)
  while '?' in e:
    depth, bracket, q = 0, 0, -1
    for i, c in enumerate(e):
      if c == '(': depth += 1
      elif c == ')': depth -= 1
      elif c == '[': bracket += 1
      elif c == ']': bracket -= 1
      elif c == '?' and depth == 0 and bracket == 0: q = i; break
    if q < 0: break
    depth, bracket, col = 0, 0, -1
    for i in range(q + 1, len(e)):
      if e[i] == '(': depth += 1
      elif e[i] == ')': depth -= 1
      elif e[i] == '[': bracket += 1
      elif e[i] == ']': bracket -= 1
      elif e[i] == ':' and depth == 0 and bracket == 0: col = i; break
    if col < 0: break
    cond, t, f = e[:q].strip(), e[q+1:col].strip(), e[col+1:].strip()
    e = f'(({t}) if ({cond}) else ({f}))'
  return e
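
# Illustrative transforms (mirrors test_pcode.py):
#   _expr("a > b ? 1 : 0")  -> "((1) if (a > b) else (0))"
#   _expr("{ a, b }")       -> "_pack(a, b)"
#   _expr("64'U(x)")        -> "(x)"
#   _expr("a && b")         -> "a  and  b"   # extra spaces are cosmetic; Python ignores them
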
# ═══════════════════════════════════════════════════════════════════════════════
# EXECUTION CONTEXT
# ═══════════════════════════════════════════════════════════════════════════════

class ExecContext:
  """Context for running compiled pseudocode."""
  def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=MASK32, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
    self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
    self.D0, self.D1 = Reg(d0), Reg(0)
    self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
    self.tmp, self.saveexec = Reg(0), Reg(exec_mask)
    self.lane, self.laneId, self.literal = lane, lane, literal
    self.SIMM16, self.SIMM32 = Reg(literal), Reg(literal)
    self.VGPR = vgprs if vgprs is not None else {}
    self.SRC0, self.VDST = Reg(src0_idx), Reg(vdst_idx)

  def run(self, code: str):
    """Execute compiled code."""
    # Start with module globals (helpers, aliases), then add instance-specific bindings
    ns = dict(globals())
    ns.update({
      'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
      'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
      'EXEC_LO': SliceProxy(self.EXEC, 31, 0), 'EXEC_HI': SliceProxy(self.EXEC, 63, 32),
      'tmp': self.tmp, 'saveexec': self.saveexec,
      'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
      'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32,
      'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,
    })
    exec(code, ns)
    # Sync rebinds: if register was reassigned to new Reg or value, copy it back
    def _sync(ctx_reg, ns_val):
      if isinstance(ns_val, Reg): ctx_reg._val = ns_val._val
      else: ctx_reg._val = int(ns_val) & MASK64
    if ns.get('SCC') is not self.SCC: _sync(self.SCC, ns['SCC'])
    if ns.get('VCC') is not self.VCC: _sync(self.VCC, ns['VCC'])
    if ns.get('EXEC') is not self.EXEC: _sync(self.EXEC, ns['EXEC'])
    if ns.get('D0') is not self.D0: _sync(self.D0, ns['D0'])
    if ns.get('D1') is not self.D1: _sync(self.D1, ns['D1'])
    if ns.get('tmp') is not self.tmp: _sync(self.tmp, ns['tmp'])
    if ns.get('saveexec') is not self.saveexec: _sync(self.saveexec, ns['saveexec'])

  def result(self) -> dict:
    return {"d0": self.D0._val, "scc": self.SCC._val & 1}
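
# Illustrative end-to-end run (mirrors test_pcode.py):
#   ctx = ExecContext(s0=0x40400000, s1=0x40800000)   # 3.0f, 4.0f
#   ctx.run(compile_pseudocode("D0.f32 = S0.f32 * S1.f32"))
#   ctx.result()  # -> {'d0': 0x41400000, 'scc': 0}   (12.0f)
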
# ═══════════════════════════════════════════════════════════════════════════════
# PDF EXTRACTION AND CODE GENERATION
# ═══════════════════════════════════════════════════════════════════════════════

PDF_URL = "https://docs.amd.com/api/khub/documents/UVVZM22UN7tMUeiW_4ShTQ/content"
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)

# Patterns that can't be handled by the DSL (require special handling in emu.py)
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'BYTE_PERMUTE', 'FATAL_HALT', 'HW_REGISTERS',
               'PC =', 'PC=', 'PC+', '= PC', 'v_sad', '+:', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
               'CVT_OFF_TABLE', '.bf16', 'ThreadMask', 'u8_to_u32', 'u4_to_u32',
               'S1[i', 'C.i32', 'v_msad_u8', 'S[i]', 'in[', '2.0 / PI',
               'if n.', 'DST.u32', 'addrd = DST', 'addr = DST']  # Malformed pseudocode from PDF

def extract_pseudocode(text: str) -> str | None:
  """Extract pseudocode from an instruction description snippet."""
  lines, result, depth = text.split('\n'), [], 0
  for line in lines:
    s = line.strip()
    if not s: continue
    if re.match(r'^\d+ of \d+$', s): continue
    if re.match(r'^\d+\.\d+\..*Instructions', s): continue
    if s.startswith('"RDNA') or s.startswith('AMD '): continue
    if s.startswith('Notes') or s.startswith('Functional examples'): break
    if s.startswith('if '): depth += 1
    elif s.startswith('endif'): depth = max(0, depth - 1)
    if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
    if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
    is_code = (
      any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =']) or
      any(p in s for p in ['D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
      s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
      re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s)
    )
    if is_code: result.append(s)
  return '\n'.join(result) if result else None
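
# Illustrative filtering (assumed PDF snippet shape): given extracted text like
#   "Add two unsigned inputs.
#    1 of 600
#    D0.u32 = S0.u32 + S1.u32"
# the prose line (ends with '.') and the page marker ('1 of 600') are dropped,
# and only the assignment line is kept as pseudocode.
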
def parse_pseudocode_from_pdf(pdf_path: str | None = None) -> dict:
  """Parse pseudocode from PDF for all ops. Returns {enum_cls: {op: pseudocode}}."""
  import pdfplumber
  from tinygrad.helpers import fetch
  from extra.assembly.rdna3.autogen import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp

  OP_ENUMS = [SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp]
  defined_ops = {}
  for enum_cls in OP_ENUMS:
    for op in enum_cls:
      if op.name.startswith(('S_', 'V_')): defined_ops[(op.name, op.value)] = (enum_cls, op)

  pdf = pdfplumber.open(fetch(PDF_URL) if pdf_path is None else pdf_path)
  all_text = '\n'.join(pdf.pages[i].extract_text() or '' for i in range(195, 560))
  matches = list(INST_PATTERN.finditer(all_text))
  instructions: dict = {cls: {} for cls in OP_ENUMS}

  for i, match in enumerate(matches):
    name, opcode = match.group(1), int(match.group(2))
    key = (name, opcode)
    if key not in defined_ops: continue
    enum_cls, enum_val = defined_ops[key]
    start = match.end()
    end = matches[i + 1].start() if i + 1 < len(matches) else start + 2000
    snippet = all_text[start:end].strip()
    if (pseudocode := extract_pseudocode(snippet)): instructions[enum_cls][enum_val] = pseudocode

  return instructions

def generate_gen_pcode(output_path: str = "extra/assembly/rdna3/autogen/gen_pcode.py"):
  """Generate gen_pcode.py - compiled pseudocode functions for the emulator."""
  from pathlib import Path
  from extra.assembly.rdna3.autogen import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp

  OP_ENUMS = [SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp]

  print("Parsing pseudocode from PDF...")
  by_cls = parse_pseudocode_from_pdf()

  total_found, total_ops = 0, 0
  for enum_cls in OP_ENUMS:
    total = sum(1 for op in enum_cls if op.name.startswith(('S_', 'V_')))
    found = len(by_cls.get(enum_cls, {}))
    total_found += found
    total_ops += total
    print(f"{enum_cls.__name__}: {found}/{total} ({100*found//total if total else 0}%)")
  print(f"Total: {total_found}/{total_ops} ({100*total_found//total_ops}%)")

  print("\nCompiling to pseudocode functions...")
  lines = ['''# autogenerated by pcode.py - do not edit
# to regenerate: python -m extra.assembly.rdna3.pcode
# ruff: noqa: E501,F405,F403
from extra.assembly.rdna3.autogen import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp
from extra.assembly.rdna3.pcode import *
''']

  compiled_count, skipped_count = 0, 0

  for enum_cls in OP_ENUMS:
    cls_name = enum_cls.__name__
    pseudocode_dict = by_cls.get(enum_cls, {})
    if not pseudocode_dict: continue

    fn_entries = []
    for op, pc in pseudocode_dict.items():
      if any(p in pc for p in UNSUPPORTED):
        skipped_count += 1
        continue

      try:
        code = compile_pseudocode(pc)
        # CLZ/CTZ: The PDF pseudocode searches for the first 1 bit but doesn't break.
        # Hardware stops at first match, so we need to add break after D0.i32 = i
        if 'CLZ' in op.name or 'CTZ' in op.name:
          code = code.replace('D0.i32 = i', 'D0.i32 = i; break  # Stop at first 1 bit found')
        # Detect flags for result handling
        is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
        has_d1 = '{ D1' in pc
        if has_d1: is_64 = True
        is_cmp = cls_name == 'VOPCOp' and 'D0.u64[laneId]' in pc
        is_cmpx = cls_name == 'VOPCOp' and 'EXEC.u64[laneId]' in pc  # V_CMPX writes to EXEC per-lane
        # V_DIV_SCALE passes through S0 if no branch taken
        is_div_scale = 'DIV_SCALE' in op.name
        # VOP3SD instructions that write VCC per-lane (either via VCC.u64[laneId] or by setting VCC = 0/1)
        has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)

        # Generate function with indented body
        fn_name = f"_{cls_name}_{op.name}"
        lines.append(f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):")
        # Add original pseudocode as comment
        for pc_line in pc.split('\n'):
          lines.append(f"  # {pc_line}")
        # V_DIV_SCALE: D0 defaults to S0 if no branch taken
        if is_div_scale:
          lines.append("  S0, S1, S2, D0, D1 = Reg(s0), Reg(s1), Reg(s2), Reg(s0), Reg(0)")
        else:
          lines.append("  S0, S1, S2, D0, D1 = Reg(s0), Reg(s1), Reg(s2), Reg(d0), Reg(0)")
        lines.append("  SCC, VCC, EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)")
        lines.append("  EXEC_LO, EXEC_HI = SliceProxy(EXEC, 31, 0), SliceProxy(EXEC, 63, 32)")
        lines.append("  tmp, saveexec = Reg(0), Reg(exec_mask)")
        lines.append("  laneId = lane")
        lines.append("  SIMM16, SIMM32 = Reg(literal), Reg(literal)")
        lines.append("  SRC0, VDST = Reg(src0_idx), Reg(vdst_idx)")
        # Add compiled pseudocode with markers
        lines.append("  # --- compiled pseudocode ---")
        for line in code.split('\n'):
          lines.append(f"  {line}")
        lines.append("  # --- end pseudocode ---")
        # Generate result dict
        lines.append("  result = {'d0': D0._val, 'scc': SCC._val & 1}")
        if has_sdst:
          lines.append("  result['vcc_lane'] = (VCC._val >> lane) & 1")
        else:
          lines.append("  if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1")
        if is_cmpx:
          lines.append("  result['exec_lane'] = (EXEC._val >> lane) & 1")
        else:
          lines.append("  if EXEC._val != exec_mask: result['exec'] = EXEC._val")
        if is_cmp:
          lines.append("  result['vcc_lane'] = (D0._val >> lane) & 1")
        if is_64:
          lines.append("  result['d0_64'] = True")
        if has_d1:
          lines.append("  result['d1'] = D1._val & 1")
        lines.append("  return result")
        lines.append("")

        fn_entries.append((op, fn_name))
        compiled_count += 1
      except Exception as e:
        print(f"  Warning: Failed to compile {op.name}: {e}")
        skipped_count += 1

    if fn_entries:
      lines.append(f'{cls_name}_FUNCTIONS = {{')
      for op, fn_name in fn_entries:
        lines.append(f"  {cls_name}.{op.name}: {fn_name},")
      lines.append('}')
      lines.append('')

  # Add manually implemented lane instructions
  lines.append('''
# Manually implemented lane instructions (require special vgpr_write handling)
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # VGPR[lane][VDST] = S0.b32 - writes s0 to specified lane's VGPR
  wr_lane = s1 & 0x1f  # lane select (5 bits for wave32)
  return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}

def _VOP3Op_V_READLANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # D0 = VGPR[lane][SRC0] - reads from specified lane's VGPR
  rd_lane = s1 & 0x1f  # lane select (5 bits for wave32)
  val = VGPR[rd_lane][src0_idx] if VGPR is not None and rd_lane < len(VGPR) and src0_idx < len(VGPR[rd_lane]) else s0
  return {'d0': val & 0xffffffff, 'scc': scc}

def _VOP1Op_V_READFIRSTLANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # D0 = VGPR[first_active_lane][SRC0] - reads from first active lane
  first_lane = 0
  for i in range(32):
    if exec_mask & (1 << i):
      first_lane = i
      break
  val = VGPR[first_lane][src0_idx] if VGPR is not None and first_lane < len(VGPR) and src0_idx < len(VGPR[first_lane]) else s0
  return {'d0': val & 0xffffffff, 'scc': scc}
''')

  lines.append('COMPILED_FUNCTIONS = {')
  for enum_cls in OP_ENUMS:
    cls_name = enum_cls.__name__
    if by_cls.get(enum_cls): lines.append(f'  {cls_name}: {cls_name}_FUNCTIONS,')
  lines.append('}')
  lines.append('')
  lines.append("# Add lane instructions to their respective dicts")
  lines.append("VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32")
  lines.append("VOP3Op_FUNCTIONS[VOP3Op.V_READLANE_B32] = _VOP3Op_V_READLANE_B32")
  lines.append("VOP1Op_FUNCTIONS[VOP1Op.V_READFIRSTLANE_B32] = _VOP1Op_V_READFIRSTLANE_B32")
  lines.append('')
  lines.append('def get_compiled_functions(): return COMPILED_FUNCTIONS')

  Path(output_path).write_text('\n'.join(lines))
  print(f"\nGenerated {output_path}: {compiled_count} compiled, {skipped_count} skipped")

if __name__ == "__main__":
  generate_gen_pcode()
196 extra/assembly/rdna3/test/external_test_usability.py (new file)
@@ -0,0 +1,196 @@
# Usability tests for the RDNA3 ASM DSL
# These tests demonstrate how the DSL *should* work for a good user experience
# Currently many of these tests fail - they document desired behavior

import unittest
from extra.assembly.rdna3.autogen import *
from extra.assembly.rdna3.lib import Inst, RawImm, SGPR, VGPR

class TestRegisterSliceSyntax(unittest.TestCase):
  """
  Issue: Register slice syntax should use AMD assembly convention (inclusive end).

  In AMD assembly, s[4:7] means registers s4, s5, s6, s7 (4 registers, inclusive).
  The DSL should match this convention so that:
  - s[4:7] gives 4 registers
  - Disassembler output can be copied directly back into DSL code

  Fix: Change _RegFactory.__getitem__ to use inclusive end (sketched after this class):
    key.stop - key.start + 1 (instead of key.stop - key.start)
  """
  def test_register_slice_count(self):
    # s[4:7] should give 4 registers: s4, s5, s6, s7 (AMD convention, inclusive)
    reg = s[4:7]
    self.assertEqual(reg.count, 4, "s[4:7] should give 4 registers (s4, s5, s6, s7)")

  def test_register_slice_roundtrip(self):
    # Round-trip: DSL -> disasm -> DSL should preserve register count
    reg = s[4:7]  # 4 registers in AMD convention
    inst = s_load_b128(reg, s[0:1], NULL, 0)
    disasm = inst.disasm()
    # Disasm shows s[4:7] - user should be able to copy this back
    self.assertIn("s[4:7]", disasm)
    # And s[4:7] in DSL should give the same 4 registers
    reg_from_disasm = s[4:7]
    self.assertEqual(reg_from_disasm.count, 4, "s[4:7] from disasm should give 4 registers")
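
  # A minimal sketch of the fix described above (hypothetical - the real
  # _RegFactory lives in extra/assembly/rdna3/lib.py and may differ in detail):
  #   def __getitem__(self, key):
  #     if isinstance(key, slice): return self.reg_cls(key.start, count=key.stop - key.start + 1)
  #     return self.reg_cls(key)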


class TestReprReadability(unittest.TestCase):
  """
  Issue: repr() leaks internal RawImm type and omits zero-valued fields.

  When you create v_mov_b32_e32(v[0], v[1]), the repr shows:
    VOP1(op=1, src0=RawImm(257))

  Problems:
  1. vdst=v[0] is omitted because 0 is treated as "default"
  2. src0 shows RawImm(257) instead of v[1]
  3. User sees encoded values (257 = 256 + 1) instead of register names

  Expected repr: VOP1(op=1, vdst=v[0], src0=v[1])
  """
  def test_repr_shows_registers_not_raw_imm(self):
    inst = v_mov_b32_e32(v[0], v[1])
    # Should show v[1], not RawImm(257)
    self.assertNotIn("RawImm", repr(inst), "repr should not expose RawImm internal type")
    self.assertIn("v[1]", repr(inst), "repr should show register name")

  def test_repr_includes_zero_dst(self):
    inst = v_mov_b32_e32(v[0], v[1])
    # v[0] is a valid destination register, should be shown
    self.assertIn("vdst", repr(inst), "repr should include vdst even when 0")

  def test_repr_roundtrip(self):
    # repr should produce something that can be eval'd back
    inst = v_mov_b32_e32(v[0], v[1])
    # This would require repr to output valid Python, e.g.:
    # "VOP1(op=VOP1Op.V_MOV_B32, vdst=v[0], src0=v[1])"
    r = repr(inst)
    # At minimum, it should be human-readable
    self.assertIn("v[", r, "repr should show register syntax")


class TestInstructionEquality(unittest.TestCase):
  """
  Issue: No __eq__ method - instruction comparison requires repr() workaround.

  Two identical instructions should compare equal with ==, but currently:
    inst1 == inst2 returns False

  The test_handwritten.py works around this with:
    self.assertEqual(repr(self.inst), repr(reasm))
  """
  def test_identical_instructions_equal(self):
    inst1 = v_mov_b32_e32(v[0], v[1])
    inst2 = v_mov_b32_e32(v[0], v[1])
    self.assertEqual(inst1, inst2, "identical instructions should be equal")

  def test_different_instructions_not_equal(self):
    inst1 = v_mov_b32_e32(v[0], v[1])
    inst2 = v_mov_b32_e32(v[0], v[2])
    self.assertNotEqual(inst1, inst2, "different instructions should not be equal")


class TestVOPDHelperSignature(unittest.TestCase):
  """
  Issue: VOPD helper functions have confusing semantics.

  v_dual_mul_f32 is defined as:
    v_dual_mul_f32 = functools.partial(VOPD, VOPDOp.V_DUAL_MUL_F32)

  This binds VOPDOp.V_DUAL_MUL_F32 to the FIRST positional arg of VOPD.__init__,
  which is 'opx'. So v_dual_mul_f32 sets the X operation.

  But then test_dual_mul in test_handwritten.py does:
    v_dual_mul_f32(VOPDOp.V_DUAL_MUL_F32, vdstx=v[0], ...)

  This passes V_DUAL_MUL_F32 as the SECOND positional arg (opy), making both
  X and Y operations the same. This is confusing because:
  1. The function name suggests it handles the X operation
  2. But you still pass an opcode as the first arg (which becomes opy)

  Expected: Either make the helper fully specify both ops, or make the
  signature clearer about what the positional arg means.
  """
  def test_vopd_helper_opy_should_be_required(self):
    # Using only keyword args "works" but opy silently defaults to 0
    inst = v_dual_mul_f32(vdstx=v[0], vdsty=v[1], srcx0=v[2], vsrcx1=v[3], srcy0=v[4], vsrcy1=v[5])
    self.assertEqual(inst.opx, VOPDOp.V_DUAL_MUL_F32)
    # Bug: opy defaults to 0 (V_DUAL_FMAC_F32) silently - should require explicit opy
    # This test documents the bug - it should fail once fixed
    self.assertNotEqual(inst.opy, VOPDOp.V_DUAL_FMAC_F32, "opy should not silently default to FMAC")

  def test_vopd_helper_positional_arg_is_opy(self):
    # The first positional arg after the partial becomes opy, not a second opx
    inst = v_dual_mul_f32(VOPDOp.V_DUAL_MOV_B32, vdstx=v[0], vdsty=v[1], srcx0=v[2], vsrcx1=v[3], srcy0=v[4], vsrcy1=v[5])
    self.assertEqual(inst.opx, VOPDOp.V_DUAL_MUL_F32)  # From partial
    self.assertEqual(inst.opy, VOPDOp.V_DUAL_MOV_B32)  # From first positional arg


class TestFieldAccessPreservesType(unittest.TestCase):
  """
  Issue: Field access loses type information.

  After creating an instruction, accessing fields returns encoded int values:
    inst = v_mov_b32_e32(v[0], v[1])
    inst.vdst  # returns 0, not VGPR(0)

  This makes it impossible to round-trip register types through field access.
  """
  def test_vdst_returns_register(self):
    inst = v_mov_b32_e32(v[5], v[1])
    vdst = inst.vdst
    # Should return a VGPR, not an int
    self.assertIsInstance(vdst, (VGPR, int), "vdst should return VGPR or at least be usable")
    # Ideally: self.assertIsInstance(vdst, VGPR)

  def test_src_returns_register_for_vgpr_source(self):
    inst = v_mov_b32_e32(v[0], v[1])
    # src0 is encoded as 257 (256 + 1 for v1)
    # Ideally it should decode back to v[1]
    src0_raw = inst._values.get('src0')
    # Currently returns RawImm(257), should return VGPR(1) or similar
    self.assertNotIsInstance(src0_raw, RawImm, "source should not be RawImm internally")


class TestArgumentDiscoverability(unittest.TestCase):
  """
  Issue: No clear signature for positional arguments.

  inspect.signature(s_load_b128) shows: (*args, literal=None, **kwargs)

  Users have no way to know the argument order without reading source code.
  The order is implicitly defined by the class field definition order.

  Possible fixes:
  1. Add explicit parameter names to functools.partial
  2. Generate type stubs with proper signatures
  3. Add docstrings listing the expected arguments
  """
  def test_signature_has_named_params(self):
    import inspect
    sig = inspect.signature(s_load_b128)
    params = list(sig.parameters.keys())
    # Currently: ['args', 'literal', 'kwargs'] (from *args, literal=None, **kwargs)
    # Expected: something like ['sdata', 'sbase', 'soffset', 'offset', 'literal']
    self.assertIn('sdata', params, "signature should show field names")


class TestSpecialConstants(unittest.TestCase):
  """
  Issue: NULL and other constants are IntEnum values that might be confusing.

  NULL = SrcEnum.NULL = 124, but users might expect NULL to be a special object
  that clearly represents "no register" rather than a magic number.
  """
  def test_null_has_clear_repr(self):
    # NULL should have a clear string representation
    self.assertIn("NULL", str(NULL) or repr(NULL), "NULL should be clearly identifiable")

  def test_null_is_distinguishable_from_int(self):
    # NULL should be distinguishable from the raw integer 124
    self.assertNotEqual(type(NULL), int, "NULL should not be plain int")


if __name__ == "__main__":
  unittest.main()
@@ -20,6 +20,15 @@ class KernelInfo:
  buf_idxs: list[int]  # indices into shared buffer pool
  buf_sizes: list[int]  # sizes for each buffer index

def _is_f32_nan(bits: int) -> bool:
  """Check if 32-bit value is a NaN (exponent all 1s, mantissa non-zero)."""
  return (bits & 0x7f800000) == 0x7f800000 and (bits & 0x007fffff) != 0

def _vals_equal(a: int, b: int) -> bool:
  """Compare two 32-bit values, treating all NaN bit patterns as equal."""
  if a == b: return True
  return _is_f32_nan(a) and _is_f32_nan(b)

@dataclass
class StateSnapshot:
  pc: int
@@ -29,20 +38,20 @@ class StateSnapshot:
  sgpr: list[int]
  vgpr: list[list[int]]

  def diff(self, other: 'StateSnapshot', n_lanes: int) -> list[str]:
  def diff(self, other: 'StateSnapshot', n_lanes: int, arrow: str = " vs ") -> list[str]:
    """Return list of differences between two states."""
    diffs = []
    if self.pc != other.pc: diffs.append(f"pc: {self.pc} vs {other.pc}")
    if self.scc != other.scc: diffs.append(f"scc: {self.scc} vs {other.scc}")
    if self.vcc != other.vcc: diffs.append(f"vcc: 0x{self.vcc:08x} vs 0x{other.vcc:08x}")
    if self.exec_mask != other.exec_mask: diffs.append(f"exec: 0x{self.exec_mask:08x} vs 0x{other.exec_mask:08x}")
    if self.pc != other.pc: diffs.append(f"pc: {self.pc}{arrow}{other.pc}")
    if self.scc != other.scc: diffs.append(f"scc: {self.scc}{arrow}{other.scc}")
    if self.vcc != other.vcc: diffs.append(f"vcc: 0x{self.vcc:08x}{arrow}0x{other.vcc:08x}")
    if self.exec_mask != other.exec_mask: diffs.append(f"exec: 0x{self.exec_mask:08x}{arrow}0x{other.exec_mask:08x}")
    for i, (a, b) in enumerate(zip(self.sgpr, other.sgpr)):
      # Skip VCC_LO/HI (106/107) and EXEC_LO/HI (126/127) as they alias vcc/exec_mask which are compared separately
      if i in (106, 107, 126, 127): continue
      if a != b: diffs.append(f"sgpr[{i}]: 0x{a:08x} vs 0x{b:08x}")
      if not _vals_equal(a, b): diffs.append(f"sgpr[{i}]: 0x{a:08x}{arrow}0x{b:08x}")
    for lane in range(n_lanes):
      for i, (a, b) in enumerate(zip(self.vgpr[lane], other.vgpr[lane])):
        if a != b: diffs.append(f"vgpr[{lane}][{i}]: 0x{a:08x} vs 0x{b:08x}")
        if not _vals_equal(a, b): diffs.append(f"vgpr[{lane}][{i}]: 0x{a:08x}{arrow}0x{b:08x}")
    return diffs

class CStateSnapshot(ctypes.Structure):
@@ -157,17 +166,32 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t

    if debug: print(f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step}: PC={python_before.pc}, inst={inst_str}")

    # Instructions with known Rust emulator bugs - sync Python to Rust after execution
    # v_div_scale/v_div_fixup: Rust has different VCC handling
    # v_cvt_f16_f32: Rust clears high 16 bits, but hardware (and Python) preserves them
    sync_after = any(x in inst_str for x in ('v_div_scale_f32', 'v_div_scale_f64', 'v_div_fixup_f32', 'v_div_fixup_f64',
                                             'v_cvt_f16_f32'))
    diffs = rust_before.diff(python_before, n_lanes)
    if diffs:
      trace_lines = []
      for s, pc, d, rb, pb in trace[:-1]:
      for idx, (s, pc, d, rb, pb) in enumerate(trace):
        trace_lines.append(f"  step {s}: PC={pc:3d} {d}")
        if trace.index((s, pc, d, rb, pb)) < len(trace) - 2:
          next_rb, next_pb = trace[trace.index((s, pc, d, rb, pb)) + 1][3:5]
          inst_diffs = rb.diff(next_rb, n_lanes)
          if inst_diffs: trace_lines.append(f"    rust changes: {', '.join(inst_diffs[:3])}")
        if idx < len(trace) - 1:
          next_rb, next_pb = trace[idx + 1][3:5]
          rust_diffs = rb.diff(next_rb, n_lanes, "->")
          python_diffs = pb.diff(next_pb, n_lanes, "->")
          if rust_diffs: trace_lines.append(f"    rust: {', '.join(rust_diffs[:5])}")
          if python_diffs: trace_lines.append(f"    python: {', '.join(python_diffs[:5])}")
          elif rust_diffs: trace_lines.append(f"    python: (no changes)")
        else:
          # Last traced instruction - compare with current state
          rust_diffs = rb.diff(rust_before, n_lanes, "->")
          python_diffs = pb.diff(python_before, n_lanes, "->")
          if rust_diffs: trace_lines.append(f"    rust: {', '.join(rust_diffs[:5])}")
          if python_diffs: trace_lines.append(f"    python: {', '.join(python_diffs[:5])}")
          elif rust_diffs: trace_lines.append(f"    python: (no changes)")
      trace_str = "\n".join(trace_lines)
      return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step} before inst '{inst_str}': states differ:\n  " + "\n  ".join(diffs[:10]) + f"\n  Recent instructions:\n{trace_str}", total_steps
      return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step} before inst '{inst_str}': states differ (rust vs python):\n  " + "\n  ".join(diffs[:10]) + f"\n  Recent instructions:\n{trace_str}", total_steps

    rust_result = rust.step()
    python_result = python.step()
@@ -176,6 +200,14 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
      trace_str = "\n".join(f"  step {s}: PC={pc:3d} {d}" for s, pc, d, _, _ in trace)
      return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step}: different return codes: rust={rust_result}, python={python_result}, inst={inst_str}\n  Recent instructions:\n{trace_str}", total_steps

    # Sync Python state to Rust after instructions with known Rust emulator differences
    if sync_after:
      rust_after = rust.get_snapshot()
      for i in range(128): python.set_sgpr(i, rust_after.sgpr[i])
      for lane in range(n_lanes):
        for i in range(256): python.set_vgpr(lane, i, rust_after.vgpr[lane][i])
      python.state.pc, python.state.scc, python.state.vcc, python.state.exec_mask = rust_after.pc, rust_after.scc, rust_after.vcc, rust_after.exec_mask

    if rust_result == -1:
      total_steps += step + 1
      break
@@ -330,9 +362,21 @@ class TestTinygradKernels(unittest.TestCase):
  def test_exp(self): self._test_kernel(lambda T: T([0.0, 1.0, 2.0]).exp())
  def test_log(self): self._test_kernel(lambda T: T([1.0, 2.0, 3.0]).log())
  def test_sin(self): self._test_kernel(lambda T: T([0.0, 1.0, 2.0]).sin())
  def test_cos(self): self._test_kernel(lambda T: T([0.0, 1.0, 2.0]).cos())
  def test_sqrt(self): self._test_kernel(lambda T: T([1.0, 4.0, 9.0]).sqrt())
  def test_recip(self): self._test_kernel(lambda T: T([1.0, 2.0, 4.0]).reciprocal())

  # Sin/cos with various ranges - test polynomial expansion
  def test_sin_small(self): self._test_kernel(lambda T: T([0.1, 0.2, 0.3, 0.4, 0.5]*7).sin())  # 35 elements, small angles
  def test_sin_pi(self): self._test_kernel(lambda T: T([3.14159, 1.5708, 0.7854, -1.5708, -3.14159]*7).sin())  # around pi
  def test_sin_medium(self): self._test_kernel(lambda T: T([10.0, 20.0, 30.0, 50.0, 100.0]*7).sin())  # medium values
  def test_sin_negative(self): self._test_kernel(lambda T: T([-0.5, -1.0, -2.0, -5.0, -10.0]*7).sin())  # negative values
  def test_cos_small(self): self._test_kernel(lambda T: T([0.1, 0.2, 0.3, 0.4, 0.5]*7).cos())
  def test_cos_pi(self): self._test_kernel(lambda T: T([3.14159, 1.5708, 0.7854, -1.5708, -3.14159]*7).cos())
  def test_cos_medium(self): self._test_kernel(lambda T: T([10.0, 20.0, 30.0, 50.0, 100.0]*7).cos())
  @unittest.skip("Rust emulator has V_DIV_SCALE_F32 bug - returns 0 instead of src0 for normal cases")
  def test_tan(self): self._test_kernel(lambda T: T([0.1, 0.2, 0.5, 1.0, -0.5]*7).tan())  # avoid pi/2

  # Binary ops
  def test_add(self): self._test_kernel(lambda T: T([1.0, 2.0]) + T([3.0, 4.0]))
  def test_sub(self): self._test_kernel(lambda T: T([5.0, 6.0]) - T([1.0, 2.0]))
@@ -445,6 +489,14 @@ class TestTinygradKernels(unittest.TestCase):

  # Pooling operations - regression test for VCC wave32 mode (S_CBRANCH_VCCZ should only check VCC_LO)
  def test_avg_pool2d(self): self._test_kernel(lambda T: T.empty(1, 1, 8, 8).avg_pool2d(kernel_size=(4,4), stride=2))

  # Trig functions with special values (inf, nan, 0)
  def test_sin_special(self): self._test_kernel(lambda T: T([0., 0.25, 0.5, 1.0]*8).sin())
  def test_cos_special(self): self._test_kernel(lambda T: T([0., 0.25, 0.5, 1.0]*8).cos())

  # Sqrt and rsqrt
  def test_sqrt(self): self._test_kernel(lambda T: T([0., 1., 4., 9.]*8).sqrt())
  def test_rsqrt(self): self._test_kernel(lambda T: T([1., 4., 9., 16.]*8).rsqrt())
  @unittest.skip("Rust emulator has S_ADD_I32 SCC bug - uses carry instead of signed overflow")
  def test_avg_pool3d(self):
    import numpy as np
@@ -462,5 +514,33 @@ class TestTinygradKernels(unittest.TestCase):
    self._test_kernel(lambda T: T(np.random.randn(2, 4, 9, 9, 9).astype(np.float32).tolist()).conv_transpose2d(
      T(np.random.randn(4, 4, 3, 3, 3).astype(np.float32).tolist())), max_steps=500000)

  # Tests from test_ops.py failures
  def test_gelu_extreme(self): self._test_kernel(lambda T: T.empty(45, 65).gelu())
  def test_gemm_64x64(self): self._test_kernel(lambda T: T.empty(64, 64) @ T.empty(64, 64), max_steps=500000)
  def test_gemm_fp16(self): self._test_kernel(lambda T: T.empty(64, 64).half() @ T.empty(64, 64).half(), max_steps=500000)
  def test_global_avg_pool2d(self): self._test_kernel(lambda T: T.empty(32, 2, 111, 28).avg_pool2d(kernel_size=(111, 28)), max_steps=100000)
  @unittest.skip("Rust emulator has S_ADD_I32 SCC bug - uses carry instead of signed overflow")
  def test_grouped_conv2d(self): self._test_kernel(lambda T: T.empty(4, 15, 5, 5).conv2d(T.empty(35, 3, 3, 3), groups=5), max_steps=200000)
  @unittest.skip("Rust emulator has S_ADD_I32 SCC bug - uses carry instead of signed overflow")
  def test_grouped_conv_transpose2d(self): self._test_kernel(lambda T: T.empty(2, 4, 9, 9).conv_transpose2d(T.empty(4, 4, 3, 3), groups=2), max_steps=200000)
  def test_hardsigmoid(self): self._test_kernel(lambda T: T.empty(45, 65).hardsigmoid())
  def test_hardsigmoid_extreme(self): self._test_kernel(lambda T: T.empty(45, 65).sigmoid())
  def test_matvec(self): self._test_kernel(lambda T: (T.empty(1, 128) @ T.empty(128, 128)).relu(), max_steps=200000)
  def test_matvecmat(self): self._test_kernel(lambda T: ((T.empty(1, 128) @ T.empty(128, 128)).relu() @ T.empty(128, 128)), max_steps=300000)
  def test_max_reduce_45x3(self): self._test_kernel(lambda T: T.empty(45, 3).max())
  def test_max_dont_collapse(self): self._test_kernel(lambda T: T.empty(4, 8).max(axis=1))
  def test_max_pool2d_simple(self): self._test_kernel(lambda T: T.empty(1, 1, 2, 3).max_pool2d(kernel_size=(2, 2)))
  def test_max_pool2d_32x2(self): self._test_kernel(lambda T: T.empty(32, 2, 11, 28).max_pool2d(kernel_size=(2, 2)))
  def test_max_pool2d_asymmetric_padding(self): self._test_kernel(lambda T: T.empty(4, 2, 111, 28).max_pool2d(kernel_size=(5, 5), padding=(0, 1, 0, 1)))
  def test_max_pool2d_bigger_stride(self): self._test_kernel(lambda T: T.empty(4, 2, 11, 28).max_pool2d(kernel_size=(2, 2), stride=(2, 3)))
  def test_max_pool2d_unit_stride(self): self._test_kernel(lambda T: T.empty(3, 2, 17, 14).max_pool2d(kernel_size=(5, 5), stride=1))
  def test_max_pool2d_smaller_stride(self): self._test_kernel(lambda T: T.empty(3, 2, 17, 14).max_pool2d(kernel_size=(5, 5), stride=(2, 3)))
  def test_max_unpool2d(self): self._test_kernel(lambda T: T.max_unpool2d(*T.empty(8, 3, 50, 50).max_pool2d(kernel_size=(5, 5), stride=(6, 5), return_indices=True), kernel_size=(5, 5), stride=(6, 5)))
  def test_isinf(self): self._test_kernel(lambda T: T([float('-inf'), 0., float('inf'), 1.1]*8).isinf())
  def test_isfinite(self): self._test_kernel(lambda T: T([float('-inf'), 0., float('inf'), 1.1]*8).isfinite())

  # WMMA tests - uses wave matrix multiply for larger fp16 matmuls
  def test_wmma_gemm_fp16(self): self._test_kernel(lambda T: T.empty(64, 64).half() @ T.empty(64, 64).half(), max_steps=1000000)

if __name__ == "__main__":
  unittest.main()
File diff suppressed because it is too large
@@ -46,7 +46,8 @@ dev.synchronize()
    elapsed = time.perf_counter() - st

    self.assertNotEqual(result.returncode, 0, "should have raised")
    self.assertIn("NotImplementedError", result.stderr)
    self.assertTrue("NotImplementedError" in result.stderr or "ValueError" in result.stderr,
                    f"expected NotImplementedError or ValueError in stderr")
    # Should exit immediately, not wait for the full timeout
    self.assertLess(elapsed, 5.0, f"should exit immediately on emulator exception, took {elapsed:.1f}s")
269 extra/assembly/rdna3/test/test_pcode.py (new file)
@@ -0,0 +1,269 @@
#!/usr/bin/env python3
"""Tests for the RDNA3 pseudocode DSL."""
import unittest
from extra.assembly.rdna3.pcode import Reg, TypedView, SliceProxy, ExecContext, compile_pseudocode, _expr, MASK32, MASK64, _f32, _i32, _f16, _i16, f32_to_f16, _isnan
from extra.assembly.rdna3.autogen.gen_pcode import _VOP3SDOp_V_DIV_SCALE_F32, _VOPCOp_V_CMP_CLASS_F32

class TestReg(unittest.TestCase):
  def test_u32_read(self):
    r = Reg(0xDEADBEEF)
    self.assertEqual(int(r.u32), 0xDEADBEEF)

  def test_u32_write(self):
    r = Reg(0)
    r.u32 = 0x12345678
    self.assertEqual(r._val, 0x12345678)

  def test_f32_read(self):
    r = Reg(0x40400000)  # 3.0f
    self.assertAlmostEqual(float(r.f32), 3.0)

  def test_f32_write(self):
    r = Reg(0)
    r.f32 = 3.0
    self.assertEqual(r._val, 0x40400000)

  def test_i32_signed(self):
    r = Reg(0xFFFFFFFF)  # -1 as signed
    self.assertEqual(int(r.i32), -1)

  def test_u64(self):
    r = Reg(0xDEADBEEFCAFEBABE)
    self.assertEqual(int(r.u64), 0xDEADBEEFCAFEBABE)

  def test_f64(self):
    r = Reg(0x4008000000000000)  # 3.0 as f64
    self.assertAlmostEqual(float(r.f64), 3.0)

class TestTypedView(unittest.TestCase):
  def test_bit_slice(self):
    r = Reg(0xDEADBEEF)
    # Slices return SliceProxy which supports .u32, .u16 etc (matching pseudocode like S1.u32[1:0].u32)
    self.assertEqual(r.u32[7:0].u32, 0xEF)
    self.assertEqual(r.u32[15:8].u32, 0xBE)
    self.assertEqual(r.u32[23:16].u32, 0xAD)
    self.assertEqual(r.u32[31:24].u32, 0xDE)
    # Also works with int() for arithmetic
    self.assertEqual(int(r.u32[7:0]), 0xEF)

  def test_single_bit_read(self):
    r = Reg(0b11010101)
    self.assertEqual(r.u32[0], 1)
    self.assertEqual(r.u32[1], 0)
    self.assertEqual(r.u32[2], 1)
    self.assertEqual(r.u32[3], 0)

  def test_single_bit_write(self):
    r = Reg(0)
    r.u32[5] = 1
    r.u32[3] = 1
    self.assertEqual(r._val, 0b00101000)

  def test_nested_bit_access(self):
    # S0.u32[S1.u32[4:0]] - access bit at position from another register
    s0 = Reg(0b11010101)
    s1 = Reg(3)
    bit_pos = s1.u32[4:0]  # SliceProxy, int value = 3
    bit_val = s0.u32[int(bit_pos)]  # bit 3 of s0 = 0
    self.assertEqual(int(bit_pos), 3)
    self.assertEqual(bit_val, 0)

  def test_arithmetic(self):
    r1 = Reg(0x40400000)  # 3.0f
    r2 = Reg(0x40800000)  # 4.0f
    result = r1.f32 + r2.f32
    self.assertAlmostEqual(result, 7.0)

  def test_comparison(self):
    r1 = Reg(5)
    r2 = Reg(3)
    self.assertTrue(r1.u32 > r2.u32)
    self.assertFalse(r1.u32 < r2.u32)
    self.assertTrue(r1.u32 != r2.u32)

class TestSliceProxy(unittest.TestCase):
  def test_slice_read(self):
    r = Reg(0x56781234)
    self.assertEqual(r[15:0].u16, 0x1234)
    self.assertEqual(r[31:16].u16, 0x5678)

  def test_slice_write(self):
    r = Reg(0)
    r[15:0].u16 = 0x1234
    r[31:16].u16 = 0x5678
    self.assertEqual(r._val, 0x56781234)

  def test_slice_f16(self):
    r = Reg(0)
    r[15:0].f16 = 3.0
    self.assertAlmostEqual(_f16(r._val & 0xffff), 3.0, places=2)

class TestCompiler(unittest.TestCase):
  def test_ternary(self):
    result = _expr("a > b ? 1 : 0")
    self.assertIn("if", result)
    self.assertIn("else", result)

  def test_type_prefix_strip(self):
    self.assertEqual(_expr("1'0U"), "0")
    self.assertEqual(_expr("32'1"), "1")
    self.assertEqual(_expr("16'0xFFFF"), "0xFFFF")

  def test_suffix_strip(self):
    self.assertEqual(_expr("0ULL"), "0")
    self.assertEqual(_expr("1LL"), "1")
    self.assertEqual(_expr("5U"), "5")
    self.assertEqual(_expr("3.14F"), "3.14")

  def test_boolean_ops(self):
    self.assertIn("and", _expr("a && b"))
    self.assertIn("or", _expr("a || b"))
    self.assertIn("!=", _expr("a <> b"))

  def test_pack16(self):
    result = _expr("{ a, b }")
    self.assertIn("_pack", result)

  def test_type_cast_strip(self):
    self.assertEqual(_expr("64'U(x)"), "(x)")
    self.assertEqual(_expr("32'I(y)"), "(y)")

class TestExecContext(unittest.TestCase):
  def test_float_add(self):
    ctx = ExecContext(s0=0x40400000, s1=0x40800000)  # 3.0f, 4.0f
    ctx.D0.f32 = ctx.S0.f32 + ctx.S1.f32
    self.assertAlmostEqual(_f32(ctx.D0._val), 7.0)

  def test_float_mul(self):
    ctx = ExecContext(s0=0x40400000, s1=0x40800000)  # 3.0f, 4.0f
    ctx.run("D0.f32 = S0.f32 * S1.f32")
    self.assertAlmostEqual(_f32(ctx.D0._val), 12.0)

  def test_scc_comparison(self):
    ctx = ExecContext(s0=42, s1=42)
    ctx.run("SCC = S0.u32 == S1.u32")
    self.assertEqual(ctx.SCC._val, 1)

  def test_scc_comparison_false(self):
    ctx = ExecContext(s0=42, s1=43)
    ctx.run("SCC = S0.u32 == S1.u32")
    self.assertEqual(ctx.SCC._val, 0)

  def test_ternary(self):
    code = compile_pseudocode("D0.u32 = S0.u32 > S1.u32 ? 1'1U : 1'0U")
    ctx = ExecContext(s0=5, s1=3)
    ctx.run(code)
    self.assertEqual(ctx.D0._val, 1)

  def test_pack(self):
    code = compile_pseudocode("D0 = { S1[15:0].u16, S0[15:0].u16 }")
    ctx = ExecContext(s0=0x1234, s1=0x5678)
    ctx.run(code)
    self.assertEqual(ctx.D0._val, 0x56781234)

  def test_tmp_with_typed_access(self):
    code = compile_pseudocode("""tmp = S0.u32 + S1.u32
D0.u32 = tmp.u32""")
    ctx = ExecContext(s0=100, s1=200)
    ctx.run(code)
    self.assertEqual(ctx.D0._val, 300)

  def test_s_add_u32_pattern(self):
    # Real pseudocode pattern from S_ADD_U32
    code = compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
D0.u32 = tmp.u32""")
    # Test overflow case
    ctx = ExecContext(s0=0xFFFFFFFF, s1=0x00000001)
    ctx.run(code)
    self.assertEqual(ctx.D0._val, 0)  # Wraps to 0
    self.assertEqual(ctx.SCC._val, 1)  # Carry set

  def test_s_add_u32_no_overflow(self):
    code = compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
D0.u32 = tmp.u32""")
    ctx = ExecContext(s0=100, s1=200)
    ctx.run(code)
    self.assertEqual(ctx.D0._val, 300)
    self.assertEqual(ctx.SCC._val, 0)  # No carry

  def test_vcc_lane_read(self):
    ctx = ExecContext(vcc=0b1010, lane=1)
    # Lane 1 is set
    self.assertEqual(ctx.VCC.u64[1], 1)
    self.assertEqual(ctx.VCC.u64[2], 0)

  def test_vcc_lane_write(self):
    ctx = ExecContext(vcc=0, lane=0)
    ctx.VCC.u64[3] = 1
    ctx.VCC.u64[1] = 1
    self.assertEqual(ctx.VCC._val, 0b1010)

  def test_for_loop(self):
    # CTZ pattern - find first set bit
    code = compile_pseudocode("""tmp = -1
for i in 0 : 31 do
if S0.u32[i] == 1 then
tmp = i
D0.i32 = tmp""")
    ctx = ExecContext(s0=0b1000)  # Bit 3 is set
    ctx.run(code)
    self.assertEqual(ctx.D0._val & MASK32, 3)

  def test_result_dict(self):
    ctx = ExecContext(s0=5, s1=3)
    ctx.D0.u32 = 42
    ctx.SCC._val = 1
    result = ctx.result()
    self.assertEqual(result['d0'], 42)
    self.assertEqual(result['scc'], 1)

class TestPseudocodeRegressions(unittest.TestCase):
  """Regression tests for pseudocode instruction emulation bugs."""

  def test_v_div_scale_f32_vcc_always_returned(self):
    """V_DIV_SCALE_F32 must always return vcc_lane, even when VCC=0 (no scaling needed).
    Bug: when VCC._val == vcc (both 0), vcc_lane wasn't returned, so VCC bits weren't written.
    This caused division to produce wrong results for multiple lanes."""
    # Normal case: 1.0 / 3.0, no scaling needed, VCC should be 0
    s0 = 0x3f800000  # 1.0
    s1 = 0x40400000  # 3.0
    s2 = 0x3f800000  # 1.0 (numerator)
    result = _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0, None, {})
    # Must always have vcc_lane in result
    self.assertIn('vcc_lane', result, "V_DIV_SCALE_F32 must always return vcc_lane")
    self.assertEqual(result['vcc_lane'], 0, "vcc_lane should be 0 when no scaling needed")

  def test_v_cmp_class_f32_detects_quiet_nan(self):
    """V_CMP_CLASS_F32 must correctly identify quiet NaN vs signaling NaN.
    Bug: isQuietNAN and isSignalNAN both used math.isnan which can't distinguish them."""
    quiet_nan = 0x7fc00000  # quiet NaN: exponent=255, bit22=1
    signal_nan = 0x7f800001  # signaling NaN: exponent=255, bit22=0
    # Test quiet NaN detection (bit 1 in mask)
    s1_quiet = 0b0000000010  # bit 1 = quiet NaN
    result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
    self.assertEqual(result['vcc_lane'], 1, "Should detect quiet NaN with quiet NaN mask")
    # Test signaling NaN detection (bit 0 in mask)
    s1_signal = 0b0000000001  # bit 0 = signaling NaN
    result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
    self.assertEqual(result['vcc_lane'], 1, "Should detect signaling NaN with signaling NaN mask")
    # Test that quiet NaN doesn't match signaling NaN mask
    result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
    self.assertEqual(result['vcc_lane'], 0, "Quiet NaN should not match signaling NaN mask")
    # Test that signaling NaN doesn't match quiet NaN mask
    result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
    self.assertEqual(result['vcc_lane'], 0, "Signaling NaN should not match quiet NaN mask")

  def test_isnan_with_typed_view(self):
    """_isnan must work with TypedView objects, not just Python floats.
    Bug: _isnan checked isinstance(x, float) which returned False for TypedView."""
    nan_reg = Reg(0x7fc00000)  # quiet NaN
    normal_reg = Reg(0x3f800000)  # 1.0
    inf_reg = Reg(0x7f800000)  # +inf
    self.assertTrue(_isnan(nan_reg.f32), "_isnan should return True for NaN TypedView")
    self.assertFalse(_isnan(normal_reg.f32), "_isnan should return False for normal TypedView")
    self.assertFalse(_isnan(inf_reg.f32), "_isnan should return False for inf TypedView")

if __name__ == '__main__':
  unittest.main()
@@ -7,6 +7,7 @@ import tinygrad.runtime.autogen.amd_gpu as amd_gpu, tinygrad.runtime.autogen.am.
SDMA_MAX_COPY_SIZE = 0x400000

regCOMPUTE_PGM_LO = 0x1bac + amd_gpu.GC_BASE__INST0_SEG0
regCOMPUTE_PGM_RSRC2 = 0x1bb3 + amd_gpu.GC_BASE__INST0_SEG0
regCOMPUTE_USER_DATA_0 = 0x1be0 + amd_gpu.GC_BASE__INST0_SEG0
regCOMPUTE_NUM_THREAD_X = 0x1ba7 + amd_gpu.GC_BASE__INST0_SEG0
regGRBM_GFX_INDEX = 0x2200 + amd_gpu.GC_BASE__INST0_SEG1
@@ -179,14 +180,16 @@ class PM4Executor(AMDQueue):
    prg_addr = (self.gpu.regs[regCOMPUTE_PGM_LO] + (self.gpu.regs[regCOMPUTE_PGM_LO + 1] << 32)) << 8
    args_addr = self.gpu.regs[regCOMPUTE_USER_DATA_0] + (self.gpu.regs[regCOMPUTE_USER_DATA_0 + 1] << 32)
    lc = [self.gpu.regs[i] for i in range(regCOMPUTE_NUM_THREAD_X, regCOMPUTE_NUM_THREAD_X+3)]
    rsrc2 = self.gpu.regs[regCOMPUTE_PGM_RSRC2]

    prg_sz = 0
    for st,sz in self.gpu.mapped_ranges:
      if st <= prg_addr < st+sz: prg_sz = sz - (prg_addr - st)

    assert prg_sz > 0, "Invalid prg ptr (not found in mapped ranges)"
    # Pass valid memory ranges to Python emulator for bounds checking
    # Pass valid memory ranges and rsrc2 to Python emulator for bounds checking and SGPR layout
    if hasattr(remu, 'valid_mem_ranges'): remu.valid_mem_ranges = self.gpu.mapped_ranges
    if hasattr(remu, 'rsrc2'): remu.rsrc2 = rsrc2
    err = remu.run_asm(prg_addr, prg_sz, *gl, *lc, args_addr)
    if err != 0: raise RuntimeError("remu does not support the new instruction introduced in this kernel")

@@ -18,12 +18,13 @@ def _try_dlopen_gpuocelot():
class PythonRemu:
  """Python RDNA3 emulator wrapper that matches the libremu.so interface."""
  valid_mem_ranges: set[tuple[int, int]] = set()
  rsrc2: int = 0x19c  # Default: USER_SGPR_COUNT=14, enable X and Y workgroup IDs
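  # Bit decode of the default above, assuming the COMPUTE_PGM_RSRC2 field layout
  # (USER_SGPR_COUNT in bits [5:1], TGID_X_EN bit 7, TGID_Y_EN bit 8):
  #   (0x19c >> 1) & 0x1f == 14, bit 7 == 1, bit 8 == 1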

  def run_asm(self, lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int) -> int:
    from extra.assembly.rdna3.emu import run_asm, set_valid_mem_ranges
    # Pad ranges to handle GPU loads that may read past small buffers (e.g. s_load_b128 on 12-byte buffer)
    set_valid_mem_ranges({(start, size + 4096) for start, size in self.valid_mem_ranges})
    return run_asm(lib, lib_sz, gx, gy, gz, lx, ly, lz, args_ptr)
    return run_asm(lib, lib_sz, gx, gy, gz, lx, ly, lz, args_ptr, self.rsrc2)

def _try_dlopen_remu():
  # Use Python emulator only if PYTHON_REMU=1

@@ -641,14 +641,14 @@ class AMDQueueDesc:
  def read_ptr(self): return min(p[0] for p in self.read_ptrs)

  def signal_doorbell(self, dev, doorbell_value:int|None=None):
    for write_ptr in self.write_ptrs: write_ptr[0] = self.put_value

    # Ensure all prior writes are visible to the GPU.
    System.memory_barrier()

    # Flush hdp if queue is in dev mem.
    if dev.is_am() and not dev.is_usb(): dev.iface.dev_impl.gmc.flush_hdp()
    try:
      for write_ptr in self.write_ptrs: write_ptr[0] = self.put_value

      # Ensure all prior writes are visible to the GPU.
      System.memory_barrier()

      # Flush hdp if queue is in dev mem.
      if dev.is_am() and not dev.is_usb(): dev.iface.dev_impl.gmc.flush_hdp()
      for doorbell in self.doorbells: doorbell[0] = self.put_value if doorbell_value is None else doorbell_value
    except Exception as e:
      dev.error_state = e