create_schedule([x.lazydata]) -> x.schedule() in tests (#8449)

This commit is contained in:
qazal
2024-12-30 21:15:52 +02:00
committed by GitHub
parent 0addbad36d
commit 866dfa1f23
23 changed files with 85 additions and 105 deletions

View File

@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices, symbolic
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left, remove_movement_ops
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
from extra.models.llama import precompute_freqs_cis
@@ -28,7 +28,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
elif isinstance(t, List) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
else:
assert isinstance(t, UOp), f"can't schedule {t}"
sched = create_schedule([t])
sched, _ = create_schedule_with_vars([t])
if filter_sink: sched = [s for s in sched if s.ast.op is Ops.SINK]
if len(sched) != allowed:
print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
@@ -55,7 +55,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True).realize()
ret = Tensor.conv2d(img, w).relu().mean().backward()
dtypes.default_float = old_default_float
with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata])
with Context(**kwargs): s = Tensor.schedule(ret, img.grad, w.grad)
run_schedule(s.copy())
cnt = len([si for si in s if si.ast.op is Ops.SINK])
assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
@@ -1394,11 +1394,11 @@ class TestSchedule(unittest.TestCase):
def test_const_schedule(self):
constv = Tensor.empty(2, 2).lazydata.const_like(10)
self.assertEqual(len(create_schedule([constv])), 0)
check_schedule(constv, 0)
def test_const_schedule_contig(self):
constv = Tensor.empty(2, 2).lazydata.const_like(10).contiguous()
self.assertEqual(len(create_schedule([constv])), 1)
check_schedule(constv, 1)
@unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")
def test_image_matmul(self):