fix symbolic_ops tests with Tensor.training=True (#1686)

This commit is contained in:
chenyu
2023-08-26 20:19:56 -07:00
committed by GitHub
parent 6c5dc9c153
commit 66fbf4800b
3 changed files with 10 additions and 5 deletions

View File

@@ -670,7 +670,7 @@ class Tensor:
return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret
def dropout(self, p=0.5) -> Tensor:
if not Tensor.training: return self
if not Tensor.training or p == 0: return self
mask = (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p).cast(dtypes.bool)
return self * mask * (1/(1.0 - p))