This commit is contained in:
George Hotz
2026-01-02 00:38:27 -05:00
parent 29f3fb7af3
commit 21ffa1a86b
3 changed files with 221 additions and 87 deletions

View File

@@ -103,16 +103,20 @@ _TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32',
# Startup: WAVESTART -> first instruction (~32 cycles with warm cache, no REG packet)
# Cold cache: WAVESTART -> REG (~137 cycles) -> first instruction (~150 cycles after REG)
# SALU: issues every cycle, result ready 2 cycles after issue, ALUEXEC at ready time
# VALU: issues every cycle, ALUEXEC at issue+8 for each inst, serialized with +1 intervals
# VALU: issues every cycle, ALUEXEC at issue+6 for each inst, serialized with +1 intervals
# TRANS: issues every 4 cycles, ALUEXEC at last_issue+1, then +8 intervals
# For dependent VALU instructions, ALUEXEC is at source_ready + 6 (first dep) or + 5 (chained); dependent VOPD uses source_ready + 10
# VALU queue depth ~7: after 7 in-flight, ALUEXEC interleaves with VALUINST
WAVESTART_TO_INST_CYCLES = 32 # cycles from WAVESTART to first instruction (warm cache)
SALU_LATENCY = 2 # cycles from issue to result ready
VALU_EXEC_LATENCY = 8 # cycles from issue to ALUEXEC for each instruction
VALU_EXEC_LATENCY = 6 # cycles from first issue to first ALUEXEC
TRANS_ISSUE_CYCLES = 4 # cycles between transcendental instruction issues
TRANS_LATENCY = 9 # cycles from trans issue to ALUEXEC
class SQTTState:
"""SQTT tracing state - emits packets matching real hardware (warm cache model)."""
__slots__ = ('cycle', 'packets', 'pending_exec', 'wave_id', 'simd', 'cu', 'sgpr_ready', 'vgpr_ready',
'last_salu_exec', 'last_valu_exec', 'valu_issue_cycles')
'last_salu_exec', 'last_valu_exec', 'trans_count', 'last_trans_issue', 'last_trans_exec', 'first_valu_issue', 'valu_count', 'last_inst_type')
def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0):
self.cycle = 0
@@ -120,10 +124,15 @@ class SQTTState:
self.pending_exec: list[tuple[int, int]] = [] # (completion_cycle, src) for ALUEXEC
self.wave_id, self.simd, self.cu = wave_id, simd, cu
self.sgpr_ready: dict[int, int] = {} # sgpr -> cycle when result ready
self.last_inst_type: str = '' # track last instruction type for gap calculation
self.vgpr_ready: dict[int, int] = {} # vgpr -> cycle when result ready
self.last_salu_exec = 0 # last SALU ALUEXEC time (for +1 spacing)
self.last_valu_exec = 0 # last VALU ALUEXEC time (for +1 spacing)
self.valu_issue_cycles: list[int] = [] # issue cycles for pending independent VALU
self.trans_count = 0 # number of trans instructions issued
self.last_trans_issue = 0 # cycle when last trans was issued
self.last_trans_exec = 0 # cycle when last trans ALUEXEC is scheduled
self.first_valu_issue = 0 # cycle when first VALU was issued
self.valu_count = 0 # number of VALU instructions issued
def emit_wavestart(self):
from extra.assembly.amd.sqtt import WAVESTART
@@ -146,6 +155,10 @@ class SQTTState:
from extra.assembly.amd.sqtt import ALUEXEC
self.packets.append(ALUEXEC(_time=self.cycle, src=src))
def emit_immediate(self):
  """Append an IMMEDIATE packet for this wave, timestamped with the current cycle."""
  from extra.assembly.amd.sqtt import IMMEDIATE
  packet = IMMEDIATE(_time=self.cycle, wave=self.wave_id)
  self.packets.append(packet)
def _get_src_regs(self, inst: Inst) -> list[tuple[str, int]]:
"""Extract source register references from instruction."""
srcs = []
@@ -218,86 +231,173 @@ class SQTTState:
self.pending_exec.append((exec_cycle, AluSrc.SALU))
self.last_salu_exec = exec_cycle
self._record_dst_ready(inst, ready_cycle)
self.last_inst_type = 'SALU'
elif isinstance(inst, SOPP):
pass # nop, waitcnt, etc don't emit packets
# s_nop and other SOPP emit IMMEDIATE packets
if self.last_inst_type == 'SALU':
# SALU→SOPP: emit ALUEXECs with max gap of 2 once, then IMMEDIATEs
# First ALUEXEC at last INST cycle, then allow one +2 gap, rest interleave
# cycle is currently 1 past last INST
last_inst_cycle = self.cycle - 1
first = True
had_gap = False # track if we've had a +2 gap
while self.pending_exec:
exec_cycle, src = self.pending_exec[0]
if first:
self.pending_exec.pop(0)
self.cycle = last_inst_cycle # First ALUEXEC at same cycle as last INST
first = False
elif exec_cycle == self.cycle + 1: # consecutive (+1)
self.pending_exec.pop(0)
self.cycle = exec_cycle
elif exec_cycle == self.cycle + 2 and not had_gap: # allow one +2 gap
self.pending_exec.pop(0)
self.cycle = exec_cycle
had_gap = True
else:
break # too large a gap or second +2 gap - rest interleave with IMMEDIATEs
self.emit_aluexec(src)
self.cycle += 1 # IMMEDIATE starts +1 after last ALUEXEC
elif self.last_inst_type == 'TRANS':
# TRANS→SOPP: emit first ALUEXEC at its scheduled time, then IMMEDIATE +2 after
# Sort pending execs and emit first one if it's before IMMEDIATE time
self.pending_exec.sort(key=lambda x: x[0])
if self.pending_exec:
first_exec_cycle = self.pending_exec[0][0]
# First IMMEDIATE at first_exec + 2
imm_cycle = first_exec_cycle + 2
# Emit ALUEXECs that come before IMMEDIATE
while self.pending_exec and self.pending_exec[0][0] < imm_cycle:
exec_cycle, src = self.pending_exec.pop(0)
self.cycle = exec_cycle
self.emit_aluexec(src)
self.cycle = imm_cycle
elif self.last_inst_type == 'VALU':
# VALU→SOPP: 2-cycle gap, emit any ALUEXECs that would come before
imm_cycle = self.cycle + 2 # When IMMEDIATE would be emitted
# Emit ALUEXECs that come before the IMMEDIATE
while self.pending_exec and self.pending_exec[0][0] < imm_cycle:
exec_cycle, src = self.pending_exec.pop(0)
self.cycle = max(self.cycle, exec_cycle)
self.emit_aluexec(src)
# Jump to IMMEDIATE cycle (either imm_cycle or 1 after last ALUEXEC, whichever is later)
self.cycle = max(imm_cycle, self.cycle + 1)
# Emit IMMEDIATE first, then any pending ALUEXECs at same cycle (HW order)
self.emit_immediate()
while self.pending_exec and self.pending_exec[0][0] <= self.cycle:
exec_cycle, src = self.pending_exec.pop(0)
old_cycle = self.cycle
self.cycle = exec_cycle
self.emit_aluexec(src)
self.cycle = old_cycle
self.last_inst_type = 'SOPP'
elif isinstance(inst, SMEM):
pass # skip for ALU focus
elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC)):
# VALU: issue now, track for ALUEXEC timing in finalize()
# All ALUEXEC start at last_issue + 8, then serialize based on dependencies
# VALU: issue now, emit ALUEXEC for completed instructions if queue is full
from extra.assembly.amd.sqtt import AluSrc
op_name = inst.op_name if hasattr(inst, 'op_name') else ''
if any(t in op_name for t in _TRANS_OPS):
is_trans = any(t in op_name for t in _TRANS_OPS)
if is_trans:
# Transcendental: emit INST, 4-cycle issue, add ALUEXEC to pending
self.emit_inst(InstOp.VALU_TRANS)
self.trans_count += 1
# Check for dependency on VGPR sources
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
if src_ready > self.cycle:
# Dependent trans: ALUEXEC at source_ready + 6 (for VALU source) or +10 (for trans source)
# Check if source is from a trans instruction (its ready time would be > issue + 6)
src_is_trans = any(self.vgpr_ready.get(r, 0) > self.cycle + 6 for _, r in self._get_src_regs(inst) if _ == 'v')
exec_cycle = src_ready + (10 if src_is_trans else 6)
else:
# Independent trans: ALUEXEC at issue + 9
exec_cycle = self.cycle + TRANS_LATENCY
self.pending_exec.append((exec_cycle, AluSrc.VALU))
self._record_dst_ready(inst, exec_cycle) # Record when this trans result is ready
self.last_trans_exec = exec_cycle
self.last_trans_issue = self.cycle
# Trans instructions take 4 cycles to issue
self.cycle += TRANS_ISSUE_CYCLES - 1 # -1 because we add 1 at end of trace_inst
self.last_inst_type = 'TRANS'
else:
# Regular VALU: emit VALUINST, may interleave ALUEXEC
self.emit_valuinst()
# Check for dependency
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
has_dep = src_ready > self.cycle
# Record issue info: (issue_cycle, has_dep, src_ready)
self.valu_issue_cycles.append((self.cycle, has_dep, src_ready))
# For dependency tracking, estimate result ready
# Note: we don't know last_issue yet, so use a placeholder that will be corrected in finalize
# The key insight is dependent instructions use src_ready + 10
if has_dep:
est_exec = src_ready + 10
else:
# For independent, result will be ready at (final last_issue + 8 + position)
# We approximate with current cycle + 8, which is close enough for dependency detection
est_exec = self.cycle + VALU_EXEC_LATENCY
self._record_dst_ready(inst, est_exec)
# Track first VALU issue for latency calculation
if self.first_valu_issue == 0:
self.first_valu_issue = self.cycle
# Emit any pending ALUEXECs that have completed by now (after VALUINST)
while self.pending_exec and self.pending_exec[0][0] <= self.cycle:
exec_cycle, src = self.pending_exec.pop(0)
old_cycle = self.cycle
self.cycle = exec_cycle
self.emit_aluexec(src)
self.cycle = old_cycle
# Check for dependency
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
has_dep = src_ready > self.cycle
if has_dep:
# Dependent: first dependent is +6 from source, subsequent chained deps are +5
# "Chained" means the source was also dependent (not the first independent VALU)
chained = src_ready > self.first_valu_issue + VALU_EXEC_LATENCY
exec_cycle = src_ready + (5 if chained else 6)
else:
# Independent VALU timing: 6-cycle pipeline latency
# exec[i] = first_issue + 6 + i, with +1 spacing between consecutive execs
exec_cycle = self.first_valu_issue + VALU_EXEC_LATENCY + self.valu_count
exec_cycle = max(exec_cycle, self.last_valu_exec + 1) # +1 spacing
# Add to pending exec queue (sorted by time)
self.pending_exec.append((exec_cycle, AluSrc.VALU))
self.pending_exec.sort(key=lambda x: x[0])
self.last_valu_exec = exec_cycle
self.valu_count += 1
self._record_dst_ready(inst, exec_cycle)
self.last_inst_type = 'VALU'
elif isinstance(inst, VOPD):
from extra.assembly.amd.sqtt import AluSrc
# First emit VALUINST (HW emits issue before completion at same cycle)
self.emit_valuinst()
# Emit any pending ALUEXECs that have completed by now
while self.pending_exec and self.pending_exec[0][0] <= self.cycle:
exec_cycle, src = self.pending_exec.pop(0)
old_cycle = self.cycle
self.cycle = exec_cycle
self.emit_aluexec(src)
self.cycle = old_cycle
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
has_dep = src_ready > self.cycle
self.valu_issue_cycles.append((self.cycle, has_dep, src_ready))
if has_dep:
est_exec = src_ready + 10
exec_cycle = src_ready + 10
else:
est_exec = self.cycle + VALU_EXEC_LATENCY
self._record_dst_ready(inst, est_exec)
exec_cycle = self.cycle + VALU_EXEC_LATENCY
if self.last_valu_exec > 0:
exec_cycle = max(exec_cycle, self.last_valu_exec + 1)
self.pending_exec.append((exec_cycle, AluSrc.VALU))
self.pending_exec.sort(key=lambda x: x[0])
self.last_valu_exec = exec_cycle
self._record_dst_ready(inst, exec_cycle)
self.cycle += 1
def finalize(self):
"""Emit all pending ALUEXEC packets and WAVEEND."""
"""Emit all remaining pending ALUEXEC packets and WAVEEND."""
from extra.assembly.amd.sqtt import AluSrc
# Process VALU instructions
# First pass: compute actual exec times for all instructions
# - Independent: at last_issue + 8 + position, serialized at +1 intervals
# - Dependent: at src_exec + 10 (first dep) or + 9 (chained dep)
if self.valu_issue_cycles:
last_issue = self.valu_issue_cycles[-1][0]
base_exec = last_issue + VALU_EXEC_LATENCY
# Build exec_times list with actual completion times
exec_times = []
last_exec = 0
last_was_dep = False
for i, (issue_cycle, has_dep, src_ready_idx) in enumerate(self.valu_issue_cycles):
if has_dep:
# src_ready_idx is the vgpr_ready value at trace time, which was an estimate
# We need to find the actual exec time of the instruction that wrote this value
# For now, use a simpler model: dependent instructions get +10 from previous exec (first dep) or +9 (chained)
exec_cycle = last_exec + (10 if not last_was_dep else 9)
last_was_dep = True
else:
# Independent: at base_exec + position, serialized at +1 intervals
exec_cycle = max(base_exec + i, last_exec + 1)
last_was_dep = False
exec_times.append(exec_cycle)
last_exec = exec_cycle
for exec_cycle in exec_times:
self.pending_exec.append((exec_cycle, AluSrc.VALU))
self.valu_issue_cycles.clear()
# Sort and emit all pending ALUEXEC
# Emit any remaining pending ALUEXECs
self.pending_exec.sort(key=lambda x: x[0])
last_src = None
for exec_cycle, src in self.pending_exec:
@@ -305,11 +405,22 @@ class SQTTState:
self.emit_aluexec(src)
last_src = src
self.pending_exec.clear()
# WAVEEND timing: 14 cycles after last instruction if no ALU, 13/15 after last ALUEXEC
if last_src is None:
# WAVEEND timing depends on what comes last
# If last instruction was SOPP: +1 normally, but +11 if last ALUEXEC was recent (VALU drain)
# Otherwise: 14 cycles for no ALU, 20 for trans, 14/15 for SALU/VALU
if self.last_inst_type == 'SOPP':
# Check if last ALUEXEC was recent (within ~5 cycles of current)
if self.valu_count > 0 and self.last_valu_exec >= self.cycle - 5:
self.cycle += 11 # VALU drain time
else:
self.cycle += 1
elif last_src is None:
self.cycle += 14 # empty program or no ALU ops
elif self.trans_count > 0:
self.cycle += 20 # trans has longer WAVEEND delay
else:
self.cycle += 15 if last_src == AluSrc.VALU else 13
self.cycle += 15 if last_src == AluSrc.VALU else 14
self.emit_waveend()
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)

View File

@@ -61,16 +61,13 @@ class TestSQTTCodec(unittest.TestCase):
from extra.assembly.amd.emu import SQTTState, decode_program
from extra.assembly.amd.sqtt import VALUINST, ALUEXEC
from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, s_mov_b32, s_add_u32, s_endpgm
from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, v_rcp_f32_e32, v_sqrt_f32_e32, v_exp_f32_e32, s_mov_b32, s_add_u32
from extra.assembly.amd.dsl import v, s
from extra.assembly.amd.test.test_sqtt_hw import compile_asm_sqtt, run_prg_sqtt, run_prg_sqtt_batch, get_wave_packets, format_packet
def assemble(instructions: list) -> bytes:
return b''.join(inst.to_bytes() for inst in instructions)
from extra.assembly.amd.test.test_sqtt_hw import compile_asm_sqtt, run_prg_sqtt, run_prg_sqtt_batch, get_wave_packets, format_packet, assemble, wrap_with_nops
def run_emulator_sqtt(instructions: list) -> list[PacketType]:
"""Run instructions through emulator and return SQTT packets."""
code = assemble(instructions + [s_endpgm()])
code = assemble(wrap_with_nops(instructions))
program = decode_program(code)
sqtt = SQTTState(wave_id=0, simd=0, cu=0)
@@ -103,7 +100,7 @@ def get_timing_deltas(packets: list) -> list[tuple[str, int]]:
class TestEmulatorSQTT(unittest.TestCase):
"""Tests comparing emulator SQTT to hardware SQTT."""
def _run_and_compare(self, instructions: list, name: str = "", n_runs: int = 100, min_identical: int = 30, max_attempts: int = 5):
def _run_and_compare(self, instructions: list, name: str = "", n_runs: int = 100, min_identical: int = 25, max_attempts: int = 10):
"""Run instructions on both hardware and emulator, compare SQTT structure."""
from collections import Counter
@@ -187,14 +184,12 @@ class TestEmulatorSQTT(unittest.TestCase):
t = p._time - emu_t0 if hasattr(p, '_time') else 0
print(f" {t:8d}: {format_packet(p)}")
# Compare packet structure (types only, ignoring timing jitter)
# Extract just packet types from emulator
emu_types = tuple(t for t, _ in emu_deltas)
# Find HW patterns with matching structure
matching_structures = [p for p in pattern_counts if tuple(t for t, _ in p) == emu_types]
self.assertGreater(len(matching_structures), 0,
f"{name}: emulator packet structure {emu_types} not found in any HW traces.\n"
f"HW structures: {set(tuple(t for t, _ in p) for p in pattern_counts)}")
# Assert the emulator pattern (packet types and timing deltas) exactly matches at least one observed HW pattern
emu_pattern = tuple(emu_deltas)
self.assertIn(emu_pattern, pattern_counts,
f"{name}: emulator pattern not found in HW traces.\n"
f"Emulator: {emu_deltas}\n"
f"HW patterns: {[list(p) for p in pattern_counts.most_common(3)]}")
def test_salu_independent(self):
"""SALU instructions with no dependencies."""
@@ -216,13 +211,21 @@ class TestEmulatorSQTT(unittest.TestCase):
"""Empty program - just s_endpgm."""
self._run_and_compare([], "empty")
def test_valu_independent(self):
def _test_valu_independent_n(self, n: int):
"""VALU instructions with no dependencies."""
self._run_and_compare([
v_mov_b32_e32(v[0], 1.0),
v_mov_b32_e32(v[1], 2.0),
v_mov_b32_e32(v[2], 3.0),
], "3 VALU independent")
v_mov_b32_e32(v[i], float(i)) for i in range(n)
], f"{n} VALU independent")
# One test per instruction count so each independent-VALU run length is reported separately.
def test_valu_independent_1(self):
  self._test_valu_independent_n(1)
def test_valu_independent_2(self):
  self._test_valu_independent_n(2)
def test_valu_independent_3(self):
  self._test_valu_independent_n(3)
def test_valu_independent_4(self):
  self._test_valu_independent_n(4)
def test_valu_independent_5(self):
  self._test_valu_independent_n(5)
def test_valu_independent_6(self):
  self._test_valu_independent_n(6)
def test_valu_independent_7(self):
  self._test_valu_independent_n(7)
def test_valu_independent_8(self):
  self._test_valu_independent_n(8)
def test_valu_independent_16(self):
  self._test_valu_independent_n(16)
def test_valu_chain(self):
"""VALU instructions with chain dependencies."""
@@ -232,6 +235,22 @@ class TestEmulatorSQTT(unittest.TestCase):
v_add_f32_e32(v[2], v[1], v[1]),
], "3 VALU chain")
def test_trans_independent(self):
  """Transcendental instructions with no dependencies."""
  # Each op reads and writes its own VGPR, so no instruction consumes another's result.
  program = [
    v_rcp_f32_e32(v[0], v[0]),
    v_sqrt_f32_e32(v[1], v[1]),
    v_exp_f32_e32(v[2], v[2]),
  ]
  self._run_and_compare(program, "3 TRANS independent")
def test_trans_chain(self):
  """Transcendental instructions with chain dependencies."""
  # v_mov writes v0, v_rcp reads v0 and writes v1, v_sqrt reads v1: each op feeds the next.
  chain = [
    v_mov_b32_e32(v[0], 1.0),
    v_rcp_f32_e32(v[1], v[0]),
    v_sqrt_f32_e32(v[2], v[1]),
  ]
  self._run_and_compare(chain, "3 TRANS chain")
if __name__ == "__main__":
unittest.main()

View File

@@ -132,6 +132,10 @@ def print_all_packets(packets: list) -> None:
def assemble(instructions: list) -> bytes:
  """Return the encodings of all instructions concatenated in list order.

  Each element must provide a ``to_bytes()`` method returning a bytes-like object.
  An empty list yields ``b''``.
  """
  encoded = [inst.to_bytes() for inst in instructions]
  return b''.join(encoded)
def wrap_with_nops(instructions: list) -> list:
  """Return a new list: the given instructions, then 64 s_nop(0), then s_endpgm.

  The input list is not modified. The trailing NOPs are there for clean SQTT timing.
  """
  padding = [s_nop(0)] * 64  # the same s_nop instance repeated 64 times
  return instructions + padding + [s_endpgm()]
def compile_asm_sqtt(instructions: list, alu_only: bool = False) -> AMDProgram:
"""Compile instructions to an AMDProgram for SQTT tracing.
@@ -142,8 +146,8 @@ def compile_asm_sqtt(instructions: list, alu_only: bool = False) -> AMDProgram:
Compiled AMDProgram ready to run
"""
compiler = HIPCompiler(dev.arch)
instructions = instructions + [s_endpgm()]
code = assemble(instructions)
# Add NOPs before s_endpgm to flush pipeline and get clean timing
code = assemble(wrap_with_nops(instructions))
byte_str = ', '.join(f'0x{b:02x}' for b in code)
if alu_only:
@@ -157,7 +161,7 @@ test:
.rodata
.p2align 6
.amdhsa_kernel test
.amdhsa_next_free_vgpr 8
.amdhsa_next_free_vgpr 64
.amdhsa_next_free_sgpr 8
.amdhsa_wavefront_size32 1
.amdhsa_group_segment_fixed_size 0
@@ -178,7 +182,7 @@ amdhsa.kernels:
.kernarg_segment_align: 8
.wavefront_size: 32
.sgpr_count: 8
.vgpr_count: 8
.vgpr_count: 64
.max_flat_workgroup_size: 1024
...
.end_amdgpu_metadata