diff --git a/test/unit/test_function.py b/test/unit/test_function.py
index 67e0ca4d65..2a3fa6a1e2 100644
--- a/test/unit/test_function.py
+++ b/test/unit/test_function.py
@@ -193,5 +193,42 @@ class TestFunction(unittest.TestCase):
     np.testing.assert_equal(a.numpy(), [1,2,3])
     np.testing.assert_equal(b.numpy(), [10,20,30])
 
+class TestFunctionMulti(unittest.TestCase):
+  devices_2 = ("CPU:0", "CPU:1")
+
+  def test_simple_multi(self):
+    @function
+    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
+
+    a = Tensor([1,2,3,4]).shard(self.devices_2, axis=None)
+    b = Tensor([10,20,30,40]).shard(self.devices_2, axis=None)
+    np.testing.assert_equal(f(a,b).numpy(), [11,22,33,44])
+
+  def test_simple_multi_sharded(self):
+    @function
+    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
+
+    a = Tensor([1,2,3,4]).shard(self.devices_2, axis=0)
+    b = Tensor([10,20,30,40]).shard(self.devices_2, axis=0)
+    np.testing.assert_equal(f(a,b).numpy(), [11,22,33,44])
+
+  def test_data_parallel_multi(self):
+    @function
+    def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
+
+    x = Tensor([[1.,2.],[3.,4.],[5.,6.],[7.,8.]]).shard(self.devices_2, axis=0)
+    w = Tensor([[1.,0.],[0.,1.]]).shard(self.devices_2, axis=None)
+    np.testing.assert_allclose(f(x, w).numpy(), [[1.,2.],[3.,4.],[5.,6.],[7.,8.]])
+
+  def test_grad_implicit_multi(self):
+    w = Tensor([1., 2., 3., 4.], requires_grad=True).shard(self.devices_2, axis=None)
+    w.realize()
+    @function
+    def f(x:Tensor) -> Tensor: return x * w
+
+    x = Tensor([4., 5., 6., 7.]).shard(self.devices_2, axis=None)
+    f(x).sum().backward()
+    np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6., 7.])
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 4826ddab2c..12adde4866 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -867,6 +867,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   def param_like(self, slot:int):
     if self.op is Ops.BIND:
       return UOp.param(slot, self.dtype, self._shape, self._device, self._min_max, self.src[0].arg[0])
+    if self.axis is not None:
+      return UOp.param(slot, self.dtype, self.shard_shape, self._device).multi(self.axis)
     return UOp.param(slot, self.dtype, self._shape, self._device)
 
   def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None) -> UOp:
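
Note on the ops.py change: without the new branch, a sharded UOp (one whose axis is not None) falls through to the plain UOp.param call, which builds a param with the tensor's full shape and no multi axis. The added branch instead builds the param with the per-device shard shape (shard_shape) and re-wraps it with .multi(self.axis), so tracing through a @function preserves the sharding the eager tensors had. A minimal usage sketch of the path this enables, not part of the diff: it assumes two CPU devices and that `function` is imported the way test/unit/test_function.py imports it (that import sits outside the hunk above).

# Sketch (hypothetical example, not part of the diff): trace a @function where
# one argument is sharded on axis 0 and the other is replicated (axis=None).
# The new param_like branch gives the sharded argument a param with its
# per-device shard_shape, re-tagged with .multi(axis).
import numpy as np
from tinygrad import Tensor

DEVICES = ("CPU:0", "CPU:1")

@function
def scale_rows(x: Tensor, w: Tensor) -> Tensor:
  return x @ w

x = Tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]]).shard(DEVICES, axis=0)  # rows split across devices
w = Tensor([[2., 0.], [0., 2.]]).shard(DEVICES, axis=None)                   # replicated on both devices
# x @ (2*I) doubles every entry of x, independent of how rows are sharded
np.testing.assert_allclose(scale_rows(x, w).numpy(),
                           [[2., 4.], [6., 8.], [10., 12.], [14., 16.]])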