From 08eb1f1f56cc71279deed76abcbba9ed239fafbd Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Mon, 20 Jan 2025 16:42:42 -0500
Subject: [PATCH] simplify tensors before scheduling [pr] (#8580)

* delete forced_realize

* put that back

* work

* remove forced_realize

* expectedFailures

* contiguous(buffer)

* multi

* expectedFailures

* cleaner create_subbuffer

* more comments

* remove that

* note

* realizes

* work

* one upat and image is back

* remove

* cleaner

* fix test_complex_backward for now

---------

Co-authored-by: George Hotz
---
 test/test_image_dtype.py    |  3 +-
 test/test_schedule.py       |  6 +--
 tinygrad/engine/schedule.py | 77 ++++++++++++------------------------
 3 files changed, 30 insertions(+), 56 deletions(-)

diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py
index 88a4c929c4..62fcb4a443 100644
--- a/test/test_image_dtype.py
+++ b/test/test_image_dtype.py
@@ -113,7 +113,8 @@ class TestImageDType(unittest.TestCase):
     assert it.lazydata.base.realized._buf != b1
 
   # issue caused by: don't realize image to image casts. this is part of a larger problem
-  @unittest.expectedFailure
+  #@unittest.expectedFailure
+  # update: passing after tensor_map
   def test_lil_model(self):
     with Context(IMAGE=2):
       x = Tensor.zeros(1, 1)
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 79fc2516ac..b0426630de 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -220,7 +220,7 @@ class TestSchedule(unittest.TestCase):
     GlobalCounters.reset()
     expr = (a*b)/b
     expr.realize()
-    self.assertEqual(GlobalCounters.kernel_count, 1)
+    self.assertEqual(GlobalCounters.kernel_count, 0) # the scheduler can fold divs now!
     self.assertEqual(GlobalCounters.global_ops, 0)
     np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0))
 
@@ -229,7 +229,7 @@ class TestSchedule(unittest.TestCase):
     GlobalCounters.reset()
     expr = a/a
     expr.realize()
-    self.assertEqual(GlobalCounters.kernel_count, 1)
+    self.assertEqual(GlobalCounters.kernel_count, 0)
     self.assertEqual(GlobalCounters.global_ops, 0)
     np.testing.assert_allclose(expr.numpy(), np.full((4,), 1.0))
 
@@ -2204,7 +2204,7 @@ class TestConst(unittest.TestCase):
     sched = add.schedule()
     self.assertEqual(len(sched), 0)
     # b+0 and b share the same underlying device memory
-    self.assertIs(add.lazydata.realized, b.lazydata.realized)
+    self.assertIs(add.lazydata.buffer, b.lazydata.buffer)
     self.assertListEqual(add.tolist(), [2, 2, 2, 2])
 
   def test_src_masked_const_folding(self):
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 9db897446f..d5acb3cbe3 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -2,7 +2,7 @@ import sys, atexit, functools, pickle
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
-from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify
+from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify, graph_rewrite_map
 from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, ContextVar
 from tinygrad.dtype import DType, ImageDType, dtypes
@@ -88,15 +88,15 @@ class ScheduleContext:
 
 # wrap tensor uops around a VIEW(BUFFER, )
 # this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
-def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
+def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
   if (r:=cache.get(buf)) is not None: return r
   # SINK is passthrough
-  if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
+  if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
   # skip creating buffers for CONST/BIND/DEVICE/BUFFER
   if buf.base.is_realized or buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
   # VIEW is passthrough
   if buf is not buf.base:
-    cache[buf] = ret = add_buffers(buf.base, ctx, cache).view(unwrap(buf.st))
+    cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st))
     return ret
   # make things that can't be images not images
   dtype = buf.dtype
@@ -105,9 +105,9 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
     dtype = buf.dtype.base
   # ASSIGN already has a target buffer, otherwise we create a new one
   buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
-  op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
+  op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
   # track the underlying tensor uop for this buffer
-  ctx.tensor_uops[buf_uop] = [buf]
+  ctx.tensor_uops[buf_uop] = tensor_map[buf]
   # (early) bufferize
   cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
   return ret
@@ -358,10 +358,8 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
     case _: return None
   return reduce.const_like(ret)
 
-def found_contiguous(ctx:ScheduleContext, contig:UOp, base:UOp, b:UOp):
-  if contig.src[0].op is Ops.VIEW and len(contig.src[0].src):
-    old_base = contig.src[0].src[0]
-    if old_base.op is Ops.VIEW and (sti:=unwrap(contig.src[0].st).invert(old_base.shape)) is not None: ctx.contiguous[old_base] = base.view(sti)
+def found_contiguous(ctx:ScheduleContext, contig:UOp, src:UOp):
+  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx.contiguous[src.base] = contig.view(sti)
 def replace_contiguous(ctx:ScheduleContext, alu:UOp):
   new_src = list(alu.src)
   for i,s in enumerate(alu.src):
@@ -372,8 +370,6 @@ ops_folding = symbolic_simple+PatternMatcher([
   # op with size 0 is zero
   (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
     and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
-  # if the uop folded to a CONST we can delete the BUFFER
-  (UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)),
   # DETACH is a NOOP here
   (UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]),
   # reduce of size 0 is the identity element
@@ -386,13 +382,16 @@ ops_folding = symbolic_simple+PatternMatcher([
   # no COPY to same device, except clone (arg is True)
   (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
    lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
+  # remove cast to image when it's already a contiguous image
+  (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
+   lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
   # remove contiguous if we can just view the buffer
   (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
    lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
   # double contiguous is one contiguous
   (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.CONTIGUOUS),)), lambda root: root.src[0]),
   # support for using a contiguous permuted view instead of the parent view if one exists
-  (UPatScheduled(Ops.CONTIGUOUS, name="contig"), found_contiguous),
+  (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
   (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
   # remove CONST/BIND/BUFFER/VIEW from SINK
   (UPat(Ops.SINK, name="root"),
@@ -400,34 +399,6 @@ ops_folding = symbolic_simple+PatternMatcher([
     if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
 ])
 
-# ** buffer merging
-
-def merge(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp) -> UOp:
-  assert v1.st is not None and v2.st is not None and v1.st == v2.st, f"implicit movementop {v1.st} {v2.st}"
-  # if b2 is realized also realize b1
-  if b2 in ctx.realizes:
-    ctx.realizes[b1] = b1
-    del ctx.realizes[b2]
-  # ops referring to b2 now ref to b1
-  ctx.tensor_uops[b1] += ctx.tensor_uops[b2]
-  del ctx.tensor_uops[b2]
-  # merge
-  return v1
-
-def merge_realized(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp):
-  # early become
-  for luop in ctx.tensor_uops.get(b1, [])+ctx.tensor_uops.get(b2, []): ctx.becomes_map[luop] = b1.view(unwrap(luop.st))
-  return v1
-
-merge_bufs = PatternMatcher([
-  # merge base
-  (UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"), UPat())))), merge),
-  (UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"),)))), merge_realized),
-  # collapse view
-  (UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat())).view(name="mv"))), lambda mv:mv),
-  (UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).view(name="mv"))), lambda mv:mv),
-])
-
 # ** this decides which ops get realized
 
 def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
@@ -481,7 +452,7 @@ def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
   return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
 
 def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
-  if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[x] = m
+  if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
   if b not in ctx.realizes: return x # collapse BUFFER
   ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
   return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
@@ -523,15 +494,13 @@ remove_movement_ops = PatternMatcher([
 @track_rewrites(named=True)
 def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
   if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec)
-  # if using VIZ, do a graph rewrite to vizualize the Tensor graph
-  if getenv("VIZ"): graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
+  tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+ops_folding, ctx:=ScheduleContext())
+  rev_tensor_map: dict[UOp, list[UOp]] = {}
+  for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
   # add BUFFER uops
-  sink = add_buffers(big_sink, ctx:=ScheduleContext(), cache={})
-  # const folding and fusion
-  sink = graph_rewrite(sink, remove_movement_ops+ops_folding+do_realize, ctx)
-  sink = graph_rewrite(sink, merge_bufs, ctx)
-  # create the scheduler context
-  graph_rewrite(sink, create_ctx, ctx)
+  sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={})
+  # add realizes
+  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
   # group realizes into kernels
   store_groups = group_realizes(ctx)
   graph_rewrite(sink, break_sched, ctx)
@@ -539,13 +508,17 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
   prescheduled: list[ScheduleItem] = []
   for store_uops in store_groups:
     small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops])
-    # TODO: this still exists because symbolic folding is happening after bufferization
-    if not all(x.op is Ops.STORE for x in small_sink.src): continue
+    if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}")
    prescheduled.append(schedule_uop(small_sink, ctx))
     # can only schedule once
     for buf_uop in store_uops:
      for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))
 
+  # tensors can become an existing buffer, no ScheduleItem needed
+  for k,v in tensor_map.items():
+    # NOTE: we only add base tensors to becomes_map
+    if k is not v and v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st))
+
   # add kernel children
   schedule_targets = {out:si for si in prescheduled for out in si.outputs}
   graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
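
The rev_tensor_map built in create_schedule_with_vars is a plain inversion of the graph_rewrite_map result: graph_rewrite_map returns a dict from each pre-rewrite UOp to its simplified replacement, and add_buffers needs the reverse direction because several tensor UOps can collapse into the same simplified node, and all of them must be linked to that node's buffer. A minimal sketch of the inversion, using strings as stand-ins for UOps (the names are illustrative, not tinygrad API):

# stand-in for graph_rewrite_map's {old_uop: rewritten_uop} result
tensor_map = {"t0": "n0", "t1": "n0", "t2": "t2"}

# invert it: every rewritten node maps back to all tensor uops that became it
rev_tensor_map: dict[str, list[str]] = {}
for k, v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)

assert rev_tensor_map == {"n0": ["t0", "t1"], "t2": ["t2"]}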
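
The updated assertions in test_schedule.py reflect that expressions like (a*b)/b and a/a now fold away during tensor simplification, before any kernel is emitted. A usage sketch of that behavior (the fixture for a and b sits outside the hunks shown, so the realized inputs below are an assumption):

import numpy as np
from tinygrad import Tensor
from tinygrad.helpers import GlobalCounters

# assumed fixture: two already-realized 4-element buffers
a = Tensor.full((4,), 4.0).contiguous().realize()
b = Tensor.full((4,), 2.0).contiguous().realize()

GlobalCounters.reset()
expr = (a*b)/b   # simplifies to `a` before scheduling
expr.realize()   # nothing left to schedule, so no kernel is launched
assert GlobalCounters.kernel_count == 0
np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0))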