mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
* Switch to dawn, all tests passing locally * Use dawn-python * Skip failing test * Skip midcast and fix timestamp on metal ci * Autogen webgpu * Try fetch dawn lib again * /usr/lib * Without lib prefix * Test autogen diff * Delete webgpu support, move everything to ops_webgpu * mypy fix * Simplify, refactor * Line savings * No ResultContainer * Type annotation for result * Some more simplifications * Why was this explicit sync used at all? * Refactor: delete functions that are only used once * Create shader module inline * Clear unit tests cache, maybe that solves it * That wasn't it * Try deleting cache to pass failing weight compare * weights_only=False for pytorch 2.6 * Simplify ctype array creation * Remove nanosecond precision timestamps * Simplify error handling * Refactor, add back type annotations * Deleted custom submit function, refactor * read_buffer simplify * Fix use after free, refactor * Simplify supported_features * Runtime docs --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
20 lines
722 B
Python
20 lines
722 B
Python
import unittest
|
|
import numpy as np
|
|
from tinygrad import Tensor, Variable, Device
|
|
|
|
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU can only run kernels with up to 10 buffers")
class TestSample(unittest.TestCase):
  def test_sample(self):
    """Gathering rows via bound symbolic Variables must match numpy fancy indexing.

    Builds one bound Variable per sampled row index, shrinks X to each single
    row, concatenates them, and compares against X.numpy()[idxs].
    """
    X = Tensor.rand(10000, 50).realize()
    BS = 16
    # size must be a tuple: (BS) is just an int in parens, (BS,) is explicit
    idxs = np.random.randint(0, X.shape[0], size=(BS,))
    # this uncovered a bug with arg sort order
    batch = [Variable(f'idx{i}', 0, X.shape[0]-1).bind(s) for i, s in enumerate(idxs.tolist())]
    x = Tensor.cat(*[X.shrink(((batch[i], batch[i]+1), None)) for i in range(BS)])
    ret = x.numpy()
    base = X.numpy()[idxs]
    # err_msg surfaces the random indices on failure (replaces a stray debug print)
    np.testing.assert_equal(ret, base, err_msg=f"idxs={idxs}")
|
|
|
|
# Allow running this test file directly (in addition to pytest discovery).
if __name__ == '__main__':
  unittest.main()