Files
tinygrad/test/test_linearizer.py
gswangg 0e6f057eae migrate test_linearizer.py to UOP AST (pt. 1) (#6150)
* migrate test_multioutput to UOP AST

* inline buf declarations

* migrate test_multireduce to UOp AST

* update test_mid_dim_multireduce to UOp AST

* update test_triple_multireduce with UOp AST

* make global definitions more concise

* update test_double_reduce_multireduce with UOp AST

* update test_multireduce_with_parallel with UOp AST

* update test_multiout_multireduce to UOp AST

* make gidx style consistent across updated tests

---------

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
2024-08-20 10:02:20 +03:00

1916 lines
115 KiB
Python

from typing import List, Tuple, Union, cast
import numpy as np
import unittest
from dataclasses import replace
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
from tinygrad.codegen.lowerer import get_grouped_dims
from tinygrad.ops import UOp, UOps
from tinygrad.device import Device, Buffer
from extra.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, MetaOps, TernaryOps, ReduceOps, UnaryOps, to_uop
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
# from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup
from tinygrad.dtype import DType, PtrDType, dtypes
def helper_realized_ast(r:Union[Tensor, List[Tensor]]) -> Tuple[UOp, List[Buffer]]:
  """Return the AST of the last scheduled kernel for `r` together with the buffers to run it with.

  Every schedule item before the last one is executed; the last one is left unrun. The returned list starts with
  newly allocated buffers (one per output of the last schedule item) followed by that item's inputs.
  """
  tensors = [r] if isinstance(r, Tensor) else r
  schedule = create_schedule([t.lazydata for t in tensors])
  # run all kernels except the last one, so all input buffers of the last kernel should be realized afterwards
  run_schedule(schedule[:-1])
  last_item = schedule[-1]
  # allocate a fresh buffer for every output of the last kernel
  fresh_outputs = [Buffer(out.device, out.size, out.dtype).allocate() for out in last_item.outputs]
  return last_item.ast, fresh_outputs+list(last_item.inputs)
def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0):
  """Assert that an (m,k)@(k,n) matmul can use tensor cores and that its realized result matches numpy.

  The matmul is realized through the normal schedule and compared against `np_a @ np_b`. Separately, a Kernel is
  built from the matmul AST, `apply_tensor_cores` is applied and the kernel is linearized to check that at least
  one WMMA uop and exactly one TC opt are present.
  """
  lhs, rhs = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
  np_lhs, np_rhs = lhs.numpy(), rhs.numpy()
  result = lhs.matmul(rhs, acc_dtype=dtype_out)
  schedule = create_schedule([result.lazydata])
  matmul_ast = schedule[-1].ast  # read the AST before the schedule is executed
  run_schedule(schedule)
  out = result.numpy()

  # named `kernel` so the int parameter `k` is not shadowed
  kernel = Kernel(matmul_ast)
  kernel.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
  kernel.linearize()
  wmma_uops = [uop for uop in kernel.uops if uop.op is UOps.WMMA]
  tc_opts = [opt for opt in kernel.applied_opts if opt.op is OptOps.TC]
  assert len(wmma_uops) > 0, "tensor core not triggered"
  assert len(tc_opts) == 1, "tensor core opt not included"

  expected = np_lhs @ np_rhs
  # tolerances depend on the output dtype first, then on the input dtype
  if dtype_out == dtypes.half:
    tc_atol, tc_rtol = 1e-2, 1e-3
  elif dtype_in == dtypes.bfloat16:
    tc_atol, tc_rtol = 1e-2, 3e-3
  else:
    tc_atol, tc_rtol = 5e-3, 1e-4
  np.testing.assert_allclose(expected, out, atol=tc_atol, rtol=tc_rtol)
def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0, ensure_triggered:bool=True):
  """Linearize an (m,k)@(k,n) matmul kernel after `apply_tensor_cores` and check WMMA uop / TC opt counts.

  With `ensure_triggered` true, at least one WMMA uop and exactly one TC opt must be present; otherwise both
  counts must be zero. The kernel is only linearized here, it is never executed.
  """
  lhs, rhs = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
  result = lhs.matmul(rhs, acc_dtype=dtype_out)
  matmul_ast = create_schedule([result.lazydata])[-1].ast
  # named `kernel` so the int parameter `k` is not shadowed
  kernel = Kernel(matmul_ast)
  kernel.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
  kernel.linearize()
  wmmas = sum(1 for uop in kernel.uops if uop.op is UOps.WMMA)
  tcs = sum(1 for opt in kernel.applied_opts if opt.op is OptOps.TC)
  if not ensure_triggered:
    assert wmmas == 0, "tensor core is incorrectly triggered"
    assert tcs == 0, "tensor core opt is incorrectly included"
    return
  assert wmmas > 0, "tensor core not triggered"
  assert tcs == 1, "tensor core opt not included"
class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
  # a and b each reach the kernel through two different shrinks; the last kernel must still see each buffer once
  a, b = Tensor.randn(4), Tensor.randn(4)
  np_a, np_b = a.numpy(), b.numpy()
  a_halves = a.shrink(((0, 2),)) - a.shrink(((2, 4),))
  b_halves = b.shrink(((0, 2),)) - b.shrink(((2, 4),))
  c = a_halves - b_halves
  exec_items = list(lower_schedule(create_schedule([c.lazydata])))
  for item in exec_items: item.run()
  last_bufs = exec_items[-1].bufs
  # three buffers in total; everything after index 0 (presumably the output) is exactly the two realized inputs
  assert len(last_bufs) == 3 and set(last_bufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized}
  expected = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
  np.testing.assert_allclose(expected, c.numpy(), atol=1e-4, rtol=1e-4)
def test_load_removed(self):
  # a where() with a constant condition must evaluate to exactly the selected branch
  a = Tensor.rand(1).realize()
  b = Tensor.rand(1).realize()
  picked_true = Tensor.where(Tensor(True), a, b).numpy()
  picked_false = Tensor.where(Tensor(False), a, b).numpy()
  for branch, picked in ((a, picked_true), (b, picked_false)):
    np.testing.assert_equal(branch.numpy(), picked)
def test_multioutput(self):
  # one kernel with two stores (a+b into buffer 0, a*b into buffer 1) reading the same two input buffers
  dtype, st = dtypes.int, ShapeTracker.from_shape((8,))
  g0, g1, g2, g3 = (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), arg=idx) for idx in range(4))
  lhs = UOp(UOps.LOAD, dtype, (g2, st.to_uop()))
  rhs = UOp(UOps.LOAD, dtype, (g3, st.to_uop()))
  store_sum = UOp(UOps.STORE, None, (g0, st.to_uop(), lhs + rhs))
  store_prod = UOp(UOps.STORE, None, (g1, st.to_uop(), lhs * rhs))
  sink = UOp(UOps.SINK, src=(store_sum, store_prod))

  lhs_t = Tensor.full(st.shape, 2).contiguous().realize()
  rhs_t = Tensor.full(st.shape, 3).contiguous().realize()
  expected = [lhs_t.numpy()+rhs_t.numpy(), lhs_t.numpy()*rhs_t.numpy()]
  lin = helper_linearizer_ast(sink, [lhs_t, rhs_t], wanna_output=expected)[0]

  # every STORE must target its own DEFINE_GLOBAL, and those must be buffers 0 and 1 in that order
  stores = [u for u in lin.uops if u.op is UOps.STORE]
  written_globals = dedup(flatten([[p for p in u.src[0].sparents if p.op is UOps.DEFINE_GLOBAL] for u in stores]))
  assert len(written_globals) == len(stores) == 2
  assert [g.arg for g in written_globals] == [0, 1]
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_multireduce(self):
  # two chained SUM reduces over a 32-element vector: sum_i(x[i] - sum_j(x[j])), checked under several opt lists
  Tensor.manual_seed(0)
  x = Tensor.randn(32, dtype=dtypes.float).realize()
  st_x = x.lazydata.st
  g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
  # first reduce: x broadcast to (32, 32) and summed over axis 1
  first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop()))
  first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (1,)))
  # second reduce: (x - first_reduce) summed over axis 0
  second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop()))
  diff = second_x - first_reduce
  second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (ReduceOps.SUM, (0,)))
  store = UOp(UOps.STORE, None, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce))
  sink = UOp(UOps.SINK, src=(store,))
  opts = [
    [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # grouping
    [Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 8)],
    [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 16)],
    [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.GROUPTOP, 0, 32)],
    [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2)], # unroll reduce
    [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)],
    [Opt(OptOps.UNROLL, 0, 8), Opt(OptOps.UNROLL, 1, 8)] if Device.DEFAULT not in {"NV", "METAL"} else [], # can't do float8,
    [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)], # grouping + unrolling
    [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],
    [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UNROLL, 2, 8), Opt(OptOps.UNROLL, 2, 8)],
    [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 0, 8)],
  ]
  wanna_output = (x.numpy()-x.numpy().sum(-1, keepdims=True)).sum(-1).reshape(1,1)
  lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
  for lin in lins:
    # ops of the RANGE/ENDRANGE uops whose RANGE has a truthy arg[1] (presumably the reduce flag -- confirm), in program order
    ranges = [u.op for u in lin.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
    # two identical neighbours mean one such range was opened (or closed) directly inside another
    for prev, cur in zip(ranges, ranges[1:]):
      assert prev != cur, f"multireduce nested the ranges! {prev, cur}"
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_mid_dim_multireduce(self):
  # two chained SUM reduces over the middle axis of a (27, 32, 5) tensor, checked under many opt combinations
  Tensor.manual_seed(0)
  x = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
  st_x = x.lazydata.st
  g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
  # first reduce: x broadcast to (27, 32, 32, 5) and summed over axis 2
  first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop()))
  first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (2,)))
  # second reduce: (x - first_reduce) summed over axis 1
  second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop()))
  diff = second_x - first_reduce
  second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (ReduceOps.SUM, (1,)))
  store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
  sink = UOp(UOps.SINK, src=(store,))
  opts = [
    # locals
    [Opt(OptOps.LOCAL, 0, 3)],
    [Opt(OptOps.LOCAL, 0, 9)],
    [Opt(OptOps.LOCAL, 0, 27)],
    # grouping
    [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],
    [Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 8)],
    [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 16)],
    [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.GROUPTOP, 0, 32)],
    # unroll
    [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2)],
    [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)],
    [Opt(OptOps.UNROLL, 0, 8), Opt(OptOps.UNROLL, 1, 8)] if Device.DEFAULT not in {"NV", "METAL"} else [],
    # upcasting
    [Opt(OptOps.UPCAST, 0, 3)],
    [Opt(OptOps.UPCAST, 0, 9)],
    # locals with grouping
    [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],
    # locals with unroll
    [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2)],
    # locals with upcasting
    [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.UPCAST, 0, 9)],
    # grouping with unrolling
    [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)],
    [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UNROLL, 2, 8), Opt(OptOps.UNROLL, 2, 8)],
    # grouping with upcasting
    [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UPCAST, 0, 3)],
    # locals with grouping with unroll
    [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)],
    [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UNROLL, 2, 8), Opt(OptOps.UNROLL, 2, 8)],
    # locals with grouping with upcasting
    [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.UPCAST, 0, 3), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],
    [Opt(OptOps.LOCAL, 0, 9), Opt(OptOps.UPCAST, 0, 3), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],
    # grouping with unrolling and upcasting
    [Opt(OptOps.UPCAST, 0, 3), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)],
    [Opt(OptOps.UPCAST, 0, 3), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UNROLL, 2, 8), Opt(OptOps.UNROLL, 2, 8)],
    # locals + grouping + unrolling + upcasting
    [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.UPCAST, 0, 3), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2),
      Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)],
  ]
  wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
  lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
  for lin in lins:
    # ops of the RANGE/ENDRANGE uops whose RANGE has a truthy arg[1] (presumably the reduce flag -- confirm), in program order
    ranges = [u.op for u in lin.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
    # two identical neighbours mean one such range was opened (or closed) directly inside another
    for prev, cur in zip(ranges, ranges[1:]):
      assert prev != cur, f"multireduce nested the ranges! {prev, cur}"
def test_triple_multireduce(self):
  # three chained SUM reduces, each consuming the previous one, over three separate (27, 32, 5) inputs
  Tensor.manual_seed(0)
  x0 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
  x1 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
  x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
  g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(4)]
  # innermost reduce: x0 summed over axis 3
  first_x = UOp(UOps.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop()))
  first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (3,)))
  # middle reduce: (x1 - first_reduce) summed over axis 2
  second_x = UOp(UOps.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop()))
  diff = (second_x-first_reduce)
  second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (ReduceOps.SUM, (2,)))
  # outer reduce: (x2 * second_reduce) summed over axis 1
  third_x = UOp(UOps.LOAD, dtypes.float, (g3, x2.lazydata.st.reshape((27, 32, 1, 1, 5)).to_uop()))
  mul = (third_x*second_reduce)
  third_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (mul,), (ReduceOps.SUM, (1,)))
  store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 1, 5)).to_uop(), third_reduce))
  sink = UOp(UOps.SINK, src=(store,))
  wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5)
  lins = helper_linearizer_ast(sink, [x0,x1,x2], wanna_output=[wanna_output])
  for lin in lins:
    # ops of the RANGE/ENDRANGE uops whose RANGE has a truthy arg[1] (presumably the reduce flag -- confirm), in program order
    ranges = [u.op for u in lin.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
    # two identical neighbours mean one such range was opened (or closed) directly inside another
    for prev, cur in zip(ranges, ranges[1:]):
      assert prev != cur, f"multireduce nested the ranges! {prev, cur}"
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_double_reduce_multireduce(self):
  # two chained SUM reduces that each reduce over two axes at once
  Tensor.manual_seed(0)
  x = Tensor.randn(8, 32, 8, 16, dtype=dtypes.float).realize()
  st = x.lazydata.st
  g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
  # first reduce: x broadcast to (8, 32, 32, 8, 16, 16) and summed over axes 2 and 5
  first_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop()))
  first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (2, 5)))
  # second reduce: (x - first_reduce) summed over axes 1 and 4
  second_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 32, 1, 8, 16, 1)).to_uop()))
  squares = (second_x-first_reduce)
  squares_sum = UOp(UOps.REDUCE_AXIS, dtypes.float, (squares,), (ReduceOps.SUM, (1, 4)))
  store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((8, 1, 1, 8, 1, 1)).to_uop(), squares_sum,))
  sink = UOp(UOps.SINK, src=(store,))
  wanna_output = (x.numpy()-x.numpy().sum(axis=(1,3), keepdims=True)).sum(axis=(1,3)).reshape((8,1,1,8,1,1))
  opts = [
    # openCL / GPU=1 is 256 max threads
    # grouping
    [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # first dim of both reduces
    [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 3, 2)], # both dims of the second reduce
    [Opt(OptOps.GROUPTOP, 2, 2), Opt(OptOps.GROUPTOP, 3, 2)], # second dim of both reduces
    [Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.GROUPTOP, 3, 2)], # both dims of the first reduce
    # group all reduce dims
    [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.GROUPTOP, 2, 2), Opt(OptOps.GROUPTOP, 3, 2)],
    # checking how it works with 2 grouped reduces + unrolling
    [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.GROUPTOP, 2, 4), Opt(OptOps.GROUPTOP, 3, 4),
      Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)],
    # Checking how it works with 2 grouped reduces + locals.
    [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 0, 4),
      Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.GROUPTOP, 2, 2), Opt(OptOps.GROUPTOP, 3, 2)],
    # Checking how it works with 2 grouped reduces + locals + unroll.
    [Opt(OptOps.LOCAL, 0, 2),
      Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.GROUPTOP, 2, 4), Opt(OptOps.GROUPTOP, 3, 4),
      Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)],
    # Checking how it works with 2 grouped reduces + locals + upcast.
    [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 0, 2),
      Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.GROUPTOP, 2, 2), Opt(OptOps.GROUPTOP, 3, 2)],
    # Checking how it works with 2 grouped reduces + locals + upcast + unroll.
    [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 0, 2),
      Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.GROUPTOP, 2, 4), Opt(OptOps.GROUPTOP, 3, 4),
      Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)],
  ]
  lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
  for lin in lins:
    # ops of the RANGE/ENDRANGE uops whose RANGE has a truthy arg[1] (presumably the reduce flag -- confirm), in program order
    ranges = [u.op for u in lin.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
    # each reduce spans two axes, so two equal ops in a row are allowed here but three in a row are not
    for before, prev, cur in zip(ranges, ranges[1:], ranges[2:]):
      assert before != cur or prev != cur, f"multireduce nested the ranges! {before, prev, cur}"
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_partial_opt_multireduce(self):
  # check how it works with one reduce optimized and one unoptimized
  Tensor.manual_seed(0)
  x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize()
  # first reduce: x broadcast to (27, 15, 15, 5) and summed over axis 2
  first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))))
  first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,))
  # second reduce: (x - first_reduce) summed over axis 1
  second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 15, 1, 5))))
  diff = (second_x-first_reduce)
  second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,))
  store = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5))))
  # every opt list touches a single axis only
  opts = [
    [Opt(OptOps.GROUPTOP, 0, 3)], # grouping
    [Opt(OptOps.GROUPTOP, 1, 3)],
    [Opt(OptOps.GROUPTOP, 0, 15)],
    [Opt(OptOps.GROUPTOP, 1, 15)],
    [Opt(OptOps.UNROLL, 0, 3)],
    [Opt(OptOps.UNROLL, 1, 3)],
  ]
  wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
  lins = helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output], opts=opts)
  for lin in lins:
    # ops of the RANGE/ENDRANGE uops whose RANGE has a truthy arg[1] (presumably the reduce flag -- confirm), in program order
    ranges = [u.op for u in lin.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
    # two identical neighbours mean one such range was opened (or closed) directly inside another
    for prev, cur in zip(ranges, ranges[1:]):
      assert prev != cur, f"multireduce nested the ranges! {prev, cur}"
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_multireduce_with_parallel(self):
  # two independent first-stage reduces (sum(x) and sum(exp2(x_p))) feed one second-stage reduce
  Tensor.manual_seed(0)
  x = Tensor.randn(4, 32, dtype=dtypes.float).realize()
  x_p = Tensor.randn(4, 32, dtype=dtypes.float).realize()
  g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)]
  first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
  first_x_p = UOp(UOps.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
  first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (2,)))
  first_reduce_p = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x_p.alu(UnaryOps.EXP2),), (ReduceOps.SUM, (2,)))
  second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1)).to_uop()))
  diff = (second_x-(first_reduce + first_reduce_p))
  second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (ReduceOps.SUM, (1,)))
  store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((4, 1, 1)).to_uop(), second_reduce))
  sink = UOp(UOps.SINK, src=(store,))
  opts = [
    # [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # grouping
    # [Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 8)],
    # [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 16)],
    # [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.GROUPTOP, 0, 32)],
    [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2)], # unroll reduce
    [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)],
    [Opt(OptOps.UNROLL, 0, 8), Opt(OptOps.UNROLL, 1, 8)] if Device.DEFAULT not in {"NV", "METAL"} else [], # can't do float8,
    # [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)], # grouping + unrolling
    # [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],
    # [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UNROLL, 2, 8), Opt(OptOps.UNROLL, 2, 8)],
    # [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 0, 8)],
  ]
  wanna_output = (x.numpy()-(x.numpy().sum(-1, keepdims=True)+np.exp2(x_p.numpy()).sum(-1, keepdims=True))).sum(-1).reshape(4, 1,1)
  lins = helper_linearizer_ast(sink, [x,x_p], wanna_output=[wanna_output], opts=opts)
  for lin in lins:
    # ops of the RANGE/ENDRANGE uops whose RANGE has a truthy arg[1] (presumably the reduce flag -- confirm), in program order
    ranges = [u.op for u in lin.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
    # two identical neighbours mean one such range was opened (or closed) directly inside another
    for prev, cur in zip(ranges, ranges[1:]):
      assert prev != cur, f"multireduce nested the ranges! {prev, cur}"
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
def test_multiout_multireduce(self):
  # check how multireduce works with multioutput: buffer 0 gets the double reduce, buffer 1 gets it scaled by 1/15
  Tensor.manual_seed(0)
  x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
  out_st = ShapeTracker.from_shape((27, 1, 1, 5))
  g0, g1, g2 = (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=idx) for idx in range(3))
  # inner reduce: x broadcast to (27, 15, 15, 5) and summed over axis 2
  bcast_load = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
  inner_sum = UOp(UOps.REDUCE_AXIS, dtypes.float, (bcast_load,), (ReduceOps.SUM, (2,)))
  # outer reduce: (x - inner_sum) summed over axis 1
  plain_load = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
  outer_sum = UOp(UOps.REDUCE_AXIS, dtypes.float, (plain_load-inner_sum,), (ReduceOps.SUM, (1,)))
  store0 = UOp(UOps.STORE, src=(g0, out_st.to_uop(), outer_sum))
  scale = UOp(UOps.CONST, dtypes.float, (out_st.to_uop(),), 1/15)
  store1 = UOp(UOps.STORE, src=(g1, out_st.to_uop(), outer_sum * scale))
  sink = UOp(UOps.SINK, src=(store0, store1))
  expected = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
  helper_linearizer_ast(sink, [x], wanna_output=[expected, expected/15])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.expectedFailure
def test_multiout_intermediate_multireduce(self):
  # check how outputting at different stages of the multireduce works
  # TODO: Fails because the stores shapes do not match: store1.shape = (27,15,1,5) != store0.shape = (27,1,1,5)
  # so the output shapes are different (FAIL!),
  # if we change the shape of store1 to be contiguous, it will match store0 but not the value it's storing (FAIL!)
  Tensor.manual_seed(0)
  x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
  x_st = x.lazydata.st
  bcast_load = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x_st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))))
  inner_sum = LazyOp(ReduceOps.SUM, (bcast_load,), (2,))
  plain_load = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x_st.reshape((27, 15, 1, 5))))
  outer_sum = LazyOp(ReduceOps.SUM, (plain_load-inner_sum,), (1,))
  store0 = LazyOp(BufferOps.STORE, (outer_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5))))
  # the intermediate (first) reduce is written through a hand-built, non-contiguous view
  inner_view = View(shape=(27,15,1,5), strides=(5,0,1,1), offset=0, mask=None, contiguous=False)
  store1 = LazyOp(BufferOps.STORE, (inner_sum,), MemBuffer(1, dtypes.float, ShapeTracker(views=(inner_view,))))
  wanna_output0 = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
  wanna_output1 = x.numpy().sum(axis=1).reshape(27,1,1,5)

  # compile and run the two-store kernel by hand instead of going through helper_linearizer_ast
  ast = LazyOp(MetaOps.KERNEL, (store0,store1))
  kernel = Kernel(ast)
  prg = CompiledRunner(replace(kernel.to_program(), dname=Device.DEFAULT))
  inbufs = [x.lazydata.base.buffer]
  out_device = inbufs[-1].device if inbufs else Device.DEFAULT
  outbufs = [Buffer(out_device, out.arg.st.size, out.arg.dtype).allocate() for out in ast.src]
  prg.exec(outbufs+inbufs)
  got0 = np.frombuffer(outbufs[0].as_buffer(), _to_np_dtype(outbufs[0].dtype)).reshape(27,1,1,5)
  np.testing.assert_allclose(got0, wanna_output0)
  # only the first 135 (= 27*5) elements of the second output are compared
  got1 = np.frombuffer(outbufs[1].as_buffer(), _to_np_dtype(outbufs[1].dtype))[:135].reshape(27,1,1,5)
  np.testing.assert_allclose(got1, wanna_output1)
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
def test_complete_unroll_multireduce(self):
  # both size-3 reduce axes are unrolled by 3 (UNROLL on axis 0 applied twice)
  Tensor.manual_seed(0)
  x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
  x_st = x.lazydata.st
  bcast_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))))
  inner_sum = LazyOp(ReduceOps.SUM, (bcast_load,), (2,))
  plain_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((27, 3, 1, 5))))
  outer_sum = LazyOp(ReduceOps.SUM, (plain_load-inner_sum,), (1,))
  store = LazyOp(BufferOps.STORE, (outer_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5))))
  unroll_twice = [Opt(OptOps.UNROLL, 0, 3), Opt(OptOps.UNROLL, 0, 3)]
  expected = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
  helper_linearizer_ast((store, ), [x], wanna_output=[expected], opts=[unroll_twice])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
def test_upcast_multireduce(self):
  # the double reduce is checked with a single UPCAST(0, 3) opt
  Tensor.manual_seed(0)
  x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
  x_st = x.lazydata.st
  bcast_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))))
  inner_sum = LazyOp(ReduceOps.SUM, (bcast_load,), (2,))
  plain_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((27, 3, 1, 5))))
  outer_sum = LazyOp(ReduceOps.SUM, (plain_load-inner_sum,), (1,))
  store = LazyOp(BufferOps.STORE, (outer_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5))))
  upcast_only = [Opt(OptOps.UPCAST, 0, 3)]
  expected = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
  helper_linearizer_ast((store, ), [x], wanna_output=[expected], opts=[upcast_only])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skip("can't group with multiple reduces yet")
def test_early_endif(self):
  # make sure the if block of a grouped reduce can be closed early and the result loaded back in
  Tensor.manual_seed(0)
  x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize()
  x_st = x.lazydata.st
  bcast_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5))))
  inner_sum = LazyOp(ReduceOps.SUM, (bcast_load,), (2,))
  plain_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((27, 12, 1, 5))))
  outer_sum = LazyOp(ReduceOps.SUM, (plain_load-inner_sum,), (1,))
  store = LazyOp(BufferOps.STORE, (outer_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5))))
  group_both = [Opt(OptOps.GROUPTOP, 0, 3), Opt(OptOps.GROUPTOP, 1, 3)]
  expected = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
  helper_linearizer_ast((store, ), [x], wanna_output=[expected], opts=[group_both])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
def test_mean_std_multireduce(self):
  # standard deviation over the last axis (size 35) written as mean reduce -> squared-difference reduce -> sqrt
  def const_inv_n(shape):
    # 1/35 constant: a scalar ShapeTracker reshaped to 4 dims and expanded to `shape`
    return LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand(shape)))
  Tensor.manual_seed(0)
  x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
  x_st = x.lazydata.st
  bcast_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))))
  total = LazyOp(ReduceOps.SUM, (bcast_load,), (3,))
  mean = total * const_inv_n((15, 25, 35, 1))
  plain_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((15, 25, 35, 1))))
  squares = (plain_load-mean)*(plain_load-mean)
  squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
  variance = squares_sum * const_inv_n((15, 25, 1, 1))
  std = LazyOp(UnaryOps.SQRT, (variance,), None)
  store = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 25, 1, 1))))
  expected = x.numpy().std(axis=2, ddof=0).reshape((15,25,1,1))
  helper_linearizer_ast((store,), [x], wanna_output=[expected])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
def test_mean_std_multireduce_mid_dim(self):
  # standard deviation over the middle axis (size 25, hence the 0.04 constant) as two chained reduces and a sqrt
  def const_inv_n(shape):
    # 0.04 constant: a scalar ShapeTracker reshaped to 4 dims and expanded to `shape`
    return LazyOp(BufferOps.CONST, (), ConstBuffer(0.04, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand(shape)))
  Tensor.manual_seed(0)
  x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
  x_st = x.lazydata.st
  bcast_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35))))
  total = LazyOp(ReduceOps.SUM, (bcast_load,), (2,))
  mean = total * const_inv_n((15, 25, 1, 35))
  plain_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((15, 25, 1, 35))))
  squares = (plain_load-mean)*(plain_load-mean)
  squares_sum = LazyOp(ReduceOps.SUM, (squares,), (1,))
  variance = squares_sum * const_inv_n((15, 1, 1, 35))
  std = LazyOp(UnaryOps.SQRT, (variance,), None)
  store = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 1, 1, 35))))
  expected = x.numpy().std(axis=1, ddof=0).reshape((15,1,1,35))
  helper_linearizer_ast((store,), [x], wanna_output=[expected])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.expectedFailure
def test_mean_std_multireduce_multiout(self):
  # TODO: Same error as in test_multiout_intermediate_multireduce
  # std goes to buffer 0 and the intermediate mean to buffer 1, both from the same kernel
  def const_inv_n(shape):
    # 1/35 constant: a scalar ShapeTracker reshaped to 4 dims and expanded to `shape`
    return LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand(shape)))
  Tensor.manual_seed(0)
  x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
  x_st = x.lazydata.st
  bcast_load = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x_st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))))
  total = LazyOp(ReduceOps.SUM, (bcast_load,), (3,))
  mean = total * const_inv_n((15, 25, 35, 1))
  plain_load = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x_st.reshape((15, 25, 35, 1))))
  squares = (plain_load-mean)*(plain_load-mean)
  squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
  variance = squares_sum * const_inv_n((15, 25, 1, 1))
  std = LazyOp(UnaryOps.SQRT, (variance,), None)
  store_mean = LazyOp(BufferOps.STORE, (mean,), MemBuffer(1, dtypes.float, ShapeTracker.from_shape((15,25,1,1))))
  store_std = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 25, 1, 1))))
  expected = [x.numpy().std(axis=2, ddof=0).reshape(15,25,1,1), x.numpy().mean(axis=2).reshape(15,25,1,1)]
  lins = helper_linearizer_ast((store_std,store_mean), [x], wanna_output=expected)
  for lin in lins:
    accs = [u for u in lin.uops if u.op is UOps.DEFINE_ACC]
    assert len(accs) == 2, "got more than two accs (implies the kernel didn't reuse the mean reduce)"
@unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "ocelot/remu doesn't have multiple wave syncs yet")
def test_var_multireduce(self):
  # variance over the last axis (size 32, hence the 0.03125 constant), checked against numpy and Tensor.var
  def const_inv_n(shape):
    # 0.03125 constant: a scalar ShapeTracker reshaped to 4 dims and expanded to `shape`
    return LazyOp(BufferOps.CONST, (), ConstBuffer(0.03125, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand(shape)))
  Tensor.manual_seed(0)
  x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize()
  x_st = x.lazydata.st
  # push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
  bcast_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32))))
  total = LazyOp(ReduceOps.SUM, (bcast_load,), (3,))
  mean = total * const_inv_n((3, 27, 32, 1))
  # store = LazyOp(BufferOps.STORE, (mean,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 32, 1))))
  # verify_lazyop(store)
  plain_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((3, 27, 32, 1))))
  squares = (plain_load-mean)*(plain_load-mean)
  squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
  variance = squares_sum * const_inv_n((3, 27, 1, 1))
  store = LazyOp(BufferOps.STORE, (variance,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
  expected = x.numpy().var(axis=2, ddof=0).reshape((3,27,1,1))
  helper_linearizer_ast((store, ), [x], wanna_output=[expected])
  # tinygrad ref
  y_tiny = x.var(axis=2, correction=0).reshape(3,27,1,1)
  np.testing.assert_allclose(y_tiny.numpy(), expected, atol=1e-4, rtol=1e-4)
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
def test_softmax_multireduce(self):
  # MAX reduce feeding a SUM reduce; the kernel stores 1 / sum(exp2(x - max(x))) per row
  x = Tensor.rand(4, 32).realize()
  x_st = x.lazydata.st
  expanded_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((4, 1, 32,)).expand((4, 32, 32))))
  row_max = LazyOp(op=ReduceOps.MAX, src=(expanded_load,), arg=(2,))
  direct_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x_st.reshape((4, 32, 1,))))
  neg_max = LazyOp(op=UnaryOps.NEG, src=(row_max,), arg=None)
  shifted = LazyOp(op=BinaryOps.ADD, src=(direct_load, neg_max))
  exponentials = LazyOp(op=UnaryOps.EXP2, src=(shifted,))
  denominator = LazyOp(op=ReduceOps.SUM, src=(exponentials,), arg=(1,))
  # only the reciprocal of the denominator is stored: kernels cannot do a return to full shape
  inv_denominator = LazyOp(op=UnaryOps.RECIP, src=(denominator,))
  out_buf = MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1,1)))
  store = LazyOp(op=BufferOps.STORE, src=(inv_denominator,), arg=out_buf)
  x_np = x.numpy()
  expected = 1/np.exp2(x_np - x_np.max(axis=-1, keepdims=True)).sum(axis=-1, keepdims=True).reshape(4,1,1)
  helper_linearizer_ast((store,), [x], wanna_output=[expected])
# *** buildup to fused indexing
@unittest.skipIf(CI, "very slow because of recomputing")
def test_arange_expanded(self):
  # Tensor.arange(16384) expanded such that output shape is (4, 16384, 256, 1);
  # the hand-built kernel pushes that expand through the SUM reduce that produces the arange
  tiny = Tensor.arange(16384).reshape(16384, 1).expand(4, 16384, 256).reshape(4, 16384, 256, 1)
  expected = np.broadcast_to(np.arange(16384).reshape(16384, 1), (4, 16384, 256)).reshape(4, 16384, 256, 1)
  # NOTE: this is stupidly recomputing because it's not fused, but it proves a point.
  masked_view = View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False)
  strided_view = View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)
  ones_st = ShapeTracker(views=(masked_view, strided_view)).reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
  reduce_axes = (3,)
  ones_sum = LazyOp(ReduceOps.SUM, (LazyOp(BufferOps.CONST, (), ConstBuffer(1, dtypes.int, ones_st)),), reduce_axes)
  out_shape = tuple(1 if axis in reduce_axes else dim for axis, dim in enumerate(ones_st.shape))
  shifted = ones_sum-LazyOp.const(1, dtypes.int, out_shape)
  store = LazyOp(BufferOps.STORE, (shifted,), MemBuffer(0, dtypes.int, st=ShapeTracker.from_shape(out_shape)))
  helper_linearizer_ast((store,), [], wanna_output=[expected])
  with Context(DEBUG=0, NOOPT=0):
    np.testing.assert_equal(tiny.numpy(), expected)
@unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "very slow")
def test_indexing_multireduce(self):
  # p1: the arange, built as a SUM over masked constant ones, minus one
  ones_views = (
    View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False),
    View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False),
  )
  # TODO: do this arange broadcast in the scheduler
  ones_st = ShapeTracker(views=ones_views).reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
  arange_axis = (3,)
  ones_sum = LazyOp(ReduceOps.SUM, (LazyOp(BufferOps.CONST, (), ConstBuffer(1, dtypes.int, ones_st)),), arange_axis)
  full_shape = tuple(1 if axis in arange_axis else dim for axis, dim in enumerate(ones_st.shape))
  arange = ones_sum-LazyOp.const(1, dtypes.int, full_shape)

  # p2: the indexing
  dataset = Tensor.rand(16384, 256).realize()
  idxs = Tensor([0,3,5,6]).realize()
  dataset_buf = MemBuffer(1, dataset.dtype, ShapeTracker.from_shape(dataset.shape).reshape((1, 16384, 256, 1)).expand(full_shape))
  idxs_buf = MemBuffer(2, dtypes.int, ShapeTracker.from_shape((4,)+(1,)*(len(full_shape)-1)).expand(full_shape))
  row_selected = LazyOp(UnaryOps.CAST, (arange.eq(LazyOp(BufferOps.LOAD, (), idxs_buf)),), dataset.dtype)
  gathered = LazyOp(ReduceOps.SUM, (LazyOp(BufferOps.LOAD, (), dataset_buf)*row_selected,), (1,))
  out_shape = tuple(1 if axis in gathered.arg else dim for axis, dim in enumerate(full_shape))
  store = LazyOp(BufferOps.STORE, (gathered,), MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape(out_shape)))

  expected = dataset.numpy()[idxs.numpy()].reshape(4, 1, 256, 1)
  helper_linearizer_ast((store,), [dataset, idxs], wanna_output=[expected])
# AssertionError: repeated stores in uops
def test_argmax_multireduce_axis0(self):
# Runs a hand-written kernel AST for argmax over axis 0 of a (10, 20) tensor and compares it with np.argmax.
# The max over axis 0 is realized up front and handed to the kernel as a second input next to `t`.
t = Tensor.randn(10, 20).realize()
t_max = t.max((0,)).realize()
real_argmax = np.argmax(t.numpy(), axis=0, keepdims=False).reshape(1, 20, 1)
# The AST stores an int result of shape (1, 20, 1). It contains two reduces: a SUM (axis 2) over a masked
# constant -1, nested inside a MAX (axis 0) over that value multiplied by a CAST-to-int comparison of the two loads.
# NOTE(review): this looks like argmax encoded as 10 - max((t == t_max) * reversed_index) - 1; confirm against Tensor.argmax.
ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),))), src=( # noqa E501
LazyOp(BinaryOps.ADD, arg=None, src=(
LazyOp(BinaryOps.ADD, arg=None, src=(
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=10, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()), # noqa E501
LazyOp(UnaryOps.NEG, arg=None, src=(
LazyOp(ReduceOps.MAX, arg=(0,), src=(
LazyOp(BinaryOps.MUL, arg=None, src=(
LazyOp(UnaryOps.CAST, arg=dtypes.int, src=(
LazyOp(BinaryOps.CMPNE, arg=None, src=(
LazyOp(BinaryOps.CMPNE, arg=None, src=(
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),))), src=()), # noqa E501
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),))), src=()),)), # noqa E501
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)), # noqa E501
LazyOp(BinaryOps.ADD, arg=None, src=(
LazyOp(ReduceOps.SUM, arg=(2,), src=(
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False)))), src=()),)), # noqa E501
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=10, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)),)),)), # noqa E501
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa E501
# the MemBuffer loads use idx=1 and idx=2; the inputs are passed as [t, t_max] (presumably bound in that order; verify in helper_linearizer_ast)
helper_linearizer_ast(ast, [t, t_max], wanna_output=[real_argmax])
def test_argmax_multireduce_flat(self):
# Runs a hand-written kernel AST for argmax over all 200 elements of a (10, 20) tensor and compares it with np.argmax.
# The global max is realized up front and handed to the kernel as a second input next to `t`.
t = Tensor.randn(10, 20).realize()
t_max = t.max().realize()
real_argmax = np.argmax(t.numpy())
# The AST stores an int result of shape (1, 1). It contains two reduces: a SUM (axis 1) over a masked
# constant -1, nested inside a MAX (axis 0) over that value multiplied by a CAST-to-int comparison of the two loads.
# The input is loaded through a flat (200, 1) view; the max is loaded with stride 0, i.e. the same element at every position.
ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),))), src=( # noqa E501
LazyOp(BinaryOps.ADD, arg=None, src=(
LazyOp(BinaryOps.ADD, arg=None, src=(
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=200, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),))), src=()), # noqa E501
LazyOp(UnaryOps.NEG, arg=None, src=(
LazyOp(ReduceOps.MAX, arg=(0,), src=(
LazyOp(BinaryOps.MUL, arg=None, src=(
LazyOp(UnaryOps.CAST, arg=dtypes.int, src=(
LazyOp(BinaryOps.CMPNE, arg=None, src=(
LazyOp(BinaryOps.CMPNE, arg=None, src=(
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))), src=()), # noqa E501
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)), # noqa E501
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)), # noqa E501
LazyOp(BinaryOps.ADD, arg=None, src=(
LazyOp(ReduceOps.SUM, arg=(1,), src=(
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False)))), src=()),)), # noqa E501
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=200, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)),)),)), # noqa E501
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),))), src=()),)),)),)) # noqa E501
# the MemBuffer loads use idx=1 and idx=2; the inputs are passed as [t, t_max] (presumably bound in that order; verify in helper_linearizer_ast)
helper_linearizer_ast(ast, [t, t_max], wanna_output=[real_argmax])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
def test_padto_sum_multireduce(self):
  # sum(x - sum(x)) along one axis as a two-reduce kernel, linearized with PADTO on axis 0
  Tensor.manual_seed(0)
  N = 17
  x = Tensor.rand(N, N).realize()
  opts = [
    [Opt(OptOps.PADTO, 0, 32)],
    [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
    # TODO: multireduce pads
    # causes an issue because the acc won't be masked in the second reduce
    # [Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)]
  ]
  # (reshape of the expanded load, reshape of the plain load, inner reduce axis, outer reduce axis, output shape)
  # the outer reduce axis is also the numpy axis used for the reference
  cases = [
    ((1, N, N), (N, 1, N), 1, 0, (1, 1, N)),
    ((N, 1, N), (N, N, 1), 2, 1, (N, 1, 1)),
  ]
  for expanded_shape, plain_shape, inner_axis, outer_axis, out_shape in cases:
    expanded_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape(expanded_shape).expand((N,N,N))))
    plain_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape(plain_shape)))
    inner_sum = LazyOp(ReduceOps.SUM, (expanded_load,), (inner_axis,))
    difference = LazyOp(BinaryOps.ADD, (plain_load, LazyOp(op=UnaryOps.NEG, src=(inner_sum,), arg=None)))
    outer_sum = LazyOp(ReduceOps.SUM, (difference,), (outer_axis,))
    store = LazyOp(BufferOps.STORE, (outer_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape(out_shape)))
    expected = (x.numpy()-x.numpy().sum(axis=outer_axis, keepdims=True)).sum(axis=outer_axis).reshape(out_shape)
    helper_linearizer_ast((store,), [x], wanna_output=[expected], opts=opts)
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
def test_padto_max_multireduce(self):
  # max(x - max(x)) along one axis as a two-reduce kernel, linearized with PADTO on axis 0
  Tensor.manual_seed(0)
  N = 17
  x = Tensor.rand(N, N).realize()
  opts = [
    [Opt(OptOps.PADTO, 0, 32)],
    [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),]
  ]

  def check(expanded_shape, plain_shape, inner_axis, outer_axis, out_shape):
    # build the kernel for one axis layout and compare against numpy (outer_axis doubles as the numpy axis)
    expanded_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape(expanded_shape).expand((N,N,N))))
    plain_load = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape(plain_shape)))
    inner_max = LazyOp(ReduceOps.MAX, (expanded_load,), (inner_axis,))
    difference = LazyOp(BinaryOps.ADD, (plain_load, LazyOp(op=UnaryOps.NEG, src=(inner_max,), arg=None)))
    outer_max = LazyOp(ReduceOps.MAX, (difference,), (outer_axis,))
    store = LazyOp(BufferOps.STORE, (outer_max,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape(out_shape)))
    expected = (x.numpy()-x.numpy().max(axis=outer_axis, keepdims=True)).max(axis=outer_axis).reshape(out_shape)
    helper_linearizer_ast((store,), [x], wanna_output=[expected], opts=opts)

  check((1, N, N), (N, 1, N), 1, 0, (1, 1, N))
  check((N, 1, N), (N, N, 1), 2, 1, (N, 1, 1))
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
def test_padto_where_multireduce(self):
# Runs a kernel with a WHERE nested inside a SUM nested inside another WHERE, under PADTO opts on axis 0,
# once reducing along axis 1 of x and once along axis 0, and compares each run with a numpy reference.
# ternary operators try to use both ridxs
# we need to make sure the ternary operators nest properly
N = 17
x = Tensor.rand(N, N).realize()
a = Tensor.rand(1, 1).realize()
b = Tensor.rand(1, 1).realize()
opts = [[Opt(OptOps.PADTO, 0, 32)],[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],]
# TODO: these large ASTs are suboptimal but we need this until the scheduler can fuse these
# first run: inner SUM over axis 2 of the expanded load, outer SUM over axis 1, output shape (N,1,1)
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=1,keepdims=True), a.numpy(), b.numpy())).sum(axis=1),0.0,1.0).reshape((N,1,1)) # noqa: E501
ld0 = x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))
ld1 = x.lazydata.st.reshape((N, N, 1))
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld1)),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,N,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld0)),), arg=(2,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,N,1)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,N,1)))),)),)),), arg=(1,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,1)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,1)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((N,1,1)))) # noqa: E501
helper_linearizer_ast((ast,), [x,a,b], opts=opts, wanna_output=[wanna_output])
# second run: same computation with the axes swapped (inner SUM over axis 1, outer SUM over axis 0), output shape (1,1,N)
ld0 = x.lazydata.st.reshape((1, N, N)).expand((N,N,N))
ld1 = x.lazydata.st.reshape((N, 1, N))
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0).reshape(1,1,N) # noqa: E501
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((1,1,N)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld1)),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,N)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld0)),), arg=(1,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,N)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,N)))),)),)),), arg=(0,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((1,1,N)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((1,1,N)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,N)))) # noqa: E501
helper_linearizer_ast((ast,), [x,a,b], opts=opts, wanna_output=[wanna_output])
# the remaining cases (padding a reduce axis, and a full reduce over both axes) are disabled below
# # pad reduce axis
# helper_linearizer_ast((ast,), [x,a,b], opts=[[Opt(OptOps.PADTO, 1, 32)],], wanna_output=[wanna_output])
# ld0 = x.lazydata.st.reshape((1,1,N,N)).expand((N,N,N,N))
# ld1 = x.lazydata.st.reshape((N,N,1,1))
# wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501
# ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld1)),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)).expand((N,N,1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld0)),), arg=(2,3,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)).expand((N,N,1,1)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)).expand((N,N,1,1)))),)),)),), arg=(0,1,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)))) # noqa: E501
# helper_linearizer_ast((ast,), [x,a,b], opts=[[Opt(OptOps.PADTO, 0, 32)],], wanna_output=[wanna_output])
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_end_local(self):
  # sum 32 ints into one value; the uop list must end with an ENDIF that comes after the last STORE
  in_buf = MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker.from_shape((32,)))
  out_buf = MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker.from_shape((1,)))
  total = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, arg=in_buf),), arg=(0,))
  ast = (LazyOp(op=BufferOps.STORE, src=(total,), arg=out_buf),)
  ones = Tensor.full(in_buf.st.shape, 1).contiguous().realize()
  lins = helper_linearizer_ast(ast, [ones], wanna_output=[ones.numpy().sum()])
  k = lins[1]
  self.assertEqual(k.uops[-1].op, UOps.ENDIF)
  last_store = [u for u in k.uops if u.op is UOps.STORE][-1]
  self.assertLess(k.uops.index(last_store), k.uops.index(k.uops[-1]))
def test_two_nested_range(self):
  # summing a (2,) tensor expanded to (2, 3) must linearize to a single RANGE
  a = Tensor.randn(2, ).realize()
  out = a.reshape(2, 1).expand(2, 3).sum()
  expected = np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()
  lin = helper_linearizer_opt(out, wanna_output=[expected])[0]
  range_positions = [pos for pos, u in enumerate(lin.uops) if u.op is UOps.RANGE]
  # NOTE: it collapses now; the old layout was RANGE -> LOAD -> RANGE -> PHI with a LOAD between the two ranges
  assert len(range_positions) == 1
def test_three_nested_range(self):
  # summing a (2,) tensor expanded to (2, 2, 3) must linearize to a single RANGE
  a = Tensor.randn(2, ).realize()
  out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum()
  expected = np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()
  lin = helper_linearizer_opt(out, wanna_output=[expected])[0]
  range_positions = [pos for pos, u in enumerate(lin.uops) if u.op is UOps.RANGE]
  # NOTE: it collapses now; the old layout was RANGE -> RANGE -> LOAD -> RANGE -> PHI,
  # with nothing toposorted between the first two ranges and a LOAD between the last two
  assert len(range_positions) == 1
def test_two_nested_range_alt_indexing(self):
  # padded sum: expected uop layout is RANGE -> ALU -> RANGE -> ALU + LOAD -> PHI
  a = Tensor([2, 2]).realize()
  out = a.reshape(2, 1).pad(((1, 1), (1, 1)), 2).sum()
  lin = helper_linearizer_opt(out, wanna_output=[24])[0]
  range_positions = [pos for pos, u in enumerate(lin.uops) if u.op is UOps.RANGE]
  between_ranges = lin.uops[range_positions[0]:range_positions[1]]
  # an ALU but no LOAD sits between the first two ranges
  assert any(u.op is UOps.ALU for u in between_ranges)
  assert not any(u.op is UOps.LOAD for u in between_ranges)
  # from the second range onwards there is an ALU or a LOAD
  assert any(u.op in {UOps.ALU, UOps.LOAD} for u in lin.uops[range_positions[1]:])
def test_range_outer_op_before_phi(self):
  # expected uop layout: LOAD -> RANGE -> LOAD -> PHI
  a = Tensor.randn(4, 1).realize()
  b = Tensor.randn(1, 1).realize()
  out = (a + b[0]).sum() + b[0]
  expected = (a.numpy()+b.numpy()[0]).sum()+b.numpy()
  lin = helper_linearizer_opt(out, wanna_output=[expected])[0]
  range_positions = [pos for pos, u in enumerate(lin.uops) if u.op is UOps.RANGE]
  # the uop two slots before the first RANGE must be a LOAD
  uop_before_range = lin.uops[range_positions[0]-2]
  assert uop_before_range.op is UOps.LOAD
def test_range_outer_op_before_phi_nested_range(self):
  # same as test_range_outer_op_before_phi but over an expanded (2, 3) input; only one RANGE may remain
  a = Tensor.randn(2, ).realize()
  b = Tensor.randn(1, 1).realize()
  out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0]
  expected = (np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()
  lin = helper_linearizer_opt(out, wanna_output=[expected])[0]
  range_positions = [pos for pos, u in enumerate(lin.uops) if u.op is UOps.RANGE]
  # NOTE: it collapses now. The previous two-range expectations were:
  #   PTX:   LOAD -> RANGE -> CAST -> ALU -> ALU -> LOAD -> ALU -> RANGE -> ALU -> PHI (ranges 6 uops apart)
  #   other: LOAD -> RANGE -> LOAD -> ALU -> RANGE -> PHI (ranges 3 uops apart)
  # in both cases with a LOAD two uops before the first range and [LOAD, ALU] right before the second
  assert len(range_positions) == 1
def test_range_outer_op_after_phi(self):
  # expected uop layout: RANGE -> LOAD -> PHI -> ALU
  a = Tensor.randn(4, 1).realize()
  out = a.sum() * a.sum()
  expected = a.numpy().sum()*a.numpy().sum()
  lin = helper_linearizer_opt(out, wanna_output=[expected])[0]
  last_endrange = max(pos for pos, u in enumerate(lin.uops) if u.op is UOps.ENDRANGE)
  # the uop right after the last ENDRANGE must be an ALU
  uop_after_loop = lin.uops[last_endrange+1]
  assert uop_after_loop.op is UOps.ALU
def test_range_outer_op_after_phi_nested_range(self):
  # expected uop layout: RANGE -> LOAD -> PHI -> ALU
  a = Tensor.randn(2, ).realize()
  out = a.reshape(2, 1).expand(2, 3).sum() + a.reshape(2, 1).expand(2, 3).sum()
  expected = (np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2
  lin = helper_linearizer_opt(out, wanna_output=[expected])[0]
  last_endrange = max(pos for pos, u in enumerate(lin.uops) if u.op is UOps.ENDRANGE)
  # the uop right after the last ENDRANGE must be an ALU
  uop_after_loop = lin.uops[last_endrange+1]
  assert uop_after_loop.op is UOps.ALU
def test_load_dedup(self):
  # for different leaves in the AST, the same loads may occur.
  a = Tensor.randn(4).realize()
  # these are of size 3 to avoid float4 coalesce
  r = a[:-1] + a[1:]
  schedule_item = create_schedule([r.lazydata])[-1]
  k = Kernel(schedule_item.ast)
  k.upcast()
  k.linearize()
  num_loads = sum(1 for u in k.uops if u.op is UOps.LOAD)
  # exactly four LOADs are expected; the two asserts keep separate messages for too many vs. too few
  assert num_loads <= 4, "more load uops than needed"
  assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
def test_load_cache_const_bufs(self):
  # make sure const buffers are differentiated from local and mem buffers
  DT = dtypes.int
  ST = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),))

  def const(val): return LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=val, dtype=DT, st=ST))

  two = const(2)
  # data1[0] + 2
  load_plus_two = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=DT, st=ST)), two))
  # (literal const 1) + 2
  one_plus_two = LazyOp(op=BinaryOps.ADD, src=(const(1), two))
  total = LazyOp(op=BinaryOps.ADD, src=(load_plus_two, one_plus_two))
  ast = LazyOp(op=BufferOps.STORE, src=(total,), arg=MemBuffer(idx=0, dtype=DT, st=ST))
  lin = Kernel(ast)
  lin.linearize()
  assert len(lin.uops) <= 7, "too many uops"
  # the value stored by the final uop must still be built from a LOAD and a CONST
  stored_value_srcs = [u.op for u in lin.uops[-1].src[2].src]
  assert stored_value_srcs == [UOps.LOAD, UOps.CONST]
def test_upcast_cse(self):
  # when upcasting, within a subtree, there may be common expressions.
  a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
  r = a.expand([2]) + b.expand([2])
  schedule_item = create_schedule([r.lazydata])[-1]
  k = Kernel(schedule_item.ast)
  k.upcast()
  k.linearize()
  # both upcasted lanes compute the same sum, so at most one ALU uop may remain
  alu_count = sum(1 for u in k.uops if u.op is UOps.ALU)
  assert alu_count <= 1, "more alu uops than needed"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_reduce_upcast(self):
  # a small conv + relu upcasted twice: no accumulator remains and the single store is a float4
  x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
  r = Tensor.conv2d(x,w,padding=1).relu()
  k = Kernel(create_schedule([r.lazydata])[-1].ast)
  for _ in range(2): k.upcast()
  k.linearize()
  acc_uops = [u for u in k.uops if u.op is UOps.DEFINE_ACC]
  store_uops = [u for u in k.uops if u.op is UOps.STORE]
  assert len(acc_uops) == 0 # it's removed now
  assert len(store_uops) == 1
  stored_value = store_uops[0].src[-1]
  assert stored_value.dtype == dtypes.float.vec(4)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
def test_upcast_with_locals(self):
  # matmul + relu under the hand coded optimizations: the local store is upcasted, the global one is not
  x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
  r = (x@y).relu()
  k = Kernel(create_schedule([r.lazydata])[-1].ast)
  k.hand_coded_optimizations()
  k.linearize()
  acc_uops = [u for u in k.uops if u.op is UOps.DEFINE_ACC]
  store_uops = [u for u in k.uops if u.op is UOps.STORE]
  float4 = dtypes.float.vec(4)

  # the first store is to lds and can be upcasted
  assert acc_uops[0].dtype == store_uops[0].src[-1].dtype == float4
  assert store_uops[0].src[0].op is UOps.DEFINE_LOCAL
  # the second store is to gds with no upcasts
  assert store_uops[1].src[2].dtype == dtypes.float
  assert store_uops[1].src[0].op is UOps.DEFINE_GLOBAL
def test_zero_fold(self):
  # stacking two single-element tensors and upcasting must leave no ALU uops in the kernel
  a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
  r = Tensor.stack(a, b)
  schedule_item = create_schedule([r.lazydata])[-1]
  k = Kernel(schedule_item.ast)
  k.upcast()
  k.linearize()
  alu_count = sum(1 for u in k.uops if u.op is UOps.ALU)
  assert alu_count == 0, "more alu uops than needed"
def test_sum_acc_dtype(self):
  # maps the input dtype of a sum to the dtype its accumulator is expected to have
  expected_acc_dtypes = (
    (dtypes.bool, dtypes.int),
    (dtypes.int16, dtypes.int),
    (dtypes.float16, dtypes.float),
    (dtypes.bfloat16, dtypes.float),
  )
  for tensor_dtype, acc_dtype in expected_acc_dtypes:
    a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
    k = Kernel(create_schedule([a.lazydata])[-1].ast)
    k.linearize()
    acc_uops = [u for u in k.uops if u.op is UOps.DEFINE_ACC]
    assert acc_uops[0].dtype == acc_dtype
def test_arg_acc_dtype(self):
  # an explicit acc_dtype argument must decide the accumulator dtype; None falls back to float for every case listed
  def check_acc_dtype(c: Tensor, expected_dtype:DType):
    # linearize the last scheduled kernel of `c` and check the dtype of its first accumulator
    k = Kernel(create_schedule([c.lazydata])[-1].ast)
    k.linearize()
    acc_uops = [u for u in k.uops if u.op is UOps.DEFINE_ACC]
    assert acc_uops[0].dtype == expected_dtype

  # (tensor dtype, requested acc_dtype, expected accumulator dtype)
  cases = (
    (dtypes.float16, None, dtypes.float),
    (dtypes.bfloat16, None, dtypes.float),
    (dtypes.float, None, dtypes.float),
    (dtypes.float16, dtypes.float16, dtypes.float16),
    (dtypes.bfloat16, dtypes.bfloat16, dtypes.bfloat16),
    (dtypes.float, dtypes.float16, dtypes.float16),
  )
  for tensor_dtype, acc_dtype, expected_dtype in cases:
    a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
    check_acc_dtype(a.sum(acc_dtype=acc_dtype), expected_dtype)
    check_acc_dtype(a.matmul(b, acc_dtype=acc_dtype), expected_dtype)
    check_acc_dtype(Tensor.einsum("ki,ij->kj", a, b, acc_dtype=acc_dtype), expected_dtype)
    d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype)
    check_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores(self):
  # run the allclose helper at each tensor core's own dims; bfloat16 cores are skipped under the emulators
  emulated = getenv("EMULATE_CUDA") or getenv("EMULATE_INTEL")
  for tc in Device[Device.DEFAULT].renderer.tensor_cores:
    uses_bfloat16 = tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16
    if emulated and uses_bfloat16: continue
    m, n, k = tc.dims[0], tc.dims[1], tc.dims[2]
    helper_tc_allclose(m, n, k, tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_padded(self):
  # dims one larger than the tensor core's dims: the TC opt may only trigger with TC_OPT=2
  for tc in Device[Device.DEFAULT].renderer.tensor_cores:
    if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
    pad = 1
    m, n, k = tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad
    # check that TC is triggered for TC_OPT=2 and not triggered for TC_OPT<2
    for tc_opt, should_trigger in ((2, True), (1, False), (0, False)):
      helper_tc_ensure_uops_and_opts_count(m, n, k, tc.dtype_in, tc.dtype_out, tc_opt=tc_opt, ensure_triggered=should_trigger)

    # check excessive padding doesn't trigger padded TC in TC_OPT=2
    for shrunk_axis in range(3):
      dims = [tc.dims[0], tc.dims[1], tc.dims[2]]
      dims[shrunk_axis] = dims[shrunk_axis]//4
      helper_tc_ensure_uops_and_opts_count(dims[0], dims[1], dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)

    # check correctness
    helper_tc_allclose(m, n, k, tc.dtype_in, tc.dtype_out, tc_opt=2)
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_multi_reduce(self):
  # a conv2d offers 9 choices of TC axes; every choice must trigger the tensor core and give matching results
  for tc in Device[Device.DEFAULT].renderer.tensor_cores:
    if tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16: continue
    # this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes
    golden_result = None
    for axis in range(9):
      a = Tensor.rand(16, 16, 29, 29, dtype=tc.dtype_in).realize()
      b = Tensor.rand(32, 16, 16, 16, dtype=tc.dtype_in).realize()
      c = a.conv2d(b, padding=1, acc_dtype=tc.dtype_out)
      realized_ast, real_bufs = helper_realized_ast(c)

      k = Kernel(realized_ast)
      k.apply_tensor_cores(1, axis=axis, tc_opt=2)
      k.linearize()
      assert len([u for u in k.uops if u.op is UOps.WMMA]) > 0, "tensor core not triggered"
      assert len([opt for opt in k.applied_opts if opt.op is OptOps.TC]) == 1, "tensor core opt not included"

      prg = CompiledRunner(k.to_program())
      out_buf = real_bufs[0]
      # Zero to check that all values are filled
      out_buf.copyin(np.zeros((out_buf.size, ), dtype=_to_np_dtype(out_buf.dtype)).data)
      prg.exec(real_bufs)
      result = np.frombuffer(out_buf.as_buffer(), _to_np_dtype(out_buf.dtype))

      # ensure the results for each choice of axis matches
      if golden_result is None:
        golden_result = np.frombuffer(out_buf.as_buffer(), _to_np_dtype(out_buf.dtype))
      np.testing.assert_allclose(result, golden_result, atol=0.1, rtol=0.15)

    # check that get_kernel_actions produces all 9 options
    from tinygrad.engine.search import get_kernel_actions
    candidate_kernels = get_kernel_actions(Kernel(realized_ast), False)
    tc_actions = [cand for _, cand in candidate_kernels.items() if cand.applied_opts[0].op == OptOps.TC]
    assert len(tc_actions) == 9, f"get_kernel_actions should contain 9 possible TC actions, only got {len(tc_actions)}"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_unroll_phi(self):
  # matmul on the first tensor core config with an extra UNROLL opt:
  # the first source of every WMMA's last source must not be a PHI
  first_tc = Device[Device.DEFAULT].renderer.tensor_cores[0]
  lhs = Tensor.rand(128, 128, dtype=first_tc.dtype_in)
  rhs = Tensor.rand(128, 128, dtype=first_tc.dtype_in)
  product = lhs.matmul(rhs, acc_dtype=first_tc.dtype_out)
  unroll_opts = [[Opt(OptOps.UNROLL, 0, 4)]]
  kernel = helper_linearizer_opt(product, unroll_opts, apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
  for wmma in (uop for uop in kernel.uops if uop.op is UOps.WMMA):
    assert wmma.src[-1].src[0].op != UOps.PHI
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_unroll_casted_phi(self):
  # same check as the plain unroll test, but on the first tensor core whose input and output dtypes differ
  mixed_dtype_tcs = [cand for cand in Device[Device.DEFAULT].renderer.tensor_cores if cand.dtype_in != cand.dtype_out]
  tc = mixed_dtype_tcs[0]
  lhs = Tensor.rand(128, 128, dtype=tc.dtype_in)
  rhs = Tensor.rand(128, 128, dtype=tc.dtype_in)
  product = lhs.matmul(rhs, acc_dtype=tc.dtype_out)
  unroll_opts = [[Opt(OptOps.UNROLL, 0, 4)]]
  kernel = helper_linearizer_opt(product, unroll_opts, apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
  for wmma in (uop for uop in kernel.uops if uop.op is UOps.WMMA):
    #assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
    assert wmma.src[-1].src[0].op != UOps.PHI
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_unroll_casted_phi_with_children(self):
  # all PHI children are outside the loop
  mixed_dtype_tcs = [cand for cand in Device[Device.DEFAULT].renderer.tensor_cores if cand.dtype_in != cand.dtype_out]
  tc = mixed_dtype_tcs[0]
  lhs = Tensor.rand(128, 128, dtype=tc.dtype_in)
  rhs = Tensor.rand(128, 128, dtype=tc.dtype_in)
  # the relu gives the matmul result a consumer after the reduce
  activated = lhs.matmul(rhs, acc_dtype=tc.dtype_out).relu()
  unroll_opts = [[Opt(OptOps.UNROLL, 0, 4)]]
  kernel = helper_linearizer_opt(activated, unroll_opts, apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
  for wmma in (uop for uop in kernel.uops if uop.op is UOps.WMMA):
    #assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
    assert wmma.src[-1].src[0].op != UOps.PHI
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_simple_unroll_no_between_phi_dependencies(self):
  lhs, rhs = Tensor.rand(128, 128), Tensor.rand(128, 128)
  activated = (lhs@rhs).relu()
  kernel = helper_linearizer_opt(activated, [[Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]])[-1]
  # the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x PHI -> ENDRANGE
  for uop in kernel.uops:
    # every PHI takes an ALU as its second source
    if uop.op is UOps.PHI:
      assert uop.src[1].op is UOps.ALU
    # children of PHI are placed after ENDRANGE
    reads_a_phi = any(parent.op is UOps.PHI for parent in uop.src)
    if reads_a_phi:
      end_range_positions = [pos for pos, other in enumerate(kernel.uops) if other.op is UOps.ENDRANGE]
      assert end_range_positions[0] < kernel.uops.index(uop)
def test_grouped_dims(self):
  def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
    # collect the distinct SPECIAL uops reachable from the returned indexes, ordered by their name (arg[0])
    idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
    specials_per_idx = [[parent for parent in sorted(idx.sparents) if parent.op is UOps.SPECIAL] for idx in idxs]
    loop_idxs = sorted(dedup(flatten(specials_per_idx)), key=lambda special: special.arg[0])
    sizes = [special.arg[1] for special in loop_idxs]
    assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
    assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
    assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
    # TODO: add these back after uop symbolic
    # for i in range(len(dims)):
    #   assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}"
    # for i in range(len(loop_idxs)):
    #   assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
    #   assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"

  limits = (16,16,16)

  # no-op
  _assert_grouped_dims("gidx", (2,), limits, False, [2])
  _assert_grouped_dims("gidx", (2,3), limits, False, [2,3])

  # check reverse dims
  _assert_grouped_dims("gidx", (2,3), limits, True, [3,2])
  _assert_grouped_dims("gidx", (2,3,4), limits, False, [2,3,4])

  # test splitting globals
  # _assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
  # _assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,4,12])
  # _assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [12,16,4])
  # _assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,4,24])

  # collapse onto the left most axis
  _assert_grouped_dims("gidx", (2,3,4,5), limits, False, [6,4,5])
  _assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
  # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)])

  # collapse on left-most available axis (the left most is too small)
  _assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
  _assert_grouped_dims("gidx", (2,3,4,5), limits, True, [5,12,2])
  # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])

  # # dim too large and not factorable
  # with self.assertRaises(AssertionError):
  #   get_grouped_dims("gidx", (23,), (16,16,16), False,)
  # with self.assertRaises(AssertionError):
  #   get_grouped_dims("gidx", (128,3,4), (16,4,23), False,)

  # too large for sizes
  with self.assertRaises(RuntimeError):
    get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))

  # # variable too large
  # with self.assertRaises(AssertionError):
  #   get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
def test_default_global_reversed(self):
  # shrink so that the dims do not collapse
  shrunk = Tensor.ones(5, 6, 7).contiguous().realize().shrink(((0, 4), (0, 5), (0, 6)))
  kernel = helper_linearizer_opt(shrunk+1)[0]
  specials = sorted(dedup([uop for uop in kernel.uops if uop.op is UOps.SPECIAL]), key=lambda uop: uop.arg[0])
  # ordered by name, the SPECIAL sizes are the shrunk shape (4, 5, 6) in reverse
  expected_args = [('gidx0', 6), ('gidx1', 5), ('gidx2', 4)]
  for pos, expected_arg in enumerate(expected_args):
    assert specials[pos].arg == expected_arg, specials[pos].arg
def test_div_collapse(self):
  def helper(t, msg, max_ops=0):
    # t must schedule to a single SINK kernel whose uops contain exactly max_ops RECIP args
    sink_items = [item for item in create_schedule([t.lazydata]) if item.ast.op is UOps.SINK]
    assert len(sink_items) == 1
    kernel = Kernel(sink_items[0].ast)
    recip_count = sum(uop.arg is UnaryOps.RECIP for uop in kernel.linearize().uops)
    assert recip_count == max_ops, msg

  a, b, d = Tensor.empty((4,4)), Tensor.empty((4,4)), Tensor.empty((4,4))

  helper((a*b)/b, "found UnaryOps.RECIP in (a*b)/b operation")
  helper(a/a, "found UnaryOps.RECIP in (a/a) operation")
  helper((a/b)/d, "found multiple UnaryOps.RECIP in (a/b)/d operation", 1)
def test_sum_collapse(self):
  # the sum of an expanded single value must linearize without any RANGE uop
  total = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
  sink_items = [item for item in create_schedule([total.lazydata]) if item.ast.op is UOps.SINK]
  assert len(sink_items) == 1
  kernel = Kernel(sink_items[0].ast)
  range_uops = [uop for uop in kernel.linearize().uops if uop.op is UOps.RANGE]
  assert not range_uops, "found loop in sum collapse"
def test_assign_fold(self):
  # add a tensor that is ones in row 1 and padding elsewhere onto a realized ones tensor, in place
  target = Tensor.ones(4, 4).contiguous().realize()
  row_mask = Tensor.ones(4, 4).shrink(((1, 2), None)).pad(((1, 2), None))
  target.assign(target+row_mask)
  target.realize()
  expected = [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.]
  np.testing.assert_equal(target.flatten().numpy(), expected)
def test_where_fold(self):
  target = Tensor.ones(4, 4).contiguous().realize()
  row_mask = target.shrink(((1, 2), None)).pad(((1, 2), None))
  target.assign(row_mask.where(2, target))
  schedule = create_schedule([target.lazydata])
  assert len(schedule) == 1
  # keep a copy of the schedule list before it is handed to run_schedule
  schedule_copy = schedule[:]
  run_schedule(schedule)
  expected = [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.]
  np.testing.assert_equal(target.flatten().numpy(), expected)
  # the hand-optimized kernel for the same AST must not contain a WHERE
  kernel = Kernel(schedule_copy[-1].ast)
  kernel.hand_coded_optimizations()
  kernel.linearize()
  assert not any(uop.arg == TernaryOps.WHERE for uop in kernel.uops), "found where where where should be folded"
def test_phi_simplification(self):
  def helper(t, max_ops=0):
    kernel = helper_linearizer_opt(t)[-1]
    uops = list(kernel.linearize().uops)
    # ignore kernel optimized IF statements for now
    first_if = next((uop for uop in uops if uop.op is UOps.IF), None)
    if first_if:
      uops = uops[:uops.index(first_if)]
    index_ops = {uop.op for uop in uops if uop.op in {UOps.RANGE, UOps.SPECIAL}}
    assert len(index_ops) == 1, "has either specials or ranges, not both"
    phis = [uop for uop in uops if uop.op is UOps.PHI]
    assert len(phis) == 0, "PHI should have been simplified"
    # TODO: once uops track min/max this will be fixed
    #assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"

  helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2)
  helper(Tensor.arange(-1, -100, -5), max_ops=2)
  # NOTE: both of these split the reduce (this just wasn't tracked before)
  #helper(Tensor.arange(-3.2, 6.7, 0.64), max_ops=2)
  #helper(Tensor.arange(256), max_ops=2)
  helper(Tensor.arange(255), max_ops=2)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_grouped_store_phis(self):
  """
  float4 acc0 = float4(0.0,0.0,0.0,0.0);
  {
    acc0 = // ...
  }
  *((device float4*)(data0+alu2)) = float4(acc0.x,acc0.y,acc0.z,acc0.w);
  simplifies to:
  *((device float4*)(data0+alu2)) = acc0;
  """
  lhs, rhs = Tensor.randn(64,64), Tensor.randn(64,64)
  product = lhs.matmul(rhs)
  kernel = helper_linearizer_opt(product)[-1]
  # check that the float4 cast collapses
  stored_values = [uop.src[-1] for uop in kernel.uops if uop.op is UOps.STORE]
  for stored in stored_values:
    assert stored.dtype == dtypes.float.vec(4) and stored.op is not UOps.VECTORIZE
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_arange_opts(self):
  # arange(128) must stay correct under grouping and under a growing LOCAL/UPCAST/GROUP/UNROLL chain
  values = Tensor.arange(128)
  local8 = Opt(op=OptOps.LOCAL, axis=0, amt=8)
  upcast0 = Opt(op=OptOps.UPCAST, axis=0, amt=0)
  group8 = Opt(op=OptOps.GROUP, axis=0, amt=8)
  unroll4 = Opt(op=OptOps.UNROLL, axis=1, amt=4)
  opt_sets = [
    [Opt(OptOps.GROUP, 0, 32)],
    [Opt(OptOps.GROUPTOP, 0, 32)],
    [local8],
    [local8, upcast0],
    [local8, upcast0, group8],
    [local8, upcast0, group8, unroll4],
  ]
  helper_linearizer_opt(values, opt_sets)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_grouped_store_values(self):
  # the first STORE of a flipped contiguous copy writes a float4 that is not built by a VECTORIZE
  source = Tensor.randn((4,3,6,6)).realize()
  flipped = source.flip((0,1)).contiguous()
  kernel = helper_linearizer_opt(flipped)[-1]
  stored_values = [uop.src[-1] for uop in kernel.uops if uop.op is UOps.STORE]
  first_stored = stored_values[0]
  assert first_stored.dtype == dtypes.float.vec(4) and first_stored.op is not UOps.VECTORIZE
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_grouped_store_locals_and_globals(self):
  lhs, rhs = Tensor.rand(128, 128), Tensor.rand(128, 128)
  product = lhs@rhs
  opt = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8),
         Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces
  kernel = helper_linearizer_opt(product, opts=[opt])[-1]

  def get_recursive(uop):
    # the uop itself plus everything reachable through its sources
    reachable = set(uop.src)
    reachable.add(uop)
    for parent in uop.src: reachable |= get_recursive(parent)
    return reachable

  def stores_into(define_op):
    # STOREs whose first source reaches a uop with the given op
    return [u for u in kernel.uops if u.op is UOps.STORE and any(x.op is define_op for x in get_recursive(u.src[0]))]

  local_stores = stores_into(UOps.DEFINE_LOCAL)
  global_stores = stores_into(UOps.DEFINE_GLOBAL)
  barrier = [u for u in kernel.uops if u.op is UOps.BARRIER][0]
  # check that the float4 cast collapses for all stores
  for store in local_stores+global_stores:
    stored = store.src[2]
    assert stored.dtype.count > 1 and stored.op is not UOps.VECTORIZE
  # # check the children's vins
  # TODO: src ALU are not the same, should it?
  # assert barrier.src == tuple(local_stores)
  ifs_on_barrier = [u for u in kernel.uops if u.op is UOps.IF and u.src[-1] == barrier]
  assert len(ifs_on_barrier) == 1
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_grouped_store_local_only(self):
  vec, mat = Tensor.rand(1,128), Tensor.rand(128, 128)
  activated = (vec@mat).relu()
  kernel = helper_linearizer_opt(activated)[-1]
  stores = [uop for uop in kernel.uops if uop.op is UOps.STORE]
  first_store, second_store = stores[0], stores[1]

  # the float4 value stores directly in lds and we skip upcast
  assert first_store.src[-1].dtype == dtypes.float.vec(4)
  assert first_store.src[-1].op is not UOps.VECTORIZE

  # the global store doesn't change
  assert second_store.src[2].dtype == dtypes.float
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_skip_unmatching_upcasts(self):
  Tensor.manual_seed(0)
  # copy kernel: buffer 1 is loaded with strides (1, 240, 0, 0) and stored to buffer 0 with strides (40, 1, 0, 0)
  ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
  opt = [
    Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16),
    Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)
  ]
  inputs = [Tensor.randn(240*40).realize()]
  kernel = helper_linearizer_ast(ast, inputs, opts=[opt])[-1]
  # the first STORE is expected to write a float4 that is assembled by a VECTORIZE
  first_store = [uop for uop in kernel.uops if uop.op is UOps.STORE][0]
  stored = first_store.src[-1]
  assert stored.op is UOps.VECTORIZE and stored.dtype == dtypes.float.vec(4)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_skip_unmatching_upcasts_with_gep(self):
  Tensor.manual_seed(0)
  # copy kernel: buffer 1 is loaded with strides (1, 8, 0, 0) and stored to buffer 0 with strides (32, 1, 0, 0)
  ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
  opt = [
    Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8),
    Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8),
    Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2),
  ]
  inputs = [Tensor.randn(8*32).realize()]
  kernel = helper_linearizer_ast(ast, inputs, opts=[opt])[-1]
  # the first STORE is expected to write a multi-element value that is assembled by a VECTORIZE
  first_store = [uop for uop in kernel.uops if uop.op is UOps.STORE][0]
  stored = first_store.src[-1]
  assert stored.op is UOps.VECTORIZE and stored.dtype.count != 1
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
class TestFloat4(unittest.TestCase):
  # Tests counting how many vectorized (width 4 / width 2) loads, stores and accumulators a kernel emits.

  @staticmethod
  def count_float4(k):
    # returns (number of float4 LOADs, number of 3-source STOREs whose stored value is float4)
    vec4 = dtypes.float.vec(4)
    loads = [uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == vec4]
    stores = [uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == vec4]
    return (len(loads), len(stores))

  @staticmethod
  def count_half4(k):
    # same as count_float4, but for half4
    vec4 = dtypes.half.vec(4)
    loads = [uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == vec4]
    stores = [uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == vec4]
    return (len(loads), len(stores))

  # TODO: express opts below as auto opts

  def test_float4_basic(self):
    lhs = Tensor.rand(2, 8).realize()
    rhs = Tensor.rand(2, 8).realize()
    added = lhs + rhs

    kernel = Kernel(create_schedule([added.lazydata])[0].ast)
    kernel.hand_coded_optimizations()
    kernel.linearize()

    assert TestFloat4.count_float4(kernel) == (2, 1)

  def test_float4_multidim(self):
    lhs = Tensor.rand(2, 8).realize()
    rhs = Tensor.rand(2, 8).realize()
    added = lhs + rhs

    kernel = Kernel(create_schedule([added.lazydata])[0].ast)
    kernel.shift_to(0, 4)  # float4 dimension
    kernel.shift_to(0, 2, insert_before=kernel.shape_len-1)
    kernel.upcast()
    kernel.upcast()
    kernel.local_dims += 1
    kernel.linearize()

    assert TestFloat4.count_float4(kernel) == (4, 2)

  def test_float4_unaligned_load(self):
    # both inputs are read starting at element 1 of a 9 element buffer
    lhs = Tensor.rand(9).realize().shrink(((1, 9),))
    rhs = Tensor.rand(9).realize().shrink(((1, 9),))
    added = lhs + rhs

    kernel = Kernel(create_schedule([added.lazydata])[0].ast)
    kernel.hand_coded_optimizations()  # implicit trigger float4 dim
    kernel.linearize()

    assert TestFloat4.count_float4(kernel) == (0, 1)

  def test_float4_multidim_unaligned_load(self):
    lhs = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
    rhs = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
    added = lhs + rhs

    kernel = Kernel(create_schedule([added.lazydata])[0].ast)
    kernel.shift_to(len(kernel.full_unupcasted_shape)-1, 4)  # manual trigger float4 dim
    kernel.upcast()
    kernel.shift_to(len(kernel.full_unupcasted_shape)-1, 2, insert_before=kernel.shape_len-1)
    kernel.upcast()
    kernel.local_dims += 1
    kernel.linearize()

    assert TestFloat4.count_float4(kernel) == (0, 2)

  def test_float4_sometimes_unaligned(self):
    signal = Tensor.rand(1, 1, 8).realize()
    weights = Tensor.rand(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5)))
    conv = signal.conv2d(weights)
    # only the first and last conv dot products are aligned in a, and b is never aligned, so no
    # float4 should be emitted (the reduce axis of size 4 is the float4 axis here)

    kernel = Kernel(create_schedule([conv.lazydata])[0].ast)
    kernel.upcast()
    kernel.linearize()

    assert TestFloat4.count_float4(kernel) == (0, 0)

  def test_float4_multidim_sometimes_unaligned(self):
    signal = Tensor.rand(1, 1, 7).realize()
    weights = Tensor.rand(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5)))
    conv = signal.conv2d(weights)
    # the first conv dot product is aligned in a. If we upcast the output and reduce
    # dimension, then we could do float4 for only that one set of loads, but we currently
    # don't.
    # UPDATE: now we do this fusion

    kernel = Kernel(create_schedule([conv.lazydata])[0].ast)
    kernel.upcast()
    kernel.upcast()
    kernel.linearize()

    assert TestFloat4.count_float4(kernel) in {(0,1), (1,1)}

  def test_float4_noncontiguous(self):
    lhs = Tensor.rand(4, 2).realize()
    rhs = Tensor.rand(4, 2).realize()
    added = lhs + rhs

    # we will upcast the top axis of sz 4. they should not be coalesced into float4,
    # since the top axis is not contiguous.

    kernel = Kernel(create_schedule([added.lazydata])[0].ast)
    kernel.shift_to(0, 4, top=True)  # top axes are float4 axes
    kernel.upcast()
    kernel.linearize()

    assert TestFloat4.count_float4(kernel) == (0, 0)

  def test_float4_expand(self):
    lhs = Tensor.rand(9).realize().shrink(((1, 9),))
    rhs = Tensor.rand(2).realize().reshape((2, 1)).expand((2,4)).reshape((8,))
    added = lhs + rhs

    # the first input is read starting at element 1 and the second is an expanded view of 2 elements;
    # the test expects no float4 load and a single float4 store.

    kernel = Kernel(create_schedule([added.lazydata])[0].ast)
    kernel.shift_to(0, 4)  # float4 axis
    kernel.upcast()
    kernel.linearize()

    assert TestFloat4.count_float4(kernel) == (0, 1)

  def test_float4_heterogeneous(self):
    lhs = Tensor.rand(8).realize()
    rhs = Tensor.rand(9).realize().shrink(((1, 9),))
    added = lhs + rhs

    # exactly one of the two loads is expected to become float4
    # NOTE(review): the previous comment read "should float4 b but not a"; the first input is the one
    # read without an offset here - confirm which load is vectorized

    kernel = Kernel(create_schedule([added.lazydata])[0].ast)
    kernel.shift_to(0, 4)  # float4 axis
    kernel.upcast()
    kernel.linearize()

    assert TestFloat4.count_float4(kernel) == (1, 1)

  def test_half4_load_unrolled(self):
    # from llama 7B shard 4 gpus
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501

    # TODO: fix this, expected might change but should be positive
    cases = [
      ((7, 0), [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=3), Opt(op=OptOps.UNROLL, axis=0, amt=4)]),
      ((5, 0), [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)]),
      ((2, 0), [Opt(op=OptOps.UNROLL, axis=0, amt=4)]),
    ]
    for expected, opts in cases:
      kernel = Kernel(ast)
      for opt in opts:
        kernel.apply_opt(opt)
      kernel.linearize()

      count = TestFloat4.count_half4(kernel)
      assert count == expected, f"{count=}, {expected=}"

  def test_float4_acc(self):
    # from float32 stable diffusion red tinybox
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(5, 6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501

    # expected is the number of float4 DEFINE_ACC uops for each opt list
    cases = [
      (1, [Opt(op=OptOps.UPCAST, axis=2, amt=4)]),
      (4, [Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)]),
    ]
    for expected, opts in cases:
      kernel = Kernel(ast)
      for opt in opts:
        kernel.apply_opt(opt)
      kernel.linearize()

      float4_accs = [uop for uop in kernel.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)]
      count = len(float4_accs)
      assert count == expected, f"{count=}, {expected=}"

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  def test_float2_acc(self):
    # from resnet
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))))),), arg=dtypes.float),), arg=(4, 6)),), arg=dtypes.half),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501

    # expected is the number of float2 DEFINE_ACC uops for each opt list
    cases = [
      (16, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4)]), # noqa: E501
      (4, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2)]),
    ]
    for expected, opts in cases:
      kernel = Kernel(ast)
      for opt in opts:
        kernel.apply_opt(opt)
      kernel.linearize()

      float2_accs = [uop for uop in kernel.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(2)]
      count = len(float2_accs)
      assert count == expected, f"{count=}, {expected=}"
class TestHandCodedOpts(unittest.TestCase):
  # Tests on the shape of the kernel left behind by Kernel.hand_coded_optimizations.

  def test_masked_upcast(self):
    layer_1 = Tensor.cat(*[Tensor.rand(5) for _ in range(4)])
    layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20))

    last_item = create_schedule([layer_2.lazydata])[-1]
    kernel = Kernel(last_item.ast)
    kernel.hand_coded_optimizations()
    assert len(kernel.bufs) == 6  # make sure all ops are done in one kernel
    # masked upcast should upcast masked axis of size 7
    # masked upcast should not upcast large (20) last axis
    # float4/other hcopt shouldn't upcast last axis, since we already have 7 upcast, and the last axis is not very contiguous
    assert kernel.upcasted == 1 and kernel.full_shape[-1] == 7

  def test_masked_upcast_wino(self):
    monster = Tensor.stack(*[Tensor.stack(*[Tensor.rand(16) for _ in range(6)]) for _ in range(6)])

    last_item = create_schedule([monster.lazydata])[-1]
    kernel = Kernel(last_item.ast)
    kernel.hand_coded_optimizations()
    assert len(kernel.bufs) == 37  # make sure all ops are done in one kernel
    # should upcast the two Tensor.stacks
    upcasted_shape = kernel.full_shape[kernel.shape_len-kernel.upcasted:kernel.shape_len]
    assert kernel.upcasted >= 2 and upcasted_shape.count(6) == 2

  def test_masked_upcast_wino_full(self):
    with Context(WINO=1):
      x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
      out = Tensor.conv2d(x,w, padding=1)
      upcasts = []
      wino_schedule = create_schedule([out.lazydata])
      # collect upcasts of tile transform kernels
      for si in wino_schedule:
        kernel = Kernel(si.ast)
        kernel.hand_coded_optimizations()
        # kernels with a reduce are not tile transforms (there is a gemm reduce kernel),
        # neither are kernels with fewer than 36 buffers (there's a permute kernel at the end)
        if kernel.reduceop is None and len(kernel.bufs) >= 36:
          upcasts.append(tuple(kernel.full_shape[kernel.shape_len - kernel.upcasted:kernel.shape_len]))
      assert len(upcasts) == 3  # 3 transformation matrices
      assert len(wino_schedule) <= 4  # 4 kernels
      # this test case's inputs are too small, so one of the 4-stacks became a local, which is fine i guess
      assert upcasts.count((6, 6)) == 2 #and upcasts.count((4, 4)) == 1

      out.mean().backward()
      backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata])
      for si in backward_schedule:
        kernel = Kernel(si.ast)
        kernel.hand_coded_optimizations()
        kernel.linearize()
        # kernels with fewer than 20 buffers are not tile transform kernels
        if len(kernel.bufs) >= 20:
          # heuristic number to make sure that at least some upcasts but not too many upcasts are being done
          assert 6 <= prod(kernel.full_shape[kernel.shape_len - kernel.upcasted:kernel.shape_len]) <= 216
      assert len(backward_schedule) <= 13  # just the current number, but it could be better

  def test_masked_upcast_many(self):
    layer_1 = Tensor.cat(Tensor.rand(3, 4), Tensor.rand(4, 4))
    layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 7, 4))
    layer_3 = Tensor.cat(layer_2.unsqueeze(0), Tensor.rand(6, 7, 7, 4))

    kernel = helper_linearizer_opt(layer_3)[-1]
    assert len(kernel.bufs) == 5  # make sure all ops are done in one kernel
    # check that we don't do too many upcasts
    upcasted_shape = kernel.full_shape[kernel.shape_len-kernel.upcasted:kernel.shape_len]
    assert prod(upcasted_shape) <= 49

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  def test_matvec(self):
    N = 128
    vec = Tensor.rand(1, N).realize()
    mat = Tensor.rand(N, N).realize()
    product = vec @ mat

    # the last kernel returned by the helper is the hand-optimized one when no opts are passed
    kernel = helper_linearizer_opt(product)[-1]

    assert kernel.group_for_reduces == 1
    assert kernel.local_dims == 1
    assert kernel.upcasted == 1
def helper_linearizer_ast(ast:Union[Tuple[LazyOp, ...], LazyOp, UOp], inputs:List[Tensor], *args, **kwargs):
  # Runs the linearizer checks on a hand-written AST: a tuple of LazyOps is wrapped in a KERNEL LazyOp,
  # LazyOps are converted to UOps, and one output Buffer is allocated per source of the final AST.
  # Extra positional and keyword arguments are forwarded to _helper_linearizer_opt_ast.
  if not isinstance(ast, (LazyOp, UOp)): ast = LazyOp(MetaOps.KERNEL, ast)
  inbufs = [tensor.lazydata.base.buffer for tensor in inputs]
  if isinstance(ast, LazyOp): ast = to_uop(ast)
  # outputs go on the device of the last input, or on the default device when there are no inputs
  outbufs = [
    Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, cast(DType,out.src[2].dtype)).allocate()
    for out in ast.src
  ]
  return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
def helper_linearizer_opt(r:Union[Tensor, List[Tensor]], *args, **kwargs):
  # Realizes everything but the last kernel of r, then runs the linearizer checks on that last kernel.
  # Extra positional and keyword arguments are forwarded to _helper_linearizer_opt_ast.
  ast_and_bufs = helper_realized_ast(r)
  return _helper_linearizer_opt_ast(*ast_and_bufs, *args, **kwargs)
def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:List[Buffer], opts=None,
                               apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=None, wanna_output=None) -> List[Kernel]:
  """Run realized_ast unoptimized, hand-optimized and once per entry of opts, comparing the outputs.

  Args:
    realized_ast: kernel AST; one output is expected per element of realized_ast.src.
    real_bufs: buffers passed to every program; the first len(realized_ast.src) of them are the outputs.
    opts: list of opt lists, each applied to a fresh Kernel. None (the default) means no custom opts.
    apply_tc: when true, every opt list is applied through apply_tensor_cores(1, extra_opts=...) instead of apply_opt.
    atol, rtol: tolerances handed to np.testing.assert_allclose.
    color_sizes: optional expected list(zip(k.colors(), k.full_shape)) per entry of opts.
    wanna_output: expected outputs; when None or empty, the outputs of the unoptimized kernel are used.

  Returns:
    Every Kernel that was created: the unoptimized one, the hand-optimized one, then one per entry of opts.
  """
  # None stands in for "empty" so that no list object is shared between calls
  opts = [] if opts is None else opts
  color_sizes = [] if color_sizes is None else color_sizes
  wanna_output = [] if wanna_output is None else wanna_output

  lins: List[Kernel] = []
  outbufs = [(real_bufs[i], lop.st_arg.shape) for i,lop in enumerate(realized_ast.src)]

  def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))

  def read_output(buf, shape): return np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).reshape(shape)

  def zero_outputs():
    # Zero to check that all values are filled
    for buf,_ in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data)

  def assert_outputs_match():
    # wanna_output is looked up at call time, so the baseline stored below is the one that is used
    for i, (buf,shape) in enumerate(outbufs):
      np.testing.assert_allclose(read_output(buf, shape), wanna_output[i], atol=atol, rtol=rtol)

  def run_zeroed_and_check(k:Kernel):
    # compile first, then clear the outputs, run and compare
    prg = get_prg(k)
    zero_outputs()
    prg.exec(real_bufs)
    assert_outputs_match()

  def check_opt(opts, create_k, expected_color_size):
    k = create_k()
    lins.append(k)
    if apply_tc:
      assert k.apply_tensor_cores(1, extra_opts=opts), "no tensor core triggered"
    else:
      for opt in opts:
        k.apply_opt(opt)
    if expected_color_size is not None:
      cs = list(zip(k.colors(), k.full_shape))
      assert cs == expected_color_size, f"expected={expected_color_size} got={cs}"
    run_zeroed_and_check(k)

  # Get baseline if it is not provided, which is not optimized at all.
  k = Kernel(realized_ast)
  lins.append(k)
  prg = get_prg(k)
  prg.exec(real_bufs)
  if len(wanna_output) == 0:
    # copy, because the output buffers are overwritten by the runs that follow
    wanna_output = [read_output(buf, shape).copy() for buf,shape in outbufs]
  else:
    assert_outputs_match()

  # Check correctness of handcoded optimizations.
  k = Kernel(realized_ast)
  lins.append(k)
  k.hand_coded_optimizations()
  run_zeroed_and_check(k)

  for i, x in enumerate(opts): # Check custom transformations if any.
    check_opt(x, lambda: Kernel(realized_ast), color_sizes[i] if i < len(color_sizes) else None)
  return lins
class TestKernelOpts(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_local_and_grouped_reduce(self):
  N = 128
  Tensor.manual_seed(1882)
  a = Tensor.rand(4, 4, N, N)
  b = Tensor.rand(4, 4, N)
  r = (b.sqrt() + ((a+1).sum(axis=3).exp()))
  # Checking how it works with locals
  local_only = [[Opt(OptOps.LOCAL, 0, amt)] for amt in (2, 8, 16)]
  # Checking how it works with grouped reduce
  grouptop_only = [[Opt(OptOps.GROUPTOP, 0, amt)] for amt in (2, 32, 64)]
  # Checking how it works with locals + grouped reduce
  local_and_grouptop = [[Opt(OptOps.LOCAL, 0, local_amt), Opt(OptOps.GROUPTOP, 0, group_amt)]
                        for local_amt, group_amt in ((2, 2), (16, 16), (32, 2), (2, 64))]
  # Checking how it works with locals + grouped reduce + upcasts
  with_upcasts = [[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UNROLL, 1, 4)]]
  # many local + many group
  repeated = [
    [Opt(OptOps.GROUP, 0, 2)] * 4,
    [Opt(OptOps.LOCAL, 0, 2)] * 4,
    [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)] * 4,
  ]
  helper_linearizer_opt(r, local_only + grouptop_only + local_and_grouptop + with_upcasts + repeated)
def test_upcasts(self):
  N = 16
  Tensor.manual_seed(1772)
  a = Tensor.rand(N, N)
  b = Tensor.rand(N, N)
  r = (a+b).sqrt() * ((a+1).exp())
  # Checking how it works with upcasts
  upcast_opts = [[Opt(OptOps.UPCAST, 0, amt)] for amt in (2, 4, 8)]
  helper_linearizer_opt(r, upcast_opts)
def test_full_upcast(self):
  Tensor.manual_seed(1772)
  a = Tensor.rand(4)
  b = Tensor.rand(4)
  r = (a+b).sqrt() * ((a+1).exp())
  # Checking how it works with upcasts: the UPCAST amount equals the size of the only axis
  full_upcast = [Opt(OptOps.UPCAST, 0, 4)]
  helper_linearizer_opt(r, [full_upcast])
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_matmul(self):
  N = 128
  Tensor.manual_seed(1552)
  a = Tensor.rand(N, N)
  b = Tensor.rand(N, N)
  r = a@b
  # Checking how it works with upcasts
  upcast_opts = [
    [Opt(OptOps.UPCAST, 0, 2)],
    [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)],
  ]
  # Checking how it works with locals
  local_opts = [
    [Opt(OptOps.LOCAL, 0, 2)],
    [Opt(OptOps.LOCAL, 1, 32)],
    [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)],
    [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 32)],
    [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.LOCAL, 1, 8)],
  ]
  # Checking how it works with grouped_reduce
  grouped_opts = [
    [Opt(OptOps.GROUPTOP, 0, 2)],
    [Opt(OptOps.GROUPTOP, 0, 32)],
    [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.UNROLL, 0, 4)],
  ]
  # Checking how it works with local+grouped_reduce
  local_grouped_opts = [
    [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 32)],
    [Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 32)],
    [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 4)],
  ]
  combined_opts = [
    # Checking all together
    [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4),
     Opt(OptOps.UPCAST, 1, 2)],
    # Full global upcast + local
    [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 8)],
  ]
  helper_linearizer_opt(r, upcast_opts + local_opts + grouped_opts + local_grouped_opts + combined_opts)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_double_reduce(self):
N = 128
Tensor.manual_seed(1552)
a = Tensor.rand(8, N, 8, N)
r = a.sum(axis=(1,3))
helper_linearizer_opt(r, [
# openCL / GPU=1 is 256 max threads
[Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 32)],
[Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 1, 32)], # Checking how it works with 1 grouped_reduce.
[Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],
[Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2)],
[Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 64)], # Checking how it works with 2 grouped_reduces.
[Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 0, 4)],
[Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 2, 4)], # Checking how it works with 2 grouped_reduces + upcasts.
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4)],
# Checking how it works with 2 grouped_reduces + upcasts + locals.
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 1, 4)],
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2)],
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2),
Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)], # Checking how it works with 2 grouped_reduces + upcasts + locals.
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2),
Opt(OptOps.UPCAST, 0, 2)], # No globals
])
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_invalid_tensor_core_extra_opts(self):
N = 128
Tensor.manual_seed(1552)
a = Tensor.rand(N, N)
b = Tensor.rand(N, N)
realized_ast, _ = helper_realized_ast(a@b)
invalid_opts = [
[Opt(OptOps.LOCAL, 2, 2)],
[Opt(OptOps.UPCAST, 2, 2)],
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 2, 2)],
]
for x in invalid_opts:
k = Kernel(realized_ast)
with self.assertRaises(AssertionError):
assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core" # for METAL in runners
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_buf_index_not_found_tensor_core(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPNE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
with self.assertRaises(KernelOptError):
k.apply_opt(Opt(OptOps.TC, 0, 1))
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_core_opts(self):
N = 128
Tensor.manual_seed(1552)
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
# bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices.
if tc.dtype_in == dtypes.bfloat16: continue
a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
r = a.matmul(b, acc_dtype=tc.dtype_out)
(atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
helper_linearizer_opt(r, [
[],
[Opt(OptOps.UPCAST, 0, 4)],
[Opt(OptOps.UPCAST, 1, 4)],
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
[Opt(OptOps.UNROLL, 0, 2)], # check unroll
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of unroll and local
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
[Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)], # check permutations
[Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)],
[Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)],
# [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
], apply_tc=True, atol=atol, rtol=rtol)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
def test_tensor_core_opts_locals(self):
N = 128
Tensor.manual_seed(1552)
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
# bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices.
if tc.dtype_in == dtypes.bfloat16: continue
a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
r = a.matmul(b, acc_dtype=tc.dtype_out)
(atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
helper_linearizer_opt(r, [
[Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals
[Opt(OptOps.LOCAL, 0, 4)], # check local
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
], apply_tc=True, atol=atol, rtol=rtol)
def test_padto_matmul(self):
if CI and Device.DEFAULT in ["AMD", "NV", "CUDA"]: self.skipTest("super slow on CUDA and AMD because of the big grid dims")
N = 17 * 17
Tensor.manual_seed(289)
a = Tensor.rand(N, N)
b = Tensor.rand(N, N)
helper_linearizer_opt(a@b, [
[Opt(OptOps.PADTO, 0, 32)],
[Opt(OptOps.PADTO, 1, 32)],
[Opt(OptOps.PADTO, 2, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)],
# can optimize further post PADTO
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2),],
])
def test_padto_upcasted_not_ok(self):
N = 4
a = Tensor.rand(N, N)
b = Tensor.rand(N, N)
helper_linearizer_opt(a@b, [
[Opt(OptOps.UPCAST, 0, 0)],
[Opt(OptOps.UPCAST, 1, 0)],
[Opt(OptOps.UNROLL, 0, 0)],
[Opt(OptOps.PADTO, 0, 8)],
[Opt(OptOps.PADTO, 1, 8)],
[Opt(OptOps.PADTO, 2, 8)],
])
with self.assertRaises(KernelOptError):
helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 2, 8)]])
with self.assertRaises(KernelOptError):
helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 2, 8)]])
with self.assertRaises(KernelOptError):
helper_linearizer_opt(a@b, [[Opt(OptOps.UNROLL, 0, 0), Opt(OptOps.PADTO, 2, 8)]])
def test_padto_sum_ok(self):
N = 18 * 18
# NOTE: this setup prevents 17 * 17 contiguous merged into one dimension
a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))) * 100
b = (Tensor.rand(N, N) < 0.5).realize().shrink(((0, 17), (0, 17)))
helper_linearizer_opt(a.sum(0), [
[Opt(OptOps.PADTO, 0, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
])
helper_linearizer_opt(a.sum(1), [
[Opt(OptOps.PADTO, 0, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
])
# can pad sum reduce axis if there's no unsafe ops prior to sum
for axis in (0, 1):
helper_linearizer_opt(a.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
helper_linearizer_opt(a.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
helper_linearizer_opt(b.sum(acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
helper_linearizer_opt(b.sum(0, acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
helper_linearizer_opt(b.sum(1, acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
# having unsafe ops after sum is fine
helper_linearizer_opt(a.sum().exp(), [[Opt(OptOps.PADTO, 0, 32)],])
helper_linearizer_opt(a.sum(0).exp(), [[Opt(OptOps.PADTO, 1, 32)],])
def test_padto_sum_not_ok(self):
N = 18 * 18
# NOTE: this setup prevents 17 * 17 contiguous merged into one dimension
a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))).exp()
# exp is not safe to pad
with self.assertRaises(KernelOptError):
helper_linearizer_opt(a.exp().sum(), [[Opt(OptOps.PADTO, 0, 32)],])
with self.assertRaises(KernelOptError):
helper_linearizer_opt(a.exp().sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
b = a < -1
# lt is not safe to pad
with self.assertRaises(KernelOptError):
helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, 0, 32)],])
with self.assertRaises(KernelOptError):
helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
def test_padto_max(self):
N = 18 * 18
# NOTE: this setup prevents 17 * 17 contiguous merged into one axis
a = -Tensor.rand(N, N).shrink(((0, 17), (0, 17))) * 100
helper_linearizer_opt(a.max(0), [
[Opt(OptOps.PADTO, 0, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
])
helper_linearizer_opt(a.max(1), [
[Opt(OptOps.PADTO, 0, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
])
# cannot pad max kernel on reduce
with self.assertRaises(KernelOptError):
helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
with self.assertRaises(KernelOptError):
helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],])
def test_padto_where(self):
Tensor.manual_seed(0)
N = 17 * 17
a = (Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1).where(1, 0)
helper_linearizer_opt(a.max(0), [
[Opt(OptOps.PADTO, 0, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
])
def test_padto_where_multioutput(self):
Tensor.manual_seed(0)
N = 17 * 17
r = Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1
a0 = r.where(1, 0)
a1 = r.where(2, 0)
helper_linearizer_opt([a0.max(0), a1.max(0)], [
[Opt(OptOps.PADTO, 0, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
])
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_padto_group(self):
Tensor.manual_seed(0)
ld0 = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)))) # noqa: E501
ld1 = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))) # noqa: E501
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(ld0, ld1)),), arg=(0, 2, 4, 6)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
helper_linearizer_ast((ast, ), [data1, data2], opts=[
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.GROUP, 0, 4)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.GROUP, 0, 4)]
])
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_color_shapes_with_local(self):
N = 32
Tensor.manual_seed(1552)
a = Tensor.rand(N, N)
b = Tensor.rand(N, N)
r = a@b
opts_shapes = [
([Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("red",32)]),
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",2),("red",16)]),
# check to ensure local_dims are stable for full UNROLL of first_reduce
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
([Opt(OptOps.UNROLL, 0, 0),Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
# check behavior for full UNROLL on an existing GROUP
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",16),("magenta",2)]),
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
([Opt(OptOps.GROUP, 0, 0),Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]),
]
helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes])
# Run the unittest suite when this file is executed directly rather than imported.
if __name__ == '__main__': unittest.main()