Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
Remove UnaryOps, BinaryOps, TernaryOps, MetaOps [pr] (#7725)
* remove unaryops
* remove ternaryops
* remove metaops
* hotfix
* remove binaryops
* hotfix: test_pattern_matcher

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
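A note on the shape of the change before the first hunk: the four per-category enums are folded into the single unified `Ops` enum with member names kept one-to-one, so the diff below is almost entirely a mechanical rename (`BinaryOps.ADD` -> `Ops.ADD`, `UnaryOps.SQRT` -> `Ops.SQRT`, `TernaryOps.WHERE` -> `Ops.WHERE`, `MetaOps.EMPTY` -> `Ops.EMPTY`). A minimal sketch of the new spelling, assuming a tinygrad checkout at or after this commit:

from tinygrad.ops import Ops

# members that used to be spread across four enums all live on Ops now;
# enum lookup by name and by attribute agree
for name in ["SQRT", "ADD", "MUL", "WHERE", "EMPTY"]:
  print(name, Ops[name] is getattr(Ops, name))  # prints True for each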
@@ -39,7 +39,7 @@ DEVICE = "CLANG" # NOTE: you can change this!
 import struct
 from tinygrad.dtype import dtypes
 from tinygrad.device import Buffer, Device
-from tinygrad.ops import BinaryOps, MetaOps, UOp, Ops
+from tinygrad.ops import UOp, Ops
 from tinygrad.shape.shapetracker import ShapeTracker
 
 # allocate some buffers + load in values
@@ -81,15 +81,15 @@ from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.schedule import create_schedule
 
 # allocate some values + load in values
-a = LazyBuffer.metaop(MetaOps.EMPTY, (1,), dtypes.int32, DEVICE)
-b = LazyBuffer.metaop(MetaOps.EMPTY, (1,), dtypes.int32, DEVICE)
+a = LazyBuffer.metaop(Ops.EMPTY, (1,), dtypes.int32, DEVICE)
+b = LazyBuffer.metaop(Ops.EMPTY, (1,), dtypes.int32, DEVICE)
 a.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
 b.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
 del a.srcs
 del b.srcs
 
 # describe the computation
-out = a.alu(BinaryOps.ADD, b)
+out = a.alu(Ops.ADD, b)
 
 # schedule the computation as a list of kernels
 sched = create_schedule([out])
@@ -8,7 +8,7 @@ from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCou
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import get_kernel, run_schedule
 from tinygrad.engine.memory import memory_planner
-from tinygrad.ops import MetaOps, Ops
+from tinygrad.ops import Ops
 
 TIMING = getenv("TIMING")
 
@@ -2,7 +2,7 @@
 from typing import Tuple
 from tinygrad import Variable
 from tinygrad.codegen.kernel import Opt, OptOps
-from tinygrad.ops import UOp, Ops, KernelInfo, TernaryOps, BinaryOps, UnaryOps, MetaOps
+from tinygrad.ops import UOp, Ops, KernelInfo
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
test/external/external_test_nv.py (vendored, 8 lines changed)
@@ -9,7 +9,7 @@ from tinygrad.engine.realize import get_runner, CompiledRunner
 from test.external.fuzz_linearizer import get_fuzz_rawbufs
 
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer
+from tinygrad.ops import LazyOp, Ops, ReduceOps, BufferOps, MemBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 
@@ -26,12 +26,12 @@ class TestNV(unittest.TestCase):
 TestNV.addr = struct.pack("QQ", TestNV.b.lazydata.buffer._buf.va_addr, TestNV.a.lazydata.buffer._buf.va_addr)
 
 def test_oor_kernels(self):
-ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
+ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
 opts = [Opt(op=OptOps.TC, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2)] # noqa: E501
 helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["NV"])
 
 def test_error_on_huge_dims(self):
-ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
+ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
 opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2)] # noqa: E501
 with self.assertRaises(RuntimeError) as cm:
 lin = Kernel(ast)
@@ -43,7 +43,7 @@ class TestNV(unittest.TestCase):
 
 def test_buf4_usage(self):
 TestNV.along = Tensor([105615], device="NV").realize()
-ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SIN, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
+ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.SIN, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
 temp_runner = get_runner(TestNV.d0.dname, (ast,))
 temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={})
 val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]
test/external/external_test_valid_remove.py (vendored, 4 lines changed)
@@ -2,7 +2,7 @@
 import unittest
 
 from tinygrad import Device
-from tinygrad.ops import UOp, Ops, BinaryOps
+from tinygrad.ops import UOp, Ops
 from tinygrad.engine.search import Opt, OptOps
 from tinygrad.dtype import dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -20,7 +20,7 @@ class TestOpenpilotValidhack(unittest.TestCase):
 UOp(Ops.ADD, dtypes.float, arg=None, src=(
 UOp(Ops.MAX, dtypes.float, arg=None, src=(
 x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8, 9, 10)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 8, 9, 10)), src=(
 UOp(Ops.CAST, dtypes.float, arg=None, src=(
 UOp(Ops.MUL, dtypes.float, arg=None, src=(
 UOp(Ops.LOAD, dtypes.float, arg=None, src=(
test/external/fuzz_linearizer.py (vendored, 4 lines changed)
@@ -25,7 +25,7 @@ from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing
-from tinygrad.ops import UnaryOps, UOp, Ops
+from tinygrad.ops import UOp, Ops
 from tinygrad.device import is_dtype_supported
 
 def on_linearizer_will_run(): pass
@@ -252,7 +252,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
 def _is_simple(lin: Kernel) -> bool:
 if len(lin.ast.src) > 1: return False
 ast:UOp = lin.ast.src[0]
-if ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
+if ast.src[0].op is Ops.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
 return False
 
 if __name__ == "__main__":
test/external/fuzz_schedule.py (vendored, 8 lines changed)
@@ -6,7 +6,7 @@ from tinygrad.engine.realize import capturing, lower_schedule_item
 from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
 from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.engine.schedule import LBScheduleItem, _graph_schedule, ScheduleItem
-from tinygrad.ops import MetaOps
+from tinygrad.ops import Ops
 from tinygrad.tensor import Tensor, _to_np_dtype
 
 ctx_vars = { MULTIOUTPUT: (0, 1) }
@@ -33,7 +33,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
 for lsi in ts:
 for out in lsi.outputs:
 # freeze assign state before exec
-if out.op is MetaOps.ASSIGN:
+if out.op is Ops.ASSIGN:
 prerealized[out] = out.buffer.as_buffer()
 assign_targets[out.srcs[1]] = out
 for x in lsi.inputs:
@@ -50,9 +50,9 @@ def fuzz_schedule(outs:List[LazyBuffer]):
 rawbufs: Dict[LazyBuffer, Buffer] = {}
 for lsi in ts:
 for out in lsi.outputs:
-base = rawbufs[lsi.inputs[0]].base if out.op is MetaOps.BUFFER_VIEW else None
+base = rawbufs[lsi.inputs[0]].base if out.op is Ops.BUFFER_VIEW else None
 rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype, base=base)
-if out.op is MetaOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
+if out.op is Ops.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
 for x in lsi.inputs:
 if x not in rawbufs:
 # override the assign_target after ASSIGN
test/external/fuzz_symbolic.py (vendored, 6 lines changed)
@@ -42,9 +42,9 @@ def gt(expr, rng=None):
 return expr > rng, rng
 
 # NOTE: you have to replace these for this test to pass
-from tinygrad.ops import python_alu, BinaryOps
-python_alu[BinaryOps.MOD] = lambda x,y: x%y
-python_alu[BinaryOps.IDIV] = lambda x,y: x//y
+from tinygrad.ops import python_alu, Ops
+python_alu[Ops.MOD] = lambda x,y: x%y
+python_alu[Ops.IDIV] = lambda x,y: x//y
 
 if __name__ == "__main__":
 ops = [add_v, div, mul, add_num, mod]
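A side note on the fuzz_symbolic hunk above: the NOTE is there because Python's % and // floor toward negative infinity, while the default python_alu entries for MOD and IDIV (untouched by this rename) evidently follow different, C-style truncating semantics, so the fuzzer swaps in the plain Python lambdas before comparing. A standalone illustration of the mismatch, not tinygrad code:

import math

print(-7 % 3, -7 // 3)                       # Python floored semantics: 2 -3
print(math.fmod(-7, 3), math.trunc(-7 / 3))  # C-style truncated: -1.0 -2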
@@ -154,7 +154,7 @@ class TestReduceOpsConstFolding(unittest.TestCase):
 _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).sum())
 np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).sum().numpy(), 4)
 
-# NOTE: cannot just count the non-padded area because some UnaryOps f do not have f(0) = 0.
+# NOTE: cannot just count the non-padded area because some Ops f do not have f(0) = 0.
 _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).exp().sum())
 np.testing.assert_allclose(Tensor.ones(4).pad(((1, 1),)).exp().sum().numpy(), 4 * math.e + 2)
 
@@ -3,7 +3,7 @@ import numpy as np
 import unittest
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.ops import Ops
-from tinygrad.engine.lazy import LazyBuffer, MetaOps
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.engine.schedule import create_schedule
 
 class TestLazyBuffer(unittest.TestCase):
@@ -95,24 +95,24 @@ class TestReduceOp(unittest.TestCase):
 
 class TestView(unittest.TestCase):
 def test_all_masked_out(self):
-# start with non CONST MetaOps
+# start with non CONST Ops
 a = Tensor.rand(10, 10)
-assert a.lazydata.base.op is not MetaOps.CONST
+assert a.lazydata.base.op is not Ops.CONST
 
 # all masked out, degrades to const 0
 b = a.pad(((0, 10), None))[10:]
 assert b.shape == (10, 10)
-assert b.lazydata.base.op is MetaOps.CONST and b.lazydata.base.arg == 0
+assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0
 
 # mask out dim = 1 works too
 b = a.pad((None, (0, 10)))[:, 10:]
 assert b.shape == (10, 10)
-assert b.lazydata.base.op is MetaOps.CONST and b.lazydata.base.arg == 0
+assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0
 
 # partial masked out does not degrade into CONST
 b = a.pad(((0, 5), None))[5:]
 assert b.shape == (10, 10)
-assert b.lazydata.base.op is not MetaOps.CONST
+assert b.lazydata.base.op is not Ops.CONST
 
 if __name__ == "__main__":
 unittest.main()
@@ -6,7 +6,7 @@ from dataclasses import replace
 from test.helpers import ast_const
 from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
 from tinygrad.codegen.lowerer import get_grouped_dims
-from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps, UnaryOps, GroupOp
+from tinygrad.ops import UOp, Ops, GroupOp
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
@@ -109,10 +109,10 @@ class TestLinearizer(unittest.TestCase):
 st_x = x.lazydata.st
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (1,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (1,)))
 second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop()))
 diff = second_x + first_reduce*ast_const(dtypes.float, -1, (32, 1))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (0,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (0,)))
 store = UOp(Ops.STORE, dtypes.void, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce))
 sink = UOp(Ops.SINK, src=(store,))
 opts = [
@@ -145,10 +145,10 @@ class TestLinearizer(unittest.TestCase):
 st_x = x.lazydata.st
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
 second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop()))
 diff = second_x + first_reduce*ast_const(dtypes.float, -1, (27, 32, 1, 5))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
 sink = UOp(Ops.SINK, src=(store,))
 opts = [
@@ -207,13 +207,13 @@ class TestLinearizer(unittest.TestCase):
 x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
 g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
 second_x = UOp(Ops.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop()))
 diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 32, 32, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (2,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (2,)))
 third_x = UOp(Ops.LOAD, dtypes.float, (g3, x2.lazydata.st.reshape((27, 32, 1, 1, 5)).to_uop()))
 mul = (third_x*second_reduce)
-third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (BinaryOps.ADD, (1,)))
+third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (Ops.ADD, (1,)))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 1, 5)).to_uop(), third_reduce))
 sink = UOp(Ops.SINK, src=(store,))
 wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5)
@@ -234,11 +234,11 @@ class TestLinearizer(unittest.TestCase):
 st = x.lazydata.st
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2, 5)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2, 5)))
 second_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 32, 1, 8, 16, 1)).to_uop()))
 neg_first_reduce = first_reduce * ast_const(dtypes.float, -1, (8, 32, 1, 8, 16, 1))
 squares = (second_x+neg_first_reduce)
-squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1, 4)))
+squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1, 4)))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((8, 1, 1, 8, 1, 1)).to_uop(), squares_sum,))
 sink = UOp(Ops.SINK, src=(store,))
 wanna_output = (x.numpy()-x.numpy().sum(axis=(1,3), keepdims=True)).sum(axis=(1,3)).reshape((8,1,1,8,1,1))
@@ -285,10 +285,10 @@ class TestLinearizer(unittest.TestCase):
 x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize()
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
 second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
 diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
 sink = UOp(Ops.SINK, src=(store,))
 opts = [
@@ -317,11 +317,11 @@ class TestLinearizer(unittest.TestCase):
 g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
 first_x_p = UOp(Ops.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
-first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(UnaryOps.EXP2),), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
+first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(Ops.EXP2),), (Ops.ADD, (2,)))
 second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1)).to_uop()))
 diff = (second_x+(first_reduce + first_reduce_p)*ast_const(dtypes.float, -1, (4, 32, 1)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4, 1, 1)).to_uop(), second_reduce))
 sink = UOp(Ops.SINK, src=(store,))
 opts = [
@@ -352,10 +352,10 @@ class TestLinearizer(unittest.TestCase):
 x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
 g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
 second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
 diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
 store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
 second_out = second_reduce * ast_const(dtypes.float, 1/15, (27, 1, 1, 5))
 store1 = UOp(Ops.STORE, src=(g1, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_out))
@@ -375,10 +375,10 @@ class TestLinearizer(unittest.TestCase):
 x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
 g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
 second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
 diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
 store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
 store1 = UOp(Ops.STORE, src=(g1, ShapeTracker(views=(View(shape=(27,15,1,5), strides=(5,0,1,1), offset=0, mask=None, contiguous=False),)).to_uop(), first_reduce)) # noqa: E501
 wanna_output0 = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
@@ -399,10 +399,10 @@ class TestLinearizer(unittest.TestCase):
 x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
 second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
 diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
 sink = UOp(Ops.SINK, src=(store,))
 opts = [[Opt(OptOps.UNROLL, 0, 3), Opt(OptOps.UNROLL, 0, 3)]]
@@ -415,10 +415,10 @@ class TestLinearizer(unittest.TestCase):
 x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
 second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
 diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
 sink = UOp(Ops.SINK, src=(store,))
 opts = [[Opt(OptOps.UPCAST, 0, 3)]]
@@ -434,10 +434,10 @@ class TestLinearizer(unittest.TestCase):
 x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize()
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
 second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop()))
 diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 12, 1, 5)))
-second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
 sink = UOp(Ops.SINK, src=(store,))
 opts = [[Opt(OptOps.GROUPTOP, 0, 3), Opt(OptOps.GROUPTOP, 1, 3)]]
@@ -450,13 +450,13 @@ class TestLinearizer(unittest.TestCase):
 x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
 neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
 second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
 squares = (second_x+neg_mean)*(second_x+neg_mean)
-squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
+squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
 variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
-std = variance.alu(UnaryOps.SQRT)
+std = variance.alu(Ops.SQRT)
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std))
 sink = UOp(Ops.SINK, src=(store,))
 wanna_output = x.numpy().std(axis=2, ddof=0).reshape((15,25,1,1))
@@ -468,13 +468,13 @@ class TestLinearizer(unittest.TestCase):
 x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
 neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35))
 second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).to_uop()))
 squares = (second_x+neg_mean)*(second_x+neg_mean)
-squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1,)))
+squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1,)))
 variance = squares_sum * ast_const(dtypes.float, 0.04, (15, 1, 1, 35))
-std = variance.alu(UnaryOps.SQRT)
+std = variance.alu(Ops.SQRT)
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 1, 1, 35)).to_uop(), std))
 sink = UOp(Ops.SINK, src=(store,))
 wanna_output = x.numpy().std(axis=1, ddof=0).reshape((15,1,1,35))
@@ -488,13 +488,13 @@ class TestLinearizer(unittest.TestCase):
 x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
 g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
 neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
 second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
 squares = (second_x+neg_mean)*(second_x+neg_mean)
-squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
+squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
 variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
-std = variance.alu(UnaryOps.SQRT)
+std = variance.alu(Ops.SQRT)
 store_mean = UOp(Ops.STORE, src=(g1, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), neg_mean))
 store_std = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std))
 sink = UOp(Ops.SINK, src=(store_std, store_mean))
@@ -511,13 +511,13 @@ class TestLinearizer(unittest.TestCase):
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
 # push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
 first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop()))
-first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
+first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
 neg_mean = first_reduce * ast_const(dtypes.float, -0.03125, (3, 27, 32, 1))
 # store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 32, 1)).to_uop(), mean))
 # verify_lazyop(store)
 second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 32, 1)).to_uop()))
 squares = (second_x+neg_mean)*(second_x+neg_mean)
-squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
+squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
 variance = squares_sum * ast_const(dtypes.float, 0.03125, (3, 27, 1, 1))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 1, 1)).to_uop(), variance))
 sink = UOp(Ops.SINK, src=(store,))
@@ -532,13 +532,13 @@ class TestLinearizer(unittest.TestCase):
 x = Tensor.rand(4, 32).realize()
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
 first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop()))
-max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.MAX, (2,)))
+max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.MAX, (2,)))
 second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop()))
 centered_x = second_x+max_x*ast_const(dtypes.float, -1, (4, 32, 1))
-exp_x = centered_x.alu(UnaryOps.EXP2)
-sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (BinaryOps.ADD, (1,)))
-# y = exp_x * sum_exp_x.alu(UnaryOps.RECIP) # kernels cannot do a return to full shape
-recip_sum_exp_x = sum_exp_x.alu(UnaryOps.RECIP)
+exp_x = centered_x.alu(Ops.EXP2)
+sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (Ops.ADD, (1,)))
+# y = exp_x * sum_exp_x.alu(Ops.RECIP) # kernels cannot do a return to full shape
+recip_sum_exp_x = sum_exp_x.alu(Ops.RECIP)
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4,1,1)).to_uop(), recip_sum_exp_x))
 sink = UOp(Ops.SINK, src=(store,))
 expected = 1/np.exp2(x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1, keepdims=True).reshape(4,1,1)
@@ -556,7 +556,7 @@ class TestLinearizer(unittest.TestCase):
 View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
 arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
 arange_axis = (3,)
-arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
+arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (Ops.ADD, arange_axis))
 output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
 out = arange+ast_const(dtypes.int, -1, output_shape)
 store = UOp(Ops.STORE, src=(UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out))
@@ -573,7 +573,7 @@ class TestLinearizer(unittest.TestCase):
 # TODO: do this arange broadcast in the scheduler
 arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
 arange_axis = (3,)
-arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
+arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (Ops.ADD, arange_axis))
 arange_out_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
 arange = arange+ast_const(dtypes.int, -1, arange_out_shape)
 # p2: the indexing
@@ -581,10 +581,10 @@ class TestLinearizer(unittest.TestCase):
 data1 = (g1, ShapeTracker.from_shape(dataset.shape).reshape((1, 16384, 256, 1)).expand(arange_out_shape).to_uop())
 idxs = Tensor([0,3,5,6]).realize()
 data2 = (g2, ShapeTracker.from_shape((4,)+(1,)*(len(arange_out_shape)-1)).expand(arange_out_shape).to_uop())
-arange_eq = arange.alu(BinaryOps.CMPNE, UOp(Ops.LOAD, dtypes.int, data2)).alu(BinaryOps.CMPNE, ast_const(dtypes.bool, True, arange_out_shape))
+arange_eq = arange.alu(Ops.CMPNE, UOp(Ops.LOAD, dtypes.int, data2)).alu(Ops.CMPNE, ast_const(dtypes.bool, True, arange_out_shape))
 reduce_input = UOp(Ops.LOAD, dataset.dtype, data1)*UOp(Ops.CAST, dataset.dtype.scalar(), src=(arange_eq,))
 out_axis = (1,)
-out = UOp(Ops.REDUCE_AXIS, reduce_input.dtype, (reduce_input,), (BinaryOps.ADD, out_axis))
+out = UOp(Ops.REDUCE_AXIS, reduce_input.dtype, (reduce_input,), (Ops.ADD, out_axis))
 output_shape = tuple(1 if i in out_axis else s for i,s in enumerate(arange_out_shape))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape(output_shape).to_uop(), out))
 sink = UOp(Ops.SINK, src=(store,))
@@ -605,7 +605,7 @@ class TestLinearizer(unittest.TestCase):
 ast_const(dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), val=10),
 UOp(Ops.MUL, dtypes.int, arg=None, src=(
 ast_const(dtypes.int, -1, (1, 20, 1)),
-UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.MAX, (0,)), src=(
 UOp(Ops.MUL, dtypes.int, arg=None, src=(
 UOp(Ops.CAST, dtypes.int, arg=None, src=(
 UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
@@ -618,7 +618,7 @@ class TestLinearizer(unittest.TestCase):
 UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
 ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501
 UOp(Ops.ADD, dtypes.int, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=(
 ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False)))),)), # noqa E501
 ast_const(dtypes.int, 10, (10, 20, 1)))),)),)),)),)),
 ast_const(dtypes.int, -1, (1, 20, 1)),)),)),))
@@ -637,7 +637,7 @@ class TestLinearizer(unittest.TestCase):
 ast_const(dtypes.int, 200, (1, 1)),
 UOp(Ops.MUL, dtypes.int, arg=None, src=(
 ast_const(dtypes.int, -1, (1, 1)),
-UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.MAX, (0,)), src=(
 UOp(Ops.MUL, dtypes.int, arg=None, src=(
 UOp(Ops.CAST, dtypes.int, arg=None, src=(
 UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
@@ -650,7 +650,7 @@ class TestLinearizer(unittest.TestCase):
 UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
 ast_const(dtypes.bool, True, (200, 1)),)),)),
 UOp(Ops.ADD, dtypes.int, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
 ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
 ast_const(dtypes.int, 200, (200, 1)),)),)),)),)),)),
 ast_const(dtypes.int, -1, (1, 1)),)),)),))
@@ -672,16 +672,16 @@ class TestLinearizer(unittest.TestCase):
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
 x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
 x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
-r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (1,)))
-r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(BinaryOps.ADD, (0,)))
+r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (1,)))
+r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(Ops.ADD, (0,)))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
 sink = UOp(Ops.SINK, src=(store,))
 helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=0, keepdims=True)).sum(axis=0).reshape(1,1,N)], opts=opts)
 
 x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
 x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
-r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (2,)))
-r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.ADD, (1,)))
+r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (2,)))
+r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.ADD, (1,)))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
 sink = UOp(Ops.SINK, src=(store,))
 helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(N,1,1)], opts=opts)
@@ -699,16 +699,16 @@ class TestLinearizer(unittest.TestCase):
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
 x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
 x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
-r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (1,)))
-r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (BinaryOps.MAX, (0,)))
+r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (1,)))
+r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (Ops.MAX, (0,)))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
 sink = UOp(Ops.SINK, src=(store,))
 helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=0, keepdims=True)).max(axis=0).reshape(1,1,N)], opts=opts)
 
 x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
 x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
-r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (2,)))
-r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.MAX, (1,)))
+r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (2,)))
+r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.MAX, (1,)))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
 sink = UOp(Ops.SINK, src=(store,))
 helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=1, keepdims=True)).max(axis=1).reshape(N,1,1)], opts=opts)
@@ -735,7 +735,7 @@ class TestLinearizer(unittest.TestCase):
 UOp(Ops.WHERE, dtypes.float, arg=None, src=(
 UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
 ast_const(dtypes.float, 0.5*N, (N, 1, 1)),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
 UOp(Ops.ADD, dtypes.float, arg=None, src=(
 UOp(Ops.LOAD, dtypes.float, src=(
 UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
@@ -743,7 +743,7 @@ class TestLinearizer(unittest.TestCase):
 UOp(Ops.WHERE, dtypes.float, arg=None, src=(
 UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
 ast_const(dtypes.float, 0.75*N, (N, N, 1)),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
 UOp(Ops.LOAD, dtypes.float, src=(
 UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
 ld0.to_uop(),)),)),)),
@@ -768,7 +768,7 @@ class TestLinearizer(unittest.TestCase):
 UOp(Ops.WHERE, dtypes.float, arg=None, src=(
 UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
 ast_const(dtypes.float, 0.5*N, (1, 1, N)),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
 UOp(Ops.ADD, dtypes.float, arg=None, src=(
 UOp(Ops.LOAD, dtypes.float, src=(
 UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -776,7 +776,7 @@ class TestLinearizer(unittest.TestCase):
 UOp(Ops.WHERE, dtypes.float, arg=None, src=(
 UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
 ast_const(dtypes.float, 0.75*N, (N, 1, N)),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
 UOp(Ops.LOAD, dtypes.float, src=(
 UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
 ld0.to_uop(),)),)),)),
@@ -804,7 +804,7 @@ class TestLinearizer(unittest.TestCase):
 UOp(Ops.WHERE, dtypes.float, arg=None, src=(
 UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
 ast_const(dtypes.float, 0.5*N, (1, 1, 1, 1)),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=(
 UOp(Ops.ADD, dtypes.float, arg=None, src=(
 UOp(Ops.LOAD, dtypes.float, src=(
 UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
@@ -812,7 +812,7 @@ class TestLinearizer(unittest.TestCase):
 UOp(Ops.WHERE, dtypes.float, arg=None, src=(
 UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
 ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 3)), src=(
 UOp(Ops.LOAD, dtypes.float, src=(
 UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
 UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
@@ -831,7 +831,7 @@ class TestLinearizer(unittest.TestCase):
 def test_end_local(self):
 g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=i) for i in range(2)]
 load = UOp(Ops.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop()))
-reduce = UOp(Ops.REDUCE_AXIS, dtypes.int, (load,), (BinaryOps.ADD, (0,)))
+reduce = UOp(Ops.REDUCE_AXIS, dtypes.int, (load,), (Ops.ADD, (0,)))
 store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce))
 sink = UOp(Ops.SINK, src=(store,))
 load_t = Tensor.full(load.st_arg.shape, 1).contiguous().realize()
@@ -1219,20 +1219,20 @@ class TestLinearizer(unittest.TestCase):
 assert len(sched) == 1
 
 lin = Kernel(sched[0].ast)
-assert sum(u.op in {UnaryOps.RECIP, BinaryOps.FDIV} for u in lin.linearize().uops) == max_ops, msg
+assert sum(u.op in {Ops.RECIP, Ops.FDIV} for u in lin.linearize().uops) == max_ops, msg
 
 a = Tensor.empty((4,4))
 b = Tensor.empty((4,4))
 d = Tensor.empty((4,4))
 
 c = (a*b)/b
-helper(c, "found UnaryOps.RECIP in (a*b)/b operation")
+helper(c, "found Ops.RECIP in (a*b)/b operation")
 
 c = a/a
-helper(c, "found UnaryOps.RECIP in (a/a) operation")
+helper(c, "found Ops.RECIP in (a/a) operation")
 
 c = (a/b)/d
-helper(c, "found multiple UnaryOps.RECIP in (a/b)/d operation", 1)
+helper(c, "found multiple Ops.RECIP in (a/b)/d operation", 1)
 
 def test_sum_collapse(self):
 t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
@@ -1260,7 +1260,7 @@ class TestLinearizer(unittest.TestCase):
 lin = Kernel(sched_copy[-1].ast)
 lin.hand_coded_optimizations()
 lin.linearize()
-assert not any(u.op == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
+assert not any(u.op == Ops.WHERE for u in lin.uops), "found where where where should be folded"
 
 def test_phi_simplification(self):
 def helper(t, max_ops=0):
@@ -1272,7 +1272,7 @@ class TestLinearizer(unittest.TestCase):
 assert len(set([u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}])) == 1, "has either specials or ranges, not both"
 assert len([u for u in uops if u.op is Ops.ASSIGN]) == 0, "ASSIGN should have been simplified"
 # TODO: once uops track min/max this will be fixed
-#assert len([u for u in uops if u.op is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
+#assert len([u for u in uops if u.op is Ops.MAX]) <= max_ops, "no unnecessary MAX ops"
 
 helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2)
 helper(Tensor.arange(-1, -100, -5), max_ops=2)
@@ -1602,7 +1602,7 @@ class TestFloat4(unittest.TestCase):
 UOp(Ops.STORE, src=(
 UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
 UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
 UOp(Ops.CAST, dtypes.float, src=(
 UOp(Ops.MUL, dtypes.half, arg=None, src=(
 UOp(Ops.LOAD, dtypes.half, src=(
@@ -1632,7 +1632,7 @@ class TestFloat4(unittest.TestCase):
 UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
 UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
 UOp(Ops.ADD, dtypes.float, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
 UOp(Ops.MUL, dtypes.float, arg=None, src=(
 UOp(Ops.LOAD, dtypes.float, src=(
 UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
@@ -1662,7 +1662,7 @@ class TestFloat4(unittest.TestCase):
 UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0),
 UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501
 UOp(Ops.CAST, dtypes.half, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (4, 6)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=(
 UOp(Ops.CAST, dtypes.float, src=(
 UOp(Ops.LOAD, dtypes.half, src=(
 UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
@@ -1949,7 +1949,7 @@ class TestKernelOpts(unittest.TestCase):
 UOp(Ops.STORE, src=(
 UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
 UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
 UOp(Ops.MUL, dtypes.float, arg=None, src=(
 UOp(Ops.CAST, dtypes.float, src=(
 UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
@@ -2138,7 +2138,7 @@ class TestKernelOpts(unittest.TestCase):
 g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
 ld0 = UOp(Ops.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
 ld1 = UOp(Ops.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
-store = UOp(Ops.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(Ops.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (BinaryOps.ADD, (0, 2, 4, 6)),))) # noqa: E501
+store = UOp(Ops.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(Ops.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (Ops.ADD, (0, 2, 4, 6)),))) # noqa: E501
 sink = UOp(Ops.SINK, src=(store,))
 data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
 data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
@@ -5,7 +5,7 @@
 import unittest
 from test.helpers import ast_const
 from tinygrad import Device, dtypes
-from tinygrad.ops import UOp, Ops, BinaryOps
+from tinygrad.ops import UOp, Ops
 from tinygrad.helpers import getenv
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad.engine.search import Opt, OptOps
@@ -21,7 +21,7 @@ class TestLinearizerDumb(unittest.TestCase):
 UOp(Ops.MAX, dtypes.half, arg=None, src=(
 UOp(Ops.MUL, dtypes.half, arg=None, src=(
 UOp(Ops.CAST, dtypes.half, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
 UOp(Ops.CAST, dtypes.float, arg=None, src=(
 UOp(Ops.MUL, dtypes.half, arg=None, src=(
 UOp(Ops.LOAD, dtypes.half, arg=None, src=(
@@ -64,7 +64,7 @@ class TestLinearizerDumb(unittest.TestCase):
 ast_const(dtypes.bool, True, st_src=(
 UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
 UOp(Ops.ADD, dtypes.int, arg=None, src=(
-UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
 ast_const(dtypes.int, -1, st_src=(
 UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)),
 ast_const(dtypes.int, 1000, st_src=(
@@ -75,7 +75,7 @@ class TestLinearizerDumb(unittest.TestCase):
 for opt in opts: k.apply_opt(opt)
 prg = k.to_program()
 print(prg.src)
-assert prg.uops is not None and not any(uop.op is BinaryOps.MAX for uop in prg.uops), "leftover MAX"
+assert prg.uops is not None and not any(uop.op is Ops.MAX for uop in prg.uops), "leftover MAX"
 
 @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
 def test_expander_new_srcs(self):
@@ -83,7 +83,7 @@ class TestLinearizerDumb(unittest.TestCase):
 UOp(Ops.STORE, dtypes.void, arg=None, src=(
 UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
 UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
-UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
+UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
 UOp(Ops.LOAD, dtypes.float, arg=None, src=(
 UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
 UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
@@ -105,14 +105,14 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=(
|
||||
ast_const(dtypes.int, 1, st_src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),)),
|
||||
ast_const(dtypes.int, -1, st_src=(
|
||||
@@ -136,7 +136,7 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
@@ -168,7 +168,7 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=(
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
@@ -200,7 +200,7 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest, random
|
||||
import numpy as np
|
||||
from tinygrad.codegen.kernel import Kernel, KernelOptError
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.ops import UOp, Ops, BinaryOps
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.engine.search import Opt, OptOps
|
||||
from tinygrad import Device, dtypes, Tensor
|
||||
from tinygrad.helpers import CI
|
||||
@@ -47,7 +47,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 16), strides=(16, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
@@ -64,7 +64,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 37, 9, 1, 1), strides=(666, 333, 9, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (4, 5)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (4, 5)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
|
||||
@@ -76,7 +76,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 1), strides=(128, 16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),))
|
||||
@@ -89,7 +89,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
@@ -111,7 +111,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
|
||||
ast_const(dtypes.int, -1, st_src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))), src=()),)),)),
|
||||
ast_const(dtypes.int, 10, st_src=(
|
||||
@@ -125,7 +125,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 1, 34, 1, 34), strides=(36992, 1156, 0, 34, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 4)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 4)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 6, 8, 4, 6, 8, 4), strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 8), (0, 1), (0, 6), (0, 8), (0, 1)), contiguous=False), View(shape=(512, 32, 6, 35, 6, 35), strides=(1179648, 36864, 6144, 192, 32, 1), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 32), (0, 6), (0, 32)), contiguous=False), View(shape=(512, 32, 238, 238), strides=(1411200, 44100, 210, 1), offset=0, mask=((0, 512), (0, 32), (0, 210), (0, 210)), contiguous=False), View(shape=(512, 32, 7, 34, 7, 34), strides=(1812608, 56644, 8092, 238, 34, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),))
|
||||
@@ -142,7 +142,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
x9:=UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
@@ -166,7 +166,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 0, 0, 4500, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -183,7 +183,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.ADD, (3,)), src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
|
||||
@@ -202,7 +202,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
@@ -277,7 +277,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
@@ -299,7 +299,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 8)), src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
@@ -313,7 +313,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
x6,)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 8)), src=(
|
||||
x5,)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)]
|
||||
helper_test_lin(Kernel(ast), opts, failed_platforms=[])
|
||||
@@ -325,7 +325,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.ADD, (3,)), src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
|
||||
@@ -344,7 +344,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
@@ -370,7 +370,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -405,7 +405,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),
|
||||
@@ -420,7 +420,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 1, 28, 28, 1, 1), strides=(31360, 0, 784, 0, 28, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -442,7 +442,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
@@ -462,7 +462,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 9, 7, 3, 3), strides=(2268, 0, 567, 0, 63, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -511,7 +511,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
@@ -624,7 +624,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
|
||||
ast_const(dtypes.int, 1, st_src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1025, 2047), strides=(0, 0), offset=0, mask=((0, 1025), (1023, 2047)), contiguous=False), View(shape=(1024, 1024), strides=(1, 2048), offset=0, mask=None, contiguous=False))), src=()),)),)),
|
||||
ast_const(dtypes.int, -1, st_src=(
|
||||
@@ -639,7 +639,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
|
||||
ast_const(dtypes.int, 1, st_src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(129, 255), strides=(0, 0), offset=0, mask=((0, 129), (127, 255)), contiguous=False), View(shape=(128, 128), strides=(1, 256), offset=0, mask=None, contiguous=False))), src=()),)),)),
|
||||
ast_const(dtypes.int, -1, st_src=(
|
||||
@@ -678,7 +678,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.MAX, (3,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.MAX, (3,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),))
|
||||
@@ -731,7 +731,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
@@ -749,7 +749,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 1, 1, 1), strides=(11532, 0, 961, 31, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
@@ -767,7 +767,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
|
||||
UOp(Ops.EXP2, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
@@ -791,7 +791,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 1, 1, 1), strides=(50176, 0, 196, 14, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
@@ -809,7 +809,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
x5:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -851,7 +851,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(180, 0, 30, 3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.MAX, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -875,7 +875,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CAST, dtypes.uchar, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.uint, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.uint, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.uint, arg=(Ops.ADD, (1,)), src=(
|
||||
UOp(Ops.CAST, dtypes.uint, arg=None, src=(
|
||||
ast_const(dtypes.uchar, 1, st_src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))), src=()),)),)),)),
|
||||
@@ -895,7 +895,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.MAX, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
|
||||
@@ -920,7 +920,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 32, 1, 1, 1, 5, 5, 256), strides=(0, 0, 6400, 0, 0, 0, 1280, 256, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 3, 4)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 3, 4)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
|
||||
@@ -943,7 +943,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.MAX, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
|
||||
@@ -969,7 +969,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
|
||||
ast_const(dtypes.int, 1, st_src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)),
|
||||
ast_const(dtypes.int, -1, st_src=(
|
||||
@@ -987,7 +987,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 1, 1, 1), strides=(100352, 0, 784, 28, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
@@ -1007,7 +1007,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
|
||||
@@ -1021,7 +1021,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
|
||||
@@ -1035,7 +1035,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
|
||||
@@ -1052,7 +1052,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 1, 1, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 3)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -1065,7 +1065,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (4,)), src=(
|
||||
ast_const(dtypes.int, 1, st_src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 3), strides=(0, 0), offset=0, mask=((0, 3), (1, 3)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 1, 0, 4), offset=0, mask=((0, 2), (0, 3), (0, 2), (0, 3), (0, 2)), contiguous=False))), src=()),)),)),
|
||||
x19:=ast_const(dtypes.int, -1, st_src=(
|
||||
@@ -1078,7 +1078,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=3, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (4,)), src=(
|
||||
ast_const(dtypes.int, 1, st_src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5), strides=(0, 0), offset=0, mask=((0, 4), (2, 5)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 0, 1, 6), offset=0, mask=None, contiguous=False))), src=()),)),)),
|
||||
x19,)),)),
|
||||
@@ -1093,7 +1093,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.bool, arg=None, src=(
|
||||
@@ -1127,7 +1127,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
|
||||
ast_const(dtypes.int, 1, st_src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)),
|
||||
ast_const(dtypes.int, -1, st_src=(
|
||||
@@ -1142,7 +1142,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 1, 1, 256, 1, 1, 256), strides=(0, 0, 65536, 0, 0, 256, 0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3, 4)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3, 4)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
@@ -1160,7 +1160,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 1), strides=(6, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -1178,7 +1178,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.bool, arg=(BinaryOps.ADD, (3,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.bool, arg=(Ops.ADD, (3,)), src=(
|
||||
UOp(Ops.MUL, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=1, src=()),
|
||||
@@ -1220,7 +1220,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
x9,)),
|
||||
UOp(Ops.ADD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
@@ -1249,7 +1249,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
@@ -1267,7 +1267,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.uchar, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.uchar, arg=(Ops.ADD, (1,)), src=(
|
||||
UOp(Ops.MUL, dtypes.uchar, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()),
|
||||
@@ -1279,7 +1279,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.VALID, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(50001, 99999), strides=(0, 0), offset=0, mask=((0, 50001), (49999, 99999)), contiguous=False), View(shape=(1024, 50000, 50000), strides=(0, 1, 100000), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
@@ -1306,7 +1306,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
@@ -1326,7 +1326,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(W, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.engine.search import Opt, OptOps
|
||||
from tinygrad.engine.search import time_linearizer, bufs_from_lin
|
||||
|
||||
# stuff needed to unpack a kernel
|
||||
from tinygrad.ops import UOp, Ops, BinaryOps
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
@@ -33,7 +33,7 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -68,7 +68,7 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||
UOp(Ops.STORE, None, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -85,7 +85,7 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||
UOp(Ops.STORE, None, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -102,7 +102,7 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||
UOp(Ops.STORE, None, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -119,7 +119,7 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||
UOp(Ops.STORE, None, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -136,7 +136,7 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||
UOp(Ops.STORE, None, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -153,7 +153,7 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||
UOp(Ops.STORE, None, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
@@ -175,7 +175,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase):
|
||||
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
|
||||
ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop()
|
||||
prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2))
|
||||
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
|
||||
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5)))))
|
||||
ast = UOp(Ops.SINK, src=(store,))
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
|
||||
_test_overflow(ast, opts)
|
||||
@@ -187,7 +187,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase):
|
||||
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
|
||||
ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop()
|
||||
prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2))
|
||||
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
|
||||
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5)))))
|
||||
ast = UOp(Ops.SINK, src=(store,))
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=5, amt=2)]
|
||||
_test_overflow(ast, opts)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest, functools, random
|
||||
from typing import List
|
||||
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
|
||||
from tinygrad.ops import MetaOps, BinaryOps, Ops
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.helpers import CI, getenv, prod, Context
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
@@ -481,7 +481,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
for p in get_parameters(bn): p.shard_(devices_4).realize()
|
||||
|
||||
out = bn(t)
|
||||
scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices_4 and sched.ast.op is not MetaOps.COPY]
|
||||
scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices_4 and sched.ast.op is not Ops.COPY]
|
||||
assert set(out.device for sched in scheds for out in sched.outputs) == set(devices_4), "should have ast on each shard device"
|
||||
asts = [sched.ast for sched in scheds]
|
||||
assert len(asts)
|
||||
@@ -640,21 +640,21 @@ class TestMultiTensor(unittest.TestCase):
|
||||
for si in t.schedule():
|
||||
ast = si.ast.src[0]
|
||||
assert ast.op is Ops.STORE
|
||||
assert ast.src[2].op is BinaryOps.ADD
|
||||
assert ast.src[2].op is Ops.ADD
|
||||
assert ast.src[2].src[0].op is Ops.LOAD
|
||||
assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1
|
||||
t = 2 * t
|
||||
for si in t.schedule():
|
||||
ast = si.ast.src[0]
|
||||
assert ast.op is Ops.STORE
|
||||
assert ast.src[2].op is BinaryOps.MUL
|
||||
assert ast.src[2].op is Ops.MUL
|
||||
assert ast.src[2].src[0].src[1].op is Ops.CONST and ast.src[2].src[0].src[1].arg == 2
|
||||
assert ast.src[2].src[1].op is Ops.LOAD
|
||||
t = t + t.full_like(3)
|
||||
for si in t.schedule():
|
||||
ast = si.ast.src[0]
|
||||
assert ast.op is Ops.STORE
|
||||
assert ast.src[2].op is BinaryOps.ADD
|
||||
assert ast.src[2].op is Ops.ADD
|
||||
assert ast.src[2].src[0].op is Ops.LOAD
|
||||
assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import dedup, flatten, prod
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
from tinygrad.ops import BinaryOps, UOp, Ops
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
@@ -34,7 +34,7 @@ class TestCStyleFailures(unittest.TestCase):
|
||||
b = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (b.index(idx),))
|
||||
alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
|
||||
alu = ld.alu(Ops.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
|
||||
store = UOp.store(a.index(idx), alu)
|
||||
sink = UOp(Ops.SINK, dtypes.void, (store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
|
||||
@@ -13,7 +13,7 @@ from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites
|
||||
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, view_left
|
||||
@@ -1042,7 +1042,7 @@ class TestSchedule(unittest.TestCase):
|
||||
b = r.sum(0) * 4
|
||||
c = r.sum(1) * 2
|
||||
schedule = check_schedule([b, c], 3)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
|
||||
|
||||
# multireduce spec
|
||||
def test_multireduce_simple_chase(self):
|
||||
@@ -1053,7 +1053,7 @@ class TestSchedule(unittest.TestCase):
|
||||
c = r.sum(1) + 12
|
||||
np_r = (a.numpy() + (a.numpy().sum(0) + 6)).sum(0) * 2
|
||||
# schedule = check_schedule([b,c], 3)
|
||||
# self.assertIs(schedule[0].ast[0].src[0].arg, BinaryOps.MUL)
|
||||
# self.assertIs(schedule[0].ast[0].src[0].arg, Ops.MUL)
|
||||
schedule = check_schedule([b,c], 4)
|
||||
run_schedule(schedule)
|
||||
np.testing.assert_allclose(b.numpy(), np_r.sum(0) + 8, atol=1e-4, rtol=1e-4)
|
||||
@@ -1066,7 +1066,7 @@ class TestSchedule(unittest.TestCase):
|
||||
d = r.T * 4
|
||||
e = r * d
|
||||
schedule = check_schedule([d, e], 3)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
|
||||
|
||||
# multireduce spec
|
||||
def test_multireduce_push_permute_chase(self):
|
||||
@@ -1077,7 +1077,7 @@ class TestSchedule(unittest.TestCase):
|
||||
d = r.T * 4
|
||||
e = r * (d + a).sum(2)
|
||||
schedule = check_schedule([d, e], 3) # make sure it doesn't fuse
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
|
||||
run_schedule(schedule)
|
||||
np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4)
|
||||
@@ -1089,7 +1089,7 @@ class TestSchedule(unittest.TestCase):
|
||||
r = a.sum(1) + c
|
||||
d = r[:4] * b
|
||||
schedule = check_schedule(d, 2)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
|
||||
|
||||
# multireduce spec
|
||||
def test_multireduce_push_shrink_chase(self):
|
||||
@@ -1102,7 +1102,7 @@ class TestSchedule(unittest.TestCase):
|
||||
out = r[:4] * b + d.sum(1)[:4]
|
||||
# schedule = check_schedule(out, 2)
|
||||
schedule = check_schedule(out, 3)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
|
||||
run_schedule(schedule)
|
||||
np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4)
|
||||
|
||||
@@ -1287,16 +1287,16 @@ class TestSchedule(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
def test_bitcast_subbufer(self):
x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
a = x.alu(UnaryOps.EXP2).cast(dtypes.int32, True, allow_buffer_view=True)
a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=True)
b = x.cast(dtypes.int32, True, allow_buffer_view=True)
b = a.alu(BinaryOps.ADD, b)
b = a.alu(Ops.ADD, b)
check_schedule(b, 2) # this should fuse when it makes sense

def test_bitcast_disable_subbufer(self):
x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
a = x.alu(UnaryOps.EXP2).cast(dtypes.int32, True, allow_buffer_view=False)
a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=False)
b = x.cast(dtypes.int32, True, allow_buffer_view=False)
b = a.alu(BinaryOps.ADD, b)
b = a.alu(Ops.ADD, b)
check_schedule(b, 1)

def test_reduceop_reshape_dont_push(self):
@@ -1530,7 +1530,7 @@ class TestIndexing(unittest.TestCase):
def test_arange_view_op(self):
a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).contiguous()
assert isinstance(a.lazydata, LazyBuffer)
self.assertIs(a.lazydata.base.op, MetaOps.BUFFER_VIEW)
self.assertIs(a.lazydata.base.op, Ops.BUFFER_VIEW)
self.check_schedule(a, 1)
np.testing.assert_equal(a.numpy(), [[4, 5]])

@@ -1538,7 +1538,7 @@ class TestIndexing(unittest.TestCase):
def test_arange_shrink_copy(self):
a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).to("CLANG")
assert isinstance(a.lazydata, LazyBuffer)
self.assertIs(a.lazydata.base.op, MetaOps.COPY)
self.assertIs(a.lazydata.base.op, Ops.COPY)
self.check_schedule(a, 1)
np.testing.assert_equal(a.numpy(), [[4, 5]])

@@ -1546,8 +1546,8 @@ class TestIndexing(unittest.TestCase):
def test_arange_expand_copy(self):
a = Tensor.arange(4).reshape(2, 2, 1).expand(2, 2, 2).to("CLANG")
assert isinstance(a.lazydata, LazyBuffer)
self.assertIs(a.lazydata.base.op, MetaOps.COPY)
self.assertIs(a.lazydata.base.srcs[0].base.op, BinaryOps.ADD)
self.assertIs(a.lazydata.base.op, Ops.COPY)
self.assertIs(a.lazydata.base.srcs[0].base.op, Ops.ADD)
self.check_schedule(a, 1)
np.testing.assert_equal(a.numpy(), [[[0, 0], [1, 1]], [[2, 2], [3, 3]]])

@@ -1643,7 +1643,7 @@ class TestIndexing(unittest.TestCase):
def test_simple_store_reshape(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
r = r + 2
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
@@ -1655,7 +1655,7 @@ class TestIndexing(unittest.TestCase):
def test_no_reshape_reduceop(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
verify_ast(sink)
@@ -1698,7 +1698,7 @@ class TestSwizzle(unittest.TestCase):
# LazyBuffer to pre-rewrite AST
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0,)))
swizzle_r = UOp(Ops.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
alu = swizzle_r+1
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
@@ -1720,9 +1720,9 @@ class TestSwizzle(unittest.TestCase):
# LazyBuffer to pre-rewrite AST
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (Ops.ADD, (0,)))
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,)))
r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (Ops.ADD, (0,)))
alu = UOp(Ops.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(Ops.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+2,),),)) # noqa: E501
# graph rewrite
@@ -1736,7 +1736,7 @@ class TestSwizzle(unittest.TestCase):

def test_swizzle_rewrite_alt(self):
swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501

@@ -3,7 +3,7 @@ import unittest
from test.helpers import ast_const
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.ops import UOp, Ops
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
from tinygrad.device import Device, Buffer
@@ -107,7 +107,7 @@ class TestBEAM(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (1,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(

@@ -2,8 +2,7 @@ from typing import List
import unittest, time
from tinygrad import dtypes, Device
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, Ops, UOp, KernelInfo
from tinygrad.ops import UPat, PatternMatcher
from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher
from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, sym
@@ -541,7 +540,7 @@ class TestExpander(unittest.TestCase):
@unittest.skip("no longer supported")
def test_reduce_known_axis(self):
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), BinaryOps.ADD)
sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), Ops.ADD)
sink = expander_rewrite(sink)
assert sink.op is Ops.CONST
self.assertEqual(sink.arg, 3*(0+1+2+3))
@@ -549,7 +548,7 @@ class TestExpander(unittest.TestCase):
@unittest.skip("no longer supported")
def test_reduce_const(self):
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), BinaryOps.ADD)
sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), Ops.ADD)
sink = expander_rewrite(sink)
assert sink.op is Ops.CONST
self.assertEqual(sink.arg, 3*4)
@@ -590,7 +589,7 @@ class TestExpander(unittest.TestCase):
def test_reduce_different_axis(self):
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), BinaryOps.ADD)
sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), Ops.ADD)
sink = expander_rewrite(sink)
print(sink)

@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.dtype import dtypes, DType
from tinygrad.device import Buffer, Device
from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule, to_si
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
@@ -29,7 +29,7 @@ def uop(uops:List[UOp], uop:Ops, dtype:Optional[DType], src:Tuple[UOp, ...], arg

def _test_single_value(vals, op, dts):
uops = []
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, dtype in enumerate(dts))
@@ -45,7 +45,7 @@ def _test_single_value(vals, op, dts):

def _test_single_value_const(vals, op, dts):
uops = []
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, op, output_dtype, loads)
@@ -103,49 +103,49 @@ class TestUOps(unittest.TestCase):

class TestFloatUOps(TestUOps):
@unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
def test_exp2(self): self._test_uop_fxn(Ops.EXP2, lambda a: np.exp2(a))
@unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
def test_log2(self): self._test_uop_fxn(Ops.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
@unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1/a if a != 0 else float('inf'))
def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
def test_sin(self): self._test_uop_fxn(Ops.SIN, lambda a: math.sin(a))
def test_recip(self): self._test_uop_fxn(Ops.RECIP, lambda a: 1/a if a != 0 else float('inf'))
def test_sqrt(self): self._test_uop_fxn(Ops.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))

def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a<b)
def test_cmpne(self): self._test_bop_fxn(BinaryOps.CMPNE, lambda a,b: a!=b)
def test_add(self): self._test_bop_fxn(Ops.ADD, lambda a,b: a+b)
def test_mul(self): self._test_bop_fxn(Ops.MUL, lambda a,b: a*b)
def test_max(self): self._test_bop_fxn(Ops.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(Ops.CMPLT, lambda a,b: a<b)
def test_cmpne(self): self._test_bop_fxn(Ops.CMPNE, lambda a,b: a!=b)
# MOD isn't tested on floats

def test_where(self):
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float, dtypes.float))
self._test_top_fxn(Ops.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float, dtypes.float))

@unittest.skipUnless(getenv("PYTHON"), "only python supports MULACC")
def test_mulacc(self):
self._test_top_fxn(TernaryOps.MULACC, lambda a,b,c: a*b+c, (dtypes.float, dtypes.float, dtypes.float))
self._test_top_fxn(Ops.MULACC, lambda a,b,c: a*b+c, (dtypes.float, dtypes.float, dtypes.float))

class TestNonFloatUOps(TestUOps):
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32))
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), (dtypes.int32, dtypes.int32))
def test_add_int32(self): self._test_bop_fxn(Ops.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32))
def test_mul_int32(self): self._test_bop_fxn(Ops.MUL, lambda a,b: int(a)*int(b), (dtypes.int32, dtypes.int32))
@unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
def test_shr_int32(self): self._test_bop_fxn(BinaryOps.SHR, lambda a,b: int(a)>>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
def test_shr_int32(self): self._test_bop_fxn(Ops.SHR, lambda a,b: int(a)>>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
@unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
def test_shl_int32(self): self._test_bop_fxn(BinaryOps.SHL, lambda a,b: int(a)<<int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
def test_shl_int32(self): self._test_bop_fxn(Ops.SHL, lambda a,b: int(a)<<int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
def test_div_int32(self):
self._test_bop_fxn(BinaryOps.IDIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
def test_and_int32(self): self._test_bop_fxn(BinaryOps.AND, lambda a,b: int(a)&int(b), (dtypes.int32, dtypes.int32))
def test_or_int32(self): self._test_bop_fxn(BinaryOps.OR, lambda a,b: int(a)|int(b), (dtypes.int32, dtypes.int32))
self._test_bop_fxn(Ops.IDIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
def test_and_int32(self): self._test_bop_fxn(Ops.AND, lambda a,b: int(a)&int(b), (dtypes.int32, dtypes.int32))
def test_or_int32(self): self._test_bop_fxn(Ops.OR, lambda a,b: int(a)|int(b), (dtypes.int32, dtypes.int32))
def test_mod_int32(self):
self._test_bop_fxn(BinaryOps.MOD,
self._test_bop_fxn(Ops.MOD,
lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (dtypes.int32, dtypes.int32), no_b_zero=True)
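
The MOD reference above uses C-style semantics: the remainder takes the sign of the dividend, unlike Python's %. A minimal sketch of that convention (trunc_mod is a hypothetical name, not a tinygrad helper):

def trunc_mod(a: int, b: int) -> int:
  # remainder carries the sign of the dividend, matching the lambda above
  return abs(a) % abs(b) * (1 if a >= 0 else -1)

assert trunc_mod(-7, 3) == -1  # Python's -7 % 3 would give 2
assert trunc_mod(7, -3) == 1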
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: int(a)<int(b), (dtypes.int32, dtypes.int32))
def test_cmpne_int32(self): self._test_bop_fxn(BinaryOps.CMPNE, lambda a,b: int(a)!=int(b), (dtypes.int32, dtypes.int32))
def test_cmplt_int32(self): self._test_bop_fxn(Ops.CMPLT, lambda a,b: int(a)<int(b), (dtypes.int32, dtypes.int32))
def test_cmpne_int32(self): self._test_bop_fxn(Ops.CMPNE, lambda a,b: int(a)!=int(b), (dtypes.int32, dtypes.int32))
@unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), (dtypes.bool, dtypes.bool))
def test_mul_bool(self): self._test_bop_fxn(Ops.MUL, lambda a,b: bool(a) and bool(b), (dtypes.bool, dtypes.bool))
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
def test_where_float16(self):
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float16, dtypes.float16))
self._test_top_fxn(Ops.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float16, dtypes.float16))

class TestBoolUOps(TestUOps):
def _test_uop_bool_fxn(self, op, fxn):
@@ -166,72 +166,72 @@ class TestBoolUOps(TestUOps):
for c in [False, True]:
self._equal(f([a,b,c], op, (dtypes.bool, )*3), fxn(a,b,c))

def test_add_bool(self): self._test_bop_bool_fxn(BinaryOps.ADD, lambda a,b: a or b)
def test_mul_bool(self): self._test_bop_bool_fxn(BinaryOps.MUL, lambda a,b: a and b)
def test_xor_bool(self): self._test_bop_bool_fxn(BinaryOps.XOR, lambda a,b: a != b)
def test_and_bool(self): self._test_bop_bool_fxn(BinaryOps.AND, lambda a,b: a & b)
def test_or_bool(self): self._test_bop_bool_fxn(BinaryOps.OR, lambda a,b: a | b)
def test_cmpne_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPNE, lambda a,b: a != b)
def test_cmplt_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPLT, lambda a,b: a < b)
def test_where_bool(self): self._test_top_bool_fxn(TernaryOps.WHERE, lambda a,b,c: b if a else c)
def test_add_bool(self): self._test_bop_bool_fxn(Ops.ADD, lambda a,b: a or b)
def test_mul_bool(self): self._test_bop_bool_fxn(Ops.MUL, lambda a,b: a and b)
def test_xor_bool(self): self._test_bop_bool_fxn(Ops.XOR, lambda a,b: a != b)
def test_and_bool(self): self._test_bop_bool_fxn(Ops.AND, lambda a,b: a & b)
def test_or_bool(self): self._test_bop_bool_fxn(Ops.OR, lambda a,b: a | b)
def test_cmpne_bool(self): self._test_bop_bool_fxn(Ops.CMPNE, lambda a,b: a != b)
def test_cmplt_bool(self): self._test_bop_bool_fxn(Ops.CMPLT, lambda a,b: a < b)
def test_where_bool(self): self._test_top_bool_fxn(Ops.WHERE, lambda a,b,c: b if a else c)

class TestExecALU(TestUOps):
def test_sqrt(self):
self.assertEqual(exec_alu(UnaryOps.SQRT, dtypes.float, (0.0,)), 0.0)
self.assertEqual(exec_alu(Ops.SQRT, dtypes.float, (0.0,)), 0.0)

def test_div(self):
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (8, 2)), 4)
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (7, 3)), 2)
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (7, -3)), -2)
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (-50, 6)), -8)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (8, 2)), 4)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (7, 3)), 2)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (7, -3)), -2)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (-50, 6)), -8)

np.testing.assert_allclose(exec_alu(BinaryOps.MUL, dtypes.float32, (7.0, exec_alu(UnaryOps.RECIP, dtypes.float32, (3.0,)))), 2+(1.0/3.0))
np.testing.assert_allclose(exec_alu(BinaryOps.MUL, dtypes.float32, (7.0, exec_alu(UnaryOps.RECIP, dtypes.float32, (-3.0,)))), -2-(1.0/3.0))
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIP, dtypes.float32, (3.0,)))), 2+(1.0/3.0))
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIP, dtypes.float32, (-3.0,)))), -2-(1.0/3.0))
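
IDIV here is truncating division (round toward zero), which differs from Python's floor //. A quick sketch of the semantics the asserts above pin down (trunc_div is an illustrative name, not a tinygrad function):

def trunc_div(a: int, b: int) -> int:
  # C-style quotient: truncate toward zero instead of flooring
  q = abs(a) // abs(b)
  return q if (a >= 0) == (b >= 0) else -q

assert trunc_div(7, -3) == -2   # Python's 7 // -3 would give -3
assert trunc_div(-50, 6) == -8  # Python's -50 // 6 would give -9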

def test_recip(self):
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (8,)), 1/8)
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (7,)), 1/7)
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (-3,)), 1/-3)
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (-50,)), 1/-50)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (8,)), 1/8)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (7,)), 1/7)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (-3,)), 1/-3)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (-50,)), 1/-50)

np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, ((32+521+3),)), 1/(32+521+3))
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, ((34**2),)), 1/(34**2))
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (10,)), 1/10)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, ((32+521+3),)), 1/(32+521+3))
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, ((34**2),)), 1/(34**2))
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (10,)), 1/10)

def test_bool_cmplt(self):
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, False)), False)
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, True)), False)
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (True, False)), False)
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (True, True)), False)

def test_bool_cmpne(self):
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (True, False)), True)
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (True, True)), False)
self.assertEqual(exec_alu(Ops.CMPNE, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(Ops.CMPNE, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(Ops.CMPNE, dtypes.bool, (True, False)), True)
self.assertEqual(exec_alu(Ops.CMPNE, dtypes.bool, (True, True)), False)

def test_bool_where(self):
self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.bool, (False, False, False)), False)
self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.int, (False, 2, 4)), 4)
np.testing.assert_allclose(exec_alu(TernaryOps.WHERE, dtypes.float, (False, 2.2, 4.5)), 4.5)
self.assertEqual(exec_alu(Ops.WHERE, dtypes.bool, (False, False, False)), False)
self.assertEqual(exec_alu(Ops.WHERE, dtypes.int, (False, 2, 4)), 4)
np.testing.assert_allclose(exec_alu(Ops.WHERE, dtypes.float, (False, 2.2, 4.5)), 4.5)

def test_overflow(self):
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)), 244)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (256, 0)), 0)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (0, -1)), 255)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (0, -1000)), 24)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (250, 250)), 244)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (256, 0)), 0)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (0, -1)), 255)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (0, -1000)), 24)

self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (127, 0)), 127)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-100, -100)), 56)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-1000, -0)), 24)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-130, -0)), 126)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (127, 0)), 127)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-128, 0)), -128)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-100, -100)), 56)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-1000, -0)), 24)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-130, -0)), 126)

self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (1, 1)), 2)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (1, 1)), 2)
self.assertEqual(exec_alu(Ops.ADD, dtypes.int8, (-128, 0)), -128)

# test no truncate
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250), truncate_output=False), 500)
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (250, 250), truncate_output=False), 500)
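
The truncation exercised above is plain two's-complement wraparound into the dtype's bit width. A standalone sketch of the arithmetic (truncate_int is an illustrative helper, not tinygrad's implementation, which lives inside exec_alu):

def truncate_int(x: int, bits: int, signed: bool) -> int:
  x &= (1 << bits) - 1                 # wrap into [0, 2**bits)
  if signed and x >= 1 << (bits - 1):  # reinterpret the high half as negative
    x -= 1 << bits
  return x

assert truncate_int(250 + 250, 8, signed=False) == 244
assert truncate_int(-100 + -100, 8, signed=True) == 56
assert truncate_int(0 + -1000, 8, signed=False) == 24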

class TestConstantFolding(unittest.TestCase):
def test_cast_const(self):
@@ -336,8 +336,8 @@ class TestAssembly(unittest.TestCase):
a2 = UOp(Ops.MUL, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].op, BinaryOps.SHL)
self.assertEqual(uops[-2].op, BinaryOps.MUL)
self.assertEqual(uops[-1].op, Ops.SHL)
self.assertEqual(uops[-2].op, Ops.MUL)

def test_bitshift_right(self):
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
@@ -348,8 +348,8 @@ class TestAssembly(unittest.TestCase):
a2 = UOp(Ops.IDIV, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].op, BinaryOps.SHR)
self.assertEqual(uops[-2].op, BinaryOps.IDIV)
self.assertEqual(uops[-1].op, Ops.SHR)
self.assertEqual(uops[-2].op, Ops.IDIV)

class TestUOpMethod(unittest.TestCase):
@unittest.skip("uops lt no longer ordered")

@@ -1,8 +1,7 @@
from typing import Dict, List, Optional
import unittest
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, \
graph_rewrite, contexts, track_rewrites
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, contexts, track_rewrites
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json

@track_rewrites()
@@ -35,7 +34,7 @@ class TestViz(unittest.TestCase):
def test_rewrite_twice(self):
pm = PatternMatcher([
(UPat.var("x")+UPat.var("x"), lambda x:x*2),
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(Ops.SHL, UOp.const(dtypes.int, 1))),
])
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
uops = helper_test_viz(a+a, pm)
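
The two patterns chain: x + x rewrites to x * 2, which then rewrites to x << 1. The integer identity the chain relies on, checked in plain Python:

x = 21
assert x + x == x * 2 == x << 1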

@@ -1,7 +1,7 @@
import unittest, math
from tinygrad import dtypes
from tinygrad.helpers import all_same
from tinygrad.ops import GroupOp, UOp, Ops, BinaryOps, exec_alu
from tinygrad.ops import GroupOp, UOp, Ops, exec_alu
from tinygrad.codegen.uopgraph import full_graph_rewrite

# Helper function to apply the graph rewrite
@@ -56,7 +56,7 @@ class TestFoldingAndReduction(unittest.TestCase):
const1 = UOp.const(dtypes.int32, 5)
const2 = UOp.const(dtypes.int32, 10)
const3 = UOp.const(dtypes.int32, 20)
optimized_sink = apply_rewrite((const1 + const2 + const3).reduce(BinaryOps.ADD))
optimized_sink = apply_rewrite((const1 + const2 + const3).reduce(Ops.ADD))
expected_sum = 5 + 10 + 20
self.assertEqual(optimized_sink.arg, expected_sum)

@@ -65,14 +65,14 @@ class TestFoldingAndReduction(unittest.TestCase):
const1 = UOp.const(dtypes.int32, 15)
const2 = UOp.const(dtypes.int32, 25)
rng = UOp.range(dtypes.int32, 0, 10, idx=0)
optimized_sink = apply_rewrite((const1 + const2).reduce(BinaryOps.ADD, rng))
optimized_sink = apply_rewrite((const1 + const2).reduce(Ops.ADD, rng))
expected_sum = 10 * (15 + 25)
self.assertEqual(optimized_sink.arg, expected_sum)

@unittest.skip("currently failing")
def test_full_graph_rewrite_range_reduction(self):
simple_range = UOp.range(dtypes.int32, 0, 5, idx=0)
optimized_sink = apply_rewrite(simple_range.reduce(BinaryOps.ADD, simple_range))
optimized_sink = apply_rewrite(simple_range.reduce(Ops.ADD, simple_range))
expected_sum = sum(range(5))
self.assertEqual(optimized_sink.arg, expected_sum)

@@ -80,7 +80,7 @@ class TestFoldingAndReduction(unittest.TestCase):
def test_full_graph_rewrite_simple_reduction_folding(self):
simple_range = UOp.range(dtypes.int32, 0, 4, idx=0)
add_uop = simple_range + UOp.const(dtypes.int32, 1)
optimized_sink = apply_rewrite(add_uop.reduce(BinaryOps.ADD, simple_range))
optimized_sink = apply_rewrite(add_uop.reduce(Ops.ADD, simple_range))
expected_sum = sum(i + 1 for i in range(4))
self.assertEqual(optimized_sink.arg, expected_sum)

@@ -89,7 +89,7 @@ class TestFoldingAndReduction(unittest.TestCase):
outer_range = UOp.range(dtypes.int32, 0, 8, 0)
inner_range = UOp.range(dtypes.int32, 0, 4, 1)
expr = (outer_range * 10) + inner_range
optimized_reduce_uop = apply_rewrite(expr.reduce(BinaryOps.ADD, outer_range, inner_range))
optimized_reduce_uop = apply_rewrite(expr.reduce(Ops.ADD, outer_range, inner_range))
self.assertEqual(optimized_reduce_uop.op, Ops.CONST)
self.assertEqual(optimized_reduce_uop.arg, sum((i * 10) + j for i in range(8) for j in range(4)))

@@ -104,7 +104,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
def test_full_graph_rewrite_division_folding_with_define_var(self):
n_var_uop = UOp.variable('n', 1, 1000)
optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
self.assertEqual(optimized_div_uop.op, BinaryOps.MUL)
self.assertEqual(optimized_div_uop.op, Ops.MUL)
self.assertEqual(optimized_div_uop.src[1].arg, 2)

def test_full_graph_rewrite_complex_mod_div_folding(self):

@@ -1,6 +1,6 @@
import unittest, itertools
from tinygrad.dtype import dtypes
from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, UnaryOps, GroupOp # noqa: F401
from tinygrad.ops import Ops, UOp, GroupOp # noqa: F401
from tinygrad.ops import PatternMatcher, UPat

class TestPatternMatcher(unittest.TestCase):
@@ -140,9 +140,9 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c2), None)
# that CONST/ALU -> ALU/CONST rewrite is now instant
"""
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=(UPat(Ops.CONST), UPat(GroupOp.ALU))), lambda x: x)])
c4 = UOp(Ops.ADD, dtypes.float, (c1,c3))
c5 = UOp(Ops.ADD, dtypes.float, (c3,c1))
self.assertEqual(matcher.rewrite(c3), None)
self.assertEqual(matcher.rewrite(c4), c4)
self.assertEqual(matcher.rewrite(c5), None)

@@ -5,7 +5,7 @@ from collections import defaultdict
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
from enum import Enum, auto

from tinygrad.ops import GroupOp, BinaryOps, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
graph_rewrite, track_rewrites
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program
@@ -276,7 +276,7 @@ class Kernel:
if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None

mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
if mul_op.op is not BinaryOps.MUL: return None
if mul_op.op is not Ops.MUL: return None

def buf_index(src:UOp) -> Optional[int]:
# TODO: apply tc even if the sources are not from LOAD
@@ -303,7 +303,7 @@ class Kernel:
return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)

def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD:
if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
for tc in self.opts.tensor_cores:
tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
# can only fuse reduces with the same tc options
@@ -338,7 +338,7 @@ class Kernel:
2: apply tensor core shape but don't use UOp.WMMA
extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
0: applies to only kernels with a single reduce axis and direct UOps.LOAD into BinaryOps.MUL
0: applies to only kernels with a single reduce axis and direct UOps.LOAD into Ops.MUL
1: allows kernels with multiple reduce axes and also multiplication of UOps.CAST'd buffers
2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by padding those axes as needed
"""
|
||||
@@ -441,7 +441,7 @@ class Kernel:
|
||||
check(not self.vars, "does not work with symbolic shape")
|
||||
check(axis < self.first_upcast, "cannot pad upcasted")
|
||||
# ok to pad SUM if all parent ALU ops have f(0) = 0
|
||||
if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is BinaryOps.ADD and can_pad(r), f"cannot pad {r}")
|
||||
if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r), f"cannot pad {r}")
|
||||
padded = False
|
||||
for i,st in enumerate(self.sts):
|
||||
if (s:=st.shape[axis]) == 1: continue # reduced
|
||||
@@ -470,8 +470,8 @@ class Kernel:
|
||||
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
|
||||
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
|
||||
if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
|
||||
self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
|
||||
(mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
|
||||
self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
|
||||
(mulop:=self.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
|
||||
st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])]
|
||||
strides0, strides1 = st0.real_strides(), st1.real_strides()
|
||||
def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
|
||||
@@ -663,7 +663,7 @@ class Kernel:
|
||||
# for TC=2, we can't do the shapetracker fixup
|
||||
srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])]
|
||||
# MUL/SUM instead of WMMA
|
||||
ret = UOp(Ops.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
|
||||
ret = UOp(Ops.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(Ops.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
|
||||
else:
|
||||
# real WMMA, use CONTRACT/EXPAND to get the vectorization right
|
||||
wmma_upcast_axes = wmma_arg[-2]
|
||||
|
||||
@@ -169,7 +169,7 @@ def sin_poly_large(d:UOp, q:UOp) -> UOp:
|
||||
|
||||
def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
|
||||
"""
|
||||
Implements a 1.0 ULP approximation for UnaryOps.SIN.
|
||||
Implements a 1.0 ULP approximation for Ops.SIN.
|
||||
- fast=True assumes x <= switch_over.
|
||||
- switch_over is the threshold for switching to payne_hanek_reduction.
|
||||
"""
@@ -192,7 +192,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:

def xexp2(d:UOp) -> UOp:
"""
Implements a 1.0 ULP approximation for UnaryOps.EXP2
Implements a 1.0 ULP approximation for Ops.EXP2
- Paper: https://arxiv.org/pdf/2001.09258
"""
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
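
Approximations in this family split the input into an integer part handled by exact exponent manipulation and a fractional part handled by a polynomial. A sketch of that decomposition (the Taylor polynomial stands in for the paper's minimax coefficients):

import math
def exp2_sketch(x: float) -> float:
  q = math.floor(x + 0.5)  # integer part
  f = x - q                # fractional part in [-0.5, 0.5]
  y = f * math.log(2)
  poly = sum(y ** k / math.factorial(k) for k in range(12))  # ~= 2**f
  return math.ldexp(poly, q)  # exact scaling by 2**q

assert abs(exp2_sketch(7.3) - 2 ** 7.3) < 1e-9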
@@ -218,7 +218,7 @@ def xexp2(d:UOp) -> UOp:

def xlog2(d:UOp) -> UOp:
"""
Implements a 1.0 ULP approximation for UnaryOps.LOG2
Implements a 1.0 ULP approximation for Ops.LOG2
Paper: https://arxiv.org/pdf/2001.09258 5.5
"""
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES

@@ -3,7 +3,7 @@ from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict,
import functools, itertools, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple
from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -23,7 +23,7 @@ def fold_expanded(ex, buf):
for i,s in enumerate(new_srcs):
idx = s.src[0].src[1]
if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
if idx.op is BinaryOps.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
# add gates for gated
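
The point of splitting each index into a (root, constant offset) pair is that loads sharing a root with consecutive offsets can later fold into a single wide access. A toy version of the grouping step (illustrative only, not the rewrite itself):

from collections import defaultdict
idxs = [("p", 0), ("p", 1), ("p", 2), ("q", 5)]  # (root, const offset) pairs
groups = defaultdict(list)
for root, off in idxs: groups[root].append(off)
assert groups["p"] == [0, 1, 2]  # contiguous offsets: a candidate for one vector load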
@@ -92,12 +92,12 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:

# can drop valid if idx is out of bound when valid is False
drop_stmt = []
for stmt in split_uop(valid, BinaryOps.AND):
for stmt in split_uop(valid, Ops.AND):
X, is_upper_bound, c = parse_valid(stmt)

# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, BinaryOps.ADD)):
testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, BinaryOps.ADD), idx)
if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx)
testidx = testidx.simplify()
if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
drop_stmt.append(stmt)
@@ -114,7 +114,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
break

if not drop_stmt and idx is start_idx: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, BinaryOps.AND) if s not in drop_stmt]) else None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
return buf.index(idx, new_valid)
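
The single-point test works because every term is non-negative: a condition X0 + X1 + ... >= 1 is False exactly when all Xi are 0, so substituting zeros probes the only failing assignment. A brute-force sanity check of that claim for small values:

import itertools
for xs in itertools.product(range(3), repeat=3):  # all Xi >= 0
  assert (sum(xs) >= 1) == (not all(x == 0 for x in xs))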

# ***** optional patterns *****
@@ -123,23 +123,23 @@ powers_of_two = {2**i:i for i in range(64)}
@functools.lru_cache(None)
def get_late_rewrite_patterns(ops, force_transcendental=False):
pat: List[Tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental]
((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
# rewrite MOD to AND (which should always be supported, but not for generic in tests)
if BinaryOps.AND in ops:
if Ops.AND in ops:
pat += [(UPat(Ops.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)]
# rewrite MUL/IDIV to SHL+SHR
if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
if Ops.SHL in ops and Ops.SHR in ops:
pat += [
(UPat(Ops.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y)
(UPat(Ops.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y)
if UnaryOps.NEG in ops:
pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
if BinaryOps.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))]
if TernaryOps.MULACC in ops:
pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c))]
if Ops.NEG in ops:
pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
if Ops.MULACC in ops:
pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
return PatternMatcher(pat)
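
The power-of-two strength reduction above rests on three integer identities, valid for non-negative x with 2**k fitting the dtype:

k, x = 5, 1234
assert x % (1 << k) == x & ((1 << k) - 1)  # MOD  -> AND
assert x * (1 << k) == x << k              # MUL  -> SHL
assert x // (1 << k) == x >> k             # IDIV -> SHR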

# ***** threefry *****
@@ -225,7 +225,7 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
if len(reduce_unparented) == 0: return None
new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
ret = new_acc.assign(new_acc.alu(alu.op, ret))
if alu.op is BinaryOps.ADD:
if alu.op is Ops.ADD:
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
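
The multiply on the ADD path encodes a simple fact: summing a loop-invariant value over a range is the value times the trip count. In plain Python:

a, b, c = 2, 10, 7
assert sum(c for _ in range(a, b)) == c * (b - a)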

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, exec_alu, python_alu
from tinygrad.ops import exec_alu, python_alu
from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
@@ -11,9 +11,9 @@ from weakref import ref, ReferenceType, WeakValueDictionary
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)):
if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None
if st.size == 0: op, arg, srcs, base = Ops.CONST, 0, (), None
dtype = to_dtype(dtype)
if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True
if op is Ops.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True

cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret
@@ -32,13 +32,13 @@ class LazyBuffer(MathTrait):
if base is None:
# properties on base
self.op, self.arg, self.srcs = op, arg, srcs # this is a UOp, except the src is LazyBuffers and not UOps
assert self.op is not MetaOps.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized"
assert self.op is not Ops.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized"

if self.op is MetaOps.BUFFER_VIEW:
if self.op is Ops.BUFFER_VIEW:
# some LazyBuffers can be processed with only a view, no AST required
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
else:
self.buffer = srcs[0].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, self.dtype)
self.buffer = srcs[0].base.buffer if self.op is Ops.ASSIGN else Buffer(device, self.size, self.dtype)
self.buffer.ref(1)
self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
self.forced_realize = False
@@ -74,14 +74,14 @@ class LazyBuffer(MathTrait):
def const_like(self, b): return self.const_with_shape(b, self.shape)
def const_with_shape(self, val:ConstType, shape:Tuple[sint,...]) -> LazyBuffer:
assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType"
return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)

def is_realized(self) -> bool: return self.base.realized is not None

def assign(self, x:LazyBuffer) -> LazyBuffer:
assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
assert self.is_realized(), f"assign target must be realized {self}"
return LazyBuffer.metaop(MetaOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,),
return LazyBuffer.metaop(Ops.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,),
src=(self.base, x), enable_cache=True)

def can_view(self):
@@ -90,7 +90,7 @@ class LazyBuffer(MathTrait):

def contiguous(self, allow_buffer_view=True):
if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
ret = self.alu(MetaOps.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(MetaOps.CONTIGUOUS)
ret = self.alu(Ops.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(Ops.CONTIGUOUS)
if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
return ret
self.base.forced_realize = True
@@ -101,7 +101,7 @@ class LazyBuffer(MathTrait):
if self.dtype == dtype: return self
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
if self.is_unrealized_unmasked_const() and not bitcast:
return create_lazybuffer(self.device, self.st, dtype, MetaOps.CONST, dtypes.as_const(self.base.arg, dtype))
return create_lazybuffer(self.device, self.st, dtype, Ops.CONST, dtypes.as_const(self.base.arg, dtype))
new_shape = self.shape
if bitcast and self.dtype.itemsize != dtype.itemsize:
if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
@@ -112,27 +112,27 @@ class LazyBuffer(MathTrait):
elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
# TODO: applying this makes gpt2 slower
return self.base.cast(dtype, bitcast)._view(self.st)
cast_op: Ops = (MetaOps.BUFFER_VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
cast_op: Ops = (Ops.BUFFER_VIEW if self.can_view() and allow_buffer_view else Ops.BITCAST) if bitcast else Ops.CAST
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))

def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, UOp)
def is_unrealized_const(self): return self.base.realized is None and self.base.op is Ops.CONST and not isinstance(self.base.arg, UOp)
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)

def _copy(self, device:str) -> LazyBuffer:
assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, MetaOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False)

def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer:
# no COPY
if self.device == device and not clone: return self

# double COPY = one COPY
if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is MetaOps.COPY:
if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is Ops.COPY:
return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)

# const doesn't have to be copied (issues with disk tensor)
if self.is_unrealized_const():
return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)

# if it's a shrink, do the shrink before the copy with CONTIGUOUS
if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)
@@ -149,25 +149,25 @@ class LazyBuffer(MathTrait):
srcs.append(root._view(s.base.contiguous_child[1]))
else:
srcs.append(s)
if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]):
if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is Ops.WHERE else srcs)]):
raise AssertionError(f"all dtypes must match {dts} on {op}")
assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
if op is Ops.WHERE: assert srcs[0].dtype == dtypes.bool, "Ops.WHERE must have the first arg be bool"

out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else srcs[-1].dtype
out_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else srcs[-1].dtype

# const folding
if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
return self.cast(out_dtype).const_like(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
if op in GroupOp.Binary:
x, y = self, in_srcs[0]
if op is BinaryOps.ADD:
if op is Ops.ADD:
if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
if op is BinaryOps.MUL:
if op is Ops.MUL:
if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const_like(0)
if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const_like(0)
if op is BinaryOps.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x
if op is Ops.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x

return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, None, tuple(srcs))
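
The const-folding branch applies textbook algebraic identities before any kernel is built. The same identities in plain integers:

x = 42
assert x + 0 == x and 0 + x == x  # ADD identity
assert x * 1 == x and x * 0 == 0  # MUL identity and annihilator
assert x // 1 == x                # IDIV by 1 is a no-op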

129 tinygrad/ops.py
@@ -29,14 +29,14 @@ class SimpleMathTrait:
dtype: Optional[DType] = getattr(self, 'dtype', None)
assert dtype is not None, "MathTraits __neg__ requires a dtype"
return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
def add(self, x, reverse=False): return self._binop(BinaryOps.ADD, x, reverse)
def mul(self, x, reverse=False): return self._binop(BinaryOps.MUL, x, reverse)
def bitwise_and(self, x, reverse=False): return self._binop(BinaryOps.AND, x, reverse)
def bitwise_or(self, x, reverse=False): return self._binop(BinaryOps.OR, x, reverse)
def xor(self, x, reverse=False): return self._binop(BinaryOps.XOR, x, reverse)
def idiv(self, x, reverse=False): return self._binop(BinaryOps.IDIV, x, reverse)
def sub(self, x, reverse=False): return self.ufix(x).alu(BinaryOps.ADD, -self) if reverse else self.alu(BinaryOps.ADD, self.ufix(-x))
def div(self, x, reverse=False): return (self.ufix(x)*self.alu(UnaryOps.RECIP)) if reverse else (self*self.ufix(x).alu(UnaryOps.RECIP))
def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse)
def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))

def __neg__(self): return self.neg()
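
Note there is no SUB or true DIV primitive here: subtraction is addition of a negation, and division is multiplication by a reciprocal. The identities, checked with plain floats:

x, y = 7.0, 3.0
assert x - y == x + (-y)                   # sub via ADD + negate
assert abs(x / y - x * (1.0 / y)) < 1e-12  # div via MUL + RECIP (up to rounding)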

@@ -58,9 +58,9 @@ class SimpleMathTrait:
def __ror__(self, x): return self.bitwise_or(x, True)
def __rxor__(self, x): return self.xor(x, True)

def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x))
def gt(self, x): return self.ufix(x).alu(BinaryOps.CMPLT, self)
def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x))
def lt(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
def gt(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
def ge(self, x): return self.lt(x).logical_not()
def le(self, x): return self.gt(x).logical_not()
def eq(self, x): return self.ne(x).logical_not()
|
||||
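As a reading aid: sub, div, ge, le, and eq above are all derived from a small set of primitives (ADD, RECIP, CMPLT, CMPNE plus negation). A toy standalone class showing the same derivations (MiniVal is hypothetical; plain ints stand in for UOps):

class MiniVal:
    def __init__(self, v): self.v = v
    def add(self, x): return MiniVal(self.v + x.v)  # primitive binop
    def neg(self): return MiniVal(-self.v)
    def lt(self, x): return self.v < x.v            # primitive compare
    def ne(self, x): return self.v != x.v
    def sub(self, x): return self.add(x.neg())      # a - b  ==  a + (-b)
    def ge(self, x): return not self.lt(x)          # a >= b ==  not (a < b)
    def eq(self, x): return not self.ne(x)          # a == b ==  not (a != b)

a, b = MiniVal(3), MiniVal(5)
assert a.sub(b).v == -2 and b.ge(a) and not a.eq(b)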
@@ -74,26 +74,26 @@ class SimpleMathTrait:

class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
# TODO: move to Tensor when new backward is done
def lshift(self, x, reverse=False): return self._binop(BinaryOps.SHL, x, reverse)
def rshift(self, x, reverse=False): return self._binop(BinaryOps.SHR, x, reverse)
def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
def __lshift__(self, x): return self.lshift(x)
def __rshift__(self, x): return self.rshift(x)
def __rlshift__(self, x): return self.lshift(x, True)
def __rrshift__(self, x): return self.rshift(x, True)

# not in Tensor
def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x))
def __rmod__(self, x): return self.ufix(x).alu(BinaryOps.MOD, self)
def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x))
def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self)

def maximum(self, x): return self.alu(BinaryOps.MAX, self.ufix(x))
def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
def minimum(self, x): return -(-self).maximum(-x)
def where(self, x, y): return self.alu(TernaryOps.WHERE, x, x.ufix(y))
def threefry(self, seed): return self.alu(BinaryOps.THREEFRY, seed)
def reciprocal(self): return self.alu(UnaryOps.RECIP)
def sqrt(self): return self.alu(UnaryOps.SQRT)
def sin(self): return self.alu(UnaryOps.SIN)
def log2(self): return self.alu(UnaryOps.LOG2)
def exp2(self): return self.alu(UnaryOps.EXP2)
def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y))
def threefry(self, seed): return self.alu(Ops.THREEFRY, seed)
def reciprocal(self): return self.alu(Ops.RECIP)
def sqrt(self): return self.alu(Ops.SQRT)
def sin(self): return self.alu(Ops.SIN)
def log2(self): return self.alu(Ops.LOG2)
def exp2(self): return self.alu(Ops.EXP2)

# the order of these Ops controls the order of the toposort
class Ops(FastEnum):

@@ -182,11 +182,8 @@ class GroupOp:
# do not preserve f(0) = 0
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}

# TODO: remove this?
UnaryOps = BinaryOps = MetaOps = TernaryOps = Ops

# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
def identity_element(op:Ops, dt:DType): return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)

def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents)
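identity_element picks the value a reduce accumulator starts from: 0 for ADD, 1 for MUL, and the dtype minimum for MAX. A standalone sketch (mini_reduce and the string keys are illustrative stand-ins, not tinygrad's API):

import math

IDENTITY = {"ADD": 0.0, "MUL": 1.0, "MAX": -math.inf}  # MAX seeds with the dtype minimum

def mini_reduce(op, fn, xs):
    acc = IDENTITY[op]           # op(identity, x) == x for every x
    for x in xs: acc = fn(acc, x)
    return acc

assert mini_reduce("ADD", lambda a, b: a + b, [1.0, 2.0, 3.0]) == 6.0
assert mini_reduce("MAX", max, [-5.0, -2.0]) == -2.0
assert mini_reduce("MUL", lambda a, b: a * b, []) == 1.0  # empty reduce returns the identity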
@@ -325,7 +322,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def alu(self, arg, *src:UOp):
out_dtype = (self, *src)[-1].dtype
if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None:
if arg in {Ops.CMPLT, Ops.CMPNE} and out_dtype is not None:
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(arg, out_dtype, (self,)+src)
@staticmethod

@@ -392,15 +389,15 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
"""largest known int that divides self"""
if self.op is Ops.CONST: return self.arg
if self.op is Ops.VCONST: return math.gcd(*self.arg)
if self.op is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
if self.op is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
return 1
def divides(self, v) -> Optional[UOp]:
if v==1: return self
if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
if self.op is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
if self.op is BinaryOps.MUL:
if self.op is Ops.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
if self.op is Ops.MUL:
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
return None # generic None if we aren't sure
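const_factor combines ADD branches with gcd and reads the constant side of a MUL. The same logic on a toy tuple tree, as a standalone sketch (("ADD", a, b) tuples are stand-ins for UOps):

import math

def const_factor(u):
    # largest known int that divides the expression u
    if u[0] == "CONST": return u[1]
    if u[0] == "ADD": return math.gcd(const_factor(u[1]), const_factor(u[2]))
    if u[0] == "MUL":
        if u[1][0] == "CONST": return u[1][1]
        if u[2][0] == "CONST": return u[2][1]
    return 1  # unknown expressions are only known to be divisible by 1

x = ("VAR", "x")
assert const_factor(("ADD", ("MUL", x, ("CONST", 6)), ("CONST", 9))) == 3  # gcd(6, 9)
assert const_factor(("MUL", x, ("CONST", 4))) == 4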
@@ -412,18 +409,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def _min_max(self) -> Tuple[ConstType, ConstType]:
if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
(s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
if self.op is BinaryOps.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
if self.op is BinaryOps.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
if self.op is BinaryOps.MOD and s1_vmin > 0: return 0, s1_vmax-1
if self.op is BinaryOps.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1
if self.op is Ops.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
if self.op is BinaryOps.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
if self.op is BinaryOps.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
if self.op is BinaryOps.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
if self.dtype == dtypes.bool:
if self.op is BinaryOps.OR: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
if self.op is BinaryOps.AND: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
if self.op is Ops.OR: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
if self.op is Ops.AND: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
# float has NAN issue and we use explicit NAN in transcendental
if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
# NOTE: returned UOp is assumed to be CONST
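The integer bounds above are plain interval arithmetic: ADD sums the endpoints, MUL takes the extremes of the four endpoint products (to cover sign flips), and MOD with a positive divisor lands in [0, divisor-1]. A standalone sketch of those three rules:

def add_range(a, b): return (a[0] + b[0], a[1] + b[1])

def mul_range(a, b):
    vals = (a[0]*b[0], a[0]*b[1], a[1]*b[0], a[1]*b[1])
    return (min(vals), max(vals))  # endpoint products cover sign flips

def mod_range(a, b):
    assert b[0] > 0                # only valid for a strictly positive divisor
    return (0, b[1] - 1)           # x % y lands in [0, y-1]

assert add_range((0, 3), (2, 5)) == (2, 8)
assert mul_range((-2, 3), (4, 5)) == (-10, 15)
assert mod_range((0, 100), (4, 4)) == (0, 3)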
@@ -467,14 +464,14 @@ def hook_overflow(dv, fxn):
return wfxn

python_alu: Dict[Ops, Callable] = {
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
UnaryOps.NEG: operator.neg, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul,
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt, BinaryOps.XOR: operator.xor,
BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_, BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: hook_overflow(math.inf, lambda x: 2**x),
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul,
Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
Ops.MAX: max, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.XOR: operator.xor,
Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift,
Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}

def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
if dtype.count > 1:
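python_alu is a table from op to a plain Python callable, and exec_alu is little more than a dict lookup plus a call (plus vector handling and output truncation). A cut-down standalone version (mini_alu and mini_exec_alu are illustrative names, not tinygrad's API):

import math, operator

mini_alu = {
    "ADD": operator.add, "MUL": operator.mul, "MAX": max,
    "RECIP": lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
    "WHERE": lambda c, a, b: a if c else b,
}

def mini_exec_alu(op, operands): return mini_alu[op](*operands)  # table-driven dispatch

assert mini_exec_alu("ADD", [2, 3]) == 5
assert mini_exec_alu("WHERE", [False, 10, 20]) == 20
assert mini_exec_alu("RECIP", [0.0]) == math.inf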
@@ -515,7 +512,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
elif u.op is Ops.STORE:
mem += u.src[1].dtype.itemsize * mults
elif u.op in GroupOp.ALU and u not in dont_count:
flops += (mults * (2 if u.op is TernaryOps.MULACC else 1)) * u.dtype.count
flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op is Ops.WMMA and u not in dont_count:
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return flops, mem

@@ -857,7 +854,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c

remainder, something_changed = [], False
for u in split_uop(x, BinaryOps.ADD):
for u in split_uop(x, Ops.ADD):
if (factor:=u.const_factor())%c != factor:
divides = u.divides(factor)*(factor%c)
assert divides is not None

@@ -877,7 +874,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
if 0 <= x.vmin and x.vmax < c: return x.const_like(0)

quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
for u in split_uop(x, BinaryOps.ADD):
for u in split_uop(x, Ops.ADD):
if u.op is Ops.CONST:
# add all const together first
if rem_const != 0: something_changed = True

@@ -911,7 +908,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo

def lt_folding(x:UOp, c:int) -> Optional[UOp]:
p, np = partition(split_uop(x, BinaryOps.ADD), lambda u: u.const_factor() == 1)
p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
return cast(UOp, functools.reduce(operator.add, np).divides(d)).lt(c//d)
return None

@@ -919,7 +916,7 @@ def lt_folding(x:UOp, c:int) -> Optional[UOp]:
def fold_unrolled_divs(divs:UOp):
# div pattern in unrolled arange
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
add_chain, denominator, seen_const, ans = list(split_uop(divs, BinaryOps.ADD)), None, [], None
add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
for u in add_chain:
if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
if denominator is None: denominator = u.src[1].arg
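The identity in the comment above can be spot-checked by brute force for non-negative x:

# x//4 + (x+1)//4 + (x+2)//4 + (x+3)//4 == x, the pattern fold_unrolled_divs matches
for x in range(1024):
    assert x // 4 + (x + 1) // 4 + (x + 2) // 4 + (x + 3) // 4 == x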
@@ -941,9 +938,9 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
# returns x0 + x1 + ... in such case, or None if not
changed, ret = False, []
for u in split_uop(X, BinaryOps.ADD):
for u in split_uop(X, Ops.ADD):
# assumed the const is the last src of MUL
if u.op is BinaryOps.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
changed = True
u = u.src[0]
if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None

@@ -953,8 +950,8 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
def is_increasing(f:UOp) -> bool:
# is f a monotonically increasing function regards its input
if f.op in GroupOp.Irreducible: return True
if f.op is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
if f.op in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
if f.op is Ops.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
if f.op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
return False # False if not sure

def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:

@@ -962,10 +959,10 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
# if it's X >= c, returns X, False, c

# (X < c).ne(True) -> X >= c
if valid.op is BinaryOps.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
(s0:=valid.src[0]).op is BinaryOps.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
(s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
# X < c -> X <= c-1
if valid.op is BinaryOps.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
raise ValueError(f"not able to parse {valid=}")

def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
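Both rewrites parse_valid relies on are easy to sanity-check over a range of integers:

# (X < c).ne(True)  ->  X >= c,  and  X < c  ->  X <= c-1  (integers only)
for X in range(-8, 8):
    for c in range(-8, 8):
        assert ((X < c) != True) == (X >= c)
        assert (X < c) == (X <= c - 1)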
@@ -973,7 +970,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:

# first, parse valid into {expr: (lower_bound, upper_bound)}
bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
for stmt in split_uop(valid, BinaryOps.AND):
for stmt in split_uop(valid, Ops.AND):
try: expr, is_upper, c = parse_valid(stmt)
except ValueError: return uop # give up if we cannot parse the valid
bounds[expr][int(is_upper)] = c

@@ -985,9 +982,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:

# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
candidates = []
if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, BinaryOps.ADD)):
if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)])
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
# try checking the whole clause
if expr in uop.sparents:
candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])

@@ -1128,8 +1125,8 @@ symbolic_flat = symbolic+PatternMatcher([
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])

# for debug
syms = { BinaryOps.ADD: "+", BinaryOps.SUB: "-", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>",
BinaryOps.MUL: "*", BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"}
syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
renderer = PatternMatcher([
(UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
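The syms table drives a debug renderer that folds an expression tree into an infix string. A standalone sketch of that fold (MINI_SYMS and the tuple tree are stand-ins for the real syms/UOp graph):

MINI_SYMS = {"ADD": "+", "MUL": "*", "CMPLT": "<"}

def render(u):
    if isinstance(u, (int, str)): return str(u)      # leaves: variables and consts
    op, a, b = u
    return f"({render(a)}{MINI_SYMS[op]}{render(b)})"  # parenthesize every binop

assert render(("CMPLT", ("ADD", ("MUL", "ridx0", 4), "ridx1"), 16)) == "(((ridx0*4)+ridx1)<16)"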
@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
import os, math
from collections import defaultdict, Counter
from tinygrad.ops import GroupOp, UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore

@@ -42,13 +42,13 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
# new load/store
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == BinaryOps.ADD else ctx[idx]})"),
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
# alu/gep
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)),
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
])

@@ -82,13 +82,13 @@ class CStyleLanguage(Renderer):
infinity: str = "INFINITY"
nan: str = "NAN"
code_for_op: Dict = {
UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})", UnaryOps.RECIP: lambda x,dtype: f"(1/{x})", UnaryOps.NEG: lambda x,dtype: f"-{x}",
UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})", BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})",
BinaryOps.SHR: lambda a,b,dtype: f"({a}>>{b})", BinaryOps.SHL: lambda a,b,dtype: f"({a}<<{b})", BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})",
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" }
Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})",
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" }

string_rewrite = base_rewrite
extra_matcher = extra_pm
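code_for_op maps each op to a lambda that splices already-rendered operand strings into a C expression, so nested ops compose by plain string substitution. A standalone sketch (mini_code_for_op is an illustrative stand-in, not the class attribute above):

mini_code_for_op = {
    "ADD": lambda a, b, dtype: f"({a}+{b})",
    "SQRT": lambda x, dtype: f"sqrt({x})",
    "WHERE": lambda a, b, c, dtype: f"({a}?{b}:{c})",
}

# nested expressions compose: render the inner op, then splice it into the outer one
inner = mini_code_for_op["ADD"]("x", "1.0f", None)
assert mini_code_for_op["WHERE"]("mask", inner, "0.0f", None) == "(mask?(x+1.0f):0.0f)"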
@@ -174,8 +174,8 @@ class ClangRenderer(CStyleLanguage):
# language options
buffer_suffix = " restrict"
type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [UnaryOps.EXP2, UnaryOps.SIN, UnaryOps.LOG2]}),
UnaryOps.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}

if AMX:
tensor_cores = [TensorCore(dims=(sz,sz,1), threads=[], reduce_axes=[], upcast_axes=([(1,sz)],[(0,sz)],[(1,sz),(0,sz)]), dtype_in=dt, dtype_out=dt)

@@ -265,12 +265,12 @@ class MetalRenderer(CStyleLanguage):
type_map = {dtypes.bfloat16: "bfloat"}

# precise::sin
code_for_op = {**CStyleLanguage.code_for_op, UnaryOps.SIN: lambda x,dtype: f"precise::sin({x})"}
code_for_op = {**CStyleLanguage.code_for_op, Ops.SIN: lambda x,dtype: f"precise::sin({x})"}

# upcast to float32 all the ops that don't support bfloat16
extra_matcher = PatternMatcher([
# NOTE: this is copied from PTX
(UPat((UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN), dtype=dtypes.bfloat16, name="x"),
(UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
]) + extra_pm

@@ -312,11 +312,11 @@ class CUDARenderer(CStyleLanguage):
code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
"i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"}
code_for_op = { **CStyleLanguage.code_for_op,
UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
Ops.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
Ops.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
type_map = {dtypes.bfloat16: "nv_bfloat16"}

def render_vector_prefix(self, dt:DType) -> str:

@@ -375,10 +375,10 @@ class AMDRenderer(CStyleLanguage):
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
"i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
code_for_op = { **CStyleLanguage.code_for_op,
UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})" }
Ops.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
Ops.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
Ops.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
Ops.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})" }
smem_prefix = "__attribute__((shared))"
barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
'__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'

@@ -433,9 +433,9 @@ class DSPRenderer(ClangRenderer):
buffer_suffix = " restrict __attribute__((align_value(128)))"
kernel_prefix = "__attribute__((noinline)) "
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
code_for_op = {**ClangRenderer.code_for_op, UnaryOps.SIN: lambda x,dtype: f"__builtin_sin({x})",
UnaryOps.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
UnaryOps.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"}
code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
Ops.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"}

def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
ret = super().render_kernel(function_name, kernel, bufs, uops, prefix)
@@ -1,7 +1,7 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tuple
import struct
from collections import defaultdict
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, GroupOp
from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer

@@ -15,24 +15,24 @@ def render_val(x, dtype):
return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")

asm_for_op: Dict[Ops, Callable] = {
UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
BinaryOps.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
BinaryOps.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
TernaryOps.WHERE: lambda d,a,b,c,dt,name:
Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
Ops.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", Ops.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
Ops.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
Ops.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
Ops.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
Ops.WHERE: lambda d,a,b,c,dt,name:
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
}

supports_half: List[Ops] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
supports_half: List[Ops] = [Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE]
doesnt_support_half: Tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
ptx_matcher = PatternMatcher([
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
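asm_for_op works the same way as the C-style code_for_op, except each lambda formats one PTX instruction from destination/source register names plus a type suffix. A standalone sketch (mini_asm is an illustrative stand-in; the register names are made up):

mini_asm = {
    "ADD": lambda d, a, b, name: f"add.{name} {d}, {a}, {b};",
    "CMPLT": lambda d, a, b, name: f"setp.lt.{name} {d}, {a}, {b};",  # compares set a predicate register
}

assert mini_asm["ADD"]("%f2", "%f0", "%f1", "f32") == "add.f32 %f2, %f0, %f1;"
assert mini_asm["CMPLT"]("%p0", "%r0", "%r1", "s32") == "setp.lt.s32 %p0, %r0, %r1;"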
@@ -151,8 +151,8 @@ class PTXRenderer(Renderer):
kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", pred_reg, invert=True))
elif uop is Ops.BARRIER and self.barrier: kk(self.barrier)
elif uop is Ops.ENDRANGE:
kk(self.code_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
self.code_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
kk(self.code_for_op[Ops.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
self.code_for_op[Ops.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
elif uop is Ops.ENDIF:
kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:")

@@ -167,7 +167,7 @@ class PTXRenderer(Renderer):
else:
if uop is Ops.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
elif uop in GroupOp.ALU:
src_dtype = src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype
src_dtype = src[0].dtype if uop in {Ops.CMPLT, Ops.CMPNE} else dtype
kk(self.code_for_op[uop](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
elif uop is Ops.DEFINE_ACC:
if dtype.count > 1:

@@ -7,7 +7,7 @@ import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, Ops, UOp, GroupOp
from tinygrad.ops import exec_alu, Ops, UOp, GroupOp
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer

@@ -175,7 +175,7 @@ class PythonProgram:
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop in GroupOp.ALU:
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}"
assert all_same([dtype] + dtp) or uop in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {uop}"
assert all_same([dtype] + dtp) or uop in {Ops.CMPNE, Ops.CMPLT, Ops.WHERE}, f"dtype mismatch on {uop}"
ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)]
assert i in ul, (uop, dtype, idp, arg)
i += 1
@@ -5,7 +5,7 @@ from typing import Tuple, List, Optional, Dict, Set
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, BinaryOps, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid
from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid

@dataclass(frozen=True, order=True)
class ShapeTracker:

@@ -77,7 +77,7 @@ class ShapeTracker:
# TODO: always apply these in to_indexed_uops?
if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
for c in split_uop(idx, BinaryOps.ADD):
for c in split_uop(idx, Ops.ADD):
if c.op is Ops.RANGE: ret[c.arg[0]] = 1
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
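The loop above recovers per-axis strides from the ADD terms of an index expression: a bare RANGE contributes stride 1, and RANGE*CONST contributes the constant. The same scan on a toy term list, as a standalone sketch (tuples standing in for UOps):

def mini_real_strides(add_terms):
    ret = {}
    for c in add_terms:
        if c[0] == "RANGE": ret[c[1]] = 1  # bare range: stride 1
        if c[0] == "MUL" and c[1][0] == "RANGE" and c[2][0] == "CONST":
            ret[c[1][1]] = c[2][1]         # range * const: stride is the const
    return ret

# idx = ridx0*12 + ridx1: axis 0 moves 12 elements per step, axis 1 moves 1
assert mini_real_strides([("MUL", ("RANGE", 0), ("CONST", 12)), ("RANGE", 1)]) == {0: 12, 1: 1}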
@@ -9,7 +9,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, Ops, BinaryOps, sint, Variable, SimpleMathTrait
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.realize import run_schedule

@@ -51,7 +51,7 @@ def _to_np_dtype(dtype:DType) -> Optional[type]:
return np.dtype(dtype.fmt).type if dtype.fmt is not None else None

def _fromnp(x: 'np.ndarray') -> LazyBuffer: # type: ignore [name-defined] # noqa: F821
ret = LazyBuffer.metaop(MetaOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
ret = LazyBuffer.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
# fake realize
ret.buffer.allocate(x)
del ret.srcs

@@ -64,9 +64,9 @@ def get_shape(x) -> Tuple[int, ...]:
return (len(subs),) + (subs[0] if subs else ())

def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
if isinstance(x, bytes): ret, data = LazyBuffer.metaop(MetaOps.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
if isinstance(x, bytes): ret, data = LazyBuffer.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
else:
ret = LazyBuffer.metaop(MetaOps.EMPTY, get_shape(x), dtype, "PYTHON")
ret = LazyBuffer.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
assert dtype.fmt is not None, f"{dtype=} has None fmt"
truncate_function = truncate[dtype]
data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
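The struct.pack call above is what actually turns a flattened Python list into raw buffer bytes. A small standalone illustration (assuming native float is 4 bytes, as on common platforms):

import struct

vals = [1.0, 2.5, -3.0]
data = struct.pack(f"@{len(vals)}f", *vals)               # "@...f" = native float32
assert struct.unpack(f"@{len(vals)}f", data) == (1.0, 2.5, -3.0)
assert len(data) == len(vals) * 4                          # 4 bytes per float32 element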
@@ -134,10 +134,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method

# create a LazyBuffer from the different types of inputs
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, UOp):
assert data.op is Ops.BIND and data.src[0].op is Ops.DEFINE_VAR and data.src[1].op is Ops.CONST, f"can't create tensor from UOp {data}"
data = _metaop(MetaOps.CONST, tuple(), dtype or data.dtype, device, data)
data = _metaop(Ops.CONST, tuple(), dtype or data.dtype, device, data)
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
elif isinstance(data, (list, tuple)):
if dtype is None:

@@ -145,15 +145,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
else: data = _frompy(data, dtype)
elif data is None: data = _metaop(MetaOps.EMPTY, (0,), dtype or dtypes.default_float, device)
elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device)
elif str(type(data)) == "<class 'numpy.ndarray'>":
import numpy as np
assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
if data.shape == (): data = _metaop(MetaOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
if data.shape == (): data = _metaop(Ops.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data) # type: ignore [name-defined]
elif isinstance(data, pathlib.Path):
dtype = dtype or dtypes.uint8
data = _metaop(MetaOps.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")

# by this point, it has to be a LazyBuffer
if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):

@@ -384,9 +384,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
def from_uop(y:UOp, **kwargs) -> Tensor:
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor
if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
if y.op is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.op is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
if y.op is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
if y.op is Ops.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
raise RuntimeError(f"unhandled UOp {y}")

# ***** creation entrypoint *****

@@ -412,7 +412,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
print(t.shape)
```
"""
return Tensor._metaop(MetaOps.EMPTY, argfix(*shape), **kwargs)
return Tensor._metaop(Ops.EMPTY, argfix(*shape), **kwargs)

@staticmethod
def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor:

@@ -424,7 +424,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
Additionally, all other keyword arguments are passed to the constructor of the tensor.
"""

r = Tensor._metaop(MetaOps.EMPTY, shape, **kwargs)
r = Tensor._metaop(Ops.EMPTY, shape, **kwargs)
r.lazydata.buffer.allocate(external_ptr=ptr)
del r.lazydata.srcs # fake realize
return r