From 7e65af390662f37f4e05771dd494267e5733a566 Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 22 Nov 2021 14:02:12 +0300 Subject: [PATCH] fix(representation): set correct output type for dot product --- .../common/representation/intermediate.py | 26 ++++++++++--------- tests/numpy/test_compile.py | 12 --------- 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index ab5446f0a..0c6d851f7 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -18,14 +18,7 @@ from ..debugging.custom_assert import assert_true from ..helpers import indexing_helpers from ..helpers.formatting_helpers import format_constant from ..helpers.python_helpers import catch, update_and_return_dict -from ..values import ( - BaseValue, - ClearScalar, - ClearTensor, - EncryptedScalar, - EncryptedTensor, - TensorValue, -) +from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func" @@ -490,11 +483,20 @@ class Dot(IntermediateNode): f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported", ) - output_scalar_value = ( - EncryptedScalar if (lhs.is_encrypted or rhs.is_encrypted) else ClearScalar - ) + output_shape: Tuple[int, ...] + if (lhs.ndim == 1 and rhs.ndim == 1) or (lhs.ndim == 0 and rhs.ndim == 0): + # numpy.dot(x, y) where x and y are both vectors or both scalars + output_shape = () + elif lhs.ndim == 1: + # numpy.dot(x, y) where x is a vector and y is a scalar + output_shape = lhs.shape + else: + # numpy.dot(x, y) where x is a scalar and y is a vector + output_shape = rhs.shape - self.outputs = [output_scalar_value(output_dtype)] + output_value = EncryptedTensor if (lhs.is_encrypted or rhs.is_encrypted) else ClearTensor + + self.outputs = [output_value(output_dtype, output_shape)] self.evaluation_function = delegate_evaluation_function def text_for_drawing(self) -> str: diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index b62ce99de..22a9642a1 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -927,11 +927,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [0, 2], ], ), - # TODO: find a way to support this case - # https://github.com/zama-ai/concretefhe-internal/issues/837 - # - # the problem is that compiler doesn't support combining scalars and tensors - # but they do support broadcasting, so scalars should be converted to (1,) shaped tensors pytest.param( lambda x: numpy.dot(x, 2), { @@ -940,13 +935,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)], ([2, 7, 1],), [4, 14, 2], - marks=pytest.mark.xfail(strict=True), ), - # TODO: find a way to support this case - # https://github.com/zama-ai/concretefhe-internal/issues/837 - # - # the problem is that compiler doesn't support combining scalars and tensors - # but they do support broadcasting, so scalars should be converted to (1,) shaped tensors pytest.param( lambda x: numpy.dot(2, x), { @@ -955,7 +944,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)], ([2, 7, 1],), [4, 14, 2], - marks=pytest.mark.xfail(strict=True), ), ], )