Mirror of https://github.com/zama-ai/concrete.git
docs: use consistent style for comment blocks
prefix comment blocks with ///
@@ -17,13 +17,13 @@
 extern "C" {
 #endif
 
-// C wrapper of the mlir::concretelang::LambdaArgument
+/// C wrapper of the mlir::concretelang::LambdaArgument
 struct lambdaArgument {
   std::shared_ptr<mlir::concretelang::LambdaArgument> ptr;
 };
 typedef struct lambdaArgument lambdaArgument;
 
-// Hold a list of lambdaArgument to represent execution arguments
+/// Hold a list of lambdaArgument to represent execution arguments
 struct executionArguments {
   lambdaArgument *data;
   size_t size;
@@ -136,13 +136,13 @@ evaluationKeysUnserialize(const std::string &buffer);
 MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize(
     concretelang::clientlib::EvaluationKeys &evaluationKeys);
 
-// Parse then print a textual representation of an MLIR module
+/// Parse then print a textual representation of an MLIR module
 MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
 
-// Terminate parallelization
+/// Terminate parallelization
 MLIR_CAPI_EXPORTED void terminateParallelization();
 
-// Create a lambdaArgument from a tensor of different data types
+/// Create a lambdaArgument from a tensor of different data types
 MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8(
     std::vector<uint8_t> data, std::vector<int64_t> dimensions);
 MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU16(
@@ -151,22 +151,22 @@ MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU32(
     std::vector<uint32_t> data, std::vector<int64_t> dimensions);
 MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU64(
     std::vector<uint64_t> data, std::vector<int64_t> dimensions);
-// Create a lambdaArgument from a scalar
+/// Create a lambdaArgument from a scalar
 MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromScalar(uint64_t scalar);
-// Check if a lambdaArgument holds a tensor
+/// Check if a lambdaArgument holds a tensor
 MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg);
-// Get tensor data from lambdaArgument
+/// Get tensor data from lambdaArgument
 MLIR_CAPI_EXPORTED std::vector<uint64_t>
 lambdaArgumentGetTensorData(lambdaArgument &lambda_arg);
-// Get tensor dimensions from lambdaArgument
+/// Get tensor dimensions from lambdaArgument
 MLIR_CAPI_EXPORTED std::vector<int64_t>
 lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg);
-// Check if a lambdaArgument holds a scalar
+/// Check if a lambdaArgument holds a scalar
 MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg);
-// Get scalar value from lambdaArgument
+/// Get scalar value from lambdaArgument
 MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg);
 
-// Compile the textual representation of MLIR modules to a library.
+/// Compile the textual representation of MLIR modules to a library.
 MLIR_CAPI_EXPORTED std::string library(std::string libraryPath,
                                        std::vector<std::string> modules);
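Taken together, the C wrapper API above composes into a short round trip. The following is a hypothetical sketch (not part of the commit), assuming this header is included and the program is linked against the compiler API library; error handling is omitted:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// lambdaArgument and the lambdaArgument* helpers are declared in the
// header shown in the diff above.

int main() {
  // Wrap a 2x2 tensor of bytes as a lambdaArgument.
  std::vector<uint8_t> data = {1, 2, 3, 4};
  std::vector<int64_t> dims = {2, 2};
  lambdaArgument tensorArg = lambdaArgumentFromTensorU8(data, dims);

  // Wrap a scalar as well.
  lambdaArgument scalarArg = lambdaArgumentFromScalar(42);

  // Inspect the arguments through the query helpers.
  if (lambdaArgumentIsTensor(tensorArg)) {
    std::vector<uint64_t> values = lambdaArgumentGetTensorData(tensorArg);
    std::vector<int64_t> shape = lambdaArgumentGetTensorDimensions(tensorArg);
    std::cout << "tensor of " << values.size() << " elements, rank "
              << shape.size() << "\n";
  }
  if (lambdaArgumentIsScalar(scalarArg))
    std::cout << "scalar = " << lambdaArgumentGetScalar(scalarArg) << "\n";
  return 0;
}
```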
@@ -29,8 +29,8 @@ using tensor1_out = std::vector<scalar_out>;
 using tensor2_out = std::vector<std::vector<scalar_out>>;
 using tensor3_out = std::vector<std::vector<std::vector<scalar_out>>>;
 
+/// Low-level class to create the client side view of a FHE function.
 class ClientLambda {
-/// Low-level class to create the client side view of a FHE function.
 public:
   virtual ~ClientLambda() = default;
@@ -104,11 +104,11 @@ static inline bool operator==(const EncryptionGate &lhs,
 }
 
 struct CircuitGateShape {
-  // Width of the scalar value
+  /// Width of the scalar value
   size_t width;
-  // Dimensions of the tensor, empty if scalar
+  /// Dimensions of the tensor, empty if scalar
   std::vector<int64_t> dimensions;
-  // Size of the buffer containing the tensor
+  /// Size of the buffer containing the tensor
   size_t size;
 };
 static inline bool operator==(const CircuitGateShape &lhs,
@@ -23,11 +23,11 @@ using concretelang::error::StringError;
 
 class PublicArguments;
 
+/// Temporary object used to hold and encrypt parameters before calling a
+/// ClientLambda. Preferably use TypeClientLambda and serializeCall(Args...).
+/// Otherwise convert it to a PublicArguments and use
+/// serializeCall(PublicArguments, KeySet).
 class EncryptedArguments {
-/// Temporary object used to hold and encrypt parameters before calling a
-/// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...).
-/// Otherwise convert it to a PublicArguments and use
-/// serializeCall(PublicArguments, KeySet).
 public:
   EncryptedArguments() : currentPos(0) {}
@@ -64,18 +64,18 @@ public:
                RuntimeContext runtimeContext);
 
   /// Check that all arguments have been pushed.
-  /// TODO: Remove public method here
+  // TODO: Remove public method here
   outcome::checked<void, StringError> checkAllArgs(KeySet &keySet);
 
 public:
-  // Add a uint64_t scalar argument.
+  /// Add a uint64_t scalar argument.
   outcome::checked<void, StringError> pushArg(uint64_t arg, KeySet &keySet);
 
   /// Add a vector-tensor argument.
   outcome::checked<void, StringError> pushArg(std::vector<uint8_t> arg,
                                               KeySet &keySet);
 
-  // Add a 1D tensor argument with data and size of the dimension.
+  /// Add a 1D tensor argument with data and size of the dimension.
   template <typename T>
   outcome::checked<void, StringError> pushArg(const T *data, int64_t dim1,
                                               KeySet &keySet) {
@@ -114,14 +114,14 @@ public:
 
   // Generalize by computing shape by template recursion
 
-  // Set a argument at the given pos as a 1D tensor of T.
+  /// Set an argument at the given pos as a 1D tensor of T.
   template <typename T>
   outcome::checked<void, StringError> pushArg(T *data, int64_t dim1,
                                               KeySet &keySet) {
     return pushArg<T>(data, llvm::ArrayRef<int64_t>(&dim1, 1), keySet);
   }
 
-  // Set a argument at the given pos as a tensor of T.
+  /// Set an argument at the given pos as a tensor of T.
   template <typename T>
   outcome::checked<void, StringError>
   pushArg(T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
@@ -133,8 +133,8 @@ public:
                                               llvm::ArrayRef<int64_t> shape,
                                               KeySet &keySet);
 
-  // Recursive case for scalars: extract first scalar argument from
-  // parameter pack and forward rest
+  /// Recursive case for scalars: extract first scalar argument from
+  /// parameter pack and forward rest
   template <typename Arg0, typename... OtherArgs>
   outcome::checked<void, StringError> pushArgs(KeySet &keySet, Arg0 arg0,
                                                OtherArgs... others) {
@@ -142,8 +142,8 @@ public:
     return pushArgs(keySet, others...);
   }
 
-  // Recursive case for tensors: extract pointer and size from
-  // parameter pack and forward rest
+  /// Recursive case for tensors: extract pointer and size from
+  /// parameter pack and forward rest
   template <typename Arg0, typename... OtherArgs>
   outcome::checked<void, StringError>
   pushArgs(KeySet &keySet, Arg0 *arg0, size_t size, OtherArgs... others) {
@@ -151,7 +151,7 @@ public:
     return pushArgs(keySet, others...);
   }
 
-  // Terminal case of pushArgs
+  /// Terminal case of pushArgs
   outcome::checked<void, StringError> pushArgs(KeySet &keySet) {
     return checkAllArgs(keySet);
   }
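The pushArgs overloads in these hunks implement a classic variadic-template peel: each call consumes either one scalar or one pointer-plus-size pair from the head of the parameter pack and recurses on the rest, with the nullary overload as the terminal case that runs checkAllArgs. A self-contained sketch of the same technique, stripped of the Concrete-specific KeySet and outcome types (an illustration, not the library's code):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>

// Terminal case: the pack is empty, run the final check.
inline bool pushArgs() {
  std::cout << "all arguments pushed\n";
  return true;
}

// Recursive case for scalars: handle arg0, forward the rest.
template <typename Arg0, typename... OtherArgs>
bool pushArgs(Arg0 arg0, OtherArgs... others) {
  std::cout << "scalar: " << arg0 << "\n";
  return pushArgs(others...);
}

// Recursive case for tensors: consume a pointer plus its size; partial
// ordering prefers this overload when the head of the pack is a pointer.
template <typename Arg0, typename... OtherArgs>
bool pushArgs(Arg0 *arg0, size_t size, OtherArgs... others) {
  std::cout << "tensor of " << size << " elements\n";
  return pushArgs(others...);
}

int main() {
  uint64_t tensor[3] = {1, 2, 3};
  // Mixed scalar / (pointer, size) argument list, as in EncryptedArguments.
  return pushArgs(uint64_t{7}, tensor, size_t{3}, uint64_t{9}) ? 0 : 1;
}
```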
@@ -160,11 +160,11 @@ private:
   outcome::checked<void, StringError> checkPushTooManyArgs(KeySet &keySet);
 
 private:
-  // Position of the next pushed argument
+  /// Position of the next pushed argument
   size_t currentPos;
   std::vector<void *> preparedArgs;
 
-  // Store buffers of ciphertexts
+  /// Store buffers of ciphertexts
   std::vector<TensorData> ciphertextBuffers;
 };
@@ -32,43 +32,43 @@ public:
   ~KeySet();
   KeySet(KeySet &other) = delete;
 
-  // allocate a KeySet according the ClientParameters.
+  /// Allocate a KeySet according to the ClientParameters.
   static outcome::checked<std::unique_ptr<KeySet>, StringError>
   generate(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb);
 
-  // isInputEncrypted return true if the input at the given pos is encrypted.
+  /// isInputEncrypted returns true if the input at the given pos is encrypted.
   bool isInputEncrypted(size_t pos);
 
-  // getInputLweSecretKeyParam returns the parameters of the lwe secret key for
-  // the input at the given `pos`.
-  // The input must be encrupted
+  /// getInputLweSecretKeyParam returns the parameters of the lwe secret key
+  /// for the input at the given `pos`. The input must be encrypted.
   LweSecretKeyParam getInputLweSecretKeyParam(size_t pos) {
     auto gate = inputGate(pos);
     auto inputSk = this->secretKeys.find(gate.encryption->secretKeyID);
     return inputSk->second.first;
   }
 
-  // getOutputLweSecretKeyParam returns the parameters of the lwe secret key for
-  // the given output.
+  /// getOutputLweSecretKeyParam returns the parameters of the lwe secret key
+  /// for the given output.
   LweSecretKeyParam getOutputLweSecretKeyParam(size_t pos) {
     auto gate = outputGate(pos);
     auto outputSk = this->secretKeys.find(gate.encryption->secretKeyID);
     return outputSk->second.first;
   }
 
-  // allocate a lwe ciphertext buffer for the argument at argPos, set the size
-  // of the allocated buffer.
+  /// Allocate an lwe ciphertext buffer for the argument at argPos and set the
+  /// size of the allocated buffer.
   outcome::checked<void, StringError>
   allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size);
 
-  // encrypt the input to the ciphertext for the argument at argPos.
+  /// Encrypt the input to the ciphertext for the argument at argPos.
   outcome::checked<void, StringError>
   encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input);
 
-  // isOuputEncrypted return true if the output at the given pos is encrypted.
+  /// isOutputEncrypted returns true if the output at the given pos is encrypted.
   bool isOutputEncrypted(size_t pos);
 
-  // decrypt the ciphertext to the output for the argument at argPos.
+  /// Decrypt the ciphertext to the output for the argument at argPos.
   outcome::checked<void, StringError>
   decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output);
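A rough sketch of how these KeySet entry points chain together on the client side — hypothetical wiring, assuming a valid ClientParameters instance and the clientlib namespace; the outcome::checked error paths are elided and must be handled in real code:

```cpp
// Hypothetical client-side round trip over the KeySet API shown above.
void roundTrip(concretelang::clientlib::ClientParameters &params) {
  using concretelang::clientlib::KeySet;

  // Generate the secret keys described by the client parameters.
  auto keySetOrErr = KeySet::generate(params, /*seed_msb=*/0, /*seed_lsb=*/0);
  auto &keySet = *keySetOrErr.value();

  // Allocate a ciphertext buffer for argument 0, then encrypt into it.
  uint64_t *ciphertext = nullptr;
  uint64_t size = 0;
  keySet.allocate_lwe(/*argPos=*/0, &ciphertext, size);
  keySet.encrypt_lwe(/*argPos=*/0, ciphertext, /*input=*/5);

  // ... the server-side lambda would run on the ciphertext here ...

  uint64_t output = 0;
  keySet.decrypt_lwe(/*argPos=*/0, ciphertext, output);
}
```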
@@ -32,9 +32,10 @@ namespace clientlib {
 using concretelang::error::StringError;
 
 class EncryptedArguments;
 
+/// PublicArguments will be sent to the server. It includes encrypted
+/// arguments and public keys.
 class PublicArguments {
-/// PublicArguments will be sended to the server. It includes encrypted
-/// arguments and public keys.
 public:
   PublicArguments(const ClientParameters &clientParameters,
                   std::vector<void *> &&preparedArgs,
@@ -56,13 +57,13 @@ private:
 
   ClientParameters clientParameters;
   std::vector<void *> preparedArgs;
-  // Store buffers of ciphertexts
+  /// Store buffers of ciphertexts
   std::vector<TensorData> ciphertextBuffers;
 };
 
+/// PublicResult is a result of a ServerLambda call which contains encrypted
+/// results.
 struct PublicResult {
-  /// PublicResult is a result of a ServerLambda call which contains encrypted
-  /// results.
 
   PublicResult(const ClientParameters &clientParameters,
                std::vector<TensorData> buffers = {})
@@ -52,10 +52,10 @@ PlaintextType convertPlaintextTypeFromPType(mlir::MLIRContext *context,
   return PlaintextType::get(context, type.getP() + 1);
 }
 
-// convertPlaintextTypeFromType create a plaintext type according the
-// precision of the given type argument. The type should be a GLWECipherText
-// (if operand is not yet lowered) or a LWECipherTextType (if operand is
-// already lowered).
+/// convertPlaintextTypeFromType creates a plaintext type according to the
+/// precision of the given type argument. The type should be a GLWECipherText
+/// (if the operand is not yet lowered) or a LWECipherTextType (if the operand
+/// is already lowered).
 PlaintextType convertPlaintextTypeFromType(mlir::MLIRContext *context,
                                            mlir::Type &type) {
   auto glwe = type.dyn_cast_or_null<GLWECipherTextType>();
@@ -76,10 +76,10 @@ CleartextType convertCleartextTypeFromPType(mlir::MLIRContext *context,
   return CleartextType::get(context, type.getP() + 1);
 }
 
-// convertCleartextTypeFromType create a cleartext type according the
-// precision of the given type argument. The type should be a GLWECipherText
-// (if operand is not yet lowered) or a LWECipherTextType (if operand is
-// already lowered).
+/// convertCleartextTypeFromType creates a cleartext type according to the
+/// precision of the given type argument. The type should be a GLWECipherText
+/// (if the operand is not yet lowered) or a LWECipherTextType (if the operand
+/// is already lowered).
 CleartextType convertCleartextTypeFromType(mlir::MLIRContext *context,
                                            mlir::Type &type) {
   auto glwe = type.dyn_cast_or_null<GLWECipherTextType>();
@@ -24,11 +24,10 @@ bool verifyEncryptedIntegerAndIntegerInputsConsistency(Operation &op,
                                                        EncryptedIntegerType &a,
                                                        IntegerType &b);
 
-/** Shared error message for all ApplyLookupTable variant Op (several Dialect)
- * E.g. FHE.apply_lookup_table(input, lut)
- * Message when the lut tensor has an invalid size,
- * i.e. it cannot accomodate the input elements bitwidth
- */
+/// Shared error message for all ApplyLookupTable variant Op (several Dialect)
+/// E.g. FHE.apply_lookup_table(input, lut)
+/// Message when the lut tensor has an invalid size,
+/// i.e. it cannot accommodate the input elements bitwidth
 template <class Op>
 void emitErrorBadLutSize(Op &op, std::string lutName, std::string inputName,
                          int expectedSize, int bitWidth) {
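For context on the size check this message reports: a table applied to a bitWidth-bit input needs one entry per possible input value. That 2^bitWidth relation is an assumption drawn from the parameter names (expectedSize, bitWidth), not from code shown in this diff:

```cpp
#include <cstdint>

// Assumed relation behind emitErrorBadLutSize: a lookup table for a
// bitWidth-bit encrypted input must have 2^bitWidth entries.
constexpr int64_t expectedLutSize(int bitWidth) {
  return int64_t{1} << bitWidth;
}

static_assert(expectedLutSize(3) == 8, "a 3-bit input needs an 8-entry table");
```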
@@ -19,7 +19,7 @@ extern void *dl_handle;
 struct WorkFunctionRegistry;
 extern WorkFunctionRegistry *node_level_work_function_registry;
 
-// Recover the name of the work function
+/// Recover the name of the work function
 static inline const char *_dfr_get_function_name_from_address(void *fn) {
   Dl_info info;
@@ -38,8 +38,8 @@ static inline wfnptr _dfr_get_function_pointer_from_name(const char *fn_name) {
   return (wfnptr)ptr;
 }
 
-// Determine where new task should run. For now just round-robin
-// distribution - TODO: optimise.
+/// Determine where a new task should run. For now just round-robin
+/// distribution - TODO: optimise.
 static inline size_t _dfr_find_next_execution_locality() {
   static size_t num_nodes = hpx::get_num_localities().get();
   static std::atomic<std::size_t> next_locality{0};
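The round-robin choice the comment describes is just an atomic counter taken modulo the number of localities. A self-contained sketch of that scheme (plain C++ instead of HPX, purely illustrative):

```cpp
#include <atomic>
#include <cstddef>
#include <iostream>

// Round-robin locality selection: each call returns the next node index,
// wrapping around num_nodes. fetch_add keeps it safe under concurrency.
std::size_t nextLocality(std::size_t num_nodes) {
  static std::atomic<std::size_t> next{0};
  return next.fetch_add(1) % num_nodes;
}

int main() {
  for (int i = 0; i < 6; ++i)
    std::cout << nextLocality(4) << ' '; // prints 0 1 2 3 0 1
  std::cout << '\n';
}
```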
@@ -26,7 +26,7 @@ typedef struct RuntimeContext {
 
   RuntimeContext() {}
 
-  // Ensure that the engines map is not copied
+  /// Ensure that the engines map is not copied
   RuntimeContext(const RuntimeContext &ctx)
       : evaluationKeys(ctx.evaluationKeys) {}
   RuntimeContext(const RuntimeContext &&other)
@@ -3,9 +3,7 @@
 // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
 // for license information.
 
-/**
-Define the API exposed to the compiler for code generation.
-*/
+/// Define the API exposed to the compiler for code generation.
 
 #ifndef CONCRETELANG_DFR_RUNTIME_API_H
 #define CONCRETELANG_DFR_RUNTIME_API_H
@@ -45,7 +45,7 @@ public:
 protected:
   ClientParameters clientParameters;
   void *(*func)(void *...);
-  // Retain module and open shared lib alive
+  /// Keep the module and the open shared lib alive
   std::shared_ptr<DynamicModule> module;
 };
@@ -18,9 +18,9 @@
 namespace mlir {
 namespace concretelang {
 
-// Compilation context that acts as the root owner of LLVM and MLIR
-// data structures directly and indirectly referenced by artefacts
-// produced by the `CompilerEngine`.
+/// Compilation context that acts as the root owner of LLVM and MLIR
+/// data structures directly and indirectly referenced by artefacts
+/// produced by the `CompilerEngine`.
 class CompilationContext {
 public:
   CompilationContext();
@@ -68,8 +68,8 @@ struct CompilationOptions {
 
 class CompilerEngine {
 public:
-  // Result of an invocation of the `CompilerEngine` with optional
-  // fields for the results produced by different stages.
+  /// Result of an invocation of the `CompilerEngine` with optional
+  /// fields for the results produced by different stages.
   class CompilationResult {
   public:
     CompilationResult(std::shared_ptr<CompilationContext> compilationContext =
@@ -89,37 +89,35 @@
   std::string outputDirPath;
   std::vector<std::string> objectsPath;
   std::vector<mlir::concretelang::ClientParameters> clientParametersList;
-  /** Path to the runtime library. Will be linked to the output library if set
-   */
+  /// Path to the runtime library. Will be linked to the output library if set
   std::string runtimeLibraryPath;
   bool cleanUp;
 
 public:
-  /** Create a library instance on which you can add compilation results.
-   * Then you can emit a library file with the given path.
-   * cleanUp at false keeps intermediate .obj files for later use. */
+  /// Create a library instance on which you can add compilation results.
+  /// Then you can emit a library file with the given path.
+  /// cleanUp at false keeps intermediate .obj files for later use.
  Library(std::string outputDirPath, std::string runtimeLibraryPath = "",
          bool cleanUp = true)
      : outputDirPath(outputDirPath), runtimeLibraryPath(runtimeLibraryPath),
        cleanUp(cleanUp) {}
-  /** Add a compilation result to the library */
+  /// Add a compilation result to the library
   llvm::Expected<std::string> addCompilation(CompilationResult &compilation);
-  /** Emit the library artifacts with the previously added compilation result
-   */
+  /// Emit the library artifacts with the previously added compilation result
   llvm::Error emitArtifacts(bool sharedLib, bool staticLib,
                             bool clientParameters, bool cppHeader);
-  /** After a shared library has been emitted, its path is here */
+  /// After a shared library has been emitted, its path is here
   std::string sharedLibraryPath;
-  /** After a static library has been emitted, its path is here */
+  /// After a static library has been emitted, its path is here
   std::string staticLibraryPath;
 
-  /** Returns the path of the shared library */
+  /// Returns the path of the shared library
   static std::string getSharedLibraryPath(std::string outputDirPath);
 
-  /** Returns the path of the static library */
+  /// Returns the path of the static library
   static std::string getStaticLibraryPath(std::string outputDirPath);
 
-  /** Returns the path of the static library */
+  /// Returns the path of the client parameters file
   static std::string getClientParametersPath(std::string outputDirPath);
 
   // For advanced use
@@ -132,56 +130,56 @@ public:
   ~Library();
 
 private:
-  /** Emit a shared library with the previously added compilation result */
+  /// Emit a static library with the previously added compilation result
   llvm::Expected<std::string> emitStatic();
-  /** Emit a shared library with the previously added compilation result */
+  /// Emit a shared library with the previously added compilation result
   llvm::Expected<std::string> emitShared();
-  /** Emit a json ClientParameters corresponding to library content */
+  /// Emit a json ClientParameters corresponding to the library content
   llvm::Expected<std::string> emitClientParametersJSON();
   /// Emit a client header file corresponding to the library content
   llvm::Expected<std::string> emitCppHeader();
 };
 
-// Specification of the exit stage of the compilation pipeline
+/// Specification of the exit stage of the compilation pipeline
 enum class Target {
-  // Only read sources and produce corresponding MLIR module
+  /// Only read sources and produce the corresponding MLIR module
   ROUND_TRIP,
 
-  // Read sources and exit before any lowering
+  /// Read sources and exit before any lowering
   FHE,
 
-  // Read sources and lower all FHE operations to TFHE
-  // operations
+  /// Read sources and lower all FHE operations to TFHE
+  /// operations
   TFHE,
 
-  // Read sources and lower all FHE and TFHE operations to Concrete
-  // operations
+  /// Read sources and lower all FHE and TFHE operations to Concrete
+  /// operations
   CONCRETE,
 
-  // Read sources and lower all FHE, TFHE and Concrete operations to BConcrete
-  // operations
+  /// Read sources and lower all FHE, TFHE and Concrete operations to
+  /// BConcrete operations
   BCONCRETE,
 
-  // Read sources and lower all FHE, TFHE and Concrete
-  // operations to canonical MLIR dialects. Cryptographic operations
-  // are lowered to invocations of the concrete library.
+  /// Read sources and lower all FHE, TFHE and Concrete
+  /// operations to canonical MLIR dialects. Cryptographic operations
+  /// are lowered to invocations of the concrete library.
   STD,
 
-  // Read sources and lower all FHE, TFHE and Concrete
-  // operations to operations from the LLVM dialect. Cryptographic
-  // operations are lowered to invocations of the concrete library.
+  /// Read sources and lower all FHE, TFHE and Concrete
+  /// operations to operations from the LLVM dialect. Cryptographic
+  /// operations are lowered to invocations of the concrete library.
   LLVM,
 
-  // Same as `LLVM`, but lowers to actual LLVM IR instead of the
-  // LLVM dialect
+  /// Same as `LLVM`, but lowers to actual LLVM IR instead of the
+  /// LLVM dialect
   LLVM_IR,
 
-  // Same as `LLVM_IR`, but invokes the LLVM optimization pipeline
-  // to produce optimized LLVM IR
+  /// Same as `LLVM_IR`, but invokes the LLVM optimization pipeline
+  /// to produce optimized LLVM IR
   OPTIMIZED_LLVM_IR,
 
-  // Same as `OPTIMIZED_LLVM_IR`, but compiles and add an object file to a
-  // futur library
+  /// Same as `OPTIMIZED_LLVM_IR`, but compiles and adds an object file to a
+  /// future library
  LIBRARY
 };
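To make the Library workflow above concrete, here is a hypothetical sketch of the intended call sequence — it uses only the members declared in this diff, while the namespace qualification and the origin of the compilation result are assumptions:

```cpp
// Hypothetical driver for the Library workflow declared above. `result`
// is assumed to come from a prior CompilerEngine invocation.
llvm::Error buildArtifacts(Library &lib, CompilationResult &result) {
  // Register a compilation result; on success this yields the object path.
  llvm::Expected<std::string> objPath = lib.addCompilation(result);
  if (!objPath)
    return objPath.takeError();

  // Emit shared + static libraries and the client parameters JSON,
  // skipping the C++ header.
  return lib.emitArtifacts(/*sharedLib=*/true, /*staticLib=*/true,
                           /*clientParameters=*/true, /*cppHeader=*/false);
}
```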
@@ -11,21 +11,21 @@
 namespace mlir {
 namespace concretelang {
 
-// Internal error class that allows for composing `llvm::Error`s
-// similar to `llvm::createStringError()`, but using stream-like
-// composition with `operator<<`.
-//
-// Example:
-//
-//   llvm::Error foo(int i, size_t s, ...) {
-//     ...
-//     if(...) {
-//       return StreamStringError()
-//              << "Some error message with an integer: "
-//              << i << " and a size_t: " << s;
-//     }
-//     ...
-//   }
+/// Internal error class that allows for composing `llvm::Error`s
+/// similar to `llvm::createStringError()`, but using stream-like
+/// composition with `operator<<`.
+///
+/// Example:
+///
+///   llvm::Error foo(int i, size_t s, ...) {
+///     ...
+///     if(...) {
+///       return StreamStringError()
+///              << "Some error message with an integer: "
+///              << i << " and a size_t: " << s;
+///     }
+///     ...
+///   }
 class StreamStringError {
 public:
   StreamStringError(const llvm::StringRef &s) : buffer(s.str()), os(buffer){};
@@ -55,9 +55,9 @@ private:
   mlir::LLVM::LLVMFunctionType type;
   std::string name;
   std::unique_ptr<mlir::ExecutionEngine> engine;
-  // Tell if the DF parallelization was on or during compilation. This will be
-  // useful to abort execution if the runtime doesn't support dataflow
-  // execution, instead of having undefined symbol issues
+  /// Tell if the DF parallelization was on or off during compilation. This
+  /// will be useful to abort execution if the runtime doesn't support
+  /// dataflow execution, instead of having undefined symbol issues
   bool useDataflow = false;
 };
@@ -17,7 +17,7 @@
 namespace mlir {
 namespace concretelang {
 
-// Abstract base class for lambda arguments
+/// Abstract base class for lambda arguments
 class LambdaArgument
     : public llvm::RTTIExtends<LambdaArgument, llvm::RTTIRoot> {
 public:
@@ -25,13 +25,13 @@ public:
 
   template <typename T> bool isa() const { return llvm::isa<T>(*this); }
 
-  // Cast functions on constant instances
+  /// Cast functions on constant instances
   template <typename T> const T &cast() const { return llvm::cast<T>(*this); }
   template <typename T> const T *dyn_cast() const {
     return llvm::dyn_cast<T>(this);
   }
 
-  // Cast functions for mutable instances
+  /// Cast functions for mutable instances
   template <typename T> T &cast() { return llvm::cast<T>(*this); }
   template <typename T> T *dyn_cast() { return llvm::dyn_cast<T>(this); }
@@ -41,10 +41,10 @@ protected:
   LambdaArgument(){};
 };
 
-// Class for integer arguments. `BackingIntType` is used as the data
-// type to hold the argument's value. The precision is the actual
-// precision of the value, which might be different from the precision
-// of the backing integer type.
+/// Class for integer arguments. `BackingIntType` is used as the data
+/// type to hold the argument's value. The precision is the actual
+/// precision of the value, which might be different from the precision
+/// of the backing integer type.
 template <typename BackingIntType = uint64_t>
 class IntLambdaArgument
     : public llvm::RTTIExtends<IntLambdaArgument<BackingIntType>,
@@ -75,10 +75,10 @@ protected:
 template <typename BackingIntType>
 char IntLambdaArgument<BackingIntType>::ID = 0;
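The isa/cast/dyn_cast members above mirror LLVM-style RTTI, which subclasses such as IntLambdaArgument opt into via llvm::RTTIExtends and a static ID. A hypothetical sketch of how a caller discriminates arguments with them:

```cpp
// Hypothetical discrimination of a LambdaArgument via the LLVM-style RTTI
// helpers declared above. getValue() on IntLambdaArgument is assumed from
// the class description; it is not shown in this diff.
uint64_t scalarOrZero(const mlir::concretelang::LambdaArgument &arg) {
  using mlir::concretelang::IntLambdaArgument;
  if (arg.isa<IntLambdaArgument<uint64_t>>())  // type check only
    return arg.cast<IntLambdaArgument<uint64_t>>().getValue();
  return 0;  // dyn_cast<>() would instead return nullptr on a mismatch
}
```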
-// Class for encrypted integer arguments. `BackingIntType` is used as
-// the data type to hold the argument's plaintext value. The precision
-// is the actual precision of the value, which might be different from
-// the precision of the backing integer type.
+/// Class for encrypted integer arguments. `BackingIntType` is used as
+/// the data type to hold the argument's plaintext value. The precision
+/// is the actual precision of the value, which might be different from
+/// the precision of the backing integer type.
 template <typename BackingIntType = uint64_t>
 class EIntLambdaArgument
     : public llvm::RTTIExtends<EIntLambdaArgument<BackingIntType>,
@@ -91,8 +91,8 @@ template <typename BackingIntType>
 char EIntLambdaArgument<BackingIntType>::ID = 0;
 
 namespace {
-// Calculates `accu *= factor` or returns an error if the result
-// would overflow
+/// Calculates `accu *= factor` or returns an error if the result
+/// would overflow
 template <typename AccuT, typename ValT>
 llvm::Error safeUnsignedMul(AccuT &accu, ValT factor) {
   static_assert(std::numeric_limits<AccuT>::is_integer &&
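Overflow-checked multiplication like safeUnsignedMul is easy to sketch in portable C++: for unsigned values, the product overflows exactly when factor != 0 and accu > max / factor. A self-contained version (returning bool instead of llvm::Error, an intentional simplification):

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Multiply accu by factor in place; return false instead of overflowing.
template <typename AccuT, typename ValT>
bool safeUnsignedMul(AccuT &accu, ValT factor) {
  static_assert(std::numeric_limits<AccuT>::is_integer &&
                    !std::numeric_limits<AccuT>::is_signed,
                "accumulator must be an unsigned integer");
  // Overflow iff accu * factor > max, i.e. accu > max / factor.
  if (factor != 0 && accu > std::numeric_limits<AccuT>::max() / factor)
    return false;
  accu = static_cast<AccuT>(accu * factor);
  return true;
}

int main() {
  uint64_t n = 1;
  assert(safeUnsignedMul(n, 1u << 20));     // fits: n becomes 2^20
  assert(!safeUnsignedMul(n, UINT64_MAX));  // would overflow, n unchanged
}
```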
@@ -113,10 +113,10 @@ llvm::Error safeUnsignedMul(AccuT &accu, ValT factor) {
 }
 } // namespace
 
-// Class for Tensor arguments. This can either be plaintext tensors
-// (for `ScalarArgumentT = IntLambaArgument<T>`) or tensors
-// representing encrypted integers (for `ScalarArgumentT =
-// EIntLambaArgument<T>`).
+/// Class for Tensor arguments. This can either be plaintext tensors
+/// (for `ScalarArgumentT = IntLambdaArgument<T>`) or tensors
+/// representing encrypted integers (for `ScalarArgumentT =
+/// EIntLambdaArgument<T>`).
 template <typename ScalarArgumentT>
 class TensorLambdaArgument
     : public llvm::RTTIExtends<TensorLambdaArgument<ScalarArgumentT>,
@@ -124,10 +124,10 @@ class TensorLambdaArgument
 public:
   typedef ScalarArgumentT scalar_type;
 
-  // Construct tensor argument from the one-dimensional array `value`,
-  // but interpreting the array's values as a linearized
-  // multi-dimensional tensor with the sizes of the dimensions
-  // specified in `dimensions`.
+  /// Construct a tensor argument from the one-dimensional array `value`,
+  /// interpreting the array's values as a linearized
+  /// multi-dimensional tensor with the sizes of the dimensions
+  /// specified in `dimensions`.
   TensorLambdaArgument(
       llvm::ArrayRef<typename ScalarArgumentT::value_type> value,
       llvm::ArrayRef<int64_t> dimensions)
@@ -135,8 +135,8 @@ public:
     std::copy(value.begin(), value.end(), std::back_inserter(this->value));
   }
 
-  // Construct a one-dimensional tensor argument from the
-  // array `value`.
+  /// Construct a one-dimensional tensor argument from the
+  /// array `value`.
   TensorLambdaArgument(
       llvm::ArrayRef<typename ScalarArgumentT::value_type> value)
       : TensorLambdaArgument(value, {(int64_t)value.size()}) {}
@@ -152,9 +152,9 @@ public:
 
   const std::vector<int64_t> &getDimensions() const { return this->dimensions; }
 
-  // Returns the total number of elements in the tensor. If the number
-  // of elements cannot be represented as a `size_t`, the method
-  // returns an error.
+  /// Returns the total number of elements in the tensor. If the number
+  /// of elements cannot be represented as a `size_t`, the method
+  /// returns an error.
   llvm::Expected<size_t> getNumElements() const {
     size_t accu = 1;
 
@@ -165,14 +165,14 @@ public:
     return accu;
   }
 
-  // Returns a bare pointer to the linearized values of the tensor
-  // (constant version).
+  /// Returns a bare pointer to the linearized values of the tensor
+  /// (constant version).
   const typename ScalarArgumentT::value_type *getValue() const {
     return this->value.data();
   }
 
-  // Returns a bare pointer to the linearized values of the tensor (mutable
-  // version).
+  /// Returns a bare pointer to the linearized values of the tensor (mutable
+  /// version).
   typename ScalarArgumentT::value_type *getValue() {
     return this->value.data();
   }
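A hypothetical sketch of constructing and querying a tensor argument with the constructors and accessors above — a 2x3 plaintext tensor backed by uint64_t, with the usual llvm::Expected handling:

```cpp
// Hypothetical use of TensorLambdaArgument with a plaintext scalar type.
using mlir::concretelang::IntLambdaArgument;
using mlir::concretelang::TensorLambdaArgument;

void inspectTensorArgument() {
  uint64_t data[6] = {1, 2, 3, 4, 5, 6};
  // Interpret the flat array as a linearized 2x3 tensor.
  TensorLambdaArgument<IntLambdaArgument<uint64_t>> arg(data, {2, 3});

  const std::vector<int64_t> &dims = arg.getDimensions(); // {2, 3}
  llvm::Expected<size_t> numElements = arg.getNumElements();
  if (numElements)
    assert(*numElements == 6 && dims.size() == 2);
  else
    llvm::consumeError(numElements.takeError());
}
```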
@@ -27,13 +27,13 @@ namespace {
 // `typedResult` must be declared at namespace scope due to return
 // type template specialization
 
-// Helper function for implementing type-dependent preparation of the result.
+/// Helper function for implementing type-dependent preparation of the result.
 template <typename ResT>
 llvm::Expected<ResT> typedResult(clientlib::KeySet &keySet,
                                  clientlib::PublicResult &result);
 
-// Specialization of `typedResult()` for scalar results, forwarding
-// scalar value to caller
+/// Specialization of `typedResult()` for scalar results, forwarding
+/// scalar value to caller
 template <>
 inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
                                             clientlib::PublicResult &result) {
@@ -60,14 +60,13 @@ typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
   return std::move(clearResult.value());
 }
 
-// Specializations of `typedResult()` for vector results, initializing
-// an `std::vector` of the right size with the results and forwarding
-// it to the caller with move semantics.
-//
-// Cannot factor out into a template template <typename T> inline
-// llvm::Expected<std::vector<uint8_t>>
-// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due
-// to ambiguity with scalar template
+/// Specializations of `typedResult()` for vector results, initializing
+/// an `std::vector` of the right size with the results and forwarding
+/// it to the caller with move semantics.
+/// Cannot factor out into a template template <typename T> inline
+/// llvm::Expected<std::vector<uint8_t>>
+/// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due
+/// to ambiguity with scalar template
 template <>
 inline llvm::Expected<std::vector<uint8_t>>
 typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
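The note at the top of this hunk ("must be declared at namespace scope due to return type template specialization") refers to a C++ rule: explicit specializations of a function template cannot appear at class scope, and since only the return type distinguishes these variants, plain overloading cannot be used either. A minimal self-contained illustration of the pattern:

```cpp
#include <cstdint>
#include <vector>

// Primary template: declared (not defined) at namespace scope; callers
// select the variant explicitly via decode<T>(...).
template <typename ResT> ResT decode(const std::vector<uint64_t> &raw);

// Explicit specialization for a scalar result.
template <>
inline uint64_t decode<uint64_t>(const std::vector<uint64_t> &raw) {
  return raw.empty() ? 0 : raw.front();
}

// Explicit specialization for a vector result; only the return type
// differs, which is why this cannot be an ordinary overload.
template <>
inline std::vector<uint64_t>
decode<std::vector<uint64_t>>(const std::vector<uint64_t> &raw) {
  return raw;
}

int main() {
  std::vector<uint64_t> raw = {7, 8, 9};
  uint64_t s = decode<uint64_t>(raw);           // 7
  auto v = decode<std::vector<uint64_t>>(raw);  // {7, 8, 9}
  return static_cast<int>(s + v.size()) == 10 ? 0 : 1;
}
```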
@@ -105,8 +104,8 @@ buildTensorLambdaResult(clientlib::KeySet &keySet,
                                                 *tensorOrError, tensorDim);
 }
 
-// pecialization of `typedResult()` for a single result wrapped into
-// a `LambdaArgument`.
+/// Specialization of `typedResult()` for a single result wrapped into
+/// a `LambdaArgument`.
 template <>
 inline llvm::Expected<std::unique_ptr<LambdaArgument>>
 typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
@@ -138,18 +137,18 @@ typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
 
 } // namespace
 
-// Adaptor class that push arguments specified as instances of
-// `LambdaArgument` to `clientlib::EncryptedArguments`.
+/// Adaptor class that pushes arguments specified as instances of
+/// `LambdaArgument` to `clientlib::EncryptedArguments`.
 class LambdaArgumentAdaptor {
 public:
-  // Checks if the argument `arg` is an plaintext / encrypted integer
-  // argument or a plaintext / encrypted tensor argument with a
-  // backing integer type `IntT` and push the argument to `encryptedArgs`.
-  //
-  // Returns `true` if `arg` has one of the types above and its value
-  // was successfully added to `encryptedArgs`, `false` if none of the types
-  // matches or an error if a type matched, but adding the argument to
-  // `encryptedArgs` failed.
+  /// Checks if the argument `arg` is a plaintext / encrypted integer
+  /// argument or a plaintext / encrypted tensor argument with a
+  /// backing integer type `IntT` and pushes the argument to `encryptedArgs`.
+  ///
+  /// Returns `true` if `arg` has one of the types above and its value
+  /// was successfully added to `encryptedArgs`, `false` if none of the types
+  /// matches, or an error if a type matched but adding the argument to
+  /// `encryptedArgs` failed.
   template <typename IntT>
   static inline llvm::Expected<bool>
   tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
@@ -174,7 +173,7 @@ public:
     return false;
   }
 
-  // Recursive case for `tryAddArg<IntT>(...)`
+  /// Recursive case for `tryAddArg<IntT>(...)`
   template <typename IntT, typename NextIntT, typename... IntTs>
   static inline llvm::Expected<bool>
   tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
@@ -191,9 +190,9 @@ public:
     return true;
   }
 
-  // Attempts to push a single argument `arg` to `encryptedArgs`. Returns an
-  // error if either the argument type is unsupported or if the argument types
-  // is supported, but adding it to `encryptedArgs` failed.
+  /// Attempts to push a single argument `arg` to `encryptedArgs`. Returns an
+  /// error if either the argument type is unsupported or if the argument type
+  /// is supported but adding it to `encryptedArgs` failed.
   static inline llvm::Error
   addArgument(clientlib::EncryptedArguments &encryptedArgs,
               const LambdaArgument &arg, clientlib::KeySet &keySet) {
@@ -121,7 +121,7 @@ public:
 private:
   std::string outputPath;
   std::string runtimeLibraryPath;
-  // Flags to select generated artifacts
+  /// Flags to select generated artifacts
   bool generateSharedLib;
   bool generateStaticLib;
   bool generateClientParameters;
@@ -11,20 +11,20 @@
 namespace mlir {
 namespace concretelang {
 
-// Returning references to instances of different classes `S` and `T`
-// is prohibited, even if `T` inherits from `S`. The wrapper class
-// `StreamWrap` can be initialized with a pointer to an instance of
-// `S` or any of its subclasses and acts as a proxy transparently
-// forwarding all calls to `S::operator<<`. The class thus hides the
-// dereferencing of the pointer and a reference to it can be used as a
-// replacement for a reference to `S`.
+/// Returning references to instances of different classes `S` and `T`
+/// is prohibited, even if `T` inherits from `S`. The wrapper class
+/// `StreamWrap` can be initialized with a pointer to an instance of
+/// `S` or any of its subclasses and acts as a proxy transparently
+/// forwarding all calls to `S::operator<<`. The class thus hides the
+/// dereferencing of the pointer and a reference to it can be used as a
+/// replacement for a reference to `S`.
 template <class S> class StreamWrap {
 public:
   StreamWrap() = delete;
   StreamWrap(S *s) : s(s) {}
 
-  // Forward all invocations of
-  // `StreamWrap<S>::operator<<` to S::operator<<`.
+  /// Forward all invocations of
+  /// `StreamWrap<S>::operator<<` to `S::operator<<`.
   template <class T> StreamWrap<S> &operator<<(const T &v) {
     *this->s << v;
     return *this;
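A self-contained illustration of the proxy idea behind StreamWrap, using std::ostream as `S` — a reduction of the exact pattern shown above, not the repository's instantiation:

```cpp
#include <iostream>
#include <sstream>

// Minimal replica of the StreamWrap proxy shown in the diff above.
template <class S> class StreamWrap {
public:
  StreamWrap() = delete;
  StreamWrap(S *s) : s(s) {}

  // Forward operator<< to the wrapped stream and return *this so that
  // insertions chain, exactly as on a plain stream reference.
  template <class T> StreamWrap<S> &operator<<(const T &v) {
    *this->s << v;
    return *this;
  }

private:
  S *s;
};

int main() {
  std::ostringstream buffer;
  StreamWrap<std::ostream> out(&buffer); // wraps a subclass of std::ostream
  out << "answer: " << 42;               // chained insertions via the proxy
  std::cout << buffer.str() << "\n";
}
```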
@@ -6,9 +6,9 @@
 #ifndef CONCRETELANG_SUPPORT_MATH_H_
 #define CONCRETELANG_SUPPORT_MATH_H_
 
-// Calculates (T)ceil(log2f(v))
-// TODO: Replace with some fancy bit twiddling hack
+/// Calculates (T)ceil(log2f(v))
 template <typename T> static T ceilLog2(const T v) {
+  // TODO: Replace with some fancy bit twiddling hack
   T tmp = v;
   T log2 = 0;
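The TODO about a "bit twiddling hack" has a standard answer in C++20: std::bit_width. A self-contained sketch for unsigned T (an alternative the comment merely hints at, not the repository's implementation):

```cpp
#include <bit>      // std::bit_width (C++20)
#include <cassert>
#include <cstdint>

// ceil(log2(v)) for unsigned v >= 1: bit_width(x) is floor(log2(x)) + 1,
// so bit_width(v - 1) yields the ceiling in O(1).
template <typename T> constexpr T ceilLog2(T v) {
  return static_cast<T>(std::bit_width(v - T{1}));
}

int main() {
  assert(ceilLog2<uint64_t>(1) == 0);
  assert(ceilLog2<uint64_t>(8) == 3); // exact power of two
  assert(ceilLog2<uint64_t>(9) == 4); // rounds up
}
```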
@@ -155,26 +155,26 @@ struct ConcreteIntToCleartextOpPattern
   };
 };
 
-// This rewrite pattern transforms any instance of `Concrete.zero_tensor`
-// operators.
-//
-// Example:
-//
-// ```mlir
-// %0 = "Concrete.zero_tensor" () :
-//   tensor<...x!Concrete.lwe_ciphertext<lweDim,p>>
-// ```
-//
-// becomes:
-//
-// ```mlir
-// %0 = tensor.generate {
-//   ^bb0(... : index):
-//     %c0 = arith.constant 0 : i64
-//     tensor.yield %z
-//   }: tensor<...xlweDim+1xi64>
-//   i64>
-// ```
+/// This rewrite pattern transforms any instance of `Concrete.zero_tensor`
+/// operators.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = "Concrete.zero_tensor" () :
+///   tensor<...x!Concrete.lwe_ciphertext<lweDim,p>>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %0 = tensor.generate {
+///   ^bb0(... : index):
+///     %c0 = arith.constant 0 : i64
+///     tensor.yield %c0
+///   }: tensor<...xlweDim+1xi64>
+/// ```
 template <typename ZeroOp>
 struct ZeroOpPattern : public mlir::OpRewritePattern<ZeroOp> {
   ZeroOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
@@ -204,19 +204,19 @@ struct ZeroOpPattern : public mlir::OpRewritePattern<ZeroOp> {
   };
 };
 
-// This template rewrite pattern transforms any instance of
-// `ConcreteOp` to an instance of `BConcreteOp`.
-//
-// Example:
-//
-//   %0 = "ConcreteOp"(%arg0, ...) :
-//     (!Concrete.lwe_ciphertext<lwe_dimension, p>, ...) ->
-//     (!Concrete.lwe_ciphertext<lwe_dimension, p>)
-//
-// becomes:
-//
-//   %0 = "BConcreteOp"(%arg0, ...) : (tensor<dimension+1, i64>>, ..., ) ->
-//     (tensor<dimension+1, i64>>)
+/// This template rewrite pattern transforms any instance of
+/// `ConcreteOp` to an instance of `BConcreteOp`.
+///
+/// Example:
+///
+///   %0 = "ConcreteOp"(%arg0, ...) :
+///     (!Concrete.lwe_ciphertext<lwe_dimension, p>, ...) ->
+///     (!Concrete.lwe_ciphertext<lwe_dimension, p>)
+///
+/// becomes:
+///
+///   %0 = "BConcreteOp"(%arg0, ...) : (tensor<dimension+1, i64>, ..., ) ->
+///     (tensor<dimension+1, i64>)
 template <typename ConcreteOp, typename BConcreteOp>
 struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
   LowToBConcrete(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
@@ -248,27 +248,27 @@ struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
   };
 };
 
-// This rewrite pattern transforms any instance of
-// `Concrete.glwe_from_table` operators.
-//
-// Example:
-//
-// ```mlir
-// %0 = "Concrete.glwe_from_table"(%tlu)
-//   : (tensor<$Dxi64>) ->
-//     !Concrete.glwe_ciphertext<$polynomialSize,$glweDimension,$p>
-// ```
-//
-// with $D = 2^$p
-//
-// becomes:
-//
-// ```mlir
-// %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)]
-//   : tensor<polynomialSize*(glweDimension+1), i64>
-// "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension, %tlu)
-//   : tensor<polynomialSize*(glweDimension+1), i64>, i64, i64, tensor<$Dxi64>
-// ```
+/// This rewrite pattern transforms any instance of
+/// `Concrete.glwe_from_table` operators.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = "Concrete.glwe_from_table"(%tlu)
+///   : (tensor<$Dxi64>) ->
+///     !Concrete.glwe_ciphertext<$polynomialSize,$glweDimension,$p>
+/// ```
+///
+/// with $D = 2^$p
+///
+/// becomes:
+///
+/// ```mlir
+/// %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)]
+///   : tensor<polynomialSize*(glweDimension+1), i64>
+/// "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension, %tlu)
+///   : tensor<polynomialSize*(glweDimension+1), i64>, i64, i64, tensor<$Dxi64>
+/// ```
 struct GlweFromTablePattern : public mlir::OpRewritePattern<
                                   mlir::concretelang::Concrete::GlweFromTable> {
   GlweFromTablePattern(::mlir::MLIRContext *context,
@@ -305,26 +305,26 @@ struct GlweFromTablePattern : public mlir::OpRewritePattern<
   };
 };
 
-// This rewrite pattern transforms any instance of
-// `tensor.extract_slice` operators that operates on tensor of lwe ciphertext.
-//
-// Example:
-//
-// ```mlir
-// %0 = tensor.extract_slice %arg0
-//   [offsets...] [sizes...] [strides...]
-//   : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> to
-//     tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
-// ```
-//
-// becomes:
-//
-// ```mlir
-// %0 = tensor.extract_slice %arg0
-//   [offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
-//   : tensor<...xlweDimension+1,i64> to
-//     tensor<...xlweDimension+1,i64>
-// ```
+/// This rewrite pattern transforms any instance of
+/// `tensor.extract_slice` operators that operate on tensors of lwe ciphertexts.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = tensor.extract_slice %arg0
+///   [offsets...] [sizes...] [strides...]
+///   : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> to
+///     tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %0 = tensor.extract_slice %arg0
+///   [offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
+///   : tensor<...xlweDimension+1,i64> to
+///     tensor<...xlweDimension+1,i64>
+/// ```
 struct ExtractSliceOpPattern
     : public mlir::OpRewritePattern<mlir::tensor::ExtractSliceOp> {
   ExtractSliceOpPattern(::mlir::MLIRContext *context,
@@ -380,27 +380,26 @@ struct ExtractSliceOpPattern
   };
 };
 
-// This rewrite pattern transforms any instance of
-// `tensor.extract` operators that operates on tensor of lwe ciphertext.
-//
-// Example:
-//
-// ```mlir
-// %0 = tensor.extract %t[offsets...]
-//   : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
-// ```
-//
-// becomes:
-//
-// ```mlir
-// %1 = tensor.extract_slice %arg0
-//   [offsets...] [1..., lweDimension+1] [1...]
-//   : tensor<...xlweDimension+1,i64> to
-//     tensor<1...xlweDimension+1,i64>
-// %0 = linalg.tensor_collapse_shape %0 [[...]] :
-//   tensor<1x1xlweDimension+1xi64> into tensor<lweDimension+1xi64>
-// ```
-//
+/// This rewrite pattern transforms any instance of
+/// `tensor.extract` operators that operate on tensors of lwe ciphertexts.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = tensor.extract %t[offsets...]
+///   : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %1 = tensor.extract_slice %arg0
+///   [offsets...] [1..., lweDimension+1] [1...]
+///   : tensor<...xlweDimension+1,i64> to
+///     tensor<1...xlweDimension+1,i64>
+/// %0 = linalg.tensor_collapse_shape %0 [[...]] :
+///   tensor<1x1xlweDimension+1xi64> into tensor<lweDimension+1xi64>
+/// ```
 // TODO: since there is a bug on lowering extract_slice with rank reduction we
 // add a linalg.tensor_collapse_shape after the extract_slice without rank
 // reduction. See
@@ -487,26 +486,26 @@ struct ExtractOpPattern
   };
 };
 
-// This rewrite pattern transforms any instance of
-// `tensor.insert_slice` operators that operates on tensor of lwe ciphertext.
-//
-// Example:
-//
-// ```mlir
-// %0 = tensor.insert_slice %arg1
-//   into %arg0[offsets...] [sizes...] [strides...]
-//   : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
-//     tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
-// ```
-//
-// becomes:
-//
-// ```mlir
-// %0 = tensor.insert_slice %arg1
-//   into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
-//   : tensor<...xlweDimension+1xi64> into
-//     tensor<...xlweDimension+1xi64>
-// ```
+/// This rewrite pattern transforms any instance of
+/// `tensor.insert_slice` operators that operate on tensors of lwe ciphertexts.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = tensor.insert_slice %arg1
+///   into %arg0[offsets...] [sizes...] [strides...]
+///   : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
+///     tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %0 = tensor.insert_slice %arg1
+///   into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
+///   : tensor<...xlweDimension+1xi64> into
+///     tensor<...xlweDimension+1xi64>
+/// ```
 struct InsertSliceOpPattern
     : public mlir::OpRewritePattern<mlir::tensor::InsertSliceOp> {
   InsertSliceOpPattern(::mlir::MLIRContext *context,
@@ -559,28 +558,28 @@ struct InsertSliceOpPattern
   };
 };
 
-// This rewrite pattern transforms any instance of `tensor.insert`
-// operators that operates on an lwe ciphertexts to a
-// `tensor.insert_slice` op operating on the bufferized representation
-// of the ciphertext.
-//
-// Example:
-//
-// ```mlir
-// %0 = tensor.insert %arg1
-//   into %arg0[offsets...]
-//   : !Concrete.lwe_ciphertext<lweDimension,p> into
-//     tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
-// ```
-//
-// becomes:
-//
-// ```mlir
-// %0 = tensor.insert_slice %arg1
-//   into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
-//   : tensor<lweDimension+1xi64> into
-//     tensor<...xlweDimension+1xi64>
-// ```
+/// This rewrite pattern transforms any instance of `tensor.insert`
+/// operators that operate on an lwe ciphertext to a
+/// `tensor.insert_slice` op operating on the bufferized representation
+/// of the ciphertext.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = tensor.insert %arg1
+///   into %arg0[offsets...]
+///   : !Concrete.lwe_ciphertext<lweDimension,p> into
+///     tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %0 = tensor.insert_slice %arg1
+///   into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
+///   : tensor<lweDimension+1xi64> into
+///     tensor<...xlweDimension+1xi64>
+/// ```
 struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
   InsertOpPattern(::mlir::MLIRContext *context,
                   mlir::PatternBenefit benefit = 1)
@@ -628,31 +627,31 @@ struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
   };
 };
 
-// This rewrite pattern transforms any instance of
-// `tensor.from_elements` operators that operates on tensor of lwe ciphertext.
-//
-// Example:
-//
-// ```mlir
-// %0 = tensor.from_elements %e0, ..., %e(n-1)
-//   : tensor<Nx!Concrete.lwe_ciphertext<lweDim,p>>
-// ```
-//
-// becomes:
-//
-// ```mlir
-// %m = memref.alloc() : memref<NxlweDim+1xi64>
-// %s0 = memref.subview %m[0, 0][1, lweDim+1][1, 1] : memref<lweDim+1xi64>
-// %m0 = memref.buffer_cast %e0 : memref<lweDim+1xi64>
-// memref.copy %m0, s0 : memref<lweDim+1xi64> to memref<lweDim+1xi64>
-// ...
-// %s(n-1) = memref.subview %m[(n-1), 0][1, lweDim+1][1, 1]
-//   : memref<lweDim+1xi64>
-// %m(n-1) = memref.buffer_cast %e(n-1) : memref<lweDim+1xi64>
-// memref.copy %e(n-1), s(n-1)
-//   : memref<lweDim+1xi64> to memref<lweDim+1xi64>
-// %0 = memref.tensor_load %m : memref<NxlweDim+1xi64>
-// ```
+/// This rewrite pattern transforms any instance of
+/// `tensor.from_elements` operators that operate on tensors of lwe ciphertexts.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = tensor.from_elements %e0, ..., %e(n-1)
+///   : tensor<Nx!Concrete.lwe_ciphertext<lweDim,p>>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %m = memref.alloc() : memref<NxlweDim+1xi64>
+/// %s0 = memref.subview %m[0, 0][1, lweDim+1][1, 1] : memref<lweDim+1xi64>
+/// %m0 = memref.buffer_cast %e0 : memref<lweDim+1xi64>
+/// memref.copy %m0, %s0 : memref<lweDim+1xi64> to memref<lweDim+1xi64>
+/// ...
+/// %s(n-1) = memref.subview %m[(n-1), 0][1, lweDim+1][1, 1]
+///   : memref<lweDim+1xi64>
+/// %m(n-1) = memref.buffer_cast %e(n-1) : memref<lweDim+1xi64>
+/// memref.copy %m(n-1), %s(n-1)
+///   : memref<lweDim+1xi64> to memref<lweDim+1xi64>
+/// %0 = memref.tensor_load %m : memref<NxlweDim+1xi64>
+/// ```
 struct FromElementsOpPattern
     : public mlir::OpRewritePattern<mlir::tensor::FromElementsOp> {
   FromElementsOpPattern(::mlir::MLIRContext *context,
@@ -715,26 +714,26 @@ struct FromElementsOpPattern
   };
 };
 
-// This template rewrite pattern transforms any instance of
-// `ShapeOp` operators that operates on tensor of lwe ciphertext by adding the
-// lwe size as a size of the tensor result and by adding a trivial reassociation
-// at the end of the reassociations map.
-//
-// Example:
-//
-// ```mlir
-// %0 = "ShapeOp" %arg0 [reassocations...]
-//   : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
-//     tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
-// ```
-//
-// becomes:
-//
-// ```mlir
-// %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]]
-//   : tensor<...xlweDimesion+1xi64> into
-//     tensor<...xlweDimesion+1xi64>
-// ```
+/// This template rewrite pattern transforms any instance of
+/// `ShapeOp` operators that operate on tensors of lwe ciphertexts by adding
+/// the lwe size as a size of the tensor result and by adding a trivial
+/// reassociation at the end of the reassociations map.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = "ShapeOp" %arg0 [reassociations...]
+///   : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
+///     tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]]
+///   : tensor<...xlweDimension+1xi64> into
+///     tensor<...xlweDimension+1xi64>
+/// ```
 template <typename ShapeOp, typename VecTy, bool inRank>
 struct TensorShapeOpPattern : public mlir::OpRewritePattern<ShapeOp> {
   TensorShapeOpPattern(::mlir::MLIRContext *context,
@@ -775,8 +774,8 @@ struct TensorShapeOpPattern : public mlir::OpRewritePattern<ShapeOp> {
   };
 };
 
-// Add the instantiated TensorShapeOpPattern rewrite pattern with the `ShapeOp`
-// to the patterns set and populate the conversion target.
+/// Add the instantiated TensorShapeOpPattern rewrite pattern with the `ShapeOp`
+/// to the patterns set and populate the conversion target.
 template <typename ShapeOp, typename VecTy, bool inRank>
 void insertTensorShapeOpPattern(mlir::MLIRContext &context,
                                 mlir::RewritePatternSet &patterns,
@@ -789,26 +788,26 @@ void insertTensorShapeOpPattern(mlir::MLIRContext &context,
   });
 }
 
-// Rewrites `linalg.init_tensor` ops for which the converted type in
-// BConcrete is different from the original type.
-//
-// Example:
-//
-// ```
-// linalg.init_tensor [4] : tensor<4x!Concrete.lwe_ciphertext<4096,6>>
-// ```
-//
-// which has become after type conversion:
-//
-// ```
-// linalg.init_tensor [4] : tensor<4x4097xi64>
-// ```
-//
-// is finally fixed:
-//
-// ```
-// linalg.init_tensor [4, 4097] : tensor<4x4097xi64>
-// ```
+/// Rewrites `linalg.init_tensor` ops for which the converted type in
+/// BConcrete is different from the original type.
+///
+/// Example:
+///
+/// ```
+/// linalg.init_tensor [4] : tensor<4x!Concrete.lwe_ciphertext<4096,6>>
+/// ```
+///
+/// which after type conversion has become:
+///
+/// ```
+/// linalg.init_tensor [4] : tensor<4x4097xi64>
+/// ```
+///
+/// and is finally fixed to:
+///
+/// ```
+/// linalg.init_tensor [4, 4097] : tensor<4x4097xi64>
+/// ```
 struct InitTensorOpPattern
     : public mlir::OpRewritePattern<mlir::linalg::InitTensorOp> {
   InitTensorOpPattern(::mlir::MLIRContext *context,
@@ -40,41 +40,41 @@ struct DotToLinalgGeneric
      : ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::Dot>(
            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

// This rewrite pattern transforms any instance of
// `FHELinalg.dot_eint_int` to an instance of `linalg.generic` with an
// appropriate region using `FHE.mul_eint_int` and
// `FHE.add_eint` operations, an appropriate specification for the
// iteration dimensions and appropriate operations managing the
// accumulator of `linalg.generic`.
//
// Example:
//
//   %o = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
//     (tensor<4x!FHE.eint<0>>,
//      tensor<4xi32>) -> (!FHE.eint<0>)
//
// becomes:
//
//   %0 = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<0>>
//   %1 = linalg.generic {
//          indexing_maps = [#map0, #map0, #map1],
//          iterator_types = ["reduction"]
//        }
//        ins(%arg0, %arg1 : tensor<2x!FHE.eint<0>>, tensor<2xi32>)
//        outs(%0 : tensor<1x!FHE.eint<0>>) {
//          ^bb0(%arg2: !FHE.eint<0>, %arg3: i32, %arg4: !FHE.eint<0>):
//            %4 = "FHE.mul_eint_int"(%arg2, %arg3) :
//                   (!FHE.eint<0>, i32) -> !FHE.eint<0>
//
//            %5 = "FHE.add_eint"(%4, %arg4) :
//                   (!FHE.eint<0>, !FHE.eint<0>) -> !FHE.eint<0>
//
//            linalg.yield %5 : !FHE.eint<0>
//   } -> tensor<1x!FHE.eint<0>>
//
//   %c0 = constant 0 : index
//   %o = tensor.extract %1[%c0] : tensor<1x!FHE.eint<0>>
//
/// This rewrite pattern transforms any instance of
/// `FHELinalg.dot_eint_int` to an instance of `linalg.generic` with an
/// appropriate region using `FHE.mul_eint_int` and
/// `FHE.add_eint` operations, an appropriate specification for the
/// iteration dimensions and appropriate operations managing the
/// accumulator of `linalg.generic`.
///
/// Example:
///
///   %o = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
///     (tensor<4x!FHE.eint<0>>,
///      tensor<4xi32>) -> (!FHE.eint<0>)
///
/// becomes:
///
///   %0 = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<0>>
///   %1 = linalg.generic {
///          indexing_maps = [#map0, #map0, #map1],
///          iterator_types = ["reduction"]
///        }
///        ins(%arg0, %arg1 : tensor<2x!FHE.eint<0>>, tensor<2xi32>)
///        outs(%0 : tensor<1x!FHE.eint<0>>) {
///          ^bb0(%arg2: !FHE.eint<0>, %arg3: i32, %arg4: !FHE.eint<0>):
///            %4 = "FHE.mul_eint_int"(%arg2, %arg3) :
///                   (!FHE.eint<0>, i32) -> !FHE.eint<0>
///
///            %5 = "FHE.add_eint"(%4, %arg4) :
///                   (!FHE.eint<0>, !FHE.eint<0>) -> !FHE.eint<0>
///
///            linalg.yield %5 : !FHE.eint<0>
///   } -> tensor<1x!FHE.eint<0>>
///
///   %c0 = constant 0 : index
///   %o = tensor.extract %1[%c0] : tensor<1x!FHE.eint<0>>
///
  ::mlir::LogicalResult
  matchAndRewrite(::mlir::concretelang::FHELinalg::Dot dotOp,
                  ::mlir::PatternRewriter &rewriter) const override {
@@ -149,16 +149,16 @@ getBroadcastedAffineMap(const mlir::RankedTensorType &resultType,
                        rewriter.getContext());
}

// This creates an affine map following the broadcasting rules, but also takes
// out one specific element of the LUT from the LUT dimension, which should be
// the last.
//
// Example:
//
// resultType: 4x2x5, operandType: 4x2x8, lut_index: 3
// return: affine_map<(d0, d1, d2) -> (d0, d1, 3)>
// The last dimension of the operand is the LUT size, and the map takes out
// the element at index 3.
/// This creates an affine map following the broadcasting rules, but also takes
/// out one specific element of the LUT from the LUT dimension, which should be
/// the last.
///
/// Example:
///
/// resultType: 4x2x5, operandType: 4x2x8, lut_index: 3
/// return: affine_map<(d0, d1, d2) -> (d0, d1, 3)>
/// The last dimension of the operand is the LUT size, and the map takes out
/// the element at index 3.
mlir::AffineMap
getBroadcastedAffineMapMultiLUT(const mlir::RankedTensorType &resultType,
                                const mlir::RankedTensorType &operandType,
@@ -183,44 +183,44 @@ getBroadcastedAffineMapMultiLUT(const mlir::RankedTensorType &resultType,
                        rewriter.getContext());
}

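As an illustration of the map construction, here is a self-contained sketch that pins the innermost result expression to a constant LUT index; it deliberately ignores the broadcast handling of size-1 dimensions, and the function name is an assumption made for this example:

```cpp
// Sketch: build affine_map<(d0, d1, d2) -> (d0, d1, 3)> for lutIndex = 3.
// All leading result dimensions are identity; the last one is pinned to
// the constant LUT index so the generic op reads one LUT entry per element.
mlir::AffineMap buildMultiLutMap(unsigned numDims, int64_t lutIndex,
                                 mlir::MLIRContext *ctx) {
  llvm::SmallVector<mlir::AffineExpr, 4> exprs;
  for (unsigned i = 0; i + 1 < numDims; ++i)
    exprs.push_back(mlir::getAffineDimExpr(i, ctx));
  exprs.push_back(mlir::getAffineConstantExpr(lutIndex, ctx));
  return mlir::AffineMap::get(numDims, /*symbolCount=*/0, exprs, ctx);
}
```
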
// This template rewrite pattern transforms any instance of
// operators `FHELinalgOp` that implement the broadcasting rules to an
// instance of `linalg.generic` with an appropriate region using the `FHEOp`
// operation, an appropriate specification for the iteration dimensions and
// appropriate operations managing the accumulator of `linalg.generic`.
//
// Example:
//
// %res = FHELinalg.op(%lhs, %rhs):
// (tensor<D$Ax...xD1x!FHE.eint<p>>, tensor<D$B'x...xD1'xT>)
//   -> tensor<DR"x...xD1"x!FHE.eint<p>>
//
// becomes:
//
// #maps_0 = [
//   affine_map<(a$R", ..., a$A, ..., a1) ->
//     (dim(lhs, $A) == 1 ? 0 : a$A,..., dim(lhs, 1) == 1 ? 0 : a1)>,
//   affine_map<(a$R", ..., a1) ->
//     (dim(rhs, $B') == 1 ? 0 : a$B', ..., dim(rhs, 1) == 1 ? 0 : a1)>,
//   affine_map<(a$R", ..., a1) -> (a$R", ..., a1)
// ]
// #attributes_0 {
//   indexing_maps = #maps_0,
//   iterator_types = ["parallel", ..., "parallel"], // $R" parallel
// }
// %init = linalg.init_tensor [DR",...,D1"]
//            : tensor<DR"x...xD1"x!FHE.eint<p>>
// %res = linalg.generic {
//   ins(%lhs, %rhs: tensor<DAx...xD1x!FHE.eint<p>>,tensor<DB'x...xD1'xT>)
//   outs(%init : tensor<DR"x...xD1"x!FHE.eint<p>>)
//   {
//     ^bb0(%arg0: !FHE.eint<p>, %arg1: T):
//       %0 = FHE.op(%arg0, %arg1): !FHE.eint<p>, T ->
//       !FHE.eint<p>
//       linalg.yield %0 : !FHE.eint<p>
//   }
// }
//
/// This template rewrite pattern transforms any instance of
/// operators `FHELinalgOp` that implement the broadcasting rules to an
/// instance of `linalg.generic` with an appropriate region using the `FHEOp`
/// operation, an appropriate specification for the iteration dimensions and
/// appropriate operations managing the accumulator of `linalg.generic`.
///
/// Example:
///
/// %res = FHELinalg.op(%lhs, %rhs):
/// (tensor<D$Ax...xD1x!FHE.eint<p>>, tensor<D$B'x...xD1'xT>)
///   -> tensor<DR"x...xD1"x!FHE.eint<p>>
///
/// becomes:
///
/// #maps_0 = [
///   affine_map<(a$R", ..., a$A, ..., a1) ->
///     (dim(lhs, $A) == 1 ? 0 : a$A,..., dim(lhs, 1) == 1 ? 0 : a1)>,
///   affine_map<(a$R", ..., a1) ->
///     (dim(rhs, $B') == 1 ? 0 : a$B', ..., dim(rhs, 1) == 1 ? 0 : a1)>,
///   affine_map<(a$R", ..., a1) -> (a$R", ..., a1)
/// ]
/// #attributes_0 {
///   indexing_maps = #maps_0,
///   iterator_types = ["parallel", ..., "parallel"], // $R" parallel
/// }
/// %init = linalg.init_tensor [DR",...,D1"]
///            : tensor<DR"x...xD1"x!FHE.eint<p>>
/// %res = linalg.generic {
///   ins(%lhs, %rhs: tensor<DAx...xD1x!FHE.eint<p>>,tensor<DB'x...xD1'xT>)
///   outs(%init : tensor<DR"x...xD1"x!FHE.eint<p>>)
///   {
///     ^bb0(%arg0: !FHE.eint<p>, %arg1: T):
///       %0 = FHE.op(%arg0, %arg1): !FHE.eint<p>, T ->
///       !FHE.eint<p>
///       linalg.yield %0 : !FHE.eint<p>
///   }
/// }
///
template <typename FHELinalgOp, typename FHEOp>
struct FHELinalgOpToLinalgGeneric : public mlir::OpRewritePattern<FHELinalgOp> {
  FHELinalgOpToLinalgGeneric(::mlir::MLIRContext *context,
@@ -290,51 +290,51 @@ llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
  return llvm::SmallVector<llvm::StringRef>(n, "parallel");
}

// This class rewrite pattern transforms any instance of
// operators `FHELinalg.ApplyMappedLookupTableEintOp` that implement the
// broadcasting rules to an instance of `linalg.generic` with an appropriate
// region using the `FHE.ApplyLookupTableEintOp` operation, an appropriate
// specification for the iteration dimensions and appropriate operations
// managing the accumulator of `linalg.generic`.
//
// The current implementation does not rely on 'tensor.extract_slice'
// because of a bug in lowering this operation.
//
// Example:
//   %res = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map)
//     : (tensor<2x3x!FHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>)
//     -> tensor<2x3x!FHE.eint<2>>
//
// becomes:
//
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %init = linalg.init_tensor [2, 3] : tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>
// %output = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types
// = ["parallel", "parallel"]} ins(%arg0, %arg2 :
// tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>, tensor<2x3xindex>) outs(%0 :
// tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>) {
// ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %lut_idx: index, %arg5:
// !TFHE.glwe<{_,_,_}{2}>):  // no predecessors
//   // SHOULD BE
//   %lut = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1,4] [1, 1]
//        : tensor<5x4xi64> to tensor<4xi64>
//   // BUT IS
//   %i0 = arith.constant 0 : index
//   ...
//   %i3 = arith.constant 3 : index
//   %e0 = tensor.extract %arg5[%lut_idx, %i0] : tensor<5x4xi64>
//   ...
//   %e3 = tensor.extract %arg5[%lut_idx, %i3] : tensor<5x4xi64>
//   %lut = tensor.from_elements %e0, ..., %e3 : tensor<4xi64>
//   %res = "TFHE.apply_lookup_table"(%arg3, %[[LUT]])
//     {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension
//     = -1 : i32,
//       levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS =
//       -1 : i32, polynomialSize = -1 : i32}
//     : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
//     !TFHE.glwe<{_,_,_}{2}> linalg.yield %res :
//     !TFHE.glwe<{_,_,_}{2}>
// } -> tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>
/// This class rewrite pattern transforms any instance of
/// operators `FHELinalg.ApplyMappedLookupTableEintOp` that implement the
/// broadcasting rules to an instance of `linalg.generic` with an appropriate
/// region using the `FHE.ApplyLookupTableEintOp` operation, an appropriate
/// specification for the iteration dimensions and appropriate operations
/// managing the accumulator of `linalg.generic`.
///
/// The current implementation does not rely on 'tensor.extract_slice'
/// because of a bug in lowering this operation.
///
/// Example:
///   %res = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map)
///     : (tensor<2x3x!FHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>)
///     -> tensor<2x3x!FHE.eint<2>>
///
/// becomes:
///
/// #map = affine_map<(d0, d1) -> (d0, d1)>
/// %init = linalg.init_tensor [2, 3] : tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>
/// %output = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types
/// = ["parallel", "parallel"]} ins(%arg0, %arg2 :
/// tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>, tensor<2x3xindex>) outs(%0 :
/// tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>) {
/// ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %lut_idx: index, %arg5:
/// !TFHE.glwe<{_,_,_}{2}>):  // no predecessors
///   // SHOULD BE
///   %lut = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1,4] [1, 1]
///        : tensor<5x4xi64> to tensor<4xi64>
///   // BUT IS
///   %i0 = arith.constant 0 : index
///   ...
///   %i3 = arith.constant 3 : index
///   %e0 = tensor.extract %arg5[%lut_idx, %i0] : tensor<5x4xi64>
///   ...
///   %e3 = tensor.extract %arg5[%lut_idx, %i3] : tensor<5x4xi64>
///   %lut = tensor.from_elements %e0, ..., %e3 : tensor<4xi64>
///   %res = "TFHE.apply_lookup_table"(%arg3, %[[LUT]])
///     {baseLogBS = -1 : i32, baseLogKS = -1 : i32,
///     glweDimension = -1 : i32,
///       levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS =
///       -1 : i32, polynomialSize = -1 : i32}
///     : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
///     !TFHE.glwe<{_,_,_}{2}> linalg.yield %res :
///     !TFHE.glwe<{_,_,_}{2}>
/// } -> tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>

namespace FHELinalg = mlir::concretelang::FHELinalg;

@@ -450,50 +450,50 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric
  };
};

// This class rewrite pattern transforms any instance of
// operators `FHELinalg.ApplyMultiLookupTableEintOp` that implement the
// broadcasting rules to an instance of `linalg.generic` with an appropriate
// region using the `FHE.ApplyLookupTableEintOp` operation, an appropriate
// specification for the iteration dimensions and appropriate operations
// managing the accumulator of `linalg.generic`.
//
// Example:
//
// %res = "FHELinalg.apply_multi_lookup_table"(%t, %luts):
// (tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>>
//
// becomes:
//
// #maps_0 = [
//   affine_map<(d0, d1) -> (d0, d1)>
//   affine_map<(d0, d1) -> (d1, 0)>
//   affine_map<(d0, d1) -> (d1, 1)>
//   affine_map<(d0, d1) -> (d1, 2)>
//   affine_map<(d0, d1) -> (d1, 3)>
// ]
// #attributes_0 {
//   indexing_maps = #maps_0,
//   iterator_types = ["parallel", "parallel"],
// }
// %init = linalg.init_tensor [4, 3]
//            : tensor<4x3x!FHE.eint<2>>
// %res = linalg.generic {
//   ins(%t, %luts, %luts, %luts, %luts: tensor<4x3x!FHE.eint<p>>,
//   tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>)
//   outs(%init : tensor<4x3x!FHE.eint<2>>)
//   {
//     ^bb0(%arg0: !FHE.eint<2>, %arg1: i64, %arg2: i64, %arg3: i64,
//     %arg4: i64, %arg5: !FHE.eint<2>):
//       %lut = tensor.from_elements %arg1, %arg2, %arg3, %arg4 :
//       tensor<4xi64> %0 = "TFHE.apply_lookup_table"(%arg0, %lut)
//       {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension = -1 :
//       i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 :
//       i32, polynomialSize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>,
//       tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}>
//       linalg.yield %0 : !FHE.eint<2>
//   }
// }
//
/// This class rewrite pattern transforms any instance of
/// operators `FHELinalg.ApplyMultiLookupTableEintOp` that implement the
/// broadcasting rules to an instance of `linalg.generic` with an appropriate
/// region using the `FHE.ApplyLookupTableEintOp` operation, an appropriate
/// specification for the iteration dimensions and appropriate operations
/// managing the accumulator of `linalg.generic`.
///
/// Example:
///
/// %res = "FHELinalg.apply_multi_lookup_table"(%t, %luts):
/// (tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>>
///
/// becomes:
///
/// #maps_0 = [
///   affine_map<(d0, d1) -> (d0, d1)>
///   affine_map<(d0, d1) -> (d1, 0)>
///   affine_map<(d0, d1) -> (d1, 1)>
///   affine_map<(d0, d1) -> (d1, 2)>
///   affine_map<(d0, d1) -> (d1, 3)>
/// ]
/// #attributes_0 {
///   indexing_maps = #maps_0,
///   iterator_types = ["parallel", "parallel"],
/// }
/// %init = linalg.init_tensor [4, 3]
///            : tensor<4x3x!FHE.eint<2>>
/// %res = linalg.generic {
///   ins(%t, %luts, %luts, %luts, %luts: tensor<4x3x!FHE.eint<p>>,
///   tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>)
///   outs(%init : tensor<4x3x!FHE.eint<2>>)
///   {
///     ^bb0(%arg0: !FHE.eint<2>, %arg1: i64, %arg2: i64, %arg3: i64,
///     %arg4: i64, %arg5: !FHE.eint<2>):
///       %lut = tensor.from_elements %arg1, %arg2, %arg3, %arg4 :
///       tensor<4xi64> %0 = "TFHE.apply_lookup_table"(%arg0, %lut)
///       {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension = -1
///       : i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1
///       : i32, polynomialSize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>,
///       tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}>
///       linalg.yield %0 : !FHE.eint<2>
///   }
/// }
///
struct FHELinalgApplyMultiLookupTableToLinalgGeneric
    : public mlir::OpRewritePattern<
          mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp> {
@@ -578,42 +578,42 @@ struct FHELinalgApplyMultiLookupTableToLinalgGeneric
  };
};

// This template rewrite pattern transforms any instance of
// operators `FHELinalg.apply_lookup_table` that implement the broadcasting
// rules to an instance of `linalg.generic` with an appropriate region using
// the `FHE.apply_lookup_table` operation, an appropriate specification for the
// iteration dimensions and appropriate operations managing the accumulator of
// `linalg.generic`.
//
// Example:
//
// FHELinalg.apply_lookup_table(%t, %lut):
//  tensor<DNx...xD1x!FHE.eint<p>>, tensor<DAxi64>
//      -> tensor<DNx...xD1x!FHE.eint<p'>>
//
// becomes:
//
// #maps_0 = [
//    affine_map<(aN, ..., a1) -> (aN, ..., a1)>,
//    affine_map<(aN, ..., a1) -> (aN, ..., a1)>
// ]
// #attributes_0 {
//     indexing_maps = #maps_0,
//     iterator_types = ["parallel",..],//N parallel
// }
// %init = linalg.init_tensor [DN,...,D1]
//            : tensor<DNx...xD1x!FHE.eint<p'>>
// %res = linalg.generic {
//   ins(%t: tensor<DNx...xD1x!FHE.eint<p>>)
//   outs(%init : tensor<DNx...xD1x!FHE.eint<p'>>)
//   {
//     ^bb0(%arg0: !FHE.eint<p>):
//       %0 = FHE.apply_lookup_table(%arg0, %lut): !FHE.eint<p>,
//       tensor<4xi64> -> !FHE.eint<p'>
//       linalg.yield %0 : !FHE.eint<p'>
//   }
// }
//
/// This template rewrite pattern transforms any instance of
/// operators `FHELinalg.apply_lookup_table` that implement the broadcasting
/// rules to an instance of `linalg.generic` with an appropriate region using
/// the `FHE.apply_lookup_table` operation, an appropriate specification for the
/// iteration dimensions and appropriate operations managing the accumulator of
/// `linalg.generic`.
///
/// Example:
///
/// FHELinalg.apply_lookup_table(%t, %lut):
///  tensor<DNx...xD1x!FHE.eint<p>>, tensor<DAxi64>
///      -> tensor<DNx...xD1x!FHE.eint<p'>>
///
/// becomes:
///
/// #maps_0 = [
///    affine_map<(aN, ..., a1) -> (aN, ..., a1)>,
///    affine_map<(aN, ..., a1) -> (aN, ..., a1)>
/// ]
/// #attributes_0 {
///     indexing_maps = #maps_0,
///     iterator_types = ["parallel",..],//N parallel
/// }
/// %init = linalg.init_tensor [DN,...,D1]
///            : tensor<DNx...xD1x!FHE.eint<p'>>
/// %res = linalg.generic {
///   ins(%t: tensor<DNx...xD1x!FHE.eint<p>>)
///   outs(%init : tensor<DNx...xD1x!FHE.eint<p'>>)
///   {
///     ^bb0(%arg0: !FHE.eint<p>):
///       %0 = FHE.apply_lookup_table(%arg0, %lut): !FHE.eint<p>,
///       tensor<4xi64> -> !FHE.eint<p'>
///       linalg.yield %0 : !FHE.eint<p'>
///   }
/// }
///
struct FHELinalgApplyLookupTableToLinalgGeneric
    : public mlir::OpRewritePattern<
          mlir::concretelang::FHELinalg::ApplyLookupTableEintOp> {
@@ -681,39 +681,39 @@ struct FHELinalgApplyLookupTableToLinalgGeneric
  };
};

// This template rewrite pattern transforms any instance of
// operators `FHELinalg.neg_eint` to an instance of `linalg.generic` with an
// appropriate region using the `FHE.neg_eint` operation, an appropriate
// specification for the iteration dimensions and appropriate operations
// managing the accumulator of `linalg.generic`.
//
// Example:
//
// FHELinalg.neg_eint(%tensor):
//  tensor<DNx...xD1x!FHE.eint<p>> -> tensor<DNx...xD1x!FHE.eint<p'>>
//
// becomes:
//
// #maps_0 = [
//    affine_map<(aN, ..., a1) -> (aN, ..., a1)>,
//    affine_map<(aN, ..., a1) -> (aN, ..., a1)>
// ]
// #attributes_0 {
//     indexing_maps = #maps_0,
//     iterator_types = ["parallel",..],//N parallel
// }
// %init = linalg.init_tensor [DN,...,D1]
//            : tensor<DNx...xD1x!FHE.eint<p'>>
// %res = linalg.generic {
//   ins(%tensor: tensor<DNx...xD1x!FHE.eint<p>>)
//   outs(%init : tensor<DNx...xD1x!FHE.eint<p'>>)
//   {
//     ^bb0(%arg0: !FHE.eint<p>):
//       %0 = FHE.neg_eint(%arg0): !FHE.eint<p> -> !FHE.eint<p'>
//       linalg.yield %0 : !FHE.eint<p'>
//   }
// }
//
/// This template rewrite pattern transforms any instance of
/// operators `FHELinalg.neg_eint` to an instance of `linalg.generic` with an
/// appropriate region using the `FHE.neg_eint` operation, an appropriate
/// specification for the iteration dimensions and appropriate operations
/// managing the accumulator of `linalg.generic`.
///
/// Example:
///
/// FHELinalg.neg_eint(%tensor):
///  tensor<DNx...xD1x!FHE.eint<p>> -> tensor<DNx...xD1x!FHE.eint<p'>>
///
/// becomes:
///
/// #maps_0 = [
///    affine_map<(aN, ..., a1) -> (aN, ..., a1)>,
///    affine_map<(aN, ..., a1) -> (aN, ..., a1)>
/// ]
/// #attributes_0 {
///     indexing_maps = #maps_0,
///     iterator_types = ["parallel",..],//N parallel
/// }
/// %init = linalg.init_tensor [DN,...,D1]
///            : tensor<DNx...xD1x!FHE.eint<p'>>
/// %res = linalg.generic {
///   ins(%tensor: tensor<DNx...xD1x!FHE.eint<p>>)
///   outs(%init : tensor<DNx...xD1x!FHE.eint<p'>>)
///   {
///     ^bb0(%arg0: !FHE.eint<p>):
///       %0 = FHE.neg_eint(%arg0): !FHE.eint<p> -> !FHE.eint<p'>
///       linalg.yield %0 : !FHE.eint<p'>
///   }
/// }
///
struct FHELinalgNegEintToLinalgGeneric
    : public mlir::OpRewritePattern<mlir::concretelang::FHELinalg::NegEintOp> {
  FHELinalgNegEintToLinalgGeneric(
@@ -778,44 +778,43 @@ struct FHELinalgNegEintToLinalgGeneric
  };
};

// This template rewrite pattern transforms any instance of
// operators `FHELinalgMatmulOp` to an instance of `linalg.generic`
// with an appropriate region using a builder that creates the multiplication
// operators and the `FHE.add_eint` operation, an appropriate specification for
// the iteration dimensions and appropriate operations managing the accumulator
// of `linalg.generic`.
//
// Example:
//
//  "FHELinalg.matmul_eint_int(%a, %b) :
//      (tensor<MxPx!FHE.eint<p>>, tensor<PxNxip'>) ->
//          tensor<MxNx!FHE.eint<p>>"
//
// becomes:
//
// #maps_0 = [
//   (m, n, p) -> (m, p),
//   (m, n, p) -> (p, n),
//   (m, n, p) -> (m, n)
// ]
// #attributes_0 = {
//   indexing_maps = #maps_0,
//   iterator_types = ["parallel", "parallel", "reduction"]
// }
// %init = FHE.zero_tensor : tensor<MxNx!FHE.eint<p>>
// linalg.generic #attributes_0
//   ins(%A, %B : tensor<MxPx!FHE.eint<p>>,
//                tensor<PxNxip'>)
//   outs(%C : tensor<MxNx!FHE.eint<p>>)
//   {
//     ^bb0(%a: !FHE.eint<p>, %b: ip', %c: !FHE.eint<p>) :
//       %d = createMulOp(%a, %b): !FHE.eint<p>
//       %e = "FHE.add_eint"(%c, %d):
//              (!FHE.eint<p>, !FHE.eint<p>) -> !FHE.eint<p>
//       linalg.yield %e : !FHE.eint<p>
//   }
//
/// This template rewrite pattern transforms any instance of
/// operators `FHELinalgMatmulOp` to an instance of `linalg.generic`
/// with an appropriate region using a builder that creates the multiplication
/// operators and the `FHE.add_eint` operation, an appropriate specification for
/// the iteration dimensions and appropriate operations managing the accumulator
/// of `linalg.generic`.
///
/// Example:
///
///  "FHELinalg.matmul_eint_int(%a, %b) :
///      (tensor<MxPx!FHE.eint<p>>, tensor<PxNxip'>) ->
///          tensor<MxNx!FHE.eint<p>>"
///
/// becomes:
///
/// #maps_0 = [
///   (m, n, p) -> (m, p),
///   (m, n, p) -> (p, n),
///   (m, n, p) -> (m, n)
/// ]
/// #attributes_0 = {
///   indexing_maps = #maps_0,
///   iterator_types = ["parallel", "parallel", "reduction"]
/// }
/// %init = FHE.zero_tensor : tensor<MxNx!FHE.eint<p>>
/// linalg.generic #attributes_0
///   ins(%A, %B : tensor<MxPx!FHE.eint<p>>,
///                tensor<PxNxip'>)
///   outs(%C : tensor<MxNx!FHE.eint<p>>)
///   {
///     ^bb0(%a: !FHE.eint<p>, %b: ip', %c: !FHE.eint<p>) :
///       %d = createMulOp(%a, %b): !FHE.eint<p>
///       %e = "FHE.add_eint"(%c, %d):
///              (!FHE.eint<p>, !FHE.eint<p>) -> !FHE.eint<p>
///       linalg.yield %e : !FHE.eint<p>
///   }
///
template <typename FHELinalgMatmulOp>
struct FHELinalgMatmulToLinalgGeneric
    : public mlir::OpRewritePattern<FHELinalgMatmulOp> {
@@ -1089,37 +1088,37 @@ private:
      createMulOp;
};

// This rewrite pattern transforms any instance of operators
// `FHELinalg.sum` to an instance of `linalg.generic`.
//
// Example:
//
//   %result = "FHELinalg.sum"(%input) :
//     tensor<d0xd1x...xdNx!FHE.eint<p>>() -> !FHE.eint<p>
//
// becomes:
//
//   #map0 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)>
//   #map1 = affine_map<(i0, i1, ..., iN) -> (0)>
//
//   %accumulator = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<7>>
//   %accumulation = linalg.generic
//     {
//       indexing_maps = [#map0, #map1],
//       iterator_types = ["reduction", "reduction", ..., "reduction"]
//     }
//     ins(%input : tensor<d0xd1x...xdNx!FHE.eint<7>>)
//     outs(%accumulator : tensor<1x!FHE.eint<7>>)
//     {
//       ^bb0(%a: !FHE.eint<7>, %b: !FHE.eint<7>):
//         %c = "FHE.add_eint"(%a, %b) :
//           (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7>
//         linalg.yield %c : !FHE.eint<7>
//     } -> tensor<1x!FHE.eint<7>>
//
//   %index = arith.constant 0 : index
//   %result = tensor.extract %index : tensor<1x!FHE.eint<7>>
//
/// This rewrite pattern transforms any instance of operators
/// `FHELinalg.sum` to an instance of `linalg.generic`.
///
/// Example:
///
///   %result = "FHELinalg.sum"(%input) :
///     tensor<d0xd1x...xdNx!FHE.eint<p>>() -> !FHE.eint<p>
///
/// becomes:
///
///   #map0 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)>
///   #map1 = affine_map<(i0, i1, ..., iN) -> (0)>
///
///   %accumulator = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<7>>
///   %accumulation = linalg.generic
///     {
///       indexing_maps = [#map0, #map1],
///       iterator_types = ["reduction", "reduction", ..., "reduction"]
///     }
///     ins(%input : tensor<d0xd1x...xdNx!FHE.eint<7>>)
///     outs(%accumulator : tensor<1x!FHE.eint<7>>)
///     {
///       ^bb0(%a: !FHE.eint<7>, %b: !FHE.eint<7>):
///         %c = "FHE.add_eint"(%a, %b) :
///           (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7>
///         linalg.yield %c : !FHE.eint<7>
///     } -> tensor<1x!FHE.eint<7>>
///
///   %index = arith.constant 0 : index
///   %result = tensor.extract %index : tensor<1x!FHE.eint<7>>
///
struct SumToLinalgGeneric
    : public ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::SumOp> {
  SumToLinalgGeneric(::mlir::MLIRContext *context)
@@ -1245,32 +1244,32 @@ struct SumToLinalgGeneric
  };
};

// This rewrite pattern transforms any instance of operators
// `FHELinalg.transpose` to an instance of `linalg.generic`.
//
// Example:
//
//   %result = "FHELinalg.transpose"(%input: tensor<d0xd1x...xdNx!FHE.eint<p>>)
//     -> tensor<dNx...xd1xd0x!FHE.eint<p>>
//
// becomes:
//
//   #map0 = affine_map<(i0, i1, ..., iN) -> (iN, ..., i1, i0)>
//   #map1 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)>
//
//   %accumulator = "FHE.zero_tensor"() : () ->
//     tensor<dNx...xd1xd0x!FHE.eint<6>> %result = linalg.generic
//     {
//       indexing_maps = [#map0, #map1],
//       iterator_types = ["parallel", "parallel", ..., "parallel"]
//     }
//     ins(%input : tensor<d0xd1x...xdNx!FHE.eint<7>>)
//     outs(%accumulator : tensor<dNx...xd1xd0x!FHE.eint<7>>)
//     {
//       ^bb0(%a: !FHE.eint<7>, %b: !FHE.eint<7>):
//         linalg.yield %a : !FHE.eint<7>
//     } -> tensor<dNx...xd1xd0x!FHE.eint<7>>
//
/// This rewrite pattern transforms any instance of operators
/// `FHELinalg.transpose` to an instance of `linalg.generic`.
///
/// Example:
///
///   %result = "FHELinalg.transpose"(%input: tensor<d0xd1x...xdNx!FHE.eint<p>>)
///     -> tensor<dNx...xd1xd0x!FHE.eint<p>>
///
/// becomes:
///
///   #map0 = affine_map<(i0, i1, ..., iN) -> (iN, ..., i1, i0)>
///   #map1 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)>
///
///   %accumulator = "FHE.zero_tensor"() : () ->
///     tensor<dNx...xd1xd0x!FHE.eint<6>> %result = linalg.generic
///     {
///       indexing_maps = [#map0, #map1],
///       iterator_types = ["parallel", "parallel", ..., "parallel"]
///     }
///     ins(%input : tensor<d0xd1x...xdNx!FHE.eint<7>>)
///     outs(%accumulator : tensor<dNx...xd1xd0x!FHE.eint<7>>)
///     {
///       ^bb0(%a: !FHE.eint<7>, %b: !FHE.eint<7>):
///         linalg.yield %a : !FHE.eint<7>
///     } -> tensor<dNx...xd1xd0x!FHE.eint<7>>
///
struct TransposeToLinalgGeneric
    : public ::mlir::OpRewritePattern<
          mlir::concretelang::FHELinalg::TransposeOp> {
@@ -1325,25 +1324,25 @@ struct TransposeToLinalgGeneric
  };
};

// This rewrite pattern transforms any instance of operators
// `FHELinalg.concat` to instances of `tensor.insert_slice`.
//
// Example:
//
//   %result = "FHELinalg.concat"(%x, %y) { axis = 1 } :
//     (tensor<2x3x!FHE.eint<4>>, tensor<2x4x!FHE.eint<4>>)
//       -> tensor<2x7x!FHE.eint<4>>
//
// becomes:
//
//   %empty = "FHE.zero_tensor"() : () -> tensor<2x7x!FHE.eint<4>>
//
//   %x_copied = tensor.insert_slice %x into %empty[0, 0] [2, 3] [1, 1]
//     : tensor<2x3x!FHE.eint<4>> into tensor<2x7x!FHE.eint<4>>
//
//   %y_copied = tensor.insert_slice %y into %x_copied[0, 3] [2, 4] [1, 1]
//     : tensor<2x4x!FHE.eint<4>> into tensor<2x7x!FHE.eint<4>>
//
/// This rewrite pattern transforms any instance of operators
/// `FHELinalg.concat` to instances of `tensor.insert_slice`.
///
/// Example:
///
///   %result = "FHELinalg.concat"(%x, %y) { axis = 1 } :
///     (tensor<2x3x!FHE.eint<4>>, tensor<2x4x!FHE.eint<4>>)
///       -> tensor<2x7x!FHE.eint<4>>
///
/// becomes:
///
///   %empty = "FHE.zero_tensor"() : () -> tensor<2x7x!FHE.eint<4>>
///
///   %x_copied = tensor.insert_slice %x into %empty[0, 0] [2, 3] [1, 1]
///     : tensor<2x3x!FHE.eint<4>> into tensor<2x7x!FHE.eint<4>>
///
///   %y_copied = tensor.insert_slice %y into %x_copied[0, 3] [2, 4] [1, 1]
///     : tensor<2x4x!FHE.eint<4>> into tensor<2x7x!FHE.eint<4>>
///
struct ConcatRewritePattern
    : public mlir::OpRewritePattern<FHELinalg::ConcatOp> {
  ConcatRewritePattern(mlir::MLIRContext *context)
@@ -1449,8 +1448,8 @@ getAsOpFoldResult(mlir::OpBuilder &b, mlir::Location loc,
  }));
}

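The offset bookkeeping behind those `tensor.insert_slice` ops is simple: each input lands at the running sum of the previous inputs' sizes along the concatenation axis. A self-contained sketch (illustrative helper, not the pattern's actual code):

```cpp
// Sketch: compute per-input offsets along the concatenation axis.
// For sizes {3, 4} along axis 1 this yields offsets {0, 3}, matching the
// [0, 0] and [0, 3] insert positions in the example above.
llvm::SmallVector<int64_t> concatOffsets(llvm::ArrayRef<int64_t> axisSizes) {
  llvm::SmallVector<int64_t> offsets;
  int64_t offset = 0;
  for (int64_t size : axisSizes) {
    offsets.push_back(offset); // where this input starts on the axis
    offset += size;            // the next input starts right after it
  }
  return offsets;
}
```
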
// Helper function to get the padding tensor given the padding int values, and
// the value to pad with
/// Helper function to get the padding tensor given the padding int values, and
/// the value to pad with
static mlir::Value
getPaddedTensor(mlir::Operation *op, mlir::OpBuilder &b, mlir::Value &input,
                mlir::SmallVectorImpl<int64_t> &lowPaddingInts,
@@ -1472,10 +1471,10 @@ getPaddedTensor(mlir::Operation *op, mlir::OpBuilder &b, mlir::Value &input,
  return paddedInput;
}

// This rewrite pattern transforms any instance of operators
// `FHELinalg.conv2d` to an instance of `linalg.fhelinalg_conv_2d_nchw_fchw`.
// The transformation consists of padding the input tensor, and initializing the
// output tensor with bias values if any.
/// This rewrite pattern transforms any instance of operators
/// `FHELinalg.conv2d` to an instance of `linalg.fhelinalg_conv_2d_nchw_fchw`.
/// The transformation consists of padding the input tensor, and initializing
/// the output tensor with bias values if any.
struct FHELinalgConv2dToLinalgConv2d
    : public ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::Conv2dOp> {
  FHELinalgConv2dToLinalgConv2d(::mlir::MLIRContext *context)

@@ -60,30 +60,30 @@ public:
  }
};

// This rewrite pattern transforms any instance of `FHE.apply_lookup_table`
// operators.
//
// Example:
//
// ```mlir
// %0 = "FHE.apply_lookup_table"(%ct, %lut): (!FHE.eint<2>, tensor<4xi64>)
//   ->(!FHE.eint<2>)
// ```
//
// becomes:
//
// ```mlir
// %glwe_lut = "TFHE.glwe_from_table"(%lut)
//     : (tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}>
// %glwe_ks = "TFHE.keyswitch_glwe"(%ct)
//     {baseLog = -1 : i32, level = -1 : i32}
//     : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
// %0 = "TFHE.bootstrap_glwe"(%glwe_ks, %glwe_lut)
//     {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32,
//     polynomialSize = -1 : i32}
//     : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) ->
//     !TFHE.glwe<{_,_,_}{2}>
// ```
/// This rewrite pattern transforms any instance of `FHE.apply_lookup_table`
/// operators.
///
/// Example:
///
/// ```mlir
/// %0 = "FHE.apply_lookup_table"(%ct, %lut): (!FHE.eint<2>, tensor<4xi64>)
///   ->(!FHE.eint<2>)
/// ```
///
/// becomes:
///
/// ```mlir
/// %glwe_lut = "TFHE.glwe_from_table"(%lut)
///     : (tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}>
/// %glwe_ks = "TFHE.keyswitch_glwe"(%ct)
///     {baseLog = -1 : i32, level = -1 : i32}
///     : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
/// %0 = "TFHE.bootstrap_glwe"(%glwe_ks, %glwe_lut)
///     {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32,
///     polynomialSize = -1 : i32}
///     : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) ->
///     !TFHE.glwe<{_,_,_}{2}>
/// ```
struct ApplyLookupTableEintOpPattern
    : public mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp> {
  ApplyLookupTableEintOpPattern(mlir::MLIRContext *context,
@@ -115,8 +115,8 @@ struct ApplyLookupTableEintOpPattern
  };
};

// This rewrite pattern transforms any instance of `FHE.sub_eint_int`
// operators to a negation and an addition.
/// This rewrite pattern transforms any instance of `FHE.sub_eint_int`
/// operators to a negation and an addition.
struct SubEintIntOpPattern : public mlir::OpRewritePattern<FHE::SubEintIntOp> {
  SubEintIntOpPattern(mlir::MLIRContext *context,
                      mlir::PatternBenefit benefit = 1)
@@ -156,8 +156,8 @@ struct SubEintIntOpPattern : public mlir::OpRewritePattern<FHE::SubEintIntOp> {
  };
};

// This rewrite pattern transforms any instance of `FHE.sub_eint`
// operators to a negation and an addition.
/// This rewrite pattern transforms any instance of `FHE.sub_eint`
/// operators to a negation and an addition.
struct SubEintOpPattern : public mlir::OpRewritePattern<FHE::SubEintOp> {
  SubEintOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<FHE::SubEintOp>(context, benefit) {}
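Conceptually both rewrites are just `a - b → a + (-b)`. A hedged sketch of the region body, using the FHE op names from above (the builder calls are illustrative, not the exact generated builders):

```cpp
// Sketch: replace %res = FHE.sub_eint(%a, %b) with a negation of the
// subtrahend followed by an encrypted addition.
mlir::LogicalResult rewriteSubAsNegAdd(FHE::SubEintOp op,
                                       mlir::PatternRewriter &rewriter) {
  mlir::Value negated = rewriter.create<FHE::NegEintOp>(
      op.getLoc(), op.getOperand(1).getType(), op.getOperand(1));
  rewriter.replaceOpWithNewOp<FHE::AddEintOp>(op, op->getResult(0).getType(),
                                              op.getOperand(0), negated);
  return mlir::success();
}
```
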

@@ -43,26 +43,26 @@ struct MLIRLowerableDialectsToLLVMPass
};
} // namespace

// This rewrite pattern transforms any instance of `memref.copy`
// operators on 1D memref.
// This is introduced to avoid the MLIR lowering of `memref.copy` of ranked
// memrefs, which allocates an unranked memref structure on the stack before
// calling @memrefCopy.
//
// Example:
//
// ```mlir
// memref.copy %src, %dst : memref<Xxi64> to memref<Xxi64>
// ```
//
// becomes:
//
// ```mlir
// %_src = memref.cast %src : memref<Xxi64> to memref<?xi64>
// %_dst = memref.cast %dst : memref<Xxi64> to memref<?xi64>
// call @memref_copy_one_rank(%_src, %_dst) : (memref<?xi64>, memref<?xi64>) ->
//   ()
// ```
/// This rewrite pattern transforms any instance of `memref.copy`
/// operators on 1D memref.
/// This is introduced to avoid the MLIR lowering of `memref.copy` of ranked
/// memrefs, which allocates an unranked memref structure on the stack before
/// calling @memrefCopy.
///
/// Example:
///
/// ```mlir
/// memref.copy %src, %dst : memref<Xxi64> to memref<Xxi64>
/// ```
///
/// becomes:
///
/// ```mlir
/// %_src = memref.cast %src : memref<Xxi64> to memref<?xi64>
/// %_dst = memref.cast %dst : memref<Xxi64> to memref<?xi64>
/// call @memref_copy_one_rank(%_src, %_dst) : (memref<?xi64>, memref<?xi64>) ->
///   ()
/// ```
struct Memref1DCopyOpPattern
    : public mlir::OpRewritePattern<mlir::memref::CopyOp> {
  Memref1DCopyOpPattern(mlir::MLIRContext *context,
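A sketch of what the pattern emits, assuming a pre-declared `@memref_copy_one_rank` runtime function and the `getCasted1DMemRef` helper defined later in this change (the accessors and call shape are illustrative):

```cpp
// Sketch: cast both sides to memref<?xi64> and call the rank-1 runtime copy
// instead of letting memref.copy lower through the generic @memrefCopy path.
mlir::LogicalResult rewriteCopy(mlir::memref::CopyOp op,
                                mlir::PatternRewriter &rewriter) {
  mlir::Value src = getCasted1DMemRef(rewriter, op.getLoc(), op.source());
  mlir::Value dst = getCasted1DMemRef(rewriter, op.getLoc(), op.target());
  rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
      op, "memref_copy_one_rank", mlir::TypeRange{},
      mlir::ValueRange{src, dst});
  return mlir::success();
}
```
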

@@ -147,25 +147,25 @@ private:
  mlir::concretelang::V0FHEContext &fheContext;
};

// This rewrite pattern transforms any instance of `TFHE.glwe_from_table` by
// parametrizing the GLWE return type and padding the table if the precision
// has been changed.
//
// Example:
//
// ```mlir
// %lut = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi64>
// %0 = "TFHE.glwe_from_table" (%lut) : (tensor<4xi64>) ->
//   !TFHE.glwe<{_,_,_}{2}>
// ```
//
// becomes:
//
// ```mlir
// %lut = arith.constant dense<[0, 1, 2, 3, 0, 1, 2, 3]> : tensor<8xi64>
// %0 = "TFHE.glwe_from_table" (%lut) : (tensor<8xi64>) ->
//   !TFHE.glwe<{_,_,_}{3}>
// ```
/// This rewrite pattern transforms any instance of `TFHE.glwe_from_table` by
/// parametrizing the GLWE return type and padding the table if the precision
/// has been changed.
///
/// Example:
///
/// ```mlir
/// %lut = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi64>
/// %0 = "TFHE.glwe_from_table" (%lut) : (tensor<4xi64>) ->
///   !TFHE.glwe<{_,_,_}{2}>
/// ```
///
/// becomes:
///
/// ```mlir
/// %lut = arith.constant dense<[0, 1, 2, 3, 0, 1, 2, 3]> : tensor<8xi64>
/// %0 = "TFHE.glwe_from_table" (%lut) : (tensor<8xi64>) ->
///   !TFHE.glwe<{_,_,_}{3}>
/// ```
struct GLWEFromTablePattern
    : public mlir::OpRewritePattern<TFHE::GLWEFromTableOp> {
  GLWEFromTablePattern(mlir::MLIRContext *context,

@@ -57,8 +57,8 @@ struct AddRuntimeContextToFuncOpPattern
    return mlir::success();
  }

// Legal functions are those that are private or have a Concrete.context as
// their last argument.
/// Legal functions are those that are private or have a Concrete.context as
/// their last argument.
static bool isLegal(mlir::func::FuncOp funcOp) {
  if (!funcOp.isPublic()) {
    return true;

@@ -44,7 +44,7 @@ mlir::Type getDynamic1DMemrefWithUnknownOffset(mlir::RewriterBase &rewriter) {
      mlir::getAffineSymbolExpr(0, ctx)));
}

// Returns `memref.cast %0 : memref<AxT> to memref<?xT>` if %0 is a 1D memref
/// Returns `memref.cast %0 : memref<AxT> to memref<?xT>` if %0 is a 1D memref
mlir::Value getCasted1DMemRef(mlir::RewriterBase &rewriter, mlir::Location loc,
                              mlir::Value value) {
  mlir::Type valueType = value.getType();
@@ -115,7 +115,7 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
  return insertForwardDeclaration(op, rewriter, funcName, funcType);
}

// Returns the value of the context argument from the enclosing func
/// Returns the value of the context argument from the enclosing func
mlir::Value getContextArgument(mlir::Operation *op) {
  mlir::Block *block = op->getBlock();
  while (block != nullptr) {
@@ -233,25 +233,25 @@ struct BufferizableGlweFromTableOpInterface
    return BufferRelation::None;
  }

// Bufferize GlweFromTable
// ```
// "BConcrete.fill_glwe_table"(%glwe, %lut) {glweDimension=1,
//   polynomialSize=2048, outPrecision=3} :
//   (tensor<4096xi64>, tensor<32xi64>) -> ()
// ```
//
// to
//
// ```
// %glweDim = arith.constant 1 : i32
// %polySize = arith.constant 2048 : i32
// %outPrecision = arith.constant 3 : i32
// %glwe_ = memref.cast %glwe : memref<4096xi64> to memref<?xi64>
// %lut_ = memref.cast %lut : memref<32xi64> to memref<?xi64>
// call @expand_lut_in_trivial_glwe_ct(%glwe_, %polySize, %glweDim,
//   %outPrecision, %lut_) :
//   (memref<?xi64>, i32, i32, i32, memref<?xi64>) -> ()
// ```
/// Bufferize GlweFromTable
/// ```
/// "BConcrete.fill_glwe_table"(%glwe, %lut) {glweDimension=1,
///   polynomialSize=2048, outPrecision=3} :
///   (tensor<4096xi64>, tensor<32xi64>) -> ()
/// ```
///
/// to
///
/// ```
/// %glweDim = arith.constant 1 : i32
/// %polySize = arith.constant 2048 : i32
/// %outPrecision = arith.constant 3 : i32
/// %glwe_ = memref.cast %glwe : memref<4096xi64> to memref<?xi64>
/// %lut_ = memref.cast %lut : memref<32xi64> to memref<?xi64>
/// call @expand_lut_in_trivial_glwe_ct(%glwe_, %polySize, %glweDim,
///   %outPrecision, %lut_) :
///   (memref<?xi64>, i32, i32, i32, memref<?xi64>) -> ()
/// ```
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                        BufferizationState &state) const {


@@ -16,7 +16,7 @@ namespace concretelang {

namespace {

// Get the integer value that the cleartext was created from if it exists.
/// Get the integer value that the cleartext was created from if it exists.
llvm::Optional<mlir::Value>
getIntegerFromCleartextIfExists(mlir::Value cleartext) {
  assert(
@@ -32,7 +32,7 @@ getIntegerFromCleartextIfExists(mlir::Value cleartext) {
  return {};
}

// Get the constant integer that the cleartext was created from if it exists.
/// Get the constant integer that the cleartext was created from if it exists.
llvm::Optional<IntegerAttr>
getConstantIntFromCleartextIfExists(mlir::Value cleartext) {
  auto cleartextInt = getIntegerFromCleartextIfExists(cleartext);
@@ -49,9 +49,9 @@ getConstantIntFromCleartextIfExists(mlir::Value cleartext) {
  return {};
}

// Rewrite a `Concrete.mul_cleartext_lwe_ciphertext` operation as a
// `Concrete.zero` operation if it's being multiplied with a constant 0, or as a
// `Concrete.negate_lwe_ciphertext` if multiplied with a constant -1.
/// Rewrite a `Concrete.mul_cleartext_lwe_ciphertext` operation as a
/// `Concrete.zero` operation if it's being multiplied with a constant 0, or as
/// a `Concrete.negate_lwe_ciphertext` if multiplied with a constant -1.
class MulCleartextLweCiphertextOpPattern
    : public mlir::OpRewritePattern<
          mlir::concretelang::Concrete::MulCleartextLweCiphertextOp> {
@@ -85,8 +85,8 @@ public:
  }
};

// Optimization pass that should choose more efficient ways of performing crypto
// operations.
/// Optimization pass that should choose more efficient ways of performing
/// crypto operations.
class ConcreteOptimizationPass
    : public ConcreteOptimizationBase<ConcreteOptimizationPass> {
public:

@@ -32,8 +32,8 @@ namespace mlir {
namespace concretelang {
namespace {

// Returns `true` if the given value is a scalar or tensor argument of
// a function, for which a MANP of 1 can be assumed.
/// Returns `true` if the given value is a scalar or tensor argument of
/// a function, for which a MANP of 1 can be assumed.
static bool isEncryptedFunctionParameter(mlir::Value value) {
  if (!value.isa<mlir::BlockArgument>())
    return false;
@@ -54,9 +54,9 @@ static bool isEncryptedFunctionParameter(mlir::Value value) {
          .isa<mlir::concretelang::FHE::EncryptedIntegerType>()));
}

// Returns the bit width of `value` if `value` is an encrypted integer
// or the bit width of the elements if `value` is a tensor of
// encrypted integers.
/// Returns the bit width of `value` if `value` is an encrypted integer
/// or the bit width of the elements if `value` is a tensor of
/// encrypted integers.
static unsigned int getEintPrecision(mlir::Value value) {
  if (auto ty = value.getType()
                    .dyn_cast_or_null<
@@ -77,11 +77,11 @@ static unsigned int getEintPrecision(mlir::Value value) {
  return 0;
}

// The `MANPLatticeValue` represents the squared Minimal Arithmetic
// Noise Padding for an operation using the squared 2-norm of an
// equivalent dot operation. This can either be an actual value if the
// values for its predecessors have been calculated beforehand or an
// unknown value otherwise.
/// The `MANPLatticeValue` represents the squared Minimal Arithmetic
/// Noise Padding for an operation using the squared 2-norm of an
/// equivalent dot operation. This can either be an actual value if the
/// values for its predecessors have been calculated beforehand or an
/// unknown value otherwise.
struct MANPLatticeValue {
  MANPLatticeValue(llvm::Optional<llvm::APInt> manp = {}) : manp(manp) {}

@@ -109,10 +109,10 @@ struct MANPLatticeValue {
    return this->manp == rhs.manp;
  }

// Required by `mlir::LatticeElement::join()`, but should never be
// invoked, as `MANPAnalysis::visitOperation()` takes care of
// combining the squared Minimal Arithmetic Noise Padding of
// operands into the Minimal Arithmetic Noise Padding of the result.
/// Required by `mlir::LatticeElement::join()`, but should never be
/// invoked, as `MANPAnalysis::visitOperation()` takes care of
/// combining the squared Minimal Arithmetic Noise Padding of
/// operands into the Minimal Arithmetic Noise Padding of the result.
static MANPLatticeValue join(const MANPLatticeValue &lhs,
                             const MANPLatticeValue &rhs) {
  assert(false && "Minimal Arithmetic Noise Padding values can only be "
@@ -126,9 +126,9 @@ protected:
  llvm::Optional<llvm::APInt> manp;
};

// Checks if `lhs` is less than `rhs`, where both values are assumed
// to be positive. The bit width of the smaller `APInt` is extended
// before comparison via `APInt::ult`.
/// Checks if `lhs` is less than `rhs`, where both values are assumed
/// to be positive. The bit width of the smaller `APInt` is extended
/// before comparison via `APInt::ult`.
static bool APIntWidthExtendULT(const llvm::APInt &lhs,
                                const llvm::APInt &rhs) {
  if (lhs.getBitWidth() < rhs.getBitWidth())
@@ -139,9 +139,9 @@ static bool APIntWidthExtendULT(const llvm::APInt &lhs,
  return lhs.ult(rhs);
}

// Adds two `APInt` values, where both values are assumed to be
// positive. The bit width of the operands is extended in order to
// guarantee that the sum fits into the resulting `APInt`.
/// Adds two `APInt` values, where both values are assumed to be
/// positive. The bit width of the operands is extended in order to
/// guarantee that the sum fits into the resulting `APInt`.
static llvm::APInt APIntWidthExtendUAdd(const llvm::APInt &lhs,
                                        const llvm::APInt &rhs) {
  unsigned maxBits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
@@ -154,9 +154,9 @@ static llvm::APInt APIntWidthExtendUAdd(const llvm::APInt &lhs,
  return lhs.zext(targetWidth) + rhs.zext(targetWidth);
}

// Multiplies two `APInt` values, where both values are assumed to be
// positive. The bit width of the operands is extended in order to
// guarantee that the product fits into the resulting `APInt`.
/// Multiplies two `APInt` values, where both values are assumed to be
/// positive. The bit width of the operands is extended in order to
/// guarantee that the product fits into the resulting `APInt`.
static llvm::APInt APIntWidthExtendUMul(const llvm::APInt &lhs,
                                        const llvm::APInt &rhs) {
  // Make sure the required number of bits can be represented by the
@@ -170,9 +170,9 @@ static llvm::APInt APIntWidthExtendUMul(const llvm::APInt &lhs,
  return lhs.zext(targetWidth) * rhs.zext(targetWidth);
}
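A quick, standalone numeric illustration of why the widths are extended before the arithmetic (plain `llvm::APInt`, nothing project-specific):

```cpp
// Sketch: 200 + 100 and 200 * 100 both overflow 8 bits, but zero-extending
// the operands first preserves the exact unsigned results, which is the
// guarantee the helpers above provide.
llvm::APInt a(8, 200), b(8, 100);
llvm::APInt sum = a.zext(9) + b.zext(9);    // 300, exact
llvm::APInt prod = a.zext(16) * b.zext(16); // 20000, exact
```
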
||||
// Returns the maximum value beetwen `lhs` and `rhs`, where both values are
|
||||
// assumed to be positive. The bit width of the smaller `APInt` is extended
|
||||
// before comparison via `APInt::ult`.
|
||||
/// Returns the maximum value beetwen `lhs` and `rhs`, where both values are
|
||||
/// assumed to be positive. The bit width of the smaller `APInt` is extended
|
||||
/// before comparison via `APInt::ult`.
|
||||
static llvm::APInt APIntUMax(const llvm::APInt &lhs, const llvm::APInt &rhs) {
|
||||
if (APIntWidthExtendULT(lhs, rhs)) {
|
||||
return rhs;
|
||||
@@ -180,9 +180,9 @@ static llvm::APInt APIntUMax(const llvm::APInt &lhs, const llvm::APInt &rhs) {
|
||||
return lhs;
|
||||
}
|
||||
|
||||
// Calculates the square of `i`. The bit width `i` is extended in
|
||||
// order to guarantee that the product fits into the resulting
|
||||
// `APInt`.
|
||||
/// Calculates the square of `i`. The bit width `i` is extended in
|
||||
/// order to guarantee that the product fits into the resulting
|
||||
/// `APInt`.
|
||||
static llvm::APInt APIntWidthExtendUnsignedSq(const llvm::APInt &i) {
|
||||
// Make sure the required number of bits can be represented by the
|
||||
// `unsigned` argument of `zext`.
|
||||
@@ -194,7 +194,7 @@ static llvm::APInt APIntWidthExtendUnsignedSq(const llvm::APInt &i) {
|
||||
return ie * ie;
|
||||
}
|
||||
|
||||
// Calculates the square of the absolute value of `i`.
|
||||
/// Calculates the square of the absolute value of `i`.
|
||||
static llvm::APInt APIntWidthExtendSqForConstant(const llvm::APInt &i) {
|
||||
// Make sure the required number of bits can be represented by the
|
||||
// `unsigned` argument of `zext`.
|
||||
@@ -204,9 +204,9 @@ static llvm::APInt APIntWidthExtendSqForConstant(const llvm::APInt &i) {
|
||||
i.abs().getZExtValue() * i.abs().getZExtValue());
|
||||
}
|
||||
|
||||
// Calculates the square root of `i` and rounds it to the next highest
|
||||
// integer value (i.e., the square of the result is guaranteed to be
|
||||
// greater or equal to `i`).
|
||||
/// Calculates the square root of `i` and rounds it to the next highest
|
||||
/// integer value (i.e., the square of the result is guaranteed to be
|
||||
/// greater or equal to `i`).
|
||||
static llvm::APInt APIntCeilSqrt(const llvm::APInt &i) {
|
||||
llvm::APInt res = i.sqrt();
|
||||
llvm::APInt resSq = APIntWidthExtendUnsignedSq(res);
|
||||
@@ -217,17 +217,17 @@ static llvm::APInt APIntCeilSqrt(const llvm::APInt &i) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// Returns a string representation of `i` assuming that `i` is an
|
||||
// unsigned value.
|
||||
/// Returns a string representation of `i` assuming that `i` is an
|
||||
/// unsigned value.
|
||||
static std::string APIntToStringValUnsigned(const llvm::APInt &i) {
|
||||
llvm::SmallString<32> s;
|
||||
i.toStringUnsigned(s);
|
||||
return std::string(s.c_str());
|
||||
}
|
||||
|
||||
// Calculates the square of the 2-norm of a tensor initialized with a
|
||||
// dense matrix of constant, signless integers. Aborts if the value
|
||||
// type or initialization of of `cstOp` is incorrect.
|
||||
/// Calculates the square of the 2-norm of a tensor initialized with a
|
||||
/// dense matrix of constant, signless integers. Aborts if the value
|
||||
/// type or initialization of of `cstOp` is incorrect.
|
||||
static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp,
|
||||
llvm::APInt eNorm) {
|
||||
mlir::DenseIntElementsAttr denseVals =
|
||||
@@ -252,10 +252,10 @@ static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp,
|
||||
return accu;
|
||||
}
|
||||
|
||||
// Calculates the square of the 2-norm of a 1D tensor of signless
|
||||
// integers by conservatively assuming that the dynamic values are the
|
||||
// maximum for the integer width. Aborts if the tensor type `tTy` is
|
||||
// incorrect.
|
||||
/// Calculates the square of the 2-norm of a 1D tensor of signless
|
||||
/// integers by conservatively assuming that the dynamic values are the
|
||||
/// maximum for the integer width. Aborts if the tensor type `tTy` is
|
||||
/// incorrect.
|
||||
static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy,
|
||||
llvm::APInt eNorm) {
|
||||
assert(tTy && tTy.getElementType().isSignlessInteger() &&
|
||||
@@ -283,7 +283,7 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy,
|
||||
return APIntWidthExtendUMul(maxMulSqNorm, nEltsAP);
|
||||
}
|
||||
|
||||
// Returns the squared 2-norm of the maximum value of the dense values.
/// Returns the squared 2-norm of the maximum value of the dense values.
static llvm::APInt maxIntNorm2Sq(mlir::DenseIntElementsAttr denseVals) {
auto denseValsAP = denseVals.getValues<llvm::APInt>();

@@ -298,9 +298,9 @@ static llvm::APInt maxIntNorm2Sq(mlir::DenseIntElementsAttr denseVals) {
return APIntWidthExtendSqForConstant(maxCst);
}

// Returns the squared 2-norm for a dynamic integer by conservatively
// assuming that the integer's value is the maximum for the integer
// width.
/// Returns the squared 2-norm for a dynamic integer by conservatively
/// assuming that the integer's value is the maximum for the integer
/// width.
static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) {
assert(t.isSignlessInteger() && "Type must be a signless integer type");
assert(std::numeric_limits<unsigned>::max() - t.getIntOrFloatBitWidth() > 1);
@@ -309,8 +309,8 @@ static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) {
return APIntWidthExtendUnsignedSq(maxVal);
}

// Calculates the squared Minimal Arithmetic Noise Padding of an
// `FHELinalg.dot_eint_int` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHELinalg.dot_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::Dot op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -343,8 +343,8 @@ static llvm::APInt getSqMANP(
}
}

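To make the dot rule tangible, here is a hedged sketch using plain integers (the real analysis widens `llvm::APInt` values instead; the helper name is illustrative): with cleartext weights c_i, the squared 2-norm of sum_i c_i * x_i is bounded by (sum_i c_i^2) times the operands' squared norm.

#include <cstdint>
#include <vector>

// Illustrative only; overflow handling is omitted.
static uint64_t dotSqNormSketch(const std::vector<int64_t> &weights,
                                uint64_t operandSqNorm) {
  uint64_t sumSq = 0;
  for (int64_t c : weights) // c^2 is well-defined modulo 2^64 for negative c
    sumSq += static_cast<uint64_t>(c) * static_cast<uint64_t>(c);
  return sumSq * operandSqNorm;
}
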
// Calculates the squared Minimal Arithmetic Noise Padding of an
// `FHE.add_eint_int` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHE.add_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::AddEintIntOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -378,8 +378,8 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUAdd(sqNorm, eNorm);
}

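The additive rule boils down to "squared norms add"; a hedged equivalent of the width-extending addition used above (illustrative name):

#include <algorithm>
#include <llvm/ADT/APInt.h>

// Illustrative: extend both operands by one bit so the unsigned sum cannot
// wrap, then add.
static llvm::APInt uAddExtendSketch(const llvm::APInt &a,
                                    const llvm::APInt &b) {
  unsigned w = std::max(a.getBitWidth(), b.getBitWidth()) + 1;
  return a.zext(w) + b.zext(w);
}
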
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `FHE.add_eint` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.add_eint` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::AddEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -395,8 +395,8 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUAdd(a, b);
}

// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `FHE.sub_int_eint` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.sub_int_eint` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::SubIntEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -429,8 +429,8 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUAdd(sqNorm, eNorm);
}

// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `FHE.sub_eint_int` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.sub_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::SubEintIntOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -463,8 +463,8 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUAdd(sqNorm, eNorm);
}

// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `FHE.sub_eint` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.sub_eint` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::SubEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -480,8 +480,8 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUAdd(a, b);
}

// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `FHE.neg_eint` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.neg_eint` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::NegEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -496,8 +496,8 @@ static llvm::APInt getSqMANP(
return eNorm;
}

// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `FHE.mul_eint_int` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.mul_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::MulEintIntOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -531,8 +531,8 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUMul(sqNorm, eNorm);
}

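For the multiplicative rule, scaling a ciphertext by a cleartext constant c multiplies the squared 2-norm by c^2; a hedged equivalent of the width-extending product (illustrative name):

#include <llvm/ADT/APInt.h>

// Illustrative: the product of a p-bit and a q-bit unsigned value always
// fits in p + q bits, so extend both operands before multiplying.
static llvm::APInt uMulExtendSketch(const llvm::APInt &a,
                                    const llvm::APInt &b) {
  unsigned w = a.getBitWidth() + b.getBitWidth();
  return a.zext(w) * b.zext(w);
}
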
// Calculates the squared Minimal Arithmetic Noise Padding of an
// `FHELinalg.add_eint_int` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHELinalg.add_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::AddEintIntOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -587,8 +587,8 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUAdd(a, b);
}

// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `FHELinalg.sub_int_eint` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHELinalg.sub_int_eint` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::SubIntEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -678,8 +678,8 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUAdd(a, b);
}

// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `FHELinalg.neg_eint` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHELinalg.neg_eint` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::NegEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -694,8 +694,8 @@ static llvm::APInt getSqMANP(
return eNorm;
}

// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `FHE.mul_eint_int` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.mul_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::MulEintIntOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -804,8 +804,8 @@ static llvm::APInt calculateSqManpForMatMulWithDenseValues(
return maximumNorm;
}

// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `FHE.mul_eint_int` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.mul_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::MatMulEintIntOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -1508,7 +1508,7 @@ private:
} // namespace

namespace {
// For documentation see MANP.td
/// For documentation see MANP.td
struct MANPPass : public MANPBase<MANPPass> {
void runOnOperation() override {
mlir::func::FuncOp func = getOperation();
@@ -1524,16 +1524,16 @@ protected:
};
} // end anonymous namespace

// Create an instance of the Minimal Arithmetic Noise Padding analysis
// pass. If `debug` is true, for each operation, the pass emits a
// remark containing the squared Minimal Arithmetic Noise Padding of
// the equivalent dot operation.
/// Create an instance of the Minimal Arithmetic Noise Padding analysis
/// pass. If `debug` is true, for each operation, the pass emits a
/// remark containing the squared Minimal Arithmetic Noise Padding of
/// the equivalent dot operation.
std::unique_ptr<mlir::Pass> createMANPPass(bool debug) {
return std::make_unique<MANPPass>(debug);
}

namespace {
// For documentation see MANP.td
/// For documentation see MANP.td
struct MaxMANPPass : public MaxMANPBase<MaxMANPPass> {
void runOnOperation() override {
mlir::func::FuncOp func = getOperation();

@@ -163,7 +163,7 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::Operation &op,
return mlir::success();
}

// Avoid addition with constant 0
/// Avoid addition with constant 0
OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2);
auto toAdd = operands[1].dyn_cast_or_null<mlir::IntegerAttr>();
@@ -176,7 +176,7 @@ OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}

// Avoid subtraction with constant 0
/// Avoid subtraction with constant 0
OpFoldResult SubEintIntOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2);
auto toSub = operands[1].dyn_cast_or_null<mlir::IntegerAttr>();
@@ -189,7 +189,7 @@ OpFoldResult SubEintIntOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}

// Avoid multiplication with constant 1
/// Avoid multiplication with constant 1
OpFoldResult MulEintIntOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2);
auto toMul = operands[1].dyn_cast_or_null<mlir::IntegerAttr>();

@@ -1717,7 +1717,7 @@ mlir::LogicalResult TransposeOp::verify() {
return mlir::success();
}

// Avoid addition with constant tensor of 0s
/// Avoid addition with constant tensor of 0s
OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2);
auto toAdd = operands[1].dyn_cast_or_null<mlir::DenseIntElementsAttr>();
@@ -1731,7 +1731,7 @@ OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
return getOperand(0);
}

// Avoid subtraction with constant tensor of 0s
/// Avoid subtraction with constant tensor of 0s
OpFoldResult SubEintIntOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2);
auto toSub = operands[1].dyn_cast_or_null<mlir::DenseIntElementsAttr>();
@@ -1745,7 +1745,7 @@ OpFoldResult SubEintIntOp::fold(ArrayRef<Attribute> operands) {
return getOperand(0);
}

// Avoid multiplication with constant tensor of 1s
/// Avoid multiplication with constant tensor of 1s
OpFoldResult MulEintIntOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2);
auto toMul = operands[1].dyn_cast_or_null<mlir::DenseIntElementsAttr>();

@@ -22,9 +22,9 @@ namespace concretelang {

namespace {

// Creates a `tensor.extract_slice` operation that extracts a
// contiguous, 2-dimensional slice with a static size specified by
// `sizes` at the dynamic offset `offsets`.
/// Creates a `tensor.extract_slice` operation that extracts a
/// contiguous, 2-dimensional slice with a static size specified by
/// `sizes` at the dynamic offset `offsets`.
mlir::tensor::ExtractSliceOp
extractContiguous2DSlice(mlir::OpBuilder &builder, mlir::Location loc,
mlir::Value T, llvm::ArrayRef<int64_t> sizes,
@@ -43,18 +43,18 @@ extractContiguous2DSlice(mlir::OpBuilder &builder, mlir::Location loc,
builder.getI64IntegerAttr(1)});
}

// Creates a perfect loop nest of SCF for loops with the lower bounds
// `lbs`, the upper bounds `ubs` and the steps `steps` in the order
// from the outermost to the innermost loop. The values specified in
// `loopCarriedDeps` are loop-carried dependencies carried across all
// loops.
//
// The function `func` is called with a builder for the body of the
// innermost loop, the original location `loc`, a vector with all
// induction variables from the outermost to the innermost loop and the
// loop-carried dependencies.
//
// Returns the outermost loop.
/// Creates a perfect loop nest of SCF for loops with the lower bounds
/// `lbs`, the upper bounds `ubs` and the steps `steps` in the order
/// from the outermost to the innermost loop. The values specified in
/// `loopCarriedDeps` are loop-carried dependencies carried across all
/// loops.
///
/// The function `func` is called with a builder for the body of the
/// innermost loop, the original location `loc`, a vector with all
/// induction variables from the outermost to the innermost loop and the
/// loop-carried dependencies.
///
/// Returns the outermost loop.
mlir::scf::ForOp buildLoopNestWithLoopCarriedDependency(
mlir::OpBuilder builder, mlir::Location loc,
llvm::ArrayRef<mlir::Value> lbs, llvm::ArrayRef<mlir::Value> ubs,
@@ -104,28 +104,28 @@ mlir::scf::ForOp buildLoopNestWithLoopCarriedDependency(
return fops[0];
}

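Since that contract is dense, the following hedged two-level illustration shows the nesting scheme (simplified and illustrative, not the pass's implementation): each loop receives the carried values as iter_args and yields the results of the loop it contains.

#include <mlir/Dialect/SCF/SCF.h> // header path differs across MLIR versions

// Illustrative 2-deep nest with one loop-carried value. The innermost body
// (built where indicated) must end in its own scf.yield.
static mlir::scf::ForOp buildNestSketch(mlir::OpBuilder &b, mlir::Location loc,
                                        mlir::Value lb, mlir::Value ub,
                                        mlir::Value step, mlir::Value dep) {
  auto outer = b.create<mlir::scf::ForOp>(loc, lb, ub, step,
                                          mlir::ValueRange{dep});
  b.setInsertionPointToStart(outer.getBody());
  auto inner = b.create<mlir::scf::ForOp>(loc, lb, ub, step,
                                          outer.getRegionIterArgs());
  b.create<mlir::scf::YieldOp>(loc, inner.getResults());
  b.setInsertionPointToStart(inner.getBody());
  // ... innermost body goes here, using inner.getInductionVar() and
  // inner.getRegionIterArgs(), terminated by an scf.yield ...
  return outer;
}
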
// Marker to avoid infinite recursion of the rewriting pattern
/// Marker to avoid infinite recursion of the rewriting pattern
static const mlir::StringLiteral kTransformMarker =
"__internal_fhe_linalg_tiling_marker__";

// Rewrite an `FHELinalg.matmul_eint_int` operation as an equivalent
// sequence of operations consisting of a perfect loop nest of SCF for
// loops with a `FHELinalg.matmul_eint_int` operation that performs
// a matrix multiplication on a single tile.
//
// The terminology is as follows:
//
// - A: The input matrix of encrypted integers of size `NxM`
// - B: The input matrix of plaintext integers of size `MxK`
// - C: The output matrix of encrypted integers of size `NxK`
//
// At each iteration of the innermost loop, the generated
// `FHELinalg.matmul_eint_int` operation performs a multiplication
// of a matrix tile of size `TxU` and a matrix of size `UxV`,
// producing a tile of size `TxV`.
//
// Partial tiles are currently not supported, i.e., `N` must be a
// multiple of `T`, `M` a multiple of `U` and `K` a multiple of `V`.
/// Rewrite an `FHELinalg.matmul_eint_int` operation as an equivalent
/// sequence of operations consisting of a perfect loop nest of SCF for
/// loops with a `FHELinalg.matmul_eint_int` operation that performs
/// a matrix multiplication on a single tile.
///
/// The terminology is as follows:
///
/// - A: The input matrix of encrypted integers of size `NxM`
/// - B: The input matrix of plaintext integers of size `MxK`
/// - C: The output matrix of encrypted integers of size `NxK`
///
/// At each iteration of the innermost loop, the generated
/// `FHELinalg.matmul_eint_int` operation performs a multiplication
/// of a matrix tile of size `TxU` and a matrix of size `UxV`,
/// producing a tile of size `TxV`.
///
/// Partial tiles are currently not supported, i.e., `N` must be a
/// multiple of `T`, `M` a multiple of `U` and `K` a multiple of `V`.
class MatMulTilingPattern
: public mlir::OpRewritePattern<
mlir::concretelang::FHELinalg::MatMulEintIntOp> {
@@ -312,8 +312,8 @@ public:
}
};

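A hedged restatement of the divisibility precondition from the comment above (illustrative helper, not part of the pass):

#include <cstdint>

// Illustrative: N x M times M x K with T x U and U x V tiles (T x V output
// tiles) requires each dimension to be an exact multiple of its tile size.
static bool tileSizesDivideSketch(int64_t N, int64_t M, int64_t K,
                                  int64_t T, int64_t U, int64_t V) {
  return N % T == 0 && M % U == 0 && K % V == 0;
}
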
// Performs the actual tiling of `FHELinalg.matmul_eint_int`
// operations that have been marked with a "tile-sizes" attribute.
/// Performs the actual tiling of `FHELinalg.matmul_eint_int`
/// operations that have been marked with a "tile-sizes" attribute.
class FHELinalgTilingPass : public FHELinalgTilingBase<FHELinalgTilingPass> {
public:
void runOnOperation() override {
@@ -332,8 +332,8 @@ public:
}
};

// Marks all `FHELinalg.matmul_eint_int` operations with a
// "tile-sizes" attribute containing the specified tile sizes.
/// Marks all `FHELinalg.matmul_eint_int` operations with a
/// "tile-sizes" attribute containing the specified tile sizes.
class FHELinalgTilingMarkerPass
: public FHELinalgTilingMarkerBase<FHELinalgTilingMarkerPass> {
public:

@@ -90,7 +90,7 @@ void populateRTBufferizePatterns(
}

namespace {
// For documentation see Autopar.td
/// For documentation see Autopar.td
struct BufferizeDataflowTaskOpsPass
: public BufferizeDataflowTaskOpsBase<BufferizeDataflowTaskOpsPass> {


@@ -52,8 +52,8 @@ static bool isCandidateForTask(Operation *op) {
FHELinalg::ConcatOp, FHELinalg::FhelinalgConv2DNchwFchwOp>(op);
}

// Identify operations that are beneficial to sink into tasks. These
// operations must not have side-effects and not be `isCandidateForTask`
/// Identify operations that are beneficial to sink into tasks. These
/// operations must not have side-effects and not be `isCandidateForTask`
static bool isSinkingBeneficiary(Operation *op) {
return isa<FHE::ZeroEintOp, arith::ConstantOp, memref::DimOp, arith::SelectOp,
mlir::arith::CmpIOp>(op);
@@ -126,7 +126,7 @@ LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
return success();
}

// For documentation see Autopar.td
/// For documentation see Autopar.td
struct BuildDataflowTaskGraphPass
: public BuildDataflowTaskGraphBase<BuildDataflowTaskGraphPass> {

@@ -194,7 +194,7 @@ std::unique_ptr<mlir::Pass> createBuildDataflowTaskGraphPass(bool debug) {
}

namespace {
// Marker to avoid infinite recursion of the rewriting pattern
/// Marker to avoid infinite recursion of the rewriting pattern
static const mlir::StringLiteral kTransformMarker =
"_internal_RT_FixDataflowTaskOpInputsPattern_marker__";

@@ -232,7 +232,7 @@ public:
} // namespace

namespace {
// For documentation see Autopar.td
/// For documentation see Autopar.td
struct FixupDataflowTaskOpsPass
: public FixupDataflowTaskOpsBase<FixupDataflowTaskOpsPass> {


@@ -271,7 +271,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
DFTOp.erase();
}

// For documentation see Autopar.td
/// For documentation see Autopar.td
struct LowerDataflowTasksPass
: public LowerDataflowTasksBase<LowerDataflowTasksPass> {


@@ -79,8 +79,8 @@ LLVM::LLVMFuncOp getOrInsertFuncOpDecl(mlir::Operation *op,
return funcOp;
}

// This function is only needed for debug purposes to inspect values
// in the generated code - it is therefore not generally in use.
/// This function is only needed for debug purposes to inspect values
/// in the generated code - it is therefore not generally in use.
LLVM_ATTRIBUTE_UNUSED void
insertPrintDebugCall(ConversionPatternRewriter &rewriter, mlir::Operation *op,
Value val) {

@@ -51,9 +51,9 @@ mlir::LogicalResult _verifyGLWEIntegerOperator(mlir::OpState &op,
return mlir::success();
}

// verifyGLWEIntegerOperator verifies the parameters of operators that have the
// following signature (!TFHE.glwe<{dim,poly,bits}{p}>, ip+1) ->
// (!TFHE.glwe<{dim,poly,bits}{p}>))
/// verifyGLWEIntegerOperator verifies the parameters of operators that have the
/// following signature (!TFHE.glwe<{dim,poly,bits}{p}>, ip+1) ->
/// (!TFHE.glwe<{dim,poly,bits}{p}>))
template <class Operator>
mlir::LogicalResult verifyGLWEIntegerOperator(Operator &op) {
auto a = ((mlir::Type)(op.a().getType())).cast<GLWECipherTextType>();
@@ -64,9 +64,9 @@ mlir::LogicalResult verifyGLWEIntegerOperator(Operator &op) {
return _verifyGLWEIntegerOperator(op, a, b, result);
}

// verifyIntegerGLWEOperator verifies the parameters of operators that have the
// following signature (ip+1, !TFHE.glwe<{dim,poly,bits}{p}>) ->
// (!TFHE.glwe<{dim,poly,bits}{p}>))
/// verifyIntegerGLWEOperator verifies the parameters of operators that have the
/// following signature (ip+1, !TFHE.glwe<{dim,poly,bits}{p}>) ->
/// (!TFHE.glwe<{dim,poly,bits}{p}>))
template <class Operator>
mlir::LogicalResult verifyIntegerGLWEOperator(Operator &op) {
auto a = ((mlir::Type)(op.a().getType())).cast<IntegerType>();
@@ -77,10 +77,10 @@ mlir::LogicalResult verifyIntegerGLWEOperator(Operator &op) {
return _verifyGLWEIntegerOperator(op, b, a, result);
}

// verifyBinaryGLWEOperator verifies the parameters of operators that have the
// following signature (!TFHE.glwe<{dim,poly,bits}{p}>,
// !TFHE.glwe<{dim,poly,bits}{p}>) ->
// (!TFHE.glwe<{dim,poly,bits}{p}>))
/// verifyBinaryGLWEOperator verifies the parameters of operators that have the
/// following signature (!TFHE.glwe<{dim,poly,bits}{p}>,
/// !TFHE.glwe<{dim,poly,bits}{p}>) ->
/// (!TFHE.glwe<{dim,poly,bits}{p}>))
template <class Operator>
mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) {
auto a = ((mlir::Type)(op.a().getType())).cast<GLWECipherTextType>();
@@ -111,9 +111,9 @@ mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) {
return mlir::success();
}

// verifyUnaryGLWEOperator verifies the parameters of operators that have the
// following signature (!TFHE.glwe<{dim,poly,bits}{p}>) ->
// (!TFHE.glwe<{dim,poly,bits}{p}>))
/// verifyUnaryGLWEOperator verifies the parameters of operators that have the
/// following signature (!TFHE.glwe<{dim,poly,bits}{p}>) ->
/// (!TFHE.glwe<{dim,poly,bits}{p}>))
template <class Operator>
mlir::LogicalResult verifyUnaryGLWEOperator(Operator &op) {
auto a = ((mlir::Type)(op.a().getType())).cast<GLWECipherTextType>();

@@ -3,14 +3,11 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

/**
This file implements the dataflow runtime. It encapsulates all of
the underlying communication, parallelism, etc. and only exposes a
simplified interface for code generation in runtime_api.h

This hides the details of implementation, including of the HPX
framework currently used, from the code generation side.
*/
/// This file implements the dataflow runtime. It encapsulates all of
/// the underlying communication, parallelism, etc. and only exposes a
/// simplified interface for code generation in runtime_api.h
/// This hides the details of implementation, including of the HPX
/// framework currently used, from the code generation side.

#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED

@@ -55,11 +52,11 @@ void _dfr_deallocate_future(void *in) {
delete (static_cast<hpx::shared_future<void *> *>(in));
}

// Runtime generic async_task. Each of the first NUM_PARAMS pairs of
// arguments in the variadic list corresponds to a void* pointer on a
// hpx::future<void*> and the size of data within the future. After
// that come NUM_OUTPUTS pairs of hpx::future<void*>* and size_t for
// the returns.
/// Runtime generic async_task. Each of the first NUM_PARAMS pairs of
/// arguments in the variadic list corresponds to a void* pointer on a
/// hpx::future<void*> and the size of data within the future. After
/// that come NUM_OUTPUTS pairs of hpx::future<void*>* and size_t for
/// the returns.
void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
...) {
std::vector<void *> params;
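To illustrate the variadic convention documented above, a hedged hypothetical call with two inputs and one output; `work_fn`, the future handles and the sizes are placeholders, not real runtime symbols:

// Hypothetical: inputs come first as (void* on hpx::shared_future<void*>,
// size) pairs, followed by (hpx::future<void*>*, size_t) pairs for outputs.
_dfr_create_async_task(work_fn, /*num_params=*/2, /*num_outputs=*/1,
                       in0, in0_size,
                       in1, in1_size,
                       &out, out_size);
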
@@ -776,7 +773,7 @@ void _dfr_debug_print_task(const char *name, int inputs, int outputs) {
// clang-format on
}

// Generic utility function for printing debug info
/// Generic utility function for printing debug info
void _dfr_print_debug(size_t val) {
hpx::cout << "_dfr_print_debug : " << val << "\n" << std::flush;
}

@@ -17,7 +17,7 @@ get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) {
return context->evaluationKeys.getBsk();
}

// Instantiate one engine per thread on demand
/// Instantiate one engine per thread on demand
Engine *get_engine(mlir::concretelang::RuntimeContext *context) {
pthread_t threadId = pthread_self();
std::lock_guard<std::mutex> guard(context->engines_map_guard);

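The per-thread, on-demand pattern described above looks roughly like this hedged sketch (type and member names are illustrative, not the runtime's actual fields):

#include <map>
#include <mutex>
#include <pthread.h>

// Illustrative: look up the calling thread's engine under a lock and create
// it on first use, so each thread ends up with exactly one engine.
Engine *getEngineSketch(std::mutex &guard,
                        std::map<pthread_t, Engine *> &engines) {
  pthread_t self = pthread_self();
  std::lock_guard<std::mutex> lock(guard);
  auto it = engines.find(self);
  if (it == engines.end())
    it = engines.emplace(self, new Engine()).first;
  return it->second;
}
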
@@ -15,13 +15,13 @@
namespace concretelang {
namespace serverlib {

// Helper class template that yields an unsigned integer type given a
// size in bytes
/// Helper class template that yields an unsigned integer type given a
/// size in bytes
template <std::size_t size> struct int_type_of_size {};
template <> struct int_type_of_size<4> { typedef uint32_t type; };
template <> struct int_type_of_size<8> { typedef uint64_t type; };

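A hedged usage example of the helper above (the asserts assume the common 8-byte pointer layout; only 4- and 8-byte sizes are specialized):

static_assert(sizeof(int_type_of_size<4>::type) == 4, "expected uint32_t");
static_assert(sizeof(int_type_of_size<8>::type) == 8, "expected uint64_t");
using ptr_word_t = int_type_of_size<sizeof(void *)>::type; // uint64_t on LP64
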
// Converts one function pointer into another
/// Converts one function pointer into another
// TODO: Not sure this is valid in all implementations / on all
// architectures
template <typename FnDstT, typename FnSrcT> FnDstT convert_fnptr(FnSrcT src) {

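One common way such a conversion can be implemented is a memcpy round-trip, sketched here under an illustrative name; as the TODO notes, validity across function-pointer types remains platform-dependent:

#include <cstring>

// Illustrative: copy the representation byte-for-byte instead of casting,
// which at least avoids strict-aliasing pitfalls.
template <typename FnDstT, typename FnSrcT>
FnDstT convert_fnptr_sketch(FnSrcT src) {
  static_assert(sizeof(FnDstT) == sizeof(FnSrcT),
                "pointer representations must match");
  FnDstT dst;
  std::memcpy(&dst, &src, sizeof(dst));
  return dst;
}
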
@@ -19,13 +19,13 @@ print(
namespace concretelang {
namespace serverlib {

// Helper class template that yields an unsigned integer type given a
// size in bytes
/// Helper class template that yields an unsigned integer type given a
/// size in bytes
template <std::size_t size> struct int_type_of_size {};
template <> struct int_type_of_size<4> { typedef uint32_t type; };
template <> struct int_type_of_size<8> { typedef uint64_t type; };

// Converts one function pointer into another
/// Converts one function pointer into another
// TODO: Not sure this is valid in all implementations / on all
// architectures
template <typename FnDstT, typename FnSrcT> FnDstT convert_fnptr(FnSrcT src) {

@@ -39,8 +39,8 @@
namespace mlir {
namespace concretelang {

// Creates a new compilation context that can be shared across
// compilation engines and results
/// Creates a new compilation context that can be shared across
/// compilation engines and results
std::shared_ptr<CompilationContext> CompilationContext::createShared() {
return std::make_shared<CompilationContext>();
}
@@ -53,8 +53,8 @@ CompilationContext::~CompilationContext() {
delete this->llvmContext;
}

// Returns the MLIR context for a compilation context. Creates and
// initializes a new MLIR context if necessary.
/// Returns the MLIR context for a compilation context. Creates and
/// initializes a new MLIR context if necessary.
mlir::MLIRContext *CompilationContext::getMLIRContext() {
if (this->mlirContext == nullptr) {
mlir::DialectRegistry registry;
@@ -79,8 +79,8 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() {
return this->mlirContext;
}

// Returns the LLVM context for a compilation context. Creates and
// initializes a new LLVM context if necessary.
/// Returns the LLVM context for a compilation context. Creates and
/// initializes a new LLVM context if necessary.
llvm::LLVMContext *CompilationContext::getLLVMContext() {
if (this->llvmContext == nullptr)
this->llvmContext = new llvm::LLVMContext();
@@ -88,9 +88,9 @@ llvm::LLVMContext *CompilationContext::getLLVMContext() {
return this->llvmContext;
}

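A hedged usage sketch of the lazily created contexts described above:

// Both contexts are created on first access and shared afterwards.
auto ctx = mlir::concretelang::CompilationContext::createShared();
mlir::MLIRContext *mlirCtx = ctx->getMLIRContext(); // built on first call
llvm::LLVMContext *llvmCtx = ctx->getLLVMContext(); // likewise lazy
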
// Sets the FHE constraints for the compilation. Overrides any
// automatically detected configuration and prevents the autodetection
// pass from running.
/// Sets the FHE constraints for the compilation. Overrides any
/// automatically detected configuration and prevents the autodetection
/// pass from running.
void CompilerEngine::setFHEConstraints(
const mlir::concretelang::V0FHEConstraint &c) {
this->overrideMaxEintPrecision = c.p;
@@ -112,7 +112,7 @@ void CompilerEngine::setEnablePass(
this->enablePass = enablePass;
}

// Returns the overridden V0FHEConstraint or tries to compute them from FHE
/// Returns the overridden V0FHEConstraint or tries to compute them from FHE
llvm::Expected<llvm::Optional<mlir::concretelang::V0FHEConstraint>>
CompilerEngine::getV0FHEConstraint(CompilationResult &res) {
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
@@ -136,7 +136,7 @@ CompilerEngine::getV0FHEConstraint(CompilationResult &res) {
return fheConstraintsOrErr.get();
}

// Sets the fheContext field if the v0Constraint can be computed
/// Sets the fheContext field if the v0Constraint can be computed
llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
auto fheConstraintOrErr = getV0FHEConstraint(res);
if (auto err = fheConstraintOrErr.takeError())
@@ -165,10 +165,10 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
}

using OptionalLib = llvm::Optional<std::shared_ptr<CompilerEngine::Library>>;
// Compile the sources managed by the source manager `sm` to the
// target dialect `target`. If successful, the result can be retrieved
// using `getModule()` and `getLLVMModule()`, respectively depending
// on the target dialect.
/// Compile the sources managed by the source manager `sm` to the
/// target dialect `target`. If successful, the result can be retrieved
/// using `getModule()` and `getLLVMModule()`, respectively depending
/// on the target dialect.
llvm::Expected<CompilerEngine::CompilationResult>
CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
std::unique_ptr<mlir::SourceMgrDiagnosticVerifierHandler> smHandler;
@@ -371,19 +371,19 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
return std::move(res);
}

// Compile the source `s` to the target dialect `target`. If successful, the
// result can be retrieved using `getModule()` and `getLLVMModule()`,
// respectively depending on the target dialect.
/// Compile the source `s` to the target dialect `target`. If successful, the
/// result can be retrieved using `getModule()` and `getLLVMModule()`,
/// respectively depending on the target dialect.
llvm::Expected<CompilerEngine::CompilationResult>
CompilerEngine::compile(llvm::StringRef s, Target target, OptionalLib lib) {
std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
return this->compile(std::move(mb), target, lib);
}

// Compile the source contained in `buffer` to the target dialect
// `target`. If successful, the result can be retrieved using
// `getModule()` and `getLLVMModule()`, respectively depending on the
// target dialect.
/// Compile the source contained in `buffer` to the target dialect
/// `target`. If successful, the result can be retrieved using
/// `getModule()` and `getLLVMModule()`, respectively depending on the
/// target dialect.
llvm::Expected<CompilerEngine::CompilationResult>
CompilerEngine::compile(std::unique_ptr<llvm::MemoryBuffer> buffer,
Target target, OptionalLib lib) {
@@ -442,7 +442,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, std::string outputDirPath,
return *outputLib.get();
}

/** Returns the path of the shared library */
/// Returns the path of the shared library
std::string
CompilerEngine::Library::getSharedLibraryPath(std::string outputDirPath) {
llvm::SmallString<0> sharedLibraryPath(outputDirPath);
@@ -450,7 +450,7 @@ CompilerEngine::Library::getSharedLibraryPath(std::string outputDirPath) {
return sharedLibraryPath.str().str();
}

/** Returns the path of the static library */
/// Returns the path of the static library
std::string
CompilerEngine::Library::getStaticLibraryPath(std::string outputDirPath) {
llvm::SmallString<0> staticLibraryPath(outputDirPath);
@@ -458,7 +458,7 @@ CompilerEngine::Library::getStaticLibraryPath(std::string outputDirPath) {
return staticLibraryPath.str().str();
}

/** Returns the path of the client parameters file */
/// Returns the path of the client parameters file
std::string
CompilerEngine::Library::getClientParametersPath(std::string outputDirPath) {
llvm::SmallString<0> clientParametersPath(outputDirPath);

@@ -30,7 +30,7 @@ const auto securityLevel = SECURITY_LEVEL_128;
const auto keyFormat = KEY_FORMAT_BINARY;
const auto v0Curve = getV0Curves(securityLevel, keyFormat);

// For the v0 the secretKeyID and precision are the same for all gates.
/// For the v0 the secretKeyID and precision are the same for all gates.
llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
Precision precision,
Variance variance,

@@ -11,17 +11,17 @@ static bool verbose = false;
static StreamWrap<llvm::raw_ostream> errWrap(&llvm::errs());
static StreamWrap<llvm::raw_ostream> nullWrap(&llvm::nulls());

// Returns a stream for logging errors
/// Returns a stream for logging errors
StreamWrap<llvm::raw_ostream> &log_error(void) { return errWrap; }

// Returns a stream that either shows or discards messages depending
// on the setup through `setupLogging`.
/// Returns a stream that either shows or discards messages depending
/// on the setup through `setupLogging`.
StreamWrap<llvm::raw_ostream> &log_verbose(void) {
return (verbose) ? errWrap : nullWrap;
}

// Sets up logging. If `verbose` is false, messages passed to
// `log_verbose` will be discarded.
/// Sets up logging. If `verbose` is false, messages passed to
/// `log_verbose` will be discarded.
void setupLogging(bool verbose) { ::mlir::concretelang::verbose = verbose; }
bool isVerbose() { return verbose; }
} // namespace concretelang

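A hedged usage sketch of the logging helpers above:

// With verbose logging enabled, log_verbose() writes to stderr; with it
// disabled, the stream discards its input. log_error() always prints.
mlir::concretelang::setupLogging(/*verbose=*/true);
mlir::concretelang::log_verbose() << "running MANP analysis\n";
mlir::concretelang::log_error() << "unexpected operand type\n";
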
@@ -271,32 +271,32 @@ cmdlineCompilationOptions() {
return options;
}

// Process a single source buffer
//
// The parameter `action` specifies how the buffer should be processed
// and thus defines the output.
//
// If the specified action involves JIT compilation, `funcName`
// designates the function to JIT compile. This function is invoked
// using the parameters given in `jitArgs`.
//
// The parameter `parametrizeTFHE` defines whether the
// parametrization pass for TFHE is executed. If the `action` does
// not involve any MidlFHE manipulation, this parameter does not have
// any effect.
//
// The parameters `overrideMaxEintPrecision` and `overrideMaxMANP`, if
// set, override the values for the maximum required precision of
// encrypted integers and the maximum value for the Minimum Arithmetic
// Noise Padding otherwise determined automatically.
//
// If `verifyDiagnostics` is `true`, the procedure only checks if the
// diagnostic messages provided in the source buffer using
// `expected-error` are produced. If `verifyDiagnostics` is `false`,
// the procedure checks if the parsed module is valid and if all
// requested transformations succeeded.
//
// Compilation output is written to the stream specified by `os`.
/// Process a single source buffer
///
/// The parameter `action` specifies how the buffer should be processed
/// and thus defines the output.
///
/// If the specified action involves JIT compilation, `funcName`
/// designates the function to JIT compile. This function is invoked
/// using the parameters given in `jitArgs`.
///
/// The parameter `parametrizeTFHE` defines whether the
/// parametrization pass for TFHE is executed. If the `action` does
/// not involve any MidlFHE manipulation, this parameter does not have
/// any effect.
///
/// The parameters `overrideMaxEintPrecision` and `overrideMaxMANP`, if
/// set, override the values for the maximum required precision of
/// encrypted integers and the maximum value for the Minimum Arithmetic
/// Noise Padding otherwise determined automatically.
///
/// If `verifyDiagnostics` is `true`, the procedure only checks if the
/// diagnostic messages provided in the source buffer using
/// `expected-error` are produced. If `verifyDiagnostics` is `false`,
/// the procedure checks if the parsed module is valid and if all
/// requested transformations succeeded.
///
/// Compilation output is written to the stream specified by `os`.
mlir::LogicalResult processInputBuffer(
std::unique_ptr<llvm::MemoryBuffer> buffer, std::string sourceFileName,
mlir::concretelang::CompilationOptions &options, enum Action action,