mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
test fixup prereqs for delete_buffer_view [pr] (#8523)
This commit is contained in:
@@ -48,7 +48,7 @@ for si in schedule: print(str(si)[:80])
|
||||
# 4. Lower a schedule.
|
||||
|
||||
from tinygrad.engine.realize import lower_schedule_item, ExecItem
|
||||
lowered: List[ExecItem] = [ExecItem(lower_schedule_item(si).prg, list(si.bufs)) for si in tqdm(schedule)]
|
||||
lowered: List[ExecItem] = [lower_schedule_item(si) for si in tqdm(schedule)]
|
||||
|
||||
# *****
|
||||
# 5. Run the schedule
|
||||
|
||||
@@ -17,7 +17,7 @@ from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_re
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
|
||||
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
||||
class KernelCountException(Exception): pass
|
||||
@@ -29,7 +29,9 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
|
||||
else:
|
||||
assert isinstance(t, UOp), f"can't schedule {t}"
|
||||
sched, _, __ = create_schedule_with_vars([t])
|
||||
if filter_sink: sched = [s for s in sched if s.ast.op is Ops.SINK]
|
||||
# test lowering all the ScheduleItems to ExecItems
|
||||
lowered = list(lower_schedule(sched.copy()))
|
||||
if filter_sink: sched = [s for s,ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)]
|
||||
if len(sched) != allowed:
|
||||
print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if DEBUG >= 3:
|
||||
@@ -37,10 +39,6 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
|
||||
print("kernel", i+1)
|
||||
print(s.ast)
|
||||
raise KernelCountException(f"{len(sched)=} != {allowed}")
|
||||
# test the (sink) ops linearize
|
||||
for s in sched:
|
||||
if s.ast.op is not Ops.SINK: continue
|
||||
get_runner(s.bufs[0].device, s.ast)
|
||||
return sched
|
||||
|
||||
def _realize_weights(m):
|
||||
@@ -1477,10 +1475,10 @@ class TestIndexing(unittest.TestCase):
|
||||
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
|
||||
lst = [xt] if isinstance(xt, Tensor) else xt
|
||||
s = Tensor.schedule(*lst)
|
||||
kernels = [si for si in s if si.ast.op is Ops.SINK]
|
||||
for si in kernels: verify_ast(si.ast)
|
||||
run_schedule(s.copy())
|
||||
lowered = list(lower_schedule(s.copy()))
|
||||
kernels = [ei for ei in list(lowered) if isinstance(ei.prg, CompiledRunner)]
|
||||
if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
|
||||
for ei in lowered: ei.run(do_update_stats=True)
|
||||
return s
|
||||
|
||||
def test_simple_indexing(self):
|
||||
|
||||
Reference in New Issue
Block a user