diff --git a/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.h b/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.h index 2a4230e20..2c9cbf319 100644 --- a/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.h +++ b/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.h @@ -47,6 +47,20 @@ mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter, return op.getODSResults(0).front(); } +template +mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter, + mlir::Location loc, mlir::Value arg0, + mlir::OpResult result) { + mlir::SmallVector args{arg0}; + mlir::SmallVector attrs; + auto eint = + result.getType().cast(); + mlir::SmallVector resTypes{ + convertTypeEncryptedIntegerToGLWE(rewriter.getContext(), eint)}; + Operator op = rewriter.create(loc, resTypes, args, attrs); + return op.getODSResults(0).front(); +} + mlir::Value createApplyLookupTableGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Value arg0, diff --git a/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.td b/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.td index ab5ce47bc..5e9143380 100644 --- a/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.td +++ b/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.td @@ -28,6 +28,12 @@ def SubIntEintPattern : Pat< (SubIntEintOp:$result $arg0, $arg1), (createSubIntGLWEOp $arg0, $arg1, $result)>; +def createNegGLWEOp : NativeCodeCall<"mlir::zamalang::createGLWEOpFromHLFHE($_builder, $_loc, $0, $1)">; + +def NegEintPattern : Pat< + (NegEintOp:$result $arg0), + (createNegGLWEOp $arg0, $result)>; + def createMulGLWEIntOp : NativeCodeCall<"mlir::zamalang::createGLWEOpFromHLFHE($_builder, $_loc, $0, $1, $2)">; def MulEintIntPattern : Pat< diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h index b5daad82d..749ad29d1 100644 --- a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h @@ -146,6 +146,16 @@ mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter &rewriter, result, arg1_type); } +mlir::Value createNegLweCiphertext(mlir::PatternRewriter &rewriter, + mlir::Location loc, mlir::Value arg0, + mlir::OpResult result) { + auto arg0_type = arg0.getType(); + auto negated = + rewriter.create( + loc, convertTypeToLWE(rewriter.getContext(), arg0_type), arg0); + return negated.getODSResults(0).front(); +} + mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Value arg0, mlir::Value arg1, diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td index e57b209a3..8b3e06e7f 100644 --- a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td @@ -35,6 +35,12 @@ def SubIntGLWEPattern : Pat< (SubIntGLWEOp:$result $arg0, $arg1), (createSubIntLweOp $arg0, $arg1, $result)>; +def createNegLweOp : NativeCodeCall<"mlir::zamalang::createNegLweCiphertext($_builder, $_loc, $0, $1)">; + +def NegGLWEPattern : Pat< + (NegGLWEOp:$result $arg0), + (createNegLweOp $arg0, $result)>; + def createPBS : NativeCodeCall<"mlir::zamalang::createPBS($_builder, $_loc, $0, $1, $2, $3, $4, $5, $6, $7, $8, $9)">; def ApplyLookupTableGLWEPattern : Pat< diff --git a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp index 6e80c4196..0cb432010 100644 --- a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp @@ -256,6 +256,8 @@ void populateWithMidLFHEOpTypeConversionPatterns( mlir::zamalang::MidLFHE::AddGLWEOp>(patterns, target, typeConverter); populateWithMidLFHEOpTypeConversionPattern< mlir::zamalang::MidLFHE::SubIntGLWEOp>(patterns, target, typeConverter); + populateWithMidLFHEOpTypeConversionPattern< + mlir::zamalang::MidLFHE::NegGLWEOp>(patterns, target, typeConverter); populateWithMidLFHEOpTypeConversionPattern< mlir::zamalang::MidLFHE::MulGLWEIntOp>(patterns, target, typeConverter); populateWithMidLFHEApplyLookupTableParametrizationPattern( diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/neg_eint.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/neg_eint.mlir new file mode 100644 index 000000000..134d3fd8f --- /dev/null +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/neg_eint.mlir @@ -0,0 +1,10 @@ +// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s + +// CHECK-LABEL: func @neg_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> +func @neg_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { + // CHECK-NEXT: %[[V1:.*]] = "MidLFHE.neg_glwe"(%arg0) : (!MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> + // CHECK-NEXT: return %[[V1]] : !MidLFHE.glwe<{_,_,_}{7}> + + %1 = "HLFHE.neg_eint"(%arg0): (!HLFHE.eint<7>) -> (!HLFHE.eint<7>) + return %1: !HLFHE.eint<7> +} diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/neg_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/neg_glwe.mlir new file mode 100644 index 000000000..76e3dfb3b --- /dev/null +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/neg_glwe.mlir @@ -0,0 +1,9 @@ +// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @neg_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> +func @neg_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>) -> !MidLFHE.glwe<{1024,1,64}{4}> { + // CHECK-NEXT: %[[V1:.*]] = "LowLFHE.negate_lwe_ciphertext"(%arg0) : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4> + %1 = "MidLFHE.neg_glwe"(%arg0): (!MidLFHE.glwe<{1024,1,64}{4}>) -> (!MidLFHE.glwe<{1024,1,64}{4}>) + return %1: !MidLFHE.glwe<{1024,1,64}{4}> +}