mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
clean up test_scatter_reduce (#12125)
This commit is contained in:
@@ -2928,13 +2928,13 @@ class TestOps(unittest.TestCase):
|
||||
@slow_test
|
||||
def test_scatter_reduce(self):
|
||||
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
|
||||
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
a = Tensor(b.detach().cpu().numpy().astype(np.int32), requires_grad=False)
|
||||
for reduce in ("sum", "prod", "mean", "amin", "amax"):
|
||||
for dim in (-1,1,-3):
|
||||
helper_test_op([(4,5,6), (4,5,6)],
|
||||
helper_test_op([(3,4,5), (3,4,5)],
|
||||
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce),
|
||||
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce), forward_only=True)
|
||||
helper_test_op([(4,5,6), (4,5,6)],
|
||||
helper_test_op([(3,4,5), (3,4,5)],
|
||||
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce, include_self=False),
|
||||
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce, include_self=False), forward_only=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user