diff --git a/test/test_jit.py b/test/test_jit.py
index 7b0b5dc5f4..816117f69e 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5,7 +5,7 @@ import numpy as np
 from hypothesis import given, settings, strategies as strat
 from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV, needs_second_gpu
 from tinygrad.tensor import Tensor
-from tinygrad.engine.jit import TinyJit, GraphRunner, MultiGraphRunner, graph_class
+from tinygrad.engine.jit import TinyJit, JitError, GraphRunner, MultiGraphRunner, graph_class
 from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
 from tinygrad.device import Device
 from tinygrad.helpers import Context, JIT, GlobalCounters, getenv
@@ -76,7 +76,7 @@ class TestJit(unittest.TestCase):
   def test_nothing_jitted(self):
     @TinyJit
     def add(a, b): return None
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       for _ in range(5):
         a = Tensor.randn(10, 10)
         b = Tensor.randn(10, 10)
@@ -125,13 +125,13 @@ class TestJit(unittest.TestCase):
       b = Tensor.randn(10, 10)
       add(a, b)
     bad = Tensor.randn(20, 20)
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       add(a, bad)
 
   def test_jit_shape_views_mismatch(self):
     @TinyJit
     def add(a): return (a+1).realize()
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       for i in range(1,5):
         # a has an offset that the kernel doesn't know about
         a = Tensor.randn(10, 10).realize()[:, i:i+2]
@@ -142,7 +142,7 @@ class TestJit(unittest.TestCase):
     @TinyJit
     def add(a, b): return (a+b).realize()
     a = Tensor.randn(10, 10)
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       add(a, a)
 
   def test_jit_assign(self, dtype=dtypes.float32):
@@ -510,7 +510,7 @@ class TestJit(unittest.TestCase):
     # TODO: this should fail since input has a different size
     f(Tensor(2.0)).item()
     # TODO: this should not fail, and should return 3
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       f(Tensor([2.0])).item()
 
   @unittest.skip("Pending multioutput implementation #3607")
diff --git a/test/test_jit_footguns.py b/test/test_jit_footguns.py
index bf0c596de0..8a8506b827 100644
--- a/test/test_jit_footguns.py
+++ b/test/test_jit_footguns.py
@@ -21,6 +21,7 @@ ERRORS RAISED (lower priority - at least users know):
 import unittest
 import numpy as np
 from tinygrad import Tensor, TinyJit
+from tinygrad.engine.jit import JitError
 
 
 class TestJitFootguns(unittest.TestCase):
@@ -70,7 +71,7 @@ class TestJitFootguns(unittest.TestCase):
     def f(a, b): return (a + b).realize()
 
     x = Tensor([1, 2, 3])
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       f(x, x)
 
   def test_tensors_in_containers_ignored(self):
@@ -116,7 +117,7 @@ class TestJitFootguns(unittest.TestCase):
     def f(a): return (a + 1).realize()
 
     base = Tensor.randn(10, 10).realize()
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       for i in range(1, 5):
         f(base[:, i:i+2])  # different offset each time
 
@@ -128,7 +129,7 @@ class TestJitFootguns(unittest.TestCase):
     def f(a, b): return (a + b).realize()
     f(Tensor.randn(10, 10), Tensor.randn(10, 10))  # warmup
     f(Tensor.randn(10, 10), Tensor.randn(10, 10))  # capture
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       f(Tensor.randn(20, 20), Tensor.randn(20, 20))
 
   def test_python_constants_frozen(self):
@@ -170,7 +171,7 @@ class TestJitFootguns(unittest.TestCase):
 
     f(Tensor([1]), Tensor([2]))  # warmup with positional
     f(Tensor([1]), Tensor([2]))  # capture with positional
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       f(a=Tensor([3]), b=Tensor([4]))  # kwargs fail
 
   def test_class_method_shared_across_instances(self):
@@ -213,7 +214,7 @@ class TestJitFootguns(unittest.TestCase):
 
     @TinyJit
     def f(a, b): return None
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       for _ in range(3):
         f(Tensor([1]), Tensor([2]))
 
diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py
index 9174a47187..ddc464661d 100644
--- a/test/test_symbolic_jit.py
+++ b/test/test_symbolic_jit.py
@@ -2,6 +2,7 @@ import unittest
 
 from test.helpers import assert_jit_cache_len
 from tinygrad import Variable, Tensor, TinyJit
+from tinygrad.engine.jit import JitError
 import numpy as np
 
 class TestSymbolicJit(unittest.TestCase):
@@ -172,7 +173,7 @@ class TestSymbolicJit(unittest.TestCase):
     vi2 = Variable("i", 1, 10).bind(7)
     a = Tensor.rand(3, 7)[:, :vi2]
     bad = Tensor.rand(4, 7)[:, :vi2]
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       add(a, bad)
 
   def test_shrink(self):
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 146472bcd0..46a5517eec 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -13,6 +13,7 @@ from dataclasses import dataclass, replace
 from weakref import WeakKeyDictionary
 
 class GraphException(Exception): pass
+class JitError(Exception): pass
 
 def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
 
@@ -225,7 +226,7 @@ def _prepare_jit_inputs(args, kwargs):
   lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
   input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb] for lb in lbs if lb.base.realized is not None])
-  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
+  if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT")
   st_varval_dtype_device = [(*(lb.substitute({lb.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), lb.dtype, lb.device) for lb in lbs]
   _var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
   var_vals = {k.expr:v for k,v in _var_vals.items()}
@@ -294,7 +295,7 @@ class TinyJit(Generic[ReturnType]):
         finally: capturing.clear()
       jit_cache = self._jit_cache
       del self._buffer_replace, self._jit_cache
-      assert len(jit_cache), "didn't JIT anything!"
+      if not len(jit_cache): raise JitError("didn't JIT anything!")
       if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")
 
       # track inputs that are views of buffers
@@ -333,9 +334,9 @@ class TinyJit(Generic[ReturnType]):
     elif self.cnt >= 2:
       # jit exec
       assert self.captured is not None
-      assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}"
-      assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \
-        f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
+      if self.captured.expected_names != names: raise JitError(f"args mismatch in JIT: {self.captured.expected_names=} != {names}")
+      if self.captured.expected_st_vars_dtype_device != st_vars_dtype_device:
+        raise JitError(f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}")
       ret = self.captured(input_buffers, var_vals)
 
     self.cnt += 1
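
Note on usage: with this patch, JIT misuse (duplicate inputs, nothing captured, or an argument/shape mismatch against the captured run) raises JitError from tinygrad.engine.jit instead of a bare AssertionError, so callers can catch it specifically. Below is a minimal sketch of what that looks like from user code; the add function and the shapes are illustrative, not part of the diff.

from tinygrad import Tensor, TinyJit
from tinygrad.engine.jit import JitError

@TinyJit
def add(a, b): return (a + b).realize()

# warmup and capture with 10x10 inputs
for _ in range(2): add(Tensor.randn(10, 10), Tensor.randn(10, 10))

# a later call with a different shape no longer trips an assert;
# it raises JitError, which the caller can handle explicitly
try:
  add(Tensor.randn(20, 20), Tensor.randn(20, 20))
except JitError as e:
  print("JIT rejected the call:", e)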