fix test_symbolic_arange_sym_step (#13952)

This commit is contained in:
chenyu
2026-01-01 09:41:07 -05:00
committed by GitHub
parent b91b46091c
commit c69470be52
2 changed files with 13 additions and 3 deletions

View File

@@ -73,8 +73,6 @@ class TestTensorVariable(unittest.TestCase):
ret = Tensor.arange(vv.bind(4), 7)
self.assertListEqual(ret[:3].tolist(), [4,5,6])
# TODO: add vmin/vmax pattern for symbolic denominator
@unittest.expectedFailure
def test_symbolic_arange_sym_step(self):
vv = Variable("step", 1, 3)
ret = Tensor.arange(0, 10, vv.bind(2))
@@ -86,6 +84,18 @@ class TestTensorVariable(unittest.TestCase):
ret = Tensor.arange(begin.bind(4), end.bind(7))
self.assertListEqual(ret[:3].tolist(), [4,5,6])
def test_symbolic_arange_three_vars(self):
begin = Variable("b", 0, 5)
end = Variable("e", 10, 20)
step = Variable("s", 1, 3)
ret = Tensor.arange(begin.bind(2), end.bind(14), step.bind(3))
self.assertListEqual(ret[:4].tolist(), [2,5,8,11])
def test_symbolic_full(self):
vv = Variable("x", 1, 10).bind(5)
t = Tensor.full((3,), vv)
self.assertListEqual(t.tolist(), [5,5,5])
def test_variable_empty(self):
v = Variable("i", 1, 10)
# TODO: Tensor creation from unbound variable should assert

View File

@@ -127,7 +127,7 @@ class Tensor(OpMixin):
# create a UOp from the different types of inputs
if isinstance(data, UOp):
assert _dtype is None or _dtype==data.dtype, f"dtype doesn't match ({_dtype} vs {data.dtype}), and casting isn't supported"
assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.index, f"dtype mismatch: {_dtype} vs {data.dtype}"
# if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
if data.dtype==dtypes.index: data = _index_to_concrete_int(data)
if data.op is Ops.BIND: # type: ignore # mypy type narrowing is bugged here