Files
tinygrad/test/test_sample.py
George Hotz 9e07824542 move device to device.py (#2466)
* move device to device.py

* pylint test --disable R,C,W,E --enable E0611

* fix tests
2023-11-27 11:34:37 -08:00

22 lines
816 B
Python

import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad import Device
from tinygrad.shape.symbolic import Variable
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG", "CUDA", "LLVM"], f"{Device.DEFAULT} is not supported")
class TestSample(unittest.TestCase):
  """Gather rows of a realized tensor via shrinks bounded by bound symbolic Variables.

  Skipped unless the default device is one of GPU, METAL, CLANG, CUDA or LLVM.
  """
  def test_sample(self):
    """Concatenating one symbolic row-shrink per sampled index must equal numpy fancy indexing."""
    X = Tensor.rand(10000, 50).realize()
    BS = 16
    idxs = np.random.randint(0, X.shape[0], size=BS)
    # this uncovered a bug with arg sort order: each sampled row gets its own
    # Variable (idx0..idx{BS-1}) spanning every valid row, bound to the concrete index
    batch = [Variable(f'idx{i}', 0, X.shape[0]-1).bind(s) for i, s in enumerate(idxs.tolist())]
    # select the single row [b, b+1) and all columns for each bound variable, then stack along dim 0
    x = Tensor.cat(*[X.shrink(((b, b+1), (0, X.shape[1]))) for b in batch])
    ret = x.numpy()
    base = X.numpy()[idxs]
    # report the sampled indices only when the comparison fails, instead of printing on every run
    np.testing.assert_equal(ret, base, err_msg=f"sampled idxs: {idxs.tolist()}")
if __name__ == '__main__':
  # executed directly (not imported by a runner): hand off to the stdlib test runner,
  # spelling out its default verbosity level
  unittest.main(verbosity=1)