mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix TestDropoutProbabilityEdgeCases (#11322)
This commit is contained in:
@@ -97,21 +97,17 @@ class TestEmptyTensorEdgeCases(unittest.TestCase):
|
||||
class TestDropoutProbabilityEdgeCases(unittest.TestCase):
|
||||
# we don't need more of these
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_dropout_rate_one(self):
|
||||
# out is full of NaNs it should be 0s
|
||||
with Tensor.train():
|
||||
out = Tensor.ones(100).dropout(1.0)
|
||||
np.testing.assert_allclose(out.numpy(), np.zeros(100))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_dropout_invalid_prob(self):
|
||||
# negative dropout probability should raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
torch.nn.functional.dropout(torch.ones(10), -0.1, True)
|
||||
with Tensor.train():
|
||||
out = Tensor.ones(10).dropout(-0.1)
|
||||
np.testing.assert_allclose(out.numpy(), np.ones(10))
|
||||
with self.assertRaises(ValueError):
|
||||
with Tensor.train():
|
||||
Tensor.ones(10).dropout(-0.1)
|
||||
|
||||
class TestInputValidation(unittest.TestCase):
|
||||
# we don't need more of these, input validation bugs are not very interesting, many are WONTFIX
|
||||
|
||||
@@ -3853,7 +3853,9 @@ class Tensor(MathTrait):
|
||||
print(t.dropout().numpy())
|
||||
```
|
||||
"""
|
||||
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
|
||||
if not Tensor.training or p == 0: return self
|
||||
if p == 1: return self.zeros_like()
|
||||
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
|
||||
|
||||
# helper function commonly used for indexing
|
||||
|
||||
Reference in New Issue
Block a user