mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
tensor reduction touchup (#6402)
- fixing spacing - use get_args to get valid Literal values and raise ValueError to match, and a test for that - use `Y` to be consistent
This commit is contained in:
@@ -2109,6 +2109,9 @@ class TestOps(unittest.TestCase):
|
||||
for r in ("mean", "sum", "none"):
|
||||
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction=r),
|
||||
lambda x,y: x.cross_entropy(y, reduction=r))
|
||||
self.helper_test_exception([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction="typo"),
|
||||
lambda x,y: x.cross_entropy(y, reduction="typo"), expected=ValueError)
|
||||
|
||||
def test_cross_entropy_smoothing(self):
|
||||
for ls in (0., 0.3, 0.7, 1.):
|
||||
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls),
|
||||
|
||||
Reference in New Issue
Block a user