From ed7dd36e7040e7f8ea653329006d33f9a1b49042 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 13 Aug 2021 11:33:49 +0200 Subject: [PATCH] fix(compiler): Remove the type conversion workaround, actually the order of addConversion calls matters (lifo) --- .../HLFHEToMidLFHE/HLFHEToMidLFHE.cpp | 37 +++---------------- .../MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp | 8 +--- 2 files changed, 8 insertions(+), 37 deletions(-) diff --git a/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp b/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp index acb31956f..64dd6b484 100644 --- a/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp +++ b/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp @@ -26,11 +26,12 @@ class HLFHEToMidLFHETypeConverter : public mlir::TypeConverter { public: HLFHEToMidLFHETypeConverter() { - addConversion([&](EncryptedIntegerType type) { + addConversion([](mlir::Type type) { return type; }); + addConversion([](EncryptedIntegerType type) { return mlir::zamalang::convertTypeEncryptedIntegerToGLWE( type.getContext(), type); }); - addConversion([&](mlir::MemRefType type) { + addConversion([](mlir::MemRefType type) { auto eint = type.getElementType().dyn_cast_or_null(); if (eint == nullptr) { @@ -44,25 +45,6 @@ public: return r; }); } - - /// [workaround] as converter.isLegal returns unexpected false for glwe with - /// same parameters. - static bool _isLegal(mlir::Type type) { - if (type.isa()) { - return false; - } - auto memref = type.dyn_cast_or_null(); - if (memref != nullptr) { - return _isLegal(memref.getElementType()); - } - return true; - } - - // [workaround] - template static bool _isLegal(TypeRangeT &&types) { - return llvm::all_of(types, - [&](const mlir::Type ty) { return _isLegal(ty); }); - } }; void HLFHEToMidLFHEPass::runOnOperation() { @@ -86,16 +68,9 @@ void HLFHEToMidLFHEPass::runOnOperation() { }); // Make sure that func has legal signature - target.addDynamicallyLegalOp([](mlir::FuncOp funcOp) { - HLFHEToMidLFHETypeConverter converter; - // [workaround] should be this commented code but for an unknown reasons - // converter.isLegal returns false for glwe with same parameters. - // - // return converter.isSignatureLegal(op.getType()) && - // converter.isLegal(&op.getBody()); - auto funcType = funcOp.getType(); - return HLFHEToMidLFHETypeConverter::_isLegal(funcType.getInputs()) && - HLFHEToMidLFHETypeConverter::_isLegal(funcType.getResults()); + target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { + return converter.isSignatureLegal(funcOp.getType()) && + converter.isLegal(&funcOp.getBody()); }); // Add all patterns required to lower all ops from `HLFHE` to // `MidLFHE` diff --git a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp index 127d85825..747b9b2cb 100644 --- a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp +++ b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp @@ -27,6 +27,7 @@ class MidLFHEToLowLFHETypeConverter : public mlir::TypeConverter { public: MidLFHEToLowLFHETypeConverter() { + addConversion([](mlir::Type type) { return type; }); addConversion([&](GLWECipherTextType type) { return mlir::zamalang::convertTypeGLWEToLWE(type.getContext(), type); }); @@ -41,10 +42,6 @@ public: type.getAffineMaps(), type.getMemorySpace()); return r; }); - // [workaround] need these converters to consider those types legal - addConversion([&](mlir::IntegerType type) { return type; }); - addConversion( - [&](mlir::zamalang::LowLFHE::LweCiphertextType type) { return type; }); } }; @@ -69,8 +66,7 @@ void MidLFHEToLowLFHEPass::runOnOperation() { }); // Make sure that func has legal signature - target.addDynamicallyLegalOp([](mlir::FuncOp funcOp) { - MidLFHEToLowLFHETypeConverter converter; + target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { return converter.isSignatureLegal(funcOp.getType()) && converter.isLegal(&funcOp.getBody()); });