mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
fix test_symbolic_arange_sym_step (#13952)
This commit is contained in:
@@ -73,8 +73,6 @@ class TestTensorVariable(unittest.TestCase):
|
|||||||
ret = Tensor.arange(vv.bind(4), 7)
|
ret = Tensor.arange(vv.bind(4), 7)
|
||||||
self.assertListEqual(ret[:3].tolist(), [4,5,6])
|
self.assertListEqual(ret[:3].tolist(), [4,5,6])
|
||||||
|
|
||||||
# TODO: add vmin/vmax pattern for symbolic denominator
|
|
||||||
@unittest.expectedFailure
|
|
||||||
def test_symbolic_arange_sym_step(self):
|
def test_symbolic_arange_sym_step(self):
|
||||||
vv = Variable("step", 1, 3)
|
vv = Variable("step", 1, 3)
|
||||||
ret = Tensor.arange(0, 10, vv.bind(2))
|
ret = Tensor.arange(0, 10, vv.bind(2))
|
||||||
@@ -86,6 +84,18 @@ class TestTensorVariable(unittest.TestCase):
|
|||||||
ret = Tensor.arange(begin.bind(4), end.bind(7))
|
ret = Tensor.arange(begin.bind(4), end.bind(7))
|
||||||
self.assertListEqual(ret[:3].tolist(), [4,5,6])
|
self.assertListEqual(ret[:3].tolist(), [4,5,6])
|
||||||
|
|
||||||
|
def test_symbolic_arange_three_vars(self):
|
||||||
|
begin = Variable("b", 0, 5)
|
||||||
|
end = Variable("e", 10, 20)
|
||||||
|
step = Variable("s", 1, 3)
|
||||||
|
ret = Tensor.arange(begin.bind(2), end.bind(14), step.bind(3))
|
||||||
|
self.assertListEqual(ret[:4].tolist(), [2,5,8,11])
|
||||||
|
|
||||||
|
def test_symbolic_full(self):
|
||||||
|
vv = Variable("x", 1, 10).bind(5)
|
||||||
|
t = Tensor.full((3,), vv)
|
||||||
|
self.assertListEqual(t.tolist(), [5,5,5])
|
||||||
|
|
||||||
def test_variable_empty(self):
|
def test_variable_empty(self):
|
||||||
v = Variable("i", 1, 10)
|
v = Variable("i", 1, 10)
|
||||||
# TODO: Tensor creation from unbound variable should assert
|
# TODO: Tensor creation from unbound variable should assert
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ class Tensor(OpMixin):
|
|||||||
|
|
||||||
# create a UOp from the different types of inputs
|
# create a UOp from the different types of inputs
|
||||||
if isinstance(data, UOp):
|
if isinstance(data, UOp):
|
||||||
assert _dtype is None or _dtype==data.dtype, f"dtype doesn't match ({_dtype} vs {data.dtype}), and casting isn't supported"
|
assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.index, f"dtype mismatch: {_dtype} vs {data.dtype}"
|
||||||
# if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
|
# if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
|
||||||
if data.dtype==dtypes.index: data = _index_to_concrete_int(data)
|
if data.dtype==dtypes.index: data = _index_to_concrete_int(data)
|
||||||
if data.op is Ops.BIND: # type: ignore # mypy type narrowing is bugged here
|
if data.op is Ops.BIND: # type: ignore # mypy type narrowing is bugged here
|
||||||
|
|||||||
Reference in New Issue
Block a user