assign jit test case with other tensor as input (#4098)

hmm it works
This commit is contained in:
chenyu
2024-04-06 14:41:14 -04:00
committed by GitHub
parent e4a1858471
commit bdbcac67f1

View File

@@ -79,10 +79,34 @@ class TestAssign(unittest.TestCase):
x.realize()
x = Tensor([0])
for _ in range(5): f(x)
assert x.item() == 5
y = Tensor([0])
for _ in range(4): f(y)
assert y.item() == 4
def test_assign_other_jit(self):
@TinyJit
def f(x, a):
x.assign(a)
x.realize()
x = Tensor([0])
for i in range(1, 6):
f(x, x.full_like(i).contiguous()) # const would be implicitly folded without contiguous
assert x.item() == i
def test_assign_add_other_jit(self):
@TinyJit
def f(x, a):
x += a
x.realize()
x = Tensor([0])
a = 0
for i in range(1, 6):
a += i
f(x, x.full_like(i).contiguous())
assert x.item() == a
def test_assign_changes(self):
a = Tensor.ones(4).contiguous().realize()
old_a = a