forced_realize is Ops.CONTIGUOUS

This commit is contained in:
qazal
2024-11-04 17:51:15 +02:00
parent 36488a2a43
commit d8bcd5d301
3 changed files with 4 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
import sys
from collections import defaultdict, deque
from typing import Set, Tuple, List, Dict, DefaultDict
from tinygrad.ops import GroupOp, MetaOps, ReduceOps, UOp, UnaryOps, resolve
from tinygrad.ops import GroupOp, MetaOps, Ops, ReduceOps, UOp, UnaryOps, resolve
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, prod, dedup, all_int, merge_dicts
from tinygrad.dtype import ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
@@ -147,7 +147,7 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff
for r in reduce_of_const:
group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue
if any(tr.op is Ops.CONTIGUOUS for tr in group) or any(x.base in group for x in outs): continue
kernel_children = {c for tr in group for c in children[tr] if c.op not in {MetaOps.COPY, MetaOps.BUFFER_VIEW}}
if len(kernel_children) == 0: continue
for tr in group:

View File

@@ -41,7 +41,6 @@ class LazyBuffer(MathTrait):
self.buffer = srcs[0].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, self.dtype)
self.buffer.ref(1)
self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
self.forced_realize = False
else:
# properties on view
assert base.base == base, "base must be a base itself"
@@ -93,8 +92,7 @@ class LazyBuffer(MathTrait):
ret = self.alu(MetaOps.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(MetaOps.CONTIGUOUS)
if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
return ret
self.base.forced_realize = True
return self
return self.alu(MetaOps.CONTIGUOUS)
def bitcast(self, dtype:DType) -> LazyBuffer: return self.cast(dtype, bitcast=True)
def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer:

View File

@@ -77,7 +77,6 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
else: ret = UOp(Ops.ALU, dtype, src, buf.op)
cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
if buf.forced_realize: ctx.realizes[ubuf] = ubuf
return ret
# **** AST graph rewrite
@@ -245,12 +244,12 @@ break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize
@track_rewrites(named=True)
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
for out in outs: out.forced_realize = True
# create the big graph
ctx = ScheduleContext()
cache: Dict[LazyBuffer, UOp] = {}
big_graph = UOp.sink(*(to_uop(x, ctx, cache) for x in outs))
# get realizes
ctx.realizes.update(((u:=ctx.buf_uops[x.buffer]), u) for x in outs)
graph_rewrite(big_graph, do_realize, ctx.realizes)
store_groups, lazybufs_to_realize, assigns = get_realizes(outs, ctx)
# split realizes into small graphs