mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
64 nops
This commit is contained in:
@@ -103,16 +103,20 @@ _TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32',
|
||||
# Startup: WAVESTART -> first instruction (~32 cycles with warm cache, no REG packet)
|
||||
# Cold cache: WAVESTART -> REG (~137 cycles) -> first instruction (~150 cycles after REG)
|
||||
# SALU: issues every cycle, result ready 2 cycles after issue, ALUEXEC at ready time
|
||||
# VALU: issues every cycle, ALUEXEC at issue+8 for each inst, serialized with +1 intervals
|
||||
# VALU: issues every cycle, ALUEXEC at issue+6 for each inst, serialized with +1 intervals
|
||||
# TRANS: issues every 4 cycles, ALUEXEC at last_issue+1, then +8 intervals
|
||||
# For dependent VALU instructions, ALUEXEC is at source_ready + 6 (first dep) or + 5 (chained); dependent VOPD uses source_ready + 10
|
||||
# VALU queue depth ~7: after 7 in-flight, ALUEXEC interleaves with VALUINST
|
||||
WAVESTART_TO_INST_CYCLES = 32 # cycles from WAVESTART to first instruction (warm cache)
|
||||
SALU_LATENCY = 2 # cycles from issue to result ready
|
||||
VALU_EXEC_LATENCY = 8 # cycles from issue to ALUEXEC for each instruction
|
||||
VALU_EXEC_LATENCY = 6 # cycles from first issue to first ALUEXEC
|
||||
TRANS_ISSUE_CYCLES = 4 # cycles between transcendental instruction issues
|
||||
TRANS_LATENCY = 9 # cycles from trans issue to ALUEXEC
|
||||
|
||||
class SQTTState:
|
||||
"""SQTT tracing state - emits packets matching real hardware (warm cache model)."""
|
||||
__slots__ = ('cycle', 'packets', 'pending_exec', 'wave_id', 'simd', 'cu', 'sgpr_ready', 'vgpr_ready',
|
||||
'last_salu_exec', 'last_valu_exec', 'valu_issue_cycles')
|
||||
'last_salu_exec', 'last_valu_exec', 'trans_count', 'last_trans_issue', 'last_trans_exec', 'first_valu_issue', 'valu_count', 'last_inst_type')
|
||||
|
||||
def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0):
|
||||
self.cycle = 0
|
||||
@@ -120,10 +124,15 @@ class SQTTState:
|
||||
self.pending_exec: list[tuple[int, int]] = [] # (completion_cycle, src) for ALUEXEC
|
||||
self.wave_id, self.simd, self.cu = wave_id, simd, cu
|
||||
self.sgpr_ready: dict[int, int] = {} # sgpr -> cycle when result ready
|
||||
self.last_inst_type: str = '' # track last instruction type for gap calculation
|
||||
self.vgpr_ready: dict[int, int] = {} # vgpr -> cycle when result ready
|
||||
self.last_salu_exec = 0 # last SALU ALUEXEC time (for +1 spacing)
|
||||
self.last_valu_exec = 0 # last VALU ALUEXEC time (for +1 spacing)
|
||||
self.valu_issue_cycles: list[int] = [] # issue cycles for pending independent VALU
|
||||
self.trans_count = 0 # number of trans instructions issued
|
||||
self.last_trans_issue = 0 # cycle when last trans was issued
|
||||
self.last_trans_exec = 0 # cycle when last trans ALUEXEC is scheduled
|
||||
self.first_valu_issue = 0 # cycle when first VALU was issued
|
||||
self.valu_count = 0 # number of VALU instructions issued
|
||||
|
||||
def emit_wavestart(self):
|
||||
from extra.assembly.amd.sqtt import WAVESTART
|
||||
@@ -146,6 +155,10 @@ class SQTTState:
|
||||
from extra.assembly.amd.sqtt import ALUEXEC
|
||||
self.packets.append(ALUEXEC(_time=self.cycle, src=src))
|
||||
|
||||
def emit_immediate(self):
  """Append an IMMEDIATE packet for this wave, timestamped with the current cycle."""
  # function-scope import of the packet type
  from extra.assembly.amd.sqtt import IMMEDIATE
  packet = IMMEDIATE(_time=self.cycle, wave=self.wave_id)
  self.packets.append(packet)
|
||||
|
||||
def _get_src_regs(self, inst: Inst) -> list[tuple[str, int]]:
|
||||
"""Extract source register references from instruction."""
|
||||
srcs = []
|
||||
@@ -218,86 +231,173 @@ class SQTTState:
|
||||
self.pending_exec.append((exec_cycle, AluSrc.SALU))
|
||||
self.last_salu_exec = exec_cycle
|
||||
self._record_dst_ready(inst, ready_cycle)
|
||||
self.last_inst_type = 'SALU'
|
||||
|
||||
elif isinstance(inst, SOPP):
|
||||
pass # nop, waitcnt, etc don't emit packets
|
||||
# s_nop and other SOPP emit IMMEDIATE packets
|
||||
if self.last_inst_type == 'SALU':
|
||||
# SALU→SOPP: emit ALUEXECs with max gap of 2 once, then IMMEDIATEs
|
||||
# First ALUEXEC at last INST cycle, then allow one +2 gap, rest interleave
|
||||
# cycle is currently 1 past last INST
|
||||
last_inst_cycle = self.cycle - 1
|
||||
first = True
|
||||
had_gap = False # track if we've had a +2 gap
|
||||
while self.pending_exec:
|
||||
exec_cycle, src = self.pending_exec[0]
|
||||
if first:
|
||||
self.pending_exec.pop(0)
|
||||
self.cycle = last_inst_cycle # First ALUEXEC at same cycle as last INST
|
||||
first = False
|
||||
elif exec_cycle == self.cycle + 1: # consecutive (+1)
|
||||
self.pending_exec.pop(0)
|
||||
self.cycle = exec_cycle
|
||||
elif exec_cycle == self.cycle + 2 and not had_gap: # allow one +2 gap
|
||||
self.pending_exec.pop(0)
|
||||
self.cycle = exec_cycle
|
||||
had_gap = True
|
||||
else:
|
||||
break # too large a gap or second +2 gap - rest interleave with IMMEDIATEs
|
||||
self.emit_aluexec(src)
|
||||
self.cycle += 1 # IMMEDIATE starts +1 after last ALUEXEC
|
||||
elif self.last_inst_type == 'TRANS':
|
||||
# TRANS→SOPP: emit first ALUEXEC at its scheduled time, then IMMEDIATE +2 after
|
||||
# Sort pending execs and emit first one if it's before IMMEDIATE time
|
||||
self.pending_exec.sort(key=lambda x: x[0])
|
||||
if self.pending_exec:
|
||||
first_exec_cycle = self.pending_exec[0][0]
|
||||
# First IMMEDIATE at first_exec + 2
|
||||
imm_cycle = first_exec_cycle + 2
|
||||
# Emit ALUEXECs that come before IMMEDIATE
|
||||
while self.pending_exec and self.pending_exec[0][0] < imm_cycle:
|
||||
exec_cycle, src = self.pending_exec.pop(0)
|
||||
self.cycle = exec_cycle
|
||||
self.emit_aluexec(src)
|
||||
self.cycle = imm_cycle
|
||||
elif self.last_inst_type == 'VALU':
|
||||
# VALU→SOPP: 2-cycle gap, emit any ALUEXECs that would come before
|
||||
imm_cycle = self.cycle + 2 # When IMMEDIATE would be emitted
|
||||
# Emit ALUEXECs that come before the IMMEDIATE
|
||||
while self.pending_exec and self.pending_exec[0][0] < imm_cycle:
|
||||
exec_cycle, src = self.pending_exec.pop(0)
|
||||
self.cycle = max(self.cycle, exec_cycle)
|
||||
self.emit_aluexec(src)
|
||||
# Jump to IMMEDIATE cycle (either imm_cycle or 1 after last ALUEXEC, whichever is later)
|
||||
self.cycle = max(imm_cycle, self.cycle + 1)
|
||||
# Emit IMMEDIATE first, then any pending ALUEXECs at same cycle (HW order)
|
||||
self.emit_immediate()
|
||||
while self.pending_exec and self.pending_exec[0][0] <= self.cycle:
|
||||
exec_cycle, src = self.pending_exec.pop(0)
|
||||
old_cycle = self.cycle
|
||||
self.cycle = exec_cycle
|
||||
self.emit_aluexec(src)
|
||||
self.cycle = old_cycle
|
||||
self.last_inst_type = 'SOPP'
|
||||
|
||||
elif isinstance(inst, SMEM):
|
||||
pass # skip for ALU focus
|
||||
|
||||
elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC)):
|
||||
# VALU: issue now, track for ALUEXEC timing in finalize()
|
||||
# All ALUEXEC start at last_issue + 8, then serialize based on dependencies
|
||||
# VALU: issue now, emit ALUEXEC for completed instructions if queue is full
|
||||
from extra.assembly.amd.sqtt import AluSrc
|
||||
|
||||
op_name = inst.op_name if hasattr(inst, 'op_name') else ''
|
||||
if any(t in op_name for t in _TRANS_OPS):
|
||||
is_trans = any(t in op_name for t in _TRANS_OPS)
|
||||
|
||||
if is_trans:
|
||||
# Transcendental: emit INST, 4-cycle issue, add ALUEXEC to pending
|
||||
self.emit_inst(InstOp.VALU_TRANS)
|
||||
self.trans_count += 1
|
||||
# Check for dependency on VGPR sources
|
||||
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
|
||||
if src_ready > self.cycle:
|
||||
# Dependent trans: ALUEXEC at source_ready + 6 (for VALU source) or +10 (for trans source)
|
||||
# Check if source is from a trans instruction (its ready time would be > issue + 6)
|
||||
src_is_trans = any(self.vgpr_ready.get(r, 0) > self.cycle + 6 for _, r in self._get_src_regs(inst) if _ == 'v')
|
||||
exec_cycle = src_ready + (10 if src_is_trans else 6)
|
||||
else:
|
||||
# Independent trans: ALUEXEC at issue + 9
|
||||
exec_cycle = self.cycle + TRANS_LATENCY
|
||||
self.pending_exec.append((exec_cycle, AluSrc.VALU))
|
||||
self._record_dst_ready(inst, exec_cycle) # Record when this trans result is ready
|
||||
self.last_trans_exec = exec_cycle
|
||||
self.last_trans_issue = self.cycle
|
||||
# Trans instructions take 4 cycles to issue
|
||||
self.cycle += TRANS_ISSUE_CYCLES - 1 # -1 because we add 1 at end of trace_inst
|
||||
self.last_inst_type = 'TRANS'
|
||||
else:
|
||||
# Regular VALU: emit VALUINST, may interleave ALUEXEC
|
||||
self.emit_valuinst()
|
||||
# Check for dependency
|
||||
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
|
||||
has_dep = src_ready > self.cycle
|
||||
# Record issue info: (issue_cycle, has_dep, src_ready)
|
||||
self.valu_issue_cycles.append((self.cycle, has_dep, src_ready))
|
||||
# For dependency tracking, estimate result ready
|
||||
# Note: we don't know last_issue yet, so use a placeholder that will be corrected in finalize
|
||||
# The key insight is dependent instructions use src_ready + 10
|
||||
if has_dep:
|
||||
est_exec = src_ready + 10
|
||||
else:
|
||||
# For independent, result will be ready at (final last_issue + 8 + position)
|
||||
# We approximate with current cycle + 8, which is close enough for dependency detection
|
||||
est_exec = self.cycle + VALU_EXEC_LATENCY
|
||||
self._record_dst_ready(inst, est_exec)
|
||||
|
||||
# Track first VALU issue for latency calculation
|
||||
if self.first_valu_issue == 0:
|
||||
self.first_valu_issue = self.cycle
|
||||
|
||||
# Emit any pending ALUEXECs that have completed by now (after VALUINST)
|
||||
while self.pending_exec and self.pending_exec[0][0] <= self.cycle:
|
||||
exec_cycle, src = self.pending_exec.pop(0)
|
||||
old_cycle = self.cycle
|
||||
self.cycle = exec_cycle
|
||||
self.emit_aluexec(src)
|
||||
self.cycle = old_cycle
|
||||
|
||||
# Check for dependency
|
||||
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
|
||||
has_dep = src_ready > self.cycle
|
||||
|
||||
if has_dep:
|
||||
# Dependent: first dependent is +6 from source, subsequent chained deps are +5
|
||||
# "Chained" means the source was also dependent (not the first independent VALU)
|
||||
chained = src_ready > self.first_valu_issue + VALU_EXEC_LATENCY
|
||||
exec_cycle = src_ready + (5 if chained else 6)
|
||||
else:
|
||||
# Independent VALU timing: 6-cycle pipeline latency
|
||||
# exec[i] = first_issue + 6 + i, with +1 spacing between consecutive execs
|
||||
exec_cycle = self.first_valu_issue + VALU_EXEC_LATENCY + self.valu_count
|
||||
exec_cycle = max(exec_cycle, self.last_valu_exec + 1) # +1 spacing
|
||||
|
||||
# Add to pending exec queue (sorted by time)
|
||||
self.pending_exec.append((exec_cycle, AluSrc.VALU))
|
||||
self.pending_exec.sort(key=lambda x: x[0])
|
||||
self.last_valu_exec = exec_cycle
|
||||
self.valu_count += 1
|
||||
self._record_dst_ready(inst, exec_cycle)
|
||||
self.last_inst_type = 'VALU'
|
||||
|
||||
elif isinstance(inst, VOPD):
|
||||
from extra.assembly.amd.sqtt import AluSrc
|
||||
|
||||
# First emit VALUINST (HW emits issue before completion at same cycle)
|
||||
self.emit_valuinst()
|
||||
|
||||
# Emit any pending ALUEXECs that have completed by now
|
||||
while self.pending_exec and self.pending_exec[0][0] <= self.cycle:
|
||||
exec_cycle, src = self.pending_exec.pop(0)
|
||||
old_cycle = self.cycle
|
||||
self.cycle = exec_cycle
|
||||
self.emit_aluexec(src)
|
||||
self.cycle = old_cycle
|
||||
src_ready = max((self.vgpr_ready.get(r, 0) for _, r in self._get_src_regs(inst) if _ == 'v'), default=0)
|
||||
has_dep = src_ready > self.cycle
|
||||
self.valu_issue_cycles.append((self.cycle, has_dep, src_ready))
|
||||
|
||||
if has_dep:
|
||||
est_exec = src_ready + 10
|
||||
exec_cycle = src_ready + 10
|
||||
else:
|
||||
est_exec = self.cycle + VALU_EXEC_LATENCY
|
||||
self._record_dst_ready(inst, est_exec)
|
||||
exec_cycle = self.cycle + VALU_EXEC_LATENCY
|
||||
if self.last_valu_exec > 0:
|
||||
exec_cycle = max(exec_cycle, self.last_valu_exec + 1)
|
||||
|
||||
self.pending_exec.append((exec_cycle, AluSrc.VALU))
|
||||
self.pending_exec.sort(key=lambda x: x[0])
|
||||
self.last_valu_exec = exec_cycle
|
||||
self._record_dst_ready(inst, exec_cycle)
|
||||
|
||||
self.cycle += 1
|
||||
|
||||
def finalize(self):
|
||||
"""Emit all pending ALUEXEC packets and WAVEEND."""
|
||||
"""Emit all remaining pending ALUEXEC packets and WAVEEND."""
|
||||
from extra.assembly.amd.sqtt import AluSrc
|
||||
|
||||
# Process VALU instructions
|
||||
# First pass: compute actual exec times for all instructions
|
||||
# - Independent: at last_issue + 8 + position, serialized at +1 intervals
|
||||
# - Dependent: at src_exec + 10 (first dep) or + 9 (chained dep)
|
||||
if self.valu_issue_cycles:
|
||||
last_issue = self.valu_issue_cycles[-1][0]
|
||||
base_exec = last_issue + VALU_EXEC_LATENCY
|
||||
|
||||
# Build exec_times list with actual completion times
|
||||
exec_times = []
|
||||
last_exec = 0
|
||||
last_was_dep = False
|
||||
|
||||
for i, (issue_cycle, has_dep, src_ready_idx) in enumerate(self.valu_issue_cycles):
|
||||
if has_dep:
|
||||
# src_ready_idx is the vgpr_ready value at trace time, which was an estimate
|
||||
# We need to find the actual exec time of the instruction that wrote this value
|
||||
# For now, use a simpler model: dependent instructions get +10 from previous exec (first dep) or +9 (chained)
|
||||
exec_cycle = last_exec + (10 if not last_was_dep else 9)
|
||||
last_was_dep = True
|
||||
else:
|
||||
# Independent: at base_exec + position, serialized at +1 intervals
|
||||
exec_cycle = max(base_exec + i, last_exec + 1)
|
||||
last_was_dep = False
|
||||
exec_times.append(exec_cycle)
|
||||
last_exec = exec_cycle
|
||||
|
||||
for exec_cycle in exec_times:
|
||||
self.pending_exec.append((exec_cycle, AluSrc.VALU))
|
||||
self.valu_issue_cycles.clear()
|
||||
|
||||
# Sort and emit all pending ALUEXEC
|
||||
# Emit any remaining pending ALUEXECs
|
||||
self.pending_exec.sort(key=lambda x: x[0])
|
||||
last_src = None
|
||||
for exec_cycle, src in self.pending_exec:
|
||||
@@ -305,11 +405,22 @@ class SQTTState:
|
||||
self.emit_aluexec(src)
|
||||
last_src = src
|
||||
self.pending_exec.clear()
|
||||
# WAVEEND timing: 14 cycles after last instruction if no ALU, 13/15 after last ALUEXEC
|
||||
if last_src is None:
|
||||
|
||||
# WAVEEND timing depends on what comes last
|
||||
# If last instruction was SOPP: +1 normally, but +11 if last ALUEXEC was recent (VALU drain)
|
||||
# Otherwise: 14 cycles for no ALU, 20 for trans, 14/15 for SALU/VALU
|
||||
if self.last_inst_type == 'SOPP':
|
||||
# Check if last ALUEXEC was recent (within ~5 cycles of current)
|
||||
if self.valu_count > 0 and self.last_valu_exec >= self.cycle - 5:
|
||||
self.cycle += 11 # VALU drain time
|
||||
else:
|
||||
self.cycle += 1
|
||||
elif last_src is None:
|
||||
self.cycle += 14 # empty program or no ALU ops
|
||||
elif self.trans_count > 0:
|
||||
self.cycle += 20 # trans has longer WAVEEND delay
|
||||
else:
|
||||
self.cycle += 15 if last_src == AluSrc.VALU else 13
|
||||
self.cycle += 15 if last_src == AluSrc.VALU else 14
|
||||
self.emit_waveend()
|
||||
|
||||
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
|
||||
|
||||
@@ -61,16 +61,13 @@ class TestSQTTCodec(unittest.TestCase):
|
||||
|
||||
from extra.assembly.amd.emu import SQTTState, decode_program
|
||||
from extra.assembly.amd.sqtt import VALUINST, ALUEXEC
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, s_mov_b32, s_add_u32, s_endpgm
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, v_rcp_f32_e32, v_sqrt_f32_e32, v_exp_f32_e32, s_mov_b32, s_add_u32
|
||||
from extra.assembly.amd.dsl import v, s
|
||||
from extra.assembly.amd.test.test_sqtt_hw import compile_asm_sqtt, run_prg_sqtt, run_prg_sqtt_batch, get_wave_packets, format_packet
|
||||
|
||||
def assemble(instructions: list) -> bytes:
|
||||
return b''.join(inst.to_bytes() for inst in instructions)
|
||||
from extra.assembly.amd.test.test_sqtt_hw import compile_asm_sqtt, run_prg_sqtt, run_prg_sqtt_batch, get_wave_packets, format_packet, assemble, wrap_with_nops
|
||||
|
||||
def run_emulator_sqtt(instructions: list) -> list[PacketType]:
|
||||
"""Run instructions through emulator and return SQTT packets."""
|
||||
code = assemble(instructions + [s_endpgm()])
|
||||
code = assemble(wrap_with_nops(instructions))
|
||||
program = decode_program(code)
|
||||
|
||||
sqtt = SQTTState(wave_id=0, simd=0, cu=0)
|
||||
@@ -103,7 +100,7 @@ def get_timing_deltas(packets: list) -> list[tuple[str, int]]:
|
||||
class TestEmulatorSQTT(unittest.TestCase):
|
||||
"""Tests comparing emulator SQTT to hardware SQTT."""
|
||||
|
||||
def _run_and_compare(self, instructions: list, name: str = "", n_runs: int = 100, min_identical: int = 30, max_attempts: int = 5):
|
||||
def _run_and_compare(self, instructions: list, name: str = "", n_runs: int = 100, min_identical: int = 25, max_attempts: int = 10):
|
||||
"""Run instructions on both hardware and emulator, compare SQTT structure."""
|
||||
from collections import Counter
|
||||
|
||||
@@ -187,14 +184,12 @@ class TestEmulatorSQTT(unittest.TestCase):
|
||||
t = p._time - emu_t0 if hasattr(p, '_time') else 0
|
||||
print(f" {t:8d}: {format_packet(p)}")
|
||||
|
||||
# Compare packet structure (types only, ignoring timing jitter)
|
||||
# Extract just packet types from emulator
|
||||
emu_types = tuple(t for t, _ in emu_deltas)
|
||||
# Find HW patterns with matching structure
|
||||
matching_structures = [p for p in pattern_counts if tuple(t for t, _ in p) == emu_types]
|
||||
self.assertGreater(len(matching_structures), 0,
|
||||
f"{name}: emulator packet structure {emu_types} not found in any HW traces.\n"
|
||||
f"HW structures: {set(tuple(t for t, _ in p) for p in pattern_counts)}")
|
||||
# Assert emulator pattern matches most common HW pattern exactly
|
||||
emu_pattern = tuple(emu_deltas)
|
||||
self.assertIn(emu_pattern, pattern_counts,
|
||||
f"{name}: emulator pattern not found in HW traces.\n"
|
||||
f"Emulator: {emu_deltas}\n"
|
||||
f"HW patterns: {[list(p) for p in pattern_counts.most_common(3)]}")
|
||||
|
||||
def test_salu_independent(self):
|
||||
"""SALU instructions with no dependencies."""
|
||||
@@ -216,13 +211,21 @@ class TestEmulatorSQTT(unittest.TestCase):
|
||||
"""Empty program - just s_endpgm."""
|
||||
self._run_and_compare([], "empty")
|
||||
|
||||
def test_valu_independent(self):
|
||||
def _test_valu_independent_n(self, n: int):
|
||||
"""VALU instructions with no dependencies."""
|
||||
self._run_and_compare([
|
||||
v_mov_b32_e32(v[0], 1.0),
|
||||
v_mov_b32_e32(v[1], 2.0),
|
||||
v_mov_b32_e32(v[2], 3.0),
|
||||
], "3 VALU independent")
|
||||
v_mov_b32_e32(v[i], float(i)) for i in range(n)
|
||||
], f"{n} VALU independent")
|
||||
|
||||
def test_valu_independent_1(self): self._test_valu_independent_n(1)
|
||||
def test_valu_independent_2(self): self._test_valu_independent_n(2)
|
||||
def test_valu_independent_3(self): self._test_valu_independent_n(3)
|
||||
def test_valu_independent_4(self): self._test_valu_independent_n(4)
|
||||
def test_valu_independent_5(self): self._test_valu_independent_n(5)
|
||||
def test_valu_independent_6(self): self._test_valu_independent_n(6)
|
||||
def test_valu_independent_7(self): self._test_valu_independent_n(7)
|
||||
def test_valu_independent_8(self): self._test_valu_independent_n(8)
|
||||
def test_valu_independent_16(self): self._test_valu_independent_n(16)
|
||||
|
||||
def test_valu_chain(self):
|
||||
"""VALU instructions with chain dependencies."""
|
||||
@@ -232,6 +235,22 @@ class TestEmulatorSQTT(unittest.TestCase):
|
||||
v_add_f32_e32(v[2], v[1], v[1]),
|
||||
], "3 VALU chain")
|
||||
|
||||
def test_trans_independent(self):
  """Three transcendental ops where each reads and writes only its own VGPR."""
  program = [
    v_rcp_f32_e32(v[0], v[0]),
    v_sqrt_f32_e32(v[1], v[1]),
    v_exp_f32_e32(v[2], v[2]),
  ]
  self._run_and_compare(program, "3 TRANS independent")
|
||||
|
||||
def test_trans_chain(self):
  """v_mov -> v_rcp -> v_sqrt, where each op consumes the previous op's destination VGPR."""
  chain = [
    v_mov_b32_e32(v[0], 1.0),
    v_rcp_f32_e32(v[1], v[0]),
    v_sqrt_f32_e32(v[2], v[1]),
  ]
  self._run_and_compare(chain, "3 TRANS chain")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -132,6 +132,10 @@ def print_all_packets(packets: list) -> None:
|
||||
def assemble(instructions: list) -> bytes:
  """Concatenate the encoded bytes of every instruction, preserving order."""
  encoded = [inst.to_bytes() for inst in instructions]
  return b''.join(encoded)
|
||||
|
||||
def wrap_with_nops(instructions: list, n_nops: int = 64) -> list:
  """Add trailing NOPs and s_endpgm for clean SQTT timing.

  Args:
    instructions: the instruction sequence to wrap; it is not modified.
    n_nops: how many `s_nop(0)` instructions to place before `s_endpgm`.
      Defaults to 64, the padding previously hard-coded here.

  Returns:
    A new list: the given instructions, then `n_nops` NOPs, then `s_endpgm()`.
  """
  # a single s_nop instance is repeated, exactly as the original `[s_nop(0)]*64` did
  padding = [s_nop(0)] * n_nops
  return [*instructions, *padding, s_endpgm()]
|
||||
|
||||
def compile_asm_sqtt(instructions: list, alu_only: bool = False) -> AMDProgram:
|
||||
"""Compile instructions to an AMDProgram for SQTT tracing.
|
||||
|
||||
@@ -142,8 +146,8 @@ def compile_asm_sqtt(instructions: list, alu_only: bool = False) -> AMDProgram:
|
||||
Compiled AMDProgram ready to run
|
||||
"""
|
||||
compiler = HIPCompiler(dev.arch)
|
||||
instructions = instructions + [s_endpgm()]
|
||||
code = assemble(instructions)
|
||||
# Add NOPs before s_endpgm to flush pipeline and get clean timing
|
||||
code = assemble(wrap_with_nops(instructions))
|
||||
byte_str = ', '.join(f'0x{b:02x}' for b in code)
|
||||
|
||||
if alu_only:
|
||||
@@ -157,7 +161,7 @@ test:
|
||||
.rodata
|
||||
.p2align 6
|
||||
.amdhsa_kernel test
|
||||
.amdhsa_next_free_vgpr 8
|
||||
.amdhsa_next_free_vgpr 64
|
||||
.amdhsa_next_free_sgpr 8
|
||||
.amdhsa_wavefront_size32 1
|
||||
.amdhsa_group_segment_fixed_size 0
|
||||
@@ -178,7 +182,7 @@ amdhsa.kernels:
|
||||
.kernarg_segment_align: 8
|
||||
.wavefront_size: 32
|
||||
.sgpr_count: 8
|
||||
.vgpr_count: 8
|
||||
.vgpr_count: 64
|
||||
.max_flat_workgroup_size: 1024
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
|
||||
Reference in New Issue
Block a user