mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-15 01:48:23 -05:00
work
This commit is contained in:
@@ -11,6 +11,17 @@ class TestOuterworldReduce(unittest.TestCase):
|
||||
t = Tensor(UOp(Ops.REDUCE, dtype=out.uop.dtype, src=(out.uop, a), arg=Ops.ADD))
|
||||
print(t.numpy())
|
||||
|
||||
class TestOuterworldAssign(unittest.TestCase):
|
||||
def test_triple_add(self):
|
||||
t = Tensor.zeros(5).contiguous().realize()
|
||||
|
||||
a = UOp.range(3, -1, AxisType.OUTER)
|
||||
t = t.assign(t+1)
|
||||
t = Tensor(UOp(Ops.ENDRANGE, dtype=t.uop.dtype, src=(a, t.uop))).contiguous()
|
||||
|
||||
self.assertListEqual(t.tolist(), [3,3,3,3,3])
|
||||
|
||||
@unittest.skip("gemm is complex")
|
||||
def test_triple_gemm(self):
|
||||
Tensor.manual_seed(1337)
|
||||
x0 = Tensor.rand(1, 16).realize()
|
||||
@@ -24,7 +35,7 @@ class TestOuterworldReduce(unittest.TestCase):
|
||||
# does ASSIGN always terminate the range?
|
||||
a = UOp.range(3, -1, AxisType.REDUCE)
|
||||
x1 = x1.assign(x1 @ W[a])
|
||||
out = Tensor(UOp(Ops.ENDRANGE, dtype=x1.uop.dtype, src=(x1.uop, a))).contiguous()
|
||||
out = Tensor(UOp(Ops.ENDRANGE, dtype=x1.uop.dtype, src=(a, x1.uop))).contiguous()
|
||||
out.realize()
|
||||
print(out)
|
||||
|
||||
|
||||
@@ -18,8 +18,8 @@ class Opt:
|
||||
|
||||
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
|
||||
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
|
||||
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
|
||||
axis_colors = {AxisType.OUTER: "GREEN", AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN",
|
||||
AxisType.LOOP: "WHITE", AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
|
||||
|
||||
class KernelOptError(Exception): pass
|
||||
def check(cond:bool, msg:str=""):
|
||||
|
||||
@@ -13,8 +13,8 @@ from tinygrad.renderer import Renderer
|
||||
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
||||
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
|
||||
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
||||
axis_to_pos = {AxisType.OUTER: -2, AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2,
|
||||
AxisType.UPCAST: 3, AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
||||
|
||||
class Scheduler:
|
||||
def __init__(self, ast:UOp, opts:Renderer):
|
||||
|
||||
@@ -15,10 +15,11 @@ if TYPE_CHECKING:
|
||||
|
||||
class AxisType(Enum):
|
||||
def __repr__(self): return str(self)
|
||||
OUTER = auto()
|
||||
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
|
||||
THREAD = auto()
|
||||
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.ENDRANGE: 1}
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
||||
@@ -240,6 +241,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if s in ret: del ret[s]
|
||||
else:
|
||||
for s in self.src: ret.update(s.ranges)
|
||||
if self.op is Ops.ENDRANGE: del ret[self.src[0]]
|
||||
return ret
|
||||
|
||||
@property
|
||||
@@ -1098,6 +1100,8 @@ pm_lower_index_dtype = PatternMatcher([
|
||||
(UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"),
|
||||
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
|
||||
(UPat((Ops.SINK, Ops.NOOP), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))),
|
||||
# hack for ENDRANGE
|
||||
(UPat(Ops.ENDRANGE, src=(UPat(Ops.RANGE, name="r").cast(dtypes.index),), allow_any_len=True, name="x"), lambda x,r: x.replace(src=(r,)+x.src[1:])),
|
||||
])
|
||||
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
|
||||
|
||||
|
||||
@@ -108,6 +108,10 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
|
||||
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE)), arg=None), lambda copy,x: copy.dtype == x.dtype),
|
||||
(UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
|
||||
(UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),
|
||||
|
||||
# endrange/reduce for outerworld range work
|
||||
(UPat(Ops.ENDRANGE, src=(UPat(Ops.RANGE),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
|
||||
])
|
||||
|
||||
# ***** uop type spec *****
|
||||
|
||||
Reference in New Issue
Block a user