diff --git a/tinygrad/kernelize/grouper.py b/tinygrad/kernelize/grouper.py
index 38c5a691f3..2477da5263 100644
--- a/tinygrad/kernelize/grouper.py
+++ b/tinygrad/kernelize/grouper.py
@@ -3,7 +3,7 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
 from tinygrad.shape.shapetracker import ShapeTracker
 
 ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
-                     Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.GBARRIER}
+                     Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
 
 # **** Grouper decides which of the UOps realize
 
diff --git a/tinygrad/kernelize/kernelize.py b/tinygrad/kernelize/kernelize.py
index 77b341cef2..765989b8ec 100644
--- a/tinygrad/kernelize/kernelize.py
+++ b/tinygrad/kernelize/kernelize.py
@@ -136,9 +136,9 @@ def append_to_kernel(x:UOp):
   if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(dedup(metadata))))
 
 create_kernels = PatternMatcher([
-  # always give assign/gbarrier a kernel
+  # always give assign/contiguous a kernel
   (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
-  (UPat(Ops.GBARRIER, src=(UPat.var("x"),)), create_kernel),
+  (UPat(Ops.CONTIGUOUS, name="x"), create_kernel),
   # walk back the local graph until we reach a realized source
   (UPat(Ops.KERNEL, name="x"), append_to_kernel),
   # push RESHAPE through MSELECT
@@ -240,7 +240,7 @@ view_right = merge_views+PatternMatcher([
   # apply view after reduceops
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
   # apply view after elementwise ops
-  (UPat(GroupOp.All-{Ops.SINK, Ops.GBARRIER, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
+  (UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
   # merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
    lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
@@ -252,7 +252,7 @@ add_buffer_ops = PatternMatcher([
   # LOAD
   (UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st),)),
   # STORE (except for meta ops)
-  (UPat(Ops.SINK, src=(UPat(GroupOp.Meta, name="x"),)), lambda x:x),
+  (UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
   (UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
    UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
   # passthrough ASSIGN
@@ -385,8 +385,8 @@ do_fuse = PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange),
 ])
 
-add_gbarrier = PatternMatcher([(UPat(GroupOp.All-{Ops.GBARRIER, Ops.ASSIGN}, name="x"),
-                                lambda ctx,x: x.replace(tag=1).gbarrier() if x in ctx and x.tag is None else None)])
+add_contiguous = PatternMatcher([(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"),
+                                  lambda ctx,x: x.replace(tag=1).contiguous() if x in ctx and x.tag is None else None)])
 
 # TODO: get this from the device through GrouperOpts
 DEVICE_MAX_BUFS = {"METAL":32, "WEBGPU":8}
@@ -398,22 +398,18 @@ def limit_bufs(root:UOp):
   # count number of unique buffers flowing into this op
   bufs: set[UOp] = set()
   def gate_input(u:UOp):
-    if (is_load:=(u.op in {Ops.BUFFER, Ops.GBARRIER, Ops.ASSIGN, Ops.MSTACK})): bufs.add(u)
+    if (is_load:=(u.op in {Ops.BUFFER, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.MSTACK})): bufs.add(u)
     return not is_load
   root.toposort(gate=gate_input)
   # NOTE: this -1 is for the output buffer
   if len(bufs)>=MAX_BUFS-1:
-    return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).gbarrier() for s in root.src))
+    return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).contiguous() for s in root.src))
 
-finalize_gbarrier = PatternMatcher([
+finalize_contiguous = PatternMatcher([
   # if an op takes more than one input, check combined LOADs don't exceed device limits
   (UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
-  # merge gbarrier
-  (UPat((Ops.GBARRIER, Ops.CONTIGUOUS), src=(UPat(Ops.GBARRIER),), name="x"), lambda x: x.src[0]),
-  # add contiguous to VIEW before GBARRIER
-  (UPat(Ops.GBARRIER, src=(UPat(Ops.VIEW,),), name="x"), lambda x: x.src[0].contiguous().gbarrier()),
-  # remove gbarrier on constants without a contiguous
-  (UPat(Ops.GBARRIER, src=(UPat(Ops.CONST),), name="x"), lambda x: x.src[0]),
+  # merge contiguous
+  (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.CONTIGUOUS),), name="x"), lambda x: x.src[0]),
   # simplify views
   (UPat(Ops.VIEW, src=(UPat.var('x')), name="v"), lambda x,v: x.view(new_st) if (new_st:=v.arg.simplify()) != v.arg else None),
 ])
@@ -438,10 +434,10 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
 
   # display the cleaned up tensor graph
   if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
 
-  # insert gbarriers in places determined by the realize map
+  # insert contiguous in places determined by the realize map
   realize_map = group_realizes(tensor_map[sink])
-  tensor_map = graph_rewrite_map(tensor_map[sink], add_gbarrier, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="insert_gbarrier")
-  tensor_map = graph_rewrite_map(tensor_map[sink], finalize_gbarrier+remove_tags, input_map=tensor_map, name="finalize_gbarrier")
+  tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add_contiguous")
+  tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")
 
   # TODO: move view_left/view_right here
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 350b5a7b8e..97f3190687 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -272,7 +272,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def contiguous(self): return self.alu(Ops.CONTIGUOUS)
   def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
   def fuse(self): return self.alu(Ops.FUSE)
-  def gbarrier(self): return self.alu(Ops.GBARRIER)
   def allreduce(self, op, device:str|tuple[str, ...]|UOp):
     assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
     return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index 1437042b9d..1aaad9721f 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -14,7 +14,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
                Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
                **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
                Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
-               Ops.ALLREDUCE: "#ff40a0", Ops.GBARRIER: "#FFC14D", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0"}
+               Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0"}
 
 # VIZ API