clean up test_scatter_reduce (#12125)

This commit is contained in:
chenyu
2025-09-11 16:36:58 -04:00
committed by GitHub
parent 9ad6a56d17
commit 544eb2c402

View File

@@ -2928,13 +2928,13 @@ class TestOps(unittest.TestCase):
@slow_test
def test_scatter_reduce(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
a = Tensor(b.detach().cpu().numpy().astype(np.int32), requires_grad=False)
for reduce in ("sum", "prod", "mean", "amin", "amax"):
for dim in (-1,1,-3):
helper_test_op([(4,5,6), (4,5,6)],
helper_test_op([(3,4,5), (3,4,5)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce), forward_only=True)
helper_test_op([(4,5,6), (4,5,6)],
helper_test_op([(3,4,5), (3,4,5)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce, include_self=False),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce, include_self=False), forward_only=True)