[pr] Unify CONTIGUOUS and GBARRIER (#11121)

* Unify CONTIGUOUS and GBARRIER

* Simplify rules
This commit is contained in:
quortus
2025-07-08 19:27:23 +02:00
committed by GitHub
parent b516fe71b4
commit 790b05ab12
4 changed files with 16 additions and 21 deletions

View File

@@ -3,7 +3,7 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
from tinygrad.shape.shapetracker import ShapeTracker
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.GBARRIER}
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
# **** Grouper decides which of the UOps realize

View File

@@ -136,9 +136,9 @@ def append_to_kernel(x:UOp):
if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(dedup(metadata))))
create_kernels = PatternMatcher([
# always give assign/gbarrier a kernel
# always give assign/contiguous a kernel
(UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
(UPat(Ops.GBARRIER, src=(UPat.var("x"),)), create_kernel),
(UPat(Ops.CONTIGUOUS, name="x"), create_kernel),
# walk back the local graph until we reach a realized source
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
# push RESHAPE through MSELECT
@@ -240,7 +240,7 @@ view_right = merge_views+PatternMatcher([
# apply view after reduceops
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
# apply view after elementwise ops
(UPat(GroupOp.All-{Ops.SINK, Ops.GBARRIER, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
@@ -252,7 +252,7 @@ add_buffer_ops = PatternMatcher([
# LOAD
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st),)),
# STORE (except for meta ops)
(UPat(Ops.SINK, src=(UPat(GroupOp.Meta, name="x"),)), lambda x:x),
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
# passthrough ASSIGN
@@ -385,8 +385,8 @@ do_fuse = PatternMatcher([
(UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange),
])
add_gbarrier = PatternMatcher([(UPat(GroupOp.All-{Ops.GBARRIER, Ops.ASSIGN}, name="x"),
lambda ctx,x: x.replace(tag=1).gbarrier() if x in ctx and x.tag is None else None)])
add_contiguous = PatternMatcher([(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"),
lambda ctx,x: x.replace(tag=1).contiguous() if x in ctx and x.tag is None else None)])
# TODO: get this from the device through GrouperOpts
DEVICE_MAX_BUFS = {"METAL":32, "WEBGPU":8}
@@ -398,22 +398,18 @@ def limit_bufs(root:UOp):
# count number of unique buffers flowing into this op
bufs: set[UOp] = set()
def gate_input(u:UOp):
if (is_load:=(u.op in {Ops.BUFFER, Ops.GBARRIER, Ops.ASSIGN, Ops.MSTACK})): bufs.add(u)
if (is_load:=(u.op in {Ops.BUFFER, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.MSTACK})): bufs.add(u)
return not is_load
root.toposort(gate=gate_input)
# NOTE: this -1 is for the output buffer
if len(bufs)>=MAX_BUFS-1:
return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).gbarrier() for s in root.src))
return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).contiguous() for s in root.src))
finalize_gbarrier = PatternMatcher([
finalize_contiguous = PatternMatcher([
# if an op takes more than one input, check combined LOADs don't exceed device limits
(UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
# merge gbarrier
(UPat((Ops.GBARRIER, Ops.CONTIGUOUS), src=(UPat(Ops.GBARRIER),), name="x"), lambda x: x.src[0]),
# add contiguous to VIEW before GBARRIER
(UPat(Ops.GBARRIER, src=(UPat(Ops.VIEW,),), name="x"), lambda x: x.src[0].contiguous().gbarrier()),
# remove gbarrier on constants without a contiguous
(UPat(Ops.GBARRIER, src=(UPat(Ops.CONST),), name="x"), lambda x: x.src[0]),
# merge contiguous
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.CONTIGUOUS),), name="x"), lambda x: x.src[0]),
# simplify views
(UPat(Ops.VIEW, src=(UPat.var('x')), name="v"), lambda x,v: x.view(new_st) if (new_st:=v.arg.simplify()) != v.arg else None),
])
@@ -438,10 +434,10 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
# display the cleaned up tensor graph
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
# insert gbarriers in places determined by the realize map
# insert contiguous in places determined by the realize map
realize_map = group_realizes(tensor_map[sink])
tensor_map = graph_rewrite_map(tensor_map[sink], add_gbarrier, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="insert_gbarrier")
tensor_map = graph_rewrite_map(tensor_map[sink], finalize_gbarrier+remove_tags, input_map=tensor_map, name="finalize_gbarrier")
tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add_contiguous")
tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")
# TODO: move view_left/view_right here

View File

@@ -272,7 +272,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
def fuse(self): return self.alu(Ops.FUSE)
def gbarrier(self): return self.alu(Ops.GBARRIER)
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)

View File

@@ -14,7 +14,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
Ops.ALLREDUCE: "#ff40a0", Ops.GBARRIER: "#FFC14D", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0"}
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0"}
# VIZ API