fix(representation): set correct output type for dot product

This commit is contained in:
Umut
2021-11-22 14:02:12 +03:00
parent 1d77816aa3
commit 7e65af3906
2 changed files with 14 additions and 24 deletions

View File

@@ -18,14 +18,7 @@ from ..debugging.custom_assert import assert_true
from ..helpers import indexing_helpers
from ..helpers.formatting_helpers import format_constant
from ..helpers.python_helpers import catch, update_and_return_dict
from ..values import (
BaseValue,
ClearScalar,
ClearTensor,
EncryptedScalar,
EncryptedTensor,
TensorValue,
)
from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
@@ -490,11 +483,20 @@ class Dot(IntermediateNode):
f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported",
)
output_scalar_value = (
EncryptedScalar if (lhs.is_encrypted or rhs.is_encrypted) else ClearScalar
)
output_shape: Tuple[int, ...]
if (lhs.ndim == 1 and rhs.ndim == 1) or (lhs.ndim == 0 and rhs.ndim == 0):
# numpy.dot(x, y) where x and y are both vectors or both scalars
output_shape = ()
elif lhs.ndim == 1:
# numpy.dot(x, y) where x is a vector and y is a scalar
output_shape = lhs.shape
else:
# numpy.dot(x, y) where x is a scalar and y is a vector
output_shape = rhs.shape
self.outputs = [output_scalar_value(output_dtype)]
output_value = EncryptedTensor if (lhs.is_encrypted or rhs.is_encrypted) else ClearTensor
self.outputs = [output_value(output_dtype, output_shape)]
self.evaluation_function = delegate_evaluation_function
def text_for_drawing(self) -> str:

View File

@@ -927,11 +927,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[0, 2],
],
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: numpy.dot(x, 2),
{
@@ -940,13 +935,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
([2, 7, 1],),
[4, 14, 2],
marks=pytest.mark.xfail(strict=True),
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: numpy.dot(2, x),
{
@@ -955,7 +944,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
([2, 7, 1],),
[4, 14, 2],
marks=pytest.mark.xfail(strict=True),
),
],
)