diff --git a/compiler/include/zamalang/Conversion/CMakeLists.txt b/compiler/include/zamalang/Conversion/CMakeLists.txt index 35343356a..578527e23 100644 --- a/compiler/include/zamalang/Conversion/CMakeLists.txt +++ b/compiler/include/zamalang/Conversion/CMakeLists.txt @@ -3,4 +3,5 @@ mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion) add_public_tablegen_target(MLIRConversionPassIncGen) -add_subdirectory(HLFHEToMidLFHE) \ No newline at end of file +add_subdirectory(HLFHEToMidLFHE) +add_subdirectory(MidLFHEToLowLFHE) \ No newline at end of file diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/CMakeLists.txt b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/CMakeLists.txt new file mode 100644 index 000000000..cef0de880 --- /dev/null +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Patterns.td) +mlir_tablegen(Patterns.h.inc -gen-rewriters -name MidLFHE) +add_public_tablegen_target(MidLFHEToLowLFHEPatternsIncGen) diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Pass.h b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Pass.h new file mode 100644 index 000000000..b73d8c572 --- /dev/null +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Pass.h @@ -0,0 +1,14 @@ + +#ifndef ZAMALANG_CONVERSION_MIDLFHETOLOWLFHE_PASS_H_ +#define ZAMALANG_CONVERSION_MIDLFHETOLOWLFHE_PASS_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace zamalang { +/// Create a pass to convert `MidLFHE` dialect to `LowLFHE` dialect. +std::unique_ptr> createConvertMidLFHEToLowLFHEPass(); +} // namespace zamalang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h new file mode 100644 index 000000000..1efbbead8 --- /dev/null +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h @@ -0,0 +1,116 @@ +#ifndef ZAMALANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS_H_ +#define ZAMALANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS_H_ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h" +#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h" + +namespace mlir { +namespace zamalang { + +using LowLFHE::CleartextType; +using LowLFHE::LweCiphertextType; +using LowLFHE::PlaintextType; +using MidLFHE::GLWECipherTextType; + +LweCiphertextType convertTypeGLWEToLWE(mlir::MLIRContext *context, + GLWECipherTextType &glwe) { + return LweCiphertextType::get(context); +} + +PlaintextType convertIntToPlaintextType(mlir::MLIRContext *context, + IntegerType &type) { + return PlaintextType::get(context, type.getWidth()); +} + +CleartextType convertIntToCleartextType(mlir::MLIRContext *context, + IntegerType &type) { + return CleartextType::get(context, type.getWidth()); +} + +template +mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter, + mlir::Location loc, mlir::Value arg0, + mlir::Value arg1, + mlir::OpResult result) { + mlir::SmallVector args{arg0, arg1}; + mlir::SmallVector attrs; + auto glwe = result.getType().cast(); + mlir::SmallVector resTypes{ + convertTypeGLWEToLWE(rewriter.getContext(), glwe)}; + Operator op = rewriter.create(loc, resTypes, args, attrs); + return op.getODSResults(0).front(); +} + +mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter, + mlir::Location loc, mlir::Value arg0, + mlir::Value arg1, + mlir::OpResult result) { + auto integer_type = arg1.getType().cast(); + PlaintextType encoded_type = + convertIntToPlaintextType(rewriter.getContext(), integer_type); + // encode int into plaintext + mlir::Value encoded = + rewriter + .create(loc, encoded_type, arg1) + .plaintext(); + // convert result type + GLWECipherTextType glwe_type = result.getType().cast(); + LweCiphertextType lwe_type = + convertTypeGLWEToLWE(rewriter.getContext(), glwe_type); + // replace op using the encoded plaintext instead of int + auto op = + rewriter.create( + loc, lwe_type, arg0, encoded); + return op.getODSResults(0).front(); +} + +mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter rewriter, + mlir::Location loc, mlir::Value arg0, + mlir::Value arg1, mlir::OpResult result) { + auto arg1_type = arg1.getType().cast(); + auto negated_arg1 = + rewriter + .create( + loc, convertTypeGLWEToLWE(rewriter.getContext(), arg1_type), arg1) + .result(); + return createAddPlainLweCiphertext(rewriter, loc, negated_arg1, arg0, result); +} + +mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter, + mlir::Location loc, mlir::Value arg0, + mlir::Value arg1, + mlir::OpResult result) { + auto integer_type = arg1.getType().cast(); + CleartextType encoded_type = + convertIntToCleartextType(rewriter.getContext(), integer_type); + // encode int into plaintext + mlir::Value encoded = rewriter + .create( + loc, encoded_type, arg1) + .cleartext(); + // convert result type + GLWECipherTextType glwe_type = result.getType().cast(); + LweCiphertextType lwe_type = + convertTypeGLWEToLWE(rewriter.getContext(), glwe_type); + // replace op using the encoded plaintext instead of int + auto op = + rewriter.create( + loc, lwe_type, arg0, encoded); + return op.getODSResults(0).front(); +} + +} // namespace zamalang +} // namespace mlir + +namespace { +#include "zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h.inc" +} + +void populateWithGeneratedMidLFHEToLowLFHE(mlir::RewritePatternSet &patterns) { + populateWithGenerated(patterns); +} + +#endif diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td new file mode 100644 index 000000000..6157dbce1 --- /dev/null +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td @@ -0,0 +1,32 @@ +#ifndef ZAMALANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS +#define ZAMALANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS + +include "mlir/Dialect/StandardOps/IR/Ops.td" +include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.td" +include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td" + +def createAddLWEOp : NativeCodeCall<"mlir::zamalang::createLowLFHEOpFromMidLFHE($_builder, $_loc, $0, $1, $2)">; + +def AddGLWEPattern : Pat< + (AddGLWEOp:$result $arg0, $arg1), + (createAddLWEOp $arg0, $arg1, $result)>; + +def createAddPlainLweOp : NativeCodeCall<"mlir::zamalang::createAddPlainLweCiphertext($_builder, $_loc, $0, $1, $2)">; + +def AddGLWEIntPattern : Pat< + (AddGLWEIntOp:$result $arg0, $arg1), + (createAddPlainLweOp $arg0, $arg1, $result)>; + +def createMulClearLweOp : NativeCodeCall<"mlir::zamalang::createMulClearLweCiphertext($_builder, $_loc, $0, $1, $2)">; + +def MulGLWEIntPattern : Pat< + (MulGLWEIntOp:$result $arg0, $arg1), + (createMulClearLweOp $arg0, $arg1, $result)>; + +def createSubIntLweOp : NativeCodeCall<"mlir::zamalang::createSubIntLweCiphertext($_builder, $_loc, $0, $1, $2)">; + +def SubIntGLWEPattern : Pat< + (SubIntGLWEOp:$result $arg0, $arg1), + (createSubIntLweOp $arg0, $arg1, $result)>; + +#endif diff --git a/compiler/include/zamalang/Conversion/Passes.h b/compiler/include/zamalang/Conversion/Passes.h index 991e4e8fa..e90ed5758 100644 --- a/compiler/include/zamalang/Conversion/Passes.h +++ b/compiler/include/zamalang/Conversion/Passes.h @@ -9,6 +9,7 @@ #include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h" #include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h" #include "zamalang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h" +#include "zamalang/Conversion/MidLFHEToLowLFHE/Pass.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" #define GEN_PASS_CLASSES diff --git a/compiler/include/zamalang/Conversion/Passes.td b/compiler/include/zamalang/Conversion/Passes.td index d0e68e15f..f1f39f203 100644 --- a/compiler/include/zamalang/Conversion/Passes.td +++ b/compiler/include/zamalang/Conversion/Passes.td @@ -17,6 +17,14 @@ def HLFHEToMidLFHE : Pass<"hlfhe-to-midlfhe", "mlir::ModuleOp"> { let dependentDialects = ["mlir::linalg::LinalgDialect"]; } +def MidLFHEToLowLFHE : Pass<"midlfhe-to-lowlfhe", "mlir::ModuleOp"> { + let summary = "Lowers operations from the MidLFHE dialect to LowLFHE"; + let description = [{ Lowers operations from the MidLFHE dialect to LowLFHE }]; + let constructor = "mlir::zamalang::createConvertMidLFHEToLowLFHEPass()"; + let options = []; + let dependentDialects = ["mlir::linalg::LinalgDialect"]; +} + def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> { let summary = "Lowers operations from MLIR lowerable dialects to LLVM"; let constructor = "mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass()"; diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index 75629b843..9de41dd79 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(HLFHEToMidLFHE) +add_subdirectory(MidLFHEToLowLFHE) add_subdirectory(HLFHETensorOpsToLinalg) add_subdirectory(MLIRLowerableDialectsToLLVM) diff --git a/compiler/lib/Conversion/MidLFHEToLowLFHE/CMakeLists.txt b/compiler/lib/Conversion/MidLFHEToLowLFHE/CMakeLists.txt new file mode 100644 index 000000000..bed556503 --- /dev/null +++ b/compiler/lib/Conversion/MidLFHEToLowLFHE/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MidLFHEToLowLFHE + MidLFHEToLowLFHE.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/MidLFHE + + DEPENDS + MidLFHEDialect + LowLFHEDialect + MidLFHEToLowLFHEPatternsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRTransforms + MLIRMath) + +target_link_libraries(MidLFHEToLowLFHE PUBLIC MLIRIR) diff --git a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp new file mode 100644 index 000000000..127d85825 --- /dev/null +++ b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp @@ -0,0 +1,99 @@ +#include + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h" +#include "zamalang/Conversion/Passes.h" +#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h" +#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" +#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" +#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" +#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" + +namespace { +struct MidLFHEToLowLFHEPass + : public MidLFHEToLowLFHEBase { + void runOnOperation() final; +}; +} // namespace + +using mlir::zamalang::LowLFHE::LweCiphertextType; +using mlir::zamalang::MidLFHE::GLWECipherTextType; + +/// MidLFHEToLowLFHETypeConverter is a TypeConverter that transform +/// `MidLFHE.gwle<{_,_,_}{p}>` to LowLFHE.lwe_ciphertext +class MidLFHEToLowLFHETypeConverter : public mlir::TypeConverter { + +public: + MidLFHEToLowLFHETypeConverter() { + addConversion([&](GLWECipherTextType type) { + return mlir::zamalang::convertTypeGLWEToLWE(type.getContext(), type); + }); + addConversion([&](mlir::MemRefType type) { + auto glwe = type.getElementType().dyn_cast_or_null(); + if (glwe == nullptr) { + return (mlir::Type)(type); + } + mlir::Type r = mlir::MemRefType::get( + type.getShape(), + mlir::zamalang::convertTypeGLWEToLWE(glwe.getContext(), glwe), + type.getAffineMaps(), type.getMemorySpace()); + return r; + }); + // [workaround] need these converters to consider those types legal + addConversion([&](mlir::IntegerType type) { return type; }); + addConversion( + [&](mlir::zamalang::LowLFHE::LweCiphertextType type) { return type; }); + } +}; + +void MidLFHEToLowLFHEPass::runOnOperation() { + auto op = this->getOperation(); + + mlir::ConversionTarget target(getContext()); + MidLFHEToLowLFHETypeConverter converter; + + // Mark ops from the target dialect as legal operations + target.addLegalDialect(); + + // Make sure that no ops from `MidLFHE` remain after the lowering + target.addIllegalDialect(); + + // Make sure that no ops `linalg.generic` that have illegal types + target.addDynamicallyLegalOp( + [&](mlir::linalg::GenericOp op) { + return (converter.isLegal(op.getOperandTypes()) && + converter.isLegal(op.getResultTypes()) && + converter.isLegal(op->getRegion(0).front().getArgumentTypes())); + }); + + // Make sure that func has legal signature + target.addDynamicallyLegalOp([](mlir::FuncOp funcOp) { + MidLFHEToLowLFHETypeConverter converter; + return converter.isSignatureLegal(funcOp.getType()) && + converter.isLegal(&funcOp.getBody()); + }); + // Add all patterns required to lower all ops from `MidLFHE` to + // `LowLFHE` + mlir::OwningRewritePatternList patterns(&getContext()); + + populateWithGeneratedMidLFHEToLowLFHE(patterns); + patterns + .add>( + &getContext(), converter); + mlir::populateFuncOpTypeConversionPattern(patterns, converter); + + // Apply conversion + if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { + this->signalPassFailure(); + } +} + +namespace mlir { +namespace zamalang { +std::unique_ptr> createConvertMidLFHEToLowLFHEPass() { + return std::make_unique(); +} +} // namespace zamalang +} // namespace mlir diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp index 43bdc7e30..598d9ee65 100644 --- a/compiler/lib/Support/CompilerTools.cpp +++ b/compiler/lib/Support/CompilerTools.cpp @@ -37,6 +37,8 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirLLVMDialect( pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass); addFilteredPassToPassManager( pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass); + addFilteredPassToPassManager( + pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass); addFilteredPassToPassManager( pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(), enablePass); diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir new file mode 100644 index 000000000..45f1c8d3e --- /dev/null +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir @@ -0,0 +1,10 @@ +// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext, %arg1: !LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext +func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { + // CHECK-NEXT: %[[V1:.*]] = "LowLFHE.add_lwe_ciphertexts"(%arg0, %arg1) : (!LowLFHE.lwe_ciphertext, !LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext + // CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext + + %0 = "MidLFHE.add_glwe"(%arg0, %arg1): (!MidLFHE.glwe<{1024,12,64}{7}>, !MidLFHE.glwe<{1024,12,64}{7}>) -> (!MidLFHE.glwe<{1024,12,64}{7}>) + return %0: !MidLFHE.glwe<{1024,12,64}{7}> +} diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir new file mode 100644 index 000000000..bbb921749 --- /dev/null +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir @@ -0,0 +1,22 @@ +// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext +func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { + // CHECK-NEXT: %[[V1:.*]] = constant 1 : i8 + // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.encode_int"(%[[V1]]) : (i8) -> !LowLFHE.plaintext<8> + // CHECK-NEXT: %[[V3:.*]] = "LowLFHE.add_plaintext_lwe_ciphertext"(%arg0, %[[V2]]) : (!LowLFHE.lwe_ciphertext, !LowLFHE.plaintext<8>) -> !LowLFHE.lwe_ciphertext + // CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext + %0 = constant 1 : i8 + %1 = "MidLFHE.add_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,1,64}{7}>, i8) -> (!MidLFHE.glwe<{1024,1,64}{7}>) + return %1: !MidLFHE.glwe<{1024,1,64}{7}> +} + + +// CHECK-LABEL: func @add_glwe_int(%arg0: !LowLFHE.lwe_ciphertext, %arg1: i5) -> !LowLFHE.lwe_ciphertext +func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE.glwe<{1024,1,64}{4}> { + // CHECK-NEXT: %[[V1:.*]] = "LowLFHE.encode_int"(%arg1) : (i5) -> !LowLFHE.plaintext<5> + // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.add_plaintext_lwe_ciphertext"(%arg0, %[[V1]]) : (!LowLFHE.lwe_ciphertext, !LowLFHE.plaintext<5>) -> !LowLFHE.lwe_ciphertext + // CHECK-NEXT: return %[[V2]] : !LowLFHE.lwe_ciphertext + %1 = "MidLFHE.add_glwe_int"(%arg0, %arg1): (!MidLFHE.glwe<{1024,1,64}{4}>, i5) -> (!MidLFHE.glwe<{1024,1,64}{4}>) + return %1: !MidLFHE.glwe<{1024,1,64}{4}> +} \ No newline at end of file diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir new file mode 100644 index 000000000..090316780 --- /dev/null +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir @@ -0,0 +1,22 @@ +// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext +func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { + // CHECK-NEXT: %[[V1:.*]] = constant 1 : i8 + // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.int_to_cleartext"(%[[V1]]) : (i8) -> !LowLFHE.cleartext<8> + // CHECK-NEXT: %[[V3:.*]] = "LowLFHE.mul_cleartext_lwe_ciphertext"(%arg0, %[[V2]]) : (!LowLFHE.lwe_ciphertext, !LowLFHE.cleartext<8>) -> !LowLFHE.lwe_ciphertext + // CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext + %0 = constant 1 : i8 + %1 = "MidLFHE.mul_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,1,64}{7}>, i8) -> (!MidLFHE.glwe<{1024,1,64}{7}>) + return %1: !MidLFHE.glwe<{1024,1,64}{7}> +} + + +// CHECK-LABEL: func @mul_glwe_int(%arg0: !LowLFHE.lwe_ciphertext, %arg1: i5) -> !LowLFHE.lwe_ciphertext +func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE.glwe<{1024,1,64}{4}> { + // CHECK-NEXT: %[[V1:.*]] = "LowLFHE.int_to_cleartext"(%arg1) : (i5) -> !LowLFHE.cleartext<5> + // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.mul_cleartext_lwe_ciphertext"(%arg0, %[[V1]]) : (!LowLFHE.lwe_ciphertext, !LowLFHE.cleartext<5>) -> !LowLFHE.lwe_ciphertext + // CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext + %1 = "MidLFHE.mul_glwe_int"(%arg0, %arg1): (!MidLFHE.glwe<{1024,1,64}{4}>, i5) -> (!MidLFHE.glwe<{1024,1,64}{4}>) + return %1: !MidLFHE.glwe<{1024,1,64}{4}> +} \ No newline at end of file diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir new file mode 100644 index 000000000..c5936cfe9 --- /dev/null +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir @@ -0,0 +1,23 @@ +// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext +func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { + // CHECK-NEXT: %[[V1:.*]] = constant 1 : i8 + // CHECK-NEXT: %[[NEG:.*]] = "LowLFHE.negate_lwe_ciphertext"(%arg0) : (!LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext + // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.encode_int"(%[[V1]]) : (i8) -> !LowLFHE.plaintext<8> + // CHECK-NEXT: %[[V3:.*]] = "LowLFHE.add_plaintext_lwe_ciphertext"(%[[NEG]], %[[V2]]) : (!LowLFHE.lwe_ciphertext, !LowLFHE.plaintext<8>) -> !LowLFHE.lwe_ciphertext + // CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext + %0 = constant 1 : i8 + %1 = "MidLFHE.sub_int_glwe"(%0, %arg0): (i8, !MidLFHE.glwe<{1024,1,64}{7}>) -> (!MidLFHE.glwe<{1024,1,64}{7}>) + return %1: !MidLFHE.glwe<{1024,1,64}{7}> +} + +// CHECK-LABEL: func @sub_int_glwe(%arg0: !LowLFHE.lwe_ciphertext, %arg1: i5) -> !LowLFHE.lwe_ciphertext +func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE.glwe<{1024,1,64}{4}> { + // CHECK-NEXT: %[[NEG:.*]] = "LowLFHE.negate_lwe_ciphertext"(%arg0) : (!LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext + // CHECK-NEXT: %[[V1:.*]] = "LowLFHE.encode_int"(%arg1) : (i5) -> !LowLFHE.plaintext<5> + // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.add_plaintext_lwe_ciphertext"(%[[NEG]], %[[V1]]) : (!LowLFHE.lwe_ciphertext, !LowLFHE.plaintext<5>) -> !LowLFHE.lwe_ciphertext + // CHECK-NEXT: return %[[V2]] : !LowLFHE.lwe_ciphertext + %1 = "MidLFHE.sub_int_glwe"(%arg1, %arg0): (i5, !MidLFHE.glwe<{1024,1,64}{4}>) -> (!MidLFHE.glwe<{1024,1,64}{4}>) + return %1: !MidLFHE.glwe<{1024,1,64}{4}> +} \ No newline at end of file