mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
refactor(compiler): Introduce compilation pipeline with multiple entries / exits
This refactoring commit restructures the compilation pipeline of
`zamacompiler`, such that it is possible to enter and exit the
pipeline at different points, effectively defining the level of
abstraction at the input and the required level of abstraction for the
output.
The entry point is specified using the `--entry-dialect`
argument. Valid choices are:
`--entry-dialect=hlfhe`: Source contains HLFHE operations
`--entry-dialect=midlfhe`: Source contains MidLFHE operations
`--entry-dialect=lowlfhe`: Source contains LowLFHE operations
`--entry-dialect=std`: Source does not contain any FHE Operations
`--entry-dialect=llvm`: Source is in LLVM dialect
The exit point is defined by an action, specified using --action.
`--action=roundtrip`:
Parse the source file to in-memory representation and immediately
dump as text without any processing
`--action=dump-midlfhe`:
Lower source to MidLFHE and dump result as text
`--action=dump-lowlfhe`:
Lower source to LowLFHE and dump result as text
`--action=dump-std`:
Lower source to only standard MLIR dialects (i.e., all FHE
operations have already been lowered)
`--action=dump-llvm-dialect`:
Lower source to MLIR's LLVM dialect (i.e., the LLVM dialect, not
LLVM IR)
`--action=dump-llvm-ir`:
Lower source to plain LLVM IR (i.e., not the LLVM dialect, but
actual LLVM IR)
`--action=dump-optimized-llvm-ir`:
Lower source to plain LLVM IR (i.e., not the LLVM dialect, but
actual LLVM IR), pass the result through the LLVM optimizer and
print the result.
`--action=dump-jit-invoke`:
Execute the full lowering pipeline to optimized LLVM IR, JIT
compile the result, invoke the function specified in
`--jit-funcname` with the parameters from `--jit-args` and print
the functions return value.
This commit is contained in:
committed by
Quentin Bourgerie
parent
ddebedd1d6
commit
30374ebb2c
@@ -19,7 +19,7 @@ struct V0Parameter {
|
||||
size_t ksLevel;
|
||||
size_t ksLogBase;
|
||||
|
||||
V0Parameter() {}
|
||||
V0Parameter() = delete;
|
||||
|
||||
V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel,
|
||||
size_t brLogBase, size_t ksLevel, size_t ksLogBase)
|
||||
@@ -31,11 +31,14 @@ struct V0Parameter {
|
||||
};
|
||||
|
||||
struct V0FHEContext {
|
||||
V0FHEContext() = delete;
|
||||
V0FHEContext(const V0FHEConstraint &constraint, const V0Parameter ¶meter)
|
||||
: constraint(constraint), parameter(parameter) {}
|
||||
|
||||
V0FHEConstraint constraint;
|
||||
V0Parameter parameter;
|
||||
};
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -1,17 +1,7 @@
|
||||
#ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H
|
||||
#define ZAMALANG_SUPPORT_COMPILER_ENGINE_H
|
||||
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Support/CompilerTools.h"
|
||||
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <string>
|
||||
#include "Jit.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
@@ -55,4 +45,4 @@ private:
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
#ifndef ZAMALANG_SUPPORT_COMPILERTOOLS_H_
|
||||
#define ZAMALANG_SUPPORT_COMPILERTOOLS_H_
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
|
||||
#include <mlir/ExecutionEngine/ExecutionEngine.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
|
||||
#include "zamalang/Support/ClientParameters.h"
|
||||
#include "zamalang/Support/KeySet.h"
|
||||
#include "zamalang/Support/V0Parameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
class CompilerTools {
|
||||
public:
|
||||
struct LowerOptions {
|
||||
llvm::function_ref<bool(std::string)> enablePass;
|
||||
bool verbose;
|
||||
|
||||
LowerOptions()
|
||||
: verbose(false), enablePass([](std::string pass) { return true; }){};
|
||||
};
|
||||
|
||||
/// lowerHLFHEToMlirLLVMDialect run all passes to lower FHE dialects to mlir
|
||||
/// lowerable to llvm dialect.
|
||||
/// The given module MLIR operation would be modified and the constraints set.
|
||||
static mlir::LogicalResult
|
||||
lowerHLFHEToMlirStdsDialect(mlir::MLIRContext &context,
|
||||
mlir::Operation *module, V0FHEContext &fheContext,
|
||||
LowerOptions options = LowerOptions());
|
||||
|
||||
/// lowerMlirStdsDialectToMlirLLVMDialect run all passes to lower MLIR
|
||||
/// dialects to MLIR LLVM dialect. The given module MLIR operation would be
|
||||
/// modified.
|
||||
static mlir::LogicalResult
|
||||
lowerMlirStdsDialectToMlirLLVMDialect(mlir::MLIRContext &context,
|
||||
mlir::Operation *module,
|
||||
LowerOptions options = LowerOptions());
|
||||
|
||||
static llvm::Expected<std::unique_ptr<llvm::Module>>
|
||||
toLLVMModule(llvm::LLVMContext &llvmContext, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
|
||||
};
|
||||
|
||||
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
|
||||
/// of the module.
|
||||
class JITLambda {
|
||||
public:
|
||||
class Argument {
|
||||
public:
|
||||
Argument(KeySet &keySet);
|
||||
~Argument();
|
||||
|
||||
// Create lambda Argument that use the given KeySet to perform encryption
|
||||
// and decryption operations.
|
||||
static llvm::Expected<std::unique_ptr<Argument>> create(KeySet &keySet);
|
||||
|
||||
// Set a scalar argument at the given pos as a uint64_t.
|
||||
llvm::Error setArg(size_t pos, uint64_t arg);
|
||||
|
||||
// Set a argument at the given pos as a tensor of int64.
|
||||
llvm::Error setArg(size_t pos, uint64_t *data, size_t size) {
|
||||
return setArg(pos, 64, (void *)data, size);
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of int32.
|
||||
llvm::Error setArg(size_t pos, uint32_t *data, size_t size) {
|
||||
return setArg(pos, 32, (void *)data, size);
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of int32.
|
||||
llvm::Error setArg(size_t pos, uint16_t *data, size_t size) {
|
||||
return setArg(pos, 16, (void *)data, size);
|
||||
}
|
||||
|
||||
// Set a tensor argument at the given pos as a uint64_t.
|
||||
llvm::Error setArg(size_t pos, uint8_t *data, size_t size) {
|
||||
return setArg(pos, 8, (void *)data, size);
|
||||
}
|
||||
|
||||
// Get the result at the given pos as an uint64_t.
|
||||
llvm::Error getResult(size_t pos, uint64_t &res);
|
||||
|
||||
// Fill the result.
|
||||
llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
|
||||
|
||||
private:
|
||||
llvm::Error setArg(size_t pos, size_t width, void *data, size_t size);
|
||||
|
||||
friend JITLambda;
|
||||
// Store the pointer on inputs values and outputs values
|
||||
std::vector<void *> rawArg;
|
||||
// Store the values of inputs
|
||||
std::vector<void *> inputs;
|
||||
// Store the values of outputs
|
||||
std::vector<void *> outputs;
|
||||
// Store the input gates description and the offset of the argument.
|
||||
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> inputGates;
|
||||
// Store the outputs gates description and the offset of the argument.
|
||||
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> outputGates;
|
||||
// Store allocated lwe ciphertexts (for free)
|
||||
std::vector<LweCiphertext_u64 *> allocatedCiphertexts;
|
||||
// Store buffers of ciphertexts
|
||||
std::vector<LweCiphertext_u64 **> ciphertextBuffers;
|
||||
|
||||
KeySet &keySet;
|
||||
};
|
||||
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
|
||||
: type(type), name(name){};
|
||||
|
||||
/// create a JITLambda that point to the function name of the given module.
|
||||
static llvm::Expected<std::unique_ptr<JITLambda>>
|
||||
create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
|
||||
|
||||
/// invokeRaw execute the jit lambda with a list of Argument, the last one is
|
||||
/// used to store the result of the computation.
|
||||
/// Example:
|
||||
/// uin64_t arg0 = 1;
|
||||
/// uin64_t res;
|
||||
/// llvm::SmallVector<void *> args{&arg1, &res};
|
||||
/// lambda.invokeRaw(args);
|
||||
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
|
||||
|
||||
/// invoke the jit lambda with the Argument.
|
||||
llvm::Error invoke(Argument &args);
|
||||
|
||||
private:
|
||||
mlir::LLVM::LLVMFunctionType type;
|
||||
llvm::StringRef name;
|
||||
std::unique_ptr<mlir::ExecutionEngine> engine;
|
||||
};
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -1,11 +1,12 @@
|
||||
#ifndef COMPILER_JIT_H
|
||||
#define COMPILER_JIT_H
|
||||
|
||||
#include "zamalang/Support/CompilerTools.h"
|
||||
|
||||
#include <mlir/ExecutionEngine/ExecutionEngine.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
|
||||
#include <zamalang/Support/KeySet.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
mlir::LogicalResult
|
||||
@@ -13,6 +14,96 @@ runJit(mlir::ModuleOp module, llvm::StringRef func,
|
||||
llvm::ArrayRef<uint64_t> funcArgs, mlir::zamalang::KeySet &keySet,
|
||||
std::function<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::raw_ostream &os);
|
||||
|
||||
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
|
||||
/// of the module.
|
||||
class JITLambda {
|
||||
public:
|
||||
class Argument {
|
||||
public:
|
||||
Argument(KeySet &keySet);
|
||||
~Argument();
|
||||
|
||||
// Create lambda Argument that use the given KeySet to perform encryption
|
||||
// and decryption operations.
|
||||
static llvm::Expected<std::unique_ptr<Argument>> create(KeySet &keySet);
|
||||
|
||||
// Set a scalar argument at the given pos as a uint64_t.
|
||||
llvm::Error setArg(size_t pos, uint64_t arg);
|
||||
|
||||
// Set a argument at the given pos as a tensor of int64.
|
||||
llvm::Error setArg(size_t pos, uint64_t *data, size_t size) {
|
||||
return setArg(pos, 64, (void *)data, size);
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of int32.
|
||||
llvm::Error setArg(size_t pos, uint32_t *data, size_t size) {
|
||||
return setArg(pos, 32, (void *)data, size);
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of int32.
|
||||
llvm::Error setArg(size_t pos, uint16_t *data, size_t size) {
|
||||
return setArg(pos, 16, (void *)data, size);
|
||||
}
|
||||
|
||||
// Set a tensor argument at the given pos as a uint64_t.
|
||||
llvm::Error setArg(size_t pos, uint8_t *data, size_t size) {
|
||||
return setArg(pos, 8, (void *)data, size);
|
||||
}
|
||||
|
||||
// Get the result at the given pos as an uint64_t.
|
||||
llvm::Error getResult(size_t pos, uint64_t &res);
|
||||
|
||||
// Fill the result.
|
||||
llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
|
||||
|
||||
private:
|
||||
llvm::Error setArg(size_t pos, size_t width, void *data, size_t size);
|
||||
|
||||
friend JITLambda;
|
||||
// Store the pointer on inputs values and outputs values
|
||||
std::vector<void *> rawArg;
|
||||
// Store the values of inputs
|
||||
std::vector<void *> inputs;
|
||||
// Store the values of outputs
|
||||
std::vector<void *> outputs;
|
||||
// Store the input gates description and the offset of the argument.
|
||||
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> inputGates;
|
||||
// Store the outputs gates description and the offset of the argument.
|
||||
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> outputGates;
|
||||
// Store allocated lwe ciphertexts (for free)
|
||||
std::vector<LweCiphertext_u64 *> allocatedCiphertexts;
|
||||
// Store buffers of ciphertexts
|
||||
std::vector<LweCiphertext_u64 **> ciphertextBuffers;
|
||||
|
||||
KeySet &keySet;
|
||||
};
|
||||
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
|
||||
: type(type), name(name){};
|
||||
|
||||
/// create a JITLambda that point to the function name of the given module.
|
||||
static llvm::Expected<std::unique_ptr<JITLambda>>
|
||||
create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
|
||||
|
||||
/// invokeRaw execute the jit lambda with a list of Argument, the last one is
|
||||
/// used to store the result of the computation.
|
||||
/// Example:
|
||||
/// uin64_t arg0 = 1;
|
||||
/// uin64_t res;
|
||||
/// llvm::SmallVector<void *> args{&arg1, &res};
|
||||
/// lambda.invokeRaw(args);
|
||||
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
|
||||
|
||||
/// invoke the jit lambda with the Argument.
|
||||
llvm::Error invoke(Argument &args);
|
||||
|
||||
private:
|
||||
mlir::LLVM::LLVMFunctionType type;
|
||||
llvm::StringRef name;
|
||||
std::unique_ptr<mlir::ExecutionEngine> engine;
|
||||
};
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
42
compiler/include/zamalang/Support/Pipeline.h
Normal file
42
compiler/include/zamalang/Support/Pipeline.h
Normal file
@@ -0,0 +1,42 @@
|
||||
#ifndef ZAMALANG_SUPPORT_PIPELINE_H_
|
||||
#define ZAMALANG_SUPPORT_PIPELINE_H_
|
||||
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
#include <zamalang/Support/V0Parameters.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
namespace pipeline {
|
||||
|
||||
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool verbose);
|
||||
|
||||
mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
V0FHEContext &fheContext,
|
||||
bool parametrize);
|
||||
|
||||
mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module);
|
||||
|
||||
mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool verbose);
|
||||
|
||||
mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
|
||||
llvm::Module &module);
|
||||
|
||||
mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
V0FHEContext &fheContext, bool verbose);
|
||||
|
||||
std::unique_ptr<llvm::Module>
|
||||
lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context,
|
||||
llvm::LLVMContext &llvmContext,
|
||||
mlir::ModuleOp &module);
|
||||
} // namespace pipeline
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user