mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
compile raise CompileError and skip only RuntimeError in multiprocess… (#4646)
* compile raise CompileError and skip only RuntimeError in multiprocess beam renderer error with multiprocess should not be skipped by beam * use `==` for dtype to dtype comparison * that needs to be is * typo
This commit is contained in:
@@ -79,7 +79,7 @@ def make_tensor(shape, dtype:dtypes, noncontiguous) -> Tensor:
|
||||
+---------------------------+------------+----------+
|
||||
"""
|
||||
contiguous = not noncontiguous
|
||||
if dtype is dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, contiguous=contiguous).cast(dtypes.bool)
|
||||
if dtype == dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, contiguous=contiguous).cast(dtypes.bool)
|
||||
elif dtype.is_unsigned(): return Tensor.randint(shape=shape, low=0, high=10, contiguous=contiguous).cast(dtype)
|
||||
elif dtype.is_int(): return Tensor.randint(shape=shape, low=-9, high=10, contiguous=contiguous).cast(dtype) # signed int
|
||||
elif dtype.is_float(): return Tensor.rand(shape=shape, low=-9, high=9, dtype=dtype, contiguous=contiguous)
|
||||
@@ -452,7 +452,7 @@ class TestIndexing(unittest.TestCase):
|
||||
|
||||
def tensor_indices_to_np(tensor: Tensor, indices):
|
||||
npt = tensor.numpy()
|
||||
idxs = tuple(i.numpy().tolist() if isinstance(i, Tensor) and i.dtype is dtypes.int64 else
|
||||
idxs = tuple(i.numpy().tolist() if isinstance(i, Tensor) and i.dtype == dtypes.int64 else
|
||||
i for i in indices)
|
||||
return npt, idxs
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ class TestDType(unittest.TestCase):
|
||||
arr = np.asarray(data, dtype=dt)
|
||||
tin = Tensor(arr).numpy()
|
||||
tor = torch.as_tensor(arr).detach().numpy()
|
||||
assert dt is tin.dtype is tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
|
||||
assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
|
||||
np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
|
||||
|
||||
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
|
||||
|
||||
@@ -706,12 +706,12 @@ class TestSchedule(unittest.TestCase):
|
||||
a = shared * 2
|
||||
b = shared * 3
|
||||
sched = check_schedule([a, b], 1)
|
||||
for si in sched[:-2]: assert all(out.dtype is dtypes.half for out in si.outputs)
|
||||
for si in sched[:-2]: assert all(out.dtype == dtypes.half for out in si.outputs)
|
||||
|
||||
# reduce
|
||||
a = z.sum(axis=0).half().float().sum(axis=0)
|
||||
sched = check_schedule(a, 2)
|
||||
for si in sched[:-1]: assert all(out.dtype is dtypes.half for out in si.outputs)
|
||||
for si in sched[:-1]: assert all(out.dtype == dtypes.half for out in si.outputs)
|
||||
|
||||
# expand
|
||||
# expand will realize just after the .float(), so requires change to realize-before-expand
|
||||
|
||||
Reference in New Issue
Block a user