mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
s/lazybufs/tensor_uops [pr] (#8207)
This commit is contained in:
@@ -39,7 +39,7 @@ class ScheduleItem:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleContext:
|
||||
lazybufs: Dict[UOp, List[UOp]] = field(default_factory=dict) # this maps BUFFER uops of this schedule to the underlying lazybuffer
|
||||
tensor_uops: Dict[UOp, List[UOp]] = field(default_factory=dict) # this maps BUFFER uops of this schedule to the tensor uop
|
||||
var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value
|
||||
assigns: Set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule
|
||||
realizes: Dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
|
||||
@@ -76,8 +76,8 @@ def to_uop(buf:UOp, ctx:ScheduleContext, cache:Dict[UOp, UOp]) -> UOp:
|
||||
buf_uop = src[0].base.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
|
||||
op = UOp(buf.op, dtype.base, src, buf.arg)
|
||||
ret = UOp(Ops.VIEW, dtype.base, (buf_uop,) if op is None else (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
|
||||
# track the underlying lazydata for this op
|
||||
if op is not None: ctx.lazybufs[buf_uop] = [buf]
|
||||
# track the underlying tensor uop for this op
|
||||
if op is not None: ctx.tensor_uops[buf_uop] = [buf]
|
||||
cache[buf] = ret
|
||||
return ret
|
||||
|
||||
@@ -137,7 +137,7 @@ view_right = merge_views+PatternMatcher([
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItemContext:
|
||||
lazybufs: Dict[UOp, List[UOp]]
|
||||
tensor_uops: Dict[UOp, List[UOp]]
|
||||
ops_metadata: Dict[UOp, Metadata]
|
||||
assigns: Set[UOp]
|
||||
var_vals: Dict[Variable, int]
|
||||
@@ -182,7 +182,7 @@ multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: c
|
||||
|
||||
def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemContext]:
|
||||
# create the ast context
|
||||
si_ctx = ScheduleItemContext(ctx.lazybufs, ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src})
|
||||
si_ctx = ScheduleItemContext(ctx.tensor_uops, ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src})
|
||||
create_ctx = add_metadata if len(si_ctx.assigns) == 0 else add_metadata+add_assign_adjacents
|
||||
sink = graph_rewrite(pre, create_ctx if len(si_ctx.sinked) == 1 else multioutput+create_ctx, si_ctx)
|
||||
# do movement ops
|
||||
@@ -301,7 +301,7 @@ def group_realizes(ctx:ScheduleContext) -> List[List[UOp]]:
|
||||
# maybe fuse arange with its children
|
||||
for rbuf in reduce_of_const:
|
||||
group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
|
||||
if any(luop.forced_realize for tr in group for luop in ctx.lazybufs[tr]): continue
|
||||
if any(luop.forced_realize for tr in group for luop in ctx.tensor_uops[tr]): continue
|
||||
kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
|
||||
if len(kernel_children) == 0: continue
|
||||
for tr in group: del ctx.realizes[tr]
|
||||
@@ -397,14 +397,14 @@ def merge(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp) -> UOp:
|
||||
ctx.realizes[b1] = b1
|
||||
del ctx.realizes[b2]
|
||||
# ops referring to b2 now ref to b1
|
||||
ctx.lazybufs[b1] += ctx.lazybufs[b2]
|
||||
del ctx.lazybufs[b2]
|
||||
ctx.tensor_uops[b1] += ctx.tensor_uops[b2]
|
||||
del ctx.tensor_uops[b2]
|
||||
# merge
|
||||
return v1
|
||||
|
||||
def merge_realized(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp):
|
||||
# early become
|
||||
for luop in ctx.lazybufs.get(b1, [])+ctx.lazybufs.get(b2, []): luop.become(b1.view(unwrap(luop.st)))
|
||||
for luop in ctx.tensor_uops.get(b1, [])+ctx.tensor_uops.get(b2, []): luop.become(b1.view(unwrap(luop.st)))
|
||||
return v1
|
||||
|
||||
merge_bufs = PatternMatcher([
|
||||
@@ -468,7 +468,7 @@ def append_realize(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
|
||||
|
||||
def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
|
||||
# TODO: metadata post merge
|
||||
if (m:=ctx.lazybufs[b][0].metadata) is not None: ctx.ops_metadata[to_store] = m
|
||||
if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[to_store] = m
|
||||
return to_store
|
||||
|
||||
break_sched = PatternMatcher([
|
||||
@@ -515,7 +515,7 @@ def create_schedule_with_vars(outs:List[UOp]) -> Tuple[List[ScheduleItem], Dict[
|
||||
prescheduled.append(ScheduleItem(ast, tuple(u.buffer for u in ast_ctx.bufs if u.size != 0), tuple(ast_ctx.metadata),
|
||||
frozenset(ubuf for ubuf,ops in ast_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops))))
|
||||
for buf_uop in ast_ctx.sinked:
|
||||
for luop in ast_ctx.lazybufs[buf_uop]: luop.become(buf_uop.view(unwrap(luop.st)))
|
||||
for luop in ast_ctx.tensor_uops[buf_uop]: luop.become(buf_uop.view(unwrap(luop.st)))
|
||||
# do BFS
|
||||
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
|
||||
graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
|
||||
|
||||
Reference in New Issue
Block a user