fix TestDropoutProbabilityEdgeCases (#11322)

This commit is contained in:
chenyu
2025-07-22 11:13:56 -04:00
committed by GitHub
parent fb42c84365
commit c6aa8e58ca
2 changed files with 5 additions and 7 deletions

View File

@@ -97,21 +97,17 @@ class TestEmptyTensorEdgeCases(unittest.TestCase):
class TestDropoutProbabilityEdgeCases(unittest.TestCase):
# we don't need more of these
@unittest.expectedFailure
def test_dropout_rate_one(self):
# out is full of NaNs it should be 0s
with Tensor.train():
out = Tensor.ones(100).dropout(1.0)
np.testing.assert_allclose(out.numpy(), np.zeros(100))
@unittest.expectedFailure
def test_dropout_invalid_prob(self):
# negative dropout probability should raise an error
with self.assertRaises(ValueError):
torch.nn.functional.dropout(torch.ones(10), -0.1, True)
with Tensor.train():
out = Tensor.ones(10).dropout(-0.1)
np.testing.assert_allclose(out.numpy(), np.ones(10))
with self.assertRaises(ValueError):
with Tensor.train():
Tensor.ones(10).dropout(-0.1)
class TestInputValidation(unittest.TestCase):
# we don't need more of these, input validation bugs are not very interesting, many are WONTFIX

View File

@@ -3853,7 +3853,9 @@ class Tensor(MathTrait):
print(t.dropout().numpy())
```
"""
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
if not Tensor.training or p == 0: return self
if p == 1: return self.zeros_like()
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
# helper function commonly used for indexing