This commit is contained in:
George Hotz
2026-01-02 16:45:34 -08:00
parent 92cb8b6776
commit 6ea3586101

View File

@@ -423,65 +423,17 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None:
# SQTT TRACING
# ═══════════════════════════════════════════════════════════════════════════════
WAVESTART_TO_INST_CYCLES = 32 # cycles from WAVESTART to first instruction
# Issue intervals (fixed, independent of lane count)
VALU_ISSUE_CYCLES = 1
TRANS_ISSUE_CYCLES = 4
DP_ISSUE_CYCLES = 32
SALU_ISSUE_CYCLES = 1
# ALU latencies (cycles from dispatch to result ready / ALUEXEC)
WAVESTART_TO_INST_CYCLES = 32
VALU_LATENCY = 6
SALU_LATENCY = 2
TRANS_LATENCY = 9
DP_LATENCY = 38
# Pipeline delay from last ALU dispatch to first s_nop IMMEDIATE
SNOP_PIPELINE_DELAY = 3
# s_nop(N) delays the next instruction's issue by:
# issue_delay = N + 1 + SNOP_ISSUE_OVERHEAD + bypass_penalty + extra_stall
#
# where:
# - SNOP_ISSUE_OVERHEAD = 3 (pipeline overhead)
# - bypass_penalty = 4 if N >= 4 and pending ALUEXEC (register cache bypass timeout)
# - extra_stall = 4 if N in 11-22 and pending ALUEXEC (additional pipeline hazard)
#
# For s_nop IMMEDIATE packet timing (without pending ALUEXEC):
# - N in 7-18 has +4 extra delay
SNOP_ISSUE_OVERHEAD = 3
SNOP_EXTRA_DELAY_MIN = 7
SNOP_EXTRA_DELAY_MAX = 18
SNOP_EXTRA_DELAY_MIN_PENDING = 11
SNOP_EXTRA_DELAY_MAX_PENDING = 22
SNOP_EXTRA_DELAY_MIN, SNOP_EXTRA_DELAY_MAX = 7, 18 # extra +4 delay range (no pending)
SNOP_EXTRA_DELAY_MIN_PENDING, SNOP_EXTRA_DELAY_MAX_PENDING = 11, 22 # extra +4 delay range (pending)
SNOP_EXTRA_DELAY_CYCLES = 4
# Forwarding latencies (cycles until result available for dependent instruction)
VALU_FORWARD_LATENCY = 5 # result available 5 cycles after dispatch (writeback at 6)
TRANS_FORWARD_LATENCY = 13 # result available 13 cycles after dispatch
SALU_FORWARD_LATENCY = 1 # result available 1 cycle after dispatch (writeback at 2)
# Forwarding depth limit: after this many dependent ops in a chain, latency increases
FORWARD_DEPTH_LIMIT = 4
FORWARD_DEEP_LATENCY = 9 # latency for deep dependency chains (beyond depth limit)
# Register cache bypass timeout: s_nop(N) with N >= 4 causes VALU results to go through
# the full register file instead of bypass path, adding +4 cycles to ALUEXEC
REGCACHE_BYPASS_TIMEOUT = 4
FORWARD_DEPTH_LIMIT = 4 # chain depth where forwarding exhaustion starts
FORWARD_DEEP_LATENCY = 9 # latency for exhausted forwarding
REGCACHE_BYPASS_TIMEOUT = 4 # s_nop(N>=4) triggers bypass penalty
REGCACHE_BYPASS_PENALTY = 4
# Transcendental ops (use TRANS unit)
_TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64',
'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'}
# Double precision ops (use DP unit)
_DP_OPS = {'V_ADD_F64', 'V_MUL_F64', 'V_FMA_F64', 'V_DIV_F64', 'V_MIN_F64', 'V_MAX_F64',
'V_LDEXP_F64', 'V_FREXP_MANT_F64', 'V_FREXP_EXP_I32_F64', 'V_FRACT_F64',
'V_TRUNC_F64', 'V_CEIL_F64', 'V_RNDNE_F64', 'V_FLOOR_F64', 'V_DIV_SCALE_F64',
'V_DIV_FMAS_F64', 'V_DIV_FIXUP_F64', 'V_CVT_F64_I32', 'V_CVT_F64_U32',
'V_CVT_I32_F64', 'V_CVT_U32_F64', 'V_CVT_F32_F64', 'V_CVT_F64_F32'}
from extra.assembly.amd.sqtt import WAVESTART, WAVEEND, IMMEDIATE, VALUINST, ALUEXEC, AluSrc
class SQTTState:
@@ -542,24 +494,12 @@ class SQTTState:
if dst_reg is not None:
self.vgpr[dst_reg] = (completion_cycle, self.vgpr.get(dst_reg, (0, 0))[1])
def _get_valu_src_regs(self, inst: Inst) -> list[int]:
def _get_src_vgprs(self, inst: Inst) -> list[int]:
"""Extract source VGPR indices from a VALU instruction."""
src_vgprs = []
if isinstance(inst, VOP1):
if inst.src0 >= 256: src_vgprs.append(inst.src0 - 256)
elif isinstance(inst, VOP2):
if inst.src0 >= 256: src_vgprs.append(inst.src0 - 256)
src_vgprs.append(inst.vsrc1) # vsrc1 is always a VGPR index
elif isinstance(inst, VOP3):
for src in [inst.src0, inst.src1, getattr(inst, 'src2', None)]:
if src is not None and src >= 256: src_vgprs.append(src - 256)
return src_vgprs
def _get_valu_dst_reg(self, inst: Inst) -> int | None:
"""Extract destination VGPR index from a VALU instruction."""
if isinstance(inst, (VOP1, VOP2, VOP3)):
return inst.vdst
return None
if isinstance(inst, VOP1): return [inst.src0 - 256] if inst.src0 >= 256 else []
if isinstance(inst, VOP2): return ([inst.src0 - 256] if inst.src0 >= 256 else []) + [inst.vsrc1]
if isinstance(inst, VOP3): return [s - 256 for s in [inst.src0, inst.src1, getattr(inst, 'src2', None)] if s is not None and s >= 256]
return []
def process_instruction(self, inst: Inst):
if inst.op == SOPPOp.S_NOP: self._process_snop(inst.simm16)
@@ -589,10 +529,10 @@ class SQTTState:
def _process_valu(self, inst: Inst):
"""Process VALU instruction - emit VALUINST and schedule ALUEXEC."""
dispatch, dst = self.cycle, self._get_valu_dst_reg(inst)
dispatch, dst = self.cycle, inst.vdst
# Find critical dependency: source VGPR with latest ready time
deps = [(r, self.vgpr[r]) for r in self._get_valu_src_regs(inst) if r in self.vgpr]
deps = [(r, self.vgpr[r]) for r in self._get_src_vgprs(inst) if r in self.vgpr]
src_vgpr, (source_ready, src_depth) = max(deps, key=lambda x: x[1][0]) if deps else (None, (0, 0))
depth = src_depth + 1 if deps else 0