mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
enable test_sample for all backend (#2593)
This commit is contained in:
@@ -1,10 +1,8 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad import Device
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG", "CUDA", "LLVM"], f"{Device.DEFAULT} is not supported")
|
||||
class TestSample(unittest.TestCase):
|
||||
def test_sample(self):
|
||||
X = Tensor.rand(10000, 50).realize()
|
||||
@@ -12,7 +10,7 @@ class TestSample(unittest.TestCase):
|
||||
idxs = np.random.randint(0, X.shape[0], size=(BS))
|
||||
# this uncovered a bug with arg sort order
|
||||
batch = [Variable(f'idx{i}', 0, X.shape[0]-1).bind(s) for i,s in enumerate(idxs.tolist())]
|
||||
x = Tensor.cat(*[X.shrink(((batch[i], batch[i]+1), (0,X.shape[1]))) for i in range(BS)])
|
||||
x = Tensor.cat(*[X.shrink(((batch[i], batch[i]+1), None)) for i in range(BS)])
|
||||
print(idxs)
|
||||
ret = x.numpy()
|
||||
base = X.numpy()[idxs]
|
||||
|
||||
Reference in New Issue
Block a user