From 672008ccabc76ea21597ed0b7516e1ba00020c68 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Fri, 2 Jan 2026 11:31:41 -0800
Subject: [PATCH] framework

---
 extra/assembly/amd/emu.py | 69 ++++++++++++---------------------------
 1 file changed, 20 insertions(+), 49 deletions(-)

diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py
index b346d1baf1..954cb3b946 100644
--- a/extra/assembly/amd/emu.py
+++ b/extra/assembly/amd/emu.py
@@ -431,12 +431,20 @@ TRANS_ISSUE_CYCLES = 4
 DP_ISSUE_CYCLES = 32
 SALU_ISSUE_CYCLES = 1
 
-# ALU latencies (cycles from dispatch to result ready)
+# ALU latencies (cycles from dispatch to result ready / ALUEXEC)
 VALU_LATENCY = 6
 SALU_LATENCY = 2
 TRANS_LATENCY = 9
 DP_LATENCY = 38
 
+# Pipeline delay from last ALU dispatch to first s_nop IMMEDIATE
+SNOP_PIPELINE_DELAY = 3
+
+# Forwarding latencies (cycles until result available for dependent instruction)
+VALU_FORWARD_LATENCY = 5  # result available 5 cycles after dispatch (writeback at 6)
+TRANS_FORWARD_LATENCY = 13  # result available 13 cycles after dispatch
+SALU_FORWARD_LATENCY = 1  # result available 1 cycle after dispatch (writeback at 2)
+
 # Transcendental ops (use TRANS unit)
 _TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64',
               'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'}
@@ -450,18 +458,11 @@ _DP_OPS = {'V_ADD_F64', 'V_MUL_F64', 'V_FMA_F64', 'V_DIV_F64', 'V_MIN_F64', 'V_M
 
 class SQTTState:
   """SQTT tracing state - emits packets when instructions dispatch."""
-  __slots__ = ('packets', 'wave_id', 'simd', 'cu', 'cycle', 'vgpr_ready', 'sgpr_ready',
-               'pending_completions', 'trans_busy_until', 'dp_busy_until')
 
   def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0):
     self.packets = []
     self.wave_id, self.simd, self.cu = wave_id, simd, cu
     self.cycle = 0
-    self.vgpr_ready = {}  # vgpr -> cycle when result ready
-    self.sgpr_ready = {}  # sgpr -> cycle when result ready
-    self.pending_completions = []  # list of (cycle, AluSrc)
-    self.trans_busy_until = 0  # when TRANS unit is free
-    self.dp_busy_until = 0  # when DP unit is free
 
   def emit(self, pkt_class, **kwargs):
     self.packets.append(pkt_class(_time=self.cycle, **kwargs))
@@ -469,53 +470,24 @@ class SQTTState:
   def emit_wavestart(self):
     from extra.assembly.amd.sqtt import WAVESTART
     self.emit(WAVESTART, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
-    self.cycle = WAVESTART_TO_INST_CYCLES
+    for _ in range(WAVESTART_TO_INST_CYCLES):
+      self.tick()
 
   def emit_waveend(self):
     from extra.assembly.amd.sqtt import WAVEEND
     self.emit(WAVEEND, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
 
-  def emit_dispatch_inst(self, inst: Inst):
-    """Emit SQTT packet for instruction dispatch with proper timing."""
-    from extra.assembly.amd.sqtt import INST, VALUINST, IMMEDIATE, InstOp
+  def tick(self):
+    """Process one cycle: emit any completing ALUEXECs, then advance cycle."""
+    self.cycle += 1
 
-    if isinstance(inst, (SOP1, SOP2, SOPC, SOPK)):
-      # SALU: 1-cycle issue
-      self.emit(INST, wave=self.wave_id, op=InstOp.SALU)
-      self.cycle += SALU_ISSUE_CYCLES
-
-    elif isinstance(inst, SOPP):
-      if inst.op == SOPPOp.S_NOP:
-        # s_nop emits IMMEDIATE and delays
-        self.emit(IMMEDIATE, wave=self.wave_id)
-        self.cycle += inst.simm16 + 1
-      # Other SOPP (s_endpgm, etc.) - no packet
-
-    elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD)):
-      op_name = inst.op_name if hasattr(inst, 'op_name') else ''
-
-      if any(t in op_name for t in _TRANS_OPS):
-        # TRANS: wait for unit, 4-cycle issue interval
-        self.cycle = max(self.cycle, self.trans_busy_until)
-        self.emit(INST, wave=self.wave_id, op=InstOp.VALU_TRANS)
-        self.trans_busy_until = self.cycle + TRANS_ISSUE_CYCLES
-        self.cycle += 1  # next instruction can issue next cycle (different unit)
-
-      elif any(t in op_name for t in _DP_OPS):
-        # DP: wait for unit, 32-cycle issue interval
-        self.cycle = max(self.cycle, self.dp_busy_until)
-        self.emit(INST, wave=self.wave_id, op=InstOp.VALU_64)
-        self.dp_busy_until = self.cycle + DP_ISSUE_CYCLES
-        self.cycle += 1
-
-      else:
-        # Regular VALU: 1-cycle issue
-        self.emit(VALUINST, wave=self.wave_id)
-        self.cycle += VALU_ISSUE_CYCLES
+  def process_instruction(self, inst: Inst):
+    """Simulate cycles until instruction dispatches, emitting SQTT packets."""
+    pass
 
   def finalize(self):
     """Emit pending ALUEXECs and WAVEEND."""
-    # TODO: emit any pending ALUEXECs before WAVEEND
+    # Emit any remaining ALUEXECs
     self.emit_waveend()
 
 # ═══════════════════════════════════════════════════════════════════════════════
@@ -529,10 +501,9 @@ def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int, trace:
 
   # TODO: add ALUEXEC emits if anything completed
 
-  # Emit SQTT packet for this instruction dispatch
-  # TODO: see if we ahve to block the dispatch
+  # Emit SQTT packets for this instruction
   if trace is not None:
-    trace.emit_dispatch_inst(inst)
+    trace.process_instruction(inst)
 
   if isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)):
     delta = exec_scalar(st, inst)