import numpy as np
import unittest
from tinygrad.function import function
from tinygrad import Tensor, GlobalCounters
from tinygrad.uop.ops import UOp

class TestFunction(unittest.TestCase):
  def test_simple(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    np.testing.assert_equal(f(a,b).numpy(), [5,7,9])

  def test_simple_same(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
    a = Tensor([1,2,3])
    np.testing.assert_equal(f(a,a).numpy(), [2,4,6])

  def test_implicit(self):
    inp = Tensor([7,8,9])
    @function(allow_implicit=True)
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp
    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    np.testing.assert_equal(f(a,b).numpy(), [12,15,18])

  def test_implicit_same_as_input(self):
    inp = Tensor([7,8,9])
    @function(allow_implicit=True)
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp
    a = Tensor([1,2,3])
    np.testing.assert_equal(f(a, inp).numpy(), [15,18,21])

  def test_implicit_2(self):
    inp = Tensor([7,8,9])
    @function(allow_implicit=True)
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp
    inp2 = Tensor([7,8,10])
    @function(allow_implicit=True)
    def g(a:Tensor, b:Tensor) -> Tensor: return a+b+inp2
    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    c = f(a,b)
    d = g(a,b)
    c.realize(d)
    np.testing.assert_equal(c.numpy(), [12,15,18])
    np.testing.assert_equal(d.numpy(), [12,15,19])

  def test_implicit_unrealized(self):
    inp = Tensor([1,2,3]) + Tensor([4,5,6])
    @function(allow_implicit=True)
    def f(a:Tensor) -> Tensor: return a + inp
    np.testing.assert_equal(f(Tensor([10,20,30])).numpy(), [15,27,39])

  def test_detach(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a.detach() + b
    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    np.testing.assert_equal(f(a, b).numpy(), [5,7,9])

  def test_contiguous_backward(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return (a + b).contiguous_backward()
    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    np.testing.assert_equal(f(a, b).numpy(), [5,7,9])

  def test_method(self):
    class Foo:
      def __init__(self): self.w = Tensor([10,20,30])
      @function
      def __call__(self, x:Tensor) -> Tensor: return x + self.w
    foo = Foo()
    np.testing.assert_equal(foo(Tensor([1,2,3])).numpy(), [11,22,33])

  def test_grad_gemm(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a @ b
    a = Tensor([[1.,2.],[3.,4.]], requires_grad=True)
    b = Tensor([[5.,6.],[7.,8.]], requires_grad=True)
    (f(a, b).contiguous() * b).sum().backward()
    Tensor.realize(a, b, a.grad, b.grad)
    # L = sum((a@b) * b), dL/d(a@b) = b, dL/da = b @ b^T, dL/db = a^T @ b + (a@b)
    na, nb = a.numpy(), b.numpy()
    np.testing.assert_allclose(a.grad.numpy(), nb @ nb.T)
    np.testing.assert_allclose(b.grad.numpy(), na.T @ nb + na @ nb)

  def test_grad_implicit(self):
    w = Tensor([1., 2., 3.], requires_grad=True)
    w.realize()  # TODO: this is required
    @function(allow_implicit=True)
    def f(x:Tensor) -> Tensor: return x * w
    x = Tensor([4., 5., 6.])
    f(x).sum().backward()
    np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6.])

  def test_symbolic_index(self):
    table = Tensor([10,20,30,40]).contiguous().realize()
    @function(allow_implicit=True)
    def f(x:Tensor, start_pos:int|UOp) -> Tensor: return x + table[start_pos]
    v = UOp.variable("start_pos", 0, 3)
    np.testing.assert_equal(f(Tensor([1,2,3]), v.bind(0)).numpy(), [11,12,13])

  def test_symbolic_shape_input(self):
    table = Tensor([10,20,30,40]).contiguous().realize()
    @function
    def f(x:Tensor) -> Tensor: return x * 2
    sz = UOp.variable("sz", 1, 3)
    slic = table[:sz.bind(2)]
    np.testing.assert_equal(f(slic)[:2].numpy(), [20,40])
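
  def test_grad_gemm_numeric_check(self):
    # Illustrative sketch (not from the original suite): cross-checks the closed-form
    # gradients asserted in test_grad_gemm with a plain-numpy central finite difference,
    # so dL/da = b @ b^T and dL/db = a^T @ b + a@b can be verified independently of tinygrad.
    na = np.array([[1.,2.],[3.,4.]])
    nb = np.array([[5.,6.],[7.,8.]])
    def loss(a, b): return ((a @ b) * b).sum()
    eps, grad_a, grad_b = 1e-4, np.zeros_like(na), np.zeros_like(nb)
    for i in range(2):
      for j in range(2):
        d = np.zeros_like(na); d[i,j] = eps
        grad_a[i,j] = (loss(na+d, nb) - loss(na-d, nb)) / (2*eps)
        grad_b[i,j] = (loss(na, nb+d) - loss(na, nb-d)) / (2*eps)
    np.testing.assert_allclose(grad_a, nb @ nb.T, rtol=1e-6)
    np.testing.assert_allclose(grad_b, na.T @ nb + na @ nb, rtol=1e-6)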

  def test_nested_calls(self):
    w = Tensor([10., 20., 30.])
    @function(allow_implicit=True)
    def f(a:Tensor) -> Tensor: return a + w
    @function(allow_implicit=True)
    def g(a:Tensor) -> Tensor: return a * w
    a = Tensor([1., 2., 3.])
    np.testing.assert_allclose(g(f(a)).numpy(), [110., 440., 990.])

  def test_nested_calls_backward(self):
    w = Tensor([[1., 2.], [3., 4.]]).contiguous().realize()
    @function(allow_implicit=True)
    def inner(x:Tensor) -> Tensor: return x + w
    @function(allow_implicit=True)
    def outer(a:Tensor, b:Tensor) -> Tensor: return inner(a.reshape(1,2) + b.reshape(1,2))
    a = Tensor([1., 2.], requires_grad=True)
    b = Tensor([3., 4.], requires_grad=True)
    outer(a, b).sum().backward()
    np.testing.assert_allclose(a.grad.numpy(), [2., 2.])
    np.testing.assert_allclose(b.grad.numpy(), [2., 2.])

  def test_unused_param_backward(self):
    @function
    def f(a:Tensor, b:Tensor, c:Tensor) -> Tensor: return a + c  # b is unused
    a = Tensor([1., 2., 3.], requires_grad=True)
    b = Tensor([4., 5., 6.], requires_grad=True)
    c = Tensor([7., 8., 9.], requires_grad=True)
    f(a, b, c).sum().backward()
    np.testing.assert_allclose(a.grad.numpy(), [1., 1., 1.])
    np.testing.assert_allclose(b.grad.numpy(), [0., 0., 0.])
    np.testing.assert_allclose(c.grad.numpy(), [1., 1., 1.])

  def test_name(self):
    @function
    def f(a:Tensor) -> Tensor: return a + 1
    assert f(Tensor([1])).uop.src[0].arg.name.endswith("f")

  def test_method_name(self):
    class Foo:
      @function
      def __call__(self, x:Tensor) -> Tensor: return x + 1
    assert Foo()(Tensor([1])).uop.src[0].arg.name.endswith("Foo.__call__")

  def test_callable_instance(self):
    class Foo:
      def __init__(self): self.w = Tensor([10,20,30])
      def __call__(self, x:Tensor) -> Tensor: return x + self.w
    foo = Foo()
    f = function(foo, allow_implicit=True)
    np.testing.assert_equal(f(Tensor([1,2,3])).numpy(), [11,22,33])
    assert f(Tensor([1,2,3])).uop.src[0].arg.name.endswith("Foo")

  def test_iadd(self):
    @function
    def f(x:Tensor) -> Tensor:
      x += 1
      return x
    a = Tensor([1,2,3]).realize()
    np.testing.assert_equal(f(a).numpy(), [2,3,4])
    np.testing.assert_equal(a.numpy(), [3,4,5])  # TODO: should be [1,2,3]

  def test_implicit_assign(self):
    a = Tensor([1,2,3])
    a += 1
    c = Tensor([2,2,2]).contiguous()
    @function
    def f(b:Tensor) -> Tensor: return a+b+c
    b = Tensor([10,20,30]).realize()
    np.testing.assert_equal(f(b).numpy(), [14,25,36])

  def test_assign_input(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor:
      a.assign(b+1)
      return a
    a = Tensor([1,2,3]).realize()
    b = Tensor([10,20,30]).realize()
    np.testing.assert_equal(f(a,b).numpy(), [11,21,31])
    np.testing.assert_equal(a.numpy(), [11,21,31])  # TODO: should be [1,2,3]
    np.testing.assert_equal(b.numpy(), [10,20,30])

  def test_view_assign_explicit_buffer(self):
    """view assign on an explicit param's buffer should not create implicit inputs."""
    class State:
      def __init__(self): self.buf = Tensor.zeros(2, 4).contiguous().realize()
      @function(allow_implicit=False)
      def __call__(self, x:Tensor) -> Tensor:
        self.buf[:, 0:2].assign(x)
        return self.buf[:, 0:2]
    s = State()
    np.testing.assert_equal(s(Tensor([[5., 6.], [7., 8.]])).numpy(), [[5., 6.], [7., 8.]])

  def test_single_after_store(self):
    """AFTER(buf, STORE(view, data)) should write data through the view into buf, same as the double-after pattern."""
    @function
    def f(buf:Tensor, x:Tensor, start_pos:int|UOp) -> Tensor:
      slice_uop = buf[:, start_pos:start_pos+1].uop
      assigned = Tensor(buf.uop.after(slice_uop.store(x.uop)))
      return assigned
    buf = Tensor.zeros(2, 8).contiguous().realize()
    x = Tensor([[1.], [2.]]).realize()
    v = UOp.variable("sp", 0, 7)
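    # bind the symbolic start position to 0 so x is written through the view into column 0 of buf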
    r0 = f(buf, x, v.bind(0)).numpy()
    np.testing.assert_equal(r0, [[1.,0.,0.,0.,0.,0.,0.,0.], [2.,0.,0.,0.,0.,0.,0.,0.]])

  @unittest.expectedFailure
  def test_assign_slice(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor:
      a[1:] = b[1:]+1
      return a
    a = Tensor([1,2,3]).realize()
    b = Tensor([10,20,30]).realize()
    np.testing.assert_equal(f(a,b).numpy(), [1,21,31])
    np.testing.assert_equal(a.numpy(), [1,2,3])
    np.testing.assert_equal(b.numpy(), [10,20,30])

class TestFunctionMulti(unittest.TestCase):
  devices_2 = ("CPU:0", "CPU:1")

  def test_simple_multi(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
    a = Tensor([1,2,3,4]).shard(self.devices_2, axis=None)
    b = Tensor([10,20,30,40]).shard(self.devices_2, axis=None)
    np.testing.assert_equal(f(a,b).numpy(), [11,22,33,44])

  def test_simple_multi_sharded(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
    a = Tensor([1,2,3,4]).shard(self.devices_2, axis=0)
    b = Tensor([10,20,30,40]).shard(self.devices_2, axis=0)
    np.testing.assert_equal(f(a,b).numpy(), [11,22,33,44])

  def test_data_parallel_multi(self):
    @function
    def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
    x = Tensor([[1.,2.],[3.,4.],[5.,6.],[7.,8.]]).shard(self.devices_2, axis=0)
    w = Tensor([[1.,0.],[0.,1.]]).shard(self.devices_2, axis=None)
    np.testing.assert_allclose(f(x, w).numpy(), [[1.,2.],[3.,4.],[5.,6.],[7.,8.]])

  def test_grad_implicit_multi(self):
    w = Tensor([1., 2., 3., 4.], requires_grad=True).shard(self.devices_2, axis=None)
    w.realize()
    @function(allow_implicit=True)
    def f(x:Tensor) -> Tensor: return x * w
    x = Tensor([4., 5., 6., 7.]).shard(self.devices_2, axis=None)
    f(x).sum().backward()
    np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6., 7.])

  def test_call_axis(self):
    @function
    def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
    x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]]).shard(self.devices_2, axis=0)
    w = Tensor([[1.,2.],[3.,4.]]).shard(self.devices_2, axis=None)
    result = f(x, w)
    # CALL output should inherit axis=0 from the sharded input
    self.assertEqual(result.uop.axis, 0)
    # reduce on the sharded axis should remove it
    self.assertIsNone(result.sum().uop.axis)

  def test_call_axis_shard_inside(self):
    @function
    def f(x:Tensor, w:Tensor) -> Tensor: return x.shard(self.devices_2, axis=0) @ w.shard(self.devices_2, axis=None)
    x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]])
    w = Tensor([[1.,2.],[3.,4.]])
    result = f(x, w)
    self.assertEqual(result.uop.axis, 0)
    np.testing.assert_allclose(result.numpy(), x.numpy() @ w.numpy())

  def test_data_parallel_backward(self):
    @function
    def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
    x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]], requires_grad=True).shard(self.devices_2, axis=0)
    w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(self.devices_2, axis=None)
    w.realize()
    f(x, w).sum().backward()
    # L = sum(x @ w), so dL/d(x@w) = ones(4,2) and dL/dx = ones(4,2) @ w^T
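    # w = [[1,2],[3,4]], so w^T = [[1,3],[2,4]]; the hard-coded matrix below is w^T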
    np.testing.assert_allclose(x.grad.numpy(), np.ones((4,2)) @ np.array([[1,3],[2,4]]))

  def test_data_parallel_backward_4(self):
    devices_4 = tuple(f"CPU:{i}" for i in range(4))
    @function
    def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
    x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
    w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
    w.realize()
    f(x, w).sum().backward()
    np.testing.assert_allclose(x.grad.numpy(), np.ones((8,2)) @ np.array([[1,3],[2,4]]))

  def test_data_parallel_backward_implicit(self):
    devices_4 = tuple(f"CPU:{i}" for i in range(4))
    w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
    w.realize()
    @function(allow_implicit=True)
    def f(x:Tensor) -> Tensor: return x @ w
    x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
    f(x).sum().backward()
    np.testing.assert_allclose(x.grad.numpy(), np.ones((8,2)) @ np.array([[1,3],[2,4]]))

  def test_data_parallel_backward_twice(self):
    devices_4 = tuple(f"CPU:{i}" for i in range(4))
    w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
    w.realize()
    # pre-init grads like the training loop does
    w.grad = w.zeros_like().contiguous().realize()
    @function(allow_implicit=True)
    def f(x:Tensor) -> Tensor: return x @ w
    expected = np.ones((8,2)) @ np.array([[1,3],[2,4]])
    for _ in range(2):
      x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
      f(x).sum().backward()
      np.testing.assert_allclose(x.grad.numpy(), expected)

class TestFunctionTuple(unittest.TestCase):
  def test_tuple(self, precompile=False):
    x = Tensor.ones(3).contiguous()
    @function(precompile=precompile)
    def f(t:Tensor): return (t+1, t+2)
    t1, t2 = f(x)
    t1.realize(t2)
    print(t1.tolist(), t2.tolist())
    assert t1.tolist() == [2,2,2]
    assert t2.tolist() == [3,3,3]

  def test_tuple_precompile(self): self.test_tuple(True)

  def test_grad_tuple(self, precompile=False):
    x = Tensor.ones(3, requires_grad=True).contiguous()
    y = Tensor.ones(3, requires_grad=True).contiguous()
    @function(precompile=precompile)
    def f(u1:Tensor, u2:Tensor): return (u1+1, u2+2)
    t1, t2 = f(x,y)
    (t1+t2).sum().backward()
    x.grad.realize(y.grad)

  def test_grad_tuple_precompile(self): self.test_grad_tuple(True)

  def test_grad_fxn_tuple(self):
    # grad_fxn for tuple: receives one gradient per output as positional args
    def grad_fxn(d_out0:UOp, d_out1:UOp, call:UOp):
      # f(u1, u2) = (u1+1, u2+2)
      # df/du1 = d_out0, df/du2 = d_out1
      return (d_out0, d_out1)
    x = Tensor.ones(3, requires_grad=True).contiguous()
    y = Tensor.ones(3, requires_grad=True).contiguous()
    @function(grad_fxn=grad_fxn)
    def f(u1:Tensor, u2:Tensor): return (u1+1, u2+2)
    t1, t2 = f(x, y)
    (t1+t2).sum().backward()
    np.testing.assert_allclose(x.grad.numpy(), [1., 1., 1.])
    np.testing.assert_allclose(y.grad.numpy(), [1., 1., 1.])

class TestFunctionGrad(unittest.TestCase):
  def test_function_grad_ops(self, precompile=False, precompile_backward=False):
    N = 64
    x = Tensor.ones(N,N).contiguous()
    w1 = Tensor.ones(N,N, requires_grad=True).contiguous()
    w2 = Tensor.ones(N,N, requires_grad=True).contiguous()
    w3 = Tensor.ones(N,N, requires_grad=True).contiguous()
    ref = Tensor.ones(N,N).contiguous()
    Tensor.realize(x, w1, w2, w3, ref)
    @function(precompile=precompile, precompile_backward=precompile_backward)
    def f(x, w1, w2, w3) -> tuple[Tensor, ...]:
      p1 = x@w1
      p2 = p1@w2
      p3 = p2@w3
      return p1, p2, p3, p3.contiguous()
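    # build the loss from the last output only (the contiguous copy of p3); p1/p2/p3 are returned but not used directly here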
    ret = f(x, w1, w2, w3)[-1]
    loss = (ret-ref).square().mean().backward()
    print("RESET")
    GlobalCounters.reset()
    loss.realize(w1.grad, w2.grad, w3.grad)
    print(GlobalCounters.global_ops, GlobalCounters.global_mem)
    self.assertLessEqual(GlobalCounters.global_ops, 4739073)

  def test_function_grad_ops_precompile(self): self.test_function_grad_ops(precompile=True)

  def test_function_grad_ops_precompile_backward(self): self.test_function_grad_ops(precompile=True, precompile_backward=True)

if __name__ == '__main__':
  unittest.main()