mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
should mark the ones that are expected to work with expectedFailure, and delete the ones that are not expected to work
274 lines
11 KiB
Python
274 lines
11 KiB
Python
import unittest, math
|
|
from tinygrad import Tensor, Device, dtypes
|
|
from tinygrad.ops import Ops
|
|
from tinygrad.engine.schedule import create_schedule
|
|
from tinygrad.helpers import CI
|
|
import numpy as np
|
|
from tinygrad.device import is_dtype_supported
|
|
|
|
def _check_ast_count(desired_count:int, t:Tensor):
  """Assert that scheduling `t` produces exactly `desired_count` kernel ASTs.

  A count of 0 means the expression was fully const folded away (no kernel).
  NOTE: this has side effect because everything can be scheduled only once
  """
  schedule = create_schedule(t.lazydata.lbs)
  # only SINK items are compute kernels; other schedule items are excluded
  asts = [s for s in schedule if s.ast.op is Ops.SINK]
  # include the actual count in the failure message for easier debugging
  assert len(asts) == desired_count, f"expected {desired_count} asts, got {len(asts)}"
|
|
|
|
class TestUnaryOpsConstFolding(unittest.TestCase):
  """Ops on fully-const tensors should const fold to 0 scheduled kernels."""
  def test_all_consts_ops(self):
    # every input is a const, so the whole expression folds away
    _check_ast_count(0, Tensor.ones(4).exp())
    _check_ast_count(0, Tensor.ones(4).sqrt())
    _check_ast_count(0, Tensor.ones(4) + Tensor.ones(4))
    _check_ast_count(0, Tensor.ones(4) / Tensor.ones(4))

  def test_cast(self):
    # casting a const stays a const (no kernel)
    _check_ast_count(0, Tensor.ones(4).cast(dtypes.int16))
    # -1 cast to uint16 is still resolved at const-fold time
    _check_ast_count(0, Tensor.full(4, fill_value=-1).cast(dtypes.uint16))

  @unittest.expectedFailure # no two level fold at lazybuffer
  def test_neg_folding(self):
    # these require folding across two stacked ops, which is not done yet
    _check_ast_count(0, Tensor([1, 2, 3]).mul(-1).neg())
    _check_ast_count(0, Tensor([1, 2, 3]).neg().mul(-1))
    _check_ast_count(0, Tensor([1, 2, 3]).neg().neg())

  def test_neg_realized_no_fold(self):
    # once realized, the input is no longer a const, so neg needs one kernel
    x = Tensor.randn(32, 32)
    x = x.clip(0, 1).realize()
    _check_ast_count(1, x.neg())
|
|
|
|
class TestBinaryOpsConstFolding(unittest.TestCase):
  """Binary ops with identity/absorbing const operands fold to 0 kernels.

  Covers python-literal consts and const Tensor operands, on either side.
  """
  # x + 0 == x
  def test_add_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + 0)
  def test_add_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(4))
  def test_literal_zero_add(self):
    _check_ast_count(0, 0 + Tensor([1.0, 2, 3, 4]))
  def test_tensor_zero_add(self):
    _check_ast_count(0, Tensor.zeros(4) + Tensor([1.0, 2, 3, 4]))

  # x - 0 == x
  def test_sub_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) - 0)
  def test_sub_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) - Tensor.zeros(4))

  # x * 0 == 0
  def test_mul_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 0)
  def test_mul_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.zeros(4))
  def test_literal_zero_mul(self):
    # NOTE(review): the trailing "* 0" looks like a leftover (siblings use a single
    # literal); harmless either way since both muls fold — confirm intent
    _check_ast_count(0, 0 * Tensor([1.0, 2, 3, 4]) * 0)
  def test_tensor_zero_mul(self):
    _check_ast_count(0, Tensor.zeros(4) * Tensor([1.0, 2, 3, 4]))

  # x * 1 == x
  def test_mul_literal_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 1)
  def test_mul_tensor_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.ones(4))
  def test_literal_one_mul(self):
    _check_ast_count(0, 1 * Tensor([1.0, 2, 3, 4]))
  def test_tensor_one_mul(self):
    _check_ast_count(0, Tensor.ones(4) * Tensor([1.0, 2, 3, 4]))

  # bool tensor * bool const also folds
  def test_bool_tensor_mul_bool(self):
    _check_ast_count(0, Tensor([True, False]) * True)
    _check_ast_count(0, Tensor([True, False]) * False)
  def test_bool_mul_bool_tensor(self):
    _check_ast_count(0, True * Tensor([True, False]))
    _check_ast_count(0, False * Tensor([True, False]))

  # x / 1 == x
  def test_div_literal_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) / 1)
  def test_div_tensor_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) / Tensor.ones(4))

  # x // 1 == x (integer division)
  def test_idiv_literal_one(self):
    _check_ast_count(0, Tensor([1, 2, 3, 4]) // 1)
  def test_idiv_tensor_one(self):
    _check_ast_count(0, Tensor([1, 2, 3, 4]) // Tensor.ones(4, dtype=dtypes.int32))

  # x ** 0 == 1
  def test_pow_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 0)
  def test_pow_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.zeros(4))

  # x ** 1 == x, 1 ** x == 1
  def test_pow_literal_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 1)
  def test_pow_tensor_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
  def test_literal_one_pow(self):
    _check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
  def test_tensor_one_pow(self):
    _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
|
|
|
|
# folds advanced indexing into basic indexing
|
|
class TestIndexingConstFolding(unittest.TestCase):
  """Indexing with scalar const Tensors should fold away (0 kernels)."""
  def test_scalar_index(self):
    t = Tensor.arange(16).float().reshape(1,1,4,4).realize()
    # scalar const indices (including one computed from a const) fold to basic indexing
    _check_ast_count(0, t[:,:,Tensor(1),:])
    _check_ast_count(0, t[:,:,Tensor(1)+2,:])
    _check_ast_count(0, t[:,:,Tensor(1),Tensor(0)])

  @unittest.expectedFailure
  def test_const_tensor_index(self):
    # TODO: implement const tensor folded indexing
    # non-scalar const index tensors are not folded yet
    t = Tensor.arange(16).float().reshape(1,1,4,4).realize()
    _check_ast_count(0, t[:,:,Tensor.ones(2,1),:])
    _check_ast_count(0, t[:,:,Tensor.ones(1,2)+2,:])
    _check_ast_count(0, t[:,:,Tensor.ones(1,1),Tensor.zeros(2,1,2)])
|
|
|
|
class TestMovedConstFolding(unittest.TestCase):
  """Const folding through movement ops (shrink/pad) and casts of padded consts."""
  def test_add_shrunk_zero(self):
    # a shrunk zeros tensor is still all-zero, so the add folds away
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(6).shrink(((1, 5),)))

  def test_add_padded_zero(self):
    # TODO: it's 1 now, this might be possible to fold
    _check_ast_count(1, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(2).pad(((1, 1),)))

  def test_mul_shrunk_one(self):
    # a shrunk ones tensor is still all-one, so the mul folds away
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.ones(6).shrink(((1, 5),)))

  def test_mul_padded_one(self):
    # renamed from test_add_padded_one: the op under test is mul, not add
    # (parallels test_mul_shrunk_one); padded ones are not uniform, so 1 kernel
    _check_ast_count(1, Tensor([1.0, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),)))

  def test_cast_padded(self):
    # NOTE: this is folded due to CAST_BEFORE_VIEW
    if is_dtype_supported(dtypes.int16):
      _check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
      np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
    if is_dtype_supported(dtypes.uint16):
      # -1 wraps to 65535 in uint16; padded area stays 0
      _check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
      np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
    # not folded
    if is_dtype_supported(dtypes.int64):
      _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
      np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
|
|
|
|
class TestReduceOpsConstFolding(unittest.TestCase):
  """Reduces over uniform const tensors fold to a const (0 kernels), unless padded."""
  def test_const_sum(self):
    # full reduce of a const folds
    _check_ast_count(0, Tensor.ones(4, 5, 6).sum())
    np.testing.assert_equal(Tensor.ones(4, 5, 6).sum().numpy(), 4 * 5 * 6)
    # axis reduce of a const folds too
    _check_ast_count(0, Tensor.ones(4, 5, 6).sum(axis=0))
    np.testing.assert_equal(Tensor.ones(4, 5, 6).sum(axis=0).numpy(), np.full((5, 6), 4))
    # scalar tensor sum is a no-op
    _check_ast_count(0, Tensor(4).sum())
    np.testing.assert_equal(Tensor(4).sum().numpy(), 4)

  def test_padded_const_sum(self):
    # padding makes the values non-uniform, so one kernel is needed
    _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).sum())
    np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).sum().numpy(), 4)

    # NOTE: cannot just count the non-padded area because some UnaryOps f do not have f(0) = 0.
    _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).exp().sum())
    # exp(1) on the 4 real elements plus exp(0)=1 on the 2 padded ones
    np.testing.assert_allclose(Tensor.ones(4).pad(((1, 1),)).exp().sum().numpy(), 4 * math.e + 2)

  def test_const_prod(self):
    _check_ast_count(0, Tensor.full((2, 3), fill_value=2).prod())
    np.testing.assert_equal(Tensor.full((2, 3), fill_value=2).prod().numpy(), 2**(2*3))
    _check_ast_count(0, Tensor.full((4, 5, 6), fill_value=2).prod(axis=0))
    np.testing.assert_equal(Tensor.full((4, 5, 6), fill_value=2).prod(axis=0).numpy(), np.full((5, 6), 2**4))
    _check_ast_count(0, Tensor(4).prod())
    np.testing.assert_equal(Tensor(4).prod().numpy(), 4)

  def test_const_max(self):
    # max of a uniform tensor is that value
    _check_ast_count(0, Tensor.ones(4, 5, 6).max())
    np.testing.assert_equal(Tensor.ones(4, 5, 6).max().numpy(), 1)
    _check_ast_count(0, Tensor(4).max())
    np.testing.assert_equal(Tensor(4).max().numpy(), 4)

  def test_sum_output_dtype(self):
    # sum output dtype can be different from input
    for dt in dtypes.fields().values():
      if is_dtype_supported(dt):
        t = Tensor.ones(16, dtype=dt).reshape(4, 4)
        # the folded sum must pick the same dtype as the real (contiguous) sum
        assert t.sum().dtype == t.contiguous().sum().dtype
|
|
|
|
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
class TestMultiConstFolding(unittest.TestCase):
  """Const folding still applies when a tensor is sharded across multiple devices."""
  def test_multi_const_folding_literal(self):
    # shard across 4 virtual devices of the default backend
    ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
    t = Tensor.arange(16).float().realize().to(ds)

    # non const folding case creates one ast on each shard
    _check_ast_count(4, t + 1)
    _check_ast_count(4, 1 + t)
    _check_ast_count(4, t * 2)
    _check_ast_count(4, 2 * t)

    # const folded
    _check_ast_count(0, t + 0)
    _check_ast_count(0, 0 + t)
    _check_ast_count(0, t * 0)
    _check_ast_count(0, 0 * t)
    _check_ast_count(0, t * 1)
    _check_ast_count(0, 1 * t)
    # folded results must still compute correct values
    np.testing.assert_equal((t + 0).numpy(), np.arange(16))
    np.testing.assert_equal((t * 0).numpy(), [0] * 16)
    np.testing.assert_equal((t * 1).numpy(), np.arange(16))

    # pow with literal const also folds
    _check_ast_count(0, t ** 0)
    _check_ast_count(0, t ** 1)
    _check_ast_count(0, 1 ** t)

  def test_multi_const_folding_tensor(self):
    ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
    t = Tensor.arange(16).float().realize().to(ds)
    zero = Tensor.zeros(16).realize().to(ds)
    one = Tensor.ones(16).realize().to(ds)

    # const folded
    _check_ast_count(0, t + zero)
    _check_ast_count(0, zero + t)
    _check_ast_count(0, t * zero)
    _check_ast_count(0, zero * t)
    _check_ast_count(0, t * one)
    _check_ast_count(0, one * t)
    np.testing.assert_equal((t + zero).numpy(), np.arange(16))
    np.testing.assert_equal((t * zero).numpy(), [0] * 16)
    np.testing.assert_equal((t * one).numpy(), np.arange(16))

  @unittest.expectedFailure
  def test_multi_todo_pow(self):
    ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
    t = Tensor.arange(16).float().realize().to(ds)
    zero = Tensor.zeros(16).realize().to(ds)
    one = Tensor.ones(16).realize().to(ds)

    # TODO: fix pow folding
    _check_ast_count(0, t ** zero)
    _check_ast_count(0, t ** one)
    _check_ast_count(0, one ** t)
|
|
|
|
class TestTautologicalCompare(unittest.TestCase):
  """Comparisons whose result is fixed by the operands' values/identity."""
  # without const folding, these would have triggered -Wtautological-compare in clang
  def test_lt_false(self):
    # bool < False is always false
    np.testing.assert_equal((Tensor([True, False]) < False).numpy(), [False, False])

  def test_true_lt(self):
    # True < bool is always false
    np.testing.assert_equal((True < Tensor([True, False])).numpy(), [False, False])

  def test_truth_table(self):
    # bool < bool: only False < True is true
    np.testing.assert_equal((Tensor(False) < Tensor(False)).numpy(), False)
    np.testing.assert_equal((Tensor(False) < Tensor(True)).numpy(), True)
    np.testing.assert_equal((Tensor(True) < Tensor(False)).numpy(), False)
    np.testing.assert_equal((Tensor(True) < Tensor(True)).numpy(), False)

  def test_a_eq_a(self):
    # self eq is always true for int or bool
    a = Tensor([1, 2, 3])
    np.testing.assert_equal((a == a).numpy(), [True, True, True])

    # not true for nan (IEEE 754: nan != nan)
    a = Tensor([math.nan, 1.0, 2.0])
    np.testing.assert_equal((a == a).numpy(), [False, True, True])

  def test_a_ne_a(self):
    # self not eq is always false for int or bool
    a = Tensor([1, 2, 3])
    np.testing.assert_equal((a != a).numpy(), [False, False, False])

    # not true for nan
    a = Tensor([math.nan, 1.0, 2.0])
    np.testing.assert_equal((a != a).numpy(), [True, False, False])
|
|
|
|
# allow running this test file directly
if __name__ == "__main__":
  unittest.main()
|