diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index 77df44681..9f24c98c2 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -124,7 +124,6 @@ enum CompilationTarget { FHE, TFHE, CONCRETE, - BCONCRETE, STD, LLVM, LLVM_IR, @@ -138,8 +137,7 @@ typedef enum CompilationTarget CompilationTarget; MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreate( MlirStringRef funcName, bool autoParallelize, bool batchConcreteOps, bool dataflowParallelize, bool emitGPUOps, bool loopParallelize, - bool optimizeConcrete, OptimizerConfig optimizerConfig, - bool verifyDiagnostics); + bool optimizeTFHE, OptimizerConfig optimizerConfig, bool verifyDiagnostics); MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreateDefault(); diff --git a/compiler/include/concretelang/Conversion/CMakeLists.txt b/compiler/include/concretelang/Conversion/CMakeLists.txt index 6643f8d57..a9c1a87f5 100644 --- a/compiler/include/concretelang/Conversion/CMakeLists.txt +++ b/compiler/include/concretelang/Conversion/CMakeLists.txt @@ -2,5 +2,3 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion) add_public_tablegen_target(ConcretelangConversionPassIncGen) add_dependencies(mlir-headers ConcretelangConversionPassIncGen) - -add_subdirectory(TFHEToConcrete) diff --git a/compiler/include/concretelang/Conversion/ConcreteToBConcrete/Pass.h b/compiler/include/concretelang/Conversion/ConcreteToBConcrete/Pass.h deleted file mode 100644 index 35626ab62..000000000 --- a/compiler/include/concretelang/Conversion/ConcreteToBConcrete/Pass.h +++ /dev/null @@ -1,18 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#ifndef ZAMALANG_CONVERSION_CONCRETETOBCONCRETE_PASS_H_ -#define ZAMALANG_CONVERSION_CONCRETETOBCONCRETE_PASS_H_ - -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace concretelang { -/// Create a pass to convert `Concrete` dialect to `BConcrete` dialect. -std::unique_ptr> createConvertConcreteToBConcretePass(); -} // namespace concretelang -} // namespace mlir - -#endif diff --git a/compiler/include/concretelang/Conversion/BConcreteToCAPI/Pass.h b/compiler/include/concretelang/Conversion/ConcreteToCAPI/Pass.h similarity index 63% rename from compiler/include/concretelang/Conversion/BConcreteToCAPI/Pass.h rename to compiler/include/concretelang/Conversion/ConcreteToCAPI/Pass.h index ebd9c3d91..3b8245e40 100644 --- a/compiler/include/concretelang/Conversion/BConcreteToCAPI/Pass.h +++ b/compiler/include/concretelang/Conversion/ConcreteToCAPI/Pass.h @@ -3,16 +3,16 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#ifndef ZAMALANG_CONVERSION_BCONCRETETOCAPI_PASS_H_ -#define ZAMALANG_CONVERSION_BCONCRETETOCAPI_PASS_H_ +#ifndef ZAMALANG_CONVERSION_CONCRETETOCAPI_PASS_H_ +#define ZAMALANG_CONVERSION_CONCRETETOCAPI_PASS_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace concretelang { -/// Create a pass to convert `BConcrete` dialect to CAPI calls. +/// Create a pass to convert `Concrete` dialect to CAPI calls. std::unique_ptr> -createConvertBConcreteToCAPIPass(bool gpu); +createConvertConcreteToCAPIPass(bool gpu); } // namespace concretelang } // namespace mlir diff --git a/compiler/include/concretelang/Conversion/Passes.h b/compiler/include/concretelang/Conversion/Passes.h index 6cdfb800b..8e7a919d7 100644 --- a/compiler/include/concretelang/Conversion/Passes.h +++ b/compiler/include/concretelang/Conversion/Passes.h @@ -13,8 +13,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "concretelang/Conversion/BConcreteToCAPI/Pass.h" -#include "concretelang/Conversion/ConcreteToBConcrete/Pass.h" +#include "concretelang/Conversion/ConcreteToCAPI/Pass.h" #include "concretelang/Conversion/ExtractSDFGOps/Pass.h" #include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h" #include "concretelang/Conversion/FHEToTFHECrt/Pass.h" @@ -25,7 +24,6 @@ #include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h" #include "concretelang/Conversion/TFHEToConcrete/Pass.h" #include "concretelang/Conversion/TracingToCAPI/Pass.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" #include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" #include "concretelang/Dialect/FHE/IR/FHEDialect.h" #include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" diff --git a/compiler/include/concretelang/Conversion/Passes.td b/compiler/include/concretelang/Conversion/Passes.td index 2fb409a0b..9f99117eb 100644 --- a/compiler/include/concretelang/Conversion/Passes.td +++ b/compiler/include/concretelang/Conversion/Passes.td @@ -48,13 +48,6 @@ def LinalgGenericOpWithTensorsToLoops : Pass<"linalg-generic-op-with-tensors-to- let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::scf::SCFDialect"]; } -def ConcreteToBConcrete : Pass<"concrete-to-bconcrete", "mlir::ModuleOp"> { - let summary = "Lowers operations from the Concrete dialect to Bufferized Concrete"; - let description = [{ Lowers operations from the Concrete dialect to Bufferized Concrete }]; - let constructor = "mlir::concretelang::createConvertConcreteToBConcretePass()"; - let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::Concrete::ConcreteDialect", "mlir::concretelang::BConcrete::BConcreteDialect"]; -} - def ExtractSDFGOps : Pass<"extract-sdfg-ops", "::mlir::func::FuncOp"> { let summary = "Extracts SDFG ops and creates a static data flow graph"; let description = [{ Extracts SDFG ops and creates a static data flow graph }]; @@ -62,11 +55,11 @@ def ExtractSDFGOps : Pass<"extract-sdfg-ops", "::mlir::func::FuncOp"> { let dependentDialects = ["mlir::concretelang::SDFG::SDFGDialect"]; } -def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> { - let summary = "Lowers operations from the BConcrete dialect to CAPI calls"; - let description = [{ Lowers operations from the BConcrete dialect to CAPI calls }]; - let constructor = "mlir::concretelang::createConvertBConcreteToCAPIPass()"; - let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect"]; +def ConcreteToCAPI : Pass<"concrete-to-capi", "mlir::ModuleOp"> { + let summary = "Lowers operations from the Concrete dialect to CAPI calls"; + let description = [{ Lowers operations from the Concrete dialect to CAPI calls }]; + let constructor = "mlir::concretelang::createConvertConcreteToCAPIPass()"; + let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect"]; } def TracingToCAPI : Pass<"tracing-to-capi", "mlir::ModuleOp"> { diff --git a/compiler/include/concretelang/Conversion/TFHEToConcrete/CMakeLists.txt b/compiler/include/concretelang/Conversion/TFHEToConcrete/CMakeLists.txt deleted file mode 100644 index 89494e516..000000000 --- a/compiler/include/concretelang/Conversion/TFHEToConcrete/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Patterns.td) -mlir_tablegen(Patterns.h.inc -gen-rewriters -name TFHE) -add_public_tablegen_target(TFHEToConcretePatternsIncGen) -add_dependencies(mlir-headers TFHEToConcretePatternsIncGen) - -add_concretelang_doc(Patterns TFHEToConcretePatterns concretelang/ -gen-pass-doc) diff --git a/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.h b/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.h deleted file mode 100644 index 9949b5f88..000000000 --- a/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.h +++ /dev/null @@ -1,187 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#ifndef CONCRETELANG_CONVERSION_TFHETOCONCRETE_PATTERNS_H_ -#define CONCRETELANG_CONVERSION_TFHETOCONCRETE_PATTERNS_H_ - -#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" -#include "concretelang/Dialect/TFHE/IR/TFHEOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/PatternMatch.h" - -namespace mlir { -namespace concretelang { - -using Concrete::CleartextType; -using Concrete::LweCiphertextType; -using Concrete::PlaintextType; -using TFHE::GLWECipherTextType; - -LweCiphertextType convertTypeToLWE(mlir::MLIRContext *context, - mlir::Type type) { - auto glwe = type.dyn_cast_or_null(); - if (glwe != nullptr) { - assert(glwe.getPolynomialSize() == 1); - return LweCiphertextType::get(context, glwe.getDimension(), glwe.getP()); - } - auto lwe = type.dyn_cast_or_null(); - if (lwe != nullptr) { - return lwe; - } - assert(false && "expect glwe or lwe"); - return nullptr; -} - -/// Converts the type `t` to an LWE type if `t` is a -/// `TFHE::GLWECipherTextType`, otherwise just returns `t`. -mlir::Type convertTypeToLWEIfTFHEType(mlir::MLIRContext *context, - mlir::Type t) { - if (auto eint = t.dyn_cast()) - return convertTypeToLWE(context, eint); - - return t; -} - -template -PlaintextType convertPlaintextTypeFromPType(mlir::MLIRContext *context, - PType &type) { - return PlaintextType::get(context, type.getP() + 1); -} - -/// convertPlaintextTypeFromType create a plaintext type according the -/// precision of the given type argument. The type should be a GLWECipherText -/// (if operand is not yet lowered) or a LWECipherTextType (if operand is -/// already lowered). -PlaintextType convertPlaintextTypeFromType(mlir::MLIRContext *context, - mlir::Type &type) { - auto glwe = type.dyn_cast_or_null(); - if (glwe != nullptr) { - return convertPlaintextTypeFromPType(context, glwe); - } - auto lwe = type.dyn_cast_or_null(); - if (lwe != nullptr) { - return convertPlaintextTypeFromPType(context, lwe); - } - assert(false && "expect glwe or lwe"); - return nullptr; -} - -template -CleartextType convertCleartextTypeFromPType(mlir::MLIRContext *context, - PType &type) { - return CleartextType::get(context, type.getP() + 1); -} - -/// convertCleartextTypeFromType create a cleartext type according the -/// precision of the given type argument. The type should be a GLWECipherText -/// (if operand is not yet lowered) or a LWECipherTextType (if operand is -/// already lowered). -CleartextType convertCleartextTypeFromType(mlir::MLIRContext *context, - mlir::Type &type) { - auto glwe = type.dyn_cast_or_null(); - if (glwe != nullptr) { - return convertCleartextTypeFromPType(context, glwe); - } - auto lwe = type.dyn_cast_or_null(); - if (lwe != nullptr) { - return convertCleartextTypeFromPType(context, lwe); - } - assert(false && "expect glwe or lwe"); - return nullptr; -} - -mlir::Value createZeroLWEOpFromTFHE(mlir::PatternRewriter &rewriter, - mlir::Location loc, mlir::OpResult result) { - mlir::SmallVector args{}; - mlir::SmallVector attrs; - auto glwe = result.getType().cast(); - mlir::SmallVector resTypes{ - convertTypeToLWE(rewriter.getContext(), glwe)}; - Concrete::ZeroLWEOp op = - rewriter.create(loc, resTypes, args, attrs); - return op.getODSResults(0).front(); -} - -template -mlir::Value createConcreteOpFromTFHE(mlir::PatternRewriter &rewriter, - mlir::Location loc, mlir::Value arg0, - mlir::Value arg1, mlir::OpResult result) { - mlir::SmallVector args{arg0, arg1}; - mlir::SmallVector attrs; - mlir::SmallVector resTypes{result.getType()}; - Operator op = rewriter.create(loc, resTypes, args, attrs); - convertOperandAndResultTypes(rewriter, op, convertTypeToLWE); - - return op.getODSResults(0).front(); -} - -mlir::Value createAddPlainLweCiphertextWithGlwe( - mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Value arg0, - mlir::Value arg1, mlir::OpResult result, mlir::Type encryptedType) { - auto op = - rewriter - .create( - loc, result.getType(), arg0, arg1); - - convertOperandAndResultTypes(rewriter, op, convertTypeToLWEIfTFHEType); - - return op.getODSResults(0).front(); -} - -mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter &rewriter, - mlir::Location loc, mlir::Value arg0, - mlir::Value arg1, - mlir::OpResult result) { - return createAddPlainLweCiphertextWithGlwe(rewriter, loc, arg0, arg1, result, - arg0.getType()); -} - -mlir::Value createNegLweCiphertext(mlir::PatternRewriter &rewriter, - mlir::Location loc, mlir::Value arg0, - mlir::OpResult result) { - auto negated = - rewriter.create( - loc, arg0.getType(), arg0); - convertOperandAndResultTypes(rewriter, negated, convertTypeToLWEIfTFHEType); - return negated.getODSResults(0).front(); -} - -mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter &rewriter, - mlir::Location loc, mlir::Value arg0, - mlir::Value arg1, mlir::OpResult result) { - auto negated_arg1 = createNegLweCiphertext(rewriter, loc, arg1, result); - return createAddPlainLweCiphertextWithGlwe(rewriter, loc, negated_arg1, arg0, - result, arg1.getType()); -} - -mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter, - mlir::Location loc, mlir::Value arg0, - mlir::Value arg1, - mlir::OpResult result) { - // replace op using the encoded plaintext instead of int - auto op = - rewriter - .create( - loc, result.getType(), arg0, arg1); - - convertOperandAndResultTypes(rewriter, op, convertTypeToLWEIfTFHEType); - - return op.getODSResults(0).front(); -} - -} // namespace concretelang -} // namespace mlir - -namespace { -#include "concretelang/Conversion/TFHEToConcrete/Patterns.h.inc" -} - -void populateWithGeneratedTFHEToConcrete(mlir::RewritePatternSet &patterns) { - populateWithGenerated(patterns); -} - -#endif diff --git a/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.td b/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.td deleted file mode 100644 index e68c0feee..000000000 --- a/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.td +++ /dev/null @@ -1,45 +0,0 @@ -#ifndef CONCRETELANG_CONVERSION_TFHETOCONCRETE_PATTERNS -#define CONCRETELANG_CONVERSION_TFHETOCONCRETE_PATTERNS - -include "mlir/Pass/PassBase.td" -include "mlir/IR/PatternBase.td" -include "concretelang/Dialect/Concrete/IR/ConcreteOps.td" -include "concretelang/Dialect/TFHE/IR/TFHEOps.td" - -def createZeroLWEOp : NativeCodeCall<"mlir::concretelang::createZeroLWEOpFromTFHE($_builder, $_loc, $0)">; - -def ZeroGLWEPattern : Pat< - (TFHE_ZeroGLWEOp:$result), - (createZeroLWEOp $result)>; - -def createAddLWEOp : NativeCodeCall<"mlir::concretelang::createConcreteOpFromTFHE($_builder, $_loc, $0, $1, $2)">; - -def AddGLWEPattern : Pat< - (TFHE_AddGLWEOp:$result $arg0, $arg1), - (createAddLWEOp $arg0, $arg1, $result)>; - -def createAddPlainLweOp : NativeCodeCall<"mlir::concretelang::createAddPlainLweCiphertext($_builder, $_loc, $0, $1, $2)">; - -def AddGLWEIntPattern : Pat< - (TFHE_AddGLWEIntOp:$result $arg0, $arg1), - (createAddPlainLweOp $arg0, $arg1, $result)>; - -def createMulClearLweOp : NativeCodeCall<"mlir::concretelang::createMulClearLweCiphertext($_builder, $_loc, $0, $1, $2)">; - -def MulGLWEIntPattern : Pat< - (TFHE_MulGLWEIntOp:$result $arg0, $arg1), - (createMulClearLweOp $arg0, $arg1, $result)>; - -def createSubIntLweOp : NativeCodeCall<"mlir::concretelang::createSubIntLweCiphertext($_builder, $_loc, $0, $1, $2)">; - -def SubGLWEIntPattern : Pat< - (TFHE_SubGLWEIntOp:$result $arg0, $arg1), - (createSubIntLweOp $arg0, $arg1, $result)>; - -def createNegLweOp : NativeCodeCall<"mlir::concretelang::createNegLweCiphertext($_builder, $_loc, $0, $1)">; - -def NegGLWEPattern : Pat< - (TFHE_NegGLWEOp:$result $arg0), - (createNegLweOp $arg0, $result)>; - -#endif diff --git a/compiler/include/concretelang/Dialect/BConcrete/CMakeLists.txt b/compiler/include/concretelang/Dialect/BConcrete/CMakeLists.txt deleted file mode 100644 index 9f57627c3..000000000 --- a/compiler/include/concretelang/Dialect/BConcrete/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(IR) -add_subdirectory(Transforms) diff --git a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteDialect.h b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteDialect.h deleted file mode 100644 index 0b1a2e25c..000000000 --- a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteDialect.h +++ /dev/null @@ -1,18 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#ifndef ZAMALANG_DIALECT_BConcrete_IR_BConcreteDIALECT_H -#define ZAMALANG_DIALECT_BConcrete_IR_BConcreteDIALECT_H - -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" - -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" - -#include "concretelang/Dialect/BConcrete/IR/BConcreteOpsDialect.h.inc" - -#endif diff --git a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteDialect.td b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteDialect.td deleted file mode 100644 index df4524c16..000000000 --- a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteDialect.td +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef ZAMALANG_DIALECT_BConcrete_IR_BConcrete_DIALECT -#define ZAMALANG_DIALECT_BConcrete_IR_BConcrete_DIALECT - -include "mlir/IR/OpBase.td" - -def BConcrete_Dialect : Dialect { - let name = "BConcrete"; - let summary = "Bufferized concrete dialect"; - let description = [{ - A dialect for representation of bufferized concrete operations on fully homomorphic ciphertext. - }]; - let cppNamespace = "::mlir::concretelang::BConcrete"; -} - -#endif diff --git a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.h b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.h deleted file mode 100644 index 568843db0..000000000 --- a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.h +++ /dev/null @@ -1,22 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#ifndef ZAMALANG_DIALECT_BConcrete_BConcrete_OPS_H -#define ZAMALANG_DIALECT_BConcrete_BConcrete_OPS_H - -#include -#include -#include -#include -#include -#include - -#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" -#include "concretelang/Dialect/RT/IR/RTTypes.h" - -#define GET_OP_CLASSES -#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h.inc" - -#endif diff --git a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td deleted file mode 100644 index 5f93581a1..000000000 --- a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td +++ /dev/null @@ -1,314 +0,0 @@ -#ifndef ZAMALANG_DIALECT_BConcrete_IR_BConcrete_OPS -#define ZAMALANG_DIALECT_BConcrete_IR_BConcrete_OPS - -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/IR/BuiltinTypes.td" -include "mlir/Dialect/MemRef/IR/MemRefBase.td" -include "mlir/Dialect/LLVMIR/LLVMOpBase.td" - -include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.td" -include "concretelang/Dialect/Concrete/IR/ConcreteTypes.td" -include "concretelang/Dialect/RT/IR/RTDialect.td" -include "concretelang/Dialect/RT/IR/RTTypes.td" - -class BConcrete_Op traits = []> : - Op; - -// BConcrete tensor operators ///////////////////////////////////////////////// - -def BConcrete_AddLweTensorOp : BConcrete_Op<"add_lwe_tensor", [NoSideEffect]> { - let arguments = (ins - 1DTensorOf<[I64]>:$lhs, - 1DTensorOf<[I64]>:$rhs - ); - let results = (outs 1DTensorOf<[I64]>:$result); -} - -def BConcrete_AddPlaintextLweTensorOp : BConcrete_Op<"add_plaintext_lwe_tensor", [NoSideEffect]> { - let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs); - let results = (outs 1DTensorOf<[I64]>:$result); -} - -def BConcrete_MulCleartextLweTensorOp : BConcrete_Op<"mul_cleartext_lwe_tensor", [NoSideEffect]> { - let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs); - let results = (outs 1DTensorOf<[I64]>:$result); -} - -def BConcrete_NegateLweTensorOp : BConcrete_Op<"negate_lwe_tensor", [NoSideEffect]> { - let arguments = (ins 1DTensorOf<[I64]>:$ciphertext); - let results = (outs 1DTensorOf<[I64]>:$result); -} - -def BConcrete_KeySwitchLweTensorOp : BConcrete_Op<"keyswitch_lwe_tensor", [NoSideEffect]> { - let arguments = (ins - // LweKeySwitchKeyType:$keyswitch_key, - 1DTensorOf<[I64]>:$ciphertext, - I32Attr:$level, - I32Attr:$baseLog, - I32Attr:$lwe_dim_in, - I32Attr:$lwe_dim_out - ); - let results = (outs 1DTensorOf<[I64]>:$result); -} - -def BConcrete_BatchedKeySwitchLweTensorOp : BConcrete_Op<"batched_keyswitch_lwe_tensor", [NoSideEffect]> { - let arguments = (ins - // LweKeySwitchKeyType:$keyswitch_key, - 2DTensorOf<[I64]>:$ciphertext, - I32Attr:$level, - I32Attr:$baseLog, - I32Attr:$lwe_dim_in, - I32Attr:$lwe_dim_out - ); - let results = (outs 2DTensorOf<[I64]>:$result); -} - -def BConcrete_EncodeExpandLutForBootstrapTensorOp : BConcrete_Op<"encode_expand_lut_for_bootstrap_tensor", [NoSideEffect]> { - let summary = - "Encode and expand a lookup table so that it can be used for a bootstrap."; - - let arguments = (ins - 1DTensorOf<[I64]> : $input_lookup_table, - I32Attr: $polySize, - I32Attr: $outputBits, - BoolAttr: $isSigned - ); - - let results = (outs 1DTensorOf<[I64]> : $result); -} - -def BConcrete_EncodeExpandLutForWopPBSTensorOp : BConcrete_Op<"encode_expand_lut_for_woppbs_tensor", [NoSideEffect]> { - let summary = - "Encode and expand a lookup table so that it can be used for a wop pbs."; - - let arguments = (ins - 1DTensorOf<[I64]> : $input_lookup_table, - I64ArrayAttr: $crtDecomposition, - I64ArrayAttr: $crtBits, - I32Attr : $polySize, - I32Attr : $modulusProduct, - BoolAttr: $isSigned - ); - - let results = (outs 1DTensorOf<[I64]> : $result); -} - - -def BConcrete_EncodePlaintextWithCrtTensorOp : BConcrete_Op<"encode_plaintext_with_crt_tensor", [NoSideEffect]> { - let summary = - "Encodes a plaintext by decomposing it on a crt basis."; - - let arguments = (ins - I64 : $input, - I64ArrayAttr: $mods, - I64Attr: $modsProd - ); - - let results = (outs 1DTensorOf<[I64]> : $result); -} - - -def BConcrete_BootstrapLweTensorOp : BConcrete_Op<"bootstrap_lwe_tensor", [NoSideEffect]> { - let arguments = (ins - 1DTensorOf<[I64]>:$input_ciphertext, - 1DTensorOf<[I64]>:$lookup_table, - I32Attr:$inputLweDim, - I32Attr:$polySize, - I32Attr:$level, - I32Attr:$baseLog, - I32Attr:$glweDimension, - I32Attr:$outPrecision - ); - let results = (outs 1DTensorOf<[I64]>:$result); -} - -def BConcrete_BatchedBootstrapLweTensorOp : BConcrete_Op<"batched_bootstrap_lwe_tensor", [NoSideEffect]> { - let arguments = (ins - 2DTensorOf<[I64]>:$input_ciphertext, - 1DTensorOf<[I64]>:$lookup_table, - I32Attr:$inputLweDim, - I32Attr:$polySize, - I32Attr:$level, - I32Attr:$baseLog, - I32Attr:$glweDimension, - I32Attr:$outPrecision - ); - let results = (outs 2DTensorOf<[I64]>:$result); -} - -def BConcrete_WopPBSCRTLweTensorOp : BConcrete_Op<"wop_pbs_crt_lwe_tensor", [NoSideEffect]> { - let arguments = (ins - 2DTensorOf<[I64]>:$ciphertext, - 1DTensorOf<[I64]>:$lookupTable, - // Bootstrap parameters - I32Attr : $bootstrapLevel, - I32Attr : $bootstrapBaseLog, - // Keyswitch parameters - I32Attr : $keyswitchLevel, - I32Attr : $keyswitchBaseLog, - // Packing keyswitch key parameters - I32Attr : $packingKeySwitchInputLweDimension, - I32Attr : $packingKeySwitchoutputPolynomialSize, - I32Attr : $packingKeySwitchLevel, - I32Attr : $packingKeySwitchBaseLog, - // Circuit bootstrap parameters - I32Attr : $circuitBootstrapLevel, - I32Attr : $circuitBootstrapBaseLog - ); - let results = (outs 2DTensorOf<[I64]>:$result); -} - -// BConcrete memref operators ///////////////////////////////////////////////// - -def BConcrete_LweBuffer : MemRefRankOf<[I64], [1]>; -def BConcrete_LutBuffer : MemRefRankOf<[I64], [1]>; -def BConcrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>; -def BConcrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>; -def BConcrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>; - -def BConcrete_AddLweBufferOp : BConcrete_Op<"add_lwe_buffer"> { - let arguments = (ins - BConcrete_LweBuffer:$result, - BConcrete_LweBuffer:$lhs, - BConcrete_LweBuffer:$rhs - ); -} - -def BConcrete_AddPlaintextLweBufferOp : BConcrete_Op<"add_plaintext_lwe_buffer"> { - let arguments = (ins - BConcrete_LweBuffer:$result, - BConcrete_LweBuffer:$lhs, - I64:$rhs - ); -} - -def BConcrete_MulCleartextLweBufferOp : BConcrete_Op<"mul_cleartext_lwe_buffer"> { - let arguments = (ins - BConcrete_LweBuffer:$result, - BConcrete_LweBuffer:$lhs, - I64:$rhs - ); -} - -def BConcrete_NegateLweBufferOp : BConcrete_Op<"negate_lwe_buffer"> { - let arguments = (ins - BConcrete_LweBuffer:$result, - BConcrete_LweBuffer:$ciphertext - ); -} - -def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> { - let arguments = (ins - BConcrete_LweBuffer:$result, - BConcrete_LweBuffer:$ciphertext, - I32Attr:$level, - I32Attr:$baseLog, - I32Attr:$lwe_dim_in, - I32Attr:$lwe_dim_out - ); -} - -def BConcrete_BatchedKeySwitchLweBufferOp : BConcrete_Op<"batched_keyswitch_lwe_buffer"> { - let arguments = (ins - BConcrete_BatchLweBuffer:$result, - BConcrete_BatchLweBuffer:$ciphertext, - I32Attr:$level, - I32Attr:$baseLog, - I32Attr:$lwe_dim_in, - I32Attr:$lwe_dim_out - ); -} - -def BConcrete_EncodeExpandLutForBootstrapBufferOp : BConcrete_Op<"encode_expand_lut_for_bootstrap_buffer"> { - let summary = - "Encode and expand a lookup table so that it can be used for a bootstrap."; - - let arguments = (ins - BConcrete_LutBuffer: $result, - BConcrete_LutBuffer: $input_lookup_table, - I32Attr: $polySize, - I32Attr: $outputBits, - BoolAttr : $isSigned - ); -} - -def BConcrete_EncodeExpandLutForWopPBSBufferOp : BConcrete_Op<"encode_expand_lut_for_woppbs_buffer"> { - let summary = - "Encode and expand a lookup table so that it can be used for a wop pbs."; - - let arguments = (ins - BConcrete_LutBuffer : $result, - BConcrete_LutBuffer : $input_lookup_table, - I64ArrayAttr: $crtDecomposition, - I64ArrayAttr: $crtBits, - I32Attr : $polySize, - I32Attr : $modulusProduct, - BoolAttr: $isSigned - ); -} - -def BConcrete_EncodePlaintextWithCrtBufferOp : BConcrete_Op<"encode_plaintext_with_crt_buffer"> { - let summary = - "Encodes a plaintext by decomposing it on a crt basis."; - - let arguments = (ins - BConcrete_CrtPlaintextBuffer: $result, - I64 : $input, - I64ArrayAttr: $mods, - I64Attr: $modsProd - ); -} - -def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> { - let arguments = (ins - BConcrete_LweBuffer:$result, - BConcrete_LweBuffer:$input_ciphertext, - BConcrete_LutBuffer:$lookup_table, - I32Attr:$inputLweDim, - I32Attr:$polySize, - I32Attr:$level, - I32Attr:$baseLog, - I32Attr:$glweDimension, - I32Attr:$outPrecision - ); -} - -def BConcrete_BatchedBootstrapLweBufferOp : BConcrete_Op<"batched_bootstrap_lwe_buffer"> { - let arguments = (ins - BConcrete_BatchLweBuffer:$result, - BConcrete_BatchLweBuffer:$input_ciphertext, - BConcrete_LutBuffer:$lookup_table, - I32Attr:$inputLweDim, - I32Attr:$polySize, - I32Attr:$level, - I32Attr:$baseLog, - I32Attr:$glweDimension, - I32Attr:$outPrecision - ); -} - -def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> { - let arguments = (ins - BConcrete_LweCRTBuffer:$result, - BConcrete_LweCRTBuffer:$ciphertext, - BConcrete_LutBuffer:$lookup_table, - // Bootstrap parameters - I32Attr : $bootstrapLevel, - I32Attr : $bootstrapBaseLog, - // Keyswitch parameters - I32Attr : $keyswitchLevel, - I32Attr : $keyswitchBaseLog, - // Packing keyswitch key parameters - I32Attr : $packingKeySwitchInputLweDimension, - I32Attr : $packingKeySwitchoutputPolynomialSize, - I32Attr : $packingKeySwitchLevel, - I32Attr : $packingKeySwitchBaseLog, - // Circuit bootstrap parameters - I32Attr : $circuitBootstrapLevel, - I32Attr : $circuitBootstrapBaseLog, - I64ArrayAttr:$crtDecomposition - ); -} - -#endif diff --git a/compiler/include/concretelang/Dialect/BConcrete/IR/CMakeLists.txt b/compiler/include/concretelang/Dialect/BConcrete/IR/CMakeLists.txt deleted file mode 100644 index 87dacb599..000000000 --- a/compiler/include/concretelang/Dialect/BConcrete/IR/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS BConcreteOps.td) -mlir_tablegen(BConcreteOps.h.inc -gen-op-decls) -mlir_tablegen(BConcreteOps.cpp.inc -gen-op-defs) -mlir_tablegen(BConcreteOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=BConcrete) -mlir_tablegen(BConcreteOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=BConcrete) -mlir_tablegen(BConcreteOpsDialect.h.inc -gen-dialect-decls -dialect=BConcrete) -mlir_tablegen(BConcreteOpsDialect.cpp.inc -gen-dialect-defs -dialect=BConcrete) -add_public_tablegen_target(MLIRBConcreteOpsIncGen) -add_dependencies(mlir-headers MLIRBConcreteOpsIncGen) diff --git a/compiler/include/concretelang/Dialect/BConcrete/Transforms/CMakeLists.txt b/compiler/include/concretelang/Dialect/BConcrete/Transforms/CMakeLists.txt deleted file mode 100644 index e74efae2f..000000000 --- a/compiler/include/concretelang/Dialect/BConcrete/Transforms/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name BConcrete) -add_public_tablegen_target(BConcreteTransformsIncGen) diff --git a/compiler/include/concretelang/Dialect/CMakeLists.txt b/compiler/include/concretelang/Dialect/CMakeLists.txt index 3889967e6..a76fe6f1e 100644 --- a/compiler/include/concretelang/Dialect/CMakeLists.txt +++ b/compiler/include/concretelang/Dialect/CMakeLists.txt @@ -2,7 +2,6 @@ add_subdirectory(FHE) add_subdirectory(FHELinalg) add_subdirectory(TFHE) add_subdirectory(Concrete) -add_subdirectory(BConcrete) add_subdirectory(RT) add_subdirectory(SDFG) add_subdirectory(Tracing) diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td index 130762fc2..101811186 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td @@ -3,76 +3,135 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/IR/BuiltinTypes.td" +include "mlir/Dialect/MemRef/IR/MemRefBase.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "concretelang/Dialect/Concrete/IR/ConcreteDialect.td" include "concretelang/Dialect/Concrete/IR/ConcreteTypes.td" include "concretelang/Interfaces/BatchableInterface.td" +include "concretelang/Dialect/RT/IR/RTDialect.td" +include "concretelang/Dialect/RT/IR/RTTypes.td" + +def Concrete_LweTensor : 1DTensorOf<[I64]>; +def Concrete_LutTensor : 1DTensorOf<[I64]>; +def Concrete_CrtPlaintextTensor : 1DTensorOf<[I64]>; +def Concrete_LweCRTTensor : 2DTensorOf<[I64]>; +def Concrete_BatchLweTensor : 2DTensorOf<[I64]>; + +def Concrete_LweBuffer : MemRefRankOf<[I64], [1]>; +def Concrete_LutBuffer : MemRefRankOf<[I64], [1]>; +def Concrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>; +def Concrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>; +def Concrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>; class Concrete_Op traits = []> : Op; -def Concrete_ZeroLWEOp : Concrete_Op<"zero"> { - let summary = "Returns a trivial encyption of 0"; - let arguments = (ins); - let results = (outs Concrete_LweCiphertextType:$out); -} - -def Concrete_ZeroTensorLWEOp : Concrete_Op<"zero_tensor"> { - let summary = "Returns a trivial encyption of 0"; - - let arguments = (ins); - let results = (outs Type.predicate, HasStaticShapePred]>>:$tensor); -} - -def Concrete_AddLweCiphertextsOp : Concrete_Op<"add_lwe_ciphertexts"> { +def Concrete_AddLweTensorOp : Concrete_Op<"add_lwe_tensor", [NoSideEffect]> { let summary = "Returns the sum of 2 lwe ciphertexts"; - let arguments = (ins Concrete_LweCiphertextType:$lhs, Concrete_LweCiphertextType:$rhs); - let results = (outs Concrete_LweCiphertextType:$result); + let arguments = (ins + Concrete_LweTensor:$lhs, + Concrete_LweTensor:$rhs + ); + let results = (outs Concrete_LweTensor:$result); } -def Concrete_AddPlaintextLweCiphertextOp : Concrete_Op<"add_plaintext_lwe_ciphertext"> { - let summary = "Returns the sum of a clear integer and a lwe ciphertext"; +def Concrete_AddLweBufferOp : Concrete_Op<"add_lwe_buffer"> { + let summary = "Returns the sum of 2 lwe ciphertexts"; - let arguments = (ins Concrete_LweCiphertextType:$lhs, AnyInteger:$rhs); - let results = (outs Concrete_LweCiphertextType:$result); + let arguments = (ins + Concrete_LweBuffer:$result, + Concrete_LweBuffer:$lhs, + Concrete_LweBuffer:$rhs + ); } -def Concrete_MulCleartextLweCiphertextOp : Concrete_Op<"mul_cleartext_lwe_ciphertext"> { +def Concrete_AddPlaintextLweTensorOp : Concrete_Op<"add_plaintext_lwe_tensor", [NoSideEffect]> { + let summary = "Returns the sum of a clear integer and an lwe ciphertext"; + + let arguments = (ins Concrete_LweTensor:$lhs, I64:$rhs); + let results = (outs Concrete_LweTensor:$result); +} + +def Concrete_AddPlaintextLweBufferOp : Concrete_Op<"add_plaintext_lwe_buffer"> { + let summary = "Returns the sum of a clear integer and an lwe ciphertext"; + + let arguments = (ins + Concrete_LweBuffer:$result, + Concrete_LweBuffer:$lhs, + I64:$rhs + ); +} + +def Concrete_MulCleartextLweTensorOp : Concrete_Op<"mul_cleartext_lwe_tensor", [NoSideEffect]> { let summary = "Returns the product of a clear integer and a lwe ciphertext"; - let arguments = (ins Concrete_LweCiphertextType:$lhs, AnyInteger:$rhs); - let results = (outs Concrete_LweCiphertextType:$result); + let arguments = (ins Concrete_LweTensor:$lhs, I64:$rhs); + let results = (outs Concrete_LweTensor:$result); } -def Concrete_NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> { +def Concrete_MulCleartextLweBufferOp : Concrete_Op<"mul_cleartext_lwe_buffer"> { + let summary = "Returns the product of a clear integer and a lwe ciphertext"; + + let arguments = (ins + Concrete_LweBuffer:$result, + Concrete_LweBuffer:$lhs, + I64:$rhs + ); +} + +def Concrete_NegateLweTensorOp : Concrete_Op<"negate_lwe_tensor", [NoSideEffect]> { let summary = "Negates a lwe ciphertext"; - let arguments = (ins Concrete_LweCiphertextType:$ciphertext); - let results = (outs Concrete_LweCiphertextType:$result); + let arguments = (ins Concrete_LweTensor:$ciphertext); + let results = (outs Concrete_LweTensor:$result); } -def Concrete_EncodeExpandLutForBootstrapOp : Concrete_Op<"encode_expand_lut_for_bootstrap"> { - let summary = - "Encode and expand a lookup table so that it can be used for a bootstrap."; +def Concrete_NegateLweBufferOp : Concrete_Op<"negate_lwe_buffer"> { + let summary = "Negates a lwe ciphertext"; let arguments = (ins - 1DTensorOf<[I64]> : $input_lookup_table, - I32Attr: $polySize, - I32Attr: $outputBits, - BoolAttr: $isSigned + Concrete_LweBuffer:$result, + Concrete_LweBuffer:$ciphertext + ); +} + +def Concrete_EncodeExpandLutForBootstrapTensorOp : Concrete_Op<"encode_expand_lut_for_bootstrap_tensor", [NoSideEffect]> { + let summary = + "Encode and expand a lookup table so that it can be used for a bootstrap"; + + let arguments = (ins + Concrete_LutTensor : $input_lookup_table, + I32Attr: $polySize, + I32Attr: $outputBits, + BoolAttr: $isSigned ); - let results = (outs 1DTensorOf<[I64]> : $result); + let results = (outs Concrete_LutTensor : $result); } -def Concrete_EncodeExpandLutForWopPBSOp : Concrete_Op<"encode_expand_lut_for_woppbs"> { -let summary = - "Encode and expand a lookup table so that it can be used for a wop pbs."; +def Concrete_EncodeExpandLutForBootstrapBufferOp : Concrete_Op<"encode_expand_lut_for_bootstrap_buffer"> { + let summary = + "Encode and expand a lookup table so that it can be used for a bootstrap"; let arguments = (ins - 1DTensorOf<[I64]> : $input_lookup_table, + Concrete_LutBuffer: $result, + Concrete_LutBuffer: $input_lookup_table, + I32Attr: $polySize, + I32Attr: $outputBits, + BoolAttr : $isSigned + ); +} + +def Concrete_EncodeExpandLutForWopPBSTensorOp : Concrete_Op<"encode_expand_lut_for_woppbs_tensor", [NoSideEffect]> { + let summary = + "Encode and expand a lookup table so that it can be used for a wop pbs"; + + let arguments = (ins + Concrete_LutTensor : $input_lookup_table, I64ArrayAttr: $crtDecomposition, I64ArrayAttr: $crtBits, I32Attr : $polySize, @@ -80,12 +139,27 @@ let summary = BoolAttr: $isSigned ); - let results = (outs 1DTensorOf<[I64]> : $result); + let results = (outs Concrete_LutTensor : $result); } -def Concrete_EncodePlaintextWithCrtOp : Concrete_Op<"encode_plaintext_with_crt"> { +def Concrete_EncodeExpandLutForWopPBSBufferOp : Concrete_Op<"encode_expand_lut_for_woppbs_buffer"> { let summary = - "Encodes a plaintext by decomposing it on a crt basis."; + "Encode and expand a lookup table so that it can be used for a wop pbs"; + + let arguments = (ins + Concrete_LutBuffer : $result, + Concrete_LutBuffer : $input_lookup_table, + I64ArrayAttr: $crtDecomposition, + I64ArrayAttr: $crtBits, + I32Attr : $polySize, + I32Attr : $modulusProduct, + BoolAttr: $isSigned + ); +} + +def Concrete_EncodePlaintextWithCrtTensorOp : Concrete_Op<"encode_plaintext_with_crt_tensor", [NoSideEffect]> { + let summary = + "Encodes a plaintext by decomposing it on a crt basis"; let arguments = (ins I64 : $input, @@ -93,21 +167,35 @@ def Concrete_EncodePlaintextWithCrtOp : Concrete_Op<"encode_plaintext_with_crt"> I64Attr: $modsProd ); - let results = (outs 1DTensorOf<[I64]> : $result); + let results = (outs Concrete_CrtPlaintextTensor : $result); } -def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe", [BatchableOpInterface]> { - let summary = "Bootstraps a LWE ciphertext with a GLWE trivial encryption of the lookup table"; +def Concrete_EncodePlaintextWithCrtBufferOp : Concrete_Op<"encode_plaintext_with_crt_buffer"> { + let summary = + "Encodes a plaintext by decomposing it on a crt basis"; - let arguments = (ins - Concrete_LweCiphertextType:$input_ciphertext, - 1DTensorOf<[I64]>:$lookup_table, + let arguments = (ins + Concrete_CrtPlaintextBuffer: $result, + I64 : $input, + I64ArrayAttr: $mods, + I64Attr: $modsProd + ); +} + +def Concrete_BootstrapLweTensorOp : Concrete_Op<"bootstrap_lwe_tensor", [NoSideEffect]> { + let summary = "Bootstraps an LWE ciphertext with a GLWE trivial encryption of the lookup table"; + + let arguments = (ins + Concrete_LweTensor:$input_ciphertext, + Concrete_LweTensor:$lookup_table, + I32Attr:$inputLweDim, + I32Attr:$polySize, I32Attr:$level, I32Attr:$baseLog, - I32Attr:$polySize, - I32Attr:$glweDimension + I32Attr:$glweDimension, + I32Attr:$outPrecision ); - let results = (outs Concrete_LweCiphertextType:$result); + let results = (outs Concrete_LweTensor:$result); let extraClassDeclaration = [{ ::mlir::OpOperand& getBatchableOperand() { @@ -124,38 +212,74 @@ def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe", [BatchableOpInterface batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(), getResult().getType()); - return builder.create( + return builder.create( mlir::TypeRange{resType}, mlir::ValueRange{batchedOperands, lookup_table()}, getOperation()->getAttrs()); } }]; - } -def Concrete_BatchedBootstrapLweOp : Concrete_Op<"batched_bootstrap_lwe"> { - let summary = "Batched version of BootstrapLweOp, which performs the same operation on a tensor of elements"; +def Concrete_BootstrapLweBufferOp : Concrete_Op<"bootstrap_lwe_buffer"> { + let summary = "Bootstraps a LWE ciphertext with a GLWE trivial encryption of the lookup table"; let arguments = (ins - 1DTensorOf<[Concrete_LweCiphertextType]>:$input_ciphertexts, - 1DTensorOf<[I64]>:$lookup_table, + Concrete_LweBuffer:$result, + Concrete_LweBuffer:$input_ciphertext, + Concrete_LutBuffer:$lookup_table, + I32Attr:$inputLweDim, + I32Attr:$polySize, I32Attr:$level, I32Attr:$baseLog, - I32Attr:$polySize, - I32Attr:$glweDimension + I32Attr:$glweDimension, + I32Attr:$outPrecision ); - let results = (outs 1DTensorOf<[Concrete_LweCiphertextType]>:$result); } -def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe", [BatchableOpInterface]> { - let summary = "Keyswitches a LWE ciphertext"; +def Concrete_BatchedBootstrapLweTensorOp : Concrete_Op<"batched_bootstrap_lwe_tensor", [NoSideEffect]> { + let summary = "Batched version of BootstrapLweOp, which performs the same operation on multiple elements"; let arguments = (ins - Concrete_LweCiphertextType:$ciphertext, + Concrete_BatchLweTensor:$input_ciphertext, + Concrete_LutTensor:$lookup_table, + I32Attr:$inputLweDim, + I32Attr:$polySize, I32Attr:$level, - I32Attr:$baseLog + I32Attr:$baseLog, + I32Attr:$glweDimension, + I32Attr:$outPrecision ); - let results = (outs Concrete_LweCiphertextType:$result); + let results = (outs Concrete_BatchLweTensor:$result); +} + +def Concrete_BatchedBootstrapLweBufferOp : Concrete_Op<"batched_bootstrap_lwe_buffer"> { + let summary = "Batched version of BootstrapLweOp, which performs the same operation on multiple elements"; + + let arguments = (ins + Concrete_BatchLweBuffer:$result, + Concrete_BatchLweBuffer:$input_ciphertext, + Concrete_LutBuffer:$lookup_table, + I32Attr:$inputLweDim, + I32Attr:$polySize, + I32Attr:$level, + I32Attr:$baseLog, + I32Attr:$glweDimension, + I32Attr:$outPrecision + ); +} + +def Concrete_KeySwitchLweTensorOp : Concrete_Op<"keyswitch_lwe_tensor", [NoSideEffect]> { + let summary = "Keyswitches an LWE ciphertext"; + + let arguments = (ins + // LweKeySwitchKeyType:$keyswitch_key, + Concrete_LweTensor:$ciphertext, + I32Attr:$level, + I32Attr:$baseLog, + I32Attr:$lwe_dim_in, + I32Attr:$lwe_dim_out + ); + let results = (outs Concrete_LweTensor:$result); let extraClassDeclaration = [{ ::mlir::OpOperand& getBatchableOperand() { @@ -172,7 +296,7 @@ def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe", [BatchableOpInterface batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(), getResult().getType()); - return builder.create( + return builder.create( mlir::TypeRange{resType}, mlir::ValueRange{batchedOperands}, getOperation()->getAttrs()); @@ -180,24 +304,73 @@ def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe", [BatchableOpInterface }]; } -def Concrete_BatchedKeySwitchLweOp : Concrete_Op<"batched_keyswitch_lwe"> { - let summary = "Batched version of KeySwitchLweOp, which performs the same operation on a tensor of elements"; +def Concrete_KeySwitchLweBufferOp : Concrete_Op<"keyswitch_lwe_buffer"> { + let summary = "Keyswitches an LWE ciphertext"; let arguments = (ins - 1DTensorOf<[Concrete_LweCiphertextType]>:$ciphertexts, + Concrete_LweBuffer:$result, + Concrete_LweBuffer:$ciphertext, I32Attr:$level, - I32Attr:$baseLog + I32Attr:$baseLog, + I32Attr:$lwe_dim_in, + I32Attr:$lwe_dim_out ); - let results = (outs 1DTensorOf<[Concrete_LweCiphertextType]>:$result); } -// TODO(16bits): hack -def Concrete_WopPBSLweOp : Concrete_Op<"wop_pbs_lwe"> { - let summary = ""; +def Concrete_BatchedKeySwitchLweTensorOp : Concrete_Op<"batched_keyswitch_lwe_tensor", [NoSideEffect]> { + let summary = "Batched version of KeySwitchLweOp, which performs the same operation on multiple elements"; let arguments = (ins - Type.predicate, HasStaticShapePred]>>:$ciphertexts, - 1DTensorOf<[I64]>:$accumulator, + // LweKeySwitchKeyType:$keyswitch_key, + Concrete_BatchLweTensor:$ciphertext, + I32Attr:$level, + I32Attr:$baseLog, + I32Attr:$lwe_dim_in, + I32Attr:$lwe_dim_out + ); + let results = (outs Concrete_BatchLweTensor:$result); +} + +def Concrete_BatchedKeySwitchLweBufferOp : Concrete_Op<"batched_keyswitch_lwe_buffer"> { + let summary = "Batched version of KeySwitchLweOp, which performs the same operation on multiple elements"; + + let arguments = (ins + Concrete_BatchLweBuffer:$result, + Concrete_BatchLweBuffer:$ciphertext, + I32Attr:$level, + I32Attr:$baseLog, + I32Attr:$lwe_dim_in, + I32Attr:$lwe_dim_out + ); +} + +def Concrete_WopPBSCRTLweTensorOp : Concrete_Op<"wop_pbs_crt_lwe_tensor", [NoSideEffect]> { + let arguments = (ins + Concrete_LweCRTTensor:$ciphertext, + Concrete_LutTensor:$lookupTable, + // Bootstrap parameters + I32Attr : $bootstrapLevel, + I32Attr : $bootstrapBaseLog, + // Keyswitch parameters + I32Attr : $keyswitchLevel, + I32Attr : $keyswitchBaseLog, + // Packing keyswitch key parameters + I32Attr : $packingKeySwitchInputLweDimension, + I32Attr : $packingKeySwitchoutputPolynomialSize, + I32Attr : $packingKeySwitchLevel, + I32Attr : $packingKeySwitchBaseLog, + // Circuit bootstrap parameters + I32Attr : $circuitBootstrapLevel, + I32Attr : $circuitBootstrapBaseLog + ); + let results = (outs Concrete_LweCRTTensor:$result); +} + +def Concrete_WopPBSCRTLweBufferOp : Concrete_Op<"wop_pbs_crt_lwe_buffer"> { + let arguments = (ins + Concrete_LweCRTBuffer:$result, + Concrete_LweCRTBuffer:$ciphertext, + Concrete_LutBuffer:$lookup_table, // Bootstrap parameters I32Attr : $bootstrapLevel, I32Attr : $bootstrapBaseLog, @@ -212,10 +385,8 @@ def Concrete_WopPBSLweOp : Concrete_Op<"wop_pbs_lwe"> { // Circuit bootstrap parameters I32Attr : $circuitBootstrapLevel, I32Attr : $circuitBootstrapBaseLog, - // Crt decomposition - I64ArrayAttr: $crtDecomposition + I64ArrayAttr:$crtDecomposition ); - let results = (outs Type.predicate, HasStaticShapePred]>>:$result); } #endif diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td index 126180e8b..888f1ab0b 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td @@ -7,80 +7,6 @@ include "concretelang/Dialect/Concrete/IR/ConcreteDialect.td" class Concrete_Type traits = []> : TypeDef { } -def Concrete_GlweCiphertextType : Concrete_Type<"GlweCiphertext"> { - let mnemonic = "glwe_ciphertext"; - - let summary = "A GLWE ciphertext (encryption of a polynomial of fixed-precision integers)"; - - let description = [{ - GLWE ciphertext. - }]; - - let hasCustomAssemblyFormat = 1; - - let parameters = (ins - "signed":$glweDimension, - "signed":$polynomialSize, - // Precision of the lwe ciphertext - "signed":$p - ); -} - -def Concrete_LweCiphertextType : Concrete_Type<"LweCiphertext", [MemRefElementTypeInterface]> { - let mnemonic = "lwe_ciphertext"; - - let summary = "A LWE ciphertext (encryption of a fixed-precision integer)"; - - let description = [{ - Learning With Error ciphertext. - }]; - - - let parameters = (ins - // The dimension of the lwe ciphertext - "signed":$dimension, - // Precision of the lwe ciphertext - "signed":$p - - ); - - let hasCustomAssemblyFormat = 1; -} - -def Concrete_CleartextType : Concrete_Type<"Cleartext"> { - let mnemonic = "cleartext"; - - let summary = "A cleartext (a fixed-precision integer) ready to be multiplied to a LWE ciphertext"; - - let description = [{ - Cleartext. - }]; - - let parameters = (ins - // Number of bits of the cleartext representation - "signed":$p - ); - - let hasCustomAssemblyFormat = 1; -} - -def Concrete_PlaintextType : Concrete_Type<"Plaintext"> { - let mnemonic = "plaintext"; - - let summary = "A Plaintext (a fixed-precision integer) ready to be added to a LWE ciphertext"; - - let description = [{ - Plaintext. - }]; - - let parameters = (ins - // Number of bits of the cleartext representation - "signed":$p - ); - - let hasCustomAssemblyFormat = 1; -} - def Concrete_Context : Concrete_Type<"Context"> { let mnemonic = "context"; diff --git a/compiler/include/concretelang/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.h b/compiler/include/concretelang/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.h similarity index 68% rename from compiler/include/concretelang/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.h rename to compiler/include/concretelang/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.h index 947551431..2fce04a96 100644 --- a/compiler/include/concretelang/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.h +++ b/compiler/include/concretelang/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.h @@ -3,16 +3,16 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#ifndef CONCRETELANG_DIALECT_BCONCRETE_BUFFERIZABLEOPINTERFACEIMPL_H -#define CONCRETELANG_DIALECT_BCONCRETE_BUFFERIZABLEOPINTERFACEIMPL_H +#ifndef CONCRETELANG_DIALECT_CONCRETE_BUFFERIZABLEOPINTERFACEIMPL_H +#define CONCRETELANG_DIALECT_CONCRETE_BUFFERIZABLEOPINTERFACEIMPL_H namespace mlir { class DialectRegistry; namespace concretelang { -namespace BConcrete { +namespace Concrete { void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); -} // namespace BConcrete +} // namespace Concrete } // namespace concretelang } // namespace mlir diff --git a/compiler/include/concretelang/Dialect/Concrete/Transforms/CMakeLists.txt b/compiler/include/concretelang/Dialect/Concrete/Transforms/CMakeLists.txt index 425267561..bbcdd599a 100644 --- a/compiler/include/concretelang/Dialect/Concrete/Transforms/CMakeLists.txt +++ b/compiler/include/concretelang/Dialect/Concrete/Transforms/CMakeLists.txt @@ -1,4 +1,3 @@ -set(LLVM_TARGET_DEFINITIONS Optimization.td) -mlir_tablegen(Optimization.h.inc -gen-pass-decls -name Transforms) -add_public_tablegen_target(ConcretelangConcreteOptimizationPassIncGen) -add_dependencies(mlir-headers ConcretelangConcreteOptimizationPassIncGen) +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Concrete) +add_public_tablegen_target(ConcreteTransformsIncGen) diff --git a/compiler/include/concretelang/Dialect/Concrete/Transforms/Optimization.td b/compiler/include/concretelang/Dialect/Concrete/Transforms/Optimization.td deleted file mode 100644 index 42bc363e7..000000000 --- a/compiler/include/concretelang/Dialect/Concrete/Transforms/Optimization.td +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef CONCRETELANG_CONCRETE_OPTIMIZATION_PASS -#define CONCRETELANG_CONCRETE_OPTIMIZATION_PASS - -include "mlir/Pass/PassBase.td" - -def ConcreteOptimization : Pass<"concrete-optimization"> { - let summary = "Optimize Concrete operations"; - let constructor = "mlir::concretelang::createConcreteOptimizationPass()"; - let options = []; - let dependentDialects = [ "mlir::concretelang::Concrete::ConcreteDialect" ]; -} - -#endif diff --git a/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.h b/compiler/include/concretelang/Dialect/Concrete/Transforms/Passes.h similarity index 56% rename from compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.h rename to compiler/include/concretelang/Dialect/Concrete/Transforms/Passes.h index cd337afc5..ad8ad961c 100644 --- a/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.h +++ b/compiler/include/concretelang/Dialect/Concrete/Transforms/Passes.h @@ -3,19 +3,18 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#ifndef CONCRETELANG_DIALECT_BCONCRETE_TRANSFORMS_PASSES_H_ -#define CONCRETELANG_DIALECT_BCONCRETE_TRANSFORMS_PASSES_H_ +#ifndef CONCRETELANG_DIALECT_CONCRETE_TRANSFORMS_PASSES_H_ +#define CONCRETELANG_DIALECT_CONCRETE_TRANSFORMS_PASSES_H_ #include "mlir/Pass/Pass.h" #define GEN_PASS_CLASSES -#include "concretelang/Dialect/BConcrete/Transforms/Passes.h.inc" +#include "concretelang/Dialect/Concrete/Transforms/Passes.h.inc" namespace mlir { namespace concretelang { std::unique_ptr> createAddRuntimeContext(); -std::unique_ptr> createEliminateCRTOps(); } // namespace concretelang } // namespace mlir -#endif // CONCRETELANG_DIALECT_BCONCRETE_TRANSFORMS_PASSES_H_ +#endif // CONCRETELANG_DIALECT_CONCRETE_TRANSFORMS_PASSES_H_ diff --git a/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.td b/compiler/include/concretelang/Dialect/Concrete/Transforms/Passes.td similarity index 77% rename from compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.td rename to compiler/include/concretelang/Dialect/Concrete/Transforms/Passes.td index e9519d4a9..e77d7770a 100644 --- a/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.td +++ b/compiler/include/concretelang/Dialect/Concrete/Transforms/Passes.td @@ -16,10 +16,4 @@ def AddRuntimeContext : Pass<"add-runtime-context", "mlir::ModuleOp"> { let constructor = "mlir::concretelang::createAddRuntimeContext()"; } -def EliminateCRTOps - : Pass<"eliminate-bconcrete-crt-ops", "mlir::func::FuncOp"> { - let summary = "Eliminate the crt bconcrete operators."; - let constructor = "mlir::concretelang::createEliminateCRTOpsPass()"; -} - #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES diff --git a/compiler/include/concretelang/Dialect/TFHE/CMakeLists.txt b/compiler/include/concretelang/Dialect/TFHE/CMakeLists.txt index f33061b2d..9f57627c3 100644 --- a/compiler/include/concretelang/Dialect/TFHE/CMakeLists.txt +++ b/compiler/include/concretelang/Dialect/TFHE/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/compiler/include/concretelang/Dialect/TFHE/Transforms/CMakeLists.txt b/compiler/include/concretelang/Dialect/TFHE/Transforms/CMakeLists.txt new file mode 100644 index 000000000..f19200ffa --- /dev/null +++ b/compiler/include/concretelang/Dialect/TFHE/Transforms/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Optimization.td) +mlir_tablegen(Optimization.h.inc -gen-pass-decls -name Transforms) +add_public_tablegen_target(ConcretelangTFHEOptimizationPassIncGen) +add_dependencies(mlir-headers ConcretelangTFHEOptimizationPassIncGen) diff --git a/compiler/include/concretelang/Dialect/Concrete/Transforms/Optimization.h b/compiler/include/concretelang/Dialect/TFHE/Transforms/Optimization.h similarity index 53% rename from compiler/include/concretelang/Dialect/Concrete/Transforms/Optimization.h rename to compiler/include/concretelang/Dialect/TFHE/Transforms/Optimization.h index 7cb234522..031acf528 100644 --- a/compiler/include/concretelang/Dialect/Concrete/Transforms/Optimization.h +++ b/compiler/include/concretelang/Dialect/TFHE/Transforms/Optimization.h @@ -3,18 +3,18 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#ifndef CONCRETELANG_CONCRETE_OPTIMIZATION_PASS_H -#define CONCRETELANG_CONCRETE_OPTIMIZATION_PASS_H +#ifndef CONCRETELANG_TFHE_OPTIMIZATION_PASS_H +#define CONCRETELANG_TFHE_OPTIMIZATION_PASS_H -#include +#include #include #define GEN_PASS_CLASSES -#include +#include namespace mlir { namespace concretelang { -std::unique_ptr> createConcreteOptimizationPass(); +std::unique_ptr> createTFHEOptimizationPass(); } // namespace concretelang } // namespace mlir diff --git a/compiler/include/concretelang/Dialect/TFHE/Transforms/Optimization.td b/compiler/include/concretelang/Dialect/TFHE/Transforms/Optimization.td new file mode 100644 index 000000000..9e48d2081 --- /dev/null +++ b/compiler/include/concretelang/Dialect/TFHE/Transforms/Optimization.td @@ -0,0 +1,13 @@ +#ifndef CONCRETELANG_TFHE_OPTIMIZATION_PASS +#define CONCRETELANG_TFHE_OPTIMIZATION_PASS + +include "mlir/Pass/PassBase.td" + +def TFHEOptimization : Pass<"tfhe-optimization"> { + let summary = "Optimize TFHE operations"; + let constructor = "mlir::concretelang::createTFHEOptimizationPass()"; + let options = []; + let dependentDialects = [ "mlir::concretelang::TFHE::TFHEDialect" ]; +} + +#endif diff --git a/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td b/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td index 186c83414..0010bfc36 100644 --- a/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td +++ b/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td @@ -29,7 +29,6 @@ def Tracing_TraceCiphertextOp : Tracing_Op<"trace_ciphertext"> { FHE_EncryptedIntegerType.predicate, FHE_EncryptedSignedIntegerType.predicate, TFHE_GLWECipherTextType.predicate, - Concrete_LweCiphertextType.predicate, 1DTensorOf<[I64]>.predicate, MemRefRankOf<[I64], [1]>.predicate ]>>: $ciphertext, diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index 3af76dd52..71298ae29 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -62,7 +62,7 @@ struct CompilationOptions { bool emitSDFGOps; bool unrollLoopsWithSDFGConvertibleOps; bool dataflowParallelize; - bool optimizeConcrete; + bool optimizeTFHE; /// use GPU during execution by generating GPU operations if possible bool emitGPUOps; llvm::Optional> fhelinalgTileSizes; @@ -82,7 +82,7 @@ struct CompilationOptions { : v0FHEConstraints(llvm::None), verifyDiagnostics(false), autoParallelize(false), loopParallelize(false), batchConcreteOps(false), emitSDFGOps(false), unrollLoopsWithSDFGConvertibleOps(false), - dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false), + dataflowParallelize(false), optimizeTFHE(true), emitGPUOps(false), clientParametersFuncName(llvm::None), optimizerConfig(optimizer::DEFAULT_CONFIG), chunkIntegers(false), chunkSize(4), chunkWidth(2){}; @@ -212,12 +212,8 @@ public: /// operations CONCRETE, - /// Read sources and lower all FHE, TFHE and Concrete operations to - /// BConcrete operations - BCONCRETE, - - /// Read sources and lower all FHE, TFHE and Concrete operations to - /// BConcrete, then extract SDFG operations + /// Read sources and lower all FHE and TFHE operations to Concrete + /// then extract SDFG operations SDFG, /// Read sources and lower all FHE, TFHE and Concrete diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index 9152f75f0..141a8b9c8 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -68,14 +68,9 @@ lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, bool parallelizeLoops, bool batchOperations); -mlir::LogicalResult -lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass, - bool parallelizeLoops); - -mlir::LogicalResult -optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass); +mlir::LogicalResult optimizeTFHE(mlir::MLIRContext &context, + mlir::ModuleOp &module, + std::function enablePass); mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context, mlir::ModuleOp &module, @@ -83,8 +78,8 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context, bool unrollLoops); mlir::LogicalResult -lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass); +lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); mlir::LogicalResult lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index e6e4e2b10..3f3ba8fcf 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -53,10 +53,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](CompilationOptions &options, bool b) { options.dataflowParallelize = b; }) - .def("set_optimize_concrete", - [](CompilationOptions &options, bool b) { - options.optimizeConcrete = b; - }) + .def("set_optimize_concrete", [](CompilationOptions &options, + bool b) { options.optimizeTFHE = b; }) .def("set_p_error", [](CompilationOptions &options, double p_error) { options.optimizerConfig.p_error = p_error; diff --git a/compiler/lib/Bindings/Rust/build.rs b/compiler/lib/Bindings/Rust/build.rs index 1ad9114be..1c6cc848c 100644 --- a/compiler/lib/Bindings/Rust/build.rs +++ b/compiler/lib/Bindings/Rust/build.rs @@ -241,20 +241,19 @@ const LLVM_STATIC_LIBS: [&str; 51] = [ "LLVMX86Info", ]; -const CONCRETE_COMPILER_LIBS: [&str; 35] = [ +const CONCRETE_COMPILER_LIBS: [&str; 33] = [ "RTDialect", "RTDialectTransforms", "ConcretelangSupport", - "BConcreteToCAPI", + "ConcreteToCAPI", "ConcretelangConversion", "ConcretelangTransforms", "FHETensorOpsToLinalg", "ConcretelangServerLib", - "ConcreteToBConcrete", "CONCRETELANGCAPIFHE", "TFHEGlobalParametrization", "ConcretelangClientLib", - "ConcretelangBConcreteTransforms", + "ConcretelangConcreteTransforms", "ConcretelangSDFGInterfaces", "ConcretelangSDFGTransforms", "CONCRETELANGCAPISupport", @@ -267,8 +266,7 @@ const CONCRETE_COMPILER_LIBS: [&str; 35] = [ "TFHEToConcrete", "FHEToTFHECrt", "FHEToTFHEScalar", - "ConcreteDialectTransforms", - "BConcreteDialect", + "TFHEDialectTransforms", "concrete_optimizer", "LinalgExtras", "FHEDialectAnalysis", diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index cd59eee8e..e853ce5a2 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -65,7 +65,7 @@ CompilationOptions compilationOptionsCreate(MlirStringRef funcName, bool autoParallelize, bool batchConcreteOps, bool dataflowParallelize, bool emitGPUOps, bool loopParallelize, - bool optimizeConcrete, OptimizerConfig optimizerConfig, + bool optimizeTFHE, OptimizerConfig optimizerConfig, bool verifyDiagnostics) { std::string funcNameStr(funcName.data, funcName.length); auto options = new mlir::concretelang::CompilationOptions(funcNameStr); @@ -74,7 +74,7 @@ compilationOptionsCreate(MlirStringRef funcName, bool autoParallelize, options->dataflowParallelize = dataflowParallelize; options->emitGPUOps = emitGPUOps; options->loopParallelize = loopParallelize; - options->optimizeConcrete = optimizeConcrete; + options->optimizeTFHE = optimizeTFHE; options->optimizerConfig = *unwrap(optimizerConfig); options->verifyDiagnostics = verifyDiagnostics; return wrap(options); @@ -133,8 +133,6 @@ llvm::Expected -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "concretelang/Conversion/Passes.h" -#include "concretelang/Conversion/Utils/FuncConstOpConversion.h" -#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h" -#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" -#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" -#include "concretelang/Dialect/RT/IR/RTOps.h" -#include "concretelang/Dialect/Tracing/IR/TracingOps.h" - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/IR/Function.h" - -namespace Concrete = ::mlir::concretelang::Concrete; -namespace BConcrete = ::mlir::concretelang::BConcrete; -namespace Tracing = ::mlir::concretelang::Tracing; - -namespace { -struct ConcreteToBConcretePass - : public ConcreteToBConcreteBase { - void runOnOperation() final; -}; -} // namespace - -/// ConcreteToBConcreteTypeConverter is a TypeConverter that transform -/// `Concrete.lwe_ciphertext` to `tensor>` -/// `tensor<...xConcrete.lwe_ciphertext>` to -/// `tensor<...xdimension+1, i64>>` -class ConcreteToBConcreteTypeConverter : public mlir::TypeConverter { - -public: - ConcreteToBConcreteTypeConverter() { - addConversion([](mlir::Type type) { return type; }); - addConversion([&](mlir::concretelang::Concrete::PlaintextType type) { - return mlir::IntegerType::get(type.getContext(), 64); - }); - addConversion([&](mlir::concretelang::Concrete::CleartextType type) { - return mlir::IntegerType::get(type.getContext(), 64); - }); - addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) { - assert(type.getDimension() != -1); - llvm::SmallVector shape; - shape.push_back(type.getDimension() + 1); - return mlir::RankedTensorType::get( - shape, mlir::IntegerType::get(type.getContext(), 64)); - }); - addConversion([&](mlir::concretelang::Concrete::GlweCiphertextType type) { - assert(type.getGlweDimension() != -1); - assert(type.getPolynomialSize() != -1); - - return mlir::RankedTensorType::get( - {type.getPolynomialSize() * (type.getGlweDimension() + 1)}, - mlir::IntegerType::get(type.getContext(), 64)); - }); - addConversion([&](mlir::RankedTensorType type) { - auto lwe = type.getElementType() - .dyn_cast_or_null< - mlir::concretelang::Concrete::LweCiphertextType>(); - if (lwe == nullptr) { - return (mlir::Type)(type); - } - assert(lwe.getDimension() != -1); - mlir::SmallVector newShape; - newShape.reserve(type.getShape().size() + 1); - newShape.append(type.getShape().begin(), type.getShape().end()); - newShape.push_back(lwe.getDimension() + 1); - mlir::Type r = mlir::RankedTensorType::get( - newShape, mlir::IntegerType::get(type.getContext(), 64)); - return r; - }); - addConversion([&](mlir::concretelang::RT::FutureType type) { - return mlir::concretelang::RT::FutureType::get( - this->convertType(type.dyn_cast() - .getElementType())); - }); - addConversion([&](mlir::concretelang::RT::PointerType type) { - return mlir::concretelang::RT::PointerType::get( - this->convertType(type.dyn_cast() - .getElementType())); - }); - } -}; - -template -struct ZeroOpPattern : public mlir::OpRewritePattern { - ZeroOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern(context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(ZeroOp zeroOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - auto resultTy = zeroOp.getType(); - auto newResultTy = converter.convertType(resultTy); - - auto generateBody = [&](mlir::OpBuilder &nestedBuilder, - mlir::Location nestedLoc, - mlir::ValueRange blockArgs) { - // %c0 = 0 : i64 - auto cstOp = nestedBuilder.create( - nestedLoc, nestedBuilder.getI64IntegerAttr(0)); - // tensor.yield %z : !FHE.eint

- nestedBuilder.create(nestedLoc, cstOp.getResult()); - }; - // tensor.generate - rewriter.replaceOpWithNewOp( - zeroOp, newResultTy, mlir::ValueRange{}, generateBody); - - return ::mlir::success(); - }; -}; - -template -struct LowToBConcrete : public mlir::OpRewritePattern { - LowToBConcrete(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern(context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(ConcreteOp concreteOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - mlir::TypeRange resultTyRange = concreteOp->getResultTypes(); - - llvm::ArrayRef<::mlir::NamedAttribute> attributes = - concreteOp.getOperation()->getAttrs(); - - mlir::Operation *bConcreteOp; - bConcreteOp = rewriter.replaceOpWithNewOp( - concreteOp, resultTyRange, concreteOp.getOperation()->getOperands(), - attributes); - - mlir::concretelang::convertOperandAndResultTypes( - rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - return ::mlir::success(); - }; -}; - -struct LowerKeySwitch : public mlir::OpRewritePattern< - mlir::concretelang::Concrete::KeySwitchLweOp> { - LowerKeySwitch(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern( - context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::KeySwitchLweOp ksOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - - // construct attributes for in/out dimensions - mlir::concretelang::Concrete::LweCiphertextType outType = ksOp.getType(); - auto outDimAttr = rewriter.getI32IntegerAttr(outType.getDimension()); - auto inputType = converter.convertType(ksOp.ciphertext().getType()) - .cast(); - auto inputDimension = inputType.getShape().back() - 1; - mlir::IntegerAttr inputDimAttr = rewriter.getI32IntegerAttr(inputDimension); - - mlir::Operation *bKeySwitchOp = rewriter.replaceOpWithNewOp< - mlir::concretelang::BConcrete::KeySwitchLweTensorOp>( - ksOp, outType, ksOp.ciphertext(), ksOp.levelAttr(), ksOp.baseLogAttr(), - inputDimAttr, outDimAttr); - - mlir::concretelang::convertOperandAndResultTypes( - rewriter, bKeySwitchOp, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - return ::mlir::success(); - }; -}; - -struct LowerBatchedKeySwitch - : public mlir::OpRewritePattern< - mlir::concretelang::Concrete::BatchedKeySwitchLweOp> { - LowerBatchedKeySwitch(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern< - mlir::concretelang::Concrete::BatchedKeySwitchLweOp>(context, - benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::BatchedKeySwitchLweOp bksOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - - mlir::concretelang::Concrete::LweCiphertextType outType = - bksOp.getType() - .cast() - .getElementType() - .cast(); - - auto outDimAttr = rewriter.getI32IntegerAttr(outType.getDimension()); - auto inputType = - bksOp.ciphertexts() - .getType() - .cast() - .getElementType() - .cast(); - - mlir::IntegerAttr inputDimAttr = - rewriter.getI32IntegerAttr(inputType.getDimension()); - - mlir::Operation *bBatchedKeySwitchOp = rewriter.replaceOpWithNewOp< - mlir::concretelang::BConcrete::BatchedKeySwitchLweTensorOp>( - bksOp, bksOp.getType(), bksOp.ciphertexts(), bksOp.levelAttr(), - bksOp.baseLogAttr(), inputDimAttr, outDimAttr); - - mlir::concretelang::convertOperandAndResultTypes( - rewriter, bBatchedKeySwitchOp, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - return ::mlir::success(); - }; -}; - -struct LowerBootstrap : public mlir::OpRewritePattern< - mlir::concretelang::Concrete::BootstrapLweOp> { - LowerBootstrap(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern( - context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::BootstrapLweOp bsOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - - mlir::concretelang::Concrete::LweCiphertextType outType = bsOp.getType(); - auto inputType = converter.convertType(bsOp.input_ciphertext().getType()) - .cast(); - auto inputDimension = inputType.getShape().back() - 1; - mlir::IntegerAttr inputDimAttr = rewriter.getI32IntegerAttr(inputDimension); - - auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP()); - mlir::Operation *bBootstrapOp = rewriter.replaceOpWithNewOp< - mlir::concretelang::BConcrete::BootstrapLweTensorOp>( - bsOp, outType, bsOp.input_ciphertext(), bsOp.lookup_table(), - inputDimAttr, bsOp.polySizeAttr(), bsOp.levelAttr(), bsOp.baseLogAttr(), - bsOp.glweDimensionAttr(), outputPrecisionAttr); - - mlir::concretelang::convertOperandAndResultTypes( - rewriter, bBootstrapOp, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - return ::mlir::success(); - }; -}; - -struct LowerBatchedBootstrap - : public mlir::OpRewritePattern< - mlir::concretelang::Concrete::BatchedBootstrapLweOp> { - LowerBatchedBootstrap(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern< - mlir::concretelang::Concrete::BatchedBootstrapLweOp>(context, - benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::BatchedBootstrapLweOp bbsOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - - mlir::concretelang::Concrete::LweCiphertextType outType = - bbsOp.getType() - .cast() - .getElementType() - .cast(); - - auto inputType = - bbsOp.input_ciphertexts() - .getType() - .cast() - .getElementType() - .cast(); - - auto inputDimAttr = rewriter.getI32IntegerAttr(inputType.getDimension()); - auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP()); - - mlir::Operation *bBatchedBootstrapOp = rewriter.replaceOpWithNewOp< - mlir::concretelang::BConcrete::BatchedBootstrapLweTensorOp>( - bbsOp, bbsOp.getType(), bbsOp.input_ciphertexts(), bbsOp.lookup_table(), - inputDimAttr, bbsOp.polySizeAttr(), bbsOp.levelAttr(), - bbsOp.baseLogAttr(), bbsOp.glweDimensionAttr(), outputPrecisionAttr); - - mlir::concretelang::convertOperandAndResultTypes( - rewriter, bBatchedBootstrapOp, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - return ::mlir::success(); - }; -}; - -struct AddPlaintextLweCiphertextOpPattern - : public mlir::OpRewritePattern { - AddPlaintextLweCiphertextOpPattern(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern( - context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(Concrete::AddPlaintextLweCiphertextOp concreteOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - mlir::concretelang::Concrete::LweCiphertextType resultTy = - ((mlir::Type)concreteOp->getResult(0).getType()) - .cast(); - auto newResultTy = - converter.convertType(resultTy).cast(); - - llvm::ArrayRef<::mlir::NamedAttribute> attributes = - concreteOp.getOperation()->getAttrs(); - - mlir::Operation *bConcreteOp; - bConcreteOp = - rewriter.replaceOpWithNewOp( - concreteOp, newResultTy, - mlir::ValueRange{concreteOp.lhs(), concreteOp.rhs()}, attributes); - - mlir::concretelang::convertOperandAndResultTypes( - rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - return ::mlir::success(); - }; -}; - -struct MulCleartextLweCiphertextOpPattern - : public mlir::OpRewritePattern { - MulCleartextLweCiphertextOpPattern(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern( - context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(Concrete::MulCleartextLweCiphertextOp concreteOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - mlir::concretelang::Concrete::LweCiphertextType resultTy = - ((mlir::Type)concreteOp->getResult(0).getType()) - .cast(); - auto newResultTy = - converter.convertType(resultTy).cast(); - - llvm::ArrayRef<::mlir::NamedAttribute> attributes = - concreteOp.getOperation()->getAttrs(); - - mlir::Operation *bConcreteOp; - bConcreteOp = - rewriter.replaceOpWithNewOp( - concreteOp, newResultTy, - mlir::ValueRange{concreteOp.lhs(), concreteOp.rhs()}, attributes); - - mlir::concretelang::convertOperandAndResultTypes( - rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - return ::mlir::success(); - }; -}; - -struct ExtractSliceOpPattern - : public mlir::OpRewritePattern { - ExtractSliceOpPattern(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern(context, - benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - auto resultTy = extractSliceOp.result().getType(); - auto newResultTy = - converter.convertType(resultTy).cast(); - - // add 0 to the static_offsets - mlir::SmallVector staticOffsets; - staticOffsets.append(extractSliceOp.static_offsets().begin(), - extractSliceOp.static_offsets().end()); - staticOffsets.push_back(rewriter.getI64IntegerAttr(0)); - - // add the lweSize to the sizes - mlir::SmallVector staticSizes; - staticSizes.append(extractSliceOp.static_sizes().begin(), - extractSliceOp.static_sizes().end()); - staticSizes.push_back(rewriter.getI64IntegerAttr( - newResultTy.getDimSize(newResultTy.getRank() - 1))); - - // add 1 to the strides - mlir::SmallVector staticStrides; - staticStrides.append(extractSliceOp.static_strides().begin(), - extractSliceOp.static_strides().end()); - staticStrides.push_back(rewriter.getI64IntegerAttr(1)); - - // replace tensor.extract_slice to the new one - mlir::tensor::ExtractSliceOp extractOp = - rewriter.replaceOpWithNewOp( - extractSliceOp, newResultTy, extractSliceOp.source(), - extractSliceOp.offsets(), extractSliceOp.sizes(), - extractSliceOp.strides(), rewriter.getArrayAttr(staticOffsets), - rewriter.getArrayAttr(staticSizes), - rewriter.getArrayAttr(staticStrides)); - - mlir::concretelang::convertOperandAndResultTypes( - rewriter, extractOp, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - return ::mlir::success(); - }; -}; - -// TODO: since they are a bug on lowering extract_slice with rank reduction we -// add a linalg.tensor_collapse_shape after the extract_slice without rank -// reduction. See -// https://github.com/zama-ai/concrete-compiler-internal/issues/396. -struct ExtractOpPattern - : public mlir::OpRewritePattern { - ExtractOpPattern(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern(context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::tensor::ExtractOp extractOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - auto lweResultTy = - extractOp.result() - .getType() - .dyn_cast_or_null< - mlir::concretelang::Concrete::LweCiphertextType>(); - if (lweResultTy == nullptr) { - return mlir::failure(); - } - auto newResultTy = - converter.convertType(lweResultTy).cast(); - auto rankOfResult = extractOp.indices().size() + 1; - - // [min..., 0] for static_offsets () - mlir::SmallVector staticOffsets( - rankOfResult, - rewriter.getI64IntegerAttr(std::numeric_limits::min())); - staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0); - - // [1..., lweDimension+1] for static_sizes or - // [1..., nbBlock, lweDimension+1] - mlir::SmallVector staticSizes( - rankOfResult, rewriter.getI64IntegerAttr(1)); - staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr( - newResultTy.getDimSize(newResultTy.getRank() - 1)); - - // [1...] for static_strides - mlir::SmallVector staticStrides( - rankOfResult, rewriter.getI64IntegerAttr(1)); - - // replace tensor.extract_slice to the new one - mlir::SmallVector extractedSliceShape(rankOfResult, 1); - extractedSliceShape[extractedSliceShape.size() - 1] = - newResultTy.getDimSize(0); - - auto extractedSliceType = - mlir::RankedTensorType::get(extractedSliceShape, rewriter.getI64Type()); - - auto extractedSlice = rewriter.create( - extractOp.getLoc(), extractedSliceType, extractOp.tensor(), - extractOp.indices(), mlir::SmallVector{}, - mlir::SmallVector{}, rewriter.getArrayAttr(staticOffsets), - rewriter.getArrayAttr(staticSizes), - rewriter.getArrayAttr(staticStrides)); - mlir::concretelang::convertOperandAndResultTypes( - rewriter, extractedSlice, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - mlir::ReassociationIndices reassociation; - for (int64_t i = 0; i < extractedSliceType.getRank(); i++) { - reassociation.push_back(i); - } - - mlir::SmallVector reassocs{reassociation}; - - mlir::tensor::CollapseShapeOp collapseOp = - rewriter.replaceOpWithNewOp( - extractOp, newResultTy, extractedSlice, reassocs); - - mlir::concretelang::convertOperandAndResultTypes( - rewriter, collapseOp, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - return ::mlir::success(); - }; -}; - -struct InsertSliceOpPattern - : public mlir::OpRewritePattern { - InsertSliceOpPattern(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern(context, - benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::tensor::InsertSliceOp insertSliceOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - auto resultTy = insertSliceOp.result().getType(); - auto lweResultTy = - resultTy.cast() - .getElementType() - .cast(); - if (lweResultTy == nullptr) { - return mlir::failure(); - } - auto newResultTy = - converter.convertType(resultTy).cast(); - - // add 0 to static_offsets - mlir::SmallVector staticOffsets; - staticOffsets.append(insertSliceOp.static_offsets().begin(), - insertSliceOp.static_offsets().end()); - staticOffsets.push_back(rewriter.getI64IntegerAttr(0)); - - // add lweDimension+1 to static_sizes - mlir::SmallVector staticSizes; - staticSizes.append(insertSliceOp.static_sizes().begin(), - insertSliceOp.static_sizes().end()); - staticSizes.push_back(rewriter.getI64IntegerAttr( - newResultTy.getDimSize(newResultTy.getRank() - 1))); - - // add 1 to the strides - mlir::SmallVector staticStrides; - staticStrides.append(insertSliceOp.static_strides().begin(), - insertSliceOp.static_strides().end()); - staticStrides.push_back(rewriter.getI64IntegerAttr(1)); - - // replace tensor.insert_slice with the new one - auto newOp = rewriter.replaceOpWithNewOp( - insertSliceOp, newResultTy, insertSliceOp.source(), - insertSliceOp.dest(), insertSliceOp.offsets(), insertSliceOp.sizes(), - insertSliceOp.strides(), rewriter.getArrayAttr(staticOffsets), - rewriter.getArrayAttr(staticSizes), - rewriter.getArrayAttr(staticStrides)); - - mlir::concretelang::convertOperandAndResultTypes( - rewriter, newOp, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - return ::mlir::success(); - }; -}; - -struct InsertOpPattern : public mlir::OpRewritePattern { - InsertOpPattern(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern(context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::tensor::InsertOp insertOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - auto resultTy = - insertOp.result().getType().dyn_cast_or_null(); - auto lweResultTy = resultTy.getElementType() - .dyn_cast_or_null(); - if (lweResultTy == nullptr) { - return mlir::failure(); - }; - mlir::RankedTensorType newResultTy = - converter.convertType(resultTy).cast(); - - // add zeros to static_offsets - mlir::SmallVector offsets; - offsets.append(insertOp.indices().begin(), insertOp.indices().end()); - offsets.push_back(rewriter.getIndexAttr(0)); - - // Inserting a smaller tensor into a (potentially) bigger one. Set - // dimensions for all leading dimensions of the target tensor not - // present in the source to 1. - mlir::SmallVector sizes(insertOp.indices().size(), - rewriter.getI64IntegerAttr(1)); - - // Add size for the bufferized source element - sizes.push_back(rewriter.getI64IntegerAttr( - newResultTy.getDimSize(newResultTy.getRank() - 1))); - - // Set stride of all dimensions to 1 - mlir::SmallVector strides( - newResultTy.getRank(), rewriter.getI64IntegerAttr(1)); - - // replace tensor.insert_slice with the new one - mlir::tensor::InsertSliceOp insertSliceOp = - rewriter.replaceOpWithNewOp( - insertOp, insertOp.getOperand(0), insertOp.dest(), offsets, sizes, - strides); - - mlir::concretelang::convertOperandAndResultTypes( - rewriter, insertSliceOp, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - return ::mlir::success(); - }; -}; - -/// FromElementsOpPatterns transform each tensor.from_elements that operates on -/// Concrete.lwe_ciphertext -/// -/// refs: check_tests/Conversion/ConcreteToBConcrete/tensor_from_elements.mlir -struct FromElementsOpPattern - : public mlir::OpRewritePattern { - FromElementsOpPattern(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern(context, - benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::tensor::FromElementsOp fromElementsOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - - auto resultTy = fromElementsOp.result().getType(); - if (converter.isLegal(resultTy)) { - return mlir::failure(); - } - auto oldTensorResultTy = resultTy.cast(); - auto oldRank = oldTensorResultTy.getRank(); - - auto newTensorResultTy = - converter.convertType(resultTy).cast(); - auto newRank = newTensorResultTy.getRank(); - auto newShape = newTensorResultTy.getShape(); - - mlir::Value tensor = rewriter.create( - fromElementsOp.getLoc(), newTensorResultTy, mlir::ValueRange{}); - - // sizes are [1, ..., 1, diffShape...] - llvm::SmallVector sizes(oldRank, - rewriter.getI64IntegerAttr(1)); - for (auto i = newRank - oldRank; i > 0; i--) { - sizes.push_back(rewriter.getI64IntegerAttr(*(newShape.end() - i))); - } - - // strides are [1, ..., 1] - llvm::SmallVector oneStrides( - newShape.size(), rewriter.getI64IntegerAttr(1)); - - // start with offets [0, ..., 0] - llvm::SmallVector currentOffsets(newRank, 0); - - // for each elements insert_slice with right offet - for (auto elt : llvm::enumerate(fromElementsOp.elements())) { - // Just create offsets as attributes - llvm::SmallVector offsets; - offsets.reserve(currentOffsets.size()); - std::transform(currentOffsets.begin(), currentOffsets.end(), - std::back_inserter(offsets), - [&](auto v) { return rewriter.getI64IntegerAttr(v); }); - mlir::tensor::InsertSliceOp insOp = - rewriter.create( - fromElementsOp.getLoc(), - /* src: */ elt.value(), - /* dst: */ tensor, - /* offs: */ offsets, - /* sizes: */ sizes, - /* strides: */ oneStrides); - - mlir::concretelang::convertOperandAndResultTypes( - rewriter, insOp, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - tensor = insOp.getResult(); - - // Increment the offsets - for (auto i = newRank - 2; i >= 0; i--) { - if (currentOffsets[i] == newShape[i] - 1) { - currentOffsets[i] = 0; - continue; - } - currentOffsets[i]++; - break; - } - } - - rewriter.replaceOp(fromElementsOp, tensor); - return ::mlir::success(); - }; -}; - -// This template rewrite pattern transforms any instance of -// `ShapeOp` operators that operates on tensor of lwe ciphertext by adding the -// lwe size as a size of the tensor result and by adding a trivial -// reassociation at the end of the reassociations map. -// -// Example: -// -// ```mlir -// %0 = "ShapeOp" %arg0 [reassocations...] -// : tensor<...x!Concrete.lwe_ciphertext> into -// tensor<...x!Concrete.lwe_ciphertext> -// ``` -// -// becomes: -// -// ```mlir -// %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]] -// : tensor<...xlweDimesion+1xi64> into -// tensor<...xlweDimesion+1xi64> -// ``` -template -struct TensorShapeOpPattern : public mlir::OpRewritePattern { - TensorShapeOpPattern(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern(context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(ShapeOp shapeOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - auto resultTy = ((mlir::Type)shapeOp.result().getType()).cast(); - - auto newResultTy = - ((mlir::Type)converter.convertType(resultTy)).cast(); - - auto reassocTy = - ((mlir::Type)converter.convertType( - (inRank ? shapeOp.src() : shapeOp.result()).getType())) - .cast(); - - auto oldReassocs = shapeOp.getReassociationIndices(); - mlir::SmallVector newReassocs; - newReassocs.append(oldReassocs.begin(), oldReassocs.end()); - - // add [rank] to reassociations - { - mlir::ReassociationIndices lweAssoc; - lweAssoc.push_back(reassocTy.getRank() - 1); - newReassocs.push_back(lweAssoc); - } - - ShapeOp op = rewriter.replaceOpWithNewOp( - shapeOp, newResultTy, shapeOp.src(), newReassocs); - - // fix operand types - mlir::concretelang::convertOperandAndResultTypes( - rewriter, op, [&](mlir::MLIRContext *, mlir::Type t) { - return converter.convertType(t); - }); - - return ::mlir::success(); - }; -}; - -/// Add the instantiated TensorShapeOpPattern rewrite pattern with the `ShapeOp` -/// to the patterns set and populate the conversion target. -template -void insertTensorShapeOpPattern(mlir::MLIRContext &context, - mlir::RewritePatternSet &patterns, - mlir::ConversionTarget &target) { - patterns.insert>(&context); - target.addDynamicallyLegalOp([&](mlir::Operation *op) { - ConcreteToBConcreteTypeConverter converter; - return converter.isLegal(op->getResultTypes()) && - converter.isLegal(op->getOperandTypes()); - }); -} - -struct AllocTensorOpPattern - : public mlir::OpRewritePattern { - AllocTensorOpPattern(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern(context, - benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::bufferization::AllocTensorOp allocTensorOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - mlir::RankedTensorType resultTy = - allocTensorOp.getType().dyn_cast(); - - if (!resultTy || !resultTy.hasStaticShape()) - return mlir::failure(); - - mlir::RankedTensorType newResultTy = - converter.convertType(resultTy).dyn_cast(); - - if (resultTy.getShape().size() != newResultTy.getShape().size()) { - rewriter.replaceOpWithNewOp( - allocTensorOp, newResultTy, mlir::ValueRange{}); - } - - return ::mlir::success(); - }; -}; - -struct ForOpPattern : public mlir::OpRewritePattern { - ForOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern(context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::scf::ForOp forOp, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - - // TODO: Check if there is a cleaner way to modify the types in - // place through appropriate interfaces or by reconstructing the - // ForOp with the right types. - rewriter.updateRootInPlace(forOp, [&] { - for (mlir::Value initArg : forOp.getInitArgs()) { - mlir::Type convertedType = converter.convertType(initArg.getType()); - initArg.setType(convertedType); - } - - for (mlir::Value &blockArg : forOp.getBody()->getArguments()) { - mlir::Type convertedType = converter.convertType(blockArg.getType()); - blockArg.setType(convertedType); - } - - for (mlir::OpResult result : forOp.getResults()) { - mlir::Type convertedType = converter.convertType(result.getType()); - result.setType(convertedType); - } - }); - - return ::mlir::success(); - }; -}; - -void ConcreteToBConcretePass::runOnOperation() { - auto op = this->getOperation(); - - // Then convert ciphertext to tensor or add a dimension to tensor of - // ciphertext and memref of ciphertext - { - mlir::ConversionTarget target(getContext()); - ConcreteToBConcreteTypeConverter converter; - mlir::RewritePatternSet patterns(&getContext()); - - // All BConcrete ops are legal after the conversion - target.addLegalDialect(); - - // Add Concrete ops are illegal after the conversion - target.addIllegalDialect(); - - target.addLegalDialect(); - - // Add patterns to convert the zero ops to tensor.generate - patterns - .insert, - ZeroOpPattern>( - &getContext()); - target.addLegalOp(); - - // Add patterns to trivialy convert Concrete op to the equivalent - // BConcrete op - patterns.insert< - LowerBootstrap, LowerBatchedBootstrap, LowerKeySwitch, - LowerBatchedKeySwitch, - LowToBConcrete, - AddPlaintextLweCiphertextOpPattern, MulCleartextLweCiphertextOpPattern, - LowToBConcrete< - mlir::concretelang::Concrete::EncodeExpandLutForBootstrapOp, - mlir::concretelang::BConcrete::EncodeExpandLutForBootstrapTensorOp>, - LowToBConcrete< - mlir::concretelang::Concrete::EncodeExpandLutForWopPBSOp, - mlir::concretelang::BConcrete::EncodeExpandLutForWopPBSTensorOp>, - LowToBConcrete< - mlir::concretelang::Concrete::EncodePlaintextWithCrtOp, - mlir::concretelang::BConcrete::EncodePlaintextWithCrtTensorOp>, - LowToBConcrete, - LowToBConcrete>( - &getContext()); - - // Add patterns to rewrite tensor operators that works on encrypted - // tensors - patterns - .insert(&getContext()); - - target.addDynamicallyLegalOp( - [&](mlir::Operation *op) { - return converter.isLegal(op->getResultTypes()) && - converter.isLegal(op->getOperandTypes()); - }); - - patterns.insert(&getContext()); - - target.addDynamicallyLegalOp( - [&](mlir::Operation *op) { - return converter.isLegal(op->getResult(0).getType()); - }); - target.addLegalOp(); - - patterns.insert(&getContext()); - - // Add patterns to rewrite some of memref ops that was introduced by the - // linalg bufferization of encrypted tensor (first conversion of this - // pass) - insertTensorShapeOpPattern(getContext(), patterns, target); - insertTensorShapeOpPattern(getContext(), patterns, target); - insertTensorShapeOpPattern(getContext(), patterns, target); - insertTensorShapeOpPattern(getContext(), patterns, target); - - target.addDynamicallyLegalOp< - mlir::arith::ConstantOp, mlir::scf::ForOp, mlir::scf::ParallelOp, - mlir::scf::YieldOp, mlir::AffineApplyOp, mlir::memref::SubViewOp, - mlir::memref::LoadOp, mlir::memref::TensorStoreOp>( - [&](mlir::Operation *op) { - return converter.isLegal(op->getResultTypes()) && - converter.isLegal(op->getOperandTypes()); - }); - - // Add patterns to do the conversion of func - mlir::populateFunctionOpInterfaceTypeConversionPattern( - patterns, converter); - - target.addDynamicallyLegalOp( - [&](mlir::func::FuncOp funcOp) { - return converter.isSignatureLegal(funcOp.getFunctionType()) && - converter.isLegal(&funcOp.getBody()); - }); - target.addDynamicallyLegalOp( - [&](mlir::func::ConstantOp op) { - return FunctionConstantOpConversion< - ConcreteToBConcreteTypeConverter>::isLegal(op, converter); - }); - patterns - .insert>( - &getContext(), converter); - - target.addDynamicallyLegalOp([&](mlir::scf::ForOp forOp) { - return converter.isLegal(forOp.getInitArgs().getTypes()) && - converter.isLegal(forOp.getResults().getTypes()); - }); - - // Add pattern for return op - target.addDynamicallyLegalOp( - [&](mlir::Operation *op) { - return converter.isLegal(op->getResultTypes()) && - converter.isLegal(op->getOperandTypes()); - }); - - // Conversion of Tracing dialect - patterns.add, - mlir::concretelang::GenericTypeConverterPattern< - Tracing::TracePlaintextOp>>(&getContext(), converter); - mlir::concretelang::addDynamicallyLegalTypeOp( - target, converter); - mlir::concretelang::addDynamicallyLegalTypeOp( - target, converter); - - // Conversion of RT Dialect Ops - patterns.add< - mlir::concretelang::GenericTypeConverterPattern, - mlir::concretelang::GenericTypeConverterPattern, - mlir::concretelang::GenericTypeConverterPattern< - mlir::concretelang::RT::MakeReadyFutureOp>, - mlir::concretelang::GenericTypeConverterPattern< - mlir::concretelang::RT::AwaitFutureOp>, - mlir::concretelang::GenericTypeConverterPattern< - mlir::concretelang::RT::CreateAsyncTaskOp>, - mlir::concretelang::GenericTypeConverterPattern< - mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>, - mlir::concretelang::GenericTypeConverterPattern< - mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, - mlir::concretelang::GenericTypeConverterPattern< - mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>, - mlir::concretelang::GenericTypeConverterPattern< - mlir::concretelang::RT::WorkFunctionReturnOp>, - mlir::concretelang::GenericTypeConverterPattern< - mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), - converter); - mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::MakeReadyFutureOp>(target, converter); - mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::AwaitFutureOp>(target, converter); - mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter); - mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter); - mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>( - target, converter); - mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter); - mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter); - mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter); - - // Apply conversion - if (mlir::applyPartialConversion(op, target, std::move(patterns)) - .failed()) { - this->signalPassFailure(); - } - } -} - -namespace mlir { -namespace concretelang { -std::unique_ptr> -createConvertConcreteToBConcretePass() { - return std::make_unique(); -} -} // namespace concretelang -} // namespace mlir diff --git a/compiler/lib/Conversion/ConcreteToCAPI/CMakeLists.txt b/compiler/lib/Conversion/ConcreteToCAPI/CMakeLists.txt new file mode 100644 index 000000000..f85c4a1a6 --- /dev/null +++ b/compiler/lib/Conversion/ConcreteToCAPI/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library( + ConcreteToCAPI + ConcreteToCAPI.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete + DEPENDS + ConcreteDialect + mlir-headers + LINK_LIBS + PUBLIC + MLIRIR + MLIRTransforms) + +target_link_libraries(ConcreteToCAPI PUBLIC ConcreteDialect MLIRIR) diff --git a/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp b/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp similarity index 81% rename from compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp rename to compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp index de1729668..be9fc9aaf 100644 --- a/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp +++ b/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp @@ -8,12 +8,13 @@ #include "concretelang/Conversion/Passes.h" #include "concretelang/Conversion/Tools.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" +#include "concretelang/Dialect/RT/IR/RTOps.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" namespace { -namespace BConcrete = mlir::concretelang::BConcrete; +namespace Concrete = mlir::concretelang::Concrete; namespace arith = mlir::arith; namespace func = mlir::func; namespace memref = mlir::memref; @@ -200,23 +201,23 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( return insertForwardDeclaration(op, rewriter, funcName, funcType); } -template -void addNoOperands(BConcreteOp op, mlir::SmallVector &operands, +template +void addNoOperands(ConcreteOp op, mlir::SmallVector &operands, mlir::RewriterBase &rewriter) {} -template -struct BConcreteToCAPICallPattern : public mlir::OpRewritePattern { - BConcreteToCAPICallPattern( +template +struct ConcreteToCAPICallPattern : public mlir::OpRewritePattern { + ConcreteToCAPICallPattern( ::mlir::MLIRContext *context, - std::function &, + std::function &, mlir::RewriterBase &)> - addOperands = addNoOperands, + addOperands = addNoOperands, mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern(context, benefit), + : ::mlir::OpRewritePattern(context, benefit), addOperands(addOperands) {} ::mlir::LogicalResult - matchAndRewrite(BConcreteOp bOp, + matchAndRewrite(ConcreteOp bOp, ::mlir::PatternRewriter &rewriter) const override { // Create the operands @@ -246,7 +247,7 @@ struct BConcreteToCAPICallPattern : public mlir::OpRewritePattern { }; private: - std::function &, + std::function &, mlir::RewriterBase &)> addOperands; }; @@ -297,7 +298,7 @@ void bootstrapAddOperands(BootstrapOp op, operands.push_back(getContextArgument(op)); } -void wopPBSAddOperands(BConcrete::WopPBSCRTLweBufferOp op, +void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op, mlir::SmallVector &operands, mlir::RewriterBase &rewriter) { mlir::Type crtType = mlir::RankedTensorType::get( @@ -333,7 +334,7 @@ void wopPBSAddOperands(BConcrete::WopPBSCRTLweBufferOp op, } void encodePlaintextWithCrtAddOperands( - BConcrete::EncodePlaintextWithCrtBufferOp op, + Concrete::EncodePlaintextWithCrtBufferOp op, mlir::SmallVector &operands, mlir::RewriterBase &rewriter) { // mods mlir::Type modsType = mlir::RankedTensorType::get({(int)op.modsAttr().size()}, @@ -358,7 +359,7 @@ void encodePlaintextWithCrtAddOperands( } void encodeExpandLutForBootstrapAddOperands( - BConcrete::EncodeExpandLutForBootstrapBufferOp op, + Concrete::EncodeExpandLutForBootstrapBufferOp op, mlir::SmallVector &operands, mlir::RewriterBase &rewriter) { // poly_size operands.push_back( @@ -372,7 +373,7 @@ void encodeExpandLutForBootstrapAddOperands( } void encodeExpandLutForWopPBSAddOperands( - BConcrete::EncodeExpandLutForWopPBSBufferOp op, + Concrete::EncodeExpandLutForWopPBSBufferOp op, mlir::SmallVector &operands, mlir::RewriterBase &rewriter) { // crt_decomposition @@ -424,9 +425,9 @@ void encodeExpandLutForWopPBSAddOperands( rewriter.create(op.getLoc(), op.isSignedAttr())); } -struct BConcreteToCAPIPass : public BConcreteToCAPIBase { +struct ConcreteToCAPIPass : public ConcreteToCAPIBase { - BConcreteToCAPIPass(bool gpu) : gpu(gpu) {} + ConcreteToCAPIPass(bool gpu) : gpu(gpu) {} void runOnOperation() override { auto op = this->getOperation(); @@ -441,73 +442,73 @@ struct BConcreteToCAPIPass : public BConcreteToCAPIBase { target.addLegalDialect(); // Make sure that no ops from `FHE` remain after the lowering - target.addIllegalDialect(); + target.addIllegalDialect(); - // Add patterns to transform BConcrete operators to CAPI call - patterns.add>( + // Add patterns to transform Concrete operators to CAPI call + patterns.add>( &getContext()); patterns.add< - BConcreteToCAPICallPattern>( + ConcreteToCAPICallPattern>( &getContext()); patterns.add< - BConcreteToCAPICallPattern>( + ConcreteToCAPICallPattern>( &getContext()); - patterns.add>( + patterns.add>( &getContext()); + patterns + .add>( + &getContext(), encodePlaintextWithCrtAddOperands); patterns.add< - BConcreteToCAPICallPattern>( - &getContext(), encodePlaintextWithCrtAddOperands); - patterns.add>( + ConcreteToCAPICallPattern>( &getContext(), encodeExpandLutForBootstrapAddOperands); patterns.add< - BConcreteToCAPICallPattern>( + ConcreteToCAPICallPattern>( &getContext(), encodeExpandLutForWopPBSAddOperands); if (gpu) { - patterns.add>( - &getContext(), keyswitchAddOperands); - patterns.add>( - &getContext(), bootstrapAddOperands); + patterns.add>( + &getContext(), keyswitchAddOperands); + patterns.add>( + &getContext(), bootstrapAddOperands); patterns.add< - BConcreteToCAPICallPattern>( + ConcreteToCAPICallPattern>( &getContext(), - keyswitchAddOperands); + keyswitchAddOperands); patterns.add< - BConcreteToCAPICallPattern>( + ConcreteToCAPICallPattern>( &getContext(), - bootstrapAddOperands); + bootstrapAddOperands); } else { - patterns.add>( - &getContext(), keyswitchAddOperands); - patterns.add>( - &getContext(), bootstrapAddOperands); - patterns.add< - BConcreteToCAPICallPattern>( - &getContext(), - keyswitchAddOperands); - patterns.add< - BConcreteToCAPICallPattern>( - &getContext(), - bootstrapAddOperands); + patterns.add>( + &getContext(), keyswitchAddOperands); + patterns.add>( + &getContext(), bootstrapAddOperands); + patterns + .add>( + &getContext(), + keyswitchAddOperands); + patterns + .add>( + &getContext(), + bootstrapAddOperands); } - patterns.add>( + patterns.add>( &getContext(), wopPBSAddOperands); // Apply conversion @@ -526,8 +527,8 @@ private: namespace mlir { namespace concretelang { std::unique_ptr> -createConvertBConcreteToCAPIPass(bool gpu) { - return std::make_unique(gpu); +createConvertConcreteToCAPIPass(bool gpu) { + return std::make_unique(gpu); } } // namespace concretelang } // namespace mlir diff --git a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp index d57ae774a..1a17950a4 100644 --- a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp +++ b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp @@ -111,14 +111,6 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() { mlir::LowerToLLVMOptions options(&getContext()); mlir::LLVMTypeConverter typeConverter(&getContext(), options); typeConverter.addConversion(convertTypes); - typeConverter.addConversion( - [&](mlir::concretelang::Concrete::PlaintextType type) { - return mlir::IntegerType::get(type.getContext(), 64); - }); - typeConverter.addConversion( - [&](mlir::concretelang::Concrete::CleartextType type) { - return mlir::IntegerType::get(type.getContext(), 64); - }); // Setup the set of the patterns rewriter. At this point we want to // convert the `scf` operations to `std` and `std` operations to `llvm`. @@ -153,9 +145,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() { llvm::Optional MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) { - if (type.isa() || - type.isa() || - type.isa() || + if (type.isa() || type.isa() || type.isa() || type.isa()) { @@ -166,14 +156,6 @@ MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) { mlir::LowerToLLVMOptions options(type.getContext()); mlir::LLVMTypeConverter typeConverter(type.getContext(), options); typeConverter.addConversion(convertTypes); - typeConverter.addConversion( - [&](mlir::concretelang::Concrete::PlaintextType type) { - return mlir::IntegerType::get(type.getContext(), 64); - }); - typeConverter.addConversion( - [&](mlir::concretelang::Concrete::CleartextType type) { - return mlir::IntegerType::get(type.getContext(), 64); - }); mlir::Type subtype = type.dyn_cast().getElementType(); mlir::Type convertedSubtype = typeConverter.convertType(subtype); diff --git a/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp b/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp index c0fefd51e..c23587382 100644 --- a/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp +++ b/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp @@ -368,8 +368,8 @@ void SDFGToStreamEmulatorPass::runOnOperation() { target.addIllegalOp(); - // All BConcrete ops are legal after the conversion - target.addLegalDialect(); + // All Concrete ops are legal after the conversion + target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); diff --git a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index d80424cc3..b95b725e0 100644 --- a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -10,16 +10,19 @@ #include "mlir/Transforms/DialectConversion.h" #include "concretelang/Conversion/Passes.h" -#include "concretelang/Conversion/TFHEToConcrete/Patterns.h" #include "concretelang/Conversion/Utils/FuncConstOpConversion.h" #include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" +#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h" #include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" #include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" #include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" #include "concretelang/Dialect/RT/IR/RTOps.h" #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" +#include "concretelang/Dialect/TFHE/IR/TFHEOps.h" #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" #include "concretelang/Dialect/Tracing/IR/TracingOps.h" +#include "concretelang/Support/Constants.h" namespace TFHE = mlir::concretelang::TFHE; namespace Concrete = mlir::concretelang::Concrete; @@ -31,27 +34,38 @@ struct TFHEToConcretePass : public TFHEToConcreteBase { }; } // namespace -using mlir::concretelang::Concrete::LweCiphertextType; using mlir::concretelang::TFHE::GLWECipherTextType; /// TFHEToConcreteTypeConverter is a TypeConverter that transform -/// `TFHE.glwe<{_,_,_}{p}>` to Concrete.lwe_ciphertext +/// `TFHE.glwe<{dimension,1,bits}{p}>` to `tensor>` +/// `tensor<...xTFHE.glwe<{dimension,1,bits}{p}>>` to +/// `tensor<...xdimension+1, i64>>` class TFHEToConcreteTypeConverter : public mlir::TypeConverter { public: TFHEToConcreteTypeConverter() { addConversion([](mlir::Type type) { return type; }); addConversion([&](GLWECipherTextType type) { - return mlir::concretelang::convertTypeToLWE(type.getContext(), type); + assert(type.getPolynomialSize() <= 1 && + "converter doesn't support polynomialSize > 1"); + assert(type.getDimension() != -1); + llvm::SmallVector shape; + shape.push_back(type.getDimension() + 1); + return mlir::RankedTensorType::get( + shape, mlir::IntegerType::get(type.getContext(), 64)); }); addConversion([&](mlir::RankedTensorType type) { auto glwe = type.getElementType().dyn_cast_or_null(); if (glwe == nullptr) { return (mlir::Type)(type); } + mlir::SmallVector newShape; + newShape.reserve(type.getShape().size() + 1); + newShape.append(type.getShape().begin(), type.getShape().end()); + assert(glwe.getDimension() != -1); + newShape.push_back(glwe.getDimension() + 1); mlir::Type r = mlir::RankedTensorType::get( - type.getShape(), - mlir::concretelang::convertTypeToLWE(glwe.getContext(), glwe)); + newShape, mlir::IntegerType::get(type.getContext(), 64)); return r; }); addConversion([&](mlir::concretelang::RT::FutureType type) { @@ -69,73 +83,84 @@ public: namespace { -struct BootstrapGLWEOpPattern - : public mlir::OpRewritePattern { - BootstrapGLWEOpPattern(mlir::MLIRContext *context, - mlir::TypeConverter &converter, - mlir::PatternBenefit benefit = 100) - : mlir::OpRewritePattern(context, benefit), - converter(converter) {} +struct SubIntGLWEOpPattern + : public mlir::OpConversionPattern { - mlir::LogicalResult - matchAndRewrite(TFHE::BootstrapGLWEOp bsOp, - mlir::PatternRewriter &rewriter) const override { - mlir::Type resultType = converter.convertType(bsOp.getType()); + SubIntGLWEOpPattern(mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter) + : mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} - auto newOp = rewriter.replaceOpWithNewOp( - bsOp, resultType, bsOp.ciphertext(), bsOp.lookup_table(), bsOp.level(), - bsOp.baseLog(), bsOp.polySize(), bsOp.glweDimension()); + ::mlir::LogicalResult + matchAndRewrite(TFHE::SubGLWEIntOp subOp, TFHE::SubGLWEIntOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Value negated = rewriter.create( + subOp.getLoc(), adaptor.b().getType(), adaptor.b()); - rewriter.startRootUpdate(newOp); - newOp.input_ciphertext().setType( - converter.convertType(bsOp.ciphertext().getType())); - rewriter.finalizeRootUpdate(newOp); + rewriter.replaceOpWithNewOp( + subOp, this->getTypeConverter()->convertType(subOp.getType()), negated, + subOp.a()); - return ::mlir::success(); + return mlir::success(); } - -private: - mlir::TypeConverter &converter; }; -struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern { - WopPBSGLWEOpPattern(mlir::MLIRContext *context, - mlir::TypeConverter &converter, - mlir::PatternBenefit benefit = 100) - : mlir::OpRewritePattern(context, benefit), - converter(converter) {} +struct BootstrapGLWEOpPattern + : public mlir::OpConversionPattern { - mlir::LogicalResult - matchAndRewrite(TFHE::WopPBSGLWEOp wopOp, - mlir::PatternRewriter &rewriter) const override { - mlir::Type resultType = converter.convertType(wopOp.getType()); + BootstrapGLWEOpPattern(mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter) + : mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} - auto newOp = rewriter.replaceOpWithNewOp( - wopOp, resultType, wopOp.ciphertexts(), wopOp.lookupTable(), - // Bootstrap parameters - wopOp.bootstrapLevel(), wopOp.bootstrapBaseLog(), - // Keyswitch parameters - wopOp.keyswitchLevel(), wopOp.keyswitchBaseLog(), - // Packing keyswitch key parameters - wopOp.packingKeySwitchInputLweDimension(), - wopOp.packingKeySwitchoutputPolynomialSize(), - wopOp.packingKeySwitchLevel(), wopOp.packingKeySwitchBaseLog(), - // Circuit bootstrap parameters - wopOp.circuitBootstrapLevel(), wopOp.circuitBootstrapBaseLog(), - // Crt Decomposition - wopOp.crtDecomposition()); + ::mlir::LogicalResult + matchAndRewrite(TFHE::BootstrapGLWEOp bsOp, + TFHE::BootstrapGLWEOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.startRootUpdate(newOp); + TFHE::GLWECipherTextType resultType = + bsOp.getType().cast(); + TFHE::GLWECipherTextType inputType = + bsOp.ciphertext().getType().cast(); - newOp.ciphertexts().setType( - converter.convertType(wopOp.ciphertexts().getType())); + rewriter.replaceOpWithNewOp( + bsOp, this->getTypeConverter()->convertType(resultType), + adaptor.ciphertext(), adaptor.lookup_table(), inputType.getDimension(), + adaptor.polySize(), adaptor.level(), adaptor.baseLog(), + adaptor.glweDimension(), resultType.getP()); - rewriter.finalizeRootUpdate(newOp); - return ::mlir::success(); + return mlir::success(); } +}; -private: - mlir::TypeConverter &converter; +struct KeySwitchGLWEOpPattern + : public mlir::OpConversionPattern { + + KeySwitchGLWEOpPattern(mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter) + : mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(TFHE::KeySwitchGLWEOp ksOp, + TFHE::KeySwitchGLWEOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + TFHE::GLWECipherTextType resultType = + ksOp.getType().cast(); + TFHE::GLWECipherTextType inputType = + ksOp.ciphertext().getType().cast(); + + rewriter.replaceOpWithNewOp( + ksOp, this->getTypeConverter()->convertType(resultType), + adaptor.ciphertext(), adaptor.level(), adaptor.baseLog(), + inputType.getDimension(), resultType.getDimension()); + + return mlir::success(); + } }; struct TracePlaintextOpPattern @@ -163,6 +188,419 @@ struct TracePlaintextOpPattern } }; +template +struct ZeroOpPattern : public mlir::OpRewritePattern { + ZeroOpPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern( + context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(ZeroOp zeroOp, + mlir::PatternRewriter &rewriter) const override { + TFHEToConcreteTypeConverter converter; + auto newResultTy = converter.convertType(zeroOp.getType()); + + auto generateBody = [&](mlir::OpBuilder &nestedBuilder, + mlir::Location nestedLoc, + mlir::ValueRange blockArgs) { + // %c0 = 0 : i64 + auto cstOp = nestedBuilder.create( + nestedLoc, nestedBuilder.getI64IntegerAttr(0)); + // tensor.yield %z : !FHE.eint

+ nestedBuilder.create(nestedLoc, cstOp.getResult()); + }; + // tensor.generate + rewriter.replaceOpWithNewOp( + zeroOp, newResultTy, mlir::ValueRange{}, generateBody); + + return ::mlir::success(); + }; +}; + +/// Pattern that rewrites the ExtractSlice operation, taking into account the +/// additional LWE dimension introduced during type conversion +struct ExtractSliceOpPattern + : public mlir::OpConversionPattern { + ExtractSliceOpPattern(mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter) + : ::mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, + mlir::tensor::ExtractSliceOp::Adaptor adaptor, + ::mlir::ConversionPatternRewriter &rewriter) const override { + // is not a tensor of GLWEs that need to be extended with the LWE dimension + if (this->getTypeConverter()->isLegal(extractSliceOp.getType())) { + return mlir::failure(); + } + auto resultTy = extractSliceOp.result().getType(); + auto newResultTy = this->getTypeConverter() + ->convertType(resultTy) + .cast(); + + // add 0 to the static_offsets + mlir::SmallVector staticOffsets; + staticOffsets.append(adaptor.static_offsets().begin(), + adaptor.static_offsets().end()); + staticOffsets.push_back(rewriter.getI64IntegerAttr(0)); + + // add the lweSize to the sizes + mlir::SmallVector staticSizes; + staticSizes.append(adaptor.static_sizes().begin(), + adaptor.static_sizes().end()); + staticSizes.push_back(rewriter.getI64IntegerAttr( + newResultTy.getDimSize(newResultTy.getRank() - 1))); + + // add 1 to the strides + mlir::SmallVector staticStrides; + staticStrides.append(adaptor.static_strides().begin(), + adaptor.static_strides().end()); + staticStrides.push_back(rewriter.getI64IntegerAttr(1)); + + // replace tensor.extract_slice to the new one + rewriter.replaceOpWithNewOp( + extractSliceOp, newResultTy, adaptor.source(), adaptor.offsets(), + adaptor.sizes(), adaptor.strides(), + rewriter.getArrayAttr(staticOffsets), + rewriter.getArrayAttr(staticSizes), + rewriter.getArrayAttr(staticStrides)); + + return ::mlir::success(); + }; +}; + +/// Pattern that rewrites the Extract operation, taking into account the +/// additional LWE dimension introduced during type conversion +struct ExtractOpPattern + : public mlir::OpConversionPattern { + ExtractOpPattern(::mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter) + : ::mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::tensor::ExtractOp extractOp, + mlir::tensor::ExtractOp::Adaptor adaptor, + ::mlir::ConversionPatternRewriter &rewriter) const override { + // is not a tensor of GLWEs that need to be extended with the LWE dimension + if (this->getTypeConverter()->isLegal(extractOp.getType())) { + return mlir::failure(); + } + + auto newResultType = this->getTypeConverter() + ->convertType(extractOp.getType()) + .cast(); + auto tensorRank = + adaptor.tensor().getType().cast().getRank(); + + // [min..., 0] for static_offsets () + mlir::SmallVector staticOffsets( + tensorRank, + rewriter.getI64IntegerAttr(std::numeric_limits::min())); + staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0); + + // [1..., lweDimension+1] for static_sizes or + // [1..., nbBlock, lweDimension+1] + mlir::SmallVector staticSizes( + tensorRank, rewriter.getI64IntegerAttr(1)); + staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr( + newResultType.getDimSize(newResultType.getRank() - 1)); + + // [1...] for static_strides + mlir::SmallVector staticStrides( + tensorRank, rewriter.getI64IntegerAttr(1)); + + rewriter.replaceOpWithNewOp( + extractOp, newResultType, adaptor.tensor(), adaptor.indices(), + mlir::SmallVector{}, mlir::SmallVector{}, + rewriter.getArrayAttr(staticOffsets), + rewriter.getArrayAttr(staticSizes), + rewriter.getArrayAttr(staticStrides)); + + return ::mlir::success(); + }; +}; + +/// Pattern that rewrites the InsertSlice operation, taking into account the +/// additional LWE dimension introduced during type conversion +struct InsertSliceOpPattern + : public mlir::OpConversionPattern { + InsertSliceOpPattern(mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter) + : ::mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::tensor::InsertSliceOp insertSliceOp, + mlir::tensor::InsertSliceOp::Adaptor adaptor, + ::mlir::ConversionPatternRewriter &rewriter) const override { + // is not a tensor of GLWEs that need to be extended with the LWE dimension + if (this->getTypeConverter()->isLegal(insertSliceOp.getType())) { + return mlir::failure(); + } + + auto newResultTy = this->getTypeConverter() + ->convertType(insertSliceOp.result().getType()) + .cast(); + + // add 0 to static_offsets + mlir::SmallVector staticOffsets; + staticOffsets.append(adaptor.static_offsets().begin(), + adaptor.static_offsets().end()); + staticOffsets.push_back(rewriter.getI64IntegerAttr(0)); + + // add lweDimension+1 to static_sizes + mlir::SmallVector staticSizes; + staticSizes.append(adaptor.static_sizes().begin(), + adaptor.static_sizes().end()); + staticSizes.push_back(rewriter.getI64IntegerAttr( + newResultTy.getDimSize(newResultTy.getRank() - 1))); + + // add 1 to the strides + mlir::SmallVector staticStrides; + staticStrides.append(adaptor.static_strides().begin(), + adaptor.static_strides().end()); + staticStrides.push_back(rewriter.getI64IntegerAttr(1)); + + // replace tensor.insert_slice with the new one + rewriter.replaceOpWithNewOp( + insertSliceOp, newResultTy, adaptor.source(), adaptor.dest(), + adaptor.offsets(), adaptor.sizes(), adaptor.strides(), + rewriter.getArrayAttr(staticOffsets), + rewriter.getArrayAttr(staticSizes), + rewriter.getArrayAttr(staticStrides)); + + return ::mlir::success(); + }; +}; + +/// Pattern that rewrites the Insert operation, taking into account the +/// additional LWE dimension introduced during type conversion +struct InsertOpPattern + : public mlir::OpConversionPattern { + InsertOpPattern(mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter) + : ::mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::tensor::InsertOp insertOp, + mlir::tensor::InsertOp::Adaptor adaptor, + ::mlir::ConversionPatternRewriter &rewriter) const override { + // is not a tensor of GLWEs that need to be extended with the LWE dimension + if (this->getTypeConverter()->isLegal(insertOp.getType())) { + return mlir::failure(); + } + + mlir::RankedTensorType newResultTy = + this->getTypeConverter() + ->convertType(insertOp.result().getType()) + .cast(); + + // add zeros to static_offsets + mlir::SmallVector offsets; + offsets.append(adaptor.indices().begin(), adaptor.indices().end()); + offsets.push_back(rewriter.getIndexAttr(0)); + + // Inserting a smaller tensor into a (potentially) bigger one. Set + // dimensions for all leading dimensions of the target tensor not + // present in the source to 1. + mlir::SmallVector sizes(adaptor.indices().size(), + rewriter.getI64IntegerAttr(1)); + + // Add size for the bufferized source element + sizes.push_back(rewriter.getI64IntegerAttr( + newResultTy.getDimSize(newResultTy.getRank() - 1))); + + // Set stride of all dimensions to 1 + mlir::SmallVector strides( + newResultTy.getRank(), rewriter.getI64IntegerAttr(1)); + + // replace tensor.insert_slice with the new one + rewriter.replaceOpWithNewOp( + insertOp, adaptor.scalar(), adaptor.dest(), offsets, sizes, strides); + + return ::mlir::success(); + }; +}; + +/// FromElementsOpPatterns transform each tensor.from_elements that operates on +/// TFHE.glwe +/// +/// refs: check_tests/Conversion/TFHEToConcrete/tensor_from_elements.mlir +struct FromElementsOpPattern + : public mlir::OpConversionPattern { + FromElementsOpPattern(mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter) + : ::mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::tensor::FromElementsOp fromElementsOp, + mlir::tensor::FromElementsOp::Adaptor adaptor, + ::mlir::ConversionPatternRewriter &rewriter) const override { + + // is not a tensor of GLWEs that need to be extended with the LWE dimension + if (this->getTypeConverter()->isLegal(fromElementsOp.getType())) { + return mlir::failure(); + } + + auto converter = this->getTypeConverter(); + + auto resultTy = fromElementsOp.result().getType(); + if (converter->isLegal(resultTy)) { + return mlir::failure(); + } + auto oldTensorResultTy = resultTy.cast(); + auto oldRank = oldTensorResultTy.getRank(); + + auto newTensorResultTy = + converter->convertType(resultTy).cast(); + auto newRank = newTensorResultTy.getRank(); + auto newShape = newTensorResultTy.getShape(); + + mlir::Value tensor = rewriter.create( + fromElementsOp.getLoc(), newTensorResultTy, mlir::ValueRange{}); + + // sizes are [1, ..., 1, diffShape...] + llvm::SmallVector sizes(oldRank, + rewriter.getI64IntegerAttr(1)); + for (auto i = newRank - oldRank; i > 0; i--) { + sizes.push_back(rewriter.getI64IntegerAttr(*(newShape.end() - i))); + } + + // strides are [1, ..., 1] + llvm::SmallVector oneStrides( + newShape.size(), rewriter.getI64IntegerAttr(1)); + + // start with offets [0, ..., 0] + llvm::SmallVector currentOffsets(newRank, 0); + + // for each elements insert_slice with right offet + for (auto elt : llvm::enumerate(adaptor.elements())) { + // Just create offsets as attributes + llvm::SmallVector offsets; + offsets.reserve(currentOffsets.size()); + std::transform(currentOffsets.begin(), currentOffsets.end(), + std::back_inserter(offsets), + [&](auto v) { return rewriter.getI64IntegerAttr(v); }); + mlir::tensor::InsertSliceOp insOp = + rewriter.create( + fromElementsOp.getLoc(), + /* src: */ elt.value(), + /* dst: */ tensor, + /* offs: */ offsets, + /* sizes: */ sizes, + /* strides: */ oneStrides); + + tensor = insOp.getResult(); + + // Increment the offsets + for (auto i = newRank - 2; i >= 0; i--) { + if (currentOffsets[i] == newShape[i] - 1) { + currentOffsets[i] = 0; + continue; + } + currentOffsets[i]++; + break; + } + } + + rewriter.replaceOp(fromElementsOp, tensor); + return ::mlir::success(); + }; +}; + +// This template rewrite pattern transforms any instance of +// `ShapeOp` operators that operates on tensor of lwe ciphertext by adding +// the lwe size as a size of the tensor result and by adding a trivial +// reassociation at the end of the reassociations map. +// +// Example: +// +// ```mlir +// %0 = "ShapeOp" %arg0 [reassocations...] +// : tensor<...x!TFHE.glwe<{dimension,1,bits}{p}>> into +// tensor<...x!TFHE.glwe<{dimension,1,bits}{p}>> +// ``` +// +// becomes: +// +// ```mlir +// %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]] +// : tensor<...xdimension+1xi64> into +// tensor<...xdimension+1xi64> +// ``` +template +struct TensorShapeOpPattern : public mlir::OpConversionPattern { + TensorShapeOpPattern(mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter) + : ::mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(ShapeOp shapeOp, ShapeOpAdaptor adaptor, + ::mlir::ConversionPatternRewriter &rewriter) const override { + // is not a tensor of GLWEs that need to be extended with the LWE dimension + if (this->getTypeConverter()->isLegal(shapeOp.getType())) { + return mlir::failure(); + } + + auto newResultTy = + ((mlir::Type)this->getTypeConverter()->convertType(shapeOp.getType())) + .cast(); + + auto reassocTy = + ((mlir::Type)this->getTypeConverter()->convertType( + (inRank ? shapeOp.src() : shapeOp.result()).getType())) + .cast(); + + auto oldReassocs = shapeOp.getReassociationIndices(); + mlir::SmallVector newReassocs; + newReassocs.append(oldReassocs.begin(), oldReassocs.end()); + + // add [rank] to reassociations + { + mlir::ReassociationIndices lweAssoc; + lweAssoc.push_back(reassocTy.getRank() - 1); + newReassocs.push_back(lweAssoc); + } + + rewriter.replaceOpWithNewOp(shapeOp, newResultTy, adaptor.src(), + newReassocs); + + return ::mlir::success(); + }; +}; + +/// Add the instantiated TensorShapeOpPattern rewrite pattern with the +/// `ShapeOp` to the patterns set and populate the conversion target. +template +void insertTensorShapeOpPattern(mlir::MLIRContext &context, + mlir::TypeConverter &converter, + mlir::RewritePatternSet &patterns, + mlir::ConversionTarget &target) { + patterns.insert>( + &context, converter); + target.addDynamicallyLegalOp([&](mlir::Operation *op) { + return converter.isLegal(op->getResultTypes()) && + converter.isLegal(op->getOperandTypes()); + }); +} + +// The pass is supposed to endup with no TFHE.glwe type. Tensors should be +// extended with an additional dimension at the end, and some patterns in this +// pass are fully dedicated to rewrite tensor ops with this additional dimension +// in mind void TFHEToConcretePass::runOnOperation() { auto op = this->getOperation(); @@ -205,60 +643,86 @@ void TFHEToConcretePass::runOnOperation() { patterns.add>( &getContext(), converter); - populateWithGeneratedTFHEToConcrete(patterns); + // populateWithGeneratedTFHEToConcrete(patterns); - patterns.add>(&getContext(), converter); - patterns.add>( + // Generic patterns + patterns.insert< + mlir::concretelang::GenericOneToOneOpConversionPattern< + mlir::concretelang::TFHE::AddGLWEOp, + mlir::concretelang::Concrete::AddLweTensorOp>, + mlir::concretelang::GenericOneToOneOpConversionPattern< + mlir::concretelang::TFHE::AddGLWEIntOp, + mlir::concretelang::Concrete::AddPlaintextLweTensorOp>, + mlir::concretelang::GenericOneToOneOpConversionPattern< + mlir::concretelang::TFHE::MulGLWEIntOp, + mlir::concretelang::Concrete::MulCleartextLweTensorOp>, + mlir::concretelang::GenericOneToOneOpConversionPattern< + mlir::concretelang::TFHE::NegGLWEOp, + mlir::concretelang::Concrete::NegateLweTensorOp>, + mlir::concretelang::GenericOneToOneOpConversionPattern< + mlir::concretelang::TFHE::EncodeExpandLutForBootstrapOp, + mlir::concretelang::Concrete::EncodeExpandLutForBootstrapTensorOp, + true>, + mlir::concretelang::GenericOneToOneOpConversionPattern< + mlir::concretelang::TFHE::EncodeExpandLutForWopPBSOp, + mlir::concretelang::Concrete::EncodeExpandLutForWopPBSTensorOp, true>, + mlir::concretelang::GenericOneToOneOpConversionPattern< + mlir::concretelang::TFHE::EncodePlaintextWithCrtOp, + mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>, + mlir::concretelang::GenericOneToOneOpConversionPattern< + mlir::concretelang::TFHE::WopPBSGLWEOp, + mlir::concretelang::Concrete::WopPBSCRTLweTensorOp, true>>( &getContext(), converter); - patterns.add>(&getContext(), - converter); - patterns.add>(&getContext(), - converter); - patterns.add(&getContext(), converter); - patterns.add(&getContext(), converter); - target.addDynamicallyLegalOp( - [&](Concrete::BootstrapLweOp op) { - return (converter.isLegal(op->getOperandTypes()) && - converter.isLegal(op->getResultTypes())); + // pattern of remaining TFHE ops + patterns.insert, + ZeroOpPattern>( + &getContext()); + patterns.insert(&getContext(), converter); + + // Add patterns to rewrite tensor operators that works on tensors of TFHE GLWE + // types + patterns.insert(&getContext(), + converter); + // Add patterns to rewrite some of tensor ops that were introduced by the + // linalg bufferization of encrypted tensor + insertTensorShapeOpPattern(getContext(), converter, + patterns, target); + insertTensorShapeOpPattern(getContext(), converter, + patterns, target); + // legalize ops only if operand and result types are legal + target.addDynamicallyLegalOp< + mlir::tensor::YieldOp, mlir::scf::YieldOp, mlir::tensor::GenerateOp, + mlir::tensor::ExtractSliceOp, mlir::tensor::ExtractOp, + mlir::tensor::InsertSliceOp, mlir::tensor::ExpandShapeOp, + mlir::tensor::CollapseShapeOp, mlir::bufferization::AllocTensorOp>( + [&](mlir::Operation *op) { + return converter.isLegal(op->getResultTypes()) && + converter.isLegal(op->getOperandTypes()); }); - patterns.add>(&getContext(), - converter); - patterns.add>( - &getContext(), converter); - - patterns.add< - mlir::concretelang::GenericTypeConverterPattern>( - patterns.getContext(), converter); - - patterns.add< - mlir::concretelang::GenericTypeConverterPattern>( - patterns.getContext(), converter); - - patterns.add>( - &getContext(), converter); + // rewrite scf for loops if working on illegal types patterns.add>( &getContext(), converter); - mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target, - converter); + target.addDynamicallyLegalOp([&](mlir::scf::ForOp forOp) { + return converter.isLegal(forOp.getInitArgs().getTypes()) && + converter.isLegal(forOp.getResults().getTypes()); + }); + + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); mlir::populateFunctionOpInterfaceTypeConversionPattern( patterns, converter); // Conversion of Tracing dialect - patterns.add>(&getContext(), converter); + patterns.add>(&getContext(), converter); mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); patterns.add(&getContext(), converter); @@ -271,25 +735,27 @@ void TFHEToConcretePass::runOnOperation() { // Conversion of RT Dialect Ops patterns.add< - mlir::concretelang::GenericTypeConverterPattern, - mlir::concretelang::GenericTypeConverterPattern, - mlir::concretelang::GenericTypeConverterPattern< - mlir::bufferization::AllocTensorOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::func::ReturnOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::scf::YieldOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::bufferization::AllocTensorOp, true>, + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::MakeReadyFutureOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::AwaitFutureOp>, - mlir::concretelang::GenericTypeConverterPattern< - mlir::concretelang::RT::CreateAsyncTaskOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::concretelang::RT::CreateAsyncTaskOp, true>, + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::WorkFunctionReturnOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), converter); mlir::concretelang::addDynamicallyLegalTypeOp< @@ -310,13 +776,6 @@ void TFHEToConcretePass::runOnOperation() { mlir::concretelang::addDynamicallyLegalTypeOp< mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter); - mlir::concretelang::addDynamicallyLegalTypeOp( - target, converter); - mlir::concretelang::addDynamicallyLegalTypeOp( - target, converter); - mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::bufferization::AllocTensorOp>(target, converter); - // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { this->signalPassFailure(); diff --git a/compiler/lib/Dialect/BConcrete/CMakeLists.txt b/compiler/lib/Dialect/BConcrete/CMakeLists.txt deleted file mode 100644 index 9f57627c3..000000000 --- a/compiler/lib/Dialect/BConcrete/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(IR) -add_subdirectory(Transforms) diff --git a/compiler/lib/Dialect/BConcrete/IR/BConcreteDialect.cpp b/compiler/lib/Dialect/BConcrete/IR/BConcreteDialect.cpp deleted file mode 100644 index 88a988b9f..000000000 --- a/compiler/lib/Dialect/BConcrete/IR/BConcreteDialect.cpp +++ /dev/null @@ -1,26 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" - -#define GET_TYPEDEF_CLASSES -#include "concretelang/Dialect/BConcrete/IR/BConcreteOpsTypes.cpp.inc" - -#include "concretelang/Dialect/BConcrete/IR/BConcreteOpsDialect.cpp.inc" - -using namespace mlir::concretelang::BConcrete; - -void BConcreteDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.cpp.inc" - >(); - - addTypes< -#define GET_TYPEDEF_LIST -#include "concretelang/Dialect/BConcrete/IR/BConcreteOpsTypes.cpp.inc" - >(); -} diff --git a/compiler/lib/Dialect/BConcrete/IR/BConcreteOps.cpp b/compiler/lib/Dialect/BConcrete/IR/BConcreteOps.cpp deleted file mode 100644 index fb732f68b..000000000 --- a/compiler/lib/Dialect/BConcrete/IR/BConcreteOps.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" - -#define GET_OP_CLASSES -#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.cpp.inc" diff --git a/compiler/lib/Dialect/BConcrete/IR/CMakeLists.txt b/compiler/lib/Dialect/BConcrete/IR/CMakeLists.txt deleted file mode 100644 index 69c6f551f..000000000 --- a/compiler/lib/Dialect/BConcrete/IR/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -add_mlir_dialect_library( - BConcreteDialect - BConcreteDialect.cpp - BConcreteOps.cpp - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete - DEPENDS - mlir-headers - LINK_LIBS - PUBLIC - MLIRIR) - -target_link_libraries(BConcreteDialect PUBLIC MLIRIR) diff --git a/compiler/lib/Dialect/BConcrete/Transforms/CMakeLists.txt b/compiler/lib/Dialect/BConcrete/Transforms/CMakeLists.txt deleted file mode 100644 index 14d98990e..000000000 --- a/compiler/lib/Dialect/BConcrete/Transforms/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -add_mlir_dialect_library( - ConcretelangBConcreteTransforms - BufferizableOpInterfaceImpl.cpp - AddRuntimeContext.cpp - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete - DEPENDS - BConcreteTransformsIncGen - mlir-headers - LINK_LIBS - PUBLIC - ConcretelangConversion - MLIRArithmeticDialect - MLIRBufferizationDialect - MLIRBufferizationTransforms - MLIRIR - MLIRMemRefDialect - MLIRPass - MLIRTransforms) diff --git a/compiler/lib/Dialect/CMakeLists.txt b/compiler/lib/Dialect/CMakeLists.txt index 76ea8788b..9aaed6502 100644 --- a/compiler/lib/Dialect/CMakeLists.txt +++ b/compiler/lib/Dialect/CMakeLists.txt @@ -2,7 +2,6 @@ add_subdirectory(FHELinalg) add_subdirectory(FHE) add_subdirectory(TFHE) add_subdirectory(Concrete) -add_subdirectory(BConcrete) add_subdirectory(RT) add_subdirectory(SDFG) add_subdirectory(Tracing) diff --git a/compiler/lib/Dialect/Concrete/IR/ConcreteDialect.cpp b/compiler/lib/Dialect/Concrete/IR/ConcreteDialect.cpp index 9517e7174..cdce892d9 100644 --- a/compiler/lib/Dialect/Concrete/IR/ConcreteDialect.cpp +++ b/compiler/lib/Dialect/Concrete/IR/ConcreteDialect.cpp @@ -24,122 +24,3 @@ void ConcreteDialect::initialize() { #include "concretelang/Dialect/Concrete/IR/ConcreteOpsTypes.cpp.inc" >(); } - -void printSigned(mlir::AsmPrinter &p, signed i) { - if (i == -1) - p << "_"; - else - p << i; -} - -mlir::Type GlweCiphertextType::parse(mlir::AsmParser &parser) { - if (parser.parseLess()) - return Type(); - int glweDimension = -1; - if (parser.parseOptionalKeyword("_") && parser.parseInteger(glweDimension)) - return Type(); - if (parser.parseComma()) - return Type(); - int polynomialSize = -1; - if (parser.parseOptionalKeyword("_") && parser.parseInteger(polynomialSize)) - return Type(); - if (parser.parseComma()) - return Type(); - - int p = -1; - if (parser.parseOptionalKeyword("_") && parser.parseInteger(p)) - return Type(); - if (parser.parseGreater()) - return Type(); - Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); - return getChecked(loc, loc.getContext(), glweDimension, polynomialSize, p); -} - -void GlweCiphertextType::print(mlir::AsmPrinter &p) const { - p << "<"; - printSigned(p, getGlweDimension()); - p << ","; - printSigned(p, getPolynomialSize()); - p << ","; - printSigned(p, getP()); - p << ">"; -} - -void LweCiphertextType::print(mlir::AsmPrinter &p) const { - p << "<"; - printSigned(p, getDimension()); - p << ","; - printSigned(p, getP()); - p << ">"; -} - -mlir::Type LweCiphertextType::parse(mlir::AsmParser &parser) { - if (parser.parseLess()) - return mlir::Type(); - - int dimension = -1; - if (parser.parseOptionalKeyword("_") && parser.parseInteger(dimension)) - return mlir::Type(); - if (parser.parseComma()) - return mlir::Type(); - int p = -1; - if (parser.parseOptionalKeyword("_") && parser.parseInteger(p)) - return mlir::Type(); - if (parser.parseGreater()) - return mlir::Type(); - - mlir::Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); - - return getChecked(loc, loc.getContext(), dimension, p); -} - -void CleartextType::print(mlir::AsmPrinter &p) const { - p << "<"; - if (getP() == -1) - p << "_"; - else - p << getP(); - p << ">"; -} - -mlir::Type CleartextType::parse(mlir::AsmParser &parser) { - if (parser.parseLess()) - return mlir::Type(); - - int p = -1; - - if (parser.parseOptionalKeyword("_") && parser.parseInteger(p)) - return mlir::Type(); - if (parser.parseGreater()) - return mlir::Type(); - - Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); - - return getChecked(loc, loc.getContext(), p); -} - -void PlaintextType::print(mlir::AsmPrinter &p) const { - p << "<"; - if (getP() == -1) - p << "_"; - else - p << getP(); - p << ">"; -} - -mlir::Type PlaintextType::parse(mlir::AsmParser &parser) { - - if (parser.parseLess()) - return mlir::Type(); - - int p = -1; - - if (parser.parseOptionalKeyword("_") && parser.parseInteger(p)) - return mlir::Type(); - if (parser.parseGreater()) - return mlir::Type(); - - Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); - - return getChecked(loc, loc.getContext(), p); -} diff --git a/compiler/lib/Dialect/BConcrete/Transforms/AddRuntimeContext.cpp b/compiler/lib/Dialect/Concrete/Transforms/AddRuntimeContext.cpp similarity index 96% rename from compiler/lib/Dialect/BConcrete/Transforms/AddRuntimeContext.cpp rename to compiler/lib/Dialect/Concrete/Transforms/AddRuntimeContext.cpp index 2b7bbb440..5245fd0b6 100644 --- a/compiler/lib/Dialect/BConcrete/Transforms/AddRuntimeContext.cpp +++ b/compiler/lib/Dialect/Concrete/Transforms/AddRuntimeContext.cpp @@ -8,9 +8,9 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" -#include "concretelang/Dialect/BConcrete/Transforms/Passes.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" +#include "concretelang/Dialect/Concrete/Transforms/Passes.h" namespace { struct AddRuntimeContextToFuncOpPattern diff --git a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp b/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp similarity index 60% rename from compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp rename to compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp index b5c9b31ba..cc8e84beb 100644 --- a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp @@ -15,9 +15,9 @@ #include "mlir/IR/Operation.h" #include "concretelang/Conversion/Tools.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" -#include "concretelang/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" +#include "concretelang/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.h" #include "concretelang/Dialect/Tracing/IR/TracingOps.h" #include "concretelang/Support/CompilerEngine.h" #include @@ -30,8 +30,8 @@ using namespace mlir::tensor; namespace { -namespace BConcrete = mlir::concretelang::BConcrete; namespace Tracing = mlir::concretelang::Tracing; +namespace Concrete = mlir::concretelang::Concrete; template struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel< @@ -95,61 +95,58 @@ struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel< } // namespace -void mlir::concretelang::BConcrete:: +void mlir::concretelang::Concrete:: registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, - BConcrete::BConcreteDialect *dialect) { + Concrete::ConcreteDialect *dialect) { // add_lwe_tensor => add_lwe_buffer - BConcrete::AddLweTensorOp::attachInterface< - TensorToMemrefOp>( + Concrete::AddLweTensorOp::attachInterface< + TensorToMemrefOp>( *ctx); // add_plaintext_lwe_tensor => add_plaintext_lwe_buffer - BConcrete::AddPlaintextLweTensorOp::attachInterface< - TensorToMemrefOp>(*ctx); + Concrete::AddPlaintextLweTensorOp::attachInterface>( + *ctx); // mul_cleartext_lwe_tensor => mul_cleartext_lwe_buffer - BConcrete::MulCleartextLweTensorOp::attachInterface< - TensorToMemrefOp>(*ctx); + Concrete::MulCleartextLweTensorOp::attachInterface>( + *ctx); // negate_cleartext_lwe_tensor => negate_cleartext_lwe_buffer - BConcrete::NegateLweTensorOp::attachInterface>(*ctx); + Concrete::NegateLweTensorOp::attachInterface>(*ctx); // negate_cleartext_lwe_tensor => negate_cleartext_lwe_buffer - BConcrete::NegateLweTensorOp::attachInterface>(*ctx); + Concrete::NegateLweTensorOp::attachInterface>(*ctx); // keyswitch_lwe_tensor => keyswitch_lwe_buffer - BConcrete::KeySwitchLweTensorOp::attachInterface>( - *ctx); + Concrete::KeySwitchLweTensorOp::attachInterface>(*ctx); // bootstrap_lwe_tensor => bootstrap_lwe_buffer - BConcrete::BootstrapLweTensorOp::attachInterface>( - *ctx); + Concrete::BootstrapLweTensorOp::attachInterface>(*ctx); // batched_keyswitch_lwe_tensor => batched_keyswitch_lwe_buffer - BConcrete::BatchedKeySwitchLweTensorOp::attachInterface< - TensorToMemrefOp>(*ctx); + Concrete::BatchedKeySwitchLweTensorOp::attachInterface< + TensorToMemrefOp>(*ctx); // batched_bootstrap_lwe_tensor => batched_bootstrap_lwe_buffer - BConcrete::BatchedBootstrapLweTensorOp::attachInterface< - TensorToMemrefOp>(*ctx); + Concrete::BatchedBootstrapLweTensorOp::attachInterface< + TensorToMemrefOp>(*ctx); // wop_pbs_crt_lwe_tensor => wop_pbs_crt_lwe_buffer - BConcrete::WopPBSCRTLweTensorOp::attachInterface>( - *ctx); + Concrete::WopPBSCRTLweTensorOp::attachInterface>(*ctx); // encode_plaintext_with_crt_tensor => encode_plaintext_with_crt_buffer - BConcrete::EncodePlaintextWithCrtTensorOp::attachInterface< - TensorToMemrefOp>(*ctx); + Concrete::EncodePlaintextWithCrtTensorOp::attachInterface< + TensorToMemrefOp>(*ctx); // encode_expand_lut_for_bootstrap_tensor => // encode_expand_lut_for_bootstrap_buffer - BConcrete::EncodeExpandLutForBootstrapTensorOp::attachInterface< - TensorToMemrefOp>(*ctx); + Concrete::EncodeExpandLutForBootstrapTensorOp::attachInterface< + TensorToMemrefOp>(*ctx); // encode_expand_lut_for_woppbs_tensor => // encode_expand_lut_for_woppbs_buffer - BConcrete::EncodeExpandLutForWopPBSTensorOp::attachInterface< - TensorToMemrefOp>(*ctx); + Concrete::EncodeExpandLutForWopPBSTensorOp::attachInterface< + TensorToMemrefOp>(*ctx); }); } diff --git a/compiler/lib/Dialect/Concrete/Transforms/CMakeLists.txt b/compiler/lib/Dialect/Concrete/Transforms/CMakeLists.txt index 0694da6e5..0c867f0de 100644 --- a/compiler/lib/Dialect/Concrete/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/Concrete/Transforms/CMakeLists.txt @@ -1,12 +1,19 @@ -add_mlir_library( - ConcreteDialectTransforms - Optimization.cpp +add_mlir_dialect_library( + ConcretelangConcreteTransforms + BufferizableOpInterfaceImpl.cpp + AddRuntimeContext.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete DEPENDS - ConcreteDialect + ConcreteTransformsIncGen mlir-headers LINK_LIBS PUBLIC + ConcretelangConversion + MLIRArithmeticDialect + MLIRBufferizationDialect + MLIRBufferizationTransforms MLIRIR - ConcreteDialect) + MLIRMemRefDialect + MLIRPass + MLIRTransforms) diff --git a/compiler/lib/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.cpp b/compiler/lib/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.cpp index e3c01e291..547ab4161 100644 --- a/compiler/lib/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/compiler/lib/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.cpp @@ -15,8 +15,8 @@ #include "mlir/IR/Operation.h" #include "concretelang/Conversion/Tools.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" #include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" #include "concretelang/Dialect/SDFG/IR/SDFGOps.h" #include "concretelang/Dialect/SDFG/IR/SDFGTypes.h" diff --git a/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt b/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt index 723edc0fc..7e2473e91 100644 --- a/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt @@ -3,7 +3,7 @@ add_mlir_dialect_library( BufferizableOpInterfaceImpl.cpp SDFGConvertibleOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/SDFG DEPENDS mlir-headers diff --git a/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp b/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp index 0c98ec8fd..848f2ed5d 100644 --- a/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp +++ b/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp @@ -3,8 +3,8 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" #include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" #include "concretelang/Dialect/SDFG/IR/SDFGOps.h" #include "concretelang/Dialect/SDFG/Interfaces/SDFGConvertibleInterface.h" @@ -54,33 +54,33 @@ struct ReplaceWithProcessSDFGConversionInterface void registerSDFGConvertibleOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, - BConcrete::BConcreteDialect *dialect) { - mlir::concretelang::BConcrete::AddLweTensorOp::attachInterface< + Concrete::ConcreteDialect *dialect) { + mlir::concretelang::Concrete::AddLweTensorOp::attachInterface< ReplaceWithProcessSDFGConversionInterface< - mlir::concretelang::BConcrete::AddLweTensorOp, add_eint>>(*ctx); + mlir::concretelang::Concrete::AddLweTensorOp, add_eint>>(*ctx); - mlir::concretelang::BConcrete::AddPlaintextLweTensorOp::attachInterface< + mlir::concretelang::Concrete::AddPlaintextLweTensorOp::attachInterface< ReplaceWithProcessSDFGConversionInterface< - mlir::concretelang::BConcrete::AddPlaintextLweTensorOp, + mlir::concretelang::Concrete::AddPlaintextLweTensorOp, add_eint_int>>(*ctx); - mlir::concretelang::BConcrete::MulCleartextLweTensorOp::attachInterface< + mlir::concretelang::Concrete::MulCleartextLweTensorOp::attachInterface< ReplaceWithProcessSDFGConversionInterface< - mlir::concretelang::BConcrete::MulCleartextLweTensorOp, + mlir::concretelang::Concrete::MulCleartextLweTensorOp, mul_eint_int>>(*ctx); - mlir::concretelang::BConcrete::NegateLweTensorOp::attachInterface< + mlir::concretelang::Concrete::NegateLweTensorOp::attachInterface< ReplaceWithProcessSDFGConversionInterface< - mlir::concretelang::BConcrete::NegateLweTensorOp, neg_eint>>(*ctx); + mlir::concretelang::Concrete::NegateLweTensorOp, neg_eint>>(*ctx); - mlir::concretelang::BConcrete::KeySwitchLweTensorOp::attachInterface< + mlir::concretelang::Concrete::KeySwitchLweTensorOp::attachInterface< ReplaceWithProcessSDFGConversionInterface< - mlir::concretelang::BConcrete::KeySwitchLweTensorOp, keyswitch, + mlir::concretelang::Concrete::KeySwitchLweTensorOp, keyswitch, true>>(*ctx); - mlir::concretelang::BConcrete::BootstrapLweTensorOp::attachInterface< + mlir::concretelang::Concrete::BootstrapLweTensorOp::attachInterface< ReplaceWithProcessSDFGConversionInterface< - mlir::concretelang::BConcrete::BootstrapLweTensorOp, bootstrap, + mlir::concretelang::Concrete::BootstrapLweTensorOp, bootstrap, true>>(*ctx); }); } diff --git a/compiler/lib/Dialect/TFHE/CMakeLists.txt b/compiler/lib/Dialect/TFHE/CMakeLists.txt index f33061b2d..9f57627c3 100644 --- a/compiler/lib/Dialect/TFHE/CMakeLists.txt +++ b/compiler/lib/Dialect/TFHE/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/compiler/lib/Dialect/TFHE/Transforms/CMakeLists.txt b/compiler/lib/Dialect/TFHE/Transforms/CMakeLists.txt new file mode 100644 index 000000000..eada6b25a --- /dev/null +++ b/compiler/lib/Dialect/TFHE/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_library( + TFHEDialectTransforms + Optimization.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/TFHE + DEPENDS + TFHEDialect + mlir-headers + LINK_LIBS + PUBLIC + MLIRIR + TFHEDialect) diff --git a/compiler/lib/Dialect/Concrete/Transforms/Optimization.cpp b/compiler/lib/Dialect/TFHE/Transforms/Optimization.cpp similarity index 67% rename from compiler/lib/Dialect/Concrete/Transforms/Optimization.cpp rename to compiler/lib/Dialect/TFHE/Transforms/Optimization.cpp index ba59ba2b3..c5bcbc933 100644 --- a/compiler/lib/Dialect/Concrete/Transforms/Optimization.cpp +++ b/compiler/lib/Dialect/TFHE/Transforms/Optimization.cpp @@ -7,8 +7,8 @@ #include #include -#include -#include +#include +#include #include namespace mlir { @@ -30,20 +30,18 @@ getConstantIntFromCleartextIfExists(mlir::Value cleartext) { return {}; } -/// Rewrite a `Concrete.mul_cleartext_lwe_ciphertext` operation as a -/// `Concrete.zero` operation if it's being multiplied with a constant 0, or as -/// a `Concrete.negate_lwe_ciphertext` if multiplied with a constant -1. +/// Rewrite a TFHE multiplication with an integer operation as a +/// Zero operation if it's being multiplied with a constant 0, or as +/// a Negate operation if multiplied with a constant -1. class MulCleartextLweCiphertextOpPattern - : public mlir::OpRewritePattern< - mlir::concretelang::Concrete::MulCleartextLweCiphertextOp> { + : public mlir::OpRewritePattern { public: MulCleartextLweCiphertextOpPattern(mlir::MLIRContext *context) - : mlir::OpRewritePattern< - mlir::concretelang::Concrete::MulCleartextLweCiphertextOp>( + : mlir::OpRewritePattern( context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::MulCleartextLweCiphertextOp op, + matchAndRewrite(mlir::concretelang::TFHE::MulGLWEIntOp op, mlir::PatternRewriter &rewriter) const override { auto cleartext = op.getOperand(1); auto constIntToMul = getConstantIntFromCleartextIfExists(cleartext); @@ -51,13 +49,12 @@ public: if (constIntToMul.hasValue()) { auto toMul = constIntToMul.getValue().getInt(); if (toMul == 0) { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getResult().getType()); return mlir::success(); } if (toMul == -1) { - rewriter.replaceOpWithNewOp< - mlir::concretelang::Concrete::NegateLweCiphertextOp>( + rewriter.replaceOpWithNewOp( op, op.getResult().getType(), op.getOperand(0)); return mlir::success(); } @@ -68,8 +65,7 @@ public: /// Optimization pass that should choose more efficient ways of performing /// crypto operations. -class ConcreteOptimizationPass - : public ConcreteOptimizationBase { +class TFHEOptimizationPass : public TFHEOptimizationBase { public: void runOnOperation() override { mlir::Operation *op = getOperation(); @@ -85,8 +81,8 @@ public: } // end anonymous namespace -std::unique_ptr> createConcreteOptimizationPass() { - return std::make_unique(); +std::unique_ptr> createTFHEOptimizationPass() { + return std::make_unique(); } } // namespace concretelang diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index b3ea2fc7b..7a4cd81ba 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -27,11 +27,11 @@ add_mlir_library( FHEDialectTransforms RTDialectAnalysis ConcretelangTransforms - ConcretelangBConcreteTransforms + ConcretelangConcreteTransforms ConcretelangSDFGTransforms ConcretelangSDFGInterfaces LinalgExtras - ConcreteDialectTransforms + TFHEDialectTransforms concrete_optimizer MLIRExecutionEngine ${LLVM_PTHREAD_LIB} diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index b253aca70..cbdcbb150 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -27,9 +27,8 @@ #include "concretelang/Conversion/Utils/GlobalFHEContext.h" #include -#include -#include #include +#include #include #include #include @@ -80,13 +79,12 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() { mlir::concretelang::TFHE::TFHEDialect, mlir::concretelang::FHELinalg::FHELinalgDialect, mlir::concretelang::Concrete::ConcreteDialect, - mlir::concretelang::BConcrete::BConcreteDialect, mlir::concretelang::SDFG::SDFGDialect, mlir::func::FuncDialect, mlir::memref::MemRefDialect, mlir::linalg::LinalgDialect, mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect, mlir::omp::OpenMPDialect, mlir::bufferization::BufferizationDialect>(); - BConcrete::registerBufferizableOpInterfaceExternalModels(registry); Tracing::registerBufferizableOpInterfaceExternalModels(registry); + Concrete::registerBufferizableOpInterfaceExternalModels(registry); SDFG::registerSDFGConvertibleOpInterfaceExternalModels(registry); SDFG::registerBufferizableOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); @@ -392,6 +390,15 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { .failed()) { return errorDiag("Lowering from FHE to TFHE failed"); } + + // Optimizing TFHE + if (this->compilerOptions.optimizeTFHE && + mlir::concretelang::pipeline::optimizeTFHE(mlirContext, module, + this->enablePass) + .failed()) { + return errorDiag("Optimizing TFHE failed"); + } + if (target == Target::TFHE) return std::move(res); @@ -402,37 +409,17 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { return errorDiag("Lowering from TFHE to Concrete failed"); } - // Optimizing Concrete - if (this->compilerOptions.optimizeConcrete && - mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module, - this->enablePass) - .failed()) { - return errorDiag("Optimizing Concrete failed"); - } - if (target == Target::CONCRETE) return std::move(res); - // Concrete -> BConcrete - if (mlir::concretelang::pipeline::lowerConcreteToBConcrete( - mlirContext, module, this->enablePass, loopParallelize) - .failed()) { - return StreamStringError( - "Lowering from Concrete to Bufferized Concrete failed"); - } - - if (target == Target::BCONCRETE) { - return std::move(res); - } - - // Extract SDFG data flow graph from BConcrete representation + // Extract SDFG data flow graph from Concrete representation if (options.emitSDFGOps) { if (mlir::concretelang::pipeline::extractSDFGOps( mlirContext, module, enablePass, options.unrollLoopsWithSDFGConvertibleOps) .failed()) { - return errorDiag("Extraction of SDFG operations from BConcrete " + return errorDiag("Extraction of SDFG operations from Concrete " "representation failed"); } } @@ -441,9 +428,9 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { return std::move(res); } - // BConcrete -> Canonical dialects - if (mlir::concretelang::pipeline::lowerBConcreteToStd(mlirContext, module, - enablePass) + // Concrete -> Canonical dialects + if (mlir::concretelang::pipeline::lowerConcreteToStd(mlirContext, module, + enablePass) .failed()) { return errorDiag("Lowering from Bufferized Concrete to canonical MLIR " "dialects failed"); diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 41b83ee8d..99afc590a 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -31,8 +31,7 @@ #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Error.h" #include -#include -#include +#include #include #include #include @@ -41,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -290,27 +290,13 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, return pm.run(module.getOperation()); } -mlir::LogicalResult -optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass) { +mlir::LogicalResult optimizeTFHE(mlir::MLIRContext &context, + mlir::ModuleOp &module, + std::function enablePass) { mlir::PassManager pm(&context); - pipelinePrinting("ConcreteOptimization", pm, context); - addPotentiallyNestedPass( - pm, mlir::concretelang::createConcreteOptimizationPass(), enablePass); - - return pm.run(module.getOperation()); -} - -mlir::LogicalResult -lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass, - bool parallelizeLoops) { - mlir::PassManager pm(&context); - pipelinePrinting("ConcreteToBConcrete", pm, context); - - addPotentiallyNestedPass( - pm, mlir::concretelang::createConvertConcreteToBConcretePass(), - enablePass); + pipelinePrinting("TFHEOptimization", pm, context); + addPotentiallyNestedPass(pm, mlir::concretelang::createTFHEOptimizationPass(), + enablePass); return pm.run(module.getOperation()); } @@ -320,7 +306,7 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context, std::function enablePass, bool unroll) { mlir::PassManager pm(&context); - pipelinePrinting("extract SDFG ops from BConcrete", pm, context); + pipelinePrinting("extract SDFG ops from Concrete", pm, context); addPotentiallyNestedPass( pm, mlir::concretelang::createExtractSDFGOpsPass(unroll), enablePass); LogicalResult res = pm.run(module.getOperation()); @@ -329,10 +315,10 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context, } mlir::LogicalResult -lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass) { +lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { mlir::PassManager pm(&context); - pipelinePrinting("BConcreteToStd", pm, context); + pipelinePrinting("ConcreteToStd", pm, context); addPotentiallyNestedPass(pm, mlir::concretelang::createAddRuntimeContext(), enablePass); return pm.run(module.getOperation()); @@ -399,8 +385,7 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass); addPotentiallyNestedPass( - pm, mlir::concretelang::createConvertBConcreteToCAPIPass(gpu), - enablePass); + pm, mlir::concretelang::createConvertConcreteToCAPIPass(gpu), enablePass); addPotentiallyNestedPass( pm, mlir::concretelang::createConvertTracingToCAPIPass(), enablePass); diff --git a/compiler/src/CMakeLists.txt b/compiler/src/CMakeLists.txt index 804bb435a..5aba3e8cc 100644 --- a/compiler/src/CMakeLists.txt +++ b/compiler/src/CMakeLists.txt @@ -12,7 +12,6 @@ target_link_libraries( PRIVATE ${dialect_libs} ${conversion_libs} MLIRTransforms - BConcreteDialect ConcreteDialect TFHEDialect FHEDialect diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 5e1381108..c035b3b57 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -49,7 +49,6 @@ enum Action { DUMP_FHE_NO_LINALG, DUMP_TFHE, DUMP_CONCRETE, - DUMP_BCONCRETE, DUMP_SDFG, DUMP_STD, DUMP_LLVM_DIALECT, @@ -95,10 +94,10 @@ llvm::cl::opt verbose("verbose", llvm::cl::desc("verbose logs"), llvm::cl::init(false)); llvm::cl::opt - optimizeConcrete("optimize-concrete", - llvm::cl::desc("enable/disable optimizations of concrete " - "dialects. (Enabled by default)"), - llvm::cl::init(true)); + optimizeTFHE("optimize-tfhe", + llvm::cl::desc("enable/disable optimizations of TFHE " + "dialects. (Enabled by default)"), + llvm::cl::init(true)); llvm::cl::opt emitGPUOps( "emit-gpu-ops", @@ -126,9 +125,6 @@ static llvm::cl::opt action( "Lower to TFHE and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_CONCRETE, "dump-concrete", "Lower to Concrete and dump result")), - llvm::cl::values( - clEnumValN(Action::DUMP_BCONCRETE, "dump-bconcrete", - "Lower to Bufferized Concrete and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_SDFG, "dump-sdfg", "Lower to SDFG operations annd dump result")), llvm::cl::values(clEnumValN(Action::DUMP_STD, "dump-std", @@ -354,7 +350,7 @@ cmdlineCompilationOptions() { options.emitSDFGOps = cmdline::emitSDFGOps; options.unrollLoopsWithSDFGConvertibleOps = cmdline::unrollLoopsWithSDFGConvertibleOps; - options.optimizeConcrete = cmdline::optimizeConcrete; + options.optimizeTFHE = cmdline::optimizeTFHE; options.emitGPUOps = cmdline::emitGPUOps; options.chunkIntegers = cmdline::chunkIntegers; options.chunkSize = cmdline::chunkSize; @@ -531,9 +527,6 @@ mlir::LogicalResult processInputBuffer( case Action::DUMP_CONCRETE: target = mlir::concretelang::CompilerEngine::Target::CONCRETE; break; - case Action::DUMP_BCONCRETE: - target = mlir::concretelang::CompilerEngine::Target::BCONCRETE; - break; case Action::DUMP_SDFG: target = mlir::concretelang::CompilerEngine::Target::SDFG; break; diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe.mlir deleted file mode 100644 index cc158f1b0..000000000 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s - -//CHECK: func @add_lwe_ciphertexts(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64> -//CHECK: return %[[V0]] : tensor<2049xi64> -//CHECK: } -func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { - %0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> - return %0 : !Concrete.lwe_ciphertext<2048,7> -} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir deleted file mode 100644 index 1ca1118f5..000000000 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir +++ /dev/null @@ -1,22 +0,0 @@ -// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s - - -//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { -//CHECK: %c1_i64 = arith.constant 1 : i64 -//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %c1_i64) : (tensor<1025xi64>, i64) -> tensor<1025xi64> -//CHECK: return %[[V2]] : tensor<1025xi64> -//CHECK: } -func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { - %0 = arith.constant 1 : i64 - %2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7> - return %2 : !Concrete.lwe_ciphertext<1024,7> -} - -//CHECK: func.func @add_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> { -//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> -//CHECK: return %[[V2]] : tensor<1025xi64> -//CHECK: } -func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i64) -> !Concrete.lwe_ciphertext<1024,4> { - %1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4> - return %1 : !Concrete.lwe_ciphertext<1024,4> -} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir deleted file mode 100644 index 617f861bb..000000000 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s - -//CHECK: func.func @apply_lookup_table(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: tensor<16xi64>) -> tensor<1025xi64> { -//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 600 : i32} : (tensor<1025xi64>) -> tensor<601xi64> -//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_tensor"(%[[V1]], %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<1025xi64> -//CHECK: return %[[V2]] : tensor<1025xi64> -//CHECK: } -func.func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> { - %1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4> - %2 = "Concrete.bootstrap_lwe"(%1, %arg1) {baseLog = 2 : i32, polySize = 1024 : i32, level = 3 : i32, glweDimension = 4 : i32} : (!Concrete.lwe_ciphertext<600,4>, tensor<16xi64> ) -> !Concrete.lwe_ciphertext<1024,4> - return %2 : !Concrete.lwe_ciphertext<1024,4> -} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir deleted file mode 100644 index ac1b09fde..000000000 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s - -//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { -//CHECK: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> -//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 600 : i32} : (tensor<2049xi64>) -> tensor<601xi64> -//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_tensor"(%[[V1]], %cst) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<2049xi64> -//CHECK: return %[[V2]] : tensor<2049xi64> -//CHECK: } -func.func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> { - %tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> - %1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4> - %2 = "Concrete.bootstrap_lwe"(%1, %tlu) {baseLog = 2 : i32, polySize = 2048 : i32, level = 3 : i32, glweDimension = 4 : i32} : (!Concrete.lwe_ciphertext<600,4>, tensor<16xi64>) -> !Concrete.lwe_ciphertext<2048,4> - return %2 : !Concrete.lwe_ciphertext<2048,4> -} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_bootstrap.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_bootstrap.mlir deleted file mode 100644 index 57f99eada..000000000 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_bootstrap.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s - -// CHECK: func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> { -// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_bootstrap_tensor"(%arg0) {isSigned = true, outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64> -// CHECK-NEXT: return %0 : tensor<1024xi64> -// CHECK-NEXT: } -func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> { - %0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32, isSigned = true} : (tensor<4xi64>) -> tensor<1024xi64> - return %0 : tensor<1024xi64> -} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_woppbs.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_woppbs.mlir deleted file mode 100644 index 2d56b4e6d..000000000 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_woppbs.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s - -// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> { -// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> -// CHECK-NEXT: return %0 : tensor<40960xi64> -// CHECK-NEXT: } -func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> { - %0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> - return %0 : tensor<40960xi64> -} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_plaintext_with_crt.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_plaintext_with_crt.mlir deleted file mode 100644 index f8a72f321..000000000 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_plaintext_with_crt.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s - -// CHECK: func.func @main(%arg0: i64) -> tensor<5xi64> { -// CHECK-NEXT: %0 = "BConcrete.encode_plaintext_with_crt_tensor"(%arg0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> -// CHECK-NEXT: return %0 : tensor<5xi64> -// CHECK-NEXT: } -func.func @main(%arg0: i64) -> tensor<5xi64> { - %0 = "Concrete.encode_plaintext_with_crt"(%arg0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> - return %0 : tensor<5xi64> -} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/identity.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/identity.mlir deleted file mode 100644 index 3b0224725..000000000 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/identity.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s - -// CHECK: func.func @identity(%arg0: tensor<1025xi64>) -> tensor<1025xi64> { -// CHECK-NEXT: return %arg0 : tensor<1025xi64> -// CHECK-NEXT: } -func.func @identity(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { - return %arg0 : !Concrete.lwe_ciphertext<1024,7> -} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir deleted file mode 100644 index b4c8990a3..000000000 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir +++ /dev/null @@ -1,21 +0,0 @@ -// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s - -//CHECK: func.func @mul_lwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { -//CHECK: %c1_i64 = arith.constant 1 : i64 -//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %c1_i64) : (tensor<1025xi64>, i64) -> tensor<1025xi64> -//CHECK: return %[[V1]] : tensor<1025xi64> -//CHECK: } -func.func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { - %0 = arith.constant 1 : i64 - %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7> - return %2 : !Concrete.lwe_ciphertext<1024,7> -} - -//CHECK: func.func @mul_lwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> { -//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> -//CHECK: return %[[V1]] : tensor<1025xi64> -//CHECK: } -func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i64) -> !Concrete.lwe_ciphertext<1024,4> { - %1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4> - return %1 : !Concrete.lwe_ciphertext<1024,4> -} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir deleted file mode 100644 index 14a3a1712..000000000 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s - -//CHECK: func.func @neg_lwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_tensor"(%[[A0]]) : (tensor<1025xi64>) -> tensor<1025xi64> -//CHECK: return %[[V0]] : tensor<1025xi64> -//CHECK: } -func.func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> { - %0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> - return %0 : !Concrete.lwe_ciphertext<1024,4> -} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_identity.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_identity.mlir deleted file mode 100644 index f718a0963..000000000 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_identity.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s - -// CHECK: func.func @tensor_identity(%arg0: tensor<2x3x4x1025xi64>) -> tensor<2x3x4x1025xi64> { -// CHECK-NEXT: return %arg0 : tensor<2x3x4x1025xi64> -// CHECK-NEXT: } -func.func @tensor_identity(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>> { - return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>> -} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToLLVM/gpu_ops.mlir b/compiler/tests/check_tests/Conversion/ConcreteToLLVM/gpu_ops.mlir index f474a8195..8e035e1e6 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToLLVM/gpu_ops.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToLLVM/gpu_ops.mlir @@ -2,9 +2,9 @@ //CHECK: llvm.call @memref_keyswitch_lwe_cuda_u64 //CHECK: llvm.call @memref_bootstrap_lwe_cuda_u64 -func.func @main(%arg0: !Concrete.lwe_ciphertext<1024,2>) -> !Concrete.lwe_ciphertext<1024,2> { +func.func @main(%arg0: tensor<1025xi64>) -> tensor<1025xi64> { %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi64> - %0 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 5 : i32} : (!Concrete.lwe_ciphertext<1024,2>) -> !Concrete.lwe_ciphertext<575,2> - %1 = "Concrete.bootstrap_lwe"(%0, %cst) {baseLog = 2 : i32, level = 5 : i32, polySize = 1024: i32, glweDimension = 1 : i32} : (!Concrete.lwe_ciphertext<575,2>, tensor<4xi64>) -> !Concrete.lwe_ciphertext<1024,2> - return %1 : !Concrete.lwe_ciphertext<1024,2> + %0 = "Concrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, level = 5 : i32, lwe_dim_in = 1025 : i32, lwe_dim_out = 576 : i32} : (tensor<1025xi64>) -> tensor<576xi64> + %1 = "Concrete.bootstrap_lwe_tensor"(%0, %cst) {baseLog = 2 : i32, level = 5 : i32, polySize = 1024: i32, glweDimension = 1 : i32, inputLweDim = 576 : i32, outPrecision = 2 : i32} : (tensor<576xi64>, tensor<4xi64>) -> tensor<1025xi64> + return %1 : tensor<1025xi64> } diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint.mlir index 99d991adc..1837ed67e 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s +// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s //CHECK-LABEL: func.func @add_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>, %arg1: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { //CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>> diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint_int.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint_int.mlir index 89f20b32d..181e9a9ef 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint_int.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s +// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s // CHECK-LABEL: func.func @add_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { // CHECK-NEXT: %c1_i8 = arith.constant 1 : i8 diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate.mlir index a958311d8..9e9c17899 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s +// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s // CHECK: func.func @apply_lookup_table(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>> // CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate_cst.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate_cst.mlir index ce3b8359d..1ad7e30f6 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate_cst.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate_cst.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s +// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s // CHECK: func.func @apply_lookup_table_cst(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { // CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64> diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/conv2d.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/conv2d.mlir index ec45fd288..d6de6fe22 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/conv2d.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/conv2d.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s +// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s //CHECK-LABEL: func.func @conv2d(%arg0: tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> { diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/mul_eint_int.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/mul_eint_int.mlir index bbc98bb1b..e0f7ac85a 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/mul_eint_int.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/mul_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s +// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s // CHECK-LABEL: func.func @mul_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { // CHECK-NEXT: %c2_i8 = arith.constant 2 : i8 diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/neg_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/neg_eint.mlir index 3d18ce33a..16852a8fa 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/neg_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/neg_eint.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s +// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s // CHECK-LABEL: func.func @neg_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { // CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>> diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/sub_int_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/sub_int_eint.mlir index 42d0c5dd1..cb63ac45d 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/sub_int_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/sub_int_eint.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s +// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s // CHECK-LABEL: func.func @sub_int_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { // CHECK-NEXT: %c1_i8 = arith.constant 1 : i8 diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint.mlir index c56585f24..8ec0d0f71 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s +// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s // CHECK-LABEL: func.func @add_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>, %arg1: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint_int.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint_int.mlir index 481577b39..655bfc435 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint_int.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s +// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s // CHECK-LABEL: func.func @add_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate.mlir index a17912704..516207a52 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s +// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s // CHECK: func.func @apply_lookup_table(%arg0: !TFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> { // CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {isSigned = false, outputBits = 3 : i32, polySize = 256 : i32} : (tensor<4xi64>) -> tensor<256xi64> diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate_cst.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate_cst.mlir index 4a3df25b0..988c00484 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate_cst.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate_cst.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s +// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s //CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> { diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/conv2d.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/conv2d.mlir index 2f05cac72..81a42947d 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/conv2d.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/conv2d.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s +// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s //CHECK: func.func @conv2d(%arg0: tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>> // CHECK-NEXT: %c4 = arith.constant 4 : index diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/mul_eint_int.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/mul_eint_int.mlir index e3f822d03..038037951 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/mul_eint_int.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/mul_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s +// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s // CHECK-LABEL: func.func @mul_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir index c162c2df4..8c9651ad7 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s +// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s // CHECK-LABEL: func.func @neg_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/sub_int_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/sub_int_eint.mlir index 2ec79d939..752e8f7f7 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/sub_int_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/sub_int_eint.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s +// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s // CHECK-LABEL: func.func @sub_int_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> func.func @sub_int_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe.mlir index 2ecf8db9f..9c29ea86e 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe.mlir @@ -1,9 +1,9 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s -// CHECK-LABEL: func.func @add_glwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> +// CHECK-LABEL: func.func @add_glwe(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64> func.func @add_glwe(%arg0: !TFHE.glwe<{2048,1,64}{7}>, %arg1: !TFHE.glwe<{2048,1,64}{7}>) -> !TFHE.glwe<{2048,1,64}{7}> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> - // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> + // CHECK-NEXT: %[[V1:.*]] = "Concrete.add_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64> + // CHECK-NEXT: return %[[V1]] : tensor<2049xi64> %0 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<{2048,1,64}{7}>, !TFHE.glwe<{2048,1,64}{7}>) -> (!TFHE.glwe<{2048,1,64}{7}>) return %0: !TFHE.glwe<{2048,1,64}{7}> diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir index 000d4ff98..66c5b2a6c 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir @@ -1,9 +1,9 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s -//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { +//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { //CHECK: %c1_i64 = arith.constant 1 : i64 -//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %c1_i64) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7> -//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,7> +//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_tensor"(%[[A0]], %c1_i64) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: return %[[V0]] : tensor<1025xi64> //CHECK: } func.func @add_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> { %0 = arith.constant 1 : i64 @@ -12,9 +12,9 @@ func.func @add_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{ } -//CHECK: func.func @add_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i64) -> !Concrete.lwe_ciphertext<1024,4> { -//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4> -//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,4> +//CHECK: func.func @add_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> { +//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: return %[[V0]] : tensor<1025xi64> //CHECK: } func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i64) -> !TFHE.glwe<{1024,1,64}{4}> { %1 = "TFHE.add_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i64) -> (!TFHE.glwe<{1024,1,64}{4}>) diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/bootstrap.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/bootstrap.mlir index fadf286a6..6325fb98c 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/bootstrap.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/bootstrap.mlir @@ -1,9 +1,9 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s -//CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<600,7>) -> !Concrete.lwe_ciphertext<1024,4> { +//CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: tensor<601xi64>) -> tensor<1025xi64> { //CHECK: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64> -//CHECK: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %cst) {baseLog = 1 : i32, glweDimension = 1 : i32, level = 3 : i32, polySize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,7>, tensor<128xi64>) -> !Concrete.lwe_ciphertext<1024,4> -//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,4> +//CHECK: %[[V1:.*]] = "Concrete.bootstrap_lwe_tensor"(%arg0, %cst) {baseLog = 1 : i32, glweDimension = 1 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<128xi64>) -> tensor<1025xi64> +//CHECK: return %[[V1]] : tensor<1025xi64> //CHECK: } func.func @bootstrap_lwe(%ciphertext: !TFHE.glwe<{600,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{4}> { %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64> diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_bootstrap.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_bootstrap.mlir index b93fe73a6..fd369d907 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_bootstrap.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_bootstrap.mlir @@ -1,7 +1,7 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s // CHECK: func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> { -// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {isSigned = true, outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64> +// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_bootstrap_tensor"(%arg0) {isSigned = true, outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64> // CHECK-NEXT: return %0 : tensor<1024xi64> // CHECK-NEXT: } func.func @apply_lookup_table(%arg1: tensor<4xi64>) -> tensor<1024xi64> { diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_woppbs.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_woppbs.mlir index 271398f66..b37612ac3 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_woppbs.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_woppbs.mlir @@ -1,7 +1,7 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s // CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> { -// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> +// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> // CHECK-NEXT: return %0 : tensor<40960xi64> // CHECK-NEXT: } func.func @main(%arg1: tensor<4xi64>) -> tensor<40960xi64> { diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_plaintext_with_crt.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_plaintext_with_crt.mlir index 7959ba646..694634319 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_plaintext_with_crt.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_plaintext_with_crt.mlir @@ -1,7 +1,7 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s // CHECK: func.func @main(%arg0: i64) -> tensor<5xi64> { -// CHECK-NEXT: %0 = "Concrete.encode_plaintext_with_crt"(%arg0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> +// CHECK-NEXT: %0 = "Concrete.encode_plaintext_with_crt_tensor"(%arg0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> // CHECK-NEXT: return %0 : tensor<5xi64> // CHECK-NEXT: } func.func @main(%arg1: i64) -> tensor<5xi64> { diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/keyswitch.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/keyswitch.mlir index 4517f08c0..f4b451c6c 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/keyswitch.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/keyswitch.mlir @@ -1,8 +1,8 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s -// CHECK: func.func @keyswitch_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,2>) -> !Concrete.lwe_ciphertext<567,2> { -// CHECK-NEXT: %[[V0:.*]] = "Concrete.keyswitch_lwe"(%[[A0]]) {baseLog = 3 : i32, level = 2 : i32} : (!Concrete.lwe_ciphertext<1024,2>) -> !Concrete.lwe_ciphertext<567,2> -// CHECK-NEXT: return %[[V0]] : !Concrete.lwe_ciphertext<567,2> +// CHECK: func.func @keyswitch_glwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<568xi64> { +// CHECK-NEXT: %[[V0:.*]] = "Concrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 3 : i32, level = 2 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 567 : i32} : (tensor<1025xi64>) -> tensor<568xi64> +// CHECK-NEXT: return %[[V0]] : tensor<568xi64> // CHECK-NEXT: } func.func @keyswitch_glwe(%arg0: !TFHE.glwe<{1024,1,64}{2}>) -> !TFHE.glwe<{567,1,64}{2}> { %0 = "TFHE.keyswitch_glwe"(%arg0) {baseLog = 3 : i32, level = 2 : i32} : (!TFHE.glwe<{1024,1,64}{2}>) -> !TFHE.glwe<{567,1,64}{2}> diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir index 81239695f..eaae122bf 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir @@ -1,9 +1,9 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s -//CHECK: func.func @mul_glwe_const_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { +//CHECK: func.func @mul_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { //CHECK: %c1_i64 = arith.constant 1 : i64 -//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %c1_i64) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7> -//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,7> +//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_tensor"(%[[A0]], %c1_i64) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: return %[[V0]] : tensor<1025xi64> //CHECK: } func.func @mul_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> { %0 = arith.constant 1 : i64 @@ -12,9 +12,9 @@ func.func @mul_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{ } -//CHECK: func.func @mul_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i64) -> !Concrete.lwe_ciphertext<1024,4> { -//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4> -//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,4> +//CHECK: func.func @mul_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> { +//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: return %[[V0]] : tensor<1025xi64> //CHECK: } func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i64) -> !TFHE.glwe<{1024,1,64}{4}> { %1 = "TFHE.mul_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i64) -> (!TFHE.glwe<{1024,1,64}{4}>) diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/neg_glwe.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/neg_glwe.mlir index 7ae9e711e..a4e77ccb8 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/neg_glwe.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/neg_glwe.mlir @@ -1,9 +1,9 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s -// CHECK-LABEL: func.func @neg_glwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> +// CHECK-LABEL: func.func @neg_glwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> func.func @neg_glwe(%arg0: !TFHE.glwe<{1024,1,64}{4}>) -> !TFHE.glwe<{1024,1,64}{4}> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> - // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<1024,4> + // CHECK-NEXT: %[[V1:.*]] = "Concrete.negate_lwe_tensor"(%arg0) : (tensor<1025xi64>) -> tensor<1025xi64> + // CHECK-NEXT: return %[[V1]] : tensor<1025xi64> %1 = "TFHE.neg_glwe"(%arg0): (!TFHE.glwe<{1024,1,64}{4}>) -> (!TFHE.glwe<{1024,1,64}{4}>) return %1: !TFHE.glwe<{1024,1,64}{4}> } diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir index 192cbf1c6..c382e9f81 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir @@ -1,10 +1,10 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s -//CHECK: func.func @sub_const_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { +//CHECK: func.func @sub_const_int_glwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { //CHECK: %c1_i64 = arith.constant 1 : i64 -//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_ciphertext"(%[[A0]]) : (!Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> -//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %c1_i64) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7> -//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,7> +//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_tensor"(%[[A0]]) : (tensor<1025xi64>) -> tensor<1025xi64> +//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_tensor"(%[[V0]], %c1_i64) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: return %[[V1]] : tensor<1025xi64> //CHECK: } func.func @sub_const_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> { %0 = arith.constant 1 : i64 @@ -12,10 +12,10 @@ func.func @sub_const_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{ return %1: !TFHE.glwe<{1024,1,64}{7}> } -//CHECK: func.func @sub_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i64) -> !Concrete.lwe_ciphertext<1024,4> { -//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_ciphertext"(%[[A0]]) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> -//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4> -//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,4> +//CHECK: func.func @sub_int_glwe(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> { +//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_tensor"(%[[A0]]) : (tensor<1025xi64>) -> tensor<1025xi64> +//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_tensor"(%[[V0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: return %[[V1]] : tensor<1025xi64> //CHECK: } func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i64) -> !TFHE.glwe<{1024,1,64}{4}> { %1 = "TFHE.sub_int_glwe"(%arg1, %arg0): (i64, !TFHE.glwe<{1024,1,64}{4}>) -> (!TFHE.glwe<{1024,1,64}{4}>) diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_exapand_collapse_shape.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_exapand_collapse_shape.mlir similarity index 54% rename from compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_exapand_collapse_shape.mlir rename to compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_exapand_collapse_shape.mlir index ec6b0bf28..51962d18d 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_exapand_collapse_shape.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_exapand_collapse_shape.mlir @@ -1,12 +1,12 @@ -// RUN: concretecompiler --split-input-file --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s +// RUN: concretecompiler --split-input-file --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s //CHECK: func.func @tensor_collapse_shape(%[[A0:.*]]: tensor<2x3x4x5x6x1025xi64>) -> tensor<720x1025xi64> { //CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2, 3, 4\], \[5\]\]]] : tensor<2x3x4x5x6x1025xi64> into tensor<720x1025xi64> //CHECK: return %[[V0]] : tensor<720x1025xi64> //CHECK: } -func.func @tensor_collapse_shape(%arg0: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<720x!Concrete.lwe_ciphertext<1024,4>> { - %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4]] {MANP = 1 : ui1}: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<1024,4>> into tensor<720x!Concrete.lwe_ciphertext<1024,4>> - return %0 : tensor<720x!Concrete.lwe_ciphertext<1024,4>> +func.func @tensor_collapse_shape(%arg0: tensor<2x3x4x5x6x!TFHE.glwe<{1024, 1, 64}{4}>>) -> tensor<720x!TFHE.glwe<{1024, 1, 64}{4}>> { + %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4]] {MANP = 1 : ui1}: tensor<2x3x4x5x6x!TFHE.glwe<{1024, 1, 64}{4}>> into tensor<720x!TFHE.glwe<{1024, 1, 64}{4}>> + return %0 : tensor<720x!TFHE.glwe<{1024, 1, 64}{4}>> } // ----- @@ -15,10 +15,10 @@ func.func @tensor_collapse_shape(%arg0: tensor<2x3x4x5x6x!Concrete.lwe_ciphertex //CHECK: %[[V1:.*]] = tensor.expand_shape %[[V0]] [[_:\[\[0, 1\], \[2\]\]]] : tensor<30x1025xi64> into tensor<5x6x1025xi64> //CHECK: return %[[V1]] : tensor<5x6x1025xi64> //CHECK: } -func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> { - %0 = tensor.collapse_shape %arg0 [[0, 1, 2]] {MANP = 1 : ui1}: tensor<2x3x5x!Concrete.lwe_ciphertext<1024,4>> into tensor<30x!Concrete.lwe_ciphertext<1024,4>> - %1 = tensor.expand_shape %0 [[0, 1]] {MANP = 1 : ui1}: tensor<30x!Concrete.lwe_ciphertext<1024,4>> into tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> - return %1 : tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> +func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x!TFHE.glwe<{1024, 1, 64}{4}>>) -> tensor<5x6x!TFHE.glwe<{1024, 1, 64}{4}>> { + %0 = tensor.collapse_shape %arg0 [[0, 1, 2]] {MANP = 1 : ui1}: tensor<2x3x5x!TFHE.glwe<{1024, 1, 64}{4}>> into tensor<30x!TFHE.glwe<{1024, 1, 64}{4}>> + %1 = tensor.expand_shape %0 [[0, 1]] {MANP = 1 : ui1}: tensor<30x!TFHE.glwe<{1024, 1, 64}{4}>> into tensor<5x6x!TFHE.glwe<{1024, 1, 64}{4}>> + return %1 : tensor<5x6x!TFHE.glwe<{1024, 1, 64}{4}>> } // ----- @@ -26,9 +26,9 @@ func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x!Concrete.lwe_ciphertex //CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1\], \[2\], \[3, 4\], \[5\]\]]] : tensor<2x3x2x3x4x1025xi64> into tensor<6x2x12x1025xi64> //CHECK: return %[[V0]] : tensor<6x2x12x1025xi64> //CHECK: } -func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>> { - %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] {MANP = 1 : ui1}: tensor<2x3x2x3x4x!Concrete.lwe_ciphertext<1024,4>> into tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>> - return %0 : tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>> +func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x!TFHE.glwe<{1024, 1, 64}{4}>>) -> tensor<6x2x12x!TFHE.glwe<{1024, 1, 64}{4}>> { + %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] {MANP = 1 : ui1}: tensor<2x3x2x3x4x!TFHE.glwe<{1024, 1, 64}{4}>> into tensor<6x2x12x!TFHE.glwe<{1024, 1, 64}{4}>> + return %0 : tensor<6x2x12x!TFHE.glwe<{1024, 1, 64}{4}>> } // ----- @@ -36,7 +36,7 @@ func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x!Concrete.lwe_ciphe //CHECK: %[[V0:.*]] = tensor.expand_shape %[[A0]] [[_:\[\[0, 1\], \[2\]\]]] : tensor<30x1025xi64> into tensor<5x6x1025xi64> //CHECK: return %[[V0]] : tensor<5x6x1025xi64> //CHECK: } -func.func @tensor_expand_shape_crt(%arg0: tensor<30x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> { - %0 = tensor.expand_shape %arg0 [[0, 1]] {MANP = 1 : ui1}: tensor<30x!Concrete.lwe_ciphertext<1024,4>> into tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> - return %0 : tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> +func.func @tensor_expand_shape_crt(%arg0: tensor<30x!TFHE.glwe<{1024, 1, 64}{4}>>) -> tensor<5x6x!TFHE.glwe<{1024, 1, 64}{4}>> { + %0 = tensor.expand_shape %arg0 [[0, 1]] {MANP = 1 : ui1}: tensor<30x!TFHE.glwe<{1024, 1, 64}{4}>> into tensor<5x6x!TFHE.glwe<{1024, 1, 64}{4}>> + return %0 : tensor<5x6x!TFHE.glwe<{1024, 1, 64}{4}>> } diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_from_elements.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_from_elements.mlir similarity index 72% rename from compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_from_elements.mlir rename to compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_from_elements.mlir index eb5801e5c..9d0b18b3d 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_from_elements.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_from_elements.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete --split-input-file %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --split-input-file %s 2>&1| FileCheck %s // CHECK: func.func @main(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>, %[[A2:.*]]: tensor<2049xi64>, %[[A3:.*]]: tensor<2049xi64>, %[[A4:.*]]: tensor<2049xi64>, %[[A5:.*]]: tensor<2049xi64>) -> tensor<6x2049xi64> { // CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<6x2049xi64> @@ -10,9 +10,9 @@ // CHECK: %[[V6:.*]] = tensor.insert_slice %[[A5]] into %[[V5]][5, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64> // CHECK: return %[[V6]] : tensor<6x2049xi64> // CHECK: } -func.func @main(%arg0 : !Concrete.lwe_ciphertext<2048,4>, %arg1 : !Concrete.lwe_ciphertext<2048,4>, %arg2 : !Concrete.lwe_ciphertext<2048,4>, %arg3 : !Concrete.lwe_ciphertext<2048,4>, %arg4 : !Concrete.lwe_ciphertext<2048,4>, %arg5 : !Concrete.lwe_ciphertext<2048,4>) -> tensor<6x!Concrete.lwe_ciphertext<2048,4>> { - %0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<6x!Concrete.lwe_ciphertext<2048,4>> - return %0 : tensor<6x!Concrete.lwe_ciphertext<2048,4>> +func.func @main(%arg0 : !TFHE.glwe<{2048, 1, 64}{4}>, %arg1 : !TFHE.glwe<{2048, 1, 64}{4}>, %arg2 : !TFHE.glwe<{2048, 1, 64}{4}>, %arg3 : !TFHE.glwe<{2048, 1, 64}{4}>, %arg4 : !TFHE.glwe<{2048, 1, 64}{4}>, %arg5 : !TFHE.glwe<{2048, 1, 64}{4}>) -> tensor<6x!TFHE.glwe<{2048, 1, 64}{4}>> { + %0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<6x!TFHE.glwe<{2048, 1, 64}{4}>> + return %0 : tensor<6x!TFHE.glwe<{2048, 1, 64}{4}>> } // ----- @@ -27,7 +27,7 @@ func.func @main(%arg0 : !Concrete.lwe_ciphertext<2048,4>, %arg1 : !Concrete.lwe_ // CHECK: %[[V6:.*]] = tensor.insert_slice %[[A5]] into %[[V5]][1, 2, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64> // CHECK: return %[[V6]] : tensor<2x3x2049xi64> // CHECK: } -func.func @main(%arg0 : !Concrete.lwe_ciphertext<2048,4>, %arg1 : !Concrete.lwe_ciphertext<2048,4>, %arg2 : !Concrete.lwe_ciphertext<2048,4>, %arg3 : !Concrete.lwe_ciphertext<2048,4>, %arg4 : !Concrete.lwe_ciphertext<2048,4>, %arg5 : !Concrete.lwe_ciphertext<2048,4>) -> tensor<2x3x!Concrete.lwe_ciphertext<2048,4>> { - %0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<2x3x!Concrete.lwe_ciphertext<2048,4>> - return %0 : tensor<2x3x!Concrete.lwe_ciphertext<2048,4>> +func.func @main(%arg0 : !TFHE.glwe<{2048, 1, 64}{4}>, %arg1 : !TFHE.glwe<{2048, 1, 64}{4}>, %arg2 : !TFHE.glwe<{2048, 1, 64}{4}>, %arg3 : !TFHE.glwe<{2048, 1, 64}{4}>, %arg4 : !TFHE.glwe<{2048, 1, 64}{4}>, %arg5 : !TFHE.glwe<{2048, 1, 64}{4}>) -> tensor<2x3x!TFHE.glwe<{2048, 1, 64}{4}>> { + %0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<2x3x!TFHE.glwe<{2048, 1, 64}{4}>> + return %0 : tensor<2x3x!TFHE.glwe<{2048, 1, 64}{4}>> } diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_identity.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_identity.mlir new file mode 100644 index 000000000..49401debe --- /dev/null +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_identity.mlir @@ -0,0 +1,8 @@ +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s + +// CHECK: func.func @tensor_identity(%arg0: tensor<2x3x4x1025xi64>) -> tensor<2x3x4x1025xi64> { +// CHECK-NEXT: return %arg0 : tensor<2x3x4x1025xi64> +// CHECK-NEXT: } +func.func @tensor_identity(%arg0: tensor<2x3x4x!TFHE.glwe<{1024, 1, 64}{7}>>) -> tensor<2x3x4x!TFHE.glwe<{1024, 1, 64}{7}>> { + return %arg0 : tensor<2x3x4x!TFHE.glwe<{1024, 1, 64}{7}>> +} diff --git a/compiler/tests/check_tests/Dialect/BConcrete/ops_memref.mlir b/compiler/tests/check_tests/Dialect/BConcrete/ops_memref.mlir deleted file mode 100644 index 91d822d91..000000000 --- a/compiler/tests/check_tests/Dialect/BConcrete/ops_memref.mlir +++ /dev/null @@ -1,37 +0,0 @@ -// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s - -func.func @add_lwe_ciphertexts(%arg0: memref<2049xi64>, %arg1: memref<2049xi64>, %result : memref<2049xi64>) { - //CHECK: "BConcrete.add_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) : (memref<2049xi64>, memref<2049xi64>, memref<2049xi64>) -> () - "BConcrete.add_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, memref<2049xi64>) -> () - return -} - -func.func @add_plaintext_lwe_ciphertext(%arg0: memref<2049xi64>, %arg1: i64, %result: memref<2049xi64>) { - //CHECK: "BConcrete.add_plaintext_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) : (memref<2049xi64>, memref<2049xi64>, i64) -> () - "BConcrete.add_plaintext_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, i64) -> () - return -} - -func.func @mul_cleartext_lwe_ciphertext(%arg0: memref<2049xi64>, %arg1: i64, %result: memref<2049xi64>) { - //CHECK: "BConcrete.mul_cleartext_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A0:.*]]) : (memref<2049xi64>, memref<2049xi64>, i64) -> () - "BConcrete.mul_cleartext_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, i64) -> () - return -} - -func.func @negate_lwe_ciphertext(%arg0: memref<2049xi64>, %result: memref<2049xi64>) { - //CHECK: "BConcrete.negate_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) : (memref<2049xi64>, memref<2049xi64>) -> () - "BConcrete.negate_lwe_buffer"(%result, %arg0) : (memref<2049xi64>, memref<2049xi64>) -> () - return -} - -func.func @bootstrap_lwe(%arg0: memref<2049xi64>, %arg1: memref<16xi64>, %result: memref<2049xi64>) { - //CHECK: "BConcrete.bootstrap_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> () - "BConcrete.bootstrap_lwe_buffer"(%result, %arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> () - return -} - -func.func @keyswitch_lwe(%arg0: memref<2049xi64>, %result: memref<2049xi64>) { - //CHECK: "BConcrete.keyswitch_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> () - "BConcrete.keyswitch_lwe_buffer"(%result, %arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> () - return -} diff --git a/compiler/tests/check_tests/Dialect/BConcrete/ops_tensor.mlir b/compiler/tests/check_tests/Dialect/BConcrete/ops_tensor.mlir deleted file mode 100644 index d66f67576..000000000 --- a/compiler/tests/check_tests/Dialect/BConcrete/ops_tensor.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s - -//CHECK: func.func @add_lwe_ciphertexts(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64> -//CHECK: return %[[V0]] : tensor<2049xi64> -//CHECK: } -func.func @add_lwe_ciphertexts(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64> { - %0 = "BConcrete.add_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>) -> ( tensor<2049xi64>) - return %0 : tensor<2049xi64> -} - -//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64> -//CHECK: return %[[V0]] : tensor<2049xi64> -//CHECK: } -func.func @add_plaintext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) -> tensor<2049xi64> { - %0 = "BConcrete.add_plaintext_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> ( tensor<2049xi64>) - return %0 : tensor<2049xi64> -} - -//CHECK: func @mul_cleartext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64> -//CHECK: return %[[V0]] : tensor<2049xi64> -//CHECK: } -func.func @mul_cleartext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) -> tensor<2049xi64> { - %0 = "BConcrete.mul_cleartext_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> (tensor<2049xi64>) - return %0 : tensor<2049xi64> -} - -//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_tensor"(%[[A0]]) : (tensor<2049xi64>) -> tensor<2049xi64> -//CHECK: return %[[V0]] : tensor<2049xi64> -//CHECK: } -func.func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> { - %0 = "BConcrete.negate_lwe_tensor"(%arg0) : (tensor<2049xi64>) -> (tensor<2049xi64>) - return %0 : tensor<2049xi64> -} - -//CHECK: func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64> -//CHECK: return %[[V0]] : tensor<2049xi64> -//CHECK: } -func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> { - %0 = "BConcrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> (tensor<2049xi64>) - return %0 : tensor<2049xi64> -} - -//CHECK: func.func @keyswitch_lwe(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> tensor<2049xi64> -//CHECK: return %[[V0]] : tensor<2049xi64> -//CHECK: } -func.func @keyswitch_lwe(%arg0: tensor<2049xi64>) -> tensor<2049xi64> { - %0 = "BConcrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>) - return %0 : tensor<2049xi64> -} diff --git a/compiler/tests/check_tests/Dialect/BConcrete/bufferization-nonzero-offets.mlir b/compiler/tests/check_tests/Dialect/Concrete/bufferization-nonzero-offets.mlir similarity index 93% rename from compiler/tests/check_tests/Dialect/BConcrete/bufferization-nonzero-offets.mlir rename to compiler/tests/check_tests/Dialect/Concrete/bufferization-nonzero-offets.mlir index 4193b66a0..1c5b0598c 100644 --- a/compiler/tests/check_tests/Dialect/BConcrete/bufferization-nonzero-offets.mlir +++ b/compiler/tests/check_tests/Dialect/Concrete/bufferization-nonzero-offets.mlir @@ -4,7 +4,7 @@ // Extracted from the source referenced in Issue 663. This should // trigger the folding of memrefs of itermediate results to memrefs // with non-zero offsets. Prior to the use of symbolic offsets in the -// memref used in the memref.cast operation produced by the BConcrete +// memref used in the memref.cast operation produced by the Concrete // bufferizer, bufferization of the function below would fail. func.func @main(%arg0: tensor<32x!FHE.eint<8>>, %arg1: tensor<256xi64>) -> !FHE.eint<8> { diff --git a/compiler/tests/check_tests/Dialect/Concrete/no_optimization.mlir b/compiler/tests/check_tests/Dialect/Concrete/no_optimization.mlir deleted file mode 100644 index ec69bb7a7..000000000 --- a/compiler/tests/check_tests/Dialect/Concrete/no_optimization.mlir +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: concretecompiler --optimize-concrete=false --action=dump-concrete %s 2>&1| FileCheck %s - -//CHECK: func.func @mul_cleartext_lwe_ciphertext_0(%[[A0:.*]]: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { -//CHECK: %c0_i7 = arith.constant 0 : i7 -//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %c0_i7) : (!Concrete.lwe_ciphertext<2048,7>, i7) -> !Concrete.lwe_ciphertext<2048,7> -//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<2048,7> -//CHECK: } -func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { - %0 = arith.constant 0 : i7 - %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>) - return %2: !Concrete.lwe_ciphertext<2048,7> -} diff --git a/compiler/tests/check_tests/Dialect/Concrete/ops.mlir b/compiler/tests/check_tests/Dialect/Concrete/ops.mlir index 970bb73f0..cbdb01335 100644 --- a/compiler/tests/check_tests/Dialect/Concrete/ops.mlir +++ b/compiler/tests/check_tests/Dialect/Concrete/ops.mlir @@ -1,53 +1,95 @@ // RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s -// CHECK-LABEL: func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> -func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> - // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> +// Tensor ops - %1 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.lwe_ciphertext<2048,7>) -> (!Concrete.lwe_ciphertext<2048,7>) - return %1: !Concrete.lwe_ciphertext<2048,7> +//CHECK: func.func @add_lwe_ciphertexts(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { +//CHECK: %[[V0:.*]] = "Concrete.add_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64> +//CHECK: return %[[V0]] : tensor<2049xi64> +//CHECK: } +func.func @add_lwe_ciphertexts(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64> { + %0 = "Concrete.add_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>) -> ( tensor<2049xi64>) + return %0 : tensor<2049xi64> } -// CHECK-LABEL: func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i5) -> !Concrete.lwe_ciphertext<2048,7> -func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i5) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, i5) -> !Concrete.lwe_ciphertext<2048,7> - // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - - %1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, i5) -> (!Concrete.lwe_ciphertext<2048,7>) - return %1: !Concrete.lwe_ciphertext<2048,7> +//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> { +//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64> +//CHECK: return %[[V0]] : tensor<2049xi64> +//CHECK: } +func.func @add_plaintext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) -> tensor<2049xi64> { + %0 = "Concrete.add_plaintext_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> ( tensor<2049xi64>) + return %0 : tensor<2049xi64> } -// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7> -func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, i7) -> !Concrete.lwe_ciphertext<2048,7> - // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - - %1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>) - return %1: !Concrete.lwe_ciphertext<2048,7> +//CHECK: func @mul_cleartext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> { +//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64> +//CHECK: return %[[V0]] : tensor<2049xi64> +//CHECK: } +func.func @mul_cleartext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) -> tensor<2049xi64> { + %0 = "Concrete.mul_cleartext_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> (tensor<2049xi64>) + return %0 : tensor<2049xi64> } -// CHECK-LABEL: func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> -func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> - // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - - %1 = "Concrete.negate_lwe_ciphertext"(%arg0): (!Concrete.lwe_ciphertext<2048,7>) -> (!Concrete.lwe_ciphertext<2048,7>) - return %1: !Concrete.lwe_ciphertext<2048,7> +//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { +//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_tensor"(%[[A0]]) : (tensor<2049xi64>) -> tensor<2049xi64> +//CHECK: return %[[V0]] : tensor<2049xi64> +//CHECK: } +func.func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> { + %0 = "Concrete.negate_lwe_tensor"(%arg0) : (tensor<2049xi64>) -> (tensor<2049xi64>) + return %0 : tensor<2049xi64> } -// CHECK-LABEL: func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: tensor<128xi64>) -> !Concrete.lwe_ciphertext<2048,7> -func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: tensor<128xi64>) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, level = 3 : i32, polySize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, tensor<128xi64>) -> !Concrete.lwe_ciphertext<2048,7> - // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - %1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, polySize = 2048 : i32, level = 3 : i32, glweDimension = 4 : i32} : (!Concrete.lwe_ciphertext<2048,7>, tensor<128xi64>) -> !Concrete.lwe_ciphertext<2048,7> - return %1: !Concrete.lwe_ciphertext<2048,7> +//CHECK: func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> { +//CHECK: %[[V0:.*]] = "Concrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64> +//CHECK: return %[[V0]] : tensor<2049xi64> +//CHECK: } +func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> { + %0 = "Concrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> (tensor<2049xi64>) + return %0 : tensor<2049xi64> } -// CHECK-LABEL: func.func @keyswitch_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> -func.func @keyswitch_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> - // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - %1 = "Concrete.keyswitch_lwe"(%arg0){baseLog = 2 : i32, level = 3 : i32}: (!Concrete.lwe_ciphertext<2048,7>) -> (!Concrete.lwe_ciphertext<2048,7>) - return %1: !Concrete.lwe_ciphertext<2048,7> +//CHECK: func.func @keyswitch_lwe(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { +//CHECK: %[[V0:.*]] = "Concrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> tensor<2049xi64> +//CHECK: return %[[V0]] : tensor<2049xi64> +//CHECK: } +func.func @keyswitch_lwe(%arg0: tensor<2049xi64>) -> tensor<2049xi64> { + %0 = "Concrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>) + return %0 : tensor<2049xi64> +} + +// MemRef Ops + +func.func @add_lwe_ciphertexts_buffer(%arg0: memref<2049xi64>, %arg1: memref<2049xi64>, %result : memref<2049xi64>) { + //CHECK: "Concrete.add_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) : (memref<2049xi64>, memref<2049xi64>, memref<2049xi64>) -> () + "Concrete.add_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, memref<2049xi64>) -> () + return +} + +func.func @add_plaintext_lwe_ciphertext_buffer(%arg0: memref<2049xi64>, %arg1: i64, %result: memref<2049xi64>) { + //CHECK: "Concrete.add_plaintext_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) : (memref<2049xi64>, memref<2049xi64>, i64) -> () + "Concrete.add_plaintext_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, i64) -> () + return +} + +func.func @mul_cleartext_lwe_ciphertext_buffer(%arg0: memref<2049xi64>, %arg1: i64, %result: memref<2049xi64>) { + //CHECK: "Concrete.mul_cleartext_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A0:.*]]) : (memref<2049xi64>, memref<2049xi64>, i64) -> () + "Concrete.mul_cleartext_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, i64) -> () + return +} + +func.func @negate_lwe_ciphertext_buffer(%arg0: memref<2049xi64>, %result: memref<2049xi64>) { + //CHECK: "Concrete.negate_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) : (memref<2049xi64>, memref<2049xi64>) -> () + "Concrete.negate_lwe_buffer"(%result, %arg0) : (memref<2049xi64>, memref<2049xi64>) -> () + return +} + +func.func @bootstrap_lwe_buffer(%arg0: memref<2049xi64>, %arg1: memref<16xi64>, %result: memref<2049xi64>) { + //CHECK: "Concrete.bootstrap_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> () + "Concrete.bootstrap_lwe_buffer"(%result, %arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> () + return +} + +func.func @keyswitch_lwe_buffer(%arg0: memref<2049xi64>, %result: memref<2049xi64>) { + //CHECK: "Concrete.keyswitch_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> () + "Concrete.keyswitch_lwe_buffer"(%result, %arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> () + return } diff --git a/compiler/tests/check_tests/Dialect/Concrete/optimization.mlir b/compiler/tests/check_tests/Dialect/Concrete/optimization.mlir deleted file mode 100644 index 5f10ec7a0..000000000 --- a/compiler/tests/check_tests/Dialect/Concrete/optimization.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: concretecompiler --action=dump-concrete %s 2>&1| FileCheck %s - - -// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7> -func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, i7) -> !Concrete.lwe_ciphertext<2048,7> - // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - - %1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>) - return %1: !Concrete.lwe_ciphertext<2048,7> -} - -// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> -func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.zero"() : () -> !Concrete.lwe_ciphertext<2048,7> - // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - - %0 = arith.constant 0 : i7 - %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>) - return %2: !Concrete.lwe_ciphertext<2048,7> -} - -// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> -func.func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> - // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - - %0 = arith.constant -1 : i7 - %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>) - return %2: !Concrete.lwe_ciphertext<2048,7> -} diff --git a/compiler/tests/check_tests/Dialect/Concrete/types.mlir b/compiler/tests/check_tests/Dialect/Concrete/types.mlir deleted file mode 100644 index 311fb8376..000000000 --- a/compiler/tests/check_tests/Dialect/Concrete/types.mlir +++ /dev/null @@ -1,20 +0,0 @@ -// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s - - -// CHECK-LABEL: func.func @type_plaintext(%arg0: !Concrete.plaintext<7>) -> !Concrete.plaintext<7> -func.func @type_plaintext(%arg0: !Concrete.plaintext<7>) -> !Concrete.plaintext<7> { - // CHECK-NEXT: return %arg0 : !Concrete.plaintext<7> - return %arg0: !Concrete.plaintext<7> -} - -// CHECK-LABEL: func.func @type_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> -func.func @type_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext<2048,7> - return %arg0: !Concrete.lwe_ciphertext<2048,7> -} - -// CHECK-LABEL: func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5> -func.func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5> { - // CHECK-NEXT: return %arg0 : !Concrete.cleartext<5> - return %arg0: !Concrete.cleartext<5> -} diff --git a/compiler/tests/check_tests/Dialect/TFHE/no_optimization.mlir b/compiler/tests/check_tests/Dialect/TFHE/no_optimization.mlir new file mode 100644 index 000000000..3c8c891c9 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/TFHE/no_optimization.mlir @@ -0,0 +1,12 @@ +// RUN: concretecompiler --passes tfhe-optimization --optimize-tfhe=false --action=dump-tfhe %s 2>&1| FileCheck %s + +//CHECK: func.func @mul_cleartext_glwe_ciphertext_0(%[[A0:.*]]: !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,527,64}{7}> { +//CHECK: %c0_i64 = arith.constant 0 : i64 +//CHECK: %[[V0:.*]] = "TFHE.mul_glwe_int"(%[[A0]], %c0_i64) : (!TFHE.glwe<{1,527,64}{7}>, i64) -> !TFHE.glwe<{1,527,64}{7}> +//CHECK: return %[[V0]] : !TFHE.glwe<{1,527,64}{7}> +//CHECK: } +func.func @mul_cleartext_glwe_ciphertext_0(%arg0: !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,527,64}{7}> { + %0 = arith.constant 0 : i64 + %2 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1,527,64}{7}>, i64) -> (!TFHE.glwe<{1,527,64}{7}>) + return %2: !TFHE.glwe<{1,527,64}{7}> +} diff --git a/compiler/tests/check_tests/Dialect/TFHE/optimization.mlir b/compiler/tests/check_tests/Dialect/TFHE/optimization.mlir new file mode 100644 index 000000000..57a74701d --- /dev/null +++ b/compiler/tests/check_tests/Dialect/TFHE/optimization.mlir @@ -0,0 +1,31 @@ +// RUN: concretecompiler --passes tfhe-optimization --action=dump-tfhe %s 2>&1| FileCheck %s + + +// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !TFHE.glwe<{1,527,64}{7}>, %arg1: i64) -> !TFHE.glwe<{1,527,64}{7}> +func.func @mul_cleartext_lwe_ciphertext(%arg0: !TFHE.glwe<{1,527,64}{7}>, %arg1: i64) -> !TFHE.glwe<{1,527,64}{7}> { + // CHECK-NEXT: %[[V1:.*]] = "TFHE.mul_glwe_int"(%arg0, %arg1) : (!TFHE.glwe<{1,527,64}{7}>, i64) -> !TFHE.glwe<{1,527,64}{7}> + // CHECK-NEXT: return %[[V1]] : !TFHE.glwe<{1,527,64}{7}> + + %1 = "TFHE.mul_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1,527,64}{7}>, i64) -> (!TFHE.glwe<{1,527,64}{7}>) + return %1: !TFHE.glwe<{1,527,64}{7}> +} + +// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,527,64}{7}> +func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,527,64}{7}> { + // CHECK-NEXT: %[[V1:.*]] = "TFHE.zero"() : () -> !TFHE.glwe<{1,527,64}{7}> + // CHECK-NEXT: return %[[V1]] : !TFHE.glwe<{1,527,64}{7}> + + %0 = arith.constant 0 : i64 + %2 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1,527,64}{7}>, i64) -> (!TFHE.glwe<{1,527,64}{7}>) + return %2: !TFHE.glwe<{1,527,64}{7}> +} + +// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,527,64}{7}> +func.func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,527,64}{7}> { + // CHECK-NEXT: %[[V1:.*]] = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,527,64}{7}> + // CHECK-NEXT: return %[[V1]] : !TFHE.glwe<{1,527,64}{7}> + + %0 = arith.constant -1 : i64 + %2 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1,527,64}{7}>, i64) -> (!TFHE.glwe<{1,527,64}{7}>) + return %2: !TFHE.glwe<{1,527,64}{7}> +} diff --git a/compiler/tests/check_tests/Transforms/batching.mlir b/compiler/tests/check_tests/Transforms/batching.mlir.disabled similarity index 100% rename from compiler/tests/check_tests/Transforms/batching.mlir rename to compiler/tests/check_tests/Transforms/batching.mlir.disabled