mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
framework
This commit is contained in:
@@ -431,12 +431,20 @@ TRANS_ISSUE_CYCLES = 4
|
|||||||
DP_ISSUE_CYCLES = 32
|
DP_ISSUE_CYCLES = 32
|
||||||
SALU_ISSUE_CYCLES = 1
|
SALU_ISSUE_CYCLES = 1
|
||||||
|
|
||||||
# ALU latencies (cycles from dispatch to result ready)
|
# ALU latencies (cycles from dispatch to result ready / ALUEXEC)
|
||||||
VALU_LATENCY = 6
|
VALU_LATENCY = 6
|
||||||
SALU_LATENCY = 2
|
SALU_LATENCY = 2
|
||||||
TRANS_LATENCY = 9
|
TRANS_LATENCY = 9
|
||||||
DP_LATENCY = 38
|
DP_LATENCY = 38
|
||||||
|
|
||||||
|
# Pipeline delay from last ALU dispatch to first s_nop IMMEDIATE
|
||||||
|
SNOP_PIPELINE_DELAY = 3
|
||||||
|
|
||||||
|
# Forwarding latencies (cycles until result available for dependent instruction)
|
||||||
|
VALU_FORWARD_LATENCY = 5 # result available 5 cycles after dispatch (writeback at 6)
|
||||||
|
TRANS_FORWARD_LATENCY = 13 # result available 13 cycles after dispatch
|
||||||
|
SALU_FORWARD_LATENCY = 1 # result available 1 cycle after dispatch (writeback at 2)
|
||||||
|
|
||||||
# Transcendental ops (use TRANS unit)
|
# Transcendental ops (use TRANS unit)
|
||||||
_TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64',
|
_TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64',
|
||||||
'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'}
|
'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'}
|
||||||
@@ -450,18 +458,11 @@ _DP_OPS = {'V_ADD_F64', 'V_MUL_F64', 'V_FMA_F64', 'V_DIV_F64', 'V_MIN_F64', 'V_M
|
|||||||
|
|
||||||
class SQTTState:
|
class SQTTState:
|
||||||
"""SQTT tracing state - emits packets when instructions dispatch."""
|
"""SQTT tracing state - emits packets when instructions dispatch."""
|
||||||
__slots__ = ('packets', 'wave_id', 'simd', 'cu', 'cycle', 'vgpr_ready', 'sgpr_ready',
|
|
||||||
'pending_completions', 'trans_busy_until', 'dp_busy_until')
|
|
||||||
|
|
||||||
def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0):
|
def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0):
|
||||||
self.packets = []
|
self.packets = []
|
||||||
self.wave_id, self.simd, self.cu = wave_id, simd, cu
|
self.wave_id, self.simd, self.cu = wave_id, simd, cu
|
||||||
self.cycle = 0
|
self.cycle = 0
|
||||||
self.vgpr_ready = {} # vgpr -> cycle when result ready
|
|
||||||
self.sgpr_ready = {} # sgpr -> cycle when result ready
|
|
||||||
self.pending_completions = [] # list of (cycle, AluSrc)
|
|
||||||
self.trans_busy_until = 0 # when TRANS unit is free
|
|
||||||
self.dp_busy_until = 0 # when DP unit is free
|
|
||||||
|
|
||||||
def emit(self, pkt_class, **kwargs):
|
def emit(self, pkt_class, **kwargs):
|
||||||
self.packets.append(pkt_class(_time=self.cycle, **kwargs))
|
self.packets.append(pkt_class(_time=self.cycle, **kwargs))
|
||||||
@@ -469,53 +470,24 @@ class SQTTState:
|
|||||||
def emit_wavestart(self):
|
def emit_wavestart(self):
|
||||||
from extra.assembly.amd.sqtt import WAVESTART
|
from extra.assembly.amd.sqtt import WAVESTART
|
||||||
self.emit(WAVESTART, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
|
self.emit(WAVESTART, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
|
||||||
self.cycle = WAVESTART_TO_INST_CYCLES
|
for _ in range(WAVESTART_TO_INST_CYCLES):
|
||||||
|
self.tick()
|
||||||
|
|
||||||
def emit_waveend(self):
|
def emit_waveend(self):
|
||||||
from extra.assembly.amd.sqtt import WAVEEND
|
from extra.assembly.amd.sqtt import WAVEEND
|
||||||
self.emit(WAVEEND, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
|
self.emit(WAVEEND, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
|
||||||
|
|
||||||
def emit_dispatch_inst(self, inst: Inst):
|
def tick(self):
|
||||||
"""Emit SQTT packet for instruction dispatch with proper timing."""
|
"""Process one cycle: emit any completing ALUEXECs, then advance cycle."""
|
||||||
from extra.assembly.amd.sqtt import INST, VALUINST, IMMEDIATE, InstOp
|
self.cycle += 1
|
||||||
|
|
||||||
if isinstance(inst, (SOP1, SOP2, SOPC, SOPK)):
|
def process_instruction(self, inst: Inst):
|
||||||
# SALU: 1-cycle issue
|
"""Simulate cycles until instruction dispatches, emitting SQTT packets."""
|
||||||
self.emit(INST, wave=self.wave_id, op=InstOp.SALU)
|
pass
|
||||||
self.cycle += SALU_ISSUE_CYCLES
|
|
||||||
|
|
||||||
elif isinstance(inst, SOPP):
|
|
||||||
if inst.op == SOPPOp.S_NOP:
|
|
||||||
# s_nop emits IMMEDIATE and delays
|
|
||||||
self.emit(IMMEDIATE, wave=self.wave_id)
|
|
||||||
self.cycle += inst.simm16 + 1
|
|
||||||
# Other SOPP (s_endpgm, etc.) - no packet
|
|
||||||
|
|
||||||
elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD)):
|
|
||||||
op_name = inst.op_name if hasattr(inst, 'op_name') else ''
|
|
||||||
|
|
||||||
if any(t in op_name for t in _TRANS_OPS):
|
|
||||||
# TRANS: wait for unit, 4-cycle issue interval
|
|
||||||
self.cycle = max(self.cycle, self.trans_busy_until)
|
|
||||||
self.emit(INST, wave=self.wave_id, op=InstOp.VALU_TRANS)
|
|
||||||
self.trans_busy_until = self.cycle + TRANS_ISSUE_CYCLES
|
|
||||||
self.cycle += 1 # next instruction can issue next cycle (different unit)
|
|
||||||
|
|
||||||
elif any(t in op_name for t in _DP_OPS):
|
|
||||||
# DP: wait for unit, 32-cycle issue interval
|
|
||||||
self.cycle = max(self.cycle, self.dp_busy_until)
|
|
||||||
self.emit(INST, wave=self.wave_id, op=InstOp.VALU_64)
|
|
||||||
self.dp_busy_until = self.cycle + DP_ISSUE_CYCLES
|
|
||||||
self.cycle += 1
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Regular VALU: 1-cycle issue
|
|
||||||
self.emit(VALUINST, wave=self.wave_id)
|
|
||||||
self.cycle += VALU_ISSUE_CYCLES
|
|
||||||
|
|
||||||
def finalize(self):
|
def finalize(self):
|
||||||
"""Emit pending ALUEXECs and WAVEEND."""
|
"""Emit pending ALUEXECs and WAVEEND."""
|
||||||
# TODO: emit any pending ALUEXECs before WAVEEND
|
# Emit any remaining ALUEXECs
|
||||||
self.emit_waveend()
|
self.emit_waveend()
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════════════
|
||||||
@@ -529,10 +501,9 @@ def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int, trace:
|
|||||||
|
|
||||||
# TODO: add ALUEXEC emits if anything completed
|
# TODO: add ALUEXEC emits if anything completed
|
||||||
|
|
||||||
# Emit SQTT packet for this instruction dispatch
|
# Emit SQTT packets for this instruction
|
||||||
# TODO: see if we ahve to block the dispatch
|
|
||||||
if trace is not None:
|
if trace is not None:
|
||||||
trace.emit_dispatch_inst(inst)
|
trace.process_instruction(inst)
|
||||||
|
|
||||||
if isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)):
|
if isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)):
|
||||||
delta = exec_scalar(st, inst)
|
delta = exec_scalar(st, inst)
|
||||||
|
|||||||
Reference in New Issue
Block a user