compile raise CompileError and skip only RuntimeError in multiprocess… (#4646)

* compile raise CompileError and skip only RuntimeError in multiprocess beam

renderer error with multiprocess should not be skipped by beam

* use `==` for dtype to dtype comparison

* that needs to be is

* typo
This commit is contained in:
chenyu
2024-05-19 00:25:25 -04:00
committed by GitHub
parent 8a0d1ca7bb
commit 286b4dbdf2
14 changed files with 30 additions and 24 deletions

View File

@@ -79,7 +79,7 @@ def make_tensor(shape, dtype:dtypes, noncontiguous) -> Tensor:
+---------------------------+------------+----------+
"""
contiguous = not noncontiguous
if dtype is dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, contiguous=contiguous).cast(dtypes.bool)
if dtype == dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, contiguous=contiguous).cast(dtypes.bool)
elif dtype.is_unsigned(): return Tensor.randint(shape=shape, low=0, high=10, contiguous=contiguous).cast(dtype)
elif dtype.is_int(): return Tensor.randint(shape=shape, low=-9, high=10, contiguous=contiguous).cast(dtype) # signed int
elif dtype.is_float(): return Tensor.rand(shape=shape, low=-9, high=9, dtype=dtype, contiguous=contiguous)
@@ -452,7 +452,7 @@ class TestIndexing(unittest.TestCase):
def tensor_indices_to_np(tensor: Tensor, indices):
npt = tensor.numpy()
idxs = tuple(i.numpy().tolist() if isinstance(i, Tensor) and i.dtype is dtypes.int64 else
idxs = tuple(i.numpy().tolist() if isinstance(i, Tensor) and i.dtype == dtypes.int64 else
i for i in indices)
return npt, idxs

View File

@@ -113,7 +113,7 @@ class TestDType(unittest.TestCase):
arr = np.asarray(data, dtype=dt)
tin = Tensor(arr).numpy()
tor = torch.as_tensor(arr).detach().numpy()
assert dt is tin.dtype is tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):

View File

@@ -706,12 +706,12 @@ class TestSchedule(unittest.TestCase):
a = shared * 2
b = shared * 3
sched = check_schedule([a, b], 1)
for si in sched[:-2]: assert all(out.dtype is dtypes.half for out in si.outputs)
for si in sched[:-2]: assert all(out.dtype == dtypes.half for out in si.outputs)
# reduce
a = z.sum(axis=0).half().float().sum(axis=0)
sched = check_schedule(a, 2)
for si in sched[:-1]: assert all(out.dtype is dtypes.half for out in si.outputs)
for si in sched[:-1]: assert all(out.dtype == dtypes.half for out in si.outputs)
# expand
# expand will realize just after the .float(), so requires change to realize-before-expand