From 5e1221845ff6eaedc9f5b53b01df1538996f91f3 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Sun, 29 Sep 2024 11:34:39 +0800
Subject: [PATCH] refactor schedule edges to tuple[LazyBuffer, ...] [run_process_replay] (#6797)

---
 tinygrad/engine/schedule.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 6cdb789779..314829ef41 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -128,7 +128,7 @@ def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp:
 # *** List[LazyBuffer] lowering to ScheduleItem ***
 
 def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:List[LazyBuffer],
-                   realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
+                   bufs:Tuple[LazyBuffer, ...], assign_targets:Dict[LazyBuffer, LazyBuffer],
                    cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
   """recursively create a UOp"""
   if buf is not buf.base: st, buf = buf.st+st, buf.base
@@ -137,7 +137,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
 
   # buffer ops define ShapeTracker
-  if buf in realizes and buf not in outputs:
+  if buf in bufs and buf not in outputs:
     unbound_st, st_var_vals = st.simplify().unbind()
     var_vals.update(st_var_vals)
     # if it's a const, we generate it
@@ -158,11 +158,11 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
 
   # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
-    rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, realizes, assign_targets, cache)
+    rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, bufs, assign_targets, cache)
     return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).swizzle(st))
 
   # elementwise ops pass shapetracker
-  in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, cache) for x in buf.srcs)
+  in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, bufs, assign_targets, cache) for x in buf.srcs)
   if buf.op is MetaOps.CONTIGUOUS:
     assert buf in outputs, f"{buf.op} must be writable"
     return in_uops[0]
@@ -171,7 +171,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
   return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))
 
-def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
+def _lower_lazybuffer(outs:List[LazyBuffer], bufs:Tuple[LazyBuffer, ...]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
   """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
   if (out:=outs[0]).op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
     return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), (out,)+tuple(x.base for x in out.srcs)), {}
@@ -182,7 +182,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
   ast: List[UOp] = []
   inputs: List[LazyBuffer] = []
   for out in outs:
-    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
+    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, bufs, assign_targets, cache=cache)
     if out.op is MetaOps.ASSIGN and out.arg:
       assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
       output_st = out.arg[0]
@@ -273,7 +273,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
 
 def _get_output_groups(outs:List[LazyBuffer]) -> \
   Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]],  # these are the output groups
-        Dict[LazyBuffer, None],  # these are all the realizes in the graph
+        Tuple[LazyBuffer, ...],  # these are all the realizes in the graph
         Dict[LazyBuffer, LazyBuffer]]:  # these are the buffers we ASSIGN to in this schedule
   """find all the realizes in the graph, group the output LazyBuffers into kernels."""
   # start by just realizing the buffers passed in
@@ -363,7 +363,7 @@ def _get_output_groups(outs:List[LazyBuffer]) -> \
       assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
       buf.buffer.dtype = dtypes.float32
       buf.buffer.options = None
-  return output_groups, realizes, assign_targets
+  return output_groups, tuple(realizes), assign_targets
 
 SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
 def _graph_schedule(outs:List[LazyBuffer]) -> \
@@ -371,12 +371,12 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \
         DefaultDict[LBScheduleItem, int],  # this is the in-degree of the graph
         Dict[Variable, int]]:  # this has all the var values of the schedule
   """create a graph for realizing the outputs"""
-  output_groups, realizes, assign_targets = _get_output_groups(outs)
+  output_groups, bufs, assign_targets = _get_output_groups(outs)
   # preschedule all buffers in realizes
   prescheduled: List[LBScheduleItem] = []
   var_vals: Dict[Variable, int] = {}
   for group in output_groups.values():
-    prescheduled.append((ret:=_lower_lazybuffer(group, realizes))[0])
+    prescheduled.append((ret:=_lower_lazybuffer(group, bufs))[0])
     var_vals = merge_dicts([var_vals, ret[1]])
 
   schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
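
Not part of the patch: a minimal standalone sketch of the type change the diff makes. The old code passed realizes around as a Dict[LazyBuffer, None] used as an insertion-ordered set; the patch converts it once with tuple(realizes) and threads it through as bufs:Tuple[LazyBuffer, ...]. The Buf class below is a hypothetical stand-in for LazyBuffer, just to show that membership tests like "buf in bufs" and iteration order behave the same either way, since Python dicts preserve insertion order.

# sketch only: Buf is a stand-in for LazyBuffer, not tinygrad code
class Buf:
  def __init__(self, name:str): self.name = name
  def __repr__(self): return f"Buf({self.name})"

a, b, c = Buf("a"), Buf("b"), Buf("c")
realizes = {a: None, b: None, c: None}  # old style: dict used as an ordered set
bufs = tuple(realizes)                  # new style: same buffers, same order, as a tuple

assert (b in realizes) == (b in bufs)   # membership checks behave identically
assert list(realizes) == list(bufs)     # iteration order is preserved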