more high level contiguous tests + scheduler deletions [pr] (#8695)

* delete those

* move the upat too

* rename ops_folding to just sym

* keep that
This commit is contained in:
qazal
2025-01-20 18:52:58 -05:00
committed by GitHub
parent 08eb1f1f56
commit 66ac0087e8
2 changed files with 28 additions and 33 deletions

View File

@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
from tinygrad.codegen.kernel import verify_ast
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, ops_folding
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis
@@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
@track_rewrites(named=True)
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext())
class TestSchedule(unittest.TestCase):
def test_basic_binop_fusion(self):
@@ -1824,7 +1824,7 @@ def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.
# these pattern matchers should move to engine/schedule.py
sym = symbolic_simple+PatternMatcher([
ops_folding = symbolic_simple+PatternMatcher([
(UPat(Ops.DETACH, name="x"), lambda x:x.src[0]),
])
@@ -1842,8 +1842,8 @@ def run_tensor_ast(r:Tensor):
output = UOp.new_buffer(r.device, r.lazydata.size, r.dtype)
glbl = UOp(Ops.DEFINE_GLOBAL, output.dtype.ptr(size=output.size), (), 0)
sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink()
sink = graph_rewrite(sink, remove_movement_ops+sym+load_buffers+view_left, bufs:=[output])
sink = graph_rewrite(sink, remove_movement_ops+sym+view_right)
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output])
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right)
si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ())
run_schedule([si])
return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist()
@@ -2336,34 +2336,29 @@ class TestBufferUOp(unittest.TestCase):
class TestContiguous(unittest.TestCase):
def test_contiguous_buffer(self):
a = Tensor.empty(4).lazydata
b = a.alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
self.assertIs(b, a)
a = Tensor.empty(4)
b = a.contiguous()
check_schedule(b, 0)
def test_contiguous_buffer_view(self):
a = Tensor.empty(4).lazydata
b = a.reshape((2, 2)).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
self.assertIs(b, a.buf_uop.view(unwrap(b.st)))
a = Tensor.empty(4)
b = a.reshape((2, 2)).contiguous()
check_schedule(b, 0)
def test_non_contiguous_buffer_view(self):
a = Tensor.empty(4, 1).lazydata
b = a.expand((4, 4)).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
a = Tensor.empty(4, 1)
b = a.expand((4, 4)).contiguous()
check_schedule(b, 1)
def test_size_change_buffer_view(self):
a = Tensor.empty(4).lazydata
b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
a = Tensor.empty(4)
b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).contiguous()
check_schedule(b, 1)
def test_double_contiguous_realizes_once(self):
a = Tensor.empty(4, 1).lazydata
b = a.expand((4, 4)).alu(Ops.CONTIGUOUS).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
a = Tensor.empty(4, 1)
b = a.expand((4, 4)).contiguous().contiguous()
check_schedule(b, 1)
if __name__ == '__main__':
unittest.main(verbosity=2)