From d8bcd5d301e44c7d775b6ec1aaf8a7eefb0611ec Mon Sep 17 00:00:00 2001 From: qazal Date: Mon, 4 Nov 2024 17:51:15 +0200 Subject: [PATCH] forced_realize is Ops.CONTIGUOUS --- tinygrad/engine/fuse.py | 4 ++-- tinygrad/engine/lazy.py | 4 +--- tinygrad/engine/schedule.py | 3 +-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tinygrad/engine/fuse.py b/tinygrad/engine/fuse.py index 8faf012d7f..1b029fd0dd 100644 --- a/tinygrad/engine/fuse.py +++ b/tinygrad/engine/fuse.py @@ -1,7 +1,7 @@ import sys from collections import defaultdict, deque from typing import Set, Tuple, List, Dict, DefaultDict -from tinygrad.ops import GroupOp, MetaOps, ReduceOps, UOp, UnaryOps, resolve +from tinygrad.ops import GroupOp, MetaOps, Ops, ReduceOps, UOp, UnaryOps, resolve from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, prod, dedup, all_int, merge_dicts from tinygrad.dtype import ImageDType from tinygrad.shape.shapetracker import ShapeTracker @@ -147,7 +147,7 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff for r in reduce_of_const: group = {tr:None for tr,rop in reduce_for_op.items() if rop is r} - if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue + if any(tr.op is Ops.CONTIGUOUS for tr in group) or any(x.base in group for x in outs): continue kernel_children = {c for tr in group for c in children[tr] if c.op not in {MetaOps.COPY, MetaOps.BUFFER_VIEW}} if len(kernel_children) == 0: continue for tr in group: diff --git a/tinygrad/engine/lazy.py b/tinygrad/engine/lazy.py index 438d71014b..16c7dc312f 100644 --- a/tinygrad/engine/lazy.py +++ b/tinygrad/engine/lazy.py @@ -41,7 +41,6 @@ class LazyBuffer(MathTrait): self.buffer = srcs[0].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, self.dtype) self.buffer.ref(1) self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None - self.forced_realize = False else: # properties on view assert base.base == base, "base must be a base itself" @@ -93,8 +92,7 @@ class LazyBuffer(MathTrait): ret = self.alu(MetaOps.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(MetaOps.CONTIGUOUS) if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti return ret - self.base.forced_realize = True - return self + return self.alu(MetaOps.CONTIGUOUS) def bitcast(self, dtype:DType) -> LazyBuffer: return self.cast(dtype, bitcast=True) def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer: diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 43702fa4d4..2face24c25 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -77,7 +77,6 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> else: ret = UOp(Ops.ALU, dtype, src, buf.op) cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret))) if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata - if buf.forced_realize: ctx.realizes[ubuf] = ubuf return ret # **** AST graph rewrite @@ -245,12 +244,12 @@ break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize @track_rewrites(named=True) def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {} - for out in outs: out.forced_realize = True # create the big graph ctx = ScheduleContext() cache: Dict[LazyBuffer, UOp] = {} big_graph = UOp.sink(*(to_uop(x, ctx, cache) for x in outs)) # get realizes + ctx.realizes.update(((u:=ctx.buf_uops[x.buffer]), u) for x in outs) graph_rewrite(big_graph, do_realize, ctx.realizes) store_groups, lazybufs_to_realize, assigns = get_realizes(outs, ctx) # split realizes into small graphs