* JitError

* test_symbolic_jit
This commit is contained in:
chenyu
2026-01-06 12:19:50 -05:00
committed by GitHub
parent 6ddddc68af
commit 4491ec0c9e
4 changed files with 20 additions and 17 deletions

View File

@@ -5,7 +5,7 @@ import numpy as np
from hypothesis import given, settings, strategies as strat from hypothesis import given, settings, strategies as strat
from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV, needs_second_gpu from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV, needs_second_gpu
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit, GraphRunner, MultiGraphRunner, graph_class from tinygrad.engine.jit import TinyJit, JitError, GraphRunner, MultiGraphRunner, graph_class
from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
from tinygrad.device import Device from tinygrad.device import Device
from tinygrad.helpers import Context, JIT, GlobalCounters, getenv from tinygrad.helpers import Context, JIT, GlobalCounters, getenv
@@ -76,7 +76,7 @@ class TestJit(unittest.TestCase):
def test_nothing_jitted(self): def test_nothing_jitted(self):
@TinyJit @TinyJit
def add(a, b): return None def add(a, b): return None
with self.assertRaises(AssertionError): with self.assertRaises(JitError):
for _ in range(5): for _ in range(5):
a = Tensor.randn(10, 10) a = Tensor.randn(10, 10)
b = Tensor.randn(10, 10) b = Tensor.randn(10, 10)
@@ -125,13 +125,13 @@ class TestJit(unittest.TestCase):
b = Tensor.randn(10, 10) b = Tensor.randn(10, 10)
add(a, b) add(a, b)
bad = Tensor.randn(20, 20) bad = Tensor.randn(20, 20)
with self.assertRaises(AssertionError): with self.assertRaises(JitError):
add(a, bad) add(a, bad)
def test_jit_shape_views_mismatch(self): def test_jit_shape_views_mismatch(self):
@TinyJit @TinyJit
def add(a): return (a+1).realize() def add(a): return (a+1).realize()
with self.assertRaises(AssertionError): with self.assertRaises(JitError):
for i in range(1,5): for i in range(1,5):
# a has an offset that the kernel doesn't know about # a has an offset that the kernel doesn't know about
a = Tensor.randn(10, 10).realize()[:, i:i+2] a = Tensor.randn(10, 10).realize()[:, i:i+2]
@@ -142,7 +142,7 @@ class TestJit(unittest.TestCase):
@TinyJit @TinyJit
def add(a, b): return (a+b).realize() def add(a, b): return (a+b).realize()
a = Tensor.randn(10, 10) a = Tensor.randn(10, 10)
with self.assertRaises(AssertionError): with self.assertRaises(JitError):
add(a, a) add(a, a)
def test_jit_assign(self, dtype=dtypes.float32): def test_jit_assign(self, dtype=dtypes.float32):
@@ -510,7 +510,7 @@ class TestJit(unittest.TestCase):
# TODO: this should fail since input has a different size # TODO: this should fail since input has a different size
f(Tensor(2.0)).item() f(Tensor(2.0)).item()
# TODO: this should not fail, and should return 3 # TODO: this should not fail, and should return 3
with self.assertRaises(AssertionError): with self.assertRaises(JitError):
f(Tensor([2.0])).item() f(Tensor([2.0])).item()
@unittest.skip("Pending multioutput implementation #3607") @unittest.skip("Pending multioutput implementation #3607")

View File

@@ -21,6 +21,7 @@ ERRORS RAISED (lower priority - at least users know):
import unittest import unittest
import numpy as np import numpy as np
from tinygrad import Tensor, TinyJit from tinygrad import Tensor, TinyJit
from tinygrad.engine.jit import JitError
class TestJitFootguns(unittest.TestCase): class TestJitFootguns(unittest.TestCase):
@@ -70,7 +71,7 @@ class TestJitFootguns(unittest.TestCase):
def f(a, b): return (a + b).realize() def f(a, b): return (a + b).realize()
x = Tensor([1, 2, 3]) x = Tensor([1, 2, 3])
with self.assertRaises(AssertionError): with self.assertRaises(JitError):
f(x, x) f(x, x)
def test_tensors_in_containers_ignored(self): def test_tensors_in_containers_ignored(self):
@@ -116,7 +117,7 @@ class TestJitFootguns(unittest.TestCase):
def f(a): return (a + 1).realize() def f(a): return (a + 1).realize()
base = Tensor.randn(10, 10).realize() base = Tensor.randn(10, 10).realize()
with self.assertRaises(AssertionError): with self.assertRaises(JitError):
for i in range(1, 5): for i in range(1, 5):
f(base[:, i:i+2]) # different offset each time f(base[:, i:i+2]) # different offset each time
@@ -128,7 +129,7 @@ class TestJitFootguns(unittest.TestCase):
f(Tensor.randn(10, 10), Tensor.randn(10, 10)) # warmup f(Tensor.randn(10, 10), Tensor.randn(10, 10)) # warmup
f(Tensor.randn(10, 10), Tensor.randn(10, 10)) # capture f(Tensor.randn(10, 10), Tensor.randn(10, 10)) # capture
with self.assertRaises(AssertionError): with self.assertRaises(JitError):
f(Tensor.randn(20, 20), Tensor.randn(20, 20)) f(Tensor.randn(20, 20), Tensor.randn(20, 20))
def test_python_constants_frozen(self): def test_python_constants_frozen(self):
@@ -170,7 +171,7 @@ class TestJitFootguns(unittest.TestCase):
f(Tensor([1]), Tensor([2])) # warmup with positional f(Tensor([1]), Tensor([2])) # warmup with positional
f(Tensor([1]), Tensor([2])) # capture with positional f(Tensor([1]), Tensor([2])) # capture with positional
with self.assertRaises(AssertionError): with self.assertRaises(JitError):
f(a=Tensor([3]), b=Tensor([4])) # kwargs fail f(a=Tensor([3]), b=Tensor([4])) # kwargs fail
def test_class_method_shared_across_instances(self): def test_class_method_shared_across_instances(self):
@@ -213,7 +214,7 @@ class TestJitFootguns(unittest.TestCase):
@TinyJit @TinyJit
def f(a, b): return None def f(a, b): return None
with self.assertRaises(AssertionError): with self.assertRaises(JitError):
for _ in range(3): for _ in range(3):
f(Tensor([1]), Tensor([2])) f(Tensor([1]), Tensor([2]))

View File

@@ -2,6 +2,7 @@ import unittest
from test.helpers import assert_jit_cache_len from test.helpers import assert_jit_cache_len
from tinygrad import Variable, Tensor, TinyJit from tinygrad import Variable, Tensor, TinyJit
from tinygrad.engine.jit import JitError
import numpy as np import numpy as np
class TestSymbolicJit(unittest.TestCase): class TestSymbolicJit(unittest.TestCase):
@@ -172,7 +173,7 @@ class TestSymbolicJit(unittest.TestCase):
vi2 = Variable("i", 1, 10).bind(7) vi2 = Variable("i", 1, 10).bind(7)
a = Tensor.rand(3, 7)[:, :vi2] a = Tensor.rand(3, 7)[:, :vi2]
bad = Tensor.rand(4, 7)[:, :vi2] bad = Tensor.rand(4, 7)[:, :vi2]
with self.assertRaises(AssertionError): with self.assertRaises(JitError):
add(a, bad) add(a, bad)
def test_shrink(self): def test_shrink(self):

View File

@@ -13,6 +13,7 @@ from dataclasses import dataclass, replace
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
class GraphException(Exception): pass class GraphException(Exception): pass
class JitError(Exception): pass
def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
@@ -225,7 +226,7 @@ def _prepare_jit_inputs(args, kwargs):
lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors]) lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb] input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
for lb in lbs if lb.base.realized is not None]) for lb in lbs if lb.base.realized is not None])
assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT" if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT")
st_varval_dtype_device = [(*(lb.substitute({lb.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), lb.dtype, lb.device) for lb in lbs] st_varval_dtype_device = [(*(lb.substitute({lb.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), lb.dtype, lb.device) for lb in lbs]
_var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))]) _var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
var_vals = {k.expr:v for k,v in _var_vals.items()} var_vals = {k.expr:v for k,v in _var_vals.items()}
@@ -294,7 +295,7 @@ class TinyJit(Generic[ReturnType]):
finally: capturing.clear() finally: capturing.clear()
jit_cache = self._jit_cache jit_cache = self._jit_cache
del self._buffer_replace, self._jit_cache del self._buffer_replace, self._jit_cache
assert len(jit_cache), "didn't JIT anything!" if not len(jit_cache): raise JitError("didn't JIT anything!")
if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs") if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")
# track inputs that are views of buffers # track inputs that are views of buffers
@@ -333,9 +334,9 @@ class TinyJit(Generic[ReturnType]):
elif self.cnt >= 2: elif self.cnt >= 2:
# jit exec # jit exec
assert self.captured is not None assert self.captured is not None
assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}" if self.captured.expected_names != names: raise JitError(f"args mismatch in JIT: {self.captured.expected_names=} != {names}")
assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \ if self.captured.expected_st_vars_dtype_device != st_vars_dtype_device:
f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}" raise JitError(f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}")
ret = self.captured(input_buffers, var_vals) ret = self.captured(input_buffers, var_vals)
self.cnt += 1 self.cnt += 1