mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
short
This commit is contained in:
@@ -423,65 +423,17 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None:
|
||||
# SQTT TRACING
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
WAVESTART_TO_INST_CYCLES = 32 # cycles from WAVESTART to first instruction
|
||||
|
||||
# Issue intervals (fixed, independent of lane count)
|
||||
VALU_ISSUE_CYCLES = 1
|
||||
TRANS_ISSUE_CYCLES = 4
|
||||
DP_ISSUE_CYCLES = 32
|
||||
SALU_ISSUE_CYCLES = 1
|
||||
|
||||
# ALU latencies (cycles from dispatch to result ready / ALUEXEC)
|
||||
WAVESTART_TO_INST_CYCLES = 32
|
||||
VALU_LATENCY = 6
|
||||
SALU_LATENCY = 2
|
||||
TRANS_LATENCY = 9
|
||||
DP_LATENCY = 38
|
||||
|
||||
# Pipeline delay from last ALU dispatch to first s_nop IMMEDIATE
|
||||
SNOP_PIPELINE_DELAY = 3
|
||||
|
||||
# s_nop(N) delays the next instruction's issue by:
|
||||
# issue_delay = N + 1 + SNOP_ISSUE_OVERHEAD + bypass_penalty + extra_stall
|
||||
#
|
||||
# where:
|
||||
# - SNOP_ISSUE_OVERHEAD = 3 (pipeline overhead)
|
||||
# - bypass_penalty = 4 if N >= 4 and pending ALUEXEC (register cache bypass timeout)
|
||||
# - extra_stall = 4 if N in 11-22 and pending ALUEXEC (additional pipeline hazard)
|
||||
#
|
||||
# For s_nop IMMEDIATE packet timing (without pending ALUEXEC):
|
||||
# - N in 7-18 has +4 extra delay
|
||||
SNOP_ISSUE_OVERHEAD = 3
|
||||
SNOP_EXTRA_DELAY_MIN = 7
|
||||
SNOP_EXTRA_DELAY_MAX = 18
|
||||
SNOP_EXTRA_DELAY_MIN_PENDING = 11
|
||||
SNOP_EXTRA_DELAY_MAX_PENDING = 22
|
||||
SNOP_EXTRA_DELAY_MIN, SNOP_EXTRA_DELAY_MAX = 7, 18 # extra +4 delay range (no pending)
|
||||
SNOP_EXTRA_DELAY_MIN_PENDING, SNOP_EXTRA_DELAY_MAX_PENDING = 11, 22 # extra +4 delay range (pending)
|
||||
SNOP_EXTRA_DELAY_CYCLES = 4
|
||||
|
||||
# Forwarding latencies (cycles until result available for dependent instruction)
|
||||
VALU_FORWARD_LATENCY = 5 # result available 5 cycles after dispatch (writeback at 6)
|
||||
TRANS_FORWARD_LATENCY = 13 # result available 13 cycles after dispatch
|
||||
SALU_FORWARD_LATENCY = 1 # result available 1 cycle after dispatch (writeback at 2)
|
||||
|
||||
# Forwarding depth limit: after this many dependent ops in a chain, latency increases
|
||||
FORWARD_DEPTH_LIMIT = 4
|
||||
FORWARD_DEEP_LATENCY = 9 # latency for deep dependency chains (beyond depth limit)
|
||||
|
||||
# Register cache bypass timeout: s_nop(N) with N >= 4 causes VALU results to go through
|
||||
# the full register file instead of bypass path, adding +4 cycles to ALUEXEC
|
||||
REGCACHE_BYPASS_TIMEOUT = 4
|
||||
FORWARD_DEPTH_LIMIT = 4 # chain depth where forwarding exhaustion starts
|
||||
FORWARD_DEEP_LATENCY = 9 # latency for exhausted forwarding
|
||||
REGCACHE_BYPASS_TIMEOUT = 4 # s_nop(N>=4) triggers bypass penalty
|
||||
REGCACHE_BYPASS_PENALTY = 4
|
||||
|
||||
# Transcendental ops (use TRANS unit)
|
||||
_TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64',
|
||||
'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'}
|
||||
|
||||
# Double precision ops (use DP unit)
|
||||
_DP_OPS = {'V_ADD_F64', 'V_MUL_F64', 'V_FMA_F64', 'V_DIV_F64', 'V_MIN_F64', 'V_MAX_F64',
|
||||
'V_LDEXP_F64', 'V_FREXP_MANT_F64', 'V_FREXP_EXP_I32_F64', 'V_FRACT_F64',
|
||||
'V_TRUNC_F64', 'V_CEIL_F64', 'V_RNDNE_F64', 'V_FLOOR_F64', 'V_DIV_SCALE_F64',
|
||||
'V_DIV_FMAS_F64', 'V_DIV_FIXUP_F64', 'V_CVT_F64_I32', 'V_CVT_F64_U32',
|
||||
'V_CVT_I32_F64', 'V_CVT_U32_F64', 'V_CVT_F32_F64', 'V_CVT_F64_F32'}
|
||||
|
||||
from extra.assembly.amd.sqtt import WAVESTART, WAVEEND, IMMEDIATE, VALUINST, ALUEXEC, AluSrc
|
||||
|
||||
class SQTTState:
|
||||
@@ -542,24 +494,12 @@ class SQTTState:
|
||||
if dst_reg is not None:
|
||||
self.vgpr[dst_reg] = (completion_cycle, self.vgpr.get(dst_reg, (0, 0))[1])
|
||||
|
||||
def _get_valu_src_regs(self, inst: Inst) -> list[int]:
|
||||
def _get_src_vgprs(self, inst: Inst) -> list[int]:
|
||||
"""Extract source VGPR indices from a VALU instruction."""
|
||||
src_vgprs = []
|
||||
if isinstance(inst, VOP1):
|
||||
if inst.src0 >= 256: src_vgprs.append(inst.src0 - 256)
|
||||
elif isinstance(inst, VOP2):
|
||||
if inst.src0 >= 256: src_vgprs.append(inst.src0 - 256)
|
||||
src_vgprs.append(inst.vsrc1) # vsrc1 is always a VGPR index
|
||||
elif isinstance(inst, VOP3):
|
||||
for src in [inst.src0, inst.src1, getattr(inst, 'src2', None)]:
|
||||
if src is not None and src >= 256: src_vgprs.append(src - 256)
|
||||
return src_vgprs
|
||||
|
||||
def _get_valu_dst_reg(self, inst: Inst) -> int | None:
|
||||
"""Extract destination VGPR index from a VALU instruction."""
|
||||
if isinstance(inst, (VOP1, VOP2, VOP3)):
|
||||
return inst.vdst
|
||||
return None
|
||||
if isinstance(inst, VOP1): return [inst.src0 - 256] if inst.src0 >= 256 else []
|
||||
if isinstance(inst, VOP2): return ([inst.src0 - 256] if inst.src0 >= 256 else []) + [inst.vsrc1]
|
||||
if isinstance(inst, VOP3): return [s - 256 for s in [inst.src0, inst.src1, getattr(inst, 'src2', None)] if s is not None and s >= 256]
|
||||
return []
|
||||
|
||||
def process_instruction(self, inst: Inst):
|
||||
if inst.op == SOPPOp.S_NOP: self._process_snop(inst.simm16)
|
||||
@@ -589,10 +529,10 @@ class SQTTState:
|
||||
|
||||
def _process_valu(self, inst: Inst):
|
||||
"""Process VALU instruction - emit VALUINST and schedule ALUEXEC."""
|
||||
dispatch, dst = self.cycle, self._get_valu_dst_reg(inst)
|
||||
dispatch, dst = self.cycle, inst.vdst
|
||||
|
||||
# Find critical dependency: source VGPR with latest ready time
|
||||
deps = [(r, self.vgpr[r]) for r in self._get_valu_src_regs(inst) if r in self.vgpr]
|
||||
deps = [(r, self.vgpr[r]) for r in self._get_src_vgprs(inst) if r in self.vgpr]
|
||||
src_vgpr, (source_ready, src_depth) = max(deps, key=lambda x: x[1][0]) if deps else (None, (0, 0))
|
||||
depth = src_depth + 1 if deps else 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user