uop style prefer small dtype + cleanups [pr] (#7671)

* just this

* space

* typing 2
This commit is contained in:
qazal
2024-11-13 15:32:34 +02:00
committed by GitHub
parent e84d089ef1
commit 29508504ea
2 changed files with 5 additions and 4 deletions

View File

@@ -24,8 +24,8 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
if len(st_childs:=dedup(s.st for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
_recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, UOp], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\
realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None]) -> Dict[LazyBuffer, None]:
def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, UOp], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None]) -> Dict[LazyBuffer, None]:
rc_parents, cache = deque(group), set()
while rc_parents:
if (p:=rc_parents.pop()) in cache: continue
@@ -80,7 +80,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
tr = tr_next
# don't cast to higher size before store (tr cannot be realized if forced_realize)
if tr.op is Ops.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
if tr.op is Ops.CAST and tr.dtype.base.itemsize > tr.srcs[0].dtype.base.itemsize:
tr = tr.srcs[0].base
group = {tr: None}
realizes[tr] = None

View File

@@ -45,7 +45,8 @@ class ScheduleContext:
assigns: Set[UOp] = field(default_factory=set) # this holds all the UOps.BUFFERs we ASSIGN to in this schedule
lazybufs: Dict[Buffer, LazyBuffer] = field(default_factory=dict) # this is a lookup for the LazyBuffers we need to mark as realized
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduces, cache:Dict[LazyBuffer, UOp]) -> UOp:
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None],
double_reduces:Dict[LazyBuffer, None], cache:Dict[LazyBuffer, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
if buf is not buf.base:
cache[buf] = ret = to_uop(buf.base, ctx, children, allbufs, double_reduces, cache).view(buf.st)