default tensors to int32 in test_ops (#8097)
torch defaults to int64, but we care more about int32 anyway. Also remove the skip guards on tests that were only skipped because int64 is not supported on every backend.
@@ -71,6 +71,9 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
   np.random.seed(0)
   np_data = [np.random.uniform(low=low, high=high, size=size).astype(_to_np_dtype(dtypes.default_float)) for size in shps]
   ts = [torch.tensor(data, requires_grad=(not forward_only)) for data in np_data]
+  for i in range(len(ts)):
+    # NOTE: torch default int64 for python ints input
+    if ts[i].dtype == torch.int64: ts[i] = ts[i].type(torch.int32)
   tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
   return ts, tst
 
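The coercion above exists because torch infers int64 for plain Python ints, while these tests now target int32. A minimal standalone sketch of the behavior being normalized (not from the diff):

```python
import torch

# torch.tensor on Python ints infers int64; some backends tinygrad tests
# against (e.g. WEBGPU) have no 64-bit integer type, so the harness
# downcasts reference tensors to int32 before comparing.
t = torch.tensor([0, 1, 2])
assert t.dtype == torch.int64
assert t.type(torch.int32).dtype == torch.int32
```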
@@ -312,8 +315,7 @@ class TestOps(unittest.TestCase):
   def _test_cmp(self, fxn, reverse=True):
     # test different dtypes
     helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0.,1,2], [2.,1,0]])
-    if is_dtype_supported(dtypes.long):
-      helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
+    helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
     helper_test_op(None, fxn, fxn, forward_only=True, vals=[[True, True, False], [False,True,False]])
     # test broadcasting
     for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]:
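With the harness change, integer `vals` now reach both frameworks as int32, so the comparison test no longer needs a long-support guard. A rough standalone equivalent of what the un-guarded line checks, assuming tinygrad's default int dtype is int32:

```python
import numpy as np
import torch
from tinygrad import Tensor

# torch side downcast to int32 by prepare_test_op; tinygrad side is int32 by default
lhs = torch.tensor([0, 1, 2]).type(torch.int32)
rhs = Tensor([0, 1, 2])
np.testing.assert_equal((lhs < 1).numpy(), (rhs < 1).numpy())
```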
@@ -563,10 +565,7 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)
 
   def test_pow_int(self):
-    # TODO: better infra for these, helper_test_op creates buffer in long first, so WEBGPU fails
-    def _test(base, exponent):
-      np.testing.assert_equal((Tensor(base) ** Tensor(exponent)).numpy(),
-                              (torch.tensor(base, dtype=torch.int) ** torch.tensor(exponent, dtype=torch.int)).numpy())
+    def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True)
 
     for base in ([1, 2, 3], [-1, -2, -3]):
       for exponent in ([2, 3, 4], [-2, -3, -4]):
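The removed hand-rolled assert existed only because helper_test_op used to stage integer buffers in long first, breaking WEBGPU; with int32 staging it can go through the shared helper. Roughly what the removed assert did, as a standalone sketch (positive exponents shown):

```python
import numpy as np
import torch
from tinygrad import Tensor

base, exponent = [1, 2, 3], [2, 3, 4]
# torch.int is an alias for torch.int32, matching tinygrad's default int
np.testing.assert_equal(
  (Tensor(base) ** Tensor(exponent)).numpy(),
  (torch.tensor(base, dtype=torch.int32) ** torch.tensor(exponent, dtype=torch.int32)).numpy())
```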
@@ -1098,9 +1097,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,3)], lambda x: x.min().mul(0.5))
     helper_test_op([()], lambda x: x.min())
 
-    if is_dtype_supported(dtypes.long):
-      helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
-      helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
+    helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
+    helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
     helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[0, 1]] if Device.DEFAULT == "WEBGPU" else [[False, True]])
     helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[False, True]])
 
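These vals probe the int32 boundary: -2**31 is INT32_MIN, the most negative representable value and the identity element an int32 max reduction starts from, so placing it first and last catches accumulator-initialization and clipping bugs. The same pair is reused in the test_max hunk below. For instance:

```python
import torch

# -2**31 == INT32_MIN: a min reduction must return it, and a max reduction
# must not leak it from an uninitialized accumulator
t = torch.tensor([0, -2**31]).type(torch.int32)
assert t.min().item() == -2**31
assert t.max().item() == 0
```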
@@ -1111,9 +1109,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
     helper_test_op([()], lambda x: x.max())
 
-    if is_dtype_supported(dtypes.long):
-      helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
-      helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
+    helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
+    helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
     helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[False, True]])
     helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[True, False]])
 
@@ -1216,12 +1213,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(3,4,5,6)], lambda x: torch.stack(torch.std_mean(x, axis=(1,2))),
                    lambda x: Tensor.stack(*x.std_mean(axis=(1,2))))
   def test_softmax(self):
-    # exceed per kernel buffer limit with backward
-    forward_only = (Device.DEFAULT == "WEBGPU")
-    helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
-    helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
-    helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
-    helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
+    helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
   def test_softmax_other_axis(self):
     helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7)
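Dropping forward_only here means WEBGPU is now expected to run softmax's backward pass within its per-kernel buffer limit. A minimal standalone version of what each helper line verifies on the forward side, assuming float32 inputs:

```python
import numpy as np
import torch
from tinygrad import Tensor

x = np.random.uniform(-1, 1, size=(45, 65)).astype(np.float32)
np.testing.assert_allclose(Tensor(x).softmax(axis=1).numpy(),
                           torch.nn.Softmax(dim=1)(torch.tensor(x)).numpy(), atol=1e-7)
```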
@@ -2246,7 +2241,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(1,128), (128,128)], lambda x,y: (x@y).relu())
 
   @unittest.skip("this test is broken #862")
-  def test_max_inf(self):
+  def test_max_nan(self):
     n = Tensor([1, float("nan")]).max().numpy()
     assert math.isnan(n.item()), f"{n.item()} is not nan"
 
@@ -2484,45 +2479,39 @@ class TestOps(unittest.TestCase):
       helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls),
                      lambda x,y: x.cross_entropy(y, label_smoothing=ls))
 
-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss(self):
     helper_test_op([(32,10), (32)],
       lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
-      lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
+      lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)
 
-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_3d(self):
     helper_test_op([(32,10,3,3,3), (32,3,3,3)],
       lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
-      lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
+      lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)
 
-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_reductions(self):
     for r in ("mean", "sum", "none"):
       helper_test_op([(32,10), (32)],
         lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long), reduction=r),
-        lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), reduction=r), forward_only=True)
+        lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), reduction=r), forward_only=True)
     self.helper_test_exception([(32,10), (32)],
       lambda x,y: torch.nn.functional.nll_loss(x, torch.clip(y,0).type(torch.long), reduction="typo"),
-      lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.long), reduction="typo"), expected=ValueError)
+      lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.int32), reduction="typo"), expected=ValueError)
 
-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_weight(self):
     for r in ("mean", "sum", "none"):
       helper_test_op([(32,10), (32), (10)],
         lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long),
                                                    weight=z, reduction=r),
-        lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
+        lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)
 
-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_3d_weight(self):
     for r in ("mean", "sum", "none"):
       helper_test_op([(32,10,3,3,3), (32,3,3,3), (10)],
         lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long),
                                                    weight=z, reduction=r),
-        lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
+        lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)
 
-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_ignore_index(self):
     logits = [[2.0, 0.5, -1.0],
               [1.5, 2.5, -0.5],
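torch.nn.functional.nll_loss insists on int64 targets, which is why the torch lambdas keep .type(torch.long); only the tinygrad side switches to int32, which is what lets the long-support skip guards go away. A standalone sketch of the forward comparison, with shapes from the test and random data:

```python
import numpy as np
import torch
from tinygrad import Tensor, dtypes

logits = np.random.randn(32, 10).astype(np.float32)
targets = np.random.randint(0, 10, size=(32,)).astype(np.int32)

ref = torch.nn.functional.nll_loss(
  torch.nn.functional.log_softmax(torch.tensor(logits), dim=1),
  torch.tensor(targets).type(torch.long))  # torch requires long targets
out = Tensor(logits).log_softmax(axis=1).nll_loss(Tensor(targets).cast(dtypes.int32))
np.testing.assert_allclose(out.numpy(), ref.numpy(), rtol=1e-5)
```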
@@ -2530,7 +2519,7 @@ class TestOps(unittest.TestCase):
     targets = [0, 1, 2]
     helper_test_op(None, lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1),
                                                                   torch.clip(y,0).type(torch.long), ignore_index=1),
-                   lambda x,y: x.log_softmax().nll_loss(y.clip(0).cast(dtypes.long), ignore_index=1),
+                   lambda x,y: x.log_softmax().nll_loss(y.clip(0), ignore_index=1),
                    forward_only=True, vals=[logits, targets])
 
   def test_one_hot(self):
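Here y.clip(0) already yields an int32 tensor under the new defaults, so the explicit cast is dropped entirely. For reference, ignore_index=1 excludes every sample whose target equals 1 from the loss; a torch-only sketch (the third logits row is hypothetical, the hunk truncates it):

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0], [1.5, 2.5, -0.5], [0.1, 0.2, 0.3]])
targets = torch.tensor([0, 1, 2])
lsm = torch.nn.functional.log_softmax(logits, dim=1)
# sample 1 (target == 1) is dropped; the mean runs over the remaining two
masked = torch.nn.functional.nll_loss(lsm, targets, ignore_index=1)
assert torch.allclose(masked, (-lsm[0, 0] - lsm[2, 2]) / 2)
```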
@@ -2552,8 +2541,7 @@ class TestOps(unittest.TestCase):
   @unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
   def test_cast(self):
     helper_test_op([(3, 3)], lambda x: x.float())
-    if is_dtype_supported(dtypes.long):
-      helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
+    helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
     helper_test_op(None, lambda x: x.float(), vals=[[True, False]], forward_only=True)
     helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
     helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)
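As elsewhere, the integer-vals cast check now runs unconditionally because the Python ints arrive as int32 rather than int64. A minimal standalone version, assuming int32 and float32 are tinygrad's default int and float dtypes:

```python
from tinygrad import Tensor, dtypes

t = Tensor([0, 1, 2, 3])            # created as int32 under the new defaults
assert t.dtype == dtypes.int32
assert t.float().dtype == dtypes.float32
```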
@@ -2587,7 +2575,6 @@ class TestOpsUint8(unittest.TestCase):
                      lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="nearest-exact"),
                      lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="nearest-exact"), forward_only=True)
 
-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_min(self):
     helper_test_op(None,
       lambda x: x.type(torch.uint8).min(),
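The removed guard made the whole uint8 min test depend on long support, even though the values fit comfortably in 8 bits; the inputs only passed through int64 transiently on the torch side before this commit. A rough standalone check of the reduction itself:

```python
import torch

# uint8 min needs no 64-bit support: once downcast, the reduction runs in uint8
t = torch.tensor([3, 1, 2]).type(torch.uint8)
assert t.min().item() == 1
```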