fix test failure

This commit is contained in:
George Hotz
2026-01-01 13:21:55 -05:00
parent 8f4de73141
commit 729bb04d8c

View File

@@ -990,6 +990,152 @@ class TestLaneInstructions(unittest.TestCase):
for lane in range(4):
self.assertEqual(st.vgpr[lane][1], 10, f"Sum 1+2+3+4 should be 10")
def test_v_writelane_b32_different_vgpr(self):
    """V_WRITELANE_B32 writes to a non-zero VGPR index.

    Regression test for a bug where vdst_idx was always 0 due to a function
    signature mismatch (the _vars parameter shifted all arguments). This caused
    every WRITELANE operation to write to v[0] regardless of the actual
    destination register.
    """
    # Single source of truth for the program, the checks and the messages.
    value = 0x12345678
    target_lane = 1
    n_lanes = 4
    instructions = [
        v_mov_b32_e32(v[0], 0),  # Initialize v0 = 0
        v_mov_b32_e32(v[5], 0),  # Initialize v5 = 0
        s_mov_b32(s[0], value),  # Value to write
        v_writelane_b32(v[5], s[0], target_lane),  # Write to the target lane's v5 (NOT v0!)
    ]
    st = run_program(instructions, n_lanes=n_lanes)
    # v[0] should remain 0 for all lanes (bug would have written here)
    for lane in range(n_lanes):
        self.assertEqual(st.vgpr[lane][0], 0, f"v[0] lane {lane} should be 0 (untouched)")
    # v[5] should have the value only in the target lane
    for lane in range(n_lanes):
        if lane == target_lane:
            expected = value
            msg = f"v[5] lane {lane} should have 0x{value:08X}"
        else:
            expected = 0
            msg = f"v[5] lane {lane} should be 0"
        self.assertEqual(st.vgpr[lane][5], expected, msg)
def test_v_writelane_b32_high_vgpr_index(self):
    """V_WRITELANE_B32 targeting a high-numbered destination register (v[15]).

    Checks that the destination index is honoured for larger register
    numbers: after the write, only lane 0 of v[15] holds the value and every
    lane of v[0] is still 0.
    """
    program = [
        v_mov_b32_e32(v[0], 0),  # v0 starts at 0
        v_mov_b32_e32(v[15], 0),  # v15 starts at 0
        s_mov_b32(s[0], 0xCAFEBABE),  # scalar payload
        v_writelane_b32(v[15], s[0], 0),  # store into lane 0 of v15
    ]
    result = run_program(program, n_lanes=4)

    # Each entry is (lane, register, expected value, failure message),
    # listed in the order the checks must run.
    checks = [(lane, 0, 0, f"v[0] lane {lane} should be 0") for lane in range(4)]
    checks.append((0, 15, 0xCAFEBABE, "v[15] lane 0 should have 0xCAFEBABE"))
    checks.extend((lane, 15, 0, f"v[15] lane {lane} should be 0") for lane in range(1, 4))

    for lane, reg, expected, message in checks:
        self.assertEqual(result.vgpr[lane][reg], expected, message)
def test_v_writelane_b32_multiple_writes_different_vgprs(self):
    """Several V_WRITELANE_B32 writes, each aimed at a different VGPR and lane.

    Mirrors the pattern used in sparse_categorical_crossentropy, where values
    are placed into different VGPR indices via writelane and read back later.
    """
    # (destination register, destination lane, value written)
    writes = ((3, 0, 100), (7, 1, 200), (10, 2, 300))

    # Zero v0 and every destination register first.
    program = [v_mov_b32_e32(v[reg], 0) for reg in (0, 3, 7, 10)]
    # Then, per write: load the value into s0 and store it into the chosen lane.
    for reg, dst_lane, value in writes:
        program.append(s_mov_b32(s[0], value))
        program.append(v_writelane_b32(v[reg], s[0], dst_lane))

    st = run_program(program, n_lanes=4)

    # v[0] is never a destination, so it must still be 0 in every lane.
    for lane in range(4):
        self.assertEqual(st.vgpr[lane][0], 0, f"v[0] lane {lane} should be 0")

    # Each destination holds its value in exactly one lane and 0 elsewhere.
    for reg, dst_lane, value in writes:
        self.assertEqual(st.vgpr[dst_lane][reg], value, f"v[{reg}] lane {dst_lane} should be {value}")
        for lane in range(4):
            if lane != dst_lane:
                self.assertEqual(st.vgpr[lane][reg], 0, f"v[{reg}] lane {lane} should be 0")
def test_v_writelane_then_readlane_different_vgpr(self):
    """Round-trip a value through V_WRITELANE / V_READLANE on a non-zero VGPR.

    Regression test: the original bug made writelane always target v[0], so
    reading back from the intended VGPR returned 0 instead of the written
    value. This is the exact pattern that failed in
    sparse_categorical_crossentropy.
    """
    program = [
        v_mov_b32_e32(v[0], 0),  # v0 starts at 0
        v_mov_b32_e32(v[8], 0),  # v8 starts at 0
        s_mov_b32(s[0], 0xABCD1234),
        v_writelane_b32(v[8], s[0], 2),  # store into lane 2 of v8
        self._readlane(1, v[8], 2),  # fetch lane 2 of v8 into s1
        v_mov_b32_e32(v[1], s[1]),  # copy s1 into v1 on every lane
    ]
    final = run_program(program, n_lanes=4)

    # Every lane of v[1] must carry the value that went through the round trip.
    for lane in range(4):
        observed = final.vgpr[lane][1]
        self.assertEqual(
            observed,
            0xABCD1234,
            f"Lane {lane}: readlane should return 0xABCD1234, got 0x{observed:08x}",
        )

    # v[0] must be unchanged (the bug would have redirected the write here).
    for lane in range(4):
        self.assertEqual(final.vgpr[lane][0], 0, f"v[0] lane {lane} should be 0 (untouched)")
def test_v_writelane_b32_accumulate_pattern(self):
    """Fill one VGPR lane by lane with V_WRITELANE_B32, then read back and sum.

    This pattern is used in reductions where each lane's result is written to
    a different lane of the same VGPR and the results are read back afterwards.
    """
    per_lane = (10, 20, 30, 40)  # value stored into lanes 0..3 of v6

    program = [v_mov_b32_e32(v[6], 0)]  # accumulator v6 starts at 0
    # One s_mov + writelane pair per lane.
    for lane, value in enumerate(per_lane):
        program.append(s_mov_b32(s[0], value))
        program.append(v_writelane_b32(v[6], s[0], lane))
    # Read lane 0 into s0, then fold lanes 1..3 into s0 through s1.
    program.append(self._readlane(0, v[6], 0))  # s0 = 10
    for lane in (1, 2, 3):
        program.append(self._readlane(1, v[6], lane))  # s1 = value held by this lane
        program.append(s_add_u32(s[0], s[0], s[1]))  # s0 += s1 (30, 60, 100)
    program.append(v_mov_b32_e32(v[7], s[0]))  # copy the sum into v7 on every lane

    st = run_program(program, n_lanes=4)

    # Each lane of v[6] holds the value written to it.
    for lane, value in enumerate(per_lane):
        self.assertEqual(st.vgpr[lane][6], value, f"v[6] lane {lane} should be {value}")

    # Every lane of v[7] holds the total.
    for lane in range(4):
        self.assertEqual(st.vgpr[lane][7], 100, f"Sum should be 100, got {st.vgpr[lane][7]}")
class TestTrigonometry(unittest.TestCase):
"""Tests for trigonometric instructions."""