test fixup prereqs for delete_buffer_view [pr] (#8523)

This commit is contained in:
qazal
2025-01-07 11:52:18 +02:00
committed by GitHub
parent 85a4397f27
commit 0e97f807e0
2 changed files with 8 additions and 10 deletions

View File

@@ -48,7 +48,7 @@ for si in schedule: print(str(si)[:80])
# 4. Lower a schedule.
from tinygrad.engine.realize import lower_schedule_item, ExecItem
lowered: List[ExecItem] = [ExecItem(lower_schedule_item(si).prg, list(si.bufs)) for si in tqdm(schedule)]
lowered: List[ExecItem] = [lower_schedule_item(si) for si in tqdm(schedule)]
# *****
# 5. Run the schedule

View File

@@ -17,7 +17,7 @@ from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_re
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis
class KernelCountException(Exception): pass
@@ -29,7 +29,9 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
else:
assert isinstance(t, UOp), f"can't schedule {t}"
sched, _, __ = create_schedule_with_vars([t])
if filter_sink: sched = [s for s in sched if s.ast.op is Ops.SINK]
# test lowering all the ScheduleItems to ExecItems
lowered = list(lower_schedule(sched.copy()))
if filter_sink: sched = [s for s,ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)]
if len(sched) != allowed:
print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if DEBUG >= 3:
@@ -37,10 +39,6 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
print("kernel", i+1)
print(s.ast)
raise KernelCountException(f"{len(sched)=} != {allowed}")
# test the (sink) ops linearize
for s in sched:
if s.ast.op is not Ops.SINK: continue
get_runner(s.bufs[0].device, s.ast)
return sched
def _realize_weights(m):
@@ -1477,10 +1475,10 @@ class TestIndexing(unittest.TestCase):
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
lst = [xt] if isinstance(xt, Tensor) else xt
s = Tensor.schedule(*lst)
kernels = [si for si in s if si.ast.op is Ops.SINK]
for si in kernels: verify_ast(si.ast)
run_schedule(s.copy())
lowered = list(lower_schedule(s.copy()))
kernels = [ei for ei in list(lowered) if isinstance(ei.prg, CompiledRunner)]
if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
for ei in lowered: ei.run(do_update_stats=True)
return s
def test_simple_indexing(self):