improve test_nll_loss (#10986)

build the target and weight tensors outside the test lambdas so the tests exercise the backward pass too.
Author: chenyu
Date: 2025-06-26 02:46:55 -04:00
Committed by: GitHub
parent 0612acfc70
commit 49bba2f0a0
2 changed files with 30 additions and 25 deletions
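Integer class-index tensors cannot require grad, so as long as the target was generated as a helper_test_op input, these tests had to run with forward_only=True. Building target (and weight) outside the lambdas and closing over them leaves the float logits as the only test input, so the gradient check runs as well. A minimal sketch of the idea in plain PyTorch (names and shapes mirror the tests; the helper_test_op harness itself is not shown):

import numpy as np
import torch
import torch.nn.functional as F

# build the target outside: it is a constant, not a differentiable test input
target = torch.tensor(np.random.randint(0, 10, (32,), dtype=np.int64))

x = torch.randn(32, 10, requires_grad=True)  # logits are the only grad input
loss = F.nll_loss(F.log_softmax(x, dim=1), target)
loss.backward()  # gradient flows to x; an integer target could never require grad
assert x.grad is not None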

@@ -369,6 +369,7 @@ decomps = [
   aten.threshold,
   aten.nll_loss_forward,
   aten.nll_loss_backward,
+  aten.nll_loss2d_backward,
   # AttributeError: 'int' object has no attribute '_broadcasted'
   aten.sigmoid_backward,
   aten.tanh_backward,
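The first hunk is in the torch backend's decomposition list: aten.nll_loss2d_backward is the op autograd dispatches for the backward of nll_loss on spatial inputs, which the 3d tests now hit once they run backward. A sketch of how such a list is commonly turned into a decomposition table with stock PyTorch machinery (assumed usage for illustration, not necessarily this backend's exact code):

import torch
from torch._decomp import get_decompositions

# ops to lower into simpler aten ops instead of implementing natively
decomps = [
    torch.ops.aten.nll_loss_forward,
    torch.ops.aten.nll_loss_backward,
    torch.ops.aten.nll_loss2d_backward,  # newly added: backward of the spatial variant
]
decomposition_table = get_decompositions(decomps)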


@@ -2946,47 +2946,51 @@ class TestOps(unittest.TestCase):
       lambda x: x.cross_entropy(Tensor(classes), label_smoothing=ls))
 
   def test_nll_loss(self):
-    helper_test_op([(32,10), (32)],
-      lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), y.clip(0).type(torch.long)),
-      lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)
+    target = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()
+    helper_test_op([(32,10)],
+      lambda x: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.tensor(target)),
+      lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target)))
 
   def test_nll_loss_3d(self):
-    helper_test_op([(32,10,3,3,3), (32,3,3,3)],
-      lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), y.clip(0).type(torch.long)),
-      lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)
+    target = np.random.randint(0, 10, (32,3,3,3), dtype=np.int32).tolist()
+    helper_test_op([(32,10,3,3,3)],
+      lambda x: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.tensor(target)),
+      lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target)))
 
   def test_nll_loss_reductions(self):
+    target = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()
     for r in ("mean", "sum", "none"):
-      helper_test_op([(32,10), (32)],
-        lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), y.clip(0).type(torch.long), reduction=r),
-        lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), reduction=r), forward_only=True)
-    self.helper_test_exception([(32,10), (32)],
-      lambda x,y: torch.nn.functional.nll_loss(x, y.clip(0).type(torch.long), reduction="typo"),
-      lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.int32), reduction="typo"), expected=ValueError)
+      helper_test_op([(32,10)],
+        lambda x: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.tensor(target), reduction=r),
+        lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target), reduction=r))
+    self.helper_test_exception([(32,10)],
+      lambda x: torch.nn.functional.nll_loss(x, torch.tensor(target), reduction="typo"),
+      lambda x: x.nll_loss(Tensor(target), reduction="typo"), expected=ValueError)
 
   def test_nll_loss_weight(self):
+    target = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()
+    weight = np.random.normal(0, 1, (10,)).astype(np.float32).tolist()
     for r in ("mean", "sum", "none"):
-      helper_test_op([(32,10), (32), (10)],
-        lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), y.clip(0).type(torch.long),
-                                                   weight=z, reduction=r),
-        lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)
+      helper_test_op([(32,10)],
+        lambda x: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.tensor(target), torch.tensor(weight), reduction=r),
+        lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target), Tensor(weight), reduction=r))
 
   def test_nll_loss_3d_weight(self):
+    target = np.random.randint(0, 10, (32,3,3,3), dtype=np.int32).tolist()
+    weight = np.random.normal(0, 1, (10,)).astype(np.float32).tolist()
     for r in ("mean", "sum", "none"):
-      helper_test_op([(32,10,3,3,3), (32,3,3,3), (10)],
-        lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), y.clip(0).type(torch.long),
-                                                   weight=z, reduction=r),
-        lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)
+      helper_test_op([(32,10,3,3,3)],
+        lambda x: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.tensor(target), torch.tensor(weight), reduction=r),
+        lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target), Tensor(weight), reduction=r))
 
   def test_nll_loss_ignore_index(self):
     logits = [[2.0, 0.5, -1.0],
               [1.5, 2.5, -0.5],
               [0.0, -2.0, 1.0]]
-    targets = [0, 1, 2]
-    helper_test_op(None, lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1),
-                                                                  y.clip(0).type(torch.long), ignore_index=1),
-                         lambda x,y: x.log_softmax().nll_loss(y.clip(0), ignore_index=1),
-                         forward_only=True, vals=[logits, targets])
+    target = [0, 1, 2]
+    helper_test_op(None, lambda x: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.tensor(target), ignore_index=1),
+                         lambda x: x.log_softmax().nll_loss(Tensor(target), ignore_index=1),
+                         vals=[logits])
 
   def test_one_hot(self):
     data = [1, 2, 4]
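With the targets hoisted out, the ignore_index case is gradient-checked too: a sample whose target equals ignore_index contributes neither to the loss nor to the gradient. A small standalone check of that semantics in plain PyTorch, using the same logits as the test (illustrative, independent of the test harness):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [1.5, 2.5, -0.5],
                       [0.0, -2.0, 1.0]], requires_grad=True)
target = torch.tensor([0, 1, 2])

loss = F.nll_loss(F.log_softmax(logits, dim=1), target, ignore_index=1)
loss.backward()
# the sample with target == ignore_index is dropped from the mean and gets zero grad
assert torch.all(logits.grad[1] == 0)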