import unittest, math
from tinygrad import Tensor, Device, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.helpers import CI
from tinygrad.ops import BufferOps
import numpy as np
from test.helpers import is_dtype_supported

def _check_ast_count(desired_count:int, t:Tensor):
  # NOTE: this has a side effect, because everything can only be scheduled once
  schedule = create_schedule(t.lazydata.lbs)
  asts = [s for s in schedule if s.ast[0].op is BufferOps.STORE]
  assert len(asts) == desired_count
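# For example, Tensor.ones(4) + Tensor.ones(4) below folds entirely at schedule
# time and should produce zero STORE kernels, while an op on realized (non-const)
# data still schedules a kernel. The tests count kernels this way throughout.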
class TestUnaryOpsConstFolding(unittest.TestCase):
  def test_all_consts_ops(self):
    _check_ast_count(0, Tensor.ones(4).exp())
    _check_ast_count(0, Tensor.ones(4).sqrt())
    _check_ast_count(0, Tensor.ones(4) + Tensor.ones(4))
    _check_ast_count(0, Tensor.ones(4) / Tensor.ones(4))

  def test_cast(self):
    _check_ast_count(0, Tensor.ones(4).cast(dtypes.int16))
    _check_ast_count(0, Tensor.full(4, fill_value=-1).cast(dtypes.uint16))
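# NOTE: the folded cast must match real cast semantics: -1 cast to uint16 wraps
# modulo 2**16 to 65535, as checked numerically in test_cast_padded below.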
class TestBinaryOpsConstFolding(unittest.TestCase):
  def test_add_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + 0)
  def test_add_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(4))
  def test_literal_zero_add(self):
    _check_ast_count(0, 0 + Tensor([1.0, 2, 3, 4]))
  def test_tensor_zero_add(self):
    _check_ast_count(0, Tensor.zeros(4) + Tensor([1.0, 2, 3, 4]))

  def test_sub_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) - 0)
  def test_sub_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) - Tensor.zeros(4))

  def test_mul_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 0)
  def test_mul_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.zeros(4))
  def test_literal_zero_mul(self):
    _check_ast_count(0, 0 * Tensor([1.0, 2, 3, 4]))
  def test_tensor_zero_mul(self):
    _check_ast_count(0, Tensor.zeros(4) * Tensor([1.0, 2, 3, 4]))

  def test_mul_literal_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 1)
  def test_mul_tensor_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.ones(4))
  def test_literal_one_mul(self):
    _check_ast_count(0, 1 * Tensor([1.0, 2, 3, 4]))
  def test_tensor_one_mul(self):
    _check_ast_count(0, Tensor.ones(4) * Tensor([1.0, 2, 3, 4]))

  def test_div_literal_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) / 1)
  def test_div_tensor_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) / Tensor.ones(4))

  def test_pow_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 0)
  def test_pow_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.zeros(4))

  def test_pow_literal_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 1)
  def test_pow_tensor_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
  def test_literal_one_pow(self):
    _check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
  def test_tensor_one_pow(self):
    _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
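# NOTE: the x*0 -> 0 folds above are exact for these finite inputs; under IEEE 754,
# inf*0 and nan*0 are nan, so a fully general fold could not assume this.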
# folds advanced indexing into basic indexing
class TestIndexingConstFolding(unittest.TestCase):
  def test_scalar_index(self):
    t = Tensor.arange(16).float().reshape(1,1,4,4).realize()
    _check_ast_count(0, t[:,:,Tensor(1),:])
    _check_ast_count(0, t[:,:,Tensor(1)+2,:])
    _check_ast_count(0, t[:,:,Tensor(1),Tensor(0)])

  @unittest.expectedFailure
  def test_const_tensor_index(self):
    # TODO: implement const tensor folded indexing
    t = Tensor.arange(16).float().reshape(1,1,4,4).realize()
    _check_ast_count(0, t[:,:,Tensor.ones(2,1),:])
    _check_ast_count(0, t[:,:,Tensor.ones(1,2)+2,:])
    _check_ast_count(0, t[:,:,Tensor.ones(1,1),Tensor.zeros(2,1,2)])
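# NOTE: a scalar const index like Tensor(1) is known at schedule time, so
# t[:,:,Tensor(1),:] can lower to the same view as the basic index t[:,:,1,:]
# with no extra kernels; const tensor indices (expectedFailure above) don't fold yet.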
class TestMovedConstFolding(unittest.TestCase):
  def test_add_shrunk_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(6).shrink(((1, 5),)))

  def test_add_padded_zero(self):
    # TODO: this is 1 now; it might be possible to fold
    _check_ast_count(1, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(2).pad(((1, 1),)))

  def test_mul_shrunk_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.ones(6).shrink(((1, 5),)))

  def test_mul_padded_one(self):
    _check_ast_count(1, Tensor([1.0, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),)))

  def test_cast_padded(self):
    _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
    np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
    _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
    np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
    _check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
    np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
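# NOTE: a shrink of a const is still a single const, so it folds; a pad surrounds
# the const with a zero region, so the result is no longer one const and currently
# costs one kernel. A cast after pad must also keep the padded zeros as zeros,
# e.g. ones(4).pad(((1, 1),)).cast(dtypes.int16) -> [0, 1, 1, 1, 1, 0].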
class TestReduceOpsConstFolding(unittest.TestCase):
  def test_const_sum(self):
    _check_ast_count(0, Tensor.ones(4, 5, 6).sum())
    np.testing.assert_equal(Tensor.ones(4, 5, 6).sum().numpy(), 4 * 5 * 6)
    _check_ast_count(0, Tensor.ones(4, 5, 6).sum(axis=0))
    np.testing.assert_equal(Tensor.ones(4, 5, 6).sum(axis=0).numpy(), np.full((5, 6), 4))
    _check_ast_count(0, Tensor(4).sum())
    np.testing.assert_equal(Tensor(4).sum().numpy(), 4)

  def test_padded_const_sum(self):
    _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).sum())
    np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).sum().numpy(), 4)

    # NOTE: cannot just count the non-padded area, because some UnaryOps f do not satisfy f(0) = 0
    _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).exp().sum())
    np.testing.assert_allclose(Tensor.ones(4).pad(((1, 1),)).exp().sum().numpy(), 4 * math.e + 2)

  def test_const_max(self):
    _check_ast_count(0, Tensor.ones(4, 5, 6).max())
    np.testing.assert_equal(Tensor.ones(4, 5, 6).max().numpy(), 1)
    _check_ast_count(0, Tensor(4).max())
    np.testing.assert_equal(Tensor(4).max().numpy(), 4)

  def test_sum_output_dtype(self):
    # sum output dtype can be different from the input dtype
    for dt in dtypes.fields().values():
      if is_dtype_supported(dt):
        t = Tensor.ones(16, dtype=dt).reshape(4, 4)
        assert t.sum().dtype == t.contiguous().sum().dtype
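# NOTE: in test_padded_const_sum, exp(0) = 1, so the two padded elements still
# contribute 2 to the folded total 4*e + 2. In test_sum_output_dtype, sum()
# generally accumulates in a promoted dtype, so the folded constant must carry
# the same dtype the unfolded kernel would produce.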
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
class TestMultiConstFolding(unittest.TestCase):
  def test_multi_const_folding_literal(self):
    ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
    t = Tensor.arange(16).float().realize().to(ds)

    # the non-const-folded case creates one AST on each shard
    _check_ast_count(4, t + 1)
    _check_ast_count(4, 1 + t)
    _check_ast_count(4, t * 2)
    _check_ast_count(4, 2 * t)

    # const folded
    _check_ast_count(0, t + 0)
    _check_ast_count(0, 0 + t)
    _check_ast_count(0, t * 0)
    _check_ast_count(0, 0 * t)
    _check_ast_count(0, t * 1)
    _check_ast_count(0, 1 * t)
    np.testing.assert_equal((t + 0).numpy(), np.arange(16))
    np.testing.assert_equal((t * 0).numpy(), [0] * 16)
    np.testing.assert_equal((t * 1).numpy(), np.arange(16))

    _check_ast_count(0, t ** 0)
    _check_ast_count(0, t ** 1)
    _check_ast_count(0, 1 ** t)
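  # NOTE (assumption): .to(ds) yields a multi-device tensor with one lazy buffer
  # per device, which is why each unfolded elementwise op above schedules exactly
  # 4 ASTs, while a folded op schedules none on any shard.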
  def test_multi_const_folding_tensor(self):
    ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
    t = Tensor.arange(16).float().realize().to(ds)
    zero = Tensor.zeros(16).realize().to(ds)
    one = Tensor.ones(16).realize().to(ds)

    # const folded
    _check_ast_count(0, t + zero)
    _check_ast_count(0, zero + t)
    _check_ast_count(0, t * zero)
    _check_ast_count(0, zero * t)
    _check_ast_count(0, t * one)
    _check_ast_count(0, one * t)
    np.testing.assert_equal((t + zero).numpy(), np.arange(16))
    np.testing.assert_equal((t * zero).numpy(), [0] * 16)
    np.testing.assert_equal((t * one).numpy(), np.arange(16))

  @unittest.expectedFailure
  def test_multi_todo_pow(self):
    ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
    t = Tensor.arange(16).float().realize().to(ds)
    zero = Tensor.zeros(16).realize().to(ds)
    one = Tensor.ones(16).realize().to(ds)

    # TODO: fix pow folding
    _check_ast_count(0, t ** zero)
    _check_ast_count(0, t ** one)
    _check_ast_count(0, one ** t)
class TestTautologicalCompare(unittest.TestCase):
  # without const folding, these would have triggered -Wtautological-compare in clang
  def test_lt_false(self):
    # bool < False is always false
    np.testing.assert_equal((Tensor([True, False]) < False).numpy(), [False, False])

  def test_true_lt(self):
    # True < bool is always false
    np.testing.assert_equal((True < Tensor([True, False])).numpy(), [False, False])

  def test_truth_table(self):
    np.testing.assert_equal((Tensor(False) < Tensor(False)).numpy(), False)
    np.testing.assert_equal((Tensor(False) < Tensor(True)).numpy(), True)
    np.testing.assert_equal((Tensor(True) < Tensor(False)).numpy(), False)
    np.testing.assert_equal((Tensor(True) < Tensor(True)).numpy(), False)
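  # NOTE: on bools, a < b is equivalent to (not a) and b, so False < True is the
  # only True entry in the truth table above.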
  @unittest.skip("not implemented yet")
  def test_a_eq_a(self):
    # self eq is always true for int or bool
    a = Tensor([1, 2, 3])
    np.testing.assert_equal((a == a).numpy(), [True, True, True])

    # not true for NaN, which compares unequal to itself
    a = Tensor([math.nan, 1.0, 2.0])
    np.testing.assert_equal((a == a).numpy(), [False, True, True])

  @unittest.skip("not implemented yet")
  def test_a_ne_a(self):
    # self not-eq is always false for int or bool
    a = Tensor([1, 2, 3])
    np.testing.assert_equal((a != a).numpy(), [False, False, False])

    # not true for NaN: nan != nan is True
    a = Tensor([math.nan, 1.0, 2.0])
    np.testing.assert_equal((a != a).numpy(), [True, False, False])
if __name__ == '__main__':
  unittest.main()