From d4e5831a57754719eca1bd8a5ce9f5827a088c59 Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 1 Nov 2021 18:11:26 +0300 Subject: [PATCH] fix(representation): add shape equality check to dot product node --- concrete/common/representation/intermediate.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 949bef2d1..051fbd8e2 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -408,10 +408,16 @@ class Dot(IntermediateNode): f"Dot only supports two vectors ({TensorValue.__name__} with ndim == 1)", ) + lhs = cast(TensorValue, self.inputs[0]) + rhs = cast(TensorValue, self.inputs[1]) + + assert_true( + lhs.shape[0] == rhs.shape[0], + f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported", + ) + output_scalar_value = ( - EncryptedScalar - if (self.inputs[0].is_encrypted or self.inputs[1].is_encrypted) - else ClearScalar + EncryptedScalar if (lhs.is_encrypted or rhs.is_encrypted) else ClearScalar ) self.outputs = [output_scalar_value(output_dtype)] @@ -447,13 +453,12 @@ class MatMul(IntermediateNode): f"MatMul only supports two matrices ({TensorValue.__name__} with ndim == 2)", ) - # regular assertions are for mypy to see the inputs are TensorValue lhs = cast(TensorValue, self.inputs[0]) rhs = cast(TensorValue, self.inputs[1]) assert_true( lhs.shape[1] == rhs.shape[0], - f"MatMul between matrices of shapes {lhs.shape} and {rhs.shape} " f"is not supported", + f"MatMul between matrices of shapes {lhs.shape} and {rhs.shape} is not supported", ) output_shape = (lhs.shape[0], rhs.shape[1])