mirror of https://github.com/tinygrad/tinygrad.git
Merge remote-tracking branch 'origin/master' into asm_ucode
# Conflicts:
#   test/test_jit.py
#   test/test_jit_footguns.py
#   tinygrad/engine/jit.py
@@ -43,16 +43,13 @@ def _test_cast(a:Tensor, target_dtype:DType):
   if a.is_floating_point() and dtypes.is_unsigned(target_dtype):
     # converting negative float to unsigned integer is undefined
     a = a.abs()
   if target_dtype == dtypes.half and Device.DEFAULT == "PYTHON":
     # TODO: struct.pack cannot pack value > 65504 (max of half) into e format
     a = (a > 65504).where(65504, a)

   expected = list(a.numpy().astype(_to_np_dtype(target_dtype)))
-  if target_dtype in dtypes.fp8s: expected = list(map(lambda x: truncate[target_dtype](x), expected))
+  if target_dtype in dtypes.fp8s: expected = [truncate[target_dtype](x) for x in expected]
   _test_op(lambda: a.cast(target_dtype), target_dtype, expected)
 def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
   expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype)).tolist()
-  if target_dtype in dtypes.fp8s: expected = list(map(lambda x: fp8_to_float(x, target_dtype), expected))
+  if target_dtype in dtypes.fp8s: expected = [fp8_to_float(x, target_dtype) for x in expected]
   _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected)

 class TestDType(unittest.TestCase):
@@ -68,37 +65,34 @@ class TestDType(unittest.TestCase):
   def test_to_np(self):
     _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE)))

-  def test_casts_to(self): list(map(
-    lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE),
-    get_available_cast_dtypes(self.DTYPE)
-  ))
-  def test_casts_from(self): list(map(
-    lambda dtype: _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype),
-    get_available_cast_dtypes(self.DTYPE)
-  ))
+  def test_casts_to(self):
+    for dtype in get_available_cast_dtypes(self.DTYPE):
+      _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE)
+
+  def test_casts_from(self):
+    for dtype in get_available_cast_dtypes(self.DTYPE):
+      _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype)

   def test_same_size_ops(self):
-    list(map(
-      lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize == self.DTYPE.itemsize else None,
-      get_available_cast_dtypes(self.DTYPE)
-    ))
+    for dtype in get_available_cast_dtypes(self.DTYPE):
+      if dtype.itemsize == self.DTYPE.itemsize:
+        _test_ops(a_dtype=self.DTYPE, b_dtype=dtype)

   def test_upcast_ops(self):
-    list(map(
-      lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize > self.DTYPE.itemsize else None,
-      get_available_cast_dtypes(self.DTYPE)
-    ))
+    for dtype in get_available_cast_dtypes(self.DTYPE):
+      if dtype.itemsize > self.DTYPE.itemsize:
+        _test_ops(a_dtype=self.DTYPE, b_dtype=dtype)

   def test_upcast_to_ops(self):
-    list(map(
-      lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None,
-      get_available_cast_dtypes(self.DTYPE)
-    ))
+    for dtype in get_available_cast_dtypes(self.DTYPE):
+      if dtype.itemsize < self.DTYPE.itemsize:
+        _test_ops(a_dtype=dtype, b_dtype=self.DTYPE)

   def test_bitcast(self):
     if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
-    list(map(
-      lambda dtype:
-      _test_bitcast(Tensor(self.DATA[:8], dtype=self.DTYPE), dtype) if dtype != dtypes.bool else None,
-      get_available_cast_dtypes(self.DTYPE)
-    ))
+    for dtype in get_available_cast_dtypes(self.DTYPE):
+      if dtype != dtypes.bool:
+        _test_bitcast(Tensor(self.DATA[:8], dtype=self.DTYPE), dtype)

   @unittest.skipIf(Device.DEFAULT == "PYTHON", "skip for now")
   @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "skip for now")
@@ -307,7 +301,7 @@ class TestBitCast(unittest.TestCase):
     data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
     expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2))
     if dt2 in dtypes.fp8s:
-      expected = torch.tensor(list(map(lambda x: fp8_to_float(x, dt2), expected.view(-1).tolist()))).view_as(expected)
+      expected = torch.tensor([fp8_to_float(x, dt2) for x in expected.view(-1).tolist()]).view_as(expected)
     _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, expected.tolist())

   def test_shape_change_bitcast_exceptions(self):
@@ -184,16 +184,9 @@ class TestJit(unittest.TestCase):
   def test_array_jit(self):
     @TinyJit
     def add_array(a, arr): return (a+arr[0]).realize()
-    for i in range(5):
-      a = Tensor.randn(10, 10)
-      b = Tensor.randn(10, 10)
-      a.realize(), b.realize()
-      c = add_array(a, [b])
-      if i >= 2:
-        # should fail once jitted since jit can't handle arrays
-        np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
-      else:
-        np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
+    for _ in range(5):
+      a, b = Tensor.randn(10, 10).realize(), Tensor.randn(10, 10).realize()
+      np.testing.assert_allclose(add_array(a, [b]).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
     assert_jit_cache_len(add_array, 1)

   def test_jit_copyin(self):
@@ -414,12 +407,6 @@ class TestJit(unittest.TestCase):
     assert isinstance(jf.jit_cache[0].prg, graph_t)
     assert isinstance(jf.jit_cache[1].prg, graph_t)

-  def test_jit_const_inputs(self):
-    @TinyJit
-    def g(x,y,z): return (x+y+z).realize()
-    for i in range(5):
-      np.testing.assert_equal(g(Tensor([i]*3), Tensor.ones(3), Tensor.zeros(3)).numpy(), np.array([i+1]*3))
-
   def test_jitted_clone(self):
     def f(a): return a.clone().realize()
     jf = TinyJit(f)
@@ -496,9 +483,10 @@ class TestJit(unittest.TestCase):

     f(Tensor.empty(1))
     f(Tensor.empty(1))
-    # TODO: this should fail since input has a different size
-    f(Tensor(2.0)).item()
-    # TODO: this should not fail, and should return 3
+    # scalar const input is not allowed
+    with self.assertRaises(JitError):
+      f(Tensor(2.0)).item()
+    # list input has different view structure than empty(1)
+    with self.assertRaises(JitError):
+      f(Tensor([2.0])).item()
@@ -6,14 +6,13 @@ Each test shows behavior that works without JIT but changes with JIT.
 Comments marked "should be X!" indicate the intuitively expected value.

 SILENT MISMATCHES (highest priority - wrong results, no error):
 tensors_in_containers_ignored EASY only checks t.__class__ is Tensor, could scan lists/dicts
 class_method_shared_across_instances EASY could check if first arg is self and warn
 output_buffer_reuse MED performance tradeoff, could add option or better docs
 python_constants_frozen HARD inherent to tracing JITs
 conditional_branches_frozen HARD inherent to tracing JITs
-unrealized_const_input_frozen HARD unrealized const has no buffer to replace, values baked in

 ERRORS RAISED (lower priority - at least users know):
+unrealized_const_input_error EASY raises JitError for unrealized const inputs
 non_tensor_outputs_error EASY raises JitError if return contains non-Tensor values
 positional_kwargs_cannot_mix EASY normalize positional args to kwargs using function signature
 duplicate_inputs_fail MED would need to handle aliasing in input_replace
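As background for the python_constants_frozen and conditional_branches_frozen entries above, a minimal sketch of the frozen-value behavior (illustrative only, not taken from this diff; it assumes the usual TinyJit warmup/capture/replay flow):

# illustrative sketch: a plain Python value is read while the JIT traces,
# so changing it after capture has no effect on later calls
from tinygrad import Tensor, TinyJit

scale = 2.0

@TinyJit
def f(x: Tensor) -> Tensor: return (x * scale).realize()

for step in range(4):
  out = f(Tensor([1.0, 2.0, 3.0]).realize())
  scale += 1.0  # ignored once the JIT is replaying captured kernels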
@@ -65,20 +64,12 @@ class TestJitFootguns(unittest.TestCase):
     with self.assertRaises(JitError):
       f(x, x)

-  def test_tensors_in_containers_ignored(self):
-    """Tensors inside lists/dicts are not tracked as inputs."""
+  def test_tensors_in_containers(self):
     @TinyJit
     def f(a, arr): return (a + arr[0]).realize()

-    results = []
     for i in range(4):
       a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize()
-      results.append(f(a, [b]).numpy().copy())
-
-    np.testing.assert_array_equal(results[0], [1, 1, 1]) # warmup
-    np.testing.assert_array_equal(results[1], [2, 2, 2]) # capture
-    np.testing.assert_array_equal(results[2], [2, 2, 2]) # should be [3,3,3]!
-    np.testing.assert_array_equal(results[3], [2, 2, 2]) # should be [4,4,4]!
+      np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i])

   def test_nested_jit_fails_on_second_call(self):
     """Nested JIT works on first call but fails on second."""
@@ -140,16 +131,20 @@ class TestJitFootguns(unittest.TestCase):
     self.assertEqual(results[2], 20) # should be 30!
     self.assertEqual(results[3], 20) # should be 40!

-  def test_unrealized_const_input_frozen(self):
-    """Unrealized const tensors have no buffer to replace, so values are baked in at capture time."""
+  def test_unrealized_const_input_error(self):
+    """Const tensors have no buffer to replace, so JIT raises an error. Even explicit .realize() doesn't help."""
     @TinyJit
     def f(a, b): return (a * b).realize()

-    for i in range(1, 5):
-      result = f(Tensor([1, 2, 3]).realize(), Tensor(i)) # Tensor(i) is unrealized const
-      # value is frozen at capture (i=2), so i=3,4 give wrong results
-      expected = [2, 4, 6] if i >= 2 else [i, 2*i, 3*i]
-      np.testing.assert_equal(result.numpy(), expected) # i=3,4 should be [3,6,9], [4,8,12]!
+    # unrealized const fails
+    with self.assertRaises(JitError):
+      f(Tensor([1, 2, 3]).realize(), Tensor(2))
+
+    # explicit .realize() on const still fails - const cannot be realized to have a buffer
+    @TinyJit
+    def g(a, b): return (a * b).realize()
+    with self.assertRaises(JitError):
+      g(Tensor([1, 2, 3]).realize(), Tensor(2).realize())

   def test_conditional_branches_frozen(self):
     """Only the branch taken during capture runs thereafter."""
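One possible way around the new JitError for const scalar inputs, sketched here as an illustration rather than taken from this diff, is to pass a buffer-backed one-element tensor instead of a bare scalar const:

# illustrative sketch: Tensor(2) is a const with no buffer for the JIT to swap,
# but Tensor([2]) is backed by a real buffer and can change between calls
from tinygrad import Tensor, TinyJit

@TinyJit
def f(a: Tensor, b: Tensor) -> Tensor: return (a * b).realize()

a = Tensor([1, 2, 3]).realize()
for i in range(1, 5):
  out = f(a, Tensor([i]).realize())  # should give [i, 2*i, 3*i] on every iteration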
@@ -541,6 +541,37 @@ class TestMultiTensor(unittest.TestCase):
     np.testing.assert_allclose(r.numpy(), np.ones(256)+np.ones(256), atol=1e-4, rtol=1e-5)
     assert len(jf.jit_cache) > 0

+  def test_multitensor_jit_in_list(self):
+    # test MULTI tensor inside a list container - exercises the container unpacking + MULTI unpacking
+    @TinyJit
+    def f(a, arr): return (a + arr[0]).realize()
+    for i in range(5):
+      a = Tensor.full((4,), i).contiguous().realize().shard(devices_2, 0).realize()
+      b = Tensor.ones(4).contiguous().realize().shard(devices_2, 0).realize()
+      out = f(a, [b])
+      np.testing.assert_allclose(out.numpy(), np.full(4, i) + np.ones(4), atol=1e-4, rtol=1e-5)
+
+  def test_multitensor_jit_multiple_inputs(self):
+    # test multiple MULTI tensors as inputs - each gets unpacked to component UOps
+    @TinyJit
+    def f(a, b, c): return (a + b + c).realize()
+    for i in range(5):
+      a = Tensor.full((4,), i).contiguous().realize().shard(devices_2, 0).realize()
+      b = Tensor.full((4,), i*2).contiguous().realize().shard(devices_2, 0).realize()
+      c = Tensor.ones(4).contiguous().realize().shard(devices_2, 0).realize()
+      out = f(a, b, c)
+      np.testing.assert_allclose(out.numpy(), np.full(4, i) + np.full(4, i*2) + np.ones(4), atol=1e-4, rtol=1e-5)
+
+  def test_multitensor_jit_different_sharding(self):
+    # test MULTI tensors with different sharding - one sharded on axis 0, one broadcast (axis=None)
+    @TinyJit
+    def f(a, b): return (a + b).realize()
+    for i in range(5):
+      a = Tensor.full((4, 4), i).contiguous().realize().shard(devices_2, 0).realize()
+      b = Tensor.full((4, 4), i*2).contiguous().realize().shard(devices_2, None).realize()
+      out = f(a, b)
+      np.testing.assert_allclose(out.numpy(), np.full((4, 4), i) + np.full((4, 4), i*2), atol=1e-4, rtol=1e-5)
+
   @unittest.skip("test broken")
   def test_multi_device_jit_graph(self):
     if Device[d0].graph is None or Device[d1].graph is None: raise unittest.SkipTest("only test graphs")
@@ -217,7 +217,7 @@ class TestSymbolicJit(unittest.TestCase):
   def test_ones_sum(self):
     def f(a): return a.sum().realize()
     jf = TinyJit(f)
-    t = Tensor.ones(10)
+    t = Tensor.ones(10).contiguous()
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       symbolic = jf(t[:vi]).item()
@@ -199,7 +199,7 @@ class TestSymbolicOps(unittest.TestCase):
     np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_ones_sum(self):
-    t = Tensor.ones(10)
+    t = Tensor.ones(10).contiguous()
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       symbolic = t[:vi].sum().item()
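Both .contiguous() changes above matter because Tensor.ones(10) is a lazily broadcast const rather than a buffer-backed tensor; a rough sketch of the distinction (illustrative, not from this diff):

# illustrative sketch: ones() stays a const view until it is forced into a real buffer
from tinygrad import Tensor

t_const = Tensor.ones(10)                        # broadcast const, no buffer of its own
t_buf = Tensor.ones(10).contiguous().realize()   # materialized buffer holding ten 1.0s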
@@ -241,18 +241,18 @@ class TestTK(unittest.TestCase):
     np.testing.assert_allclose(b.numpy(), ref.numpy())
     np.testing.assert_allclose(c.numpy(), ref.numpy())

-  @unittest.skip("TODO")
   def test_load_store_group(self):
-    N = 256
+    N = 1024
     BLOCK_SIZE = 64
-    with Kernel("load_store_group", (N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS * 2) as ker:
+    NUM_WORKERS = 4
+    with Kernel("load_store_group", (N // (BLOCK_SIZE * NUM_WORKERS), N // BLOCK_SIZE, 1), WARP_THREADS * NUM_WORKERS) as ker:
       warp = ker.warp
-      group = ker.group(2)
+      group = ker.group(NUM_WORKERS)

       b = ker.gl((1, 1, N, N), dtypes.float32)
       a = ker.gl((1, 1, N, N), dtypes.float32)

-      a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
+      a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE * NUM_WORKERS), dtypes.float32)

       a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
       b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
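To make the new launch geometry concrete, the constants in the hunk above work out as follows (plain arithmetic on the values shown, not additional diff content):

# worked sizes for the updated test
N, BLOCK_SIZE, NUM_WORKERS = 1024, 64, 4
grid = (N // (BLOCK_SIZE * NUM_WORKERS), N // BLOCK_SIZE, 1)  # (4, 16, 1)
smem_tile = (BLOCK_SIZE, BLOCK_SIZE * NUM_WORKERS)            # (64, 256): one 64x64 slice per warp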
@@ -260,9 +260,9 @@ class TestTK(unittest.TestCase):
       col, row = ker.blockIdx_x, ker.blockIdx_y

       a_smem = group.load(a_smem, a, (), (0, 0, row, col), axis=2)
-      a_reg = warp.load(a_reg, a_smem)
+      a_reg = warp.load(a_reg, a_smem, (), (0, ker.warpid,))
       b_reg = warp.copy(b_reg, a_reg)
-      b = warp.store(b, b_reg, (0, 0, row, col), (), axis=2)
+      b = warp.store(b, b_reg, (0, 0, row, col * NUM_WORKERS + ker.warpid), (), axis=2)

       sink = ker.finish()