diff --git a/test/test_ops.py b/test/test_ops.py index 3174280290..2c4424357d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -102,6 +102,7 @@ class TestOps(unittest.TestCase): def test_eye(self): helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True) helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True) + helper_test_op([], lambda: torch.eye(0), lambda: Tensor.eye(0), forward_only=True) def test_split(self): test_cases = [ @@ -149,8 +150,6 @@ class TestOps(unittest.TestCase): helper_test_op([], lambda: torch.arange(5, 10, 3), lambda: Tensor.arange(5, 10, 3), forward_only=True) helper_test_op([], lambda: torch.arange(10, 5, -3), lambda: Tensor.arange(10, 5, -3), forward_only=True) helper_test_op([], lambda: torch.arange(11, 5, -3), lambda: Tensor.arange(11, 5, -3), forward_only=True) - def test_arange_simple(self): - helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True) def test_arange_big(self): helper_test_op([], lambda: torch.arange(256), lambda: Tensor.arange(256), forward_only=True) @@ -234,33 +233,30 @@ class TestOps(unittest.TestCase): (t*(t < 0)).sum().backward() np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), atol=5e-4, rtol=1e-5) - #@unittest.skip("this is broken with contiguous") def test_trunc(self): - helper_test_op([(45,65)], lambda x: torch.trunc(x), lambda x: x.trunc(), forward_only=True) + helper_test_op([(45,65)], lambda x: x.trunc(), forward_only=True) a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5]) helper_test_op([], lambda: torch.trunc(b), lambda: Tensor.trunc(a), forward_only=True) - #@unittest.skip("this is broken with contiguous") def test_floor(self): - helper_test_op([(45,65)], lambda x: torch.floor(x), lambda x: x.floor(), forward_only=True) + helper_test_op([(45,65)], lambda x: x.floor(), forward_only=True) a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5]) helper_test_op([], lambda: torch.floor(b), lambda: Tensor.floor(a), forward_only=True) - #@unittest.skip("this is broken with contiguous") def test_ceil(self): - helper_test_op([(45,65)], lambda x: torch.ceil(x), lambda x: x.ceil(), forward_only=True) + helper_test_op([(45,65)], lambda x: x.ceil(), forward_only=True) a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5]) helper_test_op([], lambda: torch.ceil(b), lambda: Tensor.ceil(a), forward_only=True) def test_tril(self): - helper_test_op([(3,3)], lambda x: x.tril(), lambda x: x.tril()) - helper_test_op([(3,3)], lambda x: x.tril(1), lambda x: x.tril(1)) - helper_test_op([(3,3)], lambda x: x.tril(-1), lambda x: x.tril(-1)) - helper_test_op([(5,3,3)], lambda x: x.tril(), lambda x: x.tril()) - helper_test_op([(5,3,3)], lambda x: x.tril(1), lambda x: x.tril(1)) + helper_test_op([(3,3)], lambda x: x.tril()) + helper_test_op([(3,3)], lambda x: x.tril(1)) + helper_test_op([(3,3)], lambda x: x.tril(-1)) + helper_test_op([(5,3,3)], lambda x: x.tril()) + helper_test_op([(5,3,3)], lambda x: x.tril(1)) def test_triu(self): - helper_test_op([(3,3)], lambda x: x.triu(), lambda x: x.triu()) - helper_test_op([(3,3)], lambda x: x.triu(1), lambda x: x.triu(1)) - helper_test_op([(3,3)], lambda x: x.triu(-1), lambda x: x.triu(-1)) - helper_test_op([(5,3,3)], lambda x: x.triu(), lambda x: x.triu()) - helper_test_op([(5,3,3)], lambda x: x.triu(1), lambda x: x.triu(1)) + helper_test_op([(3,3)], lambda x: x.triu()) + helper_test_op([(3,3)], lambda x: x.triu(1)) + helper_test_op([(3,3)], lambda x: x.triu(-1)) + helper_test_op([(5,3,3)], lambda x: x.triu()) + helper_test_op([(5,3,3)], lambda x: x.triu(1)) def test_maximum(self): helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum) helper_test_op([(), ()], torch.maximum, Tensor.maximum) @@ -269,81 +265,97 @@ class TestOps(unittest.TestCase): def test_minimum(self): helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum) helper_test_op([(), ()], torch.minimum, Tensor.minimum) + def test_add(self): + helper_test_op([(45,68), (45,68)], lambda x,y: x+y) helper_test_op([(45,68), (45,68)], lambda x,y: x+y, Tensor.add) - def test_add_number(self): - helper_test_op([(), ()], lambda x,y: x+y, Tensor.add) + helper_test_op([(), ()], lambda x,y: x+y) def test_add3(self): helper_test_op([(45,65), (45,65), (45,65)], lambda x,y,z: x+y+z) - def test_add_simple(self): - helper_test_op([(256), (256)], lambda x,y: x+y, Tensor.add, forward_only=True) def test_broadcasted_add(self): - helper_test_op([(45,65), (45,1)], lambda x,y: x+y, lambda x,y: x+y) - helper_test_op([(45,65), ()], lambda x,y: x+y, lambda x,y: x+y) + helper_test_op([(45,65), (45,1)], lambda x,y: x+y) + helper_test_op([(45,65), ()], lambda x,y: x+y) def test_broadcasted_add_2(self): - helper_test_op([(45,65), (65,)], lambda x,y: x+y, lambda x,y: x+y) + helper_test_op([(45,65), (65,)], lambda x,y: x+y) + def test_sub(self): helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub) - helper_test_op([(), ()], lambda x,y: x-y, Tensor.sub) + helper_test_op([(45,65), (45,65)], lambda x,y: x-y) + helper_test_op([(), ()], lambda x,y: x-y) + def test_scalar_sub(self): + helper_test_op([(45,65)], lambda x: x-2) + helper_test_op([()], lambda x: x-2) + def test_scalar_rsub(self): + helper_test_op([(45,65)], lambda x: 2-x) + helper_test_op([()], lambda x: 2-x) + def test_neg(self): helper_test_op([(45,65)], lambda x: -x) helper_test_op([()], lambda x: -x) + def test_mul(self): helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul) - def test_mul_number(self): - helper_test_op([(), ()], lambda x,y: x*y, Tensor.mul) - def test_mul_const(self): - helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2) - helper_test_op([(45,65)], lambda x: x*-1, lambda x: x*-1) - helper_test_op([(45,65)], lambda x: 255*x, lambda x: 255*x) + helper_test_op([(64,64), (64,64)], lambda x,y: x*y) + helper_test_op([(), ()], lambda x,y: x*y) + def test_scalar_mul(self): + helper_test_op([(45,65)], lambda x: x*2) + helper_test_op([(45,65)], lambda x: x*-1) + helper_test_op([(45,65)], lambda x: 255*x) + helper_test_op([()], lambda x: x*2) + def test_scalar_rmul(self): + helper_test_op([(45,65)], lambda x: 2*x) + helper_test_op([()], lambda x: 2*x) + def test_div(self): helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div) - helper_test_op([(), ()], lambda x,y: x/y, Tensor.div) + helper_test_op([(45,65), (45,65)], lambda x,y: x/y) + helper_test_op([(), ()], lambda x,y: x/y) helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=np.array([[5],[1]], dtype=np.int32)) def test_div_int(self): helper_test_op(None, lambda x: (x/2).to(torch.int), lambda x: x/2, forward_only=True, vals=np.array([[3]], dtype=np.int32)) def test_div_const(self): - helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255) - helper_test_op([(45,65)], lambda x: x/1, lambda x: x/1) - helper_test_op([(45,65)], lambda x: 1/x, lambda x: 1/x) - helper_test_op([(45,65)], lambda x: x/2, lambda x: x/2) - helper_test_op([(45,65)], lambda x: 2/x, lambda x: 2/x) - helper_test_op([()], lambda x: x/2, lambda x: x/2) - helper_test_op([()], lambda x: 2/x, lambda x: 2/x) + helper_test_op([(45,65)], lambda x: x/255) + helper_test_op([(45,65)], lambda x: x/1) + helper_test_op([(45,65)], lambda x: 1/x) + helper_test_op([(45,65)], lambda x: x/2) + helper_test_op([(45,65)], lambda x: 2/x) + helper_test_op([()], lambda x: x/2) + helper_test_op([()], lambda x: 2/x) @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "METAL has issues with -inf") def test_mul_const_naninf(self): - helper_test_op([(45,65)], lambda x: x*float("inf"), lambda x: x*float("inf")) - helper_test_op([(45,65)], lambda x: x*-float("inf"), lambda x: x*-float("inf")) - helper_test_op([(45,65)], lambda x: x*float("nan"), lambda x: x*float("nan")) + helper_test_op([(45,65)], lambda x: x*math.inf) + helper_test_op([(45,65)], lambda x: x*-math.inf) + helper_test_op([(45,65)], lambda x: x*math.nan) @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "METAL has issues with -inf") def test_div_const_naninf(self): - helper_test_op([(45,65)], lambda x: x/float("inf"), lambda x: x/float("inf")) - helper_test_op([(45,65)], lambda x: x/-float("inf"), lambda x: x/-float("inf")) - helper_test_op([(45,65)], lambda x: x/float("nan"), lambda x: x/float("nan")) - helper_test_op([(45,65)], lambda x: float("inf")/x, lambda x: float("inf")/x) - helper_test_op([(45,65)], lambda x: (-float("inf"))/x, lambda x: (-float("inf"))/x) - helper_test_op([(45,65)], lambda x: float("nan")/x, lambda x: float("nan")/x) + helper_test_op([(45,65)], lambda x: x/math.inf) + helper_test_op([(45,65)], lambda x: x/-math.inf) + helper_test_op([(45,65)], lambda x: x/math.nan) + helper_test_op([(45,65)], lambda x: math.inf/x) + helper_test_op([(45,65)], lambda x: (-math.inf)/x) + helper_test_op([(45,65)], lambda x: math.nan/x) + def test_pow_full(self): - helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, a=0) + helper_test_op([(45,65), (45,65)], lambda x,y: x**y, a=0) def test_pow(self): # TODO: why is a=0 for these tests? - helper_test_op([(45,65)], lambda x: x**0, lambda x: Tensor.pow(x,0), a=0) - helper_test_op([(45,65)], lambda x: x**1, lambda x: Tensor.pow(x,1), a=0) - helper_test_op([(45,65)], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0) - helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=0) - helper_test_op([(45,65)], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0) - helper_test_op([()], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0) - helper_test_op([()], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0) + helper_test_op([(45,65)], lambda x: x**0, a=0) + helper_test_op([(45,65)], lambda x: x**1, a=0) + helper_test_op([(45,65)], lambda x: x**2, a=0) + helper_test_op([(45,65)], lambda x: x**3, a=0) + helper_test_op([(45,65)], lambda x: x**-2, a=0) + helper_test_op([()], lambda x: x**2, a=0) + helper_test_op([()], lambda x: x**-2, a=0) # Regression tests for https://github.com/tinygrad/tinygrad/issues/1151 - helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=-10) - helper_test_op([()], lambda x: x**3, lambda x: Tensor.pow(x,3), a=-10) + helper_test_op([(45,65)], lambda x: x**3, a=-10) + helper_test_op([()], lambda x: x**3, a=-10) # Regression tests for https://github.com/tinygrad/tinygrad/issues/1251 - helper_test_op([(45,65)], lambda x: x**0.2, lambda x: Tensor.pow(x,0.2), a=-10) - helper_test_op([(45,65)], lambda x: x**1.2, lambda x: Tensor.pow(x,1.2), a=-10) - helper_test_op([()], lambda x: x**0.2, lambda x: Tensor.pow(x,0.2), a=-10) - helper_test_op([()], lambda x: x**1.2, lambda x: Tensor.pow(x,1.2), a=-10) + helper_test_op([(45,65)], lambda x: x**0.2, a=-10) + helper_test_op([(45,65)], lambda x: x**1.2, a=-10) + helper_test_op([()], lambda x: x**0.2, a=-10) + helper_test_op([()], lambda x: x**1.2, a=-10) a, b = Tensor([0.0], requires_grad=True), torch.tensor([0.0], requires_grad=True) - helper_test_op([], lambda: b**1.1, lambda: a**1.1, ) + helper_test_op([], lambda: b**1.1, lambda: a**1.1) def test_pow_const(self): helper_test_op([(45,65)], lambda x: x**1.0, lambda x: x**1.0) helper_test_op([(45,65)], lambda x: x**-1.0, lambda x: x**-1.0) @@ -352,12 +364,14 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65)], lambda x: 2.0**x, lambda x: 2.0**x) helper_test_op([()], lambda x: x**2.0, lambda x: x**2.0) helper_test_op([()], lambda x: 2.0**x, lambda x: 2.0**x) + def test_sqrt(self): - helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, a=0) - helper_test_op([()], lambda x: x.sqrt(), Tensor.sqrt, a=0) + helper_test_op([(45,65)], lambda x: x.sqrt(), a=0) + helper_test_op([()], lambda x: x.sqrt(), a=0) def test_rsqrt(self): - helper_test_op([(45,65)], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0) - helper_test_op([()], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0) + helper_test_op([(45,65)], lambda x: x.rsqrt(), a=0) + helper_test_op([()], lambda x: x.rsqrt(), a=0) + def test_xor(self): tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int) ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32) @@ -366,17 +380,20 @@ class TestOps(unittest.TestCase): helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True) def test_sin(self): - helper_test_op([(45,65)], lambda x: x.sin(), Tensor.sin, a=0) + helper_test_op([(45,65)], lambda x: x.sin(), a=0) + helper_test_op([()], lambda x: x.sin(), a=0) def test_cos(self): - helper_test_op([(45,65)], lambda x: x.cos(), Tensor.cos, a=0) + helper_test_op([(45,65)], lambda x: x.cos(), a=0) + helper_test_op([()], lambda x: x.cos(), a=0) def test_tan(self): - helper_test_op([(45,65)], lambda x: x.tan(), Tensor.tan, a=0) + helper_test_op([(45,65)], lambda x: x.tan(), a=0) + helper_test_op([()], lambda x: x.tan(), a=0) def test_relu(self): - helper_test_op([(64,64)], lambda x: x.relu(), Tensor.relu) - helper_test_op([()], lambda x: x.relu(), Tensor.relu) + helper_test_op([(64,64)], lambda x: x.relu()) + helper_test_op([()], lambda x: x.relu()) def test_relu_exact(self): - helper_test_op(None, lambda x: x.relu(), Tensor.relu, vals=[[-1.,0,1]]) + helper_test_op(None, lambda x: x.relu(), vals=[[-1.,0,1]]) def test_relu_maximum_exact(self): helper_test_op(None, lambda x: torch.maximum(x, torch.zeros_like(x, requires_grad=False)), lambda x: Tensor.maximum(x, 0), vals=[[-1.,0,1]]) def test_leakyrelu(self): @@ -386,38 +403,40 @@ class TestOps(unittest.TestCase): for val in range(1, 5): helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val)) helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val)) + def test_abs(self): - helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs) - helper_test_op([()], lambda x: torch.abs(x), Tensor.abs) + helper_test_op([(45,65)], torch.abs, Tensor.abs) + helper_test_op([()], torch.abs, Tensor.abs) def test_log(self): - helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log) - helper_test_op([()], lambda x: torch.log(x), Tensor.log) + helper_test_op([(45,65)], torch.log, Tensor.log) + helper_test_op([()], torch.log, Tensor.log) def test_log2(self): - helper_test_op([(45,65)], lambda x: torch.log2(x), Tensor.log2) - helper_test_op([()], lambda x: torch.log2(x), Tensor.log2) + helper_test_op([(45,65)], torch.log2, Tensor.log2) + helper_test_op([()], torch.log2, Tensor.log2) def test_exp(self): - helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp) - helper_test_op([()], lambda x: torch.exp(x), Tensor.exp) + helper_test_op([(45,65)], torch.exp, Tensor.exp) + helper_test_op([()], torch.exp, Tensor.exp) def test_exp2(self): - helper_test_op([(45,65)], lambda x: torch.exp2(x), Tensor.exp2) - helper_test_op([()], lambda x: torch.exp2(x), Tensor.exp2) + helper_test_op([(45,65)], torch.exp2, Tensor.exp2) + helper_test_op([()], torch.exp2, Tensor.exp2) def test_sign(self): - helper_test_op([(45,65)], lambda x: torch.sign(x), Tensor.sign) - helper_test_op([()], lambda x: torch.sign(x), Tensor.sign) + helper_test_op([(45,65)], torch.sign, Tensor.sign) + helper_test_op([()], torch.sign, Tensor.sign) def test_softsign(self): - helper_test_op([(45,65)], lambda x: torch.nn.functional.softsign(x), Tensor.softsign) - helper_test_op([()], lambda x: torch.nn.functional.softsign(x), Tensor.softsign) + helper_test_op([(45,65)], torch.nn.functional.softsign, Tensor.softsign) + helper_test_op([()], torch.nn.functional.softsign, Tensor.softsign) def test_sigmoid(self): - helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid) - helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, a=100) - helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, a=-100) - helper_test_op([()], lambda x: x.sigmoid(), Tensor.sigmoid, forward_only=True) + helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid) + helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, a=100) + helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, a=-100) + helper_test_op([()], torch.sigmoid, Tensor.sigmoid) def test_softplus(self): - helper_test_op([(45,65)], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6) - helper_test_op([()], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6) + helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, atol=1e-6, grad_atol=1e-6) + helper_test_op([()], torch.nn.functional.softplus, Tensor.softplus, atol=1e-6, grad_atol=1e-6) def test_gelu(self): helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu) - #helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=100) + if not (CI and Device.DEFAULT == "METAL"): + helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=100) helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=-100) def test_quick_gelu(self): helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu) @@ -425,38 +444,18 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, a=-100) helper_test_op([()], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu) def test_elu(self): - helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x), Tensor.elu) + helper_test_op([(45,65)], torch.nn.functional.elu, Tensor.elu) helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x, alpha=0.1), lambda x: Tensor.elu(x, alpha=0.1)) - helper_test_op([()], lambda x: torch.nn.functional.elu(x), Tensor.elu) + helper_test_op([()], torch.nn.functional.elu, Tensor.elu) def test_relu6(self): - helper_test_op([(45,65)], lambda x: torch.nn.functional.relu6(x), Tensor.relu6) - helper_test_op([()], lambda x: torch.nn.functional.relu6(x), Tensor.relu6) + helper_test_op([(45,65)], torch.nn.functional.relu6, Tensor.relu6) + helper_test_op([()], torch.nn.functional.relu6, Tensor.relu6) def test_hardswish(self): - helper_test_op([(45,65)], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6) - helper_test_op([()], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6) + helper_test_op([(45,65)], torch.nn.functional.hardswish, Tensor.hardswish, atol=1e-6, grad_atol=1e-6) + helper_test_op([()], torch.nn.functional.hardswish, Tensor.hardswish, atol=1e-6, grad_atol=1e-6) def test_mish(self): - def _mish_pytorch(x): - return x*torch.tanh(torch.nn.functional.softplus(x)) - helper_test_op([(45,65)], _mish_pytorch, Tensor.mish, atol=1e-4) - helper_test_op([()], _mish_pytorch, Tensor.mish, atol=1e-4) - @unittest.skipIf(IMAGE>0, "no 1d dot for images") - def test_dot_1d(self): - helper_test_op([(65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - helper_test_op([(65), (65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - helper_test_op([(45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - helper_test_op([(8,45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - helper_test_op([(65), (8,65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - def test_dot(self): - helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - helper_test_op([(8,45,65), (8,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - with self.assertRaises(AssertionError): - a = Tensor(3.14) - a.matmul(a) + helper_test_op([(45,65)], torch.nn.functional.mish, Tensor.mish, atol=1e-4) + helper_test_op([()], torch.nn.functional.mish, Tensor.mish, atol=1e-4) def test_multinomial(self): # NOTE: this is random, so it has a very large atol @@ -477,17 +476,17 @@ class TestOps(unittest.TestCase): def test_argmax(self): self.assertEqual(torch.Tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max - helper_test_op([(10,20)], lambda x: x.argmax(), lambda x: x.argmax(), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmax(0, False), lambda x: x.argmax(0, False), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmax(1, False), lambda x: x.argmax(1, False), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmax(1, True), lambda x: x.argmax(1, True), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmax(), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmax(0, False), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmax(1, False), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmax(1, True), forward_only=True) def test_argmin(self): self.assertEqual(torch.Tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy()) - helper_test_op([(10,20)], lambda x: x.argmin(), lambda x: x.argmin(), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmin(0, False), lambda x: x.argmin(0, False), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmin(1, False), lambda x: x.argmin(1, False), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmin(1, True), lambda x: x.argmin(1, True), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmin(), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmin(0, False), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmin(1, False), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmin(1, True), forward_only=True) def test_einsum(self): # matrix transpose @@ -545,6 +544,25 @@ class TestOps(unittest.TestCase): with self.assertRaises(AssertionError): Tensor.einsum('ij,jk->ij', a) + @unittest.skipIf(IMAGE>0, "no 1d dot for images") + def test_dot_1d(self): + helper_test_op([(65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) + helper_test_op([(65), (65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) + helper_test_op([(45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) + helper_test_op([(8,45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) + helper_test_op([(65), (8,65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) + self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) + self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) + self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) + def test_dot(self): + helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) + helper_test_op([(8,45,65), (8,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) + self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) + self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) + with self.assertRaises(AssertionError): + a = Tensor(3.14) + a.matmul(a) + def test_matmul_simple(self): helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) def test_matmul(self): @@ -583,126 +601,107 @@ class TestOps(unittest.TestCase): def test_multidot(self): helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) + def test_sum_simple(self): - helper_test_op(None, lambda x: x.sum(), Tensor.sum, vals=[[1.,1.]]) + helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]]) def test_sum_full(self): - helper_test_op([(16384)], lambda x: x.sum(), lambda x: x.sum()) - def test_sum_small_full(self): - helper_test_op([(45,5)], lambda x: x.sum(), Tensor.sum) + helper_test_op([(16384)], lambda x: x.sum()) def test_sum_relu(self): - helper_test_op([(3,4,5)], lambda x: x.relu().sum().relu(), lambda x: x.relu().sum().relu()) + helper_test_op([(3,4,5)], lambda x: x.relu().sum().relu()) def test_sum(self): - helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=3), lambda x: Tensor.sum(x, axis=3)) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,3)), lambda x: Tensor.sum(x, axis=(1,3))) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2))) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2))) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1)) + helper_test_op([(45,3)], lambda x: x.sum()) + helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=3)) + helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,3))) + helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2))) + helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2))) + helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1)) helper_test_op([()], lambda x: x.sum(), Tensor.sum) def test_sum_with_zeros_shape(self): - helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,)), lambda x: Tensor.sum(x, axis=(0,))) - helper_test_op([(4, 0)], lambda x: x.sum(axis=(1,)), lambda x: Tensor.sum(x, axis=(1,))) - helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,1)), lambda x: Tensor.sum(x, axis=(0,1))) + helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,))) + helper_test_op([(4, 0)], lambda x: x.sum(axis=(1,))) + helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,1))) def test_min(self): - helper_test_op([(3,3)], lambda x: x.min(), Tensor.min) - helper_test_op([(45,3)], lambda x: x.min(), Tensor.min) - helper_test_op([(45,3)], lambda x: x.min().mul(0.5), lambda x: Tensor.min(x).mul(0.5)) - helper_test_op([()], lambda x: x.min(), Tensor.min) + helper_test_op([(3,3)], lambda x: x.min()) + helper_test_op([(45,3)], lambda x: x.min()) + helper_test_op([(45,3)], lambda x: x.min().mul(0.5)) + helper_test_op([()], lambda x: x.min()) def test_max(self): - helper_test_op([(45,3)], lambda x: x.max(), Tensor.max) - helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5)) - helper_test_op(None, lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), vals=[[[1.0,1.0,0.0,1.0]],]) - helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1)) - helper_test_op([()], lambda x: x.max(), Tensor.max) + helper_test_op([(45,3)], lambda x: x.max()) + helper_test_op([(45,3)], lambda x: x.max().mul(0.5)) + helper_test_op(None, lambda x: x.max().mul(0.5), vals=[[[1.0,1.0,0.0,1.0]],]) + helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1)) + helper_test_op([()], lambda x: x.max()) def test_mean(self): helper_test_op([(3,4,5,6)], lambda x: x.mean()) helper_test_op([()], lambda x: x.mean()) def test_mean_axis(self): - helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2))) + helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2))) def test_var(self): - helper_test_op([(15, 25, 35)], lambda x: torch.var(x), lambda x: Tensor.var(x)) - helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=None, correction=0), lambda x: Tensor.var(x, correction=0)) - helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=None, correction=5), lambda x: Tensor.var(x, correction=5)) + helper_test_op([(15, 25, 35)], lambda x: x.var()) + helper_test_op([(15, 25, 35)], lambda x: x.var(correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.var(correction=5)) def test_var_axis(self): - helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=0), lambda x: Tensor.var(x, axis=0)) - helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=2), lambda x: Tensor.var(x, axis=2)) - helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=[1, 2]), lambda x: Tensor.var(x, axis=[1, 2])) - helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=None), lambda x: Tensor.var(x, axis=None)) - helper_test_op([(15, 25, 35)], lambda x: torch.var(x, correction=0, dim=0), lambda x: Tensor.var(x, axis=0, correction=0)) - helper_test_op([(15, 25, 35)], lambda x: torch.var(x, correction=0, dim=2), lambda x: Tensor.var(x, axis=2, correction=0)) - helper_test_op([(15, 25, 35)], lambda x: torch.var(x, correction=0, dim=[1, 2]), lambda x: Tensor.var(x, axis=[1, 2], correction=0)) - helper_test_op([(15, 25, 35)], lambda x: torch.var(x, correction=0, dim=None), lambda x: Tensor.var(x, axis=None, correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.var(0)) + helper_test_op([(15, 25, 35)], lambda x: x.var(2)) + helper_test_op([(15, 25, 35)], lambda x: x.var([1, 2])) + helper_test_op([(15, 25, 35)], lambda x: x.var(0, correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.var(2, correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.var([1, 2], correction=0)) def test_var_keepdim(self): - helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=None, keepdim=True), lambda x: Tensor.var(x, keepdim=True)) - helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=0, keepdim=True, correction=0), - lambda x: Tensor.var(x, keepdim=True, correction=0, axis=0)) + helper_test_op([(15, 25, 35)], lambda x: x.var(keepdim=True)) + helper_test_op([(15, 25, 35)], lambda x: x.var(0, keepdim=True, correction=0)) def test_std(self): - helper_test_op([(15, 25, 35)], lambda x: torch.std(x), lambda x: Tensor.std(x)) - helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=None, correction=0), lambda x: Tensor.std(x, correction=0)) - helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=None, correction=5), lambda x: Tensor.std(x, correction=5)) + helper_test_op([(15, 25, 35)], lambda x: x.std()) + helper_test_op([(15, 25, 35)], lambda x: x.std(correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.std(correction=5)) def test_std_axis(self): - helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=0), lambda x: Tensor.std(x, axis=0)) - helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=2), lambda x: Tensor.std(x, axis=2)) - helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=[1, 2]), lambda x: Tensor.std(x, axis=[1, 2])) - helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=None), lambda x: Tensor.std(x, axis=None)) - helper_test_op([(15, 25, 35)], lambda x: torch.std(x, correction=0, dim=0), lambda x: Tensor.std(x, axis=0, correction=0)) - helper_test_op([(15, 25, 35)], lambda x: torch.std(x, correction=0, dim=2), lambda x: Tensor.std(x, axis=2, correction=0)) - helper_test_op([(15, 25, 35)], lambda x: torch.std(x, correction=0, dim=[1, 2]), lambda x: Tensor.std(x, axis=[1, 2], correction=0)) - helper_test_op([(15, 25, 35)], lambda x: torch.std(x, correction=0, dim=None), lambda x: Tensor.std(x, axis=None, correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.std(0)) + helper_test_op([(15, 25, 35)], lambda x: x.std(2)) + helper_test_op([(15, 25, 35)], lambda x: x.std([1, 2])) + helper_test_op([(15, 25, 35)], lambda x: x.std(0, correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.std(2, correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.std([1, 2], correction=0)) def test_std_keepdim(self): - helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=None, keepdim=True), lambda x: Tensor.std(x, keepdim=True)) - helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=0, keepdim=True, correction=0), - lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0)) + helper_test_op([(15, 25, 35)], lambda x: x.std(keepdim=True)) + helper_test_op([(15, 25, 35)], lambda x: x.std(0, keepdim=True, correction=0)) def test_softmax(self): # exceed per kernel buffer limit with backward forward_only = (Device.DEFAULT == "WEBGPU") - helper_test_op([(45,65)], lambda x: torch.nn.Softmax(dim=1)(x), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only) - helper_test_op([()], lambda x: torch.nn.Softmax(dim=0)(x), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only) + helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only) + helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only) def test_log_softmax(self): - helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) - helper_test_op([()], lambda x: torch.nn.LogSoftmax(dim=0)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) + helper_test_op([(45,65)], torch.nn.LogSoftmax(dim=1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) + helper_test_op([()], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) def test_log_softmax_other_axis(self): - helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7) - helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7) - helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7) + helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7) + helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7) + helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7) def test_tanh(self): - helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6) - helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, a=-100) - helper_test_op([()], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6) + helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6) + helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6, a=-100) + helper_test_op([()], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6) def test_hardtanh(self): for val in range(10, 30, 5): - helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6) - helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6) + helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x, -val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6) + helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x, -val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6) def test_topo_sort(self): - helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6) - helper_test_op([()], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6) + helper_test_op([(45,65)], lambda x: (x+x)*x, atol=1e-6, grad_atol=1e-6) + helper_test_op([()], lambda x: (x+x)*x, atol=1e-6, grad_atol=1e-6) - def test_scalar_mul(self): - helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2) - helper_test_op([()], lambda x: x*2, lambda x: x*2) - def test_scalar_rmul(self): - helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x) - helper_test_op([()], lambda x: 2*x, lambda x: 2*x) - def test_scalar_sub(self): - helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2) - helper_test_op([()], lambda x: x-2, lambda x: x-2) - def test_scalar_rsub(self): - helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x) - helper_test_op([()], lambda x: 2-x, lambda x: 2-x) def test_flip_eye_crash(self): helper_test_op([], lambda: (torch.eye(10)@torch.eye(10).flip(0)), lambda: (Tensor.eye(10)@Tensor.eye(10).flip(0)), forward_only=True) def test_broadcast_full(self): for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul), - (torch.div, Tensor.div)]: #, (torch.pow, Tensor.pow)]: + (torch.div, Tensor.div), (torch.pow, Tensor.pow)]: for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]: with self.subTest(op=torch_op.__name__, shapes=shapes): helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0) def test_broadcast_simple(self): - helper_test_op([(45,65), (45,1)], lambda x,y: x/y, lambda x,y: x/y) - helper_test_op([(45,65), ()], lambda x,y: x/y, lambda x,y: x/y) + helper_test_op([(45,65), (45,1)], lambda x,y: x/y) + helper_test_op([(45,65), ()], lambda x,y: x/y) def test_broadcast_partial(self): for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul), @@ -714,49 +713,49 @@ class TestOps(unittest.TestCase): helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0) def test_slice_in_bounds_1dim(self): - helper_test_op([(3)], lambda x: x[1:3], lambda x: x[1:3]) - helper_test_op([(3)], lambda x: x[0:2], lambda x: x[0:2]) - helper_test_op([(3)], lambda x: x[-2:2], lambda x: x[-2:2]) + helper_test_op([(3)], lambda x: x[1:3]) + helper_test_op([(3)], lambda x: x[0:2]) + helper_test_op([(3)], lambda x: x[-2:2]) def test_slice_on_0dim_tensor(self): - helper_test_op([()], lambda x: x[None], lambda x: x[None]) + helper_test_op([()], lambda x: x[None]) with self.assertRaises(IndexError): a = Tensor(3.14) a[0] def test_slice_int_indexing(self): - helper_test_op([(3)], lambda x: x[0], lambda x: x[0]) - helper_test_op([(3)], lambda x: x[2], lambda x: x[2]) - helper_test_op([(3)], lambda x: x[-1], lambda x: x[-1]) - helper_test_op([(3)], lambda x: x[-3], lambda x: x[-3]) - helper_test_op([(10,10)], lambda x: x[1], lambda x: x[1]) - helper_test_op([(3,3,3)], lambda x: x[1,1,1], lambda x: x[1,1,1]) + helper_test_op([(3)], lambda x: x[0]) + helper_test_op([(3)], lambda x: x[2]) + helper_test_op([(3)], lambda x: x[-1]) + helper_test_op([(3)], lambda x: x[-3]) + helper_test_op([(10,10)], lambda x: x[1]) + helper_test_op([(3,3,3)], lambda x: x[1,1,1]) def test_slice_in_bounds_multidim(self): - helper_test_op([(3,3,3)], lambda x: x[1:2], lambda x: x[1:2]) - helper_test_op([(3,3,3)], lambda x: x[1:2, 2], lambda x: x[1:2, 2]) - helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2]) - helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1]) + helper_test_op([(3,3,3)], lambda x: x[1:2]) + helper_test_op([(3,3,3)], lambda x: x[1:2, 2]) + helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2]) + helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, 0:-1]) def test_slice_with_none(self): - helper_test_op([(3,3,3)], lambda x: x[None], lambda x: x[None]) - helper_test_op([(3,3,3)], lambda x: x[1:2, None], lambda x: x[1:2, None]) - helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2], lambda x: x[1:2, None, 1:2]) - helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1], lambda x: x[1:2, 1:2, None, -1]) - helper_test_op([(3,3,3)], lambda x: x[None, None, 1, None, 2, 0:2], lambda x: x[None, None, 1, None, 2, 0:2]) + helper_test_op([(3,3,3)], lambda x: x[None]) + helper_test_op([(3,3,3)], lambda x: x[1:2, None]) + helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2]) + helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1]) + helper_test_op([(3,3,3)], lambda x: x[None, None, 1, None, 2, 0:2]) def test_slice_one_endpoint_out_of_bounds(self): - helper_test_op([(3,3,3)], lambda x: x[0:4], lambda x: x[0:4]) - helper_test_op([(3,3,3)], lambda x: x[-6:4], lambda x: x[-6:4]) - helper_test_op([(3,3,3)], lambda x: x[1:50], lambda x: x[1:50]) - helper_test_op([(3,3,3)], lambda x: x[1:50, 1:2, -1], lambda x: x[1:50, 1:2, -1]) + helper_test_op([(3,3,3)], lambda x: x[0:4]) + helper_test_op([(3,3,3)], lambda x: x[-6:4]) + helper_test_op([(3,3,3)], lambda x: x[1:50]) + helper_test_op([(3,3,3)], lambda x: x[1:50, 1:2, -1]) def test_slice_stride_gt_one(self): - helper_test_op([(7,5,10)], lambda x: x[::2, ::3, ::4], lambda x: x[::2, ::3, ::4]) - helper_test_op([(7,5,10)], lambda x: x[1:5:2, ::3, ::4], lambda x: x[1:5:2, ::3, ::4]) - helper_test_op([(7,5,10)], lambda x: x[1:5:2, 3, ::4], lambda x: x[1:5:2, 3, ::4]) - helper_test_op([(7,5,10)], lambda x: x[1:5:2, None, None, 3, None, ::4], lambda x: x[1:5:2, None, None, 3, None, ::4]) + helper_test_op([(7,5,10)], lambda x: x[::2, ::3, ::4]) + helper_test_op([(7,5,10)], lambda x: x[1:5:2, ::3, ::4]) + helper_test_op([(7,5,10)], lambda x: x[1:5:2, 3, ::4]) + helper_test_op([(7,5,10)], lambda x: x[1:5:2, None, None, 3, None, ::4]) def test_slice_negative_strides(self): # Torch doesn't support slicing with negative steps @@ -767,26 +766,24 @@ class TestOps(unittest.TestCase): np.testing.assert_allclose(a[:, 2:0:-1], t[:, 2:0:-1].numpy()) np.testing.assert_allclose(a[:, 2:0:-1, 3:1:-2], t[:, 2:0:-1, 3:1:-2].numpy()) np.testing.assert_allclose(a[4:0:-3, 2:0:-1, -1:-5:-2], t[4:0:-3, 2:0:-1, -1:-5:-2].numpy()) - if Device.DEFAULT not in ["CPU"]: - # broken - np.testing.assert_allclose(a[2:5:-1, :, :], t[2:5:-1, :, :].numpy()) # shape = (0, 10, 10) - np.testing.assert_allclose(a[:, 2:5:-1, :], t[:, 2:5:-1, :].numpy()) # shape = (0, 10, 10) - np.testing.assert_allclose(a[:, :, 2:5:-1], t[:, :, 2:5:-1].numpy()) # shape = (0, 10, 10) + np.testing.assert_allclose(a[2:5:-1, :, :], t[2:5:-1, :, :].numpy()) # shape = (0, 10, 10) + np.testing.assert_allclose(a[:, 2:5:-1, :], t[:, 2:5:-1, :].numpy()) # shape = (0, 10, 10) + np.testing.assert_allclose(a[:, :, 2:5:-1], t[:, :, 2:5:-1].numpy()) # shape = (0, 10, 10) def test_slice_both_endpoints_out_of_bounds(self): - helper_test_op([(3,3,3)], lambda x: x[5:10], lambda x: x[5:10], forward_only=True) - helper_test_op([(3,3,3)], lambda x: x[-15:-7], lambda x: x[-15:-7], forward_only=True) + helper_test_op([(3,3,3)], lambda x: x[5:10]) + helper_test_op([(3,3,3)], lambda x: x[-15:-7]) def test_slice_start_gt_end(self): - helper_test_op([(3,3,3)], lambda x: x[-2:2], lambda x: x[-2:2], forward_only=True) - helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True) + helper_test_op([(3,3,3)], lambda x: x[-2:2]) + helper_test_op([(3,3,3)], lambda x: x[-2:-5]) def test_slice_empty(self): - helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True) + helper_test_op([(10,10)], lambda x: x[1:1]) def test_slice_zero_in_shape(self): - helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True) # x.shape = (0, 10) - helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True) # x.shape = (0, 3, 3) + helper_test_op([(10,10)], lambda x: x[1:1]) # x.shape = (0, 10) + helper_test_op([(3,3,3)], lambda x: x[-2:-5]) # x.shape = (0, 3, 3) def test_slice_errors(self): a = Tensor.ones(4, 3) @@ -799,11 +796,11 @@ class TestOps(unittest.TestCase): with self.assertRaises(IndexError): b[:] # slice cannot be applied to a 0-dim tensor def test_slice_ellipsis(self): - helper_test_op([(3,3,3,3)], lambda x: x[..., 0], lambda x: x[..., 0]) - helper_test_op([(3,3,3,3)], lambda x: x[0, ...], lambda x: x[0, ...]) - helper_test_op([(3,3,3,3)], lambda x: x[0, ..., 0], lambda x: x[0, ..., 0]) - helper_test_op([(3,3,3,3)], lambda x: x[0:3, ..., 2:3], lambda x: x[0:3, ..., 2:3]) - helper_test_op([(3,3,3,3)], lambda x: x[None, 0:3, ..., 0, None], lambda x: x[None, 0:3, ..., 0, None]) + helper_test_op([(3,3,3,3)], lambda x: x[..., 0]) + helper_test_op([(3,3,3,3)], lambda x: x[0, ...]) + helper_test_op([(3,3,3,3)], lambda x: x[0, ..., 0]) + helper_test_op([(3,3,3,3)], lambda x: x[0:3, ..., 2:3]) + helper_test_op([(3,3,3,3)], lambda x: x[None, 0:3, ..., 0, None]) def test_pad2d(self): helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4))) @@ -850,80 +847,82 @@ class TestOps(unittest.TestCase): helper_test_op([(4,4)], lambda x: torch.stack([x for i in range(4)])[3], lambda x: Tensor.stack([x for i in range(4)])[3]) def test_transpose(self): - helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(1,2)) - helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(0,2)) - helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.permute(order=(3,0,2,1))) - helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.permute(order=(3,2,1,0))) - helper_test_op([()], lambda x: x.permute(()), lambda x: x.permute(())) + helper_test_op([(3,3)], lambda x: x.T) + helper_test_op([(3,3,3)], lambda x: x.transpose(1,2)) + helper_test_op([(3,3,3)], lambda x: x.transpose(0,2)) + helper_test_op([(1,2,3,4)], lambda x: x.permute((3,0,2,1))) + helper_test_op([(3,4,5,6)], lambda x: x.permute((3,2,1,0))) + helper_test_op([()], lambda x: x.permute(())) def test_reshape(self): - helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6))) - helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6))) - helper_test_op([()], lambda x: torch.reshape(x, []), lambda x: x.reshape([])) - helper_test_op([(1,)], lambda x: torch.reshape(x, []), lambda x: x.reshape([])) - helper_test_op([()], lambda x: torch.reshape(x, [1]), lambda x: x.reshape([1])) + helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,3,6,6))) + helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,1,6,6))) + helper_test_op([()], lambda x: x.reshape([])) + helper_test_op([(1,)], lambda x: x.reshape([])) + helper_test_op([()], lambda x: x.reshape([1])) + helper_test_op([()], lambda x: x.reshape([1, 1, 1])) with self.assertRaises(ValueError): x = Tensor.ones((4,3,6,6)) x.reshape([]) def test_flip(self): - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,)), lambda x: x.flip(axis=(0,))) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1)), lambda x: x.flip(axis=(0,1))) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)), lambda x: x.flip(axis=(0,1,3))) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,))) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)).flip((0,)), lambda x: x.flip(axis=(0,1,3)).flip(0)) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(-1,))) - helper_test_op([()], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=())) - helper_test_op([(1,)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=())) - helper_test_op([(4, 3, 6, 6)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=())) + helper_test_op([(4,3,6,6)], lambda x: x.flip((0,))) + helper_test_op([(4,3,6,6)], lambda x: x.flip((0,1))) + helper_test_op([(4,3,6,6)], lambda x: x.flip((0,1,3))) + helper_test_op([(4,3,6,6)], lambda x: x.flip((3,))) + helper_test_op([(4,3,6,6)], lambda x: x.flip((0,1,3)).flip(0)) + helper_test_op([(4,3,6,6)], lambda x: x.flip((-1,))) + helper_test_op([()], lambda x: x.flip(())) + helper_test_op([(1,)], lambda x: x.flip(())) + helper_test_op([(4, 3, 6, 6)], lambda x: x.flip(())) def test_squeeze(self): - helper_test_op([(1,3,6,6)], lambda x: torch.squeeze(x, 0), lambda x: x.squeeze(dim=0)) - helper_test_op([(4,3,1,6)], lambda x: torch.squeeze(x, 1), lambda x: x.squeeze(dim=1)) - helper_test_op([(4,3,6,6)], lambda x: torch.squeeze(x, 3), lambda x: x.squeeze(dim=3)) + helper_test_op([(1,3,6,6)], lambda x: x.squeeze(0)) + helper_test_op([(4,3,1,6)], lambda x: x.squeeze(1)) + helper_test_op([(4,3,6,6)], lambda x: x.squeeze(3)) self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, 50), lambda x: x.squeeze(dim=50), expected=IndexError) self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, -50), lambda x: x.squeeze(dim=-50), expected=IndexError) - helper_test_op([(4,3,6,1)], lambda x: torch.squeeze(x, -1), lambda x: x.squeeze(dim=-1)) - helper_test_op([(4,3,6,6)], lambda x: torch.squeeze(x), lambda x: x.squeeze()) - helper_test_op([(1,3,6,6)], lambda x: torch.squeeze(x), lambda x: x.squeeze()) - helper_test_op([(2,3,1)], lambda x: torch.squeeze(x), lambda x: x.squeeze()) - helper_test_op([()], lambda x: torch.squeeze(x, -1), lambda x: x.squeeze(dim=-1)) - helper_test_op([()], lambda x: torch.squeeze(x, 0), lambda x: x.squeeze(dim=0)) - helper_test_op([()], lambda x: torch.squeeze(x), lambda x: x.squeeze()) + helper_test_op([(4,3,6,1)], lambda x: x.squeeze(-1)) + helper_test_op([(4,3,6,6)], lambda x: x.squeeze()) + helper_test_op([(1,3,6,6)], lambda x: x.squeeze()) + helper_test_op([(2,3,1)], lambda x: x.squeeze()) + helper_test_op([()], lambda x: x.squeeze(-1)) + helper_test_op([()], lambda x: x.squeeze(0)) + helper_test_op([()], lambda x: x.squeeze()) self.helper_test_exception([()], lambda x: torch.squeeze(x, 10), lambda x: x.squeeze(dim=10), expected=IndexError) self.helper_test_exception([()], lambda x: torch.squeeze(x, 1), lambda x: x.squeeze(dim=1), expected=IndexError) self.helper_test_exception([()], lambda x: torch.squeeze(x, -2), lambda x: x.squeeze(dim=-2), expected=IndexError) def test_unsqueeze(self): - helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0)) - helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 4), lambda x: x.unsqueeze(dim=4)) - helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -1), lambda x: x.unsqueeze(dim=-1)) - helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -3), lambda x: x.unsqueeze(dim=-3)) - helper_test_op([()], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0)) + helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(0)) + helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(4)) + helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(-1)) + helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(-3)) + helper_test_op([()], lambda x: x.unsqueeze(0)) def test_flatten(self): for axis in range(3): - helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=axis), lambda x: x.flatten(start_dim=axis)) + helper_test_op([(4,3,6,6)], lambda x: x.flatten(start_dim=axis)) for axis in range(3): - helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, end_dim=axis), lambda x: x.flatten(end_dim=axis)) - helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=1, end_dim=3), lambda x: x.flatten(start_dim=1, end_dim=3)) - helper_test_op([()], lambda x: x.flatten(), lambda x: x.flatten()) - helper_test_op([(1,)], lambda x: x.flatten(), lambda x: x.flatten()) + helper_test_op([(4,3,6,6)], lambda x: x.flatten(end_dim=axis)) + helper_test_op([(4,3,6,6)], lambda x: x.flatten(start_dim=1, end_dim=3)) + helper_test_op([()], lambda x: x.flatten()) + helper_test_op([(1,)], lambda x: x.flatten()) def test_unflatten(self): - helper_test_op([(4,3,6,6)], lambda x: torch.unflatten(x, 0, (2, 2)), lambda x: x.unflatten(0, (2, 2))) - helper_test_op([(4,3,6,6)], lambda x: torch.unflatten(x, 3, (3, 2)), lambda x: x.unflatten(3, (3, 2))) - helper_test_op([(4,3,6,6)], lambda x: torch.unflatten(x, -1, (3, 2, 1)), lambda x: x.unflatten(-1, (3, 2, 1))) + helper_test_op([(4,3,6,6)], lambda x: x.unflatten(0, (2, 2))) + helper_test_op([(4,3,6,6)], lambda x: x.unflatten(3, (3, 2))) + helper_test_op([(4,3,6,6)], lambda x: x.unflatten(-1, (3, 2, 1))) def test_detach(self): - helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True) - helper_test_op([()], lambda x: x.detach(), lambda x: x.detach(), forward_only=True) + helper_test_op([(4,3,6,6)], lambda x: x.detach(), forward_only=True) + helper_test_op([()], lambda x: x.detach(), forward_only=True) def test_expand(self): - arg = (4,3,2,6) - helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg)) - helper_test_op([()], lambda x: x.expand([]), lambda x: x.expand(shape=[])) + helper_test_op([(4,3,1,6)], lambda x: x.expand((4,3,2,6))) + helper_test_op([(1,1,1,1)], lambda x: x.expand((4,3,2,6))) + helper_test_op([()], lambda x: x.expand([])) @unittest.skip("very slow") def test_sd_big_conv(self): @@ -1348,12 +1347,12 @@ class TestOps(unittest.TestCase): helper_test_op([(3, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats)) def test_clip(self): - helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2)) - helper_test_op([(45,65)], lambda x: x.clip(0, 0), lambda x: x.clip(0, 0)) - helper_test_op([(45,65)], lambda x: x.clip(10, 100), lambda x: x.clip(10, 100)) - helper_test_op([(45,65)], lambda x: x.clip(0, 0.1), lambda x: x.clip(0, 0.1)) - helper_test_op([(45,65)], lambda x: x.clip(-0.3, -0.2), lambda x: x.clip(-0.3, -0.2)) - helper_test_op([(45,65)], lambda x: x.clip(3, 0), lambda x: x.clip(3, 0)) + helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2)) + helper_test_op([(45,65)], lambda x: x.clip(0, 0)) + helper_test_op([(45,65)], lambda x: x.clip(10, 100)) + helper_test_op([(45,65)], lambda x: x.clip(0, 0.1)) + helper_test_op([(45,65)], lambda x: x.clip(-0.3, -0.2)) + helper_test_op([(45,65)], lambda x: x.clip(3, 0)) def test_matvecmat(self): helper_test_op([(1,128), (128,128), (128,128)], lambda x,y,z: (x@y).relu()@z, atol=1e-4) @@ -1484,8 +1483,7 @@ class TestOps(unittest.TestCase): lambda x: x.gather(idx=a, dim=0), expected=(RuntimeError, AssertionError)) def test_scaled_product_attention(self): - helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), - lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z)) + helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], torch.nn.functional.scaled_dot_product_attention, Tensor.scaled_dot_product_attention) helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))