mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
grouper tests from fuse_arange_default [pr] (#10394)
This commit is contained in:
@@ -86,6 +86,21 @@ class TestSchedule(unittest.TestCase):
|
||||
run_schedule(check_schedule(b, 1))
|
||||
np.testing.assert_allclose(b.numpy(), np.broadcast_to(a.numpy().astype(np.float16), (2, 4, 4))+2)
|
||||
|
||||
def test_indexing_scalars_simple(self):
|
||||
X = Tensor.randn(2, 2).realize()
|
||||
xt = X[Tensor(1)][Tensor(0)]
|
||||
with Context(FUSE_ARANGE=1):
|
||||
run_schedule(check_schedule(xt, 2))
|
||||
np.testing.assert_equal(xt.numpy(), X.numpy()[1][0])
|
||||
|
||||
@unittest.expectedFailure # TODO: failing because of can_chase
|
||||
def test_indexing_scalars_multiple_dims(self):
|
||||
X = Tensor.randn(2, 3).realize()
|
||||
xt = X[Tensor(0)][Tensor(1)]
|
||||
with Context(FUSE_ARANGE=1):
|
||||
run_schedule(check_schedule(xt, 2))
|
||||
np.testing.assert_equal(xt.numpy(), X.numpy()[0][1])
|
||||
|
||||
def test_push_pads_elementwise(self):
|
||||
x = Tensor.full((4,4), 2.).contiguous().realize()
|
||||
y = Tensor.full((4,4), 4.).contiguous().realize()
|
||||
|
||||
Reference in New Issue
Block a user