fix(representation): add shape equality check to dot product node

This commit is contained in:
Umut
2021-11-01 18:11:26 +03:00
parent 8123a5ef45
commit d4e5831a57

View File

@@ -408,10 +408,16 @@ class Dot(IntermediateNode):
f"Dot only supports two vectors ({TensorValue.__name__} with ndim == 1)",
)
lhs = cast(TensorValue, self.inputs[0])
rhs = cast(TensorValue, self.inputs[1])
assert_true(
lhs.shape[0] == rhs.shape[0],
f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported",
)
output_scalar_value = (
EncryptedScalar
if (self.inputs[0].is_encrypted or self.inputs[1].is_encrypted)
else ClearScalar
EncryptedScalar if (lhs.is_encrypted or rhs.is_encrypted) else ClearScalar
)
self.outputs = [output_scalar_value(output_dtype)]
@@ -447,13 +453,12 @@ class MatMul(IntermediateNode):
f"MatMul only supports two matrices ({TensorValue.__name__} with ndim == 2)",
)
# regular assertions are for mypy to see the inputs are TensorValue
lhs = cast(TensorValue, self.inputs[0])
rhs = cast(TensorValue, self.inputs[1])
assert_true(
lhs.shape[1] == rhs.shape[0],
f"MatMul between matrices of shapes {lhs.shape} and {rhs.shape} " f"is not supported",
f"MatMul between matrices of shapes {lhs.shape} and {rhs.shape} is not supported",
)
output_shape = (lhs.shape[0], rhs.shape[1])