This fixes .split where self.shape[dim] is not perfectly divisible by sizes; delegating to .chunk is always the wrong choice here:
- tensor((5,)).split(4) should result in (tensor((4,)), tensor((1,))); it was (tensor((3,)), tensor((2,)))

This also fixes issues in .split and .chunk where tensors with shape[dim]==0 led to empty tuples/lists when the tensor itself should have been returned instead.

Because tinygrad is expected to fail in all cases where torch fails, tinygrad will now be strict about:
- sizes having to sum up to the passed dimension in .split
- num having to be non-null for .chunk
- only valid dims being allowed in .unsqueeze
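For reference, a minimal sketch of the torch behavior these fixes align with (assuming a recent torch; the loop at the end shows the strict failure cases tinygrad now mirrors):

import torch

print([t.shape for t in torch.arange(5).split(4)])  # [torch.Size([4]), torch.Size([1])]: pieces of size 4, remainder last
print([t.shape for t in torch.arange(5).chunk(2)])  # [torch.Size([3]), torch.Size([2])]: as even as possible

for bad in (lambda: torch.arange(5).split([3, 3]),  # sizes must sum to shape[dim]
            lambda: torch.arange(5).chunk(0)):      # num must be greater than 0
  try: bad()
  except RuntimeError as e: print("raises:", e)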

import unittest
from math import prod

from hypothesis import assume, given, settings, strategies as st
from hypothesis.extra import numpy as stn

import numpy as np
import torch
import tinygrad
from tinygrad.helpers import CI

settings.register_profile(__file__, settings.default,
                          max_examples=100 if CI else 250, deadline=None)

# stay within int32: torch wraps around for larger numbers
st_int32 = st.integers(-2147483648, 2147483647)

@st.composite
def st_shape(draw) -> tuple[int, ...]:
  s = draw(stn.array_shapes(min_dims=0, max_dims=6,
                            min_side=0, max_side=512))
  # bound the total element count so the generated tensors stay small
  assume(prod(s) <= 1024 ** 2)
  # even when a zero-sized side makes prod(s) == 0, keep the non-zero sides bounded
  assume(prod([d for d in s if d]) <= 1024 ** 4)
  return s
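
# note: st_shape deliberately draws 0-d and zero-sized shapes (e.g. (), (0,), (3, 0, 7));
# those are exactly the edge cases the split/chunk fixes above are about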

def tensors_for_shape(s:tuple[int, ...]) -> tuple[torch.Tensor, tinygrad.Tensor]:
  x = np.arange(prod(s)).reshape(s)
  return torch.from_numpy(x), tinygrad.Tensor(x)

def apply(tor, ten, tor_fn, ten_fn=None):
  ok = True
  try: tor = tor_fn(tor)
  except: tor, ok = None, not ok  # noqa: E722
  try: ten = ten_fn(ten) if ten_fn is not None else tor_fn(ten)
  except: ten, ok = None, not ok  # noqa: E722
  return tor, ten, ok
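
# `apply` encodes the parity contract this suite checks: `ok` starts True and flips on
# each exception, so it ends True iff torch and tinygrad either both succeeded or both
# raised. A hypothetical call (not part of the suite):
#   tor, ten, ok = apply(torch.arange(5), tinygrad.Tensor.arange(5), lambda t: t.split(4))
#   here ok is True and both results hold pieces of sizes (4, 1)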

class TestShapeOps(unittest.TestCase):
  @settings.get_profile(__file__)
  @given(st_shape(), st_int32, st.one_of(st_int32, st.lists(st_int32)))
  def test_split(self, s:tuple[int, ...], dim:int, sizes:int|list[int]):
    tor, ten = tensors_for_shape(s)
    tor, ten, ok = apply(tor, ten, lambda t: t.split(sizes, dim))
    assert ok
    if tor is None and ten is None: return

    assert len(tor) == len(ten)
    assert all([np.array_equal(tor.numpy(), ten.numpy()) for (tor, ten) in zip(tor, ten)])

  @settings.get_profile(__file__)
  @given(st_shape(), st_int32, st_int32)
  def test_chunk(self, s:tuple[int, ...], dim:int, num:int):
    # chunking along a zero-sized dim is cloning and leads to OOM if num is unbounded
    assume((0 <= (actual_dim := len(s)+dim if dim < 0 else dim) < len(s) and s[actual_dim] > 0) or
           (num < 32))

    tor, ten = tensors_for_shape(s)
    tor, ten, ok = apply(tor, ten, lambda t: t.chunk(num, dim))
    assert ok
    if tor is None and ten is None: return

    assert len(tor) == len(ten)
    assert all([np.array_equal(tor.numpy(), ten.numpy()) for (tor, ten) in zip(tor, ten)])

  @settings.get_profile(__file__)
  @given(st_shape(), st_int32)
  def test_squeeze(self, s:tuple[int, ...], dim:int):
    tor, ten = tensors_for_shape(s)
    tor, ten, ok = apply(tor, ten, lambda t: t.squeeze(dim))
    assert ok
    if tor is None and ten is None: return
    assert np.array_equal(tor.numpy(), ten.numpy())

  @settings.get_profile(__file__)
  @given(st_shape(), st_int32)
  def test_unsqueeze(self, s:tuple[int, ...], dim:int):
    tor, ten = tensors_for_shape(s)
    tor, ten, ok = apply(tor, ten, lambda t: t.unsqueeze(dim))
    assert ok
    if tor is None and ten is None: return
    assert np.array_equal(tor.numpy(), ten.numpy())

if __name__ == '__main__':
  unittest.main()