diff --git a/test/test_tensor_variable.py b/test/test_tensor_variable.py index 643b1c5dec..b05529c71c 100644 --- a/test/test_tensor_variable.py +++ b/test/test_tensor_variable.py @@ -73,8 +73,6 @@ class TestTensorVariable(unittest.TestCase): ret = Tensor.arange(vv.bind(4), 7) self.assertListEqual(ret[:3].tolist(), [4,5,6]) - # TODO: add vmin/vmax pattern for symbolic denominator - @unittest.expectedFailure def test_symbolic_arange_sym_step(self): vv = Variable("step", 1, 3) ret = Tensor.arange(0, 10, vv.bind(2)) @@ -86,6 +84,18 @@ class TestTensorVariable(unittest.TestCase): ret = Tensor.arange(begin.bind(4), end.bind(7)) self.assertListEqual(ret[:3].tolist(), [4,5,6]) + def test_symbolic_arange_three_vars(self): + begin = Variable("b", 0, 5) + end = Variable("e", 10, 20) + step = Variable("s", 1, 3) + ret = Tensor.arange(begin.bind(2), end.bind(14), step.bind(3)) + self.assertListEqual(ret[:4].tolist(), [2,5,8,11]) + + def test_symbolic_full(self): + vv = Variable("x", 1, 10).bind(5) + t = Tensor.full((3,), vv) + self.assertListEqual(t.tolist(), [5,5,5]) + def test_variable_empty(self): v = Variable("i", 1, 10) # TODO: Tensor creation from unbound variable should assert diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index c50dcfcb6a..652fbe8511 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -127,7 +127,7 @@ class Tensor(OpMixin): # create a UOp from the different types of inputs if isinstance(data, UOp): - assert _dtype is None or _dtype==data.dtype, f"dtype doesn't match ({_dtype} vs {data.dtype}), and casting isn't supported" + assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.index, f"dtype mismatch: {_dtype} vs {data.dtype}" # if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of if data.dtype==dtypes.index: data = _index_to_concrete_int(data) if data.op is Ops.BIND: # type: ignore # mypy type narrowing is bugged here