From 597a239e282b275626a4e4f26b1f40348c6a9b7a Mon Sep 17 00:00:00 2001 From: ignaciosica Date: Sat, 16 Nov 2024 09:56:56 -0300 Subject: [PATCH] Remove UnaryOps, BinaryOps, TernaryOps, MetaOps [pr] (#7725) * remove unaryops * remove ternaryops * remove metaops * hotfix * remove binaryops * hotfix: test_pattern_matcher --------- Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com> --- docs/abstractions2.py | 8 +- examples/llm.c/export.py | 2 +- extra/optimization/helpers.py | 2 +- test/external/external_test_nv.py | 8 +- test/external/external_test_valid_remove.py | 4 +- test/external/fuzz_linearizer.py | 4 +- test/external/fuzz_schedule.py | 8 +- test/external/fuzz_symbolic.py | 6 +- test/test_const_folding.py | 2 +- test/test_lazybuffer.py | 12 +- test/test_linearizer.py | 150 +++++++++---------- test/test_linearizer_dumb.py | 20 +-- test/test_linearizer_failures.py | 108 +++++++------- test/test_linearizer_overflows.py | 20 +-- test/test_multitensor.py | 10 +- test/test_renderer_failures.py | 4 +- test/test_schedule.py | 42 +++--- test/test_search.py | 4 +- test/test_uop_graph.py | 9 +- test/test_uops.py | 152 ++++++++++---------- test/test_viz.py | 5 +- test/unit/test_graph_rewrite.py | 14 +- test/unit/test_pattern_matcher.py | 8 +- tinygrad/codegen/kernel.py | 16 +-- tinygrad/codegen/transcendental.py | 6 +- tinygrad/codegen/uopgraph.py | 30 ++-- tinygrad/engine/lazy.py | 42 +++--- tinygrad/ops.py | 129 ++++++++--------- tinygrad/renderer/cstyle.py | 52 +++---- tinygrad/renderer/ptx.py | 38 ++--- tinygrad/runtime/ops_python.py | 4 +- tinygrad/shape/shapetracker.py | 4 +- tinygrad/tensor.py | 28 ++-- 33 files changed, 473 insertions(+), 478 deletions(-) diff --git a/docs/abstractions2.py b/docs/abstractions2.py index e2690baab8..09bdcdc466 100644 --- a/docs/abstractions2.py +++ b/docs/abstractions2.py @@ -39,7 +39,7 @@ DEVICE = "CLANG" # NOTE: you can change this! import struct from tinygrad.dtype import dtypes from tinygrad.device import Buffer, Device -from tinygrad.ops import BinaryOps, MetaOps, UOp, Ops +from tinygrad.ops import UOp, Ops from tinygrad.shape.shapetracker import ShapeTracker # allocate some buffers + load in values @@ -81,15 +81,15 @@ from tinygrad.engine.realize import run_schedule from tinygrad.engine.schedule import create_schedule # allocate some values + load in values -a = LazyBuffer.metaop(MetaOps.EMPTY, (1,), dtypes.int32, DEVICE) -b = LazyBuffer.metaop(MetaOps.EMPTY, (1,), dtypes.int32, DEVICE) +a = LazyBuffer.metaop(Ops.EMPTY, (1,), dtypes.int32, DEVICE) +b = LazyBuffer.metaop(Ops.EMPTY, (1,), dtypes.int32, DEVICE) a.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 2)))) b.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 3)))) del a.srcs del b.srcs # describe the computation -out = a.alu(BinaryOps.ADD, b) +out = a.alu(Ops.ADD, b) # schedule the computation as a list of kernels sched = create_schedule([out]) diff --git a/examples/llm.c/export.py b/examples/llm.c/export.py index 46e36dddfc..c0d52f32cd 100755 --- a/examples/llm.c/export.py +++ b/examples/llm.c/export.py @@ -8,7 +8,7 @@ from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCou from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import get_kernel, run_schedule from tinygrad.engine.memory import memory_planner -from tinygrad.ops import MetaOps, Ops +from tinygrad.ops import Ops TIMING = getenv("TIMING") diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py index 342c4c27b6..7de5a24d33 100644 --- a/extra/optimization/helpers.py +++ b/extra/optimization/helpers.py @@ -2,7 +2,7 @@ from typing import Tuple from tinygrad import Variable from tinygrad.codegen.kernel import Opt, OptOps -from tinygrad.ops import UOp, Ops, KernelInfo, TernaryOps, BinaryOps, UnaryOps, MetaOps +from tinygrad.ops import UOp, Ops, KernelInfo from tinygrad.dtype import dtypes, PtrDType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View diff --git a/test/external/external_test_nv.py b/test/external/external_test_nv.py index d4d924287e..58e556b44e 100644 --- a/test/external/external_test_nv.py +++ b/test/external/external_test_nv.py @@ -9,7 +9,7 @@ from tinygrad.engine.realize import get_runner, CompiledRunner from test.external.fuzz_linearizer import get_fuzz_rawbufs from tinygrad.codegen.kernel import Kernel -from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer +from tinygrad.ops import LazyOp, Ops, ReduceOps, BufferOps, MemBuffer from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View @@ -26,12 +26,12 @@ class TestNV(unittest.TestCase): TestNV.addr = struct.pack("QQ", TestNV.b.lazydata.buffer._buf.va_addr, TestNV.a.lazydata.buffer._buf.va_addr) def test_oor_kernels(self): - ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 + ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 opts = [Opt(op=OptOps.TC, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2)] # noqa: E501 helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["NV"]) def test_error_on_huge_dims(self): - ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501 + ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501 opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2)] # noqa: E501 with self.assertRaises(RuntimeError) as cm: lin = Kernel(ast) @@ -43,7 +43,7 @@ class TestNV(unittest.TestCase): def test_buf4_usage(self): TestNV.along = Tensor([105615], device="NV").realize() - ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SIN, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501 + ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.SIN, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501 temp_runner = get_runner(TestNV.d0.dname, (ast,)) temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={}) val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0] diff --git a/test/external/external_test_valid_remove.py b/test/external/external_test_valid_remove.py index 24f94c640a..1467139daf 100644 --- a/test/external/external_test_valid_remove.py +++ b/test/external/external_test_valid_remove.py @@ -2,7 +2,7 @@ import unittest from tinygrad import Device -from tinygrad.ops import UOp, Ops, BinaryOps +from tinygrad.ops import UOp, Ops from tinygrad.engine.search import Opt, OptOps from tinygrad.dtype import dtypes from tinygrad.shape.shapetracker import ShapeTracker @@ -20,7 +20,7 @@ class TestOpenpilotValidhack(unittest.TestCase): UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.MAX, dtypes.float, arg=None, src=( x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8, 9, 10)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 8, 9, 10)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 296ec2c60e..49199c11c5 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -25,7 +25,7 @@ from tinygrad.codegen.kernel import Opt, OptOps from tinygrad.engine.search import get_kernel_actions, bufs_from_lin from tinygrad.engine.realize import CompiledRunner from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing -from tinygrad.ops import UnaryOps, UOp, Ops +from tinygrad.ops import UOp, Ops from tinygrad.device import is_dtype_supported def on_linearizer_will_run(): pass @@ -252,7 +252,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None): def _is_simple(lin: Kernel) -> bool: if len(lin.ast.src) > 1: return False ast:UOp = lin.ast.src[0] - if ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True + if ast.src[0].op is Ops.CAST and ast.src[0].src[0].op is Ops.LOAD: return True return False if __name__ == "__main__": diff --git a/test/external/fuzz_schedule.py b/test/external/fuzz_schedule.py index b9ab2fed86..059edb37b8 100644 --- a/test/external/fuzz_schedule.py +++ b/test/external/fuzz_schedule.py @@ -6,7 +6,7 @@ from tinygrad.engine.realize import capturing, lower_schedule_item from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv from tinygrad.engine.lazy import LazyBuffer from tinygrad.engine.schedule import LBScheduleItem, _graph_schedule, ScheduleItem -from tinygrad.ops import MetaOps +from tinygrad.ops import Ops from tinygrad.tensor import Tensor, _to_np_dtype ctx_vars = { MULTIOUTPUT: (0, 1) } @@ -33,7 +33,7 @@ def fuzz_schedule(outs:List[LazyBuffer]): for lsi in ts: for out in lsi.outputs: # freeze assign state before exec - if out.op is MetaOps.ASSIGN: + if out.op is Ops.ASSIGN: prerealized[out] = out.buffer.as_buffer() assign_targets[out.srcs[1]] = out for x in lsi.inputs: @@ -50,9 +50,9 @@ def fuzz_schedule(outs:List[LazyBuffer]): rawbufs: Dict[LazyBuffer, Buffer] = {} for lsi in ts: for out in lsi.outputs: - base = rawbufs[lsi.inputs[0]].base if out.op is MetaOps.BUFFER_VIEW else None + base = rawbufs[lsi.inputs[0]].base if out.op is Ops.BUFFER_VIEW else None rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype, base=base) - if out.op is MetaOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out]) + if out.op is Ops.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out]) for x in lsi.inputs: if x not in rawbufs: # override the assign_target after ASSIGN diff --git a/test/external/fuzz_symbolic.py b/test/external/fuzz_symbolic.py index bee91cfa57..42798bd181 100644 --- a/test/external/fuzz_symbolic.py +++ b/test/external/fuzz_symbolic.py @@ -42,9 +42,9 @@ def gt(expr, rng=None): return expr > rng, rng # NOTE: you have to replace these for this test to pass -from tinygrad.ops import python_alu, BinaryOps -python_alu[BinaryOps.MOD] = lambda x,y: x%y -python_alu[BinaryOps.IDIV] = lambda x,y: x//y +from tinygrad.ops import python_alu, Ops +python_alu[Ops.MOD] = lambda x,y: x%y +python_alu[Ops.IDIV] = lambda x,y: x//y if __name__ == "__main__": ops = [add_v, div, mul, add_num, mod] diff --git a/test/test_const_folding.py b/test/test_const_folding.py index 327f351248..3763e864f6 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -154,7 +154,7 @@ class TestReduceOpsConstFolding(unittest.TestCase): _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).sum()) np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).sum().numpy(), 4) - # NOTE: cannot just count the non-padded area because some UnaryOps f do not have f(0) = 0. + # NOTE: cannot just count the non-padded area because some Ops f do not have f(0) = 0. _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).exp().sum()) np.testing.assert_allclose(Tensor.ones(4).pad(((1, 1),)).exp().sum().numpy(), 4 * math.e + 2) diff --git a/test/test_lazybuffer.py b/test/test_lazybuffer.py index 168720e2d2..30e4e190c9 100644 --- a/test/test_lazybuffer.py +++ b/test/test_lazybuffer.py @@ -3,7 +3,7 @@ import numpy as np import unittest from tinygrad import Tensor, Device, dtypes from tinygrad.ops import Ops -from tinygrad.engine.lazy import LazyBuffer, MetaOps +from tinygrad.engine.lazy import LazyBuffer from tinygrad.engine.schedule import create_schedule class TestLazyBuffer(unittest.TestCase): @@ -95,24 +95,24 @@ class TestReduceOp(unittest.TestCase): class TestView(unittest.TestCase): def test_all_masked_out(self): - # start with non CONST MetaOps + # start with non CONST Ops a = Tensor.rand(10, 10) - assert a.lazydata.base.op is not MetaOps.CONST + assert a.lazydata.base.op is not Ops.CONST # all masked out, degrades to const 0 b = a.pad(((0, 10), None))[10:] assert b.shape == (10, 10) - assert b.lazydata.base.op is MetaOps.CONST and b.lazydata.base.arg == 0 + assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0 # mask out dim = 1 works too b = a.pad((None, (0, 10)))[:, 10:] assert b.shape == (10, 10) - assert b.lazydata.base.op is MetaOps.CONST and b.lazydata.base.arg == 0 + assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0 # partial masked out does not degrade into CONST b = a.pad(((0, 5), None))[5:] assert b.shape == (10, 10) - assert b.lazydata.base.op is not MetaOps.CONST + assert b.lazydata.base.op is not Ops.CONST if __name__ == "__main__": unittest.main() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 957f529c1b..d4bf0b3a55 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -6,7 +6,7 @@ from dataclasses import replace from test.helpers import ast_const from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel from tinygrad.codegen.lowerer import get_grouped_dims -from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps, UnaryOps, GroupOp +from tinygrad.ops import UOp, Ops, GroupOp from tinygrad.device import Device, Buffer from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View @@ -109,10 +109,10 @@ class TestLinearizer(unittest.TestCase): st_x = x.lazydata.st g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (1,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (1,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop())) diff = second_x + first_reduce*ast_const(dtypes.float, -1, (32, 1)) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (0,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (0,))) store = UOp(Ops.STORE, dtypes.void, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [ @@ -145,10 +145,10 @@ class TestLinearizer(unittest.TestCase): st_x = x.lazydata.st g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop())) diff = second_x + first_reduce*ast_const(dtypes.float, -1, (27, 32, 1, 5)) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [ @@ -207,13 +207,13 @@ class TestLinearizer(unittest.TestCase): x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,))) second_x = UOp(Ops.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 32, 32, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (2,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (2,))) third_x = UOp(Ops.LOAD, dtypes.float, (g3, x2.lazydata.st.reshape((27, 32, 1, 1, 5)).to_uop())) mul = (third_x*second_reduce) - third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (BinaryOps.ADD, (1,))) + third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 1, 5)).to_uop(), third_reduce)) sink = UOp(Ops.SINK, src=(store,)) wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5) @@ -234,11 +234,11 @@ class TestLinearizer(unittest.TestCase): st = x.lazydata.st g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2, 5))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2, 5))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 32, 1, 8, 16, 1)).to_uop())) neg_first_reduce = first_reduce * ast_const(dtypes.float, -1, (8, 32, 1, 8, 16, 1)) squares = (second_x+neg_first_reduce) - squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1, 4))) + squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1, 4))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((8, 1, 1, 8, 1, 1)).to_uop(), squares_sum,)) sink = UOp(Ops.SINK, src=(store,)) wanna_output = (x.numpy()-x.numpy().sum(axis=(1,3), keepdims=True)).sum(axis=(1,3)).reshape((8,1,1,8,1,1)) @@ -285,10 +285,10 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [ @@ -317,11 +317,11 @@ class TestLinearizer(unittest.TestCase): g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop())) first_x_p = UOp(Ops.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) - first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(UnaryOps.EXP2),), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) + first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(Ops.EXP2),), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1)).to_uop())) diff = (second_x+(first_reduce + first_reduce_p)*ast_const(dtypes.float, -1, (4, 32, 1))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4, 1, 1)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [ @@ -352,10 +352,10 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize() g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) second_out = second_reduce * ast_const(dtypes.float, 1/15, (27, 1, 1, 5)) store1 = UOp(Ops.STORE, src=(g1, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_out)) @@ -375,10 +375,10 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize() g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) store1 = UOp(Ops.STORE, src=(g1, ShapeTracker(views=(View(shape=(27,15,1,5), strides=(5,0,1,1), offset=0, mask=None, contiguous=False),)).to_uop(), first_reduce)) # noqa: E501 wanna_output0 = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5) @@ -399,10 +399,10 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [[Opt(OptOps.UNROLL, 0, 3), Opt(OptOps.UNROLL, 0, 3)]] @@ -415,10 +415,10 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [[Opt(OptOps.UPCAST, 0, 3)]] @@ -434,10 +434,10 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 12, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [[Opt(OptOps.GROUPTOP, 0, 3), Opt(OptOps.GROUPTOP, 1, 3)]] @@ -450,13 +450,13 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,))) neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1)) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop())) squares = (second_x+neg_mean)*(second_x+neg_mean) - squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,))) + squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,))) variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1)) - std = variance.alu(UnaryOps.SQRT) + std = variance.alu(Ops.SQRT) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std)) sink = UOp(Ops.SINK, src=(store,)) wanna_output = x.numpy().std(axis=2, ddof=0).reshape((15,25,1,1)) @@ -468,13 +468,13 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35)) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).to_uop())) squares = (second_x+neg_mean)*(second_x+neg_mean) - squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1,))) + squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1,))) variance = squares_sum * ast_const(dtypes.float, 0.04, (15, 1, 1, 35)) - std = variance.alu(UnaryOps.SQRT) + std = variance.alu(Ops.SQRT) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 1, 1, 35)).to_uop(), std)) sink = UOp(Ops.SINK, src=(store,)) wanna_output = x.numpy().std(axis=1, ddof=0).reshape((15,1,1,35)) @@ -488,13 +488,13 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,))) neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1)) second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop())) squares = (second_x+neg_mean)*(second_x+neg_mean) - squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,))) + squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,))) variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1)) - std = variance.alu(UnaryOps.SQRT) + std = variance.alu(Ops.SQRT) store_mean = UOp(Ops.STORE, src=(g1, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), neg_mean)) store_std = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std)) sink = UOp(Ops.SINK, src=(store_std, store_mean)) @@ -511,13 +511,13 @@ class TestLinearizer(unittest.TestCase): g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] # push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,))) neg_mean = first_reduce * ast_const(dtypes.float, -0.03125, (3, 27, 32, 1)) # store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 32, 1)).to_uop(), mean)) # verify_lazyop(store) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 32, 1)).to_uop())) squares = (second_x+neg_mean)*(second_x+neg_mean) - squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,))) + squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,))) variance = squares_sum * ast_const(dtypes.float, 0.03125, (3, 27, 1, 1)) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 1, 1)).to_uop(), variance)) sink = UOp(Ops.SINK, src=(store,)) @@ -532,13 +532,13 @@ class TestLinearizer(unittest.TestCase): x = Tensor.rand(4, 32).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop())) - max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.MAX, (2,))) + max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.MAX, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop())) centered_x = second_x+max_x*ast_const(dtypes.float, -1, (4, 32, 1)) - exp_x = centered_x.alu(UnaryOps.EXP2) - sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (BinaryOps.ADD, (1,))) - # y = exp_x * sum_exp_x.alu(UnaryOps.RECIP) # kernels cannot do a return to full shape - recip_sum_exp_x = sum_exp_x.alu(UnaryOps.RECIP) + exp_x = centered_x.alu(Ops.EXP2) + sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (Ops.ADD, (1,))) + # y = exp_x * sum_exp_x.alu(Ops.RECIP) # kernels cannot do a return to full shape + recip_sum_exp_x = sum_exp_x.alu(Ops.RECIP) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4,1,1)).to_uop(), recip_sum_exp_x)) sink = UOp(Ops.SINK, src=(store,)) expected = 1/np.exp2(x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1, keepdims=True).reshape(4,1,1) @@ -556,7 +556,7 @@ class TestLinearizer(unittest.TestCase): View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False))) arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384)) arange_axis = (3,) - arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis)) + arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (Ops.ADD, arange_axis)) output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape)) out = arange+ast_const(dtypes.int, -1, output_shape) store = UOp(Ops.STORE, src=(UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out)) @@ -573,7 +573,7 @@ class TestLinearizer(unittest.TestCase): # TODO: do this arange broadcast in the scheduler arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384)) arange_axis = (3,) - arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis)) + arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (Ops.ADD, arange_axis)) arange_out_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape)) arange = arange+ast_const(dtypes.int, -1, arange_out_shape) # p2: the indexing @@ -581,10 +581,10 @@ class TestLinearizer(unittest.TestCase): data1 = (g1, ShapeTracker.from_shape(dataset.shape).reshape((1, 16384, 256, 1)).expand(arange_out_shape).to_uop()) idxs = Tensor([0,3,5,6]).realize() data2 = (g2, ShapeTracker.from_shape((4,)+(1,)*(len(arange_out_shape)-1)).expand(arange_out_shape).to_uop()) - arange_eq = arange.alu(BinaryOps.CMPNE, UOp(Ops.LOAD, dtypes.int, data2)).alu(BinaryOps.CMPNE, ast_const(dtypes.bool, True, arange_out_shape)) + arange_eq = arange.alu(Ops.CMPNE, UOp(Ops.LOAD, dtypes.int, data2)).alu(Ops.CMPNE, ast_const(dtypes.bool, True, arange_out_shape)) reduce_input = UOp(Ops.LOAD, dataset.dtype, data1)*UOp(Ops.CAST, dataset.dtype.scalar(), src=(arange_eq,)) out_axis = (1,) - out = UOp(Ops.REDUCE_AXIS, reduce_input.dtype, (reduce_input,), (BinaryOps.ADD, out_axis)) + out = UOp(Ops.REDUCE_AXIS, reduce_input.dtype, (reduce_input,), (Ops.ADD, out_axis)) output_shape = tuple(1 if i in out_axis else s for i,s in enumerate(arange_out_shape)) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape(output_shape).to_uop(), out)) sink = UOp(Ops.SINK, src=(store,)) @@ -605,7 +605,7 @@ class TestLinearizer(unittest.TestCase): ast_const(dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), val=10), UOp(Ops.MUL, dtypes.int, arg=None, src=( ast_const(dtypes.int, -1, (1, 20, 1)), - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.MAX, (0,)), src=( UOp(Ops.MUL, dtypes.int, arg=None, src=( UOp(Ops.CAST, dtypes.int, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( @@ -618,7 +618,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501 ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501 UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=( ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False)))),)), # noqa E501 ast_const(dtypes.int, 10, (10, 20, 1)))),)),)),)),)), ast_const(dtypes.int, -1, (1, 20, 1)),)),)),)) @@ -637,7 +637,7 @@ class TestLinearizer(unittest.TestCase): ast_const(dtypes.int, 200, (1, 1)), UOp(Ops.MUL, dtypes.int, arg=None, src=( ast_const(dtypes.int, -1, (1, 1)), - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.MAX, (0,)), src=( UOp(Ops.MUL, dtypes.int, arg=None, src=( UOp(Ops.CAST, dtypes.int, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( @@ -650,7 +650,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501 ast_const(dtypes.bool, True, (200, 1)),)),)), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False)))),)), # noqa: E501 ast_const(dtypes.int, 200, (200, 1)),)),)),)),)),)), ast_const(dtypes.int, -1, (1, 1)),)),)),)) @@ -672,16 +672,16 @@ class TestLinearizer(unittest.TestCase): g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop())) x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop())) - r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (1,))) - r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(BinaryOps.ADD, (0,))) + r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (1,))) + r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(Ops.ADD, (0,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1)) sink = UOp(Ops.SINK, src=(store,)) helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=0, keepdims=True)).sum(axis=0).reshape(1,1,N)], opts=opts) x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop())) x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop())) - r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (2,))) - r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.ADD, (1,))) + r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (2,))) + r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1)) sink = UOp(Ops.SINK, src=(store,)) helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(N,1,1)], opts=opts) @@ -699,16 +699,16 @@ class TestLinearizer(unittest.TestCase): g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop())) x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop())) - r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (1,))) - r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (BinaryOps.MAX, (0,))) + r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (1,))) + r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (Ops.MAX, (0,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1)) sink = UOp(Ops.SINK, src=(store,)) helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=0, keepdims=True)).max(axis=0).reshape(1,1,N)], opts=opts) x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop())) x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop())) - r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (2,))) - r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.MAX, (1,))) + r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (2,))) + r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.MAX, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1)) sink = UOp(Ops.SINK, src=(store,)) helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=1, keepdims=True)).max(axis=1).reshape(N,1,1)], opts=opts) @@ -735,7 +735,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.5*N, (N, 1, 1)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), @@ -743,7 +743,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.75*N, (N, N, 1)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), ld0.to_uop(),)),)),)), @@ -768,7 +768,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.5*N, (1, 1, N)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -776,7 +776,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.75*N, (N, 1, N)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), ld0.to_uop(),)),)),)), @@ -804,7 +804,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.5*N, (1, 1, 1, 1)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), @@ -812,7 +812,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 3)), src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501 @@ -831,7 +831,7 @@ class TestLinearizer(unittest.TestCase): def test_end_local(self): g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=i) for i in range(2)] load = UOp(Ops.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop())) - reduce = UOp(Ops.REDUCE_AXIS, dtypes.int, (load,), (BinaryOps.ADD, (0,))) + reduce = UOp(Ops.REDUCE_AXIS, dtypes.int, (load,), (Ops.ADD, (0,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce)) sink = UOp(Ops.SINK, src=(store,)) load_t = Tensor.full(load.st_arg.shape, 1).contiguous().realize() @@ -1219,20 +1219,20 @@ class TestLinearizer(unittest.TestCase): assert len(sched) == 1 lin = Kernel(sched[0].ast) - assert sum(u.op in {UnaryOps.RECIP, BinaryOps.FDIV} for u in lin.linearize().uops) == max_ops, msg + assert sum(u.op in {Ops.RECIP, Ops.FDIV} for u in lin.linearize().uops) == max_ops, msg a = Tensor.empty((4,4)) b = Tensor.empty((4,4)) d = Tensor.empty((4,4)) c = (a*b)/b - helper(c, "found UnaryOps.RECIP in (a*b)/b operation") + helper(c, "found Ops.RECIP in (a*b)/b operation") c = a/a - helper(c, "found UnaryOps.RECIP in (a/a) operation") + helper(c, "found Ops.RECIP in (a/a) operation") c = (a/b)/d - helper(c, "found multiple UnaryOps.RECIP in (a/b)/d operation", 1) + helper(c, "found multiple Ops.RECIP in (a/b)/d operation", 1) def test_sum_collapse(self): t = Tensor([2]).reshape(1, 1).expand(256, 256).sum() @@ -1260,7 +1260,7 @@ class TestLinearizer(unittest.TestCase): lin = Kernel(sched_copy[-1].ast) lin.hand_coded_optimizations() lin.linearize() - assert not any(u.op == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded" + assert not any(u.op == Ops.WHERE for u in lin.uops), "found where where where should be folded" def test_phi_simplification(self): def helper(t, max_ops=0): @@ -1272,7 +1272,7 @@ class TestLinearizer(unittest.TestCase): assert len(set([u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}])) == 1, "has either specials or ranges, not both" assert len([u for u in uops if u.op is Ops.ASSIGN]) == 0, "ASSIGN should have been simplified" # TODO: once uops track min/max this will be fixed - #assert len([u for u in uops if u.op is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops" + #assert len([u for u in uops if u.op is Ops.MAX]) <= max_ops, "no unnecessary MAX ops" helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2) helper(Tensor.arange(-1, -100, -5), max_ops=2) @@ -1602,7 +1602,7 @@ class TestFloat4(unittest.TestCase): UOp(Ops.STORE, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501 - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.CAST, dtypes.float, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, src=( @@ -1632,7 +1632,7 @@ class TestFloat4(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501 UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), @@ -1662,7 +1662,7 @@ class TestFloat4(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501 UOp(Ops.CAST, dtypes.half, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (4, 6)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=( UOp(Ops.CAST, dtypes.float, src=( UOp(Ops.LOAD, dtypes.half, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1), @@ -1949,7 +1949,7 @@ class TestKernelOpts(unittest.TestCase): UOp(Ops.STORE, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( @@ -2138,7 +2138,7 @@ class TestKernelOpts(unittest.TestCase): g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] ld0 = UOp(Ops.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501 ld1 = UOp(Ops.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501 - store = UOp(Ops.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(Ops.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (BinaryOps.ADD, (0, 2, 4, 6)),))) # noqa: E501 + store = UOp(Ops.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(Ops.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (Ops.ADD, (0, 2, 4, 6)),))) # noqa: E501 sink = UOp(Ops.SINK, src=(store,)) data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize() data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize() diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index b0830f7004..d9e7a0e058 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -5,7 +5,7 @@ import unittest from test.helpers import ast_const from tinygrad import Device, dtypes -from tinygrad.ops import UOp, Ops, BinaryOps +from tinygrad.ops import UOp, Ops from tinygrad.helpers import getenv from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad.engine.search import Opt, OptOps @@ -21,7 +21,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.MAX, dtypes.half, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -64,7 +64,7 @@ class TestLinearizerDumb(unittest.TestCase): ast_const(dtypes.bool, True, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, 1000, st_src=( @@ -75,7 +75,7 @@ class TestLinearizerDumb(unittest.TestCase): for opt in opts: k.apply_opt(opt) prg = k.to_program() print(prg.src) - assert prg.uops is not None and not any(uop.op is BinaryOps.MAX for uop in prg.uops), "leftover MAX" + assert prg.uops is not None and not any(uop.op is Ops.MAX for uop in prg.uops), "leftover MAX" @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local") def test_expander_new_srcs(self): @@ -83,7 +83,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) @@ -105,14 +105,14 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.CAST, dtypes.half, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( @@ -136,7 +136,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( @@ -168,7 +168,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=( UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( @@ -200,7 +200,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 26dd7b23af..6dc6e4a97b 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -3,7 +3,7 @@ import unittest, random import numpy as np from tinygrad.codegen.kernel import Kernel, KernelOptError from tinygrad.device import is_dtype_supported -from tinygrad.ops import UOp, Ops, BinaryOps +from tinygrad.ops import UOp, Ops from tinygrad.engine.search import Opt, OptOps from tinygrad import Device, dtypes, Tensor from tinygrad.helpers import CI @@ -47,7 +47,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 16), strides=(16, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), @@ -64,7 +64,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 37, 9, 1, 1), strides=(666, 333, 9, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (4, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (4, 5)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) @@ -76,7 +76,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 1), strides=(128, 16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) @@ -89,7 +89,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( @@ -111,7 +111,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, 10, st_src=( @@ -125,7 +125,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 1, 34, 1, 34), strides=(36992, 1156, 0, 34, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 4)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 4)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 6, 8, 4, 6, 8, 4), strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 8), (0, 1), (0, 6), (0, 8), (0, 1)), contiguous=False), View(shape=(512, 32, 6, 35, 6, 35), strides=(1179648, 36864, 6144, 192, 32, 1), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 32), (0, 6), (0, 32)), contiguous=False), View(shape=(512, 32, 238, 238), strides=(1411200, 44100, 210, 1), offset=0, mask=((0, 512), (0, 32), (0, 210), (0, 210)), contiguous=False), View(shape=(512, 32, 7, 34, 7, 34), strides=(1812608, 56644, 8092, 238, 34, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)) @@ -142,7 +142,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.RECIP, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( x9:=UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( @@ -166,7 +166,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 0, 0, 4500, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -183,7 +183,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.ADD, (3,)), src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), @@ -202,7 +202,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.RECIP, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( @@ -277,7 +277,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( @@ -299,7 +299,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 8)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=( x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=( @@ -313,7 +313,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x6,)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 8)), src=( x5,)),)),)),)),)) opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -325,7 +325,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.ADD, (3,)), src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), @@ -344,7 +344,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( @@ -370,7 +370,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -405,7 +405,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), @@ -420,7 +420,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 1, 28, 28, 1, 1), strides=(31360, 0, 784, 0, 28, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -442,7 +442,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), @@ -462,7 +462,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 9, 7, 3, 3), strides=(2268, 0, 567, 0, 63, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -511,7 +511,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( @@ -624,7 +624,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1025, 2047), strides=(0, 0), offset=0, mask=((0, 1025), (1023, 2047)), contiguous=False), View(shape=(1024, 1024), strides=(1, 2048), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( @@ -639,7 +639,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(129, 255), strides=(0, 0), offset=0, mask=((0, 129), (127, 255)), contiguous=False), View(shape=(128, 128), strides=(1, 256), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( @@ -678,7 +678,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.MAX, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.MAX, (3,)), src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) @@ -731,7 +731,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -749,7 +749,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 1, 1, 1), strides=(11532, 0, 961, 31, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -767,7 +767,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.EXP2, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( @@ -791,7 +791,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 1, 1, 1), strides=(50176, 0, 196, 14, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -809,7 +809,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( x5:=UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -851,7 +851,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(180, 0, 30, 3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.MAX, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -875,7 +875,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.uchar, arg=None, src=( UOp(Ops.ADD, dtypes.uint, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.uint, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.uint, arg=(Ops.ADD, (1,)), src=( UOp(Ops.CAST, dtypes.uint, arg=None, src=( ast_const(dtypes.uchar, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))), src=()),)),)),)), @@ -895,7 +895,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.MAX, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( @@ -920,7 +920,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 32, 1, 1, 1, 5, 5, 256), strides=(0, 0, 6400, 0, 0, 0, 1280, 256, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 3, 4)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 3, 4)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( @@ -943,7 +943,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.MAX, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( @@ -969,7 +969,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( @@ -987,7 +987,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 1, 1, 1), strides=(100352, 0, 784, 28, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -1007,7 +1007,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) @@ -1021,7 +1021,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) @@ -1035,7 +1035,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) @@ -1052,7 +1052,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 1, 1, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 3)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -1065,7 +1065,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (4,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 3), strides=(0, 0), offset=0, mask=((0, 3), (1, 3)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 1, 0, 4), offset=0, mask=((0, 2), (0, 3), (0, 2), (0, 3), (0, 2)), contiguous=False))), src=()),)),)), x19:=ast_const(dtypes.int, -1, st_src=( @@ -1078,7 +1078,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=3, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (4,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5), strides=(0, 0), offset=0, mask=((0, 4), (2, 5)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 0, 1, 6), offset=0, mask=None, contiguous=False))), src=()),)),)), x19,)),)), @@ -1093,7 +1093,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.bool, arg=None, src=( @@ -1127,7 +1127,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( @@ -1142,7 +1142,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 1, 1, 256, 1, 1, 256), strides=(0, 0, 65536, 0, 0, 256, 0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3, 4)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3, 4)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -1160,7 +1160,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 1), strides=(6, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -1178,7 +1178,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.bool, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.bool, arg=(Ops.ADD, (3,)), src=( UOp(Ops.MUL, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.bool, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=1, src=()), @@ -1220,7 +1220,7 @@ class TestLinearizerFailures(unittest.TestCase): x9,)), UOp(Ops.ADD, dtypes.half, arg=None, src=( UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -1249,7 +1249,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -1267,7 +1267,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.uchar, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.uchar, arg=(Ops.ADD, (1,)), src=( UOp(Ops.MUL, dtypes.uchar, arg=None, src=( UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), @@ -1279,7 +1279,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=( UOp(Ops.WHERE, dtypes.int, arg=None, src=( UOp(Ops.VALID, dtypes.bool, arg=None, src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(50001, 99999), strides=(0, 0), offset=0, mask=((0, 50001), (49999, 99999)), contiguous=False), View(shape=(1024, 50000, 50000), strides=(0, 1, 100000), offset=0, mask=None, contiguous=False))), src=()),)), @@ -1306,7 +1306,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -1326,7 +1326,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(W, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( diff --git a/test/test_linearizer_overflows.py b/test/test_linearizer_overflows.py index 5ff6e0ca74..2e7265f652 100644 --- a/test/test_linearizer_overflows.py +++ b/test/test_linearizer_overflows.py @@ -8,7 +8,7 @@ from tinygrad.engine.search import Opt, OptOps from tinygrad.engine.search import time_linearizer, bufs_from_lin # stuff needed to unpack a kernel -from tinygrad.ops import UOp, Ops, BinaryOps +from tinygrad.ops import UOp, Ops from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View @@ -33,7 +33,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -68,7 +68,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.STORE, None, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -85,7 +85,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.STORE, None, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -102,7 +102,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.STORE, None, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -119,7 +119,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.STORE, None, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -136,7 +136,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.STORE, None, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -153,7 +153,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.STORE, None, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -175,7 +175,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase): in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop() ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop() prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2)) - store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5))))) + store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5))))) ast = UOp(Ops.SINK, src=(store,)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)] _test_overflow(ast, opts) @@ -187,7 +187,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase): in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop() ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop() prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2)) - store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5))))) + store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5))))) ast = UOp(Ops.SINK, src=(store,)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=5, amt=2)] _test_overflow(ast, opts) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 5b1704c364..58ee41ff37 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -1,7 +1,7 @@ import unittest, functools, random from typing import List from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes -from tinygrad.ops import MetaOps, BinaryOps, Ops +from tinygrad.ops import Ops from tinygrad.helpers import CI, getenv, prod, Context from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.engine.schedule import create_schedule @@ -481,7 +481,7 @@ class TestMultiTensor(unittest.TestCase): for p in get_parameters(bn): p.shard_(devices_4).realize() out = bn(t) - scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices_4 and sched.ast.op is not MetaOps.COPY] + scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices_4 and sched.ast.op is not Ops.COPY] assert set(out.device for sched in scheds for out in sched.outputs) == set(devices_4), "should have ast on each shard device" asts = [sched.ast for sched in scheds] assert len(asts) @@ -640,21 +640,21 @@ class TestMultiTensor(unittest.TestCase): for si in t.schedule(): ast = si.ast.src[0] assert ast.op is Ops.STORE - assert ast.src[2].op is BinaryOps.ADD + assert ast.src[2].op is Ops.ADD assert ast.src[2].src[0].op is Ops.LOAD assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1 t = 2 * t for si in t.schedule(): ast = si.ast.src[0] assert ast.op is Ops.STORE - assert ast.src[2].op is BinaryOps.MUL + assert ast.src[2].op is Ops.MUL assert ast.src[2].src[0].src[1].op is Ops.CONST and ast.src[2].src[0].src[1].arg == 2 assert ast.src[2].src[1].op is Ops.LOAD t = t + t.full_like(3) for si in t.schedule(): ast = si.ast.src[0] assert ast.op is Ops.STORE - assert ast.src[2].op is BinaryOps.ADD + assert ast.src[2].op is Ops.ADD assert ast.src[2].src[0].op is Ops.LOAD assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3 diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index d997cb3f37..5fa5a4b73b 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -8,7 +8,7 @@ from tinygrad.dtype import dtypes from tinygrad.engine.realize import CompiledRunner from tinygrad.helpers import dedup, flatten, prod from tinygrad.renderer.cstyle import CStyleLanguage -from tinygrad.ops import BinaryOps, UOp, Ops +from tinygrad.ops import UOp, Ops from tinygrad.renderer import Program from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.engine.lazy import LazyBuffer @@ -34,7 +34,7 @@ class TestCStyleFailures(unittest.TestCase): b = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1) idx = UOp.const(dtypes.int, 0) ld = UOp(Ops.LOAD, dtypes.int, (b.index(idx),)) - alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1)) + alu = ld.alu(Ops.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1)) store = UOp.store(a.index(idx), alu) sink = UOp(Ops.SINK, dtypes.void, (store,)) uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer)) diff --git a/test/test_schedule.py b/test/test_schedule.py index 6347bbd968..2925e25d1a 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -13,7 +13,7 @@ from tinygrad.device import is_dtype_supported from tinygrad.dtype import DType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View -from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites +from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context from tinygrad.codegen.kernel import Kernel, verify_ast from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, view_left @@ -1042,7 +1042,7 @@ class TestSchedule(unittest.TestCase): b = r.sum(0) * 4 c = r.sum(1) * 2 schedule = check_schedule([b, c], 3) - self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD) + self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD) # multireduce spec def test_multireduce_simple_chase(self): @@ -1053,7 +1053,7 @@ class TestSchedule(unittest.TestCase): c = r.sum(1) + 12 np_r = (a.numpy() + (a.numpy().sum(0) + 6)).sum(0) * 2 # schedule = check_schedule([b,c], 3) - # self.assertIs(schedule[0].ast[0].src[0].arg, BinaryOps.MUL) + # self.assertIs(schedule[0].ast[0].src[0].arg, Ops.MUL) schedule = check_schedule([b,c], 4) run_schedule(schedule) np.testing.assert_allclose(b.numpy(), np_r.sum(0) + 8, atol=1e-4, rtol=1e-4) @@ -1066,7 +1066,7 @@ class TestSchedule(unittest.TestCase): d = r.T * 4 e = r * d schedule = check_schedule([d, e], 3) - self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD) + self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD) # multireduce spec def test_multireduce_push_permute_chase(self): @@ -1077,7 +1077,7 @@ class TestSchedule(unittest.TestCase): d = r.T * 4 e = r * (d + a).sum(2) schedule = check_schedule([d, e], 3) # make sure it doesn't fuse - self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD) + self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD) run_schedule(schedule) np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4) @@ -1089,7 +1089,7 @@ class TestSchedule(unittest.TestCase): r = a.sum(1) + c d = r[:4] * b schedule = check_schedule(d, 2) - self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD) + self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD) # multireduce spec def test_multireduce_push_shrink_chase(self): @@ -1102,7 +1102,7 @@ class TestSchedule(unittest.TestCase): out = r[:4] * b + d.sum(1)[:4] # schedule = check_schedule(out, 2) schedule = check_schedule(out, 3) - self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD) + self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD) run_schedule(schedule) np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4) @@ -1287,16 +1287,16 @@ class TestSchedule(unittest.TestCase): @unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported") def test_bitcast_subbufer(self): x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata) - a = x.alu(UnaryOps.EXP2).cast(dtypes.int32, True, allow_buffer_view=True) + a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=True) b = x.cast(dtypes.int32, True, allow_buffer_view=True) - b = a.alu(BinaryOps.ADD, b) + b = a.alu(Ops.ADD, b) check_schedule(b, 2) # this should fuse when it makes sense def test_bitcast_disable_subbufer(self): x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata) - a = x.alu(UnaryOps.EXP2).cast(dtypes.int32, True, allow_buffer_view=False) + a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=False) b = x.cast(dtypes.int32, True, allow_buffer_view=False) - b = a.alu(BinaryOps.ADD, b) + b = a.alu(Ops.ADD, b) check_schedule(b, 1) def test_reduceop_reshape_dont_push(self): @@ -1530,7 +1530,7 @@ class TestIndexing(unittest.TestCase): def test_arange_view_op(self): a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).contiguous() assert isinstance(a.lazydata, LazyBuffer) - self.assertIs(a.lazydata.base.op, MetaOps.BUFFER_VIEW) + self.assertIs(a.lazydata.base.op, Ops.BUFFER_VIEW) self.check_schedule(a, 1) np.testing.assert_equal(a.numpy(), [[4, 5]]) @@ -1538,7 +1538,7 @@ class TestIndexing(unittest.TestCase): def test_arange_shrink_copy(self): a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).to("CLANG") assert isinstance(a.lazydata, LazyBuffer) - self.assertIs(a.lazydata.base.op, MetaOps.COPY) + self.assertIs(a.lazydata.base.op, Ops.COPY) self.check_schedule(a, 1) np.testing.assert_equal(a.numpy(), [[4, 5]]) @@ -1546,8 +1546,8 @@ class TestIndexing(unittest.TestCase): def test_arange_expand_copy(self): a = Tensor.arange(4).reshape(2, 2, 1).expand(2, 2, 2).to("CLANG") assert isinstance(a.lazydata, LazyBuffer) - self.assertIs(a.lazydata.base.op, MetaOps.COPY) - self.assertIs(a.lazydata.base.srcs[0].base.op, BinaryOps.ADD) + self.assertIs(a.lazydata.base.op, Ops.COPY) + self.assertIs(a.lazydata.base.srcs[0].base.op, Ops.ADD) self.check_schedule(a, 1) np.testing.assert_equal(a.numpy(), [[[0, 0], [1, 1]], [[2, 2], [3, 3]]]) @@ -1643,7 +1643,7 @@ class TestIndexing(unittest.TestCase): def test_simple_store_reshape(self): bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) - r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) + r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1))) r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(())) r = r + 2 sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),)) @@ -1655,7 +1655,7 @@ class TestIndexing(unittest.TestCase): def test_no_reshape_reduceop(self): bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) - r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) + r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1))) sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),)) rsink = graph_rewrite(sink, view_right) verify_ast(sink) @@ -1698,7 +1698,7 @@ class TestSwizzle(unittest.TestCase): # LazyBuffer to pre-rewrite AST bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop())) - r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,))) + r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0,))) swizzle_r = UOp(Ops.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(())) alu = swizzle_r+1 sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),)) @@ -1720,9 +1720,9 @@ class TestSwizzle(unittest.TestCase): # LazyBuffer to pre-rewrite AST bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)] ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop())) - r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,))) + r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (Ops.ADD, (0,))) ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop())) - r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,))) + r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (Ops.ADD, (0,))) alu = UOp(Ops.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(Ops.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(())) sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+2,),),)) # noqa: E501 # graph rewrite @@ -1736,7 +1736,7 @@ class TestSwizzle(unittest.TestCase): def test_swizzle_rewrite_alt(self): swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501 - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501 diff --git a/test/test_search.py b/test/test_search.py index b91d003e54..a9e6d4e6f0 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -3,7 +3,7 @@ import unittest from test.helpers import ast_const from tinygrad.codegen.kernel import Opt, OptOps from tinygrad.codegen.kernel import Kernel -from tinygrad.ops import UOp, Ops, BinaryOps +from tinygrad.ops import UOp, Ops from tinygrad.engine.schedule import create_schedule from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search from tinygrad.device import Device, Buffer @@ -107,7 +107,7 @@ class TestBEAM(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (1,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 502c0df635..85396d7b1a 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -2,8 +2,7 @@ from typing import List import unittest, time from tinygrad import dtypes, Device from tinygrad.helpers import DEBUG -from tinygrad.ops import BinaryOps, Ops, UOp, KernelInfo -from tinygrad.ops import UPat, PatternMatcher +from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher from tinygrad.renderer import Renderer from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, sym @@ -541,7 +540,7 @@ class TestExpander(unittest.TestCase): @unittest.skip("no longer supported") def test_reduce_known_axis(self): e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),)) - sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), BinaryOps.ADD) + sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), Ops.ADD) sink = expander_rewrite(sink) assert sink.op is Ops.CONST self.assertEqual(sink.arg, 3*(0+1+2+3)) @@ -549,7 +548,7 @@ class TestExpander(unittest.TestCase): @unittest.skip("no longer supported") def test_reduce_const(self): e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),)) - sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), BinaryOps.ADD) + sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), Ops.ADD) sink = expander_rewrite(sink) assert sink.op is Ops.CONST self.assertEqual(sink.arg, 3*4) @@ -590,7 +589,7 @@ class TestExpander(unittest.TestCase): def test_reduce_different_axis(self): e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),)) e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),)) - sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), BinaryOps.ADD) + sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), Ops.ADD) sink = expander_rewrite(sink) print(sink) diff --git a/test/test_uops.py b/test/test_uops.py index 13dbdb7d2b..abb9eae35e 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.helpers import CI, DEBUG, getenv, Context from tinygrad.dtype import dtypes, DType from tinygrad.device import Buffer, Device -from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, KernelInfo, exec_alu, spec # noqa F401 +from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu, spec # noqa F401 from tinygrad.renderer import Program from tinygrad.engine.schedule import create_schedule, to_si from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel @@ -29,7 +29,7 @@ def uop(uops:List[UOp], uop:Ops, dtype:Optional[DType], src:Tuple[UOp, ...], arg def _test_single_value(vals, op, dts): uops = [] - output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1] + output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1] buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)] loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, dtype in enumerate(dts)) @@ -45,7 +45,7 @@ def _test_single_value(vals, op, dts): def _test_single_value_const(vals, op, dts): uops = [] - output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1] + output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1] buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts)) alu = uop(uops, op, output_dtype, loads) @@ -103,49 +103,49 @@ class TestUOps(unittest.TestCase): class TestFloatUOps(TestUOps): @unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop') - def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a)) + def test_exp2(self): self._test_uop_fxn(Ops.EXP2, lambda a: np.exp2(a)) @unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop') - def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan')) + def test_log2(self): self._test_uop_fxn(Ops.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan')) @unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop') - def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a)) - def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1/a if a != 0 else float('inf')) - def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan')) + def test_sin(self): self._test_uop_fxn(Ops.SIN, lambda a: math.sin(a)) + def test_recip(self): self._test_uop_fxn(Ops.RECIP, lambda a: 1/a if a != 0 else float('inf')) + def test_sqrt(self): self._test_uop_fxn(Ops.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan')) - def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b) - def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b) - def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b)) - def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True) + def test_shr_int32(self): self._test_bop_fxn(Ops.SHR, lambda a,b: int(a)>>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True) @unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts") - def test_shl_int32(self): self._test_bop_fxn(BinaryOps.SHL, lambda a,b: int(a)< ALU/CONST rewrite is now instant """ - matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)]) - c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD) - c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD) + matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=(UPat(Ops.CONST), UPat(GroupOp.ALU))), lambda x: x)]) + c4 = UOp(Ops.ADD, dtypes.float, (c1,c3)) + c5 = UOp(Ops.ADD, dtypes.float, (c3,c1)) self.assertEqual(matcher.rewrite(c3), None) self.assertEqual(matcher.rewrite(c4), c4) self.assertEqual(matcher.rewrite(c5), None) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 547d352979..6565dc310a 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -5,7 +5,7 @@ from collections import defaultdict from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence from enum import Enum, auto -from tinygrad.ops import GroupOp, BinaryOps, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \ +from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \ graph_rewrite, track_rewrites from tinygrad.device import Device from tinygrad.renderer import Renderer, TensorCore, Program @@ -276,7 +276,7 @@ class Kernel: if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0] - if mul_op.op is not BinaryOps.MUL: return None + if mul_op.op is not Ops.MUL: return None def buf_index(src:UOp) -> Optional[int]: # TODO: apply tc even if the sources are not from LOAD @@ -303,7 +303,7 @@ class Kernel: return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads) def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool: - if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD: + if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD: for tc in self.opts.tensor_cores: tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops] # can only fuse reduces with the same tc options @@ -338,7 +338,7 @@ class Kernel: 2: apply tensor core shape but don't use UOp.WMMA extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None) tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise) - 0: applies to only kernels with a single reduce axis and direct UOps.LOAD into BinaryOps.MUL + 0: applies to only kernels with a single reduce axis and direct UOps.LOAD into Ops.MUL 1: allows kernels with multiple reduce axes and also multiplication of UOps.CAST'd buffers 2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed """ @@ -441,7 +441,7 @@ class Kernel: check(not self.vars, "does not work with symbolic shape") check(axis < self.first_upcast, "cannot pad upcasted") # ok to pad SUM if all parent ALU ops have f(0) = 0 - if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is BinaryOps.ADD and can_pad(r), f"cannot pad {r}") + if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r), f"cannot pad {r}") padded = False for i,st in enumerate(self.sts): if (s:=st.shape[axis]) == 1: continue # reduced @@ -470,8 +470,8 @@ class Kernel: # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4) if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \ - self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \ - (mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD: + self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \ + (mulop:=self.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD: st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])] strides0, strides1 = st0.real_strides(), st1.real_strides() def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides)) @@ -663,7 +663,7 @@ class Kernel: # for TC=2, we can't do the shapetracker fixup srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])] # MUL/SUM instead of WMMA - ret = UOp(Ops.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1])) + ret = UOp(Ops.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(Ops.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1])) else: # real WMMA, use CONTRACT/EXPAND to get the vectorization right wmma_upcast_axes = wmma_arg[-2] diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index 42cb2abcf5..4f9aeff51a 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -169,7 +169,7 @@ def sin_poly_large(d:UOp, q:UOp) -> UOp: def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp: """ - Implements a 1.0 ULP approximation for UnaryOps.SIN. + Implements a 1.0 ULP approximation for Ops.SIN. - fast=True assumes x <= switch_over. - switch_over is the threshold for switching to payne_hanek_reduction. """ @@ -192,7 +192,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp: def xexp2(d:UOp) -> UOp: """ - Implements a 1.0 ULP approximation for UnaryOps.EXP2 + Implements a 1.0 ULP approximation for Ops.EXP2 - Paper: https://arxiv.org/pdf/2001.09258 """ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES @@ -218,7 +218,7 @@ def xexp2(d:UOp) -> UOp: def xlog2(d:UOp) -> UOp: """ - Implements a 1.0 ULP approximation for UnaryOps.LOG2 + Implements a 1.0 ULP approximation for Ops.LOG2 Paper: https://arxiv.org/pdf/2001.09258 5.5 """ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 4113a0169c..610c36e5ad 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -3,7 +3,7 @@ from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict, import functools, itertools, operator from collections import defaultdict from tinygrad.dtype import dtypes, ImageDType, PtrDType -from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple +from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES @@ -23,7 +23,7 @@ def fold_expanded(ex, buf): for i,s in enumerate(new_srcs): idx = s.src[0].src[1] if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue - if idx.op is BinaryOps.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg + if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg else: root_src, arg = idx, 0 # add gates for gated @@ -92,12 +92,12 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]: # can drop valid if idx is out of bound when valid is False drop_stmt = [] - for stmt in split_uop(valid, BinaryOps.AND): + for stmt in split_uop(valid, Ops.AND): X, is_upper_bound, c = parse_valid(stmt) # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i - if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, BinaryOps.ADD)): - testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, BinaryOps.ADD), idx) + if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)): + testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx) testidx = testidx.simplify() if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0: drop_stmt.append(stmt) @@ -114,7 +114,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]: break if not drop_stmt and idx is start_idx: return None - new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, BinaryOps.AND) if s not in drop_stmt]) else None + new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None return buf.index(idx, new_valid) # ***** optional patterns ***** @@ -123,23 +123,23 @@ powers_of_two = {2**i:i for i in range(64)} @functools.lru_cache(None) def get_late_rewrite_patterns(ops, force_transcendental=False): pat: List[Tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \ - ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental] + ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental] # rewrite MOD to AND (which should always be supported, but not for generic in tests) - if BinaryOps.AND in ops: + if Ops.AND in ops: pat += [(UPat(Ops.MOD, src=(UPat.var('base'), UPat.cvar("const"))), lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)] # rewrite MUL/IDIV to SHL+SHR - if BinaryOps.SHL in ops and BinaryOps.SHR in ops: + if Ops.SHL in ops and Ops.SHR in ops: pat += [ (UPat(Ops.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const: mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y) (UPat(Ops.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const: div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y) - if UnaryOps.NEG in ops: - pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))] - if BinaryOps.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))] - if TernaryOps.MULACC in ops: - pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c))] + if Ops.NEG in ops: + pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))] + if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))] + if Ops.MULACC in ops: + pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))] return PatternMatcher(pat) # ***** threefry ***** @@ -225,7 +225,7 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp): if len(reduce_unparented) == 0: return None new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented)) ret = new_acc.assign(new_acc.alu(alu.op, ret)) - if alu.op is BinaryOps.ADD: + if alu.op is Ops.ADD: for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count) return ret diff --git a/tinygrad/engine/lazy.py b/tinygrad/engine/lazy.py index 6fc0675b5b..19430156ab 100644 --- a/tinygrad/engine/lazy.py +++ b/tinygrad/engine/lazy.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Optional, Any, Tuple, List, get_args from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE -from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, exec_alu, python_alu +from tinygrad.ops import exec_alu, python_alu from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.device import Buffer @@ -11,9 +11,9 @@ from weakref import ref, ReferenceType, WeakValueDictionary lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary() def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(), base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)): - if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None + if st.size == 0: op, arg, srcs, base = Ops.CONST, 0, (), None dtype = to_dtype(dtype) - if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True + if op is Ops.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base)) if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret @@ -32,13 +32,13 @@ class LazyBuffer(MathTrait): if base is None: # properties on base self.op, self.arg, self.srcs = op, arg, srcs # this is a UOp, except the src is LazyBuffers and not UOps - assert self.op is not MetaOps.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized" + assert self.op is not Ops.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized" - if self.op is MetaOps.BUFFER_VIEW: + if self.op is Ops.BUFFER_VIEW: # some LazyBuffers can be processed with only a view, no AST required self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize) else: - self.buffer = srcs[0].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, self.dtype) + self.buffer = srcs[0].base.buffer if self.op is Ops.ASSIGN else Buffer(device, self.size, self.dtype) self.buffer.ref(1) self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None self.forced_realize = False @@ -74,14 +74,14 @@ class LazyBuffer(MathTrait): def const_like(self, b): return self.const_with_shape(b, self.shape) def const_with_shape(self, val:ConstType, shape:Tuple[sint,...]) -> LazyBuffer: assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType" - return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape) + return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape) def is_realized(self) -> bool: return self.base.realized is not None def assign(self, x:LazyBuffer) -> LazyBuffer: assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}" assert self.is_realized(), f"assign target must be realized {self}" - return LazyBuffer.metaop(MetaOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), + return LazyBuffer.metaop(Ops.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(self.base, x), enable_cache=True) def can_view(self): @@ -90,7 +90,7 @@ class LazyBuffer(MathTrait): def contiguous(self, allow_buffer_view=True): if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const(): - ret = self.alu(MetaOps.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(MetaOps.CONTIGUOUS) + ret = self.alu(Ops.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(Ops.CONTIGUOUS) if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti return ret self.base.forced_realize = True @@ -101,7 +101,7 @@ class LazyBuffer(MathTrait): if self.dtype == dtype: return self if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)") if self.is_unrealized_unmasked_const() and not bitcast: - return create_lazybuffer(self.device, self.st, dtype, MetaOps.CONST, dtypes.as_const(self.base.arg, dtype)) + return create_lazybuffer(self.device, self.st, dtype, Ops.CONST, dtypes.as_const(self.base.arg, dtype)) new_shape = self.shape if bitcast and self.dtype.itemsize != dtype.itemsize: if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now") @@ -112,27 +112,27 @@ class LazyBuffer(MathTrait): elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base: # TODO: applying this makes gpt2 slower return self.base.cast(dtype, bitcast)._view(self.st) - cast_op: Ops = (MetaOps.BUFFER_VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST + cast_op: Ops = (Ops.BUFFER_VIEW if self.can_view() and allow_buffer_view else Ops.BITCAST) if bitcast else Ops.CAST return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,)) - def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, UOp) + def is_unrealized_const(self): return self.base.realized is None and self.base.op is Ops.CONST and not isinstance(self.base.arg, UOp) def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views) def _copy(self, device:str) -> LazyBuffer: assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}" - return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, MetaOps.COPY, self.buffer.nbytes, (self,), enable_cache=False) + return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False) def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer: # no COPY if self.device == device and not clone: return self # double COPY = one COPY - if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is MetaOps.COPY: + if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is Ops.COPY: return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape) # const doesn't have to be copied (issues with disk tensor) if self.is_unrealized_const(): - return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st) + return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st) # if it's a shrink, do the shrink before the copy with CONTIGUOUS if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device) @@ -149,25 +149,25 @@ class LazyBuffer(MathTrait): srcs.append(root._view(s.base.contiguous_child[1])) else: srcs.append(s) - if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]): + if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is Ops.WHERE else srcs)]): raise AssertionError(f"all dtypes must match {dts} on {op}") assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}" - if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool" + if op is Ops.WHERE: assert srcs[0].dtype == dtypes.bool, "Ops.WHERE must have the first arg be bool" - out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else srcs[-1].dtype + out_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else srcs[-1].dtype # const folding if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs): return self.cast(out_dtype).const_like(exec_alu(op, out_dtype, [s.base.arg for s in srcs])) if op in GroupOp.Binary: x, y = self, in_srcs[0] - if op is BinaryOps.ADD: + if op is Ops.ADD: if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y - if op is BinaryOps.MUL: + if op is Ops.MUL: if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const_like(0) if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const_like(0) - if op is BinaryOps.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x + if op is Ops.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, None, tuple(srcs)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index c90571448d..252f66b6f9 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -29,14 +29,14 @@ class SimpleMathTrait: dtype: Optional[DType] = getattr(self, 'dtype', None) assert dtype is not None, "MathTraits __neg__ requires a dtype" return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1) - def add(self, x, reverse=False): return self._binop(BinaryOps.ADD, x, reverse) - def mul(self, x, reverse=False): return self._binop(BinaryOps.MUL, x, reverse) - def bitwise_and(self, x, reverse=False): return self._binop(BinaryOps.AND, x, reverse) - def bitwise_or(self, x, reverse=False): return self._binop(BinaryOps.OR, x, reverse) - def xor(self, x, reverse=False): return self._binop(BinaryOps.XOR, x, reverse) - def idiv(self, x, reverse=False): return self._binop(BinaryOps.IDIV, x, reverse) - def sub(self, x, reverse=False): return self.ufix(x).alu(BinaryOps.ADD, -self) if reverse else self.alu(BinaryOps.ADD, self.ufix(-x)) - def div(self, x, reverse=False): return (self.ufix(x)*self.alu(UnaryOps.RECIP)) if reverse else (self*self.ufix(x).alu(UnaryOps.RECIP)) + def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse) + def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse) + def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse) + def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse) + def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse) + def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse) + def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x)) + def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP)) def __neg__(self): return self.neg() @@ -58,9 +58,9 @@ class SimpleMathTrait: def __ror__(self, x): return self.bitwise_or(x, True) def __rxor__(self, x): return self.xor(x, True) - def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x)) - def gt(self, x): return self.ufix(x).alu(BinaryOps.CMPLT, self) - def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x)) + def lt(self, x): return self.alu(Ops.CMPLT, self.ufix(x)) + def gt(self, x): return self.ufix(x).alu(Ops.CMPLT, self) + def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x)) def ge(self, x): return self.lt(x).logical_not() def le(self, x): return self.gt(x).logical_not() def eq(self, x): return self.ne(x).logical_not() @@ -74,26 +74,26 @@ class SimpleMathTrait: class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method # TODO: move to Tensor when new backward is done - def lshift(self, x, reverse=False): return self._binop(BinaryOps.SHL, x, reverse) - def rshift(self, x, reverse=False): return self._binop(BinaryOps.SHR, x, reverse) + def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse) + def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse) def __lshift__(self, x): return self.lshift(x) def __rshift__(self, x): return self.rshift(x) def __rlshift__(self, x): return self.lshift(x, True) def __rrshift__(self, x): return self.rshift(x, True) # not in Tensor - def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x)) - def __rmod__(self, x): return self.ufix(x).alu(BinaryOps.MOD, self) + def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x)) + def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self) - def maximum(self, x): return self.alu(BinaryOps.MAX, self.ufix(x)) + def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x)) def minimum(self, x): return -(-self).maximum(-x) - def where(self, x, y): return self.alu(TernaryOps.WHERE, x, x.ufix(y)) - def threefry(self, seed): return self.alu(BinaryOps.THREEFRY, seed) - def reciprocal(self): return self.alu(UnaryOps.RECIP) - def sqrt(self): return self.alu(UnaryOps.SQRT) - def sin(self): return self.alu(UnaryOps.SIN) - def log2(self): return self.alu(UnaryOps.LOG2) - def exp2(self): return self.alu(UnaryOps.EXP2) + def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y)) + def threefry(self, seed): return self.alu(Ops.THREEFRY, seed) + def reciprocal(self): return self.alu(Ops.RECIP) + def sqrt(self): return self.alu(Ops.SQRT) + def sin(self): return self.alu(Ops.SIN) + def log2(self): return self.alu(Ops.LOG2) + def exp2(self): return self.alu(Ops.EXP2) # the order of these Ops controls the order of the toposort class Ops(FastEnum): @@ -182,11 +182,8 @@ class GroupOp: # do not preserve f(0) = 0 UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV} -# TODO: remove this? -UnaryOps = BinaryOps = MetaOps = TernaryOps = Ops - # https://en.wikipedia.org/wiki/Identity_element -def identity_element(op:Ops, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt) +def identity_element(op:Ops, dt:DType): return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt) def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents) @@ -325,7 +322,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs) def alu(self, arg, *src:UOp): out_dtype = (self, *src)[-1].dtype - if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None: + if arg in {Ops.CMPLT, Ops.CMPNE} and out_dtype is not None: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool return UOp(arg, out_dtype, (self,)+src) @staticmethod @@ -392,15 +389,15 @@ class UOp(MathTrait, metaclass=UOpMetaClass): """largest known int that divides self""" if self.op is Ops.CONST: return self.arg if self.op is Ops.VCONST: return math.gcd(*self.arg) - if self.op is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor()) - if self.op is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1 + if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor()) + if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1 return 1 def divides(self, v) -> Optional[UOp]: if v==1: return self if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None - if self.op is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None - if self.op is BinaryOps.MUL: + if self.op is Ops.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None + if self.op is Ops.MUL: if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1] if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1 return None # generic None if we aren't sure @@ -412,18 +409,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def _min_max(self) -> Tuple[ConstType, ConstType]: if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype): (s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max - if self.op is BinaryOps.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax - if self.op is BinaryOps.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals) - if self.op is BinaryOps.MOD and s1_vmin > 0: return 0, s1_vmax-1 - if self.op is BinaryOps.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST + if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax + if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals) + if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1 + if self.op is Ops.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin) - if self.op is BinaryOps.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax) - if self.op is BinaryOps.CMPLT: return (s0_vmax 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x), - UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x), - UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, - UnaryOps.NEG: operator.neg, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, - BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf, - BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt, BinaryOps.XOR: operator.xor, - BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_, BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, - TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z} + Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: hook_overflow(math.inf, lambda x: 2**x), + Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x), + Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, + Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, + Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf, + Ops.MAX: max, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.XOR: operator.xor, + Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, + Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z} def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True): if dtype.count > 1: @@ -515,7 +512,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]: elif u.op is Ops.STORE: mem += u.src[1].dtype.itemsize * mults elif u.op in GroupOp.ALU and u not in dont_count: - flops += (mults * (2 if u.op is TernaryOps.MULACC else 1)) * u.dtype.count + flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults return flops, mem @@ -857,7 +854,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]: if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c remainder, something_changed = [], False - for u in split_uop(x, BinaryOps.ADD): + for u in split_uop(x, Ops.ADD): if (factor:=u.const_factor())%c != factor: divides = u.divides(factor)*(factor%c) assert divides is not None @@ -877,7 +874,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]: if 0 <= x.vmin and x.vmax < c: return x.const_like(0) quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1 - for u in split_uop(x, BinaryOps.ADD): + for u in split_uop(x, Ops.ADD): if u.op is Ops.CONST: # add all const together first if rem_const != 0: something_changed = True @@ -911,7 +908,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]: return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo def lt_folding(x:UOp, c:int) -> Optional[UOp]: - p, np = partition(split_uop(x, BinaryOps.ADD), lambda u: u.const_factor() == 1) + p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1) if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d: return cast(UOp, functools.reduce(operator.add, np).divides(d)).lt(c//d) return None @@ -919,7 +916,7 @@ def lt_folding(x:UOp, c:int) -> Optional[UOp]: def fold_unrolled_divs(divs:UOp): # div pattern in unrolled arange # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x - add_chain, denominator, seen_const, ans = list(split_uop(divs, BinaryOps.ADD)), None, [], None + add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None for u in add_chain: if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None if denominator is None: denominator = u.src[1].arg @@ -941,9 +938,9 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]: # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints. # returns x0 + x1 + ... in such case, or None if not changed, ret = False, [] - for u in split_uop(X, BinaryOps.ADD): + for u in split_uop(X, Ops.ADD): # assumed the const is the last src of MUL - if u.op is BinaryOps.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0: + if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0: changed = True u = u.src[0] if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None @@ -953,8 +950,8 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]: def is_increasing(f:UOp) -> bool: # is f a monotonically increasing function regards its input if f.op in GroupOp.Irreducible: return True - if f.op is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1]) - if f.op in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0]) + if f.op is Ops.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1]) + if f.op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0]) return False # False if not sure def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]: @@ -962,10 +959,10 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]: # if it's X >= c, returns X, False, c # (X < c).ne(True) -> X >= c - if valid.op is BinaryOps.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \ - (s0:=valid.src[0]).op is BinaryOps.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg + if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \ + (s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg # X < c -> X <= c-1 - if valid.op is BinaryOps.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1 + if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1 raise ValueError(f"not able to parse {valid=}") def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: @@ -973,7 +970,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: # first, parse valid into {expr: (lower_bound, upper_bound)} bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None]) - for stmt in split_uop(valid, BinaryOps.AND): + for stmt in split_uop(valid, Ops.AND): try: expr, is_upper, c = parse_valid(stmt) except ValueError: return uop # give up if we cannot parse the valid bounds[expr][int(is_upper)] = c @@ -985,9 +982,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop candidates = [] - if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, BinaryOps.ADD)): + if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)): # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output - candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)]) + candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)]) # try checking the whole clause if expr in uop.sparents: candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))]) @@ -1128,8 +1125,8 @@ symbolic_flat = symbolic+PatternMatcher([ _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) # for debug -syms = { BinaryOps.ADD: "+", BinaryOps.SUB: "-", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>", - BinaryOps.MUL: "*", BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"} +syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>", + Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"} renderer = PatternMatcher([ (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])), (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")), diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index f634bf4df4..790c59a861 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast import os, math from collections import defaultdict, Counter -from tinygrad.ops import GroupOp, UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16 +from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType from tinygrad.renderer import Renderer, TensorCore @@ -42,13 +42,13 @@ base_rewrite = PatternMatcher([ (UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)), # new load/store (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), - lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == BinaryOps.ADD else ctx[idx]})"), + lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"), (UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"), (UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"), (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"), # alu/gep (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op]( - *([strip_parens(ctx[v]) if v.op == x.op and x.op in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)), + *([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)), (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \ (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")), ]) @@ -82,13 +82,13 @@ class CStyleLanguage(Renderer): infinity: str = "INFINITY" nan: str = "NAN" code_for_op: Dict = { - UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})", UnaryOps.RECIP: lambda x,dtype: f"(1/{x})", UnaryOps.NEG: lambda x,dtype: f"-{x}", - UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})", - BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})", - BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", - BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})", BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", - BinaryOps.SHR: lambda a,b,dtype: f"({a}>>{b})", BinaryOps.SHL: lambda a,b,dtype: f"({a}<<{b})", BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", - TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" } + Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}", + Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})", + Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})", + Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})", + Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})", + Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})", + Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" } string_rewrite = base_rewrite extra_matcher = extra_pm @@ -174,8 +174,8 @@ class ClangRenderer(CStyleLanguage): # language options buffer_suffix = " restrict" type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"} - code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [UnaryOps.EXP2, UnaryOps.SIN, UnaryOps.LOG2]}), - UnaryOps.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"} + code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}), + Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"} if AMX: tensor_cores = [TensorCore(dims=(sz,sz,1), threads=[], reduce_axes=[], upcast_axes=([(1,sz)],[(0,sz)],[(1,sz),(0,sz)]), dtype_in=dt, dtype_out=dt) @@ -265,12 +265,12 @@ class MetalRenderer(CStyleLanguage): type_map = {dtypes.bfloat16: "bfloat"} # precise::sin - code_for_op = {**CStyleLanguage.code_for_op, UnaryOps.SIN: lambda x,dtype: f"precise::sin({x})"} + code_for_op = {**CStyleLanguage.code_for_op, Ops.SIN: lambda x,dtype: f"precise::sin({x})"} # upcast to float32 all the ops that don't support bfloat16 extra_matcher = PatternMatcher([ # NOTE: this is copied from PTX - (UPat((UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN), dtype=dtypes.bfloat16, name="x"), + (UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"), lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))), ]) + extra_pm @@ -312,11 +312,11 @@ class CUDARenderer(CStyleLanguage): code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}", "i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"} code_for_op = { **CStyleLanguage.code_for_op, - UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})", - UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})", - UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})", - UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})", - UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" } + Ops.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})", + Ops.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})", + Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})", + Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})", + Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" } type_map = {dtypes.bfloat16: "nv_bfloat16"} def render_vector_prefix(self, dt:DType) -> str: @@ -375,10 +375,10 @@ class AMDRenderer(CStyleLanguage): code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})", "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"} code_for_op = { **CStyleLanguage.code_for_op, - UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})", - UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})", - UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})", - UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})" } + Ops.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})", + Ops.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})", + Ops.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})", + Ops.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})" } smem_prefix = "__attribute__((shared))" barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \ '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");' @@ -433,9 +433,9 @@ class DSPRenderer(ClangRenderer): buffer_suffix = " restrict __attribute__((align_value(128)))" kernel_prefix = "__attribute__((noinline)) " type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" } - code_for_op = {**ClangRenderer.code_for_op, UnaryOps.SIN: lambda x,dtype: f"__builtin_sin({x})", - UnaryOps.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})", - UnaryOps.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"} + code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})", + Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})", + Ops.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"} def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str: ret = super().render_kernel(function_name, kernel, bufs, uops, prefix) diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 4912c024de..0c78a226f3 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -1,7 +1,7 @@ from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tuple import struct from collections import defaultdict -from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, GroupOp +from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp from tinygrad.dtype import dtypes, DType, PtrDType, ConstType from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer @@ -15,24 +15,24 @@ def render_val(x, dtype): return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "") asm_for_op: Dict[Ops, Callable] = { - UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};", - UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};", - UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};", - BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};", - BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};", - BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};", - BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};", - BinaryOps.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};", - BinaryOps.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};", - BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};", - BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};", - BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};", - TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};", - TernaryOps.WHERE: lambda d,a,b,c,dt,name: + Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};", + Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};", + Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};", + Ops.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", Ops.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};", + Ops.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};", + Ops.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};", + Ops.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};", + Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};", + Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};", + Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};", + Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};", + Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};", + Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};", + Ops.WHERE: lambda d,a,b,c,dt,name: f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};" } -supports_half: List[Ops] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE] +supports_half: List[Ops] = [Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE] doesnt_support_half: Tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half) ptx_matcher = PatternMatcher([ # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only) @@ -151,8 +151,8 @@ class PTXRenderer(Renderer): kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", pred_reg, invert=True)) elif uop is Ops.BARRIER and self.barrier: kk(self.barrier) elif uop is Ops.ENDRANGE: - kk(self.code_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]), - self.code_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int])) + kk(self.code_for_op[Ops.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]), + self.code_for_op[Ops.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int])) kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred)) elif uop is Ops.ENDIF: kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:") @@ -167,7 +167,7 @@ class PTXRenderer(Renderer): else: if uop is Ops.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:])) elif uop in GroupOp.ALU: - src_dtype = src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype + src_dtype = src[0].dtype if uop in {Ops.CMPLT, Ops.CMPNE} else dtype kk(self.code_for_op[uop](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype])) elif uop is Ops.DEFINE_ACC: if dtype.count > 1: diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 452618f813..32a38672cd 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -7,7 +7,7 @@ import pickle, base64, itertools, time, struct from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate from tinygrad.helpers import all_same, getenv, flatten from tinygrad.device import Compiled, Compiler, Allocator -from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, Ops, UOp, GroupOp +from tinygrad.ops import exec_alu, Ops, UOp, GroupOp from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer @@ -175,7 +175,7 @@ class PythonProgram: else: raise NotImplementedError(f"unimplemented tensor core {arg}") elif uop in GroupOp.ALU: assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}" - assert all_same([dtype] + dtp) or uop in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {uop}" + assert all_same([dtype] + dtp) or uop in {Ops.CMPNE, Ops.CMPLT, Ops.WHERE}, f"dtype mismatch on {uop}" ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)] assert i in ul, (uop, dtype, idp, arg) i += 1 diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index a0d93a7bc9..5204b27f1f 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -5,7 +5,7 @@ from typing import Tuple, List, Optional, Dict, Set from tinygrad.helpers import merge_dicts, getenv from tinygrad.shape.view import View, strides_for_shape from tinygrad.dtype import dtypes -from tinygrad.ops import UOp, Ops, BinaryOps, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid +from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid @dataclass(frozen=True, order=True) class ShapeTracker: @@ -77,7 +77,7 @@ class ShapeTracker: # TODO: always apply these in to_indexed_uops? if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat) - for c in split_uop(idx, BinaryOps.ADD): + for c in split_uop(idx, Ops.ADD): if c.op is Ops.RANGE: ret[c.arg[0]] = 1 if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 0efb61ad9b..ccbe0cae93 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -9,7 +9,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN from tinygrad.multi import MultiLazyBuffer -from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, Ops, BinaryOps, sint, Variable, SimpleMathTrait +from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait from tinygrad.device import Device, Buffer, BufferOptions from tinygrad.engine.lazy import LazyBuffer from tinygrad.engine.realize import run_schedule @@ -51,7 +51,7 @@ def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None def _fromnp(x: 'np.ndarray') -> LazyBuffer: # type: ignore [name-defined] # noqa: F821 - ret = LazyBuffer.metaop(MetaOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY") + ret = LazyBuffer.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY") # fake realize ret.buffer.allocate(x) del ret.srcs @@ -64,9 +64,9 @@ def get_shape(x) -> Tuple[int, ...]: return (len(subs),) + (subs[0] if subs else ()) def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer: - if isinstance(x, bytes): ret, data = LazyBuffer.metaop(MetaOps.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x + if isinstance(x, bytes): ret, data = LazyBuffer.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x else: - ret = LazyBuffer.metaop(MetaOps.EMPTY, get_shape(x), dtype, "PYTHON") + ret = LazyBuffer.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON") assert dtype.fmt is not None, f"{dtype=} has None fmt" truncate_function = truncate[dtype] data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)]) @@ -134,10 +134,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method # create a LazyBuffer from the different types of inputs if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported" - elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data) + elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data) elif isinstance(data, UOp): assert data.op is Ops.BIND and data.src[0].op is Ops.DEFINE_VAR and data.src[1].op is Ops.CONST, f"can't create tensor from UOp {data}" - data = _metaop(MetaOps.CONST, tuple(), dtype or data.dtype, device, data) + data = _metaop(Ops.CONST, tuple(), dtype or data.dtype, device, data) elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype) elif isinstance(data, (list, tuple)): if dtype is None: @@ -145,15 +145,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata else: data = _frompy(data, dtype) - elif data is None: data = _metaop(MetaOps.EMPTY, (0,), dtype or dtypes.default_float, device) + elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device) elif str(type(data)) == "": import numpy as np assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}" - if data.shape == (): data = _metaop(MetaOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item()) + if data.shape == (): data = _metaop(Ops.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item()) else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data) # type: ignore [name-defined] elif isinstance(data, pathlib.Path): dtype = dtype or dtypes.uint8 - data = _metaop(MetaOps.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}") + data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}") # by this point, it has to be a LazyBuffer if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): @@ -384,9 +384,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method def from_uop(y:UOp, **kwargs) -> Tensor: if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False) - if y.op is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1]) - if y.op is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1]) - if y.op is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1])) + if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1]) + if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1]) + if y.op is Ops.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1])) raise RuntimeError(f"unhandled UOp {y}") # ***** creation entrypoint ***** @@ -412,7 +412,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method print(t.shape) ``` """ - return Tensor._metaop(MetaOps.EMPTY, argfix(*shape), **kwargs) + return Tensor._metaop(Ops.EMPTY, argfix(*shape), **kwargs) @staticmethod def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor: @@ -424,7 +424,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method Additionally, all other keyword arguments are passed to the constructor of the tensor. """ - r = Tensor._metaop(MetaOps.EMPTY, shape, **kwargs) + r = Tensor._metaop(Ops.EMPTY, shape, **kwargs) r.lazydata.buffer.allocate(external_ptr=ptr) del r.lazydata.srcs # fake realize return r