tensor reduction touchup (#6402)

- fixing spacing
- use get_args to get valid Literal values and raise ValueError to match, and a test for that
- use `Y` to be consistent
This commit is contained in:
chenyu
2024-09-08 03:55:51 -04:00
committed by GitHub
parent 65da03e186
commit 7df4373fd9
2 changed files with 19 additions and 16 deletions

View File

@@ -2109,6 +2109,9 @@ class TestOps(unittest.TestCase):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction=r),
lambda x,y: x.cross_entropy(y, reduction=r))
self.helper_test_exception([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction="typo"),
lambda x,y: x.cross_entropy(y, reduction="typo"), expected=ValueError)
def test_cross_entropy_smoothing(self):
for ls in (0., 0.3, 0.7, 1.):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls),