Merge remote-tracking branch 'origin/master' into asm_ucode

# Conflicts:
#	test/test_jit.py
#	test/test_jit_footguns.py
#	tinygrad/engine/jit.py
George Hotz
2026-01-08 05:14:42 -08:00
33 changed files with 1401 additions and 280 deletions

View File

@@ -43,16 +43,13 @@ def _test_cast(a:Tensor, target_dtype:DType):
if a.is_floating_point() and dtypes.is_unsigned(target_dtype):
# converting negative float to unsigned integer is undefined
a = a.abs()
if target_dtype == dtypes.half and Device.DEFAULT == "PYTHON":
# TODO: struct.pack cannot pack value > 65504 (max of half) into e format
a = (a > 65504).where(65504, a)
expected = list(a.numpy().astype(_to_np_dtype(target_dtype)))
if target_dtype in dtypes.fp8s: expected = list(map(lambda x: truncate[target_dtype](x), expected))
if target_dtype in dtypes.fp8s: expected = [truncate[target_dtype](x) for x in expected]
_test_op(lambda: a.cast(target_dtype), target_dtype, expected)
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype)).tolist()
if target_dtype in dtypes.fp8s: expected = list(map(lambda x: fp8_to_float(x, target_dtype), expected))
if target_dtype in dtypes.fp8s: expected = [fp8_to_float(x, target_dtype) for x in expected]
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected)
class TestDType(unittest.TestCase):
@@ -68,37 +65,34 @@ class TestDType(unittest.TestCase):
def test_to_np(self):
_test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE)))
def test_casts_to(self): list(map(
lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE),
get_available_cast_dtypes(self.DTYPE)
))
def test_casts_from(self): list(map(
lambda dtype: _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype),
get_available_cast_dtypes(self.DTYPE)
))
def test_casts_to(self):
for dtype in get_available_cast_dtypes(self.DTYPE):
_test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE)
def test_casts_from(self):
for dtype in get_available_cast_dtypes(self.DTYPE):
_test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype)
def test_same_size_ops(self):
list(map(
lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize == self.DTYPE.itemsize else None,
get_available_cast_dtypes(self.DTYPE)
))
for dtype in get_available_cast_dtypes(self.DTYPE):
if dtype.itemsize == self.DTYPE.itemsize:
_test_ops(a_dtype=self.DTYPE, b_dtype=dtype)
def test_upcast_ops(self):
list(map(
lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize > self.DTYPE.itemsize else None,
get_available_cast_dtypes(self.DTYPE)
))
for dtype in get_available_cast_dtypes(self.DTYPE):
if dtype.itemsize > self.DTYPE.itemsize:
_test_ops(a_dtype=self.DTYPE, b_dtype=dtype)
def test_upcast_to_ops(self):
list(map(
lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None,
get_available_cast_dtypes(self.DTYPE)
))
for dtype in get_available_cast_dtypes(self.DTYPE):
if dtype.itemsize < self.DTYPE.itemsize:
_test_ops(a_dtype=dtype, b_dtype=self.DTYPE)
def test_bitcast(self):
if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
list(map(
lambda dtype:
_test_bitcast(Tensor(self.DATA[:8], dtype=self.DTYPE), dtype) if dtype != dtypes.bool else None,
get_available_cast_dtypes(self.DTYPE)
))
for dtype in get_available_cast_dtypes(self.DTYPE):
if dtype != dtypes.bool:
_test_bitcast(Tensor(self.DATA[:8], dtype=self.DTYPE), dtype)
@unittest.skipIf(Device.DEFAULT == "PYTHON", "skip for now")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "skip for now")
@@ -307,7 +301,7 @@ class TestBitCast(unittest.TestCase):
data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2))
if dt2 in dtypes.fp8s:
expected = torch.tensor(list(map(lambda x: fp8_to_float(x, dt2), expected.view(-1).tolist()))).view_as(expected)
expected = torch.tensor([fp8_to_float(x, dt2) for x in expected.view(-1).tolist()]).view_as(expected)
_test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, expected.tolist())
def test_shape_change_bitcast_exceptions(self):
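
For reference, the expected values in these bitcast tests are a pure byte-level reinterpretation, which is what the torch .view(dtype) call above computes; a standalone numpy illustration (not part of this commit):

import numpy as np

# 1.0 as float32 has bit pattern 0x3F800000; viewing the same 4 bytes as int32 gives 1065353216
print(np.float32(1.0).view(np.int32))                    # 1065353216
print(np.array([1.0], dtype=np.float32).view(np.uint8))  # [  0   0 128  63] on a little-endian machine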

View File

@@ -184,16 +184,9 @@ class TestJit(unittest.TestCase):
def test_array_jit(self):
@TinyJit
def add_array(a, arr): return (a+arr[0]).realize()
for i in range(5):
a = Tensor.randn(10, 10)
b = Tensor.randn(10, 10)
a.realize(), b.realize()
c = add_array(a, [b])
if i >= 2:
# should fail once jitted since jit can't handle arrays
np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
else:
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
for _ in range(5):
a, b = Tensor.randn(10, 10).realize(), Tensor.randn(10, 10).realize()
np.testing.assert_allclose(add_array(a, [b]).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
assert_jit_cache_len(add_array, 1)
def test_jit_copyin(self):
@@ -414,12 +407,6 @@ class TestJit(unittest.TestCase):
assert isinstance(jf.jit_cache[0].prg, graph_t)
assert isinstance(jf.jit_cache[1].prg, graph_t)
def test_jit_const_inputs(self):
@TinyJit
def g(x,y,z): return (x+y+z).realize()
for i in range(5):
np.testing.assert_equal(g(Tensor([i]*3), Tensor.ones(3), Tensor.zeros(3)).numpy(), np.array([i+1]*3))
def test_jitted_clone(self):
def f(a): return a.clone().realize()
jf = TinyJit(f)
@@ -496,9 +483,10 @@ class TestJit(unittest.TestCase):
f(Tensor.empty(1))
f(Tensor.empty(1))
# TODO: this should fail since input has a different size
f(Tensor(2.0)).item()
# TODO: this should not fail, and should return 3
# scalar const input is not allowed
with self.assertRaises(JitError):
f(Tensor(2.0)).item()
# list input has different view structure than empty(1)
with self.assertRaises(JitError):
f(Tensor([2.0])).item()

View File

@@ -6,14 +6,13 @@ Each test shows behavior that works without JIT but changes with JIT.
Comments marked "should be X!" indicate the intuitively expected value.
SILENT MISMATCHES (highest priority - wrong results, no error):
tensors_in_containers_ignored EASY only checks t.__class__ is Tensor, could scan lists/dicts
class_method_shared_across_instances EASY could check if first arg is self and warn
output_buffer_reuse MED performance tradeoff, could add option or better docs
python_constants_frozen HARD inherent to tracing JITs
conditional_branches_frozen HARD inherent to tracing JITs
unrealized_const_input_frozen HARD unrealized const has no buffer to replace, values baked in
ERRORS RAISED (lower priority - at least users know):
unrealized_const_input_error EASY raises JitError for unrealized const inputs
non_tensor_outputs_error EASY raises JitError if return contains non-Tensor values
positional_kwargs_cannot_mix EASY normalize positional args to kwargs using function signature
duplicate_inputs_fail MED would need to handle aliasing in input_replace
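
To make the python_constants_frozen entry above concrete, a minimal sketch (the `scale` constant and `scaled` function are illustrative, not part of this commit): a plain Python number is not a Tensor input, so it is baked into the captured kernels and later rebindings are silently ignored.

from tinygrad import Tensor, TinyJit

scale = 3.0  # a plain Python float, invisible to the JIT's input tracking

@TinyJit
def scaled(x: Tensor) -> Tensor:
  # `scale` is read while tracing; once the JIT has captured, its value is
  # baked into the kernels and later rebindings of `scale` are ignored
  return (x * scale).realize()

for _ in range(3): scaled(Tensor.randn(4).realize())  # warmup + capture
scale = 100.0
print(scaled(Tensor.randn(4).realize()).numpy())      # still scaled by 3.0, not 100.0
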
@@ -65,20 +64,12 @@ class TestJitFootguns(unittest.TestCase):
with self.assertRaises(JitError):
f(x, x)
def test_tensors_in_containers_ignored(self):
"""Tensors inside lists/dicts are not tracked as inputs."""
def test_tensors_in_containers(self):
@TinyJit
def f(a, arr): return (a + arr[0]).realize()
results = []
for i in range(4):
a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize()
results.append(f(a, [b]).numpy().copy())
np.testing.assert_array_equal(results[0], [1, 1, 1]) # warmup
np.testing.assert_array_equal(results[1], [2, 2, 2]) # capture
np.testing.assert_array_equal(results[2], [2, 2, 2]) # should be [3,3,3]!
np.testing.assert_array_equal(results[3], [2, 2, 2]) # should be [4,4,4]!
np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i])
def test_nested_jit_fails_on_second_call(self):
"""Nested JIT works on first call but fails on second."""
@@ -140,16 +131,20 @@ class TestJitFootguns(unittest.TestCase):
self.assertEqual(results[2], 20) # should be 30!
self.assertEqual(results[3], 20) # should be 40!
def test_unrealized_const_input_frozen(self):
"""Unrealized const tensors have no buffer to replace, so values are baked in at capture time."""
def test_unrealized_const_input_error(self):
"""Const tensors have no buffer to replace, so JIT raises an error. Even explicit .realize() doesn't help."""
@TinyJit
def f(a, b): return (a * b).realize()
for i in range(1, 5):
result = f(Tensor([1, 2, 3]).realize(), Tensor(i)) # Tensor(i) is unrealized const
# value is frozen at capture (i=2), so i=3,4 give wrong results
expected = [2, 4, 6] if i >= 2 else [i, 2*i, 3*i]
np.testing.assert_equal(result.numpy(), expected) # i=3,4 should be [3,6,9], [4,8,12]!
# unrealized const fails
with self.assertRaises(JitError):
f(Tensor([1, 2, 3]).realize(), Tensor(2))
# explicit .realize() on const still fails - const cannot be realized to have a buffer
@TinyJit
def g(a, b): return (a * b).realize()
with self.assertRaises(JitError):
g(Tensor([1, 2, 3]).realize(), Tensor(2).realize())
def test_conditional_branches_frozen(self):
"""Only the branch taken during capture runs thereafter."""

View File

@@ -541,6 +541,37 @@ class TestMultiTensor(unittest.TestCase):
np.testing.assert_allclose(r.numpy(), np.ones(256)+np.ones(256), atol=1e-4, rtol=1e-5)
assert len(jf.jit_cache) > 0
def test_multitensor_jit_in_list(self):
# test MULTI tensor inside a list container - exercises the container unpacking + MULTI unpacking
@TinyJit
def f(a, arr): return (a + arr[0]).realize()
for i in range(5):
a = Tensor.full((4,), i).contiguous().realize().shard(devices_2, 0).realize()
b = Tensor.ones(4).contiguous().realize().shard(devices_2, 0).realize()
out = f(a, [b])
np.testing.assert_allclose(out.numpy(), np.full(4, i) + np.ones(4), atol=1e-4, rtol=1e-5)
def test_multitensor_jit_multiple_inputs(self):
# test multiple MULTI tensors as inputs - each gets unpacked to component UOps
@TinyJit
def f(a, b, c): return (a + b + c).realize()
for i in range(5):
a = Tensor.full((4,), i).contiguous().realize().shard(devices_2, 0).realize()
b = Tensor.full((4,), i*2).contiguous().realize().shard(devices_2, 0).realize()
c = Tensor.ones(4).contiguous().realize().shard(devices_2, 0).realize()
out = f(a, b, c)
np.testing.assert_allclose(out.numpy(), np.full(4, i) + np.full(4, i*2) + np.ones(4), atol=1e-4, rtol=1e-5)
def test_multitensor_jit_different_sharding(self):
# test MULTI tensors with different sharding - one sharded on axis 0, one broadcast (axis=None)
@TinyJit
def f(a, b): return (a + b).realize()
for i in range(5):
a = Tensor.full((4, 4), i).contiguous().realize().shard(devices_2, 0).realize()
b = Tensor.full((4, 4), i*2).contiguous().realize().shard(devices_2, None).realize()
out = f(a, b)
np.testing.assert_allclose(out.numpy(), np.full((4, 4), i) + np.full((4, 4), i*2), atol=1e-4, rtol=1e-5)
@unittest.skip("test broken")
def test_multi_device_jit_graph(self):
if Device[d0].graph is None or Device[d1].graph is None: raise unittest.SkipTest("only test graphs")

View File

@@ -217,7 +217,7 @@ class TestSymbolicJit(unittest.TestCase):
def test_ones_sum(self):
def f(a): return a.sum().realize()
jf = TinyJit(f)
t = Tensor.ones(10)
t = Tensor.ones(10).contiguous()
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
symbolic = jf(t[:vi]).item()

View File

@@ -199,7 +199,7 @@ class TestSymbolicOps(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_ones_sum(self):
t = Tensor.ones(10)
t = Tensor.ones(10).contiguous()
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
symbolic = t[:vi].sum().item()

View File

@@ -241,18 +241,18 @@ class TestTK(unittest.TestCase):
np.testing.assert_allclose(b.numpy(), ref.numpy())
np.testing.assert_allclose(c.numpy(), ref.numpy())
@unittest.skip("TODO")
def test_load_store_group(self):
N = 256
N = 1024
BLOCK_SIZE = 64
with Kernel("load_store_group", (N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS * 2) as ker:
NUM_WORKERS = 4
with Kernel("load_store_group", (N // (BLOCK_SIZE * NUM_WORKERS), N // BLOCK_SIZE, 1), WARP_THREADS * NUM_WORKERS) as ker:
warp = ker.warp
group = ker.group(2)
group = ker.group(NUM_WORKERS)
b = ker.gl((1, 1, N, N), dtypes.float32)
a = ker.gl((1, 1, N, N), dtypes.float32)
a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE * NUM_WORKERS), dtypes.float32)
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
@@ -260,9 +260,9 @@ class TestTK(unittest.TestCase):
col, row = ker.blockIdx_x, ker.blockIdx_y
a_smem = group.load(a_smem, a, (), (0, 0, row, col), axis=2)
a_reg = warp.load(a_reg, a_smem)
a_reg = warp.load(a_reg, a_smem, (), (0, ker.warpid,))
b_reg = warp.copy(b_reg, a_reg)
b = warp.store(b, b_reg, (0, 0, row, col), (), axis=2)
b = warp.store(b, b_reg, (0, 0, row, col * NUM_WORKERS + ker.warpid), (), axis=2)
sink = ker.finish()