mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 15:45:27 -05:00
scheduler: don't trade complexity for speed (#8370)
* scheduler: don't trade complexity for speed * don't need is_scheduled * make those tests real world * graph_rewrite dedup
This commit is contained in:
@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, view_right, view_left, do_realize, remove_movement_ops
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left, remove_movement_ops
|
||||
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
||||
@@ -247,6 +247,11 @@ class TestSchedule(unittest.TestCase):
|
||||
run_schedule(sched)
|
||||
self.assertIsNot(a.lazydata.realized, b.lazydata.realized)
|
||||
|
||||
def test_dedup_outputs(self):
|
||||
a = Tensor.full((4, 4), 1.).contiguous().realize()
|
||||
b = Tensor.full((4, 4), 1.).contiguous().realize()
|
||||
check_schedule([a+b, a+b], 1)
|
||||
|
||||
def test_fold_double_unary(self):
|
||||
y = Tensor.empty(2)
|
||||
out = y.sum(keepdim=True).sqrt().__neg__()
|
||||
@@ -1972,29 +1977,14 @@ class TestView(unittest.TestCase):
|
||||
run_schedule(sched)
|
||||
np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def big_graph_rewrite(big_graph:UOp, ctx) -> UOp: return graph_rewrite(big_graph, do_realize, ctx)
|
||||
class TestBigGraph(unittest.TestCase):
|
||||
def test_sink_childless_const(self):
|
||||
x = UOp.const(dtypes.int, 0)
|
||||
big_graph = big_graph_rewrite(x.sink(), ctx:=ScheduleContext())
|
||||
self.assertIs(big_graph, UOp(Ops.NOOP))
|
||||
self.assertEqual(len(ctx.realizes), 0)
|
||||
|
||||
def test_sink_childless_const_alt(self):
|
||||
x = UOp.const(dtypes.int, 0)
|
||||
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
|
||||
big_graph = big_graph_rewrite(UOp.sink(x, y), ctx:=ScheduleContext())
|
||||
self.assertIs(big_graph, UOp(Ops.NOOP))
|
||||
self.assertEqual(len(ctx.realizes), 0)
|
||||
x = Tensor(0)
|
||||
check_schedule(x, 0)
|
||||
|
||||
def test_sink_childless_const_alt_expanded(self):
|
||||
# this is a real STORE of CONST (post expand)
|
||||
y = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 1, dtypes.int), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
|
||||
out = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 2, dtypes.int), y.reshape((1,)).expand((2,)).contiguous(),), ShapeTracker.from_shape((2,)))
|
||||
big_graph = big_graph_rewrite(out.sink(), ctx:=ScheduleContext())
|
||||
self.assertIs(big_graph, out.sink())
|
||||
self.assertEqual(len(ctx.realizes), 1)
|
||||
x = Tensor.zeros(4, 4).contiguous()
|
||||
check_schedule(x, 1)
|
||||
|
||||
tensor_const_pm = PatternMatcher([
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.CONST, src=()))), lambda: True),
|
||||
@@ -2091,7 +2081,7 @@ class TestConst(unittest.TestCase):
|
||||
|
||||
# ** part 3: Tensor variable bindings
|
||||
|
||||
@unittest.expectedFailure # TODO: should schedule assert if you try to realize a Variable?
|
||||
#@unittest.expectedFailure # TODO: should schedule assert if you try to realize a Variable?
|
||||
def test_var_schedule(self):
|
||||
vv = UOp.variable("a", 0, 10).bind(1)
|
||||
a = Tensor(vv)
|
||||
|
||||
Reference in New Issue
Block a user