mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
buffer access on UOp [pr] (#7665)
* add .buffer access on uop * rename to buf_uop * start smaller * ptr != buffer!!
This commit is contained in:
@@ -204,17 +204,17 @@ multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: c
|
||||
|
||||
def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -> Tuple[UOp, ScheduleItemContext]:
|
||||
# fuse and fold store -> loads
|
||||
sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.src[0]:x.src[2] for x in pre.src})
|
||||
sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.buf_uop:x.src[2] for x in pre.src})
|
||||
# assert cyclic dependency
|
||||
for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.src[0] in assigned), key=lambda x:x.src[0]):
|
||||
for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.buf_uop in assigned), key=lambda x:x.buf_uop):
|
||||
if not all_same([x.op for x in ops]):
|
||||
raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
|
||||
# do movementops
|
||||
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
|
||||
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
|
||||
if len(assign_targets:=[x.buf_uop for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
|
||||
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
|
||||
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.src[0] in assign_targets):
|
||||
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.buf_uop in assign_targets):
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
# convert to AST
|
||||
@@ -251,7 +251,7 @@ do_realize = PatternMatcher([
|
||||
(UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta))), realize),
|
||||
# don't realize image to image casts
|
||||
(UPatLoadStore(UPat(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float)).view(name="view"), lambda ctx,x,view,**kwargs: r.view(view.st)
|
||||
if (r:=ctx.get(b:=x.src[0])) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) else None),
|
||||
if (r:=ctx.get(b:=x.buf_uop)) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) else None),
|
||||
# realize before expand or unsafe pad ops
|
||||
(UPatLoadStore(UPat.var("base")).view(name="view"), realize_view),
|
||||
# realize before COPY or BUFFER_VIEW
|
||||
@@ -283,7 +283,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
bufs = list(ctx.buf_uops)
|
||||
prescheduled: List[ScheduleItem] = []
|
||||
for sink in sinks:
|
||||
metadata = tuple({mx for x in sink.sparents if x.op in GroupOp.Buffer and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))})
|
||||
metadata = tuple({mx for x in sink.sparents if x.op in GroupOp.Buffer and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.buf_uop))})
|
||||
ast, ast_ctx = full_ast_rewrite(sink, ctx.var_vals, ctx.assigns)
|
||||
prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=bufs[u.arg[0]]).size != 0), metadata, tuple(ast_ctx.assign_preloads)))
|
||||
# do BFS
|
||||
|
||||
@@ -353,6 +353,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def view(self, st:ShapeTracker): return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
|
||||
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
|
||||
|
||||
# *** uop Buffer stuff ***
|
||||
|
||||
@property
|
||||
def buf_uop(self) -> UOp:
|
||||
assert self.op in {*GroupOp.Buffer, Ops.ASSIGN} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"
|
||||
return self.src[0]
|
||||
|
||||
# *** uop Variable stuff ***
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user