mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(representation): add shape equality check to dot product node
This commit is contained in:
@@ -408,10 +408,16 @@ class Dot(IntermediateNode):
|
||||
f"Dot only supports two vectors ({TensorValue.__name__} with ndim == 1)",
|
||||
)
|
||||
|
||||
lhs = cast(TensorValue, self.inputs[0])
|
||||
rhs = cast(TensorValue, self.inputs[1])
|
||||
|
||||
assert_true(
|
||||
lhs.shape[0] == rhs.shape[0],
|
||||
f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported",
|
||||
)
|
||||
|
||||
output_scalar_value = (
|
||||
EncryptedScalar
|
||||
if (self.inputs[0].is_encrypted or self.inputs[1].is_encrypted)
|
||||
else ClearScalar
|
||||
EncryptedScalar if (lhs.is_encrypted or rhs.is_encrypted) else ClearScalar
|
||||
)
|
||||
|
||||
self.outputs = [output_scalar_value(output_dtype)]
|
||||
@@ -447,13 +453,12 @@ class MatMul(IntermediateNode):
|
||||
f"MatMul only supports two matrices ({TensorValue.__name__} with ndim == 2)",
|
||||
)
|
||||
|
||||
# regular assertions are for mypy to see the inputs are TensorValue
|
||||
lhs = cast(TensorValue, self.inputs[0])
|
||||
rhs = cast(TensorValue, self.inputs[1])
|
||||
|
||||
assert_true(
|
||||
lhs.shape[1] == rhs.shape[0],
|
||||
f"MatMul between matrices of shapes {lhs.shape} and {rhs.shape} " f"is not supported",
|
||||
f"MatMul between matrices of shapes {lhs.shape} and {rhs.shape} is not supported",
|
||||
)
|
||||
|
||||
output_shape = (lhs.shape[0], rhs.shape[1])
|
||||
|
||||
Reference in New Issue
Block a user