more const folding infra from the delete_lazy branch [pr] (#7976)

* more const folding infra from the delete_lazy branch [pr]

* sink base

* limit
qazal
2024-12-01 10:20:30 -05:00
committed by GitHub
parent 509c4a573f
commit aa2e7b11f8
2 changed files with 14 additions and 8 deletions

tinygrad/engine/schedule.py

@@ -5,7 +5,7 @@ from typing import FrozenSet, Set, Tuple, List, Dict, Optional, DefaultDict
 from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
 from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
-from tinygrad.dtype import ImageDType, dtypes
+from tinygrad.dtype import ConstType, ImageDType, dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape
 from tinygrad.engine.lazy import LazyBuffer
@@ -331,12 +331,16 @@ class UPatScheduled(UPat):
   def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"),
                                                                                      UPat(*args, **{**kwargs,"name":"to_store"})))
 
-# ** this folds ops that don't need a BUFFER
+# ** this is schedule level const folding
 
+def _as_const(u:UOp, val:ConstType) -> UOp:
+  assert is_scheduled(u), f"must be scheduled to fold {u}"
+  st = (base:=ShapeTracker.from_shape(())).reshape((1,)*len(u.shape)).expand(u.shape)
+  return UOp(Ops.VIEW, u.dtype, (u.buf_uop, UOp.const(u.dtype, val)), base).view(st)
+
 ops_folding = PatternMatcher([
-  # op with size 0 is just zero
-  (UPatScheduled(), lambda ctx,b,to_store,base: UOp(Ops.VIEW, base.dtype, (b, UOp.const(base.dtype, 0)), base.st)
-   if base.st.size == 0 and to_store is not UOp.const(base.dtype, 0) else None),
+  # op with size 0 is zero
+  (UPatScheduled(), lambda ctx,b,to_store,base: _as_const(base, 0) if base.size == 0 else None),
 ])
 
 # ** this decides which ops get realized
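
The new _as_const helper generalizes the inline zero-const rewrite it replaces: instead of reusing the op's existing ShapeTracker, it builds a fresh 0-d tracker, reshapes it to all-unit dims, and expands it to the op's shape, so the folded constant is a broadcast scalar that never materializes its own buffer. A minimal sketch of that ShapeTracker construction, with a hypothetical (4, 4) shape standing in for u.shape:

from tinygrad.shape.shapetracker import ShapeTracker

# same dance as _as_const: scalar -> unit dims -> broadcast to target shape
st = ShapeTracker.from_shape(()).reshape((1, 1)).expand((4, 4))
print(st.shape, st.size)  # (4, 4) 16 -- all strides are 0, one real element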
@@ -362,7 +366,7 @@ def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kw
   return to_cast.view(unwrap(view.st))
 
 def init_big_graph(ctx:ScheduleContext, sink:UOp) -> Optional[UOp]:
-  new_src = tuple(x for x in sink.src if is_scheduled(x) and uval(x).op is not Ops.CONST)
+  new_src = tuple(x.base for x in sink.src if is_scheduled(x.base) and uval(x.base).op is not Ops.CONST)
   return None if new_src == sink.src else UOp(Ops.NOOP) if len(new_src) == 0 else UOp.sink(*new_src)
 
 do_realize = PatternMatcher([
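
init_big_graph now takes x.base instead of x: a sink source may be a VIEW wrapping a scheduled op, and peeling to the base lets a CONST hiding behind such a view be pruned from the big graph too. A simplified stand-in of that filtering with a mock Node class (the is_scheduled check is omitted, and the real function returns None to signal no rewrite):

class Node:  # UOp stand-in; .base peels a VIEW the way UOp.base does
  def __init__(self, op, base=None): self.op, self._base = op, base
  @property
  def base(self): return self._base if self._base is not None else self

def prune_consts(srcs):
  # mirrors the new_src line in the diff, minus the is_scheduled check
  return tuple(x.base for x in srcs if x.base.op != "CONST")

srcs = (Node("STORE"), Node("VIEW", base=Node("CONST")))
print([n.op for n in prune_consts(srcs)])  # ['STORE'] -- the viewed CONST is gone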

tinygrad/ops.py

@@ -274,6 +274,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @functools.cached_property
   def full_shape(self) -> Tuple[sint, ...]:
     return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
+  @property
+  def shape(self) -> Tuple[sint, ...]: return unwrap(self.st).shape
+  @property
+  def size(self) -> int: return self.arg[1][1] if self.op is Ops.BUFFER else unwrap(self.st).size
 
   # *** uop evaluation ***
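
shape and size move up into the generic UOp accessors: size now answers for any UOp carrying a ShapeTracker and only falls back to the arg tuple for raw BUFFERs, whose arg appears to hold the size at arg[1][1] (matching the device accessor in the next hunk). A standalone mock of that dispatch, with invented stand-in types:

from math import prod

class ST:  # minimal ShapeTracker stand-in
  def __init__(self, shape): self.shape = shape
  @property
  def size(self): return prod(self.shape)

def uop_size(op, arg, st):
  # BUFFER carries its size in its arg; everything else asks its ShapeTracker
  return arg[1][1] if op == "BUFFER" else st.size

print(uop_size("BUFFER", (0, ("CPU", 6)), None))  # 6
print(uop_size("ADD", None, ST((2, 3))))          # 6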
@@ -383,8 +387,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
       case Ops.BUFFER: return self.arg[1][0]
       case _: return self.src[0].device
   @property
-  def size(self) -> int: return self.buf_uop.arg[1][1]
-  @property
   def buf_uop(self) -> UOp:
     if self.op is Ops.BUFFER: return self
     assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"