fix(frontend-python): mimic summing like numpy exactly

This commit is contained in:
Umut
2023-07-06 16:30:46 +02:00
parent 1215a57752
commit 9ef1229c51
2 changed files with 18 additions and 0 deletions

View File

@@ -465,6 +465,12 @@ class Tracer:
sanitized_args = [self.sanitize(args[0])]
if len(args) > 1:
kwargs["newshape"] = args[1]
elif func is np.sum:
sanitized_args = [self.sanitize(args[0])]
for i, keyword in enumerate(["axis", "dtype", "out", "keepdims", "initial", "where"]):
position = i + 1
if len(args) > position:
kwargs[keyword] = args[position]
elif func is np.transpose:
sanitized_args = [self.sanitize(args[0])]
if len(args) > 1:

View File

@@ -17,6 +17,18 @@ from concrete import fhe
"x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"},
},
),
pytest.param(
lambda x: np.sum(x, 0),
{
"x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"},
},
),
pytest.param(
lambda x: np.sum(x, 1),
{
"x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"},
},
),
pytest.param(
lambda x: np.sum(x, axis=None), # type: ignore
{