use assign_targets in LazyOp creation (#4568)

* start

* correct error

* this is possible

* document it
This commit is contained in:
qazal
2024-05-13 15:24:35 +08:00
committed by GitHub
parent b0fa97e176
commit 77aa8659f5
2 changed files with 25 additions and 10 deletions

View File

@@ -323,14 +323,28 @@ class TestAssign(unittest.TestCase):
b = Tensor.full((32, 32), 1.).contiguous().realize()
c = Tensor.full((32, 32), 2.).contiguous().realize()
# TODO: this is failing in cycle error, it should fail earlier.
with self.assertRaisesRegex(RuntimeError, "cycle"):
with self.assertRaisesRegex(RuntimeError, "contiguous"):
r = a.sum(axis=1)
b_perm = b.permute(1, 0)
b.assign(r + b)
c.assign(r + b_perm)
Tensor.realize(b, c)
def test_permuted_reduceop_multioutput_dual_use_possible(self):
a = Tensor.randn(32, 32, 32, dtype=dtypes.int).realize()
b = Tensor.arange(32 * 32).reshape(32, 32).realize()
c = Tensor.arange(32 * 32).reshape(32, 32).realize()
kc = GlobalCounters.kernel_count
r = a.sum(axis=1)
b_perm = b.permute(1, 0)
b.assign(r + b)
c.assign(r + b_perm.contiguous())
Tensor.realize(b, c)
assert GlobalCounters.kernel_count - kc == 2
np.testing.assert_equal(b.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32))
np.testing.assert_equal(c.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32).transpose(1, 0))
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?
@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")