s/lazybufs/tensor_uops [pr] (#8207)

commit 1824cbd72c
parent 6d6c34eb1e
Author: qazal
Date:   2024-12-13 13:20:02 +02:00
Committed by: GitHub


@@ -39,7 +39,7 @@ class ScheduleItem:
 @dataclass(frozen=True)
 class ScheduleContext:
-  lazybufs: Dict[UOp, List[UOp]] = field(default_factory=dict)  # this maps BUFFER uops of this schedule to the underlying lazybuffer
+  tensor_uops: Dict[UOp, List[UOp]] = field(default_factory=dict)  # this maps BUFFER uops of this schedule to the tensor uop
   var_vals: Dict[Variable, int] = field(default_factory=dict)  # this maps a BIND's DEFINE_VAR to its value
   assigns: Set[UOp] = field(default_factory=set)  # this holds all the BUFFER uops we ASSIGN to in this schedule
   realizes: Dict[UOp, UOp] = field(default_factory=dict)  # this holds all the BUFFER uops we mutate in this schedule
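
For context on the field being renamed above: ScheduleContext is a frozen dataclass, but its dict fields remain mutable containers, which is what lets later passes fill in tensor_uops one buffer at a time. A minimal standalone sketch of that pattern, using made-up names (MiniScheduleContext, string keys) rather than tinygrad's real UOp types:

    from dataclasses import dataclass, field
    from typing import Dict, List

    @dataclass(frozen=True)
    class MiniScheduleContext:
      # illustrative stand-in: maps a buffer key to the tensor uops that realize into it
      tensor_uops: Dict[str, List[str]] = field(default_factory=dict)

    ctx = MiniScheduleContext()
    ctx.tensor_uops["buffer0"] = ["tensor_a"]      # mutating the dict is allowed even on a frozen dataclass
    ctx.tensor_uops["buffer0"].append("tensor_b")  # several tensor uops may end up sharing one buffer
    # ctx.tensor_uops = {}                         # rebinding the field itself would raise FrozenInstanceError
    print(ctx.tensor_uops)                         # {'buffer0': ['tensor_a', 'tensor_b']}
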
@@ -76,8 +76,8 @@ def to_uop(buf:UOp, ctx:ScheduleContext, cache:Dict[UOp, UOp]) -> UOp:
   buf_uop = src[0].base.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
   op = UOp(buf.op, dtype.base, src, buf.arg)
   ret = UOp(Ops.VIEW, dtype.base, (buf_uop,) if op is None else (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
-  # track the underlying lazydata for this op
-  if op is not None: ctx.lazybufs[buf_uop] = [buf]
+  # track the underlying tensor uop for this op
+  if op is not None: ctx.tensor_uops[buf_uop] = [buf]
   cache[buf] = ret
   return ret
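
The hunk above sits inside to_uop, which converts each tensor-level buffer once (via the cache) and records which tensor uop owns the freshly created BUFFER uop. A rough standalone sketch of that memoize-and-register shape; every name here (to_uop_sketch, string keys) is illustrative, not tinygrad's API:

    from typing import Dict, List, Tuple

    def to_uop_sketch(node: str, tensor_uops: Dict[str, List[str]], cache: Dict[str, Tuple[str, str]]) -> Tuple[str, str]:
      if node in cache: return cache[node]   # shared nodes are converted exactly once
      buf_key = f"buffer({node})"            # stand-in for UOp.new_buffer(...)
      tensor_uops[buf_key] = [node]          # track the underlying tensor uop for this op
      cache[node] = ret = (buf_key, node)    # stand-in for the VIEW uop wrapping (buf_uop, op)
      return ret

    tensor_uops: Dict[str, List[str]] = {}
    cache: Dict[str, Tuple[str, str]] = {}
    for n in ("a", "b", "a"):                # "a" appears twice but is converted only once
      to_uop_sketch(n, tensor_uops, cache)
    print(sorted(tensor_uops))               # ['buffer(a)', 'buffer(b)']
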
@@ -137,7 +137,7 @@ view_right = merge_views+PatternMatcher([
 @dataclass(frozen=True)
 class ScheduleItemContext:
-  lazybufs: Dict[UOp, List[UOp]]
+  tensor_uops: Dict[UOp, List[UOp]]
   ops_metadata: Dict[UOp, Metadata]
   assigns: Set[UOp]
   var_vals: Dict[Variable, int]
@@ -182,7 +182,7 @@ multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: c
 def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemContext]:
   # create the ast context
-  si_ctx = ScheduleItemContext(ctx.lazybufs, ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src})
+  si_ctx = ScheduleItemContext(ctx.tensor_uops, ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src})
   create_ctx = add_metadata if len(si_ctx.assigns) == 0 else add_metadata+add_assign_adjacents
   sink = graph_rewrite(pre, create_ctx if len(si_ctx.sinked) == 1 else multioutput+create_ctx, si_ctx)
   # do movement ops
@@ -301,7 +301,7 @@ def group_realizes(ctx:ScheduleContext) -> List[List[UOp]]:
   # maybe fuse arange with its children
   for rbuf in reduce_of_const:
     group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
-    if any(luop.forced_realize for tr in group for luop in ctx.lazybufs[tr]): continue
+    if any(luop.forced_realize for tr in group for luop in ctx.tensor_uops[tr]): continue
     kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
     if len(kernel_children) == 0: continue
     for tr in group: del ctx.realizes[tr]
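
In the hunk above, each reduce's group is built as a dict used as an ordered set, and the whole arange-fusion group is skipped if any member tensor uop is forced to realize. A small standalone illustration of that guard; the data and names (reduce_for_op contents, forced_realize) are invented for the example:

    from typing import Dict, Set

    reduce_for_op: Dict[str, str] = {"tr0": "r0", "tr1": "r0", "tr2": "r1"}  # downstream buffer -> the reduce feeding it
    forced_realize: Set[str] = {"tr1"}                                       # buffers the user forced to realize

    for rbuf in ("r0", "r1"):
      group = {tr: None for tr, rop in reduce_for_op.items() if rop == rbuf}  # dict as an ordered set
      if any(tr in forced_realize for tr in group): continue                  # one forced member blocks fusing the group
      print(rbuf, "fuses with", list(group))                                  # only r1 prints: ['tr2']
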
@@ -397,14 +397,14 @@ def merge(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp) -> UOp:
   ctx.realizes[b1] = b1
   del ctx.realizes[b2]
   # ops referring to b2 now ref to b1
-  ctx.lazybufs[b1] += ctx.lazybufs[b2]
-  del ctx.lazybufs[b2]
+  ctx.tensor_uops[b1] += ctx.tensor_uops[b2]
+  del ctx.tensor_uops[b2]
   # merge
   return v1
 
 def merge_realized(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp):
   # early become
-  for luop in ctx.lazybufs.get(b1, [])+ctx.lazybufs.get(b2, []): luop.become(b1.view(unwrap(luop.st)))
+  for luop in ctx.tensor_uops.get(b1, [])+ctx.tensor_uops.get(b2, []): luop.become(b1.view(unwrap(luop.st)))
   return v1
 
 merge_bufs = PatternMatcher([
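
When two buffers are merged above, every tensor uop recorded under b2 is folded into b1's entry and b2's key is dropped. The same dict bookkeeping in isolation, with plain string keys standing in for the UOp buffers:

    from typing import Dict, List

    def merge_keys(table: Dict[str, List[str]], keep: str, drop: str) -> None:
      # fold everything recorded under `drop` into `keep`, then forget `drop`
      table[keep] += table[drop]
      del table[drop]

    tensor_uops = {"b1": ["t0"], "b2": ["t1", "t2"]}
    merge_keys(tensor_uops, keep="b1", drop="b2")
    print(tensor_uops)  # {'b1': ['t0', 't1', 't2']}
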
@@ -468,7 +468,7 @@ def append_realize(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
 def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
   # TODO: metadata post merge
-  if (m:=ctx.lazybufs[b][0].metadata) is not None: ctx.ops_metadata[to_store] = m
+  if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[to_store] = m
   return to_store
 
 break_sched = PatternMatcher([
@@ -515,7 +515,7 @@ def create_schedule_with_vars(outs:List[UOp]) -> Tuple[List[ScheduleItem], Dict[
     prescheduled.append(ScheduleItem(ast, tuple(u.buffer for u in ast_ctx.bufs if u.size != 0), tuple(ast_ctx.metadata),
                                      frozenset(ubuf for ubuf,ops in ast_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops))))
     for buf_uop in ast_ctx.sinked:
-      for luop in ast_ctx.lazybufs[buf_uop]: luop.become(buf_uop.view(unwrap(luop.st)))
+      for luop in ast_ctx.tensor_uops[buf_uop]: luop.become(buf_uop.view(unwrap(luop.st)))
   # do BFS
   schedule_targets = {out:si for si in prescheduled for out in si.outputs}
   graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
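
The "do BFS" step above builds a graph of which ScheduleItem produces each output buffer and then emits items in dependency order. A self-contained Kahn-style sketch of that idea under stated assumptions: plain strings stand in for ScheduleItem and buffer UOps, and bfs_schedule is an invented helper, not tinygrad's function:

    from collections import defaultdict, deque
    from typing import DefaultDict, Dict, List

    def bfs_schedule(produces: Dict[str, str], consumes: Dict[str, List[str]]) -> List[str]:
      # produces: buffer -> the item that writes it; consumes: item -> the buffers it reads
      graph: DefaultDict[str, List[str]] = defaultdict(list)
      in_degree: Dict[str, int] = {si: 0 for si in consumes}
      for si, inputs in consumes.items():
        for buf in inputs:
          if (producer := produces.get(buf)) is not None:
            graph[producer].append(si)
            in_degree[si] += 1
      queue = deque(si for si, deg in in_degree.items() if deg == 0)
      order: List[str] = []
      while queue:
        si = queue.popleft()
        order.append(si)
        for child in graph[si]:
          in_degree[child] -= 1
          if in_degree[child] == 0: queue.append(child)
      return order

    produces = {"buf_a": "kernel_a", "buf_b": "kernel_b"}
    consumes = {"kernel_a": [], "kernel_b": ["buf_a"], "kernel_c": ["buf_a", "buf_b"]}
    print(bfs_schedule(produces, consumes))  # ['kernel_a', 'kernel_b', 'kernel_c']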