revert emu to master

George Hotz
2026-01-02 10:46:46 -08:00
parent 2b56c264d5
commit df20197bfb
2 changed files with 116 additions and 451 deletions


@@ -1,14 +1,12 @@
# RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF
# mypy: ignore-errors
from __future__ import annotations
import ctypes
from enum import IntEnum
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.pcode import Reg
import ctypes, struct
from extra.assembly.amd.dsl import Inst, RawImm, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, VOPDOp)
Program = dict[int, Inst]
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
@@ -30,402 +28,35 @@ def _dst16(cur: int, val: int, is_hi: bool) -> int: return (cur & 0x0000ffff) |
def _vgpr_hi(src: int) -> bool: return src >= 256 and ((src - 256) & 0x80) != 0
def _vgpr_masked(src: int) -> int: return ((src - 256) & 0x7f) + 256 if src >= 256 else src
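# e.g. src=0x1ab (VGPR operand 0xab): _vgpr_hi(0x1ab) -> True (bit 7 selects the hi half),
# _vgpr_masked(0x1ab) -> 256 + 0x2b, i.e. v43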
# Helper: get number of dwords from memory op name
def _op_ndwords(name: str) -> int:
if '_B128' in name: return 4
if '_B96' in name: return 3
if any(s in name for s in ('_B64', '_U64', '_I64', '_F64')): return 2
return 1
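# e.g. _op_ndwords('GLOBAL_LOAD_B96') == 3, _op_ndwords('DS_STORE_B64') == 2, _op_ndwords('FLAT_LOAD_U16') == 1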
# Helper: build multi-dword Reg from consecutive VGPRs
def _vgpr_read(V: list, base: int, ndwords: int) -> Reg: return Reg(sum(V[base + i] << (32 * i) for i in range(ndwords)))
# Helper: write multi-dword value to consecutive VGPRs
def _vgpr_write(V: list, base: int, val: int, ndwords: int):
for i in range(ndwords): V[base + i] = (val >> (32 * i)) & MASK32
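# e.g. with V[4] = 0x11111111 and V[5] = 0x22222222, _vgpr_read(V, 4, 2) yields Reg(0x2222222211111111);
# _vgpr_write(V, 4, 0x2222222211111111, 2) splits the value back into the same two dwords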
# Memory access
_valid_mem_ranges: list[tuple[int, int]] = []
def set_valid_mem_ranges(ranges: set[tuple[int, int]]) -> None: _valid_mem_ranges.clear(); _valid_mem_ranges.extend(ranges)
def _mem_valid(addr: int, size: int) -> bool:
return not _valid_mem_ranges or any(s <= addr and addr + size <= s + z for s, z in _valid_mem_ranges)
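# e.g. after set_valid_mem_ranges({(0x1000, 0x100)}): _mem_valid(0x10fc, 4) is True,
# _mem_valid(0x10fe, 4) is False (the access would cross the end of the range)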
def _ctypes_at(addr: int, size: int): return (ctypes.c_uint8 if size == 1 else ctypes.c_uint16 if size == 2 else ctypes.c_uint64 if size == 8 else ctypes.c_uint32).from_address(addr)
def _ctypes_at(addr: int, size: int): return (ctypes.c_uint8 if size == 1 else ctypes.c_uint16 if size == 2 else ctypes.c_uint32).from_address(addr)
def mem_read(addr: int, size: int) -> int: return _ctypes_at(addr, size).value if _mem_valid(addr, size) else 0
def mem_write(addr: int, size: int, val: int) -> None:
if _mem_valid(addr, size): _ctypes_at(addr, size).value = val
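# Note: out-of-range accesses fail soft - mem_read returns 0 and mem_write is a no-op,
# so a stray pointer reads as zero instead of faulting the host process.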
def _make_mem_accessor(read_fn, write_fn):
"""Create a memory accessor class with the given read/write functions."""
class _MemAccessor:
__slots__ = ('_addr',)
def __init__(self, addr: int): self._addr = int(addr)
u8 = property(lambda s: read_fn(s._addr, 1), lambda s, v: write_fn(s._addr, 1, int(v)))
u16 = property(lambda s: read_fn(s._addr, 2), lambda s, v: write_fn(s._addr, 2, int(v)))
u32 = property(lambda s: read_fn(s._addr, 4), lambda s, v: write_fn(s._addr, 4, int(v)))
u64 = property(lambda s: read_fn(s._addr, 8), lambda s, v: write_fn(s._addr, 8, int(v)))
i8 = property(lambda s: _sext(read_fn(s._addr, 1), 8), lambda s, v: write_fn(s._addr, 1, int(v)))
i16 = property(lambda s: _sext(read_fn(s._addr, 2), 16), lambda s, v: write_fn(s._addr, 2, int(v)))
i32 = property(lambda s: _sext(read_fn(s._addr, 4), 32), lambda s, v: write_fn(s._addr, 4, int(v)))
i64 = property(lambda s: _sext(read_fn(s._addr, 8), 64), lambda s, v: write_fn(s._addr, 8, int(v)))
b8, b16, b32, b64 = u8, u16, u32, u64
return _MemAccessor
_GlobalMemAccessor = _make_mem_accessor(mem_read, mem_write)
class _GlobalMem:
"""Global memory wrapper that supports MEM[addr].u32 style access."""
def __getitem__(self, addr) -> _GlobalMemAccessor: return _GlobalMemAccessor(addr)
GlobalMem = _GlobalMem()
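# Usage sketch (illustrative; ptr must lie inside a registered valid range):
#   GlobalMem[ptr].u32 = 0xdeadbeef   # 4-byte store via mem_write
#   GlobalMem[ptr].u16                # -> 0xbeef (little-endian low half)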
class LDSMem:
"""LDS memory wrapper that supports MEM[addr].u32 style access."""
__slots__ = ('_lds',)
def __init__(self, lds: bytearray): self._lds = lds
def _read(self, addr: int, size: int) -> int:
addr = addr & 0xffff
return int.from_bytes(self._lds[addr:addr+size], 'little') if addr + size <= len(self._lds) else 0
def _write(self, addr: int, size: int, val: int):
addr = addr & 0xffff
if addr + size <= len(self._lds): self._lds[addr:addr+size] = (int(val) & ((1 << (size*8)) - 1)).to_bytes(size, 'little')
def __getitem__(self, addr): return _make_mem_accessor(self._read, self._write)(addr)
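# e.g. lds = LDSMem(bytearray(65536)); lds[0x20].u32 = 7; lds[0x20].i16 then reads back 7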
# Memory op tables (not pseudocode - these are format descriptions)
def _mem_ops(ops, suffix_map):
return {getattr(e, f"{p}_{s}"): v for e in ops for s, v in suffix_map.items() for p in [e.__name__.replace("Op", "")]}
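# e.g. _mem_ops([GLOBALOp], {'LOAD_B32': (1,4,0)}) -> {GLOBALOp.GLOBAL_LOAD_B32: (1,4,0)}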
_LOAD_MAP = {'LOAD_B32': (1,4,0), 'LOAD_B64': (2,4,0), 'LOAD_B96': (3,4,0), 'LOAD_B128': (4,4,0), 'LOAD_U8': (1,1,0), 'LOAD_I8': (1,1,1), 'LOAD_U16': (1,2,0), 'LOAD_I16': (1,2,1)}
_STORE_MAP = {'STORE_B32': (1,4), 'STORE_B64': (2,4), 'STORE_B96': (3,4), 'STORE_B128': (4,4), 'STORE_B8': (1,1), 'STORE_B16': (1,2)}
FLAT_LOAD, FLAT_STORE = _mem_ops([GLOBALOp, FLATOp], _LOAD_MAP), _mem_ops([GLOBALOp, FLATOp], _STORE_MAP)
# D16 ops: load/store 16-bit to lower or upper half of VGPR. Format: (size, sign, hi) where hi=1 means upper 16 bits
_D16_LOAD_MAP = {'LOAD_D16_U8': (1,0,0), 'LOAD_D16_I8': (1,1,0), 'LOAD_D16_B16': (2,0,0),
'LOAD_D16_HI_U8': (1,0,1), 'LOAD_D16_HI_I8': (1,1,1), 'LOAD_D16_HI_B16': (2,0,1)}
_D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi)
FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP)
FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP)
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
# 2ADDR ops: load/store two values using offset0 and offset1
DS_LOAD_2ADDR = {DSOp.DS_LOAD_2ADDR_B32: 4, DSOp.DS_LOAD_2ADDR_B64: 8}
DS_STORE_2ADDR = {DSOp.DS_STORE_2ADDR_B32: 4, DSOp.DS_STORE_2ADDR_B64: 8}
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
# ═══════════════════════════════════════════════════════════════════════════════
# SQTT TRACING - Emit packets matching real hardware output
# ═══════════════════════════════════════════════════════════════════════════════
# Transcendental ops that produce INST packets with op=VALU_TRANS (0x0b)
_TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64',
'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'}
# Latency model from hardware measurements (warm instruction cache):
# Startup: WAVESTART -> first instruction (~32 cycles with warm cache, no REG packet)
# Cold cache: WAVESTART -> REG (~137 cycles) -> first instruction (~150 cycles after REG)
# SALU: issues every cycle, result ready 2 cycles after issue, ALUEXEC at ready time
# VALU: issues every cycle, ALUEXEC at issue+6 for each inst, serialized with +1 intervals
# TRANS: issues every 4 cycles, ALUEXEC at last_issue+1, then +8 intervals
# For dependent instructions, ALUEXEC is at source_ready + 10 (first dep) or + 9 (chained)
# VALU queue depth ~7: after 7 in-flight, ALUEXEC interleaves with VALUINST
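# Worked example (independent VALU stream): issues at cycles 0,1,2,...; first ALUEXEC at 0+6,
# then +1 apart, so exec[i] = first_issue + 6 + i (VALU_EXEC_LATENCY below)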
WAVESTART_TO_INST_CYCLES = 32 # cycles from WAVESTART to first instruction (warm cache)
SALU_LATENCY = 2 # cycles from issue to result ready
VALU_EXEC_LATENCY = 6 # cycles from first issue to first ALUEXEC
TRANS_ISSUE_CYCLES = 4 # cycles between transcendental instruction issues
TRANS_LATENCY = 9 # cycles from trans issue to ALUEXEC
class SQTTState:
"""SQTT tracing state - emits packets matching real hardware (warm cache model)."""
__slots__ = ('cycle', 'packets', 'pending_exec', 'wave_id', 'simd', 'cu', 'sgpr_ready', 'vgpr_ready',
'last_salu_exec', 'last_valu_exec', 'trans_count', 'last_trans_issue', 'last_trans_exec', 'first_valu_issue', 'valu_count', 'last_inst_type')
def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0):
self.cycle = 0
self.packets: list = []
self.pending_exec: list[tuple[int, int]] = [] # (completion_cycle, src) for ALUEXEC
self.wave_id, self.simd, self.cu = wave_id, simd, cu
self.sgpr_ready: dict[int, int] = {} # sgpr -> cycle when result ready
self.last_inst_type: str = '' # track last instruction type for gap calculation
self.vgpr_ready: dict[int, int] = {} # vgpr -> cycle when result ready
self.last_salu_exec = 0 # last SALU ALUEXEC time (for +1 spacing)
self.last_valu_exec = 0 # last VALU ALUEXEC time (for +1 spacing)
self.trans_count = 0 # number of trans instructions issued
self.last_trans_issue = 0 # cycle when last trans was issued
self.last_trans_exec = 0 # cycle when last trans ALUEXEC is scheduled
self.first_valu_issue = 0 # cycle when first VALU was issued
self.valu_count = 0 # number of VALU instructions issued
def emit_wavestart(self):
from extra.assembly.amd.sqtt import WAVESTART
self.packets.append(WAVESTART(_time=self.cycle, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3))
self.cycle += WAVESTART_TO_INST_CYCLES
def emit_waveend(self):
from extra.assembly.amd.sqtt import WAVEEND
self.packets.append(WAVEEND(_time=self.cycle, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3))
def emit_inst(self, inst_op: int):
from extra.assembly.amd.sqtt import INST
self.packets.append(INST(_time=self.cycle, wave=self.wave_id, op=inst_op))
def emit_valuinst(self):
from extra.assembly.amd.sqtt import VALUINST
self.packets.append(VALUINST(_time=self.cycle, wave=self.wave_id))
def emit_aluexec(self, src: int):
from extra.assembly.amd.sqtt import ALUEXEC
self.packets.append(ALUEXEC(_time=self.cycle, src=src))
def emit_immediate(self):
from extra.assembly.amd.sqtt import IMMEDIATE
self.packets.append(IMMEDIATE(_time=self.cycle, wave=self.wave_id))
def _get_src_regs(self, inst: Inst) -> list[tuple[str, int]]:
"""Extract source register references from instruction."""
srcs = []
if isinstance(inst, (SOP1, SOP2, SOPC)):
if hasattr(inst, 'ssrc0') and inst.ssrc0 < SGPR_COUNT: srcs.append(('s', inst.ssrc0))
if hasattr(inst, 'ssrc1') and inst.ssrc1 < SGPR_COUNT: srcs.append(('s', inst.ssrc1))
elif isinstance(inst, SOPK):
pass # immediate only
elif isinstance(inst, VOP1):
# VOP1: only src0
if inst.src0 < SGPR_COUNT: srcs.append(('s', inst.src0))
elif inst.src0 >= 256: srcs.append(('v', inst.src0 - 256))
elif isinstance(inst, VOP2):
# VOP2: src0 (can be SGPR/VGPR/const) + vsrc1 (always VGPR)
if inst.src0 < SGPR_COUNT: srcs.append(('s', inst.src0))
elif inst.src0 >= 256: srcs.append(('v', inst.src0 - 256))
srcs.append(('v', inst.vsrc1))
elif isinstance(inst, (VOP3, VOP3SD, VOP3P, VOPC)):
for attr in ('src0', 'src1', 'src2'):
if hasattr(inst, attr):
src = getattr(inst, attr)
if src < SGPR_COUNT: srcs.append(('s', src))
elif src >= 256: srcs.append(('v', src - 256))
return srcs
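# e.g. a VOP2 v_add_f32 v2, s4, v7 yields [('s', 4), ('v', 7)]; inline and literal constants are skipped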
def _get_dst_regs(self, inst: Inst) -> list[tuple[str, int]]:
"""Extract destination register references from instruction."""
dsts = []
if isinstance(inst, (SOP1, SOP2, SOPK)):
if hasattr(inst, 'sdst') and inst.sdst < SGPR_COUNT:
dsts.append(('s', inst.sdst))
if inst.dst_regs() == 2: dsts.append(('s', inst.sdst + 1))
elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD)):
if hasattr(inst, 'vdst'):
dsts.append(('v', inst.vdst))
if inst.dst_regs() == 2: dsts.append(('v', inst.vdst + 1))
return dsts
def _check_dep_stall(self, inst: Inst) -> int:
"""Calculate stall cycles due to RAW dependencies - wait until all sources ready."""
max_ready = 0
for typ, reg in self._get_src_regs(inst):
ready = (self.sgpr_ready if typ == 's' else self.vgpr_ready).get(reg, 0)
max_ready = max(max_ready, ready)
return max(0, max_ready - self.cycle)
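# e.g. if a source register becomes ready at cycle 40 while self.cycle == 35, the instruction stalls 5 cycles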
def _record_dst_ready(self, inst: Inst, ready_cycle: int):
"""Record when destination registers will be ready."""
for typ, reg in self._get_dst_regs(inst):
(self.sgpr_ready if typ == 's' else self.vgpr_ready)[reg] = ready_cycle
def trace_inst(self, inst: Inst):
"""Emit appropriate SQTT packets for an instruction.
Key insight from hardware traces:
- Instructions issue every cycle (INST/VALUINST packets at +1 intervals)
- ALUEXEC shows when instruction completes (result ready)
- For independent ops: ALUEXEC at issue+latency, then +1 each (pipeline throughput)
- For dependent ops: ALUEXEC delayed until source ready + latency
"""
from extra.assembly.amd.sqtt import InstOp, AluSrc
if isinstance(inst, (SOP1, SOP2, SOPC, SOPK)):
# SALU: issue now, complete after latency (or wait for deps)
self.emit_inst(InstOp.SALU)
# Completion time: max of (issue + latency) and (source_ready + latency) and (last_exec + 1)
src_ready = max((self.sgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst)), default=0)
ready_cycle = max(self.cycle, src_ready) + SALU_LATENCY
exec_cycle = max(ready_cycle, self.last_salu_exec + 1)
self.pending_exec.append((exec_cycle, AluSrc.SALU))
self.last_salu_exec = exec_cycle
self._record_dst_ready(inst, ready_cycle)
self.last_inst_type = 'SALU'
elif isinstance(inst, SOPP):
# s_nop and other SOPP emit IMMEDIATE packets
if self.last_inst_type == 'SALU':
# SALU→SOPP: emit ALUEXECs with max gap of 2 once, then IMMEDIATEs
# First ALUEXEC at last INST cycle, then allow one +2 gap, rest interleave
# cycle is currently 1 past last INST
last_inst_cycle = self.cycle - 1
first = True
had_gap = False # track if we've had a +2 gap
while self.pending_exec:
exec_cycle, src = self.pending_exec[0]
if first:
self.pending_exec.pop(0)
self.cycle = last_inst_cycle # First ALUEXEC at same cycle as last INST
first = False
elif exec_cycle == self.cycle + 1: # consecutive (+1)
self.pending_exec.pop(0)
self.cycle = exec_cycle
elif exec_cycle == self.cycle + 2 and not had_gap: # allow one +2 gap
self.pending_exec.pop(0)
self.cycle = exec_cycle
had_gap = True
else:
break # too large a gap or second +2 gap - rest interleave with IMMEDIATEs
self.emit_aluexec(src)
self.cycle += 1 # IMMEDIATE starts +1 after last ALUEXEC
elif self.last_inst_type == 'TRANS':
# TRANS→SOPP: emit first ALUEXEC at its scheduled time, then IMMEDIATE +2 after
# Sort pending execs and emit first one if it's before IMMEDIATE time
self.pending_exec.sort(key=lambda x: x[0])
if self.pending_exec:
first_exec_cycle = self.pending_exec[0][0]
# First IMMEDIATE at first_exec + 2
imm_cycle = first_exec_cycle + 2
# Emit ALUEXECs that come before IMMEDIATE
while self.pending_exec and self.pending_exec[0][0] < imm_cycle:
exec_cycle, src = self.pending_exec.pop(0)
self.cycle = exec_cycle
self.emit_aluexec(src)
self.cycle = imm_cycle
elif self.last_inst_type == 'VALU':
# VALU→SOPP: 2-cycle gap, emit any ALUEXECs that would come before
imm_cycle = self.cycle + 2 # When IMMEDIATE would be emitted
# Emit ALUEXECs that come before the IMMEDIATE
while self.pending_exec and self.pending_exec[0][0] < imm_cycle:
exec_cycle, src = self.pending_exec.pop(0)
self.cycle = max(self.cycle, exec_cycle)
self.emit_aluexec(src)
# Jump to IMMEDIATE cycle (either imm_cycle or 1 after last ALUEXEC, whichever is later)
self.cycle = max(imm_cycle, self.cycle + 1)
# Emit IMMEDIATE first, then any pending ALUEXECs at same cycle (HW order)
self.emit_immediate()
while self.pending_exec and self.pending_exec[0][0] <= self.cycle:
exec_cycle, src = self.pending_exec.pop(0)
old_cycle = self.cycle
self.cycle = exec_cycle
self.emit_aluexec(src)
self.cycle = old_cycle
# s_nop(N) waits N+1 cycles - apply delay for N > 0 (simm16 field)
if inst.op == SOPPOp.S_NOP and inst.simm16 > 0:
self.cycle += inst.simm16 # add delay before next instruction
self.last_inst_type = 'SOPP'
elif isinstance(inst, SMEM):
pass # skip for ALU focus
elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC)):
# VALU: issue now, emit ALUEXEC for completed instructions if queue is full
from extra.assembly.amd.sqtt import AluSrc
op_name = inst.op_name if hasattr(inst, 'op_name') else ''
is_trans = any(t in op_name for t in _TRANS_OPS)
if is_trans:
# Transcendental: emit INST, 4-cycle issue, add ALUEXEC to pending
self.emit_inst(InstOp.VALU_TRANS)
self.trans_count += 1
# Check for dependency on VGPR sources
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
if src_ready > self.cycle:
# Dependent trans: ALUEXEC at source_ready + 6 (for VALU source) or +10 (for trans source)
# Check if source is from a trans instruction (its ready time would be > issue + 6)
src_is_trans = any(self.vgpr_ready.get(r, 0) > self.cycle + 6 for _, r in self._get_src_regs(inst) if _ == 'v')
exec_cycle = src_ready + (10 if src_is_trans else 6)
else:
# Independent trans: ALUEXEC at issue + 9
exec_cycle = self.cycle + TRANS_LATENCY
self.pending_exec.append((exec_cycle, AluSrc.VALU))
self._record_dst_ready(inst, exec_cycle) # Record when this trans result is ready
self.last_trans_exec = exec_cycle
self.last_trans_issue = self.cycle
# Trans instructions take 4 cycles to issue
self.cycle += TRANS_ISSUE_CYCLES - 1 # -1 because we add 1 at end of trace_inst
self.last_inst_type = 'TRANS'
else:
# Regular VALU: emit VALUINST, may interleave ALUEXEC
self.emit_valuinst()
# Track first VALU issue for latency calculation
if self.first_valu_issue == 0:
self.first_valu_issue = self.cycle
# Emit any pending ALUEXECs that have completed by now (after VALUINST)
while self.pending_exec and self.pending_exec[0][0] <= self.cycle:
exec_cycle, src = self.pending_exec.pop(0)
old_cycle = self.cycle
self.cycle = exec_cycle
self.emit_aluexec(src)
self.cycle = old_cycle
# Check for dependency
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
has_dep = src_ready > self.cycle
if has_dep:
# Dependent: first dependent is +6 from source, subsequent chained deps are +5
# "Chained" means the source was also dependent (not the first independent VALU)
chained = src_ready > self.first_valu_issue + VALU_EXEC_LATENCY
exec_cycle = src_ready + (5 if chained else 6)
else:
# Independent VALU timing: 6-cycle pipeline latency
# exec[i] = first_issue + 6 + i, with +1 spacing between consecutive execs
exec_cycle = self.first_valu_issue + VALU_EXEC_LATENCY + self.valu_count
exec_cycle = max(exec_cycle, self.last_valu_exec + 1) # +1 spacing
# Add to pending exec queue (sorted by time)
self.pending_exec.append((exec_cycle, AluSrc.VALU))
self.pending_exec.sort(key=lambda x: x[0])
self.last_valu_exec = exec_cycle
self.valu_count += 1
self._record_dst_ready(inst, exec_cycle)
self.last_inst_type = 'VALU'
elif isinstance(inst, VOPD):
from extra.assembly.amd.sqtt import AluSrc
# First emit VALUINST (HW emits issue before completion at same cycle)
self.emit_valuinst()
# Emit any pending ALUEXECs that have completed by now
while self.pending_exec and self.pending_exec[0][0] <= self.cycle:
exec_cycle, src = self.pending_exec.pop(0)
old_cycle = self.cycle
self.cycle = exec_cycle
self.emit_aluexec(src)
self.cycle = old_cycle
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
has_dep = src_ready > self.cycle
if has_dep:
exec_cycle = src_ready + 10
else:
exec_cycle = self.cycle + VALU_EXEC_LATENCY
if self.last_valu_exec > 0:
exec_cycle = max(exec_cycle, self.last_valu_exec + 1)
self.pending_exec.append((exec_cycle, AluSrc.VALU))
self.pending_exec.sort(key=lambda x: x[0])
self.last_valu_exec = exec_cycle
self._record_dst_ready(inst, exec_cycle)
self.cycle += 1
def finalize(self):
"""Emit all remaining pending ALUEXEC packets and WAVEEND."""
from extra.assembly.amd.sqtt import AluSrc
# Emit any remaining pending ALUEXECs
self.pending_exec.sort(key=lambda x: x[0])
last_src = None
for exec_cycle, src in self.pending_exec:
self.cycle = exec_cycle
self.emit_aluexec(src)
last_src = src
self.pending_exec.clear()
# WAVEEND timing depends on what comes last
# If last instruction was SOPP: +1 normally, but +11 if last ALUEXEC was recent (VALU drain)
# Otherwise: 14 cycles for no ALU, 20 for trans, 14/15 for SALU/VALU
if self.last_inst_type == 'SOPP':
# Check if last ALUEXEC was recent (within ~5 cycles of current)
if self.valu_count > 0 and self.last_valu_exec >= self.cycle - 5:
self.cycle += 11 # VALU drain time
else:
self.cycle += 1
elif last_src is None:
self.cycle += 14 # empty program or no ALU ops
elif self.trans_count > 0:
self.cycle += 20 # trans has longer WAVEEND delay
else:
self.cycle += 15 if last_src == AluSrc.VALU else 14
self.emit_waveend()
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
_VOPD_TO_VOP = {
VOPDOp.V_DUAL_FMAC_F32: VOP3Op.V_FMAC_F32, VOPDOp.V_DUAL_FMAAK_F32: VOP2Op.V_FMAAK_F32, VOPDOp.V_DUAL_FMAMK_F32: VOP2Op.V_FMAMK_F32,
@@ -547,46 +178,81 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
s0 = st.rsrc64(ssrc0, 0) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0))
s1 = st.rsrc64(inst.ssrc1, 0) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0)
d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
exec_mask = st.exec_mask
literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else st.literal
# Create Reg objects for compiled function - mask VCC/EXEC to 32 bits for wave32
result = fn(Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc & MASK32), 0, Reg(st.exec_mask & MASK32), literal, None, PC=Reg(st.pc * 4))
# Execute compiled function - pass PC in bytes for instructions that need it
# For wave32, mask VCC and EXEC to 32 bits since only the lower 32 bits are relevant
pc_bytes = st.pc * 4
vcc32, exec32 = st.vcc & MASK32, exec_mask & MASK32
result = fn(s0, s1, 0, d0, st.scc, vcc32, 0, exec32, literal, None, {}, pc=pc_bytes)
# Apply results - extract values from returned Reg objects
if sdst is not None and 'D0' in result:
(st.wsgpr64 if inst.dst_regs() == 2 else st.wsgpr)(sdst, result['D0']._val)
if 'SCC' in result: st.scc = result['SCC']._val & 1
if 'EXEC' in result: st.exec_mask = result['EXEC']._val
if 'PC' in result:
# Apply results
if sdst is not None:
(st.wsgpr64 if result.get('d0_64') else st.wsgpr)(sdst, result['d0'])
if 'scc' in result: st.scc = result['scc']
if 'exec' in result: st.exec_mask = result['exec']
if 'new_pc' in result:
# Convert absolute byte address to word delta
pc_val = result['PC']._val
new_pc = pc_val if pc_val < 0x8000000000000000 else pc_val - 0x10000000000000000
new_pc_words = new_pc // 4
# new_pc is where we want to go, st.pc is current position, inst._words will be added after
new_pc_words = result['new_pc'] // 4
return new_pc_words - st.pc - 1 # -1 because emulator adds inst_words (1 for scalar)
return 0
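# e.g. a branch resolving to byte address 0x40 (word 16) while st.pc == 10 returns 16 - 10 - 1 = 5;
# the caller then adds the returned delta plus inst_words, landing the wave on word 16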
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None) -> None:
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None:
"""Execute vector instruction for one lane."""
compiled = _get_compiled()
V = st.vgpr[lane]
# Memory ops (FLAT/GLOBAL/SCRATCH and DS) - use generated pcode
if isinstance(inst, (FLAT, DS)):
op, vdst, op_name = inst.op, inst.vdst, inst.op.name
fn, ndwords = compiled[type(op)][op], _op_ndwords(op_name)
# Memory ops (not ALU pseudocode)
if isinstance(inst, FLAT):
addr = V[inst.addr] | (V[inst.addr + 1] << 32)
ADDR = (st.rsgpr64(inst.saddr) + V[inst.addr] + _sext(inst.offset, 13)) & MASK64 if inst.saddr not in (NULL, 0x7f) else (addr + _sext(inst.offset, 13)) & MASK64
# For loads, VDATA comes from vdst (preserves unwritten bits); for stores, from inst.data
vdata_src = vdst if 'LOAD' in op_name else inst.data
result = fn(GlobalMem, ADDR, _vgpr_read(V, vdata_src, ndwords), Reg(V[vdst]), Reg(0))
if 'VDATA' in result: _vgpr_write(V, vdst, result['VDATA']._val, ndwords)
if 'RETURN_DATA' in result: _vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords)
else: # DS
DATA0, DATA1 = _vgpr_read(V, inst.data0, ndwords), _vgpr_read(V, inst.data1, ndwords) if inst.data1 is not None else Reg(0)
result = fn(lds, Reg(V[inst.addr]), DATA0, DATA1, Reg(inst.offset0), Reg(inst.offset1), Reg(0))
if 'RETURN_DATA' in result and ('_RTN' in op_name or '_LOAD' in op_name):
_vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords * 2 if '_2ADDR_' in op_name else ndwords)
op, addr_reg, data_reg, vdst, offset, saddr = inst.op, inst.addr, inst.data, inst.vdst, _sext(inst.offset, 13), inst.saddr
addr = V[addr_reg] | (V[addr_reg+1] << 32)
addr = (st.rsgpr64(saddr) + V[addr_reg] + offset) & MASK64 if saddr not in (NULL, 0x7f) else (addr + offset) & MASK64
if op in FLAT_LOAD:
cnt, sz, sign = FLAT_LOAD[op]
for i in range(cnt): val = mem_read(addr + i * sz, sz); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
elif op in FLAT_STORE:
cnt, sz = FLAT_STORE[op]
for i in range(cnt): mem_write(addr + i * sz, sz, V[data_reg + i] & ((1 << (sz * 8)) - 1))
elif op in FLAT_D16_LOAD:
sz, sign, hi = FLAT_D16_LOAD[op]
val = mem_read(addr, sz)
if sign: val = _sext(val, sz * 8) & 0xffff
V[vdst] = _dst16(V[vdst], val, hi)
elif op in FLAT_D16_STORE:
sz, hi = FLAT_D16_STORE[op]
mem_write(addr, sz, _src16(V[data_reg], hi) & ((1 << (sz * 8)) - 1))
else: raise NotImplementedError(f"FLAT op {op}")
return
if isinstance(inst, DS):
op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
if op in DS_LOAD:
cnt, sz, sign = DS_LOAD[op]
for i in range(cnt): val = int.from_bytes(lds[addr0+i*sz:addr0+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
elif op in DS_STORE:
cnt, sz = DS_STORE[op]
for i in range(cnt): lds[addr0+i*sz:addr0+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
elif op in DS_LOAD_2ADDR:
# Load two values from addr+offset0*sz and addr+offset1*sz into vdst (B32: 1 dword each, B64: 2 dwords each)
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
sz = DS_LOAD_2ADDR[op]
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
cnt = sz // 4 # 1 for B32, 2 for B64
for i in range(cnt): V[vdst + i] = int.from_bytes(lds[addr0+i*4:addr0+i*4+4], 'little')
for i in range(cnt): V[vdst + cnt + i] = int.from_bytes(lds[addr1+i*4:addr1+i*4+4], 'little')
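# e.g. DS_LOAD_2ADDR_B32 with offset0=0, offset1=1 loads lds[addr] into v[vdst] and lds[addr+4] into v[vdst+1]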
elif op in DS_STORE_2ADDR:
# Store two values from data0 and data1 to addr+offset0*sz and addr+offset1*sz
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
sz = DS_STORE_2ADDR[op]
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
cnt = sz // 4
for i in range(cnt): lds[addr0+i*4:addr0+i*4+4] = (V[inst.data0 + i] & MASK32).to_bytes(4, 'little')
for i in range(cnt): lds[addr1+i*4:addr1+i*4+4] = (V[inst.data1 + i] & MASK32).to_bytes(4, 'little')
else: raise NotImplementedError(f"DS op {op}")
return
# VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)
@@ -594,25 +260,24 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None)
vdsty = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
inputs = [(inst.opx, st.rsrc(inst.srcx0, lane), V[inst.vsrcx1], V[inst.vdstx], inst.vdstx),
(inst.opy, st.rsrc(inst.srcy0, lane), V[inst.vsrcy1], V[vdsty], vdsty)]
def exec_vopd(vopd_op, s0, s1, d0):
op = _VOPD_TO_VOP[vopd_op]
return compiled[type(op)][op](Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)['D0']._val
for vopd_op, s0, s1, d0, dst in inputs: V[dst] = exec_vopd(vopd_op, s0, s1, d0)
results = [(dst, fn(s0, s1, 0, d0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})['d0'])
for vopd_op, s0, s1, d0, dst in inputs if (op := _VOPD_TO_VOP.get(vopd_op)) and (fn := compiled.get(type(op), {}).get(op))]
for dst, val in results: V[dst] = val
return
# VOP3SD: has extra scalar dest for carry output
if isinstance(inst, VOP3SD):
fn = compiled[VOP3SDOp][inst.op]
fn = compiled.get(VOP3SDOp, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
# Read sources based on register counts from inst properties
def rsrc_n(src, regs): return st.rsrc64(src, lane) if regs == 2 else st.rsrc(src, lane)
s0, s1, s2 = rsrc_n(inst.src0, inst.src_regs(0)), rsrc_n(inst.src1, inst.src_regs(1)), rsrc_n(inst.src2, inst.src_regs(2))
# Carry-in ops use src2 as carry bitmask instead of VCC
vcc = st.rsgpr64(inst.src2) if 'CO_CI' in inst.op_name else st.vcc
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(V[inst.vdst]), Reg(st.scc), Reg(vcc), lane, Reg(st.exec_mask), st.literal, None)
d0_val = result['D0']._val
V[inst.vdst] = d0_val & MASK32
if inst.dst_regs() == 2: V[inst.vdst + 1] = (d0_val >> 32) & MASK32
if 'VCC' in result: st.pend_sgpr_lane(inst.sdst, lane, (result['VCC']._val >> lane) & 1)
result = fn(s0, s1, s2, V[inst.vdst], st.scc, vcc, lane, st.exec_mask, st.literal, None, {})
V[inst.vdst] = result['d0'] & MASK32
if result.get('d0_64'): V[inst.vdst + 1] = (result['d0'] >> 32) & MASK32
if result.get('vcc_lane') is not None: st.pend_sgpr_lane(inst.sdst, lane, result['vcc_lane'])
return
# Get op enum and sources (None means "no source" for that operand)
@@ -652,7 +317,8 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None)
if abs_ & (1<<i): srcs[i] = abs(srcs[i])
if neg & (1<<i): srcs[i] = -srcs[i]
result = srcs[0] * srcs[1] + srcs[2]
st.vgpr[lane][inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
V = st.vgpr[lane]
V[inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
return
# VOP3P packed ops: opsel selects halves for lo, opsel_hi for hi; neg toggles f16 sign
raws = [st.rsrc_f16(inst.src0, lane), st.rsrc_f16(inst.src1, lane), st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0]
@@ -661,13 +327,15 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None)
hi_sels = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
srcs = [((_src16(raws[i], hi_sels[i]) ^ (0x8000 if neg_hi & (1<<i) else 0)) << 16) |
(_src16(raws[i], opsel & (1<<i)) ^ (0x8000 if neg & (1<<i) else 0)) for i in range(3)]
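# i.e. srcs[i] = (selected_hi16 << 16) | selected_lo16, with neg_hi/neg flipping the f16 sign bit (0x8000) per half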
result = compiled[VOP3POp][inst.op](Reg(srcs[0]), Reg(srcs[1]), Reg(srcs[2]), Reg(0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)
st.vgpr[lane][inst.vdst] = result['D0']._val & MASK32
fn = compiled.get(VOP3POp, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
st.vgpr[lane][inst.vdst] = fn(srcs[0], srcs[1], srcs[2], 0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})['d0'] & MASK32
return
else: raise NotImplementedError(f"Unknown vector type {type(inst)}")
op_cls = type(inst.op)
if (fn := compiled.get(op_cls, {}).get(inst.op)) is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
fn = compiled.get(op_cls, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
# Read sources (with VOP3 modifiers if applicable)
neg, abs_ = (getattr(inst, 'neg', 0), getattr(inst, 'abs', 0)) if isinstance(inst, VOP3) else (0, 0)
@@ -709,27 +377,24 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None)
# Execute compiled function - pass src0_idx and vdst_idx for lane instructions
# For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR)
src0_idx = (src0 - 256) if src0 is not None and src0 >= 256 else (src0 if src0 is not None else 0)
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(d0), Reg(st.scc), Reg(vcc_for_fn), lane, Reg(st.exec_mask), st.literal, st.vgpr, src0_idx, vdst)
result = fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, st.literal, st.vgpr, {}, src0_idx, vdst)
# Apply results - extract values from returned Reg objects
# Apply results
if 'vgpr_write' in result:
# Lane instruction wrote to VGPR: (lane, vgpr_idx, value)
wr_lane, wr_idx, wr_val = result['vgpr_write']
st.vgpr[wr_lane][wr_idx] = wr_val
if 'VCC' in result:
if 'vcc_lane' in result:
# VOP2 carry ops write to VCC implicitly; VOPC/VOP3 write to vdst
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC']._val >> lane) & 1)
if 'EXEC' in result:
# V_CMPX instructions write to EXEC per-lane (not to vdst)
st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC']._val >> lane) & 1)
elif op_cls is VOPCOp:
# VOPC comparison result stored in D0 bitmask, extract lane bit (non-CMPX only)
st.pend_sgpr_lane(vdst, lane, (result['D0']._val >> lane) & 1)
if op_cls is not VOPCOp and 'vgpr_write' not in result:
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, result['vcc_lane'])
if 'exec_lane' in result:
# V_CMPX instructions write to EXEC per-lane
st.pend_sgpr_lane(EXEC_LO, lane, result['exec_lane'])
if 'd0' in result and op_cls is not VOPCOp and 'vgpr_write' not in result:
writes_to_sgpr = 'READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name
d0_val = result['D0']._val
d0_val = result['d0']
if writes_to_sgpr: st.wsgpr(vdst, d0_val & MASK32)
elif inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif result.get('d0_64'): V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
else: V[vdst] = d0_val & MASK32
@@ -759,7 +424,7 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None:
# MAIN EXECUTION LOOP
# ═══════════════════════════════════════════════════════════════════════════════
def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int:
def step_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int:
inst = program.get(st.pc)
if inst is None: return 1
inst_words, st.literal = inst._words, getattr(inst, '_literal', None) or 0
@@ -779,7 +444,7 @@ def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int
st.pc += inst_words
return 0
def exec_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int:
def exec_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int:
while st.pc in program:
result = step_wave(program, st, lds, n_lanes)
if result == -1: return 0
@@ -789,7 +454,7 @@ def exec_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int
def exec_workgroup(program: Program, workgroup_id: tuple[int, int, int], local_size: tuple[int, int, int], args_ptr: int,
wg_id_sgpr_base: int, wg_id_enables: tuple[bool, bool, bool]) -> None:
lx, ly, lz = local_size
total_threads, lds = lx * ly * lz, LDSMem(bytearray(65536))
total_threads, lds = lx * ly * lz, bytearray(65536)
waves: list[tuple[WaveState, int, int]] = []
for wave_start in range(0, total_threads, WAVE_SIZE):
n_lanes, st = min(WAVE_SIZE, total_threads - wave_start), WaveState()


@@ -8,7 +8,7 @@ import os
os.environ["SQTT"] = "1"
os.environ["PROFILE"] = "1"
os.environ["SQTT_LIMIT_SE"] = "2"
os.environ["SQTT_TOKEN_EXCLUDE"] = "3786" # exclude WAVERDY, REG, EVENT, UTILCTR, WAVEALLOC, PERF, ALUEXEC
os.environ["SQTT_TOKEN_EXCLUDE"] = "3784" # Exclude WAVERDY, REG, EVENT, UTILCTR, WAVEALLOC, PERF
import unittest
from tinygrad.device import Device
@@ -242,7 +242,7 @@ class TestEmulatorSQTT(unittest.TestCase):
def test_valu_independent_8(self): self._test_valu_independent_n(8)
def test_valu_independent_16(self): self._test_valu_independent_n(16)
def test_trans_independent_16(self): self._test_valu_independent_n(16, True)
def test_trans_independent_16(self): self._test_valu_independent_n(16, trans=True)
def test_valu_chain(self):
"""VALU instructions with chain dependencies."""