Merge branch 'master' into retinanet_mlperf

Author: Francis Lata
Date: 2024-11-18 04:42:57 -08:00
57 changed files with 1197 additions and 1047 deletions


@@ -9,7 +9,7 @@ from tinygrad.engine.realize import get_runner, CompiledRunner
from test.external.fuzz_linearizer import get_fuzz_rawbufs
from tinygrad.codegen.kernel import Kernel
-from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer
+from tinygrad.ops import LazyOp, Ops, ReduceOps, BufferOps, MemBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -26,12 +26,12 @@ class TestNV(unittest.TestCase):
TestNV.addr = struct.pack("QQ", TestNV.b.lazydata.buffer._buf.va_addr, TestNV.a.lazydata.buffer._buf.va_addr)
def test_oor_kernels(self):
-ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
+ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
opts = [Opt(op=OptOps.TC, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2)] # noqa: E501
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["NV"])
def test_error_on_huge_dims(self):
-ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
+ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2)] # noqa: E501
with self.assertRaises(RuntimeError) as cm:
lin = Kernel(ast)
@@ -43,7 +43,7 @@ class TestNV(unittest.TestCase):
def test_buf4_usage(self):
TestNV.along = Tensor([105615], device="NV").realize()
-ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SIN, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
+ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.SIN, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
temp_runner = get_runner(TestNV.d0.dname, (ast,))
temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={})
val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]


@@ -2,7 +2,7 @@
import unittest
from tinygrad import Device
-from tinygrad.ops import UOp, Ops, BinaryOps
+from tinygrad.ops import UOp, Ops
from tinygrad.engine.search import Opt, OptOps
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
@@ -20,7 +20,7 @@ class TestOpenpilotValidhack(unittest.TestCase):
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MAX, dtypes.float, arg=None, src=(
x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8, 9, 10)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 8, 9, 10)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(


@@ -25,7 +25,7 @@ from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing
-from tinygrad.ops import UnaryOps, UOp, Ops
+from tinygrad.ops import UOp, Ops
from tinygrad.device import is_dtype_supported
def on_linearizer_will_run(): pass
@@ -252,7 +252,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
def _is_simple(lin: Kernel) -> bool:
if len(lin.ast.src) > 1: return False
ast:UOp = lin.ast.src[0]
-if ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
+if ast.src[0].op is Ops.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
return False
if __name__ == "__main__":


@@ -6,7 +6,7 @@ from tinygrad.engine.realize import capturing, lower_schedule_item
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.schedule import LBScheduleItem, _graph_schedule, ScheduleItem
-from tinygrad.ops import MetaOps
+from tinygrad.ops import Ops
from tinygrad.tensor import Tensor, _to_np_dtype
ctx_vars = { MULTIOUTPUT: (0, 1) }
@@ -33,7 +33,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
for lsi in ts:
for out in lsi.outputs:
# freeze assign state before exec
-if out.op is MetaOps.ASSIGN:
+if out.op is Ops.ASSIGN:
prerealized[out] = out.buffer.as_buffer()
assign_targets[out.srcs[1]] = out
for x in lsi.inputs:
@@ -50,9 +50,9 @@ def fuzz_schedule(outs:List[LazyBuffer]):
rawbufs: Dict[LazyBuffer, Buffer] = {}
for lsi in ts:
for out in lsi.outputs:
-base = rawbufs[lsi.inputs[0]].base if out.op is MetaOps.BUFFER_VIEW else None
+base = rawbufs[lsi.inputs[0]].base if out.op is Ops.BUFFER_VIEW else None
rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype, base=base)
-if out.op is MetaOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
+if out.op is Ops.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
for x in lsi.inputs:
if x not in rawbufs:
# override the assign_target after ASSIGN


@@ -42,9 +42,9 @@ def gt(expr, rng=None):
return expr > rng, rng
# NOTE: you have to replace these for this test to pass
-from tinygrad.ops import python_alu, BinaryOps
-python_alu[BinaryOps.MOD] = lambda x,y: x%y
-python_alu[BinaryOps.IDIV] = lambda x,y: x//y
+from tinygrad.ops import python_alu, Ops
+python_alu[Ops.MOD] = lambda x,y: x%y
+python_alu[Ops.IDIV] = lambda x,y: x//y
if __name__ == "__main__":
ops = [add_v, div, mul, add_num, mod]


@@ -11,7 +11,6 @@ class TestCompileFailures(unittest.TestCase):
def test_interpolate_atari(self):
self.compile(Tensor.empty(210, 160, dtype='uint8').interpolate((64, 64)))
-@unittest.skip("FIXME: broken on METAL")
def test_add_max_uchar(self):
self.compile((Tensor.empty(1024, dtype='uint8') + Tensor.empty(1024, dtype='uint8')).max())


@@ -154,7 +154,7 @@ class TestReduceOpsConstFolding(unittest.TestCase):
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).sum())
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).sum().numpy(), 4)
-# NOTE: cannot just count the non-padded area because some UnaryOps f do not have f(0) = 0.
+# NOTE: cannot just count the non-padded area because some Ops f do not have f(0) = 0.
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).exp().sum())
np.testing.assert_allclose(Tensor.ones(4).pad(((1, 1),)).exp().sum().numpy(), 4 * math.e + 2)
@@ -251,7 +251,6 @@ class TestTautologicalCompare(unittest.TestCase):
np.testing.assert_equal((Tensor(True) < Tensor(False)).numpy(), False)
np.testing.assert_equal((Tensor(True) < Tensor(True)).numpy(), False)
-@unittest.skip("not implemented yet")
def test_a_eq_a(self):
# self eq is always true for int or bool
a = Tensor([1, 2, 3])
@@ -261,7 +260,6 @@ class TestTautologicalCompare(unittest.TestCase):
a = Tensor([math.nan, 1.0, 2.0])
np.testing.assert_equal((a == a).numpy(), [False, True, True])
-@unittest.skip("not implemented yet")
def test_a_ne_a(self):
# self not eq is always false for int or bool
a = Tensor([1, 2, 3])


@@ -46,5 +46,24 @@ class TestFusionOp(unittest.TestCase):
with self.assertRaises(AssertionError): self.assertEqual(sched1[-1].ast, sched3[-1].ast)
self.assertLess(time.perf_counter()-st, 2.0)
+def test_recursive_pad(self):
+st = time.perf_counter()
+val = 1.0
+a = Tensor(val).realize()
+for _ in range(24): a = Tensor.stack(a, a)[0]
+r = a.item()
+self.assertEqual(r, val)
+self.assertLess(time.perf_counter()-st, 2.0)
+def test_recursive_reshape(self):
+st = time.perf_counter()
+a = Tensor.empty(32, 32).realize()
+b = Tensor.empty(16, 2).realize()
+r = a.sum(1)
+for _ in range(24): r = r.reshape(16, 2) + b
+sched = r.schedule()
+self.assertEqual(len(sched), 1)
+self.assertLess(time.perf_counter()-st, 2.0)
if __name__ == '__main__':
unittest.main(verbosity=2)


@@ -2,8 +2,9 @@
import numpy as np
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.engine.realize import run_schedule
+from tinygrad.ops import Ops
-from tinygrad.engine.lazy import LazyBuffer, MetaOps
+from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.schedule import create_schedule
class TestLazyBuffer(unittest.TestCase):
@@ -69,6 +70,25 @@ class TestLazyBuffer(unittest.TestCase):
assert lb.const_like(1).base.arg == 1.0
assert type(lb.const_like(1).base.arg) is float
+def test_forced_realized_alu(self):
+a = Tensor.randn(2, 2).realize()
+b = Tensor.randn(2, 2).realize()
+add = a + b
+add.lazydata.forced_realize = True
+out = add+2
+sched = create_schedule([out.lazydata])
+self.assertEqual(len(sched), 2)
+run_schedule(sched)
+np.testing.assert_allclose(out.numpy(), a.numpy()+b.numpy()+2)
+def test_forced_realized_metaop(self):
+empty = Tensor.empty(1)
+empty.lazydata.forced_realize = True
+sched = create_schedule([empty.lazydata])
+self.assertEqual(len(sched), 1)
+self.assertIs(sched[0].ast.op, Ops.EMPTY)
+run_schedule(sched)
class TestReduceOp(unittest.TestCase):
def test_no_split_reduce_kernel(self):
a = Tensor.rand(4, 4).realize()
@@ -95,24 +115,24 @@ class TestReduceOp(unittest.TestCase):
class TestView(unittest.TestCase):
def test_all_masked_out(self):
-# start with non CONST MetaOps
+# start with non CONST Ops
a = Tensor.rand(10, 10)
-assert a.lazydata.base.op is not MetaOps.CONST
+assert a.lazydata.base.op is not Ops.CONST
# all masked out, degrades to const 0
b = a.pad(((0, 10), None))[10:]
assert b.shape == (10, 10)
-assert b.lazydata.base.op is MetaOps.CONST and b.lazydata.base.arg == 0
+assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0
# mask out dim = 1 works too
b = a.pad((None, (0, 10)))[:, 10:]
assert b.shape == (10, 10)
-assert b.lazydata.base.op is MetaOps.CONST and b.lazydata.base.arg == 0
+assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0
# partial masked out does not degrade into CONST
b = a.pad(((0, 5), None))[5:]
assert b.shape == (10, 10)
-assert b.lazydata.base.op is not MetaOps.CONST
+assert b.lazydata.base.op is not Ops.CONST
if __name__ == "__main__":
unittest.main()


@@ -6,7 +6,7 @@ from dataclasses import replace
from test.helpers import ast_const
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
from tinygrad.codegen.lowerer import get_grouped_dims
-from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps, UnaryOps, GroupOp
+from tinygrad.ops import UOp, Ops, GroupOp
from tinygrad.device import Device, Buffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -109,10 +109,10 @@ class TestLinearizer(unittest.TestCase):
st_x = x.lazydata.st
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (1,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (1,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop()))
diff = second_x + first_reduce*ast_const(dtypes.float, -1, (32, 1))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (0,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (0,)))
store = UOp(Ops.STORE, dtypes.void, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [
@@ -145,10 +145,10 @@ class TestLinearizer(unittest.TestCase):
st_x = x.lazydata.st
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop()))
diff = second_x + first_reduce*ast_const(dtypes.float, -1, (27, 32, 1, 5))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [
@@ -207,13 +207,13 @@ class TestLinearizer(unittest.TestCase):
x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 32, 32, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (2,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (2,)))
third_x = UOp(Ops.LOAD, dtypes.float, (g3, x2.lazydata.st.reshape((27, 32, 1, 1, 5)).to_uop()))
mul = (third_x*second_reduce)
-third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (BinaryOps.ADD, (1,)))
+third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 1, 5)).to_uop(), third_reduce))
sink = UOp(Ops.SINK, src=(store,))
wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5)
@@ -234,11 +234,11 @@ class TestLinearizer(unittest.TestCase):
st = x.lazydata.st
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2, 5)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2, 5)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 32, 1, 8, 16, 1)).to_uop()))
neg_first_reduce = first_reduce * ast_const(dtypes.float, -1, (8, 32, 1, 8, 16, 1))
squares = (second_x+neg_first_reduce)
-squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1, 4)))
+squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1, 4)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((8, 1, 1, 8, 1, 1)).to_uop(), squares_sum,))
sink = UOp(Ops.SINK, src=(store,))
wanna_output = (x.numpy()-x.numpy().sum(axis=(1,3), keepdims=True)).sum(axis=(1,3)).reshape((8,1,1,8,1,1))
@@ -285,10 +285,10 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [
@@ -317,11 +317,11 @@ class TestLinearizer(unittest.TestCase):
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
first_x_p = UOp(Ops.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
-first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(UnaryOps.EXP2),), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
+first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(Ops.EXP2),), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1)).to_uop()))
diff = (second_x+(first_reduce + first_reduce_p)*ast_const(dtypes.float, -1, (4, 32, 1)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4, 1, 1)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [
@@ -352,10 +352,10 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
second_out = second_reduce * ast_const(dtypes.float, 1/15, (27, 1, 1, 5))
store1 = UOp(Ops.STORE, src=(g1, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_out))
@@ -375,10 +375,10 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
store1 = UOp(Ops.STORE, src=(g1, ShapeTracker(views=(View(shape=(27,15,1,5), strides=(5,0,1,1), offset=0, mask=None, contiguous=False),)).to_uop(), first_reduce)) # noqa: E501
wanna_output0 = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
@@ -399,10 +399,10 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [[Opt(OptOps.UNROLL, 0, 3), Opt(OptOps.UNROLL, 0, 3)]]
@@ -415,10 +415,10 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [[Opt(OptOps.UPCAST, 0, 3)]]
@@ -434,10 +434,10 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 12, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [[Opt(OptOps.GROUPTOP, 0, 3), Opt(OptOps.GROUPTOP, 1, 3)]]
@@ -450,13 +450,13 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
squares = (second_x+neg_mean)*(second_x+neg_mean)
-squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
+squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
-std = variance.alu(UnaryOps.SQRT)
+std = variance.alu(Ops.SQRT)
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std))
sink = UOp(Ops.SINK, src=(store,))
wanna_output = x.numpy().std(axis=2, ddof=0).reshape((15,25,1,1))
@@ -468,13 +468,13 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).to_uop()))
squares = (second_x+neg_mean)*(second_x+neg_mean)
-squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1,)))
+squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1,)))
variance = squares_sum * ast_const(dtypes.float, 0.04, (15, 1, 1, 35))
-std = variance.alu(UnaryOps.SQRT)
+std = variance.alu(Ops.SQRT)
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 1, 1, 35)).to_uop(), std))
sink = UOp(Ops.SINK, src=(store,))
wanna_output = x.numpy().std(axis=1, ddof=0).reshape((15,1,1,35))
@@ -488,13 +488,13 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
squares = (second_x+neg_mean)*(second_x+neg_mean)
-squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
+squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
-std = variance.alu(UnaryOps.SQRT)
+std = variance.alu(Ops.SQRT)
store_mean = UOp(Ops.STORE, src=(g1, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), neg_mean))
store_std = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std))
sink = UOp(Ops.SINK, src=(store_std, store_mean))
@@ -511,13 +511,13 @@ class TestLinearizer(unittest.TestCase):
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
# push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -0.03125, (3, 27, 32, 1))
# store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 32, 1)).to_uop(), mean))
# verify_lazyop(store)
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 32, 1)).to_uop()))
squares = (second_x+neg_mean)*(second_x+neg_mean)
-squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
+squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
variance = squares_sum * ast_const(dtypes.float, 0.03125, (3, 27, 1, 1))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 1, 1)).to_uop(), variance))
sink = UOp(Ops.SINK, src=(store,))
@@ -532,13 +532,13 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.rand(4, 32).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop()))
-max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.MAX, (2,)))
+max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.MAX, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop()))
centered_x = second_x+max_x*ast_const(dtypes.float, -1, (4, 32, 1))
-exp_x = centered_x.alu(UnaryOps.EXP2)
-sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (BinaryOps.ADD, (1,)))
-# y = exp_x * sum_exp_x.alu(UnaryOps.RECIP) # kernels cannot do a return to full shape
-recip_sum_exp_x = sum_exp_x.alu(UnaryOps.RECIP)
+exp_x = centered_x.alu(Ops.EXP2)
+sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (Ops.ADD, (1,)))
+# y = exp_x * sum_exp_x.alu(Ops.RECIP) # kernels cannot do a return to full shape
+recip_sum_exp_x = sum_exp_x.alu(Ops.RECIP)
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4,1,1)).to_uop(), recip_sum_exp_x))
sink = UOp(Ops.SINK, src=(store,))
expected = 1/np.exp2(x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1, keepdims=True).reshape(4,1,1)
@@ -556,7 +556,7 @@ class TestLinearizer(unittest.TestCase):
View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
arange_axis = (3,)
-arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
+arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (Ops.ADD, arange_axis))
output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
out = arange+ast_const(dtypes.int, -1, output_shape)
store = UOp(Ops.STORE, src=(UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out))
@@ -573,7 +573,7 @@ class TestLinearizer(unittest.TestCase):
# TODO: do this arange broadcast in the scheduler
arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
arange_axis = (3,)
-arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
+arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (Ops.ADD, arange_axis))
arange_out_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
arange = arange+ast_const(dtypes.int, -1, arange_out_shape)
# p2: the indexing
@@ -581,10 +581,10 @@ class TestLinearizer(unittest.TestCase):
data1 = (g1, ShapeTracker.from_shape(dataset.shape).reshape((1, 16384, 256, 1)).expand(arange_out_shape).to_uop())
idxs = Tensor([0,3,5,6]).realize()
data2 = (g2, ShapeTracker.from_shape((4,)+(1,)*(len(arange_out_shape)-1)).expand(arange_out_shape).to_uop())
-arange_eq = arange.alu(BinaryOps.CMPNE, UOp(Ops.LOAD, dtypes.int, data2)).alu(BinaryOps.CMPNE, ast_const(dtypes.bool, True, arange_out_shape))
+arange_eq = arange.alu(Ops.CMPNE, UOp(Ops.LOAD, dtypes.int, data2)).alu(Ops.CMPNE, ast_const(dtypes.bool, True, arange_out_shape))
reduce_input = UOp(Ops.LOAD, dataset.dtype, data1)*UOp(Ops.CAST, dataset.dtype.scalar(), src=(arange_eq,))
out_axis = (1,)
-out = UOp(Ops.REDUCE_AXIS, reduce_input.dtype, (reduce_input,), (BinaryOps.ADD, out_axis))
+out = UOp(Ops.REDUCE_AXIS, reduce_input.dtype, (reduce_input,), (Ops.ADD, out_axis))
output_shape = tuple(1 if i in out_axis else s for i,s in enumerate(arange_out_shape))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape(output_shape).to_uop(), out))
sink = UOp(Ops.SINK, src=(store,))
@@ -605,7 +605,7 @@ class TestLinearizer(unittest.TestCase):
ast_const(dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), val=10),
UOp(Ops.MUL, dtypes.int, arg=None, src=(
ast_const(dtypes.int, -1, (1, 20, 1)),
-UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.MAX, (0,)), src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
@@ -618,7 +618,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501
UOp(Ops.ADD, dtypes.int, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=(
ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False)))),)), # noqa E501
ast_const(dtypes.int, 10, (10, 20, 1)))),)),)),)),)),
ast_const(dtypes.int, -1, (1, 20, 1)),)),)),))
@@ -637,7 +637,7 @@ class TestLinearizer(unittest.TestCase):
ast_const(dtypes.int, 200, (1, 1)),
UOp(Ops.MUL, dtypes.int, arg=None, src=(
ast_const(dtypes.int, -1, (1, 1)),
-UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.MAX, (0,)), src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
@@ -650,7 +650,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
ast_const(dtypes.bool, True, (200, 1)),)),)),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
ast_const(dtypes.int, 200, (200, 1)),)),)),)),)),)),
ast_const(dtypes.int, -1, (1, 1)),)),)),))
@@ -672,16 +672,16 @@ class TestLinearizer(unittest.TestCase):
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
-r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (1,)))
-r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(BinaryOps.ADD, (0,)))
+r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (1,)))
+r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(Ops.ADD, (0,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
sink = UOp(Ops.SINK, src=(store,))
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=0, keepdims=True)).sum(axis=0).reshape(1,1,N)], opts=opts)
x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
-r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (2,)))
-r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.ADD, (1,)))
+r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (2,)))
+r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
sink = UOp(Ops.SINK, src=(store,))
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(N,1,1)], opts=opts)
@@ -699,16 +699,16 @@ class TestLinearizer(unittest.TestCase):
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
-r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (1,)))
-r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (BinaryOps.MAX, (0,)))
+r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (1,)))
+r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (Ops.MAX, (0,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
sink = UOp(Ops.SINK, src=(store,))
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=0, keepdims=True)).max(axis=0).reshape(1,1,N)], opts=opts)
x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
-r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (2,)))
-r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.MAX, (1,)))
+r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (2,)))
+r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.MAX, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
sink = UOp(Ops.SINK, src=(store,))
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=1, keepdims=True)).max(axis=1).reshape(N,1,1)], opts=opts)
@@ -735,7 +735,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.5*N, (N, 1, 1)),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
@@ -743,7 +743,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.75*N, (N, N, 1)),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
ld0.to_uop(),)),)),)),
@@ -768,7 +768,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.5*N, (1, 1, N)),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -776,7 +776,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.75*N, (N, 1, N)),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
ld0.to_uop(),)),)),)),
@@ -804,7 +804,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.5*N, (1, 1, 1, 1)),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
@@ -812,7 +812,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 3)), src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
@@ -831,7 +831,7 @@ class TestLinearizer(unittest.TestCase):
def test_end_local(self):
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=i) for i in range(2)]
load = UOp(Ops.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop()))
-reduce = UOp(Ops.REDUCE_AXIS, dtypes.int, (load,), (BinaryOps.ADD, (0,)))
+reduce = UOp(Ops.REDUCE_AXIS, dtypes.int, (load,), (Ops.ADD, (0,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce))
sink = UOp(Ops.SINK, src=(store,))
load_t = Tensor.full(load.st_arg.shape, 1).contiguous().realize()
@@ -1219,20 +1219,20 @@ class TestLinearizer(unittest.TestCase):
assert len(sched) == 1
lin = Kernel(sched[0].ast)
-assert sum(u.op in {UnaryOps.RECIP, BinaryOps.FDIV} for u in lin.linearize().uops) == max_ops, msg
+assert sum(u.op in {Ops.RECIP, Ops.FDIV} for u in lin.linearize().uops) == max_ops, msg
a = Tensor.empty((4,4))
b = Tensor.empty((4,4))
d = Tensor.empty((4,4))
c = (a*b)/b
-helper(c, "found UnaryOps.RECIP in (a*b)/b operation")
+helper(c, "found Ops.RECIP in (a*b)/b operation")
c = a/a
-helper(c, "found UnaryOps.RECIP in (a/a) operation")
+helper(c, "found Ops.RECIP in (a/a) operation")
c = (a/b)/d
-helper(c, "found multiple UnaryOps.RECIP in (a/b)/d operation", 1)
+helper(c, "found multiple Ops.RECIP in (a/b)/d operation", 1)
def test_sum_collapse(self):
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
@@ -1260,7 +1260,7 @@ class TestLinearizer(unittest.TestCase):
lin = Kernel(sched_copy[-1].ast)
lin.hand_coded_optimizations()
lin.linearize()
-assert not any(u.op == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
+assert not any(u.op == Ops.WHERE for u in lin.uops), "found where where where should be folded"
def test_phi_simplification(self):
def helper(t, max_ops=0):
@@ -1272,7 +1272,7 @@ class TestLinearizer(unittest.TestCase):
assert len(set([u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}])) == 1, "has either specials or ranges, not both"
assert len([u for u in uops if u.op is Ops.ASSIGN]) == 0, "ASSIGN should have been simplified"
# TODO: once uops track min/max this will be fixed
-#assert len([u for u in uops if u.op is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
+#assert len([u for u in uops if u.op is Ops.MAX]) <= max_ops, "no unnecessary MAX ops"
helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2)
helper(Tensor.arange(-1, -100, -5), max_ops=2)
@@ -1602,7 +1602,7 @@ class TestFloat4(unittest.TestCase):
UOp(Ops.STORE, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.CAST, dtypes.float, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, src=(
@@ -1632,7 +1632,7 @@ class TestFloat4(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
UOp(Ops.ADD, dtypes.float, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
@@ -1662,7 +1662,7 @@ class TestFloat4(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501
UOp(Ops.CAST, dtypes.half, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (4, 6)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=(
UOp(Ops.CAST, dtypes.float, src=(
UOp(Ops.LOAD, dtypes.half, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
@@ -1949,7 +1949,7 @@ class TestKernelOpts(unittest.TestCase):
UOp(Ops.STORE, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
@@ -2138,7 +2138,7 @@ class TestKernelOpts(unittest.TestCase):
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
ld0 = UOp(Ops.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
ld1 = UOp(Ops.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
-store = UOp(Ops.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(Ops.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (BinaryOps.ADD, (0, 2, 4, 6)),))) # noqa: E501
+store = UOp(Ops.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(Ops.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (Ops.ADD, (0, 2, 4, 6)),))) # noqa: E501
sink = UOp(Ops.SINK, src=(store,))
data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()


@@ -5,7 +5,7 @@
import unittest
from test.helpers import ast_const
from tinygrad import Device, dtypes
-from tinygrad.ops import UOp, Ops, BinaryOps
+from tinygrad.ops import UOp, Ops
from tinygrad.helpers import getenv
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.engine.search import Opt, OptOps
@@ -21,7 +21,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.MAX, dtypes.half, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.CAST, dtypes.half, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -64,7 +64,7 @@ class TestLinearizerDumb(unittest.TestCase):
ast_const(dtypes.bool, True, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, -1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, 1000, st_src=(
@@ -75,7 +75,7 @@ class TestLinearizerDumb(unittest.TestCase):
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
print(prg.src)
-assert prg.uops is not None and not any(uop.op is BinaryOps.MAX for uop in prg.uops), "leftover MAX"
+assert prg.uops is not None and not any(uop.op is Ops.MAX for uop in prg.uops), "leftover MAX"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
def test_expander_new_srcs(self):
@@ -83,7 +83,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
@@ -105,14 +105,14 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, -1, st_src=(
@@ -136,7 +136,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
@@ -168,7 +168,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
@@ -200,7 +200,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),


@@ -3,7 +3,7 @@ import unittest, random
import numpy as np
from tinygrad.codegen.kernel import Kernel, KernelOptError
from tinygrad.device import is_dtype_supported
-from tinygrad.ops import UOp, Ops, BinaryOps
+from tinygrad.ops import UOp, Ops
from tinygrad.engine.search import Opt, OptOps
from tinygrad import Device, dtypes, Tensor
from tinygrad.helpers import CI
@@ -47,7 +47,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 16), strides=(16, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
@@ -64,7 +64,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 37, 9, 1, 1), strides=(666, 333, 9, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (4, 5)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (4, 5)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
@@ -76,7 +76,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 1), strides=(128, 16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),))
@@ -89,7 +89,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
@@ -111,7 +111,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, -1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, 10, st_src=(
@@ -125,7 +125,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 1, 34, 1, 34), strides=(36992, 1156, 0, 34, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 4)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 4)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 6, 8, 4, 6, 8, 4), strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 8), (0, 1), (0, 6), (0, 8), (0, 1)), contiguous=False), View(shape=(512, 32, 6, 35, 6, 35), strides=(1179648, 36864, 6144, 192, 32, 1), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 32), (0, 6), (0, 32)), contiguous=False), View(shape=(512, 32, 238, 238), strides=(1411200, 44100, 210, 1), offset=0, mask=((0, 512), (0, 32), (0, 210), (0, 210)), contiguous=False), View(shape=(512, 32, 7, 34, 7, 34), strides=(1812608, 56644, 8092, 238, 34, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),))
@@ -142,7 +142,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
x9:=UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
@@ -166,7 +166,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 0, 0, 4500, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -183,7 +183,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
@@ -202,7 +202,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
@@ -277,7 +277,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
@@ -299,7 +299,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 8)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=(
x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
@@ -313,7 +313,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
x6,)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 8)), src=(
x5,)),)),)),)),))
opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)]
helper_test_lin(Kernel(ast), opts, failed_platforms=[])
@@ -325,7 +325,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
@@ -344,7 +344,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
@@ -370,7 +370,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -405,7 +405,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),
@@ -420,7 +420,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 1, 28, 28, 1, 1), strides=(31360, 0, 784, 0, 28, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -442,7 +442,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
@@ -462,7 +462,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 9, 7, 3, 3), strides=(2268, 0, 567, 0, 63, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -511,7 +511,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
@@ -624,7 +624,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1025, 2047), strides=(0, 0), offset=0, mask=((0, 1025), (1023, 2047)), contiguous=False), View(shape=(1024, 1024), strides=(1, 2048), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, -1, st_src=(
@@ -639,7 +639,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(129, 255), strides=(0, 0), offset=0, mask=((0, 129), (127, 255)), contiguous=False), View(shape=(128, 128), strides=(1, 256), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, -1, st_src=(
@@ -678,7 +678,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.MAX, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.MAX, (3,)), src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),))
@@ -731,7 +731,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -749,7 +749,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 1, 1, 1), strides=(11532, 0, 961, 31, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -767,7 +767,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.EXP2, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
@@ -791,7 +791,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 1, 1, 1), strides=(50176, 0, 196, 14, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -809,7 +809,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
x5:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -851,7 +851,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(180, 0, 30, 3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -875,7 +875,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.uchar, arg=None, src=(
UOp(Ops.ADD, dtypes.uint, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.uint, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.uint, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.CAST, dtypes.uint, arg=None, src=(
ast_const(dtypes.uchar, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))), src=()),)),)),)),
@@ -895,7 +895,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
@@ -920,7 +920,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 32, 1, 1, 1, 5, 5, 256), strides=(0, 0, 6400, 0, 0, 0, 1280, 256, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 3, 4)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 3, 4)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
@@ -943,7 +943,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
@@ -969,7 +969,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, -1, st_src=(
@@ -987,7 +987,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 1, 1, 1), strides=(100352, 0, 784, 28, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -1007,7 +1007,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
@@ -1021,7 +1021,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
@@ -1035,7 +1035,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
@@ -1052,7 +1052,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 1, 1, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 3)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -1065,7 +1065,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (4,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 3), strides=(0, 0), offset=0, mask=((0, 3), (1, 3)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 1, 0, 4), offset=0, mask=((0, 2), (0, 3), (0, 2), (0, 3), (0, 2)), contiguous=False))), src=()),)),)),
x19:=ast_const(dtypes.int, -1, st_src=(
@@ -1078,7 +1078,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=3, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (4,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5), strides=(0, 0), offset=0, mask=((0, 4), (2, 5)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 0, 1, 6), offset=0, mask=None, contiguous=False))), src=()),)),)),
x19,)),)),
@@ -1093,7 +1093,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.bool, arg=None, src=(
@@ -1127,7 +1127,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, -1, st_src=(
@@ -1142,7 +1142,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 1, 1, 256, 1, 1, 256), strides=(0, 0, 65536, 0, 0, 256, 0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3, 4)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3, 4)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -1160,7 +1160,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 1), strides=(6, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -1178,7 +1178,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.bool, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.bool, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.MUL, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.bool, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=1, src=()),
@@ -1220,7 +1220,7 @@ class TestLinearizerFailures(unittest.TestCase):
x9,)),
UOp(Ops.ADD, dtypes.half, arg=None, src=(
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -1249,7 +1249,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -1267,7 +1267,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.uchar, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.uchar, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.MUL, dtypes.uchar, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()),
@@ -1279,7 +1279,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
UOp(Ops.VALID, dtypes.bool, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(50001, 99999), strides=(0, 0), offset=0, mask=((0, 50001), (49999, 99999)), contiguous=False), View(shape=(1024, 50000, 50000), strides=(0, 1, 100000), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -1306,7 +1306,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -1326,7 +1326,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(W, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
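Each failure case in this file is replayed the same way: paste the offending AST dump, replay the opts that produced it, and assert which platforms still fail. The driver call, as used throughout the hunks above (ast stands for one of the dumped sinks):

from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import Opt, OptOps

opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)]
helper_test_lin(Kernel(ast), opts, failed_platforms=[])  # empty list: expected to pass everywhere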

View File

@@ -8,7 +8,7 @@ from tinygrad.engine.search import Opt, OptOps
from tinygrad.engine.search import time_linearizer, bufs_from_lin
# stuff needed to unpack a kernel
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -33,7 +33,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -68,7 +68,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -85,7 +85,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -102,7 +102,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -119,7 +119,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -136,7 +136,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -153,7 +153,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -175,7 +175,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase):
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop()
prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2))
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5)))))
ast = UOp(Ops.SINK, src=(store,))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
_test_overflow(ast, opts)
@@ -187,7 +187,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase):
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop()
prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2))
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5)))))
ast = UOp(Ops.SINK, src=(store,))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=5, amt=2)]
_test_overflow(ast, opts)
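Unlike the dump-based tests, TestLinearizerOverflowAlt builds its ASTs by hand, which makes the enum change easy to see in isolation. The construction pattern from the hunk above (g0..g2 are the DEFINE_GLOBAL UOps and in_st_1/in_st_2/ot_st the ShapeTracker uops defined earlier in the test):

prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2))
red = UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5)))  # op member moved from BinaryOps to Ops
ast = UOp(Ops.SINK, src=(UOp(Ops.STORE, src=(g0, ot_st, red)),))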

View File

@@ -1,7 +1,7 @@
import unittest, functools, random
from typing import List
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
from tinygrad.ops import MetaOps, BinaryOps, Ops
from tinygrad.ops import Ops
from tinygrad.helpers import CI, getenv, prod, Context
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.schedule import create_schedule
@@ -51,6 +51,15 @@ class TestMultiTensor(unittest.TestCase):
assert lb.shape == (128,)
(X + X).realize()
def test_tensor_from_multi(self):
X = Tensor([1, 2], dtype=dtypes.int).shard_(devices_2, 0)
Y = Tensor(X.lazydata)
self.assertEqual(Y.device, Device.DEFAULT)
np.testing.assert_equal(X.numpy(), Y.numpy())
with self.assertRaises(AssertionError):
_ = Tensor(X.lazydata, dtype=dtypes.float)
def test_sharded_arange(self):
sharded_arange = Tensor.arange(1000).shard(devices_2, 0)
sharded_arange.realize()
@@ -481,7 +490,7 @@ class TestMultiTensor(unittest.TestCase):
for p in get_parameters(bn): p.shard_(devices_4).realize()
out = bn(t)
scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices_4 and sched.ast.op is not MetaOps.COPY]
scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices_4 and sched.ast.op is not Ops.COPY]
assert set(out.device for sched in scheds for out in sched.outputs) == set(devices_4), "should have ast on each shard device"
asts = [sched.ast for sched in scheds]
assert len(asts)
@@ -640,21 +649,21 @@ class TestMultiTensor(unittest.TestCase):
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is Ops.STORE
assert ast.src[2].op is BinaryOps.ADD
assert ast.src[2].op is Ops.ADD
assert ast.src[2].src[0].op is Ops.LOAD
assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1
t = 2 * t
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is Ops.STORE
assert ast.src[2].op is BinaryOps.MUL
assert ast.src[2].op is Ops.MUL
assert ast.src[2].src[0].src[1].op is Ops.CONST and ast.src[2].src[0].src[1].arg == 2
assert ast.src[2].src[1].op is Ops.LOAD
t = t + t.full_like(3)
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is Ops.STORE
assert ast.src[2].op is BinaryOps.ADD
assert ast.src[2].op is Ops.ADD
assert ast.src[2].src[0].op is Ops.LOAD
assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3
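The new test_tensor_from_multi pins down gathering a sharded tensor back onto the default device: wrapping a MultiLazyBuffer in Tensor(...) collects it onto Device.DEFAULT, and re-wrapping with a different dtype is rejected. A usage sketch, assuming the devices_2 tuple defined at the top of this file:

import numpy as np
from tinygrad import Tensor, Device, dtypes

X = Tensor([1, 2], dtype=dtypes.int).shard_(devices_2, 0)  # shard in place along axis 0
Y = Tensor(X.lazydata)                                     # gathers back onto Device.DEFAULT
assert Y.device == Device.DEFAULT
np.testing.assert_equal(X.numpy(), Y.numpy())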

View File

@@ -216,6 +216,29 @@ class TestOps(unittest.TestCase):
for i in range(len(tor)):
helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
def test_meshgrid(self):
x, xt = torch.tensor([0.,1.,2.], requires_grad=True), Tensor([0.,1.,2.], requires_grad=True)
y, yt = torch.tensor([3.,4.,5.,6.], requires_grad=True), Tensor([3.,4.,5.,6.], requires_grad=True)
z, zt = torch.tensor([7.,8.,9.], requires_grad=True), Tensor([7.,8.,9.], requires_grad=True)
for indexing in ("ij", "xy"):
tor = torch.meshgrid(x, indexing=indexing)
ten = xt.meshgrid(indexing=indexing)
self.assertEqual(len(tor), len(ten))
for tor_i, ten_i in zip(tor, ten):
helper_test_op([], lambda: tor_i, lambda: ten_i)
tor = torch.meshgrid(x, y, indexing=indexing)
ten = xt.meshgrid(yt, indexing=indexing)
self.assertEqual(len(tor), len(ten))
for tor_i, ten_i in zip(tor, ten):
helper_test_op([], lambda: tor_i, lambda: ten_i)
tor = torch.meshgrid(x, torch.tensor(10., requires_grad=True), y, z, indexing=indexing)
ten = xt.meshgrid(Tensor(10., requires_grad=True), yt, zt, indexing=indexing)
self.assertEqual(len(tor), len(ten))
for tor_i, ten_i in zip(tor, ten):
helper_test_op([], lambda: tor_i, lambda: ten_i)
self.helper_test_exception([], lambda: torch.meshgrid(x, indexing="bad"), lambda: xt.meshgrid(indexing="bad"), expected=RuntimeError)
def test_arange(self):
helper_test_op([], lambda: torch.arange(10, dtype=torch.int32), lambda: Tensor.arange(10), forward_only=True)
helper_test_op([], lambda: torch.arange(36, dtype=torch.int32), lambda: Tensor.arange(36), forward_only=True)
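test_meshgrid above is new and tracks torch's semantics for both indexing modes, including a 0-d operand and rejection of unknown modes. Roughly, for 1-d inputs of lengths 3 and 4 (a sketch of the API the test exercises):

from tinygrad import Tensor

x, y = Tensor([0., 1., 2.]), Tensor([3., 4., 5., 6.])
gx, gy = x.meshgrid(y, indexing="ij")  # both grids have shape (3, 4)
hx, hy = x.meshgrid(y, indexing="xy")  # transposed convention: shape (4, 3)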
@@ -879,15 +902,15 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot)
helper_test_op([(8,45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot)
helper_test_op([(65), (8,65,45)], lambda x,y: x.matmul(y), Tensor.dot)
self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
def test_dot(self):
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5)
helper_test_op([(8,45,65), (8,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5)
self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
with self.assertRaises(AssertionError):
self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
with self.assertRaises(RuntimeError):
a = Tensor(3.14)
a.matmul(a)
def test_mulacc_with_zero_strides(self):
@@ -954,7 +977,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(0), (0)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7)
def test_broadcastdot(self):
helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
with self.assertRaises(AssertionError):
with self.assertRaises(RuntimeError):
a = Tensor(3.14)
b = Tensor.ones(3,3)
a @ b
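Shape errors from matmul/dot now surface as RuntimeError rather than AssertionError, so the torch and tinygrad sides of helper_test_exception can share a single expected type. The 0-d case from the hunk above, as a standalone snippet:

from tinygrad import Tensor

a = Tensor(3.14)   # 0-d tensor
b = Tensor.ones(3, 3)
try:
    a @ b          # matmul with a 0-d operand
except RuntimeError as e:
    print("rejected:", e)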
@@ -988,6 +1011,12 @@ class TestOps(unittest.TestCase):
self.helper_test_exception([()], lambda x: x.sum(1), lambda x: x.sum(1), expected=IndexError)
self.helper_test_exception([()], lambda x: x.sum((1,)), lambda x: x.sum((1,)), expected=IndexError)
def test_sum_acc_dtype(self):
helper_test_op([(45,3)], lambda x: x.sum(), lambda x: x.sum(acc_dtype=dtypes.float32))
if is_dtype_supported(dtypes.float64): helper_test_op([(45,3)], lambda x: x.sum(dtype=torch.float64), lambda x: x.sum(acc_dtype=dtypes.float64))
with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).sum(acc_dtype="")
def test_sum_with_zeros_shape(self):
helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,)))
helper_test_op([(4, 0)], lambda x: x.sum(axis=(1,)))
@@ -1003,6 +1032,9 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: x.prod(0))
helper_test_op([()], lambda x: x.prod(-1))
def test_prod_acc_dtype(self):
with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).prod(acc_dtype="")
def test_min(self):
helper_test_op([(3,3)], lambda x: x.min())
helper_test_op([(45,3)], lambda x: x.min())
@@ -1339,27 +1371,36 @@ class TestOps(unittest.TestCase):
helper_test_op([(4,4)], lambda x: x[:, 1:2][0:1])
helper_test_op([(4,4)], lambda x: x[:, 1:2][:, 0:1])
def test_pad2d(self):
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)))
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4)), lambda x: x.pad2d(padding=(-1,2,-3,4)))
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad2d(padding=(1,2,3,4),value=5))
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4), value=5), lambda x: x.pad2d(padding=(-1,2,-3,4),value=5))
def test_pad(self):
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad(padding=(1,2,3,4)))
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4)), lambda x: x.pad(padding=(-1,2,-3,4)))
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(padding=(1,2,3,4),value=5))
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4), value=5), lambda x: x.pad(padding=(-1,2,-3,4),value=5))
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=math.inf), lambda x: x.pad(padding=(1,2,3,4),value=math.inf))
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4), value=-math.inf),
lambda x: x.pad(padding=(-1,2,-3,4),value=-math.inf))
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)),lambda x: x.pad(((3,4),(1,2))))
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4)), lambda x: x.pad(((-3,4), (-1,2))))
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(((3,4), (1,2)), value=5))
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=math.inf), lambda x: x.pad(((3,4), (1,2)), value=math.inf))
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=-math.inf), lambda x: x.pad(((3,4), (1,2)), value=-math.inf))
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,3,4), value=1), lambda x: x.pad(((3,4), None), value=1))
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0), value=1), lambda x: x.pad((None, None), value=1))
# raise error for uneven pads
self.helper_test_exception([(3,3)], lambda x: torch.nn.functional.pad(x, (2,0,2)), lambda x: x.pad((2,0,2)),
expected=(RuntimeError, ValueError))
# raise error for too many or too few pads
self.helper_test_exception([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0,1,0,3,0)), lambda x: x.pad((0,0,0,0,1,0,3,0)),
expected=(RuntimeError, ValueError))
x = Tensor.ones(3,3)
with self.assertRaises(ValueError): x.pad((None,(0,1),(3,0)))
with self.assertRaises(ValueError): x.pad(((0,1),))
def test_pad_reshape(self):
helper_test_op([(1, 2)],
lambda x: torch.nn.functional.pad(x, (0, 1, 1, 0)).reshape((3, 2)),
lambda x: x.pad2d((0, 1, 1, 0)).reshape((3, 2)), forward_only=True)
lambda x: x.pad((0, 1, 1, 0)).reshape((3, 2)), forward_only=True)
helper_test_op([(1, 2)],
lambda x: torch.nn.functional.pad(x, (0, 2, 1, 1)).reshape((4, 3)),
lambda x: x.pad2d((0, 2, 1, 1)).reshape((4, 3)), forward_only=True)
lambda x: x.pad((0, 2, 1, 1)).reshape((4, 3)), forward_only=True)
helper_test_op([(1, 1, 1, 2)],
lambda x: torch.nn.functional.pad(x, (0, 4, 2, 2, 1, 2, 0, 2)).reshape((4, 3, 6, 5)),
lambda x: x.pad(((0, 2), (1, 2), (2, 2), (0, 4))).reshape((4, 3, 6, 5)), forward_only=True)
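pad now covers everything pad2d did: it accepts the flat torch-style tuple (last dimension first) as well as the per-dimension tuple form, where None skips an axis. The forms exercised above, side by side:

from tinygrad import Tensor

x = Tensor.ones(3, 3)
x.pad((1, 2, 3, 4))             # flat: dim1 padded (1, 2), dim0 padded (3, 4)
x.pad(((3, 4), (1, 2)))         # per-dim: dim0 (3, 4), dim1 (1, 2) -- same result
x.pad(((3, 4), None), value=1)  # None leaves dim1 untouched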
@@ -1832,7 +1873,7 @@ class TestOps(unittest.TestCase):
def test_padding_add(self):
helper_test_op([(64,64), (60,60)],
lambda x,w: x+torch.nn.functional.pad(w, (2,2,2,2)),
lambda x,w: x+w.pad2d((2,2,2,2)))
lambda x,w: x+w.pad((2,2,2,2)))
def test_dilated_conv2d(self):
bs = 4
@@ -1844,34 +1885,40 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x,w,dilation=dilation).relu(),
lambda x,w: Tensor.conv2d(x,w,dilation=dilation).relu())
def test_maxpool2d_simple(self):
def test_max_pool2d_simple(self):
ksz = (2,2)
helper_test_op([(1,1,2,3)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
lambda x: Tensor.max_pool2d(x, kernel_size=ksz))
def test_maxpool2d(self):
def test_max_pool2d(self):
for ksz in [(2,2), (3,3), 2, 3, (3,2), (5,5), (5,1)]:
with self.subTest(kernel_size=ksz):
helper_test_op([(32,2,110,28)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
lambda x: Tensor.max_pool2d(x, kernel_size=ksz))
def test_maxpool2d_padding(self):
def test_max_pool2d_padding(self):
for ksz in [(2,2), (3,3), 2, 3, (3,2)]:
with self.subTest(kernel_size=ksz):
helper_test_op([(32,2,110,28)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, padding=1),
lambda x: Tensor.max_pool2d(x, kernel_size=ksz, padding=1))
def test_maxpool2d_bigger_stride(self):
def test_max_pool2d_padding_int(self):
ksz = (2,2)
helper_test_op([(32,2,110,28)],
lambda x: torch.nn.functional.max_pool2d(x.int(), kernel_size=ksz, padding=1),
lambda x: Tensor.max_pool2d(x.int(), kernel_size=ksz, padding=1), forward_only=True)
def test_max_pool2d_bigger_stride(self):
for stride in [(2,3), (3,2), 2, 3]:
with self.subTest(stride=stride):
helper_test_op([(32,2,110,28)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride),
lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride))
def test_maxpool2d_bigger_stride_dilation(self):
def test_max_pool2d_bigger_stride_dilation(self):
for stride, dilation in zip([(2,3), (3,2), 2, 3, 4], [(3,2), (2,3), 2, 3, 6]):
with self.subTest(stride=stride):
helper_test_op([(32,2,110,28)],
@@ -1879,25 +1926,25 @@ class TestOps(unittest.TestCase):
lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride, dilation=dilation))
@unittest.skipIf(Device.DEFAULT in {"CUDA", "NV"}, "CUDA fails on this")
def test_maxpool2d_unit_stride(self):
def test_max_pool2d_unit_stride(self):
helper_test_op([(8, 2, 17, 14)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=1),
lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=1))
def test_maxpool2d_smaller_stride(self):
def test_max_pool2d_smaller_stride(self):
for stride in [(2,3), (3,2), 2, 3]:
with self.subTest(stride=stride):
helper_test_op([(8, 2, 17, 14)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=stride),
lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=stride))
def test_maxpool2d_dilation(self):
def test_max_pool2d_dilation(self):
for dilation in [(2, 3), (3, 2), 2, 3]:
helper_test_op([(8, 2, 17, 14)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), dilation=dilation),
lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), dilation=dilation))
def test_avgpool2d(self):
def test_avg_pool2d(self):
shape = (32,2,111,28)
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
with self.subTest(kernel_size=ksz):
@@ -1907,12 +1954,12 @@ class TestOps(unittest.TestCase):
# TODO fix edge case
@unittest.expectedFailure
def test_avgpool2d_failure(self):
def test_avg_pool2d_failure(self):
helper_test_op([(1,1,8,8)],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(1,2), padding=(0,1), stride=(5,1)),
lambda x: Tensor.avg_pool2d(x, kernel_size=(1,2), padding=(0,1), stride=(5,1)), rtol=1e-5)
def test_avgpool2d_padding(self):
def test_avg_pool2d_padding(self):
shape = (32,2,111,28)
for ksz in [(2,2), (3,3), 2, 3, (3,2)]:
with self.subTest(kernel_size=ksz):
@@ -1920,7 +1967,7 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1), rtol=1e-5)
def test_avgpool2d_padding_not_counted(self):
def test_avg_pool2d_padding_not_counted(self):
shape = (32,2,111,28)
for ksz in [(2,2), (3,3), 2, 3, (3,2)]:
with self.subTest(kernel_size=ksz):
@@ -1928,7 +1975,7 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False), rtol=1e-5)
def test_global_avgpool2d(self):
def test_global_avg_pool2d(self):
helper_test_op([(32,2,111,28)],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)),
lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5)
@@ -2032,6 +2079,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(6))
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, 1))
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, 0))
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, -1))
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, -2))
def test_simple_repeat(self):
repeats = [3, 3, 4]
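The pooling tests are renamed so test_max_pool2d*/test_avg_pool2d* match the public method names, and the new test_max_pool2d_padding_int covers integer inputs (forward only). The API itself is unchanged; for example:

from tinygrad import Tensor

x = Tensor.randn(1, 1, 8, 8)
y = x.max_pool2d(kernel_size=(2, 2), padding=1)
z = x.avg_pool2d(kernel_size=(2, 2), padding=1, count_include_pad=False)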

View File

@@ -8,7 +8,7 @@ from tinygrad.dtype import dtypes
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import dedup, flatten, prod
from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.ops import BinaryOps, UOp, Ops
from tinygrad.ops import UOp, Ops
from tinygrad.renderer import Program
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.lazy import LazyBuffer
@@ -34,7 +34,7 @@ class TestCStyleFailures(unittest.TestCase):
b = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(Ops.LOAD, dtypes.int, (b.index(idx),))
alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
alu = ld.alu(Ops.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
store = UOp.store(a.index(idx), alu)
sink = UOp(Ops.SINK, dtypes.void, (store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
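
# Context for the import churn in these files: the separate UnaryOps/BinaryOps/
# TernaryOps enums were folded into the single Ops namespace. A minimal sketch
# of that kind of consolidation (illustrative names, not tinygrad's definitions):
from enum import Enum, auto

class Ops(Enum):
    ADD = auto(); MUL = auto(); MAX = auto()       # binary
    EXP2 = auto(); SQRT = auto(); RECIP = auto()   # unary
    WHERE = auto()                                 # ternary

# arity now lives in op groups instead of separate enum types
GROUP_BINARY = {Ops.ADD, Ops.MUL, Ops.MAX}
assert Ops.MAX in GROUP_BINARY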

View File

@@ -13,13 +13,12 @@ from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, view_left
from tinygrad.engine.realize import CompiledRunner, run_schedule
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
from test.helpers import ast_const, timeit
from extra.models.llama import precompute_freqs_cis
class KernelCountException(Exception): pass
@@ -41,9 +40,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
# test that the (sink) ops linearize
for s in sched:
if s.ast.op is not Ops.SINK: continue
l = Kernel(s.ast)
l.hand_coded_optimizations()
l.to_program()
get_runner(s.bufs[0].device, s.ast)
return sched
def _realize_weights(m):
@@ -311,7 +308,6 @@ class TestSchedule(unittest.TestCase):
img = Tensor.empty(64,64)
x = (img.sum(0) + img.sum(1))
out = x.relu()
del x # is 3 without this
check_schedule(out, 2)
#@unittest.skip("failing in old lazy")
@@ -335,6 +331,7 @@ class TestSchedule(unittest.TestCase):
d = (a+b).reshape(16,1)
check_schedule(d, 0, [c])
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_multi_permute_should_collapse(self):
a = Tensor.empty(4,4,4,4)
b = Tensor.empty(16)
@@ -1045,7 +1042,7 @@ class TestSchedule(unittest.TestCase):
b = r.sum(0) * 4
c = r.sum(1) * 2
schedule = check_schedule([b, c], 3)
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
# multireduce spec
def test_multireduce_simple_chase(self):
@@ -1056,7 +1053,7 @@ class TestSchedule(unittest.TestCase):
c = r.sum(1) + 12
np_r = (a.numpy() + (a.numpy().sum(0) + 6)).sum(0) * 2
# schedule = check_schedule([b,c], 3)
# self.assertIs(schedule[0].ast[0].src[0].arg, BinaryOps.MUL)
# self.assertIs(schedule[0].ast[0].src[0].arg, Ops.MUL)
schedule = check_schedule([b,c], 4)
run_schedule(schedule)
np.testing.assert_allclose(b.numpy(), np_r.sum(0) + 8, atol=1e-4, rtol=1e-4)
@@ -1069,7 +1066,7 @@ class TestSchedule(unittest.TestCase):
d = r.T * 4
e = r * d
schedule = check_schedule([d, e], 3)
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
# multireduce spec
def test_multireduce_push_permute_chase(self):
@@ -1080,7 +1077,7 @@ class TestSchedule(unittest.TestCase):
d = r.T * 4
e = r * (d + a).sum(2)
schedule = check_schedule([d, e], 3) # make sure it doesn't fuse
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
run_schedule(schedule)
np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4)
@@ -1092,7 +1089,7 @@ class TestSchedule(unittest.TestCase):
r = a.sum(1) + c
d = r[:4] * b
schedule = check_schedule(d, 2)
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
# multireduce spec
def test_multireduce_push_shrink_chase(self):
@@ -1105,7 +1102,7 @@ class TestSchedule(unittest.TestCase):
out = r[:4] * b + d.sum(1)[:4]
# schedule = check_schedule(out, 2)
schedule = check_schedule(out, 3)
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
run_schedule(schedule)
np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4)
@@ -1290,16 +1287,16 @@ class TestSchedule(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
def test_bitcast_subbufer(self):
x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
a = x.alu(UnaryOps.EXP2).cast(dtypes.int32, True, allow_buffer_view=True)
a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=True)
b = x.cast(dtypes.int32, True, allow_buffer_view=True)
b = a.alu(BinaryOps.ADD, b)
b = a.alu(Ops.ADD, b)
check_schedule(b, 2) # this should fuse when it makes sense
def test_bitcast_disable_subbufer(self):
x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
a = x.alu(UnaryOps.EXP2).cast(dtypes.int32, True, allow_buffer_view=False)
a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=False)
b = x.cast(dtypes.int32, True, allow_buffer_view=False)
b = a.alu(BinaryOps.ADD, b)
b = a.alu(Ops.ADD, b)
check_schedule(b, 1)
def test_reduceop_reshape_dont_push(self):
@@ -1533,7 +1530,7 @@ class TestIndexing(unittest.TestCase):
def test_arange_view_op(self):
a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).contiguous()
assert isinstance(a.lazydata, LazyBuffer)
self.assertIs(a.lazydata.base.op, MetaOps.BUFFER_VIEW)
self.assertIs(a.lazydata.base.op, Ops.BUFFER_VIEW)
self.check_schedule(a, 1)
np.testing.assert_equal(a.numpy(), [[4, 5]])
@@ -1541,7 +1538,7 @@ class TestIndexing(unittest.TestCase):
def test_arange_shrink_copy(self):
a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).to("CLANG")
assert isinstance(a.lazydata, LazyBuffer)
self.assertIs(a.lazydata.base.op, MetaOps.COPY)
self.assertIs(a.lazydata.base.op, Ops.COPY)
self.check_schedule(a, 1)
np.testing.assert_equal(a.numpy(), [[4, 5]])
@@ -1549,8 +1546,8 @@ class TestIndexing(unittest.TestCase):
def test_arange_expand_copy(self):
a = Tensor.arange(4).reshape(2, 2, 1).expand(2, 2, 2).to("CLANG")
assert isinstance(a.lazydata, LazyBuffer)
self.assertIs(a.lazydata.base.op, MetaOps.COPY)
self.assertIs(a.lazydata.base.srcs[0].base.op, BinaryOps.ADD)
self.assertIs(a.lazydata.base.op, Ops.COPY)
self.assertIs(a.lazydata.base.srcs[0].base.op, Ops.ADD)
self.check_schedule(a, 1)
np.testing.assert_equal(a.numpy(), [[[0, 0], [1, 1]], [[2, 2], [3, 3]]])
@@ -1635,16 +1632,6 @@ class TestIndexing(unittest.TestCase):
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
self.assertEqual(swizzle_cnt(new_uop), 0)
def test_strongly_connected_DAG(self):
val = 1.0
a = Tensor(val).realize()
def f(a):
for _ in range(24): a = Tensor.stack(a, a)[0]
return a.item()
r, et = timeit(f, a)
self.assertEqual(r, val)
self.assertLess(et, 1600)
def test_no_rewrite_elementwise(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
@@ -1656,9 +1643,9 @@ class TestIndexing(unittest.TestCase):
def test_simple_store_reshape(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
r = r + ast_const(dtypes.int, 2, ())
r = r + 2
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
# this AST first needs to swizzle, but it doesn't have implicit movement ops
@@ -1668,50 +1655,12 @@ class TestIndexing(unittest.TestCase):
def test_no_reshape_reduceop(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
verify_ast(sink)
self.assertEqual(sink.key, rsink.key)
def test_reshape_many(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
for _ in range(24): r = r + ast_const(dtypes.int, 2, ())
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink, et = timeit(graph_rewrite, sink, view_right)
# this AST first needs to swizzle, but it doesn't have implicit movement ops
with self.assertRaisesRegex(AssertionError, "swizzle"): verify_ast(sink)
verify_ast(rsink)
self.assertLessEqual(et, 1e3)
@unittest.skip("test is flaky")
def test_complexity(self):
SZ = 30 if getenv("BIG") else 10
sizes = [10*(i+1) for i in range(SZ)]
tms: List[float] = []
for sz in sizes:
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
for _ in range(sz): r = r + ast_const(dtypes.int, 2, ())
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink, et = timeit(graph_rewrite, sink, view_right)
with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
verify_ast(rsink)
tms.append(et)
if getenv("GRAPH_TIMING"):
import plotly.express as px
fig = px.line(x=sizes, y=tms, title="graph_rewrite time as ast grows")
fig.update_layout(paper_bgcolor="black", plot_bgcolor="black", font={"color":"white"},
yaxis={"gridcolor":"rgba(255, 255, 255, 0.3)"}, xaxis={"gridcolor":"rgba(255, 255, 255, 0.3)"})
fig.show()
change = tms[-1] / tms[0]
assert change <= SZ, f"bad complexity, time increased by {change:4.2f}x while input only grew {SZ}x"
@track_rewrites(named=True)
def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right)
@@ -1749,10 +1698,9 @@ class TestSwizzle(unittest.TestCase):
# LazyBuffer to pre-rewrite AST
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0,)))
swizzle_r = UOp(Ops.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
const = ast_const(dtypes.int, 1, ())
alu = swizzle_r+const
alu = swizzle_r+1
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
# graph rewrite
sink = swizzle_rewrite(sink)
@@ -1772,11 +1720,11 @@ class TestSwizzle(unittest.TestCase):
# LazyBuffer to pre-rewrite AST
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (Ops.ADD, (0,)))
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,)))
r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (Ops.ADD, (0,)))
alu = UOp(Ops.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(Ops.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+2,),),)) # noqa: E501
# graph rewrite
sink = swizzle_rewrite(sink)
# verify output
@@ -1788,7 +1736,7 @@ class TestSwizzle(unittest.TestCase):
def test_swizzle_rewrite_alt(self):
swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
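
# The swizzle rewrites (view_left/view_right) move a VIEW across a REDUCE_AXIS.
# Numerically that relies on reshaping after a reduce matching a reduce over the
# compatibly reshaped input; a NumPy sketch of the scalar case above:
import numpy as np
a = np.arange(32 * 32).reshape(32, 32)
assert a.sum(axis=(0, 1)).reshape(()) == a.reshape(32 * 32).sum()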

View File

@@ -3,7 +3,7 @@ import unittest
from test.helpers import ast_const
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.ops import UOp, Ops
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
from tinygrad.device import Device, Buffer
@@ -107,7 +107,7 @@ class TestBEAM(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (1,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(

View File

@@ -55,10 +55,10 @@ class TestTensorVariable(unittest.TestCase):
ret = t.var().item()
assert ret == 0
def test_symbolic_pad2d(self):
def test_symbolic_pad(self):
vv = Variable("a", 1, 10).bind(2)
t = Tensor.ones(2, 2).contiguous()
t = t.pad2d([vv, vv, vv, vv]).mean()
t = t.pad([vv, vv, vv, vv]).mean()
ones = 4
zeros = 6+6+4+4+6+6
self.assertAlmostEqual(t.item(), ones/(ones+zeros))
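
# Worked arithmetic for the symbolic pad test: with vv bound to 2, a (2,2) ones
# tensor padded by 2 on every side becomes (6,6), so the mean is 4/36. A NumPy
# sketch of the same computation:
import numpy as np
t = np.pad(np.ones((2, 2)), 2)
assert t.shape == (6, 6)
assert np.isclose(t.mean(), 4 / 36)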

View File

@@ -2,8 +2,7 @@ from typing import List
import unittest, time
from tinygrad import dtypes, Device
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, Ops, UOp, KernelInfo
from tinygrad.ops import UPat, PatternMatcher
from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher
from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, sym
@@ -541,7 +540,7 @@ class TestExpander(unittest.TestCase):
@unittest.skip("no longer supported")
def test_reduce_known_axis(self):
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), BinaryOps.ADD)
sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), Ops.ADD)
sink = expander_rewrite(sink)
assert sink.op is Ops.CONST
self.assertEqual(sink.arg, 3*(0+1+2+3))
@@ -549,7 +548,7 @@ class TestExpander(unittest.TestCase):
@unittest.skip("no longer supported")
def test_reduce_const(self):
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), BinaryOps.ADD)
sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), Ops.ADD)
sink = expander_rewrite(sink)
assert sink.op is Ops.CONST
self.assertEqual(sink.arg, 3*4)
@@ -590,7 +589,7 @@ class TestExpander(unittest.TestCase):
def test_reduce_different_axis(self):
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), BinaryOps.ADD)
sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), Ops.ADD)
sink = expander_rewrite(sink)
print(sink)
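
# In the (now-skipped) expander folds above, EXPAND materializes a small axis as
# explicit constants and REDUCE(Ops.ADD) over it collapses to one constant at
# rewrite time; the arithmetic those folds encoded, in plain Python:
e1 = list(range(4))                                   # EXPAND of consts 0..3
assert sum(3 * v for v in e1) == 3 * (0 + 1 + 2 + 3)  # test_reduce_known_axis
assert sum(3 for _ in e1) == 3 * 4                    # test_reduce_const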

View File

@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.dtype import dtypes, DType
from tinygrad.device import Buffer, Device
from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule, to_si
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
@@ -29,7 +29,7 @@ def uop(uops:List[UOp], uop:Ops, dtype:Optional[DType], src:Tuple[UOp, ...], arg
def _test_single_value(vals, op, dts):
uops = []
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, dtype in enumerate(dts))
@@ -45,7 +45,7 @@ def _test_single_value(vals, op, dts):
def _test_single_value_const(vals, op, dts):
uops = []
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, op, output_dtype, loads)
@@ -103,49 +103,49 @@ class TestUOps(unittest.TestCase):
class TestFloatUOps(TestUOps):
@unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
def test_exp2(self): self._test_uop_fxn(Ops.EXP2, lambda a: np.exp2(a))
@unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
def test_log2(self): self._test_uop_fxn(Ops.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
@unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1/a if a != 0 else float('inf'))
def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
def test_sin(self): self._test_uop_fxn(Ops.SIN, lambda a: math.sin(a))
def test_recip(self): self._test_uop_fxn(Ops.RECIP, lambda a: 1/a if a != 0 else float('inf'))
def test_sqrt(self): self._test_uop_fxn(Ops.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a<b)
def test_cmpne(self): self._test_bop_fxn(BinaryOps.CMPNE, lambda a,b: a!=b)
def test_add(self): self._test_bop_fxn(Ops.ADD, lambda a,b: a+b)
def test_mul(self): self._test_bop_fxn(Ops.MUL, lambda a,b: a*b)
def test_max(self): self._test_bop_fxn(Ops.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(Ops.CMPLT, lambda a,b: a<b)
def test_cmpne(self): self._test_bop_fxn(Ops.CMPNE, lambda a,b: a!=b)
# MOD isn't tested on floats
def test_where(self):
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float, dtypes.float))
self._test_top_fxn(Ops.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float, dtypes.float))
@unittest.skipUnless(getenv("PYTHON"), "only python supports MULACC")
def test_mulacc(self):
self._test_top_fxn(TernaryOps.MULACC, lambda a,b,c: a*b+c, (dtypes.float, dtypes.float, dtypes.float))
self._test_top_fxn(Ops.MULACC, lambda a,b,c: a*b+c, (dtypes.float, dtypes.float, dtypes.float))
class TestNonFloatUOps(TestUOps):
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32))
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), (dtypes.int32, dtypes.int32))
def test_add_int32(self): self._test_bop_fxn(Ops.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32))
def test_mul_int32(self): self._test_bop_fxn(Ops.MUL, lambda a,b: int(a)*int(b), (dtypes.int32, dtypes.int32))
@unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
def test_shr_int32(self): self._test_bop_fxn(BinaryOps.SHR, lambda a,b: int(a)>>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
def test_shr_int32(self): self._test_bop_fxn(Ops.SHR, lambda a,b: int(a)>>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
@unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
def test_shl_int32(self): self._test_bop_fxn(BinaryOps.SHL, lambda a,b: int(a)<<int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
def test_shl_int32(self): self._test_bop_fxn(Ops.SHL, lambda a,b: int(a)<<int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
def test_div_int32(self):
self._test_bop_fxn(BinaryOps.IDIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
def test_and_int32(self): self._test_bop_fxn(BinaryOps.AND, lambda a,b: int(a)&int(b), (dtypes.int32, dtypes.int32))
def test_or_int32(self): self._test_bop_fxn(BinaryOps.OR, lambda a,b: int(a)|int(b), (dtypes.int32, dtypes.int32))
self._test_bop_fxn(Ops.IDIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
def test_and_int32(self): self._test_bop_fxn(Ops.AND, lambda a,b: int(a)&int(b), (dtypes.int32, dtypes.int32))
def test_or_int32(self): self._test_bop_fxn(Ops.OR, lambda a,b: int(a)|int(b), (dtypes.int32, dtypes.int32))
def test_mod_int32(self):
self._test_bop_fxn(BinaryOps.MOD,
self._test_bop_fxn(Ops.MOD,
lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (dtypes.int32, dtypes.int32), no_b_zero=True)
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: int(a)<int(b), (dtypes.int32, dtypes.int32))
def test_cmpne_int32(self): self._test_bop_fxn(BinaryOps.CMPNE, lambda a,b: int(a)!=int(b), (dtypes.int32, dtypes.int32))
def test_cmplt_int32(self): self._test_bop_fxn(Ops.CMPLT, lambda a,b: int(a)<int(b), (dtypes.int32, dtypes.int32))
def test_cmpne_int32(self): self._test_bop_fxn(Ops.CMPNE, lambda a,b: int(a)!=int(b), (dtypes.int32, dtypes.int32))
@unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), (dtypes.bool, dtypes.bool))
def test_mul_bool(self): self._test_bop_fxn(Ops.MUL, lambda a,b: bool(a) and bool(b), (dtypes.bool, dtypes.bool))
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
def test_where_float16(self):
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float16, dtypes.float16))
self._test_top_fxn(Ops.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float16, dtypes.float16))
class TestBoolUOps(TestUOps):
def _test_uop_bool_fxn(self, op, fxn):
@@ -166,72 +166,72 @@ class TestBoolUOps(TestUOps):
for c in [False, True]:
self._equal(f([a,b,c], op, (dtypes.bool, )*3), fxn(a,b,c))
def test_add_bool(self): self._test_bop_bool_fxn(BinaryOps.ADD, lambda a,b: a or b)
def test_mul_bool(self): self._test_bop_bool_fxn(BinaryOps.MUL, lambda a,b: a and b)
def test_xor_bool(self): self._test_bop_bool_fxn(BinaryOps.XOR, lambda a,b: a != b)
def test_and_bool(self): self._test_bop_bool_fxn(BinaryOps.AND, lambda a,b: a & b)
def test_or_bool(self): self._test_bop_bool_fxn(BinaryOps.OR, lambda a,b: a | b)
def test_cmpne_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPNE, lambda a,b: a != b)
def test_cmplt_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPLT, lambda a,b: a < b)
def test_where_bool(self): self._test_top_bool_fxn(TernaryOps.WHERE, lambda a,b,c: b if a else c)
def test_add_bool(self): self._test_bop_bool_fxn(Ops.ADD, lambda a,b: a or b)
def test_mul_bool(self): self._test_bop_bool_fxn(Ops.MUL, lambda a,b: a and b)
def test_xor_bool(self): self._test_bop_bool_fxn(Ops.XOR, lambda a,b: a != b)
def test_and_bool(self): self._test_bop_bool_fxn(Ops.AND, lambda a,b: a & b)
def test_or_bool(self): self._test_bop_bool_fxn(Ops.OR, lambda a,b: a | b)
def test_cmpne_bool(self): self._test_bop_bool_fxn(Ops.CMPNE, lambda a,b: a != b)
def test_cmplt_bool(self): self._test_bop_bool_fxn(Ops.CMPLT, lambda a,b: a < b)
def test_where_bool(self): self._test_top_bool_fxn(Ops.WHERE, lambda a,b,c: b if a else c)
class TestExecALU(TestUOps):
def test_sqrt(self):
self.assertEqual(exec_alu(UnaryOps.SQRT, dtypes.float, (0.0,)), 0.0)
self.assertEqual(exec_alu(Ops.SQRT, dtypes.float, (0.0,)), 0.0)
def test_div(self):
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (8, 2)), 4)
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (7, 3)), 2)
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (7, -3)), -2)
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (-50, 6)), -8)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (8, 2)), 4)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (7, 3)), 2)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (7, -3)), -2)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (-50, 6)), -8)
np.testing.assert_allclose(exec_alu(BinaryOps.MUL, dtypes.float32, (7.0, exec_alu(UnaryOps.RECIP, dtypes.float32, (3.0,)))), 2+(1.0/3.0))
np.testing.assert_allclose(exec_alu(BinaryOps.MUL, dtypes.float32, (7.0, exec_alu(UnaryOps.RECIP, dtypes.float32, (-3.0,)))), -2-(1.0/3.0))
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIP, dtypes.float32, (3.0,)))), 2+(1.0/3.0))
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIP, dtypes.float32, (-3.0,)))), -2-(1.0/3.0))
def test_recip(self):
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (8,)), 1/8)
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (7,)), 1/7)
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (-3,)), 1/-3)
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (-50,)), 1/-50)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (8,)), 1/8)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (7,)), 1/7)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (-3,)), 1/-3)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (-50,)), 1/-50)
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, ((32+521+3),)), 1/(32+521+3))
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, ((34**2),)), 1/(34**2))
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (10,)), 1/10)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, ((32+521+3),)), 1/(32+521+3))
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, ((34**2),)), 1/(34**2))
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (10,)), 1/10)
def test_bool_cmplt(self):
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, False)), False)
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, True)), False)
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (True, False)), False)
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (True, True)), False)
def test_bool_cmpne(self):
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (True, False)), True)
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (True, True)), False)
self.assertEqual(exec_alu(Ops.CMPNE, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(Ops.CMPNE, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(Ops.CMPNE, dtypes.bool, (True, False)), True)
self.assertEqual(exec_alu(Ops.CMPNE, dtypes.bool, (True, True)), False)
def test_bool_where(self):
self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.bool, (False, False, False)), False)
self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.int, (False, 2, 4)), 4)
np.testing.assert_allclose(exec_alu(TernaryOps.WHERE, dtypes.float, (False, 2.2, 4.5)), 4.5)
self.assertEqual(exec_alu(Ops.WHERE, dtypes.bool, (False, False, False)), False)
self.assertEqual(exec_alu(Ops.WHERE, dtypes.int, (False, 2, 4)), 4)
np.testing.assert_allclose(exec_alu(Ops.WHERE, dtypes.float, (False, 2.2, 4.5)), 4.5)
def test_overflow(self):
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)), 244)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (256, 0)), 0)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (0, -1)), 255)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (0, -1000)), 24)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (250, 250)), 244)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (256, 0)), 0)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (0, -1)), 255)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (0, -1000)), 24)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (127, 0)), 127)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-100, -100)), 56)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-1000, -0)), 24)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-130, -0)), 126)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (127, 0)), 127)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-128, 0)), -128)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-100, -100)), 56)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-1000, -0)), 24)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-130, -0)), 126)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (1, 1)), 2)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (1, 1)), 2)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-128, 0)), -128)
# test that the result is not truncated when truncate_output=False
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250), truncate_output=False), 500)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (250, 250), truncate_output=False), 500)
class TestConstantFolding(unittest.TestCase):
def test_cast_const(self):
@@ -336,8 +336,8 @@ class TestAssembly(unittest.TestCase):
a2 = UOp(Ops.MUL, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].op, BinaryOps.SHL)
self.assertEqual(uops[-2].op, BinaryOps.MUL)
self.assertEqual(uops[-1].op, Ops.SHL)
self.assertEqual(uops[-2].op, Ops.MUL)
def test_bitshift_right(self):
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
@@ -348,8 +348,8 @@ class TestAssembly(unittest.TestCase):
a2 = UOp(Ops.IDIV, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].op, BinaryOps.SHR)
self.assertEqual(uops[-2].op, BinaryOps.IDIV)
self.assertEqual(uops[-1].op, Ops.SHR)
self.assertEqual(uops[-2].op, Ops.IDIV)
class TestUOpMethod(unittest.TestCase):
@unittest.skip("uops lt no longer ordered")

View File

@@ -1,8 +1,7 @@
from typing import Dict, List, Optional
import unittest
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, \
graph_rewrite, contexts, track_rewrites
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, contexts, track_rewrites
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
@track_rewrites()
@@ -35,7 +34,7 @@ class TestViz(unittest.TestCase):
def test_rewrite_twice(self):
pm = PatternMatcher([
(UPat.var("x")+UPat.var("x"), lambda x:x*2),
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(Ops.SHL, UOp.const(dtypes.int, 1))),
])
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
uops = helper_test_viz(a+a, pm)
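
# The viz test drives the classic strength-reduction chain x+x -> x*2 -> x<<1.
# A standalone sketch of the same fixed-point rewriting on a toy AST (plain
# tuples here, not tinygrad UOps):
def rewrite(node):
    op, a, b = node
    if op == "+" and a == b: return ("*", a, 2)
    if op == "*" and b == 2: return ("<<", a, 1)
    return None

n = ("+", "x", "x")
while (r := rewrite(n)) is not None: n = r
assert n == ("<<", "x", 1)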

View File

@@ -1,7 +1,7 @@
import unittest, math
from tinygrad import dtypes
from tinygrad.helpers import all_same
from tinygrad.ops import GroupOp, UOp, Ops, BinaryOps, exec_alu
from tinygrad.ops import GroupOp, UOp, Ops, exec_alu
from tinygrad.codegen.uopgraph import full_graph_rewrite
# Helper function to apply the graph rewrite
@@ -56,7 +56,7 @@ class TestFoldingAndReduction(unittest.TestCase):
const1 = UOp.const(dtypes.int32, 5)
const2 = UOp.const(dtypes.int32, 10)
const3 = UOp.const(dtypes.int32, 20)
optimized_sink = apply_rewrite((const1 + const2 + const3).reduce(BinaryOps.ADD))
optimized_sink = apply_rewrite((const1 + const2 + const3).reduce(Ops.ADD))
expected_sum = 5 + 10 + 20
self.assertEqual(optimized_sink.arg, expected_sum)
@@ -65,14 +65,14 @@ class TestFoldingAndReduction(unittest.TestCase):
const1 = UOp.const(dtypes.int32, 15)
const2 = UOp.const(dtypes.int32, 25)
rng = UOp.range(dtypes.int32, 0, 10, idx=0)
optimized_sink = apply_rewrite((const1 + const2).reduce(BinaryOps.ADD, rng))
optimized_sink = apply_rewrite((const1 + const2).reduce(Ops.ADD, rng))
expected_sum = 10 * (15 + 25)
self.assertEqual(optimized_sink.arg, expected_sum)
@unittest.skip("currently failing")
def test_full_graph_rewrite_range_reduction(self):
simple_range = UOp.range(dtypes.int32, 0, 5, idx=0)
optimized_sink = apply_rewrite(simple_range.reduce(BinaryOps.ADD, simple_range))
optimized_sink = apply_rewrite(simple_range.reduce(Ops.ADD, simple_range))
expected_sum = sum(range(5))
self.assertEqual(optimized_sink.arg, expected_sum)
@@ -80,7 +80,7 @@ class TestFoldingAndReduction(unittest.TestCase):
def test_full_graph_rewrite_simple_reduction_folding(self):
simple_range = UOp.range(dtypes.int32, 0, 4, idx=0)
add_uop = simple_range + UOp.const(dtypes.int32, 1)
optimized_sink = apply_rewrite(add_uop.reduce(BinaryOps.ADD, simple_range))
optimized_sink = apply_rewrite(add_uop.reduce(Ops.ADD, simple_range))
expected_sum = sum(i + 1 for i in range(4))
self.assertEqual(optimized_sink.arg, expected_sum)
@@ -89,7 +89,7 @@ class TestFoldingAndReduction(unittest.TestCase):
outer_range = UOp.range(dtypes.int32, 0, 8, 0)
inner_range = UOp.range(dtypes.int32, 0, 4, 1)
expr = (outer_range * 10) + inner_range
optimized_reduce_uop = apply_rewrite(expr.reduce(BinaryOps.ADD, outer_range, inner_range))
optimized_reduce_uop = apply_rewrite(expr.reduce(Ops.ADD, outer_range, inner_range))
self.assertEqual(optimized_reduce_uop.op, Ops.CONST)
self.assertEqual(optimized_reduce_uop.arg, sum((i * 10) + j for i in range(8) for j in range(4)))
@@ -104,7 +104,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
def test_full_graph_rewrite_division_folding_with_define_var(self):
n_var_uop = UOp.variable('n', 1, 1000)
optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
self.assertEqual(optimized_div_uop.op, BinaryOps.MUL)
self.assertEqual(optimized_div_uop.op, Ops.MUL)
self.assertEqual(optimized_div_uop.src[1].arg, 2)
def test_full_graph_rewrite_complex_mod_div_folding(self):
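
# Expected constants for the folding tests above, checked in plain Python:
# (n*6)//3 reduces exactly to n*2 for any integer n, and the nested range
# reduction has the closed form 40*sum(range(8)) + 8*sum(range(4)) = 1168.
n = 7  # stands in for Variable('n', 1, 1000)
assert (n * 6) // 3 == n * 2
assert sum((i * 10) + j for i in range(8) for j in range(4)) == 1168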

View File

@@ -1,6 +1,6 @@
import unittest, itertools
from tinygrad.dtype import dtypes
from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, UnaryOps, GroupOp # noqa: F401
from tinygrad.ops import Ops, UOp, GroupOp # noqa: F401
from tinygrad.ops import PatternMatcher, UPat
class TestPatternMatcher(unittest.TestCase):
@@ -140,9 +140,9 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c2), None)
# that CONST/ALU -> ALU/CONST rewrite is now instant
"""
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=(UPat(Ops.CONST), UPat(GroupOp.ALU))), lambda x: x)])
c4 = UOp(Ops.ADD, dtypes.float, (c1,c3))
c5 = UOp(Ops.ADD, dtypes.float, (c3,c1))
self.assertEqual(matcher.rewrite(c3), None)
self.assertEqual(matcher.rewrite(c4), c4)
self.assertEqual(matcher.rewrite(c5), None)

View File

@@ -109,7 +109,6 @@ class TestShapeTrackerAddVariable(unittest.TestCase):
vm2 = View(shape=(var_i, var_j, 3), strides=(var_j*3, 3, 1), offset=0, mask=None, contiguous=True)
ShapeTracker((vm1,)) + ShapeTracker((vm2,))
@unittest.skip("two vars not supported")
def test_merge_symbolic_views_2(self):
var_i = Variable('i', 1, 10)
var_j = Variable('j', 1, 10)

View File

@@ -455,6 +455,13 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
def test_where_removal(self):
cond = Variable("a", 0, 3).lt(2)
u1, u0 = cond.ufix(1), cond.ufix(0)
self.helper_test_variable(cond, 0, 1, "(a<2)")
self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")
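
# The where-removal rule uses that, on a 0/1-valued condition, where(cond, 1, 0)
# is the identity, so repeated applications collapse to the bare comparison.
# A plain-Python sketch over the variable's range:
def where(c, t, f): return t if c else f
for a in range(4):
    cond = int(a < 2)
    once = where(cond, 1, 0)
    twice = where(once, 1, 0)
    assert once == twice == cond
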
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
MIN, MAX = 0, 10