mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
early mark uops as realized [pr] (#7821)
* early mark uops as realized [pr]
* merge with metadata
* aesthetics
This commit is contained in:
@@ -43,17 +43,17 @@ def is_scheduled(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleContext:
|
||||
ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata
|
||||
lazybufs: Dict[UOp, LazyBuffer] = field(default_factory=dict) # this maps BUFFER uops to their LazyBuffers
|
||||
var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value
|
||||
assigns: Set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule
|
||||
realizes: Dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
|
||||
allbufs: Dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops to the actual op
|
||||
children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
|
||||
|
||||
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], lazybufs:Dict[Buffer, LazyBuffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
|
||||
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
|
||||
if (r:=cache.get(buf)) is not None: return r
|
||||
if buf is not buf.base:
|
||||
cache[buf] = ret = to_uop(buf.base, ctx, buffers, lazybufs, cache).view(buf.st)
|
||||
cache[buf] = ret = to_uop(buf.base, ctx, buffers, cache).view(buf.st)
|
||||
return ret
|
||||
# make things that can't be images not images
|
||||
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
|
||||
@@ -68,18 +68,17 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], lazyb
|
||||
buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
|
||||
op = None
|
||||
elif buf.op is Ops.ASSIGN:
|
||||
target, new_val = [to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs]
|
||||
target, new_val = [to_uop(x, ctx, buffers, cache) for x in buf.srcs]
|
||||
ctx.assigns.add(ubuf:=target.buf_uop)
|
||||
op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
|
||||
else:
|
||||
buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
|
||||
op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs),
|
||||
op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs),
|
||||
None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg)
|
||||
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
|
||||
if op is not None:
|
||||
lazybufs[buf.buffer] = buf
|
||||
ctx.lazybufs[ubuf] = buf
|
||||
ctx.allbufs[ubuf] = ret
|
||||
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
|
||||
for x in op.src:
|
||||
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None
|
||||
return ret
|
||||
@@ -160,9 +159,9 @@ view_right = merge_views+PatternMatcher([
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItemContext:
|
||||
var_vals: Dict[Variable, int]
|
||||
lazybufs: Dict[UOp, LazyBuffer]
|
||||
assigned: Set[UOp]
|
||||
ubuf_metadata: Dict[UOp, Metadata]
|
||||
var_vals: Dict[Variable, int]
|
||||
sinked: Dict[UOp, UOp]
|
||||
sts: Set[ShapeTracker] = field(default_factory=set)
|
||||
bufs: List[UOp] = field(default_factory=list)
|
||||
@@ -194,7 +193,7 @@ to_si = PatternMatcher([
|
||||
# ** fusion
|
||||
|
||||
def fuse_src(ctx:ScheduleItemContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
|
||||
if (metadata:=ctx.ubuf_metadata.get(b)) is not None: ctx.metadata.add(metadata)
|
||||
if (metadata:=ctx.lazybufs[b].metadata) is not None: ctx.metadata.add(metadata)
|
||||
return to_store
|
||||
|
||||
lazy = PatternMatcher([
|
||||
@@ -206,8 +205,8 @@ lazy = PatternMatcher([
|
||||
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.sinked.get(b)),])
|
||||
|
||||
def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemContext]:
|
||||
si_ctx = ScheduleItemContext(ctx.var_vals, ctx.assigns, ctx.ubuf_metadata, {x.buf_uop:x.src[2] for x in pre.src},
|
||||
metadata={mx for x in pre.src if (mx:=ctx.ubuf_metadata.get(x.buf_uop))})
|
||||
si_ctx = ScheduleItemContext(ctx.lazybufs, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src},
|
||||
metadata={mx for x in pre.src if (mx:=ctx.lazybufs[x.buf_uop].metadata) is not None})
|
||||
# fuse and fold store -> loads
|
||||
sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, si_ctx)
|
||||
# assert cyclic dependency
|
||||
@@ -386,8 +385,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
ctx = ScheduleContext()
|
||||
cache: Dict[LazyBuffer, UOp] = {}
|
||||
buffers: Dict[UOp, Buffer] = {}
|
||||
lazybufs: Dict[Buffer, LazyBuffer] = {}
|
||||
big_graph = UOp.sink(*(to_uop(x, ctx, buffers, lazybufs, cache) for x in outs))
|
||||
big_graph = UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs))
|
||||
# get realizes
|
||||
graph_rewrite(big_graph, do_realize, ctx.realizes)
|
||||
store_groups = group_realizes(ctx, ctx.realizes)
|
||||
@@ -399,6 +397,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
ast, ast_ctx = full_ast_rewrite(UOp.sink(*(ctx.realizes[u] for u in store_uops)), ctx)
|
||||
prescheduled.append(ScheduleItem(ast, tuple(buffers[u] for u in ast_ctx.bufs if u.size != 0),
|
||||
tuple(ast_ctx.metadata), frozenset(x.buf_uop for x in ast_ctx.assign_preloads)))
|
||||
for u in ast_ctx.sinked: del ast_ctx.lazybufs[u].srcs # can only schedule once
|
||||
# do BFS
|
||||
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
|
||||
graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
|
||||
@@ -418,7 +417,6 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
schedule: List[ScheduleItem] = []
|
||||
while queue:
|
||||
schedule.append(si:=queue.popleft())
|
||||
for b in si.outputs: del lazybufs[b].srcs # can only schedule once
|
||||
for x in graph[si]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
Reference in New Issue
Block a user