From 1824cbd72c6546bcedffa024b4d64b51df990bb6 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Fri, 13 Dec 2024 13:20:02 +0200
Subject: [PATCH] s/lazybufs/tensor_uops [pr] (#8207)

---
 tinygrad/engine/schedule.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 336116d124..dca79e8dc4 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -39,7 +39,7 @@ class ScheduleItem:
 
 @dataclass(frozen=True)
 class ScheduleContext:
-  lazybufs: Dict[UOp, List[UOp]] = field(default_factory=dict)  # this maps BUFFER uops of this schedule to the underlying lazybuffer
+  tensor_uops: Dict[UOp, List[UOp]] = field(default_factory=dict)  # this maps BUFFER uops of this schedule to the tensor uop
   var_vals: Dict[Variable, int] = field(default_factory=dict)  # this maps a BIND's DEFINE_VAR to its value
   assigns: Set[UOp] = field(default_factory=set)  # this holds all the BUFFER uops we ASSIGN to in this schedule
   realizes: Dict[UOp, UOp] = field(default_factory=dict)  # this holds all the BUFFER uops we mutate in this schedule
@@ -76,8 +76,8 @@ def to_uop(buf:UOp, ctx:ScheduleContext, cache:Dict[UOp, UOp]) -> UOp:
     buf_uop = src[0].base.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
     op = UOp(buf.op, dtype.base, src, buf.arg)
   ret = UOp(Ops.VIEW, dtype.base, (buf_uop,) if op is None else (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
-  # track the underlying lazydata for this op
-  if op is not None: ctx.lazybufs[buf_uop] = [buf]
+  # track the underlying tensor uop for this op
+  if op is not None: ctx.tensor_uops[buf_uop] = [buf]
   cache[buf] = ret
   return ret
 
@@ -137,7 +137,7 @@ view_right = merge_views+PatternMatcher([
 
 @dataclass(frozen=True)
 class ScheduleItemContext:
-  lazybufs: Dict[UOp, List[UOp]]
+  tensor_uops: Dict[UOp, List[UOp]]
   ops_metadata: Dict[UOp, Metadata]
   assigns: Set[UOp]
   var_vals: Dict[Variable, int]
@@ -182,7 +182,7 @@ multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: c
 
 def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemContext]:
   # create the ast context
-  si_ctx = ScheduleItemContext(ctx.lazybufs, ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src})
+  si_ctx = ScheduleItemContext(ctx.tensor_uops, ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src})
   create_ctx = add_metadata if len(si_ctx.assigns) == 0 else add_metadata+add_assign_adjacents
   sink = graph_rewrite(pre, create_ctx if len(si_ctx.sinked) == 1 else multioutput+create_ctx, si_ctx)
   # do movement ops
@@ -301,7 +301,7 @@ def group_realizes(ctx:ScheduleContext) -> List[List[UOp]]:
   # maybe fuse arange with its children
   for rbuf in reduce_of_const:
     group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
-    if any(luop.forced_realize for tr in group for luop in ctx.lazybufs[tr]): continue
+    if any(luop.forced_realize for tr in group for luop in ctx.tensor_uops[tr]): continue
     kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
     if len(kernel_children) == 0: continue
     for tr in group: del ctx.realizes[tr]
@@ -397,14 +397,14 @@ def merge(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp) -> UOp:
     ctx.realizes[b1] = b1
     del ctx.realizes[b2]
   # ops referring to b2 now ref to b1
-  ctx.lazybufs[b1] += ctx.lazybufs[b2]
-  del ctx.lazybufs[b2]
+  ctx.tensor_uops[b1] += ctx.tensor_uops[b2]
+  del ctx.tensor_uops[b2]
   # merge
   return v1
 
 def merge_realized(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp):
   # early become
-  for luop in ctx.lazybufs.get(b1, [])+ctx.lazybufs.get(b2, []): luop.become(b1.view(unwrap(luop.st)))
+  for luop in ctx.tensor_uops.get(b1, [])+ctx.tensor_uops.get(b2, []): luop.become(b1.view(unwrap(luop.st)))
   return v1
 
 merge_bufs = PatternMatcher([
@@ -468,7 +468,7 @@ def append_realize(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
 
 def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
   # TODO: metadata post merge
-  if (m:=ctx.lazybufs[b][0].metadata) is not None: ctx.ops_metadata[to_store] = m
+  if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[to_store] = m
   return to_store
 
 break_sched = PatternMatcher([
@@ -515,7 +515,7 @@ def create_schedule_with_vars(outs:List[UOp]) -> Tuple[List[ScheduleItem], Dict[
     prescheduled.append(ScheduleItem(ast, tuple(u.buffer for u in ast_ctx.bufs if u.size != 0), tuple(ast_ctx.metadata),
                                      frozenset(ubuf for ubuf,ops in ast_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops))))
    for buf_uop in ast_ctx.sinked:
-      for luop in ast_ctx.lazybufs[buf_uop]: luop.become(buf_uop.view(unwrap(luop.st)))
+      for luop in ast_ctx.tensor_uops[buf_uop]: luop.become(buf_uop.view(unwrap(luop.st)))
   # do BFS
   schedule_targets = {out:si for si in prescheduled for out in si.outputs}
   graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
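
Note (not part of the patch): a minimal standalone sketch of the mapping this rename touches, assuming only what the diff shows. ScheduleContext.tensor_uops keys each BUFFER uop to the tensor uops realized into it; to_uop() seeds an entry, and merge() folds b2's entries into b1. The ToyUOp class and the track/merge_buffers helpers are hypothetical names for illustration, not tinygrad APIs.

# illustrative sketch only -- models the dict semantics seen in the diff above
from typing import Dict, List

class ToyUOp:  # hypothetical stand-in for tinygrad's UOp
  def __init__(self, name:str): self.name = name
  def __repr__(self): return self.name

tensor_uops: Dict[ToyUOp, List[ToyUOp]] = {}  # BUFFER uop -> tensor uops realized into it

def track(buf_uop:ToyUOp, tensor_uop:ToyUOp) -> None:
  # mirrors to_uop(): ctx.tensor_uops[buf_uop] = [buf]
  tensor_uops.setdefault(buf_uop, []).append(tensor_uop)

def merge_buffers(b1:ToyUOp, b2:ToyUOp) -> None:
  # mirrors merge(): tensor uops that referred to b2 now refer to b1
  tensor_uops[b1] += tensor_uops.pop(b2)

b1, b2, t1, t2 = ToyUOp("BUFFER_1"), ToyUOp("BUFFER_2"), ToyUOp("t1"), ToyUOp("t2")
track(b1, t1); track(b2, t2)
merge_buffers(b1, b2)
assert tensor_uops == {b1: [t1, t2]}  # b2's tensor uop now lives under b1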