import unittest
import numpy as np
from tinygrad import Tensor, function
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp

class TestCall(unittest.TestCase):
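  # tests for Tensor.call: evaluating a standalone UOp graph (built from UOp.param placeholders) on existing tensors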
  def test_call_plus(self):
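    # simplest case: a graph that adds its two params should reproduce a+b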
    a = Tensor.randn(10, 10)
    b = Tensor.randn(10, 10)
    Tensor.realize(a,b)

    # we define a plus function
    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))

    c = Tensor.call(a, b, fxn=plus_fxn)
    np.testing.assert_equal(c.numpy(), (a+b).numpy())

  def test_call_plus_backward(self):
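    # backward through Tensor.call with an explicitly supplied grad_fxn, checked against plain a+b gradients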
    a = Tensor.ones(10, 10, requires_grad=True)
    b = Tensor.ones(10, 10, requires_grad=True)

    (a+b).mean().backward()
    gt_a_grad = a.grad.numpy()
    gt_b_grad = b.grad.numpy()
    a.grad, b.grad = None, None

    # this is the gradient for +
    def grad_fxn(grad:UOp, call:UOp): return (grad, grad)

    # we define a plus function
    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
    c = Tensor.call(a, b, fxn=plus_fxn, grad_fxn=grad_fxn)
    c.mean().backward()

    np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)

  def test_call_plus_backward_auto(self):
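    # same as above, but with no explicit grad_fxn: gradients should be derived from the called graph automatically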
    a = Tensor.ones(10, 10, requires_grad=True)
    b = Tensor.ones(10, 10, requires_grad=True)

    (a+b).mean().backward()
    gt_a_grad = a.grad.numpy()
    gt_b_grad = b.grad.numpy()
    a.grad, b.grad = None, None

    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
    c = Tensor.call(a, b, fxn=plus_fxn)
    c.mean().backward()

    np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)

  def test_call_gemm(self):
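    # the called function can also be built directly from the input tensors via as_param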
    M, K, N = 4, 8, 4
    a = Tensor.randn(M, K)
    b = Tensor.randn(K, N)
    Tensor.realize(a, b)
    c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)

  @unittest.skip("needs GEMM on mixins")
  def test_call_gemm_uop(self):
    M, K, N = 4, 8, 4
    a = Tensor.randn(M, K)
    b = Tensor.randn(K, N)
    Tensor.realize(a, b)

    # we define a gemm function
    x = UOp.param(0, dtypes.float, shape=(M, K))
    y = UOp.param(1, dtypes.float, shape=(K, N))
    c = Tensor.call(a, b, fxn=x@y)

    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)

  def test_call_complex_backward_auto(self):
    # complex chain: (a*b + a).exp2() * b.reciprocal() - tests mul, add, exp2, reciprocal, param reuse
    a = Tensor.randn(10, 10, requires_grad=True)
    b = Tensor.randn(10, 10, requires_grad=True) + 2 # avoid div by zero
    Tensor.realize(a, b)

    ((a*b + a).exp2() * b.reciprocal()).mean().backward()
    gt_a_grad, gt_b_grad = a.grad.numpy(), b.grad.numpy()
    a.grad, b.grad = None, None

    p0, p1 = UOp.param(0, dtypes.float, (10,10)), UOp.param(1, dtypes.float, (10,10))
    complex_fxn = (p0*p1 + p0).exp2() * p1.reciprocal()
    c = Tensor.call(a, b, fxn=complex_fxn)
    c.mean().backward()

    np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)

  def test_call_plus_sharded(self):
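    # Tensor.call should also work when the inputs are sharded across multiple devices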
    devs = ("CPU:0", "CPU:1")
    a = Tensor.ones(10, 10).shard(devs, axis=0)
    b = Tensor.ones(10, 10).shard(devs, axis=0)
    Tensor.realize(a, b)
    c = Tensor.call(a, b, fxn=a.as_param(0) + b.as_param(1))
    np.testing.assert_equal(c.numpy(), 2 * np.ones((10, 10)))

class TestCallSchedule(unittest.TestCase):
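  # scheduling tests for functions decorated with @function(precompile=True)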
  def test_reshape_precompile(self):
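    # precompiled sum applied to a tensor built through a reshape/assign/reshape chain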
    a = Tensor.empty(4, 8).realize()
    a = a.reshape(4,4,2).assign(Tensor.empty(4,4,2)).reshape(8,4)
    @function(precompile=True)
    def s(x): return x.sum(axis=0)
    (s(a)*3).realize()

  def test_call_precompiled(self):
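    # a precompiled elementwise function whose output is used in further lazy ops before realize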
    a = Tensor.empty(4, 8)
    @function(precompile=True)
    def s(x): return x*2
    (s(a)*3).realize()

  def test_double_call(self):
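    # the precompiled function applied to its own output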
    a = Tensor.empty(4, 8)
    @function(precompile=True)
    def s(x): return x*2
    s(s(a)).realize()

  def test_double_call_contiguous(self):
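    # as above, but with a contiguous() between the two calls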
    a = Tensor.empty(4, 8)
    @function(precompile=True)
    def s(x): return x*2
    s(s(a).contiguous()).realize()

  def test_call_double_gemm(self):
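    # backward pass through a precompiled chain of two matmuls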
    a = Tensor.randn(4, 8, requires_grad=True)
    b = Tensor.randn(8, 12, requires_grad=True)
    c = Tensor.randn(12, 16, requires_grad=True)
    ref = Tensor.randn(4, 16)
    Tensor.realize(a,b,c,ref)
    @function(precompile=True)
    def gemm(a:Tensor, b:Tensor, c:Tensor) -> Tensor: return (a@b)@c
    out = gemm(a,b,c)
    (out-ref).square().mean().backward()
    out.realize(a.grad, b.grad, c.grad)

if __name__ == '__main__':
  unittest.main()