This commit is contained in:
George Hotz
2025-10-10 18:10:42 +08:00
parent 5b24999a36
commit 67a409343d
5 changed files with 25 additions and 6 deletions

View File

@@ -11,6 +11,17 @@ class TestOuterworldReduce(unittest.TestCase):
t = Tensor(UOp(Ops.REDUCE, dtype=out.uop.dtype, src=(out.uop, a), arg=Ops.ADD))
print(t.numpy())
class TestOuterworldAssign(unittest.TestCase):
# Checks an add-one assign wrapped in an OUTER range: every element of a zero tensor is expected to end up as 3.
def test_triple_add(self):
# start from a realized, contiguous tensor of five zeros
t = Tensor.zeros(5).contiguous().realize()
# range with AxisType.OUTER; NOTE(review): presumably 3 is the trip count, matching the expected value asserted below - confirm against UOp.range
a = UOp.range(3, -1, AxisType.OUTER)
# the assign whose value depends on the tensor's own previous contents
t = t.assign(t+1)
# close the range with ENDRANGE; the range is src[0] and the assigned value is src[1]
t = Tensor(UOp(Ops.ENDRANGE, dtype=t.uop.dtype, src=(a, t.uop))).contiguous()
# all five elements must equal 3
self.assertListEqual(t.tolist(), [3,3,3,3,3])
@unittest.skip("gemm is complex")
def test_triple_gemm(self):
Tensor.manual_seed(1337)
x0 = Tensor.rand(1, 16).realize()
@@ -24,7 +35,7 @@ class TestOuterworldReduce(unittest.TestCase):
# does ASSIGN always terminate the range?
a = UOp.range(3, -1, AxisType.REDUCE)
x1 = x1.assign(x1 @ W[a])
out = Tensor(UOp(Ops.ENDRANGE, dtype=x1.uop.dtype, src=(x1.uop, a))).contiguous()
out = Tensor(UOp(Ops.ENDRANGE, dtype=x1.uop.dtype, src=(a, x1.uop))).contiguous()
out.realize()
print(out)

View File

@@ -18,8 +18,8 @@ class Opt:
# Single-letter label for each axis type.
# NOTE(review): there is no AxisType.OUTER entry here, although axis_colors does have one - confirm that no axis_letters[...] lookup can receive an OUTER axis.
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
axis_colors = {AxisType.OUTER: "GREEN", AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN",
AxisType.LOOP: "WHITE", AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
class KernelOptError(Exception): pass
def check(cond:bool, msg:str=""):

View File

@@ -13,8 +13,8 @@ from tinygrad.renderer import Renderer
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
# OUTER has the lowest position value in this table (-2), one below LOOP (-1).
axis_to_pos = {AxisType.OUTER: -2, AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2,
AxisType.UPCAST: 3, AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
class Scheduler:
def __init__(self, ast:UOp, opts:Renderer):

View File

@@ -15,10 +15,11 @@ if TYPE_CHECKING:
# Enumeration of the kinds of axis; every member value comes from auto(), so values follow declaration order.
class AxisType(Enum):
# repr() returns the same text as str()
def __repr__(self): return str(self)
# OUTER is declared before all other members, so it receives the first auto() value
OUTER = auto()
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
# THREAD is declared last, after UNROLL
THREAD = auto()
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.ENDRANGE: 1}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
@@ -240,6 +241,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if s in ret: del ret[s]
else:
for s in self.src: ret.update(s.ranges)
if self.op is Ops.ENDRANGE: del ret[self.src[0]]
return ret
@property
@@ -1098,6 +1100,8 @@ pm_lower_index_dtype = PatternMatcher([
(UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"),
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
(UPat((Ops.SINK, Ops.NOOP), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))),
# hack for ENDRANGE
(UPat(Ops.ENDRANGE, src=(UPat(Ops.RANGE, name="r").cast(dtypes.index),), allow_any_len=True, name="x"), lambda x,r: x.replace(src=(r,)+x.src[1:])),
])
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]

View File

@@ -108,6 +108,10 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE)), arg=None), lambda copy,x: copy.dtype == x.dtype),
(UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
(UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),
# endrange/reduce for outerworld range work
(UPat(Ops.ENDRANGE, src=(UPat(Ops.RANGE),), allow_any_len=True), lambda: True),
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
])
# ***** uop type spec *****