mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
[pr] Unify CONTIGUOUS and GBARRIER (#11121)
* Unify CONTIGUOUS and GBARRIER
* Simplify rules
This commit is contained in:
@@ -3,7 +3,7 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.GBARRIER}
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
|
||||
|
||||
# **** Grouper decides which of the UOps realize
|
||||
|
||||
|
||||
@@ -136,9 +136,9 @@ def append_to_kernel(x:UOp):
|
||||
if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(dedup(metadata))))
|
||||
|
||||
create_kernels = PatternMatcher([
|
||||
# always give assign/gbarrier a kernel
|
||||
# always give assign/contiguous a kernel
|
||||
(UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
|
||||
(UPat(Ops.GBARRIER, src=(UPat.var("x"),)), create_kernel),
|
||||
(UPat(Ops.CONTIGUOUS, name="x"), create_kernel),
|
||||
# walk back the local graph until we reach a realized source
|
||||
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
|
||||
# push RESHAPE through MSELECT
|
||||
@@ -240,7 +240,7 @@ view_right = merge_views+PatternMatcher([
|
||||
# apply view after reduceops
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
|
||||
# apply view after elementwise ops
|
||||
(UPat(GroupOp.All-{Ops.SINK, Ops.GBARRIER, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
|
||||
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
|
||||
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
||||
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
|
||||
@@ -252,7 +252,7 @@ add_buffer_ops = PatternMatcher([
|
||||
# LOAD
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st),)),
|
||||
# STORE (except for meta ops)
|
||||
(UPat(Ops.SINK, src=(UPat(GroupOp.Meta, name="x"),)), lambda x:x),
|
||||
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
|
||||
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
|
||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
||||
# passthrough ASSIGN
|
||||
@@ -385,8 +385,8 @@ do_fuse = PatternMatcher([
|
||||
(UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange),
|
||||
])
|
||||
|
||||
add_gbarrier = PatternMatcher([(UPat(GroupOp.All-{Ops.GBARRIER, Ops.ASSIGN}, name="x"),
|
||||
lambda ctx,x: x.replace(tag=1).gbarrier() if x in ctx and x.tag is None else None)])
|
||||
add_contiguous = PatternMatcher([(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"),
|
||||
lambda ctx,x: x.replace(tag=1).contiguous() if x in ctx and x.tag is None else None)])
|
||||
|
||||
# TODO: get this from the device through GrouperOpts
|
||||
DEVICE_MAX_BUFS = {"METAL":32, "WEBGPU":8}
|
||||
@@ -398,22 +398,18 @@ def limit_bufs(root:UOp):
|
||||
# count number of unique buffers flowing into this op
|
||||
bufs: set[UOp] = set()
|
||||
def gate_input(u:UOp):
|
||||
if (is_load:=(u.op in {Ops.BUFFER, Ops.GBARRIER, Ops.ASSIGN, Ops.MSTACK})): bufs.add(u)
|
||||
if (is_load:=(u.op in {Ops.BUFFER, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.MSTACK})): bufs.add(u)
|
||||
return not is_load
|
||||
root.toposort(gate=gate_input)
|
||||
# NOTE: this -1 is for the output buffer
|
||||
if len(bufs)>=MAX_BUFS-1:
|
||||
return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).gbarrier() for s in root.src))
|
||||
return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).contiguous() for s in root.src))
|
||||
|
||||
finalize_gbarrier = PatternMatcher([
|
||||
finalize_contiguous = PatternMatcher([
|
||||
# if an op takes more than one input, check combined LOADs don't exceed device limits
|
||||
(UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
|
||||
# merge gbarrier
|
||||
(UPat((Ops.GBARRIER, Ops.CONTIGUOUS), src=(UPat(Ops.GBARRIER),), name="x"), lambda x: x.src[0]),
|
||||
# add contiguous to VIEW before GBARRIER
|
||||
(UPat(Ops.GBARRIER, src=(UPat(Ops.VIEW,),), name="x"), lambda x: x.src[0].contiguous().gbarrier()),
|
||||
# remove gbarrier on constants without a contiguous
|
||||
(UPat(Ops.GBARRIER, src=(UPat(Ops.CONST),), name="x"), lambda x: x.src[0]),
|
||||
# merge contiguous
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.CONTIGUOUS),), name="x"), lambda x: x.src[0]),
|
||||
# simplify views
|
||||
(UPat(Ops.VIEW, src=(UPat.var('x')), name="v"), lambda x,v: x.view(new_st) if (new_st:=v.arg.simplify()) != v.arg else None),
|
||||
])
|
||||
@@ -438,10 +434,10 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
# display the cleaned up tensor graph
|
||||
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
|
||||
|
||||
# insert gbarriers in places determined by the realize map
|
||||
# insert contiguous in places determined by the realize map
|
||||
realize_map = group_realizes(tensor_map[sink])
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], add_gbarrier, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="insert_gbarrier")
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], finalize_gbarrier+remove_tags, input_map=tensor_map, name="finalize_gbarrier")
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add_contiguous")
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")
|
||||
|
||||
# TODO: move view_left/view_right here
|
||||
|
||||
|
||||
@@ -272,7 +272,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
|
||||
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
|
||||
def fuse(self): return self.alu(Ops.FUSE)
|
||||
def gbarrier(self): return self.alu(Ops.GBARRIER)
|
||||
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
|
||||
assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
|
||||
return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
|
||||
|
||||
@@ -14,7 +14,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
||||
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
|
||||
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.GBARRIER: "#FFC14D", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0"}
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0"}
|
||||
|
||||
# VIZ API
|
||||
|
||||
|
||||
Reference in New Issue
Block a user