experiments with multi being range

George Hotz
2025-10-17 14:11:55 +08:00
parent 9561803cb0
commit 978502be46
5 changed files with 47 additions and 15 deletions

View File

@@ -17,9 +17,9 @@ class Opt:
def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.MULTI: "m"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta", AxisType.MULTI: "GREEN"}
class KernelOptError(Exception): pass
def check(cond:bool, msg:str=""):
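The new MULTI axis gets its own display letter ("m") and color ("GREEN"), so a device axis shows up in kernel shape printouts like any other axis type. A rough sketch, not tinygrad's actual printer, of how such a letter table can turn a kernel's axes into a compact shape string:

from enum import Enum, auto

class AxisType(Enum):
  GLOBAL = auto(); LOCAL = auto(); REDUCE = auto(); MULTI = auto()  # noqa: E702

axis_letters = {AxisType.GLOBAL: "g", AxisType.LOCAL: "l", AxisType.REDUCE: "R", AxisType.MULTI: "m"}

def shape_str(axes:list[tuple[int, AxisType]]) -> str:
  # [(2, MULTI), (256, GLOBAL), (16, REDUCE)] -> "m2 g256 R16"
  return " ".join(f"{axis_letters[t]}{sz}" for sz,t in axes)

print(shape_str([(2, AxisType.MULTI), (256, AxisType.GLOBAL), (16, AxisType.REDUCE)]))  # m2 g256 R16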

View File

@@ -13,8 +13,8 @@ from tinygrad.renderer import Renderer
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
-axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
-               AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
+axis_to_pos = {AxisType.MULTI: -2, AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2,
+               AxisType.UPCAST: 3, AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
class Scheduler:
def __init__(self, ast:UOp, opts:Renderer):
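With priority -2, MULTI sorts before LOOP (-1) and every device-level axis (>= 0), so the device range becomes the outermost axis when ranges are ordered by this table. An illustrative sketch of that ordering (names only, not the real Scheduler):

axis_to_pos = {"MULTI": -2, "LOOP": -1, "GLOBAL": 0, "LOCAL": 2, "REDUCE": 4}
ranges = [("REDUCE", 16), ("GLOBAL", 256), ("MULTI", 2), ("LOOP", 4)]
print(sorted(ranges, key=lambda r: axis_to_pos[r[0]]))
# [('MULTI', 2), ('LOOP', 4), ('GLOBAL', 256), ('REDUCE', 16)]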

View File

@@ -108,6 +108,18 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), find_permutes),
])
# *****************
+pm_where_is_multi = PatternMatcher([
+  # move *0 through where
+  (UPat.var("gate").where(UPat.var("a"), 0) * UPat.var("b"), lambda gate,a,b: gate.where(a*b, 0)),
+  # move *0 through unary op
+  (UPat(Ops.CONTIGUOUS, src=(UPat.var("gate").where(UPat.var("a"), 0),), name="u"), lambda gate,a,u: gate.where(u.replace(src=(a,)), 0)),
+  # move where 0 through reduce if the reduce ranges are not in the gate
+  (UPat(Ops.REDUCE, src=(UPat.var("gate").where(UPat.var("a"), 0),), name="red", allow_any_len=True),
+   lambda gate,a,red: gate.where(red.replace(src=(a,)+red.src[1:]), 0) if all(r not in gate.ranges for r in red.src[1:]) else None),
+])
# *****************
# 3.5 cleanups
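These three rules push a where(gate, x, 0) gate outward through multiplies, CONTIGUOUS, and reduces; the reduce rule is only sound when the gate does not depend on the reduce ranges, hence the `r not in gate.ranges` check. A plain-number check of the underlying algebra (assumed semantics; the real rules rewrite UOp graphs, not floats):

def where(gate:bool, a:float, b:float) -> float: return a if gate else b

for g in (True, False):
  # rule 1: where(g, a, 0) * b == where(g, a*b, 0)
  assert where(g, 3.0, 0.0) * 5.0 == where(g, 3.0*5.0, 0.0)
  # rule 3: sum of where(g, a[r], 0) over r == where(g, sum(a), 0),
  # valid only because g does not depend on the reduce index r
  a = [1.0, 2.0, 3.0]
  assert sum(where(g, x, 0.0) for x in a) == where(g, sum(a), 0.0)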
@@ -340,6 +352,7 @@ def handle_assign(ctx:LocalAddBufferContext, assign:UOp):
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
if r.tag is not None: return None
+if r.arg[-1] is AxisType.MULTI: return None
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=())
ctx.range += 1
return ret
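renumber_range gives every kernel fresh range ids, but the MULTI range is shared across all kernels of a multi-device graph, so it keeps its fixed id (-10, per _shard below). A toy version of that skip, with an assumed (idx, axis_type) arg layout:

def renumber(ranges, start=0):
  out, nxt = [], start
  for idx, axtype in ranges:
    if axtype == "MULTI": out.append((idx, axtype))  # left alone, as in the diff
    else: out.append((nxt, axtype)); nxt += 1
  return out

print(renumber([(-10, "MULTI"), (7, "GLOBAL"), (3, "REDUCE")]))
# [(-10, 'MULTI'), (0, 'GLOBAL'), (1, 'REDUCE')]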
@@ -414,7 +427,7 @@ class Kernel:
return f"<Kernel {len(list(self.ast.toposort()))} {ast_rep} {self.metadata}>"
def split_store(ctx:list[UOp], x:UOp):
-if len(x.ranges): return None
+if len([r for r in x.ranges if r.arg[-1] != AxisType.MULTI]): return None
if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None
# local kernel rewrite
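split_store previously required a store to have no open ranges at all; it now ignores MULTI ranges, so a store that is only still inside the device range can be split into its own kernel. A minimal sketch of the relaxed condition (names assumed):

def can_split(open_ranges:list[str]) -> bool:
  return len([r for r in open_ranges if r != "MULTI"]) == 0

assert can_split([])            # old behavior: no open ranges
assert can_split(["MULTI"])     # newly allowed: device range stays open
assert not can_split(["LOOP"])  # anything else still blocks the split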
@@ -496,6 +509,9 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
tsink = graph_rewrite(tsink, symbolic_simple+pm_reduce_unparented, name="symbolic") # this supports const folding
+tsink = graph_rewrite(tsink, pm_where_is_multi, name="where_is_multi")
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
# TODO: can you substitute and remove costly buffers at the same time?
tsink = graph_rewrite(tsink, pm_substitute_recurse, bottom_up=True, name="run substitutes")

View File

@@ -231,9 +231,9 @@ class Tensor(MathTrait):
# verify Tensors match the spec
if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec)
-if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
-  _apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
-  big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))
+#if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
+#  _apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
+#  big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))
becomes_map = get_rangeify_map(big_sink)
_apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
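With the get_multi_map path commented out, sharded tensors flow into get_rangeify_map with their MULTI ranges intact instead of first being expanded into one UOp graph per device. A hedged usage sketch (this is an experimental commit, so the path may not fully work yet):

from tinygrad import Tensor
t = Tensor.arange(8).shard(("CPU:0", "CPU:1"), axis=0)  # device axis becomes a MULTI range
t.kernelize()  # kernelize now runs on the MULTI graph directly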

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
class AxisType(Enum):
def __repr__(self): return str(self)
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
-THREAD = auto()
+THREAD = auto(); MULTI = auto() # noqa: E702
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3}
@@ -114,7 +114,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def key(self) -> bytes:
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
-def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}{x.tagstr()}, src=(%s))")
+def __repr__(self):
+  if self.dtype == dtypes.index: return srender(self) # makes shapes print nicely
+  return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}{x.tagstr()}, src=(%s))")
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
def tagstr(self): return f", tag={self.tag}" if self.tag is not None else ""
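Shapes can now contain symbolic index expressions (e.g. the dnum*sz bounds introduced in _shard below), and rendering index-dtype UOps with srender keeps printed shapes readable instead of dumping a nested tree. A toy analogue of that dispatch:

class Var:  # stands in for an index-dtype UOp; srender would produce this string
  def __init__(self, name): self.name = name
  def __mul__(self, k): return Var(f"{self.name}*{k}")
  def __repr__(self): return self.name

dnum = Var("dnum")
print((dnum*4, 3))  # (dnum*4, 3) instead of a nested UOp tree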
@@ -220,7 +222,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
match self.op:
case Ops.RESHAPE:
if not all(x >= 0 for x in self.marg): raise ValueError(f"shape can't contain negative numbers {self.marg}")
-if prod(ps) != prod(self.marg): raise ValueError(f"bad reshape: {ps} -> {self.marg}")
+#if prod(ps) != prod(self.marg): raise ValueError(f"bad reshape: {ps} -> {self.marg}")
return self.marg
case Ops.EXPAND:
if len(ps) != len(self.marg) or not all(s==ns or (s==1 and ns>=0) for s,ns in zip(ps, self.marg)):
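Disabling the strict product check looks necessary for the symbolic reshape experiment in _shard below, which targets a size like dnum*sz: for any concrete dnum the integer product equality cannot hold. A worked example of the mismatch (hypothetical numbers: 8 elements, dcount=2, sz=4):

from math import prod
for dnum in range(2):
  assert prod((8,)) != prod((dnum*4,))  # 8 != 0 and 8 != 4: the old check would raise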
@@ -436,16 +438,30 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def _unshard(self, axis:int) -> UOp:
bsz, dcount = self.shape[axis], len(self.device)
-dnum = UOp.variable("_device_num", 0, dcount-1)
+#dnum = UOp.variable("_device_num", 0, dcount-1)
+dnum = UOp.range(dcount, -10, AxisType.MULTI)
return self.pad(tuple((0,0) if a != axis else (bsz*dnum, bsz*(dcount-1) - bsz*dnum) for a in range(len(self.shape))))
def _shard(self, axis:int) -> UOp:
dcount = len(self.device)
-dnum = UOp.variable("_device_num", 0, dcount-1)
+dnum = UOp.range(dcount, -10, AxisType.MULTI)
if self.shape[axis] % dcount != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {dcount=}")
sz = self.shape[axis] // dcount
-return self.shrink(tuple((0,s) if i != axis else (dnum*sz,dnum*sz+sz) for i,s in enumerate(self.shape)))
-def shard(self, devices:tuple[str, ...], axis:int) -> UOp: return self.copy_to_device(devices)._shard(axis).multi(axis)
+#ret = self.reshape(tuple(s if i != axis else dnum*sz for i,s in enumerate(self.shape)))
+#return ret
+#flatten([[s] if i != axis else [dcount, sz] for i,s in enumerate(self.shape)])))
+#ret = self.shrink(tuple((0,s) if i != axis else (dnum*sz,dnum*sz+sz) for i,s in enumerate(self.shape)))
+#print(ret.shape)
+#print(dnum)
+#dnum = UOp.variable("_device_num", 0, dcount-1)
+# TODO: 0 isn't correct here
+ret = self.shrink(tuple((0,s) if i != axis else (dnum*sz,dnum*sz+sz) for i,s in enumerate(self.shape)))
+ret = ret.pad(tuple((0,0) if a != axis else (sz*dnum, sz*(dcount-1) - sz*dnum) for a in range(len(self.shape))))
+return ret
+def shard(self, devices:tuple[str, ...], axis:int) -> UOp: return self.copy_to_device(devices)._shard(axis) #.multi(axis)
# *** from LazyBuffer ***
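The new _shard composes a shrink with a pad back to the global shape, which is exactly the masked select that pm_where_is_multi is written to push around: keeping row block [dnum*sz, dnum*sz+sz) and zero-padding the rest equals where(gate, x, 0) with a per-device gate. A numeric check under those assumed semantics:

import numpy as np
x, dcount = np.arange(8), 2
sz = len(x) // dcount
for dnum in range(dcount):
  shrunk = x[dnum*sz : dnum*sz+sz]                              # _shard's shrink
  padded = np.pad(shrunk, (sz*dnum, sz*(dcount-1) - sz*dnum))   # the pad back to full size
  gate = (np.arange(len(x)) // sz) == dnum                      # this device's rows
  assert (padded == np.where(gate, x, 0)).all()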