From eebd23155cb0ce240cfd7e93677e3f38e4efa157 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 19 Sep 2024 20:03:28 +0800 Subject: [PATCH] move scheduler rewrites into full_ast_rewrite [run_process_replay] (#6609) --- tinygrad/engine/schedule.py | 106 ++++++++++++++++++------------------ tinygrad/ops.py | 2 +- 2 files changed, 55 insertions(+), 53 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 21867c9c32..7b03c2f9f9 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -43,55 +43,9 @@ class LBScheduleItem: """The unique identifier of a schedule item in the toposort.""" return hash(self.outputs[0]) -# *** DAG transformation: List[LazyBuffer] -> ScheduleItem *** +# *** UOp with SWIZZLE (movementops) rewriting to UOp we can index *** -def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:Dict[LazyBuffer, int], - realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], - cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp: - """recursively create a UOp""" - if buf is not buf.base: st, buf = buf.st+st, buf.base - if (buf, st) in cache: return cache[(buf, st)] - assert buf.op is not None, "base must be a base itself" - dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype - - # buffer ops define ShapeTracker - if buf.realized is not None or (buf in realizes and buf not in outputs): - unbound_st, st_var_vals = st.simplify().unbind() - var_vals.update(st_var_vals) - # if it's a const, we generate it - if buf.op is MetaOps.CONST: - if isinstance(val:=buf.arg, Variable): - val, var_val = val.unbind() - var_vals[val] = var_val - else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}" - return UOp(UOps.CONST, dtype, (unbound_st.to_uop(),), val) - # otherwise, it's a load and we add it to the inputs - if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \ - ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))): - # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine - raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" - +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) - ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), - outputs.index(assign_targets[buf]) if buf in assign_targets else len(outputs)+inputs.setdefault(buf, len(inputs))) - return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop())) - - # reduce ops change ShapeTracker - if buf.op in ReduceOps: - rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, realizes, assign_targets, cache) - return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).swizzle(st)) - - # elementwise ops pass shapetracker - in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, cache) for x in buf.srcs) - if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}: - assert buf in outputs, f"{buf.op} must be writable" - return in_uops[0] - if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_uops)) - if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops)) - return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op)) - -# ** AST graph rewrite: UOp with SWIZZLE (movementops) -> UOp we can index ** - -# ***** helpers for doing movementops on uops ***** +# ** helpers for doing movementops on uops def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp: if (n:=cache.get(u)): return n @@ -108,7 +62,7 @@ def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTr tmp = input_st.permute(permute_axis) return tmp, tmp.shape[-len(axis):] -# ***** reduceop fusor ***** +# ** reduceop fusor def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]: if (swizzle_st:=unwrap(swizzle.st)).contiguous: return None @@ -160,6 +114,56 @@ reduceop_fusor = PatternMatcher([ (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) +def full_ast_rewrite(sink:UOp) -> UOp: + if not AST_REWRITE: return sink + return graph_rewrite(sink, reduceop_fusor) + +# *** List[LazyBuffer] lowering to ScheduleItem *** + +def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:Dict[LazyBuffer, int], + realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], + cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp: + """recursively create a UOp""" + if buf is not buf.base: st, buf = buf.st+st, buf.base + if (buf, st) in cache: return cache[(buf, st)] + assert buf.op is not None, "base must be a base itself" + dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype + + # buffer ops define ShapeTracker + if buf.realized is not None or (buf in realizes and buf not in outputs): + unbound_st, st_var_vals = st.simplify().unbind() + var_vals.update(st_var_vals) + # if it's a const, we generate it + if buf.op is MetaOps.CONST: + if isinstance(val:=buf.arg, Variable): + val, var_val = val.unbind() + var_vals[val] = var_val + else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}" + return UOp(UOps.CONST, dtype, (unbound_st.to_uop(),), val) + # otherwise, it's a load and we add it to the inputs + if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \ + ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))): + # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine + raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" + +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) + ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), + outputs.index(assign_targets[buf]) if buf in assign_targets else len(outputs)+inputs.setdefault(buf, len(inputs))) + return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop())) + + # reduce ops change ShapeTracker + if buf.op in ReduceOps: + rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, realizes, assign_targets, cache) + return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).swizzle(st)) + + # elementwise ops pass shapetracker + in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, cache) for x in buf.srcs) + if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}: + assert buf in outputs, f"{buf.op} must be writable" + return in_uops[0] + if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_uops)) + if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops)) + return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op)) + def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> Tuple[LBScheduleItem, Dict[Variable, int]]: """describe the computation for a LazyBuffer with UOp + inputs + var_vals""" if (out:=outs[0]).op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: @@ -179,9 +183,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> var_vals.update(vv) ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i) ast.append(UOp(UOps.STORE, dtypes.void, (ubuf, output_st.to_uop(), src))) - sink = UOp(UOps.SINK, dtypes.void, tuple(ast)) - if AST_REWRITE: - sink = graph_rewrite(sink, reduceop_fusor) + sink = full_ast_rewrite(ast[0].sink(*ast[1:])) return LBScheduleItem(sink, outs, list(inputs), dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])), var_vals # *** DAG creation: decide which LazyBuffers should realize *** diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 08c1e93849..18dc1b55b3 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -175,7 +175,7 @@ class UOp(MathTrait): ret = self.src[self.st_loc] assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}" return ret.arg - def sink(self, *srcs): return UOp(UOps.SINK, dtypes.void, (self,)+srcs) + def sink(self, *srcs:UOp): return UOp(UOps.SINK, dtypes.void, (self,)+srcs) def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st) def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return type(self).const(self.dtype, b) def broadcast(self, count:int):