assembly/amd: make the emu.py code shine (#13996)

* assembly/amd: make the code shine

* lil clean

* reg back in pcode

* cleanups

* gen fma_mix

* no writelane hacks

* fn cleanup

* dead vgpr_write

* readable

* smem

* cleanup bench_emu

* speedups

* simpler and faster

* direct inst._fn

* split fxn

* Revert "simpler and faster"

This reverts commit e85f6594b3.

* move lds to wavestate

* dispatcher

* pc in dispatch

* literal isn't wavestate

* cleanups + program

* one readlane

* exec_vop3sd in exec_vop

* cleaner exec_vopd

* fully merge VOP3P

* no special paths

* no SliceProxy

* low=0

* no bigint

* failing tests

* fma on python 3.13
This commit is contained in:
George Hotz
2026-01-03 23:33:09 -05:00
committed by GitHub
parent bdb421f13e
commit 8328511808
15 changed files with 14949 additions and 8550 deletions

View File

@@ -668,6 +668,7 @@ jobs:
key: rdna3-emu key: rdna3-emu
deps: testing_minimal deps: testing_minimal
amd: 'true' amd: 'true'
python-version: '3.13'
- name: Install LLVM 21 - name: Install LLVM 21
run: | run: |
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import struct, math, re import struct, math, re
from enum import IntEnum from enum import IntEnum
from functools import cache, cached_property from functools import cache
from typing import overload, Annotated, TypeVar, Generic from typing import overload, Annotated, TypeVar, Generic
from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op, from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp) SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
@@ -346,6 +346,7 @@ class Inst:
if 'abs_' in kwargs: kwargs['abs'] = kwargs.pop('abs_') if 'abs_' in kwargs: kwargs['abs'] = kwargs.pop('abs_')
orig_args = dict(zip(field_names, args)) | kwargs orig_args = dict(zip(field_names, args)) | kwargs
self._values.update(orig_args) self._values.update(orig_args)
self._precompute()
self._validate(orig_args) self._validate(orig_args)
# Pre-shift literal for 64-bit sources (literal param is always raw 32-bit value from user) # Pre-shift literal for 64-bit sources (literal param is always raw 32-bit value from user)
if literal is not None: if literal is not None:
@@ -386,6 +387,7 @@ class Inst:
elif name == 'sbase': self._values[name] = (val.idx if isinstance(val, Reg) else val.val if isinstance(val, SrcMod) else val * 2) // 2 elif name == 'sbase': self._values[name] = (val.idx if isinstance(val, Reg) else val.val if isinstance(val, SrcMod) else val * 2) // 2
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = val.idx // 4 elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = val.idx // 4
elif marker is _VDSTYEnc and isinstance(val, VGPR): self._values[name] = val.idx >> 1 elif marker is _VDSTYEnc and isinstance(val, VGPR): self._values[name] = val.idx >> 1
self._precompute_fields()
def _encode_field(self, name: str, val) -> int: def _encode_field(self, name: str, val) -> int:
if isinstance(val, RawImm): return val.val if isinstance(val, RawImm): return val.val
@@ -450,6 +452,8 @@ class Inst:
inst = object.__new__(cls) inst = object.__new__(cls)
inst._values = {n: RawImm(v) if n in SRC_FIELDS else v for n, bf in cls._fields.items() if n != 'encoding' for v in [(word >> bf.lo) & bf.mask()]} inst._values = {n: RawImm(v) if n in SRC_FIELDS else v for n, bf in cls._fields.items() if n != 'encoding' for v in [(word >> bf.lo) & bf.mask()]}
inst._literal = None inst._literal = None
inst._precompute()
inst._precompute_fields()
return inst return inst
@classmethod @classmethod
@@ -510,25 +514,32 @@ class Inst:
'VOPD': VOPDOp, 'VINTERP': VINTERPOp} 'VOPD': VOPDOp, 'VINTERP': VINTERPOp}
_VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770} _VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
@property def _precompute(self):
def op(self): """Precompute op, op_name, _spec_regs, _spec_dtype for fast access."""
"""Return the op as an enum (e.g., VOP1Op.V_MOV_B32). VOP3 returns VOPCOp/VOP3SDOp for those op ranges."""
val = self._values.get('op') val = self._values.get('op')
if val is None: return None if val is None: self.op = None
if hasattr(val, 'name'): return val # already an enum elif hasattr(val, 'name'): self.op = val
cls_name = self.__class__.__name__ else:
assert cls_name in self._enum_map, f"no enum map for {cls_name}" cls_name = self.__class__.__name__
return self._enum_map[cls_name](val) # VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
if cls_name == 'VOP3':
try:
if val < 256: self.op = VOPCOp(val)
elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val)
else: self.op = VOP3Op(val)
except ValueError: self.op = val
elif cls_name in self._enum_map:
try: self.op = self._enum_map[cls_name](val)
except ValueError: self.op = val
else: self.op = val
self.op_name = self.op.name if hasattr(self.op, 'name') else ''
self._spec_regs = spec_regs(self.op_name)
self._spec_dtype = spec_dtype(self.op_name)
@cached_property def _precompute_fields(self):
def op_name(self) -> str: """Unwrap all field values as direct attributes for fast access."""
op = self.op for name, val in self._values.items():
return op.name if hasattr(op, 'name') else '' if name != 'op': setattr(self, name, unwrap(val))
@cached_property
def _spec_regs(self) -> tuple[int, int, int, int]: return spec_regs(self.op_name)
@cached_property
def _spec_dtype(self) -> tuple[str | None, str | None, str | None, str | None]: return spec_dtype(self.op_name)
def dst_regs(self) -> int: return self._spec_regs[0] def dst_regs(self) -> int: return self._spec_regs[0]
def src_regs(self, n: int) -> int: return self._spec_regs[n + 1] def src_regs(self, n: int) -> int: return self._spec_regs[n + 1]
def num_srcs(self) -> int: return spec_num_srcs(self.op_name) def num_srcs(self) -> int: return spec_num_srcs(self.op_name)

View File

@@ -1,15 +1,14 @@
# RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF # RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF
# mypy: ignore-errors # mypy: ignore-errors
from __future__ import annotations from __future__ import annotations
import ctypes import ctypes, functools
from tinygrad.runtime.autogen import hsa
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64 from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.pcode import Reg
from extra.assembly.amd.asm import detect_format from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions from extra.assembly.amd.autogen.rdna3.gen_pcode import COMPILED_FUNCTIONS
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD, from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp) SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
Program = dict[int, Inst]
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256 WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, SrcEnum.NULL, SrcEnum.EXEC_LO, SrcEnum.EXEC_HI, SrcEnum.SCC VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, SrcEnum.NULL, SrcEnum.EXEC_LO, SrcEnum.EXEC_HI, SrcEnum.SCC
@@ -29,6 +28,41 @@ def _dst16(cur: int, val: int, is_hi: bool) -> int: return (cur & 0x0000ffff) |
def _vgpr_hi(src: int) -> bool: return src >= 256 and ((src - 256) & 0x80) != 0 def _vgpr_hi(src: int) -> bool: return src >= 256 and ((src - 256) & 0x80) != 0
def _vgpr_masked(src: int) -> int: return ((src - 256) & 0x7f) + 256 if src >= 256 else src def _vgpr_masked(src: int) -> int: return ((src - 256) & 0x7f) + 256 if src >= 256 else src
# VOP3 source modifier: apply abs/neg to value
def _mod_src(val: int, idx: int, neg: int, abs_: int, is64: bool = False) -> int:
to_f, to_i = (_f64, _i64) if is64 else (_f32, _i32)
if (abs_ >> idx) & 1: val = to_i(abs(to_f(val)))
if (neg >> idx) & 1: val = to_i(-to_f(val))
return val
# Read source operand with VOP3 modifiers
def _read_src(st, inst, src, idx: int, lane: int, neg: int, abs_: int, opsel: int) -> int:
if src is None: return 0
literal, regs, is_src_16 = inst._literal, inst.src_regs(idx), inst.is_src_16(idx)
if regs == 2: return _mod_src(st.rsrc64(src, lane, literal), idx, neg, abs_, is64=True)
if isinstance(inst, VOP3P):
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
if 'FMA_MIX' in inst.op_name:
raw = st.rsrc(src, lane, literal)
sign_bit = (15 if not (opsel & (1 << idx)) else 31) if (opsel_hi >> idx) & 1 else 31
if inst.neg_hi & (1 << idx): raw &= ~(1 << sign_bit)
if neg & (1 << idx): raw ^= (1 << sign_bit)
return raw
raw = st.rsrc_f16(src, lane, literal)
hi = _src16(raw, opsel_hi & (1 << idx)) ^ (0x8000 if inst.neg_hi & (1 << idx) else 0)
lo = _src16(raw, opsel & (1 << idx)) ^ (0x8000 if neg & (1 << idx) else 0)
return (hi << 16) | lo
if is_src_16 and isinstance(inst, VOP3):
raw = st.rsrc_f16(src, lane, literal) if 128 <= src < 255 else st.rsrc(src, lane, literal)
val = _src16(raw, bool(opsel & (1 << idx)))
if abs_ & (1 << idx): val &= 0x7fff
if neg & (1 << idx): val ^= 0x8000
return val
if is_src_16 and isinstance(inst, (VOP1, VOP2, VOPC)):
if src >= 256: return _src16(_mod_src(st.rsrc(_vgpr_masked(src), lane, literal), idx, neg, abs_), _vgpr_hi(src))
return _mod_src(st.rsrc_f16(src, lane, literal), idx, neg, abs_) & 0xffff
return _mod_src(st.rsrc(src, lane, literal), idx, neg, abs_)
# Helper: get number of dwords from memory op name # Helper: get number of dwords from memory op name
def _op_ndwords(name: str) -> int: def _op_ndwords(name: str) -> int:
if '_B128' in name: return 4 if '_B128' in name: return 4
@@ -36,8 +70,8 @@ def _op_ndwords(name: str) -> int:
if any(s in name for s in ('_B64', '_U64', '_I64', '_F64')): return 2 if any(s in name for s in ('_B64', '_U64', '_I64', '_F64')): return 2
return 1 return 1
# Helper: build multi-dword Reg from consecutive VGPRs # Helper: build multi-dword int from consecutive VGPRs
def _vgpr_read(V: list, base: int, ndwords: int) -> Reg: return Reg(sum(V[base + i] << (32 * i) for i in range(ndwords))) def _vgpr_read(V: list, base: int, ndwords: int) -> int: return sum(V[base + i] << (32 * i) for i in range(ndwords))
# Helper: write multi-dword value to consecutive VGPRs # Helper: write multi-dword value to consecutive VGPRs
def _vgpr_write(V: list, base: int, val: int, ndwords: int): def _vgpr_write(V: list, base: int, val: int, ndwords: int):
@@ -88,7 +122,8 @@ class LDSMem:
if addr + size <= len(self._lds): self._lds[addr:addr+size] = (int(val) & ((1 << (size*8)) - 1)).to_bytes(size, 'little') if addr + size <= len(self._lds): self._lds[addr:addr+size] = (int(val) & ((1 << (size*8)) - 1)).to_bytes(size, 'little')
def __getitem__(self, addr): return _make_mem_accessor(self._read, self._write)(addr) def __getitem__(self, addr): return _make_mem_accessor(self._read, self._write)(addr)
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16} # SMEM dst register count (for writing result back to SGPRs)
SMEM_DST_COUNT = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup) # VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
_VOPD_TO_VOP = { _VOPD_TO_VOP = {
@@ -100,19 +135,12 @@ _VOPD_TO_VOP = {
VOPDOp.V_DUAL_ADD_NC_U32: VOP3Op.V_ADD_NC_U32, VOPDOp.V_DUAL_LSHLREV_B32: VOP3Op.V_LSHLREV_B32, VOPDOp.V_DUAL_AND_B32: VOP3Op.V_AND_B32, VOPDOp.V_DUAL_ADD_NC_U32: VOP3Op.V_ADD_NC_U32, VOPDOp.V_DUAL_LSHLREV_B32: VOP3Op.V_LSHLREV_B32, VOPDOp.V_DUAL_AND_B32: VOP3Op.V_AND_B32,
} }
# Compiled pseudocode functions (lazy loaded)
_COMPILED: dict | None = None
def _get_compiled() -> dict:
global _COMPILED
if _COMPILED is None: _COMPILED = get_compiled_functions()
return _COMPILED
class WaveState: class WaveState:
__slots__ = ('sgpr', 'vgpr', 'scc', 'pc', 'literal', '_pend_sgpr') __slots__ = ('sgpr', 'vgpr', 'scc', 'pc', '_pend_sgpr', 'lds', 'n_lanes')
def __init__(self): def __init__(self, lds: LDSMem | None = None, n_lanes: int = WAVE_SIZE):
self.sgpr, self.vgpr = [0] * SGPR_COUNT, [[0] * VGPR_COUNT for _ in range(WAVE_SIZE)] self.sgpr, self.vgpr = [0] * SGPR_COUNT, [[0] * VGPR_COUNT for _ in range(WAVE_SIZE)]
self.sgpr[EXEC_LO], self.scc, self.pc, self.literal, self._pend_sgpr = 0xffffffff, 0, 0, 0, {} self.sgpr[EXEC_LO], self.scc, self.pc, self._pend_sgpr, self.lds, self.n_lanes = 0xffffffff, 0, 0, {}, lds, n_lanes
@property @property
def vcc(self) -> int: return self.sgpr[VCC_LO] | (self.sgpr[VCC_HI] << 32) def vcc(self) -> int: return self.sgpr[VCC_LO] | (self.sgpr[VCC_HI] << 32)
@@ -129,18 +157,18 @@ class WaveState:
def rsgpr64(self, i: int) -> int: return self.rsgpr(i) | (self.rsgpr(i+1) << 32) def rsgpr64(self, i: int) -> int: return self.rsgpr(i) | (self.rsgpr(i+1) << 32)
def wsgpr64(self, i: int, v: int): self.wsgpr(i, v & MASK32); self.wsgpr(i+1, (v >> 32) & MASK32) def wsgpr64(self, i: int, v: int): self.wsgpr(i, v & MASK32); self.wsgpr(i+1, (v >> 32) & MASK32)
def _rsrc_base(self, v: int, lane: int, consts): def _rsrc_base(self, v: int, lane: int, consts, literal: int):
if v < SGPR_COUNT: return self.sgpr[v] if v < SGPR_COUNT: return self.sgpr[v]
if v == SCC: return self.scc if v == SCC: return self.scc
if v < 255: return consts[v - 128] if v < 255: return consts[v - 128]
if v == 255: return self.literal if v == 255: return literal
return self.vgpr[lane][v - 256] if v <= 511 else 0 return self.vgpr[lane][v - 256] if v <= 511 else 0
def rsrc(self, v: int, lane: int) -> int: return self._rsrc_base(v, lane, _INLINE_CONSTS) def rsrc(self, v: int, lane: int, literal: int = 0) -> int: return self._rsrc_base(v, lane, _INLINE_CONSTS, literal)
def rsrc_f16(self, v: int, lane: int) -> int: return self._rsrc_base(v, lane, _INLINE_CONSTS_F16) def rsrc_f16(self, v: int, lane: int, literal: int = 0) -> int: return self._rsrc_base(v, lane, _INLINE_CONSTS_F16, literal)
def rsrc64(self, v: int, lane: int) -> int: def rsrc64(self, v: int, lane: int, literal: int = 0) -> int:
if 128 <= v < 255: return _INLINE_CONSTS_F64[v - 128] if 128 <= v < 255: return _INLINE_CONSTS_F64[v - 128]
if v == 255: return self.literal # literal is already shifted in from_bytes for 64-bit ops if v == 255: return literal # literal is already shifted in from_bytes for 64-bit ops
return self.rsrc(v, lane) | ((self.rsrc(v+1, lane) if v < VCC_LO or 256 <= v <= 511 else 0) << 32) return self.rsrc(v, lane, literal) | ((self.rsrc(v+1, lane, literal) if v < VCC_LO or 256 <= v <= 511 else 0) << 32)
def pend_sgpr_lane(self, reg: int, lane: int, val: int): def pend_sgpr_lane(self, reg: int, lane: int, val: int):
if reg not in self._pend_sgpr: self._pend_sgpr[reg] = 0 if reg not in self._pend_sgpr: self._pend_sgpr[reg] = 0
@@ -150,251 +178,130 @@ class WaveState:
self._pend_sgpr.clear() self._pend_sgpr.clear()
def decode_program(data: bytes) -> Program:
result: Program = {}
i = 0
while i < len(data):
try: inst_class = detect_format(data[i:])
except ValueError: break # stop at invalid instruction (padding/metadata after code)
if inst_class is None: i += 4; continue
base_size = inst_class._size()
# Pass enough data for potential 64-bit literal (base + 8 bytes max)
inst = inst_class.from_bytes(data[i:i+base_size+8])
for name, val in inst._values.items():
if name != 'op': setattr(inst, name, unwrap(val)) # skip op to preserve property access
inst._words = inst.size() // 4
result[i // 4] = inst
i += inst._words * 4
return result
# ═══════════════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════════════
# EXECUTION - All ALU ops use pseudocode from PDF # EXECUTION - All ops use pseudocode from PDF
# ═══════════════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════════════
def exec_scalar(st: WaveState, inst: Inst) -> int: def exec_scalar(st: WaveState, inst: Inst):
"""Execute scalar instruction. Returns PC delta or negative for special cases.""" """Execute scalar instruction. Returns 0 to continue execution."""
compiled = _get_compiled()
# SOPP: special cases for control flow that has no pseudocode
if isinstance(inst, SOPP):
if inst.op == SOPPOp.S_ENDPGM: return -1
if inst.op == SOPPOp.S_BARRIER: return -2
# SMEM: memory loads (not ALU)
if isinstance(inst, SMEM):
addr = st.rsgpr64(inst.sbase * 2) + _sext(inst.offset, 21)
if inst.soffset not in (NULL, 0x7f): addr += st.rsrc(inst.soffset, 0)
if (cnt := SMEM_LOAD.get(inst.op)) is None: raise NotImplementedError(f"SMEM op {inst.op}")
for i in range(cnt): st.wsgpr(inst.sdata + i, mem_read((addr + i * 4) & MASK64, 4))
return 0
# Get op enum and lookup compiled function # Get op enum and lookup compiled function
if isinstance(inst, SOP1): ssrc0, sdst = inst.ssrc0, inst.sdst if isinstance(inst, SMEM): ssrc0, sdst = None, None
elif isinstance(inst, SOP1): ssrc0, sdst = inst.ssrc0, inst.sdst
elif isinstance(inst, SOP2): ssrc0, sdst = inst.ssrc0, inst.sdst elif isinstance(inst, SOP2): ssrc0, sdst = inst.ssrc0, inst.sdst
elif isinstance(inst, SOPC): ssrc0, sdst = inst.ssrc0, None elif isinstance(inst, SOPC): ssrc0, sdst = inst.ssrc0, None
elif isinstance(inst, SOPK): ssrc0, sdst = inst.sdst, inst.sdst # sdst is both src and dst elif isinstance(inst, SOPK): ssrc0, sdst = inst.sdst, inst.sdst # sdst is both src and dst
elif isinstance(inst, SOPP): ssrc0, sdst = None, None elif isinstance(inst, SOPP): ssrc0, sdst = None, None
else: raise NotImplementedError(f"Unknown scalar type {type(inst)}") else: raise NotImplementedError(f"Unknown scalar type {type(inst)}")
# SOPP has gaps in the opcode enum - treat unknown opcodes as no-ops # SMEM: memory loads
try: op = inst.op if isinstance(inst, SMEM):
except ValueError: addr = st.rsgpr64(inst.sbase * 2) + _sext(inst.offset, 21)
if isinstance(inst, SOPP): return 0 if inst.soffset not in (NULL, 0x7f): addr += st.rsrc(inst.soffset, 0, inst._literal)
raise result = inst._fn(GlobalMem, addr & MASK64)
fn = compiled.get(type(op), {}).get(op) if 'SDATA' in result:
if fn is None: sdata = result['SDATA']
# SOPP instructions without pseudocode (waits, hints, nops) are no-ops for i in range(SMEM_DST_COUNT.get(inst.op, 1)): st.wsgpr(inst.sdata + i, (sdata >> (i * 32)) & MASK32)
if isinstance(inst, SOPP): return 0 st.pc += inst._words
raise NotImplementedError(f"{op.name} not in pseudocode") return 0
# Build context - use inst methods to determine operand sizes # Build context - use inst methods to determine operand sizes
s0 = st.rsrc64(ssrc0, 0) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0)) literal = inst._literal
s1 = st.rsrc64(inst.ssrc1, 0) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0) s0 = st.rsrc64(ssrc0, 0, literal) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0, literal) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0))
s1 = st.rsrc64(inst.ssrc1, 0, literal) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0, literal) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0)
d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0) d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else st.literal literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else inst._literal
# Create Reg objects for compiled function - mask VCC/EXEC to 32 bits for wave32 # Call compiled function with int parameters
result = fn(Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc & MASK32), 0, Reg(st.exec_mask & MASK32), literal, None, PC=Reg(st.pc * 4)) result = inst._fn(s0, s1, 0, d0, st.scc, st.vcc & MASK32, 0, st.exec_mask & MASK32, literal, None, pc=st.pc * 4)
# Apply results - extract values from returned Reg objects # Apply results (already int values)
if sdst is not None and 'D0' in result: if sdst is not None and 'D0' in result:
(st.wsgpr64 if inst.dst_regs() == 2 else st.wsgpr)(sdst, result['D0']._val) (st.wsgpr64 if inst.dst_regs() == 2 else st.wsgpr)(sdst, result['D0'])
if 'SCC' in result: st.scc = result['SCC']._val & 1 if 'SCC' in result: st.scc = result['SCC'] & 1
if 'EXEC' in result: st.exec_mask = result['EXEC']._val if 'EXEC' in result: st.exec_mask = result['EXEC']
if 'PC' in result: if 'PC' in result:
# Convert absolute byte address to word delta # Convert absolute byte address to word offset
pc_val = result['PC']._val pc_val = result['PC']
new_pc = pc_val if pc_val < 0x8000000000000000 else pc_val - 0x10000000000000000 new_pc = pc_val if pc_val < 0x8000000000000000 else pc_val - 0x10000000000000000
new_pc_words = new_pc // 4 st.pc = new_pc // 4
return new_pc_words - st.pc - 1 # -1 because emulator adds inst_words (1 for scalar) else:
st.pc += inst._words
return 0 return 0
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None) -> None: # ═══════════════════════════════════════════════════════════════════════════════
"""Execute vector instruction for one lane.""" # VECTOR INSTRUCTIONS
compiled = _get_compiled() # ═══════════════════════════════════════════════════════════════════════════════
V = st.vgpr[lane]
# Memory ops (FLAT/GLOBAL/SCRATCH and DS) - use generated pcode def exec_vopd(st: WaveState, inst, V: list, lane: int) -> None:
if isinstance(inst, (FLAT, DS)): """VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)."""
op, vdst, op_name = inst.op, inst.vdst, inst.op.name literal, vdstx, vdsty = inst._literal, inst.vdstx, (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
fn, ndwords = compiled[type(op)][op], _op_ndwords(op_name) sx0, sx1, dx, sy0, sy1, dy = st.rsrc(inst.srcx0, lane, literal), V[inst.vsrcx1], V[vdstx], st.rsrc(inst.srcy0, lane, literal), V[inst.vsrcy1], V[vdsty]
if isinstance(inst, FLAT): opx, opy = _VOPD_TO_VOP[inst.opx], _VOPD_TO_VOP[inst.opy]
addr = V[inst.addr] | (V[inst.addr + 1] << 32) V[vdstx] = COMPILED_FUNCTIONS[type(opx)][opx](sx0, sx1, 0, dx, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
ADDR = (st.rsgpr64(inst.saddr) + V[inst.addr] + _sext(inst.offset, 13)) & MASK64 if inst.saddr not in (NULL, 0x7f) else (addr + _sext(inst.offset, 13)) & MASK64 V[vdsty] = COMPILED_FUNCTIONS[type(opy)][opy](sy0, sy1, 0, dy, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
# For loads, VDATA comes from vdst (preserves unwritten bits); for stores, from inst.data
vdata_src = vdst if 'LOAD' in op_name else inst.data
result = fn(GlobalMem, ADDR, _vgpr_read(V, vdata_src, ndwords), Reg(V[vdst]), Reg(0))
if 'VDATA' in result: _vgpr_write(V, vdst, result['VDATA']._val, ndwords)
if 'RETURN_DATA' in result: _vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords)
else: # DS
DATA0, DATA1 = _vgpr_read(V, inst.data0, ndwords), _vgpr_read(V, inst.data1, ndwords) if inst.data1 is not None else Reg(0)
result = fn(lds, Reg(V[inst.addr]), DATA0, DATA1, Reg(inst.offset0), Reg(inst.offset1), Reg(0))
if 'RETURN_DATA' in result and ('_RTN' in op_name or '_LOAD' in op_name):
_vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords * 2 if '_2ADDR_' in op_name else ndwords)
return
# VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes) def exec_flat(st: WaveState, inst, V: list, lane: int) -> None:
if isinstance(inst, VOPD): """FLAT/GLOBAL/SCRATCH memory ops."""
vdsty = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1) ndwords = _op_ndwords(inst.op_name)
inputs = [(inst.opx, st.rsrc(inst.srcx0, lane), V[inst.vsrcx1], V[inst.vdstx], inst.vdstx), addr = V[inst.addr] | (V[inst.addr + 1] << 32)
(inst.opy, st.rsrc(inst.srcy0, lane), V[inst.vsrcy1], V[vdsty], vdsty)] ADDR = (st.rsgpr64(inst.saddr) + V[inst.addr] + _sext(inst.offset, 13)) & MASK64 if inst.saddr not in (NULL, 0x7f) else (addr + _sext(inst.offset, 13)) & MASK64
def exec_vopd(vopd_op, s0, s1, d0): vdata_src = inst.vdst if 'LOAD' in inst.op_name else inst.data
op = _VOPD_TO_VOP[vopd_op] result = inst._fn(GlobalMem, ADDR, _vgpr_read(V, vdata_src, ndwords), V[inst.vdst])
return compiled[type(op)][op](Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)['D0']._val if 'VDATA' in result: _vgpr_write(V, inst.vdst, result['VDATA'], ndwords)
for vopd_op, s0, s1, d0, dst in inputs: V[dst] = exec_vopd(vopd_op, s0, s1, d0) if 'RETURN_DATA' in result: _vgpr_write(V, inst.vdst, result['RETURN_DATA'], ndwords)
return
# VOP3SD: has extra scalar dest for carry output def exec_ds(st: WaveState, inst, V: list, lane: int) -> None:
if isinstance(inst, VOP3SD): """DS (LDS) memory ops."""
fn = compiled[VOP3SDOp][inst.op] ndwords = _op_ndwords(inst.op_name)
# Read sources based on register counts from inst properties data0, data1 = _vgpr_read(V, inst.data0, ndwords), _vgpr_read(V, inst.data1, ndwords) if inst.data1 is not None else 0
def rsrc_n(src, regs): return st.rsrc64(src, lane) if regs == 2 else st.rsrc(src, lane) result = inst._fn(st.lds, V[inst.addr], data0, data1, inst.offset0, inst.offset1)
s0, s1, s2 = rsrc_n(inst.src0, inst.src_regs(0)), rsrc_n(inst.src1, inst.src_regs(1)), rsrc_n(inst.src2, inst.src_regs(2)) if 'RETURN_DATA' in result and ('_RTN' in inst.op_name or '_LOAD' in inst.op_name):
# Carry-in ops use src2 as carry bitmask instead of VCC _vgpr_write(V, inst.vdst, result['RETURN_DATA'], ndwords * 2 if '_2ADDR_' in inst.op_name else ndwords)
vcc = st.rsgpr64(inst.src2) if 'CO_CI' in inst.op_name else st.vcc
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(V[inst.vdst]), Reg(st.scc), Reg(vcc), lane, Reg(st.exec_mask), st.literal, None)
d0_val = result['D0']._val
V[inst.vdst] = d0_val & MASK32
if inst.dst_regs() == 2: V[inst.vdst + 1] = (d0_val >> 32) & MASK32
if 'VCC' in result: st.pend_sgpr_lane(inst.sdst, lane, (result['VCC']._val >> lane) & 1)
return
# Get op enum and sources (None means "no source" for that operand) def exec_vop(st: WaveState, inst: Inst, V: list, lane: int) -> None:
# dst_hi: for VOP1/VOP2 16-bit dst ops, bit 7 of vdst indicates .h (high 16-bit) destination """VOP1/VOP2/VOP3/VOP3SD/VOP3P/VOPC: standard ALU ops."""
dst_hi = False if isinstance(inst, VOP3P):
if isinstance(inst, VOP1): src0, src1, src2, vdst, dst_hi = inst.src0, inst.src1, inst.src2, inst.vdst, False
if inst.op == VOP1Op.V_NOP: return neg, abs_, opsel = inst.neg, 0, inst.opsel
src0, src1, src2 = inst.src0, None, None elif isinstance(inst, VOP1):
dst_hi = (inst.vdst & 0x80) != 0 and inst.is_dst_16() src0, src1, src2, vdst = inst.src0, None, None, inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst
vdst = inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst neg, abs_, opsel, dst_hi = 0, 0, 0, (inst.vdst & 0x80) != 0 and inst.is_dst_16()
elif isinstance(inst, VOP2): elif isinstance(inst, VOP2):
src0, src1, src2 = inst.src0, inst.vsrc1 + 256, None src0, src1, src2, vdst = inst.src0, inst.vsrc1 + 256, None, inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst
dst_hi = (inst.vdst & 0x80) != 0 and inst.is_dst_16() neg, abs_, opsel, dst_hi = 0, 0, 0, (inst.vdst & 0x80) != 0 and inst.is_dst_16()
vdst = inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst elif isinstance(inst, (VOP3, VOP3SD)):
elif isinstance(inst, VOP3): src0, src1, src2, vdst = inst.src0, inst.src1, (None if isinstance(inst, VOP3) and inst.op.value < 256 else inst.src2), inst.vdst
# VOP3 ops 0-255 are VOPC comparisons encoded as VOP3 - inst.op returns VOPCOp for these neg, abs_, opsel, dst_hi = (inst.neg, inst.abs, inst.opsel, False) if isinstance(inst, VOP3) else (0, 0, 0, False)
src0, src1, src2, vdst = inst.src0, inst.src1, (None if inst.op.value < 256 else inst.src2), inst.vdst
elif isinstance(inst, VOPC): elif isinstance(inst, VOPC):
# For 16-bit VOPC, vsrc1 uses same encoding as VOP2 16-bit: bit 7 selects hi(1) or lo(0) half src0, src1, src2, vdst, neg, abs_, opsel, dst_hi = inst.src0, inst.vsrc1 + 256, None, VCC_LO, 0, 0, 0, False
# vsrc1 field is 8 bits: [6:0] = VGPR index, [7] = hi flag else:
src0, src1, src2, vdst = inst.src0, inst.vsrc1 + 256, None, VCC_LO raise NotImplementedError(f"exec_vop: unhandled instruction type {type(inst).__name__}")
elif isinstance(inst, VOP3P):
# VOP3P: Packed 16-bit operations using compiled functions
# WMMA: wave-level matrix multiply-accumulate (special handling - needs cross-lane access)
if 'WMMA' in inst.op_name:
if lane == 0: # Only execute once per wave, write results for all lanes
exec_wmma(st, inst, inst.op)
return
# V_FMA_MIX: Mixed precision FMA - opsel_hi controls f32(0) vs f16(1), opsel selects which f16 half
if 'FMA_MIX' in inst.op_name:
opsel, opsel_hi, opsel_hi2 = getattr(inst, 'opsel', 0), getattr(inst, 'opsel_hi', 0), getattr(inst, 'opsel_hi2', 0)
neg, abs_ = getattr(inst, 'neg', 0), getattr(inst, 'neg_hi', 0) # neg_hi reused as abs
raws = [st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane) if inst.src2 is not None else 0]
is_f16 = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
srcs = [_f16(_src16(raws[i], bool(opsel & (1<<i)))) if is_f16[i] else _f32(raws[i]) for i in range(3)]
for i in range(3):
if abs_ & (1<<i): srcs[i] = abs(srcs[i])
if neg & (1<<i): srcs[i] = -srcs[i]
result = srcs[0] * srcs[1] + srcs[2]
st.vgpr[lane][inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
return
# VOP3P packed ops: opsel selects halves for lo, opsel_hi for hi; neg toggles f16 sign
raws = [st.rsrc_f16(inst.src0, lane), st.rsrc_f16(inst.src1, lane), st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0]
opsel, opsel_hi, opsel_hi2 = getattr(inst, 'opsel', 0), getattr(inst, 'opsel_hi', 3), getattr(inst, 'opsel_hi2', 1)
neg, neg_hi = getattr(inst, 'neg', 0), getattr(inst, 'neg_hi', 0)
hi_sels = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
srcs = [((_src16(raws[i], hi_sels[i]) ^ (0x8000 if neg_hi & (1<<i) else 0)) << 16) |
(_src16(raws[i], opsel & (1<<i)) ^ (0x8000 if neg & (1<<i) else 0)) for i in range(3)]
result = compiled[VOP3POp][inst.op](Reg(srcs[0]), Reg(srcs[1]), Reg(srcs[2]), Reg(0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)
st.vgpr[lane][inst.vdst] = result['D0']._val & MASK32
return
else: raise NotImplementedError(f"Unknown vector type {type(inst)}")
op_cls = type(inst.op) s0 = _read_src(st, inst, src0, 0, lane, neg, abs_, opsel)
if (fn := compiled.get(op_cls, {}).get(inst.op)) is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode") s1 = _read_src(st, inst, src1, 1, lane, neg, abs_, opsel)
s2 = _read_src(st, inst, src2, 2, lane, neg, abs_, opsel)
if isinstance(inst, VOP2) and inst.is_16bit(): d0 = _src16(V[vdst], dst_hi)
elif inst.dst_regs() == 2: d0 = V[vdst] | (V[vdst + 1] << 32)
else: d0 = V[vdst]
# Read sources (with VOP3 modifiers if applicable) if isinstance(inst, VOP3SD) and 'CO_CI' in inst.op_name: vcc_for_fn = st.rsgpr64(inst.src2)
neg, abs_ = (getattr(inst, 'neg', 0), getattr(inst, 'abs', 0)) if isinstance(inst, VOP3) else (0, 0) elif isinstance(inst, VOP3) and inst.op in (VOP3Op.V_CNDMASK_B32, VOP3Op.V_CNDMASK_B16) and src2 is not None and src2 < 256: vcc_for_fn = st.rsgpr64(src2)
opsel = getattr(inst, 'opsel', 0) if isinstance(inst, VOP3) else 0 else: vcc_for_fn = st.vcc
def mod_src(val: int, idx: int, is64=False) -> int:
to_f, to_i = (_f64, _i64) if is64 else (_f32, _i32)
if (abs_ >> idx) & 1: val = to_i(abs(to_f(val)))
if (neg >> idx) & 1: val = to_i(-to_f(val))
return val
# Use inst methods to determine operand sizes (inst.is_src_16, inst.is_src_64, etc.)
is_vop2_16bit = isinstance(inst, VOP2) and inst.is_16bit()
# Read sources based on register counts and dtypes from inst properties
def read_src(src, idx, regs, is_src_16):
if src is None: return 0
if regs == 2: return mod_src(st.rsrc64(src, lane), idx, is64=True)
if is_src_16 and isinstance(inst, VOP3):
raw = st.rsrc_f16(src, lane) if 128 <= src < 255 else st.rsrc(src, lane)
val = _src16(raw, bool(opsel & (1 << idx)))
if abs_ & (1 << idx): val &= 0x7fff
if neg & (1 << idx): val ^= 0x8000
return val
if is_src_16 and isinstance(inst, (VOP1, VOP2, VOPC)):
if src >= 256: return _src16(mod_src(st.rsrc(_vgpr_masked(src), lane), idx), _vgpr_hi(src))
return mod_src(st.rsrc_f16(src, lane), idx) & 0xffff
return mod_src(st.rsrc(src, lane), idx)
s0 = read_src(src0, 0, inst.src_regs(0), inst.is_src_16(0))
s1 = read_src(src1, 1, inst.src_regs(1), inst.is_src_16(1)) if src1 is not None else 0
s2 = read_src(src2, 2, inst.src_regs(2), inst.is_src_16(2)) if src2 is not None else 0
# Read destination (accumulator for VOP2 f16, 64-bit for 64-bit ops)
d0 = _src16(V[vdst], dst_hi) if is_vop2_16bit else (V[vdst] | (V[vdst + 1] << 32)) if inst.dst_regs() == 2 else V[vdst]
# V_CNDMASK_B32/B16: VOP3 encoding uses src2 as mask (not VCC); VOP2 uses VCC implicitly
# Pass the correct mask as vcc to the function so pseudocode VCC.u64[laneId] works correctly
vcc_for_fn = st.rsgpr64(src2) if inst.op in (VOP3Op.V_CNDMASK_B32, VOP3Op.V_CNDMASK_B16) and isinstance(inst, VOP3) and src2 is not None and src2 < 256 else st.vcc
# Execute compiled function - pass src0_idx and vdst_idx for lane instructions
# For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR)
src0_idx = (src0 - 256) if src0 is not None and src0 >= 256 else (src0 if src0 is not None else 0) src0_idx = (src0 - 256) if src0 is not None and src0 >= 256 else (src0 if src0 is not None else 0)
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(d0), Reg(st.scc), Reg(vcc_for_fn), lane, Reg(st.exec_mask), st.literal, st.vgpr, src0_idx, vdst) extra_kwargs = {'opsel': opsel, 'opsel_hi': inst.opsel_hi | (inst.opsel_hi2 << 2)} if isinstance(inst, VOP3P) and 'FMA_MIX' in inst.op_name else {}
result = inst._fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, inst._literal, st.vgpr, src0_idx, vdst, **extra_kwargs)
# Apply results - extract values from returned Reg objects
if 'vgpr_write' in result:
# Lane instruction wrote to VGPR: (lane, vgpr_idx, value)
wr_lane, wr_idx, wr_val = result['vgpr_write']
st.vgpr[wr_lane][wr_idx] = wr_val
if 'VCC' in result: if 'VCC' in result:
# VOP2 carry ops write to VCC implicitly; VOPC/VOP3 write to vdst if isinstance(inst, VOP3SD): st.pend_sgpr_lane(inst.sdst, lane, (result['VCC'] >> lane) & 1)
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC']._val >> lane) & 1) else: st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC'] >> lane) & 1)
if 'EXEC' in result: if 'EXEC' in result:
# V_CMPX instructions write to EXEC per-lane (not to vdst) st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC'] >> lane) & 1)
st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC']._val >> lane) & 1) elif isinstance(inst.op, VOPCOp):
elif op_cls is VOPCOp: st.pend_sgpr_lane(vdst, lane, (result['D0'] >> lane) & 1)
# VOPC comparison result stored in D0 bitmask, extract lane bit (non-CMPX only) if not isinstance(inst.op, VOPCOp):
st.pend_sgpr_lane(vdst, lane, (result['D0']._val >> lane) & 1) d0_val = result['D0']
if op_cls is not VOPCOp and 'vgpr_write' not in result: if inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
writes_to_sgpr = 'READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name elif not isinstance(inst, VOP3P) and inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
d0_val = result['D0']._val
if writes_to_sgpr: st.wsgpr(vdst, d0_val & MASK32)
elif inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
else: V[vdst] = d0_val & MASK32 else: V[vdst] = d0_val & MASK32
# ═══════════════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════════════
@@ -419,64 +326,102 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None:
else: else:
for i in range(256): st.vgpr[i % 32][vdst + i//32] = _i32(mat_d[i]) for i in range(256): st.vgpr[i % 32][vdst + i//32] = _i32(mat_d[i])
# ═══════════════════════════════════════════════════════════════════════════════
# PROGRAM DECODE
# ═══════════════════════════════════════════════════════════════════════════════
# Wave-level dispatch functions: (st, inst) -> return_code (0 = continue, -1 = end, -2 = barrier)
def dispatch_endpgm(st, inst): return -1
def dispatch_barrier(st, inst): st.pc += inst._words; return -2
def dispatch_nop(st, inst): st.pc += inst._words; return 0
def dispatch_wmma(st, inst): exec_wmma(st, inst, inst.op); st.pc += inst._words; return 0
def dispatch_writelane(st, inst): st.vgpr[st.rsrc(inst.src1, 0, inst._literal) & 0x1f][inst.vdst] = st.rsrc(inst.src0, 0, inst._literal) & MASK32; st.pc += inst._words; return 0
def dispatch_readlane(st, inst):
src0_idx = (inst.src0 - 256) if inst.src0 >= 256 else inst.src0
s1 = st.rsrc(inst.src1, 0, inst._literal) if getattr(inst, 'src1', None) is not None else 0
result = inst._fn(0, s1, 0, 0, st.scc, st.vcc, 0, st.exec_mask, inst._literal, st.vgpr, src0_idx, inst.vdst)
st.wsgpr(inst.vdst, result['D0'])
st.pc += inst._words; return 0
# Per-lane dispatch wrapper: wraps per-lane exec functions into wave-level dispatch
@functools.cache
def dispatch_lane(exec_fn):
def dispatch(st, inst):
exec_mask, vgpr, n_lanes = st.exec_mask, st.vgpr, st.n_lanes
for lane in range(n_lanes):
if exec_mask >> lane & 1: exec_fn(st, inst, vgpr[lane], lane)
st.commit_pends()
st.pc += inst._words
return 0
return dispatch
def decode_program(data: bytes) -> dict[int, Inst]:
result: dict[int, Inst] = {}
i = 0
while i < len(data):
try: inst_class = detect_format(data[i:])
except ValueError: break # stop at invalid instruction (padding/metadata after code)
inst = inst_class.from_bytes(data[i:i+inst_class._size()+8]) # +8 for potential 64-bit literal
inst._words = inst.size() // 4
# Determine dispatch function and pcode function
fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op)
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_ENDPGM: inst._dispatch = dispatch_endpgm
elif isinstance(inst, SOPP) and inst.op == SOPPOp.S_BARRIER: inst._dispatch = dispatch_barrier
elif isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)): inst._dispatch = exec_scalar
elif isinstance(inst, VOP1) and inst.op == VOP1Op.V_NOP: inst._dispatch = dispatch_nop
elif isinstance(inst, VOP3P) and 'WMMA' in inst.op_name: inst._dispatch = dispatch_wmma
elif isinstance(inst, VOP3) and inst.op == VOP3Op.V_WRITELANE_B32: inst._dispatch = dispatch_writelane
elif isinstance(inst, (VOP1, VOP3)) and inst.op in (VOP1Op.V_READFIRSTLANE_B32, VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32): inst._dispatch = dispatch_readlane
elif isinstance(inst, VOPD): inst._dispatch = dispatch_lane(exec_vopd)
elif isinstance(inst, FLAT): inst._dispatch = dispatch_lane(exec_flat)
elif isinstance(inst, DS): inst._dispatch = dispatch_lane(exec_ds)
else: inst._dispatch = dispatch_lane(exec_vop)
# Validate pcode exists for instructions that need it (scalar/wave-level ops and VOPD don't need pcode)
needs_pcode = inst._dispatch not in (dispatch_endpgm, dispatch_barrier, exec_scalar, dispatch_nop, dispatch_wmma,
dispatch_writelane, dispatch_readlane, dispatch_lane(exec_vopd))
if fn is None and inst.op_name and needs_pcode: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
inst._fn = fn if fn else lambda *args, **kwargs: {}
result[i // 4] = inst
i += inst._words * 4
return result
# ═══════════════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════════════
# MAIN EXECUTION LOOP # MAIN EXECUTION LOOP
# ═══════════════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════════════
def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int: def exec_wave(program: dict[int, Inst], st: WaveState) -> int:
inst = program.get(st.pc) while (inst := program.get(st.pc)) and (result := inst._dispatch(st, inst)) == 0: pass
if inst is None: return 1 return result
inst_words, st.literal = inst._words, getattr(inst, '_literal', None) or 0
if isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)): def exec_workgroup(program: dict[int, Inst], workgroup_id: tuple[int, int, int], local_size: tuple[int, int, int], args_ptr: int, rsrc2: int) -> None:
delta = exec_scalar(st, inst)
if delta == -1: return -1 # endpgm
if delta == -2: st.pc += inst_words; return -2 # barrier
st.pc += inst_words + delta
else:
# V_READFIRSTLANE/V_READLANE write to SGPR, execute once; others execute per-lane with exec_mask
is_readlane = isinstance(inst, (VOP1, VOP3)) and ('READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name)
exec_mask = 1 if is_readlane else st.exec_mask
for lane in range(1 if is_readlane else n_lanes):
if exec_mask & (1 << lane): exec_vector(st, inst, lane, lds)
st.commit_pends()
st.pc += inst_words
return 0
def exec_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int:
while st.pc in program:
result = step_wave(program, st, lds, n_lanes)
if result == -1: return 0
if result == -2: return -2
return 0
def exec_workgroup(program: Program, workgroup_id: tuple[int, int, int], local_size: tuple[int, int, int], args_ptr: int,
wg_id_sgpr_base: int, wg_id_enables: tuple[bool, bool, bool]) -> None:
lx, ly, lz = local_size lx, ly, lz = local_size
total_threads, lds = lx * ly * lz, LDSMem(bytearray(65536)) total_threads = lx * ly * lz
waves: list[tuple[WaveState, int, int]] = [] # GRANULATED_LDS_SIZE is in 512-byte units (see ops_amd.py: lds_size = ((group_segment_size + 511) // 512))
lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512
lds = LDSMem(bytearray(lds_size)) if lds_size else None
waves: list[WaveState] = []
for wave_start in range(0, total_threads, WAVE_SIZE): for wave_start in range(0, total_threads, WAVE_SIZE):
n_lanes, st = min(WAVE_SIZE, total_threads - wave_start), WaveState() n_lanes = min(WAVE_SIZE, total_threads - wave_start)
st = WaveState(lds, n_lanes)
st.exec_mask = (1 << n_lanes) - 1 st.exec_mask = (1 << n_lanes) - 1
st.wsgpr64(0, args_ptr) st.wsgpr64(0, args_ptr) # s[0:1] = kernel arguments pointer
# Set workgroup IDs in SGPRs based on USER_SGPR_COUNT and enable flags from COMPUTE_PGM_RSRC2 # COMPUTE_PGM_RSRC2: USER_SGPR_COUNT is where workgroup IDs start, ENABLE_SGPR_WORKGROUP_ID_X/Y/Z control which are passed
sgpr_idx = wg_id_sgpr_base sgpr_idx = (rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT_SHIFT
for wg_id, enabled in zip(workgroup_id, wg_id_enables): if rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_X: st.sgpr[sgpr_idx] = workgroup_id[0]; sgpr_idx += 1
if enabled: st.sgpr[sgpr_idx] = wg_id; sgpr_idx += 1 if rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Y: st.sgpr[sgpr_idx] = workgroup_id[1]; sgpr_idx += 1
# Set workitem IDs in VGPR0 using packed method: v0 = (Z << 20) | (Y << 10) | X if rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Z: st.sgpr[sgpr_idx] = workgroup_id[2]
for i in range(n_lanes): # VGPR0 = packed workitem IDs: (Z << 20) | (Y << 10) | X
tid = wave_start + i for tid in range(wave_start, wave_start + n_lanes):
st.vgpr[i][0] = ((tid // (lx * ly)) << 20) | (((tid // lx) % ly) << 10) | (tid % lx) st.vgpr[tid - wave_start][0] = ((tid // (lx * ly)) << 20) | (((tid // lx) % ly) << 10) | (tid % lx)
waves.append((st, n_lanes, wave_start)) waves.append(st)
has_barrier = any(isinstance(inst, SOPP) and inst.op == SOPPOp.S_BARRIER for inst in program.values()) while waves:
for _ in range(2 if has_barrier else 1): waves = [st for st in waves if exec_wave(program, st) != -1]
for st, n_lanes, _ in waves: exec_wave(program, st, lds, n_lanes)
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c) -> int: def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c) -> int:
program = decode_program((ctypes.c_char * lib_sz).from_address(lib).raw) program = decode_program((ctypes.c_char * lib_sz).from_address(lib).raw)
if not program: return -1
wg_id_enables = tuple(bool((rsrc2 >> (7+i)) & 1) for i in range(3))
for gidz in range(gz): for gidz in range(gz):
for gidy in range(gy): for gidy in range(gy):
for gidx in range(gx): exec_workgroup(program, (gidx, gidy, gidz), (lx, ly, lz), args_ptr, (rsrc2 >> 1) & 0x1f, wg_id_enables) for gidx in range(gx): exec_workgroup(program, (gidx, gidy, gidz), (lx, ly, lz), args_ptr, rsrc2)
return 0 return 0

View File

@@ -1,6 +1,6 @@
# DSL for RDNA3 pseudocode - makes pseudocode expressions work directly as Python # DSL for RDNA3 pseudocode - makes pseudocode expressions work directly as Python
import struct, math import struct, math
from extra.assembly.amd.dsl import MASK32, MASK64, MASK128, _f32, _i32, _sext, _f16, _i16, _f64, _i64 from extra.assembly.amd.dsl import MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
# ═══════════════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════════════
# HELPER FUNCTIONS # HELPER FUNCTIONS
@@ -33,7 +33,9 @@ def _isquietnan(x): return _check_nan_type(x, 1, True) # quiet NaN has quiet bi
def _issignalnan(x): return _check_nan_type(x, 0, False) # signaling NaN has quiet bit = 0 def _issignalnan(x): return _check_nan_type(x, 0, False) # signaling NaN has quiet bit = 0
def _gt_neg_zero(a, b): return (a > b) or (a == 0 and b == 0 and not math.copysign(1, a) < 0 and math.copysign(1, b) < 0) def _gt_neg_zero(a, b): return (a > b) or (a == 0 and b == 0 and not math.copysign(1, a) < 0 and math.copysign(1, b) < 0)
def _lt_neg_zero(a, b): return (a < b) or (a == 0 and b == 0 and math.copysign(1, a) < 0 and not math.copysign(1, b) < 0) def _lt_neg_zero(a, b): return (a < b) or (a == 0 and b == 0 and math.copysign(1, a) < 0 and not math.copysign(1, b) < 0)
def _fma(a, b, c): return a * b + c def _fma(a, b, c):
try: return math.fma(a, b, c)
except ValueError: return float('nan') # inf * 0 + c is NaN per IEEE 754
def _signext(v): return v def _signext(v): return v
def _fpop(fn): def _fpop(fn):
def wrapper(x): def wrapper(x):
@@ -269,31 +271,6 @@ ROUND_MODE = _RoundMode()
def cvtToQuietNAN(x): return float('nan') def cvtToQuietNAN(x): return float('nan')
DST = None # Placeholder, will be set in context DST = None # Placeholder, will be set in context
# 2/PI with 1201 bits of precision for V_TRIG_PREOP_F64
# Computed as: int((2/pi) * 2^1201) - this is the fractional part of 2/pi scaled to integer
# The MSB (bit 1200) corresponds to 2^0 position in the fraction 0.b1200 b1199 ... b1 b0
_TWO_OVER_PI_1201_RAW = 0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6
class _BigInt:
"""Wrapper for large integers that supports bit slicing [high:low]."""
__slots__ = ('_val',)
def __init__(self, val): self._val = val
def __getitem__(self, key):
if isinstance(key, slice):
high, low = key.start, key.stop
if high < low: high, low = low, high # Handle reversed slice
mask = (1 << (high - low + 1)) - 1
return (self._val >> low) & mask
return (self._val >> key) & 1
def __int__(self): return self._val
def __index__(self): return self._val
def __lshift__(self, n): return self._val << int(n)
def __rshift__(self, n): return self._val >> int(n)
def __and__(self, n): return self._val & int(n)
def __or__(self, n): return self._val | int(n)
TWO_OVER_PI_1201 = _BigInt(_TWO_OVER_PI_1201_RAW)
class _WaveMode: class _WaveMode:
IEEE = False IEEE = False
WAVE_MODE = _WaveMode() WAVE_MODE = _WaveMode()
@@ -312,14 +289,16 @@ class _Denorm:
f64 = _DenormChecker(64) f64 = _DenormChecker(64)
DENORM = _Denorm() DENORM = _Denorm()
class SliceProxy: class TypedView:
"""Proxy for D0[31:16] that supports .f16/.u16 etc getters and setters.""" """View into a Reg with typed access. Used for both full-width (Reg.u32) and slices (Reg[31:16])."""
__slots__ = ('_reg', '_high', '_low', '_reversed') __slots__ = ('_reg', '_high', '_low', '_signed', '_float', '_bf16', '_reversed')
def __init__(self, reg, high, low): def __init__(self, reg, high, low=0, signed=False, is_float=False, is_bf16=False):
self._reg = reg
# Handle reversed slices like [0:31] which means bit-reverse # Handle reversed slices like [0:31] which means bit-reverse
if high < low: self._high, self._low, self._reversed = low, high, True if high < low: high, low, reversed = low, high, True
else: self._high, self._low, self._reversed = high, low, False else: reversed = False
self._reg, self._high, self._low, self._reversed = reg, high, low, reversed
self._signed, self._float, self._bf16 = signed, is_float, is_bf16
def _nbits(self): return self._high - self._low + 1 def _nbits(self): return self._high - self._low + 1
def _mask(self): return (1 << self._nbits()) - 1 def _mask(self): return (1 << self._nbits()) - 1
def _get(self): def _get(self):
@@ -330,6 +309,12 @@ class SliceProxy:
if self._reversed: v = _brev(v, self._nbits()) if self._reversed: v = _brev(v, self._nbits())
self._reg._val = (self._reg._val & ~(self._mask() << self._low)) | ((v & self._mask()) << self._low) self._reg._val = (self._reg._val & ~(self._mask() << self._low)) | ((v & self._mask()) << self._low)
@property
def _val(self): return self._get()
@property
def _bits(self): return self._nbits()
# Type accessors for slices (e.g., D0[31:16].f16)
u8 = property(lambda s: s._get() & 0xff) u8 = property(lambda s: s._get() & 0xff)
u16 = property(lambda s: s._get() & 0xffff, lambda s, v: s._set(v)) u16 = property(lambda s: s._get() & 0xffff, lambda s, v: s._set(v))
u32 = property(lambda s: s._get() & MASK32, lambda s, v: s._set(v)) u32 = property(lambda s: s._get() & MASK32, lambda s, v: s._set(v))
@@ -340,33 +325,17 @@ class SliceProxy:
bf16 = property(lambda s: _bf16(s._get()), lambda s, v: s._set(v if isinstance(v, int) else _ibf16(float(v)))) bf16 = property(lambda s: _bf16(s._get()), lambda s, v: s._set(v if isinstance(v, int) else _ibf16(float(v))))
b16, b32 = u16, u32 b16, b32 = u16, u32
def __int__(self): return self._get() # Chained type access (e.g., jump_addr.i64 when jump_addr is already TypedView)
def __index__(self): return self._get()
# Comparison operators (compare as integers)
def __eq__(s, o): return s._get() == int(o)
def __ne__(s, o): return s._get() != int(o)
def __lt__(s, o): return s._get() < int(o)
def __le__(s, o): return s._get() <= int(o)
def __gt__(s, o): return s._get() > int(o)
def __ge__(s, o): return s._get() >= int(o)
class TypedView:
"""View for S0.u32 that supports [4:0] slicing and [bit] access."""
__slots__ = ('_reg', '_bits', '_signed', '_float', '_bf16')
def __init__(self, reg, bits, signed=False, is_float=False, is_bf16=False):
self._reg, self._bits, self._signed, self._float, self._bf16 = reg, bits, signed, is_float, is_bf16
@property @property
def _val(self): def i64(s): return s if s._nbits() == 64 and s._signed else int(s)
mask = MASK64 if self._bits == 64 else MASK32 if self._bits == 32 else (1 << self._bits) - 1 @property
return self._reg._val & mask def u64(s): return s if s._nbits() == 64 and not s._signed else int(s) & MASK64
def __getitem__(self, key): def __getitem__(self, key):
if isinstance(key, slice): if isinstance(key, slice):
high, low = int(key.start), int(key.stop) high, low = int(key.start), int(key.stop)
return SliceProxy(self._reg, high, low) return TypedView(self._reg, high, low)
return (self._val >> int(key)) & 1 return (self._get() >> int(key)) & 1
def __setitem__(self, key, value): def __setitem__(self, key, value):
if isinstance(key, slice): if isinstance(key, slice):
@@ -377,14 +346,16 @@ class TypedView:
elif value: self._reg._val |= (1 << int(key)) elif value: self._reg._val |= (1 << int(key))
else: self._reg._val &= ~(1 << int(key)) else: self._reg._val &= ~(1 << int(key))
def __int__(self): return _sext(self._val, self._bits) if self._signed else self._val def __int__(self): return _sext(self._get(), self._nbits()) if self._signed else self._get()
def __index__(self): return int(self) def __index__(self): return int(self)
def __trunc__(self): return int(float(self)) if self._float else int(self) def __trunc__(self): return int(float(self)) if self._float else int(self)
def __float__(self): def __float__(self):
if self._float: if self._float:
if self._bf16: return _bf16(self._val) # bf16 uses different conversion if self._bf16: return _bf16(self._get())
return _f16(self._val) if self._bits == 16 else _f32(self._val) if self._bits == 32 else _f64(self._val) bits = self._nbits()
return _f16(self._get()) if bits == 16 else _f32(self._get()) if bits == 32 else _f64(self._get())
return float(int(self)) return float(int(self))
def __bool__(s): return bool(int(s))
# Arithmetic - floats use float(), ints use int() # Arithmetic - floats use float(), ints use int()
def __add__(s, o): return float(s) + float(o) if s._float else int(s) + int(o) def __add__(s, o): return float(s) + float(o) if s._float else int(s) + int(o)
@@ -405,8 +376,8 @@ class TypedView:
def __or__(s, o): return int(s) | int(o) def __or__(s, o): return int(s) | int(o)
def __xor__(s, o): return int(s) ^ int(o) def __xor__(s, o): return int(s) ^ int(o)
def __invert__(s): return ~int(s) def __invert__(s): return ~int(s)
def __lshift__(s, o): n = int(o); return int(s) << n if 0 <= n < 64 else 0 def __lshift__(s, o): n = int(o); return int(s) << n if 0 <= n < 64 or s._nbits() > 64 else 0
def __rshift__(s, o): n = int(o); return int(s) >> n if 0 <= n < 64 else 0 def __rshift__(s, o): n = int(o); return int(s) >> n if 0 <= n < 64 or s._nbits() > 64 else 0
def __rand__(s, o): return int(o) & int(s) def __rand__(s, o): return int(o) & int(s)
def __ror__(s, o): return int(o) | int(s) def __ror__(s, o): return int(o) | int(s)
def __rxor__(s, o): return int(o) ^ int(s) def __rxor__(s, o): return int(o) ^ int(s)
@@ -425,51 +396,42 @@ class TypedView:
def __gt__(s, o): return float(s) > float(o) if s._float else int(s) > int(o) def __gt__(s, o): return float(s) > float(o) if s._float else int(s) > int(o)
def __ge__(s, o): return float(s) >= float(o) if s._float else int(s) >= int(o) def __ge__(s, o): return float(s) >= float(o) if s._float else int(s) >= int(o)
def __bool__(s): return bool(int(s)) SliceProxy = TypedView # Alias for compatibility
# Allow chained type access like jump_addr.i64 when jump_addr is already a TypedView
# These just return self or convert appropriately
@property
def i64(s): return s if s._bits == 64 and s._signed else int(s)
@property
def u64(s): return s if s._bits == 64 and not s._signed else int(s) & MASK64
@property
def i32(s): return s if s._bits == 32 and s._signed else _sext(int(s) & MASK32, 32)
@property
def u32(s): return s if s._bits == 32 and not s._signed else int(s) & MASK32
class Reg: class Reg:
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works. Supports up to 128 bits for DS_LOAD_B128.""" """GPU register: D0.f32 = S0.f32 + S1.f32 just works. Supports up to 128 bits for DS_LOAD_B128."""
__slots__ = ('_val',) __slots__ = ('_val',)
def __init__(self, val=0): self._val = int(val) & MASK128 def __init__(self, val=0): self._val = int(val)
# Typed views # Typed views - TypedView(reg, high, signed, is_float, is_bf16)
u64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64)) u64 = property(lambda s: TypedView(s, 63), lambda s, v: setattr(s, '_val', int(v) & MASK64))
i64 = property(lambda s: TypedView(s, 64, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK64)) i64 = property(lambda s: TypedView(s, 63, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK64))
b64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64)) b64 = property(lambda s: TypedView(s, 63), lambda s, v: setattr(s, '_val', int(v) & MASK64))
f64 = property(lambda s: TypedView(s, 64, is_float=True), lambda s, v: setattr(s, '_val', v if isinstance(v, int) else _i64(float(v)))) f64 = property(lambda s: TypedView(s, 63, is_float=True), lambda s, v: setattr(s, '_val', v if isinstance(v, int) else _i64(float(v))))
u32 = property(lambda s: TypedView(s, 32), lambda s, v: setattr(s, '_val', int(v) & MASK32)) u32 = property(lambda s: TypedView(s, 31), lambda s, v: setattr(s, '_val', int(v) & MASK32))
i32 = property(lambda s: TypedView(s, 32, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK32)) i32 = property(lambda s: TypedView(s, 31, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK32))
b32 = property(lambda s: TypedView(s, 32), lambda s, v: setattr(s, '_val', int(v) & MASK32)) b32 = property(lambda s: TypedView(s, 31), lambda s, v: setattr(s, '_val', int(v) & MASK32))
f32 = property(lambda s: TypedView(s, 32, is_float=True), lambda s, v: setattr(s, '_val', _i32(float(v)))) f32 = property(lambda s: TypedView(s, 31, is_float=True), lambda s, v: setattr(s, '_val', _i32(float(v))))
u24 = property(lambda s: TypedView(s, 24)) u24 = property(lambda s: TypedView(s, 23))
i24 = property(lambda s: TypedView(s, 24, signed=True)) i24 = property(lambda s: TypedView(s, 23, signed=True))
u16 = property(lambda s: TypedView(s, 16), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff))) u16 = property(lambda s: TypedView(s, 15), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
i16 = property(lambda s: TypedView(s, 16, signed=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff))) i16 = property(lambda s: TypedView(s, 15, signed=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
b16 = property(lambda s: TypedView(s, 16), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff))) b16 = property(lambda s: TypedView(s, 15), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
f16 = property(lambda s: TypedView(s, 16, is_float=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _i16(float(v))) & 0xffff))) f16 = property(lambda s: TypedView(s, 15, is_float=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _i16(float(v))) & 0xffff)))
bf16 = property(lambda s: TypedView(s, 16, is_float=True, is_bf16=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _ibf16(float(v))) & 0xffff))) bf16 = property(lambda s: TypedView(s, 15, is_float=True, is_bf16=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _ibf16(float(v))) & 0xffff)))
u8 = property(lambda s: TypedView(s, 8)) u8 = property(lambda s: TypedView(s, 7))
i8 = property(lambda s: TypedView(s, 8, signed=True)) i8 = property(lambda s: TypedView(s, 7, signed=True))
u1 = property(lambda s: TypedView(s, 1)) # single bit u3 = property(lambda s: TypedView(s, 2)) # 3-bit for opsel fields
u1 = property(lambda s: TypedView(s, 0)) # single bit
def __getitem__(s, key): def __getitem__(s, key):
if isinstance(key, slice): return SliceProxy(s, int(key.start), int(key.stop)) if isinstance(key, slice): return TypedView(s, int(key.start), int(key.stop))
return (s._val >> int(key)) & 1 return (s._val >> int(key)) & 1
def __setitem__(s, key, value): def __setitem__(s, key, value):
if isinstance(key, slice): if isinstance(key, slice):
high, low = int(key.start), int(key.stop) high, low = int(key.start), int(key.stop)
if high < low: high, low = low, high
mask = (1 << (high - low + 1)) - 1 mask = (1 << (high - low + 1)) - 1
s._val = (s._val & ~(mask << low)) | ((int(value) & mask) << low) s._val = (s._val & ~(mask << low)) | ((int(value) & mask) << low)
elif value: s._val |= (1 << int(key)) elif value: s._val |= (1 << int(key))
@@ -504,4 +466,5 @@ class Reg:
def __eq__(s, o): return s._val == int(o) def __eq__(s, o): return s._val == int(o)
def __ne__(s, o): return s._val != int(o) def __ne__(s, o): return s._val != int(o)
# 2/PI with 1201 bits of precision for V_TRIG_PREOP_F64
TWO_OVER_PI_1201 = Reg(0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6)

View File

@@ -42,7 +42,7 @@ INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+|(?:FLAT|GLOBAL|SCRATCH)_[A-Z0-9
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS', UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', 'ThreadMask', 'CVT_OFF_TABLE', 'ThreadMask',
'S1[i', 'C.i32', 'S[i]', 'in[', 'S1[i', 'C.i32', 'thread_',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST', 'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
'BARRIER_STATE', 'ReallocVgprs', 'BARRIER_STATE', 'ReallocVgprs',
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL', 'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
@@ -477,7 +477,7 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
# Get op enums for this arch (import from .ins which re-exports from .enum) # Get op enums for this arch (import from .ins which re-exports from .enum)
import importlib import importlib
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins") autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp', 'FLATOp', 'GLOBALOp', 'SCRATCHOp'] if hasattr(autogen, name)] OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'SMEMOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp', 'FLATOp', 'GLOBALOp', 'SCRATCHOp'] if hasattr(autogen, name)]
# Build defined ops mapping # Build defined ops mapping
defined_ops: dict[tuple, list] = {} defined_ops: dict[tuple, list] = {}
@@ -513,20 +513,10 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
for op, fn_name in fn_entries: fn_lines.append(f" {cls_name}.{op.name}: {fn_name},") for op, fn_name in fn_entries: fn_lines.append(f" {cls_name}.{op.name}: {fn_name},")
fn_lines.append('}\n') fn_lines.append('}\n')
# Add V_WRITELANE_B32 if VOP3Op exists
if 'VOP3Op' in enum_names:
fn_lines.append('''
# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode)
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):
wr_lane = s1 & 0x1f
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
''')
fn_lines.append('COMPILED_FUNCTIONS = {') fn_lines.append('COMPILED_FUNCTIONS = {')
for enum_cls in OP_ENUMS: for enum_cls in OP_ENUMS:
if all_fn_entries.get(enum_cls): fn_lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_FUNCTIONS,') if all_fn_entries.get(enum_cls): fn_lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_FUNCTIONS,')
fn_lines.append('}\n\ndef get_compiled_functions(): return COMPILED_FUNCTIONS') fn_lines.append('}')
# Second pass: scan generated code for pcode imports # Second pass: scan generated code for pcode imports
fn_code_str = '\n'.join(fn_lines) fn_code_str = '\n'.join(fn_lines)
@@ -576,83 +566,102 @@ def _apply_pseudocode_fixes(op, code: str) -> str:
return code return code
def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]: def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]:
"""Generate a single compiled pseudocode function.""" """Generate a single compiled pseudocode function.
Functions take int parameters and return dict of int values.
Reg wrapping happens inside the function, only for registers actually used."""
has_d1 = '{ D1' in pc has_d1 = '{ D1' in pc
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
is_div_scale = 'DIV_SCALE' in op.name is_div_scale = 'DIV_SCALE' in op.name
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale) has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
is_ds = cls_name == 'DSOp' is_ds = cls_name == 'DSOp'
is_flat = cls_name in ('FLATOp', 'GLOBALOp', 'SCRATCHOp') is_flat = cls_name in ('FLATOp', 'GLOBALOp', 'SCRATCHOp')
is_smem = cls_name == 'SMEMOp'
has_s_array = 'S[i]' in pc # FMA_MIX style: S[0], S[1], S[2] array access
combined = code + pc combined = code + pc
fn_name = f"_{cls_name}_{op.name}" fn_name = f"_{cls_name}_{op.name}"
# Function accepts Reg objects directly (uppercase names), laneId is passed directly as int
# DSOp functions get additional MEM and offset parameters
# FLAT/GLOBAL ops get MEM, vaddr, vdata, saddr, offset parameters
if is_ds:
lines = [f"def {fn_name}(MEM, ADDR, DATA0, DATA1, OFFSET0, OFFSET1, RETURN_DATA):"]
elif is_flat:
lines = [f"def {fn_name}(MEM, ADDR, VDATA, VDST, RETURN_DATA):"]
else:
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
# Registers that need special handling (aliases or init) # Detect which registers are used/modified
def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE) def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
special_regs = []
if is_ds: special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
elif is_flat: special_regs = [('DATA', 'VDATA')]
else:
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
used = {name for name, _ in special_regs if name in combined}
# Detect which registers are modified (not just read) - look for assignments
modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined)) modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined))
modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined)) modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined))
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined)) modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
modifies_scc = bool(re.search(r'\bSCC\s*=', combined)) modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
modifies_pc = bool(re.search(r'\bPC\s*=', combined)) modifies_pc = bool(re.search(r'\bPC\s*=', combined))
# DS/FLAT ops: detect memory writes (MEM[...] = ...)
modifies_mem = (is_ds or is_flat) and bool(re.search(r'MEM\[.*\]\.[a-z0-9]+\s*=', combined))
# FLAT ops: detect VDST writes
modifies_vdst = is_flat and bool(re.search(r'VDST[\.\[].*=', combined))
# Build init code for special registers # Build function signature and Reg init lines
init_lines = [] if is_smem:
if is_div_scale: init_lines.append(" D0 = Reg(S0._val)") lines = [f"def {fn_name}(MEM, addr):"]
reg_inits = ["ADDR=Reg(addr)", "SDATA=Reg(0)"]
special_regs = []
elif is_ds:
lines = [f"def {fn_name}(MEM, addr, data0, data1, offset0, offset1):"]
reg_inits = ["ADDR=Reg(addr)", "DATA0=Reg(data0)", "DATA1=Reg(data1)", "OFFSET0=Reg(offset0)", "OFFSET1=Reg(offset1)", "RETURN_DATA=Reg(0)"]
special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
elif is_flat:
lines = [f"def {fn_name}(MEM, addr, vdata, vdst):"]
reg_inits = ["ADDR=addr", "VDATA=Reg(vdata)", "VDST=Reg(vdst)", "RETURN_DATA=Reg(0)"]
special_regs = [('DATA', 'VDATA')]
elif has_s_array:
# FMA_MIX style: needs S[i] array, opsel, opsel_hi for source selection (neg/neg_hi applied in emu.py before call)
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None, opsel=0, opsel_hi=0):"]
reg_inits = ["S0=Reg(s0)", "S1=Reg(s1)", "S2=Reg(s2)", "S=[S0,S1,S2]", "D0=Reg(d0)", "OPSEL=Reg(opsel)", "OPSEL_HI=Reg(opsel_hi)"]
special_regs = []
# Detect array declarations like "declare in : 32'F[3]" and create them (rename 'in' to 'ins' since 'in' is a keyword)
if "in[" in combined:
reg_inits.append("ins=[Reg(0),Reg(0),Reg(0)]")
code = code.replace("in[", "ins[")
else:
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):"]
# Only create Regs for registers actually used in the pseudocode
reg_inits = []
if 'S0' in combined: reg_inits.append("S0=Reg(s0)")
if 'S1' in combined: reg_inits.append("S1=Reg(s1)")
if 'S2' in combined: reg_inits.append("S2=Reg(s2)")
if modifies_d0 or 'D0' in combined: reg_inits.append("D0=Reg(s0)" if is_div_scale else "D0=Reg(d0)")
if modifies_scc or 'SCC' in combined: reg_inits.append("SCC=Reg(scc)")
if modifies_vcc or 'VCC' in combined: reg_inits.append("VCC=Reg(vcc)")
if modifies_exec or 'EXEC' in combined: reg_inits.append("EXEC=Reg(exec_mask)")
if modifies_pc or 'PC' in combined: reg_inits.append("PC=Reg(pc) if pc is not None else None")
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
# Build init code
init_parts = reg_inits.copy()
for name, init in special_regs: for name, init in special_regs:
if name in used: init_lines.append(f" {name} = {init}") if name in combined: init_parts.append(f"{name}={init}")
if 'EXEC_LO' in code: init_lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)") if 'EXEC_LO' in code: init_parts.append("EXEC_LO=SliceProxy(EXEC, 31, 0)")
if 'EXEC_HI' in code: init_lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)") if 'EXEC_HI' in code: init_parts.append("EXEC_HI=SliceProxy(EXEC, 63, 32)")
if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)") if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_parts.append("VCCZ=Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)") if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_parts.append("EXECZ=Reg(1 if EXEC._val == 0 else 0)")
code_lines = [line for line in code.split('\n') if line.strip()]
if init_lines:
lines.extend(init_lines)
if code_lines: lines.append(" # --- compiled pseudocode ---")
for line in code_lines:
lines.append(f" {line}")
# Build result dict - only include registers that are modified # Add init line and separator
if init_parts: lines.append(f" {'; '.join(init_parts)}")
lines.append(" # --- compiled pseudocode ---")
# Add compiled pseudocode
for line in code.split('\n'):
if line.strip(): lines.append(f" {line}")
# Build result dict
result_items = [] result_items = []
if modifies_d0: result_items.append("'D0': D0") if modifies_d0: result_items.append("'D0': D0._val")
if modifies_scc: result_items.append("'SCC': SCC") if modifies_scc: result_items.append("'SCC': SCC._val")
if modifies_vcc: result_items.append("'VCC': VCC") if modifies_vcc: result_items.append("'VCC': VCC._val")
if modifies_exec: result_items.append("'EXEC': EXEC") if modifies_exec: result_items.append("'EXEC': EXEC._val")
if has_d1: result_items.append("'D1': D1") if has_d1: result_items.append("'D1': D1._val")
if modifies_pc: result_items.append("'PC': PC") if modifies_pc: result_items.append("'PC': PC._val")
# DS ops: return RETURN_DATA if it was written (left side of assignment) if is_smem and 'SDATA' in combined and re.search(r'^\s*SDATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'SDATA': SDATA._val")
if is_ds and 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE): if is_ds and 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'RETURN_DATA': RETURN_DATA") result_items.append("'RETURN_DATA': RETURN_DATA._val")
# FLAT ops: return RETURN_DATA for atomics, VDATA for loads (only if written to)
if is_flat: if is_flat:
if 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE): if 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'RETURN_DATA': RETURN_DATA") result_items.append("'RETURN_DATA': RETURN_DATA._val")
if re.search(r'^\s*VDATA[\.\[].*=', code, re.MULTILINE): if re.search(r'^\s*VDATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'VDATA': VDATA") result_items.append("'VDATA': VDATA._val")
lines.append(f" return {{{', '.join(result_items)}}}\n") lines.append(f" return {{{', '.join(result_items)}}}\n")
return fn_name, '\n'.join(lines) return fn_name, '\n'.join(lines)

View File

@@ -1,13 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Benchmark comparing Python vs Rust RDNA3 emulators on synthetic and real tinygrad kernels.""" """Benchmark comparing Python vs Rust RDNA3 emulators on real tinygrad kernels."""
import ctypes, time, os, struct, cProfile, pstats, io import ctypes, time, os
from pathlib import Path from pathlib import Path
from typing import Callable
# Set AMD=1 before importing tinygrad # Set AMD=1 before importing tinygrad
os.environ["AMD"] = "1" os.environ["AMD"] = "1"
from extra.assembly.amd.emu import run_asm as python_run_asm, set_valid_mem_ranges, decode_program, step_wave, WaveState, WAVE_SIZE from extra.assembly.amd.emu import run_asm as python_run_asm, set_valid_mem_ranges, decode_program
REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so" REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so"
if not REMU_PATH.exists(): if not REMU_PATH.exists():
@@ -42,7 +41,7 @@ def setup_buffers(buf_sizes: list[int], init_data: dict[int, bytes] | None = Non
ranges.add((args_ptr, ctypes.sizeof(args))) ranges.add((args_ptr, ctypes.sizeof(args)))
return buffers, args, args_ptr, ranges return buffers, args, args_ptr, ranges
def benchmark_emulator(name: str, run_fn, kernel: bytes, global_size, local_size, args_ptr, iterations: int = 5): def benchmark_emulator(name: str, run_fn, kernel: bytes, global_size, local_size, args_ptr, rsrc2: int, iterations: int = 5):
"""Benchmark an emulator and return average time.""" """Benchmark an emulator and return average time."""
gx, gy, gz = global_size gx, gy, gz = global_size
lx, ly, lz = local_size lx, ly, lz = local_size
@@ -50,13 +49,13 @@ def benchmark_emulator(name: str, run_fn, kernel: bytes, global_size, local_size
lib_ptr = ctypes.addressof(kernel_buf) lib_ptr = ctypes.addressof(kernel_buf)
# Warmup # Warmup
run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr) run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr, rsrc2)
# Timed runs # Timed runs
times = [] times = []
for _ in range(iterations): for _ in range(iterations):
start = time.perf_counter() start = time.perf_counter()
result = run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr) result = run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr, rsrc2)
end = time.perf_counter() end = time.perf_counter()
if result != 0: if result != 0:
print(f" {name} returned error: {result}") print(f" {name} returned error: {result}")
@@ -65,27 +64,12 @@ def benchmark_emulator(name: str, run_fn, kernel: bytes, global_size, local_size
return sum(times) / len(times) return sum(times) / len(times)
def create_synthetic_kernel(n_ops: int) -> bytes: def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], dict[int, bytes], int] | None:
"""Create a synthetic kernel with n_ops vector operations.""" """Get a real tinygrad kernel by operation name. Returns (code, global_size, local_size, buf_sizes, buf_data, rsrc2)."""
instructions = []
# VOP2 instructions: v_add_f32, v_mul_f32, v_max_f32, v_min_f32
ops = [
(0b0000011 << 25) | (1 << 17) | (0 << 9) | 256, # v_add_f32 v0, v0, v1
(0b0001000 << 25) | (1 << 17) | (0 << 9) | 256, # v_mul_f32 v0, v0, v1
(0b0010000 << 25) | (1 << 17) | (0 << 9) | 256, # v_max_f32 v0, v0, v1
(0b0001111 << 25) | (1 << 17) | (0 << 9) | 256, # v_min_f32 v0, v0, v1
]
for i in range(n_ops):
instructions.append(ops[i % len(ops)])
# S_ENDPGM
instructions.append((0b101111111 << 23) | (48 << 16) | 0)
return b''.join(struct.pack('<I', inst) for inst in instructions)
def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], dict[int, bytes]] | None:
"""Get a real tinygrad kernel by operation name. Returns (code, global_size, local_size, buf_sizes, buf_data)."""
try: try:
from tinygrad import Tensor from tinygrad import Tensor
from tinygrad.runtime.support.elf import elf_loader from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.autogen import hsa
import numpy as np import numpy as np
np.random.seed(42) np.random.seed(42)
@@ -112,7 +96,9 @@ def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], d
lowered = ei.lower() lowered = ei.lower()
if ei.ast.op.name == 'SINK' and lowered.prg and lowered.prg.p.lib: if ei.ast.op.name == 'SINK' and lowered.prg and lowered.prg.p.lib:
lib = bytes(lowered.prg.p.lib) lib = bytes(lowered.prg.p.lib)
image = memoryview(bytearray(lib))
_, sections, _ = elf_loader(lib) _, sections, _ = elf_loader(lib)
rodata_entry = next((sh.header.sh_addr for sh in sections if sh.name == ".rodata"), -1)
for sec in sections: for sec in sections:
if sec.name == '.text': if sec.name == '.text':
buf_sizes = [b.nbytes for b in lowered.bufs] buf_sizes = [b.nbytes for b in lowered.bufs]
@@ -122,67 +108,22 @@ def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], d
if hasattr(buf, 'base') and buf.base is not None and hasattr(buf.base, '_buf'): if hasattr(buf, 'base') and buf.base is not None and hasattr(buf.base, '_buf'):
try: buf_data[i] = bytes(buf.base._buf) try: buf_data[i] = bytes(buf.base._buf)
except: pass except: pass
return (bytes(sec.content), tuple(lowered.prg.p.global_size), tuple(lowered.prg.p.local_size), buf_sizes, buf_data) # Extract rsrc2 from ELF (same as ops_amd.py)
group_segment_size = image[rodata_entry:rodata_entry+4].cast("I")[0]
lds_size = ((group_segment_size + 511) // 512) & 0x1FF
code = hsa.amd_kernel_code_t.from_buffer_copy(bytes(image[rodata_entry:rodata_entry+256]) + b'\x00'*256)
rsrc2 = code.compute_pgm_rsrc2 | (lds_size << 15)
return (bytes(sec.content), tuple(lowered.prg.p.global_size), tuple(lowered.prg.p.local_size), buf_sizes, buf_data, rsrc2)
return None return None
except Exception as e: except Exception as e:
print(f" Error getting kernel: {e}") print(f" Error getting kernel: {e}")
return None return None
def profile_python_emu(kernel: bytes, global_size, local_size, args_ptr, n_runs: int = 1):
"""Profile the Python emulator to find bottlenecks."""
gx, gy, gz = global_size
lx, ly, lz = local_size
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
lib_ptr = ctypes.addressof(kernel_buf)
pr = cProfile.Profile()
pr.enable()
for _ in range(n_runs):
python_run_asm(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr)
pr.disable()
s = io.StringIO()
ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
ps.print_stats(20)
return s.getvalue()
def measure_step_rate(kernel: bytes, n_steps: int = 10000) -> float:
"""Measure raw step_wave() performance (steps per second)."""
program = decode_program(kernel)
if not program: return 0.0
st = WaveState()
st.exec_mask = 0xffffffff
lds = bytearray(65536)
n_lanes = 32
# Reset PC for each measurement
start = time.perf_counter()
for _ in range(n_steps):
st.pc = 0
while st.pc in program:
result = step_wave(program, st, lds, n_lanes)
if result == -1: break
elapsed = time.perf_counter() - start
return n_steps / elapsed if elapsed > 0 else 0
# Test configurations
SYNTHETIC_TESTS = [
("synthetic_10ops", 10, (1, 1, 1), (32, 1, 1)),
("synthetic_100ops", 100, (1, 1, 1), (32, 1, 1)),
("synthetic_500ops", 500, (1, 1, 1), (32, 1, 1)),
("synthetic_100ops_4wg", 100, (4, 1, 1), (32, 1, 1)),
("synthetic_100ops_16wg", 100, (16, 1, 1), (32, 1, 1)),
]
TINYGRAD_TESTS = ["add", "mul", "reduce_sum", "softmax", "exp", "gelu", "matmul_small"] TINYGRAD_TESTS = ["add", "mul", "reduce_sum", "softmax", "exp", "gelu", "matmul_small"]
def main(): def main():
import argparse import argparse
parser = argparse.ArgumentParser(description="Benchmark RDNA3 emulators") parser = argparse.ArgumentParser(description="Benchmark RDNA3 emulators")
parser.add_argument("--profile", action="store_true", help="Profile Python emulator")
parser.add_argument("--synthetic-only", action="store_true", help="Only run synthetic tests")
parser.add_argument("--tinygrad-only", action="store_true", help="Only run tinygrad tests")
parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per benchmark") parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per benchmark")
args = parser.parse_args() args = parser.parse_args()
@@ -197,98 +138,55 @@ def main():
results = [] results = []
# Synthetic workloads print("\n[TINYGRAD KERNELS]")
if not args.tinygrad_only: print("-" * 90)
print("\n[SYNTHETIC WORKLOADS]")
print("-" * 90)
for name, n_ops, global_size, local_size in SYNTHETIC_TESTS: for op_name in TINYGRAD_TESTS:
kernel = create_synthetic_kernel(n_ops) print(f"\n{op_name}:", end=" ", flush=True)
n_insts = count_instructions(kernel) kernel_info = get_tinygrad_kernel(op_name)
n_workgroups = global_size[0] * global_size[1] * global_size[2] if kernel_info is None:
n_threads = local_size[0] * local_size[1] * local_size[2] print("failed to compile")
total_work = n_insts * n_workgroups * n_threads continue
print(f"\n{name}: {n_insts} insts × {n_workgroups} WGs × {n_threads} threads = {total_work:,} ops") kernel, global_size, local_size, buf_sizes, buf_data, rsrc2 = kernel_info
n_insts = count_instructions(kernel)
n_workgroups = global_size[0] * global_size[1] * global_size[2]
n_threads = local_size[0] * local_size[1] * local_size[2]
total_work = n_insts * n_workgroups * n_threads
buf_sizes = [4096] print(f"{n_insts} insts × {n_workgroups} WGs × {n_threads} threads = {total_work:,} ops")
buffers, args_arr, args_ptr, ranges = setup_buffers(buf_sizes)
set_valid_mem_ranges(ranges)
# Benchmark buffers, args_arr, args_ptr, ranges = setup_buffers(buf_sizes, buf_data)
py_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, args.iterations) set_valid_mem_ranges(ranges)
rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size, args_ptr, args.iterations) if rust_remu else None
if py_time: py_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations)
py_rate = total_work / py_time / 1e6 rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations) if rust_remu else None
print(f" Python: {py_time*1000:8.3f} ms ({py_rate:7.2f} M ops/s)")
if rust_time:
rust_rate = total_work / rust_time / 1e6
speedup = py_time / rust_time if py_time else 0
print(f" Rust: {rust_time*1000:8.3f} ms ({rust_rate:7.2f} M ops/s) [{speedup:.1f}x faster]")
results.append(("synthetic", name, n_insts, n_workgroups, py_time, rust_time)) if py_time:
py_rate = total_work / py_time / 1e6
print(f" Python: {py_time*1000:8.3f} ms ({py_rate:7.2f} M ops/s)")
if rust_time:
rust_rate = total_work / rust_time / 1e6
speedup = py_time / rust_time if py_time else 0
print(f" Rust: {rust_time*1000:8.3f} ms ({rust_rate:7.2f} M ops/s) [{speedup:.1f}x faster]")
# Tinygrad kernels results.append((op_name, n_insts, n_workgroups, py_time, rust_time))
if not args.synthetic_only:
print("\n[TINYGRAD KERNELS]")
print("-" * 90)
for op_name in TINYGRAD_TESTS:
print(f"\n{op_name}:", end=" ", flush=True)
kernel_info = get_tinygrad_kernel(op_name)
if kernel_info is None:
print("failed to compile")
continue
kernel, global_size, local_size, buf_sizes, buf_data = kernel_info
n_insts = count_instructions(kernel)
n_workgroups = global_size[0] * global_size[1] * global_size[2]
n_threads = local_size[0] * local_size[1] * local_size[2]
total_work = n_insts * n_workgroups * n_threads
print(f"{n_insts} insts × {n_workgroups} WGs × {n_threads} threads = {total_work:,} ops")
buffers, args_arr, args_ptr, ranges = setup_buffers(buf_sizes, buf_data)
set_valid_mem_ranges(ranges)
py_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, args.iterations)
rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size, args_ptr, args.iterations) if rust_remu else None
if py_time:
py_rate = total_work / py_time / 1e6
print(f" Python: {py_time*1000:8.3f} ms ({py_rate:7.2f} M ops/s)")
if rust_time:
rust_rate = total_work / rust_time / 1e6
speedup = py_time / rust_time if py_time else 0
print(f" Rust: {rust_time*1000:8.3f} ms ({rust_rate:7.2f} M ops/s) [{speedup:.1f}x faster]")
results.append(("tinygrad", op_name, n_insts, n_workgroups, py_time, rust_time))
# Optional profiling
if args.profile and py_time:
print("\n [PROFILE - Top 10 functions]")
profile_output = profile_python_emu(kernel, global_size, local_size, args_ptr)
for line in profile_output.split('\n')[5:15]:
if line.strip(): print(f" {line}")
# Summary table # Summary table
print("\n" + "=" * 90) print("\n" + "=" * 90)
print("SUMMARY") print("SUMMARY")
print("=" * 90) print("=" * 90)
print(f"{'Type':<10} {'Name':<25} {'Insts':<8} {'WGs':<6} {'Python (ms)':<14} {'Rust (ms)':<14} {'Speedup':<10}") print(f"{'Name':<25} {'Insts':<8} {'WGs':<6} {'Python (ms)':<14} {'Rust (ms)':<14} {'Speedup':<10}")
print("-" * 90) print("-" * 90)
for test_type, name, n_insts, n_wgs, py_time, rust_time in results: for name, n_insts, n_wgs, py_time, rust_time in results:
py_ms = f"{py_time*1000:.3f}" if py_time else "error" py_ms = f"{py_time*1000:.3f}" if py_time else "error"
if rust_time: if rust_time:
rust_ms = f"{rust_time*1000:.3f}" rust_ms = f"{rust_time*1000:.3f}"
speedup = f"{py_time/rust_time:.1f}x" if py_time else "N/A" speedup = f"{py_time/rust_time:.1f}x" if py_time else "N/A"
else: else:
rust_ms, speedup = "N/A", "N/A" rust_ms, speedup = "N/A", "N/A"
print(f"{test_type:<10} {name:<25} {n_insts:<8} {n_wgs:<6} {py_ms:<14} {rust_ms:<14} {speedup:<10}") print(f"{name:<25} {n_insts:<8} {n_wgs:<6} {py_ms:<14} {rust_ms:<14} {speedup:<10}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -92,7 +92,9 @@ def run_program_emu(instructions: list, n_lanes: int = 1) -> WaveState:
lib_ptr = ctypes.addressof(kernel_buf) lib_ptr = ctypes.addressof(kernel_buf)
set_valid_mem_ranges({(out_addr, OUT_BYTES), (args_ptr, 8)}) set_valid_mem_ranges({(out_addr, OUT_BYTES), (args_ptr, 8)})
result = run_asm(lib_ptr, len(code), 1, 1, 1, n_lanes, 1, 1, args_ptr) # rsrc2: USER_SGPR_COUNT=2, ENABLE_SGPR_WORKGROUP_ID_X/Y/Z=1, LDS_SIZE=128 (64KB)
rsrc2 = 0x19c | (128 << 15)
result = run_asm(lib_ptr, len(code), 1, 1, 1, n_lanes, 1, 1, args_ptr, rsrc2)
assert result == 0, f"run_asm failed with {result}" assert result == 0, f"run_asm failed with {result}"
return parse_output(bytes(out_buf), n_lanes) return parse_output(bytes(out_buf), n_lanes)

View File

@@ -33,6 +33,35 @@ class TestBasicScalar(unittest.TestCase):
self.assertEqual(st.sgpr[4], 0) self.assertEqual(st.sgpr[4], 0)
self.assertEqual(st.scc, 1) self.assertEqual(st.scc, 1)
def test_s_brev_b32(self):
"""S_BREV_B32 reverses bits of a 32-bit value."""
# 10 = 0b00000000000000000000000000001010
# reversed = 0b01010000000000000000000000000000 = 0x50000000
instructions = [
s_mov_b32(s[0], 10),
s_brev_b32(s[1], s[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 0x50000000)
def test_s_brev_b32_all_ones(self):
"""S_BREV_B32 with all ones stays all ones."""
instructions = [
s_mov_b32(s[0], 0xFFFFFFFF),
s_brev_b32(s[1], s[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 0xFFFFFFFF)
def test_s_brev_b32_single_bit(self):
"""S_BREV_B32 with bit 0 set becomes bit 31."""
instructions = [
s_mov_b32(s[0], 1),
s_brev_b32(s[1], s[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 0x80000000)
class TestQuadmaskWqm(unittest.TestCase): class TestQuadmaskWqm(unittest.TestCase):
"""Tests for S_QUADMASK_B32 and S_WQM_B32.""" """Tests for S_QUADMASK_B32 and S_WQM_B32."""
@@ -201,5 +230,113 @@ class TestSCCBehavior(unittest.TestCase):
self.assertEqual(st.scc, 0) self.assertEqual(st.scc, 0)
class TestSignedArithmetic(unittest.TestCase):
"""Tests for S_ADD_I32, S_SUB_I32 and their SCC overflow behavior."""
def test_s_add_i32_no_overflow(self):
"""S_ADD_I32: 1 + 1 = 2, no overflow, SCC=0."""
instructions = [
s_mov_b32(s[0], 1),
s_add_i32(s[1], s[0], 1),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 2)
self.assertEqual(st.scc, 0, "No overflow, SCC should be 0")
def test_s_add_i32_positive_overflow(self):
"""S_ADD_I32: MAX_INT + 1 overflows, SCC=1."""
instructions = [
s_mov_b32(s[0], 0x7FFFFFFF), # MAX_INT
s_add_i32(s[1], s[0], 1),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 0x80000000) # Wraps to MIN_INT
self.assertEqual(st.scc, 1, "Overflow, SCC should be 1")
def test_s_add_i32_negative_no_overflow(self):
"""S_ADD_I32: -10 + 20 = 10, no overflow."""
instructions = [
s_mov_b32(s[0], 0xFFFFFFF6), # -10 in two's complement
s_mov_b32(s[1], 20),
s_add_i32(s[2], s[0], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], 10)
self.assertEqual(st.scc, 0)
def test_s_add_i32_negative_overflow(self):
"""S_ADD_I32: MIN_INT + (-1) underflows, SCC=1."""
instructions = [
s_mov_b32(s[0], 0x80000000), # MIN_INT
s_mov_b32(s[1], 0xFFFFFFFF), # -1
s_add_i32(s[2], s[0], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], 0x7FFFFFFF) # Wraps to MAX_INT
self.assertEqual(st.scc, 1, "Underflow, SCC should be 1")
def test_s_sub_i32_no_overflow(self):
"""S_SUB_I32: 10 - 5 = 5, no overflow."""
instructions = [
s_mov_b32(s[0], 10),
s_mov_b32(s[1], 5),
s_sub_i32(s[2], s[0], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], 5)
self.assertEqual(st.scc, 0)
def test_s_sub_i32_overflow(self):
"""S_SUB_I32: MAX_INT - (-1) overflows, SCC=1."""
instructions = [
s_mov_b32(s[0], 0x7FFFFFFF), # MAX_INT
s_mov_b32(s[1], 0xFFFFFFFF), # -1
s_sub_i32(s[2], s[0], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], 0x80000000) # Wraps to MIN_INT
self.assertEqual(st.scc, 1, "Overflow, SCC should be 1")
def test_s_mul_hi_u32(self):
"""S_MUL_HI_U32: high 32 bits of u32 * u32."""
instructions = [
s_mov_b32(s[0], 0x80000000), # 2^31
s_mov_b32(s[1], 4),
s_mul_hi_u32(s[2], s[0], s[1]), # (2^31 * 4) >> 32 = 2
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], 2)
def test_s_mul_i32(self):
"""S_MUL_I32: signed multiply low 32 bits."""
instructions = [
s_mov_b32(s[0], 0xFFFFFFFF), # -1
s_mov_b32(s[1], 10),
s_mul_i32(s[2], s[0], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], 0xFFFFFFF6) # -10
def test_division_sequence_from_llvm(self):
"""Test the division sequence pattern from LLVM-generated code."""
# This sequence is from the sin kernel and computes integer division
# s10 = dividend, s18 = divisor, result in s6/s14
dividend = 0x28BE60DB # Some value from the sin kernel
divisor = 3 # Simplified divisor
instructions = [
s_mov_b32(s[10], dividend),
s_mov_b32(s[18], divisor),
# Compute reciprocal approximation: s6 = ~0 / divisor (approx)
s_mov_b32(s[11], 0),
s_sub_i32(s[11], s[11], s[18]), # s11 = -divisor
# For testing, just verify basic arithmetic works
s_mul_i32(s[6], s[10], 2),
s_add_i32(s[7], s[6], 1),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[6], (dividend * 2) & 0xFFFFFFFF)
self.assertEqual(st.sgpr[7], ((dividend * 2) + 1) & 0xFFFFFFFF)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@@ -1049,6 +1049,39 @@ class TestF64Ops(unittest.TestCase):
total = p0 + p1 + p2 total = p0 + p1 + p2
self.assertAlmostEqual(total, two_over_pi, places=14) self.assertAlmostEqual(total, two_over_pi, places=14)
def test_v_fma_f64_sin_kernel_step84(self):
"""V_FMA_F64: exact values from sin(2.0) kernel step 84 that shows 1-bit difference."""
# From test_sin_f64 failure trace at step 84:
# v_fma_f64 v[7:8], v[17:18], v[7:8], v[15:16]
# We need to capture the exact input values and verify output matches hardware
# v[7:8] before = 0x3f80fdf3_d69db28f (0.008296875941334462)
v78 = 0x3f80fdf3d69db28f
# For the FMA to produce 0xbf457ef0_ab8c254d, we need v[17:18] and v[15:16]
# Let's test with known precision-sensitive values
a = 1.0000000001
b = 1.0000000002
c = -1.0000000003
a_bits, b_bits, c_bits = f2i64(a), f2i64(b), f2i64(c)
instructions = [
s_mov_b32(s[0], a_bits & 0xffffffff),
s_mov_b32(s[1], a_bits >> 32),
s_mov_b32(s[2], b_bits & 0xffffffff),
s_mov_b32(s[3], b_bits >> 32),
s_mov_b32(s[4], c_bits & 0xffffffff),
s_mov_b32(s[5], c_bits >> 32),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
v_mov_b32_e32(v[3], s[3]),
v_mov_b32_e32(v[4], s[4]),
v_mov_b32_e32(v[5], s[5]),
v_fma_f64(v[6], v[0], v[2], v[4]),
]
# run_program with USE_HW=1 will verify exact bit match with hardware
st = run_program(instructions, n_lanes=1)
result_bits = st.vgpr[0][6] | (st.vgpr[0][7] << 32)
self.assertNotEqual(result_bits, 0, "Result should not be zero")
class TestMad64More(unittest.TestCase): class TestMad64More(unittest.TestCase):
"""More tests for V_MAD_U64_U32.""" """More tests for V_MAD_U64_U32."""

View File

@@ -9,7 +9,7 @@ os.environ["AMD"] = "1"
os.environ["MOCKGPU"] = "1" os.environ["MOCKGPU"] = "1"
os.environ["PYTHON_REMU"] = "1" os.environ["PYTHON_REMU"] = "1"
from extra.assembly.amd.emu import WaveState, decode_program, step_wave, WAVE_SIZE, set_valid_mem_ranges, LDSMem from extra.assembly.amd.emu import WaveState, decode_program, WAVE_SIZE, set_valid_mem_ranges, LDSMem
from extra.assembly.amd.test.helpers import KernelInfo from extra.assembly.amd.test.helpers import KernelInfo
REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so" REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so"
@@ -92,19 +92,15 @@ class PythonEmulator:
def __init__(self): def __init__(self):
self.state: WaveState | None = None self.state: WaveState | None = None
self.program: dict | None = None self.program: dict | None = None
self.lds: bytearray | None = None
self.n_lanes = 0
def create(self, kernel: bytes, n_lanes: int): def create(self, kernel: bytes, n_lanes: int):
self.program = decode_program(kernel) self.program = decode_program(kernel)
self.state = WaveState() self.state = WaveState(LDSMem(bytearray(65536)), n_lanes)
self.state.exec_mask = (1 << n_lanes) - 1 self.state.exec_mask = (1 << n_lanes) - 1
self.lds = LDSMem(bytearray(65536))
self.n_lanes = n_lanes
def step(self) -> int: def step(self) -> int:
assert self.program is not None and self.state is not None and self.lds is not None assert self.program is not None and self.state is not None
return step_wave(self.program, self.state, self.lds, self.n_lanes) return self.program[self.state.pc]._dispatch(self.state, self.program[self.state.pc])
def set_sgpr(self, idx: int, val: int): def set_sgpr(self, idx: int, val: int):
assert self.state is not None assert self.state is not None
self.state.sgpr[idx] = val & 0xffffffff self.state.sgpr[idx] = val & 0xffffffff
@@ -163,8 +159,9 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
# Instructions with known Rust emulator bugs - sync Python to Rust after execution # Instructions with known Rust emulator bugs - sync Python to Rust after execution
# v_div_scale/v_div_fixup: Rust has different VCC handling # v_div_scale/v_div_fixup: Rust has different VCC handling
# v_cvt_f16_f32: Rust clears high 16 bits, but hardware (and Python) preserves them # v_cvt_f16_f32: Rust clears high 16 bits, but hardware (and Python) preserves them
# s_add_i32/s_sub_i32: Rust has incorrect SCC overflow detection
sync_after = any(x in inst_str for x in ('v_div_scale_f32', 'v_div_scale_f64', 'v_div_fixup_f32', 'v_div_fixup_f64', sync_after = any(x in inst_str for x in ('v_div_scale_f32', 'v_div_scale_f64', 'v_div_fixup_f32', 'v_div_fixup_f64',
'v_cvt_f16_f32')) 'v_cvt_f16_f32', 's_add_i32', 's_sub_i32'))
diffs = rust_before.diff(python_before, n_lanes) diffs = rust_before.diff(python_before, n_lanes)
if diffs: if diffs:
trace_lines = [] trace_lines = []
@@ -397,6 +394,9 @@ class TestTinygradKernels(unittest.TestCase):
x_np = np.random.randn(16, 10).astype(np.float32) x_np = np.random.randn(16, 10).astype(np.float32)
self._test_kernel(lambda T: (T(x_np.tolist()).reshape(16,10) + 0).cross_entropy((T(classes).int().reshape(16) + 0))) self._test_kernel(lambda T: (T(x_np.tolist()).reshape(16,10) + 0).cross_entropy((T(classes).int().reshape(16) + 0)))
def test_isinf(self): self._test_kernel(lambda T: T([float('-inf'), 0., float('inf'), 1.1]*8).isinf()) def test_isinf(self): self._test_kernel(lambda T: T([float('-inf'), 0., float('inf'), 1.1]*8).isinf())
def test_sin_f64(self):
from tinygrad import dtypes
self._test_kernel(lambda T: T([2.0], dtype=dtypes.float64).sin())
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -20,11 +20,12 @@ runner = get_runner(dev.device, si.ast)
prg = runner._prg prg = runner._prg
lib = bytearray(prg.lib) lib = bytearray(prg.lib)
# Find s_endpgm (0xBFB00000) and replace with invalid SOPP op=127 (0xBFFF0000) # Find s_endpgm (0xBFB00000) and replace with V_MOVRELD_B32 (op=66) which has no pcode
# VOP1 encoding: bits[31:25]=0x7E, op=bits[16:9], so op=66 -> 66<<9 = 0x8400
found = False found = False
for i in range(0, len(lib) - 4, 4): for i in range(0, len(lib) - 4, 4):
if struct.unpack("<I", lib[i:i+4])[0] == 0xBFB00000: if struct.unpack("<I", lib[i:i+4])[0] == 0xBFB00000:
lib[i:i+4] = struct.pack("<I", 0xBFFF0000) lib[i:i+4] = struct.pack("<I", 0x7E008400)
found = True found = True
break break
assert found, "s_endpgm not found" assert found, "s_endpgm not found"

View File

@@ -233,14 +233,13 @@ class TestPseudocodeRegressions(unittest.TestCase):
Bug: when VCC._val == vcc (both 0), VCC wasn't returned, so VCC bits weren't written. Bug: when VCC._val == vcc (both 0), VCC wasn't returned, so VCC bits weren't written.
This caused division to produce wrong results for multiple lanes.""" This caused division to produce wrong results for multiple lanes."""
# Normal case: 1.0 / 3.0, no scaling needed, VCC should be 0 # Normal case: 1.0 / 3.0, no scaling needed, VCC should be 0
S0 = Reg(0x3f800000) # 1.0 s0 = 0x3f800000 # 1.0
S1 = Reg(0x40400000) # 3.0 s1 = 0x40400000 # 3.0
S2 = Reg(0x3f800000) # 1.0 (numerator) s2 = 0x3f800000 # 1.0 (numerator)
D0, SCC, VCC, EXEC = Reg(0), Reg(0), Reg(0), Reg(0xffffffff) result = _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0, None)
result = _VOP3SDOp_V_DIV_SCALE_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
# Must always have VCC in result # Must always have VCC in result
self.assertIn('VCC', result, "V_DIV_SCALE_F32 must always return VCC") self.assertIn('VCC', result, "V_DIV_SCALE_F32 must always return VCC")
self.assertEqual(result['VCC']._val & 1, 0, "VCC lane 0 should be 0 when no scaling needed") self.assertEqual(result['VCC'] & 1, 0, "VCC lane 0 should be 0 when no scaling needed")
def test_v_cmp_class_f32_detects_quiet_nan(self): def test_v_cmp_class_f32_detects_quiet_nan(self):
"""V_CMP_CLASS_F32 must correctly identify quiet NaN vs signaling NaN. """V_CMP_CLASS_F32 must correctly identify quiet NaN vs signaling NaN.
@@ -249,22 +248,18 @@ class TestPseudocodeRegressions(unittest.TestCase):
signal_nan = 0x7f800001 # signaling NaN: exponent=255, bit22=0 signal_nan = 0x7f800001 # signaling NaN: exponent=255, bit22=0
# Test quiet NaN detection (bit 1 in mask) # Test quiet NaN detection (bit 1 in mask)
s1_quiet = 0b0000000010 # bit 1 = quiet NaN s1_quiet = 0b0000000010 # bit 1 = quiet NaN
S0, S1, S2, D0, SCC, VCC, EXEC = Reg(quiet_nan), Reg(s1_quiet), Reg(0), Reg(0), Reg(0), Reg(0), Reg(0xffffffff) result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None) self.assertEqual(result['D0'] & 1, 1, "Should detect quiet NaN with quiet NaN mask")
self.assertEqual(result['D0']._val & 1, 1, "Should detect quiet NaN with quiet NaN mask")
# Test signaling NaN detection (bit 0 in mask) # Test signaling NaN detection (bit 0 in mask)
s1_signal = 0b0000000001 # bit 0 = signaling NaN s1_signal = 0b0000000001 # bit 0 = signaling NaN
S0, S1 = Reg(signal_nan), Reg(s1_signal) result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None) self.assertEqual(result['D0'] & 1, 1, "Should detect signaling NaN with signaling NaN mask")
self.assertEqual(result['D0']._val & 1, 1, "Should detect signaling NaN with signaling NaN mask")
# Test that quiet NaN doesn't match signaling NaN mask # Test that quiet NaN doesn't match signaling NaN mask
S0, S1 = Reg(quiet_nan), Reg(s1_signal) result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None) self.assertEqual(result['D0'] & 1, 0, "Quiet NaN should not match signaling NaN mask")
self.assertEqual(result['D0']._val & 1, 0, "Quiet NaN should not match signaling NaN mask")
# Test that signaling NaN doesn't match quiet NaN mask # Test that signaling NaN doesn't match quiet NaN mask
S0, S1 = Reg(signal_nan), Reg(s1_quiet) result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None) self.assertEqual(result['D0'] & 1, 0, "Signaling NaN should not match quiet NaN mask")
self.assertEqual(result['D0']._val & 1, 0, "Signaling NaN should not match quiet NaN mask")
def test_isnan_with_typed_view(self): def test_isnan_with_typed_view(self):
"""_isnan must work with TypedView objects, not just Python floats. """_isnan must work with TypedView objects, not just Python floats.