fix(frontend-python): stop crashing on scalar squeeze

This commit is contained in:
Umut
2024-02-02 12:09:27 +03:00
parent 565e6f2796
commit efc9314d25
2 changed files with 10 additions and 0 deletions

View File

@@ -521,6 +521,9 @@ class Converter:
# if the output shape is (), it means (1, 1, ..., 1, 1) is squeezed
# and the result is a scalar, so we need to do indexing, not reshape
if node.output.shape == ():
if preds[0].shape == ():
return preds[0]
assert all(size == 1 for size in preds[0].shape)
index = (0,) * len(preds[0].shape)
return ctx.index_static(ctx.typeof(node), preds[0], index)

View File

@@ -659,6 +659,13 @@ def copy_modify(x):
},
id="x ** 3",
),
pytest.param(
lambda x: np.squeeze(x),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": ()},
},
id="np.squeeze(x)",
),
pytest.param(
lambda x: np.squeeze(x),
{