buffer access on UOp [pr] (#7665)

* add .buffer access on uop

* rename to buf_uop

* start smaller

* ptr != buffer!!
This commit is contained in:
qazal
2024-11-13 11:04:19 +02:00
committed by GitHub
parent 5da149d23c
commit 217c006103
2 changed files with 13 additions and 6 deletions

View File

@@ -204,17 +204,17 @@ multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: c
def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -> Tuple[UOp, ScheduleItemContext]:
# fuse and fold store -> loads
sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.src[0]:x.src[2] for x in pre.src})
sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.buf_uop:x.src[2] for x in pre.src})
# assert cyclic dependency
for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.src[0] in assigned), key=lambda x:x.src[0]):
for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.buf_uop in assigned), key=lambda x:x.buf_uop):
if not all_same([x.op for x in ops]):
raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
# do movementops
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
if len(assign_targets:=[x.buf_uop for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.src[0] in assign_targets):
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.buf_uop in assign_targets):
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
# convert to AST
@@ -251,7 +251,7 @@ do_realize = PatternMatcher([
(UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta))), realize),
# don't realize image to image casts
(UPatLoadStore(UPat(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float)).view(name="view"), lambda ctx,x,view,**kwargs: r.view(view.st)
if (r:=ctx.get(b:=x.src[0])) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) else None),
if (r:=ctx.get(b:=x.buf_uop)) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) else None),
# realize before expand or unsafe pad ops
(UPatLoadStore(UPat.var("base")).view(name="view"), realize_view),
# realize before COPY or BUFFER_VIEW
@@ -283,7 +283,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
bufs = list(ctx.buf_uops)
prescheduled: List[ScheduleItem] = []
for sink in sinks:
metadata = tuple({mx for x in sink.sparents if x.op in GroupOp.Buffer and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))})
metadata = tuple({mx for x in sink.sparents if x.op in GroupOp.Buffer and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.buf_uop))})
ast, ast_ctx = full_ast_rewrite(sink, ctx.var_vals, ctx.assigns)
prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=bufs[u.arg[0]]).size != 0), metadata, tuple(ast_ctx.assign_preloads)))
# do BFS

View File

@@ -353,6 +353,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def view(self, st:ShapeTracker): return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
# *** uop Buffer stuff ***
@property
def buf_uop(self) -> UOp:
assert self.op in {*GroupOp.Buffer, Ops.ASSIGN} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"
return self.src[0]
# *** uop Variable stuff ***
@staticmethod