diff --git a/test/test_custom_kernel.py b/test/test_custom_kernel.py index 7defac9c20..a3a24b958e 100644 --- a/test/test_custom_kernel.py +++ b/test/test_custom_kernel.py @@ -247,5 +247,39 @@ class TestCustomKernel(unittest.TestCase): err = (O_custom - O_ref).square().max() self.assertLess(err.item(), 1e-6) + def test_multi_after_schedule_order(self): + """Test correct scheduling order when custom_kernel has multiple outputs. + + custom_kernel with 4 arguments creates 4 AFTERs from the same kernel. + The custom_kernel depends on both A2 and B2, so it must be scheduled after both. + E only depends on A2, so E can run before custom_kernel finishes waiting for B2. + + Expected schedule order: [A2, B2, E, custom_addmul, final_sum] + The custom_addmul kernel should be at index 3. + """ + from tinygrad.engine.schedule import create_schedule + from tinygrad.schedule.rangeify import get_rangeify_map + + A, B = Tensor.empty(4, 4), Tensor.empty(4, 4) + A2 = (A + 1).contiguous() # kernel 0: depends on A + B2 = (B * 2).contiguous() # kernel 1: depends on B + C, D = Tensor.empty(4, 4), Tensor.empty(4, 4) + C, D, _, _ = Tensor.custom_kernel(C, D, A2, B2, fxn=custom_elementwise_addmul_kernel) # depends on A2 AND B2 + E = (A2 * 3).contiguous() # kernel 2: depends only on A2 + result = (C + D + E).sum() # kernel 3: custom_addmul, then kernel 4: sum + + big_sink = result.uop.sink() + tensor_map = get_rangeify_map(big_sink) + sched_sink = big_sink.substitute(tensor_map) + schedule, _ = create_schedule(sched_sink) + + # Find the custom_addmul kernel position + custom_idx = next((i for i, item in enumerate(schedule) + if hasattr(item.ast, "arg") and hasattr(item.ast.arg, "name") + and "custom_addmul" in item.ast.arg.name), None) + + self.assertIsNotNone(custom_idx, "custom_addmul kernel not found in schedule") + self.assertEqual(custom_idx, 3, f"custom_addmul should be at index 3, got {custom_idx}") + if __name__ == '__main__': unittest.main()