jit case with Tensor.empty input, realized means allocated (#14930)

* simple failing jit test case with Tensor.empty

* restore the `is_allocated()` check in `realized` (this used to exist in ops.py...)

* Revert "removed if self.buffer.is_allocated() in realized (#14836)"

This reverts commit 72cf603805.
This commit is contained in:
qazal
2026-02-21 15:33:55 +08:00
committed by GitHub
parent 6533250246
commit c5029fa460
5 changed files with 17 additions and 6 deletions

View File

@@ -490,6 +490,15 @@ class TestJit(unittest.TestCase):
#with self.assertRaises(JitError):
# f(Tensor([2.0])).item()
def test_jit_init_empty_alt(self):
# JIT case where the second input of the jitted function is created with Tensor.empty_like.
# f is wrapped in TinyJit; it assigns a+1 into b and returns the result of that assign.
@TinyJit
def f(a:Tensor, b:Tensor) -> Tensor: return b.assign(a+1)
a = Tensor([1])
# f is invoked across four iterations, each time with a freshly created Tensor.empty_like(a) as b
for _ in range(4):
b = Tensor.empty_like(a)
c = f(a, b)
# a holds 1, so the value read back from the returned tensor is expected to be 2
self.assertEqual(c.item(), 2)
@unittest.skip("Pending multioutput implementation #3607")
class TestMultioutputJit(unittest.TestCase):
def _test(self, f):

View File

@@ -169,7 +169,7 @@ class TestSchedule(unittest.TestCase):
def test_empty_is_not_realized(self):
a = Tensor.empty(10)
child = a+2
assert a.uop.is_realized
assert not a.uop.is_realized
child.realize()
assert a.uop.is_realized
@@ -185,7 +185,7 @@ class TestSchedule(unittest.TestCase):
def test_childless_empty_never_allocates(self):
a = Tensor.empty(10)
a.realize()
assert not a.uop.buffer.is_allocated()
assert not a.uop.is_realized
def test_simplify_padded_const(self):
a, _ = Tensor.empty(1022).cummax(axis=0)

View File

@@ -30,14 +30,14 @@ class TestRealizeIsRealized(unittest.TestCase):
def test_empty(self):
t = Tensor.empty(4, 4).realize()
assert t.uop.is_realized
assert not t.uop.is_realized
def test_disk(self):
with tempfile.NamedTemporaryFile() as f:
f.write(b'\x00' * 16)
f.flush()
t = Tensor.empty(4, dtype=dtypes.float32, device=f"disk:{f.name}").realize()
assert t.uop.is_realized
assert not t.uop.is_realized
def test_assign(self):
t = Tensor([1, 2, 3])

View File

@@ -36,7 +36,8 @@ class TestSetitemInto(unittest.TestCase):
self.assertEqual(GlobalCounters.kernel_count, 0)
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.global_mem, 4)
# TODO: this can be just 4 if empty goes through is_realized setitem path
self.assertEqual(GlobalCounters.global_mem, 4*(3*2+1)) # 3 elements had +1, 1 is assigned directly
t[1].realize()
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)

View File

@@ -694,7 +694,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op not in (Ops.BUFFER, Ops.MSTACK): return None
# LUNIQUEs are never realized
if self.op_in_backward_slice_with_self(Ops.LUNIQUE): return None
return self.buffer
# NOTE: this is used by the JIT to determine which inputs we capture
return self.buffer if self.buffer.is_allocated() else None
# True when `realized` on this UOp's base is not None, False otherwise.
@property
def is_realized(self) -> bool: return self.base.realized is not None