mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 22:08:08 -05:00
more high level contiguous tests + scheduler deletions [pr] (#8695)
* delete those
* move the upat too
* rename ops_folding to just sym
* keep that
This commit is contained in:
@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
|
||||
from tinygrad.codegen.kernel import verify_ast
|
||||
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, ops_folding
|
||||
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
||||
@@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
|
||||
np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
|
||||
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext())
|
||||
|
||||
class TestSchedule(unittest.TestCase):
|
||||
def test_basic_binop_fusion(self):
|
||||
@@ -1824,7 +1824,7 @@ def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.
|
||||
|
||||
# these pattern matchers should move to engine/schedule.py
|
||||
|
||||
sym = symbolic_simple+PatternMatcher([
|
||||
ops_folding = symbolic_simple+PatternMatcher([
|
||||
(UPat(Ops.DETACH, name="x"), lambda x:x.src[0]),
|
||||
])
|
||||
|
||||
@@ -1842,8 +1842,8 @@ def run_tensor_ast(r:Tensor):
|
||||
output = UOp.new_buffer(r.device, r.lazydata.size, r.dtype)
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, output.dtype.ptr(size=output.size), (), 0)
|
||||
sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink()
|
||||
sink = graph_rewrite(sink, remove_movement_ops+sym+load_buffers+view_left, bufs:=[output])
|
||||
sink = graph_rewrite(sink, remove_movement_ops+sym+view_right)
|
||||
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output])
|
||||
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right)
|
||||
si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ())
|
||||
run_schedule([si])
|
||||
return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist()
|
||||
@@ -2336,34 +2336,29 @@ class TestBufferUOp(unittest.TestCase):
|
||||
|
||||
class TestContiguous(unittest.TestCase):
|
||||
def test_contiguous_buffer(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
b = a.alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
self.assertIs(b, a)
|
||||
a = Tensor.empty(4)
|
||||
b = a.contiguous()
|
||||
check_schedule(b, 0)
|
||||
|
||||
def test_contiguous_buffer_view(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
b = a.reshape((2, 2)).alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
self.assertIs(b, a.buf_uop.view(unwrap(b.st)))
|
||||
a = Tensor.empty(4)
|
||||
b = a.reshape((2, 2)).contiguous()
|
||||
check_schedule(b, 0)
|
||||
|
||||
def test_non_contiguous_buffer_view(self):
|
||||
a = Tensor.empty(4, 1).lazydata
|
||||
b = a.expand((4, 4)).alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
|
||||
a = Tensor.empty(4, 1)
|
||||
b = a.expand((4, 4)).contiguous()
|
||||
check_schedule(b, 1)
|
||||
|
||||
def test_size_change_buffer_view(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
|
||||
a = Tensor.empty(4)
|
||||
b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).contiguous()
|
||||
check_schedule(b, 1)
|
||||
|
||||
def test_double_contiguous_realizes_once(self):
|
||||
a = Tensor.empty(4, 1).lazydata
|
||||
b = a.expand((4, 4)).alu(Ops.CONTIGUOUS).alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
|
||||
a = Tensor.empty(4, 1)
|
||||
b = a.expand((4, 4)).contiguous().contiguous()
|
||||
check_schedule(b, 1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -242,7 +242,7 @@ if CAPTURE_PROCESS_REPLAY:
|
||||
def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
|
||||
def uval(u:UOp) -> UOp:
|
||||
assert is_scheduled(u), f"must be a scheduled op {u}"
|
||||
return r.src[0] if (r:=u.src[1]).op is Ops.CONTIGUOUS and not (r.src[0].base.op is Ops.VIEW and len(r.src[0].base.src) == 2) else r
|
||||
return u.src[1]
|
||||
|
||||
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
|
||||
reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
|
||||
@@ -340,10 +340,6 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
|
||||
|
||||
# **** Schedule creation and BFS toposort
|
||||
|
||||
class UPatScheduled(UPat):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
|
||||
|
||||
# ** this is schedule level const folding
|
||||
|
||||
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
|
||||
@@ -366,8 +362,8 @@ def replace_contiguous(ctx:ScheduleContext, alu:UOp):
|
||||
if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src
|
||||
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
|
||||
|
||||
ops_folding = symbolic_simple+PatternMatcher([
|
||||
# op with size 0 is zero
|
||||
sym = symbolic_simple+PatternMatcher([
|
||||
# UOp with size 0 is zero
|
||||
(UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
|
||||
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
|
||||
# DETACH is a NOOP here
|
||||
@@ -401,6 +397,10 @@ ops_folding = symbolic_simple+PatternMatcher([
|
||||
|
||||
# ** this decides which ops get realized
|
||||
|
||||
class UPatScheduled(UPat):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
|
||||
|
||||
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
|
||||
|
||||
def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
|
||||
@@ -494,7 +494,7 @@ remove_movement_ops = PatternMatcher([
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
|
||||
if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec)
|
||||
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+ops_folding, ctx:=ScheduleContext())
|
||||
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx:=ScheduleContext())
|
||||
rev_tensor_map: dict[UOp, list[UOp]] = {}
|
||||
for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
|
||||
# add BUFFER uops
|
||||
|
||||
Reference in New Issue
Block a user