mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 07:35:16 -05:00
create_schedule([x.lazydata]) -> x.schedule() in tests (#8449)
This commit is contained in:
@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices, symbolic
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left, remove_movement_ops
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
|
||||
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
||||
@@ -28,7 +28,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
|
||||
elif isinstance(t, List) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
|
||||
else:
|
||||
assert isinstance(t, UOp), f"can't schedule {t}"
|
||||
sched = create_schedule([t])
|
||||
sched, _ = create_schedule_with_vars([t])
|
||||
if filter_sink: sched = [s for s in sched if s.ast.op is Ops.SINK]
|
||||
if len(sched) != allowed:
|
||||
print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
@@ -55,7 +55,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
|
||||
w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True).realize()
|
||||
ret = Tensor.conv2d(img, w).relu().mean().backward()
|
||||
dtypes.default_float = old_default_float
|
||||
with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata])
|
||||
with Context(**kwargs): s = Tensor.schedule(ret, img.grad, w.grad)
|
||||
run_schedule(s.copy())
|
||||
cnt = len([si for si in s if si.ast.op is Ops.SINK])
|
||||
assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
|
||||
@@ -1394,11 +1394,11 @@ class TestSchedule(unittest.TestCase):
|
||||
|
||||
def test_const_schedule(self):
|
||||
constv = Tensor.empty(2, 2).lazydata.const_like(10)
|
||||
self.assertEqual(len(create_schedule([constv])), 0)
|
||||
check_schedule(constv, 0)
|
||||
|
||||
def test_const_schedule_contig(self):
|
||||
constv = Tensor.empty(2, 2).lazydata.const_like(10).contiguous()
|
||||
self.assertEqual(len(create_schedule([constv])), 1)
|
||||
check_schedule(constv, 1)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")
|
||||
def test_image_matmul(self):
|
||||
|
||||
Reference in New Issue
Block a user