From aa2e7b11f871edffbfdc581efd85b8cbcaf6ca2b Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Sun, 1 Dec 2024 10:20:30 -0500
Subject: [PATCH] more const folding infra from the delete_lazy branch [pr]
 (#7976)

* more const folding infra from the delete_lazy branch [pr]

* sink base

* limit
---
 tinygrad/engine/schedule.py | 16 ++++++++++------
 tinygrad/ops.py             |  6 ++++--
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 1106dacbd3..7a279e030e 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -5,7 +5,7 @@ from typing import FrozenSet, Set, Tuple, List, Dict, Optional, DefaultDict
 from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
 from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
-from tinygrad.dtype import ImageDType, dtypes
+from tinygrad.dtype import ConstType, ImageDType, dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape
 from tinygrad.engine.lazy import LazyBuffer
@@ -331,12 +331,16 @@ class UPatScheduled(UPat):
   def __init__(self, *args, **kwargs):
     super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{**kwargs,"name":"to_store"})))
 
-# ** this folds ops that don't need a BUFFER
+# ** this is schedule level const folding
+
+def _as_const(u:UOp, val:ConstType) -> UOp:
+  assert is_scheduled(u), f"must be scheduled to fold {u}"
+  st = (base:=ShapeTracker.from_shape(())).reshape((1,)*len(u.shape)).expand(u.shape)
+  return UOp(Ops.VIEW, u.dtype, (u.buf_uop, UOp.const(u.dtype, val)), base).view(st)
 
 ops_folding = PatternMatcher([
-  # op with size 0 is just zero
-  (UPatScheduled(), lambda ctx,b,to_store,base: UOp(Ops.VIEW, base.dtype, (b, UOp.const(base.dtype, 0)), base.st)
-   if base.st.size == 0 and to_store is not UOp.const(base.dtype, 0) else None),
+  # op with size 0 is zero
+  (UPatScheduled(), lambda ctx,b,to_store,base: _as_const(base, 0) if base.size == 0 else None),
 ])
 
 # ** this decides which ops get realized
@@ -362,7 +366,7 @@ def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kw
   return to_cast.view(unwrap(view.st))
 
 def init_big_graph(ctx:ScheduleContext, sink:UOp) -> Optional[UOp]:
-  new_src = tuple(x for x in sink.src if is_scheduled(x) and uval(x).op is not Ops.CONST)
+  new_src = tuple(x.base for x in sink.src if is_scheduled(x.base) and uval(x.base).op is not Ops.CONST)
   return None if new_src == sink.src else UOp(Ops.NOOP) if len(new_src) == 0 else UOp.sink(*new_src)
 
 do_realize = PatternMatcher([
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 757ac1128a..b67003ff5d 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -274,6 +274,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @functools.cached_property
   def full_shape(self) -> Tuple[sint, ...]:
     return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
+  @property
+  def shape(self) -> Tuple[sint, ...]: return unwrap(self.st).shape
+  @property
+  def size(self) -> int: return self.arg[1][1] if self.op is Ops.BUFFER else unwrap(self.st).size
 
   # *** uop evaluation ***
 
@@ -383,8 +387,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
       case Ops.BUFFER: return self.arg[1][0]
       case _: return self.src[0].device
   @property
-  def size(self) -> int: return self.buf_uop.arg[1][1]
-  @property
   def buf_uop(self) -> UOp:
     if self.op is Ops.BUFFER: return self
     assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"
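
Note: below is a minimal sketch of the ShapeTracker trick that _as_const relies on,
assuming the tinygrad tree at this commit is importable. A scalar tracker (shape ())
is reshaped to rank-matching ones and expanded to the target shape, so the folded
const broadcasts with all-zero strides and never needs a backing buffer:

    from tinygrad.shape.shapetracker import ShapeTracker

    shape = (4, 0, 2)                    # a size-0 shape, as matched by ops_folding
    base = ShapeTracker.from_shape(())   # scalar tracker, shape ()
    st = base.reshape((1,)*len(shape)).expand(shape)
    print(st.shape, st.size)             # (4, 0, 2) 0
    print(st.views[-1].strides)          # (0, 0, 0): every index reads the same const

Moving size (and shape) onto UOp itself is what lets ops_folding test
base.size == 0 directly, for both BUFFER uops and view-backed ones.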