fix(frontend-python): support axis as a positional argument for np.expand_dims

This commit is contained in:
Umut
2023-05-10 10:05:15 +02:00
parent 5b07f44e74
commit 0f726882c9
2 changed files with 11 additions and 0 deletions

View File

@@ -466,6 +466,10 @@ class Tracer:
sanitized_args = [self.sanitize(args[0])]
if len(args) > 1:
kwargs["axes"] = args[1]
elif func is np.expand_dims:
sanitized_args = [self.sanitize(args[0])]
if len(args) > 1:
kwargs["axis"] = args[1]
else:
sanitized_args = [self.sanitize(arg) for arg in args]

View File

@@ -586,6 +586,13 @@ def deterministic_unary_function(x):
},
id="fhe.LookupTable(list(range(32)))[x + y]",
),
pytest.param(
lambda x: np.expand_dims(x, 0),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
id="np.expand_dims(x, 0)",
),
pytest.param(
lambda x: np.expand_dims(x, axis=0),
{