mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 07:18:40 -05:00
late folding size 0 ops (#7940)
* fold st size=0
* fold 0 here
* ops folding
* update realize
This commit is contained in:
@@ -11,7 +11,6 @@ from weakref import ref, ReferenceType, WeakValueDictionary
|
||||
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
|
||||
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
||||
base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)):
|
||||
if st.size == 0: op, arg, srcs, base = Ops.CONST, 0, (), None
|
||||
dtype = to_dtype(dtype)
|
||||
if op is Ops.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True
|
||||
|
||||
|
||||
@@ -331,9 +331,18 @@ class UPatScheduled(UPat):
|
||||
# Builds a VIEW pattern captured as "base" with exactly two source patterns:
#   1. a BUFFER pattern captured as "b"
#   2. a pattern built from the caller's *args/**kwargs, captured as "to_store"
# Any "name" passed in kwargs is overridden, so the second source is always "to_store".
def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"),
|
||||
UPat(*args, **{**kwargs,"name":"to_store"})))
|
||||
|
||||
# ** this folds ops that don't need a BUFFER
|
||||
|
||||
# Rewrite rules matched against UPatScheduled() nodes; the callback receives the
# captured b, to_store and base along with ctx.
ops_folding = PatternMatcher([
|
||||
# op with size 0 is just zero
|
||||
# When base.st.size == 0, build a new VIEW with the same dtype and st whose sources are
# the original buffer b and a zero constant of base.dtype.
(UPatScheduled(), lambda ctx,b,to_store,base: UOp(Ops.VIEW, base.dtype, (b, UOp.const(base.dtype, 0)), base.st)
|
||||
# Yields None when size is nonzero or to_store is already that zero const
# (presumably None means "no rewrite" -- confirm against PatternMatcher).
# NOTE(review): `is not` is an identity check against a freshly built UOp.const; this
# presumably relies on UOp instances being deduplicated -- confirm.
if base.st.size == 0 and to_store is not UOp.const(base.dtype, 0) else None),
|
||||
])
|
||||
|
||||
# ** this decides which ops get realized
|
||||
|
||||
# Unconditionally records to_store under key b in ctx; `base` is accepted but not used.
# Returns the result of dict.update, i.e. None.
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None: return ctx.update([(b, to_store)])
|
||||
# Records to_store under key b in ctx, except when to_store's op is CONST or BIND,
# in which case ctx is left unchanged. `base` is accepted but not used.
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None:
|
||||
if to_store.op not in {Ops.CONST, Ops.BIND}: ctx.update([(b, to_store)])
|
||||
|
||||
def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> None:
|
||||
if to_store.op in {Ops.CONST, Ops.BIND}: return None
|
||||
@@ -399,7 +408,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
ctx = ScheduleContext()
|
||||
cache: Dict[LazyBuffer, UOp] = {}
|
||||
buffers: Dict[UOp, Buffer] = {}
|
||||
big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), do_realize, ctx.realizes)
|
||||
big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), ops_folding+do_realize, ctx.realizes)
|
||||
for u in big_graph.src: ctx.realizes[u.buf_uop] = u
|
||||
# group realizes into kernels
|
||||
store_groups = group_realizes(ctx)
|
||||
|
||||
Reference in New Issue
Block a user