diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index 720d2037f4..a9599a0b13 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -103,16 +103,20 @@ _TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', # Startup: WAVESTART -> first instruction (~32 cycles with warm cache, no REG packet) # Cold cache: WAVESTART -> REG (~137 cycles) -> first instruction (~150 cycles after REG) # SALU: issues every cycle, result ready 2 cycles after issue, ALUEXEC at ready time -# VALU: issues every cycle, ALUEXEC at issue+8 for each inst, serialized with +1 intervals +# VALU: issues every cycle, ALUEXEC at issue+6 for each inst, serialized with +1 intervals +# TRANS: issues every 4 cycles, ALUEXEC at last_issue+1, then +8 intervals # For dependent instructions, ALUEXEC is at source_ready + 10 (first dep) or + 9 (chained) +# VALU queue depth ~7: after 7 in-flight, ALUEXEC interleaves with VALUINST WAVESTART_TO_INST_CYCLES = 32 # cycles from WAVESTART to first instruction (warm cache) SALU_LATENCY = 2 # cycles from issue to result ready -VALU_EXEC_LATENCY = 8 # cycles from issue to ALUEXEC for each instruction +VALU_EXEC_LATENCY = 6 # cycles from first issue to first ALUEXEC +TRANS_ISSUE_CYCLES = 4 # cycles between transcendental instruction issues +TRANS_LATENCY = 9 # cycles from trans issue to ALUEXEC class SQTTState: """SQTT tracing state - emits packets matching real hardware (warm cache model).""" __slots__ = ('cycle', 'packets', 'pending_exec', 'wave_id', 'simd', 'cu', 'sgpr_ready', 'vgpr_ready', - 'last_salu_exec', 'last_valu_exec', 'valu_issue_cycles') + 'last_salu_exec', 'last_valu_exec', 'trans_count', 'last_trans_issue', 'last_trans_exec', 'first_valu_issue', 'valu_count', 'last_inst_type') def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0): self.cycle = 0 @@ -120,10 +124,15 @@ class SQTTState: self.pending_exec: list[tuple[int, int]] = [] # (completion_cycle, src) for ALUEXEC self.wave_id, self.simd, self.cu = wave_id, simd, cu self.sgpr_ready: dict[int, int] = {} # sgpr -> cycle when result ready + self.last_inst_type: str = '' # track last instruction type for gap calculation self.vgpr_ready: dict[int, int] = {} # vgpr -> cycle when result ready self.last_salu_exec = 0 # last SALU ALUEXEC time (for +1 spacing) self.last_valu_exec = 0 # last VALU ALUEXEC time (for +1 spacing) - self.valu_issue_cycles: list[int] = [] # issue cycles for pending independent VALU + self.trans_count = 0 # number of trans instructions issued + self.last_trans_issue = 0 # cycle when last trans was issued + self.last_trans_exec = 0 # cycle when last trans ALUEXEC is scheduled + self.first_valu_issue = 0 # cycle when first VALU was issued + self.valu_count = 0 # number of VALU instructions issued def emit_wavestart(self): from extra.assembly.amd.sqtt import WAVESTART @@ -146,6 +155,10 @@ class SQTTState: from extra.assembly.amd.sqtt import ALUEXEC self.packets.append(ALUEXEC(_time=self.cycle, src=src)) + def emit_immediate(self): + from extra.assembly.amd.sqtt import IMMEDIATE + self.packets.append(IMMEDIATE(_time=self.cycle, wave=self.wave_id)) + def _get_src_regs(self, inst: Inst) -> list[tuple[str, int]]: """Extract source register references from instruction.""" srcs = [] @@ -218,86 +231,173 @@ class SQTTState: self.pending_exec.append((exec_cycle, AluSrc.SALU)) self.last_salu_exec = exec_cycle self._record_dst_ready(inst, ready_cycle) + self.last_inst_type = 'SALU' elif isinstance(inst, SOPP): - pass # nop, waitcnt, etc don't emit packets + # s_nop and other SOPP emit IMMEDIATE packets + if self.last_inst_type == 'SALU': + # SALU→SOPP: emit ALUEXECs with max gap of 2 once, then IMMEDIATEs + # First ALUEXEC at last INST cycle, then allow one +2 gap, rest interleave + # cycle is currently 1 past last INST + last_inst_cycle = self.cycle - 1 + first = True + had_gap = False # track if we've had a +2 gap + while self.pending_exec: + exec_cycle, src = self.pending_exec[0] + if first: + self.pending_exec.pop(0) + self.cycle = last_inst_cycle # First ALUEXEC at same cycle as last INST + first = False + elif exec_cycle == self.cycle + 1: # consecutive (+1) + self.pending_exec.pop(0) + self.cycle = exec_cycle + elif exec_cycle == self.cycle + 2 and not had_gap: # allow one +2 gap + self.pending_exec.pop(0) + self.cycle = exec_cycle + had_gap = True + else: + break # too large a gap or second +2 gap - rest interleave with IMMEDIATEs + self.emit_aluexec(src) + self.cycle += 1 # IMMEDIATE starts +1 after last ALUEXEC + elif self.last_inst_type == 'TRANS': + # TRANS→SOPP: emit first ALUEXEC at its scheduled time, then IMMEDIATE +2 after + # Sort pending execs and emit first one if it's before IMMEDIATE time + self.pending_exec.sort(key=lambda x: x[0]) + if self.pending_exec: + first_exec_cycle = self.pending_exec[0][0] + # First IMMEDIATE at first_exec + 2 + imm_cycle = first_exec_cycle + 2 + # Emit ALUEXECs that come before IMMEDIATE + while self.pending_exec and self.pending_exec[0][0] < imm_cycle: + exec_cycle, src = self.pending_exec.pop(0) + self.cycle = exec_cycle + self.emit_aluexec(src) + self.cycle = imm_cycle + elif self.last_inst_type == 'VALU': + # VALU→SOPP: 2-cycle gap, emit any ALUEXECs that would come before + imm_cycle = self.cycle + 2 # When IMMEDIATE would be emitted + # Emit ALUEXECs that come before the IMMEDIATE + while self.pending_exec and self.pending_exec[0][0] < imm_cycle: + exec_cycle, src = self.pending_exec.pop(0) + self.cycle = max(self.cycle, exec_cycle) + self.emit_aluexec(src) + # Jump to IMMEDIATE cycle (either imm_cycle or 1 after last ALUEXEC, whichever is later) + self.cycle = max(imm_cycle, self.cycle + 1) + # Emit IMMEDIATE first, then any pending ALUEXECs at same cycle (HW order) + self.emit_immediate() + while self.pending_exec and self.pending_exec[0][0] <= self.cycle: + exec_cycle, src = self.pending_exec.pop(0) + old_cycle = self.cycle + self.cycle = exec_cycle + self.emit_aluexec(src) + self.cycle = old_cycle + self.last_inst_type = 'SOPP' elif isinstance(inst, SMEM): pass # skip for ALU focus elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC)): - # VALU: issue now, track for ALUEXEC timing in finalize() - # All ALUEXEC start at last_issue + 8, then serialize based on dependencies + # VALU: issue now, emit ALUEXEC for completed instructions if queue is full + from extra.assembly.amd.sqtt import AluSrc + op_name = inst.op_name if hasattr(inst, 'op_name') else '' - if any(t in op_name for t in _TRANS_OPS): + is_trans = any(t in op_name for t in _TRANS_OPS) + + if is_trans: + # Transcendental: emit INST, 4-cycle issue, add ALUEXEC to pending self.emit_inst(InstOp.VALU_TRANS) + self.trans_count += 1 + # Check for dependency on VGPR sources + src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0) + if src_ready > self.cycle: + # Dependent trans: ALUEXEC at source_ready + 6 (for VALU source) or +10 (for trans source) + # Check if source is from a trans instruction (its ready time would be > issue + 6) + src_is_trans = any(self.vgpr_ready.get(r, 0) > self.cycle + 6 for _, r in self._get_src_regs(inst) if _ == 'v') + exec_cycle = src_ready + (10 if src_is_trans else 6) + else: + # Independent trans: ALUEXEC at issue + 9 + exec_cycle = self.cycle + TRANS_LATENCY + self.pending_exec.append((exec_cycle, AluSrc.VALU)) + self._record_dst_ready(inst, exec_cycle) # Record when this trans result is ready + self.last_trans_exec = exec_cycle + self.last_trans_issue = self.cycle + # Trans instructions take 4 cycles to issue + self.cycle += TRANS_ISSUE_CYCLES - 1 # -1 because we add 1 at end of trace_inst + self.last_inst_type = 'TRANS' else: + # Regular VALU: emit VALUINST, may interleave ALUEXEC self.emit_valuinst() - # Check for dependency - src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0) - has_dep = src_ready > self.cycle - # Record issue info: (issue_cycle, has_dep, src_ready) - self.valu_issue_cycles.append((self.cycle, has_dep, src_ready)) - # For dependency tracking, estimate result ready - # Note: we don't know last_issue yet, so use a placeholder that will be corrected in finalize - # The key insight is dependent instructions use src_ready + 10 - if has_dep: - est_exec = src_ready + 10 - else: - # For independent, result will be ready at (final last_issue + 8 + position) - # We approximate with current cycle + 8, which is close enough for dependency detection - est_exec = self.cycle + VALU_EXEC_LATENCY - self._record_dst_ready(inst, est_exec) + + # Track first VALU issue for latency calculation + if self.first_valu_issue == 0: + self.first_valu_issue = self.cycle + + # Emit any pending ALUEXECs that have completed by now (after VALUINST) + while self.pending_exec and self.pending_exec[0][0] <= self.cycle: + exec_cycle, src = self.pending_exec.pop(0) + old_cycle = self.cycle + self.cycle = exec_cycle + self.emit_aluexec(src) + self.cycle = old_cycle + + # Check for dependency + src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0) + has_dep = src_ready > self.cycle + + if has_dep: + # Dependent: first dependent is +6 from source, subsequent chained deps are +5 + # "Chained" means the source was also dependent (not the first independent VALU) + chained = src_ready > self.first_valu_issue + VALU_EXEC_LATENCY + exec_cycle = src_ready + (5 if chained else 6) + else: + # Independent VALU timing: 6-cycle pipeline latency + # exec[i] = first_issue + 6 + i, with +1 spacing between consecutive execs + exec_cycle = self.first_valu_issue + VALU_EXEC_LATENCY + self.valu_count + exec_cycle = max(exec_cycle, self.last_valu_exec + 1) # +1 spacing + + # Add to pending exec queue (sorted by time) + self.pending_exec.append((exec_cycle, AluSrc.VALU)) + self.pending_exec.sort(key=lambda x: x[0]) + self.last_valu_exec = exec_cycle + self.valu_count += 1 + self._record_dst_ready(inst, exec_cycle) + self.last_inst_type = 'VALU' elif isinstance(inst, VOPD): + from extra.assembly.amd.sqtt import AluSrc + + # First emit VALUINST (HW emits issue before completion at same cycle) self.emit_valuinst() + + # Emit any pending ALUEXECs that have completed by now + while self.pending_exec and self.pending_exec[0][0] <= self.cycle: + exec_cycle, src = self.pending_exec.pop(0) + old_cycle = self.cycle + self.cycle = exec_cycle + self.emit_aluexec(src) + self.cycle = old_cycle src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0) has_dep = src_ready > self.cycle - self.valu_issue_cycles.append((self.cycle, has_dep, src_ready)) + if has_dep: - est_exec = src_ready + 10 + exec_cycle = src_ready + 10 else: - est_exec = self.cycle + VALU_EXEC_LATENCY - self._record_dst_ready(inst, est_exec) + exec_cycle = self.cycle + VALU_EXEC_LATENCY + if self.last_valu_exec > 0: + exec_cycle = max(exec_cycle, self.last_valu_exec + 1) + + self.pending_exec.append((exec_cycle, AluSrc.VALU)) + self.pending_exec.sort(key=lambda x: x[0]) + self.last_valu_exec = exec_cycle + self._record_dst_ready(inst, exec_cycle) self.cycle += 1 def finalize(self): - """Emit all pending ALUEXEC packets and WAVEEND.""" + """Emit all remaining pending ALUEXEC packets and WAVEEND.""" from extra.assembly.amd.sqtt import AluSrc - # Process VALU instructions - # First pass: compute actual exec times for all instructions - # - Independent: at last_issue + 8 + position, serialized at +1 intervals - # - Dependent: at src_exec + 10 (first dep) or + 9 (chained dep) - if self.valu_issue_cycles: - last_issue = self.valu_issue_cycles[-1][0] - base_exec = last_issue + VALU_EXEC_LATENCY - - # Build exec_times list with actual completion times - exec_times = [] - last_exec = 0 - last_was_dep = False - - for i, (issue_cycle, has_dep, src_ready_idx) in enumerate(self.valu_issue_cycles): - if has_dep: - # src_ready_idx is the vgpr_ready value at trace time, which was an estimate - # We need to find the actual exec time of the instruction that wrote this value - # For now, use a simpler model: dependent instructions get +10 from previous exec (first dep) or +9 (chained) - exec_cycle = last_exec + (10 if not last_was_dep else 9) - last_was_dep = True - else: - # Independent: at base_exec + position, serialized at +1 intervals - exec_cycle = max(base_exec + i, last_exec + 1) - last_was_dep = False - exec_times.append(exec_cycle) - last_exec = exec_cycle - - for exec_cycle in exec_times: - self.pending_exec.append((exec_cycle, AluSrc.VALU)) - self.valu_issue_cycles.clear() - - # Sort and emit all pending ALUEXEC + # Emit any remaining pending ALUEXECs self.pending_exec.sort(key=lambda x: x[0]) last_src = None for exec_cycle, src in self.pending_exec: @@ -305,11 +405,22 @@ class SQTTState: self.emit_aluexec(src) last_src = src self.pending_exec.clear() - # WAVEEND timing: 14 cycles after last instruction if no ALU, 13/15 after last ALUEXEC - if last_src is None: + + # WAVEEND timing depends on what comes last + # If last instruction was SOPP: +1 normally, but +11 if last ALUEXEC was recent (VALU drain) + # Otherwise: 14 cycles for no ALU, 20 for trans, 14/15 for SALU/VALU + if self.last_inst_type == 'SOPP': + # Check if last ALUEXEC was recent (within ~5 cycles of current) + if self.valu_count > 0 and self.last_valu_exec >= self.cycle - 5: + self.cycle += 11 # VALU drain time + else: + self.cycle += 1 + elif last_src is None: self.cycle += 14 # empty program or no ALU ops + elif self.trans_count > 0: + self.cycle += 20 # trans has longer WAVEEND delay else: - self.cycle += 15 if last_src == AluSrc.VALU else 13 + self.cycle += 15 if last_src == AluSrc.VALU else 14 self.emit_waveend() # VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup) diff --git a/extra/assembly/amd/test/test_sqtt_compare.py b/extra/assembly/amd/test/test_sqtt_compare.py index 04c91ef605..1c8e29fd30 100644 --- a/extra/assembly/amd/test/test_sqtt_compare.py +++ b/extra/assembly/amd/test/test_sqtt_compare.py @@ -61,16 +61,13 @@ class TestSQTTCodec(unittest.TestCase): from extra.assembly.amd.emu import SQTTState, decode_program from extra.assembly.amd.sqtt import VALUINST, ALUEXEC -from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, s_mov_b32, s_add_u32, s_endpgm +from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, v_rcp_f32_e32, v_sqrt_f32_e32, v_exp_f32_e32, s_mov_b32, s_add_u32 from extra.assembly.amd.dsl import v, s -from extra.assembly.amd.test.test_sqtt_hw import compile_asm_sqtt, run_prg_sqtt, run_prg_sqtt_batch, get_wave_packets, format_packet - -def assemble(instructions: list) -> bytes: - return b''.join(inst.to_bytes() for inst in instructions) +from extra.assembly.amd.test.test_sqtt_hw import compile_asm_sqtt, run_prg_sqtt, run_prg_sqtt_batch, get_wave_packets, format_packet, assemble, wrap_with_nops def run_emulator_sqtt(instructions: list) -> list[PacketType]: """Run instructions through emulator and return SQTT packets.""" - code = assemble(instructions + [s_endpgm()]) + code = assemble(wrap_with_nops(instructions)) program = decode_program(code) sqtt = SQTTState(wave_id=0, simd=0, cu=0) @@ -103,7 +100,7 @@ def get_timing_deltas(packets: list) -> list[tuple[str, int]]: class TestEmulatorSQTT(unittest.TestCase): """Tests comparing emulator SQTT to hardware SQTT.""" - def _run_and_compare(self, instructions: list, name: str = "", n_runs: int = 100, min_identical: int = 30, max_attempts: int = 5): + def _run_and_compare(self, instructions: list, name: str = "", n_runs: int = 100, min_identical: int = 25, max_attempts: int = 10): """Run instructions on both hardware and emulator, compare SQTT structure.""" from collections import Counter @@ -187,14 +184,12 @@ class TestEmulatorSQTT(unittest.TestCase): t = p._time - emu_t0 if hasattr(p, '_time') else 0 print(f" {t:8d}: {format_packet(p)}") - # Compare packet structure (types only, ignoring timing jitter) - # Extract just packet types from emulator - emu_types = tuple(t for t, _ in emu_deltas) - # Find HW patterns with matching structure - matching_structures = [p for p in pattern_counts if tuple(t for t, _ in p) == emu_types] - self.assertGreater(len(matching_structures), 0, - f"{name}: emulator packet structure {emu_types} not found in any HW traces.\n" - f"HW structures: {set(tuple(t for t, _ in p) for p in pattern_counts)}") + # Assert emulator pattern matches most common HW pattern exactly + emu_pattern = tuple(emu_deltas) + self.assertIn(emu_pattern, pattern_counts, + f"{name}: emulator pattern not found in HW traces.\n" + f"Emulator: {emu_deltas}\n" + f"HW patterns: {[list(p) for p in pattern_counts.most_common(3)]}") def test_salu_independent(self): """SALU instructions with no dependencies.""" @@ -216,13 +211,21 @@ class TestEmulatorSQTT(unittest.TestCase): """Empty program - just s_endpgm.""" self._run_and_compare([], "empty") - def test_valu_independent(self): + def _test_valu_independent_n(self, n: int): """VALU instructions with no dependencies.""" self._run_and_compare([ - v_mov_b32_e32(v[0], 1.0), - v_mov_b32_e32(v[1], 2.0), - v_mov_b32_e32(v[2], 3.0), - ], "3 VALU independent") + v_mov_b32_e32(v[i], float(i)) for i in range(n) + ], f"{n} VALU independent") + + def test_valu_independent_1(self): self._test_valu_independent_n(1) + def test_valu_independent_2(self): self._test_valu_independent_n(2) + def test_valu_independent_3(self): self._test_valu_independent_n(3) + def test_valu_independent_4(self): self._test_valu_independent_n(4) + def test_valu_independent_5(self): self._test_valu_independent_n(5) + def test_valu_independent_6(self): self._test_valu_independent_n(6) + def test_valu_independent_7(self): self._test_valu_independent_n(7) + def test_valu_independent_8(self): self._test_valu_independent_n(8) + def test_valu_independent_16(self): self._test_valu_independent_n(16) def test_valu_chain(self): """VALU instructions with chain dependencies.""" @@ -232,6 +235,22 @@ class TestEmulatorSQTT(unittest.TestCase): v_add_f32_e32(v[2], v[1], v[1]), ], "3 VALU chain") + def test_trans_independent(self): + """Transcendental instructions with no dependencies.""" + self._run_and_compare([ + v_rcp_f32_e32(v[0], v[0]), + v_sqrt_f32_e32(v[1], v[1]), + v_exp_f32_e32(v[2], v[2]), + ], "3 TRANS independent") + + def test_trans_chain(self): + """Transcendental instructions with chain dependencies.""" + self._run_and_compare([ + v_mov_b32_e32(v[0], 1.0), + v_rcp_f32_e32(v[1], v[0]), + v_sqrt_f32_e32(v[2], v[1]), + ], "3 TRANS chain") + if __name__ == "__main__": unittest.main() diff --git a/extra/assembly/amd/test/test_sqtt_hw.py b/extra/assembly/amd/test/test_sqtt_hw.py index 3612a81298..bd2b183a8e 100644 --- a/extra/assembly/amd/test/test_sqtt_hw.py +++ b/extra/assembly/amd/test/test_sqtt_hw.py @@ -132,6 +132,10 @@ def print_all_packets(packets: list) -> None: def assemble(instructions: list) -> bytes: return b''.join(inst.to_bytes() for inst in instructions) +def wrap_with_nops(instructions: list) -> list: + """Add trailing NOPs and s_endpgm for clean SQTT timing.""" + return instructions + [s_nop(0)]*64 + [s_endpgm()] + def compile_asm_sqtt(instructions: list, alu_only: bool = False) -> AMDProgram: """Compile instructions to an AMDProgram for SQTT tracing. @@ -142,8 +146,8 @@ def compile_asm_sqtt(instructions: list, alu_only: bool = False) -> AMDProgram: Compiled AMDProgram ready to run """ compiler = HIPCompiler(dev.arch) - instructions = instructions + [s_endpgm()] - code = assemble(instructions) + # Add NOPs before s_endpgm to flush pipeline and get clean timing + code = assemble(wrap_with_nops(instructions)) byte_str = ', '.join(f'0x{b:02x}' for b in code) if alu_only: @@ -157,7 +161,7 @@ test: .rodata .p2align 6 .amdhsa_kernel test - .amdhsa_next_free_vgpr 8 + .amdhsa_next_free_vgpr 64 .amdhsa_next_free_sgpr 8 .amdhsa_wavefront_size32 1 .amdhsa_group_segment_fixed_size 0 @@ -178,7 +182,7 @@ amdhsa.kernels: .kernarg_segment_align: 8 .wavefront_size: 32 .sgpr_count: 8 - .vgpr_count: 8 + .vgpr_count: 64 .max_flat_workgroup_size: 1024 ... .end_amdgpu_metadata