Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
@@ -5,7 +5,7 @@ import numpy as np
 from hypothesis import given, settings, strategies as strat
 from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV, needs_second_gpu
 from tinygrad.tensor import Tensor
-from tinygrad.engine.jit import TinyJit, GraphRunner, MultiGraphRunner, graph_class
+from tinygrad.engine.jit import TinyJit, JitError, GraphRunner, MultiGraphRunner, graph_class
 from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
 from tinygrad.device import Device
 from tinygrad.helpers import Context, JIT, GlobalCounters, getenv
@@ -76,7 +76,7 @@ class TestJit(unittest.TestCase):
   def test_nothing_jitted(self):
     @TinyJit
     def add(a, b): return None
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       for _ in range(5):
         a = Tensor.randn(10, 10)
         b = Tensor.randn(10, 10)
@@ -125,13 +125,13 @@ class TestJit(unittest.TestCase):
       b = Tensor.randn(10, 10)
       add(a, b)
     bad = Tensor.randn(20, 20)
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       add(a, bad)

   def test_jit_shape_views_mismatch(self):
     @TinyJit
     def add(a): return (a+1).realize()
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       for i in range(1,5):
         # a has an offset that the kernel doesn't know about
         a = Tensor.randn(10, 10).realize()[:, i:i+2]
@@ -142,7 +142,7 @@ class TestJit(unittest.TestCase):
     @TinyJit
     def add(a, b): return (a+b).realize()
     a = Tensor.randn(10, 10)
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       add(a, a)

   def test_jit_assign(self, dtype=dtypes.float32):
@@ -510,7 +510,7 @@ class TestJit(unittest.TestCase):
     # TODO: this should fail since input has a different size
     f(Tensor(2.0)).item()
     # TODO: this should not fail, and should return 3
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       f(Tensor([2.0])).item()

   @unittest.skip("Pending multioutput implementation #3607")
@@ -21,6 +21,7 @@ ERRORS RAISED (lower priority - at least users know):
 import unittest
 import numpy as np
 from tinygrad import Tensor, TinyJit
+from tinygrad.engine.jit import JitError

 class TestJitFootguns(unittest.TestCase):

@@ -70,7 +71,7 @@ class TestJitFootguns(unittest.TestCase):
     def f(a, b): return (a + b).realize()

     x = Tensor([1, 2, 3])
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       f(x, x)

   def test_tensors_in_containers_ignored(self):
@@ -116,7 +117,7 @@ class TestJitFootguns(unittest.TestCase):
     def f(a): return (a + 1).realize()

     base = Tensor.randn(10, 10).realize()
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       for i in range(1, 5):
         f(base[:, i:i+2])  # different offset each time

@@ -128,7 +129,7 @@ class TestJitFootguns(unittest.TestCase):
     f(Tensor.randn(10, 10), Tensor.randn(10, 10))  # warmup
     f(Tensor.randn(10, 10), Tensor.randn(10, 10))  # capture

-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       f(Tensor.randn(20, 20), Tensor.randn(20, 20))

   def test_python_constants_frozen(self):
@@ -170,7 +171,7 @@ class TestJitFootguns(unittest.TestCase):
     f(Tensor([1]), Tensor([2]))  # warmup with positional
     f(Tensor([1]), Tensor([2]))  # capture with positional

-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       f(a=Tensor([3]), b=Tensor([4]))  # kwargs fail

   def test_class_method_shared_across_instances(self):
@@ -213,7 +214,7 @@ class TestJitFootguns(unittest.TestCase):
     @TinyJit
     def f(a, b): return None

-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       for _ in range(3):
         f(Tensor([1]), Tensor([2]))

@@ -2,6 +2,7 @@ import unittest

 from test.helpers import assert_jit_cache_len
 from tinygrad import Variable, Tensor, TinyJit
+from tinygrad.engine.jit import JitError
 import numpy as np

 class TestSymbolicJit(unittest.TestCase):
@@ -172,7 +173,7 @@ class TestSymbolicJit(unittest.TestCase):
     vi2 = Variable("i", 1, 10).bind(7)
     a = Tensor.rand(3, 7)[:, :vi2]
     bad = Tensor.rand(4, 7)[:, :vi2]
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(JitError):
       add(a, bad)

   def test_shrink(self):
@@ -13,6 +13,7 @@ from dataclasses import dataclass, replace
 from weakref import WeakKeyDictionary

 class GraphException(Exception): pass
+class JitError(Exception): pass

 def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph

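Note: with JitError defined alongside GraphException, downstream code that guarded jitted calls with "except AssertionError" should catch the new type instead. A minimal usage sketch, mirroring the shape-mismatch tests above (the 4x4/8x8 sizes are illustrative):

from tinygrad import Tensor, TinyJit
from tinygrad.engine.jit import JitError

@TinyJit
def add(a, b): return (a + b).realize()

add(Tensor.randn(4, 4), Tensor.randn(4, 4))  # first call: warmup
add(Tensor.randn(4, 4), Tensor.randn(4, 4))  # second call: capture
try:
  add(Tensor.randn(8, 8), Tensor.randn(8, 8))  # shape differs from the capture
except JitError as e:
  print(f"JIT rejected mismatched inputs: {e}")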
@@ -225,7 +226,7 @@ def _prepare_jit_inputs(args, kwargs):
   lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
   input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
                                          for lb in lbs if lb.base.realized is not None])
-  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
+  if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT")
   st_varval_dtype_device = [(*(lb.substitute({lb.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), lb.dtype, lb.device) for lb in lbs]
   _var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
   var_vals = {k.expr:v for k,v in _var_vals.items()}
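The duplicate-input assert becomes a raised JitError; the detection itself is unchanged: collapsing the realized input buffers through set() and comparing lengths flags any buffer passed twice, e.g. calling a jitted f(x, x). A standalone sketch of the same check (check_unique_inputs is a hypothetical helper, not part of this diff):

def check_unique_inputs(input_buffers: list) -> None:
  # set() drops duplicates, so a length mismatch means some buffer repeats
  if len(set(input_buffers)) != len(input_buffers):
    raise JitError("duplicate inputs to JIT")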
@@ -294,7 +295,7 @@ class TinyJit(Generic[ReturnType]):
       finally: capturing.clear()
       jit_cache = self._jit_cache
       del self._buffer_replace, self._jit_cache
-      assert len(jit_cache), "didn't JIT anything!"
+      if not len(jit_cache): raise JitError("didn't JIT anything!")
       if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")

       # track inputs that are views of buffers
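The "didn't JIT anything!" failure also surfaces as JitError now; it fires at capture time when the traced function schedules no kernels, which is exactly what test_nothing_jitted above exercises. A minimal illustration of the failing versus working pattern (illustrative names, not from this diff):

@TinyJit
def bad(a): return None                  # schedules no kernels -> JitError once capture runs

@TinyJit
def good(a): return (a + 1).realize()    # realizes a tensor, so at least one kernel is captured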
@@ -333,9 +334,9 @@ class TinyJit(Generic[ReturnType]):
     elif self.cnt >= 2:
       # jit exec
       assert self.captured is not None
-      assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}"
-      assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \
-        f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
+      if self.captured.expected_names != names: raise JitError(f"args mismatch in JIT: {self.captured.expected_names=} != {names}")
+      if self.captured.expected_st_vars_dtype_device != st_vars_dtype_device:
+        raise JitError(f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}")
       ret = self.captured(input_buffers, var_vals)

     self.cnt += 1
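For context on why these checks sit under self.cnt >= 2: TinyJit runs the wrapped function eagerly on the first call, captures kernels on the second, and replays the capture from the third call onward, which is where the two args-mismatch checks now raise JitError. A simplified sketch of that dispatch shape (helper names other than cnt are illustrative, not the real TinyJit.__call__ body):

def __call__(self, *args, **kwargs):
  if self.cnt == 0: ret = self._run_eagerly(args, kwargs)  # warmup: plain execution
  elif self.cnt == 1: ret = self._capture(args, kwargs)    # record kernels into the jit cache
  else: ret = self._replay(args, kwargs)                   # replay; raises JitError on args mismatch
  self.cnt += 1
  return ret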