mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
hotfix: test_assign_contiguous
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad import dtypes, TinyJit
|
||||
from tinygrad import dtypes, TinyJit, GlobalCounters
|
||||
|
||||
N = 200 # has to be bigger than the cache to fail
|
||||
|
||||
@@ -69,6 +69,20 @@ class TestAssign(unittest.TestCase):
|
||||
for _ in range(4): f(y)
|
||||
assert y.item() == 4
|
||||
|
||||
def test_assign_contiguous(self):
|
||||
b = Tensor.rand(4,4).realize()
|
||||
a = (Tensor.rand(4,4).realize() + 1)
|
||||
kc = GlobalCounters.kernel_count
|
||||
b.assign(a.contiguous()).realize()
|
||||
assert GlobalCounters.kernel_count - kc == 1
|
||||
|
||||
def test_assign_contiguous_permute(self):
|
||||
b = Tensor.rand(4,4).realize()
|
||||
a = (Tensor.rand(4,4).realize() + 1).permute((1,0))
|
||||
kc = GlobalCounters.kernel_count
|
||||
b.assign(a.contiguous()).realize()
|
||||
assert GlobalCounters.kernel_count - kc == 1
|
||||
|
||||
def test_permuted_assignment(self):
|
||||
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
|
||||
b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
|
||||
|
||||
Reference in New Issue
Block a user