torch.clip(x,y) -> x.clip(y) in test_ops (#10954)

* torch.clip(x,y) -> x.clip(y) in test_ops

* test_binary_crossentropy_logits_pos_weights
This commit is contained in:
chenyu
2025-06-24 10:22:19 -04:00
committed by GitHub
parent 86d458533f
commit 35504c938e

View File

@@ -2894,23 +2894,23 @@ class TestOps(unittest.TestCase):
expected=RuntimeError)
def test_binary_crossentropy(self):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)),
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),y.clip(0,1)),
lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)),
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,y.clip(0,1)),
lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)),
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,y.clip(0,1)),
lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)),
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),y.clip(0,1)),
lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
def test_binary_crossentropy_reductions(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(), torch.clip(y,0,1), reduction=r),
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(), y.clip(0,1), reduction=r),
lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1), reduction=r))
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x, torch.clip(y,0,1), reduction=r),
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x, y.clip(0,1), reduction=r),
lambda x,y: x.binary_crossentropy_logits(y.clip(0,1), reduction=r))
def test_binary_crossentropy_pos_weights(self):
def test_binary_crossentropy_logits_pos_weights(self):
pos_weight = [0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1),
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,y.clip(0,1),
pos_weight=torch.tensor(pos_weight)),
lambda x,y: x.binary_crossentropy_logits(y.clip(0,1),pos_weight=Tensor(pos_weight)))
def test_cross_entropy(self):
@@ -2932,34 +2932,34 @@ class TestOps(unittest.TestCase):
def test_nll_loss(self):
helper_test_op([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), y.clip(0).type(torch.long)),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)
def test_nll_loss_3d(self):
helper_test_op([(32,10,3,3,3), (32,3,3,3)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), y.clip(0).type(torch.long)),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)
def test_nll_loss_reductions(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long), reduction=r),
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), y.clip(0).type(torch.long), reduction=r),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), reduction=r), forward_only=True)
self.helper_test_exception([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(x, torch.clip(y,0).type(torch.long), reduction="typo"),
lambda x,y: torch.nn.functional.nll_loss(x, y.clip(0).type(torch.long), reduction="typo"),
lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.int32), reduction="typo"), expected=ValueError)
def test_nll_loss_weight(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10), (32), (10)],
lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long),
lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), y.clip(0).type(torch.long),
weight=z, reduction=r),
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)
def test_nll_loss_3d_weight(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10,3,3,3), (32,3,3,3), (10)],
lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long),
lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), y.clip(0).type(torch.long),
weight=z, reduction=r),
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)
@@ -2969,7 +2969,7 @@ class TestOps(unittest.TestCase):
[0.0, -2.0, 1.0]]
targets = [0, 1, 2]
helper_test_op(None, lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1),
torch.clip(y,0).type(torch.long), ignore_index=1),
y.clip(0).type(torch.long), ignore_index=1),
lambda x,y: x.log_softmax().nll_loss(y.clip(0), ignore_index=1),
forward_only=True, vals=[logits, targets])