mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 07:18:40 -05:00
more const folding infra from the delete_lazy branch [pr] (#7976)
* more const folding infra from the delete_lazy branch [pr]
* sink base
* limit

tinygrad/engine/schedule.py:

@@ -5,7 +5,7 @@ from typing import FrozenSet, Set, Tuple, List, Dict, Optional, DefaultDict
 from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
 from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
-from tinygrad.dtype import ImageDType, dtypes
+from tinygrad.dtype import ConstType, ImageDType, dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape
 from tinygrad.engine.lazy import LazyBuffer
@@ -331,12 +331,16 @@ class UPatScheduled(UPat):
   def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"),
                                                                                      UPat(*args, **{**kwargs,"name":"to_store"})))
 
-# ** this folds ops that don't need a BUFFER
+# ** this is schedule level const folding
 
+def _as_const(u:UOp, val:ConstType) -> UOp:
+  assert is_scheduled(u), f"must be scheduled to fold {u}"
+  st = (base:=ShapeTracker.from_shape(())).reshape((1,)*len(u.shape)).expand(u.shape)
+  return UOp(Ops.VIEW, u.dtype, (u.buf_uop, UOp.const(u.dtype, val)), base).view(st)
+
 ops_folding = PatternMatcher([
-  # op with size 0 is just zero
-  (UPatScheduled(), lambda ctx,b,to_store,base: UOp(Ops.VIEW, base.dtype, (b, UOp.const(base.dtype, 0)), base.st)
-    if base.st.size == 0 and to_store is not UOp.const(base.dtype, 0) else None),
+  # op with size 0 is zero
+  (UPatScheduled(), lambda ctx,b,to_store,base: _as_const(base, 0) if base.size == 0 else None),
 ])
 
 # ** this decides which ops get realized
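In plain terms, the new `_as_const` rewrites an already-scheduled op into a VIEW of a const over its buffer, and `ops_folding` uses it to turn any op whose output has zero elements into a const 0. A minimal self-contained sketch of that rule, with `ToyOp` and `fold_zero_size` as illustrative stand-ins rather than tinygrad's real `UOp`/`PatternMatcher`:

# minimal sketch of the zero-size fold above; ToyOp and fold_zero_size are
# made-up stand-ins, not tinygrad APIs
from dataclasses import dataclass
from math import prod
from typing import Optional, Tuple

@dataclass(frozen=True)
class ToyOp:
  op: str                 # e.g. "ADD", "CONST"
  shape: Tuple[int, ...]  # stands in for UOp.shape (from the ShapeTracker)
  arg: object = None

def fold_zero_size(u: ToyOp) -> Optional[ToyOp]:
  # mirrors the ops_folding rule: an op whose output has 0 elements can be
  # replaced by a const 0 of the same shape, since nothing will ever read it
  if prod(u.shape) == 0 and u.op != "CONST": return ToyOp("CONST", u.shape, 0)
  return None  # non-empty output: leave the op for the scheduler

assert fold_zero_size(ToyOp("ADD", (4, 0, 3))) == ToyOp("CONST", (4, 0, 3), 0)
assert fold_zero_size(ToyOp("ADD", (4, 3))) is None

The real rule additionally has to preserve the buffer and ShapeTracker wiring, which is what the `UOp(Ops.VIEW, ...)` plumbing inside `_as_const` does.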
@@ -362,7 +366,7 @@ def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kw
   return to_cast.view(unwrap(view.st))
 
 def init_big_graph(ctx:ScheduleContext, sink:UOp) -> Optional[UOp]:
-  new_src = tuple(x for x in sink.src if is_scheduled(x) and uval(x).op is not Ops.CONST)
+  new_src = tuple(x.base for x in sink.src if is_scheduled(x.base) and uval(x.base).op is not Ops.CONST)
   return None if new_src == sink.src else UOp(Ops.NOOP) if len(new_src) == 0 else UOp.sink(*new_src)
 
 do_realize = PatternMatcher([
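The `init_big_graph` change is that the filter now looks at `x.base`, the op underneath any VIEW, so a CONST hiding behind a view is also dropped from the SINK. A rough sketch of that behavior, with `Node` and `prune_sink` as hypothetical stand-ins for `UOp` and the real rewrite:

# sketch of the x -> x.base change in init_big_graph; Node and prune_sink
# are hypothetical stand-ins, not tinygrad APIs
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class Node:
  op: str
  src: Tuple["Node", ...] = ()

  @property
  def base(self) -> "Node":
    # a VIEW is transparent: its base is the op it wraps
    return self.src[0].base if self.op == "VIEW" else self

def prune_sink(sink: Node) -> Optional[Node]:
  # keep only sources whose underlying (base) op still needs scheduling;
  # before the change, a CONST wrapped in a VIEW slipped through this filter
  new_src = tuple(x.base for x in sink.src if x.base.op != "CONST")
  if new_src == sink.src: return None  # no rewrite needed
  return Node("NOOP") if len(new_src) == 0 else Node("SINK", new_src)

viewed_const = Node("VIEW", (Node("CONST"),))
assert prune_sink(Node("SINK", (viewed_const, Node("ADD")))) == Node("SINK", (Node("ADD"),))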
tinygrad/ops.py:

@@ -274,6 +274,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @functools.cached_property
   def full_shape(self) -> Tuple[sint, ...]:
     return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
+  @property
+  def shape(self) -> Tuple[sint, ...]: return unwrap(self.st).shape
+  @property
+  def size(self) -> int: return self.arg[1][1] if self.op is Ops.BUFFER else unwrap(self.st).size
 
 # *** uop evaluation ***
 
@@ -383,8 +387,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
       case Ops.BUFFER: return self.arg[1][0]
       case _: return self.src[0].device
-  @property
-  def size(self) -> int: return self.buf_uop.arg[1][1]
   @property
   def buf_uop(self) -> UOp:
     if self.op is Ops.BUFFER: return self
     assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"
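These last two hunks move `size` off the realized-buffer path (`self.buf_uop.arg[1][1]`) onto `UOp` itself, where it answers from the node's own ShapeTracker unless the node is a BUFFER. That matters because a view can cover a different element count than the buffer it reads, which is exactly what the new zero-size fold keys off. A tiny illustration, with made-up numbers:

# why size moved: a view's element count can differ from its buffer's;
# the numbers below are invented for illustration
from math import prod

buffer_size = 16     # what the old self.buf_uop.arg[1][1] would report
view_shape = (4, 0)  # a zero-size view over that 16-element buffer
# ops_folding's zero-size rule needs the view's size (0), not the buffer's (16)
assert prod(view_shape) == 0 and buffer_size == 16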