mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -3,7 +3,7 @@
|
||||
import unittest, random, warnings
|
||||
import numpy as np
|
||||
|
||||
from tinygrad import Tensor, dtypes, Device, TinyJit
|
||||
from tinygrad import Tensor, dtypes, Device, TinyJit, Variable
|
||||
from tinygrad.helpers import all_same, prod
|
||||
from test.helpers import slow
|
||||
|
||||
@@ -647,6 +647,11 @@ class TestIndexing(unittest.TestCase):
|
||||
i, j = indices
|
||||
numpy_testing_assert_equal_helper(x[i:j], x[0:1])
|
||||
|
||||
def test_variable_with_tensor_index(self):
|
||||
t = Tensor.arange(12).reshape(3, 4)
|
||||
v = Variable("v", 0, 2).bind(1)
|
||||
numpy_testing_assert_equal_helper(t[v, Tensor([0, 1, 2])], t[1, Tensor([0, 1, 2])])
|
||||
|
||||
def test_ellipsis_tensor(self):
|
||||
x = Tensor.arange(0, 9).reshape(3, 3)
|
||||
idx = Tensor([0, 2])
|
||||
|
||||
@@ -1215,7 +1215,7 @@ class Tensor(OpMixin):
|
||||
x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], sint)))
|
||||
|
||||
# tensor indexing
|
||||
if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], int)) if isinstance(i['index'], Tensor)]:
|
||||
if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], sint)) if isinstance(i['index'], Tensor)]:
|
||||
# unload the tensor object into actual tensors
|
||||
dims, tensors, masks = [d for d,_ in tops], cast(list[Tensor], [i['index'] for _,i in tops]), []
|
||||
big_shape = _broadcast_shape(*(t.shape for t in tensors))
|
||||
|
||||
Reference in New Issue
Block a user