test schedule with multiple AFTER (#14449)

This commit is contained in:
chenyu
2026-01-30 15:59:00 -05:00
committed by GitHub
parent 486d53d646
commit cfcd1debb5

View File

@@ -247,5 +247,39 @@ class TestCustomKernel(unittest.TestCase):
err = (O_custom - O_ref).square().max()
self.assertLess(err.item(), 1e-6)
def test_multi_after_schedule_order(self):
"""Test correct scheduling order when custom_kernel has multiple outputs.
custom_kernel with 4 arguments creates 4 AFTERs from the same kernel.
The custom_kernel depends on both A2 and B2, so it must be scheduled after both.
E only depends on A2, so E can run before custom_kernel finishes waiting for B2.
Expected schedule order: [A2, B2, E, custom_addmul, final_sum]
The custom_addmul kernel should be at index 3.
"""
from tinygrad.engine.schedule import create_schedule
from tinygrad.schedule.rangeify import get_rangeify_map
A, B = Tensor.empty(4, 4), Tensor.empty(4, 4)
A2 = (A + 1).contiguous() # kernel 0: depends on A
B2 = (B * 2).contiguous() # kernel 1: depends on B
C, D = Tensor.empty(4, 4), Tensor.empty(4, 4)
C, D, _, _ = Tensor.custom_kernel(C, D, A2, B2, fxn=custom_elementwise_addmul_kernel) # depends on A2 AND B2
E = (A2 * 3).contiguous() # kernel 2: depends only on A2
result = (C + D + E).sum() # kernel 3: custom_addmul, then kernel 4: sum
big_sink = result.uop.sink()
tensor_map = get_rangeify_map(big_sink)
sched_sink = big_sink.substitute(tensor_map)
schedule, _ = create_schedule(sched_sink)
# Find the custom_addmul kernel position
custom_idx = next((i for i, item in enumerate(schedule)
if hasattr(item.ast, "arg") and hasattr(item.ast.arg, "name")
and "custom_addmul" in item.ast.arg.name), None)
self.assertIsNotNone(custom_idx, "custom_addmul kernel not found in schedule")
self.assertEqual(custom_idx, 3, f"custom_addmul should be at index 3, got {custom_idx}")
if __name__ == '__main__':
unittest.main()