Remove UnaryOps, BinaryOps, TernaryOps, MetaOps [pr] (#7725)

* remove unaryops

* remove ternaryops

* remove metaops

* hotfix

* remove binaryops

* hotfix: test_pattern_matcher

---------

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
ignaciosica
2024-11-16 09:56:56 -03:00
committed by GitHub
parent 22da31b223
commit 597a239e28
33 changed files with 473 additions and 478 deletions

View File

@@ -39,7 +39,7 @@ DEVICE = "CLANG" # NOTE: you can change this!
import struct
from tinygrad.dtype import dtypes
from tinygrad.device import Buffer, Device
from tinygrad.ops import BinaryOps, MetaOps, UOp, Ops
from tinygrad.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
# allocate some buffers + load in values
@@ -81,15 +81,15 @@ from tinygrad.engine.realize import run_schedule
from tinygrad.engine.schedule import create_schedule
# allocate some values + load in values
a = LazyBuffer.metaop(MetaOps.EMPTY, (1,), dtypes.int32, DEVICE)
b = LazyBuffer.metaop(MetaOps.EMPTY, (1,), dtypes.int32, DEVICE)
a = LazyBuffer.metaop(Ops.EMPTY, (1,), dtypes.int32, DEVICE)
b = LazyBuffer.metaop(Ops.EMPTY, (1,), dtypes.int32, DEVICE)
a.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
b.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
del a.srcs
del b.srcs
# describe the computation
out = a.alu(BinaryOps.ADD, b)
out = a.alu(Ops.ADD, b)
# schedule the computation as a list of kernels
sched = create_schedule([out])

View File

@@ -8,7 +8,7 @@ from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCou
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_kernel, run_schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.ops import MetaOps, Ops
from tinygrad.ops import Ops
TIMING = getenv("TIMING")

View File

@@ -2,7 +2,7 @@
from typing import Tuple
from tinygrad import Variable
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.ops import UOp, Ops, KernelInfo, TernaryOps, BinaryOps, UnaryOps, MetaOps
from tinygrad.ops import UOp, Ops, KernelInfo
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

View File

@@ -9,7 +9,7 @@ from tinygrad.engine.realize import get_runner, CompiledRunner
from test.external.fuzz_linearizer import get_fuzz_rawbufs
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer
from tinygrad.ops import LazyOp, Ops, ReduceOps, BufferOps, MemBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -26,12 +26,12 @@ class TestNV(unittest.TestCase):
TestNV.addr = struct.pack("QQ", TestNV.b.lazydata.buffer._buf.va_addr, TestNV.a.lazydata.buffer._buf.va_addr)
def test_oor_kernels(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
opts = [Opt(op=OptOps.TC, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2)] # noqa: E501
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["NV"])
def test_error_on_huge_dims(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2)] # noqa: E501
with self.assertRaises(RuntimeError) as cm:
lin = Kernel(ast)
@@ -43,7 +43,7 @@ class TestNV(unittest.TestCase):
def test_buf4_usage(self):
TestNV.along = Tensor([105615], device="NV").realize()
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SIN, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.SIN, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
temp_runner = get_runner(TestNV.d0.dname, (ast,))
temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={})
val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]

View File

@@ -2,7 +2,7 @@
import unittest
from tinygrad import Device
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.ops import UOp, Ops
from tinygrad.engine.search import Opt, OptOps
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
@@ -20,7 +20,7 @@ class TestOpenpilotValidhack(unittest.TestCase):
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MAX, dtypes.float, arg=None, src=(
x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8, 9, 10)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 8, 9, 10)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(

View File

@@ -25,7 +25,7 @@ from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing
from tinygrad.ops import UnaryOps, UOp, Ops
from tinygrad.ops import UOp, Ops
from tinygrad.device import is_dtype_supported
def on_linearizer_will_run(): pass
@@ -252,7 +252,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
def _is_simple(lin: Kernel) -> bool:
if len(lin.ast.src) > 1: return False
ast:UOp = lin.ast.src[0]
if ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
if ast.src[0].op is Ops.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
return False
if __name__ == "__main__":

View File

@@ -6,7 +6,7 @@ from tinygrad.engine.realize import capturing, lower_schedule_item
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.schedule import LBScheduleItem, _graph_schedule, ScheduleItem
from tinygrad.ops import MetaOps
from tinygrad.ops import Ops
from tinygrad.tensor import Tensor, _to_np_dtype
ctx_vars = { MULTIOUTPUT: (0, 1) }
@@ -33,7 +33,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
for lsi in ts:
for out in lsi.outputs:
# freeze assign state before exec
if out.op is MetaOps.ASSIGN:
if out.op is Ops.ASSIGN:
prerealized[out] = out.buffer.as_buffer()
assign_targets[out.srcs[1]] = out
for x in lsi.inputs:
@@ -50,9 +50,9 @@ def fuzz_schedule(outs:List[LazyBuffer]):
rawbufs: Dict[LazyBuffer, Buffer] = {}
for lsi in ts:
for out in lsi.outputs:
base = rawbufs[lsi.inputs[0]].base if out.op is MetaOps.BUFFER_VIEW else None
base = rawbufs[lsi.inputs[0]].base if out.op is Ops.BUFFER_VIEW else None
rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype, base=base)
if out.op is MetaOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
if out.op is Ops.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
for x in lsi.inputs:
if x not in rawbufs:
# override the assign_target after ASSIGN

View File

@@ -42,9 +42,9 @@ def gt(expr, rng=None):
return expr > rng, rng
# NOTE: you have to replace these for this test to pass
from tinygrad.ops import python_alu, BinaryOps
python_alu[BinaryOps.MOD] = lambda x,y: x%y
python_alu[BinaryOps.IDIV] = lambda x,y: x//y
from tinygrad.ops import python_alu, Ops
python_alu[Ops.MOD] = lambda x,y: x%y
python_alu[Ops.IDIV] = lambda x,y: x//y
if __name__ == "__main__":
ops = [add_v, div, mul, add_num, mod]

View File

@@ -154,7 +154,7 @@ class TestReduceOpsConstFolding(unittest.TestCase):
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).sum())
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).sum().numpy(), 4)
# NOTE: cannot just count the non-padded area because some UnaryOps f do not have f(0) = 0.
# NOTE: cannot just count the non-padded area because some Ops f do not have f(0) = 0.
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).exp().sum())
np.testing.assert_allclose(Tensor.ones(4).pad(((1, 1),)).exp().sum().numpy(), 4 * math.e + 2)

View File

@@ -3,7 +3,7 @@ import numpy as np
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.ops import Ops
from tinygrad.engine.lazy import LazyBuffer, MetaOps
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.schedule import create_schedule
class TestLazyBuffer(unittest.TestCase):
@@ -95,24 +95,24 @@ class TestReduceOp(unittest.TestCase):
class TestView(unittest.TestCase):
def test_all_masked_out(self):
# start with non CONST MetaOps
# start with non CONST Ops
a = Tensor.rand(10, 10)
assert a.lazydata.base.op is not MetaOps.CONST
assert a.lazydata.base.op is not Ops.CONST
# all masked out, degrades to const 0
b = a.pad(((0, 10), None))[10:]
assert b.shape == (10, 10)
assert b.lazydata.base.op is MetaOps.CONST and b.lazydata.base.arg == 0
assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0
# mask out dim = 1 works too
b = a.pad((None, (0, 10)))[:, 10:]
assert b.shape == (10, 10)
assert b.lazydata.base.op is MetaOps.CONST and b.lazydata.base.arg == 0
assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0
# partial masked out does not degrade into CONST
b = a.pad(((0, 5), None))[5:]
assert b.shape == (10, 10)
assert b.lazydata.base.op is not MetaOps.CONST
assert b.lazydata.base.op is not Ops.CONST
if __name__ == "__main__":
unittest.main()

View File

@@ -6,7 +6,7 @@ from dataclasses import replace
from test.helpers import ast_const
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
from tinygrad.codegen.lowerer import get_grouped_dims
from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps, UnaryOps, GroupOp
from tinygrad.ops import UOp, Ops, GroupOp
from tinygrad.device import Device, Buffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -109,10 +109,10 @@ class TestLinearizer(unittest.TestCase):
st_x = x.lazydata.st
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (1,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (1,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop()))
diff = second_x + first_reduce*ast_const(dtypes.float, -1, (32, 1))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (0,)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (0,)))
store = UOp(Ops.STORE, dtypes.void, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [
@@ -145,10 +145,10 @@ class TestLinearizer(unittest.TestCase):
st_x = x.lazydata.st
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop()))
diff = second_x + first_reduce*ast_const(dtypes.float, -1, (27, 32, 1, 5))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [
@@ -207,13 +207,13 @@ class TestLinearizer(unittest.TestCase):
x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 32, 32, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (2,)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (2,)))
third_x = UOp(Ops.LOAD, dtypes.float, (g3, x2.lazydata.st.reshape((27, 32, 1, 1, 5)).to_uop()))
mul = (third_x*second_reduce)
third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (BinaryOps.ADD, (1,)))
third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 1, 5)).to_uop(), third_reduce))
sink = UOp(Ops.SINK, src=(store,))
wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5)
@@ -234,11 +234,11 @@ class TestLinearizer(unittest.TestCase):
st = x.lazydata.st
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2, 5)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2, 5)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 32, 1, 8, 16, 1)).to_uop()))
neg_first_reduce = first_reduce * ast_const(dtypes.float, -1, (8, 32, 1, 8, 16, 1))
squares = (second_x+neg_first_reduce)
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1, 4)))
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1, 4)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((8, 1, 1, 8, 1, 1)).to_uop(), squares_sum,))
sink = UOp(Ops.SINK, src=(store,))
wanna_output = (x.numpy()-x.numpy().sum(axis=(1,3), keepdims=True)).sum(axis=(1,3)).reshape((8,1,1,8,1,1))
@@ -285,10 +285,10 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [
@@ -317,11 +317,11 @@ class TestLinearizer(unittest.TestCase):
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
first_x_p = UOp(Ops.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(UnaryOps.EXP2),), (BinaryOps.ADD, (2,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(Ops.EXP2),), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1)).to_uop()))
diff = (second_x+(first_reduce + first_reduce_p)*ast_const(dtypes.float, -1, (4, 32, 1)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4, 1, 1)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [
@@ -352,10 +352,10 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
second_out = second_reduce * ast_const(dtypes.float, 1/15, (27, 1, 1, 5))
store1 = UOp(Ops.STORE, src=(g1, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_out))
@@ -375,10 +375,10 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
store1 = UOp(Ops.STORE, src=(g1, ShapeTracker(views=(View(shape=(27,15,1,5), strides=(5,0,1,1), offset=0, mask=None, contiguous=False),)).to_uop(), first_reduce)) # noqa: E501
wanna_output0 = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
@@ -399,10 +399,10 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [[Opt(OptOps.UNROLL, 0, 3), Opt(OptOps.UNROLL, 0, 3)]]
@@ -415,10 +415,10 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [[Opt(OptOps.UPCAST, 0, 3)]]
@@ -434,10 +434,10 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 12, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
sink = UOp(Ops.SINK, src=(store,))
opts = [[Opt(OptOps.GROUPTOP, 0, 3), Opt(OptOps.GROUPTOP, 1, 3)]]
@@ -450,13 +450,13 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
squares = (second_x+neg_mean)*(second_x+neg_mean)
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
std = variance.alu(UnaryOps.SQRT)
std = variance.alu(Ops.SQRT)
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std))
sink = UOp(Ops.SINK, src=(store,))
wanna_output = x.numpy().std(axis=2, ddof=0).reshape((15,25,1,1))
@@ -468,13 +468,13 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).to_uop()))
squares = (second_x+neg_mean)*(second_x+neg_mean)
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1,)))
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1,)))
variance = squares_sum * ast_const(dtypes.float, 0.04, (15, 1, 1, 35))
std = variance.alu(UnaryOps.SQRT)
std = variance.alu(Ops.SQRT)
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 1, 1, 35)).to_uop(), std))
sink = UOp(Ops.SINK, src=(store,))
wanna_output = x.numpy().std(axis=1, ddof=0).reshape((15,1,1,35))
@@ -488,13 +488,13 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
squares = (second_x+neg_mean)*(second_x+neg_mean)
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
std = variance.alu(UnaryOps.SQRT)
std = variance.alu(Ops.SQRT)
store_mean = UOp(Ops.STORE, src=(g1, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), neg_mean))
store_std = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std))
sink = UOp(Ops.SINK, src=(store_std, store_mean))
@@ -511,13 +511,13 @@ class TestLinearizer(unittest.TestCase):
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
# push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -0.03125, (3, 27, 32, 1))
# store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 32, 1)).to_uop(), mean))
# verify_lazyop(store)
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 32, 1)).to_uop()))
squares = (second_x+neg_mean)*(second_x+neg_mean)
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
variance = squares_sum * ast_const(dtypes.float, 0.03125, (3, 27, 1, 1))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 1, 1)).to_uop(), variance))
sink = UOp(Ops.SINK, src=(store,))
@@ -532,13 +532,13 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.rand(4, 32).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop()))
max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.MAX, (2,)))
max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.MAX, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop()))
centered_x = second_x+max_x*ast_const(dtypes.float, -1, (4, 32, 1))
exp_x = centered_x.alu(UnaryOps.EXP2)
sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (BinaryOps.ADD, (1,)))
# y = exp_x * sum_exp_x.alu(UnaryOps.RECIP) # kernels cannot do a return to full shape
recip_sum_exp_x = sum_exp_x.alu(UnaryOps.RECIP)
exp_x = centered_x.alu(Ops.EXP2)
sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (Ops.ADD, (1,)))
# y = exp_x * sum_exp_x.alu(Ops.RECIP) # kernels cannot do a return to full shape
recip_sum_exp_x = sum_exp_x.alu(Ops.RECIP)
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4,1,1)).to_uop(), recip_sum_exp_x))
sink = UOp(Ops.SINK, src=(store,))
expected = 1/np.exp2(x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1, keepdims=True).reshape(4,1,1)
@@ -556,7 +556,7 @@ class TestLinearizer(unittest.TestCase):
View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
arange_axis = (3,)
arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (Ops.ADD, arange_axis))
output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
out = arange+ast_const(dtypes.int, -1, output_shape)
store = UOp(Ops.STORE, src=(UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out))
@@ -573,7 +573,7 @@ class TestLinearizer(unittest.TestCase):
# TODO: do this arange broadcast in the scheduler
arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
arange_axis = (3,)
arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (Ops.ADD, arange_axis))
arange_out_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
arange = arange+ast_const(dtypes.int, -1, arange_out_shape)
# p2: the indexing
@@ -581,10 +581,10 @@ class TestLinearizer(unittest.TestCase):
data1 = (g1, ShapeTracker.from_shape(dataset.shape).reshape((1, 16384, 256, 1)).expand(arange_out_shape).to_uop())
idxs = Tensor([0,3,5,6]).realize()
data2 = (g2, ShapeTracker.from_shape((4,)+(1,)*(len(arange_out_shape)-1)).expand(arange_out_shape).to_uop())
arange_eq = arange.alu(BinaryOps.CMPNE, UOp(Ops.LOAD, dtypes.int, data2)).alu(BinaryOps.CMPNE, ast_const(dtypes.bool, True, arange_out_shape))
arange_eq = arange.alu(Ops.CMPNE, UOp(Ops.LOAD, dtypes.int, data2)).alu(Ops.CMPNE, ast_const(dtypes.bool, True, arange_out_shape))
reduce_input = UOp(Ops.LOAD, dataset.dtype, data1)*UOp(Ops.CAST, dataset.dtype.scalar(), src=(arange_eq,))
out_axis = (1,)
out = UOp(Ops.REDUCE_AXIS, reduce_input.dtype, (reduce_input,), (BinaryOps.ADD, out_axis))
out = UOp(Ops.REDUCE_AXIS, reduce_input.dtype, (reduce_input,), (Ops.ADD, out_axis))
output_shape = tuple(1 if i in out_axis else s for i,s in enumerate(arange_out_shape))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape(output_shape).to_uop(), out))
sink = UOp(Ops.SINK, src=(store,))
@@ -605,7 +605,7 @@ class TestLinearizer(unittest.TestCase):
ast_const(dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), val=10),
UOp(Ops.MUL, dtypes.int, arg=None, src=(
ast_const(dtypes.int, -1, (1, 20, 1)),
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.MAX, (0,)), src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
@@ -618,7 +618,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=(
ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False)))),)), # noqa E501
ast_const(dtypes.int, 10, (10, 20, 1)))),)),)),)),)),
ast_const(dtypes.int, -1, (1, 20, 1)),)),)),))
@@ -637,7 +637,7 @@ class TestLinearizer(unittest.TestCase):
ast_const(dtypes.int, 200, (1, 1)),
UOp(Ops.MUL, dtypes.int, arg=None, src=(
ast_const(dtypes.int, -1, (1, 1)),
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.MAX, (0,)), src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
@@ -650,7 +650,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
ast_const(dtypes.bool, True, (200, 1)),)),)),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
ast_const(dtypes.int, 200, (200, 1)),)),)),)),)),)),
ast_const(dtypes.int, -1, (1, 1)),)),)),))
@@ -672,16 +672,16 @@ class TestLinearizer(unittest.TestCase):
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (1,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(BinaryOps.ADD, (0,)))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (1,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(Ops.ADD, (0,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
sink = UOp(Ops.SINK, src=(store,))
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=0, keepdims=True)).sum(axis=0).reshape(1,1,N)], opts=opts)
x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (2,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.ADD, (1,)))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (2,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
sink = UOp(Ops.SINK, src=(store,))
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(N,1,1)], opts=opts)
@@ -699,16 +699,16 @@ class TestLinearizer(unittest.TestCase):
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (1,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (BinaryOps.MAX, (0,)))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (1,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (Ops.MAX, (0,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
sink = UOp(Ops.SINK, src=(store,))
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=0, keepdims=True)).max(axis=0).reshape(1,1,N)], opts=opts)
x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (2,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.MAX, (1,)))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (2,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.MAX, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
sink = UOp(Ops.SINK, src=(store,))
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=1, keepdims=True)).max(axis=1).reshape(N,1,1)], opts=opts)
@@ -735,7 +735,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.5*N, (N, 1, 1)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
@@ -743,7 +743,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.75*N, (N, N, 1)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
ld0.to_uop(),)),)),)),
@@ -768,7 +768,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.5*N, (1, 1, N)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -776,7 +776,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.75*N, (N, 1, N)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
ld0.to_uop(),)),)),)),
@@ -804,7 +804,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.5*N, (1, 1, 1, 1)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
@@ -812,7 +812,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 3)), src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
@@ -831,7 +831,7 @@ class TestLinearizer(unittest.TestCase):
def test_end_local(self):
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=i) for i in range(2)]
load = UOp(Ops.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop()))
reduce = UOp(Ops.REDUCE_AXIS, dtypes.int, (load,), (BinaryOps.ADD, (0,)))
reduce = UOp(Ops.REDUCE_AXIS, dtypes.int, (load,), (Ops.ADD, (0,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce))
sink = UOp(Ops.SINK, src=(store,))
load_t = Tensor.full(load.st_arg.shape, 1).contiguous().realize()
@@ -1219,20 +1219,20 @@ class TestLinearizer(unittest.TestCase):
assert len(sched) == 1
lin = Kernel(sched[0].ast)
assert sum(u.op in {UnaryOps.RECIP, BinaryOps.FDIV} for u in lin.linearize().uops) == max_ops, msg
assert sum(u.op in {Ops.RECIP, Ops.FDIV} for u in lin.linearize().uops) == max_ops, msg
a = Tensor.empty((4,4))
b = Tensor.empty((4,4))
d = Tensor.empty((4,4))
c = (a*b)/b
helper(c, "found UnaryOps.RECIP in (a*b)/b operation")
helper(c, "found Ops.RECIP in (a*b)/b operation")
c = a/a
helper(c, "found UnaryOps.RECIP in (a/a) operation")
helper(c, "found Ops.RECIP in (a/a) operation")
c = (a/b)/d
helper(c, "found multiple UnaryOps.RECIP in (a/b)/d operation", 1)
helper(c, "found multiple Ops.RECIP in (a/b)/d operation", 1)
def test_sum_collapse(self):
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
@@ -1260,7 +1260,7 @@ class TestLinearizer(unittest.TestCase):
lin = Kernel(sched_copy[-1].ast)
lin.hand_coded_optimizations()
lin.linearize()
assert not any(u.op == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
assert not any(u.op == Ops.WHERE for u in lin.uops), "found where where where should be folded"
def test_phi_simplification(self):
def helper(t, max_ops=0):
@@ -1272,7 +1272,7 @@ class TestLinearizer(unittest.TestCase):
assert len(set([u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}])) == 1, "has either specials or ranges, not both"
assert len([u for u in uops if u.op is Ops.ASSIGN]) == 0, "ASSIGN should have been simplified"
# TODO: once uops track min/max this will be fixed
#assert len([u for u in uops if u.op is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
#assert len([u for u in uops if u.op is Ops.MAX]) <= max_ops, "no unnecessary MAX ops"
helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2)
helper(Tensor.arange(-1, -100, -5), max_ops=2)
@@ -1602,7 +1602,7 @@ class TestFloat4(unittest.TestCase):
UOp(Ops.STORE, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.CAST, dtypes.float, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, src=(
@@ -1632,7 +1632,7 @@ class TestFloat4(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
@@ -1662,7 +1662,7 @@ class TestFloat4(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501
UOp(Ops.CAST, dtypes.half, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (4, 6)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=(
UOp(Ops.CAST, dtypes.float, src=(
UOp(Ops.LOAD, dtypes.half, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
@@ -1949,7 +1949,7 @@ class TestKernelOpts(unittest.TestCase):
UOp(Ops.STORE, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
@@ -2138,7 +2138,7 @@ class TestKernelOpts(unittest.TestCase):
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
ld0 = UOp(Ops.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
ld1 = UOp(Ops.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
store = UOp(Ops.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(Ops.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (BinaryOps.ADD, (0, 2, 4, 6)),))) # noqa: E501
store = UOp(Ops.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(Ops.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (Ops.ADD, (0, 2, 4, 6)),))) # noqa: E501
sink = UOp(Ops.SINK, src=(store,))
data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()

View File

@@ -5,7 +5,7 @@
import unittest
from test.helpers import ast_const
from tinygrad import Device, dtypes
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.ops import UOp, Ops
from tinygrad.helpers import getenv
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.engine.search import Opt, OptOps
@@ -21,7 +21,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.MAX, dtypes.half, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -64,7 +64,7 @@ class TestLinearizerDumb(unittest.TestCase):
ast_const(dtypes.bool, True, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, -1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, 1000, st_src=(
@@ -75,7 +75,7 @@ class TestLinearizerDumb(unittest.TestCase):
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
print(prg.src)
assert prg.uops is not None and not any(uop.op is BinaryOps.MAX for uop in prg.uops), "leftover MAX"
assert prg.uops is not None and not any(uop.op is Ops.MAX for uop in prg.uops), "leftover MAX"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
def test_expander_new_srcs(self):
@@ -83,7 +83,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
@@ -105,14 +105,14 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, -1, st_src=(
@@ -136,7 +136,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
@@ -168,7 +168,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
@@ -200,7 +200,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),

View File

@@ -3,7 +3,7 @@ import unittest, random
import numpy as np
from tinygrad.codegen.kernel import Kernel, KernelOptError
from tinygrad.device import is_dtype_supported
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.ops import UOp, Ops
from tinygrad.engine.search import Opt, OptOps
from tinygrad import Device, dtypes, Tensor
from tinygrad.helpers import CI
@@ -47,7 +47,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 16), strides=(16, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
@@ -64,7 +64,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 37, 9, 1, 1), strides=(666, 333, 9, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (4, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (4, 5)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
@@ -76,7 +76,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 1), strides=(128, 16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),))
@@ -89,7 +89,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
@@ -111,7 +111,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, -1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, 10, st_src=(
@@ -125,7 +125,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 1, 34, 1, 34), strides=(36992, 1156, 0, 34, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 4)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 4)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 6, 8, 4, 6, 8, 4), strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 8), (0, 1), (0, 6), (0, 8), (0, 1)), contiguous=False), View(shape=(512, 32, 6, 35, 6, 35), strides=(1179648, 36864, 6144, 192, 32, 1), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 32), (0, 6), (0, 32)), contiguous=False), View(shape=(512, 32, 238, 238), strides=(1411200, 44100, 210, 1), offset=0, mask=((0, 512), (0, 32), (0, 210), (0, 210)), contiguous=False), View(shape=(512, 32, 7, 34, 7, 34), strides=(1812608, 56644, 8092, 238, 34, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),))
@@ -142,7 +142,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
x9:=UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
@@ -166,7 +166,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 0, 0, 4500, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -183,7 +183,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
@@ -202,7 +202,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
@@ -277,7 +277,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
@@ -299,7 +299,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 8)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=(
x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
@@ -313,7 +313,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
x6,)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 8)), src=(
x5,)),)),)),)),))
opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)]
helper_test_lin(Kernel(ast), opts, failed_platforms=[])
@@ -325,7 +325,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
@@ -344,7 +344,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
@@ -370,7 +370,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -405,7 +405,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),
@@ -420,7 +420,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 1, 28, 28, 1, 1), strides=(31360, 0, 784, 0, 28, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -442,7 +442,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
@@ -462,7 +462,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 9, 7, 3, 3), strides=(2268, 0, 567, 0, 63, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -511,7 +511,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
@@ -624,7 +624,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1025, 2047), strides=(0, 0), offset=0, mask=((0, 1025), (1023, 2047)), contiguous=False), View(shape=(1024, 1024), strides=(1, 2048), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, -1, st_src=(
@@ -639,7 +639,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(129, 255), strides=(0, 0), offset=0, mask=((0, 129), (127, 255)), contiguous=False), View(shape=(128, 128), strides=(1, 256), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, -1, st_src=(
@@ -678,7 +678,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.MAX, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.MAX, (3,)), src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),))
@@ -731,7 +731,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -749,7 +749,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 1, 1, 1), strides=(11532, 0, 961, 31, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -767,7 +767,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.EXP2, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
@@ -791,7 +791,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 1, 1, 1), strides=(50176, 0, 196, 14, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -809,7 +809,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
x5:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -851,7 +851,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(180, 0, 30, 3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -875,7 +875,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.uchar, arg=None, src=(
UOp(Ops.ADD, dtypes.uint, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.uint, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.uint, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.CAST, dtypes.uint, arg=None, src=(
ast_const(dtypes.uchar, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))), src=()),)),)),)),
@@ -895,7 +895,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
@@ -920,7 +920,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 32, 1, 1, 1, 5, 5, 256), strides=(0, 0, 6400, 0, 0, 0, 1280, 256, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 3, 4)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 3, 4)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
@@ -943,7 +943,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
@@ -969,7 +969,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, -1, st_src=(
@@ -987,7 +987,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 1, 1, 1), strides=(100352, 0, 784, 28, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -1007,7 +1007,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
@@ -1021,7 +1021,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
@@ -1035,7 +1035,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
@@ -1052,7 +1052,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 1, 1, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 3)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -1065,7 +1065,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (4,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 3), strides=(0, 0), offset=0, mask=((0, 3), (1, 3)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 1, 0, 4), offset=0, mask=((0, 2), (0, 3), (0, 2), (0, 3), (0, 2)), contiguous=False))), src=()),)),)),
x19:=ast_const(dtypes.int, -1, st_src=(
@@ -1078,7 +1078,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=3, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (4,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5), strides=(0, 0), offset=0, mask=((0, 4), (2, 5)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 0, 1, 6), offset=0, mask=None, contiguous=False))), src=()),)),)),
x19,)),)),
@@ -1093,7 +1093,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.bool, arg=None, src=(
@@ -1127,7 +1127,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, -1, st_src=(
@@ -1142,7 +1142,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 1, 1, 256, 1, 1, 256), strides=(0, 0, 65536, 0, 0, 256, 0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3, 4)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3, 4)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -1160,7 +1160,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 1), strides=(6, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -1178,7 +1178,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.bool, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.bool, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.MUL, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.bool, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=1, src=()),
@@ -1220,7 +1220,7 @@ class TestLinearizerFailures(unittest.TestCase):
x9,)),
UOp(Ops.ADD, dtypes.half, arg=None, src=(
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -1249,7 +1249,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -1267,7 +1267,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.uchar, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.uchar, arg=(Ops.ADD, (1,)), src=(
UOp(Ops.MUL, dtypes.uchar, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()),
@@ -1279,7 +1279,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
UOp(Ops.VALID, dtypes.bool, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(50001, 99999), strides=(0, 0), offset=0, mask=((0, 50001), (49999, 99999)), contiguous=False), View(shape=(1024, 50000, 50000), strides=(0, 1, 100000), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -1306,7 +1306,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -1326,7 +1326,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(W, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(

View File

@@ -8,7 +8,7 @@ from tinygrad.engine.search import Opt, OptOps
from tinygrad.engine.search import time_linearizer, bufs_from_lin
# stuff needed to unpack a kernel
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -33,7 +33,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -68,7 +68,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -85,7 +85,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -102,7 +102,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -119,7 +119,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -136,7 +136,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -153,7 +153,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -175,7 +175,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase):
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop()
prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2))
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5)))))
ast = UOp(Ops.SINK, src=(store,))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
_test_overflow(ast, opts)
@@ -187,7 +187,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase):
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop()
prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2))
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5)))))
ast = UOp(Ops.SINK, src=(store,))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=5, amt=2)]
_test_overflow(ast, opts)

View File

@@ -1,7 +1,7 @@
import unittest, functools, random
from typing import List
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
from tinygrad.ops import MetaOps, BinaryOps, Ops
from tinygrad.ops import Ops
from tinygrad.helpers import CI, getenv, prod, Context
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.schedule import create_schedule
@@ -481,7 +481,7 @@ class TestMultiTensor(unittest.TestCase):
for p in get_parameters(bn): p.shard_(devices_4).realize()
out = bn(t)
scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices_4 and sched.ast.op is not MetaOps.COPY]
scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices_4 and sched.ast.op is not Ops.COPY]
assert set(out.device for sched in scheds for out in sched.outputs) == set(devices_4), "should have ast on each shard device"
asts = [sched.ast for sched in scheds]
assert len(asts)
@@ -640,21 +640,21 @@ class TestMultiTensor(unittest.TestCase):
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is Ops.STORE
assert ast.src[2].op is BinaryOps.ADD
assert ast.src[2].op is Ops.ADD
assert ast.src[2].src[0].op is Ops.LOAD
assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1
t = 2 * t
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is Ops.STORE
assert ast.src[2].op is BinaryOps.MUL
assert ast.src[2].op is Ops.MUL
assert ast.src[2].src[0].src[1].op is Ops.CONST and ast.src[2].src[0].src[1].arg == 2
assert ast.src[2].src[1].op is Ops.LOAD
t = t + t.full_like(3)
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is Ops.STORE
assert ast.src[2].op is BinaryOps.ADD
assert ast.src[2].op is Ops.ADD
assert ast.src[2].src[0].op is Ops.LOAD
assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3

View File

@@ -8,7 +8,7 @@ from tinygrad.dtype import dtypes
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import dedup, flatten, prod
from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.ops import BinaryOps, UOp, Ops
from tinygrad.ops import UOp, Ops
from tinygrad.renderer import Program
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.lazy import LazyBuffer
@@ -34,7 +34,7 @@ class TestCStyleFailures(unittest.TestCase):
b = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(Ops.LOAD, dtypes.int, (b.index(idx),))
alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
alu = ld.alu(Ops.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
store = UOp.store(a.index(idx), alu)
sink = UOp(Ops.SINK, dtypes.void, (store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))

View File

@@ -13,7 +13,7 @@ from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, view_left
@@ -1042,7 +1042,7 @@ class TestSchedule(unittest.TestCase):
b = r.sum(0) * 4
c = r.sum(1) * 2
schedule = check_schedule([b, c], 3)
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
# multireduce spec
def test_multireduce_simple_chase(self):
@@ -1053,7 +1053,7 @@ class TestSchedule(unittest.TestCase):
c = r.sum(1) + 12
np_r = (a.numpy() + (a.numpy().sum(0) + 6)).sum(0) * 2
# schedule = check_schedule([b,c], 3)
# self.assertIs(schedule[0].ast[0].src[0].arg, BinaryOps.MUL)
# self.assertIs(schedule[0].ast[0].src[0].arg, Ops.MUL)
schedule = check_schedule([b,c], 4)
run_schedule(schedule)
np.testing.assert_allclose(b.numpy(), np_r.sum(0) + 8, atol=1e-4, rtol=1e-4)
@@ -1066,7 +1066,7 @@ class TestSchedule(unittest.TestCase):
d = r.T * 4
e = r * d
schedule = check_schedule([d, e], 3)
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
# multireduce spec
def test_multireduce_push_permute_chase(self):
@@ -1077,7 +1077,7 @@ class TestSchedule(unittest.TestCase):
d = r.T * 4
e = r * (d + a).sum(2)
schedule = check_schedule([d, e], 3) # make sure it doesn't fuse
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
run_schedule(schedule)
np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4)
@@ -1089,7 +1089,7 @@ class TestSchedule(unittest.TestCase):
r = a.sum(1) + c
d = r[:4] * b
schedule = check_schedule(d, 2)
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
# multireduce spec
def test_multireduce_push_shrink_chase(self):
@@ -1102,7 +1102,7 @@ class TestSchedule(unittest.TestCase):
out = r[:4] * b + d.sum(1)[:4]
# schedule = check_schedule(out, 2)
schedule = check_schedule(out, 3)
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
run_schedule(schedule)
np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4)
@@ -1287,16 +1287,16 @@ class TestSchedule(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
def test_bitcast_subbufer(self):
x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
a = x.alu(UnaryOps.EXP2).cast(dtypes.int32, True, allow_buffer_view=True)
a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=True)
b = x.cast(dtypes.int32, True, allow_buffer_view=True)
b = a.alu(BinaryOps.ADD, b)
b = a.alu(Ops.ADD, b)
check_schedule(b, 2) # this should fuse when it makes sense
def test_bitcast_disable_subbufer(self):
x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
a = x.alu(UnaryOps.EXP2).cast(dtypes.int32, True, allow_buffer_view=False)
a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=False)
b = x.cast(dtypes.int32, True, allow_buffer_view=False)
b = a.alu(BinaryOps.ADD, b)
b = a.alu(Ops.ADD, b)
check_schedule(b, 1)
def test_reduceop_reshape_dont_push(self):
@@ -1530,7 +1530,7 @@ class TestIndexing(unittest.TestCase):
def test_arange_view_op(self):
a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).contiguous()
assert isinstance(a.lazydata, LazyBuffer)
self.assertIs(a.lazydata.base.op, MetaOps.BUFFER_VIEW)
self.assertIs(a.lazydata.base.op, Ops.BUFFER_VIEW)
self.check_schedule(a, 1)
np.testing.assert_equal(a.numpy(), [[4, 5]])
@@ -1538,7 +1538,7 @@ class TestIndexing(unittest.TestCase):
def test_arange_shrink_copy(self):
a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).to("CLANG")
assert isinstance(a.lazydata, LazyBuffer)
self.assertIs(a.lazydata.base.op, MetaOps.COPY)
self.assertIs(a.lazydata.base.op, Ops.COPY)
self.check_schedule(a, 1)
np.testing.assert_equal(a.numpy(), [[4, 5]])
@@ -1546,8 +1546,8 @@ class TestIndexing(unittest.TestCase):
def test_arange_expand_copy(self):
a = Tensor.arange(4).reshape(2, 2, 1).expand(2, 2, 2).to("CLANG")
assert isinstance(a.lazydata, LazyBuffer)
self.assertIs(a.lazydata.base.op, MetaOps.COPY)
self.assertIs(a.lazydata.base.srcs[0].base.op, BinaryOps.ADD)
self.assertIs(a.lazydata.base.op, Ops.COPY)
self.assertIs(a.lazydata.base.srcs[0].base.op, Ops.ADD)
self.check_schedule(a, 1)
np.testing.assert_equal(a.numpy(), [[[0, 0], [1, 1]], [[2, 2], [3, 3]]])
@@ -1643,7 +1643,7 @@ class TestIndexing(unittest.TestCase):
def test_simple_store_reshape(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
r = r + 2
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
@@ -1655,7 +1655,7 @@ class TestIndexing(unittest.TestCase):
def test_no_reshape_reduceop(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
verify_ast(sink)
@@ -1698,7 +1698,7 @@ class TestSwizzle(unittest.TestCase):
# LazyBuffer to pre-rewrite AST
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0,)))
swizzle_r = UOp(Ops.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
alu = swizzle_r+1
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
@@ -1720,9 +1720,9 @@ class TestSwizzle(unittest.TestCase):
# LazyBuffer to pre-rewrite AST
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (Ops.ADD, (0,)))
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,)))
r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (Ops.ADD, (0,)))
alu = UOp(Ops.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(Ops.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+2,),),)) # noqa: E501
# graph rewrite
@@ -1736,7 +1736,7 @@ class TestSwizzle(unittest.TestCase):
def test_swizzle_rewrite_alt(self):
swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501

View File

@@ -3,7 +3,7 @@ import unittest
from test.helpers import ast_const
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.ops import UOp, Ops
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
from tinygrad.device import Device, Buffer
@@ -107,7 +107,7 @@ class TestBEAM(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (1,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(

View File

@@ -2,8 +2,7 @@ from typing import List
import unittest, time
from tinygrad import dtypes, Device
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, Ops, UOp, KernelInfo
from tinygrad.ops import UPat, PatternMatcher
from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher
from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, sym
@@ -541,7 +540,7 @@ class TestExpander(unittest.TestCase):
@unittest.skip("no longer supported")
def test_reduce_known_axis(self):
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), BinaryOps.ADD)
sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), Ops.ADD)
sink = expander_rewrite(sink)
assert sink.op is Ops.CONST
self.assertEqual(sink.arg, 3*(0+1+2+3))
@@ -549,7 +548,7 @@ class TestExpander(unittest.TestCase):
@unittest.skip("no longer supported")
def test_reduce_const(self):
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), BinaryOps.ADD)
sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), Ops.ADD)
sink = expander_rewrite(sink)
assert sink.op is Ops.CONST
self.assertEqual(sink.arg, 3*4)
@@ -590,7 +589,7 @@ class TestExpander(unittest.TestCase):
def test_reduce_different_axis(self):
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), BinaryOps.ADD)
sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), Ops.ADD)
sink = expander_rewrite(sink)
print(sink)

View File

@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.dtype import dtypes, DType
from tinygrad.device import Buffer, Device
from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule, to_si
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
@@ -29,7 +29,7 @@ def uop(uops:List[UOp], uop:Ops, dtype:Optional[DType], src:Tuple[UOp, ...], arg
def _test_single_value(vals, op, dts):
uops = []
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, dtype in enumerate(dts))
@@ -45,7 +45,7 @@ def _test_single_value(vals, op, dts):
def _test_single_value_const(vals, op, dts):
uops = []
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, op, output_dtype, loads)
@@ -103,49 +103,49 @@ class TestUOps(unittest.TestCase):
class TestFloatUOps(TestUOps):
@unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
def test_exp2(self): self._test_uop_fxn(Ops.EXP2, lambda a: np.exp2(a))
@unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
def test_log2(self): self._test_uop_fxn(Ops.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
@unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1/a if a != 0 else float('inf'))
def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
def test_sin(self): self._test_uop_fxn(Ops.SIN, lambda a: math.sin(a))
def test_recip(self): self._test_uop_fxn(Ops.RECIP, lambda a: 1/a if a != 0 else float('inf'))
def test_sqrt(self): self._test_uop_fxn(Ops.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a<b)
def test_cmpne(self): self._test_bop_fxn(BinaryOps.CMPNE, lambda a,b: a!=b)
def test_add(self): self._test_bop_fxn(Ops.ADD, lambda a,b: a+b)
def test_mul(self): self._test_bop_fxn(Ops.MUL, lambda a,b: a*b)
def test_max(self): self._test_bop_fxn(Ops.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(Ops.CMPLT, lambda a,b: a<b)
def test_cmpne(self): self._test_bop_fxn(Ops.CMPNE, lambda a,b: a!=b)
# MOD isn't tested on floats
def test_where(self):
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float, dtypes.float))
self._test_top_fxn(Ops.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float, dtypes.float))
@unittest.skipUnless(getenv("PYTHON"), "only python supports MULACC")
def test_mulacc(self):
self._test_top_fxn(TernaryOps.MULACC, lambda a,b,c: a*b+c, (dtypes.float, dtypes.float, dtypes.float))
self._test_top_fxn(Ops.MULACC, lambda a,b,c: a*b+c, (dtypes.float, dtypes.float, dtypes.float))
class TestNonFloatUOps(TestUOps):
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32))
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), (dtypes.int32, dtypes.int32))
def test_add_int32(self): self._test_bop_fxn(Ops.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32))
def test_mul_int32(self): self._test_bop_fxn(Ops.MUL, lambda a,b: int(a)*int(b), (dtypes.int32, dtypes.int32))
@unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
def test_shr_int32(self): self._test_bop_fxn(BinaryOps.SHR, lambda a,b: int(a)>>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
def test_shr_int32(self): self._test_bop_fxn(Ops.SHR, lambda a,b: int(a)>>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
@unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
def test_shl_int32(self): self._test_bop_fxn(BinaryOps.SHL, lambda a,b: int(a)<<int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
def test_shl_int32(self): self._test_bop_fxn(Ops.SHL, lambda a,b: int(a)<<int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
def test_div_int32(self):
self._test_bop_fxn(BinaryOps.IDIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
def test_and_int32(self): self._test_bop_fxn(BinaryOps.AND, lambda a,b: int(a)&int(b), (dtypes.int32, dtypes.int32))
def test_or_int32(self): self._test_bop_fxn(BinaryOps.OR, lambda a,b: int(a)|int(b), (dtypes.int32, dtypes.int32))
self._test_bop_fxn(Ops.IDIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
def test_and_int32(self): self._test_bop_fxn(Ops.AND, lambda a,b: int(a)&int(b), (dtypes.int32, dtypes.int32))
def test_or_int32(self): self._test_bop_fxn(Ops.OR, lambda a,b: int(a)|int(b), (dtypes.int32, dtypes.int32))
def test_mod_int32(self):
self._test_bop_fxn(BinaryOps.MOD,
self._test_bop_fxn(Ops.MOD,
lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (dtypes.int32, dtypes.int32), no_b_zero=True)
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: int(a)<int(b), (dtypes.int32, dtypes.int32))
def test_cmpne_int32(self): self._test_bop_fxn(BinaryOps.CMPNE, lambda a,b: int(a)!=int(b), (dtypes.int32, dtypes.int32))
def test_cmplt_int32(self): self._test_bop_fxn(Ops.CMPLT, lambda a,b: int(a)<int(b), (dtypes.int32, dtypes.int32))
def test_cmpne_int32(self): self._test_bop_fxn(Ops.CMPNE, lambda a,b: int(a)!=int(b), (dtypes.int32, dtypes.int32))
@unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), (dtypes.bool, dtypes.bool))
def test_mul_bool(self): self._test_bop_fxn(Ops.MUL, lambda a,b: bool(a) and bool(b), (dtypes.bool, dtypes.bool))
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
def test_where_float16(self):
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float16, dtypes.float16))
self._test_top_fxn(Ops.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float16, dtypes.float16))
class TestBoolUOps(TestUOps):
def _test_uop_bool_fxn(self, op, fxn):
@@ -166,72 +166,72 @@ class TestBoolUOps(TestUOps):
for c in [False, True]:
self._equal(f([a,b,c], op, (dtypes.bool, )*3), fxn(a,b,c))
def test_add_bool(self): self._test_bop_bool_fxn(BinaryOps.ADD, lambda a,b: a or b)
def test_mul_bool(self): self._test_bop_bool_fxn(BinaryOps.MUL, lambda a,b: a and b)
def test_xor_bool(self): self._test_bop_bool_fxn(BinaryOps.XOR, lambda a,b: a != b)
def test_and_bool(self): self._test_bop_bool_fxn(BinaryOps.AND, lambda a,b: a & b)
def test_or_bool(self): self._test_bop_bool_fxn(BinaryOps.OR, lambda a,b: a | b)
def test_cmpne_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPNE, lambda a,b: a != b)
def test_cmplt_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPLT, lambda a,b: a < b)
def test_where_bool(self): self._test_top_bool_fxn(TernaryOps.WHERE, lambda a,b,c: b if a else c)
def test_add_bool(self): self._test_bop_bool_fxn(Ops.ADD, lambda a,b: a or b)
def test_mul_bool(self): self._test_bop_bool_fxn(Ops.MUL, lambda a,b: a and b)
def test_xor_bool(self): self._test_bop_bool_fxn(Ops.XOR, lambda a,b: a != b)
def test_and_bool(self): self._test_bop_bool_fxn(Ops.AND, lambda a,b: a & b)
def test_or_bool(self): self._test_bop_bool_fxn(Ops.OR, lambda a,b: a | b)
def test_cmpne_bool(self): self._test_bop_bool_fxn(Ops.CMPNE, lambda a,b: a != b)
def test_cmplt_bool(self): self._test_bop_bool_fxn(Ops.CMPLT, lambda a,b: a < b)
def test_where_bool(self): self._test_top_bool_fxn(Ops.WHERE, lambda a,b,c: b if a else c)
class TestExecALU(TestUOps):
def test_sqrt(self):
self.assertEqual(exec_alu(UnaryOps.SQRT, dtypes.float, (0.0,)), 0.0)
self.assertEqual(exec_alu(Ops.SQRT, dtypes.float, (0.0,)), 0.0)
def test_div(self):
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (8, 2)), 4)
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (7, 3)), 2)
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (7, -3)), -2)
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (-50, 6)), -8)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (8, 2)), 4)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (7, 3)), 2)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (7, -3)), -2)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (-50, 6)), -8)
np.testing.assert_allclose(exec_alu(BinaryOps.MUL, dtypes.float32, (7.0, exec_alu(UnaryOps.RECIP, dtypes.float32, (3.0,)))), 2+(1.0/3.0))
np.testing.assert_allclose(exec_alu(BinaryOps.MUL, dtypes.float32, (7.0, exec_alu(UnaryOps.RECIP, dtypes.float32, (-3.0,)))), -2-(1.0/3.0))
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIP, dtypes.float32, (3.0,)))), 2+(1.0/3.0))
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIP, dtypes.float32, (-3.0,)))), -2-(1.0/3.0))
def test_recip(self):
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (8,)), 1/8)
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (7,)), 1/7)
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (-3,)), 1/-3)
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (-50,)), 1/-50)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (8,)), 1/8)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (7,)), 1/7)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (-3,)), 1/-3)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (-50,)), 1/-50)
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, ((32+521+3),)), 1/(32+521+3))
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, ((34**2),)), 1/(34**2))
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (10,)), 1/10)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, ((32+521+3),)), 1/(32+521+3))
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, ((34**2),)), 1/(34**2))
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (10,)), 1/10)
def test_bool_cmplt(self):
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, False)), False)
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, True)), False)
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (True, False)), False)
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (True, True)), False)
def test_bool_cmpne(self):
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (True, False)), True)
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (True, True)), False)
self.assertEqual(exec_alu(Ops.CMPNE, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(Ops.CMPNE, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(Ops.CMPNE, dtypes.bool, (True, False)), True)
self.assertEqual(exec_alu(Ops.CMPNE, dtypes.bool, (True, True)), False)
def test_bool_where(self):
self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.bool, (False, False, False)), False)
self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.int, (False, 2, 4)), 4)
np.testing.assert_allclose(exec_alu(TernaryOps.WHERE, dtypes.float, (False, 2.2, 4.5)), 4.5)
self.assertEqual(exec_alu(Ops.WHERE, dtypes.bool, (False, False, False)), False)
self.assertEqual(exec_alu(Ops.WHERE, dtypes.int, (False, 2, 4)), 4)
np.testing.assert_allclose(exec_alu(Ops.WHERE, dtypes.float, (False, 2.2, 4.5)), 4.5)
def test_overflow(self):
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)), 244)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (256, 0)), 0)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (0, -1)), 255)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (0, -1000)), 24)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (250, 250)), 244)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (256, 0)), 0)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (0, -1)), 255)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (0, -1000)), 24)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (127, 0)), 127)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-100, -100)), 56)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-1000, -0)), 24)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-130, -0)), 126)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (127, 0)), 127)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-128, 0)), -128)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-100, -100)), 56)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-1000, -0)), 24)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-130, -0)), 126)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (1, 1)), 2)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (1, 1)), 2)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-128, 0)), -128)
# test no truncate
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250), truncate_output=False), 500)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (250, 250), truncate_output=False), 500)
class TestConstantFolding(unittest.TestCase):
def test_cast_const(self):
@@ -336,8 +336,8 @@ class TestAssembly(unittest.TestCase):
a2 = UOp(Ops.MUL, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].op, BinaryOps.SHL)
self.assertEqual(uops[-2].op, BinaryOps.MUL)
self.assertEqual(uops[-1].op, Ops.SHL)
self.assertEqual(uops[-2].op, Ops.MUL)
def test_bitshift_right(self):
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
@@ -348,8 +348,8 @@ class TestAssembly(unittest.TestCase):
a2 = UOp(Ops.IDIV, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].op, BinaryOps.SHR)
self.assertEqual(uops[-2].op, BinaryOps.IDIV)
self.assertEqual(uops[-1].op, Ops.SHR)
self.assertEqual(uops[-2].op, Ops.IDIV)
class TestUOpMethod(unittest.TestCase):
@unittest.skip("uops lt no longer ordered")

View File

@@ -1,8 +1,7 @@
from typing import Dict, List, Optional
import unittest
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, \
graph_rewrite, contexts, track_rewrites
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, contexts, track_rewrites
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
@track_rewrites()
@@ -35,7 +34,7 @@ class TestViz(unittest.TestCase):
def test_rewrite_twice(self):
pm = PatternMatcher([
(UPat.var("x")+UPat.var("x"), lambda x:x*2),
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(Ops.SHL, UOp.const(dtypes.int, 1))),
])
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
uops = helper_test_viz(a+a, pm)

View File

@@ -1,7 +1,7 @@
import unittest, math
from tinygrad import dtypes
from tinygrad.helpers import all_same
from tinygrad.ops import GroupOp, UOp, Ops, BinaryOps, exec_alu
from tinygrad.ops import GroupOp, UOp, Ops, exec_alu
from tinygrad.codegen.uopgraph import full_graph_rewrite
# Helper function to apply the graph rewrite
@@ -56,7 +56,7 @@ class TestFoldingAndReduction(unittest.TestCase):
const1 = UOp.const(dtypes.int32, 5)
const2 = UOp.const(dtypes.int32, 10)
const3 = UOp.const(dtypes.int32, 20)
optimized_sink = apply_rewrite((const1 + const2 + const3).reduce(BinaryOps.ADD))
optimized_sink = apply_rewrite((const1 + const2 + const3).reduce(Ops.ADD))
expected_sum = 5 + 10 + 20
self.assertEqual(optimized_sink.arg, expected_sum)
@@ -65,14 +65,14 @@ class TestFoldingAndReduction(unittest.TestCase):
const1 = UOp.const(dtypes.int32, 15)
const2 = UOp.const(dtypes.int32, 25)
rng = UOp.range(dtypes.int32, 0, 10, idx=0)
optimized_sink = apply_rewrite((const1 + const2).reduce(BinaryOps.ADD, rng))
optimized_sink = apply_rewrite((const1 + const2).reduce(Ops.ADD, rng))
expected_sum = 10 * (15 + 25)
self.assertEqual(optimized_sink.arg, expected_sum)
@unittest.skip("currently failing")
def test_full_graph_rewrite_range_reduction(self):
simple_range = UOp.range(dtypes.int32, 0, 5, idx=0)
optimized_sink = apply_rewrite(simple_range.reduce(BinaryOps.ADD, simple_range))
optimized_sink = apply_rewrite(simple_range.reduce(Ops.ADD, simple_range))
expected_sum = sum(range(5))
self.assertEqual(optimized_sink.arg, expected_sum)
@@ -80,7 +80,7 @@ class TestFoldingAndReduction(unittest.TestCase):
def test_full_graph_rewrite_simple_reduction_folding(self):
simple_range = UOp.range(dtypes.int32, 0, 4, idx=0)
add_uop = simple_range + UOp.const(dtypes.int32, 1)
optimized_sink = apply_rewrite(add_uop.reduce(BinaryOps.ADD, simple_range))
optimized_sink = apply_rewrite(add_uop.reduce(Ops.ADD, simple_range))
expected_sum = sum(i + 1 for i in range(4))
self.assertEqual(optimized_sink.arg, expected_sum)
@@ -89,7 +89,7 @@ class TestFoldingAndReduction(unittest.TestCase):
outer_range = UOp.range(dtypes.int32, 0, 8, 0)
inner_range = UOp.range(dtypes.int32, 0, 4, 1)
expr = (outer_range * 10) + inner_range
optimized_reduce_uop = apply_rewrite(expr.reduce(BinaryOps.ADD, outer_range, inner_range))
optimized_reduce_uop = apply_rewrite(expr.reduce(Ops.ADD, outer_range, inner_range))
self.assertEqual(optimized_reduce_uop.op, Ops.CONST)
self.assertEqual(optimized_reduce_uop.arg, sum((i * 10) + j for i in range(8) for j in range(4)))
@@ -104,7 +104,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
def test_full_graph_rewrite_division_folding_with_define_var(self):
n_var_uop = UOp.variable('n', 1, 1000)
optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
self.assertEqual(optimized_div_uop.op, BinaryOps.MUL)
self.assertEqual(optimized_div_uop.op, Ops.MUL)
self.assertEqual(optimized_div_uop.src[1].arg, 2)
def test_full_graph_rewrite_complex_mod_div_folding(self):

View File

@@ -1,6 +1,6 @@
import unittest, itertools
from tinygrad.dtype import dtypes
from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, UnaryOps, GroupOp # noqa: F401
from tinygrad.ops import Ops, UOp, GroupOp # noqa: F401
from tinygrad.ops import PatternMatcher, UPat
class TestPatternMatcher(unittest.TestCase):
@@ -140,9 +140,9 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c2), None)
# that CONST/ALU -> ALU/CONST rewrite is now instant
"""
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=(UPat(Ops.CONST), UPat(GroupOp.ALU))), lambda x: x)])
c4 = UOp(Ops.ADD, dtypes.float, (c1,c3))
c5 = UOp(Ops.ADD, dtypes.float, (c3,c1))
self.assertEqual(matcher.rewrite(c3), None)
self.assertEqual(matcher.rewrite(c4), c4)
self.assertEqual(matcher.rewrite(c5), None)

View File

@@ -5,7 +5,7 @@ from collections import defaultdict
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
from enum import Enum, auto
from tinygrad.ops import GroupOp, BinaryOps, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
graph_rewrite, track_rewrites
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program
@@ -276,7 +276,7 @@ class Kernel:
if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None
mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
if mul_op.op is not BinaryOps.MUL: return None
if mul_op.op is not Ops.MUL: return None
def buf_index(src:UOp) -> Optional[int]:
# TODO: apply tc even if the sources are not from LOAD
@@ -303,7 +303,7 @@ class Kernel:
return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD:
if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
for tc in self.opts.tensor_cores:
tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
# can only fuse reduces with the same tc options
@@ -338,7 +338,7 @@ class Kernel:
2: apply tensor core shape but don't use UOp.WMMA
extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
0: applies to only kernels with a single reduce axis and direct UOps.LOAD into BinaryOps.MUL
0: applies to only kernels with a single reduce axis and direct UOps.LOAD into Ops.MUL
1: allows kernels with multiple reduce axes and also multiplication of UOps.CAST'd buffers
2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
"""
@@ -441,7 +441,7 @@ class Kernel:
check(not self.vars, "does not work with symbolic shape")
check(axis < self.first_upcast, "cannot pad upcasted")
# ok to pad SUM if all parent ALU ops have f(0) = 0
if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is BinaryOps.ADD and can_pad(r), f"cannot pad {r}")
if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r), f"cannot pad {r}")
padded = False
for i,st in enumerate(self.sts):
if (s:=st.shape[axis]) == 1: continue # reduced
@@ -470,8 +470,8 @@ class Kernel:
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
(mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
(mulop:=self.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])]
strides0, strides1 = st0.real_strides(), st1.real_strides()
def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
@@ -663,7 +663,7 @@ class Kernel:
# for TC=2, we can't do the shapetracker fixup
srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])]
# MUL/SUM instead of WMMA
ret = UOp(Ops.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
ret = UOp(Ops.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(Ops.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
else:
# real WMMA, use CONTRACT/EXPAND to get the vectorization right
wmma_upcast_axes = wmma_arg[-2]

View File

@@ -169,7 +169,7 @@ def sin_poly_large(d:UOp, q:UOp) -> UOp:
def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
"""
Implements a 1.0 ULP approximation for UnaryOps.SIN.
Implements a 1.0 ULP approximation for Ops.SIN.
- fast=True assumes x <= switch_over.
- switch_over is the threshold for switching to payne_hanek_reduction.
"""
@@ -192,7 +192,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
def xexp2(d:UOp) -> UOp:
"""
Implements a 1.0 ULP approximation for UnaryOps.EXP2
Implements a 1.0 ULP approximation for Ops.EXP2
- Paper: https://arxiv.org/pdf/2001.09258
"""
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -218,7 +218,7 @@ def xexp2(d:UOp) -> UOp:
def xlog2(d:UOp) -> UOp:
"""
Implements a 1.0 ULP approximation for UnaryOps.LOG2
Implements a 1.0 ULP approximation for Ops.LOG2
Paper: https://arxiv.org/pdf/2001.09258 5.5
"""
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES

View File

@@ -3,7 +3,7 @@ from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict,
import functools, itertools, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple
from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -23,7 +23,7 @@ def fold_expanded(ex, buf):
for i,s in enumerate(new_srcs):
idx = s.src[0].src[1]
if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
if idx.op is BinaryOps.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
# add gates for gated
@@ -92,12 +92,12 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
# can drop valid if idx is out of bound when valid is False
drop_stmt = []
for stmt in split_uop(valid, BinaryOps.AND):
for stmt in split_uop(valid, Ops.AND):
X, is_upper_bound, c = parse_valid(stmt)
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, BinaryOps.ADD)):
testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, BinaryOps.ADD), idx)
if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx)
testidx = testidx.simplify()
if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
drop_stmt.append(stmt)
@@ -114,7 +114,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
break
if not drop_stmt and idx is start_idx: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, BinaryOps.AND) if s not in drop_stmt]) else None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
return buf.index(idx, new_valid)
# ***** optional patterns *****
@@ -123,23 +123,23 @@ powers_of_two = {2**i:i for i in range(64)}
@functools.lru_cache(None)
def get_late_rewrite_patterns(ops, force_transcendental=False):
pat: List[Tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental]
((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
# rewrite MOD to AND (which should always be supported, but not for generic in tests)
if BinaryOps.AND in ops:
if Ops.AND in ops:
pat += [(UPat(Ops.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)]
# rewrite MUL/IDIV to SHL+SHR
if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
if Ops.SHL in ops and Ops.SHR in ops:
pat += [
(UPat(Ops.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y)
(UPat(Ops.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y)
if UnaryOps.NEG in ops:
pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
if BinaryOps.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))]
if TernaryOps.MULACC in ops:
pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c))]
if Ops.NEG in ops:
pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
if Ops.MULACC in ops:
pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
return PatternMatcher(pat)
# ***** threefry *****
@@ -225,7 +225,7 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
if len(reduce_unparented) == 0: return None
new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
ret = new_acc.assign(new_acc.alu(alu.op, ret))
if alu.op is BinaryOps.ADD:
if alu.op is Ops.ADD:
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, exec_alu, python_alu
from tinygrad.ops import exec_alu, python_alu
from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
@@ -11,9 +11,9 @@ from weakref import ref, ReferenceType, WeakValueDictionary
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)):
if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None
if st.size == 0: op, arg, srcs, base = Ops.CONST, 0, (), None
dtype = to_dtype(dtype)
if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True
if op is Ops.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret
@@ -32,13 +32,13 @@ class LazyBuffer(MathTrait):
if base is None:
# properties on base
self.op, self.arg, self.srcs = op, arg, srcs # this is a UOp, except the src is LazyBuffers and not UOps
assert self.op is not MetaOps.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized"
assert self.op is not Ops.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized"
if self.op is MetaOps.BUFFER_VIEW:
if self.op is Ops.BUFFER_VIEW:
# some LazyBuffers can be processed with only a view, no AST required
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
else:
self.buffer = srcs[0].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, self.dtype)
self.buffer = srcs[0].base.buffer if self.op is Ops.ASSIGN else Buffer(device, self.size, self.dtype)
self.buffer.ref(1)
self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
self.forced_realize = False
@@ -74,14 +74,14 @@ class LazyBuffer(MathTrait):
def const_like(self, b): return self.const_with_shape(b, self.shape)
def const_with_shape(self, val:ConstType, shape:Tuple[sint,...]) -> LazyBuffer:
assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType"
return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
def is_realized(self) -> bool: return self.base.realized is not None
def assign(self, x:LazyBuffer) -> LazyBuffer:
assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
assert self.is_realized(), f"assign target must be realized {self}"
return LazyBuffer.metaop(MetaOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,),
return LazyBuffer.metaop(Ops.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,),
src=(self.base, x), enable_cache=True)
def can_view(self):
@@ -90,7 +90,7 @@ class LazyBuffer(MathTrait):
def contiguous(self, allow_buffer_view=True):
if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
ret = self.alu(MetaOps.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(MetaOps.CONTIGUOUS)
ret = self.alu(Ops.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(Ops.CONTIGUOUS)
if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
return ret
self.base.forced_realize = True
@@ -101,7 +101,7 @@ class LazyBuffer(MathTrait):
if self.dtype == dtype: return self
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
if self.is_unrealized_unmasked_const() and not bitcast:
return create_lazybuffer(self.device, self.st, dtype, MetaOps.CONST, dtypes.as_const(self.base.arg, dtype))
return create_lazybuffer(self.device, self.st, dtype, Ops.CONST, dtypes.as_const(self.base.arg, dtype))
new_shape = self.shape
if bitcast and self.dtype.itemsize != dtype.itemsize:
if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
@@ -112,27 +112,27 @@ class LazyBuffer(MathTrait):
elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
# TODO: applying this makes gpt2 slower
return self.base.cast(dtype, bitcast)._view(self.st)
cast_op: Ops = (MetaOps.BUFFER_VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
cast_op: Ops = (Ops.BUFFER_VIEW if self.can_view() and allow_buffer_view else Ops.BITCAST) if bitcast else Ops.CAST
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, UOp)
def is_unrealized_const(self): return self.base.realized is None and self.base.op is Ops.CONST and not isinstance(self.base.arg, UOp)
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
def _copy(self, device:str) -> LazyBuffer:
assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, MetaOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False)
def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer:
# no COPY
if self.device == device and not clone: return self
# double COPY = one COPY
if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is MetaOps.COPY:
if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is Ops.COPY:
return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)
# const doesn't have to be copied (issues with disk tensor)
if self.is_unrealized_const():
return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)
@@ -149,25 +149,25 @@ class LazyBuffer(MathTrait):
srcs.append(root._view(s.base.contiguous_child[1]))
else:
srcs.append(s)
if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]):
if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is Ops.WHERE else srcs)]):
raise AssertionError(f"all dtypes must match {dts} on {op}")
assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
if op is Ops.WHERE: assert srcs[0].dtype == dtypes.bool, "Ops.WHERE must have the first arg be bool"
out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else srcs[-1].dtype
out_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else srcs[-1].dtype
# const folding
if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
return self.cast(out_dtype).const_like(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
if op in GroupOp.Binary:
x, y = self, in_srcs[0]
if op is BinaryOps.ADD:
if op is Ops.ADD:
if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
if op is BinaryOps.MUL:
if op is Ops.MUL:
if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const_like(0)
if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const_like(0)
if op is BinaryOps.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x
if op is Ops.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, None, tuple(srcs))

View File

@@ -29,14 +29,14 @@ class SimpleMathTrait:
dtype: Optional[DType] = getattr(self, 'dtype', None)
assert dtype is not None, "MathTraits __neg__ requires a dtype"
return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
def add(self, x, reverse=False): return self._binop(BinaryOps.ADD, x, reverse)
def mul(self, x, reverse=False): return self._binop(BinaryOps.MUL, x, reverse)
def bitwise_and(self, x, reverse=False): return self._binop(BinaryOps.AND, x, reverse)
def bitwise_or(self, x, reverse=False): return self._binop(BinaryOps.OR, x, reverse)
def xor(self, x, reverse=False): return self._binop(BinaryOps.XOR, x, reverse)
def idiv(self, x, reverse=False): return self._binop(BinaryOps.IDIV, x, reverse)
def sub(self, x, reverse=False): return self.ufix(x).alu(BinaryOps.ADD, -self) if reverse else self.alu(BinaryOps.ADD, self.ufix(-x))
def div(self, x, reverse=False): return (self.ufix(x)*self.alu(UnaryOps.RECIP)) if reverse else (self*self.ufix(x).alu(UnaryOps.RECIP))
def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse)
def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
def __neg__(self): return self.neg()
@@ -58,9 +58,9 @@ class SimpleMathTrait:
def __ror__(self, x): return self.bitwise_or(x, True)
def __rxor__(self, x): return self.xor(x, True)
def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x))
def gt(self, x): return self.ufix(x).alu(BinaryOps.CMPLT, self)
def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x))
def lt(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
def gt(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
def ge(self, x): return self.lt(x).logical_not()
def le(self, x): return self.gt(x).logical_not()
def eq(self, x): return self.ne(x).logical_not()
@@ -74,26 +74,26 @@ class SimpleMathTrait:
class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
# TODO: move to Tensor when new backward is done
def lshift(self, x, reverse=False): return self._binop(BinaryOps.SHL, x, reverse)
def rshift(self, x, reverse=False): return self._binop(BinaryOps.SHR, x, reverse)
def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
def __lshift__(self, x): return self.lshift(x)
def __rshift__(self, x): return self.rshift(x)
def __rlshift__(self, x): return self.lshift(x, True)
def __rrshift__(self, x): return self.rshift(x, True)
# not in Tensor
def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x))
def __rmod__(self, x): return self.ufix(x).alu(BinaryOps.MOD, self)
def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x))
def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self)
def maximum(self, x): return self.alu(BinaryOps.MAX, self.ufix(x))
def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
def minimum(self, x): return -(-self).maximum(-x)
def where(self, x, y): return self.alu(TernaryOps.WHERE, x, x.ufix(y))
def threefry(self, seed): return self.alu(BinaryOps.THREEFRY, seed)
def reciprocal(self): return self.alu(UnaryOps.RECIP)
def sqrt(self): return self.alu(UnaryOps.SQRT)
def sin(self): return self.alu(UnaryOps.SIN)
def log2(self): return self.alu(UnaryOps.LOG2)
def exp2(self): return self.alu(UnaryOps.EXP2)
def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y))
def threefry(self, seed): return self.alu(Ops.THREEFRY, seed)
def reciprocal(self): return self.alu(Ops.RECIP)
def sqrt(self): return self.alu(Ops.SQRT)
def sin(self): return self.alu(Ops.SIN)
def log2(self): return self.alu(Ops.LOG2)
def exp2(self): return self.alu(Ops.EXP2)
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
@@ -182,11 +182,8 @@ class GroupOp:
# do not preserve f(0) = 0
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
# TODO: remove this?
UnaryOps = BinaryOps = MetaOps = TernaryOps = Ops
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
def identity_element(op:Ops, dt:DType): return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents)
@@ -325,7 +322,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def alu(self, arg, *src:UOp):
out_dtype = (self, *src)[-1].dtype
if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None:
if arg in {Ops.CMPLT, Ops.CMPNE} and out_dtype is not None:
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(arg, out_dtype, (self,)+src)
@staticmethod
@@ -392,15 +389,15 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
"""largest known int that divides self"""
if self.op is Ops.CONST: return self.arg
if self.op is Ops.VCONST: return math.gcd(*self.arg)
if self.op is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
if self.op is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
return 1
def divides(self, v) -> Optional[UOp]:
if v==1: return self
if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
if self.op is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
if self.op is BinaryOps.MUL:
if self.op is Ops.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
if self.op is Ops.MUL:
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
return None # generic None if we aren't sure
@@ -412,18 +409,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def _min_max(self) -> Tuple[ConstType, ConstType]:
if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
(s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
if self.op is BinaryOps.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
if self.op is BinaryOps.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
if self.op is BinaryOps.MOD and s1_vmin > 0: return 0, s1_vmax-1
if self.op is BinaryOps.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1
if self.op is Ops.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
if self.op is BinaryOps.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
if self.op is BinaryOps.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
if self.op is BinaryOps.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
if self.dtype == dtypes.bool:
if self.op is BinaryOps.OR: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
if self.op is BinaryOps.AND: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
if self.op is Ops.OR: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
if self.op is Ops.AND: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
# float has NAN issue and we use explicit NAN in transcendental
if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
# NOTE: returned UOp is assumed to be CONST
@@ -467,14 +464,14 @@ def hook_overflow(dv, fxn):
return wfxn
python_alu: Dict[Ops, Callable] = {
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
UnaryOps.NEG: operator.neg, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul,
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt, BinaryOps.XOR: operator.xor,
BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_, BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: hook_overflow(math.inf, lambda x: 2**x),
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul,
Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
Ops.MAX: max, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.XOR: operator.xor,
Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift,
Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
if dtype.count > 1:
@@ -515,7 +512,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
elif u.op is Ops.STORE:
mem += u.src[1].dtype.itemsize * mults
elif u.op in GroupOp.ALU and u not in dont_count:
flops += (mults * (2 if u.op is TernaryOps.MULACC else 1)) * u.dtype.count
flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op is Ops.WMMA and u not in dont_count:
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return flops, mem
@@ -857,7 +854,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c
remainder, something_changed = [], False
for u in split_uop(x, BinaryOps.ADD):
for u in split_uop(x, Ops.ADD):
if (factor:=u.const_factor())%c != factor:
divides = u.divides(factor)*(factor%c)
assert divides is not None
@@ -877,7 +874,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
if 0 <= x.vmin and x.vmax < c: return x.const_like(0)
quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
for u in split_uop(x, BinaryOps.ADD):
for u in split_uop(x, Ops.ADD):
if u.op is Ops.CONST:
# add all const together first
if rem_const != 0: something_changed = True
@@ -911,7 +908,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
def lt_folding(x:UOp, c:int) -> Optional[UOp]:
p, np = partition(split_uop(x, BinaryOps.ADD), lambda u: u.const_factor() == 1)
p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
return cast(UOp, functools.reduce(operator.add, np).divides(d)).lt(c//d)
return None
@@ -919,7 +916,7 @@ def lt_folding(x:UOp, c:int) -> Optional[UOp]:
def fold_unrolled_divs(divs:UOp):
# div pattern in unrolled arange
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
add_chain, denominator, seen_const, ans = list(split_uop(divs, BinaryOps.ADD)), None, [], None
add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
for u in add_chain:
if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
if denominator is None: denominator = u.src[1].arg
@@ -941,9 +938,9 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
# returns x0 + x1 + ... in such case, or None if not
changed, ret = False, []
for u in split_uop(X, BinaryOps.ADD):
for u in split_uop(X, Ops.ADD):
# assumed the const is the last src of MUL
if u.op is BinaryOps.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
changed = True
u = u.src[0]
if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None
@@ -953,8 +950,8 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
def is_increasing(f:UOp) -> bool:
# is f a monotonically increasing function regards its input
if f.op in GroupOp.Irreducible: return True
if f.op is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
if f.op in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
if f.op is Ops.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
if f.op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
return False # False if not sure
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
@@ -962,10 +959,10 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
# if it's X >= c, returns X, False, c
# (X < c).ne(True) -> X >= c
if valid.op is BinaryOps.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
(s0:=valid.src[0]).op is BinaryOps.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
(s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
# X < c -> X <= c-1
if valid.op is BinaryOps.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
raise ValueError(f"not able to parse {valid=}")
def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
@@ -973,7 +970,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
# first, parse valid into {expr: (lower_bound, upper_bound)}
bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
for stmt in split_uop(valid, BinaryOps.AND):
for stmt in split_uop(valid, Ops.AND):
try: expr, is_upper, c = parse_valid(stmt)
except ValueError: return uop # give up if we cannot parse the valid
bounds[expr][int(is_upper)] = c
@@ -985,9 +982,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
candidates = []
if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, BinaryOps.ADD)):
if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)])
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
# try checking the whole clause
if expr in uop.sparents:
candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
@@ -1128,8 +1125,8 @@ symbolic_flat = symbolic+PatternMatcher([
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
# for debug
syms = { BinaryOps.ADD: "+", BinaryOps.SUB: "-", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>",
BinaryOps.MUL: "*", BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"}
syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
renderer = PatternMatcher([
(UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
import os, math
from collections import defaultdict, Counter
from tinygrad.ops import GroupOp, UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore
@@ -42,13 +42,13 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
# new load/store
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == BinaryOps.ADD else ctx[idx]})"),
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
# alu/gep
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)),
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
])
@@ -82,13 +82,13 @@ class CStyleLanguage(Renderer):
infinity: str = "INFINITY"
nan: str = "NAN"
code_for_op: Dict = {
UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})", UnaryOps.RECIP: lambda x,dtype: f"(1/{x})", UnaryOps.NEG: lambda x,dtype: f"-{x}",
UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})", BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})",
BinaryOps.SHR: lambda a,b,dtype: f"({a}>>{b})", BinaryOps.SHL: lambda a,b,dtype: f"({a}<<{b})", BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})",
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" }
Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})",
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" }
string_rewrite = base_rewrite
extra_matcher = extra_pm
@@ -174,8 +174,8 @@ class ClangRenderer(CStyleLanguage):
# language options
buffer_suffix = " restrict"
type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [UnaryOps.EXP2, UnaryOps.SIN, UnaryOps.LOG2]}),
UnaryOps.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
if AMX:
tensor_cores = [TensorCore(dims=(sz,sz,1), threads=[], reduce_axes=[], upcast_axes=([(1,sz)],[(0,sz)],[(1,sz),(0,sz)]), dtype_in=dt, dtype_out=dt)
@@ -265,12 +265,12 @@ class MetalRenderer(CStyleLanguage):
type_map = {dtypes.bfloat16: "bfloat"}
# precise::sin
code_for_op = {**CStyleLanguage.code_for_op, UnaryOps.SIN: lambda x,dtype: f"precise::sin({x})"}
code_for_op = {**CStyleLanguage.code_for_op, Ops.SIN: lambda x,dtype: f"precise::sin({x})"}
# upcast to float32 all the ops that don't support bfloat16
extra_matcher = PatternMatcher([
# NOTE: this is copied from PTX
(UPat((UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN), dtype=dtypes.bfloat16, name="x"),
(UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
]) + extra_pm
@@ -312,11 +312,11 @@ class CUDARenderer(CStyleLanguage):
code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
"i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"}
code_for_op = { **CStyleLanguage.code_for_op,
UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
Ops.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
Ops.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
type_map = {dtypes.bfloat16: "nv_bfloat16"}
def render_vector_prefix(self, dt:DType) -> str:
@@ -375,10 +375,10 @@ class AMDRenderer(CStyleLanguage):
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
"i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
code_for_op = { **CStyleLanguage.code_for_op,
UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})" }
Ops.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
Ops.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
Ops.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
Ops.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})" }
smem_prefix = "__attribute__((shared))"
barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
'__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
@@ -433,9 +433,9 @@ class DSPRenderer(ClangRenderer):
buffer_suffix = " restrict __attribute__((align_value(128)))"
kernel_prefix = "__attribute__((noinline)) "
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
code_for_op = {**ClangRenderer.code_for_op, UnaryOps.SIN: lambda x,dtype: f"__builtin_sin({x})",
UnaryOps.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
UnaryOps.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"}
code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
Ops.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"}
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
ret = super().render_kernel(function_name, kernel, bufs, uops, prefix)

View File

@@ -1,7 +1,7 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tuple
import struct
from collections import defaultdict
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, GroupOp
from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
@@ -15,24 +15,24 @@ def render_val(x, dtype):
return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
asm_for_op: Dict[Ops, Callable] = {
UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
BinaryOps.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
BinaryOps.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
TernaryOps.WHERE: lambda d,a,b,c,dt,name:
Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
Ops.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", Ops.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
Ops.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
Ops.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
Ops.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
Ops.WHERE: lambda d,a,b,c,dt,name:
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
}
supports_half: List[Ops] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
supports_half: List[Ops] = [Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE]
doesnt_support_half: Tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
ptx_matcher = PatternMatcher([
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
@@ -151,8 +151,8 @@ class PTXRenderer(Renderer):
kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", pred_reg, invert=True))
elif uop is Ops.BARRIER and self.barrier: kk(self.barrier)
elif uop is Ops.ENDRANGE:
kk(self.code_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
self.code_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
kk(self.code_for_op[Ops.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
self.code_for_op[Ops.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
elif uop is Ops.ENDIF:
kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:")
@@ -167,7 +167,7 @@ class PTXRenderer(Renderer):
else:
if uop is Ops.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
elif uop in GroupOp.ALU:
src_dtype = src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype
src_dtype = src[0].dtype if uop in {Ops.CMPLT, Ops.CMPNE} else dtype
kk(self.code_for_op[uop](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
elif uop is Ops.DEFINE_ACC:
if dtype.count > 1:

View File

@@ -7,7 +7,7 @@ import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, Ops, UOp, GroupOp
from tinygrad.ops import exec_alu, Ops, UOp, GroupOp
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer
@@ -175,7 +175,7 @@ class PythonProgram:
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop in GroupOp.ALU:
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}"
assert all_same([dtype] + dtp) or uop in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {uop}"
assert all_same([dtype] + dtp) or uop in {Ops.CMPNE, Ops.CMPLT, Ops.WHERE}, f"dtype mismatch on {uop}"
ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)]
assert i in ul, (uop, dtype, idp, arg)
i += 1

View File

@@ -5,7 +5,7 @@ from typing import Tuple, List, Optional, Dict, Set
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, BinaryOps, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid
from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid
@dataclass(frozen=True, order=True)
class ShapeTracker:
@@ -77,7 +77,7 @@ class ShapeTracker:
# TODO: always apply these in to_indexed_uops?
if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
for c in split_uop(idx, BinaryOps.ADD):
for c in split_uop(idx, Ops.ADD):
if c.op is Ops.RANGE: ret[c.arg[0]] = 1
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg

View File

@@ -9,7 +9,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, Ops, BinaryOps, sint, Variable, SimpleMathTrait
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.realize import run_schedule
@@ -51,7 +51,7 @@ def _to_np_dtype(dtype:DType) -> Optional[type]:
return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
def _fromnp(x: 'np.ndarray') -> LazyBuffer: # type: ignore [name-defined] # noqa: F821
ret = LazyBuffer.metaop(MetaOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
ret = LazyBuffer.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
# fake realize
ret.buffer.allocate(x)
del ret.srcs
@@ -64,9 +64,9 @@ def get_shape(x) -> Tuple[int, ...]:
return (len(subs),) + (subs[0] if subs else ())
def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
if isinstance(x, bytes): ret, data = LazyBuffer.metaop(MetaOps.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
if isinstance(x, bytes): ret, data = LazyBuffer.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
else:
ret = LazyBuffer.metaop(MetaOps.EMPTY, get_shape(x), dtype, "PYTHON")
ret = LazyBuffer.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
assert dtype.fmt is not None, f"{dtype=} has None fmt"
truncate_function = truncate[dtype]
data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
@@ -134,10 +134,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
# create a LazyBuffer from the different types of inputs
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, UOp):
assert data.op is Ops.BIND and data.src[0].op is Ops.DEFINE_VAR and data.src[1].op is Ops.CONST, f"can't create tensor from UOp {data}"
data = _metaop(MetaOps.CONST, tuple(), dtype or data.dtype, device, data)
data = _metaop(Ops.CONST, tuple(), dtype or data.dtype, device, data)
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
elif isinstance(data, (list, tuple)):
if dtype is None:
@@ -145,15 +145,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
else: data = _frompy(data, dtype)
elif data is None: data = _metaop(MetaOps.EMPTY, (0,), dtype or dtypes.default_float, device)
elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device)
elif str(type(data)) == "<class 'numpy.ndarray'>":
import numpy as np
assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
if data.shape == (): data = _metaop(MetaOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
if data.shape == (): data = _metaop(Ops.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data) # type: ignore [name-defined]
elif isinstance(data, pathlib.Path):
dtype = dtype or dtypes.uint8
data = _metaop(MetaOps.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
# by this point, it has to be a LazyBuffer
if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
@@ -384,9 +384,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
def from_uop(y:UOp, **kwargs) -> Tensor:
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor
if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
if y.op is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.op is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
if y.op is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
if y.op is Ops.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
raise RuntimeError(f"unhandled UOp {y}")
# ***** creation entrypoint *****
@@ -412,7 +412,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
print(t.shape)
```
"""
return Tensor._metaop(MetaOps.EMPTY, argfix(*shape), **kwargs)
return Tensor._metaop(Ops.EMPTY, argfix(*shape), **kwargs)
@staticmethod
def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor:
@@ -424,7 +424,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
Additionally, all other keyword arguments are passed to the constructor of the tensor.
"""
r = Tensor._metaop(MetaOps.EMPTY, shape, **kwargs)
r = Tensor._metaop(Ops.EMPTY, shape, **kwargs)
r.lazydata.buffer.allocate(external_ptr=ptr)
del r.lazydata.srcs # fake realize
return r