fix(compiler): Since the patterns could be call in any order take care of already lowered types in mid to low conversion

This commit is contained in:
Quentin Bourgerie
2021-08-25 14:07:42 +02:00
parent 697d4033e1
commit 15fd194075

View File

@@ -15,22 +15,66 @@ using LowLFHE::LweCiphertextType;
using LowLFHE::PlaintextType;
using MidLFHE::GLWECipherTextType;
LweCiphertextType convertTypeGLWEToLWE(mlir::MLIRContext *context,
GLWECipherTextType &glwe) {
return LweCiphertextType::get(
context, glwe.getDimension() * glwe.getPolynomialSize(), glwe.getP());
LweCiphertextType convertTypeToLWE(mlir::MLIRContext *context,
mlir::Type type) {
auto glwe = type.dyn_cast_or_null<GLWECipherTextType>();
if (glwe != nullptr) {
return LweCiphertextType::get(
context, glwe.getDimension() * glwe.getPolynomialSize(), glwe.getP());
}
auto lwe = type.dyn_cast_or_null<LweCiphertextType>();
if (lwe != nullptr) {
return lwe;
}
assert(false && "expect glwe or lwe");
}
PlaintextType convertPlaintextTypeFromGlwe(mlir::MLIRContext *context,
GLWECipherTextType &type) {
template <typename PType>
PlaintextType convertPlaintextTypeFromPType(mlir::MLIRContext *context,
PType &type) {
return PlaintextType::get(context, type.getP() + 1);
}
CleartextType convertCleartextTypeFromGlwe(mlir::MLIRContext *context,
GLWECipherTextType &type) {
// convertPlaintextTypeFromType create a plaintext type according the
// precision of the given type argument. The type should be a GLWECipherText
// (if operand is not yet lowered) or a LWECipherTextType (if operand is
// already lowered).
PlaintextType convertPlaintextTypeFromType(mlir::MLIRContext *context,
mlir::Type &type) {
auto glwe = type.dyn_cast_or_null<GLWECipherTextType>();
if (glwe != nullptr) {
return convertPlaintextTypeFromPType<GLWECipherTextType>(context, glwe);
}
auto lwe = type.dyn_cast_or_null<LweCiphertextType>();
if (lwe != nullptr) {
return convertPlaintextTypeFromPType<LweCiphertextType>(context, lwe);
}
assert(false && "expect glwe or lwe");
}
template <typename PType>
CleartextType convertCleartextTypeFromPType(mlir::MLIRContext *context,
PType &type) {
return CleartextType::get(context, type.getP() + 1);
}
// convertCleartextTypeFromType create a cleartext type according the
// precision of the given type argument. The type should be a GLWECipherText
// (if operand is not yet lowered) or a LWECipherTextType (if operand is
// already lowered).
CleartextType convertCleartextTypeFromType(mlir::MLIRContext *context,
mlir::Type &type) {
auto glwe = type.dyn_cast_or_null<GLWECipherTextType>();
if (glwe != nullptr) {
return convertCleartextTypeFromPType<GLWECipherTextType>(context, glwe);
}
auto lwe = type.dyn_cast_or_null<LweCiphertextType>();
if (lwe != nullptr) {
return convertCleartextTypeFromPType<LweCiphertextType>(context, lwe);
}
assert(false && "expect glwe or lwe");
}
template <class Operator>
mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter,
mlir::Location loc, mlir::Value arg0,
@@ -40,16 +84,16 @@ mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter,
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
auto glwe = result.getType().cast<GLWECipherTextType>();
mlir::SmallVector<mlir::Type, 1> resTypes{
convertTypeGLWEToLWE(rewriter.getContext(), glwe)};
convertTypeToLWE(rewriter.getContext(), glwe)};
Operator op = rewriter.create<Operator>(loc, resTypes, args, attrs);
return op.getODSResults(0).front();
}
mlir::Value createAddPlainLweCiphertextWithGlwe(
mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value arg0,
mlir::Value arg1, mlir::OpResult result, GLWECipherTextType glwe) {
mlir::Value arg1, mlir::OpResult result, mlir::Type encryptedType) {
PlaintextType encoded_type =
convertPlaintextTypeFromGlwe(rewriter.getContext(), glwe);
convertPlaintextTypeFromType(rewriter.getContext(), encryptedType);
// encode int into plaintext
mlir::Value encoded =
rewriter
@@ -58,7 +102,7 @@ mlir::Value createAddPlainLweCiphertextWithGlwe(
// convert result type
GLWECipherTextType glwe_type = result.getType().cast<GLWECipherTextType>();
LweCiphertextType lwe_type =
convertTypeGLWEToLWE(rewriter.getContext(), glwe_type);
convertTypeToLWE(rewriter.getContext(), result.getType());
// replace op using the encoded plaintext instead of int
auto op =
rewriter.create<mlir::zamalang::LowLFHE::AddPlaintextLweCiphertextOp>(
@@ -78,11 +122,11 @@ mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter,
mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter rewriter,
mlir::Location loc, mlir::Value arg0,
mlir::Value arg1, mlir::OpResult result) {
auto arg1_type = arg1.getType().cast<GLWECipherTextType>();
auto arg1_type = arg1.getType();
auto negated_arg1 =
rewriter
.create<mlir::zamalang::LowLFHE::NegateLweCiphertextOp>(
loc, convertTypeGLWEToLWE(rewriter.getContext(), arg1_type), arg1)
loc, convertTypeToLWE(rewriter.getContext(), arg1_type), arg1)
.result();
return createAddPlainLweCiphertextWithGlwe(rewriter, loc, negated_arg1, arg0,
result, arg1_type);
@@ -92,18 +136,17 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter,
mlir::Location loc, mlir::Value arg0,
mlir::Value arg1,
mlir::OpResult result) {
auto glwe = arg0.getType().cast<GLWECipherTextType>();
auto inType = arg0.getType();
CleartextType encoded_type =
convertCleartextTypeFromGlwe(rewriter.getContext(), glwe);
convertCleartextTypeFromType(rewriter.getContext(), inType);
// encode int into plaintext
mlir::Value encoded = rewriter
.create<mlir::zamalang::LowLFHE::IntToCleartextOp>(
loc, encoded_type, arg1)
.cleartext();
// convert result type
GLWECipherTextType glwe_type = result.getType().cast<GLWECipherTextType>();
LweCiphertextType lwe_type =
convertTypeGLWEToLWE(rewriter.getContext(), glwe_type);
auto resType = result.getType();
LweCiphertextType lwe_type = convertTypeToLWE(rewriter.getContext(), resType);
// replace op using the encoded plaintext instead of int
auto op =
rewriter.create<mlir::zamalang::LowLFHE::MulCleartextLweCiphertextOp>(