mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
good tests
This commit is contained in:
@@ -1,22 +1,17 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tests for SQTT emulator correctness against known hardware patterns.
|
||||
|
||||
These tests define the CORRECT behavior based on hardware traces.
|
||||
When run against hardware (SQTT_HW=1), all tests should pass.
|
||||
When run against the emulator, tests FAIL where the emulator is wrong.
|
||||
NOTE: This file only tests NOP and VALU behavior. For WMMA/DP/trans tests,
|
||||
see test_sqtt_compare.py.
|
||||
|
||||
Run emulator tests (expect failures - shows emulator bugs):
|
||||
PYTHONPATH="." python3 extra/assembly/amd/test/test_sqtt_correct.py
|
||||
|
||||
Run against real hardware (all tests should pass):
|
||||
SQTT_HW=1 PYTHONPATH="." python3 extra/assembly/amd/test/test_sqtt_correct.py
|
||||
Run emulator tests: PYTHONPATH="." python3 extra/assembly/amd/test/test_sqtt_correct.py
|
||||
Run hardware tests: SQTT_HW=1 PYTHONPATH="." python3 extra/assembly/amd/test/test_sqtt_correct.py
|
||||
"""
|
||||
import os
|
||||
import unittest
|
||||
|
||||
USE_HW = os.environ.get("SQTT_HW", "0") == "1"
|
||||
|
||||
# Must set SQTT env vars before importing tinygrad device
|
||||
if USE_HW:
|
||||
os.environ["SQTT"] = "1"
|
||||
os.environ["PROFILE"] = "1"
|
||||
@@ -24,10 +19,9 @@ if USE_HW:
|
||||
os.environ["SQTT_TOKEN_EXCLUDE"] = "3784"
|
||||
|
||||
from extra.assembly.amd.emu import SQTTState, decode_program, exec_wave, WaveState, LDSMem
|
||||
from extra.assembly.amd.sqtt import WAVESTART, WAVEEND, IMMEDIATE, VALUINST, ALUEXEC, PacketType
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (v_mov_b32_e32, v_add_f32_e32, v_rcp_f32_e32, v_sqrt_f32_e32,
|
||||
v_exp_f32_e32, s_mov_b32, s_nop, s_endpgm, s_delay_alu, v_mul_f32_e32)
|
||||
from extra.assembly.amd.dsl import v, s
|
||||
from extra.assembly.amd.sqtt import WAVESTART, WAVEEND
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, s_nop, s_endpgm
|
||||
from extra.assembly.amd.dsl import v
|
||||
|
||||
def assemble(instructions: list) -> bytes:
  """Concatenate the encoded form of every instruction into one byte string.

  Each element of ``instructions`` must expose a ``to_bytes()`` method; the
  results are joined in list order. An empty list yields ``b''``.
  """
  return b''.join(inst.to_bytes() for inst in instructions)
|
||||
@@ -36,7 +30,6 @@ def wrap_with_nops(instructions: list, nops=16) -> list:
|
||||
return instructions + [s_nop(0)]*nops + [s_endpgm()]
|
||||
|
||||
def get_wave_packets(packets: list) -> list:
|
||||
"""Extract packets from first WAVESTART to WAVEEND on simd 0."""
|
||||
result, in_wave = [], False
|
||||
for p in packets:
|
||||
if isinstance(p, WAVESTART) and p.simd == 0:
|
||||
@@ -47,20 +40,15 @@ def get_wave_packets(packets: list) -> list:
|
||||
return result
|
||||
|
||||
def get_timing_deltas(packets: list) -> list[tuple[str, int]]:
  """Return ``(packet_type_name, delta_from_previous)`` for non-timing packets.

  Packets whose class name is one of the timing/bookkeeping types below are
  dropped first. The first remaining packet gets delta 0; every later one gets
  the difference between its ``_time`` and the previous remaining packet's
  ``_time``. If the final entry is a WAVEEND its delta is forced to 0.
  Returns ``[]`` when nothing remains after filtering.
  """
  skip_types = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3", "REG"}
  kept = [p for p in packets if type(p).__name__ not in skip_types]
  if not kept: return []
  result = [(type(kept[0]).__name__, 0)]
  # pair each packet with its predecessor instead of indexing
  result += [(type(cur).__name__, cur._time - prev._time) for prev, cur in zip(kept, kept[1:])]
  if result[-1][0] == 'WAVEEND':
    result[-1] = ('WAVEEND', 0)  # normalize WAVEEND timing
  return result
|
||||
|
||||
def run_emulator(instructions: list) -> list:
|
||||
"""Run instructions through emulator and return wave packets."""
|
||||
instructions = wrap_with_nops(instructions)
|
||||
code = assemble(instructions)
|
||||
program = decode_program(code)
|
||||
st = WaveState()
|
||||
@@ -70,255 +58,237 @@ def run_emulator(instructions: list) -> list:
|
||||
exec_wave(program, st, lds, 32, trace)
|
||||
return get_wave_packets(trace.packets)
|
||||
|
||||
def get_all_waves(packets: list) -> list[list]:
  """Extract all WAVESTART..WAVEEND ranges on simd 0.

  A range opens at a WAVESTART whose ``simd`` is 0 and closes at the next
  WAVEEND packet; both end packets are included. A new simd-0 WAVESTART seen
  before a WAVEEND discards the unfinished range and starts over. A range
  still open at the end of ``packets`` is not returned.
  """
  # NOTE(review): the closing WAVEEND is not checked for simd -- confirm this is intended.
  waves = []
  current = None  # packets of the wave being collected, or None when outside a wave
  for p in packets:
    if isinstance(p, WAVESTART) and p.simd == 0:
      current = [p]
    elif current is not None:
      current.append(p)
      if isinstance(p, WAVEEND):
        waves.append(current)
        current = None
  return waves
|
||||
|
||||
def run_hardware(instructions: list) -> list:
|
||||
"""Run on real hardware, return most common wave packet sequence."""
|
||||
from extra.assembly.amd.test.test_sqtt_hw import compile_asm_sqtt, run_prg_sqtt_batch
|
||||
from extra.assembly.amd.sqtt import decode
|
||||
from collections import Counter
|
||||
|
||||
instructions = wrap_with_nops(instructions)
|
||||
prg = compile_asm_sqtt(instructions, alu_only=True)
|
||||
|
||||
for _ in range(10):
|
||||
blobs = run_prg_sqtt_batch(prg, n_runs=200)
|
||||
traces = [get_wave_packets(decode(blob)) for blob in blobs]
|
||||
traces = [t for t in traces if t] # filter empty
|
||||
if traces:
|
||||
# Return most common pattern
|
||||
delta_sets = [tuple(get_timing_deltas(t)) for t in traces]
|
||||
most_common = Counter(delta_sets).most_common(1)[0][0]
|
||||
# Find a trace matching that pattern
|
||||
for t in traces:
|
||||
if tuple(get_timing_deltas(t)) == most_common:
|
||||
return t
|
||||
# Extract all waves from all blobs
|
||||
traces = []
|
||||
for blob in blobs:
|
||||
traces.extend(get_all_waves(decode(blob)))
|
||||
if not traces:
|
||||
continue
|
||||
# Find most common pattern
|
||||
delta_sets = [tuple(get_timing_deltas(t)) for t in traces]
|
||||
most_common = Counter(delta_sets).most_common(1)[0][0]
|
||||
for t in traces:
|
||||
if tuple(get_timing_deltas(t)) == most_common:
|
||||
return t
|
||||
return []
|
||||
|
||||
def run_sqtt(instructions: list) -> list:
  """Run on HW or emulator based on SQTT_HW env var, return wave packets.

  The instruction list is passed through ``wrap_with_nops`` and then handed to
  ``run_hardware`` when ``USE_HW`` is true, otherwise to ``run_emulator``.
  """
  # NOTE(review): in this view run_emulator and run_hardware also call
  # wrap_with_nops on their argument, which would pad twice -- confirm that
  # only one layer is meant to wrap.
  runner = run_hardware if USE_HW else run_emulator
  return runner(wrap_with_nops(instructions))
|
||||
|
||||
def run_test(instructions: list) -> list[tuple[str, int]]:
  """Run ``instructions`` via ``run_sqtt`` and return their timing deltas.

  The result is whatever ``get_timing_deltas`` produces for the captured wave
  packets: a list of ``(packet_type_name, delta)`` tuples.
  """
  wave_packets = run_sqtt(instructions)
  return get_timing_deltas(wave_packets)
|
||||
def get_deltas(instructions: list) -> tuple[list[int], list[int]]:
  """Run and return (issue deltas, exec deltas).

  Issue = IMMEDIATE + VALUINST, Exec = ALUEXEC.
  Deltas are between consecutive packets of the same stream, so each returned
  list is one element shorter than the number of packets in that stream (and
  empty when the stream has fewer than two packets).
  """
  deltas = get_timing_deltas(run_sqtt(instructions))
  now = 0  # running timestamp rebuilt from the per-packet deltas
  issue_times, exec_times = [], []
  for ptype, delta in deltas:
    now += delta
    if ptype in ('IMMEDIATE', 'VALUINST'):
      issue_times.append(now)
    elif ptype == 'ALUEXEC':
      exec_times.append(now)
  issue = [cur - prev for prev, cur in zip(issue_times, issue_times[1:])]
  execd = [cur - prev for prev, cur in zip(exec_times, exec_times[1:])]
  return issue, execd
|
||||
|
||||
# Hardware ALUEXEC delta patterns:
#   chain: forwarding (6,5,5...) then stalls when exhausted (9,9,9...)
#   ind:   no dependencies, exec follows issue by 1 cycle
#   snop:  n+4 baseline, but +4 extra for 11 <= n <= 22
# Every CHAIN_*/IND_* table is keyed by instruction count n and holds n-1 deltas.
CHAIN_ISSUE = {
  2: [1],
  3: [1]*2,
  4: [1]*3,
  5: [1]*4,
  6: [1]*5,
  7: [1]*6,
  8: [1]*7,
  12: [1]*11,
  14: [1]*13,
  15: [1]*13 + [3],  # issue stalls start here
  16: [1]*13 + [3, 5],
  18: [1]*13 + [3] + [5]*3,
  20: [1]*13 + [3] + [5]*5,
}
CHAIN_EXEC = {
  2: [6],
  3: [6, 5],
  4: [6, 5, 5],
  5: [6, 5, 5, 9],
  6: [6, 5, 5, 9, 9],
  7: [6, 5, 5, 5, 9, 9],
  8: [6, 5, 5, 5, 9, 9, 9],
  12: [6] + [5]*4 + [9]*6,
  14: [6] + [5]*4 + [9]*8,
  15: [6] + [5]*5 + [9]*8,
  16: [6] + [5]*6 + [9]*8,
  18: [6] + [5]*8 + [9]*8,
  20: [6] + [5]*10 + [9]*8,
}
# n independent instructions: n-1 exec deltas of 1, for n = 2..8
IND_EXEC = {n: [1]*(n-1) for n in range(2, 9)}
# s_nop(n) for n = 0..63: n+4, with +4 extra in the 11..22 window
SNOP_EXEC = {n: n + 4 + (4 if 11 <= n <= 22 else 0) for n in range(64)}
|
||||
|
||||
class TestVALUChains(unittest.TestCase):
  """VALU dependency chains."""
  def _chain(self, n):
    """Check issue/exec deltas of an n-long chain against CHAIN_ISSUE/CHAIN_EXEC.

    The chain is one v_mov into v[0] followed by n-1 v_adds, each reading the
    register written by the instruction before it.
    """
    instrs = [v_mov_b32_e32(v[0], 1.0)] + [v_add_f32_e32(v[i], v[i-1], v[i-1]) for i in range(1, n)]
    issue, execd = get_deltas(instrs)
    # get_deltas counts IMMEDIATE packets as issue events too, so only the
    # first n-1 issue deltas are compared (presumably the rest come from the
    # trailing s_nop padding -- confirm).
    self.assertEqual(issue[:n-1], CHAIN_ISSUE[n])
    self.assertEqual(execd, CHAIN_EXEC[n])

  def test_chain_2(self): self._chain(2)
  def test_chain_3(self): self._chain(3)
  def test_chain_4(self): self._chain(4)
  def test_chain_5(self): self._chain(5)
  def test_chain_6(self): self._chain(6)
  def test_chain_7(self): self._chain(7)
  def test_chain_8(self): self._chain(8)
  def test_chain_12(self): self._chain(12)
  def test_chain_14(self): self._chain(14)
  def test_chain_15(self): self._chain(15)  # issue stalls start here
  def test_chain_16(self): self._chain(16)
  def test_chain_18(self): self._chain(18)
  def test_chain_20(self): self._chain(20)
|
||||
|
||||
|
||||
class TestTranscendentals(unittest.TestCase):
|
||||
"""Tests for transcendental instruction (v_rcp, v_sqrt, v_exp, v_log) tracing.
|
||||
class TestVALUIndependent(unittest.TestCase):
|
||||
"""Independent VALU instructions."""
|
||||
def _ind(self, n):
|
||||
instrs = [v_mov_b32_e32(v[i], float(i)) for i in range(n)]
|
||||
issue, execd = get_deltas(instrs)
|
||||
self.assertEqual(issue[:n-1], [1]*(n-1))
|
||||
self.assertEqual(execd, IND_EXEC[n])
|
||||
|
||||
Hardware behavior:
|
||||
- Transcendentals emit INST packets (not VALUINST) - they use a separate unit
|
||||
- Issue interval is 4 cycles (vs 1 cycle for regular VALU)
|
||||
- Latency is higher (~8 cycles vs ~6 for VALU)
|
||||
"""
|
||||
|
||||
def test_trans_emits_inst_not_valuinst(self):
|
||||
"""Transcendentals should emit INST packets, not VALUINST.
|
||||
|
||||
Hardware uses a separate transcendental unit that traces as INST.
|
||||
"""
|
||||
deltas = run_test([v_rcp_f32_e32(v[0], v[0])])
|
||||
types = [d[0] for d in deltas]
|
||||
|
||||
self.assertIn('INST', types, "Transcendental should emit INST")
|
||||
self.assertNotIn('VALUINST', types, "Transcendental should not emit VALUINST")
|
||||
|
||||
def test_trans_4_cycle_issue_interval(self):
|
||||
"""Independent transcendentals issue 4 cycles apart (not 1 like VALU).
|
||||
|
||||
The transcendental unit is narrower and takes 4 cycles per instruction.
|
||||
"""
|
||||
deltas = run_test([
|
||||
v_rcp_f32_e32(v[0], v[0]),
|
||||
v_sqrt_f32_e32(v[1], v[1]),
|
||||
v_exp_f32_e32(v[2], v[2]),
|
||||
])
|
||||
|
||||
# Find INST packets (transcendentals)
|
||||
inst_deltas = [d for d in deltas if d[0] == 'INST']
|
||||
self.assertEqual(len(inst_deltas), 3, "Should have 3 INST packets for 3 transcendentals")
|
||||
|
||||
# Hardware: 4-cycle interval between transcendentals
|
||||
for d in inst_deltas[1:]:
|
||||
self.assertEqual(d[1], 4, "Trans should issue 4 cycles apart")
|
||||
def test_ind_2(self): self._ind(2)
|
||||
def test_ind_3(self): self._ind(3)
|
||||
def test_ind_4(self): self._ind(4)
|
||||
def test_ind_5(self): self._ind(5)
|
||||
def test_ind_6(self): self._ind(6)
|
||||
def test_ind_7(self): self._ind(7)
|
||||
def test_ind_8(self): self._ind(8)
|
||||
|
||||
|
||||
class TestSALUTracing(unittest.TestCase):
|
||||
"""Tests for SALU instruction tracing.
|
||||
class TestChainWithNop(unittest.TestCase):
|
||||
"""Dependency chain with s_nop between instructions."""
|
||||
def _test(self, nop_val, expected_issue, expected_exec):
|
||||
issue, execd = get_deltas([v_mov_b32_e32(v[0], 1.0), s_nop(nop_val), v_add_f32_e32(v[1], v[0], v[0])])
|
||||
self.assertEqual(issue[:2], expected_issue)
|
||||
self.assertEqual(execd, expected_exec)
|
||||
|
||||
Hardware behavior:
|
||||
- SALU instructions (s_mov, s_add, etc.) emit INST packets
|
||||
- They execute on the scalar unit, separate from VALU
|
||||
"""
|
||||
|
||||
def test_salu_emits_inst(self):
|
||||
"""SALU instructions should emit INST packets.
|
||||
|
||||
Hardware traces scalar ALU operations with INST packets.
|
||||
"""
|
||||
deltas = run_test([s_mov_b32(s[4], 1)])
|
||||
types = [d[0] for d in deltas]
|
||||
|
||||
self.assertIn('INST', types, "SALU should emit INST")
|
||||
|
||||
def test_mixed_salu_valu_interleaving(self):
|
||||
"""Mixed SALU/VALU sequence should show interleaved INST and VALUINST.
|
||||
|
||||
Hardware can issue SALU and VALU in parallel on different units.
|
||||
The trace shows both instruction types interleaved.
|
||||
"""
|
||||
deltas = run_test([
|
||||
s_mov_b32(s[4], 1),
|
||||
v_mov_b32_e32(v[0], 1.0),
|
||||
s_mov_b32(s[5], 2),
|
||||
v_mov_b32_e32(v[1], 2.0),
|
||||
])
|
||||
types = [d[0] for d in deltas]
|
||||
|
||||
self.assertIn('INST', types, "Should have INST for SALU")
|
||||
self.assertIn('VALUINST', types, "Should have VALUINST for VALU")
|
||||
inst_count = types.count('INST')
|
||||
valuinst_count = types.count('VALUINST')
|
||||
self.assertEqual(inst_count, 2, "Should have 2 INST packets for 2 SALUs")
|
||||
self.assertEqual(valuinst_count, 2, "Should have 2 VALUINST packets for 2 VALUs")
|
||||
def test_nop0(self): self._test(0, [3, 1], [6])
|
||||
def test_nop1(self): self._test(1, [4, 1], [7])
|
||||
def test_nop2(self): self._test(2, [5, 1], [9])
|
||||
def test_nop3(self): self._test(3, [6, 1], [9])
|
||||
def test_nop4(self): self._test(4, [11, 1], [10])
|
||||
def test_nop5(self): self._test(5, [12, 1], [11])
|
||||
|
||||
|
||||
class TestLongDependencyChains(unittest.TestCase):
|
||||
"""Tests for long VALU dependency chains.
|
||||
class TestIndWithNop(unittest.TestCase):
|
||||
"""Independent instructions with s_nop between."""
|
||||
def _test(self, nop_val, expected_issue, expected_exec):
|
||||
issue, execd = get_deltas([v_mov_b32_e32(v[0], 1.0), s_nop(nop_val), v_mov_b32_e32(v[1], 2.0)])
|
||||
self.assertEqual(issue[:2], expected_issue)
|
||||
self.assertEqual(execd, expected_exec)
|
||||
|
||||
Hardware behavior:
|
||||
- Chains exhaust forwarding network after ~4-5 instructions
|
||||
- Once exhausted, instructions must wait for register writeback
|
||||
- This causes ALUEXEC timing to change at the boundary
|
||||
"""
|
||||
|
||||
def test_5_chain_aluexec_timing(self):
|
||||
"""5-instruction dependency chain has specific ALUEXEC timing pattern.
|
||||
|
||||
Hardware: First 4 ALUEXECs use forwarding, 5th waits for writeback.
|
||||
This shows up as 9-cycle delta before the last ALUEXEC.
|
||||
"""
|
||||
deltas = run_test([
|
||||
v_mov_b32_e32(v[0], 1.0),
|
||||
v_add_f32_e32(v[1], v[0], v[0]),
|
||||
v_add_f32_e32(v[2], v[1], v[1]),
|
||||
v_add_f32_e32(v[3], v[2], v[2]),
|
||||
v_add_f32_e32(v[4], v[3], v[3]),
|
||||
])
|
||||
|
||||
aluexec_deltas = [d for d in deltas if d[0] == 'ALUEXEC']
|
||||
self.assertEqual(len(aluexec_deltas), 5, "Should have 5 ALUEXECs")
|
||||
|
||||
# HW: last ALUEXEC has delta=9 (waiting for writeback after forwarding exhausted)
|
||||
self.assertEqual(aluexec_deltas[-1][1], 9, "Last ALUEXEC should wait 9 cycles")
|
||||
def test_nop0(self): self._test(0, [3, 1], [4])
|
||||
def test_nop1(self): self._test(1, [4, 1], [5])
|
||||
def test_nop3(self): self._test(3, [6, 1], [7])
|
||||
def test_nop4(self): self._test(4, [11, 1], [8])
|
||||
def test_nop5(self): self._test(5, [12, 1], [9])
|
||||
|
||||
|
||||
class TestDelayALU(unittest.TestCase):
|
||||
"""Tests for s_delay_alu instruction effects.
|
||||
class TestChain3NopMid(unittest.TestCase):
|
||||
"""3-instruction chain with s_nop in middle."""
|
||||
def _test(self, nop_val, expected_issue, expected_exec):
|
||||
issue, execd = get_deltas([
|
||||
v_mov_b32_e32(v[0], 1.0), v_add_f32_e32(v[1], v[0], v[0]),
|
||||
s_nop(nop_val), v_add_f32_e32(v[2], v[1], v[1])])
|
||||
self.assertEqual(issue[:3], expected_issue)
|
||||
self.assertEqual(execd, expected_exec)
|
||||
|
||||
Hardware behavior:
|
||||
- s_delay_alu tells hardware about upcoming dependencies
|
||||
- It affects when the next instruction can issue
|
||||
- The encoding specifies which previous instruction to wait for
|
||||
"""
|
||||
|
||||
def test_delay_alu_affects_valuinst_timing(self):
|
||||
"""s_delay_alu(1) delays next VALUINST to wait for VALU_DEP_1.
|
||||
|
||||
Should show ~5-cycle gap between VALUINSTs when delay_alu is used.
|
||||
"""
|
||||
deltas = run_test([
|
||||
v_mov_b32_e32(v[0], 1.0),
|
||||
s_delay_alu(1), # VALU_DEP_1
|
||||
v_add_f32_e32(v[1], v[0], v[0]),
|
||||
])
|
||||
|
||||
valuinst_deltas = [d for d in deltas if d[0] == 'VALUINST']
|
||||
self.assertEqual(len(valuinst_deltas), 2, "Should have 2 VALUINSTs")
|
||||
|
||||
# delay_alu should cause ~5 cycle gap
|
||||
self.assertGreaterEqual(valuinst_deltas[1][1], 4, "delay_alu should cause ~5 cycle gap")
|
||||
|
||||
def test_delay_alu_aluexec_position(self):
|
||||
"""ALUEXEC timing after delay_alu - second ALUEXEC preceded by IMMEDIATE.
|
||||
|
||||
Hardware pattern: ..., IMMEDIATE, ALUEXEC (second one)
|
||||
"""
|
||||
deltas = run_test([
|
||||
v_mov_b32_e32(v[0], 1.0),
|
||||
s_delay_alu(1),
|
||||
v_add_f32_e32(v[1], v[0], v[0]),
|
||||
])
|
||||
|
||||
# Find position of second ALUEXEC
|
||||
aluexec_indices = [i for i, d in enumerate(deltas) if d[0] == 'ALUEXEC']
|
||||
self.assertEqual(len(aluexec_indices), 2, "Should have 2 ALUEXECs")
|
||||
|
||||
# On HW, there's an IMMEDIATE before the second ALUEXEC
|
||||
second_aluexec_idx = aluexec_indices[1]
|
||||
prev_type = deltas[second_aluexec_idx - 1][0]
|
||||
self.assertEqual(prev_type, 'IMMEDIATE', "Should have IMMEDIATE before second ALUEXEC")
|
||||
def test_nop0(self): self._test(0, [1, 3, 1], [6, 5])
|
||||
def test_nop1(self): self._test(1, [1, 4, 1], [6, 5])
|
||||
def test_nop2(self): self._test(2, [1, 5, 1], [6, 5])
|
||||
def test_nop3(self): self._test(3, [1, 10, 1], [6, 5])
|
||||
|
||||
|
||||
class TestVALUAfterTrans(unittest.TestCase):
|
||||
"""Tests for VALU instructions following transcendentals.
|
||||
class TestInd3NopMid(unittest.TestCase):
|
||||
"""3 independent instructions with s_nop in middle."""
|
||||
def _test(self, nop_val, expected_issue, expected_exec):
|
||||
issue, execd = get_deltas([
|
||||
v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[1], 2.0),
|
||||
s_nop(nop_val), v_mov_b32_e32(v[2], 3.0)])
|
||||
self.assertEqual(issue[:3], expected_issue)
|
||||
self.assertEqual(execd, expected_exec)
|
||||
|
||||
Hardware behavior:
|
||||
- Transcendental completes later than regular VALU
|
||||
- Following VALUs can issue while trans is in flight
|
||||
- First packet should be INST (trans), followed by VALUINSTs
|
||||
"""
|
||||
|
||||
def test_valu_after_trans_packet_types(self):
|
||||
"""VALUs after trans: first is INST (trans), rest are VALUINST.
|
||||
|
||||
Trans has ~8 cycle latency, VALU has ~6 cycle.
|
||||
Trans issues first as INST, then VALUs as VALUINST.
|
||||
"""
|
||||
deltas = run_test([
|
||||
v_rcp_f32_e32(v[0], v[0]), # trans
|
||||
v_mov_b32_e32(v[1], 1.0), # valu
|
||||
v_mov_b32_e32(v[2], 2.0), # valu
|
||||
])
|
||||
types = [d[0] for d in deltas]
|
||||
|
||||
# First instruction (after WAVESTART) should be INST (trans)
|
||||
self.assertEqual(types[1], 'INST', "First instruction should be INST (trans)")
|
||||
self.assertEqual(types[2], 'VALUINST', "Second should be VALUINST")
|
||||
self.assertEqual(types[3], 'VALUINST', "Third should be VALUINST")
|
||||
def test_nop0(self): self._test(0, [1, 3, 1], [1, 4])
|
||||
def test_nop1(self): self._test(1, [1, 4, 1], [1, 5])
|
||||
def test_nop2(self): self._test(2, [1, 5, 1], [1, 6])
|
||||
def test_nop3(self): self._test(3, [1, 10, 1], [1, 7])
|
||||
|
||||
|
||||
class TestBasicPatterns(unittest.TestCase):
|
||||
"""Tests for basic patterns that work correctly."""
|
||||
class TestSNopDelay(unittest.TestCase):
|
||||
"""Single s_nop delay between two independent v_movs."""
|
||||
def _test(self, n):
|
||||
_, execd = get_deltas([v_mov_b32_e32(v[0], 1.0), s_nop(n), v_mov_b32_e32(v[1], 2.0)])
|
||||
self.assertEqual(execd, [SNOP_EXEC[n]])
|
||||
|
||||
def test_empty_program(self):
|
||||
"""Empty program (just epilogue) produces WAVESTART, IMMEDIATEs, WAVEEND."""
|
||||
deltas = run_test([])
|
||||
types = [d[0] for d in deltas]
|
||||
self.assertEqual(types[0], 'WAVESTART')
|
||||
self.assertEqual(types[-1], 'WAVEEND')
|
||||
self.assertTrue(all(t == 'IMMEDIATE' for t in types[1:-1]))
|
||||
|
||||
def test_single_valu(self):
|
||||
"""Single VALU should produce VALUINST + ALUEXEC."""
|
||||
deltas = run_test([v_mov_b32_e32(v[0], 1.0)])
|
||||
types = [d[0] for d in deltas]
|
||||
self.assertEqual(types.count('VALUINST'), 1)
|
||||
self.assertEqual(types.count('ALUEXEC'), 1)
|
||||
|
||||
def test_independent_valus(self):
|
||||
"""Independent VALUs issue 1 cycle apart."""
|
||||
deltas = run_test([v_mov_b32_e32(v[i], float(i)) for i in range(4)])
|
||||
valuinst_deltas = [d for d in deltas if d[0] == 'VALUINST']
|
||||
for vd in valuinst_deltas[1:]:
|
||||
self.assertEqual(vd[1], 1)
|
||||
|
||||
def test_snop_timing(self):
|
||||
"""s_nop(7) produces larger delay than s_nop(0)."""
|
||||
snop0 = run_test([s_nop(0)])
|
||||
snop7 = run_test([s_nop(7)])
|
||||
imm0 = next(d for d in snop0 if d[0] == 'IMMEDIATE')
|
||||
imm7 = next(d for d in snop7 if d[0] == 'IMMEDIATE')
|
||||
self.assertGreater(imm7[1], imm0[1])
|
||||
def test_snop_0(self): self._test(0)
|
||||
def test_snop_1(self): self._test(1)
|
||||
def test_snop_2(self): self._test(2)
|
||||
def test_snop_3(self): self._test(3)
|
||||
def test_snop_4(self): self._test(4)
|
||||
def test_snop_5(self): self._test(5)
|
||||
def test_snop_6(self): self._test(6)
|
||||
def test_snop_7(self): self._test(7)
|
||||
def test_snop_10(self): self._test(10)
|
||||
def test_snop_11(self): self._test(11) # +4 extra starts here
|
||||
def test_snop_15(self): self._test(15)
|
||||
def test_snop_22(self): self._test(22) # +4 extra ends here
|
||||
def test_snop_23(self): self._test(23)
|
||||
def test_snop_31(self): self._test(31)
|
||||
def test_snop_32(self): self._test(32)
|
||||
def test_snop_63(self): self._test(63)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user