From 978502be464b7a361e6d2cde229d20182bee93fe Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Fri, 17 Oct 2025 14:11:55 +0800
Subject: [PATCH] experiments with multi being range

---
 tinygrad/codegen/opt/__init__.py  |  4 ++--
 tinygrad/codegen/opt/postrange.py |  4 ++--
 tinygrad/schedule/rangeify.py     | 18 +++++++++++++++++-
 tinygrad/tensor.py                |  6 +++---
 tinygrad/uop/ops.py               | 30 +++++++++++++++++++++++-------
 5 files changed, 47 insertions(+), 15 deletions(-)

diff --git a/tinygrad/codegen/opt/__init__.py b/tinygrad/codegen/opt/__init__.py
index ca11b845ef..2b8eade41c 100644
--- a/tinygrad/codegen/opt/__init__.py
+++ b/tinygrad/codegen/opt/__init__.py
@@ -17,9 +17,9 @@ class Opt:
   def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
 
 axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
-                AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
+                AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.MULTI: "m"}
 axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
-               AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
+               AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta", AxisType.MULTI: "GREEN"}
 
 class KernelOptError(Exception): pass
 def check(cond:bool, msg:str=""):
diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py
index e4635a2279..636885a08e 100644
--- a/tinygrad/codegen/opt/postrange.py
+++ b/tinygrad/codegen/opt/postrange.py
@@ -13,8 +13,8 @@ from tinygrad.renderer import Renderer
 remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
 
 # NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
-axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
-               AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
+axis_to_pos = {AxisType.MULTI: -2, AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2,
+               AxisType.UPCAST: 3, AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
 
 class Scheduler:
   def __init__(self, ast:UOp, opts:Renderer):
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 5f9b4e9631..8a5f84aa77 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -108,6 +108,18 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
   (UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), find_permutes),
 ])
 
+# *****************
+
+pm_where_is_multi = PatternMatcher([
+  # move *0 through where
+  (UPat.var("gate").where(UPat.var("a"), 0) * UPat.var("b"), lambda gate,a,b: gate.where(a*b, 0)),
+  # move *0 through unary op
+  (UPat(Ops.CONTIGUOUS, src=(UPat.var("gate").where(UPat.var("a"), 0),), name="u"), lambda gate,a,u: gate.where(u.replace(src=(a,)), 0)),
+  # move where 0 through reduce if the reduce ranges are not in the gate
+  (UPat(Ops.REDUCE, src=(UPat.var("gate").where(UPat.var("a"), 0),), name="red", allow_any_len=True),
+   lambda gate,a,red: gate.where(red.replace(src=(a,)+red.src[1:]), 0) if all(r not in gate.ranges for r in red.src[1:]) else None),
+])
+
 # *****************
 # 3.5 cleanups
 
@@ -340,6 +352,7 @@ def handle_assign(ctx:LocalAddBufferContext, assign:UOp):
 
 def renumber_range(ctx:LocalAddBufferContext, r:UOp):
   if r.tag is not None: return None
+  if r.arg[-1] is AxisType.MULTI: return None
   ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=())
   ctx.range += 1
   return ret
@@ -414,7 +427,7 @@ class Kernel:
     return f""
 
 def split_store(ctx:list[UOp], x:UOp):
-  if len(x.ranges): return None
+  if len([r for r in x.ranges if r.arg[-1] != AxisType.MULTI]): return None
   if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None
 
   # local kernel rewrite
@@ -496,6 +509,9 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
 
   # NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
   tsink = graph_rewrite(tsink, symbolic_simple+pm_reduce_unparented, name="symbolic") # this supports const folding
+
+  tsink = graph_rewrite(tsink, pm_where_is_multi, name="where_is_multi")
+
   tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
   # TODO: can you substitute and remove costly buffers at the same time?
   tsink = graph_rewrite(tsink, pm_substitute_recurse, bottom_up=True, name="run substitutes")
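For intuition on pm_where_is_multi above: the patterns rely on plain value-level identities of a zero-filled where, so the gate introduced by sharding/padding can be pushed outward past multiplies, data-movement ops, and reduces whose ranges the gate does not mention. A minimal standalone sketch (plain Python on scalars, no tinygrad imports; the where helper is just for illustration) checks the identities behind the first and third patterns:

# illustrative check only, not tinygrad code
def where(gate: bool, a: float, b: float) -> float:
  return a if gate else b

# pattern 1: gate.where(a, 0) * b  ==  gate.where(a*b, 0)
for gate in (True, False):
  for a, b in [(2.0, 3.0), (-1.5, 4.0)]:
    assert where(gate, a, 0.0) * b == where(gate, a * b, 0.0)

# pattern 3: sum_i where(gate, a_i, 0) == where(gate, sum_i a_i, 0)
# (holds because the gate does not depend on the reduce index i)
xs = [1.0, 2.0, 3.0]
for gate in (True, False):
  assert sum(where(gate, x, 0.0) for x in xs) == where(gate, sum(xs), 0.0)

The middle pattern does the analogous move for CONTIGUOUS, pulling the gated where out of its src.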
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 1028519341..4537b9dc35 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -231,9 +231,9 @@ class Tensor(MathTrait):
     # verify Tensors match the spec
     if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec)
 
-    if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
-      _apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
-      big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))
+    #if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
+    #  _apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
+    #  big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))
 
     becomes_map = get_rangeify_map(big_sink)
     _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 50e474eb9c..225bf8c7dd 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
 
 class AxisType(Enum):
   def __repr__(self): return str(self)
   GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
-  THREAD = auto()
+  THREAD = auto(); MULTI = auto() # noqa: E702
 
 range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3}
@@ -114,7 +114,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
 
   @functools.cached_property
   def key(self) -> bytes:
     return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
-  def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}{x.tagstr()}, src=(%s))")
+  def __repr__(self):
+    if self.dtype == dtypes.index: return srender(self) # makes shapes print nicely
+    return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}{x.tagstr()}, src=(%s))")
   def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
   def tagstr(self): return f", tag={self.tag}" if self.tag is not None else ""
@@ -220,7 +222,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     match self.op:
       case Ops.RESHAPE:
         if not all(x >= 0 for x in self.marg): raise ValueError(f"shape can't contain negative numbers {self.marg}")
-        if prod(ps) != prod(self.marg): raise ValueError(f"bad reshape: {ps} -> {self.marg}")
+        #if prod(ps) != prod(self.marg): raise ValueError(f"bad reshape: {ps} -> {self.marg}")
         return self.marg
       case Ops.EXPAND:
         if len(ps) != len(self.marg) or not all(s==ns or (s==1 and ns>=0) for s,ns in zip(ps, self.marg)):
@@ -436,16 +438,30 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
 
   def _unshard(self, axis:int) -> UOp:
     bsz, dcount = self.shape[axis], len(self.device)
-    dnum = UOp.variable("_device_num", 0, dcount-1)
+    #dnum = UOp.variable("_device_num", 0, dcount-1)
+    dnum = UOp.range(dcount, -10, AxisType.MULTI)
     return self.pad(tuple((0,0) if a != axis else (bsz*dnum, bsz*(dcount-1) - bsz*dnum) for a in range(len(self.shape))))
 
   def _shard(self, axis:int) -> UOp:
     dcount = len(self.device)
-    dnum = UOp.variable("_device_num", 0, dcount-1)
+    dnum = UOp.range(dcount, -10, AxisType.MULTI)
     if self.shape[axis] % dcount != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {dcount=}")
     sz = self.shape[axis] // dcount
-    return self.shrink(tuple((0,s) if i != axis else (dnum*sz,dnum*sz+sz) for i,s in enumerate(self.shape)))
-  def shard(self, devices:tuple[str, ...], axis:int) -> UOp: return self.copy_to_device(devices)._shard(axis).multi(axis)
+    #ret = self.reshape(tuple(s if i != axis else dnum*sz for i,s in enumerate(self.shape)))
+    #return ret
+
+    #flatten([[s] if i != axis else [dcount, sz] for i,s in enumerate(self.shape)])))
+
+    #ret = self.shrink(tuple((0,s) if i != axis else (dnum*sz,dnum*sz+sz) for i,s in enumerate(self.shape)))
+    #print(ret.shape)
+    #print(dnum)
+    #dnum = UOp.variable("_device_num", 0, dcount-1)
+    # TODO: 0 isn't correct here
+    ret = self.shrink(tuple((0,s) if i != axis else (dnum*sz,dnum*sz+sz) for i,s in enumerate(self.shape)))
+    ret = ret.pad(tuple((0,0) if a != axis else (sz*dnum, sz*(dcount-1) - sz*dnum) for a in range(len(self.shape))))
+    return ret
+
+  def shard(self, devices:tuple[str, ...], axis:int) -> UOp: return self.copy_to_device(devices)._shard(axis) #.multi(axis)
 
   # *** from LazyBuffer ***
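For intuition on the _shard/_unshard change above: the _device_num variable becomes a RANGE with AxisType.MULTI, and a shard is expressed as a shrink to that device's slice followed by a pad back to the full axis, which is where the zero/where gates handled by pm_where_is_multi (and the "TODO: 0 isn't correct here" caveat) come from. A minimal standalone sketch of that index arithmetic on a Python list (no tinygrad imports; shard_slice and unshard_overlay are invented names for illustration):

# illustrative sketch only, not tinygrad code
def shard_slice(x: list[float], dnum: int, dcount: int) -> list[float]:
  # mirrors shrink((dnum*sz, dnum*sz+sz)) followed by a pad back to the full axis with zeros
  assert len(x) % dcount == 0, "multi axis uneven"
  sz = len(x) // dcount
  piece = x[dnum*sz : dnum*sz+sz]
  return [0.0]*(sz*dnum) + piece + [0.0]*(sz*(dcount-1) - sz*dnum)

def unshard_overlay(shards: list[list[float]]) -> list[float]:
  # the padded shards are disjoint, so summing columns overlays them back into the original
  return [sum(col) for col in zip(*shards)]

x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
shards = [shard_slice(x, d, 3) for d in range(3)]  # one shard per value of the MULTI range
assert unshard_overlay(shards) == x

Note that shard() above no longer wraps the result in .multi(axis); in this experiment the MULTI range itself carries the device axis.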