From 217c0061039b0b0871666c098a49d7d984c2fe11 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Wed, 13 Nov 2024 11:04:19 +0200
Subject: [PATCH] buffer access on UOp [pr] (#7665)

* add .buffer access on uop

* rename to buf_uop

* start smaller

* ptr != buffer!!
---
 tinygrad/engine/schedule.py | 12 ++++++------
 tinygrad/ops.py             |  7 +++++++
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 0e064f4fe2..67a2243600 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -204,17 +204,17 @@ multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: c
 
 def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -> Tuple[UOp, ScheduleItemContext]:
   # fuse and fold store -> loads
-  sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.src[0]:x.src[2] for x in pre.src})
+  sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.buf_uop:x.src[2] for x in pre.src})
   # assert cyclic dependency
-  for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.src[0] in assigned), key=lambda x:x.src[0]):
+  for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.buf_uop in assigned), key=lambda x:x.buf_uop):
     if not all_same([x.op for x in ops]):
       raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
   # do movementops
   sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
   # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
-  if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
+  if len(assign_targets:=[x.buf_uop for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
     if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
-        and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.src[0] in assign_targets):
+        and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.buf_uop in assign_targets):
       raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                          +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
   # convert to AST
@@ -251,7 +251,7 @@ do_realize = PatternMatcher([
   (UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta))), realize),
   # don't realize image to image casts
   (UPatLoadStore(UPat(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float)).view(name="view"), lambda ctx,x,view,**kwargs: r.view(view.st)
-   if (r:=ctx.get(b:=x.src[0])) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) else None),
+   if (r:=ctx.get(b:=x.buf_uop)) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) else None),
   # realize before expand or unsafe pad ops
   (UPatLoadStore(UPat.var("base")).view(name="view"), realize_view),
   # realize before COPY or BUFFER_VIEW
@@ -283,7 +283,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
   bufs = list(ctx.buf_uops)
   prescheduled: List[ScheduleItem] = []
   for sink in sinks:
-    metadata = tuple({mx for x in sink.sparents if x.op in GroupOp.Buffer and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))})
+    metadata = tuple({mx for x in sink.sparents if x.op in GroupOp.Buffer and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.buf_uop))})
     ast, ast_ctx = full_ast_rewrite(sink, ctx.var_vals, ctx.assigns)
     prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=bufs[u.arg[0]]).size != 0), metadata, tuple(ast_ctx.assign_preloads)))
   # do BFS
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 5715d064d3..b599423b85 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -353,6 +353,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def view(self, st:ShapeTracker): return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
   def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
 
+  # *** uop Buffer stuff ***
+
+  @property
+  def buf_uop(self) -> UOp:
+    assert self.op in {*GroupOp.Buffer, Ops.ASSIGN} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"
+    return self.src[0]
+
   # *** uop Variable stuff ***
 
   @staticmethod
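
Note on the new property (commentary, not part of the patch): ops in GroupOp.Buffer (LOAD, PRELOAD, STORE, ...) plus ASSIGN carry their backing Ops.BUFFER uop as src[0], and buf_uop centralizes that indexing behind an assert instead of scattering bare x.src[0] accesses through the scheduler. The second half of the assert encodes the "ptr != buffer!!" commit note: src[0] must itself be a BUFFER uop, so a misplaced access fails loudly rather than silently returning the wrong uop. A minimal self-contained sketch of the pattern follows; the Ops and UOp below are toy stand-ins for illustration, not the real tinygrad classes:

from dataclasses import dataclass
from enum import Enum, auto
from typing import Tuple

class Ops(Enum):  # toy subset of the real op enum
  BUFFER = auto(); LOAD = auto(); STORE = auto(); ASSIGN = auto()

BUFFER_OPS = {Ops.LOAD, Ops.STORE}  # stand-in for GroupOp.Buffer

@dataclass(frozen=True)
class UOp:  # toy stand-in for tinygrad's UOp
  op: Ops
  src: Tuple["UOp", ...] = ()

  @property
  def buf_uop(self) -> "UOp":
    # same contract as the patch: only ops that carry a buffer in src[0] qualify,
    # and src[0] must itself be a BUFFER uop (a pointer is not a buffer)
    assert self.op in {*BUFFER_OPS, Ops.ASSIGN} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"
    return self.src[0]

buf = UOp(Ops.BUFFER)
load = UOp(Ops.LOAD, (buf,))
assert load.buf_uop is buf  # a LOAD carries its buffer in src[0]
# UOp(Ops.BUFFER).buf_uop would raise: a BUFFER has no backing buffer in its src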