mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
refactor to generic UPat for sourcing unrealized bufs [pr] (#7731)
* base check
* use is_scheduled
* fixup lazy
* update metadata
* match is too slow
This commit is contained in:
@@ -37,6 +37,10 @@ class ScheduleItem:
|
||||
|
||||
# **** small wrapper for LazyBuffer -> UOp
|
||||
|
||||
# Builds the "LOAD of a STORE" pattern used to find unrealized buffers.
# The outer UPat.load(...) is given, in order: the var pattern "b", a wildcard UPat(), and a UPat.store(...) pattern.
# The inner store pattern reuses the same "b" object, then a wildcard, then `to_store` (defaults to a wildcard).
# The store pattern is named "store" and the load pattern is named "load".
# NOTE(review): presumably "b", "store" and "load" reach the rewrite handler as keyword arguments — confirm in UPat.
def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load")
|
||||
# Predicate: True when `u` is a LOAD uop with exactly three sources, False otherwise.
# The three-argument LOAD is the same shape that UPatLoadStore describes; this is a direct field check, not a pattern match.
# lru_cache(None) memoizes the result per UOp with no size limit.
@functools.lru_cache(None)
|
||||
def is_scheduled(u:UOp): return u.op is Ops.LOAD and len(u.src) == 3
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleContext:
|
||||
buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) # this maps Buffers to BUFFER uops
|
||||
@@ -84,7 +88,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[UOp, Dict[U
|
||||
allbufs[ubuf] = ret
|
||||
if buf.op is Ops.REDUCE_AXIS and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[ubuf] = None
|
||||
for x in src:
|
||||
if x.base.op is Ops.LOAD: children[x.base.buf_uop][ubuf] = None
|
||||
if is_scheduled(x.base): children[x.base.buf_uop][ubuf] = None
|
||||
return ret
|
||||
|
||||
# **** AST graph rewrite
|
||||
@@ -194,7 +198,7 @@ to_si = PatternMatcher([
|
||||
# ** fusion
|
||||
|
||||
# Rewrite rules listed under the "fusion" heading; each entry is a (pattern, handler) pair.
# NOTE(review): this is a diff view, so the first two entries look like the old and new spelling of one rule — confirm.
lazy = PatternMatcher([
|
||||
# explicit LOAD(b, _, STORE(b, _, v)) pattern: the handler returns the stored value v
(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat.var("v"))), lambda ctx,b,v: v),
|
||||
# same rule written with the shared UPatLoadStore helper; **kwargs absorbs any extra named captures besides v
(UPatLoadStore(UPat.var("v")), lambda ctx,v,**kwargs: v),
|
||||
# a CONTIGUOUS with a single source x is replaced by x
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
|
||||
])
|
||||
|
||||
@@ -229,7 +233,7 @@ if getenv("RUN_PROCESS_REPLAY"):
|
||||
# **** Schedule grouping
|
||||
|
||||
# Returns the value stored by a scheduled LOAD `u`, i.e. u.src[2].src[2].
# If that value reports `is_contiguous_base`, its first source is returned instead.
# NOTE(review): this is a diff view, so the two asserts below look like the old and new form of one check — confirm.
def uval(u:UOp) -> UOp:
|
||||
# explicit form of the precondition: LOAD with three sources whose third source is a STORE
assert u.op is Ops.LOAD and len(u.src) == 3 and u.src[2].op is Ops.STORE, f"must be a LOAD of STORE {u}"
|
||||
# form using the shared is_scheduled predicate (it does not check that src[2] is a STORE)
assert is_scheduled(u), f"must be a scheduled op {u}"
|
||||
# the condition of a conditional expression is evaluated first, so the walrus binds to_store before either branch uses it
return to_store.src[0] if (to_store:=u.src[2].src[2]).is_contiguous_base else to_store
|
||||
|
||||
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], realizes:Dict[UOp, UOp],
|
||||
@@ -247,7 +251,7 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Di
|
||||
# max one reduceop per kernel
|
||||
if (tr_next_uop:=uval(allbufs[tr_next]).base).op is Ops.REDUCE_AXIS: return group.setdefault(r)
|
||||
# can only fuse contiguous
|
||||
if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if x.base.op is Ops.LOAD and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
|
||||
if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
|
||||
recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
|
||||
|
||||
def get_isolated_children(r:UOp, reduce_for_op:Dict[UOp, UOp], children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp],
|
||||
@@ -258,7 +262,7 @@ def get_isolated_children(r:UOp, reduce_for_op:Dict[UOp, UOp], children:DefaultD
|
||||
cache.add(p)
|
||||
# max one reduceop per kernel
|
||||
if p.op is Ops.REDUCE_AXIS: return {}
|
||||
rc_parents.extend(x.base.buf_uop for x in p.src if x.base.op is Ops.LOAD and x.base.buf_uop is not r)
|
||||
rc_parents.extend(x.base.buf_uop for x in p.src if is_scheduled(x.base) and x.base.buf_uop is not r)
|
||||
# search descendants of the reduceop that can cleanly group
|
||||
descendants: Dict[UOp, None] = {}
|
||||
for tr in group: recursive_group(tr, unwrap(allbufs[tr].st), tr, children, allbufs, realizes, reduce_for_op, descendants, cache={})
|
||||
@@ -295,7 +299,7 @@ def group_realizes(children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp,
|
||||
st = unwrap(r_uop.st)
|
||||
while len(children[tr]) == 1:
|
||||
tr_next_uop = uval(allbufs[(tr_next:=next(iter(children[tr])))])
|
||||
st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if x.base.op is Ops.LOAD and x.base.buf_uop is tr])
|
||||
st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop is tr])
|
||||
if len(st_childs) > 1: break
|
||||
if st.size != st_childs[0].size: break
|
||||
st = st + st_childs[0]
|
||||
@@ -343,7 +347,6 @@ def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, **kwargs) -> Optional[U
|
||||
# otherwise safety check pads
|
||||
return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, **kwargs).view(st)
|
||||
|
||||
def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load")
|
||||
do_realize = PatternMatcher([
|
||||
# always realize meta ops
|
||||
(UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta))), realize),
|
||||
@@ -381,7 +384,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
bufs = list(ctx.buf_uops)
|
||||
prescheduled: List[ScheduleItem] = []
|
||||
for sink in sinks:
|
||||
metadata = tuple({mx for x in sink.sparents if x.op in GroupOp.Buffer and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.buf_uop))})
|
||||
metadata = tuple({mx for x in sink.sparents if (x.op is Ops.STORE or is_scheduled(x)) and (mx:=ctx.ubuf_metadata.get(x.buf_uop))})
|
||||
ast, ast_ctx = full_ast_rewrite(sink, ctx.var_vals, ctx.assigns)
|
||||
prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=bufs[u.arg[0]]).size != 0), metadata, tuple(ast_ctx.assign_preloads)))
|
||||
# do BFS
|
||||
|
||||
@@ -342,7 +342,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
# *** uop movement ops ***
|
||||
|
||||
# Property giving the uop underneath a VIEW: self.src[0] when the condition holds, otherwise self.
# NOTE(review): this is a diff view, so the two defs below look like the old and new form of one line — confirm.
@property
|
||||
# unwraps any VIEW that has at least one source
def base(self): return self.src[0] if self.op is Ops.VIEW and len(self.src) != 0 else self
|
||||
# unwraps only a VIEW with exactly one source; a VIEW with zero or several sources is returned as-is
def base(self): return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 else self
|
||||
# Returns this uop viewed through ShapeTracker `st`.
# If self.st is None or already equals `st`, self is returned unchanged.
# Otherwise a new VIEW uop is built with self as its only source, self's dtype, and `st` as the last argument.
def view(self, st:ShapeTracker):
|
||||
# a STORE may not be wrapped in a VIEW; fails with AssertionError if called on one
assert self.op is not Ops.STORE, "VIEW of STORE is invalid, STORE is always base"
|
||||
return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
|
||||
|
||||
Reference in New Issue
Block a user