From a1507c7fd4dd574b77b599697e3636d514ea9b9a Mon Sep 17 00:00:00 2001 From: Elias Wahl <82230675+Eliulm@users.noreply.github.com> Date: Wed, 6 Mar 2024 00:26:21 +0100 Subject: [PATCH] Fix Tensor.dropout() with multigpu (#3619) * Tensor.rand with multilazybuffer * remove recursive + test * whitespace * another whitespace. Sorry * remove else * Canonicalize multidevice tuple + Remove src --- test/test_multitensor.py | 7 +++++++ tinygrad/tensor.py | 7 +++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 1c5db13750..4bcaaa5194 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -364,6 +364,13 @@ class TestMultiTensor(unittest.TestCase): # don't allow assigns that change axes t_none.assign(t_zero) + def test_dropout_on_shard(self): + Tensor.training = True + X = Tensor.ones(256).to(devices_2) + output = X.dropout(0.5) + output.numpy() + Tensor.training = False + @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI") class TestShrinkMultiTensorShardedAxis(unittest.TestCase): # shrink a multitensor on sharded axis diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 20ff156129..d6fac4359b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -201,8 +201,11 @@ class Tensor: # ***** creation llop entrypoint ***** @staticmethod - def _loadop(op, shape, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs): - return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs) + def _loadop(op, shape, device:Optional[Union[Tuple[str], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs): + if isinstance(device, tuple): + return Tensor(MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \ + for d in device], None), device, dtype, **kwargs) + return Tensor(LazyBuffer.loadop(op, shape, 
dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs) @staticmethod def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)