mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix(representation): set correct output type for dot product
This commit is contained in:
@@ -18,14 +18,7 @@ from ..debugging.custom_assert import assert_true
|
||||
from ..helpers import indexing_helpers
|
||||
from ..helpers.formatting_helpers import format_constant
|
||||
from ..helpers.python_helpers import catch, update_and_return_dict
|
||||
from ..values import (
|
||||
BaseValue,
|
||||
ClearScalar,
|
||||
ClearTensor,
|
||||
EncryptedScalar,
|
||||
EncryptedTensor,
|
||||
TensorValue,
|
||||
)
|
||||
from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
|
||||
|
||||
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
|
||||
|
||||
@@ -490,11 +483,20 @@ class Dot(IntermediateNode):
|
||||
f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported",
|
||||
)
|
||||
|
||||
output_scalar_value = (
|
||||
EncryptedScalar if (lhs.is_encrypted or rhs.is_encrypted) else ClearScalar
|
||||
)
|
||||
output_shape: Tuple[int, ...]
|
||||
if (lhs.ndim == 1 and rhs.ndim == 1) or (lhs.ndim == 0 and rhs.ndim == 0):
|
||||
# numpy.dot(x, y) where x and y are both vectors or both scalars
|
||||
output_shape = ()
|
||||
elif lhs.ndim == 1:
|
||||
# numpy.dot(x, y) where x is a vector and y is a scalar
|
||||
output_shape = lhs.shape
|
||||
else:
|
||||
# numpy.dot(x, y) where x is a scalar and y is a vector
|
||||
output_shape = rhs.shape
|
||||
|
||||
self.outputs = [output_scalar_value(output_dtype)]
|
||||
output_value = EncryptedTensor if (lhs.is_encrypted or rhs.is_encrypted) else ClearTensor
|
||||
|
||||
self.outputs = [output_value(output_dtype, output_shape)]
|
||||
self.evaluation_function = delegate_evaluation_function
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
|
||||
@@ -927,11 +927,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[0, 2],
|
||||
],
|
||||
),
|
||||
# TODO: find a way to support this case
|
||||
# https://github.com/zama-ai/concretefhe-internal/issues/837
|
||||
#
|
||||
# the problem is that compiler doesn't support combining scalars and tensors
|
||||
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
|
||||
pytest.param(
|
||||
lambda x: numpy.dot(x, 2),
|
||||
{
|
||||
@@ -940,13 +935,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
|
||||
([2, 7, 1],),
|
||||
[4, 14, 2],
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
# TODO: find a way to support this case
|
||||
# https://github.com/zama-ai/concretefhe-internal/issues/837
|
||||
#
|
||||
# the problem is that compiler doesn't support combining scalars and tensors
|
||||
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
|
||||
pytest.param(
|
||||
lambda x: numpy.dot(2, x),
|
||||
{
|
||||
@@ -955,7 +944,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
|
||||
([2, 7, 1],),
|
||||
[4, 14, 2],
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user