mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
fix(frontend-python): stop crashing on scalar squeeze
This commit is contained in:
@@ -521,6 +521,9 @@ class Converter:
|
||||
# if the output shape is (), it means (1, 1, ..., 1, 1) is squeezed
|
||||
# and the result is a scalar, so we need to do indexing, not reshape
|
||||
if node.output.shape == ():
|
||||
if preds[0].shape == ():
|
||||
return preds[0]
|
||||
|
||||
assert all(size == 1 for size in preds[0].shape)
|
||||
index = (0,) * len(preds[0].shape)
|
||||
return ctx.index_static(ctx.typeof(node), preds[0], index)
|
||||
|
||||
@@ -659,6 +659,13 @@ def copy_modify(x):
|
||||
},
|
||||
id="x ** 3",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.squeeze(x),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": ()},
|
||||
},
|
||||
id="np.squeeze(x)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.squeeze(x),
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user