mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix multi minimal (#15044)
This commit is contained in:
@@ -193,5 +193,42 @@ class TestFunction(unittest.TestCase):
|
||||
np.testing.assert_equal(a.numpy(), [1,2,3])
|
||||
np.testing.assert_equal(b.numpy(), [10,20,30])
|
||||
|
||||
class TestFunctionMulti(unittest.TestCase):
|
||||
devices_2 = ("CPU:0", "CPU:1")
|
||||
|
||||
def test_simple_multi(self):
|
||||
@function
|
||||
def f(a:Tensor, b:Tensor) -> Tensor: return a+b
|
||||
|
||||
a = Tensor([1,2,3,4]).shard(self.devices_2, axis=None)
|
||||
b = Tensor([10,20,30,40]).shard(self.devices_2, axis=None)
|
||||
np.testing.assert_equal(f(a,b).numpy(), [11,22,33,44])
|
||||
|
||||
def test_simple_multi_sharded(self):
|
||||
@function
|
||||
def f(a:Tensor, b:Tensor) -> Tensor: return a+b
|
||||
|
||||
a = Tensor([1,2,3,4]).shard(self.devices_2, axis=0)
|
||||
b = Tensor([10,20,30,40]).shard(self.devices_2, axis=0)
|
||||
np.testing.assert_equal(f(a,b).numpy(), [11,22,33,44])
|
||||
|
||||
def test_data_parallel_multi(self):
|
||||
@function
|
||||
def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
|
||||
|
||||
x = Tensor([[1.,2.],[3.,4.],[5.,6.],[7.,8.]]).shard(self.devices_2, axis=0)
|
||||
w = Tensor([[1.,0.],[0.,1.]]).shard(self.devices_2, axis=None)
|
||||
np.testing.assert_allclose(f(x, w).numpy(), [[1.,2.],[3.,4.],[5.,6.],[7.,8.]])
|
||||
|
||||
def test_grad_implicit_multi(self):
|
||||
w = Tensor([1., 2., 3., 4.], requires_grad=True).shard(self.devices_2, axis=None)
|
||||
w.realize()
|
||||
@function
|
||||
def f(x:Tensor) -> Tensor: return x * w
|
||||
|
||||
x = Tensor([4., 5., 6., 7.]).shard(self.devices_2, axis=None)
|
||||
f(x).sum().backward()
|
||||
np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6., 7.])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -867,6 +867,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def param_like(self, slot:int):
|
||||
if self.op is Ops.BIND:
|
||||
return UOp.param(slot, self.dtype, self._shape, self._device, self._min_max, self.src[0].arg[0])
|
||||
if self.axis is not None:
|
||||
return UOp.param(slot, self.dtype, self.shard_shape, self._device).multi(self.axis)
|
||||
return UOp.param(slot, self.dtype, self._shape, self._device)
|
||||
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None) -> UOp:
|
||||
|
||||
Reference in New Issue
Block a user