mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
Zero dim Tensor support (#777)
* add and reorganize test_slice_* tests * refactor Tensor.__getitem__() * preliminary tests for 1) 0D tensors and 2) varargs for Tensor.zeros and Tensor.ones * always compare shapes of the numpy arrays obtained from tinygrad and torch tensors * add more tests for 0D support * remove test_tensor.test_slicing(). All slicing tests at test/test_ops.py * add zero-dim support * make test_end2end.py consistent with 0dim support * add test for tensor with zero in shape * don't simplify ones if shape is () * skip tests that need zero-size tensor support. - zero-size tensor support not related to 0dim tensors. * add tests for __getitem__() supporting strides >= 1 * refactor __getitem__: support for strides >= 1 * minor refactors and add comments to __getitem__ * add tests for slices with negative steps * add support for slices with negative strides
This commit is contained in:
168
test/test_ops.py
168
test/test_ops.py
@@ -15,7 +15,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
|
||||
if shps is None:
|
||||
ts = [torch.tensor(x, requires_grad=True) for x in vals]
|
||||
else:
|
||||
ts = [torch.tensor((np.random.random(size=x).astype(np.float32)+a)*b, requires_grad=True) for x in shps]
|
||||
ts = [torch.tensor((np.random.random(size=x)+a)*b, requires_grad=True, dtype=torch.float32) for x in shps]
|
||||
|
||||
tst = [Tensor(x.detach().numpy(), requires_grad=not FORWARD_ONLY) for x in ts]
|
||||
|
||||
@@ -29,7 +29,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
|
||||
|
||||
def compare(s, x,y,atol,rtol):
|
||||
if PRINT_TENSORS: print(s, x, y)
|
||||
if y.shape != tuple(): assert x.shape == y.shape, f"shape mismatch (tinygrad){x.shape} != (torch){y.shape}"
|
||||
assert x.shape == y.shape, f"shape mismatch: tinygrad={x.shape} | torch={y.shape}"
|
||||
try:
|
||||
np.testing.assert_allclose(x,y, atol=atol, rtol=rtol)
|
||||
except Exception:
|
||||
@@ -62,6 +62,8 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([], lambda: torch.full((45,65), 4), lambda: Tensor.full((45,65), 4), forward_only=True)
|
||||
def test_zeros(self):
|
||||
helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True)
|
||||
helper_test_op([], lambda: torch.zeros([45,65]), lambda: Tensor.zeros([45,65]), forward_only=True)
|
||||
helper_test_op([], lambda: torch.zeros([]), lambda: Tensor.zeros([]), forward_only=True)
|
||||
def test_zeros_like(self):
|
||||
a = Tensor([[1,2,3],[4,5,6]])
|
||||
b = torch.tensor([[1,2,3],[4,5,6]])
|
||||
@@ -70,12 +72,15 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([], lambda: torch.empty(45,65)*0/0, lambda: Tensor.empty(45,65)*0/0, forward_only=True)
|
||||
def test_ones(self):
|
||||
helper_test_op([], lambda: torch.ones(45,65), lambda: Tensor.ones(45,65), forward_only=True)
|
||||
helper_test_op([], lambda: torch.ones([45,65]), lambda: Tensor.ones([45,65]), forward_only=True)
|
||||
helper_test_op([], lambda: torch.ones([]), lambda: Tensor.ones([]), forward_only=True)
|
||||
def test_ones_like(self):
|
||||
a = Tensor([[1,2,3],[4,5,6]])
|
||||
b = torch.tensor([[1,2,3],[4,5,6]])
|
||||
helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True)
|
||||
def test_eye(self):
|
||||
helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True)
|
||||
|
||||
def test_arange(self):
|
||||
helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True)
|
||||
def test_where(self):
|
||||
@@ -121,43 +126,58 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_maximum(self):
|
||||
helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)
|
||||
helper_test_op([(), ()], torch.maximum, Tensor.maximum)
|
||||
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., 4.], [1., 2., 3., 0.]])
|
||||
def test_minimum(self):
|
||||
helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
|
||||
helper_test_op([(), ()], torch.minimum, Tensor.minimum)
|
||||
def test_add(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add)
|
||||
helper_test_op([(), ()], lambda x,y: x+y, Tensor.add)
|
||||
def test_add_simple(self):
|
||||
helper_test_op([(256), (256)], lambda x,y: x+y, Tensor.add, forward_only=True)
|
||||
def test_broadcasted_add(self):
|
||||
helper_test_op([(45,65), (45,1)], lambda x,y: x+y, lambda x,y: x+y)
|
||||
helper_test_op([(45,65), ()], lambda x,y: x+y, lambda x,y: x+y)
|
||||
def test_broadcasted_add_2(self):
|
||||
helper_test_op([(45,65), (65,)], lambda x,y: x+y, lambda x,y: x+y)
|
||||
def test_sub(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub)
|
||||
helper_test_op([(), ()], lambda x,y: x-y, Tensor.sub)
|
||||
def test_neg(self):
|
||||
helper_test_op([(45,65)], lambda x: -x)
|
||||
helper_test_op([()], lambda x: -x)
|
||||
def test_mul(self):
|
||||
helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul)
|
||||
helper_test_op([(), ()], lambda x,y: x*y, Tensor.mul)
|
||||
def test_div(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
|
||||
helper_test_op([(), ()], lambda x,y: x/y, Tensor.div)
|
||||
def test_div_const(self):
|
||||
helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255)
|
||||
helper_test_op([(45,65)], lambda x: x/1, lambda x: x/1)
|
||||
helper_test_op([(45,65)], lambda x: 1/x, lambda x: 1/x)
|
||||
helper_test_op([(45,65)], lambda x: x/2, lambda x: x/2)
|
||||
helper_test_op([(45,65)], lambda x: 2/x, lambda x: 2/x)
|
||||
helper_test_op([()], lambda x: x/2, lambda x: x/2)
|
||||
helper_test_op([()], lambda x: 2/x, lambda x: 2/x)
|
||||
def test_pow(self):
|
||||
helper_test_op([(45,65)], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
|
||||
helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=0)
|
||||
helper_test_op([(45,65)], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, a=0)
|
||||
helper_test_op([()], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
|
||||
helper_test_op([()], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
|
||||
def test_pow_const(self):
|
||||
helper_test_op([(45,65)], lambda x: x**1.0, lambda x: x**1.0)
|
||||
helper_test_op([(45,65)], lambda x: 1.0**x, lambda x: 1.0**x)
|
||||
helper_test_op([(45,65)], lambda x: x**2.0, lambda x: x**2.0)
|
||||
helper_test_op([(45,65)], lambda x: 2.0**x, lambda x: 2.0**x)
|
||||
helper_test_op([()], lambda x: x**2.0, lambda x: x**2.0)
|
||||
helper_test_op([()], lambda x: 2.0**x, lambda x: 2.0**x)
|
||||
def test_sqrt(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, a=0)
|
||||
helper_test_op([()], lambda x: x.sqrt(), Tensor.sqrt, a=0)
|
||||
|
||||
def test_sin(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sin(), Tensor.sin, a=0)
|
||||
@@ -168,47 +188,65 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_relu(self):
|
||||
helper_test_op([(64,64)], lambda x: x.relu(), Tensor.relu)
|
||||
helper_test_op([()], lambda x: x.relu(), Tensor.relu)
|
||||
def test_relu_exact(self):
|
||||
helper_test_op(None, lambda x: x.relu(), Tensor.relu, vals=[[-1.,0,1]])
|
||||
def test_relu_maximum_exact(self):
|
||||
helper_test_op(None, lambda x: torch.maximum(x, torch.zeros_like(x, requires_grad=False)), lambda x: Tensor.maximum(x, 0), vals=[[-1.,0,1]])
|
||||
def test_leakyrelu(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
|
||||
def test_celu(self):
|
||||
for val in range(1, 5):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
|
||||
helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
|
||||
def test_abs(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs)
|
||||
helper_test_op([()], lambda x: torch.abs(x), Tensor.abs)
|
||||
def test_log(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log)
|
||||
helper_test_op([()], lambda x: torch.log(x), Tensor.log)
|
||||
def test_exp(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp)
|
||||
helper_test_op([()], lambda x: torch.exp(x), Tensor.exp)
|
||||
def test_sign(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.sign(x), Tensor.sign)
|
||||
helper_test_op([()], lambda x: torch.sign(x), Tensor.sign)
|
||||
def test_softsign(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.softsign(x), Tensor.softsign)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.softsign(x), Tensor.softsign)
|
||||
def test_sigmoid(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid)
|
||||
helper_test_op([()], lambda x: x.sigmoid(), Tensor.sigmoid, forward_only=True)
|
||||
def test_softplus(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
|
||||
@unittest.skip("not supported in older pytorch")
|
||||
def test_gelu(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
|
||||
def test_quick_gelu(self):
|
||||
helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
|
||||
helper_test_op([()], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
|
||||
def test_elu(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x), Tensor.elu)
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x, alpha=0.1), lambda x: Tensor.elu(x, alpha=0.1))
|
||||
helper_test_op([()], lambda x: torch.nn.functional.elu(x), Tensor.elu)
|
||||
def test_relu6(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.relu6(x), Tensor.relu6)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.relu6(x), Tensor.relu6)
|
||||
def test_hardswish(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
|
||||
def test_mish(self):
|
||||
def _mish_pytorch(x):
|
||||
return x*torch.tanh(torch.nn.functional.softplus(x))
|
||||
helper_test_op([(45,65)], _mish_pytorch, Tensor.mish, atol=1e-4)
|
||||
helper_test_op([()], _mish_pytorch, Tensor.mish, atol=1e-4)
|
||||
def test_dot(self):
|
||||
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
with self.assertRaises(RuntimeError):
|
||||
a = Tensor(3.14)
|
||||
a.matmul(a)
|
||||
def test_matmul_simple(self):
|
||||
helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
def test_matmul(self):
|
||||
@@ -225,6 +263,10 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3)
|
||||
def test_broadcastdot(self):
|
||||
helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
|
||||
with self.assertRaises(RuntimeError):
|
||||
a = Tensor(3.14)
|
||||
b = Tensor.ones(3,3)
|
||||
a @ b
|
||||
def test_multidot(self):
|
||||
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
|
||||
@@ -241,10 +283,12 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))
|
||||
helper_test_op([()], lambda x: x.sum(), Tensor.sum)
|
||||
def test_min(self):
|
||||
helper_test_op([(3,3)], lambda x: x.min(), Tensor.min)
|
||||
helper_test_op([(45,3)], lambda x: x.min(), Tensor.min)
|
||||
helper_test_op([(45,3)], lambda x: x.min().mul(0.5), lambda x: Tensor.min(x).mul(0.5))
|
||||
helper_test_op([()], lambda x: x.min(), Tensor.min)
|
||||
def test_max(self):
|
||||
helper_test_op([(45,3)], lambda x: x.max(), Tensor.max)
|
||||
helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5))
|
||||
@@ -253,8 +297,10 @@ class TestOps(unittest.TestCase):
|
||||
[[1.0,1.0,0.0,1.0]],
|
||||
])
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1))
|
||||
helper_test_op([()], lambda x: x.max(), Tensor.max)
|
||||
def test_mean(self):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.mean())
|
||||
helper_test_op([()], lambda x: x.mean())
|
||||
def test_mean_axis(self):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)))
|
||||
def test_std(self):
|
||||
@@ -275,28 +321,34 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, keepdim=True, correction=0, dim=0), lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0))
|
||||
def test_log_softmax(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([()], lambda x: torch.nn.LogSoftmax(dim=0)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
|
||||
def test_log_softmax_other_axis(self):
|
||||
helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7)
|
||||
def test_tanh(self):
|
||||
helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
|
||||
def test_hardtanh(self):
|
||||
for val in range(10, 30, 5):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
|
||||
def test_topo_sort(self):
|
||||
helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)
|
||||
|
||||
def test_scalar_mul(self):
|
||||
helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2)
|
||||
helper_test_op([()], lambda x: x*2, lambda x: x*2)
|
||||
def test_scalar_rmul(self):
|
||||
helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x)
|
||||
|
||||
helper_test_op([()], lambda x: 2*x, lambda x: 2*x)
|
||||
def test_scalar_sub(self):
|
||||
helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2)
|
||||
helper_test_op([()], lambda x: x-2, lambda x: x-2)
|
||||
def test_scalar_rsub(self):
|
||||
helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x)
|
||||
|
||||
helper_test_op([()], lambda x: 2-x, lambda x: 2-x)
|
||||
def test_flip_eye_crash(self):
|
||||
helper_test_op([], lambda: (torch.eye(10)@torch.eye(10).flip(0)),
|
||||
lambda: (Tensor.eye(10)@Tensor.eye(10).flip(0)), forward_only=True)
|
||||
@@ -310,6 +362,7 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_broadcast_simple(self):
|
||||
helper_test_op([(45,65), (45,1)], lambda x,y: x/y, lambda x,y: x/y)
|
||||
helper_test_op([(45,65), ()], lambda x,y: x/y, lambda x,y: x/y)
|
||||
|
||||
def test_broadcast_partial(self):
|
||||
for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
|
||||
@@ -320,19 +373,83 @@ class TestOps(unittest.TestCase):
|
||||
# NOTE: ANE backwards?
|
||||
helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0)
|
||||
|
||||
def test_slice_simple(self):
|
||||
helper_test_op([(3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
|
||||
def test_slice_in_bounds_1dim(self):
|
||||
helper_test_op([(3)], lambda x: x[1:3], lambda x: x[1:3])
|
||||
helper_test_op([(3)], lambda x: x[0:2], lambda x: x[0:2])
|
||||
helper_test_op([(3)], lambda x: x[-2:2], lambda x: x[-2:2])
|
||||
|
||||
def test_slice(self):
|
||||
helper_test_op([(3,3,3,3)], lambda x: x[1:2], lambda x: x[1:2])
|
||||
helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
|
||||
helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1])
|
||||
def test_slice_on_0dim_tensor(self):
|
||||
helper_test_op([()], lambda x: x[None], lambda x: x[None])
|
||||
|
||||
def test_slice_one(self):
|
||||
with self.assertRaises(IndexError):
|
||||
a = Tensor(3.14)
|
||||
a[0]
|
||||
|
||||
def test_slice_int_indexing(self):
|
||||
helper_test_op([(3)], lambda x: x[1], lambda x: x[1])
|
||||
|
||||
def test_slice_one_multi(self):
|
||||
helper_test_op([(3)], lambda x: x[-2], lambda x: x[-2])
|
||||
helper_test_op([(10,10)], lambda x: x[1], lambda x: x[1])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1,1,1], lambda x: x[1,1,1])
|
||||
|
||||
def test_slice_in_bounds_multidim(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2], lambda x: x[1:2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, 2], lambda x: x[1:2, 2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1])
|
||||
|
||||
def test_slice_with_none(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x[None], lambda x: x[None])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, None], lambda x: x[1:2, None])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2], lambda x: x[1:2, None, 1:2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1], lambda x: x[1:2, 1:2, None, -1])
|
||||
|
||||
def test_slice_one_endpoint_out_of_bounds(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x[0:4], lambda x: x[0:4])
|
||||
helper_test_op([(3,3,3)], lambda x: x[-6:4], lambda x: x[-6:4])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:50], lambda x: x[1:50])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:50, 1:2, -1], lambda x: x[1:50, 1:2, -1])
|
||||
|
||||
def test_slice_stride_gt_one(self):
|
||||
helper_test_op([(7,5,10)], lambda x: x[::2, ::3, ::4], lambda x: x[::2, ::3, ::4])
|
||||
helper_test_op([(7,5,10)], lambda x: x[1:5:2, ::3, ::4], lambda x: x[1:5:2, ::3, ::4])
|
||||
helper_test_op([(7,5,10)], lambda x: x[1:5:2, 3, ::4], lambda x: x[1:5:2, 3, ::4])
|
||||
helper_test_op([(7,5,10)], lambda x: x[1:5:2, None, None, 3, None, ::4], lambda x: x[1:5:2, None, None, 3, None, ::4])
|
||||
|
||||
def test_slice_negative_strides(self):
|
||||
# Torch doesn't support slicing with negative steps
|
||||
a = np.random.randn(10, 10, 10).astype(np.float32)
|
||||
t = Tensor(a)
|
||||
np.testing.assert_allclose(a[::-1], t[::-1].numpy())
|
||||
np.testing.assert_allclose(a[::-2], t[::-2].numpy())
|
||||
np.testing.assert_allclose(a[:, 2:0:-1], t[:, 2:0:-1].numpy())
|
||||
np.testing.assert_allclose(a[:, 2:0:-1, 3:1:-2], t[:, 2:0:-1, 3:1:-2].numpy())
|
||||
np.testing.assert_allclose(a[4:0:-3, 2:0:-1, -1:-5:-2], t[4:0:-3, 2:0:-1, -1:-5:-2].numpy())
|
||||
|
||||
@unittest.skip("No suppport for tensors with 0s in shape")
|
||||
def test_slice_both_endpoints_out_of_bounds(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x[5:10], lambda x: x[5:10], forward_only=True)
|
||||
helper_test_op([(3,3,3)], lambda x: x[-15:-7], lambda x: x[-15:-7], forward_only=True)
|
||||
|
||||
@unittest.skip("No suppport for tensors with 0s in shape")
|
||||
def test_slice_start_gt_end(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x[-2:2], lambda x: x[-2:2], forward_only=True)
|
||||
helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True)
|
||||
|
||||
@unittest.skip("No suppport for tensors with 0s in shape")
|
||||
def test_slice_empty(self):
|
||||
helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True)
|
||||
|
||||
@unittest.skip("No suppport for tensors with 0s in shape")
|
||||
def test_slice_zero_in_shape(self):
|
||||
helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1]) # x.shape = (0, 10)
|
||||
helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5]) # x.shape = (0, 3, 3)
|
||||
|
||||
def test_slice_errors(self):
|
||||
a = Tensor.ones(4, 3)
|
||||
with self.assertRaises(IndexError):
|
||||
a[1, 77, 77, 77] # IndexError: (finds too many indices before the out of bounds)
|
||||
a[1, 77] # IndexError: (out of bounds).
|
||||
a[0, -77]
|
||||
|
||||
def test_pad2d(self):
|
||||
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)))
|
||||
@@ -342,10 +459,18 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(0,2))
|
||||
helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.permute(order=(3,0,2,1)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.permute(order=(3,2,1,0)))
|
||||
helper_test_op([()], lambda x: x.permute(()), lambda x: x.permute(()))
|
||||
|
||||
def test_reshape(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)))
|
||||
helper_test_op([()], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
|
||||
helper_test_op([(1,)], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
|
||||
helper_test_op([()], lambda x: torch.reshape(x, [1]), lambda x: x.reshape([1]))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
x = Tensor.ones((4,3,6,6))
|
||||
x.reshape([])
|
||||
|
||||
def test_flip(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,)), lambda x: x.flip(axis=(0,)))
|
||||
@@ -354,23 +479,31 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)).flip((0,)), lambda x: x.flip(axis=(0,1,3)).flip(0))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(-1,)))
|
||||
helper_test_op([()], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
|
||||
helper_test_op([(1,)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
|
||||
helper_test_op([(4, 3, 6, 6)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
|
||||
|
||||
def test_unsqueeze(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 4), lambda x: x.unsqueeze(dim=4))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -1), lambda x: x.unsqueeze(dim=-1))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -3), lambda x: x.unsqueeze(dim=-3))
|
||||
helper_test_op([()], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0))
|
||||
|
||||
def test_flatten(self):
|
||||
for axis in range(3):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=axis), lambda x: x.flatten(axis))
|
||||
helper_test_op([()], lambda x: x.flatten(), lambda x: x.flatten())
|
||||
helper_test_op([(1,)], lambda x: x.flatten(), lambda x: x.flatten())
|
||||
|
||||
def test_detach(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
|
||||
helper_test_op([()], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
|
||||
|
||||
def test_expand(self):
|
||||
arg = (4,3,2,6)
|
||||
helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg))
|
||||
helper_test_op([()], lambda x: x.expand([]), lambda x: x.expand(shape=[]))
|
||||
|
||||
@unittest.skip("very slow")
|
||||
def test_sd_big_conv(self):
|
||||
@@ -695,6 +828,10 @@ class TestOps(unittest.TestCase):
|
||||
for dim in range(-1, 2):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: torch.cat((x,y), dim), lambda x,y: x.cat(y, dim=dim))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
a = Tensor(3.14)
|
||||
a.cat(a)
|
||||
|
||||
def test_multicat(self):
|
||||
for dim in range(-1, 2):
|
||||
helper_test_op([(45,65), (45,65), (45,65)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim))
|
||||
@@ -707,6 +844,9 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
with self.assertRaises(IndexError):
|
||||
Tensor.stack([x], dim=77)
|
||||
|
||||
a = Tensor(3.14)
|
||||
np.testing.assert_allclose(Tensor.stack([a, a]).numpy(), Tensor([3.14, 3.14]).numpy())
|
||||
|
||||
def test_repeat(self):
|
||||
x = Tensor.randn(45, 65, 3)
|
||||
@@ -715,6 +855,7 @@ class TestOps(unittest.TestCase):
|
||||
for reps in [[], [4], [2, 1], [3, 2, 2]]:
|
||||
repeats = base_repeats + reps
|
||||
helper_test_op([(45, 65, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
|
||||
helper_test_op([()], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
x.repeat((2, 4))
|
||||
@@ -722,7 +863,6 @@ class TestOps(unittest.TestCase):
|
||||
with self.assertRaises(AssertionError):
|
||||
x.repeat((2, 0, 4))
|
||||
|
||||
|
||||
def test_clip(self):
|
||||
helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user