framework

This commit is contained in:
George Hotz
2026-01-02 11:31:41 -08:00
parent 849af761a4
commit 672008ccab

View File

@@ -431,12 +431,20 @@ TRANS_ISSUE_CYCLES = 4
DP_ISSUE_CYCLES = 32 DP_ISSUE_CYCLES = 32
SALU_ISSUE_CYCLES = 1 SALU_ISSUE_CYCLES = 1
# ALU latencies (cycles from dispatch to result ready) # ALU latencies (cycles from dispatch to result ready / ALUEXEC)
VALU_LATENCY = 6 VALU_LATENCY = 6
SALU_LATENCY = 2 SALU_LATENCY = 2
TRANS_LATENCY = 9 TRANS_LATENCY = 9
DP_LATENCY = 38 DP_LATENCY = 38
# Pipeline delay from last ALU dispatch to first s_nop IMMEDIATE
SNOP_PIPELINE_DELAY = 3
# Forwarding latencies (cycles until result available for dependent instruction)
VALU_FORWARD_LATENCY = 5 # result available 5 cycles after dispatch (writeback at 6)
TRANS_FORWARD_LATENCY = 13 # result available 13 cycles after dispatch
SALU_FORWARD_LATENCY = 1 # result available 1 cycle after dispatch (writeback at 2)
# Transcendental ops (use TRANS unit) # Transcendental ops (use TRANS unit)
_TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64', _TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64',
'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'} 'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'}
@@ -450,18 +458,11 @@ _DP_OPS = {'V_ADD_F64', 'V_MUL_F64', 'V_FMA_F64', 'V_DIV_F64', 'V_MIN_F64', 'V_M
class SQTTState: class SQTTState:
"""SQTT tracing state - emits packets when instructions dispatch.""" """SQTT tracing state - emits packets when instructions dispatch."""
__slots__ = ('packets', 'wave_id', 'simd', 'cu', 'cycle', 'vgpr_ready', 'sgpr_ready',
'pending_completions', 'trans_busy_until', 'dp_busy_until')
def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0): def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0):
self.packets = [] self.packets = []
self.wave_id, self.simd, self.cu = wave_id, simd, cu self.wave_id, self.simd, self.cu = wave_id, simd, cu
self.cycle = 0 self.cycle = 0
self.vgpr_ready = {} # vgpr -> cycle when result ready
self.sgpr_ready = {} # sgpr -> cycle when result ready
self.pending_completions = [] # list of (cycle, AluSrc)
self.trans_busy_until = 0 # when TRANS unit is free
self.dp_busy_until = 0 # when DP unit is free
def emit(self, pkt_class, **kwargs): def emit(self, pkt_class, **kwargs):
self.packets.append(pkt_class(_time=self.cycle, **kwargs)) self.packets.append(pkt_class(_time=self.cycle, **kwargs))
@@ -469,53 +470,24 @@ class SQTTState:
def emit_wavestart(self): def emit_wavestart(self):
from extra.assembly.amd.sqtt import WAVESTART from extra.assembly.amd.sqtt import WAVESTART
self.emit(WAVESTART, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3) self.emit(WAVESTART, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
self.cycle = WAVESTART_TO_INST_CYCLES for _ in range(WAVESTART_TO_INST_CYCLES):
self.tick()
def emit_waveend(self): def emit_waveend(self):
from extra.assembly.amd.sqtt import WAVEEND from extra.assembly.amd.sqtt import WAVEEND
self.emit(WAVEEND, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3) self.emit(WAVEEND, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
def emit_dispatch_inst(self, inst: Inst): def tick(self):
"""Emit SQTT packet for instruction dispatch with proper timing.""" """Process one cycle: emit any completing ALUEXECs, then advance cycle."""
from extra.assembly.amd.sqtt import INST, VALUINST, IMMEDIATE, InstOp self.cycle += 1
if isinstance(inst, (SOP1, SOP2, SOPC, SOPK)): def process_instruction(self, inst: Inst):
# SALU: 1-cycle issue """Simulate cycles until instruction dispatches, emitting SQTT packets."""
self.emit(INST, wave=self.wave_id, op=InstOp.SALU) pass
self.cycle += SALU_ISSUE_CYCLES
elif isinstance(inst, SOPP):
if inst.op == SOPPOp.S_NOP:
# s_nop emits IMMEDIATE and delays
self.emit(IMMEDIATE, wave=self.wave_id)
self.cycle += inst.simm16 + 1
# Other SOPP (s_endpgm, etc.) - no packet
elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD)):
op_name = inst.op_name if hasattr(inst, 'op_name') else ''
if any(t in op_name for t in _TRANS_OPS):
# TRANS: wait for unit, 4-cycle issue interval
self.cycle = max(self.cycle, self.trans_busy_until)
self.emit(INST, wave=self.wave_id, op=InstOp.VALU_TRANS)
self.trans_busy_until = self.cycle + TRANS_ISSUE_CYCLES
self.cycle += 1 # next instruction can issue next cycle (different unit)
elif any(t in op_name for t in _DP_OPS):
# DP: wait for unit, 32-cycle issue interval
self.cycle = max(self.cycle, self.dp_busy_until)
self.emit(INST, wave=self.wave_id, op=InstOp.VALU_64)
self.dp_busy_until = self.cycle + DP_ISSUE_CYCLES
self.cycle += 1
else:
# Regular VALU: 1-cycle issue
self.emit(VALUINST, wave=self.wave_id)
self.cycle += VALU_ISSUE_CYCLES
def finalize(self): def finalize(self):
"""Emit pending ALUEXECs and WAVEEND.""" """Emit pending ALUEXECs and WAVEEND."""
# TODO: emit any pending ALUEXECs before WAVEEND # Emit any remaining ALUEXECs
self.emit_waveend() self.emit_waveend()
# ═══════════════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════════════
@@ -529,10 +501,9 @@ def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int, trace:
# TODO: add ALUEXEC emits if anything completed # TODO: add ALUEXEC emits if anything completed
# Emit SQTT packet for this instruction dispatch # Emit SQTT packets for this instruction
# TODO: see if we ahve to block the dispatch
if trace is not None: if trace is not None:
trace.emit_dispatch_inst(inst) trace.process_instruction(inst)
if isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)): if isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)):
delta = exec_scalar(st, inst) delta = exec_scalar(st, inst)