mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add a test case that triggers CALL passthrough_multi (#14887)
This commit is contained in:
@@ -92,5 +92,13 @@ class TestCall(unittest.TestCase):
|
||||
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
|
||||
np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)
|
||||
|
||||
def test_call_plus_sharded(self):
|
||||
devs = ("CPU:0", "CPU:1")
|
||||
a = Tensor.ones(10, 10).shard(devs, axis=0)
|
||||
b = Tensor.ones(10, 10).shard(devs, axis=0)
|
||||
Tensor.realize(a, b)
|
||||
c = Tensor.call(a, b, fxn=a.as_param(0) + b.as_param(1))
|
||||
np.testing.assert_equal(c.numpy(), 2 * np.ones((10, 10)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user