diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index 8b167761b4..b24e904264 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -1,14 +1,12 @@ # RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF # mypy: ignore-errors from __future__ import annotations -import ctypes -from enum import IntEnum -from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64 -from extra.assembly.amd.pcode import Reg +import ctypes, struct +from extra.assembly.amd.dsl import Inst, RawImm, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64 from extra.assembly.amd.asm import detect_format from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD, - SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp) + SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, VOPDOp) Program = dict[int, Inst] WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256 @@ -30,402 +28,35 @@ def _dst16(cur: int, val: int, is_hi: bool) -> int: return (cur & 0x0000ffff) | def _vgpr_hi(src: int) -> bool: return src >= 256 and ((src - 256) & 0x80) != 0 def _vgpr_masked(src: int) -> int: return ((src - 256) & 0x7f) + 256 if src >= 256 else src -# Helper: get number of dwords from memory op name -def _op_ndwords(name: str) -> int: - if '_B128' in name: return 4 - if '_B96' in name: return 3 - if any(s in name for s in ('_B64', '_U64', '_I64', '_F64')): return 2 - return 1 - -# Helper: build multi-dword Reg from consecutive VGPRs -def _vgpr_read(V: list, base: int, ndwords: int) -> Reg: return Reg(sum(V[base + i] << (32 * i) for i in range(ndwords))) - -# Helper: write multi-dword value to consecutive VGPRs -def _vgpr_write(V: list, base: int, val: int, ndwords: int): - for i in range(ndwords): V[base + i] = (val >> (32 * i)) & MASK32 - # Memory access _valid_mem_ranges: list[tuple[int, int]] = [] def set_valid_mem_ranges(ranges: set[tuple[int, int]]) -> None: _valid_mem_ranges.clear(); _valid_mem_ranges.extend(ranges) def _mem_valid(addr: int, size: int) -> bool: return not _valid_mem_ranges or any(s <= addr and addr + size <= s + z for s, z in _valid_mem_ranges) -def _ctypes_at(addr: int, size: int): return (ctypes.c_uint8 if size == 1 else ctypes.c_uint16 if size == 2 else ctypes.c_uint64 if size == 8 else ctypes.c_uint32).from_address(addr) +def _ctypes_at(addr: int, size: int): return (ctypes.c_uint8 if size == 1 else ctypes.c_uint16 if size == 2 else ctypes.c_uint32).from_address(addr) def mem_read(addr: int, size: int) -> int: return _ctypes_at(addr, size).value if _mem_valid(addr, size) else 0 def mem_write(addr: int, size: int, val: int) -> None: if _mem_valid(addr, size): _ctypes_at(addr, size).value = val -def _make_mem_accessor(read_fn, write_fn): - """Create a memory accessor class with the given read/write functions.""" - class _MemAccessor: - __slots__ = ('_addr',) - def __init__(self, addr: int): self._addr = int(addr) - u8 = property(lambda s: read_fn(s._addr, 1), lambda s, v: write_fn(s._addr, 1, int(v))) - u16 = property(lambda s: read_fn(s._addr, 2), lambda s, v: write_fn(s._addr, 2, int(v))) - u32 = property(lambda s: read_fn(s._addr, 4), lambda s, v: write_fn(s._addr, 4, int(v))) - u64 = property(lambda s: read_fn(s._addr, 8), lambda s, v: write_fn(s._addr, 8, int(v))) - i8 = property(lambda s: _sext(read_fn(s._addr, 1), 8), lambda s, v: write_fn(s._addr, 1, int(v))) - i16 = property(lambda s: _sext(read_fn(s._addr, 2), 16), lambda s, v: write_fn(s._addr, 2, int(v))) - i32 = property(lambda s: _sext(read_fn(s._addr, 4), 32), lambda s, v: write_fn(s._addr, 4, int(v))) - i64 = property(lambda s: _sext(read_fn(s._addr, 8), 64), lambda s, v: write_fn(s._addr, 8, int(v))) - b8, b16, b32, b64 = u8, u16, u32, u64 - return _MemAccessor - -_GlobalMemAccessor = _make_mem_accessor(mem_read, mem_write) - -class _GlobalMem: - """Global memory wrapper that supports MEM[addr].u32 style access.""" - def __getitem__(self, addr) -> _GlobalMemAccessor: return _GlobalMemAccessor(addr) -GlobalMem = _GlobalMem() - -class LDSMem: - """LDS memory wrapper that supports MEM[addr].u32 style access.""" - __slots__ = ('_lds',) - def __init__(self, lds: bytearray): self._lds = lds - def _read(self, addr: int, size: int) -> int: - addr = addr & 0xffff - return int.from_bytes(self._lds[addr:addr+size], 'little') if addr + size <= len(self._lds) else 0 - def _write(self, addr: int, size: int, val: int): - addr = addr & 0xffff - if addr + size <= len(self._lds): self._lds[addr:addr+size] = (int(val) & ((1 << (size*8)) - 1)).to_bytes(size, 'little') - def __getitem__(self, addr): return _make_mem_accessor(self._read, self._write)(addr) - +# Memory op tables (not pseudocode - these are format descriptions) +def _mem_ops(ops, suffix_map): + return {getattr(e, f"{p}_{s}"): v for e in ops for s, v in suffix_map.items() for p in [e.__name__.replace("Op", "")]} +_LOAD_MAP = {'LOAD_B32': (1,4,0), 'LOAD_B64': (2,4,0), 'LOAD_B96': (3,4,0), 'LOAD_B128': (4,4,0), 'LOAD_U8': (1,1,0), 'LOAD_I8': (1,1,1), 'LOAD_U16': (1,2,0), 'LOAD_I16': (1,2,1)} +_STORE_MAP = {'STORE_B32': (1,4), 'STORE_B64': (2,4), 'STORE_B96': (3,4), 'STORE_B128': (4,4), 'STORE_B8': (1,1), 'STORE_B16': (1,2)} +FLAT_LOAD, FLAT_STORE = _mem_ops([GLOBALOp, FLATOp], _LOAD_MAP), _mem_ops([GLOBALOp, FLATOp], _STORE_MAP) +# D16 ops: load/store 16-bit to lower or upper half of VGPR. Format: (size, sign, hi) where hi=1 means upper 16 bits +_D16_LOAD_MAP = {'LOAD_D16_U8': (1,0,0), 'LOAD_D16_I8': (1,1,0), 'LOAD_D16_B16': (2,0,0), + 'LOAD_D16_HI_U8': (1,0,1), 'LOAD_D16_HI_I8': (1,1,1), 'LOAD_D16_HI_B16': (2,0,1)} +_D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi) +FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP) +FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP) +DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)} +DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)} +# 2ADDR ops: load/store two values using offset0 and offset1 +DS_LOAD_2ADDR = {DSOp.DS_LOAD_2ADDR_B32: 4, DSOp.DS_LOAD_2ADDR_B64: 8} +DS_STORE_2ADDR = {DSOp.DS_STORE_2ADDR_B32: 4, DSOp.DS_STORE_2ADDR_B64: 8} SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16} -# ═══════════════════════════════════════════════════════════════════════════════ -# SQTT TRACING - Emit packets matching real hardware output -# ═══════════════════════════════════════════════════════════════════════════════ - -# Transcendental ops that produce INST packets with op=VALU_TRANS (0x0b) -_TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64', - 'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'} - -# Latency model from hardware measurements (warm instruction cache): -# Startup: WAVESTART -> first instruction (~32 cycles with warm cache, no REG packet) -# Cold cache: WAVESTART -> REG (~137 cycles) -> first instruction (~150 cycles after REG) -# SALU: issues every cycle, result ready 2 cycles after issue, ALUEXEC at ready time -# VALU: issues every cycle, ALUEXEC at issue+6 for each inst, serialized with +1 intervals -# TRANS: issues every 4 cycles, ALUEXEC at last_issue+1, then +8 intervals -# For dependent instructions, ALUEXEC is at source_ready + 10 (first dep) or + 9 (chained) -# VALU queue depth ~7: after 7 in-flight, ALUEXEC interleaves with VALUINST -WAVESTART_TO_INST_CYCLES = 32 # cycles from WAVESTART to first instruction (warm cache) -SALU_LATENCY = 2 # cycles from issue to result ready -VALU_EXEC_LATENCY = 6 # cycles from first issue to first ALUEXEC -TRANS_ISSUE_CYCLES = 4 # cycles between transcendental instruction issues -TRANS_LATENCY = 9 # cycles from trans issue to ALUEXEC - -class SQTTState: - """SQTT tracing state - emits packets matching real hardware (warm cache model).""" - __slots__ = ('cycle', 'packets', 'pending_exec', 'wave_id', 'simd', 'cu', 'sgpr_ready', 'vgpr_ready', - 'last_salu_exec', 'last_valu_exec', 'trans_count', 'last_trans_issue', 'last_trans_exec', 'first_valu_issue', 'valu_count', 'last_inst_type') - - def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0): - self.cycle = 0 - self.packets: list = [] - self.pending_exec: list[tuple[int, int]] = [] # (completion_cycle, src) for ALUEXEC - self.wave_id, self.simd, self.cu = wave_id, simd, cu - self.sgpr_ready: dict[int, int] = {} # sgpr -> cycle when result ready - self.last_inst_type: str = '' # track last instruction type for gap calculation - self.vgpr_ready: dict[int, int] = {} # vgpr -> cycle when result ready - self.last_salu_exec = 0 # last SALU ALUEXEC time (for +1 spacing) - self.last_valu_exec = 0 # last VALU ALUEXEC time (for +1 spacing) - self.trans_count = 0 # number of trans instructions issued - self.last_trans_issue = 0 # cycle when last trans was issued - self.last_trans_exec = 0 # cycle when last trans ALUEXEC is scheduled - self.first_valu_issue = 0 # cycle when first VALU was issued - self.valu_count = 0 # number of VALU instructions issued - - def emit_wavestart(self): - from extra.assembly.amd.sqtt import WAVESTART - self.packets.append(WAVESTART(_time=self.cycle, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)) - self.cycle += WAVESTART_TO_INST_CYCLES - - def emit_waveend(self): - from extra.assembly.amd.sqtt import WAVEEND - self.packets.append(WAVEEND(_time=self.cycle, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)) - - def emit_inst(self, inst_op: int): - from extra.assembly.amd.sqtt import INST - self.packets.append(INST(_time=self.cycle, wave=self.wave_id, op=inst_op)) - - def emit_valuinst(self): - from extra.assembly.amd.sqtt import VALUINST - self.packets.append(VALUINST(_time=self.cycle, wave=self.wave_id)) - - def emit_aluexec(self, src: int): - from extra.assembly.amd.sqtt import ALUEXEC - self.packets.append(ALUEXEC(_time=self.cycle, src=src)) - - def emit_immediate(self): - from extra.assembly.amd.sqtt import IMMEDIATE - self.packets.append(IMMEDIATE(_time=self.cycle, wave=self.wave_id)) - - def _get_src_regs(self, inst: Inst) -> list[tuple[str, int]]: - """Extract source register references from instruction.""" - srcs = [] - if isinstance(inst, (SOP1, SOP2, SOPC)): - if hasattr(inst, 'ssrc0') and inst.ssrc0 < SGPR_COUNT: srcs.append(('s', inst.ssrc0)) - if hasattr(inst, 'ssrc1') and inst.ssrc1 < SGPR_COUNT: srcs.append(('s', inst.ssrc1)) - elif isinstance(inst, SOPK): - pass # immediate only - elif isinstance(inst, VOP1): - # VOP1: only src0 - if inst.src0 < SGPR_COUNT: srcs.append(('s', inst.src0)) - elif inst.src0 >= 256: srcs.append(('v', inst.src0 - 256)) - elif isinstance(inst, VOP2): - # VOP2: src0 (can be SGPR/VGPR/const) + vsrc1 (always VGPR) - if inst.src0 < SGPR_COUNT: srcs.append(('s', inst.src0)) - elif inst.src0 >= 256: srcs.append(('v', inst.src0 - 256)) - srcs.append(('v', inst.vsrc1)) - elif isinstance(inst, (VOP3, VOP3SD, VOP3P, VOPC)): - for attr in ('src0', 'src1', 'src2'): - if hasattr(inst, attr): - src = getattr(inst, attr) - if src < SGPR_COUNT: srcs.append(('s', src)) - elif src >= 256: srcs.append(('v', src - 256)) - return srcs - - def _get_dst_regs(self, inst: Inst) -> list[tuple[str, int]]: - """Extract destination register references from instruction.""" - dsts = [] - if isinstance(inst, (SOP1, SOP2, SOPK)): - if hasattr(inst, 'sdst') and inst.sdst < SGPR_COUNT: - dsts.append(('s', inst.sdst)) - if inst.dst_regs() == 2: dsts.append(('s', inst.sdst + 1)) - elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD)): - if hasattr(inst, 'vdst'): - dsts.append(('v', inst.vdst)) - if inst.dst_regs() == 2: dsts.append(('v', inst.vdst + 1)) - return dsts - - def _check_dep_stall(self, inst: Inst) -> int: - """Calculate stall cycles due to RAW dependencies - wait until all sources ready.""" - max_ready = 0 - for typ, reg in self._get_src_regs(inst): - ready = (self.sgpr_ready if typ == 's' else self.vgpr_ready).get(reg, 0) - max_ready = max(max_ready, ready) - return max(0, max_ready - self.cycle) - - def _record_dst_ready(self, inst: Inst, ready_cycle: int): - """Record when destination registers will be ready.""" - for typ, reg in self._get_dst_regs(inst): - (self.sgpr_ready if typ == 's' else self.vgpr_ready)[reg] = ready_cycle - - def trace_inst(self, inst: Inst): - """Emit appropriate SQTT packets for an instruction. - - Key insight from hardware traces: - - Instructions issue every cycle (INST/VALUINST packets at +1 intervals) - - ALUEXEC shows when instruction completes (result ready) - - For independent ops: ALUEXEC at issue+latency, then +1 each (pipeline throughput) - - For dependent ops: ALUEXEC delayed until source ready + latency - """ - from extra.assembly.amd.sqtt import InstOp, AluSrc - - if isinstance(inst, (SOP1, SOP2, SOPC, SOPK)): - # SALU: issue now, complete after latency (or wait for deps) - self.emit_inst(InstOp.SALU) - # Completion time: max of (issue + latency) and (source_ready + latency) and (last_exec + 1) - src_ready = max((self.sgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst)), default=0) - ready_cycle = max(self.cycle, src_ready) + SALU_LATENCY - exec_cycle = max(ready_cycle, self.last_salu_exec + 1) - self.pending_exec.append((exec_cycle, AluSrc.SALU)) - self.last_salu_exec = exec_cycle - self._record_dst_ready(inst, ready_cycle) - self.last_inst_type = 'SALU' - - elif isinstance(inst, SOPP): - # s_nop and other SOPP emit IMMEDIATE packets - if self.last_inst_type == 'SALU': - # SALU→SOPP: emit ALUEXECs with max gap of 2 once, then IMMEDIATEs - # First ALUEXEC at last INST cycle, then allow one +2 gap, rest interleave - # cycle is currently 1 past last INST - last_inst_cycle = self.cycle - 1 - first = True - had_gap = False # track if we've had a +2 gap - while self.pending_exec: - exec_cycle, src = self.pending_exec[0] - if first: - self.pending_exec.pop(0) - self.cycle = last_inst_cycle # First ALUEXEC at same cycle as last INST - first = False - elif exec_cycle == self.cycle + 1: # consecutive (+1) - self.pending_exec.pop(0) - self.cycle = exec_cycle - elif exec_cycle == self.cycle + 2 and not had_gap: # allow one +2 gap - self.pending_exec.pop(0) - self.cycle = exec_cycle - had_gap = True - else: - break # too large a gap or second +2 gap - rest interleave with IMMEDIATEs - self.emit_aluexec(src) - self.cycle += 1 # IMMEDIATE starts +1 after last ALUEXEC - elif self.last_inst_type == 'TRANS': - # TRANS→SOPP: emit first ALUEXEC at its scheduled time, then IMMEDIATE +2 after - # Sort pending execs and emit first one if it's before IMMEDIATE time - self.pending_exec.sort(key=lambda x: x[0]) - if self.pending_exec: - first_exec_cycle = self.pending_exec[0][0] - # First IMMEDIATE at first_exec + 2 - imm_cycle = first_exec_cycle + 2 - # Emit ALUEXECs that come before IMMEDIATE - while self.pending_exec and self.pending_exec[0][0] < imm_cycle: - exec_cycle, src = self.pending_exec.pop(0) - self.cycle = exec_cycle - self.emit_aluexec(src) - self.cycle = imm_cycle - elif self.last_inst_type == 'VALU': - # VALU→SOPP: 2-cycle gap, emit any ALUEXECs that would come before - imm_cycle = self.cycle + 2 # When IMMEDIATE would be emitted - # Emit ALUEXECs that come before the IMMEDIATE - while self.pending_exec and self.pending_exec[0][0] < imm_cycle: - exec_cycle, src = self.pending_exec.pop(0) - self.cycle = max(self.cycle, exec_cycle) - self.emit_aluexec(src) - # Jump to IMMEDIATE cycle (either imm_cycle or 1 after last ALUEXEC, whichever is later) - self.cycle = max(imm_cycle, self.cycle + 1) - # Emit IMMEDIATE first, then any pending ALUEXECs at same cycle (HW order) - self.emit_immediate() - while self.pending_exec and self.pending_exec[0][0] <= self.cycle: - exec_cycle, src = self.pending_exec.pop(0) - old_cycle = self.cycle - self.cycle = exec_cycle - self.emit_aluexec(src) - self.cycle = old_cycle - # s_nop(N) waits N+1 cycles - apply delay for N > 0 (simm16 field) - if inst.op == SOPPOp.S_NOP and inst.simm16 > 0: - self.cycle += inst.simm16 # add delay before next instruction - self.last_inst_type = 'SOPP' - - elif isinstance(inst, SMEM): - pass # skip for ALU focus - - elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC)): - # VALU: issue now, emit ALUEXEC for completed instructions if queue is full - from extra.assembly.amd.sqtt import AluSrc - - op_name = inst.op_name if hasattr(inst, 'op_name') else '' - is_trans = any(t in op_name for t in _TRANS_OPS) - - if is_trans: - # Transcendental: emit INST, 4-cycle issue, add ALUEXEC to pending - self.emit_inst(InstOp.VALU_TRANS) - self.trans_count += 1 - # Check for dependency on VGPR sources - src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0) - if src_ready > self.cycle: - # Dependent trans: ALUEXEC at source_ready + 6 (for VALU source) or +10 (for trans source) - # Check if source is from a trans instruction (its ready time would be > issue + 6) - src_is_trans = any(self.vgpr_ready.get(r, 0) > self.cycle + 6 for _, r in self._get_src_regs(inst) if _ == 'v') - exec_cycle = src_ready + (10 if src_is_trans else 6) - else: - # Independent trans: ALUEXEC at issue + 9 - exec_cycle = self.cycle + TRANS_LATENCY - self.pending_exec.append((exec_cycle, AluSrc.VALU)) - self._record_dst_ready(inst, exec_cycle) # Record when this trans result is ready - self.last_trans_exec = exec_cycle - self.last_trans_issue = self.cycle - # Trans instructions take 4 cycles to issue - self.cycle += TRANS_ISSUE_CYCLES - 1 # -1 because we add 1 at end of trace_inst - self.last_inst_type = 'TRANS' - else: - # Regular VALU: emit VALUINST, may interleave ALUEXEC - self.emit_valuinst() - - # Track first VALU issue for latency calculation - if self.first_valu_issue == 0: - self.first_valu_issue = self.cycle - - # Emit any pending ALUEXECs that have completed by now (after VALUINST) - while self.pending_exec and self.pending_exec[0][0] <= self.cycle: - exec_cycle, src = self.pending_exec.pop(0) - old_cycle = self.cycle - self.cycle = exec_cycle - self.emit_aluexec(src) - self.cycle = old_cycle - - # Check for dependency - src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0) - has_dep = src_ready > self.cycle - - if has_dep: - # Dependent: first dependent is +6 from source, subsequent chained deps are +5 - # "Chained" means the source was also dependent (not the first independent VALU) - chained = src_ready > self.first_valu_issue + VALU_EXEC_LATENCY - exec_cycle = src_ready + (5 if chained else 6) - else: - # Independent VALU timing: 6-cycle pipeline latency - # exec[i] = first_issue + 6 + i, with +1 spacing between consecutive execs - exec_cycle = self.first_valu_issue + VALU_EXEC_LATENCY + self.valu_count - exec_cycle = max(exec_cycle, self.last_valu_exec + 1) # +1 spacing - - # Add to pending exec queue (sorted by time) - self.pending_exec.append((exec_cycle, AluSrc.VALU)) - self.pending_exec.sort(key=lambda x: x[0]) - self.last_valu_exec = exec_cycle - self.valu_count += 1 - self._record_dst_ready(inst, exec_cycle) - self.last_inst_type = 'VALU' - - elif isinstance(inst, VOPD): - from extra.assembly.amd.sqtt import AluSrc - - # First emit VALUINST (HW emits issue before completion at same cycle) - self.emit_valuinst() - - # Emit any pending ALUEXECs that have completed by now - while self.pending_exec and self.pending_exec[0][0] <= self.cycle: - exec_cycle, src = self.pending_exec.pop(0) - old_cycle = self.cycle - self.cycle = exec_cycle - self.emit_aluexec(src) - self.cycle = old_cycle - src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0) - has_dep = src_ready > self.cycle - - if has_dep: - exec_cycle = src_ready + 10 - else: - exec_cycle = self.cycle + VALU_EXEC_LATENCY - if self.last_valu_exec > 0: - exec_cycle = max(exec_cycle, self.last_valu_exec + 1) - - self.pending_exec.append((exec_cycle, AluSrc.VALU)) - self.pending_exec.sort(key=lambda x: x[0]) - self.last_valu_exec = exec_cycle - self._record_dst_ready(inst, exec_cycle) - - self.cycle += 1 - - def finalize(self): - """Emit all remaining pending ALUEXEC packets and WAVEEND.""" - from extra.assembly.amd.sqtt import AluSrc - - # Emit any remaining pending ALUEXECs - self.pending_exec.sort(key=lambda x: x[0]) - last_src = None - for exec_cycle, src in self.pending_exec: - self.cycle = exec_cycle - self.emit_aluexec(src) - last_src = src - self.pending_exec.clear() - - # WAVEEND timing depends on what comes last - # If last instruction was SOPP: +1 normally, but +11 if last ALUEXEC was recent (VALU drain) - # Otherwise: 14 cycles for no ALU, 20 for trans, 14/15 for SALU/VALU - if self.last_inst_type == 'SOPP': - # Check if last ALUEXEC was recent (within ~5 cycles of current) - if self.valu_count > 0 and self.last_valu_exec >= self.cycle - 5: - self.cycle += 11 # VALU drain time - else: - self.cycle += 1 - elif last_src is None: - self.cycle += 14 # empty program or no ALU ops - elif self.trans_count > 0: - self.cycle += 20 # trans has longer WAVEEND delay - else: - self.cycle += 15 if last_src == AluSrc.VALU else 14 - self.emit_waveend() - # VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup) _VOPD_TO_VOP = { VOPDOp.V_DUAL_FMAC_F32: VOP3Op.V_FMAC_F32, VOPDOp.V_DUAL_FMAAK_F32: VOP2Op.V_FMAAK_F32, VOPDOp.V_DUAL_FMAMK_F32: VOP2Op.V_FMAMK_F32, @@ -547,46 +178,81 @@ def exec_scalar(st: WaveState, inst: Inst) -> int: s0 = st.rsrc64(ssrc0, 0) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0)) s1 = st.rsrc64(inst.ssrc1, 0) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0) d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0) + exec_mask = st.exec_mask literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else st.literal - # Create Reg objects for compiled function - mask VCC/EXEC to 32 bits for wave32 - result = fn(Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc & MASK32), 0, Reg(st.exec_mask & MASK32), literal, None, PC=Reg(st.pc * 4)) + # Execute compiled function - pass PC in bytes for instructions that need it + # For wave32, mask VCC and EXEC to 32 bits since only the lower 32 bits are relevant + pc_bytes = st.pc * 4 + vcc32, exec32 = st.vcc & MASK32, exec_mask & MASK32 + result = fn(s0, s1, 0, d0, st.scc, vcc32, 0, exec32, literal, None, {}, pc=pc_bytes) - # Apply results - extract values from returned Reg objects - if sdst is not None and 'D0' in result: - (st.wsgpr64 if inst.dst_regs() == 2 else st.wsgpr)(sdst, result['D0']._val) - if 'SCC' in result: st.scc = result['SCC']._val & 1 - if 'EXEC' in result: st.exec_mask = result['EXEC']._val - if 'PC' in result: + # Apply results + if sdst is not None: + (st.wsgpr64 if result.get('d0_64') else st.wsgpr)(sdst, result['d0']) + if 'scc' in result: st.scc = result['scc'] + if 'exec' in result: st.exec_mask = result['exec'] + if 'new_pc' in result: # Convert absolute byte address to word delta - pc_val = result['PC']._val - new_pc = pc_val if pc_val < 0x8000000000000000 else pc_val - 0x10000000000000000 - new_pc_words = new_pc // 4 + # new_pc is where we want to go, st.pc is current position, inst._words will be added after + new_pc_words = result['new_pc'] // 4 return new_pc_words - st.pc - 1 # -1 because emulator adds inst_words (1 for scalar) return 0 -def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None) -> None: +def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None: """Execute vector instruction for one lane.""" compiled = _get_compiled() V = st.vgpr[lane] - # Memory ops (FLAT/GLOBAL/SCRATCH and DS) - use generated pcode - if isinstance(inst, (FLAT, DS)): - op, vdst, op_name = inst.op, inst.vdst, inst.op.name - fn, ndwords = compiled[type(op)][op], _op_ndwords(op_name) - if isinstance(inst, FLAT): - addr = V[inst.addr] | (V[inst.addr + 1] << 32) - ADDR = (st.rsgpr64(inst.saddr) + V[inst.addr] + _sext(inst.offset, 13)) & MASK64 if inst.saddr not in (NULL, 0x7f) else (addr + _sext(inst.offset, 13)) & MASK64 - # For loads, VDATA comes from vdst (preserves unwritten bits); for stores, from inst.data - vdata_src = vdst if 'LOAD' in op_name else inst.data - result = fn(GlobalMem, ADDR, _vgpr_read(V, vdata_src, ndwords), Reg(V[vdst]), Reg(0)) - if 'VDATA' in result: _vgpr_write(V, vdst, result['VDATA']._val, ndwords) - if 'RETURN_DATA' in result: _vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords) - else: # DS - DATA0, DATA1 = _vgpr_read(V, inst.data0, ndwords), _vgpr_read(V, inst.data1, ndwords) if inst.data1 is not None else Reg(0) - result = fn(lds, Reg(V[inst.addr]), DATA0, DATA1, Reg(inst.offset0), Reg(inst.offset1), Reg(0)) - if 'RETURN_DATA' in result and ('_RTN' in op_name or '_LOAD' in op_name): - _vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords * 2 if '_2ADDR_' in op_name else ndwords) + # Memory ops (not ALU pseudocode) + if isinstance(inst, FLAT): + op, addr_reg, data_reg, vdst, offset, saddr = inst.op, inst.addr, inst.data, inst.vdst, _sext(inst.offset, 13), inst.saddr + addr = V[addr_reg] | (V[addr_reg+1] << 32) + addr = (st.rsgpr64(saddr) + V[addr_reg] + offset) & MASK64 if saddr not in (NULL, 0x7f) else (addr + offset) & MASK64 + if op in FLAT_LOAD: + cnt, sz, sign = FLAT_LOAD[op] + for i in range(cnt): val = mem_read(addr + i * sz, sz); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val + elif op in FLAT_STORE: + cnt, sz = FLAT_STORE[op] + for i in range(cnt): mem_write(addr + i * sz, sz, V[data_reg + i] & ((1 << (sz * 8)) - 1)) + elif op in FLAT_D16_LOAD: + sz, sign, hi = FLAT_D16_LOAD[op] + val = mem_read(addr, sz) + if sign: val = _sext(val, sz * 8) & 0xffff + V[vdst] = _dst16(V[vdst], val, hi) + elif op in FLAT_D16_STORE: + sz, hi = FLAT_D16_STORE[op] + mem_write(addr, sz, _src16(V[data_reg], hi) & ((1 << (sz * 8)) - 1)) + else: raise NotImplementedError(f"FLAT op {op}") + return + + if isinstance(inst, DS): + op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst + if op in DS_LOAD: + cnt, sz, sign = DS_LOAD[op] + for i in range(cnt): val = int.from_bytes(lds[addr0+i*sz:addr0+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val + elif op in DS_STORE: + cnt, sz = DS_STORE[op] + for i in range(cnt): lds[addr0+i*sz:addr0+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little') + elif op in DS_LOAD_2ADDR: + # Load two values from addr+offset0*sz and addr+offset1*sz into vdst (B32: 1 dword each, B64: 2 dwords each) + # Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA + sz = DS_LOAD_2ADDR[op] + addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff + addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff + cnt = sz // 4 # 1 for B32, 2 for B64 + for i in range(cnt): V[vdst + i] = int.from_bytes(lds[addr0+i*4:addr0+i*4+4], 'little') + for i in range(cnt): V[vdst + cnt + i] = int.from_bytes(lds[addr1+i*4:addr1+i*4+4], 'little') + elif op in DS_STORE_2ADDR: + # Store two values from data0 and data1 to addr+offset0*sz and addr+offset1*sz + # Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA + sz = DS_STORE_2ADDR[op] + addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff + addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff + cnt = sz // 4 + for i in range(cnt): lds[addr0+i*4:addr0+i*4+4] = (V[inst.data0 + i] & MASK32).to_bytes(4, 'little') + for i in range(cnt): lds[addr1+i*4:addr1+i*4+4] = (V[inst.data1 + i] & MASK32).to_bytes(4, 'little') + else: raise NotImplementedError(f"DS op {op}") return # VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes) @@ -594,25 +260,24 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None) vdsty = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1) inputs = [(inst.opx, st.rsrc(inst.srcx0, lane), V[inst.vsrcx1], V[inst.vdstx], inst.vdstx), (inst.opy, st.rsrc(inst.srcy0, lane), V[inst.vsrcy1], V[vdsty], vdsty)] - def exec_vopd(vopd_op, s0, s1, d0): - op = _VOPD_TO_VOP[vopd_op] - return compiled[type(op)][op](Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)['D0']._val - for vopd_op, s0, s1, d0, dst in inputs: V[dst] = exec_vopd(vopd_op, s0, s1, d0) + results = [(dst, fn(s0, s1, 0, d0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})['d0']) + for vopd_op, s0, s1, d0, dst in inputs if (op := _VOPD_TO_VOP.get(vopd_op)) and (fn := compiled.get(type(op), {}).get(op))] + for dst, val in results: V[dst] = val return # VOP3SD: has extra scalar dest for carry output if isinstance(inst, VOP3SD): - fn = compiled[VOP3SDOp][inst.op] + fn = compiled.get(VOP3SDOp, {}).get(inst.op) + if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode") # Read sources based on register counts from inst properties def rsrc_n(src, regs): return st.rsrc64(src, lane) if regs == 2 else st.rsrc(src, lane) s0, s1, s2 = rsrc_n(inst.src0, inst.src_regs(0)), rsrc_n(inst.src1, inst.src_regs(1)), rsrc_n(inst.src2, inst.src_regs(2)) # Carry-in ops use src2 as carry bitmask instead of VCC vcc = st.rsgpr64(inst.src2) if 'CO_CI' in inst.op_name else st.vcc - result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(V[inst.vdst]), Reg(st.scc), Reg(vcc), lane, Reg(st.exec_mask), st.literal, None) - d0_val = result['D0']._val - V[inst.vdst] = d0_val & MASK32 - if inst.dst_regs() == 2: V[inst.vdst + 1] = (d0_val >> 32) & MASK32 - if 'VCC' in result: st.pend_sgpr_lane(inst.sdst, lane, (result['VCC']._val >> lane) & 1) + result = fn(s0, s1, s2, V[inst.vdst], st.scc, vcc, lane, st.exec_mask, st.literal, None, {}) + V[inst.vdst] = result['d0'] & MASK32 + if result.get('d0_64'): V[inst.vdst + 1] = (result['d0'] >> 32) & MASK32 + if result.get('vcc_lane') is not None: st.pend_sgpr_lane(inst.sdst, lane, result['vcc_lane']) return # Get op enum and sources (None means "no source" for that operand) @@ -652,7 +317,8 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None) if abs_ & (1<= 256 else (src0 if src0 is not None else 0) - result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(d0), Reg(st.scc), Reg(vcc_for_fn), lane, Reg(st.exec_mask), st.literal, st.vgpr, src0_idx, vdst) + result = fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, st.literal, st.vgpr, {}, src0_idx, vdst) - # Apply results - extract values from returned Reg objects + # Apply results if 'vgpr_write' in result: # Lane instruction wrote to VGPR: (lane, vgpr_idx, value) wr_lane, wr_idx, wr_val = result['vgpr_write'] st.vgpr[wr_lane][wr_idx] = wr_val - if 'VCC' in result: + if 'vcc_lane' in result: # VOP2 carry ops write to VCC implicitly; VOPC/VOP3 write to vdst - st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC']._val >> lane) & 1) - if 'EXEC' in result: - # V_CMPX instructions write to EXEC per-lane (not to vdst) - st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC']._val >> lane) & 1) - elif op_cls is VOPCOp: - # VOPC comparison result stored in D0 bitmask, extract lane bit (non-CMPX only) - st.pend_sgpr_lane(vdst, lane, (result['D0']._val >> lane) & 1) - if op_cls is not VOPCOp and 'vgpr_write' not in result: + st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, result['vcc_lane']) + if 'exec_lane' in result: + # V_CMPX instructions write to EXEC per-lane + st.pend_sgpr_lane(EXEC_LO, lane, result['exec_lane']) + if 'd0' in result and op_cls is not VOPCOp and 'vgpr_write' not in result: writes_to_sgpr = 'READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name - d0_val = result['D0']._val + d0_val = result['d0'] if writes_to_sgpr: st.wsgpr(vdst, d0_val & MASK32) - elif inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32 + elif result.get('d0_64'): V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32 elif inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi) else: V[vdst] = d0_val & MASK32 @@ -759,7 +424,7 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None: # MAIN EXECUTION LOOP # ═══════════════════════════════════════════════════════════════════════════════ -def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int: +def step_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int: inst = program.get(st.pc) if inst is None: return 1 inst_words, st.literal = inst._words, getattr(inst, '_literal', None) or 0 @@ -779,7 +444,7 @@ def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int st.pc += inst_words return 0 -def exec_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int: +def exec_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int: while st.pc in program: result = step_wave(program, st, lds, n_lanes) if result == -1: return 0 @@ -789,7 +454,7 @@ def exec_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int def exec_workgroup(program: Program, workgroup_id: tuple[int, int, int], local_size: tuple[int, int, int], args_ptr: int, wg_id_sgpr_base: int, wg_id_enables: tuple[bool, bool, bool]) -> None: lx, ly, lz = local_size - total_threads, lds = lx * ly * lz, LDSMem(bytearray(65536)) + total_threads, lds = lx * ly * lz, bytearray(65536) waves: list[tuple[WaveState, int, int]] = [] for wave_start in range(0, total_threads, WAVE_SIZE): n_lanes, st = min(WAVE_SIZE, total_threads - wave_start), WaveState() diff --git a/extra/assembly/amd/test/test_sqtt_compare.py b/extra/assembly/amd/test/test_sqtt_compare.py index ca11688d8d..631ca3782e 100644 --- a/extra/assembly/amd/test/test_sqtt_compare.py +++ b/extra/assembly/amd/test/test_sqtt_compare.py @@ -8,7 +8,7 @@ import os os.environ["SQTT"] = "1" os.environ["PROFILE"] = "1" os.environ["SQTT_LIMIT_SE"] = "2" -os.environ["SQTT_TOKEN_EXCLUDE"] = "3786" # exclude WAVERDY, REG, EVENT, UTILCTR, WAVEALLOC, PERF, ALUEXEC +os.environ["SQTT_TOKEN_EXCLUDE"] = "3784" # Exclude WAVERDY, REG, EVENT, UTILCTR, WAVEALLOC, PERF import unittest from tinygrad.device import Device @@ -242,7 +242,7 @@ class TestEmulatorSQTT(unittest.TestCase): def test_valu_independent_8(self): self._test_valu_independent_n(8) def test_valu_independent_16(self): self._test_valu_independent_n(16) - def test_trans_independent_16(self): self._test_valu_independent_n(16, True) + def test_trans_independent_16(self): self._test_valu_independent_n(16, trans=True) def test_valu_chain(self): """VALU instructions with chain dependencies."""