assembly/amd: make the emu.py code shine (#13996)

* assembly/amd: make the code shine

* lil clean

* reg back in pcode

* cleanups

* gen fma_mix

* no writelane hacks

* fn cleanup

* dead vgpr_write

* readable

* smem

* cleanup bench_emu

* speedups

* simpler and faster

* direct inst._fn

* split fxn

* Revert "simpler and faster"

This reverts commit e85f6594b3.

* move lds to wavestate

* dispatcher

* pc in dispatch

* literal isn't wavestate

* cleanups + program

* one readlane

* exec_vop3sd in exec_vop

* cleaner exec_vopd

* fully merge VOP3P

* no special paths

* no SliceProxy

* low=0

* no bigint

* failing tests

* fma on python 3.13
This commit is contained in:
George Hotz
2026-01-03 23:33:09 -05:00
committed by GitHub
parent bdb421f13e
commit 8328511808
15 changed files with 14949 additions and 8550 deletions

View File

@@ -668,6 +668,7 @@ jobs:
key: rdna3-emu
deps: testing_minimal
amd: 'true'
python-version: '3.13'
- name: Install LLVM 21
run: |
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
import struct, math, re
from enum import IntEnum
from functools import cache, cached_property
from functools import cache
from typing import overload, Annotated, TypeVar, Generic
from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
@@ -346,6 +346,7 @@ class Inst:
if 'abs_' in kwargs: kwargs['abs'] = kwargs.pop('abs_')
orig_args = dict(zip(field_names, args)) | kwargs
self._values.update(orig_args)
self._precompute()
self._validate(orig_args)
# Pre-shift literal for 64-bit sources (literal param is always raw 32-bit value from user)
if literal is not None:
@@ -386,6 +387,7 @@ class Inst:
elif name == 'sbase': self._values[name] = (val.idx if isinstance(val, Reg) else val.val if isinstance(val, SrcMod) else val * 2) // 2
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = val.idx // 4
elif marker is _VDSTYEnc and isinstance(val, VGPR): self._values[name] = val.idx >> 1
self._precompute_fields()
def _encode_field(self, name: str, val) -> int:
if isinstance(val, RawImm): return val.val
@@ -450,6 +452,8 @@ class Inst:
inst = object.__new__(cls)
inst._values = {n: RawImm(v) if n in SRC_FIELDS else v for n, bf in cls._fields.items() if n != 'encoding' for v in [(word >> bf.lo) & bf.mask()]}
inst._literal = None
inst._precompute()
inst._precompute_fields()
return inst
@classmethod
@@ -510,25 +514,32 @@ class Inst:
'VOPD': VOPDOp, 'VINTERP': VINTERPOp}
_VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
@property
def op(self):
"""Return the op as an enum (e.g., VOP1Op.V_MOV_B32). VOP3 returns VOPCOp/VOP3SDOp for those op ranges."""
def _precompute(self):
"""Precompute op, op_name, _spec_regs, _spec_dtype for fast access."""
val = self._values.get('op')
if val is None: return None
if hasattr(val, 'name'): return val # already an enum
cls_name = self.__class__.__name__
assert cls_name in self._enum_map, f"no enum map for {cls_name}"
return self._enum_map[cls_name](val)
if val is None: self.op = None
elif hasattr(val, 'name'): self.op = val
else:
cls_name = self.__class__.__name__
# VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
if cls_name == 'VOP3':
try:
if val < 256: self.op = VOPCOp(val)
elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val)
else: self.op = VOP3Op(val)
except ValueError: self.op = val
elif cls_name in self._enum_map:
try: self.op = self._enum_map[cls_name](val)
except ValueError: self.op = val
else: self.op = val
self.op_name = self.op.name if hasattr(self.op, 'name') else ''
self._spec_regs = spec_regs(self.op_name)
self._spec_dtype = spec_dtype(self.op_name)
@cached_property
def op_name(self) -> str:
op = self.op
return op.name if hasattr(op, 'name') else ''
@cached_property
def _spec_regs(self) -> tuple[int, int, int, int]: return spec_regs(self.op_name)
@cached_property
def _spec_dtype(self) -> tuple[str | None, str | None, str | None, str | None]: return spec_dtype(self.op_name)
def _precompute_fields(self):
"""Unwrap all field values as direct attributes for fast access."""
for name, val in self._values.items():
if name != 'op': setattr(self, name, unwrap(val))
def dst_regs(self) -> int: return self._spec_regs[0]
def src_regs(self, n: int) -> int: return self._spec_regs[n + 1]
def num_srcs(self) -> int: return spec_num_srcs(self.op_name)

View File

@@ -1,15 +1,14 @@
# RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF
# mypy: ignore-errors
from __future__ import annotations
import ctypes
import ctypes, functools
from tinygrad.runtime.autogen import hsa
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.pcode import Reg
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
from extra.assembly.amd.autogen.rdna3.gen_pcode import COMPILED_FUNCTIONS
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
Program = dict[int, Inst]
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, SrcEnum.NULL, SrcEnum.EXEC_LO, SrcEnum.EXEC_HI, SrcEnum.SCC
@@ -29,6 +28,41 @@ def _dst16(cur: int, val: int, is_hi: bool) -> int: return (cur & 0x0000ffff) |
def _vgpr_hi(src: int) -> bool: return src >= 256 and ((src - 256) & 0x80) != 0
def _vgpr_masked(src: int) -> int: return ((src - 256) & 0x7f) + 256 if src >= 256 else src
# VOP3 source modifier: apply abs/neg to value
def _mod_src(val: int, idx: int, neg: int, abs_: int, is64: bool = False) -> int:
to_f, to_i = (_f64, _i64) if is64 else (_f32, _i32)
if (abs_ >> idx) & 1: val = to_i(abs(to_f(val)))
if (neg >> idx) & 1: val = to_i(-to_f(val))
return val
# Read source operand with VOP3 modifiers
def _read_src(st, inst, src, idx: int, lane: int, neg: int, abs_: int, opsel: int) -> int:
if src is None: return 0
literal, regs, is_src_16 = inst._literal, inst.src_regs(idx), inst.is_src_16(idx)
if regs == 2: return _mod_src(st.rsrc64(src, lane, literal), idx, neg, abs_, is64=True)
if isinstance(inst, VOP3P):
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
if 'FMA_MIX' in inst.op_name:
raw = st.rsrc(src, lane, literal)
sign_bit = (15 if not (opsel & (1 << idx)) else 31) if (opsel_hi >> idx) & 1 else 31
if inst.neg_hi & (1 << idx): raw &= ~(1 << sign_bit)
if neg & (1 << idx): raw ^= (1 << sign_bit)
return raw
raw = st.rsrc_f16(src, lane, literal)
hi = _src16(raw, opsel_hi & (1 << idx)) ^ (0x8000 if inst.neg_hi & (1 << idx) else 0)
lo = _src16(raw, opsel & (1 << idx)) ^ (0x8000 if neg & (1 << idx) else 0)
return (hi << 16) | lo
if is_src_16 and isinstance(inst, VOP3):
raw = st.rsrc_f16(src, lane, literal) if 128 <= src < 255 else st.rsrc(src, lane, literal)
val = _src16(raw, bool(opsel & (1 << idx)))
if abs_ & (1 << idx): val &= 0x7fff
if neg & (1 << idx): val ^= 0x8000
return val
if is_src_16 and isinstance(inst, (VOP1, VOP2, VOPC)):
if src >= 256: return _src16(_mod_src(st.rsrc(_vgpr_masked(src), lane, literal), idx, neg, abs_), _vgpr_hi(src))
return _mod_src(st.rsrc_f16(src, lane, literal), idx, neg, abs_) & 0xffff
return _mod_src(st.rsrc(src, lane, literal), idx, neg, abs_)
# Helper: get number of dwords from memory op name
def _op_ndwords(name: str) -> int:
if '_B128' in name: return 4
@@ -36,8 +70,8 @@ def _op_ndwords(name: str) -> int:
if any(s in name for s in ('_B64', '_U64', '_I64', '_F64')): return 2
return 1
# Helper: build multi-dword Reg from consecutive VGPRs
def _vgpr_read(V: list, base: int, ndwords: int) -> Reg: return Reg(sum(V[base + i] << (32 * i) for i in range(ndwords)))
# Helper: build multi-dword int from consecutive VGPRs
def _vgpr_read(V: list, base: int, ndwords: int) -> int: return sum(V[base + i] << (32 * i) for i in range(ndwords))
# Helper: write multi-dword value to consecutive VGPRs
def _vgpr_write(V: list, base: int, val: int, ndwords: int):
@@ -88,7 +122,8 @@ class LDSMem:
if addr + size <= len(self._lds): self._lds[addr:addr+size] = (int(val) & ((1 << (size*8)) - 1)).to_bytes(size, 'little')
def __getitem__(self, addr): return _make_mem_accessor(self._read, self._write)(addr)
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
# SMEM dst register count (for writing result back to SGPRs)
SMEM_DST_COUNT = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
_VOPD_TO_VOP = {
@@ -100,19 +135,12 @@ _VOPD_TO_VOP = {
VOPDOp.V_DUAL_ADD_NC_U32: VOP3Op.V_ADD_NC_U32, VOPDOp.V_DUAL_LSHLREV_B32: VOP3Op.V_LSHLREV_B32, VOPDOp.V_DUAL_AND_B32: VOP3Op.V_AND_B32,
}
# Compiled pseudocode functions (lazy loaded)
_COMPILED: dict | None = None
def _get_compiled() -> dict:
global _COMPILED
if _COMPILED is None: _COMPILED = get_compiled_functions()
return _COMPILED
class WaveState:
__slots__ = ('sgpr', 'vgpr', 'scc', 'pc', 'literal', '_pend_sgpr')
def __init__(self):
__slots__ = ('sgpr', 'vgpr', 'scc', 'pc', '_pend_sgpr', 'lds', 'n_lanes')
def __init__(self, lds: LDSMem | None = None, n_lanes: int = WAVE_SIZE):
self.sgpr, self.vgpr = [0] * SGPR_COUNT, [[0] * VGPR_COUNT for _ in range(WAVE_SIZE)]
self.sgpr[EXEC_LO], self.scc, self.pc, self.literal, self._pend_sgpr = 0xffffffff, 0, 0, 0, {}
self.sgpr[EXEC_LO], self.scc, self.pc, self._pend_sgpr, self.lds, self.n_lanes = 0xffffffff, 0, 0, {}, lds, n_lanes
@property
def vcc(self) -> int: return self.sgpr[VCC_LO] | (self.sgpr[VCC_HI] << 32)
@@ -129,18 +157,18 @@ class WaveState:
def rsgpr64(self, i: int) -> int: return self.rsgpr(i) | (self.rsgpr(i+1) << 32)
def wsgpr64(self, i: int, v: int): self.wsgpr(i, v & MASK32); self.wsgpr(i+1, (v >> 32) & MASK32)
def _rsrc_base(self, v: int, lane: int, consts):
def _rsrc_base(self, v: int, lane: int, consts, literal: int):
if v < SGPR_COUNT: return self.sgpr[v]
if v == SCC: return self.scc
if v < 255: return consts[v - 128]
if v == 255: return self.literal
if v == 255: return literal
return self.vgpr[lane][v - 256] if v <= 511 else 0
def rsrc(self, v: int, lane: int) -> int: return self._rsrc_base(v, lane, _INLINE_CONSTS)
def rsrc_f16(self, v: int, lane: int) -> int: return self._rsrc_base(v, lane, _INLINE_CONSTS_F16)
def rsrc64(self, v: int, lane: int) -> int:
def rsrc(self, v: int, lane: int, literal: int = 0) -> int: return self._rsrc_base(v, lane, _INLINE_CONSTS, literal)
def rsrc_f16(self, v: int, lane: int, literal: int = 0) -> int: return self._rsrc_base(v, lane, _INLINE_CONSTS_F16, literal)
def rsrc64(self, v: int, lane: int, literal: int = 0) -> int:
if 128 <= v < 255: return _INLINE_CONSTS_F64[v - 128]
if v == 255: return self.literal # literal is already shifted in from_bytes for 64-bit ops
return self.rsrc(v, lane) | ((self.rsrc(v+1, lane) if v < VCC_LO or 256 <= v <= 511 else 0) << 32)
if v == 255: return literal # literal is already shifted in from_bytes for 64-bit ops
return self.rsrc(v, lane, literal) | ((self.rsrc(v+1, lane, literal) if v < VCC_LO or 256 <= v <= 511 else 0) << 32)
def pend_sgpr_lane(self, reg: int, lane: int, val: int):
if reg not in self._pend_sgpr: self._pend_sgpr[reg] = 0
@@ -150,251 +178,130 @@ class WaveState:
self._pend_sgpr.clear()
def decode_program(data: bytes) -> Program:
result: Program = {}
i = 0
while i < len(data):
try: inst_class = detect_format(data[i:])
except ValueError: break # stop at invalid instruction (padding/metadata after code)
if inst_class is None: i += 4; continue
base_size = inst_class._size()
# Pass enough data for potential 64-bit literal (base + 8 bytes max)
inst = inst_class.from_bytes(data[i:i+base_size+8])
for name, val in inst._values.items():
if name != 'op': setattr(inst, name, unwrap(val)) # skip op to preserve property access
inst._words = inst.size() // 4
result[i // 4] = inst
i += inst._words * 4
return result
# ═══════════════════════════════════════════════════════════════════════════════
# EXECUTION - All ALU ops use pseudocode from PDF
# EXECUTION - All ops use pseudocode from PDF
# ═══════════════════════════════════════════════════════════════════════════════
def exec_scalar(st: WaveState, inst: Inst) -> int:
"""Execute scalar instruction. Returns PC delta or negative for special cases."""
compiled = _get_compiled()
# SOPP: special cases for control flow that has no pseudocode
if isinstance(inst, SOPP):
if inst.op == SOPPOp.S_ENDPGM: return -1
if inst.op == SOPPOp.S_BARRIER: return -2
# SMEM: memory loads (not ALU)
if isinstance(inst, SMEM):
addr = st.rsgpr64(inst.sbase * 2) + _sext(inst.offset, 21)
if inst.soffset not in (NULL, 0x7f): addr += st.rsrc(inst.soffset, 0)
if (cnt := SMEM_LOAD.get(inst.op)) is None: raise NotImplementedError(f"SMEM op {inst.op}")
for i in range(cnt): st.wsgpr(inst.sdata + i, mem_read((addr + i * 4) & MASK64, 4))
return 0
def exec_scalar(st: WaveState, inst: Inst):
"""Execute scalar instruction. Returns 0 to continue execution."""
# Get op enum and lookup compiled function
if isinstance(inst, SOP1): ssrc0, sdst = inst.ssrc0, inst.sdst
if isinstance(inst, SMEM): ssrc0, sdst = None, None
elif isinstance(inst, SOP1): ssrc0, sdst = inst.ssrc0, inst.sdst
elif isinstance(inst, SOP2): ssrc0, sdst = inst.ssrc0, inst.sdst
elif isinstance(inst, SOPC): ssrc0, sdst = inst.ssrc0, None
elif isinstance(inst, SOPK): ssrc0, sdst = inst.sdst, inst.sdst # sdst is both src and dst
elif isinstance(inst, SOPP): ssrc0, sdst = None, None
else: raise NotImplementedError(f"Unknown scalar type {type(inst)}")
# SOPP has gaps in the opcode enum - treat unknown opcodes as no-ops
try: op = inst.op
except ValueError:
if isinstance(inst, SOPP): return 0
raise
fn = compiled.get(type(op), {}).get(op)
if fn is None:
# SOPP instructions without pseudocode (waits, hints, nops) are no-ops
if isinstance(inst, SOPP): return 0
raise NotImplementedError(f"{op.name} not in pseudocode")
# SMEM: memory loads
if isinstance(inst, SMEM):
addr = st.rsgpr64(inst.sbase * 2) + _sext(inst.offset, 21)
if inst.soffset not in (NULL, 0x7f): addr += st.rsrc(inst.soffset, 0, inst._literal)
result = inst._fn(GlobalMem, addr & MASK64)
if 'SDATA' in result:
sdata = result['SDATA']
for i in range(SMEM_DST_COUNT.get(inst.op, 1)): st.wsgpr(inst.sdata + i, (sdata >> (i * 32)) & MASK32)
st.pc += inst._words
return 0
# Build context - use inst methods to determine operand sizes
s0 = st.rsrc64(ssrc0, 0) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0))
s1 = st.rsrc64(inst.ssrc1, 0) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0)
literal = inst._literal
s0 = st.rsrc64(ssrc0, 0, literal) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0, literal) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0))
s1 = st.rsrc64(inst.ssrc1, 0, literal) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0, literal) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0)
d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else st.literal
literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else inst._literal
# Create Reg objects for compiled function - mask VCC/EXEC to 32 bits for wave32
result = fn(Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc & MASK32), 0, Reg(st.exec_mask & MASK32), literal, None, PC=Reg(st.pc * 4))
# Call compiled function with int parameters
result = inst._fn(s0, s1, 0, d0, st.scc, st.vcc & MASK32, 0, st.exec_mask & MASK32, literal, None, pc=st.pc * 4)
# Apply results - extract values from returned Reg objects
# Apply results (already int values)
if sdst is not None and 'D0' in result:
(st.wsgpr64 if inst.dst_regs() == 2 else st.wsgpr)(sdst, result['D0']._val)
if 'SCC' in result: st.scc = result['SCC']._val & 1
if 'EXEC' in result: st.exec_mask = result['EXEC']._val
(st.wsgpr64 if inst.dst_regs() == 2 else st.wsgpr)(sdst, result['D0'])
if 'SCC' in result: st.scc = result['SCC'] & 1
if 'EXEC' in result: st.exec_mask = result['EXEC']
if 'PC' in result:
# Convert absolute byte address to word delta
pc_val = result['PC']._val
# Convert absolute byte address to word offset
pc_val = result['PC']
new_pc = pc_val if pc_val < 0x8000000000000000 else pc_val - 0x10000000000000000
new_pc_words = new_pc // 4
return new_pc_words - st.pc - 1 # -1 because emulator adds inst_words (1 for scalar)
st.pc = new_pc // 4
else:
st.pc += inst._words
return 0
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None) -> None:
"""Execute vector instruction for one lane."""
compiled = _get_compiled()
V = st.vgpr[lane]
# ═══════════════════════════════════════════════════════════════════════════════
# VECTOR INSTRUCTIONS
# ═══════════════════════════════════════════════════════════════════════════════
# Memory ops (FLAT/GLOBAL/SCRATCH and DS) - use generated pcode
if isinstance(inst, (FLAT, DS)):
op, vdst, op_name = inst.op, inst.vdst, inst.op.name
fn, ndwords = compiled[type(op)][op], _op_ndwords(op_name)
if isinstance(inst, FLAT):
addr = V[inst.addr] | (V[inst.addr + 1] << 32)
ADDR = (st.rsgpr64(inst.saddr) + V[inst.addr] + _sext(inst.offset, 13)) & MASK64 if inst.saddr not in (NULL, 0x7f) else (addr + _sext(inst.offset, 13)) & MASK64
# For loads, VDATA comes from vdst (preserves unwritten bits); for stores, from inst.data
vdata_src = vdst if 'LOAD' in op_name else inst.data
result = fn(GlobalMem, ADDR, _vgpr_read(V, vdata_src, ndwords), Reg(V[vdst]), Reg(0))
if 'VDATA' in result: _vgpr_write(V, vdst, result['VDATA']._val, ndwords)
if 'RETURN_DATA' in result: _vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords)
else: # DS
DATA0, DATA1 = _vgpr_read(V, inst.data0, ndwords), _vgpr_read(V, inst.data1, ndwords) if inst.data1 is not None else Reg(0)
result = fn(lds, Reg(V[inst.addr]), DATA0, DATA1, Reg(inst.offset0), Reg(inst.offset1), Reg(0))
if 'RETURN_DATA' in result and ('_RTN' in op_name or '_LOAD' in op_name):
_vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords * 2 if '_2ADDR_' in op_name else ndwords)
return
def exec_vopd(st: WaveState, inst, V: list, lane: int) -> None:
"""VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)."""
literal, vdstx, vdsty = inst._literal, inst.vdstx, (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
sx0, sx1, dx, sy0, sy1, dy = st.rsrc(inst.srcx0, lane, literal), V[inst.vsrcx1], V[vdstx], st.rsrc(inst.srcy0, lane, literal), V[inst.vsrcy1], V[vdsty]
opx, opy = _VOPD_TO_VOP[inst.opx], _VOPD_TO_VOP[inst.opy]
V[vdstx] = COMPILED_FUNCTIONS[type(opx)][opx](sx0, sx1, 0, dx, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
V[vdsty] = COMPILED_FUNCTIONS[type(opy)][opy](sy0, sy1, 0, dy, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
# VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)
if isinstance(inst, VOPD):
vdsty = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
inputs = [(inst.opx, st.rsrc(inst.srcx0, lane), V[inst.vsrcx1], V[inst.vdstx], inst.vdstx),
(inst.opy, st.rsrc(inst.srcy0, lane), V[inst.vsrcy1], V[vdsty], vdsty)]
def exec_vopd(vopd_op, s0, s1, d0):
op = _VOPD_TO_VOP[vopd_op]
return compiled[type(op)][op](Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)['D0']._val
for vopd_op, s0, s1, d0, dst in inputs: V[dst] = exec_vopd(vopd_op, s0, s1, d0)
return
def exec_flat(st: WaveState, inst, V: list, lane: int) -> None:
"""FLAT/GLOBAL/SCRATCH memory ops."""
ndwords = _op_ndwords(inst.op_name)
addr = V[inst.addr] | (V[inst.addr + 1] << 32)
ADDR = (st.rsgpr64(inst.saddr) + V[inst.addr] + _sext(inst.offset, 13)) & MASK64 if inst.saddr not in (NULL, 0x7f) else (addr + _sext(inst.offset, 13)) & MASK64
vdata_src = inst.vdst if 'LOAD' in inst.op_name else inst.data
result = inst._fn(GlobalMem, ADDR, _vgpr_read(V, vdata_src, ndwords), V[inst.vdst])
if 'VDATA' in result: _vgpr_write(V, inst.vdst, result['VDATA'], ndwords)
if 'RETURN_DATA' in result: _vgpr_write(V, inst.vdst, result['RETURN_DATA'], ndwords)
# VOP3SD: has extra scalar dest for carry output
if isinstance(inst, VOP3SD):
fn = compiled[VOP3SDOp][inst.op]
# Read sources based on register counts from inst properties
def rsrc_n(src, regs): return st.rsrc64(src, lane) if regs == 2 else st.rsrc(src, lane)
s0, s1, s2 = rsrc_n(inst.src0, inst.src_regs(0)), rsrc_n(inst.src1, inst.src_regs(1)), rsrc_n(inst.src2, inst.src_regs(2))
# Carry-in ops use src2 as carry bitmask instead of VCC
vcc = st.rsgpr64(inst.src2) if 'CO_CI' in inst.op_name else st.vcc
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(V[inst.vdst]), Reg(st.scc), Reg(vcc), lane, Reg(st.exec_mask), st.literal, None)
d0_val = result['D0']._val
V[inst.vdst] = d0_val & MASK32
if inst.dst_regs() == 2: V[inst.vdst + 1] = (d0_val >> 32) & MASK32
if 'VCC' in result: st.pend_sgpr_lane(inst.sdst, lane, (result['VCC']._val >> lane) & 1)
return
def exec_ds(st: WaveState, inst, V: list, lane: int) -> None:
"""DS (LDS) memory ops."""
ndwords = _op_ndwords(inst.op_name)
data0, data1 = _vgpr_read(V, inst.data0, ndwords), _vgpr_read(V, inst.data1, ndwords) if inst.data1 is not None else 0
result = inst._fn(st.lds, V[inst.addr], data0, data1, inst.offset0, inst.offset1)
if 'RETURN_DATA' in result and ('_RTN' in inst.op_name or '_LOAD' in inst.op_name):
_vgpr_write(V, inst.vdst, result['RETURN_DATA'], ndwords * 2 if '_2ADDR_' in inst.op_name else ndwords)
# Get op enum and sources (None means "no source" for that operand)
# dst_hi: for VOP1/VOP2 16-bit dst ops, bit 7 of vdst indicates .h (high 16-bit) destination
dst_hi = False
if isinstance(inst, VOP1):
if inst.op == VOP1Op.V_NOP: return
src0, src1, src2 = inst.src0, None, None
dst_hi = (inst.vdst & 0x80) != 0 and inst.is_dst_16()
vdst = inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst
def exec_vop(st: WaveState, inst: Inst, V: list, lane: int) -> None:
"""VOP1/VOP2/VOP3/VOP3SD/VOP3P/VOPC: standard ALU ops."""
if isinstance(inst, VOP3P):
src0, src1, src2, vdst, dst_hi = inst.src0, inst.src1, inst.src2, inst.vdst, False
neg, abs_, opsel = inst.neg, 0, inst.opsel
elif isinstance(inst, VOP1):
src0, src1, src2, vdst = inst.src0, None, None, inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst
neg, abs_, opsel, dst_hi = 0, 0, 0, (inst.vdst & 0x80) != 0 and inst.is_dst_16()
elif isinstance(inst, VOP2):
src0, src1, src2 = inst.src0, inst.vsrc1 + 256, None
dst_hi = (inst.vdst & 0x80) != 0 and inst.is_dst_16()
vdst = inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst
elif isinstance(inst, VOP3):
# VOP3 ops 0-255 are VOPC comparisons encoded as VOP3 - inst.op returns VOPCOp for these
src0, src1, src2, vdst = inst.src0, inst.src1, (None if inst.op.value < 256 else inst.src2), inst.vdst
src0, src1, src2, vdst = inst.src0, inst.vsrc1 + 256, None, inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst
neg, abs_, opsel, dst_hi = 0, 0, 0, (inst.vdst & 0x80) != 0 and inst.is_dst_16()
elif isinstance(inst, (VOP3, VOP3SD)):
src0, src1, src2, vdst = inst.src0, inst.src1, (None if isinstance(inst, VOP3) and inst.op.value < 256 else inst.src2), inst.vdst
neg, abs_, opsel, dst_hi = (inst.neg, inst.abs, inst.opsel, False) if isinstance(inst, VOP3) else (0, 0, 0, False)
elif isinstance(inst, VOPC):
# For 16-bit VOPC, vsrc1 uses same encoding as VOP2 16-bit: bit 7 selects hi(1) or lo(0) half
# vsrc1 field is 8 bits: [6:0] = VGPR index, [7] = hi flag
src0, src1, src2, vdst = inst.src0, inst.vsrc1 + 256, None, VCC_LO
elif isinstance(inst, VOP3P):
# VOP3P: Packed 16-bit operations using compiled functions
# WMMA: wave-level matrix multiply-accumulate (special handling - needs cross-lane access)
if 'WMMA' in inst.op_name:
if lane == 0: # Only execute once per wave, write results for all lanes
exec_wmma(st, inst, inst.op)
return
# V_FMA_MIX: Mixed precision FMA - opsel_hi controls f32(0) vs f16(1), opsel selects which f16 half
if 'FMA_MIX' in inst.op_name:
opsel, opsel_hi, opsel_hi2 = getattr(inst, 'opsel', 0), getattr(inst, 'opsel_hi', 0), getattr(inst, 'opsel_hi2', 0)
neg, abs_ = getattr(inst, 'neg', 0), getattr(inst, 'neg_hi', 0) # neg_hi reused as abs
raws = [st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane) if inst.src2 is not None else 0]
is_f16 = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
srcs = [_f16(_src16(raws[i], bool(opsel & (1<<i)))) if is_f16[i] else _f32(raws[i]) for i in range(3)]
for i in range(3):
if abs_ & (1<<i): srcs[i] = abs(srcs[i])
if neg & (1<<i): srcs[i] = -srcs[i]
result = srcs[0] * srcs[1] + srcs[2]
st.vgpr[lane][inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
return
# VOP3P packed ops: opsel selects halves for lo, opsel_hi for hi; neg toggles f16 sign
raws = [st.rsrc_f16(inst.src0, lane), st.rsrc_f16(inst.src1, lane), st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0]
opsel, opsel_hi, opsel_hi2 = getattr(inst, 'opsel', 0), getattr(inst, 'opsel_hi', 3), getattr(inst, 'opsel_hi2', 1)
neg, neg_hi = getattr(inst, 'neg', 0), getattr(inst, 'neg_hi', 0)
hi_sels = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
srcs = [((_src16(raws[i], hi_sels[i]) ^ (0x8000 if neg_hi & (1<<i) else 0)) << 16) |
(_src16(raws[i], opsel & (1<<i)) ^ (0x8000 if neg & (1<<i) else 0)) for i in range(3)]
result = compiled[VOP3POp][inst.op](Reg(srcs[0]), Reg(srcs[1]), Reg(srcs[2]), Reg(0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)
st.vgpr[lane][inst.vdst] = result['D0']._val & MASK32
return
else: raise NotImplementedError(f"Unknown vector type {type(inst)}")
src0, src1, src2, vdst, neg, abs_, opsel, dst_hi = inst.src0, inst.vsrc1 + 256, None, VCC_LO, 0, 0, 0, False
else:
raise NotImplementedError(f"exec_vop: unhandled instruction type {type(inst).__name__}")
op_cls = type(inst.op)
if (fn := compiled.get(op_cls, {}).get(inst.op)) is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
s0 = _read_src(st, inst, src0, 0, lane, neg, abs_, opsel)
s1 = _read_src(st, inst, src1, 1, lane, neg, abs_, opsel)
s2 = _read_src(st, inst, src2, 2, lane, neg, abs_, opsel)
if isinstance(inst, VOP2) and inst.is_16bit(): d0 = _src16(V[vdst], dst_hi)
elif inst.dst_regs() == 2: d0 = V[vdst] | (V[vdst + 1] << 32)
else: d0 = V[vdst]
# Read sources (with VOP3 modifiers if applicable)
neg, abs_ = (getattr(inst, 'neg', 0), getattr(inst, 'abs', 0)) if isinstance(inst, VOP3) else (0, 0)
opsel = getattr(inst, 'opsel', 0) if isinstance(inst, VOP3) else 0
def mod_src(val: int, idx: int, is64=False) -> int:
to_f, to_i = (_f64, _i64) if is64 else (_f32, _i32)
if (abs_ >> idx) & 1: val = to_i(abs(to_f(val)))
if (neg >> idx) & 1: val = to_i(-to_f(val))
return val
# Use inst methods to determine operand sizes (inst.is_src_16, inst.is_src_64, etc.)
is_vop2_16bit = isinstance(inst, VOP2) and inst.is_16bit()
# Read sources based on register counts and dtypes from inst properties
def read_src(src, idx, regs, is_src_16):
if src is None: return 0
if regs == 2: return mod_src(st.rsrc64(src, lane), idx, is64=True)
if is_src_16 and isinstance(inst, VOP3):
raw = st.rsrc_f16(src, lane) if 128 <= src < 255 else st.rsrc(src, lane)
val = _src16(raw, bool(opsel & (1 << idx)))
if abs_ & (1 << idx): val &= 0x7fff
if neg & (1 << idx): val ^= 0x8000
return val
if is_src_16 and isinstance(inst, (VOP1, VOP2, VOPC)):
if src >= 256: return _src16(mod_src(st.rsrc(_vgpr_masked(src), lane), idx), _vgpr_hi(src))
return mod_src(st.rsrc_f16(src, lane), idx) & 0xffff
return mod_src(st.rsrc(src, lane), idx)
s0 = read_src(src0, 0, inst.src_regs(0), inst.is_src_16(0))
s1 = read_src(src1, 1, inst.src_regs(1), inst.is_src_16(1)) if src1 is not None else 0
s2 = read_src(src2, 2, inst.src_regs(2), inst.is_src_16(2)) if src2 is not None else 0
# Read destination (accumulator for VOP2 f16, 64-bit for 64-bit ops)
d0 = _src16(V[vdst], dst_hi) if is_vop2_16bit else (V[vdst] | (V[vdst + 1] << 32)) if inst.dst_regs() == 2 else V[vdst]
# V_CNDMASK_B32/B16: VOP3 encoding uses src2 as mask (not VCC); VOP2 uses VCC implicitly
# Pass the correct mask as vcc to the function so pseudocode VCC.u64[laneId] works correctly
vcc_for_fn = st.rsgpr64(src2) if inst.op in (VOP3Op.V_CNDMASK_B32, VOP3Op.V_CNDMASK_B16) and isinstance(inst, VOP3) and src2 is not None and src2 < 256 else st.vcc
# Execute compiled function - pass src0_idx and vdst_idx for lane instructions
# For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR)
if isinstance(inst, VOP3SD) and 'CO_CI' in inst.op_name: vcc_for_fn = st.rsgpr64(inst.src2)
elif isinstance(inst, VOP3) and inst.op in (VOP3Op.V_CNDMASK_B32, VOP3Op.V_CNDMASK_B16) and src2 is not None and src2 < 256: vcc_for_fn = st.rsgpr64(src2)
else: vcc_for_fn = st.vcc
src0_idx = (src0 - 256) if src0 is not None and src0 >= 256 else (src0 if src0 is not None else 0)
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(d0), Reg(st.scc), Reg(vcc_for_fn), lane, Reg(st.exec_mask), st.literal, st.vgpr, src0_idx, vdst)
extra_kwargs = {'opsel': opsel, 'opsel_hi': inst.opsel_hi | (inst.opsel_hi2 << 2)} if isinstance(inst, VOP3P) and 'FMA_MIX' in inst.op_name else {}
result = inst._fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, inst._literal, st.vgpr, src0_idx, vdst, **extra_kwargs)
# Apply results - extract values from returned Reg objects
if 'vgpr_write' in result:
# Lane instruction wrote to VGPR: (lane, vgpr_idx, value)
wr_lane, wr_idx, wr_val = result['vgpr_write']
st.vgpr[wr_lane][wr_idx] = wr_val
if 'VCC' in result:
# VOP2 carry ops write to VCC implicitly; VOPC/VOP3 write to vdst
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC']._val >> lane) & 1)
if isinstance(inst, VOP3SD): st.pend_sgpr_lane(inst.sdst, lane, (result['VCC'] >> lane) & 1)
else: st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC'] >> lane) & 1)
if 'EXEC' in result:
# V_CMPX instructions write to EXEC per-lane (not to vdst)
st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC']._val >> lane) & 1)
elif op_cls is VOPCOp:
# VOPC comparison result stored in D0 bitmask, extract lane bit (non-CMPX only)
st.pend_sgpr_lane(vdst, lane, (result['D0']._val >> lane) & 1)
if op_cls is not VOPCOp and 'vgpr_write' not in result:
writes_to_sgpr = 'READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name
d0_val = result['D0']._val
if writes_to_sgpr: st.wsgpr(vdst, d0_val & MASK32)
elif inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC'] >> lane) & 1)
elif isinstance(inst.op, VOPCOp):
st.pend_sgpr_lane(vdst, lane, (result['D0'] >> lane) & 1)
if not isinstance(inst.op, VOPCOp):
d0_val = result['D0']
if inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif not isinstance(inst, VOP3P) and inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
else: V[vdst] = d0_val & MASK32
# ═══════════════════════════════════════════════════════════════════════════════
@@ -419,64 +326,102 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None:
else:
for i in range(256): st.vgpr[i % 32][vdst + i//32] = _i32(mat_d[i])
# ═══════════════════════════════════════════════════════════════════════════════
# PROGRAM DECODE
# ═══════════════════════════════════════════════════════════════════════════════
# Wave-level dispatch functions: (st, inst) -> return_code (0 = continue, -1 = end, -2 = barrier)
def dispatch_endpgm(st, inst): return -1
def dispatch_barrier(st, inst): st.pc += inst._words; return -2
def dispatch_nop(st, inst): st.pc += inst._words; return 0
def dispatch_wmma(st, inst): exec_wmma(st, inst, inst.op); st.pc += inst._words; return 0
def dispatch_writelane(st, inst): st.vgpr[st.rsrc(inst.src1, 0, inst._literal) & 0x1f][inst.vdst] = st.rsrc(inst.src0, 0, inst._literal) & MASK32; st.pc += inst._words; return 0
def dispatch_readlane(st, inst):
src0_idx = (inst.src0 - 256) if inst.src0 >= 256 else inst.src0
s1 = st.rsrc(inst.src1, 0, inst._literal) if getattr(inst, 'src1', None) is not None else 0
result = inst._fn(0, s1, 0, 0, st.scc, st.vcc, 0, st.exec_mask, inst._literal, st.vgpr, src0_idx, inst.vdst)
st.wsgpr(inst.vdst, result['D0'])
st.pc += inst._words; return 0
# Per-lane dispatch wrapper: wraps per-lane exec functions into wave-level dispatch
@functools.cache
def dispatch_lane(exec_fn):
def dispatch(st, inst):
exec_mask, vgpr, n_lanes = st.exec_mask, st.vgpr, st.n_lanes
for lane in range(n_lanes):
if exec_mask >> lane & 1: exec_fn(st, inst, vgpr[lane], lane)
st.commit_pends()
st.pc += inst._words
return 0
return dispatch
def decode_program(data: bytes) -> dict[int, Inst]:
result: dict[int, Inst] = {}
i = 0
while i < len(data):
try: inst_class = detect_format(data[i:])
except ValueError: break # stop at invalid instruction (padding/metadata after code)
inst = inst_class.from_bytes(data[i:i+inst_class._size()+8]) # +8 for potential 64-bit literal
inst._words = inst.size() // 4
# Determine dispatch function and pcode function
fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op)
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_ENDPGM: inst._dispatch = dispatch_endpgm
elif isinstance(inst, SOPP) and inst.op == SOPPOp.S_BARRIER: inst._dispatch = dispatch_barrier
elif isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)): inst._dispatch = exec_scalar
elif isinstance(inst, VOP1) and inst.op == VOP1Op.V_NOP: inst._dispatch = dispatch_nop
elif isinstance(inst, VOP3P) and 'WMMA' in inst.op_name: inst._dispatch = dispatch_wmma
elif isinstance(inst, VOP3) and inst.op == VOP3Op.V_WRITELANE_B32: inst._dispatch = dispatch_writelane
elif isinstance(inst, (VOP1, VOP3)) and inst.op in (VOP1Op.V_READFIRSTLANE_B32, VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32): inst._dispatch = dispatch_readlane
elif isinstance(inst, VOPD): inst._dispatch = dispatch_lane(exec_vopd)
elif isinstance(inst, FLAT): inst._dispatch = dispatch_lane(exec_flat)
elif isinstance(inst, DS): inst._dispatch = dispatch_lane(exec_ds)
else: inst._dispatch = dispatch_lane(exec_vop)
# Validate pcode exists for instructions that need it (scalar/wave-level ops and VOPD don't need pcode)
needs_pcode = inst._dispatch not in (dispatch_endpgm, dispatch_barrier, exec_scalar, dispatch_nop, dispatch_wmma,
dispatch_writelane, dispatch_readlane, dispatch_lane(exec_vopd))
if fn is None and inst.op_name and needs_pcode: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
inst._fn = fn if fn else lambda *args, **kwargs: {}
result[i // 4] = inst
i += inst._words * 4
return result
# ═══════════════════════════════════════════════════════════════════════════════
# MAIN EXECUTION LOOP
# ═══════════════════════════════════════════════════════════════════════════════
def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int:
inst = program.get(st.pc)
if inst is None: return 1
inst_words, st.literal = inst._words, getattr(inst, '_literal', None) or 0
def exec_wave(program: dict[int, Inst], st: WaveState) -> int:
while (inst := program.get(st.pc)) and (result := inst._dispatch(st, inst)) == 0: pass
return result
if isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)):
delta = exec_scalar(st, inst)
if delta == -1: return -1 # endpgm
if delta == -2: st.pc += inst_words; return -2 # barrier
st.pc += inst_words + delta
else:
# V_READFIRSTLANE/V_READLANE write to SGPR, execute once; others execute per-lane with exec_mask
is_readlane = isinstance(inst, (VOP1, VOP3)) and ('READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name)
exec_mask = 1 if is_readlane else st.exec_mask
for lane in range(1 if is_readlane else n_lanes):
if exec_mask & (1 << lane): exec_vector(st, inst, lane, lds)
st.commit_pends()
st.pc += inst_words
return 0
def exec_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int:
while st.pc in program:
result = step_wave(program, st, lds, n_lanes)
if result == -1: return 0
if result == -2: return -2
return 0
def exec_workgroup(program: Program, workgroup_id: tuple[int, int, int], local_size: tuple[int, int, int], args_ptr: int,
wg_id_sgpr_base: int, wg_id_enables: tuple[bool, bool, bool]) -> None:
def exec_workgroup(program: dict[int, Inst], workgroup_id: tuple[int, int, int], local_size: tuple[int, int, int], args_ptr: int, rsrc2: int) -> None:
lx, ly, lz = local_size
total_threads, lds = lx * ly * lz, LDSMem(bytearray(65536))
waves: list[tuple[WaveState, int, int]] = []
total_threads = lx * ly * lz
# GRANULATED_LDS_SIZE is in 512-byte units (see ops_amd.py: lds_size = ((group_segment_size + 511) // 512))
lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512
lds = LDSMem(bytearray(lds_size)) if lds_size else None
waves: list[WaveState] = []
for wave_start in range(0, total_threads, WAVE_SIZE):
n_lanes, st = min(WAVE_SIZE, total_threads - wave_start), WaveState()
n_lanes = min(WAVE_SIZE, total_threads - wave_start)
st = WaveState(lds, n_lanes)
st.exec_mask = (1 << n_lanes) - 1
st.wsgpr64(0, args_ptr)
# Set workgroup IDs in SGPRs based on USER_SGPR_COUNT and enable flags from COMPUTE_PGM_RSRC2
sgpr_idx = wg_id_sgpr_base
for wg_id, enabled in zip(workgroup_id, wg_id_enables):
if enabled: st.sgpr[sgpr_idx] = wg_id; sgpr_idx += 1
# Set workitem IDs in VGPR0 using packed method: v0 = (Z << 20) | (Y << 10) | X
for i in range(n_lanes):
tid = wave_start + i
st.vgpr[i][0] = ((tid // (lx * ly)) << 20) | (((tid // lx) % ly) << 10) | (tid % lx)
waves.append((st, n_lanes, wave_start))
has_barrier = any(isinstance(inst, SOPP) and inst.op == SOPPOp.S_BARRIER for inst in program.values())
for _ in range(2 if has_barrier else 1):
for st, n_lanes, _ in waves: exec_wave(program, st, lds, n_lanes)
st.wsgpr64(0, args_ptr) # s[0:1] = kernel arguments pointer
# COMPUTE_PGM_RSRC2: USER_SGPR_COUNT is where workgroup IDs start, ENABLE_SGPR_WORKGROUP_ID_X/Y/Z control which are passed
sgpr_idx = (rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT_SHIFT
if rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_X: st.sgpr[sgpr_idx] = workgroup_id[0]; sgpr_idx += 1
if rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Y: st.sgpr[sgpr_idx] = workgroup_id[1]; sgpr_idx += 1
if rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Z: st.sgpr[sgpr_idx] = workgroup_id[2]
# VGPR0 = packed workitem IDs: (Z << 20) | (Y << 10) | X
for tid in range(wave_start, wave_start + n_lanes):
st.vgpr[tid - wave_start][0] = ((tid // (lx * ly)) << 20) | (((tid // lx) % ly) << 10) | (tid % lx)
waves.append(st)
while waves:
waves = [st for st in waves if exec_wave(program, st) != -1]
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c) -> int:
program = decode_program((ctypes.c_char * lib_sz).from_address(lib).raw)
if not program: return -1
wg_id_enables = tuple(bool((rsrc2 >> (7+i)) & 1) for i in range(3))
for gidz in range(gz):
for gidy in range(gy):
for gidx in range(gx): exec_workgroup(program, (gidx, gidy, gidz), (lx, ly, lz), args_ptr, (rsrc2 >> 1) & 0x1f, wg_id_enables)
for gidx in range(gx): exec_workgroup(program, (gidx, gidy, gidz), (lx, ly, lz), args_ptr, rsrc2)
return 0

View File

@@ -1,6 +1,6 @@
# DSL for RDNA3 pseudocode - makes pseudocode expressions work directly as Python
import struct, math
from extra.assembly.amd.dsl import MASK32, MASK64, MASK128, _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.dsl import MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
# ═══════════════════════════════════════════════════════════════════════════════
# HELPER FUNCTIONS
@@ -33,7 +33,9 @@ def _isquietnan(x): return _check_nan_type(x, 1, True) # quiet NaN has quiet bi
def _issignalnan(x): return _check_nan_type(x, 0, False) # signaling NaN has quiet bit = 0
def _gt_neg_zero(a, b): return (a > b) or (a == 0 and b == 0 and not math.copysign(1, a) < 0 and math.copysign(1, b) < 0)
def _lt_neg_zero(a, b): return (a < b) or (a == 0 and b == 0 and math.copysign(1, a) < 0 and not math.copysign(1, b) < 0)
def _fma(a, b, c): return a * b + c
def _fma(a, b, c):
try: return math.fma(a, b, c)
except ValueError: return float('nan') # inf * 0 + c is NaN per IEEE 754
def _signext(v): return v
def _fpop(fn):
def wrapper(x):
@@ -269,31 +271,6 @@ ROUND_MODE = _RoundMode()
def cvtToQuietNAN(x): return float('nan')
DST = None # Placeholder, will be set in context
# 2/PI with 1201 bits of precision for V_TRIG_PREOP_F64
# Computed as: int((2/pi) * 2^1201) - this is the fractional part of 2/pi scaled to integer
# The MSB (bit 1200) corresponds to 2^0 position in the fraction 0.b1200 b1199 ... b1 b0
_TWO_OVER_PI_1201_RAW = 0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6
class _BigInt:
"""Wrapper for large integers that supports bit slicing [high:low]."""
__slots__ = ('_val',)
def __init__(self, val): self._val = val
def __getitem__(self, key):
if isinstance(key, slice):
high, low = key.start, key.stop
if high < low: high, low = low, high # Handle reversed slice
mask = (1 << (high - low + 1)) - 1
return (self._val >> low) & mask
return (self._val >> key) & 1
def __int__(self): return self._val
def __index__(self): return self._val
def __lshift__(self, n): return self._val << int(n)
def __rshift__(self, n): return self._val >> int(n)
def __and__(self, n): return self._val & int(n)
def __or__(self, n): return self._val | int(n)
TWO_OVER_PI_1201 = _BigInt(_TWO_OVER_PI_1201_RAW)
class _WaveMode:
IEEE = False
WAVE_MODE = _WaveMode()
@@ -312,14 +289,16 @@ class _Denorm:
f64 = _DenormChecker(64)
DENORM = _Denorm()
class SliceProxy:
"""Proxy for D0[31:16] that supports .f16/.u16 etc getters and setters."""
__slots__ = ('_reg', '_high', '_low', '_reversed')
def __init__(self, reg, high, low):
self._reg = reg
class TypedView:
"""View into a Reg with typed access. Used for both full-width (Reg.u32) and slices (Reg[31:16])."""
__slots__ = ('_reg', '_high', '_low', '_signed', '_float', '_bf16', '_reversed')
def __init__(self, reg, high, low=0, signed=False, is_float=False, is_bf16=False):
# Handle reversed slices like [0:31] which means bit-reverse
if high < low: self._high, self._low, self._reversed = low, high, True
else: self._high, self._low, self._reversed = high, low, False
if high < low: high, low, reversed = low, high, True
else: reversed = False
self._reg, self._high, self._low, self._reversed = reg, high, low, reversed
self._signed, self._float, self._bf16 = signed, is_float, is_bf16
def _nbits(self): return self._high - self._low + 1
def _mask(self): return (1 << self._nbits()) - 1
def _get(self):
@@ -330,6 +309,12 @@ class SliceProxy:
if self._reversed: v = _brev(v, self._nbits())
self._reg._val = (self._reg._val & ~(self._mask() << self._low)) | ((v & self._mask()) << self._low)
@property
def _val(self): return self._get()
@property
def _bits(self): return self._nbits()
# Type accessors for slices (e.g., D0[31:16].f16)
u8 = property(lambda s: s._get() & 0xff)
u16 = property(lambda s: s._get() & 0xffff, lambda s, v: s._set(v))
u32 = property(lambda s: s._get() & MASK32, lambda s, v: s._set(v))
@@ -340,33 +325,17 @@ class SliceProxy:
bf16 = property(lambda s: _bf16(s._get()), lambda s, v: s._set(v if isinstance(v, int) else _ibf16(float(v))))
b16, b32 = u16, u32
def __int__(self): return self._get()
def __index__(self): return self._get()
# Comparison operators (compare as integers)
def __eq__(s, o): return s._get() == int(o)
def __ne__(s, o): return s._get() != int(o)
def __lt__(s, o): return s._get() < int(o)
def __le__(s, o): return s._get() <= int(o)
def __gt__(s, o): return s._get() > int(o)
def __ge__(s, o): return s._get() >= int(o)
class TypedView:
"""View for S0.u32 that supports [4:0] slicing and [bit] access."""
__slots__ = ('_reg', '_bits', '_signed', '_float', '_bf16')
def __init__(self, reg, bits, signed=False, is_float=False, is_bf16=False):
self._reg, self._bits, self._signed, self._float, self._bf16 = reg, bits, signed, is_float, is_bf16
# Chained type access (e.g., jump_addr.i64 when jump_addr is already TypedView)
@property
def _val(self):
mask = MASK64 if self._bits == 64 else MASK32 if self._bits == 32 else (1 << self._bits) - 1
return self._reg._val & mask
def i64(s): return s if s._nbits() == 64 and s._signed else int(s)
@property
def u64(s): return s if s._nbits() == 64 and not s._signed else int(s) & MASK64
def __getitem__(self, key):
if isinstance(key, slice):
high, low = int(key.start), int(key.stop)
return SliceProxy(self._reg, high, low)
return (self._val >> int(key)) & 1
return TypedView(self._reg, high, low)
return (self._get() >> int(key)) & 1
def __setitem__(self, key, value):
if isinstance(key, slice):
@@ -377,14 +346,16 @@ class TypedView:
elif value: self._reg._val |= (1 << int(key))
else: self._reg._val &= ~(1 << int(key))
def __int__(self): return _sext(self._val, self._bits) if self._signed else self._val
def __int__(self): return _sext(self._get(), self._nbits()) if self._signed else self._get()
def __index__(self): return int(self)
def __trunc__(self): return int(float(self)) if self._float else int(self)
def __float__(self):
if self._float:
if self._bf16: return _bf16(self._val) # bf16 uses different conversion
return _f16(self._val) if self._bits == 16 else _f32(self._val) if self._bits == 32 else _f64(self._val)
if self._bf16: return _bf16(self._get())
bits = self._nbits()
return _f16(self._get()) if bits == 16 else _f32(self._get()) if bits == 32 else _f64(self._get())
return float(int(self))
def __bool__(s): return bool(int(s))
# Arithmetic - floats use float(), ints use int()
def __add__(s, o): return float(s) + float(o) if s._float else int(s) + int(o)
@@ -405,8 +376,8 @@ class TypedView:
def __or__(s, o): return int(s) | int(o)
def __xor__(s, o): return int(s) ^ int(o)
def __invert__(s): return ~int(s)
def __lshift__(s, o): n = int(o); return int(s) << n if 0 <= n < 64 else 0
def __rshift__(s, o): n = int(o); return int(s) >> n if 0 <= n < 64 else 0
def __lshift__(s, o): n = int(o); return int(s) << n if 0 <= n < 64 or s._nbits() > 64 else 0
def __rshift__(s, o): n = int(o); return int(s) >> n if 0 <= n < 64 or s._nbits() > 64 else 0
def __rand__(s, o): return int(o) & int(s)
def __ror__(s, o): return int(o) | int(s)
def __rxor__(s, o): return int(o) ^ int(s)
@@ -425,51 +396,42 @@ class TypedView:
def __gt__(s, o): return float(s) > float(o) if s._float else int(s) > int(o)
def __ge__(s, o): return float(s) >= float(o) if s._float else int(s) >= int(o)
def __bool__(s): return bool(int(s))
# Allow chained type access like jump_addr.i64 when jump_addr is already a TypedView
# These just return self or convert appropriately
@property
def i64(s): return s if s._bits == 64 and s._signed else int(s)
@property
def u64(s): return s if s._bits == 64 and not s._signed else int(s) & MASK64
@property
def i32(s): return s if s._bits == 32 and s._signed else _sext(int(s) & MASK32, 32)
@property
def u32(s): return s if s._bits == 32 and not s._signed else int(s) & MASK32
SliceProxy = TypedView # Alias for compatibility
class Reg:
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works. Supports up to 128 bits for DS_LOAD_B128."""
__slots__ = ('_val',)
def __init__(self, val=0): self._val = int(val) & MASK128
def __init__(self, val=0): self._val = int(val)
# Typed views
u64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64))
i64 = property(lambda s: TypedView(s, 64, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK64))
b64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64))
f64 = property(lambda s: TypedView(s, 64, is_float=True), lambda s, v: setattr(s, '_val', v if isinstance(v, int) else _i64(float(v))))
u32 = property(lambda s: TypedView(s, 32), lambda s, v: setattr(s, '_val', int(v) & MASK32))
i32 = property(lambda s: TypedView(s, 32, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK32))
b32 = property(lambda s: TypedView(s, 32), lambda s, v: setattr(s, '_val', int(v) & MASK32))
f32 = property(lambda s: TypedView(s, 32, is_float=True), lambda s, v: setattr(s, '_val', _i32(float(v))))
u24 = property(lambda s: TypedView(s, 24))
i24 = property(lambda s: TypedView(s, 24, signed=True))
u16 = property(lambda s: TypedView(s, 16), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
i16 = property(lambda s: TypedView(s, 16, signed=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
b16 = property(lambda s: TypedView(s, 16), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
f16 = property(lambda s: TypedView(s, 16, is_float=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _i16(float(v))) & 0xffff)))
bf16 = property(lambda s: TypedView(s, 16, is_float=True, is_bf16=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _ibf16(float(v))) & 0xffff)))
u8 = property(lambda s: TypedView(s, 8))
i8 = property(lambda s: TypedView(s, 8, signed=True))
u1 = property(lambda s: TypedView(s, 1)) # single bit
# Typed views - TypedView(reg, high, signed, is_float, is_bf16)
u64 = property(lambda s: TypedView(s, 63), lambda s, v: setattr(s, '_val', int(v) & MASK64))
i64 = property(lambda s: TypedView(s, 63, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK64))
b64 = property(lambda s: TypedView(s, 63), lambda s, v: setattr(s, '_val', int(v) & MASK64))
f64 = property(lambda s: TypedView(s, 63, is_float=True), lambda s, v: setattr(s, '_val', v if isinstance(v, int) else _i64(float(v))))
u32 = property(lambda s: TypedView(s, 31), lambda s, v: setattr(s, '_val', int(v) & MASK32))
i32 = property(lambda s: TypedView(s, 31, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK32))
b32 = property(lambda s: TypedView(s, 31), lambda s, v: setattr(s, '_val', int(v) & MASK32))
f32 = property(lambda s: TypedView(s, 31, is_float=True), lambda s, v: setattr(s, '_val', _i32(float(v))))
u24 = property(lambda s: TypedView(s, 23))
i24 = property(lambda s: TypedView(s, 23, signed=True))
u16 = property(lambda s: TypedView(s, 15), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
i16 = property(lambda s: TypedView(s, 15, signed=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
b16 = property(lambda s: TypedView(s, 15), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
f16 = property(lambda s: TypedView(s, 15, is_float=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _i16(float(v))) & 0xffff)))
bf16 = property(lambda s: TypedView(s, 15, is_float=True, is_bf16=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _ibf16(float(v))) & 0xffff)))
u8 = property(lambda s: TypedView(s, 7))
i8 = property(lambda s: TypedView(s, 7, signed=True))
u3 = property(lambda s: TypedView(s, 2)) # 3-bit for opsel fields
u1 = property(lambda s: TypedView(s, 0)) # single bit
def __getitem__(s, key):
if isinstance(key, slice): return SliceProxy(s, int(key.start), int(key.stop))
if isinstance(key, slice): return TypedView(s, int(key.start), int(key.stop))
return (s._val >> int(key)) & 1
def __setitem__(s, key, value):
if isinstance(key, slice):
high, low = int(key.start), int(key.stop)
if high < low: high, low = low, high
mask = (1 << (high - low + 1)) - 1
s._val = (s._val & ~(mask << low)) | ((int(value) & mask) << low)
elif value: s._val |= (1 << int(key))
@@ -504,4 +466,5 @@ class Reg:
def __eq__(s, o): return s._val == int(o)
def __ne__(s, o): return s._val != int(o)
# 2/PI with 1201 bits of precision for V_TRIG_PREOP_F64
TWO_OVER_PI_1201 = Reg(0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6)

View File

@@ -42,7 +42,7 @@ INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+|(?:FLAT|GLOBAL|SCRATCH)_[A-Z0-9
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', 'ThreadMask',
'S1[i', 'C.i32', 'S[i]', 'in[',
'S1[i', 'C.i32', 'thread_',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
'BARRIER_STATE', 'ReallocVgprs',
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
@@ -477,7 +477,7 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
# Get op enums for this arch (import from .ins which re-exports from .enum)
import importlib
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp', 'FLATOp', 'GLOBALOp', 'SCRATCHOp'] if hasattr(autogen, name)]
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'SMEMOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp', 'FLATOp', 'GLOBALOp', 'SCRATCHOp'] if hasattr(autogen, name)]
# Build defined ops mapping
defined_ops: dict[tuple, list] = {}
@@ -513,20 +513,10 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
for op, fn_name in fn_entries: fn_lines.append(f" {cls_name}.{op.name}: {fn_name},")
fn_lines.append('}\n')
# Add V_WRITELANE_B32 if VOP3Op exists
if 'VOP3Op' in enum_names:
fn_lines.append('''
# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode)
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):
wr_lane = s1 & 0x1f
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
''')
fn_lines.append('COMPILED_FUNCTIONS = {')
for enum_cls in OP_ENUMS:
if all_fn_entries.get(enum_cls): fn_lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_FUNCTIONS,')
fn_lines.append('}\n\ndef get_compiled_functions(): return COMPILED_FUNCTIONS')
fn_lines.append('}')
# Second pass: scan generated code for pcode imports
fn_code_str = '\n'.join(fn_lines)
@@ -576,83 +566,102 @@ def _apply_pseudocode_fixes(op, code: str) -> str:
return code
def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]:
"""Generate a single compiled pseudocode function."""
"""Generate a single compiled pseudocode function.
Functions take int parameters and return dict of int values.
Reg wrapping happens inside the function, only for registers actually used."""
has_d1 = '{ D1' in pc
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
is_div_scale = 'DIV_SCALE' in op.name
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
is_ds = cls_name == 'DSOp'
is_flat = cls_name in ('FLATOp', 'GLOBALOp', 'SCRATCHOp')
is_smem = cls_name == 'SMEMOp'
has_s_array = 'S[i]' in pc # FMA_MIX style: S[0], S[1], S[2] array access
combined = code + pc
fn_name = f"_{cls_name}_{op.name}"
# Function accepts Reg objects directly (uppercase names), laneId is passed directly as int
# DSOp functions get additional MEM and offset parameters
# FLAT/GLOBAL ops get MEM, vaddr, vdata, saddr, offset parameters
if is_ds:
lines = [f"def {fn_name}(MEM, ADDR, DATA0, DATA1, OFFSET0, OFFSET1, RETURN_DATA):"]
elif is_flat:
lines = [f"def {fn_name}(MEM, ADDR, VDATA, VDST, RETURN_DATA):"]
else:
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
# Registers that need special handling (aliases or init)
# Detect which registers are used/modified
def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
special_regs = []
if is_ds: special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
elif is_flat: special_regs = [('DATA', 'VDATA')]
else:
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
used = {name for name, _ in special_regs if name in combined}
# Detect which registers are modified (not just read) - look for assignments
modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined))
modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined))
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
modifies_pc = bool(re.search(r'\bPC\s*=', combined))
# DS/FLAT ops: detect memory writes (MEM[...] = ...)
modifies_mem = (is_ds or is_flat) and bool(re.search(r'MEM\[.*\]\.[a-z0-9]+\s*=', combined))
# FLAT ops: detect VDST writes
modifies_vdst = is_flat and bool(re.search(r'VDST[\.\[].*=', combined))
# Build init code for special registers
init_lines = []
if is_div_scale: init_lines.append(" D0 = Reg(S0._val)")
# Build function signature and Reg init lines
if is_smem:
lines = [f"def {fn_name}(MEM, addr):"]
reg_inits = ["ADDR=Reg(addr)", "SDATA=Reg(0)"]
special_regs = []
elif is_ds:
lines = [f"def {fn_name}(MEM, addr, data0, data1, offset0, offset1):"]
reg_inits = ["ADDR=Reg(addr)", "DATA0=Reg(data0)", "DATA1=Reg(data1)", "OFFSET0=Reg(offset0)", "OFFSET1=Reg(offset1)", "RETURN_DATA=Reg(0)"]
special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
elif is_flat:
lines = [f"def {fn_name}(MEM, addr, vdata, vdst):"]
reg_inits = ["ADDR=addr", "VDATA=Reg(vdata)", "VDST=Reg(vdst)", "RETURN_DATA=Reg(0)"]
special_regs = [('DATA', 'VDATA')]
elif has_s_array:
# FMA_MIX style: needs S[i] array, opsel, opsel_hi for source selection (neg/neg_hi applied in emu.py before call)
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None, opsel=0, opsel_hi=0):"]
reg_inits = ["S0=Reg(s0)", "S1=Reg(s1)", "S2=Reg(s2)", "S=[S0,S1,S2]", "D0=Reg(d0)", "OPSEL=Reg(opsel)", "OPSEL_HI=Reg(opsel_hi)"]
special_regs = []
# Detect array declarations like "declare in : 32'F[3]" and create them (rename 'in' to 'ins' since 'in' is a keyword)
if "in[" in combined:
reg_inits.append("ins=[Reg(0),Reg(0),Reg(0)]")
code = code.replace("in[", "ins[")
else:
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):"]
# Only create Regs for registers actually used in the pseudocode
reg_inits = []
if 'S0' in combined: reg_inits.append("S0=Reg(s0)")
if 'S1' in combined: reg_inits.append("S1=Reg(s1)")
if 'S2' in combined: reg_inits.append("S2=Reg(s2)")
if modifies_d0 or 'D0' in combined: reg_inits.append("D0=Reg(s0)" if is_div_scale else "D0=Reg(d0)")
if modifies_scc or 'SCC' in combined: reg_inits.append("SCC=Reg(scc)")
if modifies_vcc or 'VCC' in combined: reg_inits.append("VCC=Reg(vcc)")
if modifies_exec or 'EXEC' in combined: reg_inits.append("EXEC=Reg(exec_mask)")
if modifies_pc or 'PC' in combined: reg_inits.append("PC=Reg(pc) if pc is not None else None")
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
# Build init code
init_parts = reg_inits.copy()
for name, init in special_regs:
if name in used: init_lines.append(f" {name} = {init}")
if 'EXEC_LO' in code: init_lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
if 'EXEC_HI' in code: init_lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
code_lines = [line for line in code.split('\n') if line.strip()]
if init_lines:
lines.extend(init_lines)
if code_lines: lines.append(" # --- compiled pseudocode ---")
for line in code_lines:
lines.append(f" {line}")
if name in combined: init_parts.append(f"{name}={init}")
if 'EXEC_LO' in code: init_parts.append("EXEC_LO=SliceProxy(EXEC, 31, 0)")
if 'EXEC_HI' in code: init_parts.append("EXEC_HI=SliceProxy(EXEC, 63, 32)")
if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_parts.append("VCCZ=Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_parts.append("EXECZ=Reg(1 if EXEC._val == 0 else 0)")
# Build result dict - only include registers that are modified
# Add init line and separator
if init_parts: lines.append(f" {'; '.join(init_parts)}")
lines.append(" # --- compiled pseudocode ---")
# Add compiled pseudocode
for line in code.split('\n'):
if line.strip(): lines.append(f" {line}")
# Build result dict
result_items = []
if modifies_d0: result_items.append("'D0': D0")
if modifies_scc: result_items.append("'SCC': SCC")
if modifies_vcc: result_items.append("'VCC': VCC")
if modifies_exec: result_items.append("'EXEC': EXEC")
if has_d1: result_items.append("'D1': D1")
if modifies_pc: result_items.append("'PC': PC")
# DS ops: return RETURN_DATA if it was written (left side of assignment)
if modifies_d0: result_items.append("'D0': D0._val")
if modifies_scc: result_items.append("'SCC': SCC._val")
if modifies_vcc: result_items.append("'VCC': VCC._val")
if modifies_exec: result_items.append("'EXEC': EXEC._val")
if has_d1: result_items.append("'D1': D1._val")
if modifies_pc: result_items.append("'PC': PC._val")
if is_smem and 'SDATA' in combined and re.search(r'^\s*SDATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'SDATA': SDATA._val")
if is_ds and 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'RETURN_DATA': RETURN_DATA")
# FLAT ops: return RETURN_DATA for atomics, VDATA for loads (only if written to)
result_items.append("'RETURN_DATA': RETURN_DATA._val")
if is_flat:
if 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'RETURN_DATA': RETURN_DATA")
result_items.append("'RETURN_DATA': RETURN_DATA._val")
if re.search(r'^\s*VDATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'VDATA': VDATA")
result_items.append("'VDATA': VDATA._val")
lines.append(f" return {{{', '.join(result_items)}}}\n")
return fn_name, '\n'.join(lines)

View File

@@ -1,13 +1,12 @@
#!/usr/bin/env python3
"""Benchmark comparing Python vs Rust RDNA3 emulators on synthetic and real tinygrad kernels."""
import ctypes, time, os, struct, cProfile, pstats, io
"""Benchmark comparing Python vs Rust RDNA3 emulators on real tinygrad kernels."""
import ctypes, time, os
from pathlib import Path
from typing import Callable
# Set AMD=1 before importing tinygrad
os.environ["AMD"] = "1"
from extra.assembly.amd.emu import run_asm as python_run_asm, set_valid_mem_ranges, decode_program, step_wave, WaveState, WAVE_SIZE
from extra.assembly.amd.emu import run_asm as python_run_asm, set_valid_mem_ranges, decode_program
REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so"
if not REMU_PATH.exists():
@@ -42,7 +41,7 @@ def setup_buffers(buf_sizes: list[int], init_data: dict[int, bytes] | None = Non
ranges.add((args_ptr, ctypes.sizeof(args)))
return buffers, args, args_ptr, ranges
def benchmark_emulator(name: str, run_fn, kernel: bytes, global_size, local_size, args_ptr, iterations: int = 5):
def benchmark_emulator(name: str, run_fn, kernel: bytes, global_size, local_size, args_ptr, rsrc2: int, iterations: int = 5):
"""Benchmark an emulator and return average time."""
gx, gy, gz = global_size
lx, ly, lz = local_size
@@ -50,13 +49,13 @@ def benchmark_emulator(name: str, run_fn, kernel: bytes, global_size, local_size
lib_ptr = ctypes.addressof(kernel_buf)
# Warmup
run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr)
run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr, rsrc2)
# Timed runs
times = []
for _ in range(iterations):
start = time.perf_counter()
result = run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr)
result = run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr, rsrc2)
end = time.perf_counter()
if result != 0:
print(f" {name} returned error: {result}")
@@ -65,27 +64,12 @@ def benchmark_emulator(name: str, run_fn, kernel: bytes, global_size, local_size
return sum(times) / len(times)
def create_synthetic_kernel(n_ops: int) -> bytes:
"""Create a synthetic kernel with n_ops vector operations."""
instructions = []
# VOP2 instructions: v_add_f32, v_mul_f32, v_max_f32, v_min_f32
ops = [
(0b0000011 << 25) | (1 << 17) | (0 << 9) | 256, # v_add_f32 v0, v0, v1
(0b0001000 << 25) | (1 << 17) | (0 << 9) | 256, # v_mul_f32 v0, v0, v1
(0b0010000 << 25) | (1 << 17) | (0 << 9) | 256, # v_max_f32 v0, v0, v1
(0b0001111 << 25) | (1 << 17) | (0 << 9) | 256, # v_min_f32 v0, v0, v1
]
for i in range(n_ops):
instructions.append(ops[i % len(ops)])
# S_ENDPGM
instructions.append((0b101111111 << 23) | (48 << 16) | 0)
return b''.join(struct.pack('<I', inst) for inst in instructions)
def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], dict[int, bytes]] | None:
"""Get a real tinygrad kernel by operation name. Returns (code, global_size, local_size, buf_sizes, buf_data)."""
def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], dict[int, bytes], int] | None:
"""Get a real tinygrad kernel by operation name. Returns (code, global_size, local_size, buf_sizes, buf_data, rsrc2)."""
try:
from tinygrad import Tensor
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.autogen import hsa
import numpy as np
np.random.seed(42)
@@ -112,7 +96,9 @@ def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], d
lowered = ei.lower()
if ei.ast.op.name == 'SINK' and lowered.prg and lowered.prg.p.lib:
lib = bytes(lowered.prg.p.lib)
image = memoryview(bytearray(lib))
_, sections, _ = elf_loader(lib)
rodata_entry = next((sh.header.sh_addr for sh in sections if sh.name == ".rodata"), -1)
for sec in sections:
if sec.name == '.text':
buf_sizes = [b.nbytes for b in lowered.bufs]
@@ -122,67 +108,22 @@ def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], d
if hasattr(buf, 'base') and buf.base is not None and hasattr(buf.base, '_buf'):
try: buf_data[i] = bytes(buf.base._buf)
except: pass
return (bytes(sec.content), tuple(lowered.prg.p.global_size), tuple(lowered.prg.p.local_size), buf_sizes, buf_data)
# Extract rsrc2 from ELF (same as ops_amd.py)
group_segment_size = image[rodata_entry:rodata_entry+4].cast("I")[0]
lds_size = ((group_segment_size + 511) // 512) & 0x1FF
code = hsa.amd_kernel_code_t.from_buffer_copy(bytes(image[rodata_entry:rodata_entry+256]) + b'\x00'*256)
rsrc2 = code.compute_pgm_rsrc2 | (lds_size << 15)
return (bytes(sec.content), tuple(lowered.prg.p.global_size), tuple(lowered.prg.p.local_size), buf_sizes, buf_data, rsrc2)
return None
except Exception as e:
print(f" Error getting kernel: {e}")
return None
def profile_python_emu(kernel: bytes, global_size, local_size, args_ptr, n_runs: int = 1):
"""Profile the Python emulator to find bottlenecks."""
gx, gy, gz = global_size
lx, ly, lz = local_size
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
lib_ptr = ctypes.addressof(kernel_buf)
pr = cProfile.Profile()
pr.enable()
for _ in range(n_runs):
python_run_asm(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr)
pr.disable()
s = io.StringIO()
ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
ps.print_stats(20)
return s.getvalue()
def measure_step_rate(kernel: bytes, n_steps: int = 10000) -> float:
"""Measure raw step_wave() performance (steps per second)."""
program = decode_program(kernel)
if not program: return 0.0
st = WaveState()
st.exec_mask = 0xffffffff
lds = bytearray(65536)
n_lanes = 32
# Reset PC for each measurement
start = time.perf_counter()
for _ in range(n_steps):
st.pc = 0
while st.pc in program:
result = step_wave(program, st, lds, n_lanes)
if result == -1: break
elapsed = time.perf_counter() - start
return n_steps / elapsed if elapsed > 0 else 0
# Test configurations
SYNTHETIC_TESTS = [
("synthetic_10ops", 10, (1, 1, 1), (32, 1, 1)),
("synthetic_100ops", 100, (1, 1, 1), (32, 1, 1)),
("synthetic_500ops", 500, (1, 1, 1), (32, 1, 1)),
("synthetic_100ops_4wg", 100, (4, 1, 1), (32, 1, 1)),
("synthetic_100ops_16wg", 100, (16, 1, 1), (32, 1, 1)),
]
TINYGRAD_TESTS = ["add", "mul", "reduce_sum", "softmax", "exp", "gelu", "matmul_small"]
def main():
import argparse
parser = argparse.ArgumentParser(description="Benchmark RDNA3 emulators")
parser.add_argument("--profile", action="store_true", help="Profile Python emulator")
parser.add_argument("--synthetic-only", action="store_true", help="Only run synthetic tests")
parser.add_argument("--tinygrad-only", action="store_true", help="Only run tinygrad tests")
parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per benchmark")
args = parser.parse_args()
@@ -197,98 +138,55 @@ def main():
results = []
# Synthetic workloads
if not args.tinygrad_only:
print("\n[SYNTHETIC WORKLOADS]")
print("-" * 90)
print("\n[TINYGRAD KERNELS]")
print("-" * 90)
for name, n_ops, global_size, local_size in SYNTHETIC_TESTS:
kernel = create_synthetic_kernel(n_ops)
n_insts = count_instructions(kernel)
n_workgroups = global_size[0] * global_size[1] * global_size[2]
n_threads = local_size[0] * local_size[1] * local_size[2]
total_work = n_insts * n_workgroups * n_threads
for op_name in TINYGRAD_TESTS:
print(f"\n{op_name}:", end=" ", flush=True)
kernel_info = get_tinygrad_kernel(op_name)
if kernel_info is None:
print("failed to compile")
continue
print(f"\n{name}: {n_insts} insts × {n_workgroups} WGs × {n_threads} threads = {total_work:,} ops")
kernel, global_size, local_size, buf_sizes, buf_data, rsrc2 = kernel_info
n_insts = count_instructions(kernel)
n_workgroups = global_size[0] * global_size[1] * global_size[2]
n_threads = local_size[0] * local_size[1] * local_size[2]
total_work = n_insts * n_workgroups * n_threads
buf_sizes = [4096]
buffers, args_arr, args_ptr, ranges = setup_buffers(buf_sizes)
set_valid_mem_ranges(ranges)
print(f"{n_insts} insts × {n_workgroups} WGs × {n_threads} threads = {total_work:,} ops")
# Benchmark
py_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, args.iterations)
rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size, args_ptr, args.iterations) if rust_remu else None
buffers, args_arr, args_ptr, ranges = setup_buffers(buf_sizes, buf_data)
set_valid_mem_ranges(ranges)
if py_time:
py_rate = total_work / py_time / 1e6
print(f" Python: {py_time*1000:8.3f} ms ({py_rate:7.2f} M ops/s)")
if rust_time:
rust_rate = total_work / rust_time / 1e6
speedup = py_time / rust_time if py_time else 0
print(f" Rust: {rust_time*1000:8.3f} ms ({rust_rate:7.2f} M ops/s) [{speedup:.1f}x faster]")
py_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations)
rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations) if rust_remu else None
results.append(("synthetic", name, n_insts, n_workgroups, py_time, rust_time))
if py_time:
py_rate = total_work / py_time / 1e6
print(f" Python: {py_time*1000:8.3f} ms ({py_rate:7.2f} M ops/s)")
if rust_time:
rust_rate = total_work / rust_time / 1e6
speedup = py_time / rust_time if py_time else 0
print(f" Rust: {rust_time*1000:8.3f} ms ({rust_rate:7.2f} M ops/s) [{speedup:.1f}x faster]")
# Tinygrad kernels
if not args.synthetic_only:
print("\n[TINYGRAD KERNELS]")
print("-" * 90)
for op_name in TINYGRAD_TESTS:
print(f"\n{op_name}:", end=" ", flush=True)
kernel_info = get_tinygrad_kernel(op_name)
if kernel_info is None:
print("failed to compile")
continue
kernel, global_size, local_size, buf_sizes, buf_data = kernel_info
n_insts = count_instructions(kernel)
n_workgroups = global_size[0] * global_size[1] * global_size[2]
n_threads = local_size[0] * local_size[1] * local_size[2]
total_work = n_insts * n_workgroups * n_threads
print(f"{n_insts} insts × {n_workgroups} WGs × {n_threads} threads = {total_work:,} ops")
buffers, args_arr, args_ptr, ranges = setup_buffers(buf_sizes, buf_data)
set_valid_mem_ranges(ranges)
py_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, args.iterations)
rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size, args_ptr, args.iterations) if rust_remu else None
if py_time:
py_rate = total_work / py_time / 1e6
print(f" Python: {py_time*1000:8.3f} ms ({py_rate:7.2f} M ops/s)")
if rust_time:
rust_rate = total_work / rust_time / 1e6
speedup = py_time / rust_time if py_time else 0
print(f" Rust: {rust_time*1000:8.3f} ms ({rust_rate:7.2f} M ops/s) [{speedup:.1f}x faster]")
results.append(("tinygrad", op_name, n_insts, n_workgroups, py_time, rust_time))
# Optional profiling
if args.profile and py_time:
print("\n [PROFILE - Top 10 functions]")
profile_output = profile_python_emu(kernel, global_size, local_size, args_ptr)
for line in profile_output.split('\n')[5:15]:
if line.strip(): print(f" {line}")
results.append((op_name, n_insts, n_workgroups, py_time, rust_time))
# Summary table
print("\n" + "=" * 90)
print("SUMMARY")
print("=" * 90)
print(f"{'Type':<10} {'Name':<25} {'Insts':<8} {'WGs':<6} {'Python (ms)':<14} {'Rust (ms)':<14} {'Speedup':<10}")
print(f"{'Name':<25} {'Insts':<8} {'WGs':<6} {'Python (ms)':<14} {'Rust (ms)':<14} {'Speedup':<10}")
print("-" * 90)
for test_type, name, n_insts, n_wgs, py_time, rust_time in results:
for name, n_insts, n_wgs, py_time, rust_time in results:
py_ms = f"{py_time*1000:.3f}" if py_time else "error"
if rust_time:
rust_ms = f"{rust_time*1000:.3f}"
speedup = f"{py_time/rust_time:.1f}x" if py_time else "N/A"
else:
rust_ms, speedup = "N/A", "N/A"
print(f"{test_type:<10} {name:<25} {n_insts:<8} {n_wgs:<6} {py_ms:<14} {rust_ms:<14} {speedup:<10}")
print(f"{name:<25} {n_insts:<8} {n_wgs:<6} {py_ms:<14} {rust_ms:<14} {speedup:<10}")
if __name__ == "__main__":
main()

View File

@@ -92,7 +92,9 @@ def run_program_emu(instructions: list, n_lanes: int = 1) -> WaveState:
lib_ptr = ctypes.addressof(kernel_buf)
set_valid_mem_ranges({(out_addr, OUT_BYTES), (args_ptr, 8)})
result = run_asm(lib_ptr, len(code), 1, 1, 1, n_lanes, 1, 1, args_ptr)
# rsrc2: USER_SGPR_COUNT=2, ENABLE_SGPR_WORKGROUP_ID_X/Y/Z=1, LDS_SIZE=128 (64KB)
rsrc2 = 0x19c | (128 << 15)
result = run_asm(lib_ptr, len(code), 1, 1, 1, n_lanes, 1, 1, args_ptr, rsrc2)
assert result == 0, f"run_asm failed with {result}"
return parse_output(bytes(out_buf), n_lanes)

View File

@@ -33,6 +33,35 @@ class TestBasicScalar(unittest.TestCase):
self.assertEqual(st.sgpr[4], 0)
self.assertEqual(st.scc, 1)
def test_s_brev_b32(self):
"""S_BREV_B32 reverses bits of a 32-bit value."""
# 10 = 0b00000000000000000000000000001010
# reversed = 0b01010000000000000000000000000000 = 0x50000000
instructions = [
s_mov_b32(s[0], 10),
s_brev_b32(s[1], s[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 0x50000000)
def test_s_brev_b32_all_ones(self):
"""S_BREV_B32 with all ones stays all ones."""
instructions = [
s_mov_b32(s[0], 0xFFFFFFFF),
s_brev_b32(s[1], s[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 0xFFFFFFFF)
def test_s_brev_b32_single_bit(self):
"""S_BREV_B32 with bit 0 set becomes bit 31."""
instructions = [
s_mov_b32(s[0], 1),
s_brev_b32(s[1], s[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 0x80000000)
class TestQuadmaskWqm(unittest.TestCase):
"""Tests for S_QUADMASK_B32 and S_WQM_B32."""
@@ -201,5 +230,113 @@ class TestSCCBehavior(unittest.TestCase):
self.assertEqual(st.scc, 0)
class TestSignedArithmetic(unittest.TestCase):
"""Tests for S_ADD_I32, S_SUB_I32 and their SCC overflow behavior."""
def test_s_add_i32_no_overflow(self):
"""S_ADD_I32: 1 + 1 = 2, no overflow, SCC=0."""
instructions = [
s_mov_b32(s[0], 1),
s_add_i32(s[1], s[0], 1),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 2)
self.assertEqual(st.scc, 0, "No overflow, SCC should be 0")
def test_s_add_i32_positive_overflow(self):
"""S_ADD_I32: MAX_INT + 1 overflows, SCC=1."""
instructions = [
s_mov_b32(s[0], 0x7FFFFFFF), # MAX_INT
s_add_i32(s[1], s[0], 1),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 0x80000000) # Wraps to MIN_INT
self.assertEqual(st.scc, 1, "Overflow, SCC should be 1")
def test_s_add_i32_negative_no_overflow(self):
"""S_ADD_I32: -10 + 20 = 10, no overflow."""
instructions = [
s_mov_b32(s[0], 0xFFFFFFF6), # -10 in two's complement
s_mov_b32(s[1], 20),
s_add_i32(s[2], s[0], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], 10)
self.assertEqual(st.scc, 0)
def test_s_add_i32_negative_overflow(self):
"""S_ADD_I32: MIN_INT + (-1) underflows, SCC=1."""
instructions = [
s_mov_b32(s[0], 0x80000000), # MIN_INT
s_mov_b32(s[1], 0xFFFFFFFF), # -1
s_add_i32(s[2], s[0], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], 0x7FFFFFFF) # Wraps to MAX_INT
self.assertEqual(st.scc, 1, "Underflow, SCC should be 1")
def test_s_sub_i32_no_overflow(self):
"""S_SUB_I32: 10 - 5 = 5, no overflow."""
instructions = [
s_mov_b32(s[0], 10),
s_mov_b32(s[1], 5),
s_sub_i32(s[2], s[0], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], 5)
self.assertEqual(st.scc, 0)
def test_s_sub_i32_overflow(self):
"""S_SUB_I32: MAX_INT - (-1) overflows, SCC=1."""
instructions = [
s_mov_b32(s[0], 0x7FFFFFFF), # MAX_INT
s_mov_b32(s[1], 0xFFFFFFFF), # -1
s_sub_i32(s[2], s[0], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], 0x80000000) # Wraps to MIN_INT
self.assertEqual(st.scc, 1, "Overflow, SCC should be 1")
def test_s_mul_hi_u32(self):
"""S_MUL_HI_U32: high 32 bits of u32 * u32."""
instructions = [
s_mov_b32(s[0], 0x80000000), # 2^31
s_mov_b32(s[1], 4),
s_mul_hi_u32(s[2], s[0], s[1]), # (2^31 * 4) >> 32 = 2
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], 2)
def test_s_mul_i32(self):
"""S_MUL_I32: signed multiply low 32 bits."""
instructions = [
s_mov_b32(s[0], 0xFFFFFFFF), # -1
s_mov_b32(s[1], 10),
s_mul_i32(s[2], s[0], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], 0xFFFFFFF6) # -10
def test_division_sequence_from_llvm(self):
"""Test the division sequence pattern from LLVM-generated code."""
# This sequence is from the sin kernel and computes integer division
# s10 = dividend, s18 = divisor, result in s6/s14
dividend = 0x28BE60DB # Some value from the sin kernel
divisor = 3 # Simplified divisor
instructions = [
s_mov_b32(s[10], dividend),
s_mov_b32(s[18], divisor),
# Compute reciprocal approximation: s6 = ~0 / divisor (approx)
s_mov_b32(s[11], 0),
s_sub_i32(s[11], s[11], s[18]), # s11 = -divisor
# For testing, just verify basic arithmetic works
s_mul_i32(s[6], s[10], 2),
s_add_i32(s[7], s[6], 1),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[6], (dividend * 2) & 0xFFFFFFFF)
self.assertEqual(st.sgpr[7], ((dividend * 2) + 1) & 0xFFFFFFFF)
if __name__ == '__main__':
unittest.main()

View File

@@ -1049,6 +1049,39 @@ class TestF64Ops(unittest.TestCase):
total = p0 + p1 + p2
self.assertAlmostEqual(total, two_over_pi, places=14)
def test_v_fma_f64_sin_kernel_step84(self):
"""V_FMA_F64: exact values from sin(2.0) kernel step 84 that shows 1-bit difference."""
# From test_sin_f64 failure trace at step 84:
# v_fma_f64 v[7:8], v[17:18], v[7:8], v[15:16]
# We need to capture the exact input values and verify output matches hardware
# v[7:8] before = 0x3f80fdf3_d69db28f (0.008296875941334462)
v78 = 0x3f80fdf3d69db28f
# For the FMA to produce 0xbf457ef0_ab8c254d, we need v[17:18] and v[15:16]
# Let's test with known precision-sensitive values
a = 1.0000000001
b = 1.0000000002
c = -1.0000000003
a_bits, b_bits, c_bits = f2i64(a), f2i64(b), f2i64(c)
instructions = [
s_mov_b32(s[0], a_bits & 0xffffffff),
s_mov_b32(s[1], a_bits >> 32),
s_mov_b32(s[2], b_bits & 0xffffffff),
s_mov_b32(s[3], b_bits >> 32),
s_mov_b32(s[4], c_bits & 0xffffffff),
s_mov_b32(s[5], c_bits >> 32),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
v_mov_b32_e32(v[3], s[3]),
v_mov_b32_e32(v[4], s[4]),
v_mov_b32_e32(v[5], s[5]),
v_fma_f64(v[6], v[0], v[2], v[4]),
]
# run_program with USE_HW=1 will verify exact bit match with hardware
st = run_program(instructions, n_lanes=1)
result_bits = st.vgpr[0][6] | (st.vgpr[0][7] << 32)
self.assertNotEqual(result_bits, 0, "Result should not be zero")
class TestMad64More(unittest.TestCase):
"""More tests for V_MAD_U64_U32."""

View File

@@ -9,7 +9,7 @@ os.environ["AMD"] = "1"
os.environ["MOCKGPU"] = "1"
os.environ["PYTHON_REMU"] = "1"
from extra.assembly.amd.emu import WaveState, decode_program, step_wave, WAVE_SIZE, set_valid_mem_ranges, LDSMem
from extra.assembly.amd.emu import WaveState, decode_program, WAVE_SIZE, set_valid_mem_ranges, LDSMem
from extra.assembly.amd.test.helpers import KernelInfo
REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so"
@@ -92,19 +92,15 @@ class PythonEmulator:
def __init__(self):
self.state: WaveState | None = None
self.program: dict | None = None
self.lds: bytearray | None = None
self.n_lanes = 0
def create(self, kernel: bytes, n_lanes: int):
self.program = decode_program(kernel)
self.state = WaveState()
self.state = WaveState(LDSMem(bytearray(65536)), n_lanes)
self.state.exec_mask = (1 << n_lanes) - 1
self.lds = LDSMem(bytearray(65536))
self.n_lanes = n_lanes
def step(self) -> int:
assert self.program is not None and self.state is not None and self.lds is not None
return step_wave(self.program, self.state, self.lds, self.n_lanes)
assert self.program is not None and self.state is not None
return self.program[self.state.pc]._dispatch(self.state, self.program[self.state.pc])
def set_sgpr(self, idx: int, val: int):
assert self.state is not None
self.state.sgpr[idx] = val & 0xffffffff
@@ -163,8 +159,9 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
# Instructions with known Rust emulator bugs - sync Python to Rust after execution
# v_div_scale/v_div_fixup: Rust has different VCC handling
# v_cvt_f16_f32: Rust clears high 16 bits, but hardware (and Python) preserves them
# s_add_i32/s_sub_i32: Rust has incorrect SCC overflow detection
sync_after = any(x in inst_str for x in ('v_div_scale_f32', 'v_div_scale_f64', 'v_div_fixup_f32', 'v_div_fixup_f64',
'v_cvt_f16_f32'))
'v_cvt_f16_f32', 's_add_i32', 's_sub_i32'))
diffs = rust_before.diff(python_before, n_lanes)
if diffs:
trace_lines = []
@@ -397,6 +394,9 @@ class TestTinygradKernels(unittest.TestCase):
x_np = np.random.randn(16, 10).astype(np.float32)
self._test_kernel(lambda T: (T(x_np.tolist()).reshape(16,10) + 0).cross_entropy((T(classes).int().reshape(16) + 0)))
def test_isinf(self): self._test_kernel(lambda T: T([float('-inf'), 0., float('inf'), 1.1]*8).isinf())
def test_sin_f64(self):
from tinygrad import dtypes
self._test_kernel(lambda T: T([2.0], dtype=dtypes.float64).sin())
if __name__ == "__main__":
unittest.main()

View File

@@ -20,11 +20,12 @@ runner = get_runner(dev.device, si.ast)
prg = runner._prg
lib = bytearray(prg.lib)
# Find s_endpgm (0xBFB00000) and replace with invalid SOPP op=127 (0xBFFF0000)
# Find s_endpgm (0xBFB00000) and replace with V_MOVRELD_B32 (op=66) which has no pcode
# VOP1 encoding: bits[31:25]=0x7E, op=bits[16:9], so op=66 -> 66<<9 = 0x8400
found = False
for i in range(0, len(lib) - 4, 4):
if struct.unpack("<I", lib[i:i+4])[0] == 0xBFB00000:
lib[i:i+4] = struct.pack("<I", 0xBFFF0000)
lib[i:i+4] = struct.pack("<I", 0x7E008400)
found = True
break
assert found, "s_endpgm not found"

View File

@@ -233,14 +233,13 @@ class TestPseudocodeRegressions(unittest.TestCase):
Bug: when VCC._val == vcc (both 0), VCC wasn't returned, so VCC bits weren't written.
This caused division to produce wrong results for multiple lanes."""
# Normal case: 1.0 / 3.0, no scaling needed, VCC should be 0
S0 = Reg(0x3f800000) # 1.0
S1 = Reg(0x40400000) # 3.0
S2 = Reg(0x3f800000) # 1.0 (numerator)
D0, SCC, VCC, EXEC = Reg(0), Reg(0), Reg(0), Reg(0xffffffff)
result = _VOP3SDOp_V_DIV_SCALE_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
s0 = 0x3f800000 # 1.0
s1 = 0x40400000 # 3.0
s2 = 0x3f800000 # 1.0 (numerator)
result = _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0, None)
# Must always have VCC in result
self.assertIn('VCC', result, "V_DIV_SCALE_F32 must always return VCC")
self.assertEqual(result['VCC']._val & 1, 0, "VCC lane 0 should be 0 when no scaling needed")
self.assertEqual(result['VCC'] & 1, 0, "VCC lane 0 should be 0 when no scaling needed")
def test_v_cmp_class_f32_detects_quiet_nan(self):
"""V_CMP_CLASS_F32 must correctly identify quiet NaN vs signaling NaN.
@@ -249,22 +248,18 @@ class TestPseudocodeRegressions(unittest.TestCase):
signal_nan = 0x7f800001 # signaling NaN: exponent=255, bit22=0
# Test quiet NaN detection (bit 1 in mask)
s1_quiet = 0b0000000010 # bit 1 = quiet NaN
S0, S1, S2, D0, SCC, VCC, EXEC = Reg(quiet_nan), Reg(s1_quiet), Reg(0), Reg(0), Reg(0), Reg(0), Reg(0xffffffff)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
self.assertEqual(result['D0']._val & 1, 1, "Should detect quiet NaN with quiet NaN mask")
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
self.assertEqual(result['D0'] & 1, 1, "Should detect quiet NaN with quiet NaN mask")
# Test signaling NaN detection (bit 0 in mask)
s1_signal = 0b0000000001 # bit 0 = signaling NaN
S0, S1 = Reg(signal_nan), Reg(s1_signal)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
self.assertEqual(result['D0']._val & 1, 1, "Should detect signaling NaN with signaling NaN mask")
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
self.assertEqual(result['D0'] & 1, 1, "Should detect signaling NaN with signaling NaN mask")
# Test that quiet NaN doesn't match signaling NaN mask
S0, S1 = Reg(quiet_nan), Reg(s1_signal)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
self.assertEqual(result['D0']._val & 1, 0, "Quiet NaN should not match signaling NaN mask")
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
self.assertEqual(result['D0'] & 1, 0, "Quiet NaN should not match signaling NaN mask")
# Test that signaling NaN doesn't match quiet NaN mask
S0, S1 = Reg(signal_nan), Reg(s1_quiet)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
self.assertEqual(result['D0']._val & 1, 0, "Signaling NaN should not match quiet NaN mask")
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
self.assertEqual(result['D0'] & 1, 0, "Signaling NaN should not match quiet NaN mask")
def test_isnan_with_typed_view(self):
"""_isnan must work with TypedView objects, not just Python floats.