mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
more high level contiguous tests + scheduler deletions [pr] (#8695)
* delete those * move the upat too * rename ops_folding to just sym * keep that
This commit is contained in:
@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
|
||||
from tinygrad.codegen.kernel import verify_ast
|
||||
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, ops_folding
|
||||
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
||||
@@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
|
||||
np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
|
||||
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext())
|
||||
|
||||
class TestSchedule(unittest.TestCase):
|
||||
def test_basic_binop_fusion(self):
|
||||
@@ -1824,7 +1824,7 @@ def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.
|
||||
|
||||
# these pattern matchers should move to engine/schedule.py
|
||||
|
||||
sym = symbolic_simple+PatternMatcher([
|
||||
ops_folding = symbolic_simple+PatternMatcher([
|
||||
(UPat(Ops.DETACH, name="x"), lambda x:x.src[0]),
|
||||
])
|
||||
|
||||
@@ -1842,8 +1842,8 @@ def run_tensor_ast(r:Tensor):
|
||||
output = UOp.new_buffer(r.device, r.lazydata.size, r.dtype)
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, output.dtype.ptr(size=output.size), (), 0)
|
||||
sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink()
|
||||
sink = graph_rewrite(sink, remove_movement_ops+sym+load_buffers+view_left, bufs:=[output])
|
||||
sink = graph_rewrite(sink, remove_movement_ops+sym+view_right)
|
||||
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output])
|
||||
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right)
|
||||
si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ())
|
||||
run_schedule([si])
|
||||
return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist()
|
||||
@@ -2336,34 +2336,29 @@ class TestBufferUOp(unittest.TestCase):
|
||||
|
||||
class TestContiguous(unittest.TestCase):
|
||||
def test_contiguous_buffer(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
b = a.alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
self.assertIs(b, a)
|
||||
a = Tensor.empty(4)
|
||||
b = a.contiguous()
|
||||
check_schedule(b, 0)
|
||||
|
||||
def test_contiguous_buffer_view(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
b = a.reshape((2, 2)).alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
self.assertIs(b, a.buf_uop.view(unwrap(b.st)))
|
||||
a = Tensor.empty(4)
|
||||
b = a.reshape((2, 2)).contiguous()
|
||||
check_schedule(b, 0)
|
||||
|
||||
def test_non_contiguous_buffer_view(self):
|
||||
a = Tensor.empty(4, 1).lazydata
|
||||
b = a.expand((4, 4)).alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
|
||||
a = Tensor.empty(4, 1)
|
||||
b = a.expand((4, 4)).contiguous()
|
||||
check_schedule(b, 1)
|
||||
|
||||
def test_size_change_buffer_view(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
|
||||
a = Tensor.empty(4)
|
||||
b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).contiguous()
|
||||
check_schedule(b, 1)
|
||||
|
||||
def test_double_contiguous_realizes_once(self):
|
||||
a = Tensor.empty(4, 1).lazydata
|
||||
b = a.expand((4, 4)).alu(Ops.CONTIGUOUS).alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
|
||||
a = Tensor.empty(4, 1)
|
||||
b = a.expand((4, 4)).contiguous().contiguous()
|
||||
check_schedule(b, 1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user