mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
test schedule with multiple AFTER (#14449)
This commit is contained in:
@@ -247,5 +247,39 @@ class TestCustomKernel(unittest.TestCase):
|
||||
err = (O_custom - O_ref).square().max()
|
||||
self.assertLess(err.item(), 1e-6)
|
||||
|
||||
def test_multi_after_schedule_order(self):
  """Test correct scheduling order when custom_kernel has multiple outputs.

  Passing 4 arguments to custom_kernel yields 4 AFTERs that all originate from one kernel.
  That kernel reads both A2 and B2, so it has to be placed after the two of them.
  E reads only A2, so unlike the custom kernel it has no reason to wait for B2.

  Schedule order this test expects: [A2, B2, E, custom_addmul, final_sum],
  which puts the custom_addmul kernel at index 3.
  """
  from tinygrad.engine.schedule import create_schedule
  from tinygrad.schedule.rangeify import get_rangeify_map

  # Graph construction; the statements are deliberately kept in this order.
  A, B = Tensor.empty(4, 4), Tensor.empty(4, 4)
  A2 = (A + 1).contiguous()  # expected kernel 0, reads A
  B2 = (B * 2).contiguous()  # expected kernel 1, reads B
  C, D = Tensor.empty(4, 4), Tensor.empty(4, 4)
  # The custom kernel consumes A2 *and* B2; only its first two results are used below.
  C, D, _, _ = Tensor.custom_kernel(C, D, A2, B2, fxn=custom_elementwise_addmul_kernel)
  E = (A2 * 3).contiguous()  # expected kernel 2, reads A2 only
  total = (C + D + E).sum()  # expected kernel 3 is custom_addmul, kernel 4 the final sum

  # Rewrite the sink through the rangeify map, then ask for the linear schedule.
  root_sink = total.uop.sink()
  rewrites = get_rangeify_map(root_sink)
  rewritten_sink = root_sink.substitute(rewrites)
  schedule, _ = create_schedule(rewritten_sink)

  # Locate the first schedule entry whose ast.arg.name mentions custom_addmul.
  custom_idx = None
  for pos, entry in enumerate(schedule):
    if not hasattr(entry.ast, "arg"):
      continue
    if not hasattr(entry.ast.arg, "name"):
      continue
    if "custom_addmul" in entry.ast.arg.name:
      custom_idx = pos
      break

  self.assertIsNotNone(custom_idx, "custom_addmul kernel not found in schedule")
  self.assertEqual(custom_idx, 3, f"custom_addmul should be at index 3, got {custom_idx}")
|
||||
|
||||
# Entry point: hand control to unittest's runner when this file is executed as a script.
if __name__ == '__main__':
  unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user