mirror of https://github.com/tinygrad/tinygrad.git
forced_realize is Ops.CONTIGUOUS
This commit removes the per-LazyBuffer forced_realize flag. A buffer that must be materialized is now represented in the graph itself, as an explicit MetaOps.CONTIGUOUS / Ops.CONTIGUOUS node, and schedule outputs are registered in ctx.realizes directly.
@@ -1,7 +1,7 @@
 import sys
 from collections import defaultdict, deque
 from typing import Set, Tuple, List, Dict, DefaultDict
-from tinygrad.ops import GroupOp, MetaOps, ReduceOps, UOp, UnaryOps, resolve
+from tinygrad.ops import GroupOp, MetaOps, Ops, ReduceOps, UOp, UnaryOps, resolve
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, prod, dedup, all_int, merge_dicts
 from tinygrad.dtype import ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker

@@ -147,7 +147,7 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff
 
   for r in reduce_of_const:
     group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
-    if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue
+    if any(tr.op is Ops.CONTIGUOUS for tr in group) or any(x.base in group for x in outs): continue
     kernel_children = {c for tr in group for c in children[tr] if c.op not in {MetaOps.COPY, MetaOps.BUFFER_VIEW}}
     if len(kernel_children) == 0: continue
     for tr in group:
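
Note on this hunk: fusion of a reduce-of-const group was previously vetoed when any member carried the mutable forced_realize flag; the same condition is now expressed structurally, by asking whether a member is an explicit Ops.CONTIGUOUS node. A minimal standalone sketch of the two styles (plain Python, no tinygrad imports; Op, Buf, contiguous_old and contiguous_new are hypothetical stand-ins):

  from dataclasses import dataclass
  from enum import Enum, auto

  class Op(Enum):
    ADD = auto()
    CONTIGUOUS = auto()              # explicit "must materialize" marker op

  @dataclass(eq=False)
  class Buf:
    op: Op
    srcs: tuple = ()
    forced_realize: bool = False     # old style: out-of-band mutable flag

  def contiguous_old(b: Buf) -> Buf:
    b.forced_realize = True          # old: mutate the node, graph unchanged
    return b

  def contiguous_new(b: Buf) -> Buf:
    return Buf(Op.CONTIGUOUS, (b,))  # new: the requirement is a graph node

  group = [Buf(Op.ADD), contiguous_new(Buf(Op.ADD))]
  assert any(b.op is Op.CONTIGUOUS for b in group)  # the new fusion veto
  assert not any(b.forced_realize for b in group)   # the old flag is never set
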
@@ -41,7 +41,6 @@ class LazyBuffer(MathTrait):
         self.buffer = srcs[0].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, self.dtype)
         self.buffer.ref(1)
       self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
-      self.forced_realize = False
     else:
       # properties on view
       assert base.base == base, "base must be a base itself"

@@ -93,8 +92,7 @@ class LazyBuffer(MathTrait):
       ret = self.alu(MetaOps.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(MetaOps.CONTIGUOUS)
       if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
       return ret
-    self.base.forced_realize = True
-    return self
+    return self.alu(MetaOps.CONTIGUOUS)
 
   def bitcast(self, dtype:DType) -> LazyBuffer: return self.cast(dtype, bitcast=True)
   def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer:
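
This hunk changes what contiguous() does to an already-contiguous buffer: it used to flip forced_realize on the base and return self, aliasing its input; it now always returns a fresh MetaOps.CONTIGUOUS child and leaves the base untouched. A hedged usage sketch (assumes a tinygrad checkout at this commit; Tensor.arange, reshape, contiguous and numpy are public Tensor API, and the comments restate what the diff implies):

  from tinygrad import Tensor

  a = Tensor.arange(6).reshape(2, 3)  # already contiguous
  b = a.contiguous()
  # before this commit: b aliased a's LazyBuffer, with forced_realize set on its base
  # after this commit:  b is a new MetaOps.CONTIGUOUS node wrapping a's base
  print(b.numpy())                    # realizes the CONTIGUOUS node
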
@@ -77,7 +77,6 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
   else: ret = UOp(Ops.ALU, dtype, src, buf.op)
   cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
   if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
-  if buf.forced_realize: ctx.realizes[ubuf] = ubuf
   return ret
 
 # **** AST graph rewrite
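
With the flag gone, to_uop no longer needs the special case that promoted flagged buffers into ctx.realizes: anything that must materialize is now an ordinary node in the big graph, so (presumably, by the commit's design) the existing do_realize graph rewrite can discover it by pattern. A rough standalone sketch of that discovery step over a toy graph (U and collect_realizes are hypothetical, not tinygrad's PatternMatcher API):

  from dataclasses import dataclass

  @dataclass(frozen=True)
  class U:            # hypothetical stand-in for a UOp
    op: str
    srcs: tuple = ()

  def collect_realizes(root: U, realizes: dict) -> None:
    # walk the graph; every explicit CONTIGUOUS node becomes a kernel output
    if root.op == "CONTIGUOUS": realizes[root] = root
    for s in root.srcs: collect_realizes(s, realizes)

  g = U("STORE", (U("CONTIGUOUS", (U("ADD"),)),))
  realizes: dict = {}
  collect_realizes(g, realizes)
  assert list(realizes) == [U("CONTIGUOUS", (U("ADD"),))]
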
@@ -245,12 +244,12 @@ break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize
 @track_rewrites(named=True)
 def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
   if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
-  for out in outs: out.forced_realize = True
   # create the big graph
   ctx = ScheduleContext()
   cache: Dict[LazyBuffer, UOp] = {}
   big_graph = UOp.sink(*(to_uop(x, ctx, cache) for x in outs))
   # get realizes
+  ctx.realizes.update(((u:=ctx.buf_uops[x.buffer]), u) for x in outs)
   graph_rewrite(big_graph, do_realize, ctx.realizes)
   store_groups, lazybufs_to_realize, assigns = get_realizes(outs, ctx)
   # split realizes into small graphs
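
Finally, the outputs themselves: rather than flagging every output with forced_realize and letting to_uop translate the flag (the case deleted in the previous hunk), create_schedule_with_vars now writes each output's buffer UOp straight into ctx.realizes. The added line packs a walrus assignment into a generator; an equivalent, more explicit form with stand-in types (Ctx, Buf and Out are hypothetical, the key == value convention comes from the diff):

  class Ctx:
    def __init__(self):
      self.buf_uops: dict = {}  # Buffer -> UOp
      self.realizes: dict = {}  # UOp -> UOp, key == value

  def register_outputs(outs, ctx: Ctx) -> None:
    # every schedule output must be realized: map its buffer UOp to itself
    ctx.realizes.update(((u := ctx.buf_uops[x.buffer]), u) for x in outs)

  class Buf: pass               # hypothetical Buffer stand-in
  class Out:                    # hypothetical LazyBuffer stand-in
    def __init__(self, buffer): self.buffer = buffer

  ctx, b = Ctx(), Buf()
  ctx.buf_uops[b] = "UOP(b)"
  register_outputs([Out(b)], ctx)
  assert ctx.realizes == {"UOP(b)": "UOP(b)"}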