mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
update test_ops for tensors from torch (#9308)
a few detach().numpy() -> detach().cpu().numpy()
This commit is contained in:
@@ -2504,7 +2504,7 @@ class TestOps(unittest.TestCase):
|
||||
c = torch.randint(low=-5, high=5, size=(1,1,4,1,1,1), dtype=torch.int64, requires_grad=False)
|
||||
d = torch.randint(high=4, size=(2,1,1,5,1,1), dtype=torch.int64, requires_grad=False)
|
||||
e = torch.randint(high=1, size=(1,1,1,1,6,1), dtype=torch.int64, requires_grad=False)
|
||||
i, j, k, o, p = [Tensor(tor.detach().numpy().astype(np.int32), requires_grad=False) for tor in [a,b,c,d,e]]
|
||||
i, j, k, o, p = [Tensor(tor.detach().cpu().numpy().astype(np.int32), requires_grad=False) for tor in [a,b,c,d,e]]
|
||||
return a,b,c,d,e,i,j,k,o,p
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU can only run kernels with up to 10 buffers")
|
||||
@@ -2601,7 +2601,7 @@ class TestOps(unittest.TestCase):
|
||||
# indices cannot have gradient
|
||||
# indices cannot be negative (torch gather)
|
||||
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
|
||||
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
helper_test_op([(4,5,6)], lambda x: x.gather(dim=0, index=b), lambda x: x.gather(dim=0, index=a))
|
||||
helper_test_op([(4,5,6)], lambda x: x.gather(dim=1, index=b), lambda x: x.gather(dim=1, index=a))
|
||||
helper_test_op([(4,5,6)], lambda x: x.gather(dim=2, index=b), lambda x: x.gather(dim=2, index=a))
|
||||
@@ -2619,7 +2619,7 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_scatter(self):
|
||||
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
|
||||
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
for dim in (0,1,2,-1,-2,-3):
|
||||
helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src),
|
||||
lambda x,src: x.scatter(dim=dim, index=a, src=src), forward_only=True)
|
||||
@@ -2644,7 +2644,7 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
# overlapping indices with 0s
|
||||
b = torch.tensor([0,0], requires_grad=False)
|
||||
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
helper_test_op(None,
|
||||
lambda x,src: x.scatter(0, b, src),
|
||||
lambda x,src: x.scatter(0, a, src), forward_only=True,
|
||||
@@ -2652,7 +2652,7 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_scatter_add(self):
|
||||
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
|
||||
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="add"),
|
||||
lambda x: x.scatter(dim=1, index=a, src=float("inf"), reduce="add"), forward_only=True)
|
||||
|
||||
@@ -2664,7 +2664,7 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_scatter_mul(self):
|
||||
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
|
||||
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="multiply"),
|
||||
lambda x: x.scatter(dim=1, index=a, src=float("inf"), reduce="multiply"), forward_only=True)
|
||||
|
||||
@@ -2680,7 +2680,7 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_scatter_reduce(self):
|
||||
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
|
||||
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
for reduce in ("sum", "prod", "mean", "amin", "amax"):
|
||||
for dim in (0,1,2,-1,-2,-3):
|
||||
helper_test_op([(4,5,6), (4,5,6)],
|
||||
@@ -2692,7 +2692,7 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_scatter_reduce_prod_zeros(self):
|
||||
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
|
||||
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
x = Tensor.zeros([4,5,6]).float()
|
||||
y = torch.zeros([4,5,6]).float()
|
||||
helper_test_op([(4,5,6)],
|
||||
@@ -2701,7 +2701,7 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_scatter_reduce_errors(self):
|
||||
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
|
||||
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
|
||||
# invalid reduce arg
|
||||
self.helper_test_exception([(4,5,6), (4,5,6)],
|
||||
lambda x,src: x.scatter_reduce(dim=0, index=b, src=src, reduce="INVALID"),
|
||||
|
||||
Reference in New Issue
Block a user