From 9dd776533a458158da70ccede0b5466019562f65 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Tue, 14 Jun 2022 11:21:16 +0200 Subject: [PATCH] cleanup(compiler): Remove useless concrete types, simplify print and parse, and remove BConcreteToBConcreteCAPI pass --- .../BConcreteToBConcreteCAPI/Pass.h | 22 - .../include/concretelang/Conversion/Passes.h | 1 - .../Dialect/Concrete/IR/ConcreteDialect.td | 5 - .../Dialect/Concrete/IR/ConcreteTypes.td | 152 ++-- .../BConcreteToBConcreteCAPI.cpp | 718 ------------------ .../BConcreteToBConcreteCAPI/CMakeLists.txt | 17 - compiler/lib/Conversion/CMakeLists.txt | 1 - .../MLIRLowerableDialectsToLLVM.cpp | 4 - .../Dialect/Concrete/IR/ConcreteDialect.cpp | 59 -- .../RT/Analysis/LowerDataflowTasksToRT.cpp | 6 +- .../Dialect/Concrete/Concrete/types.mlir | 24 - 11 files changed, 49 insertions(+), 960 deletions(-) delete mode 100644 compiler/include/concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h delete mode 100644 compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp delete mode 100644 compiler/lib/Conversion/BConcreteToBConcreteCAPI/CMakeLists.txt diff --git a/compiler/include/concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h b/compiler/include/concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h deleted file mode 100644 index 906e1c417..000000000 --- a/compiler/include/concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h +++ /dev/null @@ -1,22 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#ifndef CONCRETELANG_CONVERSION_BCONCRETETOBCONCRETECAPI_PASS_H_ -#define CONCRETELANG_CONVERSION_BCONCRETETOBCONCRETECAPI_PASS_H_ - -#include "mlir/Pass/Pass.h" - -#include "concretelang/Conversion/Utils/GlobalFHEContext.h" - -namespace mlir { -namespace concretelang { -/// Create a pass to convert `Concrete` operators to function call to the -/// `ConcreteCAPI` -std::unique_ptr> -createConvertBConcreteToBConcreteCAPIPass(); -} // namespace concretelang -} // namespace mlir - -#endif diff --git a/compiler/include/concretelang/Conversion/Passes.h b/compiler/include/concretelang/Conversion/Passes.h index 59e99a251..767b51e14 100644 --- a/compiler/include/concretelang/Conversion/Passes.h +++ b/compiler/include/concretelang/Conversion/Passes.h @@ -11,7 +11,6 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/SCF.h" -#include "concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h" #include "concretelang/Conversion/ConcreteToBConcrete/Pass.h" #include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h" #include "concretelang/Conversion/FHEToTFHE/Pass.h" diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteDialect.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteDialect.td index 7f74d11f9..1da3649d6 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteDialect.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteDialect.td @@ -10,11 +10,6 @@ def Concrete_Dialect : Dialect { A dialect for representation of low level operation on fully homomorphic ciphertext. }]; let cppNamespace = "::mlir::concretelang::Concrete"; - let useDefaultTypePrinterParser = 0; - let extraClassDeclaration = [{ - ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; - void printType(::mlir::Type type, ::mlir::DialectAsmPrinter &printer) const override; - }]; } #endif diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td index 606eae2f3..bdece275a 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td @@ -5,141 +5,85 @@ include "mlir/IR/BuiltinTypes.td" include "concretelang/Dialect/Concrete/IR/ConcreteDialect.td" -class Concrete_Type traits = []> : TypeDef { } +class Concrete_Type traits = []> + : TypeDef {} def GlweCiphertextType : Concrete_Type<"GlweCiphertext"> { - let mnemonic = "glwe_ciphertext"; + let mnemonic = "glwe_ciphertext"; - let summary = "A GLWE ciphertext (encryption of a polynomial of fixed-precision integers)"; + let summary = "A GLWE ciphertext (encryption of a polynomial of " + "fixed-precision integers)"; - let description = [{ - GLWE ciphertext. - }]; + let description = [{GLWE ciphertext.}]; - let hasCustomAssemblyFormat = 1; + let hasCustomAssemblyFormat = 1; - let parameters = (ins - "signed":$polynomialSize, - "signed":$glweDimension, - // Precision of the lwe ciphertext - "signed":$p - ); + let parameters = (ins "signed" + : $polynomialSize, "signed" + : $glweDimension, + // Precision of the lwe ciphertext + "signed" + : $p); } -def LweCiphertextType : Concrete_Type<"LweCiphertext", [MemRefElementTypeInterface]> { - let mnemonic = "lwe_ciphertext"; +def LweCiphertextType + : Concrete_Type<"LweCiphertext", [MemRefElementTypeInterface]> { + let mnemonic = "lwe_ciphertext"; - let summary = "A LWE ciphertext (encryption of a fixed-precision integer)"; + let summary = "A LWE ciphertext (encryption of a fixed-precision integer)"; - let description = [{ - Learning With Error ciphertext. - }]; + let description = [{Learning With Error ciphertext.}]; + let parameters = (ins + // The dimension of the lwe ciphertext + "signed" + : $dimension, + // Precision of the lwe ciphertext + "signed" + : $p); - let parameters = (ins - // The dimension of the lwe ciphertext - "signed":$dimension, - // Precision of the lwe ciphertext - "signed":$p - ); - - let hasCustomAssemblyFormat = 1; + let hasCustomAssemblyFormat = 1; } def CleartextType : Concrete_Type<"Cleartext"> { - let mnemonic = "cleartext"; + let mnemonic = "cleartext"; - let summary = "A cleartext (a fixed-precision integer) ready to be multiplied to a LWE ciphertext"; + let summary = "A cleartext (a fixed-precision integer) ready to be " + "multiplied to a LWE ciphertext"; - let description = [{ - Cleartext. - }]; + let description = [{Cleartext.}]; - let parameters = (ins - // Number of bits of the cleartext representation - "signed":$p - ); + let parameters = (ins + // Number of bits of the cleartext representation + "signed" + : $p); - let hasCustomAssemblyFormat = 1; + let hasCustomAssemblyFormat = 1; } def PlaintextType : Concrete_Type<"Plaintext"> { - let mnemonic = "plaintext"; + let mnemonic = "plaintext"; - let summary = "A Plaintext (a fixed-precision integer) ready to be added to a LWE ciphertext"; + let summary = "A Plaintext (a fixed-precision integer) ready to be added to " + "a LWE ciphertext"; - let description = [{ - Plaintext. - }]; + let description = [{Plaintext.}]; - let parameters = (ins - // Number of bits of the cleartext representation - "signed":$p - ); + let parameters = (ins + // Number of bits of the cleartext representation + "signed" + : $p); - let hasCustomAssemblyFormat = 1; -} - -def PlaintextListType : Concrete_Type<"PlaintextList"> { - let mnemonic = "plaintext_list"; - - let summary = "List of plaintexts"; - - let description = [{ - Plaintext list. - }]; - - let hasCustomAssemblyFormat = 1; -} - -def ForeignPlaintextListType : Concrete_Type<"ForeignPlaintextList"> { - let mnemonic = "foreign_plaintext_list"; - - let summary = "A foreign (reference to a independently allocated memory space) plaintext list"; - - let description = [{ - Foreign plaintext list. - }]; - - let hasCustomAssemblyFormat = 1; -} - -def LweKeySwitchKeyType : Concrete_Type<"LweKeySwitchKey"> { - let mnemonic = "lwe_key_switch_key"; - - let summary = "A LWE keyswitching key"; - - let description = [{ - Learning With Error keyswitching key. - }]; - - let hasCustomAssemblyFormat = 1; -} - -def LweBootstrapKeyType : Concrete_Type<"LweBootstrapKey"> { - let mnemonic = "lwe_bootstrap_key"; - - let summary = "A LWE bootstrapping key"; - - let description = [{ - Learning With Error bootstrapping key. - }]; - - let hasCustomAssemblyFormat = 1; + let hasCustomAssemblyFormat = 1; } def Context : Concrete_Type<"Context"> { - let mnemonic = "context"; + let mnemonic = "context"; - let summary = "A runtime context"; + let summary = "A runtime context"; - let description = [{ - An abstract runtime context to pass contextual value, like public keys, ... - }]; - - let hasCustomAssemblyFormat = 1; + let description = [{An abstract runtime context to pass contextual value, + like public keys, ...}]; } - - #endif diff --git a/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp b/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp deleted file mode 100644 index e9d0a5670..000000000 --- a/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp +++ /dev/null @@ -1,718 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR//BuiltinTypes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "concretelang/Conversion/Passes.h" -#include "concretelang/Conversion/Tools.h" -#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" -#include -#include - -static mlir::Type convertTypeIfConcreteType(mlir::MLIRContext *context, - mlir::Type t) { - if (t.isa() || - t.isa()) { - return mlir::IntegerType::get(context, 64); - } else { - return t; - } -} - -namespace { -class BConcreteToBConcreteCAPITypeConverter : public mlir::TypeConverter { - -public: - BConcreteToBConcreteCAPITypeConverter() { - addConversion([](mlir::Type type) { return type; }); - addConversion([&](mlir::concretelang::Concrete::PlaintextType type) { - return convertTypeIfConcreteType(type.getContext(), type); - }); - addConversion([&](mlir::concretelang::Concrete::CleartextType type) { - return convertTypeIfConcreteType(type.getContext(), type); - }); - } -}; - -// Set of functions to generate generic types. -// Generic types are used to add forward declarations without a specific type. -// For example, we may need to add LWE ciphertext of different dimensions, or -// allocate them. All the calls to the C API should be done using this generic -// types, and casting should then be performed back to the appropriate type. - -inline mlir::Type getGenericLweBufferType(mlir::MLIRContext *context) { - return mlir::RankedTensorType::get({-1}, mlir::IntegerType::get(context, 64)); -} - -inline mlir::Type getGenericLweMemrefType(mlir::MLIRContext *context) { - return mlir::MemRefType::get({-1}, mlir::IntegerType::get(context, 64)); -} - -inline mlir::Type getGenericGlweBufferType(mlir::MLIRContext *context) { - return mlir::RankedTensorType::get({-1}, mlir::IntegerType::get(context, 64)); -} - -inline mlir::Type getGenericPlaintextType(mlir::MLIRContext *context) { - return mlir::IntegerType::get(context, 64); -} - -inline mlir::Type getGenericCleartextType(mlir::MLIRContext *context) { - return mlir::IntegerType::get(context, 64); -} - -inline mlir::concretelang::Concrete::LweKeySwitchKeyType -getGenericLweKeySwitchKeyType(mlir::MLIRContext *context) { - return mlir::concretelang::Concrete::LweKeySwitchKeyType::get(context); -} - -inline mlir::concretelang::Concrete::LweBootstrapKeyType -getGenericLweBootstrapKeyType(mlir::MLIRContext *context) { - return mlir::concretelang::Concrete::LweBootstrapKeyType::get(context); -} - -// Insert all forward declarations needed for the pass. -// Should generalize input and output types for all decalarations, and the -// pattern using them would be resposible for casting them to the appropriate -// type. -mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, - mlir::IRRewriter &rewriter) { - auto lweBufferType = getGenericLweMemrefType(rewriter.getContext()); - auto plaintextType = getGenericPlaintextType(rewriter.getContext()); - auto cleartextType = getGenericCleartextType(rewriter.getContext()); - auto keySwitchKeyType = getGenericLweKeySwitchKeyType(rewriter.getContext()); - auto bootstrapKeyType = getGenericLweBootstrapKeyType(rewriter.getContext()); - auto contextType = - mlir::concretelang::Concrete::ContextType::get(rewriter.getContext()); - - // Insert forward declaration of the add_lwe_ciphertexts function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), {lweBufferType, lweBufferType, lweBufferType}, - {}); - if (insertForwardDeclaration(op, rewriter, "memref_add_lwe_ciphertexts_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the add_plaintext_lwe_ciphertext function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), {lweBufferType, lweBufferType, plaintextType}, - {}); - if (insertForwardDeclaration( - op, rewriter, "memref_add_plaintext_lwe_ciphertext_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the mul_cleartext_lwe_ciphertext function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), {lweBufferType, lweBufferType, cleartextType}, - {}); - if (insertForwardDeclaration( - op, rewriter, "memref_mul_cleartext_lwe_ciphertext_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the negate_lwe_ciphertext function - { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - {lweBufferType, lweBufferType}, {}); - if (insertForwardDeclaration(op, rewriter, - "memref_negate_lwe_ciphertext_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the memref_keyswitch_lwe_u64 function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), {lweBufferType, lweBufferType, contextType}, {}); - if (insertForwardDeclaration(op, rewriter, "memref_keyswitch_lwe_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the memref_bootstrap_lwe_u64 function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - {lweBufferType, lweBufferType, lweBufferType, contextType}, {}); - if (insertForwardDeclaration(op, rewriter, "memref_bootstrap_lwe_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - - // Insert forward declaration of the expand_lut_in_trivial_glwe_ct function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - { - getGenericGlweBufferType(rewriter.getContext()), - rewriter.getI32Type(), - rewriter.getI32Type(), - rewriter.getI32Type(), - mlir::RankedTensorType::get( - {-1}, mlir::IntegerType::get(rewriter.getContext(), 64)), - }, - {}); - if (insertForwardDeclaration( - op, rewriter, "memref_expand_lut_in_trivial_glwe_ct_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - - // Insert forward declaration of the getGlobalKeyswitchKey function - { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - {contextType}, {keySwitchKeyType}); - if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the getGlobalBootstrapKey function - { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - {contextType}, {bootstrapKeyType}); - if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - return mlir::success(); -} - -// Replaces an operand `tensor` with -// ``` -// %casted_tensor = tensor.cast %op : tensor to tensor -// %casted_memref = bufferization.to_memref %casted_tensor : memref -// ``` -mlir::Value getCastedTensorOperand(mlir::PatternRewriter &rewriter, - mlir::Location loc, mlir::Value operand) { - mlir::Type operandType = operand.getType(); - if (operandType.isa()) { - mlir::Value castedTensor = rewriter.create( - loc, getGenericLweBufferType(rewriter.getContext()), operand); - - mlir::Value castedMemRef = rewriter.create( - loc, getGenericLweMemrefType(rewriter.getContext()), castedTensor); - return castedMemRef; - } else { - return operand; - } -} - -mlir::SmallVector -getCastedTensorOperands(mlir::PatternRewriter &rewriter, mlir::Operation *op) { - return llvm::to_vector<3>( - llvm::map_range(op->getOperands(), [&](mlir::Value operand) { - return getCastedTensorOperand(rewriter, op->getLoc(), operand); - })); -} - -// template -// mlir::SmallVector -// getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) { -// mlir::SmallVector newOperands{}; -// for (mlir::Value operand : op->getOperands()) { -// mlir::Type operandType = operand.getType(); -// if (operandType.isa()) { -// mlir::Value castedTensor = rewriter.create( -// op.getLoc(), getGenericLweBufferType(rewriter.getContext()), -// operand); - -// mlir::Value castedMemRef = -// rewriter.create( -// op.getLoc(), getGenericLweBufferType(rewriter.getContext()), -// operand); -// newOperands.push_back(castedMemRef); -// } else { -// newOperands.push_back(operand); -// } -// } -// return std::move(newOperands); -// } - -/// BConcreteOpToConcreteCAPICallPattern matches the `BConcreteOp` -/// Operation and replaces it with a call to `funcName`, the funcName should be -/// an external function that is linked later. It inserts the forward -/// declaration of the private `funcName` if it not already in the symbol table. -/// The C signature of the function should be `void (out, args..., -/// lweDimension)`, the pattern rewrites: -/// ``` -/// "%out = BConcreteOp"(args ...) : -/// (tensor...) -> tensor -/// ``` -/// to -/// ``` -/// %args_tensor = tensor.cast ... -/// %args_memref = bufferize.to_memref ... -/// %out_tensor_ranked = linalg.tensor_init ... -// %out_tensor = tensor.cast ... -/// %out_memref = bufferize.to_memref ... -/// call @funcName(%out_memref, %args_memref...) : -/// (memref, memref...) -> () -// %out = bufferize.to_tensor ... -/// ``` -template -struct ConcreteOpToConcreteCAPICallPattern - : public mlir::OpRewritePattern { - ConcreteOpToConcreteCAPICallPattern(mlir::MLIRContext *context, - mlir::StringRef funcName, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern(context, benefit), - funcName(funcName) {} - - mlir::LogicalResult - matchAndRewrite(BConcreteOp op, - mlir::PatternRewriter &rewriter) const override { - BConcreteToBConcreteCAPITypeConverter typeConverter; - - mlir::RankedTensorType tensorResultTy = - op.getResult().getType().template cast(); - - mlir::Value outTensor = rewriter.create( - op.getLoc(), tensorResultTy.getShape(), - tensorResultTy.getElementType()); - - mlir::Value outMemref = - getCastedTensorOperand(rewriter, op.getLoc(), outTensor); - - mlir::SmallVector castedOperands{outMemref}; - castedOperands.append(getCastedTensorOperands(rewriter, op)); - - mlir::func::CallOp callOp = rewriter.create( - op.getLoc(), funcName, mlir::TypeRange{}, castedOperands); - - // Convert remaining, non-tensor types (e.g., plaintext values) - mlir::concretelang::convertOperandAndResultTypes( - rewriter, callOp, [&](mlir::MLIRContext *context, mlir::Type t) { - return typeConverter.convertType(t); - }); - - mlir::Value updatedOutTensor = - rewriter.create(op.getLoc(), - outMemref); - - rewriter.replaceOpWithNewOp(op, tensorResultTy, - updatedOutTensor); - - return mlir::success(); - }; - -private: - std::string funcName; -}; - -struct ConcreteEncodeIntOpPattern - : public mlir::OpRewritePattern { - ConcreteEncodeIntOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern( - context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::EncodeIntOp op, - mlir::PatternRewriter &rewriter) const override { - { - mlir::Value castedInt = rewriter.create( - op.getLoc(), rewriter.getIntegerType(64), op->getOperands().front()); - mlir::Value constantShiftOp = rewriter.create( - op.getLoc(), rewriter.getI64IntegerAttr(64 - op.getType().getP())); - - mlir::Type resultType = rewriter.getIntegerType(64); - rewriter.replaceOpWithNewOp( - op, resultType, castedInt, constantShiftOp); - } - return mlir::success(); - }; -}; - -struct ConcreteIntToCleartextOpPattern - : public mlir::OpRewritePattern< - mlir::concretelang::Concrete::IntToCleartextOp> { - ConcreteIntToCleartextOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern( - context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::IntToCleartextOp op, - mlir::PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, rewriter.getIntegerType(64), op->getOperands().front()); - return mlir::success(); - }; -}; - -mlir::Value getContextArgument(mlir::Operation *op) { - mlir::Block *block = op->getBlock(); - while (block != nullptr) { - if (llvm::isa(block->getParentOp())) { - - mlir::Value context = block->getArguments().back(); - - assert( - context.getType().isa() && - "the Concrete.context should be the last argument of the enclosing " - "function of the op"); - - return context; - } - block = block->getParentOp()->getBlock(); - } - assert("can't find a function that enclose the op"); - return nullptr; -} - -// Rewrite pattern that rewrite every -// ``` -// %out = "BConcrete.keyswitch_lwe_buffer"(%out, %in) {...}: -// (tensor<2049xi64>) -> (tensor<2049xi64>) -// ``` -// -// to -// -// ``` -// %out = linalg.tensor_init [B] : tensor -// %out_casted = tensor.cast %out : tensor to tensor -// %out_memref = bufferize.to_memref %out_casted ... -// %in_casted = tensor.cast %in : tensor to tensor -// %in_memref = bufferize.to_memref ... -// call @memref_keyswitch_lwe_u64(%out_memref, %in_memref) : -// (tensor, !Concrete.context) -> (tensor) -// ``` -struct BConcreteKeySwitchLweOpPattern - : public mlir::OpRewritePattern< - mlir::concretelang::BConcrete::KeySwitchLweBufferOp> { - BConcreteKeySwitchLweOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern< - mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(context, - benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::concretelang::BConcrete::KeySwitchLweBufferOp op, - mlir::PatternRewriter &rewriter) const override { - // Create the output operand - mlir::RankedTensorType tensorResultTy = - op.getResult().getType().template cast(); - mlir::Value outTensor = - rewriter.replaceOpWithNewOp( - op, tensorResultTy.getShape(), tensorResultTy.getElementType()); - mlir::Value outMemref = - getCastedTensorOperand(rewriter, op.getLoc(), outTensor); - - mlir::SmallVector operands{outMemref}; - operands.append(getCastedTensorOperands(rewriter, op)); - operands.push_back(getContextArgument(op)); - - rewriter.create(op.getLoc(), "memref_keyswitch_lwe_u64", - mlir::TypeRange({}), operands); - return mlir::success(); - }; -}; - -// Rewrite pattern that rewrite every -// ``` -// %out = "BConcrete.bootstrap_lwe_buffer"(%in, %acc) {...} : -// (tensor, !Concrete.glwe_ciphertext) -> (tensor) -// ``` -// -// to -// -// ``` -// %out = linalg.tensor_init [B] : tensor -// %out_casted = tensor.cast %out : tensor to tensor -// %out_memref = bufferize.to_memref %out_casted ... -// %in_casted = tensor.cast %in : tensor to tensor -// %in_memref = bufferize.to_memref ... -// call @memref_bootstrap_lwe_u64(%out_memref, %in_memref, %acc_, %ctx) : -// (memref, memref, -// !Concrete.glwe_ciphertext, !Concrete.context) -> () -// ``` -struct BConcreteBootstrapLweOpPattern - : public mlir::OpRewritePattern< - mlir::concretelang::BConcrete::BootstrapLweBufferOp> { - BConcreteBootstrapLweOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern< - mlir::concretelang::BConcrete::BootstrapLweBufferOp>(context, - benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::concretelang::BConcrete::BootstrapLweBufferOp op, - mlir::PatternRewriter &rewriter) const override { - - // Create the output operand - mlir::RankedTensorType tensorResultTy = - op.getResult().getType().template cast(); - mlir::Value outTensor = - rewriter.replaceOpWithNewOp( - op, tensorResultTy.getShape(), tensorResultTy.getElementType()); - mlir::Value outMemref = - getCastedTensorOperand(rewriter, op.getLoc(), outTensor); - - mlir::SmallVector operands{outMemref}; - operands.append(getCastedTensorOperands(rewriter, op)); - operands.push_back(getContextArgument(op)); - - rewriter.create(op.getLoc(), "memref_bootstrap_lwe_u64", - mlir::TypeRange({}), operands); - return mlir::success(); - }; -}; - -// Rewrite pattern that rewrite every -// ``` -// "BConcrete.fill_glwe_table"(%glwe, %lut) {glweDimension=1, -// polynomialSize=2048, outPrecision=3} : -// (tensor<4096xi64>, tensor<32xi64>) -> () -// ``` -// -// to -// -// ``` -// %glweDim = arith.constant 1 : i32 -// %polySize = arith.constant 2048 : i32 -// %outPrecision = arith.constant 3 : i32 -// %glwe_ = tensor.cast %glwe : tensor<4096xi64> to tensor -// %lut_ = tensor.cast %lut : tensor<32xi64> to tensor -// call @expand_lut_in_trivial_glwe_ct(%glwe, %polySize, %glweDim, -// %outPrecision, %lut_) : -// (tensor, i32, i32, tensor) -> () -// ``` -struct BConcreteGlweFromTableOpPattern - : public mlir::OpRewritePattern< - mlir::concretelang::BConcrete::FillGlweFromTable> { - BConcreteGlweFromTableOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern< - mlir::concretelang::BConcrete::FillGlweFromTable>(context, - benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::concretelang::BConcrete::FillGlweFromTable op, - mlir::PatternRewriter &rewriter) const override { - BConcreteToBConcreteCAPITypeConverter typeConverter; - // %glweDim = arith.constant 1 : i32 - // %polySize = arith.constant 2048 : i32 - // %outPrecision = arith.constant 3 : i32 - - auto castedOp = getCastedTensorOperands(rewriter, op); - - auto polySizeOp = rewriter.create( - op.getLoc(), rewriter.getI32IntegerAttr(op.polynomialSize())); - auto glweDimensionOp = rewriter.create( - op.getLoc(), rewriter.getI32IntegerAttr(op.glweDimension())); - auto outPrecisionOp = rewriter.create( - op.getLoc(), rewriter.getI32IntegerAttr(op.outPrecision())); - - mlir::SmallVector newOperands{ - castedOp[0], polySizeOp, glweDimensionOp, outPrecisionOp, castedOp[1]}; - - // getCastedTensor(op.getLoc(), newOperands, rewriter); - // perform operands conversion - // %glwe_ = tensor.cast %glwe : tensor<4096xi64> to tensor - // %lut_ = tensor.cast %lut : tensor<32xi64> to tensor - - // call @expand_lut_in_trivial_glwe_ct(%glwe, %polySize, %glweDim, - // %lut_) : - // (tensor, i32, i32, tensor) -> () - - rewriter.replaceOpWithNewOp( - op, "memref_expand_lut_in_trivial_glwe_ct_u64", - mlir::SmallVector{}, newOperands); - return mlir::success(); - }; -}; - -/// Populate the RewritePatternSet with all patterns that rewrite Concrete -/// operators to the corresponding function call to the `Concrete C API`. -void populateBConcreteToBConcreteCAPICall(mlir::RewritePatternSet &patterns) { - patterns.add>( - patterns.getContext(), "memref_add_lwe_ciphertexts_u64"); - patterns.add>( - patterns.getContext(), "memref_add_plaintext_lwe_ciphertext_u64"); - patterns.add>( - patterns.getContext(), "memref_mul_cleartext_lwe_ciphertext_u64"); - patterns.add>( - patterns.getContext(), "memref_negate_lwe_ciphertext_u64"); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); -} - -struct AddRuntimeContextToFuncOpPattern - : public mlir::OpRewritePattern { - AddRuntimeContextToFuncOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern(context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::func::FuncOp oldFuncOp, - mlir::PatternRewriter &rewriter) const override { - mlir::OpBuilder::InsertionGuard guard(rewriter); - mlir::FunctionType oldFuncType = oldFuncOp.getFunctionType(); - - // Add a Concrete.context to the function signature - mlir::SmallVector newInputs(oldFuncType.getInputs().begin(), - oldFuncType.getInputs().end()); - newInputs.push_back( - rewriter.getType()); - mlir::FunctionType newFuncTy = rewriter.getType( - newInputs, oldFuncType.getResults()); - // Create the new func - mlir::func::FuncOp newFuncOp = rewriter.create( - oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy); - - // Create the arguments of the new func - mlir::Region &newFuncBody = newFuncOp.getBody(); - mlir::Block *newFuncEntryBlock = new mlir::Block(); - llvm::SmallVector locations(newFuncTy.getInputs().size(), - oldFuncOp.getLoc()); - - newFuncEntryBlock->addArguments(newFuncTy.getInputs(), locations); - newFuncBody.push_back(newFuncEntryBlock); - - // Clone the old body to the new one - mlir::BlockAndValueMapping map; - for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) { - map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index())); - } - for (auto &op : oldFuncOp.getBody().front()) { - newFuncEntryBlock->push_back(op.clone(map)); - } - rewriter.eraseOp(oldFuncOp); - return mlir::success(); - } - - // Legal function are one that are private or has a Concrete.context as last - // arguments. - static bool isLegal(mlir::func::FuncOp funcOp) { - if (!funcOp.isPublic()) { - return true; - } - // TODO : Don't need to add a runtime context for function that doesn't - // manipulates Concrete types. - // - // if (!llvm::any_of(funcOp.getType().getInputs(), [](mlir::Type t) { - // if (auto tensorTy = t.dyn_cast_or_null()) { - // t = tensorTy.getElementType(); - // } - // return llvm::isa( - // t.getDialect()); - // })) { - // return true; - // } - return funcOp.getFunctionType().getNumInputs() >= 1 && - funcOp.getFunctionType() - .getInputs() - .back() - .isa(); - } -}; - -namespace { -struct BConcreteToBConcreteCAPIPass - : public BConcreteToBConcreteCAPIBase { - void runOnOperation() final; -}; -} // namespace - -void BConcreteToBConcreteCAPIPass::runOnOperation() { - mlir::ModuleOp op = getOperation(); - - // First of all add the Concrete.context to the block arguments of function - // that manipulates ciphertexts. - { - mlir::ConversionTarget target(getContext()); - mlir::RewritePatternSet patterns(&getContext()); - - target.addDynamicallyLegalOp( - [&](mlir::func::FuncOp funcOp) { - return AddRuntimeContextToFuncOpPattern::isLegal(funcOp); - }); - - patterns.add(patterns.getContext()); - - // Apply the conversion - if (mlir::applyPartialConversion(op, target, std::move(patterns)) - .failed()) { - this->signalPassFailure(); - return; - } - } - - // Insert forward declaration - mlir::IRRewriter rewriter(&getContext()); - if (insertForwardDeclarations(op, rewriter).failed()) { - this->signalPassFailure(); - } - // Rewrite Concrete ops to CallOp to the Concrete C API - { - mlir::ConversionTarget target(getContext()); - mlir::RewritePatternSet patterns(&getContext()); - - target.addIllegalDialect(); - - target.addLegalDialect(); - - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - - populateBConcreteToBConcreteCAPICall(patterns); - - if (mlir::applyPartialConversion(op, target, std::move(patterns)) - .failed()) { - this->signalPassFailure(); - } - } -} -} // namespace - -namespace mlir { -namespace concretelang { -std::unique_ptr> -createConvertBConcreteToBConcreteCAPIPass() { - return std::make_unique(); -} -} // namespace concretelang -} // namespace mlir diff --git a/compiler/lib/Conversion/BConcreteToBConcreteCAPI/CMakeLists.txt b/compiler/lib/Conversion/BConcreteToBConcreteCAPI/CMakeLists.txt deleted file mode 100644 index ab2b16111..000000000 --- a/compiler/lib/Conversion/BConcreteToBConcreteCAPI/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -add_mlir_dialect_library(BConcreteToBConcreteCAPI - BConcreteToBConcreteCAPI.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE - - DEPENDS - BConcreteDialect - mlir-headers - - LINK_LIBS PUBLIC - ConcretelangConversion - MLIRIR - MLIRTransforms -) - -target_link_libraries(BConcreteToBConcreteCAPI PUBLIC MLIRIR) diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index d8ea78739..428de75c1 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -3,7 +3,6 @@ add_subdirectory(TFHEGlobalParametrization) add_subdirectory(TFHEToConcrete) add_subdirectory(FHETensorOpsToLinalg) add_subdirectory(ConcreteToBConcrete) -add_subdirectory(BConcreteToBConcreteCAPI) add_subdirectory(MLIRLowerableDialectsToLLVM) add_subdirectory(LinalgExtras) diff --git a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp index 74eba9d8d..3de07402d 100644 --- a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp +++ b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp @@ -154,11 +154,7 @@ llvm::Optional MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) { if (type.isa() || type.isa() || - type.isa() || - type.isa() || type.isa() || - type.isa() || - type.isa() || type.isa()) { return mlir::LLVM::LLVMPointerType::get( mlir::IntegerType::get(type.getContext(), 64)); diff --git a/compiler/lib/Dialect/Concrete/IR/ConcreteDialect.cpp b/compiler/lib/Dialect/Concrete/IR/ConcreteDialect.cpp index 007baf8c3..a9bf3c96f 100644 --- a/compiler/lib/Dialect/Concrete/IR/ConcreteDialect.cpp +++ b/compiler/lib/Dialect/Concrete/IR/ConcreteDialect.cpp @@ -152,62 +152,3 @@ mlir::Type PlaintextType::parse(mlir::AsmParser &parser) { return getChecked(loc, loc.getContext(), p); } - -mlir::Type PlaintextListType::parse(mlir::AsmParser &parser) { - return get(parser.getContext()); -} - -void PlaintextListType::print(mlir::AsmPrinter &p) const {} - -mlir::Type ForeignPlaintextListType::parse(mlir::AsmParser &parser) { - return get(parser.getContext()); -} - -void ForeignPlaintextListType::print(mlir::AsmPrinter &p) const {} - -mlir::Type LweKeySwitchKeyType::parse(mlir::AsmParser &parser) { - return get(parser.getContext()); -} - -void LweKeySwitchKeyType::print(mlir::AsmPrinter &p) const {} - -mlir::Type LweBootstrapKeyType::parse(mlir::AsmParser &parser) { - return get(parser.getContext()); -} - -void LweBootstrapKeyType::print(mlir::AsmPrinter &p) const {} - -void ContextType::print(mlir::AsmPrinter &p) const {} - -mlir::Type ContextType::parse(mlir::AsmParser &parser) { - return get(parser.getContext()); -} - -::mlir::Type -ConcreteDialect::parseType(::mlir::DialectAsmParser &parser) const { - mlir::Type type; - - std::string types_str[] = { - "plaintext", "plaintext_list", "foreign_plaintext_list", - "lwe_ciphertext", "lwe_key_switch_key", "lwe_bootstrap_key", - "glwe_ciphertext", "cleartext", "context", - }; - - for (const std::string &type_str : types_str) { - if (parser.parseOptionalKeyword(type_str).succeeded()) { - generatedTypeParser(parser, type_str, type); - return type; - } - } - - parser.emitError(parser.getCurrentLocation(), "Unknown Concrete type"); - - return type; -} - -void ConcreteDialect::printType(::mlir::Type type, - ::mlir::DialectAsmPrinter &printer) const { - if (generatedTypePrinter(type, printer).failed()) - // Calling default printer if failed to print Concrete type - printer.printType(type); -} diff --git a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp index 1cbd5b59a..d5e44b07f 100644 --- a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp +++ b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp @@ -166,11 +166,7 @@ static mlir::Value getSizeInBytes(Value val, Location loc, OpBuilder builder) { // bytes until we can get the actual size of the actual types. if (type.isa() || type.isa() || - type.isa() || - type.isa() || - type.isa() || - type.isa() || - type.isa()) + type.isa()) return builder.create(loc, builder.getI64IntegerAttr(8)); // For all other types, get type size. diff --git a/compiler/tests/Dialect/Concrete/Concrete/types.mlir b/compiler/tests/Dialect/Concrete/Concrete/types.mlir index c17f352c9..f1f217f38 100644 --- a/compiler/tests/Dialect/Concrete/Concrete/types.mlir +++ b/compiler/tests/Dialect/Concrete/Concrete/types.mlir @@ -7,36 +7,12 @@ func @type_plaintext(%arg0: !Concrete.plaintext<7>) -> !Concrete.plaintext<7> { return %arg0: !Concrete.plaintext<7> } -// CHECK-LABEL: func @type_plaintext_list(%arg0: !Concrete.plaintext_list) -> !Concrete.plaintext_list -func @type_plaintext_list(%arg0: !Concrete.plaintext_list) -> !Concrete.plaintext_list { - // CHECK-NEXT: return %arg0 : !Concrete.plaintext_list - return %arg0: !Concrete.plaintext_list -} - -// CHECK-LABEL: func @type_foreign_plaintext_list(%arg0: !Concrete.foreign_plaintext_list) -> !Concrete.foreign_plaintext_list -func @type_foreign_plaintext_list(%arg0: !Concrete.foreign_plaintext_list) -> !Concrete.foreign_plaintext_list { - // CHECK-NEXT: return %arg0 : !Concrete.foreign_plaintext_list - return %arg0: !Concrete.foreign_plaintext_list -} - // CHECK-LABEL: func @type_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> func @type_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { // CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext<2048,7> return %arg0: !Concrete.lwe_ciphertext<2048,7> } -// CHECK-LABEL: func @type_lwe_key_switch_key(%arg0: !Concrete.lwe_key_switch_key) -> !Concrete.lwe_key_switch_key -func @type_lwe_key_switch_key(%arg0: !Concrete.lwe_key_switch_key) -> !Concrete.lwe_key_switch_key { - // CHECK-NEXT: return %arg0 : !Concrete.lwe_key_switch_key - return %arg0: !Concrete.lwe_key_switch_key -} - -// CHECK-LABEL: func @type_lwe_bootstrap_key(%arg0: !Concrete.lwe_bootstrap_key) -> !Concrete.lwe_bootstrap_key -func @type_lwe_bootstrap_key(%arg0: !Concrete.lwe_bootstrap_key) -> !Concrete.lwe_bootstrap_key { - // CHECK-NEXT: return %arg0 : !Concrete.lwe_bootstrap_key - return %arg0: !Concrete.lwe_bootstrap_key -} - // CHECK-LABEL: func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5> func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5> { // CHECK-NEXT: return %arg0 : !Concrete.cleartext<5>