From 3f230957cb1ded82f73fbbe1582b121c248d5eeb Mon Sep 17 00:00:00 2001
From: Antoniu Pop
Date: Fri, 24 Mar 2023 10:44:38 +0000
Subject: [PATCH] feat(compiler): add batched operations for all levelled ops.

---
 .../Dialect/Concrete/IR/ConcreteOps.td        | 106 +++++++++++
 .../concretelang/Dialect/TFHE/IR/TFHEOps.td   | 169 +++++++++++++++++-
 .../include/concretelang/Runtime/wrappers.h   |  48 +++++
 .../ConcreteToCAPI/ConcreteToCAPI.cpp         |  52 ++++++
 .../TFHEToConcrete/TFHEToConcrete.cpp         |  24 ++-
 .../BufferizableOpInterfaceImpl.cpp           |  28 +++
 .../compiler/lib/Runtime/wrappers.cpp         |  96 ++++++++++
 7 files changed, 517 insertions(+), 6 deletions(-)

diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td
index e632ef59e..56e9f9201 100644
--- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td
+++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td
@@ -26,6 +26,7 @@ def Concrete_CrtLutsTensor : 2DTensorOf<[I64]>;
 def Concrete_CrtPlaintextTensor : 1DTensorOf<[I64]>;
 def Concrete_LweCRTTensor : 2DTensorOf<[I64]>;
 def Concrete_BatchLweTensor : 2DTensorOf<[I64]>;
+def Concrete_BatchPlaintextTensor : 1DTensorOf<[I64]>;

 def Concrete_LweBuffer : MemRefRankOf<[I64], [1]>;
 def Concrete_LutBuffer : MemRefRankOf<[I64], [1]>;
@@ -33,6 +34,7 @@ def Concrete_CrtLutsBuffer : MemRefRankOf<[I64], [2]>;
 def Concrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>;
 def Concrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>;
 def Concrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>;
+def Concrete_BatchPlaintextBuffer : MemRefRankOf<[I64], [1]>;

 class Concrete_Op<string mnemonic, list<Trait> traits = []> :
     Op<Concrete_Dialect, mnemonic, traits>;
@@ -58,6 +60,26 @@ def Concrete_AddLweBufferOp : Concrete_Op<"add_lwe_buffer"> {
   );
 }

+def Concrete_BatchedAddLweTensorOp : Concrete_Op<"batched_add_lwe_tensor", [Pure]> {
+  let summary = "Batched version of AddLweTensorOp, which performs the same operation on multiple elements";
+
+  let arguments = (ins
+    Concrete_BatchLweTensor:$lhs,
+    Concrete_BatchLweTensor:$rhs
+  );
+  let results = (outs Concrete_BatchLweTensor:$result);
+}
+
+def Concrete_BatchedAddLweBufferOp : Concrete_Op<"batched_add_lwe_buffer"> {
+  let summary = "Batched version of AddLweBufferOp, which performs the same operation on multiple elements";
+
+  let arguments = (ins
+    Concrete_BatchLweBuffer:$result,
+    Concrete_BatchLweBuffer:$lhs,
+    Concrete_BatchLweBuffer:$rhs
+  );
+}
+
 def Concrete_AddPlaintextLweTensorOp : Concrete_Op<"add_plaintext_lwe_tensor", [Pure]> {
   let summary = "Returns the sum of a clear integer and an lwe ciphertext";

@@ -75,6 +97,40 @@ def Concrete_AddPlaintextLweBufferOp : Concrete_Op<"add_plaintext_lwe_buffer"> {
   );
 }

+def Concrete_BatchedAddPlaintextLweTensorOp : Concrete_Op<"batched_add_plaintext_lwe_tensor", [Pure]> {
+  let summary = "Batched version of AddPlaintextLweTensorOp, which performs the same operation on multiple elements";
+
+  let arguments = (ins Concrete_BatchLweTensor:$lhs, Concrete_BatchPlaintextTensor:$rhs);
+  let results = (outs Concrete_BatchLweTensor:$result);
+}
+
+def Concrete_BatchedAddPlaintextLweBufferOp : Concrete_Op<"batched_add_plaintext_lwe_buffer"> {
+  let summary = "Batched version of AddPlaintextLweBufferOp, which performs the same operation on multiple elements";
+
+  let arguments = (ins
+    Concrete_BatchLweBuffer:$result,
+    Concrete_BatchLweBuffer:$lhs,
+    Concrete_BatchPlaintextBuffer:$rhs
+  );
+}
+
+def Concrete_BatchedAddPlaintextCstLweTensorOp : Concrete_Op<"batched_add_plaintext_cst_lwe_tensor", [Pure]> {
+  let summary = "Batched version of AddPlaintextLweTensorOp, which adds the same constant plaintext to each ciphertext of the batch";
+
+  let arguments = (ins Concrete_BatchLweTensor:$lhs, I64:$rhs);
+  let results = (outs Concrete_BatchLweTensor:$result);
+}
+
+def Concrete_BatchedAddPlaintextCstLweBufferOp : Concrete_Op<"batched_add_plaintext_cst_lwe_buffer"> {
+  let summary = "Batched version of AddPlaintextLweBufferOp, which adds the same constant plaintext to each ciphertext of the batch";
+
+  let arguments = (ins
+    Concrete_BatchLweBuffer:$result,
+    Concrete_BatchLweBuffer:$lhs,
+    I64:$rhs
+  );
+}
+
 def Concrete_MulCleartextLweTensorOp : Concrete_Op<"mul_cleartext_lwe_tensor", [Pure]> {
   let summary = "Returns the product of a clear integer and a lwe ciphertext";

@@ -92,6 +148,40 @@ def Concrete_MulCleartextLweBufferOp : Concrete_Op<"mul_cleartext_lwe_buffer"> {
   );
 }

+def Concrete_BatchedMulCleartextLweTensorOp : Concrete_Op<"batched_mul_cleartext_lwe_tensor", [Pure]> {
+  let summary = "Batched version of MulCleartextLweTensorOp, which performs the same operation on multiple elements";
+
+  let arguments = (ins Concrete_BatchLweTensor:$lhs, Concrete_BatchPlaintextTensor:$rhs);
+  let results = (outs Concrete_BatchLweTensor:$result);
+}
+
+def Concrete_BatchedMulCleartextLweBufferOp : Concrete_Op<"batched_mul_cleartext_lwe_buffer"> {
+  let summary = "Batched version of MulCleartextLweBufferOp, which performs the same operation on multiple elements";
+
+  let arguments = (ins
+    Concrete_BatchLweBuffer:$result,
+    Concrete_BatchLweBuffer:$lhs,
+    Concrete_BatchPlaintextBuffer:$rhs
+  );
+}
+
+def Concrete_BatchedMulCleartextCstLweTensorOp : Concrete_Op<"batched_mul_cleartext_cst_lwe_tensor", [Pure]> {
+  let summary = "Batched version of MulCleartextLweTensorOp, which multiplies each ciphertext of the batch by the same constant cleartext";
+
+  let arguments = (ins Concrete_BatchLweTensor:$lhs, I64:$rhs);
+  let results = (outs Concrete_BatchLweTensor:$result);
+}
+
+def Concrete_BatchedMulCleartextCstLweBufferOp : Concrete_Op<"batched_mul_cleartext_cst_lwe_buffer"> {
+  let summary = "Batched version of MulCleartextLweBufferOp, which multiplies each ciphertext of the batch by the same constant cleartext";
+
+  let arguments = (ins
+    Concrete_BatchLweBuffer:$result,
+    Concrete_BatchLweBuffer:$lhs,
+    I64:$rhs
+  );
+}
+
 def Concrete_NegateLweTensorOp : Concrete_Op<"negate_lwe_tensor", [Pure]> {
   let summary = "Negates a lwe ciphertext";

@@ -108,6 +198,22 @@ def Concrete_NegateLweBufferOp : Concrete_Op<"negate_lwe_buffer"> {
   );
 }

+def Concrete_BatchedNegateLweTensorOp : Concrete_Op<"batched_negate_lwe_tensor", [Pure]> {
+  let summary = "Batched version of NegateLweTensorOp, which performs the same operation on multiple elements";
+
+  let arguments = (ins Concrete_BatchLweTensor:$ciphertext);
+  let results = (outs Concrete_BatchLweTensor:$result);
+}
+
+def Concrete_BatchedNegateLweBufferOp : Concrete_Op<"batched_negate_lwe_buffer"> {
+  let summary = "Batched version of NegateLweBufferOp, which performs the same operation on multiple elements";
+
+  let arguments = (ins
+    Concrete_BatchLweBuffer:$result,
+    Concrete_BatchLweBuffer:$ciphertext
+  );
+}
+
 def Concrete_EncodeExpandLutForBootstrapTensorOp : Concrete_Op<"encode_expand_lut_for_bootstrap_tensor", [Pure]> {
   let summary = "Encode and expand a lookup table so that it can be used for a bootstrap";
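[Editor's note] The new Concrete ops above operate on 2-D i64 tensors/memrefs holding one LWE ciphertext per row. For reference, their element-wise semantics can be sketched in plain C++. This is an illustrative model, not the runtime code from lib/Runtime/wrappers.cpp (which is part of this patch but not shown in this excerpt); the row-major [batch][lweSize] layout and the function names are assumptions.

    #include <cstddef>
    #include <cstdint>

    // Illustrative model of batched_add_lwe_*: coefficient-wise, wrapping
    // u64 addition of two batches of LWE ciphertexts, one row per ciphertext.
    void model_batched_add_lwe(uint64_t *out, const uint64_t *lhs,
                               const uint64_t *rhs, size_t batch,
                               size_t lweSize) {
      for (size_t b = 0; b < batch; ++b)
        for (size_t i = 0; i < lweSize; ++i)
          out[b * lweSize + i] = lhs[b * lweSize + i] + rhs[b * lweSize + i];
    }

    // Illustrative model of batched_add_plaintext_cst_lwe_*: adding a
    // plaintext to an LWE ciphertext only shifts its body, i.e. the last
    // coefficient of the row; the mask coefficients are copied unchanged.
    void model_batched_add_plaintext_cst_lwe(uint64_t *out, const uint64_t *ct,
                                             uint64_t plaintext, size_t batch,
                                             size_t lweSize) {
      for (size_t b = 0; b < batch; ++b) {
        for (size_t i = 0; i < lweSize; ++i)
          out[b * lweSize + i] = ct[b * lweSize + i];
        out[(b + 1) * lweSize - 1] += plaintext;
      }
    }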
diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td
index ff601fbb2..303549771 100644
--- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td
+++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td
@@ -77,22 +77,103 @@ def TFHE_ZeroTensorGLWEOp : TFHE_Op<"zero_tensor", [Pure]> {
   let results = (outs Type<And<[TensorOf<[TFHE_GLWECipherTextType]>.predicate, HasStaticShapePred]>>:$tensor);
 }

-def TFHE_AddGLWEIntOp : TFHE_Op<"add_glwe_int", [Pure]> {
+def TFHE_ABatchedAddGLWEIntOp : TFHE_Op<"batched_add_glwe_int", [Pure]> {
+  let summary = "Batched version of AddGLWEIntOp";
+
+  let arguments = (ins
+    1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
+    1DTensorOf<[AnyInteger]> : $plaintexts
+  );
+
+  let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
+}
+
+def TFHE_ABatchedAddGLWEIntCstOp : TFHE_Op<"batched_add_glwe_int_cst", [Pure]> {
+  let summary = "Batched version of AddGLWEIntOp, which adds the same clear integer to each ciphertext of the batch";
+
+  let arguments = (ins
+    1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
+    AnyInteger : $plaintext
+  );
+
+  let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
+}
+
+def TFHE_AddGLWEIntOp : TFHE_Op<"add_glwe_int", [Pure, BatchableOpInterface]> {
   let summary = "Returns the sum of a clear integer and a lwe ciphertext";

   let arguments = (ins TFHE_GLWECipherTextType : $a, AnyInteger : $b);
   let results = (outs TFHE_GLWECipherTextType);

   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    ::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands() {
+      return getOperation()->getOpOperands().take_front(2);
+    }
+
+    ::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
+                                         ::mlir::ValueRange batchedOperands,
+                                         ::mlir::ValueRange hoistedNonBatchableOperands) {
+      ::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
+          batchedOperands[0].getType().cast<::mlir::RankedTensorType>().getShape(),
+          getResult().getType());
+
+      ::llvm::SmallVector<::mlir::Value> operands(batchedOperands);
+      if (hoistedNonBatchableOperands.empty()) {
+        return builder.create<ABatchedAddGLWEIntOp>(
+            mlir::TypeRange{resType},
+            operands,
+            getOperation()->getAttrs());
+      } else {
+        operands.append(hoistedNonBatchableOperands.begin(),
+                        hoistedNonBatchableOperands.end());
+        return builder.create<ABatchedAddGLWEIntCstOp>(
+            mlir::TypeRange{resType},
+            operands,
+            getOperation()->getAttrs());
+      }
+    }
+  }];
 }

-def TFHE_AddGLWEOp : TFHE_Op<"add_glwe", [Pure]> {
+def TFHE_ABatchedAddGLWEOp : TFHE_Op<"batched_add_glwe", [Pure]> {
+  let summary = "Batched version of AddGLWEOp";
+
+  let arguments = (ins
+    1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts_a,
+    1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts_b
+  );
+
+  let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
+}
+
+def TFHE_AddGLWEOp : TFHE_Op<"add_glwe", [Pure, BatchableOpInterface]> {
   let summary = "Returns the sum of 2 lwe ciphertexts";

   let arguments = (ins TFHE_GLWECipherTextType : $a, TFHE_GLWECipherTextType : $b);
   let results = (outs TFHE_GLWECipherTextType);

   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    ::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands() {
+      return getOperation()->getOpOperands().take_front(2);
+    }
+
+    ::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
+                                         ::mlir::ValueRange batchedOperands,
+                                         ::mlir::ValueRange hoistedNonBatchableOperands) {
+      ::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
+          batchedOperands[0].getType().cast<::mlir::RankedTensorType>().getShape(),
+          getResult().getType());
+
+      return builder.create<ABatchedAddGLWEOp>(
+          mlir::TypeRange{resType},
+          mlir::ValueRange{batchedOperands},
+          getOperation()->getAttrs());
+    }
+  }];
 }

 def TFHE_SubGLWEIntOp : TFHE_Op<"sub_int_glwe", [Pure]> {
@@ -104,22 +185,102 @@ def TFHE_SubGLWEIntOp : TFHE_Op<"sub_int_glwe", [Pure]> {
   let hasVerifier = 1;
 }

-def TFHE_NegGLWEOp : TFHE_Op<"neg_glwe", [Pure]> {
+def TFHE_BatchedNegGLWEOp : TFHE_Op<"batched_neg_glwe", [Pure]> {
+  let summary = "Batched version of NegGLWEOp";
+
+  let arguments = (ins
+    1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts
+  );
+
+  let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
+}
+
+def TFHE_NegGLWEOp : TFHE_Op<"neg_glwe", [Pure, BatchableOpInterface]> {
   let summary = "Negates a glwe ciphertext";

   let arguments = (ins TFHE_GLWECipherTextType : $a);
   let results = (outs TFHE_GLWECipherTextType);

   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    ::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands() {
+      return getOperation()->getOpOperands().take_front();
+    }
+
+    ::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
+                                         ::mlir::ValueRange batchedOperands,
+                                         ::mlir::ValueRange hoistedNonBatchableOperands) {
+      ::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
+          batchedOperands[0].getType().cast<::mlir::RankedTensorType>().getShape(),
+          getResult().getType());
+
+      return builder.create<BatchedNegGLWEOp>(
+          mlir::TypeRange{resType},
+          mlir::ValueRange{batchedOperands},
+          getOperation()->getAttrs());
+    }
+  }];
 }

-def TFHE_MulGLWEIntOp : TFHE_Op<"mul_glwe_int", [Pure]> {
+def TFHE_BatchedMulGLWEIntOp : TFHE_Op<"batched_mul_glwe_int", [Pure]> {
+  let summary = "Batched version of MulGLWEIntOp";
+
+  let arguments = (ins
+    1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
+    1DTensorOf<[AnyInteger]> : $cleartexts
+  );
+
+  let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
+}
+
+def TFHE_BatchedMulGLWEIntCstOp : TFHE_Op<"batched_mul_glwe_int_cst", [Pure]> {
+  let summary = "Batched version of MulGLWEIntOp, which multiplies each ciphertext of the batch by the same clear integer";
+
+  let arguments = (ins
+    1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
+    AnyInteger : $cleartext
+  );
+
+  let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
+}
+
+def TFHE_MulGLWEIntOp : TFHE_Op<"mul_glwe_int", [Pure, BatchableOpInterface]> {
   let summary = "Returns the product of a clear integer and a lwe ciphertext";

   let arguments = (ins TFHE_GLWECipherTextType : $a, AnyInteger : $b);
   let results = (outs TFHE_GLWECipherTextType);

   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    ::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands() {
+      return getOperation()->getOpOperands().take_front(2);
+    }
+
+    ::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
+                                         ::mlir::ValueRange batchedOperands,
+                                         ::mlir::ValueRange hoistedNonBatchableOperands) {
+      ::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
+          batchedOperands[0].getType().cast<::mlir::RankedTensorType>().getShape(),
+          getResult().getType());
+
+      ::llvm::SmallVector<::mlir::Value> operands(batchedOperands);
+      if (hoistedNonBatchableOperands.empty()) {
+        return builder.create<BatchedMulGLWEIntOp>(
+            mlir::TypeRange{resType},
+            operands,
+            getOperation()->getAttrs());
+      } else {
+        operands.append(hoistedNonBatchableOperands.begin(),
+                        hoistedNonBatchableOperands.end());
+        return builder.create<BatchedMulGLWEIntCstOp>(
+            mlir::TypeRange{resType},
+            operands,
+            getOperation()->getAttrs());
+      }
+    }
+  }];
 }

 def TFHE_BatchedKeySwitchGLWEOp : TFHE_Op<"batched_keyswitch_glwe", [Pure]> {
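[Editor's note] The BatchableOpInterface hooks added above are meant to be driven by a batching transform that is not part of this diff. The sketch below illustrates the intended contract under stated assumptions: the interface is usable as a handle type (assumed spelling mlir::concretelang::TFHE::BatchableOpInterface) and lookupBatchedValue is a made-up helper. Operands reported by getBatchableOperands that vary across loop iterations are aggregated into 1-D tensors, loop-invariant ones are hoisted as scalars, and createBatchedOperation then emits either the plain batched op or its *_cst variant, as in the AddGLWEIntOp and MulGLWEIntOp implementations above.

    #include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
    #include "llvm/ADT/SmallVector.h"
    #include "mlir/IR/ImplicitLocOpBuilder.h"

    // Assumed helper (not in this patch): returns the 1-D tensor aggregating
    // `v` across loop iterations, or a null Value if `v` is loop-invariant.
    mlir::Value lookupBatchedValue(mlir::Value v);

    // Hypothetical driver showing how a batching pass could use the hooks.
    mlir::Value batchify(mlir::concretelang::TFHE::BatchableOpInterface op,
                         mlir::ImplicitLocOpBuilder &builder) {
      llvm::SmallVector<mlir::Value> batched, hoisted;
      for (mlir::OpOperand &operand : op.getBatchableOperands()) {
        if (mlir::Value tensor = lookupBatchedValue(operand.get()))
          batched.push_back(tensor); // varies per iteration: batch it
        else
          hoisted.push_back(operand.get()); // loop-invariant: keep scalar
      }
      // E.g. AddGLWEIntOp creates batched_add_glwe_int when `hoisted` is
      // empty and batched_add_glwe_int_cst when its integer was hoisted.
      return op.createBatchedOperation(builder, batched, hoisted);
    }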
diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/wrappers.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/wrappers.h
index fb9cd070d..cd9bb0d22 100644
--- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/wrappers.h
+++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/wrappers.h
@@ -85,6 +85,54 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
                               uint32_t ksk_index,
                               mlir::concretelang::RuntimeContext *context);

+void memref_batched_add_lwe_ciphertexts_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
+    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
+    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
+    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
+    uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size0,
+    uint64_t ct1_size1, uint64_t ct1_stride0, uint64_t ct1_stride1);
+
+void memref_batched_add_plaintext_lwe_ciphertext_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
+    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
+    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
+    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
+    uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size,
+    uint64_t ct1_stride);
+
+void memref_batched_add_plaintext_cst_lwe_ciphertext_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
+    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
+    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
+    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t plaintext);
+
+void memref_batched_mul_cleartext_lwe_ciphertext_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
+    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
+    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
+    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
+    uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size,
+    uint64_t ct1_stride);
+
+void memref_batched_mul_cleartext_cst_lwe_ciphertext_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
+    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
+    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
+    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t cleartext);
+
+void memref_batched_negate_lwe_ciphertext_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
+    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
+    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
+    uint64_t ct0_stride0, uint64_t ct0_stride1);
+
 void memref_batched_keyswitch_lwe_u64(
     uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
     uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
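[Editor's note] Each memref<?x?xi64> parameter of these wrappers is expanded following the standard MLIR C calling convention into (allocated, aligned, offset, size0, size1, stride0, stride1). A minimal sketch of a direct call from C++, assuming a contiguous row-major batch of n ciphertexts of lweSize words each (both dimensions are illustrative):

    #include <cstdint>
    #include <vector>
    #include "concretelang/Runtime/wrappers.h"

    int main() {
      const uint64_t n = 4, lweSize = 1025; // illustrative dimensions
      std::vector<uint64_t> in(n * lweSize, 1), out(n * lweSize, 0);
      // A contiguous row-major [n][lweSize] buffer has strides {lweSize, 1}.
      memref_batched_negate_lwe_ciphertext_u64(
          /*out_allocated=*/out.data(), /*out_aligned=*/out.data(),
          /*out_offset=*/0, /*out_size0=*/n, /*out_size1=*/lweSize,
          /*out_stride0=*/lweSize, /*out_stride1=*/1,
          /*ct0_allocated=*/in.data(), /*ct0_aligned=*/in.data(),
          /*ct0_offset=*/0, /*ct0_size0=*/n, /*ct0_size1=*/lweSize,
          /*ct0_stride0=*/lweSize, /*ct0_stride1=*/1);
      return 0;
    }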
diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp
index 4320f8ce1..f56ff4f3c 100644
--- a/compilers/concrete-compiler/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp
@@ -27,6 +27,18 @@ char memref_mul_cleartext_lwe_ciphertext_u64[] =
 char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64";
 char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64";
 char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
+char memref_batched_add_lwe_ciphertexts_u64[] =
+    "memref_batched_add_lwe_ciphertexts_u64";
+char memref_batched_add_plaintext_lwe_ciphertext_u64[] =
+    "memref_batched_add_plaintext_lwe_ciphertext_u64";
+char memref_batched_add_plaintext_cst_lwe_ciphertext_u64[] =
+    "memref_batched_add_plaintext_cst_lwe_ciphertext_u64";
+char memref_batched_mul_cleartext_lwe_ciphertext_u64[] =
+    "memref_batched_mul_cleartext_lwe_ciphertext_u64";
+char memref_batched_mul_cleartext_cst_lwe_ciphertext_u64[] =
+    "memref_batched_mul_cleartext_cst_lwe_ciphertext_u64";
+char memref_batched_negate_lwe_ciphertext_u64[] =
+    "memref_batched_negate_lwe_ciphertext_u64";
 char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64";
 char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64";
@@ -129,6 +141,26 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
             memref1DType, i32Type, i32Type, i32Type, i32Type, i32Type,
             i32Type, contextType},
         {futureType});
+  } else if (funcName == memref_batched_add_lwe_ciphertexts_u64) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(), {memref2DType, memref2DType, memref2DType}, {});
+  } else if (funcName == memref_batched_add_plaintext_lwe_ciphertext_u64) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(), {memref2DType, memref2DType, memref1DType}, {});
+  } else if (funcName == memref_batched_add_plaintext_cst_lwe_ciphertext_u64) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(),
+        {memref2DType, memref2DType, rewriter.getI64Type()}, {});
+  } else if (funcName == memref_batched_mul_cleartext_lwe_ciphertext_u64) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(), {memref2DType, memref2DType, memref1DType}, {});
+  } else if (funcName == memref_batched_mul_cleartext_cst_lwe_ciphertext_u64) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(),
+        {memref2DType, memref2DType, rewriter.getI64Type()}, {});
+  } else if (funcName == memref_batched_negate_lwe_ciphertext_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref2DType, memref2DType}, {});
   } else if (funcName == memref_batched_keyswitch_lwe_u64 ||
              funcName == memref_batched_keyswitch_lwe_cuda_u64) {
     funcType =
@@ -515,6 +547,26 @@ struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
         .add>(
             &getContext(), encodeLutForWopPBSAddOperands);

+  patterns
+      .add<ConcreteToCAPICallPattern<Concrete::BatchedAddLweBufferOp,
+                                     memref_batched_add_lwe_ciphertexts_u64>>(
+          &getContext());
+  patterns.add<ConcreteToCAPICallPattern<
+      Concrete::BatchedAddPlaintextLweBufferOp,
+      memref_batched_add_plaintext_lwe_ciphertext_u64>>(&getContext());
+  patterns.add<ConcreteToCAPICallPattern<
+      Concrete::BatchedAddPlaintextCstLweBufferOp,
+      memref_batched_add_plaintext_cst_lwe_ciphertext_u64>>(&getContext());
+  patterns.add<ConcreteToCAPICallPattern<
+      Concrete::BatchedMulCleartextLweBufferOp,
+      memref_batched_mul_cleartext_lwe_ciphertext_u64>>(&getContext());
+  patterns.add<ConcreteToCAPICallPattern<
+      Concrete::BatchedMulCleartextCstLweBufferOp,
+      memref_batched_mul_cleartext_cst_lwe_ciphertext_u64>>(&getContext());
+  patterns.add<
+      ConcreteToCAPICallPattern<Concrete::BatchedNegateLweBufferOp,
+                                memref_batched_negate_lwe_ciphertext_u64>>(
+      &getContext());

   if (gpu) {
     patterns.add>(
diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp
index 2e607ead3..515b29193 100644
--- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp
@@ -784,8 +784,28 @@ void TFHEToConcretePass::runOnOperation() {
           mlir::concretelang::Concrete::EncodeLutForCrtWopPBSTensorOp, true>,
       mlir::concretelang::GenericOneToOneOpConversionPattern<
           mlir::concretelang::TFHE::EncodePlaintextWithCrtOp,
-          mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>>(
-      &getContext(), converter);
+          mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>,
+
+      mlir::concretelang::GenericOneToOneOpConversionPattern<
+          mlir::concretelang::TFHE::ABatchedAddGLWEIntOp,
+          mlir::concretelang::Concrete::BatchedAddPlaintextLweTensorOp>,
+      mlir::concretelang::GenericOneToOneOpConversionPattern<
+          mlir::concretelang::TFHE::ABatchedAddGLWEIntCstOp,
+          mlir::concretelang::Concrete::BatchedAddPlaintextCstLweTensorOp>,
+      mlir::concretelang::GenericOneToOneOpConversionPattern<
+          mlir::concretelang::TFHE::ABatchedAddGLWEOp,
+          mlir::concretelang::Concrete::BatchedAddLweTensorOp>,
+      mlir::concretelang::GenericOneToOneOpConversionPattern<
+          mlir::concretelang::TFHE::BatchedMulGLWEIntOp,
+          mlir::concretelang::Concrete::BatchedMulCleartextLweTensorOp>,
+      mlir::concretelang::GenericOneToOneOpConversionPattern<
+          mlir::concretelang::TFHE::BatchedMulGLWEIntCstOp,
+          mlir::concretelang::Concrete::BatchedMulCleartextCstLweTensorOp>,
+      mlir::concretelang::GenericOneToOneOpConversionPattern<
+          mlir::concretelang::TFHE::BatchedNegGLWEOp,
+          mlir::concretelang::Concrete::BatchedNegateLweTensorOp>
+
+      >(&getContext(), converter);

   // pattern of remaining TFHE ops
   patterns.insert,
diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp
index 138c95a4c..9e54bbdd4 100644
--- a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -123,6 +123,34 @@ void mlir::concretelang::Concrete::
   // bootstrap_lwe_tensor => bootstrap_lwe_buffer
   Concrete::BootstrapLweTensorOp::attachInterface>(*ctx);
+
+  // batched_add_lwe_tensor => batched_add_lwe_buffer
+  Concrete::BatchedAddLweTensorOp::attachInterface>(
+      *ctx);
+  // batched_add_plaintext_lwe_tensor => batched_add_plaintext_lwe_buffer
+  Concrete::BatchedAddPlaintextLweTensorOp::attachInterface<
+      TensorToMemrefOp>(*ctx);
+  // batched_add_plaintext_cst_lwe_tensor =>
+  //     batched_add_plaintext_cst_lwe_buffer
+  Concrete::BatchedAddPlaintextCstLweTensorOp::attachInterface<
+      TensorToMemrefOp>(*ctx);
+  // batched_mul_cleartext_lwe_tensor => batched_mul_cleartext_lwe_buffer
+  Concrete::BatchedMulCleartextLweTensorOp::attachInterface<
+      TensorToMemrefOp>(*ctx);
+  // batched_mul_cleartext_cst_lwe_tensor =>
+  //     batched_mul_cleartext_cst_lwe_buffer
+  Concrete::BatchedMulCleartextCstLweTensorOp::attachInterface<
+      TensorToMemrefOp>(*ctx);
+  // batched_negate_lwe_tensor => batched_negate_lwe_buffer
+  Concrete::BatchedNegateLweTensorOp::attachInterface<
+      TensorToMemrefOp>(*ctx);
+
   // batched_keyswitch_lwe_tensor => batched_keyswitch_lwe_buffer
   Concrete::BatchedKeySwitchLweTensorOp::attachInterface<
       TensorToMemrefOp