mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): lower neg_eint from HLFHE to LowLFHE
This commit is contained in:
@@ -47,6 +47,20 @@ mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
template <class Operator>
|
||||
mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::OpResult result) {
|
||||
mlir::SmallVector<mlir::Value, 1> args{arg0};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
|
||||
auto eint =
|
||||
result.getType().cast<mlir::zamalang::HLFHE::EncryptedIntegerType>();
|
||||
mlir::SmallVector<mlir::Type, 1> resTypes{
|
||||
convertTypeEncryptedIntegerToGLWE(rewriter.getContext(), eint)};
|
||||
Operator op = rewriter.create<Operator>(loc, resTypes, args, attrs);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
mlir::Value
|
||||
createApplyLookupTableGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
|
||||
@@ -28,6 +28,12 @@ def SubIntEintPattern : Pat<
|
||||
(SubIntEintOp:$result $arg0, $arg1),
|
||||
(createSubIntGLWEOp $arg0, $arg1, $result)>;
|
||||
|
||||
def createNegGLWEOp : NativeCodeCall<"mlir::zamalang::createGLWEOpFromHLFHE<mlir::zamalang::MidLFHE::NegGLWEOp>($_builder, $_loc, $0, $1)">;
|
||||
|
||||
def NegEintPattern : Pat<
|
||||
(NegEintOp:$result $arg0),
|
||||
(createNegGLWEOp $arg0, $result)>;
|
||||
|
||||
def createMulGLWEIntOp : NativeCodeCall<"mlir::zamalang::createGLWEOpFromHLFHE<mlir::zamalang::MidLFHE::MulGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def MulEintIntPattern : Pat<
|
||||
|
||||
@@ -146,6 +146,16 @@ mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter &rewriter,
|
||||
result, arg1_type);
|
||||
}
|
||||
|
||||
mlir::Value createNegLweCiphertext(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::OpResult result) {
|
||||
auto arg0_type = arg0.getType();
|
||||
auto negated =
|
||||
rewriter.create<mlir::zamalang::LowLFHE::NegateLweCiphertextOp>(
|
||||
loc, convertTypeToLWE(rewriter.getContext(), arg0_type), arg0);
|
||||
return negated.getODSResults(0).front();
|
||||
}
|
||||
|
||||
mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1,
|
||||
|
||||
@@ -35,6 +35,12 @@ def SubIntGLWEPattern : Pat<
|
||||
(SubIntGLWEOp:$result $arg0, $arg1),
|
||||
(createSubIntLweOp $arg0, $arg1, $result)>;
|
||||
|
||||
def createNegLweOp : NativeCodeCall<"mlir::zamalang::createNegLweCiphertext($_builder, $_loc, $0, $1)">;
|
||||
|
||||
def NegGLWEPattern : Pat<
|
||||
(NegGLWEOp:$result $arg0),
|
||||
(createNegLweOp $arg0, $result)>;
|
||||
|
||||
def createPBS : NativeCodeCall<"mlir::zamalang::createPBS($_builder, $_loc, $0, $1, $2, $3, $4, $5, $6, $7, $8, $9)">;
|
||||
|
||||
def ApplyLookupTableGLWEPattern : Pat<
|
||||
|
||||
@@ -256,6 +256,8 @@ void populateWithMidLFHEOpTypeConversionPatterns(
|
||||
mlir::zamalang::MidLFHE::AddGLWEOp>(patterns, target, typeConverter);
|
||||
populateWithMidLFHEOpTypeConversionPattern<
|
||||
mlir::zamalang::MidLFHE::SubIntGLWEOp>(patterns, target, typeConverter);
|
||||
populateWithMidLFHEOpTypeConversionPattern<
|
||||
mlir::zamalang::MidLFHE::NegGLWEOp>(patterns, target, typeConverter);
|
||||
populateWithMidLFHEOpTypeConversionPattern<
|
||||
mlir::zamalang::MidLFHE::MulGLWEIntOp>(patterns, target, typeConverter);
|
||||
populateWithMidLFHEApplyLookupTableParametrizationPattern(
|
||||
|
||||
10
compiler/tests/Conversion/HLFHEToMidLFHE/neg_eint.mlir
Normal file
10
compiler/tests/Conversion/HLFHEToMidLFHE/neg_eint.mlir
Normal file
@@ -0,0 +1,10 @@
|
||||
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @neg_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
|
||||
func @neg_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "MidLFHE.neg_glwe"(%arg0) : (!MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: return %[[V1]] : !MidLFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
%1 = "HLFHE.neg_eint"(%arg0): (!HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
return %1: !HLFHE.eint<7>
|
||||
}
|
||||
9
compiler/tests/Conversion/MidLFHEToLowLFHE/neg_glwe.mlir
Normal file
9
compiler/tests/Conversion/MidLFHEToLowLFHE/neg_glwe.mlir
Normal file
@@ -0,0 +1,9 @@
|
||||
// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @neg_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
func @neg_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>) -> !MidLFHE.glwe<{1024,1,64}{4}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.negate_lwe_ciphertext"(%arg0) : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4>
|
||||
%1 = "MidLFHE.neg_glwe"(%arg0): (!MidLFHE.glwe<{1024,1,64}{4}>) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
|
||||
return %1: !MidLFHE.glwe<{1024,1,64}{4}>
|
||||
}
|
||||
Reference in New Issue
Block a user