scheduler: don't trade complexity for speed (#8370)

* scheduler: don't trade complexity for speed

* don't need is_scheduled

* make those tests real world

* graph_rewrite dedup
Authored by qazal on 2024-12-22 03:30:51 +02:00, committed by GitHub
parent 991b91d4d6
commit 88bc51385c
2 changed files with 21 additions and 32 deletions


@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, view_right, view_left, do_realize, remove_movement_ops
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left, remove_movement_ops
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
from extra.models.llama import precompute_freqs_cis
@@ -247,6 +247,11 @@ class TestSchedule(unittest.TestCase):
    run_schedule(sched)
    self.assertIsNot(a.lazydata.realized, b.lazydata.realized)
  def test_dedup_outputs(self):
    a = Tensor.full((4, 4), 1.).contiguous().realize()
    b = Tensor.full((4, 4), 1.).contiguous().realize()
    check_schedule([a+b, a+b], 1)
  def test_fold_double_unary(self):
    y = Tensor.empty(2)
    out = y.sum(keepdim=True).sqrt().__neg__()
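The added test_dedup_outputs is the user-facing side of the "graph_rewrite dedup" bullet: scheduling two syntactically identical outputs should produce a single kernel. A minimal sketch of the same check outside the test harness, assuming the public Tensor.schedule helper returns the list of ScheduleItems the way check_schedule does:

from tinygrad import Tensor

# realize the inputs first so only the a+b kernel is left to schedule
a = Tensor.full((4, 4), 1.).contiguous().realize()
b = Tensor.full((4, 4), 1.).contiguous().realize()

# build the same expression twice; the scheduler should dedup them into one item
c, d = a + b, a + b
assert len(c.schedule(d)) == 1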
@@ -1972,29 +1977,14 @@ class TestView(unittest.TestCase):
    run_schedule(sched)
    np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
@track_rewrites(named=True)
def big_graph_rewrite(big_graph:UOp, ctx) -> UOp: return graph_rewrite(big_graph, do_realize, ctx)
class TestBigGraph(unittest.TestCase):
  def test_sink_childless_const(self):
    x = UOp.const(dtypes.int, 0)
    big_graph = big_graph_rewrite(x.sink(), ctx:=ScheduleContext())
    self.assertIs(big_graph, UOp(Ops.NOOP))
    self.assertEqual(len(ctx.realizes), 0)
  def test_sink_childless_const_alt(self):
    x = UOp.const(dtypes.int, 0)
    y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
    big_graph = big_graph_rewrite(UOp.sink(x, y), ctx:=ScheduleContext())
    self.assertIs(big_graph, UOp(Ops.NOOP))
    self.assertEqual(len(ctx.realizes), 0)
    x = Tensor(0)
    check_schedule(x, 0)
  def test_sink_childless_const_alt_expanded(self):
    # this is a real STORE of CONST (post expand)
    y = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 1, dtypes.int), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
    out = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 2, dtypes.int), y.reshape((1,)).expand((2,)).contiguous(),), ShapeTracker.from_shape((2,)))
    big_graph = big_graph_rewrite(out.sink(), ctx:=ScheduleContext())
    self.assertIs(big_graph, out.sink())
    self.assertEqual(len(ctx.realizes), 1)
    x = Tensor.zeros(4, 4).contiguous()
    check_schedule(x, 1)
tensor_const_pm = PatternMatcher([
  (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.CONST, src=()))), lambda: True),
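For reference, the reworked "real world" tests above can be reproduced with plain Tensors instead of hand-built UOp graphs; a minimal sketch, again assuming Tensor.schedule behaves as the test helpers do:

from tinygrad import Tensor

# a childless const Tensor has nothing to realize, so nothing is scheduled
assert len(Tensor(0).schedule()) == 0

# forcing a broadcast const to be contiguous materializes a real buffer: one kernel
assert len(Tensor.zeros(4, 4).contiguous().schedule()) == 1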
@@ -2091,7 +2081,7 @@ class TestConst(unittest.TestCase):
  # ** part 3: Tensor variable bindings
  @unittest.expectedFailure # TODO: should schedule assert if you try to realize a Variable?
  #@unittest.expectedFailure # TODO: should schedule assert if you try to realize a Variable?
  def test_var_schedule(self):
    vv = UOp.variable("a", 0, 10).bind(1)
    a = Tensor(vv)
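The last hunk only toggles the expectedFailure marking on test_var_schedule; the test body is truncated here. For orientation, the bound-variable pattern it relies on looks like this (a sketch mirroring the lines shown, not new behavior):

from tinygrad import Tensor
from tinygrad.ops import UOp

# a symbolic variable "a" in [0, 10], bound to the concrete value 1
vv = UOp.variable("a", 0, 10).bind(1)
a = Tensor(vv)  # a scalar Tensor backed by the bound variable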