late folding size 0 ops (#7940)

* fold st size=0

* fold 0 here

* ops folding

* update realize
This commit is contained in:
qazal
2024-11-30 11:40:02 -05:00
committed by GitHub
parent c068e8c242
commit ca20f281df
2 changed files with 11 additions and 3 deletions

View File

@@ -11,7 +11,6 @@ from weakref import ref, ReferenceType, WeakValueDictionary
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)):
if st.size == 0: op, arg, srcs, base = Ops.CONST, 0, (), None
dtype = to_dtype(dtype)
if op is Ops.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True

View File

@@ -331,9 +331,18 @@ class UPatScheduled(UPat):
def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"),
UPat(*args, **{**kwargs,"name":"to_store"})))
# ** this folds ops that don't need a BUFFER
ops_folding = PatternMatcher([
# op with size 0 is just zero
(UPatScheduled(), lambda ctx,b,to_store,base: UOp(Ops.VIEW, base.dtype, (b, UOp.const(base.dtype, 0)), base.st)
if base.st.size == 0 and to_store is not UOp.const(base.dtype, 0) else None),
])
# ** this decides which ops get realized
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None: return ctx.update([(b, to_store)])
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None:
if to_store.op not in {Ops.CONST, Ops.BIND}: ctx.update([(b, to_store)])
def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> None:
if to_store.op in {Ops.CONST, Ops.BIND}: return None
@@ -399,7 +408,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
ctx = ScheduleContext()
cache: Dict[LazyBuffer, UOp] = {}
buffers: Dict[UOp, Buffer] = {}
big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), do_realize, ctx.realizes)
big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), ops_folding+do_realize, ctx.realizes)
for u in big_graph.src: ctx.realizes[u.buf_uop] = u
# group realizes into kernels
store_groups = group_realizes(ctx)