early mark uops as realized [pr] (#7821)

* early mark uops as realized [pr]

* merge with metadata

* aesthetics
This commit is contained in:
qazal
2024-11-21 05:02:59 -05:00
committed by GitHub
parent 4542c0f000
commit cdc431803f

View File

@@ -43,17 +43,17 @@ def is_scheduled(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2
@dataclass(frozen=True)
class ScheduleContext:
ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata
lazybufs: Dict[UOp, LazyBuffer] = field(default_factory=dict) # this maps BUFFER uops to their LazyBuffers
var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value
assigns: Set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule
realizes: Dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
allbufs: Dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops to the actual op
children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], lazybufs:Dict[Buffer, LazyBuffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
if buf is not buf.base:
cache[buf] = ret = to_uop(buf.base, ctx, buffers, lazybufs, cache).view(buf.st)
cache[buf] = ret = to_uop(buf.base, ctx, buffers, cache).view(buf.st)
return ret
# make things that can't be images not images
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
@@ -68,18 +68,17 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], lazyb
buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
op = None
elif buf.op is Ops.ASSIGN:
target, new_val = [to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs]
target, new_val = [to_uop(x, ctx, buffers, cache) for x in buf.srcs]
ctx.assigns.add(ubuf:=target.buf_uop)
op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
else:
buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs),
op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs),
None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg)
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
if op is not None:
lazybufs[buf.buffer] = buf
ctx.lazybufs[ubuf] = buf
ctx.allbufs[ubuf] = ret
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
for x in op.src:
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None
return ret
@@ -160,9 +159,9 @@ view_right = merge_views+PatternMatcher([
@dataclass(frozen=True)
class ScheduleItemContext:
var_vals: Dict[Variable, int]
lazybufs: Dict[UOp, LazyBuffer]
assigned: Set[UOp]
ubuf_metadata: Dict[UOp, Metadata]
var_vals: Dict[Variable, int]
sinked: Dict[UOp, UOp]
sts: Set[ShapeTracker] = field(default_factory=set)
bufs: List[UOp] = field(default_factory=list)
@@ -194,7 +193,7 @@ to_si = PatternMatcher([
# ** fusion
def fuse_src(ctx:ScheduleItemContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
if (metadata:=ctx.ubuf_metadata.get(b)) is not None: ctx.metadata.add(metadata)
if (metadata:=ctx.lazybufs[b].metadata) is not None: ctx.metadata.add(metadata)
return to_store
lazy = PatternMatcher([
@@ -206,8 +205,8 @@ lazy = PatternMatcher([
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.sinked.get(b)),])
def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemContext]:
si_ctx = ScheduleItemContext(ctx.var_vals, ctx.assigns, ctx.ubuf_metadata, {x.buf_uop:x.src[2] for x in pre.src},
metadata={mx for x in pre.src if (mx:=ctx.ubuf_metadata.get(x.buf_uop))})
si_ctx = ScheduleItemContext(ctx.lazybufs, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src},
metadata={mx for x in pre.src if (mx:=ctx.lazybufs[x.buf_uop].metadata) is not None})
# fuse and fold store -> loads
sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, si_ctx)
# assert cyclic dependency
@@ -386,8 +385,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
ctx = ScheduleContext()
cache: Dict[LazyBuffer, UOp] = {}
buffers: Dict[UOp, Buffer] = {}
lazybufs: Dict[Buffer, LazyBuffer] = {}
big_graph = UOp.sink(*(to_uop(x, ctx, buffers, lazybufs, cache) for x in outs))
big_graph = UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs))
# get realizes
graph_rewrite(big_graph, do_realize, ctx.realizes)
store_groups = group_realizes(ctx, ctx.realizes)
@@ -399,6 +397,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
ast, ast_ctx = full_ast_rewrite(UOp.sink(*(ctx.realizes[u] for u in store_uops)), ctx)
prescheduled.append(ScheduleItem(ast, tuple(buffers[u] for u in ast_ctx.bufs if u.size != 0),
tuple(ast_ctx.metadata), frozenset(x.buf_uop for x in ast_ctx.assign_preloads)))
for u in ast_ctx.sinked: del ast_ctx.lazybufs[u].srcs # can only schedule once
# do BFS
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
@@ -418,7 +417,6 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
schedule: List[ScheduleItem] = []
while queue:
schedule.append(si:=queue.popleft())
for b in si.outputs: del lazybufs[b].srcs # can only schedule once
for x in graph[si]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)