diff --git a/docs/abstractions3.py b/docs/abstractions3.py index e5b995b6f2..b69905f490 100644 --- a/docs/abstractions3.py +++ b/docs/abstractions3.py @@ -48,7 +48,7 @@ for si in schedule: print(str(si)[:80]) # 4. Lower a schedule. from tinygrad.engine.realize import lower_schedule_item, ExecItem -lowered: List[ExecItem] = [ExecItem(lower_schedule_item(si).prg, list(si.bufs)) for si in tqdm(schedule)] +lowered: List[ExecItem] = [lower_schedule_item(si) for si in tqdm(schedule)] # ***** # 5. Run the schedule diff --git a/test/test_schedule.py b/test/test_schedule.py index 800c0aac3d..1368a9cdd5 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -17,7 +17,7 @@ from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_re from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context from tinygrad.codegen.kernel import Kernel, verify_ast from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops -from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule +from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule from extra.models.llama import precompute_freqs_cis class KernelCountException(Exception): pass @@ -29,7 +29,9 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz else: assert isinstance(t, UOp), f"can't schedule {t}" sched, _, __ = create_schedule_with_vars([t]) - if filter_sink: sched = [s for s in sched if s.ast.op is Ops.SINK] + # test lowering all the ScheduleItems to ExecItems + lowered = list(lower_schedule(sched.copy())) + if filter_sink: sched = [s for s,ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)] if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}") if DEBUG >= 3: @@ -37,10 +39,6 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz print("kernel", i+1) print(s.ast) raise KernelCountException(f"{len(sched)=} != {allowed}") - # test the (sink) ops linearize - for s in sched: - if s.ast.op is not Ops.SINK: continue - get_runner(s.bufs[0].device, s.ast) return sched def _realize_weights(m): @@ -1477,10 +1475,10 @@ class TestIndexing(unittest.TestCase): with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)): lst = [xt] if isinstance(xt, Tensor) else xt s = Tensor.schedule(*lst) - kernels = [si for si in s if si.ast.op is Ops.SINK] - for si in kernels: verify_ast(si.ast) - run_schedule(s.copy()) + lowered = list(lower_schedule(s.copy())) + kernels = [ei for ei in list(lowered) if isinstance(ei.prg, CompiledRunner)] if FUSE_ARANGE: self.assertEqual(len(kernels), cnt) + for ei in lowered: ei.run(do_update_stats=True) return s def test_simple_indexing(self):