diff --git a/test/models/test_end2end.py b/test/models/test_end2end.py
index 49c37147cd..da326118e8 100644
--- a/test/models/test_end2end.py
+++ b/test/models/test_end2end.py
@@ -22,14 +22,14 @@ def compare_tiny_torch(model, model_torch, X, Y):
 
   out = model(X)
   loss = (out * Y).mean()
-  print(loss.realize().numpy()[0])
+  print(loss.realize().numpy())
 
   out_torch = model_torch(torch.Tensor(X.numpy()))
   loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean()
   print(loss_torch.detach().numpy())
 
   # assert losses match
-  np.testing.assert_allclose(loss.realize().numpy()[0], loss_torch.detach().numpy(), atol=1e-4)
+  np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4)
 
   # zero and backward
   optimizer.zero_grad()
diff --git a/test/test_ops.py b/test/test_ops.py
index 38abd02a87..497e5aee82 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -15,7 +15,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
   if shps is None:
     ts = [torch.tensor(x, requires_grad=True) for x in vals]
   else:
-    ts = [torch.tensor((np.random.random(size=x).astype(np.float32)+a)*b, requires_grad=True) for x in shps]
+    ts = [torch.tensor((np.random.random(size=x)+a)*b, requires_grad=True, dtype=torch.float32) for x in shps]
 
   tst = [Tensor(x.detach().numpy(), requires_grad=not FORWARD_ONLY) for x in ts]
 
@@ -29,7 +29,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
 
   def compare(s, x,y,atol,rtol):
     if PRINT_TENSORS: print(s, x, y)
-    if y.shape != tuple(): assert x.shape == y.shape, f"shape mismatch (tinygrad){x.shape} != (torch){y.shape}"
+    assert x.shape == y.shape, f"shape mismatch: tinygrad={x.shape} | torch={y.shape}"
     try:
       np.testing.assert_allclose(x,y, atol=atol, rtol=rtol)
     except Exception:
@@ -62,6 +62,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([], lambda: torch.full((45,65), 4), lambda: Tensor.full((45,65), 4), forward_only=True)
   def test_zeros(self):
     helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True)
+    helper_test_op([], lambda: torch.zeros([45,65]), lambda: Tensor.zeros([45,65]), forward_only=True)
+    helper_test_op([], lambda: torch.zeros([]), lambda: Tensor.zeros([]), forward_only=True)
   def test_zeros_like(self):
     a = Tensor([[1,2,3],[4,5,6]])
     b = torch.tensor([[1,2,3],[4,5,6]])
@@ -70,12 +72,15 @@ class TestOps(unittest.TestCase):
     helper_test_op([], lambda: torch.empty(45,65)*0/0, lambda: Tensor.empty(45,65)*0/0, forward_only=True)
   def test_ones(self):
     helper_test_op([], lambda: torch.ones(45,65), lambda: Tensor.ones(45,65), forward_only=True)
+    helper_test_op([], lambda: torch.ones([45,65]), lambda: Tensor.ones([45,65]), forward_only=True)
+    helper_test_op([], lambda: torch.ones([]), lambda: Tensor.ones([]), forward_only=True)
   def test_ones_like(self):
     a = Tensor([[1,2,3],[4,5,6]])
     b = torch.tensor([[1,2,3],[4,5,6]])
     helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True)
   def test_eye(self):
     helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True)
+
   def test_arange(self):
     helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True)
   def test_where(self):
@@ -121,43 +126,58 @@ class TestOps(unittest.TestCase):
 
   def test_maximum(self):
     helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)
+    helper_test_op([(), ()], torch.maximum, Tensor.maximum)
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., 4.], [1., 2., 3., 0.]])
   def test_minimum(self):
     helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
+    helper_test_op([(), ()], torch.minimum, Tensor.minimum)
   def test_add(self):
     helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add)
+    helper_test_op([(), ()], lambda x,y: x+y, Tensor.add)
   def test_add_simple(self):
     helper_test_op([(256), (256)], lambda x,y: x+y, Tensor.add, forward_only=True)
   def test_broadcasted_add(self):
     helper_test_op([(45,65), (45,1)], lambda x,y: x+y, lambda x,y: x+y)
+    helper_test_op([(45,65), ()], lambda x,y: x+y, lambda x,y: x+y)
   def test_broadcasted_add_2(self):
     helper_test_op([(45,65), (65,)], lambda x,y: x+y, lambda x,y: x+y)
   def test_sub(self):
     helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub)
+    helper_test_op([(), ()], lambda x,y: x-y, Tensor.sub)
   def test_neg(self):
     helper_test_op([(45,65)], lambda x: -x)
+    helper_test_op([()], lambda x: -x)
   def test_mul(self):
     helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul)
+    helper_test_op([(), ()], lambda x,y: x*y, Tensor.mul)
   def test_div(self):
     helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
+    helper_test_op([(), ()], lambda x,y: x/y, Tensor.div)
   def test_div_const(self):
     helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255)
     helper_test_op([(45,65)], lambda x: x/1, lambda x: x/1)
     helper_test_op([(45,65)], lambda x: 1/x, lambda x: 1/x)
     helper_test_op([(45,65)], lambda x: x/2, lambda x: x/2)
     helper_test_op([(45,65)], lambda x: 2/x, lambda x: 2/x)
+    helper_test_op([()], lambda x: x/2, lambda x: x/2)
+    helper_test_op([()], lambda x: 2/x, lambda x: 2/x)
   def test_pow(self):
     helper_test_op([(45,65)], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
     helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=0)
     helper_test_op([(45,65)], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
     helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, a=0)
+    helper_test_op([()], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
+    helper_test_op([()], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
   def test_pow_const(self):
     helper_test_op([(45,65)], lambda x: x**1.0, lambda x: x**1.0)
     helper_test_op([(45,65)], lambda x: 1.0**x, lambda x: 1.0**x)
     helper_test_op([(45,65)], lambda x: x**2.0, lambda x: x**2.0)
     helper_test_op([(45,65)], lambda x: 2.0**x, lambda x: 2.0**x)
+    helper_test_op([()], lambda x: x**2.0, lambda x: x**2.0)
+    helper_test_op([()], lambda x: 2.0**x, lambda x: 2.0**x)
   def test_sqrt(self):
     helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, a=0)
+    helper_test_op([()], lambda x: x.sqrt(), Tensor.sqrt, a=0)
   def test_sin(self):
     helper_test_op([(45,65)], lambda x: x.sin(), Tensor.sin, a=0)
@@ -168,47 +188,65 @@ class TestOps(unittest.TestCase):
 
   def test_relu(self):
     helper_test_op([(64,64)], lambda x: x.relu(), Tensor.relu)
+    helper_test_op([()], lambda x: x.relu(), Tensor.relu)
   def test_relu_exact(self):
     helper_test_op(None, lambda x: x.relu(), Tensor.relu, vals=[[-1.,0,1]])
   def test_relu_maximum_exact(self):
     helper_test_op(None, lambda x: torch.maximum(x, torch.zeros_like(x, requires_grad=False)), lambda x: Tensor.maximum(x, 0), vals=[[-1.,0,1]])
   def test_leakyrelu(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
+    helper_test_op([()], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
   def test_celu(self):
     for val in range(1, 5):
       helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
+      helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
   def test_abs(self):
     helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs)
+    helper_test_op([()], lambda x: torch.abs(x), Tensor.abs)
   def test_log(self):
     helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log)
+    helper_test_op([()], lambda x: torch.log(x), Tensor.log)
   def test_exp(self):
     helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp)
+    helper_test_op([()], lambda x: torch.exp(x), Tensor.exp)
   def test_sign(self):
     helper_test_op([(45,65)], lambda x: torch.sign(x), Tensor.sign)
+    helper_test_op([()], lambda x: torch.sign(x), Tensor.sign)
   def test_softsign(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.softsign(x), Tensor.softsign)
+    helper_test_op([()], lambda x: torch.nn.functional.softsign(x), Tensor.softsign)
   def test_sigmoid(self):
     helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid)
+    helper_test_op([()], lambda x: x.sigmoid(), Tensor.sigmoid, forward_only=True)
   def test_softplus(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
+    helper_test_op([()], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
   @unittest.skip("not supported in older pytorch")
   def test_gelu(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
   def test_quick_gelu(self):
     helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
+    helper_test_op([()], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
   def test_elu(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x), Tensor.elu)
     helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x, alpha=0.1), lambda x: Tensor.elu(x, alpha=0.1))
+    helper_test_op([()], lambda x: torch.nn.functional.elu(x), Tensor.elu)
   def test_relu6(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.relu6(x), Tensor.relu6)
+    helper_test_op([()], lambda x: torch.nn.functional.relu6(x), Tensor.relu6)
   def test_hardswish(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
+    helper_test_op([()], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
   def test_mish(self):
     def _mish_pytorch(x): return x*torch.tanh(torch.nn.functional.softplus(x))
     helper_test_op([(45,65)], _mish_pytorch, Tensor.mish, atol=1e-4)
+    helper_test_op([()], _mish_pytorch, Tensor.mish, atol=1e-4)
   def test_dot(self):
     helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
+    with self.assertRaises(RuntimeError):
+      a = Tensor(3.14)
+      a.matmul(a)
   def test_matmul_simple(self):
     helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
   def test_matmul(self):
@@ -225,6 +263,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3)
   def test_broadcastdot(self):
     helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
+    with self.assertRaises(RuntimeError):
+      a = Tensor(3.14)
+      b = Tensor.ones(3,3)
+      a @ b
   def test_multidot(self):
     helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
     helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
@@ -241,10 +283,12 @@ class TestOps(unittest.TestCase):
     helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2)))
     helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
     helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))
+    helper_test_op([()], lambda x: x.sum(), Tensor.sum)
   def test_min(self):
     helper_test_op([(3,3)], lambda x: x.min(), Tensor.min)
     helper_test_op([(45,3)], lambda x: x.min(), Tensor.min)
     helper_test_op([(45,3)], lambda x: x.min().mul(0.5), lambda x: Tensor.min(x).mul(0.5))
+    helper_test_op([()], lambda x: x.min(), Tensor.min)
   def test_max(self):
     helper_test_op([(45,3)], lambda x: x.max(), Tensor.max)
     helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5))
@@ -253,8 +297,10 @@ class TestOps(unittest.TestCase):
       [[1.0,1.0,0.0,1.0]],
       ])
     helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1))
+    helper_test_op([()], lambda x: x.max(), Tensor.max)
   def test_mean(self):
     helper_test_op([(3,4,5,6)], lambda x: x.mean())
+    helper_test_op([()], lambda x: x.mean())
   def test_mean_axis(self):
     helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)))
   def test_std(self):
@@ -275,28 +321,34 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45, 65, 85)], lambda x: torch.std(x, keepdim=True, correction=0, dim=0), lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0))
   def test_log_softmax(self):
     helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([()], lambda x: torch.nn.LogSoftmax(dim=0)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
   def test_log_softmax_other_axis(self):
     helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7)
   def test_tanh(self):
     helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
+    helper_test_op([()], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
   def test_hardtanh(self):
     for val in range(10, 30, 5):
       helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
+      helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
   def test_topo_sort(self):
     helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)
+    helper_test_op([()], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)
   def test_scalar_mul(self):
     helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2)
+    helper_test_op([()], lambda x: x*2, lambda x: x*2)
   def test_scalar_rmul(self):
     helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x)
-
+    helper_test_op([()], lambda x: 2*x, lambda x: 2*x)
   def test_scalar_sub(self):
     helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2)
+    helper_test_op([()], lambda x: x-2, lambda x: x-2)
   def test_scalar_rsub(self):
     helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x)
-
+    helper_test_op([()], lambda x: 2-x, lambda x: 2-x)
   def test_flip_eye_crash(self):
     helper_test_op([], lambda: (torch.eye(10)@torch.eye(10).flip(0)), lambda: (Tensor.eye(10)@Tensor.eye(10).flip(0)), forward_only=True)
@@ -310,6 +362,7 @@ class TestOps(unittest.TestCase):
 
   def test_broadcast_simple(self):
     helper_test_op([(45,65), (45,1)], lambda x,y: x/y, lambda x,y: x/y)
+    helper_test_op([(45,65), ()], lambda x,y: x/y, lambda x,y: x/y)
   def test_broadcast_partial(self):
     for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
@@ -320,19 +373,83 @@ class TestOps(unittest.TestCase):
         # NOTE: ANE backwards?
         helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0)
 
-  def test_slice_simple(self):
-    helper_test_op([(3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
+  def test_slice_in_bounds_1dim(self):
+    helper_test_op([(3)], lambda x: x[1:3], lambda x: x[1:3])
+    helper_test_op([(3)], lambda x: x[0:2], lambda x: x[0:2])
+    helper_test_op([(3)], lambda x: x[-2:2], lambda x: x[-2:2])
 
-  def test_slice(self):
-    helper_test_op([(3,3,3,3)], lambda x: x[1:2], lambda x: x[1:2])
-    helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
-    helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1])
+  def test_slice_on_0dim_tensor(self):
+    helper_test_op([()], lambda x: x[None], lambda x: x[None])
 
-  def test_slice_one(self):
+    with self.assertRaises(IndexError):
+      a = Tensor(3.14)
+      a[0]
+
+  def test_slice_int_indexing(self):
     helper_test_op([(3)], lambda x: x[1], lambda x: x[1])
-
-  def test_slice_one_multi(self):
+    helper_test_op([(3)], lambda x: x[-2], lambda x: x[-2])
     helper_test_op([(10,10)], lambda x: x[1], lambda x: x[1])
+    helper_test_op([(3,3,3)], lambda x: x[1,1,1], lambda x: x[1,1,1])
+
+  def test_slice_in_bounds_multidim(self):
+    helper_test_op([(3,3,3)], lambda x: x[1:2], lambda x: x[1:2])
+    helper_test_op([(3,3,3)], lambda x: x[1:2, 2], lambda x: x[1:2, 2])
+    helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
+    helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1])
+
+  def test_slice_with_none(self):
+    helper_test_op([(3,3,3)], lambda x: x[None], lambda x: x[None])
+    helper_test_op([(3,3,3)], lambda x: x[1:2, None], lambda x: x[1:2, None])
+    helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2], lambda x: x[1:2, None, 1:2])
+    helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1], lambda x: x[1:2, 1:2, None, -1])
+
+  def test_slice_one_endpoint_out_of_bounds(self):
+    helper_test_op([(3,3,3)], lambda x: x[0:4], lambda x: x[0:4])
+    helper_test_op([(3,3,3)], lambda x: x[-6:4], lambda x: x[-6:4])
+    helper_test_op([(3,3,3)], lambda x: x[1:50], lambda x: x[1:50])
+    helper_test_op([(3,3,3)], lambda x: x[1:50, 1:2, -1], lambda x: x[1:50, 1:2, -1])
+
+  def test_slice_stride_gt_one(self):
+    helper_test_op([(7,5,10)], lambda x: x[::2, ::3, ::4], lambda x: x[::2, ::3, ::4])
+    helper_test_op([(7,5,10)], lambda x: x[1:5:2, ::3, ::4], lambda x: x[1:5:2, ::3, ::4])
+    helper_test_op([(7,5,10)], lambda x: x[1:5:2, 3, ::4], lambda x: x[1:5:2, 3, ::4])
+    helper_test_op([(7,5,10)], lambda x: x[1:5:2, None, None, 3, None, ::4], lambda x: x[1:5:2, None, None, 3, None, ::4])
+
+  def test_slice_negative_strides(self):
+    # Torch doesn't support slicing with negative steps
+    a = np.random.randn(10, 10, 10).astype(np.float32)
+    t = Tensor(a)
+    np.testing.assert_allclose(a[::-1], t[::-1].numpy())
+    np.testing.assert_allclose(a[::-2], t[::-2].numpy())
+    np.testing.assert_allclose(a[:, 2:0:-1], t[:, 2:0:-1].numpy())
+    np.testing.assert_allclose(a[:, 2:0:-1, 3:1:-2], t[:, 2:0:-1, 3:1:-2].numpy())
+    np.testing.assert_allclose(a[4:0:-3, 2:0:-1, -1:-5:-2], t[4:0:-3, 2:0:-1, -1:-5:-2].numpy())
+
+  @unittest.skip("No support for tensors with 0s in shape")
+  def test_slice_both_endpoints_out_of_bounds(self):
+    helper_test_op([(3,3,3)], lambda x: x[5:10], lambda x: x[5:10], forward_only=True)
+    helper_test_op([(3,3,3)], lambda x: x[-15:-7], lambda x: x[-15:-7], forward_only=True)
+
+  @unittest.skip("No support for tensors with 0s in shape")
+  def test_slice_start_gt_end(self):
+    helper_test_op([(3,3,3)], lambda x: x[-2:2], lambda x: x[-2:2], forward_only=True)
+    helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True)
+
+  @unittest.skip("No support for tensors with 0s in shape")
+  def test_slice_empty(self):
+    helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True)
+
+  @unittest.skip("No support for tensors with 0s in shape")
+  def test_slice_zero_in_shape(self):
+    helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1]) # x.shape = (0, 10)
+    helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5]) # x.shape = (0, 3, 3)
+
+  def test_slice_errors(self):
+    a = Tensor.ones(4, 3)
+    with self.assertRaises(IndexError):
+      a[1, 77, 77, 77] # IndexError: too many indices (raised before the out-of-bounds check)
+      a[1, 77] # IndexError: out of bounds
+      a[0, -77]
 
   def test_pad2d(self):
     helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)))
@@ -342,10 +459,18 @@ class TestOps(unittest.TestCase):
     helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(0,2))
     helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.permute(order=(3,0,2,1)))
     helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.permute(order=(3,2,1,0)))
+    helper_test_op([()], lambda x: x.permute(()), lambda x: x.permute(()))
 
   def test_reshape(self):
     helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))
     helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)))
+    helper_test_op([()], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
+    helper_test_op([(1,)], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
+    helper_test_op([()], lambda x: torch.reshape(x, [1]), lambda x: x.reshape([1]))
+
+    with self.assertRaises(AssertionError):
+      x = Tensor.ones((4,3,6,6))
+      x.reshape([])
 
   def test_flip(self):
     helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,)), lambda x: x.flip(axis=(0,)))
@@ -354,23 +479,31 @@ class TestOps(unittest.TestCase):
     helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,)))
     helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)).flip((0,)), lambda x: x.flip(axis=(0,1,3)).flip(0))
     helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(-1,)))
+    helper_test_op([()], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
+    helper_test_op([(1,)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
+    helper_test_op([(4, 3, 6, 6)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
 
   def test_unsqueeze(self):
     helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0))
     helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 4), lambda x: x.unsqueeze(dim=4))
     helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -1), lambda x: x.unsqueeze(dim=-1))
     helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -3), lambda x: x.unsqueeze(dim=-3))
+    helper_test_op([()], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0))
 
   def test_flatten(self):
     for axis in range(3):
       helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=axis), lambda x: x.flatten(axis))
+    helper_test_op([()], lambda x: x.flatten(), lambda x: x.flatten())
+    helper_test_op([(1,)], lambda x: x.flatten(), lambda x: x.flatten())
 
   def test_detach(self):
     helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
+    helper_test_op([()], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
 
   def test_expand(self):
     arg = (4,3,2,6)
     helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg))
+    helper_test_op([()], lambda x: x.expand([]), lambda x: x.expand(shape=[]))
 
   @unittest.skip("very slow")
   def test_sd_big_conv(self):
@@ -695,6 +828,10 @@ class TestOps(unittest.TestCase):
     for dim in range(-1, 2):
       helper_test_op([(45,65), (45,65)], lambda x,y: torch.cat((x,y), dim), lambda x,y: x.cat(y, dim=dim))
 
+    with self.assertRaises(AssertionError):
+      a = Tensor(3.14)
+      a.cat(a)
+
   def test_multicat(self):
     for dim in range(-1, 2):
       helper_test_op([(45,65), (45,65), (45,65)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim))
@@ -707,6 +844,9 @@ class TestOps(unittest.TestCase):
     with self.assertRaises(IndexError):
       Tensor.stack([x], dim=77)
+
+    a = Tensor(3.14)
+    np.testing.assert_allclose(Tensor.stack([a, a]).numpy(), Tensor([3.14, 3.14]).numpy())
 
   def test_repeat(self):
     x = Tensor.randn(45, 65, 3)
@@ -715,6 +855,7 @@ class TestOps(unittest.TestCase):
     for reps in [[], [4], [2, 1], [3, 2, 2]]:
       repeats = base_repeats + reps
       helper_test_op([(45, 65, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
+      helper_test_op([()], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
 
     with self.assertRaises(AssertionError):
       x.repeat((2, 4))
@@ -722,7 +863,6 @@ class TestOps(unittest.TestCase):
     with self.assertRaises(AssertionError):
       x.repeat((2, 0, 4))
 
-
   def test_clip(self):
     helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2))
 
diff --git a/test/test_tensor.py b/test/test_tensor.py
index 23ed975858..7e9e997348 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -14,6 +14,13 @@ W_init = np.random.randn(3,3).astype(np.float32)
 m_init = np.random.randn(1,3).astype(np.float32)
 
 class TestTinygrad(unittest.TestCase):
+  def test_zerodim_initialization(self):
+    a = Tensor(55)
+    b = Tensor(3.14)
+
+    self.assertEqual(a.shape, ())
+    self.assertEqual(b.shape, ())
+
   def test_plus_equals(self):
     a = Tensor.randn(10,10)
     b = Tensor.randn(10,10)
@@ -23,20 +30,6 @@ class TestTinygrad(unittest.TestCase):
     val2 = a.numpy()
     np.testing.assert_allclose(val1, val2)
 
-  def test_slicing(self):
-    x = Tensor.randn(10,10)
-    slices = [0,1,9,-1,-10,None] + [slice(s,e) for s,e in itertools.combinations([0,1,-1,None], r=2)] + [slice(9,11), slice(-11,-9)]
-    fmt = lambda s: f'{s.start}:{s.stop}' if isinstance(s, slice) else str(s)
-    for s in list(itertools.product(slices, slices)) + [(None,0,None,0,None), (slice(0,2),None,None,slice(2,4),None,None)]:
-      np.testing.assert_equal(x.numpy()[s], x[s].numpy(), f'Test failed for slice x[{",".join(fmt(x) for x in s)}]')
-    for s in [-11,10]:
-      with self.assertRaises(IndexError):
-        x[s]
-    with self.assertRaises(AssertionError):
-      x[::2]
-    with self.assertRaises(AssertionError):
-      x[0,0,0]
-
   def test_backward_pass(self):
     def test_tinygrad():
       x = Tensor(x_init, requires_grad=True)
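# Illustrative sketch (not part of the patch): what the 0-dim support exercised by
# test_zerodim_initialization and the test_ops additions above is meant to allow,
# assuming a tinygrad build that includes this change. Scalars construct tensors of
# shape (), and backward() is called on a 0-dim loss rather than a (1,)-shaped one.
import numpy as np
from tinygrad.tensor import Tensor

a = Tensor(3.14)                # a bare Python float -> shape () instead of (1,)
assert a.shape == ()
x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (x * 2).mean()           # full reductions now yield a 0-dim tensor
assert loss.shape == ()
loss.backward()                 # accepted: backward() asserts shape == ()
np.testing.assert_allclose(loss.numpy(), 4.0)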
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 589949a504..2ff62736b1 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -414,6 +414,7 @@ class Linearizer:
   def simplify_ones(self):
     # remove places where the shape is all ones
     # TODO: this should be factored in to multi shape stride
+    if self.shape_len == 0: return
     all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)]
     # keep at least 1 one
     if all(all_ones): all_ones[-1] = False
diff --git a/tinygrad/nn/image.py b/tinygrad/nn/image.py
index 698e48bc2e..86255a60f5 100644
--- a/tinygrad/nn/image.py
+++ b/tinygrad/nn/image.py
@@ -7,6 +7,7 @@ base_image_type = (100, 2, "imageh", np.float16) if FLOAT16 else (100, 4, "image
 
 def image_dot(self, w):
   # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
+  if (n1:=len(self.shape))*(n2:=len(w.shape)) == 0: raise RuntimeError(f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D")
   bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
   cin, cout = w.shape[-2], w.shape[-1]
   out_shape_t = self.shape[0:-2] + (cout,-1)
diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index e40ed7a9b7..19a53a3158 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -74,13 +74,13 @@ class View:
 
 @functools.lru_cache(maxsize=None)
 def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
-  strides = [1]
+  strides = [1] if shape else []
   for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
   return tuple(st if s != 1 else 0 for st, s in zip(strides, shape))
 
 @functools.lru_cache(maxsize=None)
 def view_from_shape(shape:Tuple[int, ...]) -> View:
-  assert all(isinstance(x, int) for x in shape) and len(shape) != 0
+  assert all(isinstance(x, int) for x in shape)
   return View(tuple(shape), strides_for_shape(shape))
 
 @functools.lru_cache(maxsize=None)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 63b994e9d8..c3058875d1 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1,6 +1,6 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import math, functools, itertools
+import math, functools, itertools, operator
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
 from tinygrad.helpers import prod, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, LazyNumpyArray
@@ -33,9 +33,9 @@ class Tensor:
   no_grad: ClassVar[bool] = False
   default_type: ClassVar[DType] = dtypes.float32
 
-  def __init__(self, data:Union[list, LazyBuffer, LazyNumpyArray, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
+  def __init__(self, data:Union[int, float, list, LazyBuffer, LazyNumpyArray, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
     device = (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # canonicalize device
-    if isinstance(data, list):
+    if isinstance(data, (int, float, list)):
       data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
     elif isinstance(data, LazyBuffer) and data.device != device:
       # TODO: this has to realize, it shouldn't have to
@@ -47,7 +47,6 @@ class Tensor:
     # by here, it's either LazyNumpyArray or LazyBuffer
     # TODO: it should all be LazyBuffer I think
     if isinstance(data, LazyNumpyArray):
-      data = data if data.shape else data.reshape((1,))
       lazydata = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None else data, device)
    elif isinstance(data, LazyBuffer):
       assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
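# Illustrative sketch (not part of the patch): the two low-level pieces that make the
# scalar path above work, assuming this branch. strides_for_shape now tolerates the
# empty shape instead of asserting, and __init__ routes bare ints/floats through
# np.array(), which already produces a 0-dim array. The helper below simply mirrors
# the patched strides_for_shape so its behaviour can be checked in isolation.
import numpy as np

def strides_for_shape(shape):
  # mirror of the patched shapetracker helper, for illustration only
  strides = [1] if shape else []
  for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
  return tuple(st if s != 1 else 0 for st, s in zip(strides, shape))

assert strides_for_shape(()) == ()                    # 0-dim view: no strides at all
assert strides_for_shape((4, 3)) == (3, 1)            # unchanged for ordinary shapes
assert np.array(3.14, dtype=np.float32).shape == ()   # what Tensor(3.14) now wraps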
@@ -122,15 +121,13 @@ class Tensor:
   # ***** creation helper functions *****
 
   @staticmethod
-  def full(shape:Tuple[int, ...], fill_value, **kwargs):
-    new_shape = argfix(shape)
-    return Tensor([fill_value], **kwargs).reshape([1]*len(new_shape)).expand(new_shape).contiguous()
+  def full(shape:Tuple[int, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape).contiguous()
 
   @staticmethod
-  def zeros(*shape, **kwargs): return Tensor.full(shape, 0, **kwargs)
+  def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)
 
   @staticmethod
-  def ones(*shape, **kwargs): return Tensor.full(shape, 1, **kwargs)
+  def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)
 
   @staticmethod
   def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
@@ -203,11 +200,11 @@ class Tensor:
     return _deepwalk(self, set(), [])
 
   def backward(self):
-    assert self.shape == (1,)
+    assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape}"
 
     # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
     # this is "implicit gradient creation"
-    self.grad = Tensor([1], device=self.device, requires_grad=False)
+    self.grad = Tensor(1, device=self.device, requires_grad=False)
 
     for t0 in reversed(self.deepwalk()):
       if not any(x.requires_grad for x in t0._ctx.parents):
@@ -227,7 +224,7 @@ class Tensor:
 
   def reshape(self, shape, *args) -> Tensor:
     new_shape = argfix(shape, *args)
-    assert len(new_shape) > 0 and all(x != 0 for x in new_shape), f"zeros not allowed in shape {new_shape}"
+    assert all(x != 0 for x in new_shape), f"zeros not allowed in shape {new_shape}"
     return mlops.Reshape.apply(self, shape=tuple(-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape))
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
@@ -243,37 +240,71 @@ class Tensor:
     padding = tuple((max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_))
     return self.pad(padding).shrink(tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)))
-
   # Tensors mostly follow the normal python indexing / slicing behavior for sequences
   # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
   # - A slice i:j returns the elements with indices in [i, j)
-  # - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
-  # - Negative values for i and j are taken relative to the end of the sequence
-  # - Both i and j will be clamped to the range (-N, N], where N in the length of the sequence
+  #    - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
+  #    - Negative values for i and j are taken relative to the end of the sequence
+  #    - Both i and j will be clamped to the range (-N, N], where N is the length of the sequence
   # - Indexing with np.newaxis or None on a given axis will add a new dimension of size one before that axis
-  # - Empty slices are not allowed
-  # - Strides other than 1 are not allowed
+  # - Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends).
+  # - For a slice [i:j:k] finding the correct indices is delegated to slice.indices(len).
+  # - Strides > 1 and < 0 are now allowed:
+  #    - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional)
+  #    - Idea of stride < 0 support:
+  #        - Do the slice first, flip the axes where slice.step is negative, do slice.step -> -slice.step. Go to steps below.
+  #    - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink):
+  #        - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s].
+  #        - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s]
+  #          is possible.
+  #        - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s].
   def __getitem__(self, val):
-    def slcfix(i, sz, default): return default if i is None else max(0, min(sz, sz+i if i < 0 else i)) # Fix negative idxs, clamp to [0,N]
-    new_slice, new_shape = [], []
-    val = [val] if not isinstance(val, (list, tuple)) else val
-    assert sum(s is not None for s in val) <= len(self.shape)
-    assert all(s.step is None or s.step == 1 for s in val if isinstance(s, slice))
-    for i,(sz,s) in enumerate(zip(self.shape, [v for v in val if v is not None])): # Slicing only depends on ints + slices
-      if isinstance(s, int) and not (-sz <= s < sz):
-        raise IndexError(f"index {s} is out of bounds for dimension {i} with size {sz}")
-      new_slice.append((s%sz, s%sz+1) if isinstance(s, int) else (slcfix(s.start, sz, 0), slcfix(s.stop, sz, sz)))
-    for s,sz in zip(val, [self.shape[i-1] for i in itertools.accumulate([int(s is not None) for s in val])]): # Shape depends on slices + positions of Nones
-      if not isinstance(s, int):
-        new_shape.append(1 if s is None else slcfix(s.stop, sz, sz) - slcfix(s.start, sz, 0))
-    new_shape += [self.shape[i] for i in range(len(new_slice), len(self.shape))]
-    new_slice += [(0,self.shape[i]) for i in range(len(new_slice), len(self.shape))]
-    return self.slice(new_slice).reshape(new_shape if len(new_shape) else (1,))
+    def normalize_int(e, i, dim_sz):
+      if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
+      raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
+    val = list(val) if isinstance(val, tuple) else [val]
+    if (num_slices := sum(isinstance(v, (slice, int)) for v in val)) > len(self.shape):
+      raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
+    orig_slices = list(val) + [slice(None)] * (len(self.shape) - num_slices)
+    valid_slices = list(itertools.filterfalse(lambda x: x is None, orig_slices))
+    valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
+    start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
+    new_slice = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in zip(start, stop, strides))
+    new_shape = tuple(e - s for s, e in new_slice)
+    # Shrink
+    sliced_tensor = self.shrink(new_slice)
+    # Flip
+    if (flip_axes := tuple(i for i, s in enumerate(strides) if s < 0)):
+      sliced_tensor = sliced_tensor.flip(axis=flip_axes)
+    if any(s > 1 or s < 0 for s in strides):
+      # normalize if negative strides
+      strides = tuple(abs(s) for s in strides)
+      def num_zeros(step, dim_sz): return 0 if step == 1 or (y := dim_sz % step) == 0 else (step - y)
+      # Pad: add pad at the end: [dim_sz] -> [dim_sz_padded]
+      paddings = tuple((0, num_zeros(s, dim_sz)) for s, dim_sz in zip(strides, sliced_tensor.shape))
+      padded_tensor = sliced_tensor.pad(paddings)
+      # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
+      new_shape = functools.reduce(operator.add, [[sh // s, s] for sh, s in zip(padded_tensor.shape, strides)], []) # type: ignore
+      reshaped_tensor = padded_tensor.reshape(new_shape)
+      # Shrink: do [:, 0]
+      new_shape = new_shape[::2]
+      final_slice = functools.reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ())
+      sliced_tensor = reshaped_tensor.shrink(final_slice)
+    final_shape = []
+    it_shape = iter(new_shape)
+    for i in orig_slices:
+      if isinstance(i, (int, slice)):
+        dim_shape = next(it_shape)
+        if isinstance(i, slice): final_shape.append(dim_shape)
+      else: # i is None
+        final_shape.append(1)
+    return sliced_tensor.reshape(tuple(final_shape)) # Reshape
 
   def cat(self, *args, dim=0):
     dim = (dim + len(self.shape)) if dim < 0 else dim
     for y in args: assert len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim)
     catargs = [self] + list(args)
+    assert all(len(t.shape) != 0 for t in catargs), "zero-dimensional tensor cannot be concatenated"
     shape_cumsum = [0, *itertools.accumulate([y.shape[dim] for y in catargs])]
     slc = [[(0, s) for s in self.shape] for _ in catargs]
     for s,k in zip(slc, shape_cumsum):
@@ -327,7 +358,7 @@ class Tensor:
     axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
     shape = [self.shape[i] for i in range(len(self.shape)) if i not in axis_]
     ret = fxn.apply(self, new_shape=tuple(1 if i in axis_ else self.shape[i] for i in range(len(self.shape))))
-    return ret if keepdim else ret.reshape(shape=[1] if shape == [] else shape)
+    return ret if keepdim else ret.reshape(shape=shape)
 
   def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
   def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
@@ -425,6 +456,7 @@ class Tensor:
     return ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))
 
   def dot(self, w:Tensor) -> Tensor:
+    if (n1:=len(self.shape))*(n2:=len(w.shape)) == 0: raise RuntimeError(f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D")
     x = self.reshape(*self.shape[0:-1], 1, self.shape[-1])
     w = w.reshape(*w.shape[0:-2], 1, w.shape[-2], w.shape[-1]).transpose(-1, -2)
     r = (x*w).sum(-1)
@@ -471,7 +503,7 @@ class Tensor:
   # ***** broadcasted binary mlops *****
 
   def _broadcasted(self, fxn:Type[Function], other:Union[Tensor, float], reverse:bool=False) -> Tensor:
-    x,y = [Tensor([t], device=self.device, requires_grad=False) if not isinstance(t, Tensor) else t for t in ([other,self] if reverse else [self,other])]
+    x,y = [Tensor(t, device=self.device, requires_grad=False) if not isinstance(t, Tensor) else t for t in ([other,self] if reverse else [self,other])]
     x,y = [t.reshape([1]*(max(len(x.shape), len(y.shape))-len(t.shape)) + list(t.shape)) for t in [x,y]]
     shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
     return fxn.apply(x.expand(shape_ret), y.expand(shape_ret))
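# Illustrative sketch (not part of the patch): a quick end-to-end check of the new
# __getitem__ behaviour documented in the comments above (strides > 1 and < 0),
# assuming a tinygrad build with this change installed. Results should match numpy,
# which is what test_slice_stride_gt_one and test_slice_negative_strides also assert.
import numpy as np
from tinygrad.tensor import Tensor

a = np.arange(2*5*10, dtype=np.float32).reshape(2, 5, 10)
t = Tensor(a)
np.testing.assert_allclose(a[:, ::2, 1:8:3], t[:, ::2, 1:8:3].numpy())      # strides > 1
np.testing.assert_allclose(a[::-1, :, 8:2:-2], t[::-1, :, 8:2:-2].numpy())  # strides < 0

# The stride > 1 trick itself, shown in plain numpy: pad the axis to a multiple of the
# step, reshape to (n // step, step), then keep column 0 of every row.
x, step = np.arange(7, dtype=np.float32), 3
padded = np.pad(x, (0, (-len(x)) % step))        # length 7 -> 9
np.testing.assert_allclose(padded.reshape(-1, step)[:, 0], x[::step])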