mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 14:43:57 -05:00
framework
This commit is contained in:
@@ -431,12 +431,20 @@ TRANS_ISSUE_CYCLES = 4
|
||||
DP_ISSUE_CYCLES = 32
|
||||
SALU_ISSUE_CYCLES = 1
|
||||
|
||||
# ALU latencies (cycles from dispatch to result ready)
|
||||
# ALU latencies (cycles from dispatch to result ready / ALUEXEC)
|
||||
VALU_LATENCY = 6
|
||||
SALU_LATENCY = 2
|
||||
TRANS_LATENCY = 9
|
||||
DP_LATENCY = 38
|
||||
|
||||
# Pipeline delay from last ALU dispatch to first s_nop IMMEDIATE
|
||||
SNOP_PIPELINE_DELAY = 3
|
||||
|
||||
# Forwarding latencies (cycles until result available for dependent instruction)
|
||||
VALU_FORWARD_LATENCY = 5 # result available 5 cycles after dispatch (writeback at 6)
|
||||
TRANS_FORWARD_LATENCY = 13 # result available 13 cycles after dispatch
|
||||
SALU_FORWARD_LATENCY = 1 # result available 1 cycle after dispatch (writeback at 2)
|
||||
|
||||
# Transcendental ops (use TRANS unit)
|
||||
_TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64',
|
||||
'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'}
|
||||
@@ -450,18 +458,11 @@ _DP_OPS = {'V_ADD_F64', 'V_MUL_F64', 'V_FMA_F64', 'V_DIV_F64', 'V_MIN_F64', 'V_M
|
||||
|
||||
class SQTTState:
|
||||
"""SQTT tracing state - emits packets when instructions dispatch."""
|
||||
__slots__ = ('packets', 'wave_id', 'simd', 'cu', 'cycle', 'vgpr_ready', 'sgpr_ready',
|
||||
'pending_completions', 'trans_busy_until', 'dp_busy_until')
|
||||
|
||||
def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0):
|
||||
self.packets = []
|
||||
self.wave_id, self.simd, self.cu = wave_id, simd, cu
|
||||
self.cycle = 0
|
||||
self.vgpr_ready = {} # vgpr -> cycle when result ready
|
||||
self.sgpr_ready = {} # sgpr -> cycle when result ready
|
||||
self.pending_completions = [] # list of (cycle, AluSrc)
|
||||
self.trans_busy_until = 0 # when TRANS unit is free
|
||||
self.dp_busy_until = 0 # when DP unit is free
|
||||
|
||||
def emit(self, pkt_class, **kwargs):
|
||||
self.packets.append(pkt_class(_time=self.cycle, **kwargs))
|
||||
@@ -469,53 +470,24 @@ class SQTTState:
|
||||
def emit_wavestart(self):
|
||||
from extra.assembly.amd.sqtt import WAVESTART
|
||||
self.emit(WAVESTART, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
|
||||
self.cycle = WAVESTART_TO_INST_CYCLES
|
||||
for _ in range(WAVESTART_TO_INST_CYCLES):
|
||||
self.tick()
|
||||
|
||||
def emit_waveend(self):
|
||||
from extra.assembly.amd.sqtt import WAVEEND
|
||||
self.emit(WAVEEND, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
|
||||
|
||||
def emit_dispatch_inst(self, inst: Inst):
|
||||
"""Emit SQTT packet for instruction dispatch with proper timing."""
|
||||
from extra.assembly.amd.sqtt import INST, VALUINST, IMMEDIATE, InstOp
|
||||
def tick(self):
|
||||
"""Process one cycle: emit any completing ALUEXECs, then advance cycle."""
|
||||
self.cycle += 1
|
||||
|
||||
if isinstance(inst, (SOP1, SOP2, SOPC, SOPK)):
|
||||
# SALU: 1-cycle issue
|
||||
self.emit(INST, wave=self.wave_id, op=InstOp.SALU)
|
||||
self.cycle += SALU_ISSUE_CYCLES
|
||||
|
||||
elif isinstance(inst, SOPP):
|
||||
if inst.op == SOPPOp.S_NOP:
|
||||
# s_nop emits IMMEDIATE and delays
|
||||
self.emit(IMMEDIATE, wave=self.wave_id)
|
||||
self.cycle += inst.simm16 + 1
|
||||
# Other SOPP (s_endpgm, etc.) - no packet
|
||||
|
||||
elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD)):
|
||||
op_name = inst.op_name if hasattr(inst, 'op_name') else ''
|
||||
|
||||
if any(t in op_name for t in _TRANS_OPS):
|
||||
# TRANS: wait for unit, 4-cycle issue interval
|
||||
self.cycle = max(self.cycle, self.trans_busy_until)
|
||||
self.emit(INST, wave=self.wave_id, op=InstOp.VALU_TRANS)
|
||||
self.trans_busy_until = self.cycle + TRANS_ISSUE_CYCLES
|
||||
self.cycle += 1 # next instruction can issue next cycle (different unit)
|
||||
|
||||
elif any(t in op_name for t in _DP_OPS):
|
||||
# DP: wait for unit, 32-cycle issue interval
|
||||
self.cycle = max(self.cycle, self.dp_busy_until)
|
||||
self.emit(INST, wave=self.wave_id, op=InstOp.VALU_64)
|
||||
self.dp_busy_until = self.cycle + DP_ISSUE_CYCLES
|
||||
self.cycle += 1
|
||||
|
||||
else:
|
||||
# Regular VALU: 1-cycle issue
|
||||
self.emit(VALUINST, wave=self.wave_id)
|
||||
self.cycle += VALU_ISSUE_CYCLES
|
||||
def process_instruction(self, inst: Inst):
|
||||
"""Simulate cycles until instruction dispatches, emitting SQTT packets."""
|
||||
pass
|
||||
|
||||
def finalize(self):
|
||||
"""Emit pending ALUEXECs and WAVEEND."""
|
||||
# TODO: emit any pending ALUEXECs before WAVEEND
|
||||
# Emit any remaining ALUEXECs
|
||||
self.emit_waveend()
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
@@ -529,10 +501,9 @@ def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int, trace:
|
||||
|
||||
# TODO: add ALUEXEC emits if anything completed
|
||||
|
||||
# Emit SQTT packet for this instruction dispatch
|
||||
# TODO: see if we have to block the dispatch
|
||||
# Emit SQTT packets for this instruction
|
||||
if trace is not None:
|
||||
trace.emit_dispatch_inst(inst)
|
||||
trace.process_instruction(inst)
|
||||
|
||||
if isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)):
|
||||
delta = exec_scalar(st, inst)
|
||||
|
||||
Reference in New Issue
Block a user