From 29508504eae2c1c95ae84bd6674215feb35db800 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 13 Nov 2024 15:32:34 +0200 Subject: [PATCH] uop style prefer small dtype + cleanups [pr] (#7671) * just this * space * typing 2 --- tinygrad/engine/fuse.py | 6 +++--- tinygrad/engine/schedule.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tinygrad/engine/fuse.py b/tinygrad/engine/fuse.py index 04dd366407..d140b77ebb 100644 --- a/tinygrad/engine/fuse.py +++ b/tinygrad/engine/fuse.py @@ -24,8 +24,8 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa if len(st_childs:=dedup(s.st for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r) _recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache) -def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, UOp], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\ - realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None]) -> Dict[LazyBuffer, None]: +def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, UOp], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], + realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None]) -> Dict[LazyBuffer, None]: rc_parents, cache = deque(group), set() while rc_parents: if (p:=rc_parents.pop()) in cache: continue @@ -80,7 +80,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break tr = tr_next # don't cast to higher size before store (tr cannot be realized if forced_realize) - if tr.op is Ops.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize: + if tr.op is Ops.CAST and tr.dtype.base.itemsize > tr.srcs[0].dtype.base.itemsize: tr = tr.srcs[0].base group = {tr: None} realizes[tr] = None diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 2df75af70d..ef09d824a2 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -45,7 +45,8 @@ class ScheduleContext: assigns: Set[UOp] = field(default_factory=set) # this holds all the UOps.BUFFERs we ASSIGN to in this schedule lazybufs: Dict[Buffer, LazyBuffer] = field(default_factory=dict) # this is a lookup for the LazyBuffers we need to mark as realized -def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduces, cache:Dict[LazyBuffer, UOp]) -> UOp: +def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None], + double_reduces:Dict[LazyBuffer, None], cache:Dict[LazyBuffer, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r if buf is not buf.base: cache[buf] = ret = to_uop(buf.base, ctx, children, allbufs, double_reduces, cache).view(buf.st)