update test_ops for tensors from torch (#9308)

a few detach().numpy() -> detach().cpu().numpy()
This commit is contained in:
chenyu
2025-02-28 15:57:25 -05:00
committed by GitHub
parent 38d7aae3b7
commit fe0f860209

View File

@@ -2504,7 +2504,7 @@ class TestOps(unittest.TestCase):
c = torch.randint(low=-5, high=5, size=(1,1,4,1,1,1), dtype=torch.int64, requires_grad=False)
d = torch.randint(high=4, size=(2,1,1,5,1,1), dtype=torch.int64, requires_grad=False)
e = torch.randint(high=1, size=(1,1,1,1,6,1), dtype=torch.int64, requires_grad=False)
i, j, k, o, p = [Tensor(tor.detach().numpy().astype(np.int32), requires_grad=False) for tor in [a,b,c,d,e]]
i, j, k, o, p = [Tensor(tor.detach().cpu().numpy().astype(np.int32), requires_grad=False) for tor in [a,b,c,d,e]]
return a,b,c,d,e,i,j,k,o,p
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU can only run kernels with up to 10 buffers")
@@ -2601,7 +2601,7 @@ class TestOps(unittest.TestCase):
# indices cannot have gradient
# indices cannot be negative (torch gather)
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
helper_test_op([(4,5,6)], lambda x: x.gather(dim=0, index=b), lambda x: x.gather(dim=0, index=a))
helper_test_op([(4,5,6)], lambda x: x.gather(dim=1, index=b), lambda x: x.gather(dim=1, index=a))
helper_test_op([(4,5,6)], lambda x: x.gather(dim=2, index=b), lambda x: x.gather(dim=2, index=a))
@@ -2619,7 +2619,7 @@ class TestOps(unittest.TestCase):
def test_scatter(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src),
lambda x,src: x.scatter(dim=dim, index=a, src=src), forward_only=True)
@@ -2644,7 +2644,7 @@ class TestOps(unittest.TestCase):
# overlapping indices with 0s
b = torch.tensor([0,0], requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
helper_test_op(None,
lambda x,src: x.scatter(0, b, src),
lambda x,src: x.scatter(0, a, src), forward_only=True,
@@ -2652,7 +2652,7 @@ class TestOps(unittest.TestCase):
def test_scatter_add(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="add"),
lambda x: x.scatter(dim=1, index=a, src=float("inf"), reduce="add"), forward_only=True)
@@ -2664,7 +2664,7 @@ class TestOps(unittest.TestCase):
def test_scatter_mul(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="multiply"),
lambda x: x.scatter(dim=1, index=a, src=float("inf"), reduce="multiply"), forward_only=True)
@@ -2680,7 +2680,7 @@ class TestOps(unittest.TestCase):
def test_scatter_reduce(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for reduce in ("sum", "prod", "mean", "amin", "amax"):
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)],
@@ -2692,7 +2692,7 @@ class TestOps(unittest.TestCase):
def test_scatter_reduce_prod_zeros(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
x = Tensor.zeros([4,5,6]).float()
y = torch.zeros([4,5,6]).float()
helper_test_op([(4,5,6)],
@@ -2701,7 +2701,7 @@ class TestOps(unittest.TestCase):
def test_scatter_reduce_errors(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
# invalid reduce arg
self.helper_test_exception([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=0, index=b, src=src, reduce="INVALID"),