diff --git a/tinygrad/engine/lazy.py b/tinygrad/engine/lazy.py index 2fad343673..cade971fe6 100644 --- a/tinygrad/engine/lazy.py +++ b/tinygrad/engine/lazy.py @@ -11,7 +11,6 @@ from weakref import ref, ReferenceType, WeakValueDictionary lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary() def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(), base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)): - if st.size == 0: op, arg, srcs, base = Ops.CONST, 0, (), None dtype = to_dtype(dtype) if op is Ops.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 89baf5a5dd..1106dacbd3 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -331,9 +331,18 @@ class UPatScheduled(UPat): def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{**kwargs,"name":"to_store"}))) +# ** this folds ops that don't need a BUFFER + +ops_folding = PatternMatcher([ + # op with size 0 is just zero + (UPatScheduled(), lambda ctx,b,to_store,base: UOp(Ops.VIEW, base.dtype, (b, UOp.const(base.dtype, 0)), base.st) + if base.st.size == 0 and to_store is not UOp.const(base.dtype, 0) else None), +]) + # ** this decides which ops get realized -def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None: return ctx.update([(b, to_store)]) +def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None: + if to_store.op not in {Ops.CONST, Ops.BIND}: ctx.update([(b, to_store)]) def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> None: if to_store.op in {Ops.CONST, Ops.BIND}: return None @@ -399,7 +408,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] ctx = ScheduleContext() cache: Dict[LazyBuffer, UOp] = {} buffers: Dict[UOp, Buffer] = {} - big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), do_realize, ctx.realizes) + big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), ops_folding+do_realize, ctx.realizes) for u in big_graph.src: ctx.realizes[u.buf_uop] = u # group realizes into kernels store_groups = group_realizes(ctx)