mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
use assign_targets in LazyOp creation (#4568)
* start * correct error * this is possible * document it
This commit is contained in:
@@ -323,14 +323,28 @@ class TestAssign(unittest.TestCase):
|
||||
b = Tensor.full((32, 32), 1.).contiguous().realize()
|
||||
c = Tensor.full((32, 32), 2.).contiguous().realize()
|
||||
|
||||
# TODO: this is failing in cycle error, it should fail earlier.
|
||||
with self.assertRaisesRegex(RuntimeError, "cycle"):
|
||||
with self.assertRaisesRegex(RuntimeError, "contiguous"):
|
||||
r = a.sum(axis=1)
|
||||
b_perm = b.permute(1, 0)
|
||||
b.assign(r + b)
|
||||
c.assign(r + b_perm)
|
||||
Tensor.realize(b, c)
|
||||
|
||||
def test_permuted_reduceop_multioutput_dual_use_possible(self):
|
||||
a = Tensor.randn(32, 32, 32, dtype=dtypes.int).realize()
|
||||
b = Tensor.arange(32 * 32).reshape(32, 32).realize()
|
||||
c = Tensor.arange(32 * 32).reshape(32, 32).realize()
|
||||
|
||||
kc = GlobalCounters.kernel_count
|
||||
r = a.sum(axis=1)
|
||||
b_perm = b.permute(1, 0)
|
||||
b.assign(r + b)
|
||||
c.assign(r + b_perm.contiguous())
|
||||
Tensor.realize(b, c)
|
||||
assert GlobalCounters.kernel_count - kc == 2
|
||||
np.testing.assert_equal(b.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32))
|
||||
np.testing.assert_equal(c.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32).transpose(1, 0))
|
||||
|
||||
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?
|
||||
|
||||
@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
|
||||
|
||||
Reference in New Issue
Block a user