From c793076fb6d224256cd9edad1a3c41eca462facf Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 11 Jan 2026 11:28:21 +0900 Subject: [PATCH] add s_delay_alu tests --- extra/assembly/amd/emu.py | 62 ++++++++-------- extra/assembly/amd/test/test_sqtt_correct.py | 76 +++++++++++++++++++- 2 files changed, 109 insertions(+), 29 deletions(-) diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index f708613ffc..181f504d8d 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -489,42 +489,48 @@ class SQTTState: self.cycle += 1 if DEBUG >= 3: print(f"C{self.cycle}:", end="") - # 1. Clear forward from previous cycle (forward only lasts 1 cycle) - prev_forward = self.forward_vgpr - self.forward_vgpr = None + # Pipeline timing model for 6,5,5 pattern: + # - Base latency (const source): 6 cycles from issue to exec + # - First dependent: 6 cycles from producer exec (no forwarding benefit) + # - Subsequent dependents: 5 cycles from producer exec (forwarding saves 1 cycle) + # + # Forwarding is available from ALU[3] output. When instruction exits ALU[3]->WB, + # a waiting dependent can enter ALU[0] on the SAME cycle, giving delta=5. + # But the FIRST dependent can't benefit because it arrives before producer is in ALU[3]. - # 2. DepFetch -> ALU[0]: check if any instruction can enter ALU - # - Must always wait until ready_cycle (1 cycle in dep_fetch) - # - If forwarding available: deps resolved, just wait for ready_cycle - # - If no forwarding and VGPR cold: extra wait for regfile read - for item in self.dep_fetch[:]: - inst, dest, srcs, ready_cycle = item - if not self._deps_ready(srcs, prev_forward): continue # deps not resolved yet - if self.cycle < ready_cycle: continue # must wait for dep_fetch cycle - self.dep_fetch.remove(item) - self.alu[0] = (inst, dest) - uses_forwarding = prev_forward is not None and prev_forward in srcs - if DEBUG >= 3: print(f" DepFetch->ALU[0] v{dest}{'(fwd)' if uses_forwarding else ''}", end="") - break + # 1. Writeback: emit ALUEXEC + if self.writeback is not None: + inst, dest = self.writeback + self.emit(ALUEXEC, src=AluSrc.VALU) + if DEBUG >= 3: print(f" WB v{dest}", end="") + self.writeback = None + + # 2. ALU[3] -> Writeback, and set forwarding from ALU[3] output + # Forwarding is available THIS cycle for instructions entering ALU[0] + alu3_out = self.alu[3] + self.alu[3] = None + forward_this_cycle = alu3_out[1] if alu3_out is not None else None + if alu3_out is not None: + self.writeback = alu3_out + if DEBUG >= 3: print(f" ALU[3]->WB", end="") # 3. Shift ALU pipeline: [0]->[1]->[2]->[3] - # But first save ALU[3] for writeback - alu3_out = self.alu[3] self.alu[3] = self.alu[2] self.alu[2] = self.alu[1] self.alu[1] = self.alu[0] self.alu[0] = None - # 4. Writeback: emit ALUEXEC, make forward available for NEXT cycle - if self.writeback is not None: - inst, dest = self.writeback - self.emit(ALUEXEC, src=AluSrc.VALU) - self.forward_vgpr = dest # Available for next cycle's dep resolution - if DEBUG >= 3: print(f" WB v{dest}", end="") - - # 5. ALU[3] output -> Writeback for next cycle - self.writeback = alu3_out - if alu3_out is not None and DEBUG >= 3: print(f" ALU[3]->WB", end="") + # 4. DepFetch -> ALU[0]: check if any instruction can enter ALU + # Forwarding from ALU[3] output is available THIS cycle (not next cycle) + for item in self.dep_fetch[:]: + inst, dest, srcs, ready_cycle = item + if not self._deps_ready(srcs, forward_this_cycle): continue # deps not resolved yet + if self.cycle < ready_cycle: continue # must wait for dep_fetch cycle + self.dep_fetch.remove(item) + self.alu[0] = (inst, dest) + uses_forwarding = forward_this_cycle is not None and forward_this_cycle in srcs + if DEBUG >= 3: print(f" DepFetch->ALU[0] v{dest}{'(fwd)' if uses_forwarding else ''}", end="") + break if DEBUG >= 3: print() diff --git a/extra/assembly/amd/test/test_sqtt_correct.py b/extra/assembly/amd/test/test_sqtt_correct.py index 10eab5ded0..40354fa32e 100644 --- a/extra/assembly/amd/test/test_sqtt_correct.py +++ b/extra/assembly/amd/test/test_sqtt_correct.py @@ -20,7 +20,7 @@ if USE_HW: from extra.assembly.amd.emu import SQTTState, decode_program, exec_wave, WaveState, LDSMem from extra.assembly.amd.sqtt import WAVESTART, WAVEEND -from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, s_nop, s_endpgm +from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, s_nop, s_endpgm, s_delay_alu from extra.assembly.amd.dsl import v def assemble(instructions: list) -> bytes: @@ -426,5 +426,79 @@ class TestVALUExecWithNop(unittest.TestCase): def test_nop4_nop0(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(4), s_nop(0)]), 10) +class TestDelayALU(unittest.TestCase): + """s_delay_alu behavior - helps understand hardware pipeline latencies. + + s_delay_alu(simm16) where simm16 encodes: + instid0[3:0] = dependency on VALU N instructions back (1-4), 0=none + skip[6:4] = skip count for second dependency + instid1[10:7] = second dependency + + Key insight: s_delay_alu tells hardware to wait for a previous VALU to complete. + The hardware determines how many cycles to stall based on pipeline state. + """ + def _exec_delta(self, instrs): + """Return exec delta for last instruction.""" + _, execd = get_deltas(instrs) + return execd[-1] if execd else None + + # Direct dependency (producer -> consumer), instid0=1 means "wait for VALU 1 back" + def test_direct_no_delay(self): + # Without s_delay_alu: 6 cycles + self.assertEqual(self._exec_delta([v_mov_b32_e32(v[0], 1.0), v_add_f32_e32(v[1], v[0], v[0])]), 6) + + def test_direct_delay1(self): + # With s_delay_alu(instid0=1): 7 cycles (+1 from the delay instruction) + self.assertEqual(self._exec_delta([v_mov_b32_e32(v[0], 1.0), s_delay_alu(simm16=1), v_add_f32_e32(v[1], v[0], v[0])]), 7) + + def test_direct_delay2(self): + # instid0=2 doesn't apply (only 1 VALU back), so no extra delay + self.assertEqual(self._exec_delta([v_mov_b32_e32(v[0], 1.0), s_delay_alu(simm16=2), v_add_f32_e32(v[1], v[0], v[0])]), 6) + + def test_direct_delay3(self): + self.assertEqual(self._exec_delta([v_mov_b32_e32(v[0], 1.0), s_delay_alu(simm16=3), v_add_f32_e32(v[1], v[0], v[0])]), 6) + + def test_direct_delay4(self): + self.assertEqual(self._exec_delta([v_mov_b32_e32(v[0], 1.0), s_delay_alu(simm16=4), v_add_f32_e32(v[1], v[0], v[0])]), 6) + + # With 1 independent instruction between producer and consumer + def test_gap1_delay1(self): + # instid0=1 waits for the independent instruction (not the producer) + instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), s_delay_alu(simm16=1), v_add_f32_e32(v[1], v[0], v[0])] + self.assertEqual(self._exec_delta(instrs), 8) + + def test_gap1_delay2(self): + # instid0=2 waits for the producer (2 VALUs back) + instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), s_delay_alu(simm16=2), v_add_f32_e32(v[1], v[0], v[0])] + self.assertEqual(self._exec_delta(instrs), 6) + + def test_gap1_delay3(self): + # instid0=3 doesn't apply (only 2 VALUs back) + instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), s_delay_alu(simm16=3), v_add_f32_e32(v[1], v[0], v[0])] + self.assertEqual(self._exec_delta(instrs), 5) + + # With 2 independent instructions between + def test_gap2_delay1(self): + instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), v_mov_b32_e32(v[6], 6.0), + s_delay_alu(simm16=1), v_add_f32_e32(v[1], v[0], v[0])] + self.assertEqual(self._exec_delta(instrs), 7) + + def test_gap2_delay2(self): + instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), v_mov_b32_e32(v[6], 6.0), + s_delay_alu(simm16=2), v_add_f32_e32(v[1], v[0], v[0])] + self.assertEqual(self._exec_delta(instrs), 7) + + def test_gap2_delay3(self): + # instid0=3 waits for the producer (3 VALUs back) + instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), v_mov_b32_e32(v[6], 6.0), + s_delay_alu(simm16=3), v_add_f32_e32(v[1], v[0], v[0])] + self.assertEqual(self._exec_delta(instrs), 5) + + def test_gap2_delay4(self): + instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), v_mov_b32_e32(v[6], 6.0), + s_delay_alu(simm16=4), v_add_f32_e32(v[1], v[0], v[0])] + self.assertEqual(self._exec_delta(instrs), 4) + + if __name__ == "__main__": unittest.main()