From bb2e430ac36cd1d644c0c00c24ea8b44664228e5 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Fri, 16 May 2025 14:06:48 -0700
Subject: [PATCH] mselect

---
 tinygrad/engine/grouper.py     | 27 ++++++++++++++++-----------
 tinygrad/engine/multi.py       |  7 +++----
 tinygrad/ops.py                | 20 ++++++++++++--------
 tinygrad/shape/shapetracker.py |  1 +
 tinygrad/shape/view.py         | 15 +++++++++------
 tinygrad/spec.py               |  3 ++-
 tinygrad/viz/serve.py          |  2 +-
 7 files changed, 44 insertions(+), 31 deletions(-)
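A note before the per-file hunks: this patch adds Ops.MSELECT, which picks one shard out of a UOp whose device is a tuple, so cross-device traffic becomes mselect(i) followed by an ordinary COPY instead of a COPY with a device index smuggled through its arg. A minimal sketch of the new idiom, mirroring the grouper change below; `buf` is a placeholder for an already-sharded UOp and is not constructed here:

import functools
from tinygrad.ops import Ops

def naive_allreduce(buf, op=Ops.ADD):
  # old: UOp(Ops.COPY, buf.dtype, (buf, device), arg=i) carried the shard index in arg
  # new: MSELECT picks shard i, then a plain COPY broadcasts it to every device
  copies = [buf.mselect(i).copy_to_device(buf.device) for i in range(len(buf.device))]
  return functools.reduce(lambda x, y: x.alu(op, y), copies)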
diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py
index 3fe199a49c..690897f294 100644
--- a/tinygrad/engine/grouper.py
+++ b/tinygrad/engine/grouper.py
@@ -7,7 +7,7 @@ from tinygrad.codegen.lowerer import get_contraction_with_reduce, get_contractio
 from tinygrad.codegen.symbolic import symbolic_simple
 from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, ContextVar, Context, diskcache_put
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP, CAPTURE_PROCESS_REPLAY, RING
-from tinygrad.dtype import ImageDType
+from tinygrad.dtype import ImageDType, dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape
 from tinygrad.spec import type_verify, sched_spec
@@ -86,8 +86,7 @@ sym = symbolic_simple+PatternMatcher([
   (UPat(Ops.COPY, name="root", src=(UPat.cvar("x"), UPat(Ops.DEVICE))), lambda root,x: root.const_like(x.arg)),
   # store a shrink before COPY, otherwise view after the COPY
   (UPat(Ops.COPY, src=(UPat(Ops.VIEW, name="v"), UPat(Ops.DEVICE)), name="copy"), lambda copy,v:
-    v.contiguous().copy_to_device(copy.device, arg=copy.arg) if prod(v.shape) < prod(v.base.shape) else \
-      v.base.copy_to_device(copy.device, arg=copy.arg).view(v.st)),
+    v.contiguous().copy_to_device(copy.device) if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device).view(v.st)),
   # remove cast to image when it's already a contiguous image
   (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
    lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
@@ -244,7 +243,7 @@ def create_kernel(x:UOp, b:UOp|None=None):
     buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
     return buffer.assign(kernel).reshape(x.shape)

-DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
+DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT}
 def append_to_kernel(ctx:dict[UOp, None], x:UOp):
   new_srcs: list[UOp] = []
   metadata = dict.fromkeys(x.arg.metadata)
@@ -393,6 +392,9 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
     if s.op is Ops.ASSIGN:
       for out in s.src[1].arg.ast.src: parents_rep[out] = s.buf_uop.view(unwrap(out.st))
       parents_rep[s] = s.buf_uop
+    if s.op is Ops.MSELECT:
+      for out in s.src[0].src[1].arg.ast.src: parents_rep[out] = s.src[0].buf_uop.view(unwrap(out.st))
+      parents_rep[s] = s.buf_uop
   ast = k.arg.ast.substitute(parents_rep, name="replace realized")
   # push views to edges
   ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
@@ -496,8 +498,9 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
   buf = buf.contiguous()

   # copy to all devices. if you shrink later, that'll be handled
-  if not use_ring: return functools.reduce(lambda x,y: x.alu(red.arg, y),
-    [UOp(Ops.COPY, buf.dtype, (buf, red.src[1]), arg=i) for i in range(len(buf.device))])
+  use_ring = False
+  if not use_ring:
+    return functools.reduce(lambda x,y: x.alu(red.arg, y), [buf.mselect(i).copy_to_device(buf.device) for i in range(len(buf.device))])

   # new ring reduce
   factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
@@ -521,10 +524,12 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
   return functools.reduce(operator.add, [c.copy_to_device(buf.device).pad(pad) for pad,c in zip(pads, reduced_chunks)]).reshape(shape)

 replace_allreduce = PatternMatcher([
-  (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
-  # copy on specific copy
-  (UPat(Ops.COPY, src=(UPat(Ops.COPY, name="c1"), UPat(Ops.DEVICE, name="out_device")), name="c2"),
-   lambda c1,c2,out_device: c1.src[0].copy_to_device(out_device, arg=c1.arg) if c2.arg is None else None),
+  (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"),), name="red"), handle_allreduce),
+])
+
+view_mselect = PatternMatcher([
+  (UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, name="v")), name="m"), lambda m,v:
+   v.src[0].mselect(m.arg).view(v.arg.substitute({v.device_num():UOp.const(dtypes.int, m.arg)}))),
 ])

 @track_rewrites(name_fxn=get_name)
@@ -533,7 +538,7 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
   tensor_map = graph_rewrite_map(big_sink, replace_allreduce, name="replace_allreduce")

   # merge_views + simplify
-  tensor_map = graph_rewrite_map(tensor_map[big_sink], insert_fuse+do_fuse+merge_views+sym+replace_contiguous, ctx={},
+  tensor_map = graph_rewrite_map(tensor_map[big_sink], insert_fuse+do_fuse+merge_views+sym+replace_contiguous+view_mselect, ctx={},
                                  input_map=tensor_map, name="merge_views")

   # display the cleaned up tensor graph
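The view_mselect rule above is the subtle part of the grouper change: a VIEW sitting over a multi-device buffer can reference the per-device variable `_device_num` in its offsets and masks, and selecting shard m has to pin that variable to the constant m before the view is re-applied beneath the MSELECT. A small runnable sketch of the substitution mechanics it relies on; the (4, 4) tracker is a stand-in, since a real multi-device view would actually contain `_device_num`:

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp
from tinygrad.shape.shapetracker import ShapeTracker

dnum = UOp.variable("_device_num", 0, 3)  # what UOp.device_num() builds for 4 devices
st = ShapeTracker.from_shape((4, 4))      # stand-in tracker; real ones may reference dnum
# mselect(2) rewrites every occurrence of dnum inside the views to the constant 2:
st2 = st.substitute({dnum: UOp.const(dtypes.int, 2)})
assert st2.shape == (4, 4)                # the shape itself is unchanged by substitution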
diff --git a/tinygrad/engine/multi.py b/tinygrad/engine/multi.py
index aef58c3696..98dc82bad9 100644
--- a/tinygrad/engine/multi.py
+++ b/tinygrad/engine/multi.py
@@ -118,7 +118,7 @@ def shrink_multi(root:UOp, multi:UOp):
     "cannot shrink sharded and non-sharded axis at the same time"
   # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
   # we just copy it to all the devices, no real. this will be optimized out later
-    return multi.src[0].copy_to_device(multi.device, arg=multi.bounds.index(root.arg[multi.axis]))
+    return multi.src[0].mselect(multi.bounds.index(root.arg[multi.axis])).copy_to_device(multi.device)
   return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src], axis=multi.axis)
@@ -139,7 +139,7 @@ def copy_multi(multi:UOp, device:UOp):
   bsz, dcount = multi.shape[multi.axis]//len(multi.device), len(multi.device)
   dnum = UOp.variable("_device_num", 0, len(multi.device)-1)
   padded = multi.src[0].pad(tuple((0,0) if a != multi.axis else (bsz*dnum, bsz*(dcount-1) - bsz*dnum) for a in range(len(multi.shape))))
-  ret = padded.allreduce(Ops.ADD, device)
+  ret = padded.allreduce(Ops.ADD)
   return ret if isinstance(device.arg, str) else ret.multi(axis=None)

 def assign_multi(dest:UOp, src:UOp):
@@ -160,8 +160,7 @@ multi_pm = PatternMatcher([
   (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
   (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
   (UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
-  (UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
-   lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
+  (UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"),), name="red"), lambda multi,red: multi.src[0].allreduce(red.arg).multi(axis=multi.axis)),
   (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), src=(UPat(Ops.MULTI, name="multi"), ), name="root"),
    passthrough_multi),
 ])
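copy_multi above gathers a sharded tensor by padding each shard into its own slot and ADD-allreducing, so concatenation falls out of plain addition. A plain-Python check of the pad arithmetic used there, with illustrative values for bsz and dcount:

# Each device dnum pads its bsz-sized shard with bsz*dnum zeros before and
# bsz*(dcount-1) - bsz*dnum zeros after; the shards land in disjoint slots, so
# summing across devices (the ADD allreduce) reconstructs the full tensor.
bsz, dcount = 2, 4                                  # illustrative shard size / device count
slots = []
for dnum in range(dcount):
  before, after = bsz*dnum, bsz*(dcount-1) - bsz*dnum
  assert before + bsz + after == bsz*dcount         # every padded shard reaches full size
  slots.append(range(before, before + bsz))         # rows owned by device dnum
assert all(s.stop == t.start for s, t in zip(slots, slots[1:]))  # disjoint and contiguous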
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index f219d72a1a..6d68f23b7d 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -116,6 +116,9 @@ class Ops(FastEnum):
   # reduce
   REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702

+  # multi
+  MCAT = auto(); MSELECT = auto()
+
   # helper ops
   GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
@@ -430,9 +433,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def contiguous(self): return self.alu(Ops.CONTIGUOUS)
   def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
   def fuse(self): return self.alu(Ops.FUSE)
-  def allreduce(self, op, device:str|tuple[str, ...]|UOp):
-    assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
-    return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
+  def allreduce(self, op:Ops): return UOp(Ops.ALLREDUCE, self.dtype, (self,), op)

   # *** from MultiLazyBuffer ***

@@ -488,8 +489,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     assert op is Ops.BIND, f"unknown op {op}"
     var, val = arg.unbind()
     return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
-  def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
-    return UOp(Ops.COPY, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), arg)
+  def copy_to_device(self, device:str|tuple[str, ...]|UOp):
+    return UOp(Ops.COPY, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device))
+  def mselect(self, idx:int) -> UOp: return UOp(Ops.MSELECT, self.dtype, (self,), idx)
+  def device_num(self): return UOp.variable("_device_num", 0, len(self.device)-1)
   def clone(self) -> UOp: return self.copy_to_device(self.device)
   @property
   def metadata(self) -> Metadata|None: return all_metadata.get(self, None)
@@ -530,11 +533,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @functools.cached_property
   def _device(self) -> Optional[str|tuple[str, ...]]:
     if self.op is Ops.DEVICE: return self.arg
-    if self.op in {Ops.COPY, Ops.BUFFER, Ops.ALLREDUCE}: return self.src[1].device
+    if self.op is Ops.MSELECT: return self.src[0].device[self.arg]
+    if self.op in {Ops.COPY, Ops.BUFFER}: return self.src[1].device
     return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
   @property
   def buf_uop(self) -> UOp:
     if self.op is Ops.BUFFER: return self
+    if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg)
     assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
     return self.src[0].base
   @property
@@ -570,8 +575,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
     assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
     return UOp(Ops.BIND, self.dtype, (self, self.const_like(val)))
-  def unbind(self, optional=False) -> tuple[Variable, int]:
-    if optional and self.op is Ops.DEFINE_VAR: return (self, -1)
+  def unbind(self) -> tuple[Variable, int]:
     assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
     return self.src[0], self.src[1].arg
   @property
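The resulting UOp surface: ALLREDUCE carries a single source with the reduce op in its arg, COPY loses its extra arg, and MSELECT plus device_num() carry the per-device index instead (the spec below also pins MSELECT's arg to 0 <= arg < len(src.device)). A runnable check of how device_num() encodes its range; the 4-device count is illustrative, and the variable is built directly since constructing a sharded UOp is out of scope for this sketch:

from tinygrad.ops import Ops, UOp

dnum = UOp.variable("_device_num", 0, 3)   # what device_num() returns for 4 devices
assert dnum.op is Ops.DEFINE_VAR
assert dnum.arg == ("_device_num", 0, 3)   # (name, min, max); bind() checks this range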
diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index 97538ddf15..2609ec17ed 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -106,6 +106,7 @@ class ShapeTracker:
     unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
     if all(len(x) == 0 for x in var_vals): return self, {}
     return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
+  def substitute(self, dvars:dict[Variable, int]): return ShapeTracker(tuple(x.substitute(dvars) for x in self.views))

   def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]:
     with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid)
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index c93977410b..a7bd6710ec 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -141,12 +141,15 @@ class View:
   def unbind(self) -> tuple[View, dict[Variable, int]]:
     var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.op is Ops.BIND]
     unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
-    def substitute(x:sint): return x if isinstance(x, int) else x.substitute(unbound_vars)
-    new_shape = tuple(map(substitute, self.shape))
-    new_strides = tuple(map(substitute, self.strides))
-    new_offset = substitute(self.offset)
-    new_mask = tuple((substitute(x[0]), substitute(x[1])) for x in self.mask) if self.mask is not None else None
-    return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
+    return self.substitute(unbound_vars), dict(x[1] for x in var_unboundvar_val)
+
+  def substitute(self, dvars:dict[UOp, UOp]):
+    def _substitute(x:sint): return x if isinstance(x, int) else x.substitute(dvars)
+    new_shape = tuple(map(_substitute, self.shape))
+    new_strides = tuple(map(_substitute, self.strides))
+    new_offset = _substitute(self.offset)
+    new_mask = tuple((_substitute(x[0]), _substitute(x[1])) for x in self.mask) if self.mask is not None else None
+    return View.create(new_shape, new_strides, new_offset, new_mask)

   @functools.cache # pylint: disable=method-cache-max-size-none
   def __add__(self, vm1:View) -> Optional[View]:
diff --git a/tinygrad/spec.py b/tinygrad/spec.py
index 0369651690..ebf2862b48 100644
--- a/tinygrad/spec.py
+++ b/tinygrad/spec.py
@@ -78,7 +78,8 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([

   # COPY/ALLREDUCE
   (UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda copy,x: copy.dtype == x.dtype),
-  (UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
+  (UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"),)), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
+  (UPat(Ops.MSELECT, name="m"), lambda m: m.arg >= 0 and m.arg < len(m.src[0].device))
 ])

 # ***** uop type spec *****
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index 9305fca8bd..8ac99ddcb2 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -15,7 +15,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
                Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
                Ops.IGNORE: "#00C000", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
                Ops.BUFFER_VIEW: "#E5EAFF", Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
-               Ops.ALLREDUCE: "#ff40a0"}
+               Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#ffc0e0", Ops.MCAT: "#ffc0e0"}

 # VIZ API
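Finally, for reviewers who want to exercise the path end to end, a sketch assuming a build that contains this patch; the CPU device strings and shapes are arbitrary choices, not part of the diff:

# Reducing over the sharded axis produces an ALLREDUCE in the tensor graph; with
# this patch the grouper lowers it to per-shard MSELECT + COPY (use_ring is
# forced off above), and VIZ=1 draws those nodes in the new MSELECT/MCAT color.
from tinygrad import Tensor

a = Tensor.ones(8, 8).contiguous().shard(("CPU", "CPU"), axis=0)
b = (a + a).sum(axis=0)    # reduce across the shard axis -> ALLREDUCE
print(b.numpy())           # realizes the graph; expect a row of 16s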