mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
update setitem tests to test for currently supported cases (#4334)
* tests, tests, tests * one more test * tests tests tests tests * t e s t * a few more
This commit is contained in:
@@ -35,7 +35,7 @@ def data_ptr(tensor:Tensor): return tensor.lazydata
|
||||
# https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html
|
||||
# TODO this is setitem
|
||||
def index_put_(tensor:Tensor, indices, values, accumulate) -> Tensor:
|
||||
pass
|
||||
tensor[indices] = values
|
||||
|
||||
# https://pytorch.org/docs/stable/generated/torch.argsort.html
|
||||
def argsort(tensor:Tensor) -> Tensor:
|
||||
@@ -215,7 +215,7 @@ class TestIndexing(unittest.TestCase):
|
||||
validate_indexing(reference)
|
||||
|
||||
# setting values
|
||||
# TODO: setitem
|
||||
# TODO: advanced setitem
|
||||
'''
|
||||
validate_setting(reference)
|
||||
'''
|
||||
@@ -276,7 +276,7 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 2],
|
||||
[4, 5]]))
|
||||
|
||||
# TODO: setitem
|
||||
# TODO: advanced setitem
|
||||
'''
|
||||
# setting values
|
||||
reference[ri([0]), ri([1])] = -1
|
||||
@@ -323,7 +323,7 @@ class TestIndexing(unittest.TestCase):
|
||||
[1, 2]])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 4], [5, 11]]))
|
||||
|
||||
# TODO: setitem
|
||||
# TODO: advanced setitem
|
||||
'''
|
||||
# setting values
|
||||
reference[ri([0]), ri([1])] = -1
|
||||
@@ -389,7 +389,7 @@ class TestIndexing(unittest.TestCase):
|
||||
|
||||
numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
|
||||
np.array([11]))
|
||||
# TODO setitem
|
||||
# TODO advanced setitem
|
||||
'''
|
||||
strided[ri([0]), ri([1])] = -1
|
||||
numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
|
||||
@@ -401,7 +401,7 @@ class TestIndexing(unittest.TestCase):
|
||||
|
||||
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
|
||||
np.array([11, 17]))
|
||||
# TODO setitem
|
||||
# TODO advanced setitem
|
||||
'''
|
||||
strided[ri([0, 1]), ri([1, 0])] = Tensor([-1, 2])
|
||||
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
|
||||
@@ -417,7 +417,7 @@ class TestIndexing(unittest.TestCase):
|
||||
[0, 1]])
|
||||
numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
np.array([[10, 11], [17, 18]]))
|
||||
# TODO setitem
|
||||
# TODO advanced setitem
|
||||
'''
|
||||
strided[rows, columns] = Tensor([[4, 6], [2, 3]])
|
||||
numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
@@ -525,7 +525,7 @@ class TestIndexing(unittest.TestCase):
|
||||
assert_get_eq(reference, indexer)
|
||||
assert_backward_eq(reference, indexer)
|
||||
|
||||
# TODO setitem
|
||||
# TODO advanced setitem
|
||||
'''
|
||||
for indexer in indices_to_test:
|
||||
assert_set_eq(reference, indexer, 44)
|
||||
@@ -580,7 +580,7 @@ class TestIndexing(unittest.TestCase):
|
||||
|
||||
for indexer in indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
# TODO setitem
|
||||
# TODO advanced setitem
|
||||
'''
|
||||
assert_set_eq(reference, indexer, 212)
|
||||
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
|
||||
@@ -654,7 +654,7 @@ class TestIndexing(unittest.TestCase):
|
||||
|
||||
for indexer in indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
# TODO setitem
|
||||
# TODO advanced setitem
|
||||
'''
|
||||
assert_set_eq(reference, indexer, 1333)
|
||||
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
|
||||
@@ -665,13 +665,13 @@ class TestIndexing(unittest.TestCase):
|
||||
]
|
||||
for indexer in indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
# TODO setitem
|
||||
# TODO advanced setitem
|
||||
'''
|
||||
assert_set_eq(reference, indexer, 1333)
|
||||
'''
|
||||
assert_backward_eq(reference, indexer)
|
||||
|
||||
# TODO setitem
|
||||
# TODO setitem backward
|
||||
'''
|
||||
def test_set_item_to_scalar_tensor(self):
|
||||
m = random.randint(1, 10)
|
||||
@@ -708,14 +708,11 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(v[::11], [0])
|
||||
numpy_testing_assert_equal_helper(v[1:6:2], [1, 3, 5])
|
||||
|
||||
# TODO setitem with stride
|
||||
'''
|
||||
def test_step_assignment(self):
|
||||
v = Tensor.zeros(4, 4)
|
||||
v = Tensor.zeros(4, 4).contiguous()
|
||||
v[0, 1::2] = Tensor([3., 4.])
|
||||
numpy_testing_assert_equal_helper(v[0].numpy().tolist(), [0, 3, 0, 4])
|
||||
numpy_testing_assert_equal_helper(v[1:].sum(), 0)
|
||||
'''
|
||||
|
||||
@unittest.skip("bool indexing not supported")
|
||||
def test_bool_indices(self):
|
||||
@@ -771,12 +768,13 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(len(w), 2)
|
||||
|
||||
# TODO setitem
|
||||
# NOTE: tinygrad doesn't support idx.max that big
|
||||
'''
|
||||
def test_index_put_accumulate_large_tensor(self):
|
||||
# This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
|
||||
N = (1 << 31) + 5
|
||||
dt = dtypes.int8
|
||||
a = Tensor.ones(N, dtype=dt)
|
||||
a = Tensor.ones(N, dtype=dt).contiguous()
|
||||
indices = Tensor([-2, 0, -2, -1, 0, -1, 1], dtype=dtypes.int64)
|
||||
values = Tensor([6, 5, 6, 6, 5, 7, 11], dtype=dt)
|
||||
|
||||
@@ -789,7 +787,7 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(a[-2], 13)
|
||||
numpy_testing_assert_equal_helper(a[-1], 14)
|
||||
|
||||
a = Tensor.ones((2, N), dtype=dt)
|
||||
a = Tensor.ones((2, N), dtype=dt).contiguous()
|
||||
indices0 = np.array([0, -1, 0, 1], dtype=dtypes.int64)
|
||||
indices1 = np.array([-2, -1, 0, 1], dtype=dtypes.int64)
|
||||
values = np.array([12, 13, 10, 11], dtype=dt)
|
||||
@@ -808,7 +806,7 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(a[0, -1], 1)
|
||||
'''
|
||||
|
||||
# TODO setitem
|
||||
# TODO fancy setitem
|
||||
'''
|
||||
def test_index_put_accumulate_duplicate_indices(self):
|
||||
for i in range(1, 512):
|
||||
@@ -850,7 +848,7 @@ class TestIndexing(unittest.TestCase):
|
||||
res = x[:, ind_int]
|
||||
numpy_testing_assert_equal_helper(ref, res)
|
||||
# no repeating indices for index_put
|
||||
# TODO setitem
|
||||
# TODO fancy setitem
|
||||
'''
|
||||
src = Tensor.randn(4)
|
||||
ind_long = Tensor.arange(4, dtype=dtypes.int64)
|
||||
@@ -863,7 +861,7 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(inp_ref, inp_res)
|
||||
'''
|
||||
|
||||
# TODO setitem
|
||||
# TODO empty setitem
|
||||
'''
|
||||
def test_index_put_accumulate_empty(self):
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/94667
|
||||
@@ -916,7 +914,7 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(v[:, [0, 4, 2]].shape, (5, 3, 3))
|
||||
numpy_testing_assert_equal_helper(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
|
||||
|
||||
# TODO setitem
|
||||
# TODO fancy setitem
|
||||
'''
|
||||
def test_index_put_src_datatype(self, dtype):
|
||||
src = Tensor.ones(3, 2, 4, dtype=dtype)
|
||||
@@ -932,7 +930,7 @@ class TestIndexing(unittest.TestCase):
|
||||
res = src[[0, 2, 1], :, :]
|
||||
numpy_testing_assert_equal_helper(res.shape, src.shape)
|
||||
# test index_put, no accum
|
||||
# TODO setitem
|
||||
# TODO fancy setitem
|
||||
'''
|
||||
src[[0, 2, 1], :, :] = res
|
||||
numpy_testing_assert_equal_helper(res.shape, src.shape)
|
||||
@@ -960,7 +958,7 @@ class TestIndexing(unittest.TestCase):
|
||||
idx = Tensor([], dtype=dtypes.int64)
|
||||
numpy_testing_assert_equal_helper(x[idx].numel(), 0)
|
||||
|
||||
# TODO setitem
|
||||
# TODO empty setitem
|
||||
'''
|
||||
# empty assignment should have no effect but not throw an exception
|
||||
y = clone(x)
|
||||
@@ -1096,32 +1094,33 @@ class TestIndexing(unittest.TestCase):
|
||||
r[zero]
|
||||
numpy_testing_assert_equal_helper(r, r[...])
|
||||
|
||||
# TODO setitem
|
||||
'''
|
||||
def test_setitem_scalars(self):
|
||||
zero = Tensor(0, dtype=dtypes.int64)
|
||||
|
||||
# non-scalar indexed with scalars
|
||||
a = Tensor.randn(2, 3)
|
||||
a_set_with_number = clone(a)
|
||||
a_set_with_scalar = clone(a)
|
||||
a = Tensor.randn(2, 3).contiguous()
|
||||
a_set_with_number = clone(a).contiguous()
|
||||
a_set_with_scalar = clone(a).contiguous()
|
||||
b = Tensor.randn(3)
|
||||
|
||||
a_set_with_number[0] = b
|
||||
a_set_with_scalar[zero] = b
|
||||
numpy_testing_assert_equal_helper(a_set_with_number, a_set_with_scalar)
|
||||
a[1, zero] = 7.7
|
||||
numpy_testing_assert_equal_helper(7.7, a[1, 0])
|
||||
# TODO: weird inaccuracy Max relative difference: 2.47707621e-08
|
||||
# numpy_testing_assert_equal_helper(7.7, a[1, 0])
|
||||
np.testing.assert_allclose(7.7, a[1, 0].numpy(), rtol=1e-7)
|
||||
|
||||
# scalar indexed with scalars
|
||||
r = Tensor.randn()
|
||||
r = Tensor.randn().contiguous()
|
||||
with self.assertRaises(IndexError):
|
||||
r[:] = 8.8
|
||||
with self.assertRaises(IndexError):
|
||||
r[zero] = 8.8
|
||||
r[...] = 9.9
|
||||
numpy_testing_assert_equal_helper(9.9, r)
|
||||
'''
|
||||
# TODO: weird inaccuracy Max relative difference: 3.85322971e-08
|
||||
# numpy_testing_assert_equal_helper(9.9, r)
|
||||
np.testing.assert_allclose(9.9, r, rtol=1e-7)
|
||||
|
||||
def test_basic_advanced_combined(self):
|
||||
# From the NumPy indexing example
|
||||
@@ -1135,7 +1134,7 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(x, unmodified)
|
||||
|
||||
# But assignment should modify the original
|
||||
# TODO setitem
|
||||
# TODO fancy setitem
|
||||
'''
|
||||
unmodified = clone(x)
|
||||
x[1:2, [1, 2]] = 0
|
||||
@@ -1151,7 +1150,7 @@ class TestIndexing(unittest.TestCase):
|
||||
x[1] = Tensor.arange(5, 7)
|
||||
numpy_testing_assert_equal_helper(x.numpy().tolist(), [[0, 1], [5, 6]])
|
||||
|
||||
# TODO setitem
|
||||
# TODO fancy setitem
|
||||
'''
|
||||
def test_byte_tensor_assignment(self):
|
||||
x = Tensor.arange(0., 16).reshape(4, 4)
|
||||
@@ -1234,14 +1233,14 @@ class TestIndexing(unittest.TestCase):
|
||||
|
||||
self.assertRaises(IndexError, runner)
|
||||
|
||||
# TODO setitem
|
||||
# TODO fancy setitem
|
||||
'''
|
||||
def test_cpu_indices(self):
|
||||
idx = Tensor([0, 1])
|
||||
b = Tensor.zeros(2)
|
||||
x = Tensor.ones(10)
|
||||
x = Tensor.ones(10).contiguous()
|
||||
x[idx] = b # index_put_
|
||||
ref = Tensor.ones(10)
|
||||
ref = Tensor.ones(10).contiguous()
|
||||
ref[:2] = 0
|
||||
numpy_testing_assert_equal_helper(x, ref)
|
||||
out = x[idx] # index
|
||||
@@ -1535,7 +1534,7 @@ class TestNumpy(unittest.TestCase):
|
||||
self.assertRaises(IndexError, a.__setitem__, ind, 0)
|
||||
'''
|
||||
|
||||
# TODO setitem
|
||||
# TODO fancy setitem
|
||||
'''
|
||||
def test_index_is_larger(self):
|
||||
# Simple case of fancy index broadcasting of the index.
|
||||
@@ -1545,7 +1544,7 @@ class TestNumpy(unittest.TestCase):
|
||||
self.assertTrue((a[:3, :3] == all_(Tensor([2., 3., 4.]))))
|
||||
'''
|
||||
|
||||
# TODO setitem
|
||||
# TODO fancy setitem
|
||||
'''
|
||||
def test_broadcast_subspace(self):
|
||||
a = Tensor.zeros((100, 100))
|
||||
@@ -1556,7 +1555,7 @@ class TestNumpy(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(a, expected)
|
||||
'''
|
||||
|
||||
# TODO setitem
|
||||
# TODO fancy setitem
|
||||
'''
|
||||
def test_truncate_leading_1s(self):
|
||||
col_max = Tensor.randn(1, 4)
|
||||
|
||||
@@ -4,28 +4,53 @@ import numpy as np
|
||||
|
||||
class TestSetitem(unittest.TestCase):
|
||||
def test_simple_setitem(self):
|
||||
t = Tensor.zeros(6, 6).contiguous().realize()
|
||||
t[2:4, 3:5] = Tensor.ones(2, 2)
|
||||
n = np.zeros((6, 6))
|
||||
n[2:4, 3:5] = np.ones((2, 2))
|
||||
np.testing.assert_allclose(t.numpy(), n)
|
||||
|
||||
t = Tensor.zeros(6, 6).contiguous().realize()
|
||||
t[2:4, 3:5] = 1.0
|
||||
n = np.zeros((6, 6))
|
||||
n[2:4, 3:5] = 1.0
|
||||
np.testing.assert_allclose(t.numpy(), n)
|
||||
cases = (
|
||||
((6,6), (slice(2,4), slice(3,5)), Tensor.ones(2,2)),
|
||||
((6,6), (slice(2,4), slice(3,5)), Tensor([1.,2.])),
|
||||
((6,6), (slice(2,4), slice(3,5)), 1.0),
|
||||
((6,6), (3, 4), 1.0),
|
||||
((6,6), (3, None, 4, None), 1.0),
|
||||
((4,4,4,4), (Ellipsis, slice(1,3), slice(None)), Tensor(4)),
|
||||
((4,4,4,4), (Ellipsis, slice(1,3)), 4),
|
||||
((4,4,4,4), (2, slice(1,3), None, 1), 4),
|
||||
((4,4,4,4), (slice(1,3), slice(None), slice(0,4,2)), 4),
|
||||
((4,4,4,4), (slice(1,3), slice(None), slice(None), slice(0,3)), 4),
|
||||
((6,6), (slice(1,5,2), slice(0,5,3)), 1.0),
|
||||
((6,6), (slice(5,1,-2), slice(5,0,-3)), 1.0),
|
||||
)
|
||||
for shp, slc, val in cases:
|
||||
t = Tensor.zeros(shp).contiguous()
|
||||
t[slc] = val
|
||||
n = np.zeros(shp)
|
||||
n[slc] = val.numpy() if isinstance(val, Tensor) else val
|
||||
np.testing.assert_allclose(t.numpy(), n)
|
||||
|
||||
def test_setitem_into_unrealized(self):
|
||||
t = Tensor.arange(4).reshape(2, 2)
|
||||
t[1] = 5
|
||||
np.testing.assert_allclose(t.numpy(), [[0, 1], [5, 5]])
|
||||
|
||||
def test_setitem_dtype(self):
|
||||
for dt in (dtypes.int, dtypes.float, dtypes.bool):
|
||||
for v in (5., 5, True):
|
||||
t = Tensor.ones(6,6, dtype=dt).contiguous()
|
||||
t[1] = v
|
||||
assert t.dtype == dt
|
||||
|
||||
def test_setitem_into_noncontiguous(self):
|
||||
t = Tensor.ones(4)
|
||||
assert not t.lazydata.st.contiguous
|
||||
with self.assertRaises(AssertionError): t[1] = 5
|
||||
|
||||
# TODO: implement fancy setitem
|
||||
@unittest.expectedFailure
|
||||
def test_fancy_setitem(self):
|
||||
t = Tensor.zeros(6,6).contiguous()
|
||||
t[[1,2], [3,2]] = 3
|
||||
n = np.zeros((6,6))
|
||||
n[[1,2], [3,2]] = 3
|
||||
np.testing.assert_allclose(t.numpy(), n)
|
||||
|
||||
def test_simple_jit_setitem(self):
|
||||
@TinyJit
|
||||
def f(t:Tensor, a:Tensor):
|
||||
|
||||
@@ -824,7 +824,7 @@ class Tensor:
|
||||
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
|
||||
assign_to = self.realize().__getitem__(indices)
|
||||
# NOTE: contiguous to prevent const folding.
|
||||
v = v._broadcast_to(broadcast_shape(assign_to.shape, v.shape)).contiguous()
|
||||
v = v.cast(assign_to.dtype)._broadcast_to(broadcast_shape(assign_to.shape, v.shape)).contiguous()
|
||||
assign_to.assign(v).realize()
|
||||
|
||||
# NOTE: using slice is discouraged and things should migrate to pad and shrink
|
||||
|
||||
Reference in New Issue
Block a user