From 197f8fd986b83506ff2c7eb5c144a5d6c606c7c8 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Thu, 26 Sep 2024 15:34:21 +0800
Subject: [PATCH] early uop globals with Buffer (#6753)

---
 tinygrad/engine/schedule.py | 22 +++++++++++++++-------
 viz/serve.py                |  9 ++++++---
 2 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index f4f78d0977..13c3e977aa 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -41,6 +41,10 @@ class LBScheduleItem:
   @property
   def inputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:]
 
+@dataclass(frozen=True)
+class ScheduleItemContext:
+  bufs: Tuple[Buffer, ...]
+
 # *** UOp with SWIZZLE (movementops) rewriting to UOp we can index ***
 
 # ** helpers for doing movementops on uops
@@ -112,9 +116,14 @@ reduceop_fusor = PatternMatcher([
   (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])
 
-def full_ast_rewrite(sink:UOp) -> UOp:
+enumerate_bufs = PatternMatcher([
+  (UPat(UOps.DEFINE_GLOBAL, name="x"), lambda ctx,x: x.replace(arg=ctx.bufs.index(x.arg)) if isinstance(x.arg, Buffer) else None),
+])
+
+def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp:
   if not AST_REWRITE: return sink
-  return graph_rewrite(sink, reduceop_fusor)
+  sink = graph_rewrite(sink, reduceop_fusor)
+  return graph_rewrite(sink, enumerate_bufs, ctx)
 
 # *** List[LazyBuffer] lowering to ScheduleItem ***
 
@@ -145,8 +154,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
         raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                            +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
     if buf not in assign_targets and buf not in inputs: inputs.append(buf)
-    ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
-               outputs.index(assign_targets[buf]) if buf in assign_targets else len(outputs)+inputs.index(buf))
+    ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), buf.buffer)
     return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
 
   # reduce ops change ShapeTracker
@@ -174,16 +182,16 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
   cache: Dict[Tuple[LazyBuffer, ShapeTracker], UOp] = {}
   ast: List[UOp] = []
   inputs: List[LazyBuffer] = []
-  for i, out in enumerate(outs):
+  for out in outs:
     src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
     if out.op is MetaOps.ASSIGN and out.arg:
       assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
       output_st = out.arg[0]
     output_st, vv = output_st.simplify().unbind()
     var_vals.update(vv)
-    ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
+    ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), out.buffer)
     ast.append(UOp(UOps.STORE, dtypes.void, (ubuf, output_st.to_uop(), src)))
-  sink = full_ast_rewrite(ast[0].sink(*ast[1:]))
+  sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(x.buffer for x in outs+inputs)))
   return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals
 
 # *** DAG creation: decide which LazyBuffers should realize ***
diff --git a/viz/serve.py b/viz/serve.py
index 848e3f061f..84d40136d1 100755
--- a/viz/serve.py
+++ b/viz/serve.py
@@ -10,12 +10,14 @@ from tinygrad.helpers import Context, getenv, to_function_name
 from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
 from tinygrad.engine.graph import uops_colors, word_wrap
 from tinygrad.engine.realize import get_runner
-from tinygrad.engine.schedule import full_ast_rewrite
+from tinygrad.engine.schedule import ScheduleItemContext, full_ast_rewrite
 
 # **** /graph - detailed UOp + rewrites
 
 # NOTE: UPats in ops.py are spec
-def graph_rewrites(ctx:TrackedRewriteContext): return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"]
+# TODO: fix key for uop with buffer
+def graph_rewrites(ctx:TrackedRewriteContext):
+  return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py" and not ("schedule" in ctx.loc[0] and "DEFINE_GLOBAL" in str(x[2]))]
 
 @dataclass(frozen=True)
 class RewriteLocation:
@@ -95,7 +97,8 @@ def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
   code = ""
   for ctx in contexts:
     if ctx.loc[0].split("/")[-1] == "schedule.py":
-      with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink)).p).name, prg.src
+      si_ctx = ScheduleItemContext(bufs=tuple(x.arg for x in ctx.sink.sparents if x.op is UOps.DEFINE_GLOBAL))
+      with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink, si_ctx)).p).name, prg.src
     elif ctx.kernel_name is not None: kernel_name, code = ctx.kernel_name, ""
     if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {})
     ret[k].ctxs[(ctx.loc, ctx.sink.key)] = ctx
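
Reviewer note on the mechanism, appended after the patch: with this change, DEFINE_GLOBAL uops are created early carrying the concrete Buffer object as their arg (see _recursive_uop and _lower_lazybuffer above), and only the final enumerate_bufs pass in full_ast_rewrite swaps each Buffer for its index into ScheduleItemContext.bufs, outputs first, then inputs. The sketch below shows that numbering step in isolation; it is a minimal stand-in, not tinygrad's real API: Buffer and UOp here are simplified dataclasses, and enumerate_bufs is written as a plain recursive rewrite instead of a PatternMatcher.

from dataclasses import dataclass, replace
from typing import Any, Tuple

@dataclass(frozen=True, eq=False)  # eq=False: buffers compare by identity, as buffer objects should
class Buffer:
  device: str
  size: int

@dataclass(frozen=True)
class UOp:
  op: str                          # e.g. "DEFINE_GLOBAL", "LOAD", "STORE", "SINK"
  src: Tuple["UOp", ...] = ()
  arg: Any = None

@dataclass(frozen=True)
class ScheduleItemContext:
  bufs: Tuple[Buffer, ...]         # outputs first, then inputs

def enumerate_bufs(ctx:ScheduleItemContext, x:UOp) -> UOp:
  # bottom-up rewrite: replace each DEFINE_GLOBAL's Buffer arg with its index into ctx.bufs
  src = tuple(enumerate_bufs(ctx, s) for s in x.src)
  if x.op == "DEFINE_GLOBAL" and isinstance(x.arg, Buffer): return UOp(x.op, src, ctx.bufs.index(x.arg))
  return replace(x, src=src)

# usage: STORE to the output, LOAD from the input; after the rewrite the args are 0 and 1
out_buf, in_buf = Buffer("CPU", 16), Buffer("CPU", 16)
ctx = ScheduleItemContext(bufs=(out_buf, in_buf))
load = UOp("LOAD", (UOp("DEFINE_GLOBAL", arg=in_buf),))
sink = UOp("SINK", (UOp("STORE", (UOp("DEFINE_GLOBAL", arg=out_buf), load)),))
numbered = enumerate_bufs(ctx, sink)
assert numbered.src[0].src[0].arg == 0 and numbered.src[0].src[1].src[0].arg == 1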
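A note on the design this enables: because the Buffer rides on the uop until the very last rewrite, viz/serve.py no longer needs the scheduler's outs/inputs ordering to replay a rewrite; the load_kernels hunk rebuilds the context directly from the sink with tuple(x.arg for x in ctx.sink.sparents if x.op is UOps.DEFINE_GLOBAL). Meanwhile the graph_rewrites filter hides the resulting DEFINE_GLOBAL churn from the viz UI until the key for uops with buffers is fixed, per the TODO in that hunk.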