fix multi minimal (#15044)

Author: George Hotz
Date: 2026-02-27 14:31:58 +08:00
Committed by: GitHub
Parent: 3e1e12528c
Commit: 010d2790ce

2 changed files with 39 additions and 0 deletions

@@ -193,5 +193,42 @@ class TestFunction(unittest.TestCase):
    np.testing.assert_equal(a.numpy(), [1,2,3])
    np.testing.assert_equal(b.numpy(), [10,20,30])

class TestFunctionMulti(unittest.TestCase):
  devices_2 = ("CPU:0", "CPU:1")

  def test_simple_multi(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
    a = Tensor([1,2,3,4]).shard(self.devices_2, axis=None)
    b = Tensor([10,20,30,40]).shard(self.devices_2, axis=None)
    np.testing.assert_equal(f(a,b).numpy(), [11,22,33,44])

  def test_simple_multi_sharded(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
    a = Tensor([1,2,3,4]).shard(self.devices_2, axis=0)
    b = Tensor([10,20,30,40]).shard(self.devices_2, axis=0)
    np.testing.assert_equal(f(a,b).numpy(), [11,22,33,44])

  def test_data_parallel_multi(self):
    @function
    def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
    x = Tensor([[1.,2.],[3.,4.],[5.,6.],[7.,8.]]).shard(self.devices_2, axis=0)
    w = Tensor([[1.,0.],[0.,1.]]).shard(self.devices_2, axis=None)
    np.testing.assert_allclose(f(x, w).numpy(), [[1.,2.],[3.,4.],[5.,6.],[7.,8.]])

  def test_grad_implicit_multi(self):
    w = Tensor([1., 2., 3., 4.], requires_grad=True).shard(self.devices_2, axis=None)
    w.realize()
    @function
    def f(x:Tensor) -> Tensor: return x * w
    x = Tensor([4., 5., 6., 7.]).shard(self.devices_2, axis=None)
    f(x).sum().backward()
    np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6., 7.])

if __name__ == '__main__':
  unittest.main()

@@ -867,6 +867,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
  def param_like(self, slot:int):
    if self.op is Ops.BIND:
      return UOp.param(slot, self.dtype, self._shape, self._device, self._min_max, self.src[0].arg[0])
    if self.axis is not None:
      return UOp.param(slot, self.dtype, self.shard_shape, self._device).multi(self.axis)
    return UOp.param(slot, self.dtype, self._shape, self._device)
  def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None) -> UOp:
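
The two lines added in this hunk (the self.axis is not None branch) are the actual fix: a UOp sharded across devices carries a multi axis, and param_like now builds the traced parameter from the per-shard shape and re-applies that axis with .multi(...), so a @function-cached graph can be replayed against sharded inputs. Below is a minimal usage sketch in the spirit of the tests above; the import path of function is an assumption (the test file already has Tensor, function, and numpy in scope), and the devices and values are illustrative.

import numpy as np
from tinygrad import Tensor
from tinygrad import function  # assumption: the exact import path may differ by tinygrad version

devices = ("CPU:0", "CPU:1")

@function
def add(a: Tensor, b: Tensor) -> Tensor:
  return a + b

# shard both inputs along axis 0; with the param_like change above, the traced
# parameters are created with the per-shard shape and the multi axis re-applied,
# so the cached function applies cleanly to multi (sharded) tensors
a = Tensor([1, 2, 3, 4]).shard(devices, axis=0)
b = Tensor([10, 20, 30, 40]).shard(devices, axis=0)
np.testing.assert_equal(add(a, b).numpy(), [11, 22, 33, 44])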