mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 14:43:57 -05:00
rever emu to master
This commit is contained in:
@@ -1,14 +1,12 @@
|
||||
# RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF
|
||||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
import ctypes
|
||||
from enum import IntEnum
|
||||
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||
from extra.assembly.amd.pcode import Reg
|
||||
import ctypes, struct
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
|
||||
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
|
||||
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, VOPDOp)
|
||||
|
||||
Program = dict[int, Inst]
|
||||
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
|
||||
@@ -30,402 +28,35 @@ def _dst16(cur: int, val: int, is_hi: bool) -> int: return (cur & 0x0000ffff) |
|
||||
def _vgpr_hi(src: int) -> bool: return src >= 256 and ((src - 256) & 0x80) != 0
|
||||
def _vgpr_masked(src: int) -> int: return ((src - 256) & 0x7f) + 256 if src >= 256 else src
|
||||
|
||||
# Helper: get number of dwords from memory op name
|
||||
def _op_ndwords(name: str) -> int:
|
||||
if '_B128' in name: return 4
|
||||
if '_B96' in name: return 3
|
||||
if any(s in name for s in ('_B64', '_U64', '_I64', '_F64')): return 2
|
||||
return 1
|
||||
|
||||
# Helper: build multi-dword Reg from consecutive VGPRs
|
||||
def _vgpr_read(V: list, base: int, ndwords: int) -> Reg: return Reg(sum(V[base + i] << (32 * i) for i in range(ndwords)))
|
||||
|
||||
# Helper: write multi-dword value to consecutive VGPRs
|
||||
def _vgpr_write(V: list, base: int, val: int, ndwords: int):
|
||||
for i in range(ndwords): V[base + i] = (val >> (32 * i)) & MASK32
|
||||
|
||||
# Memory access
|
||||
_valid_mem_ranges: list[tuple[int, int]] = []
|
||||
def set_valid_mem_ranges(ranges: set[tuple[int, int]]) -> None: _valid_mem_ranges.clear(); _valid_mem_ranges.extend(ranges)
|
||||
def _mem_valid(addr: int, size: int) -> bool:
|
||||
return not _valid_mem_ranges or any(s <= addr and addr + size <= s + z for s, z in _valid_mem_ranges)
|
||||
def _ctypes_at(addr: int, size: int): return (ctypes.c_uint8 if size == 1 else ctypes.c_uint16 if size == 2 else ctypes.c_uint64 if size == 8 else ctypes.c_uint32).from_address(addr)
|
||||
def _ctypes_at(addr: int, size: int): return (ctypes.c_uint8 if size == 1 else ctypes.c_uint16 if size == 2 else ctypes.c_uint32).from_address(addr)
|
||||
def mem_read(addr: int, size: int) -> int: return _ctypes_at(addr, size).value if _mem_valid(addr, size) else 0
|
||||
def mem_write(addr: int, size: int, val: int) -> None:
|
||||
if _mem_valid(addr, size): _ctypes_at(addr, size).value = val
|
||||
|
||||
def _make_mem_accessor(read_fn, write_fn):
|
||||
"""Create a memory accessor class with the given read/write functions."""
|
||||
class _MemAccessor:
|
||||
__slots__ = ('_addr',)
|
||||
def __init__(self, addr: int): self._addr = int(addr)
|
||||
u8 = property(lambda s: read_fn(s._addr, 1), lambda s, v: write_fn(s._addr, 1, int(v)))
|
||||
u16 = property(lambda s: read_fn(s._addr, 2), lambda s, v: write_fn(s._addr, 2, int(v)))
|
||||
u32 = property(lambda s: read_fn(s._addr, 4), lambda s, v: write_fn(s._addr, 4, int(v)))
|
||||
u64 = property(lambda s: read_fn(s._addr, 8), lambda s, v: write_fn(s._addr, 8, int(v)))
|
||||
i8 = property(lambda s: _sext(read_fn(s._addr, 1), 8), lambda s, v: write_fn(s._addr, 1, int(v)))
|
||||
i16 = property(lambda s: _sext(read_fn(s._addr, 2), 16), lambda s, v: write_fn(s._addr, 2, int(v)))
|
||||
i32 = property(lambda s: _sext(read_fn(s._addr, 4), 32), lambda s, v: write_fn(s._addr, 4, int(v)))
|
||||
i64 = property(lambda s: _sext(read_fn(s._addr, 8), 64), lambda s, v: write_fn(s._addr, 8, int(v)))
|
||||
b8, b16, b32, b64 = u8, u16, u32, u64
|
||||
return _MemAccessor
|
||||
|
||||
_GlobalMemAccessor = _make_mem_accessor(mem_read, mem_write)
|
||||
|
||||
class _GlobalMem:
|
||||
"""Global memory wrapper that supports MEM[addr].u32 style access."""
|
||||
def __getitem__(self, addr) -> _GlobalMemAccessor: return _GlobalMemAccessor(addr)
|
||||
GlobalMem = _GlobalMem()
|
||||
|
||||
class LDSMem:
|
||||
"""LDS memory wrapper that supports MEM[addr].u32 style access."""
|
||||
__slots__ = ('_lds',)
|
||||
def __init__(self, lds: bytearray): self._lds = lds
|
||||
def _read(self, addr: int, size: int) -> int:
|
||||
addr = addr & 0xffff
|
||||
return int.from_bytes(self._lds[addr:addr+size], 'little') if addr + size <= len(self._lds) else 0
|
||||
def _write(self, addr: int, size: int, val: int):
|
||||
addr = addr & 0xffff
|
||||
if addr + size <= len(self._lds): self._lds[addr:addr+size] = (int(val) & ((1 << (size*8)) - 1)).to_bytes(size, 'little')
|
||||
def __getitem__(self, addr): return _make_mem_accessor(self._read, self._write)(addr)
|
||||
|
||||
# Memory op tables (not pseudocode - these are format descriptions)
|
||||
def _mem_ops(ops, suffix_map):
|
||||
return {getattr(e, f"{p}_{s}"): v for e in ops for s, v in suffix_map.items() for p in [e.__name__.replace("Op", "")]}
|
||||
_LOAD_MAP = {'LOAD_B32': (1,4,0), 'LOAD_B64': (2,4,0), 'LOAD_B96': (3,4,0), 'LOAD_B128': (4,4,0), 'LOAD_U8': (1,1,0), 'LOAD_I8': (1,1,1), 'LOAD_U16': (1,2,0), 'LOAD_I16': (1,2,1)}
|
||||
_STORE_MAP = {'STORE_B32': (1,4), 'STORE_B64': (2,4), 'STORE_B96': (3,4), 'STORE_B128': (4,4), 'STORE_B8': (1,1), 'STORE_B16': (1,2)}
|
||||
FLAT_LOAD, FLAT_STORE = _mem_ops([GLOBALOp, FLATOp], _LOAD_MAP), _mem_ops([GLOBALOp, FLATOp], _STORE_MAP)
|
||||
# D16 ops: load/store 16-bit to lower or upper half of VGPR. Format: (size, sign, hi) where hi=1 means upper 16 bits
|
||||
_D16_LOAD_MAP = {'LOAD_D16_U8': (1,0,0), 'LOAD_D16_I8': (1,1,0), 'LOAD_D16_B16': (2,0,0),
|
||||
'LOAD_D16_HI_U8': (1,0,1), 'LOAD_D16_HI_I8': (1,1,1), 'LOAD_D16_HI_B16': (2,0,1)}
|
||||
_D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi)
|
||||
FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP)
|
||||
FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP)
|
||||
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
|
||||
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
|
||||
# 2ADDR ops: load/store two values using offset0 and offset1
|
||||
DS_LOAD_2ADDR = {DSOp.DS_LOAD_2ADDR_B32: 4, DSOp.DS_LOAD_2ADDR_B64: 8}
|
||||
DS_STORE_2ADDR = {DSOp.DS_STORE_2ADDR_B32: 4, DSOp.DS_STORE_2ADDR_B64: 8}
|
||||
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# SQTT TRACING - Emit packets matching real hardware output
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# Transcendental ops that produce INST packets with op=VALU_TRANS (0x0b)
|
||||
_TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64',
|
||||
'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'}
|
||||
|
||||
# Latency model from hardware measurements (warm instruction cache):
|
||||
# Startup: WAVESTART -> first instruction (~32 cycles with warm cache, no REG packet)
|
||||
# Cold cache: WAVESTART -> REG (~137 cycles) -> first instruction (~150 cycles after REG)
|
||||
# SALU: issues every cycle, result ready 2 cycles after issue, ALUEXEC at ready time
|
||||
# VALU: issues every cycle, ALUEXEC at issue+6 for each inst, serialized with +1 intervals
|
||||
# TRANS: issues every 4 cycles, ALUEXEC at last_issue+1, then +8 intervals
|
||||
# For dependent instructions, ALUEXEC is at source_ready + 10 (first dep) or + 9 (chained)
|
||||
# VALU queue depth ~7: after 7 in-flight, ALUEXEC interleaves with VALUINST
|
||||
WAVESTART_TO_INST_CYCLES = 32 # cycles from WAVESTART to first instruction (warm cache)
|
||||
SALU_LATENCY = 2 # cycles from issue to result ready
|
||||
VALU_EXEC_LATENCY = 6 # cycles from first issue to first ALUEXEC
|
||||
TRANS_ISSUE_CYCLES = 4 # cycles between transcendental instruction issues
|
||||
TRANS_LATENCY = 9 # cycles from trans issue to ALUEXEC
|
||||
|
||||
class SQTTState:
|
||||
"""SQTT tracing state - emits packets matching real hardware (warm cache model)."""
|
||||
__slots__ = ('cycle', 'packets', 'pending_exec', 'wave_id', 'simd', 'cu', 'sgpr_ready', 'vgpr_ready',
|
||||
'last_salu_exec', 'last_valu_exec', 'trans_count', 'last_trans_issue', 'last_trans_exec', 'first_valu_issue', 'valu_count', 'last_inst_type')
|
||||
|
||||
def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0):
|
||||
self.cycle = 0
|
||||
self.packets: list = []
|
||||
self.pending_exec: list[tuple[int, int]] = [] # (completion_cycle, src) for ALUEXEC
|
||||
self.wave_id, self.simd, self.cu = wave_id, simd, cu
|
||||
self.sgpr_ready: dict[int, int] = {} # sgpr -> cycle when result ready
|
||||
self.last_inst_type: str = '' # track last instruction type for gap calculation
|
||||
self.vgpr_ready: dict[int, int] = {} # vgpr -> cycle when result ready
|
||||
self.last_salu_exec = 0 # last SALU ALUEXEC time (for +1 spacing)
|
||||
self.last_valu_exec = 0 # last VALU ALUEXEC time (for +1 spacing)
|
||||
self.trans_count = 0 # number of trans instructions issued
|
||||
self.last_trans_issue = 0 # cycle when last trans was issued
|
||||
self.last_trans_exec = 0 # cycle when last trans ALUEXEC is scheduled
|
||||
self.first_valu_issue = 0 # cycle when first VALU was issued
|
||||
self.valu_count = 0 # number of VALU instructions issued
|
||||
|
||||
def emit_wavestart(self):
|
||||
from extra.assembly.amd.sqtt import WAVESTART
|
||||
self.packets.append(WAVESTART(_time=self.cycle, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3))
|
||||
self.cycle += WAVESTART_TO_INST_CYCLES
|
||||
|
||||
def emit_waveend(self):
|
||||
from extra.assembly.amd.sqtt import WAVEEND
|
||||
self.packets.append(WAVEEND(_time=self.cycle, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3))
|
||||
|
||||
def emit_inst(self, inst_op: int):
|
||||
from extra.assembly.amd.sqtt import INST
|
||||
self.packets.append(INST(_time=self.cycle, wave=self.wave_id, op=inst_op))
|
||||
|
||||
def emit_valuinst(self):
|
||||
from extra.assembly.amd.sqtt import VALUINST
|
||||
self.packets.append(VALUINST(_time=self.cycle, wave=self.wave_id))
|
||||
|
||||
def emit_aluexec(self, src: int):
|
||||
from extra.assembly.amd.sqtt import ALUEXEC
|
||||
self.packets.append(ALUEXEC(_time=self.cycle, src=src))
|
||||
|
||||
def emit_immediate(self):
|
||||
from extra.assembly.amd.sqtt import IMMEDIATE
|
||||
self.packets.append(IMMEDIATE(_time=self.cycle, wave=self.wave_id))
|
||||
|
||||
def _get_src_regs(self, inst: Inst) -> list[tuple[str, int]]:
|
||||
"""Extract source register references from instruction."""
|
||||
srcs = []
|
||||
if isinstance(inst, (SOP1, SOP2, SOPC)):
|
||||
if hasattr(inst, 'ssrc0') and inst.ssrc0 < SGPR_COUNT: srcs.append(('s', inst.ssrc0))
|
||||
if hasattr(inst, 'ssrc1') and inst.ssrc1 < SGPR_COUNT: srcs.append(('s', inst.ssrc1))
|
||||
elif isinstance(inst, SOPK):
|
||||
pass # immediate only
|
||||
elif isinstance(inst, VOP1):
|
||||
# VOP1: only src0
|
||||
if inst.src0 < SGPR_COUNT: srcs.append(('s', inst.src0))
|
||||
elif inst.src0 >= 256: srcs.append(('v', inst.src0 - 256))
|
||||
elif isinstance(inst, VOP2):
|
||||
# VOP2: src0 (can be SGPR/VGPR/const) + vsrc1 (always VGPR)
|
||||
if inst.src0 < SGPR_COUNT: srcs.append(('s', inst.src0))
|
||||
elif inst.src0 >= 256: srcs.append(('v', inst.src0 - 256))
|
||||
srcs.append(('v', inst.vsrc1))
|
||||
elif isinstance(inst, (VOP3, VOP3SD, VOP3P, VOPC)):
|
||||
for attr in ('src0', 'src1', 'src2'):
|
||||
if hasattr(inst, attr):
|
||||
src = getattr(inst, attr)
|
||||
if src < SGPR_COUNT: srcs.append(('s', src))
|
||||
elif src >= 256: srcs.append(('v', src - 256))
|
||||
return srcs
|
||||
|
||||
def _get_dst_regs(self, inst: Inst) -> list[tuple[str, int]]:
|
||||
"""Extract destination register references from instruction."""
|
||||
dsts = []
|
||||
if isinstance(inst, (SOP1, SOP2, SOPK)):
|
||||
if hasattr(inst, 'sdst') and inst.sdst < SGPR_COUNT:
|
||||
dsts.append(('s', inst.sdst))
|
||||
if inst.dst_regs() == 2: dsts.append(('s', inst.sdst + 1))
|
||||
elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD)):
|
||||
if hasattr(inst, 'vdst'):
|
||||
dsts.append(('v', inst.vdst))
|
||||
if inst.dst_regs() == 2: dsts.append(('v', inst.vdst + 1))
|
||||
return dsts
|
||||
|
||||
def _check_dep_stall(self, inst: Inst) -> int:
|
||||
"""Calculate stall cycles due to RAW dependencies - wait until all sources ready."""
|
||||
max_ready = 0
|
||||
for typ, reg in self._get_src_regs(inst):
|
||||
ready = (self.sgpr_ready if typ == 's' else self.vgpr_ready).get(reg, 0)
|
||||
max_ready = max(max_ready, ready)
|
||||
return max(0, max_ready - self.cycle)
|
||||
|
||||
def _record_dst_ready(self, inst: Inst, ready_cycle: int):
|
||||
"""Record when destination registers will be ready."""
|
||||
for typ, reg in self._get_dst_regs(inst):
|
||||
(self.sgpr_ready if typ == 's' else self.vgpr_ready)[reg] = ready_cycle
|
||||
|
||||
def trace_inst(self, inst: Inst):
|
||||
"""Emit appropriate SQTT packets for an instruction.
|
||||
|
||||
Key insight from hardware traces:
|
||||
- Instructions issue every cycle (INST/VALUINST packets at +1 intervals)
|
||||
- ALUEXEC shows when instruction completes (result ready)
|
||||
- For independent ops: ALUEXEC at issue+latency, then +1 each (pipeline throughput)
|
||||
- For dependent ops: ALUEXEC delayed until source ready + latency
|
||||
"""
|
||||
from extra.assembly.amd.sqtt import InstOp, AluSrc
|
||||
|
||||
if isinstance(inst, (SOP1, SOP2, SOPC, SOPK)):
|
||||
# SALU: issue now, complete after latency (or wait for deps)
|
||||
self.emit_inst(InstOp.SALU)
|
||||
# Completion time: max of (issue + latency) and (source_ready + latency) and (last_exec + 1)
|
||||
src_ready = max((self.sgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst)), default=0)
|
||||
ready_cycle = max(self.cycle, src_ready) + SALU_LATENCY
|
||||
exec_cycle = max(ready_cycle, self.last_salu_exec + 1)
|
||||
self.pending_exec.append((exec_cycle, AluSrc.SALU))
|
||||
self.last_salu_exec = exec_cycle
|
||||
self._record_dst_ready(inst, ready_cycle)
|
||||
self.last_inst_type = 'SALU'
|
||||
|
||||
elif isinstance(inst, SOPP):
|
||||
# s_nop and other SOPP emit IMMEDIATE packets
|
||||
if self.last_inst_type == 'SALU':
|
||||
# SALU→SOPP: emit ALUEXECs with max gap of 2 once, then IMMEDIATEs
|
||||
# First ALUEXEC at last INST cycle, then allow one +2 gap, rest interleave
|
||||
# cycle is currently 1 past last INST
|
||||
last_inst_cycle = self.cycle - 1
|
||||
first = True
|
||||
had_gap = False # track if we've had a +2 gap
|
||||
while self.pending_exec:
|
||||
exec_cycle, src = self.pending_exec[0]
|
||||
if first:
|
||||
self.pending_exec.pop(0)
|
||||
self.cycle = last_inst_cycle # First ALUEXEC at same cycle as last INST
|
||||
first = False
|
||||
elif exec_cycle == self.cycle + 1: # consecutive (+1)
|
||||
self.pending_exec.pop(0)
|
||||
self.cycle = exec_cycle
|
||||
elif exec_cycle == self.cycle + 2 and not had_gap: # allow one +2 gap
|
||||
self.pending_exec.pop(0)
|
||||
self.cycle = exec_cycle
|
||||
had_gap = True
|
||||
else:
|
||||
break # too large a gap or second +2 gap - rest interleave with IMMEDIATEs
|
||||
self.emit_aluexec(src)
|
||||
self.cycle += 1 # IMMEDIATE starts +1 after last ALUEXEC
|
||||
elif self.last_inst_type == 'TRANS':
|
||||
# TRANS→SOPP: emit first ALUEXEC at its scheduled time, then IMMEDIATE +2 after
|
||||
# Sort pending execs and emit first one if it's before IMMEDIATE time
|
||||
self.pending_exec.sort(key=lambda x: x[0])
|
||||
if self.pending_exec:
|
||||
first_exec_cycle = self.pending_exec[0][0]
|
||||
# First IMMEDIATE at first_exec + 2
|
||||
imm_cycle = first_exec_cycle + 2
|
||||
# Emit ALUEXECs that come before IMMEDIATE
|
||||
while self.pending_exec and self.pending_exec[0][0] < imm_cycle:
|
||||
exec_cycle, src = self.pending_exec.pop(0)
|
||||
self.cycle = exec_cycle
|
||||
self.emit_aluexec(src)
|
||||
self.cycle = imm_cycle
|
||||
elif self.last_inst_type == 'VALU':
|
||||
# VALU→SOPP: 2-cycle gap, emit any ALUEXECs that would come before
|
||||
imm_cycle = self.cycle + 2 # When IMMEDIATE would be emitted
|
||||
# Emit ALUEXECs that come before the IMMEDIATE
|
||||
while self.pending_exec and self.pending_exec[0][0] < imm_cycle:
|
||||
exec_cycle, src = self.pending_exec.pop(0)
|
||||
self.cycle = max(self.cycle, exec_cycle)
|
||||
self.emit_aluexec(src)
|
||||
# Jump to IMMEDIATE cycle (either imm_cycle or 1 after last ALUEXEC, whichever is later)
|
||||
self.cycle = max(imm_cycle, self.cycle + 1)
|
||||
# Emit IMMEDIATE first, then any pending ALUEXECs at same cycle (HW order)
|
||||
self.emit_immediate()
|
||||
while self.pending_exec and self.pending_exec[0][0] <= self.cycle:
|
||||
exec_cycle, src = self.pending_exec.pop(0)
|
||||
old_cycle = self.cycle
|
||||
self.cycle = exec_cycle
|
||||
self.emit_aluexec(src)
|
||||
self.cycle = old_cycle
|
||||
# s_nop(N) waits N+1 cycles - apply delay for N > 0 (simm16 field)
|
||||
if inst.op == SOPPOp.S_NOP and inst.simm16 > 0:
|
||||
self.cycle += inst.simm16 # add delay before next instruction
|
||||
self.last_inst_type = 'SOPP'
|
||||
|
||||
elif isinstance(inst, SMEM):
|
||||
pass # skip for ALU focus
|
||||
|
||||
elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC)):
|
||||
# VALU: issue now, emit ALUEXEC for completed instructions if queue is full
|
||||
from extra.assembly.amd.sqtt import AluSrc
|
||||
|
||||
op_name = inst.op_name if hasattr(inst, 'op_name') else ''
|
||||
is_trans = any(t in op_name for t in _TRANS_OPS)
|
||||
|
||||
if is_trans:
|
||||
# Transcendental: emit INST, 4-cycle issue, add ALUEXEC to pending
|
||||
self.emit_inst(InstOp.VALU_TRANS)
|
||||
self.trans_count += 1
|
||||
# Check for dependency on VGPR sources
|
||||
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
|
||||
if src_ready > self.cycle:
|
||||
# Dependent trans: ALUEXEC at source_ready + 6 (for VALU source) or +10 (for trans source)
|
||||
# Check if source is from a trans instruction (its ready time would be > issue + 6)
|
||||
src_is_trans = any(self.vgpr_ready.get(r, 0) > self.cycle + 6 for _, r in self._get_src_regs(inst) if _ == 'v')
|
||||
exec_cycle = src_ready + (10 if src_is_trans else 6)
|
||||
else:
|
||||
# Independent trans: ALUEXEC at issue + 9
|
||||
exec_cycle = self.cycle + TRANS_LATENCY
|
||||
self.pending_exec.append((exec_cycle, AluSrc.VALU))
|
||||
self._record_dst_ready(inst, exec_cycle) # Record when this trans result is ready
|
||||
self.last_trans_exec = exec_cycle
|
||||
self.last_trans_issue = self.cycle
|
||||
# Trans instructions take 4 cycles to issue
|
||||
self.cycle += TRANS_ISSUE_CYCLES - 1 # -1 because we add 1 at end of trace_inst
|
||||
self.last_inst_type = 'TRANS'
|
||||
else:
|
||||
# Regular VALU: emit VALUINST, may interleave ALUEXEC
|
||||
self.emit_valuinst()
|
||||
|
||||
# Track first VALU issue for latency calculation
|
||||
if self.first_valu_issue == 0:
|
||||
self.first_valu_issue = self.cycle
|
||||
|
||||
# Emit any pending ALUEXECs that have completed by now (after VALUINST)
|
||||
while self.pending_exec and self.pending_exec[0][0] <= self.cycle:
|
||||
exec_cycle, src = self.pending_exec.pop(0)
|
||||
old_cycle = self.cycle
|
||||
self.cycle = exec_cycle
|
||||
self.emit_aluexec(src)
|
||||
self.cycle = old_cycle
|
||||
|
||||
# Check for dependency
|
||||
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
|
||||
has_dep = src_ready > self.cycle
|
||||
|
||||
if has_dep:
|
||||
# Dependent: first dependent is +6 from source, subsequent chained deps are +5
|
||||
# "Chained" means the source was also dependent (not the first independent VALU)
|
||||
chained = src_ready > self.first_valu_issue + VALU_EXEC_LATENCY
|
||||
exec_cycle = src_ready + (5 if chained else 6)
|
||||
else:
|
||||
# Independent VALU timing: 6-cycle pipeline latency
|
||||
# exec[i] = first_issue + 6 + i, with +1 spacing between consecutive execs
|
||||
exec_cycle = self.first_valu_issue + VALU_EXEC_LATENCY + self.valu_count
|
||||
exec_cycle = max(exec_cycle, self.last_valu_exec + 1) # +1 spacing
|
||||
|
||||
# Add to pending exec queue (sorted by time)
|
||||
self.pending_exec.append((exec_cycle, AluSrc.VALU))
|
||||
self.pending_exec.sort(key=lambda x: x[0])
|
||||
self.last_valu_exec = exec_cycle
|
||||
self.valu_count += 1
|
||||
self._record_dst_ready(inst, exec_cycle)
|
||||
self.last_inst_type = 'VALU'
|
||||
|
||||
elif isinstance(inst, VOPD):
|
||||
from extra.assembly.amd.sqtt import AluSrc
|
||||
|
||||
# First emit VALUINST (HW emits issue before completion at same cycle)
|
||||
self.emit_valuinst()
|
||||
|
||||
# Emit any pending ALUEXECs that have completed by now
|
||||
while self.pending_exec and self.pending_exec[0][0] <= self.cycle:
|
||||
exec_cycle, src = self.pending_exec.pop(0)
|
||||
old_cycle = self.cycle
|
||||
self.cycle = exec_cycle
|
||||
self.emit_aluexec(src)
|
||||
self.cycle = old_cycle
|
||||
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
|
||||
has_dep = src_ready > self.cycle
|
||||
|
||||
if has_dep:
|
||||
exec_cycle = src_ready + 10
|
||||
else:
|
||||
exec_cycle = self.cycle + VALU_EXEC_LATENCY
|
||||
if self.last_valu_exec > 0:
|
||||
exec_cycle = max(exec_cycle, self.last_valu_exec + 1)
|
||||
|
||||
self.pending_exec.append((exec_cycle, AluSrc.VALU))
|
||||
self.pending_exec.sort(key=lambda x: x[0])
|
||||
self.last_valu_exec = exec_cycle
|
||||
self._record_dst_ready(inst, exec_cycle)
|
||||
|
||||
self.cycle += 1
|
||||
|
||||
def finalize(self):
|
||||
"""Emit all remaining pending ALUEXEC packets and WAVEEND."""
|
||||
from extra.assembly.amd.sqtt import AluSrc
|
||||
|
||||
# Emit any remaining pending ALUEXECs
|
||||
self.pending_exec.sort(key=lambda x: x[0])
|
||||
last_src = None
|
||||
for exec_cycle, src in self.pending_exec:
|
||||
self.cycle = exec_cycle
|
||||
self.emit_aluexec(src)
|
||||
last_src = src
|
||||
self.pending_exec.clear()
|
||||
|
||||
# WAVEEND timing depends on what comes last
|
||||
# If last instruction was SOPP: +1 normally, but +11 if last ALUEXEC was recent (VALU drain)
|
||||
# Otherwise: 14 cycles for no ALU, 20 for trans, 14/15 for SALU/VALU
|
||||
if self.last_inst_type == 'SOPP':
|
||||
# Check if last ALUEXEC was recent (within ~5 cycles of current)
|
||||
if self.valu_count > 0 and self.last_valu_exec >= self.cycle - 5:
|
||||
self.cycle += 11 # VALU drain time
|
||||
else:
|
||||
self.cycle += 1
|
||||
elif last_src is None:
|
||||
self.cycle += 14 # empty program or no ALU ops
|
||||
elif self.trans_count > 0:
|
||||
self.cycle += 20 # trans has longer WAVEEND delay
|
||||
else:
|
||||
self.cycle += 15 if last_src == AluSrc.VALU else 14
|
||||
self.emit_waveend()
|
||||
|
||||
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
|
||||
_VOPD_TO_VOP = {
|
||||
VOPDOp.V_DUAL_FMAC_F32: VOP3Op.V_FMAC_F32, VOPDOp.V_DUAL_FMAAK_F32: VOP2Op.V_FMAAK_F32, VOPDOp.V_DUAL_FMAMK_F32: VOP2Op.V_FMAMK_F32,
|
||||
@@ -547,46 +178,81 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
|
||||
s0 = st.rsrc64(ssrc0, 0) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0))
|
||||
s1 = st.rsrc64(inst.ssrc1, 0) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0)
|
||||
d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
|
||||
exec_mask = st.exec_mask
|
||||
literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else st.literal
|
||||
|
||||
# Create Reg objects for compiled function - mask VCC/EXEC to 32 bits for wave32
|
||||
result = fn(Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc & MASK32), 0, Reg(st.exec_mask & MASK32), literal, None, PC=Reg(st.pc * 4))
|
||||
# Execute compiled function - pass PC in bytes for instructions that need it
|
||||
# For wave32, mask VCC and EXEC to 32 bits since only the lower 32 bits are relevant
|
||||
pc_bytes = st.pc * 4
|
||||
vcc32, exec32 = st.vcc & MASK32, exec_mask & MASK32
|
||||
result = fn(s0, s1, 0, d0, st.scc, vcc32, 0, exec32, literal, None, {}, pc=pc_bytes)
|
||||
|
||||
# Apply results - extract values from returned Reg objects
|
||||
if sdst is not None and 'D0' in result:
|
||||
(st.wsgpr64 if inst.dst_regs() == 2 else st.wsgpr)(sdst, result['D0']._val)
|
||||
if 'SCC' in result: st.scc = result['SCC']._val & 1
|
||||
if 'EXEC' in result: st.exec_mask = result['EXEC']._val
|
||||
if 'PC' in result:
|
||||
# Apply results
|
||||
if sdst is not None:
|
||||
(st.wsgpr64 if result.get('d0_64') else st.wsgpr)(sdst, result['d0'])
|
||||
if 'scc' in result: st.scc = result['scc']
|
||||
if 'exec' in result: st.exec_mask = result['exec']
|
||||
if 'new_pc' in result:
|
||||
# Convert absolute byte address to word delta
|
||||
pc_val = result['PC']._val
|
||||
new_pc = pc_val if pc_val < 0x8000000000000000 else pc_val - 0x10000000000000000
|
||||
new_pc_words = new_pc // 4
|
||||
# new_pc is where we want to go, st.pc is current position, inst._words will be added after
|
||||
new_pc_words = result['new_pc'] // 4
|
||||
return new_pc_words - st.pc - 1 # -1 because emulator adds inst_words (1 for scalar)
|
||||
return 0
|
||||
|
||||
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None) -> None:
|
||||
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None:
|
||||
"""Execute vector instruction for one lane."""
|
||||
compiled = _get_compiled()
|
||||
V = st.vgpr[lane]
|
||||
|
||||
# Memory ops (FLAT/GLOBAL/SCRATCH and DS) - use generated pcode
|
||||
if isinstance(inst, (FLAT, DS)):
|
||||
op, vdst, op_name = inst.op, inst.vdst, inst.op.name
|
||||
fn, ndwords = compiled[type(op)][op], _op_ndwords(op_name)
|
||||
if isinstance(inst, FLAT):
|
||||
addr = V[inst.addr] | (V[inst.addr + 1] << 32)
|
||||
ADDR = (st.rsgpr64(inst.saddr) + V[inst.addr] + _sext(inst.offset, 13)) & MASK64 if inst.saddr not in (NULL, 0x7f) else (addr + _sext(inst.offset, 13)) & MASK64
|
||||
# For loads, VDATA comes from vdst (preserves unwritten bits); for stores, from inst.data
|
||||
vdata_src = vdst if 'LOAD' in op_name else inst.data
|
||||
result = fn(GlobalMem, ADDR, _vgpr_read(V, vdata_src, ndwords), Reg(V[vdst]), Reg(0))
|
||||
if 'VDATA' in result: _vgpr_write(V, vdst, result['VDATA']._val, ndwords)
|
||||
if 'RETURN_DATA' in result: _vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords)
|
||||
else: # DS
|
||||
DATA0, DATA1 = _vgpr_read(V, inst.data0, ndwords), _vgpr_read(V, inst.data1, ndwords) if inst.data1 is not None else Reg(0)
|
||||
result = fn(lds, Reg(V[inst.addr]), DATA0, DATA1, Reg(inst.offset0), Reg(inst.offset1), Reg(0))
|
||||
if 'RETURN_DATA' in result and ('_RTN' in op_name or '_LOAD' in op_name):
|
||||
_vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords * 2 if '_2ADDR_' in op_name else ndwords)
|
||||
# Memory ops (not ALU pseudocode)
|
||||
if isinstance(inst, FLAT):
|
||||
op, addr_reg, data_reg, vdst, offset, saddr = inst.op, inst.addr, inst.data, inst.vdst, _sext(inst.offset, 13), inst.saddr
|
||||
addr = V[addr_reg] | (V[addr_reg+1] << 32)
|
||||
addr = (st.rsgpr64(saddr) + V[addr_reg] + offset) & MASK64 if saddr not in (NULL, 0x7f) else (addr + offset) & MASK64
|
||||
if op in FLAT_LOAD:
|
||||
cnt, sz, sign = FLAT_LOAD[op]
|
||||
for i in range(cnt): val = mem_read(addr + i * sz, sz); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
|
||||
elif op in FLAT_STORE:
|
||||
cnt, sz = FLAT_STORE[op]
|
||||
for i in range(cnt): mem_write(addr + i * sz, sz, V[data_reg + i] & ((1 << (sz * 8)) - 1))
|
||||
elif op in FLAT_D16_LOAD:
|
||||
sz, sign, hi = FLAT_D16_LOAD[op]
|
||||
val = mem_read(addr, sz)
|
||||
if sign: val = _sext(val, sz * 8) & 0xffff
|
||||
V[vdst] = _dst16(V[vdst], val, hi)
|
||||
elif op in FLAT_D16_STORE:
|
||||
sz, hi = FLAT_D16_STORE[op]
|
||||
mem_write(addr, sz, _src16(V[data_reg], hi) & ((1 << (sz * 8)) - 1))
|
||||
else: raise NotImplementedError(f"FLAT op {op}")
|
||||
return
|
||||
|
||||
if isinstance(inst, DS):
|
||||
op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
|
||||
if op in DS_LOAD:
|
||||
cnt, sz, sign = DS_LOAD[op]
|
||||
for i in range(cnt): val = int.from_bytes(lds[addr0+i*sz:addr0+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
|
||||
elif op in DS_STORE:
|
||||
cnt, sz = DS_STORE[op]
|
||||
for i in range(cnt): lds[addr0+i*sz:addr0+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
|
||||
elif op in DS_LOAD_2ADDR:
|
||||
# Load two values from addr+offset0*sz and addr+offset1*sz into vdst (B32: 1 dword each, B64: 2 dwords each)
|
||||
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
|
||||
sz = DS_LOAD_2ADDR[op]
|
||||
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
|
||||
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
|
||||
cnt = sz // 4 # 1 for B32, 2 for B64
|
||||
for i in range(cnt): V[vdst + i] = int.from_bytes(lds[addr0+i*4:addr0+i*4+4], 'little')
|
||||
for i in range(cnt): V[vdst + cnt + i] = int.from_bytes(lds[addr1+i*4:addr1+i*4+4], 'little')
|
||||
elif op in DS_STORE_2ADDR:
|
||||
# Store two values from data0 and data1 to addr+offset0*sz and addr+offset1*sz
|
||||
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
|
||||
sz = DS_STORE_2ADDR[op]
|
||||
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
|
||||
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
|
||||
cnt = sz // 4
|
||||
for i in range(cnt): lds[addr0+i*4:addr0+i*4+4] = (V[inst.data0 + i] & MASK32).to_bytes(4, 'little')
|
||||
for i in range(cnt): lds[addr1+i*4:addr1+i*4+4] = (V[inst.data1 + i] & MASK32).to_bytes(4, 'little')
|
||||
else: raise NotImplementedError(f"DS op {op}")
|
||||
return
|
||||
|
||||
# VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)
|
||||
@@ -594,25 +260,24 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None)
|
||||
vdsty = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
|
||||
inputs = [(inst.opx, st.rsrc(inst.srcx0, lane), V[inst.vsrcx1], V[inst.vdstx], inst.vdstx),
|
||||
(inst.opy, st.rsrc(inst.srcy0, lane), V[inst.vsrcy1], V[vdsty], vdsty)]
|
||||
def exec_vopd(vopd_op, s0, s1, d0):
|
||||
op = _VOPD_TO_VOP[vopd_op]
|
||||
return compiled[type(op)][op](Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)['D0']._val
|
||||
for vopd_op, s0, s1, d0, dst in inputs: V[dst] = exec_vopd(vopd_op, s0, s1, d0)
|
||||
results = [(dst, fn(s0, s1, 0, d0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})['d0'])
|
||||
for vopd_op, s0, s1, d0, dst in inputs if (op := _VOPD_TO_VOP.get(vopd_op)) and (fn := compiled.get(type(op), {}).get(op))]
|
||||
for dst, val in results: V[dst] = val
|
||||
return
|
||||
|
||||
# VOP3SD: has extra scalar dest for carry output
|
||||
if isinstance(inst, VOP3SD):
|
||||
fn = compiled[VOP3SDOp][inst.op]
|
||||
fn = compiled.get(VOP3SDOp, {}).get(inst.op)
|
||||
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
|
||||
# Read sources based on register counts from inst properties
|
||||
def rsrc_n(src, regs): return st.rsrc64(src, lane) if regs == 2 else st.rsrc(src, lane)
|
||||
s0, s1, s2 = rsrc_n(inst.src0, inst.src_regs(0)), rsrc_n(inst.src1, inst.src_regs(1)), rsrc_n(inst.src2, inst.src_regs(2))
|
||||
# Carry-in ops use src2 as carry bitmask instead of VCC
|
||||
vcc = st.rsgpr64(inst.src2) if 'CO_CI' in inst.op_name else st.vcc
|
||||
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(V[inst.vdst]), Reg(st.scc), Reg(vcc), lane, Reg(st.exec_mask), st.literal, None)
|
||||
d0_val = result['D0']._val
|
||||
V[inst.vdst] = d0_val & MASK32
|
||||
if inst.dst_regs() == 2: V[inst.vdst + 1] = (d0_val >> 32) & MASK32
|
||||
if 'VCC' in result: st.pend_sgpr_lane(inst.sdst, lane, (result['VCC']._val >> lane) & 1)
|
||||
result = fn(s0, s1, s2, V[inst.vdst], st.scc, vcc, lane, st.exec_mask, st.literal, None, {})
|
||||
V[inst.vdst] = result['d0'] & MASK32
|
||||
if result.get('d0_64'): V[inst.vdst + 1] = (result['d0'] >> 32) & MASK32
|
||||
if result.get('vcc_lane') is not None: st.pend_sgpr_lane(inst.sdst, lane, result['vcc_lane'])
|
||||
return
|
||||
|
||||
# Get op enum and sources (None means "no source" for that operand)
|
||||
@@ -652,7 +317,8 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None)
|
||||
if abs_ & (1<<i): srcs[i] = abs(srcs[i])
|
||||
if neg & (1<<i): srcs[i] = -srcs[i]
|
||||
result = srcs[0] * srcs[1] + srcs[2]
|
||||
st.vgpr[lane][inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
|
||||
V = st.vgpr[lane]
|
||||
V[inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
|
||||
return
|
||||
# VOP3P packed ops: opsel selects halves for lo, opsel_hi for hi; neg toggles f16 sign
|
||||
raws = [st.rsrc_f16(inst.src0, lane), st.rsrc_f16(inst.src1, lane), st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0]
|
||||
@@ -661,13 +327,15 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None)
|
||||
hi_sels = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
|
||||
srcs = [((_src16(raws[i], hi_sels[i]) ^ (0x8000 if neg_hi & (1<<i) else 0)) << 16) |
|
||||
(_src16(raws[i], opsel & (1<<i)) ^ (0x8000 if neg & (1<<i) else 0)) for i in range(3)]
|
||||
result = compiled[VOP3POp][inst.op](Reg(srcs[0]), Reg(srcs[1]), Reg(srcs[2]), Reg(0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)
|
||||
st.vgpr[lane][inst.vdst] = result['D0']._val & MASK32
|
||||
fn = compiled.get(VOP3POp, {}).get(inst.op)
|
||||
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
|
||||
st.vgpr[lane][inst.vdst] = fn(srcs[0], srcs[1], srcs[2], 0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})['d0'] & MASK32
|
||||
return
|
||||
else: raise NotImplementedError(f"Unknown vector type {type(inst)}")
|
||||
|
||||
op_cls = type(inst.op)
|
||||
if (fn := compiled.get(op_cls, {}).get(inst.op)) is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
|
||||
fn = compiled.get(op_cls, {}).get(inst.op)
|
||||
if fn is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
|
||||
|
||||
# Read sources (with VOP3 modifiers if applicable)
|
||||
neg, abs_ = (getattr(inst, 'neg', 0), getattr(inst, 'abs', 0)) if isinstance(inst, VOP3) else (0, 0)
|
||||
@@ -709,27 +377,24 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None)
|
||||
# Execute compiled function - pass src0_idx and vdst_idx for lane instructions
|
||||
# For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR)
|
||||
src0_idx = (src0 - 256) if src0 is not None and src0 >= 256 else (src0 if src0 is not None else 0)
|
||||
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(d0), Reg(st.scc), Reg(vcc_for_fn), lane, Reg(st.exec_mask), st.literal, st.vgpr, src0_idx, vdst)
|
||||
result = fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, st.literal, st.vgpr, {}, src0_idx, vdst)
|
||||
|
||||
# Apply results - extract values from returned Reg objects
|
||||
# Apply results
|
||||
if 'vgpr_write' in result:
|
||||
# Lane instruction wrote to VGPR: (lane, vgpr_idx, value)
|
||||
wr_lane, wr_idx, wr_val = result['vgpr_write']
|
||||
st.vgpr[wr_lane][wr_idx] = wr_val
|
||||
if 'VCC' in result:
|
||||
if 'vcc_lane' in result:
|
||||
# VOP2 carry ops write to VCC implicitly; VOPC/VOP3 write to vdst
|
||||
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC']._val >> lane) & 1)
|
||||
if 'EXEC' in result:
|
||||
# V_CMPX instructions write to EXEC per-lane (not to vdst)
|
||||
st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC']._val >> lane) & 1)
|
||||
elif op_cls is VOPCOp:
|
||||
# VOPC comparison result stored in D0 bitmask, extract lane bit (non-CMPX only)
|
||||
st.pend_sgpr_lane(vdst, lane, (result['D0']._val >> lane) & 1)
|
||||
if op_cls is not VOPCOp and 'vgpr_write' not in result:
|
||||
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, result['vcc_lane'])
|
||||
if 'exec_lane' in result:
|
||||
# V_CMPX instructions write to EXEC per-lane
|
||||
st.pend_sgpr_lane(EXEC_LO, lane, result['exec_lane'])
|
||||
if 'd0' in result and op_cls is not VOPCOp and 'vgpr_write' not in result:
|
||||
writes_to_sgpr = 'READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name
|
||||
d0_val = result['D0']._val
|
||||
d0_val = result['d0']
|
||||
if writes_to_sgpr: st.wsgpr(vdst, d0_val & MASK32)
|
||||
elif inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
|
||||
elif result.get('d0_64'): V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
|
||||
elif inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
|
||||
else: V[vdst] = d0_val & MASK32
|
||||
|
||||
@@ -759,7 +424,7 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None:
|
||||
# MAIN EXECUTION LOOP
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int:
|
||||
def step_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int:
|
||||
inst = program.get(st.pc)
|
||||
if inst is None: return 1
|
||||
inst_words, st.literal = inst._words, getattr(inst, '_literal', None) or 0
|
||||
@@ -779,7 +444,7 @@ def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int
|
||||
st.pc += inst_words
|
||||
return 0
|
||||
|
||||
def exec_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int:
|
||||
def exec_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int:
|
||||
while st.pc in program:
|
||||
result = step_wave(program, st, lds, n_lanes)
|
||||
if result == -1: return 0
|
||||
@@ -789,7 +454,7 @@ def exec_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int
|
||||
def exec_workgroup(program: Program, workgroup_id: tuple[int, int, int], local_size: tuple[int, int, int], args_ptr: int,
|
||||
wg_id_sgpr_base: int, wg_id_enables: tuple[bool, bool, bool]) -> None:
|
||||
lx, ly, lz = local_size
|
||||
total_threads, lds = lx * ly * lz, LDSMem(bytearray(65536))
|
||||
total_threads, lds = lx * ly * lz, bytearray(65536)
|
||||
waves: list[tuple[WaveState, int, int]] = []
|
||||
for wave_start in range(0, total_threads, WAVE_SIZE):
|
||||
n_lanes, st = min(WAVE_SIZE, total_threads - wave_start), WaveState()
|
||||
|
||||
@@ -8,7 +8,7 @@ import os
|
||||
os.environ["SQTT"] = "1"
|
||||
os.environ["PROFILE"] = "1"
|
||||
os.environ["SQTT_LIMIT_SE"] = "2"
|
||||
os.environ["SQTT_TOKEN_EXCLUDE"] = "3786" # exclude WAVERDY, REG, EVENT, UTILCTR, WAVEALLOC, PERF, ALUEXEC
|
||||
os.environ["SQTT_TOKEN_EXCLUDE"] = "3784" # Exclude WAVERDY, REG, EVENT, UTILCTR, WAVEALLOC, PERF
|
||||
|
||||
import unittest
|
||||
from tinygrad.device import Device
|
||||
@@ -242,7 +242,7 @@ class TestEmulatorSQTT(unittest.TestCase):
|
||||
def test_valu_independent_8(self): self._test_valu_independent_n(8)
|
||||
def test_valu_independent_16(self): self._test_valu_independent_n(16)
|
||||
|
||||
def test_trans_independent_16(self): self._test_valu_independent_n(16, True)
|
||||
def test_trans_independent_16(self): self._test_valu_independent_n(16, trans=True)
|
||||
|
||||
def test_valu_chain(self):
|
||||
"""VALU instructions with chain dependencies."""
|
||||
|
||||
Reference in New Issue
Block a user