# tinygrad/test/unit/test_call.py
# from tinygrad commit 3ff03be413 "call always has tuple" (#15297) by George Hotz, 2026-03-17

import unittest
import numpy as np
from tinygrad import Tensor, function
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops
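
# Pattern under test (inferred from this file, not from Tensor's documented API):
# UOp.param(idx, dtype, shape) builds a placeholder for the idx-th argument, and
# Tensor.call(*args, fxn=...) evaluates the UOp graph `fxn` with each arg substituted
# for its matching PARAM. A minimal sketch of the shape of these tests:
#   p0 = UOp.param(0, dtypes.float, (10,10))  # placeholder for arg 0
#   p1 = UOp.param(1, dtypes.float, (10,10))  # placeholder for arg 1
#   c = Tensor.call(a, b, fxn=p0 + p1)        # computes a + b through a CALL node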
class TestCall(unittest.TestCase):
  def test_call_plus(self):
    a = Tensor.randn(10, 10)
    b = Tensor.randn(10, 10)
    Tensor.realize(a,b)
    # we define a plus function
    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
    c = Tensor.call(a, b, fxn=plus_fxn)
    np.testing.assert_equal(c.numpy(), (a+b).numpy())

  def test_call_plus_backward(self):
    a = Tensor.ones(10, 10, requires_grad=True)
    b = Tensor.ones(10, 10, requires_grad=True)
    (a+b).mean().backward()
    gt_a_grad = a.grad.numpy()
    gt_b_grad = b.grad.numpy()
    a.grad, b.grad = None, None
    # this is the gradient for +
    def grad_fxn(grad:UOp, call:UOp): return (grad, grad)
    # we define a plus function
    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
    c = Tensor.call(a, b, fxn=plus_fxn, grad_fxn=grad_fxn)
    c.mean().backward()
    np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)
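
  # grad_fxn convention, as used above and in test_call_reduce_sharded_backward below:
  # it takes (upstream_grad, call_uop), where call_uop.src[1:] appear to be the call
  # arguments, and returns one gradient UOp per argument, in PARAM index order.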

  def test_call_plus_backward_auto(self):
    a = Tensor.ones(10, 10, requires_grad=True)
    b = Tensor.ones(10, 10, requires_grad=True)
    (a+b).mean().backward()
    gt_a_grad = a.grad.numpy()
    gt_b_grad = b.grad.numpy()
    a.grad, b.grad = None, None
    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
    c = Tensor.call(a, b, fxn=plus_fxn)
    c.mean().backward()
    np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)

  def test_call_gemm(self):
    M, K, N = 4, 8, 4
    a = Tensor.randn(M, K)
    b = Tensor.randn(K, N)
    Tensor.realize(a, b)
    c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
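
  # as_param(idx) looks like shorthand for a PARAM carrying the tensor's dtype and shape,
  # i.e. roughly UOp.param(0, dtypes.float, (M, K)) for `a` here; that reading is an
  # inference from test_call_plus above, not a checked Tensor API contract.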

  @unittest.skip("needs GEMM on mixins")
  def test_call_gemm_uop(self):
    M, K, N = 4, 8, 4
    a = Tensor.randn(M, K)
    b = Tensor.randn(K, N)
    Tensor.realize(a, b)
    # we define a gemm function
    x = UOp.param(0, dtypes.float, shape=(M, K))
    y = UOp.param(1, dtypes.float, shape=(K, N))
    c = Tensor.call(a, b, fxn=x@y)
    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)

  def test_call_complex_backward_auto(self):
    # complex chain: (a*b + a).exp2() * b.reciprocal() - tests mul, add, exp2, reciprocal, param reuse
    a = Tensor.randn(10, 10, requires_grad=True)
    b = Tensor.randn(10, 10, requires_grad=True) + 2  # avoid div by zero
    Tensor.realize(a, b)
    ((a*b + a).exp2() * b.reciprocal()).mean().backward()
    gt_a_grad, gt_b_grad = a.grad.numpy(), b.grad.numpy()
    a.grad, b.grad = None, None
    p0, p1 = UOp.param(0, dtypes.float, (10,10)), UOp.param(1, dtypes.float, (10,10))
    complex_fxn = (p0*p1 + p0).exp2() * p1.reciprocal()
    c = Tensor.call(a, b, fxn=complex_fxn)
    c.mean().backward()
    np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)

  def test_call_plus_sharded(self):
    devs = ("CPU:0", "CPU:1")
    a = Tensor.ones(10, 10).shard(devs, axis=0)
    b = Tensor.ones(10, 10).shard(devs, axis=0)
    Tensor.realize(a, b)
    c = Tensor.call(a, b, fxn=a.as_param(0) + b.as_param(1))
    np.testing.assert_equal(c.numpy(), 2 * np.ones((10, 10)))
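
# The @function decorator used below wraps a plain Tensor function so calls route through
# the same CALL machinery; for the precompile=True variants this is directly observable,
# since test_precompile_schedule_cache_hit finds an Ops.CALL node in the result's toposort.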

class TestCallShape(unittest.TestCase):
  def test_call_shape_int(self):
    # fixed-shape function: shape passes through unchanged
    @function
    def f(x:Tensor) -> Tensor: return x * 2
    self.assertEqual(f(Tensor.empty(4, 8)).shape, (4, 8))

  def test_call_shape_param_substitution(self):
    # symbolic shape dimension is substituted: inner PARAM replaced with the BIND arg
    @function
    def f(x:Tensor) -> Tensor: return x * 2
    sz = UOp.variable("sz", 1, 8)
    shape = f(Tensor.empty(8)[:sz.bind(5)]).shape
    # the PARAM should be gone, replaced with the BIND from the call arg
    self.assertIsInstance(shape[0], UOp)
    self.assertNotEqual(shape[0].op, Ops.PARAM)
    self.assertEqual(shape[0], sz.bind(5))

  def test_call_shape_expr_substitution(self):
    # expression containing PARAMs in shape gets fully substituted
    @function
    def f(x:Tensor) -> Tensor: return x + 1
    sz = UOp.variable("sz", 1, 10)
    shape = f(Tensor.empty(10, 4)[:sz.bind(3)]).shape
    self.assertIsInstance(shape[0], UOp)
    self.assertNotEqual(shape[0].op, Ops.PARAM)
    self.assertEqual(shape[1], 4)

  def test_call_shape_no_param_passthrough(self):
    # a non-PARAM UOp shape element passes through unchanged
    @function
    def f(x:Tensor) -> Tensor: return x * 3
    sz = UOp.variable("sz", 1, 8)
    shape = f(Tensor.empty(8)[:sz.bind(5)]).shape
    self.assertEqual(shape[0], sz.bind(5))
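
# precompile=True presumably schedules the function body once and reuses it across call
# sites: test_precompile_schedule_cache_hit asserts two instances share one body key, and
# the symbolic-shape tests check that the reused body still handles variable-sized inputs.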

class TestCallSchedule(unittest.TestCase):
  def test_reshape_precompile(self):
    a = Tensor.empty(4, 8).realize()
    a = a.reshape(4,4,2).assign(Tensor.empty(4,4,2)).reshape(8,4)
    @function(precompile=True)
    def s(x): return x.sum(axis=0)
    (s(a)*3).realize()

  def test_call_precompiled(self):
    a = Tensor.empty(4, 8)
    @function(precompile=True)
    def s(x): return x*2
    (s(a)*3).realize()

  def test_double_call(self):
    a = Tensor.empty(4, 8)
    @function(precompile=True)
    def s(x): return x*2
    s(s(a)).realize()

  def test_double_call_contiguous(self):
    a = Tensor.empty(4, 8)
    @function(precompile=True)
    def s(x): return x*2
    s(s(a).contiguous()).realize()

  def test_call_double_gemm(self):
    a = Tensor.randn(4, 8, requires_grad=True)
    b = Tensor.randn(8, 12, requires_grad=True)
    c = Tensor.randn(12, 16, requires_grad=True)
    ref = Tensor.randn(4, 16)
    Tensor.realize(a,b,c,ref)
    @function(precompile=True)
    def gemm(a:Tensor, b:Tensor, c:Tensor) -> Tensor: return (a@b)@c
    out = gemm(a,b,c)
    (out-ref).square().mean().backward()
    out.realize(a.grad, b.grad, c.grad)

  def test_precompile_symbolic_shape(self):
    """precompile with a symbolic-shaped input produces correct values and shape"""
    @function(precompile=True)
    def f(x:Tensor) -> Tensor: return x * 2
    sz = UOp.variable("sz", 1, 8)
    a = Tensor([1., 2., 3., 4., 5., 6., 7., 8.])[:sz.bind(5)]
    out = f(a)
    self.assertIsInstance(out.shape[0], UOp)
    np.testing.assert_allclose(out[:5].numpy(), [2., 4., 6., 8., 10.])

  def test_precompile_symbolic_shape_contiguous(self):
    """precompile with a .contiguous() inside the function body on a symbolic-shaped input"""
    @function(precompile=True)
    def f(x:Tensor) -> Tensor: return (x * 2).contiguous() + 1
    sz = UOp.variable("sz", 1, 8)
    a = Tensor([1., 2., 3., 4., 5., 6., 7., 8.])[:sz.bind(3)]
    out = f(a)
    self.assertIsInstance(out.shape[0], UOp)
    np.testing.assert_allclose(out[:3].numpy(), [3., 5., 7.])

  def test_precompile_symbolic_shape_chain(self):
    """precompiled symbolic result used in downstream ops (tests AFTER has correct symbolic shape)"""
    @function(precompile=True)
    def f(x:Tensor) -> Tensor: return x * 2
    sz = UOp.variable("sz", 1, 8)
    a = Tensor([1., 2., 3., 4., 5., 6., 7., 8.])[:sz.bind(4)]
    out = f(a) + 10  # downstream op on the precompiled result
    self.assertIsInstance(out.shape[0], UOp)
    np.testing.assert_allclose(out[:4].numpy(), [12., 14., 16., 18.])

  def test_precompile_bind_arg(self):
    """precompile with a BIND (scalar variable) as a function argument"""
    @function(precompile=True)
    def f(x:Tensor, scale:UOp) -> Tensor: return x * scale
    v = UOp.variable("scale", 1, 100)
    a = Tensor([1., 2., 3.])
    out = f(a, v.bind(5))
    np.testing.assert_allclose(out.numpy(), [5., 10., 15.])

  def test_precompile_schedule_cache_hit(self):
    """two instances of the same @function should produce identical function body keys (schedule cache hit)"""
    @function(precompile=True)
    def f(x:Tensor) -> Tensor: return x + Tensor.full(x.shape, -1.0)
    a = Tensor.empty(4, 8)
    b = Tensor.empty(4, 8)
    r0, r1 = f(a), f(b)
    # find the CALL nodes
    c0 = next(u for u in r0.uop.toposort() if u.op is Ops.CALL)
    c1 = next(u for u in r1.uop.toposort() if u.op is Ops.CALL)
    # the function bodies (src[0]) should have identical keys — unique consts must not leak through
    self.assertEqual(c0.src[0].key, c1.src[0].key)

  def test_precompile_symbolic_2d(self):
    """precompile with symbolic shapes in 2D (tests debuf reshape with symbolic PARAM)"""
    @function(precompile=True)
    def f(x:Tensor) -> Tensor: return x * 2 + 1
    sz = UOp.variable("sz", 1, 16)
    a = Tensor.arange(16*4).reshape(16, 4).float()[:sz.bind(5)]
    out = f(a)
    # result shape should have the symbolic dim, not the max
    self.assertIsInstance(out.shape[0], UOp)
    np.testing.assert_allclose(out[:5].numpy(), (np.arange(16*4).reshape(16, 4)[:5] * 2 + 1).astype(np.float32))

  def test_precompile_multi_sharded(self):
    @function(precompile=True)
    def f(x:Tensor) -> Tensor: return x + 1
    devs = ("CPU:0", "CPU:1")
    a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
    out = f(a) + 2
    np.testing.assert_allclose(out.numpy(), np.arange(8, dtype=np.float32).reshape(4, 2) + 3)
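
# shard(devs, axis=...) splits a tensor across the listed devices along that axis, while
# axis=None appears to replicate it on every device (see test_call_reduce_sharded_mixed_args);
# the tests below cover multi-output and reducing CALLs on such multi-device tensors.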

class TestCallMultiSharded(unittest.TestCase):
  # TODO: multi-output + sharded needs per-device CALL execution, which requires reworking how MULTI propagates through TUPLE bodies
  def test_tuple_sharded(self):
    """multi-output function with sharded input"""
    devs = ("CPU:0", "CPU:1")
    @function
    def f(x:Tensor): return (x + 1, x * 2)
    a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
    t1, t2 = f(a)
    ref = np.arange(8, dtype=np.float32).reshape(4, 2)
    np.testing.assert_allclose(t1.numpy(), ref + 1)
    np.testing.assert_allclose(t2.numpy(), ref * 2)

  def test_tuple_sharded_precompile(self):
    """multi-output precompiled function with sharded input"""
    devs = ("CPU:0", "CPU:1")
    @function(precompile=True)
    def f(x:Tensor): return (x + 1, x * 2)
    a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
    t1, t2 = f(a)
    ref = np.arange(8, dtype=np.float32).reshape(4, 2)
    np.testing.assert_allclose(t1.numpy(), ref + 1)
    np.testing.assert_allclose(t2.numpy(), ref * 2)

  def test_tuple_sharded_different_axis(self):
    """multi-output function where outputs have different sharding: one reduces on the sharded axis, one doesn't"""
    devs = ("CPU:0", "CPU:1")
    @function
    def f(x:Tensor): return (x.sum(axis=0), x.sum(axis=1))
    a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
    t1, t2 = f(a)
    ref = np.arange(8, dtype=np.float32).reshape(4, 2)
    np.testing.assert_allclose(t1.numpy(), ref.sum(axis=0))
    np.testing.assert_allclose(t2.numpy(), ref.sum(axis=1))

  def test_tuple_sharded_different_ops(self):
    """multi-output function with different operations per output"""
    devs = ("CPU:0", "CPU:1")
    @function
    def f(x:Tensor, y:Tensor): return (x + y, x * y)
    a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
    b = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0) + 1
    t1, t2 = f(a, b)
    ref_a = np.arange(8, dtype=np.float32).reshape(4, 2)
    ref_b = ref_a + 1
    np.testing.assert_allclose(t1.numpy(), ref_a + ref_b)
    np.testing.assert_allclose(t2.numpy(), ref_a * ref_b)

  def test_tuple_sharded_mixed_use(self):
    """multi-output sharded results used in further computation"""
    devs = ("CPU:0", "CPU:1")
    @function
    def f(x:Tensor): return (x + 1, x * 2)
    a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
    t1, t2 = f(a)
    out = (t1 + t2).sum()
    ref = np.arange(8, dtype=np.float32).reshape(4, 2)
    np.testing.assert_allclose(out.numpy(), ((ref + 1) + (ref * 2)).sum())

  def test_tuple_sharded_outputs_different_axis(self):
    """multi-output function where the two outputs are sharded on different axes"""
    devs = ("CPU:0", "CPU:1")
    @function
    def f(x:Tensor, y:Tensor): return (x + 1, y + 2)
    a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
    b = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=1)
    t1, t2 = f(a, b)
    ref_a = np.arange(8, dtype=np.float32).reshape(4, 2)
    ref_b = np.arange(8, dtype=np.float32).reshape(4, 2)
    np.testing.assert_allclose(t1.numpy(), ref_a + 1)
    np.testing.assert_allclose(t2.numpy(), ref_b + 2)

  def test_call_reduce_sharded(self):
    devs = ("CPU:0", "CPU:1")
    a = Tensor.ones(10, 10).shard(devs, axis=0)
    Tensor.realize(a)
    c = Tensor.call(a, fxn=a.as_param(0).sum(axis=0))
    np.testing.assert_equal(c.numpy(), 10 * np.ones(10))

  def test_call_reduce_sharded_mixed_args(self):
    devs = ("CPU:0", "CPU:1")
    a = Tensor.ones(10, 10).shard(devs, axis=0)
    b = Tensor.ones(10).shard(devs, axis=None)
    Tensor.realize(a, b)
    c = Tensor.call(a, b, fxn=a.as_param(0).sum(axis=0) + b.as_param(1))
    np.testing.assert_equal(c.numpy(), 11 * np.ones(10))

  def test_call_reduce_sharded_backward(self):
    devs = ("CPU:0", "CPU:1")
    a = Tensor.randn(10, 10, requires_grad=True).shard(devs, axis=0)
    b = Tensor.randn(10, 10, requires_grad=True).shard(devs, axis=0)
    Tensor.realize(a, b)
    def grad_fxn(grad, call):
      a_arg, b_arg = call.src[1], call.src[2]
      return (grad.expand(a_arg.shape) * b_arg, grad.expand(b_arg.shape) * a_arg)
    body = (a.as_param(0) * b.as_param(1)).sum(axis=0)
    c = Tensor.call(a, b, fxn=body, grad_fxn=grad_fxn)
    c.sum().backward()
    np.testing.assert_allclose(a.grad.numpy(), b.numpy(), rtol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), a.numpy(), rtol=1e-5)

if __name__ == '__main__':
  unittest.main()