mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
Rebase onto llvm-project 465ee9bfb26d with local changes
This commit rebases the compiler onto commit 465ee9bfb26d from
llvm-project with locally maintained patches on top, i.e.:
* 5d8669d669ee: Fix the element alignment (size) for memrefCopy
* 4239163ea337: fix: Do not fold the memref.subview if the offset are
!= 0 and strides != 1
* 72c5decfcc21: remove github stuff from llvm
* 8d0ce8f9eca1: Support arbitrary element types in named operations
via attributes
* 94f64805c38c: Copy attributes of scf.for on bufferization and make
it an allocation hoisting barrier
Main upstream changes from llvm-project that required modification of
concretecompiler:
* Switch to C++17
* Various changes in the interfaces for linalg named operations
* Transition from `llvm::Optional` to `std::optional`
* Use of enums instead of string values for iterator types in linalg
* Changed default naming convention of getter methods in
ODS-generated operation classes from `some_value()` to
`getSomeValue()`
* Renaming of Arithmetic dialect to Arith
* Refactoring of side effect interfaces (i.e., renaming from
`NoSideEffect` to `Pure`)
* Re-design of the data flow analysis framework
* Refactoring of build targets for Python bindings
* Refactoring of array attributes with integer values
* Renaming of `linalg.init_tensor` to `tensor.empty`
* Emission of `linalg.map` operations in bufferization of the Tensor
dialect requiring another linalg conversion pass and registration
of the bufferization op interfaces for linalg operations
* Refactoring of the one-shot bufferizer
* Necessity to run the expand-strided-metadata, affine-to-std and
finalize-memref-to-llvm passes before converson to the LLVM
dialect
* Renaming of `BlockAndValueMapping` to `IRMapping`
* Changes in the build function of `LLVM::CallOp`
* Refactoring of the construction of `llvm::ArrayRef` and
`llvm::MutableArrayRef` (direct invocation of constructor instead
of builder functions for some cases)
* New naming conventions for generated SSA values requiring rewrite
of some check tests
* Refactoring of `mlir::LLVM::lookupOrCreateMallocFn()`
* Interface changes in generated type parsers
* New dependencies for to mlir_float16_utils and
MLIRSparseTensorRuntime for the runtime
* Overhaul of MLIR-c deleting `mlir-c/Registration.h`
* Deletion of library MLIRLinalgToSPIRV
* Deletion of library MLIRLinalgAnalysis
* Deletion of library MLIRMemRefUtils
* Deletion of library MLIRQuantTransforms
* Deletion of library MLIRVectorToROCDL
This commit is contained in:
committed by
Quentin Bourgerie
parent
8ebfccd9a7
commit
c8c969773e
@@ -4,7 +4,7 @@ project(concretecompiler LANGUAGES C CXX)
|
||||
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 14)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
# Needed on linux with clang 15 and on MacOS because cxx emits dollars in the optimizer C++ API
|
||||
|
||||
@@ -416,7 +416,7 @@ rust-format:
|
||||
|
||||
# libraries we want to have in the installation that aren't already a deps of other targets
|
||||
install-deps:
|
||||
cmake --build $(BUILD_DIR) --target MLIRCAPIRegistration
|
||||
cmake --build $(BUILD_DIR) --target MLIRCAPIRegisterEverything
|
||||
|
||||
ifeq ($(OS), darwin)
|
||||
# rsync should normally come pre-installed on macOS
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#define CONCRETELANG_C_DIALECT_FHE_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#define CONCRETELANG_C_DIALECT_FHELINALG_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
||||
@@ -104,7 +104,7 @@ library_get_client_parameters_path(LibrarySupport_Py support);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
|
||||
key_set(concretelang::clientlib::ClientParameters clientParameters,
|
||||
llvm::Optional<concretelang::clientlib::KeySetCache> cache);
|
||||
std::optional<concretelang::clientlib::KeySetCache> cache);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
|
||||
|
||||
@@ -184,11 +184,11 @@ static inline bool operator==(const ChunkInfo &lhs, const ChunkInfo &rhs) {
|
||||
}
|
||||
|
||||
struct CircuitGate {
|
||||
llvm::Optional<EncryptionGate> encryption;
|
||||
std::optional<EncryptionGate> encryption;
|
||||
CircuitGateShape shape;
|
||||
llvm::Optional<ChunkInfo> chunkInfo;
|
||||
std::optional<ChunkInfo> chunkInfo;
|
||||
|
||||
bool isEncrypted() { return encryption.hasValue(); }
|
||||
bool isEncrypted() { return encryption.has_value(); }
|
||||
|
||||
/// byteSize returns the size in bytes for this gate.
|
||||
size_t byteSize(std::vector<LweSecretKeyParam> secretKeys) {
|
||||
@@ -240,7 +240,7 @@ struct ClientParameters {
|
||||
|
||||
outcome::checked<LweSecretKeyParam, StringError>
|
||||
lweSecretKeyParam(CircuitGate gate) {
|
||||
if (!gate.encryption.hasValue()) {
|
||||
if (!gate.encryption.has_value()) {
|
||||
return StringError("gate is not encrypted");
|
||||
}
|
||||
assert(gate.encryption->secretKeyID < secretKeys.size());
|
||||
@@ -250,7 +250,7 @@ struct ClientParameters {
|
||||
|
||||
/// bufferSize returns the size of the whole buffer of a gate.
|
||||
int64_t bufferSize(CircuitGate gate) {
|
||||
if (!gate.encryption.hasValue()) {
|
||||
if (!gate.encryption.has_value()) {
|
||||
// Value is not encrypted just returns the tensor size
|
||||
return gate.shape.size;
|
||||
}
|
||||
@@ -261,7 +261,7 @@ struct ClientParameters {
|
||||
|
||||
/// lweBufferSize returns the size of one ciphertext of a gate.
|
||||
int64_t lweBufferSize(CircuitGate gate) {
|
||||
assert(gate.encryption.hasValue());
|
||||
assert(gate.encryption.has_value());
|
||||
auto nbBlocks = gate.encryption->encoding.crt.size();
|
||||
nbBlocks = nbBlocks == 0 ? 1 : nbBlocks;
|
||||
|
||||
@@ -273,7 +273,7 @@ struct ClientParameters {
|
||||
/// bufferShape returns the shape of the tensor for the given gate. It returns
|
||||
/// the shape used at low-level, i.e. contains the dimensions for ciphertexts.
|
||||
std::vector<int64_t> bufferShape(CircuitGate gate) {
|
||||
if (!gate.encryption.hasValue()) {
|
||||
if (!gate.encryption.has_value()) {
|
||||
// Value is not encrypted just returns the tensor shape
|
||||
return gate.shape.dimensions;
|
||||
}
|
||||
|
||||
@@ -159,7 +159,7 @@ public:
|
||||
// Set sizes
|
||||
std::vector<int64_t> sizes = keySet.clientParameters().bufferShape(input);
|
||||
|
||||
if (input.encryption.hasValue()) {
|
||||
if (input.encryption.has_value()) {
|
||||
TensorData td(sizes, EncryptedScalarElementType,
|
||||
EncryptedScalarElementWidth);
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ private:
|
||||
|
||||
///////////////////////////////////////////////
|
||||
// Convenient positional mapping between positional gate en secret key
|
||||
typedef std::vector<std::pair<CircuitGate, llvm::Optional<LweSecretKey>>>
|
||||
typedef std::vector<std::pair<CircuitGate, std::optional<LweSecretKey>>>
|
||||
SecretKeyGateMapping;
|
||||
outcome::checked<SecretKeyGateMapping, StringError>
|
||||
mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates);
|
||||
|
||||
@@ -113,7 +113,7 @@ struct PublicResult {
|
||||
// Chunked integers are represented as tensors at a lower level, so we need
|
||||
// to deal with them as tensors, then build the resulting scalar out of the
|
||||
// tensor values
|
||||
if (gate.chunkInfo.hasValue()) {
|
||||
if (gate.chunkInfo.has_value()) {
|
||||
OUTCOME_TRY(std::vector<uint64_t> decryptedChunks,
|
||||
this->asClearTextVector<uint64_t>(keySet, pos));
|
||||
uint64_t decrypted = fromChunks(decryptedChunks, gate.chunkInfo->width);
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#ifndef CONCRETELANG_TRANSFORMS_PASSES_H
|
||||
#define CONCRETELANG_TRANSFORMS_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
|
||||
@@ -79,7 +79,7 @@ def SDFGToStreamEmulator : Pass<"sdfg-to-stream-emulator", "mlir::ModuleOp"> {
|
||||
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
|
||||
let constructor = "mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass()";
|
||||
let dependentDialects = ["mlir::func::FuncDialect", "mlir::arith::ArithmeticDialect", "mlir::scf::SCFDialect", "mlir::LLVM::LLVMDialect"];
|
||||
let dependentDialects = ["mlir::func::FuncDialect", "mlir::arith::ArithDialect", "mlir::scf::SCFDialect", "mlir::LLVM::LLVMDialect"];
|
||||
let options = [];
|
||||
}
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ struct V0Parameter {
|
||||
size_t ksLevel;
|
||||
size_t ksLogBase;
|
||||
|
||||
llvm::Optional<LargeIntegerParameter> largeInteger;
|
||||
std::optional<LargeIntegerParameter> largeInteger;
|
||||
|
||||
// TODO remove the shift when we have true polynomial size
|
||||
size_t getPolynomialSize() { return 1 << logPolynomialSize; }
|
||||
|
||||
@@ -31,7 +31,7 @@ class Concrete_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<Concrete_Dialect, mnemonic, traits>;
|
||||
|
||||
|
||||
def Concrete_AddLweTensorOp : Concrete_Op<"add_lwe_tensor", [NoSideEffect]> {
|
||||
def Concrete_AddLweTensorOp : Concrete_Op<"add_lwe_tensor", [Pure]> {
|
||||
let summary = "Returns the sum of 2 lwe ciphertexts";
|
||||
|
||||
let arguments = (ins
|
||||
@@ -51,7 +51,7 @@ def Concrete_AddLweBufferOp : Concrete_Op<"add_lwe_buffer"> {
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_AddPlaintextLweTensorOp : Concrete_Op<"add_plaintext_lwe_tensor", [NoSideEffect]> {
|
||||
def Concrete_AddPlaintextLweTensorOp : Concrete_Op<"add_plaintext_lwe_tensor", [Pure]> {
|
||||
let summary = "Returns the sum of a clear integer and an lwe ciphertext";
|
||||
|
||||
let arguments = (ins Concrete_LweTensor:$lhs, I64:$rhs);
|
||||
@@ -68,7 +68,7 @@ def Concrete_AddPlaintextLweBufferOp : Concrete_Op<"add_plaintext_lwe_buffer"> {
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_MulCleartextLweTensorOp : Concrete_Op<"mul_cleartext_lwe_tensor", [NoSideEffect]> {
|
||||
def Concrete_MulCleartextLweTensorOp : Concrete_Op<"mul_cleartext_lwe_tensor", [Pure]> {
|
||||
let summary = "Returns the product of a clear integer and a lwe ciphertext";
|
||||
|
||||
let arguments = (ins Concrete_LweTensor:$lhs, I64:$rhs);
|
||||
@@ -85,7 +85,7 @@ def Concrete_MulCleartextLweBufferOp : Concrete_Op<"mul_cleartext_lwe_buffer"> {
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_NegateLweTensorOp : Concrete_Op<"negate_lwe_tensor", [NoSideEffect]> {
|
||||
def Concrete_NegateLweTensorOp : Concrete_Op<"negate_lwe_tensor", [Pure]> {
|
||||
let summary = "Negates a lwe ciphertext";
|
||||
|
||||
let arguments = (ins Concrete_LweTensor:$ciphertext);
|
||||
@@ -101,7 +101,7 @@ def Concrete_NegateLweBufferOp : Concrete_Op<"negate_lwe_buffer"> {
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_EncodeExpandLutForBootstrapTensorOp : Concrete_Op<"encode_expand_lut_for_bootstrap_tensor", [NoSideEffect]> {
|
||||
def Concrete_EncodeExpandLutForBootstrapTensorOp : Concrete_Op<"encode_expand_lut_for_bootstrap_tensor", [Pure]> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a bootstrap";
|
||||
|
||||
@@ -128,7 +128,7 @@ def Concrete_EncodeExpandLutForBootstrapBufferOp : Concrete_Op<"encode_expand_lu
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_EncodeLutForCrtWopPBSTensorOp : Concrete_Op<"encode_lut_for_crt_woppbs_tensor", [NoSideEffect]> {
|
||||
def Concrete_EncodeLutForCrtWopPBSTensorOp : Concrete_Op<"encode_lut_for_crt_woppbs_tensor", [Pure]> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a wop pbs";
|
||||
|
||||
@@ -157,7 +157,7 @@ def Concrete_EncodeLutForCrtWopPBSBufferOp : Concrete_Op<"encode_lut_for_crt_wop
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_EncodePlaintextWithCrtTensorOp : Concrete_Op<"encode_plaintext_with_crt_tensor", [NoSideEffect]> {
|
||||
def Concrete_EncodePlaintextWithCrtTensorOp : Concrete_Op<"encode_plaintext_with_crt_tensor", [Pure]> {
|
||||
let summary =
|
||||
"Encodes a plaintext by decomposing it on a crt basis";
|
||||
|
||||
@@ -182,7 +182,7 @@ def Concrete_EncodePlaintextWithCrtBufferOp : Concrete_Op<"encode_plaintext_with
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_BootstrapLweTensorOp : Concrete_Op<"bootstrap_lwe_tensor", [NoSideEffect]> {
|
||||
def Concrete_BootstrapLweTensorOp : Concrete_Op<"bootstrap_lwe_tensor", [Pure]> {
|
||||
let summary = "Bootstraps an LWE ciphertext with a GLWE trivial encryption of the lookup table";
|
||||
|
||||
let arguments = (ins
|
||||
@@ -214,7 +214,7 @@ def Concrete_BootstrapLweTensorOp : Concrete_Op<"bootstrap_lwe_tensor", [NoSideE
|
||||
|
||||
return builder.create<BatchedBootstrapLweTensorOp>(
|
||||
mlir::TypeRange{resType},
|
||||
mlir::ValueRange{batchedOperands, lookup_table()},
|
||||
mlir::ValueRange{batchedOperands, getLookupTable()},
|
||||
getOperation()->getAttrs());
|
||||
}
|
||||
}];
|
||||
@@ -236,7 +236,7 @@ def Concrete_BootstrapLweBufferOp : Concrete_Op<"bootstrap_lwe_buffer"> {
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_BatchedBootstrapLweTensorOp : Concrete_Op<"batched_bootstrap_lwe_tensor", [NoSideEffect]> {
|
||||
def Concrete_BatchedBootstrapLweTensorOp : Concrete_Op<"batched_bootstrap_lwe_tensor", [Pure]> {
|
||||
let summary = "Batched version of BootstrapLweOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins
|
||||
@@ -268,7 +268,7 @@ def Concrete_BatchedBootstrapLweBufferOp : Concrete_Op<"batched_bootstrap_lwe_bu
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_KeySwitchLweTensorOp : Concrete_Op<"keyswitch_lwe_tensor", [NoSideEffect]> {
|
||||
def Concrete_KeySwitchLweTensorOp : Concrete_Op<"keyswitch_lwe_tensor", [Pure]> {
|
||||
let summary = "Keyswitches an LWE ciphertext";
|
||||
|
||||
let arguments = (ins
|
||||
@@ -317,7 +317,7 @@ def Concrete_KeySwitchLweBufferOp : Concrete_Op<"keyswitch_lwe_buffer"> {
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_BatchedKeySwitchLweTensorOp : Concrete_Op<"batched_keyswitch_lwe_tensor", [NoSideEffect]> {
|
||||
def Concrete_BatchedKeySwitchLweTensorOp : Concrete_Op<"batched_keyswitch_lwe_tensor", [Pure]> {
|
||||
let summary = "Batched version of KeySwitchLweOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins
|
||||
@@ -344,7 +344,7 @@ def Concrete_BatchedKeySwitchLweBufferOp : Concrete_Op<"batched_keyswitch_lwe_bu
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_WopPBSCRTLweTensorOp : Concrete_Op<"wop_pbs_crt_lwe_tensor", [NoSideEffect]> {
|
||||
def Concrete_WopPBSCRTLweTensorOp : Concrete_Op<"wop_pbs_crt_lwe_tensor", [Pure]> {
|
||||
let arguments = (ins
|
||||
Concrete_LweCRTTensor:$ciphertext,
|
||||
Concrete_CrtLutsTensor:$lookupTable,
|
||||
|
||||
@@ -17,7 +17,7 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace optimizer {
|
||||
using FunctionsDag = std::map<std::string, llvm::Optional<Dag>>;
|
||||
using FunctionsDag = std::map<std::string, std::optional<Dag>>;
|
||||
|
||||
std::unique_ptr<mlir::Pass> createDagPass(optimizer::Config config,
|
||||
optimizer::FunctionsDag &dags);
|
||||
|
||||
@@ -19,6 +19,7 @@ def FHE_Dialect : Dialect {
|
||||
}];
|
||||
let cppNamespace = "::mlir::concretelang::FHE";
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
let useFoldAPI = kEmitRawAttributesFolder;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -18,7 +18,7 @@ include "concretelang/Dialect/FHE/IR/FHETypes.td"
|
||||
class FHE_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<FHE_Dialect, mnemonic, traits>;
|
||||
|
||||
def FHE_ZeroEintOp : FHE_Op<"zero", [NoSideEffect]> {
|
||||
def FHE_ZeroEintOp : FHE_Op<"zero", [Pure]> {
|
||||
let summary = "Returns a trivial encrypted integer of 0";
|
||||
|
||||
let description = [{
|
||||
@@ -35,7 +35,7 @@ def FHE_ZeroEintOp : FHE_Op<"zero", [NoSideEffect]> {
|
||||
let results = (outs FHE_AnyEncryptedInteger:$out);
|
||||
}
|
||||
|
||||
def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [NoSideEffect]> {
|
||||
def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [Pure]> {
|
||||
let summary = "Creates a new tensor with all elements initialized to an encrypted zero.";
|
||||
|
||||
let description = [{
|
||||
@@ -53,7 +53,7 @@ def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [NoSideEffect]> {
|
||||
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$tensor);
|
||||
}
|
||||
|
||||
def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [NoSideEffect]> {
|
||||
def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure]> {
|
||||
let summary = "Adds an encrypted integer and a clear integer";
|
||||
|
||||
let description = [{
|
||||
@@ -87,7 +87,7 @@ def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [NoSideEffect]> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def FHE_AddEintOp : FHE_Op<"add_eint", [NoSideEffect]> {
|
||||
def FHE_AddEintOp : FHE_Op<"add_eint", [Pure]> {
|
||||
let summary = "Adds two encrypted integers";
|
||||
|
||||
let description = [{
|
||||
@@ -120,7 +120,7 @@ def FHE_AddEintOp : FHE_Op<"add_eint", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [NoSideEffect]> {
|
||||
def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure]> {
|
||||
let summary = "Subtract an encrypted integer from a clear integer";
|
||||
|
||||
let description = [{
|
||||
@@ -153,7 +153,7 @@ def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [NoSideEffect]> {
|
||||
def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure]> {
|
||||
let summary = "Subtract a clear integer from an encrypted integer";
|
||||
|
||||
let description = [{
|
||||
@@ -187,7 +187,7 @@ def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [NoSideEffect]> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def FHE_SubEintOp : FHE_Op<"sub_eint", [NoSideEffect]> {
|
||||
def FHE_SubEintOp : FHE_Op<"sub_eint", [Pure]> {
|
||||
let summary = "Subtract an encrypted integer from an encrypted integer";
|
||||
|
||||
let description = [{
|
||||
@@ -220,7 +220,7 @@ def FHE_SubEintOp : FHE_Op<"sub_eint", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_NegEintOp : FHE_Op<"neg_eint", [NoSideEffect]> {
|
||||
def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure]> {
|
||||
|
||||
let summary = "Negates an encrypted integer";
|
||||
|
||||
@@ -251,7 +251,7 @@ def FHE_NegEintOp : FHE_Op<"neg_eint", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [NoSideEffect]> {
|
||||
def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [Pure]> {
|
||||
let summary = "Multiply an encrypted integer with a clear integer";
|
||||
|
||||
let description = [{
|
||||
@@ -285,7 +285,7 @@ def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [NoSideEffect]> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def FHE_MulEintOp : FHE_Op<"mul_eint", [NoSideEffect]> {
|
||||
def FHE_MulEintOp : FHE_Op<"mul_eint", [Pure]> {
|
||||
let summary = "Multiplies two encrypted integers";
|
||||
|
||||
let description = [{
|
||||
@@ -323,7 +323,7 @@ def FHE_MulEintOp : FHE_Op<"mul_eint", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_MaxEintOp : FHE_Op<"max_eint", [NoSideEffect]> {
|
||||
def FHE_MaxEintOp : FHE_Op<"max_eint", [Pure]> {
|
||||
let summary = "Get maximum of two encrypted integers.";
|
||||
|
||||
let description = [{
|
||||
@@ -358,7 +358,7 @@ def FHE_MaxEintOp : FHE_Op<"max_eint", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> {
|
||||
def FHE_ToSignedOp : FHE_Op<"to_signed", [Pure]> {
|
||||
let summary = "Cast an unsigned integer to a signed one";
|
||||
|
||||
let description = [{
|
||||
@@ -383,7 +383,7 @@ def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [NoSideEffect]> {
|
||||
def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure]> {
|
||||
let summary = "Cast a signed integer to an unsigned one";
|
||||
|
||||
let description = [{
|
||||
@@ -408,7 +408,7 @@ def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [NoSideEffect]> {
|
||||
def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [Pure]> {
|
||||
|
||||
let summary = "Applies a clear lookup table to an encrypted integer";
|
||||
|
||||
@@ -434,7 +434,7 @@ def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_RoundEintOp: FHE_Op<"round", [NoSideEffect]> {
|
||||
def FHE_RoundEintOp: FHE_Op<"round", [Pure]> {
|
||||
|
||||
let summary = "Rounds a ciphertext to a smaller precision.";
|
||||
|
||||
@@ -465,7 +465,7 @@ def FHE_RoundEintOp: FHE_Op<"round", [NoSideEffect]> {
|
||||
|
||||
// FHE Boolean Operations
|
||||
|
||||
def FHE_GenGateOp : FHE_Op<"gen_gate", [NoSideEffect]> {
|
||||
def FHE_GenGateOp : FHE_Op<"gen_gate", [Pure]> {
|
||||
|
||||
let summary = "Applies a truth table based on two boolean inputs";
|
||||
|
||||
@@ -492,7 +492,7 @@ def FHE_GenGateOp : FHE_Op<"gen_gate", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_MuxOp : FHE_Op<"mux", [NoSideEffect]> {
|
||||
def FHE_MuxOp : FHE_Op<"mux", [Pure]> {
|
||||
|
||||
let summary = "Multiplexer for two encrypted boolean inputs, based on an encrypted condition";
|
||||
|
||||
@@ -509,7 +509,7 @@ def FHE_MuxOp : FHE_Op<"mux", [NoSideEffect]> {
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
}
|
||||
|
||||
def FHE_BoolAndOp : FHE_Op<"and", [NoSideEffect]> {
|
||||
def FHE_BoolAndOp : FHE_Op<"and", [Pure]> {
|
||||
|
||||
let summary = "Applies an AND gate to two encrypted boolean values";
|
||||
|
||||
@@ -526,7 +526,7 @@ def FHE_BoolAndOp : FHE_Op<"and", [NoSideEffect]> {
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
}
|
||||
|
||||
def FHE_BoolOrOp : FHE_Op<"or", [NoSideEffect]> {
|
||||
def FHE_BoolOrOp : FHE_Op<"or", [Pure]> {
|
||||
|
||||
let summary = "Applies an OR gate to two encrypted boolean values";
|
||||
|
||||
@@ -543,7 +543,7 @@ def FHE_BoolOrOp : FHE_Op<"or", [NoSideEffect]> {
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
}
|
||||
|
||||
def FHE_BoolNandOp : FHE_Op<"nand", [NoSideEffect]> {
|
||||
def FHE_BoolNandOp : FHE_Op<"nand", [Pure]> {
|
||||
|
||||
let summary = "Applies a NAND gate to two encrypted boolean values";
|
||||
|
||||
@@ -560,7 +560,7 @@ def FHE_BoolNandOp : FHE_Op<"nand", [NoSideEffect]> {
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
}
|
||||
|
||||
def FHE_BoolXorOp : FHE_Op<"xor", [NoSideEffect]> {
|
||||
def FHE_BoolXorOp : FHE_Op<"xor", [Pure]> {
|
||||
|
||||
let summary = "Applies a XOR gate to two encrypted boolean values";
|
||||
|
||||
@@ -577,7 +577,7 @@ def FHE_BoolXorOp : FHE_Op<"xor", [NoSideEffect]> {
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
}
|
||||
|
||||
def FHE_BoolNotOp : FHE_Op<"not", [NoSideEffect]> {
|
||||
def FHE_BoolNotOp : FHE_Op<"not", [Pure]> {
|
||||
|
||||
let summary = "Applies a NOT gate to an encrypted boolean value";
|
||||
|
||||
@@ -594,7 +594,7 @@ def FHE_BoolNotOp : FHE_Op<"not", [NoSideEffect]> {
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
}
|
||||
|
||||
def FHE_ToBoolOp : FHE_Op<"to_bool", [NoSideEffect]> {
|
||||
def FHE_ToBoolOp : FHE_Op<"to_bool", [Pure]> {
|
||||
let summary = "Cast an unsigned integer to a boolean";
|
||||
|
||||
let description = [{
|
||||
@@ -620,7 +620,7 @@ def FHE_ToBoolOp : FHE_Op<"to_bool", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_FromBoolOp : FHE_Op<"from_bool", [NoSideEffect]> {
|
||||
def FHE_FromBoolOp : FHE_Op<"from_bool", [Pure]> {
|
||||
let summary = "Cast a boolean to an unsigned integer";
|
||||
|
||||
let description = [{
|
||||
|
||||
@@ -10,6 +10,7 @@ def FHELinalg_Dialect : Dialect {
|
||||
A dialect for representation of high level linalg operations on fully homomorphic ciphertexts.
|
||||
}];
|
||||
let cppNamespace = "::mlir::concretelang::FHELinalg";
|
||||
let useFoldAPI = kEmitRawAttributesFolder;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -72,13 +72,13 @@ def SDFG_MakeStream : SDFG_Op<"make_stream"> {
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
bool createsInputStream() {
|
||||
return type() == StreamKind::host_to_device ||
|
||||
type() == StreamKind::on_device;
|
||||
return getType() == StreamKind::host_to_device ||
|
||||
getType() == StreamKind::on_device;
|
||||
}
|
||||
|
||||
bool createsOutputStream() {
|
||||
return type() == StreamKind::device_to_host ||
|
||||
type() == StreamKind::on_device;
|
||||
return getType() == StreamKind::device_to_host ||
|
||||
getType() == StreamKind::on_device;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -45,13 +45,13 @@ enum Backend {
|
||||
|
||||
/// Compilation options allows to configure the compilation pipeline.
|
||||
struct CompilationOptions {
|
||||
llvm::Optional<mlir::concretelang::V0FHEConstraint> v0FHEConstraints;
|
||||
std::optional<mlir::concretelang::V0FHEConstraint> v0FHEConstraints;
|
||||
|
||||
llvm::Optional<mlir::concretelang::V0Parameter> v0Parameter;
|
||||
std::optional<mlir::concretelang::V0Parameter> v0Parameter;
|
||||
|
||||
/// largeIntegerParameter force the compiler engine to lower FHE.eint using
|
||||
/// the large integers strategy with the given parameters.
|
||||
llvm::Optional<mlir::concretelang::LargeIntegerParameter>
|
||||
std::optional<mlir::concretelang::LargeIntegerParameter>
|
||||
largeIntegerParameter;
|
||||
|
||||
bool verifyDiagnostics;
|
||||
@@ -65,9 +65,9 @@ struct CompilationOptions {
|
||||
bool optimizeTFHE;
|
||||
/// use GPU during execution by generating GPU operations if possible
|
||||
bool emitGPUOps;
|
||||
llvm::Optional<std::vector<int64_t>> fhelinalgTileSizes;
|
||||
std::optional<std::vector<int64_t>> fhelinalgTileSizes;
|
||||
|
||||
llvm::Optional<std::string> clientParametersFuncName;
|
||||
std::optional<std::string> clientParametersFuncName;
|
||||
|
||||
optimizer::Config optimizerConfig;
|
||||
|
||||
@@ -79,11 +79,11 @@ struct CompilationOptions {
|
||||
unsigned int chunkWidth;
|
||||
|
||||
CompilationOptions()
|
||||
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
|
||||
: v0FHEConstraints(std::nullopt), verifyDiagnostics(false),
|
||||
autoParallelize(false), loopParallelize(false), batchConcreteOps(false),
|
||||
emitSDFGOps(false), unrollLoopsWithSDFGConvertibleOps(false),
|
||||
dataflowParallelize(false), optimizeTFHE(true), emitGPUOps(false),
|
||||
clientParametersFuncName(llvm::None),
|
||||
clientParametersFuncName(std::nullopt),
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkIntegers(false),
|
||||
chunkSize(4), chunkWidth(2){};
|
||||
|
||||
@@ -119,11 +119,11 @@ public:
|
||||
CompilationContext::createShared())
|
||||
: compilationContext(compilationContext) {}
|
||||
|
||||
llvm::Optional<mlir::OwningOpRef<mlir::ModuleOp>> mlirModuleRef;
|
||||
llvm::Optional<mlir::concretelang::ClientParameters> clientParameters;
|
||||
llvm::Optional<CompilationFeedback> feedback;
|
||||
std::optional<mlir::OwningOpRef<mlir::ModuleOp>> mlirModuleRef;
|
||||
std::optional<mlir::concretelang::ClientParameters> clientParameters;
|
||||
std::optional<CompilationFeedback> feedback;
|
||||
std::unique_ptr<llvm::Module> llvmModule;
|
||||
llvm::Optional<mlir::concretelang::V0FHEContext> fheContext;
|
||||
std::optional<mlir::concretelang::V0FHEContext> fheContext;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<CompilationContext> compilationContext;
|
||||
@@ -176,7 +176,7 @@ public:
|
||||
void addExtraObjectFilePath(std::string objectFilePath);
|
||||
llvm::Expected<std::string>
|
||||
emit(std::string path, std::string dotExt, std::string linker,
|
||||
llvm::Optional<std::vector<std::string>> extraArgs = {});
|
||||
std::optional<std::vector<std::string>> extraArgs = {});
|
||||
~Library();
|
||||
|
||||
private:
|
||||
@@ -242,21 +242,21 @@ public:
|
||||
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
|
||||
: overrideMaxEintPrecision(), overrideMaxMANP(), compilerOptions(),
|
||||
generateClientParameters(
|
||||
compilerOptions.clientParametersFuncName.hasValue()),
|
||||
compilerOptions.clientParametersFuncName.has_value()),
|
||||
enablePass([](mlir::Pass *pass) { return true; }),
|
||||
compilationContext(compilationContext) {}
|
||||
|
||||
llvm::Expected<CompilationResult>
|
||||
compile(llvm::StringRef s, Target target,
|
||||
llvm::Optional<std::shared_ptr<Library>> lib = {});
|
||||
std::optional<std::shared_ptr<Library>> lib = {});
|
||||
|
||||
llvm::Expected<CompilationResult>
|
||||
compile(std::unique_ptr<llvm::MemoryBuffer> buffer, Target target,
|
||||
llvm::Optional<std::shared_ptr<Library>> lib = {});
|
||||
std::optional<std::shared_ptr<Library>> lib = {});
|
||||
|
||||
llvm::Expected<CompilationResult>
|
||||
compile(llvm::SourceMgr &sm, Target target,
|
||||
llvm::Optional<std::shared_ptr<Library>> lib = {});
|
||||
std::optional<std::shared_ptr<Library>> lib = {});
|
||||
|
||||
llvm::Expected<CompilerEngine::Library>
|
||||
compile(std::vector<std::string> inputs, std::string outputDirPath,
|
||||
@@ -276,11 +276,11 @@ public:
|
||||
|
||||
void setCompilationOptions(CompilationOptions &options) {
|
||||
compilerOptions = options;
|
||||
if (options.v0FHEConstraints.hasValue()) {
|
||||
if (options.v0FHEConstraints.has_value()) {
|
||||
setFHEConstraints(*options.v0FHEConstraints);
|
||||
}
|
||||
|
||||
if (options.clientParametersFuncName.hasValue()) {
|
||||
if (options.clientParametersFuncName.has_value()) {
|
||||
setGenerateClientParameters(true);
|
||||
}
|
||||
}
|
||||
@@ -292,8 +292,8 @@ public:
|
||||
void setEnablePass(std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
protected:
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision;
|
||||
llvm::Optional<size_t> overrideMaxMANP;
|
||||
std::optional<size_t> overrideMaxEintPrecision;
|
||||
std::optional<size_t> overrideMaxMANP;
|
||||
CompilationOptions compilerOptions;
|
||||
bool generateClientParameters;
|
||||
std::function<bool(mlir::Pass *)> enablePass;
|
||||
@@ -301,7 +301,7 @@ protected:
|
||||
std::shared_ptr<CompilationContext> compilationContext;
|
||||
|
||||
private:
|
||||
llvm::Expected<llvm::Optional<optimizer::Description>>
|
||||
llvm::Expected<std::optional<optimizer::Description>>
|
||||
getConcreteOptimizerDescription(CompilationResult &res);
|
||||
llvm::Error determineFHEParameters(CompilationResult &res);
|
||||
};
|
||||
|
||||
@@ -34,7 +34,7 @@ class JITSupport
|
||||
JitCompilationResult> {
|
||||
|
||||
public:
|
||||
JITSupport(llvm::Optional<std::string> runtimeLibPath = llvm::None);
|
||||
JITSupport(std::optional<std::string> runtimeLibPath = std::nullopt);
|
||||
|
||||
llvm::Expected<std::unique_ptr<JitCompilationResult>>
|
||||
compile(llvm::SourceMgr &program, CompilationOptions options) override;
|
||||
@@ -63,7 +63,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::Optional<std::string> runtimeLibPath;
|
||||
std::optional<std::string> runtimeLibPath;
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline;
|
||||
};
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ public:
|
||||
static llvm::Expected<std::unique_ptr<JITLambda>>
|
||||
create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::Optional<std::string> runtimeLibPath = {});
|
||||
std::optional<std::string> runtimeLibPath = {});
|
||||
|
||||
/// Call the JIT lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
|
||||
@@ -13,10 +13,9 @@ llvm::Error emitObject(llvm::Module &module, std::string objectPath);
|
||||
|
||||
llvm::Error callCmd(std::string cmd);
|
||||
|
||||
llvm::Error
|
||||
emitLibrary(std::vector<std::string> objectsPath, std::string libraryPath,
|
||||
std::string linker,
|
||||
llvm::Optional<std::vector<std::string>> extraArgs = {});
|
||||
llvm::Error emitLibrary(std::vector<std::string> objectsPath,
|
||||
std::string libraryPath, std::string linker,
|
||||
std::optional<std::vector<std::string>> extraArgs = {});
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -206,7 +206,7 @@ typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return (sign) ? buildScalarLambdaResult<int8_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint8_t>(keySet, result);
|
||||
}
|
||||
} else if (gate.chunkInfo.hasValue()) {
|
||||
} else if (gate.chunkInfo.has_value()) {
|
||||
// chunked scalar case
|
||||
assert(gate.shape.dimensions.size() == 1);
|
||||
width = gate.shape.size * gate.chunkInfo->width;
|
||||
@@ -392,10 +392,10 @@ public:
|
||||
/// Build the client KeySet from the client parameters.
|
||||
static llvm::Expected<std::unique_ptr<clientlib::KeySet>>
|
||||
keySet(clientlib::ClientParameters clientParameters,
|
||||
llvm::Optional<clientlib::KeySetCache> cache) {
|
||||
std::optional<clientlib::KeySetCache> cache) {
|
||||
std::shared_ptr<clientlib::KeySetCache> cachePtr;
|
||||
if (cache.hasValue()) {
|
||||
cachePtr = std::make_shared<clientlib::KeySetCache>(cache.getValue());
|
||||
if (cache.has_value()) {
|
||||
cachePtr = std::make_shared<clientlib::KeySetCache>(cache.value());
|
||||
}
|
||||
auto keySet =
|
||||
clientlib::KeySetCache::generate(cachePtr, clientParameters, 0, 0);
|
||||
@@ -434,7 +434,7 @@ public:
|
||||
static llvm::Expected<ClientServer>
|
||||
create(llvm::StringRef program,
|
||||
CompilationOptions options = CompilationOptions("main"),
|
||||
llvm::Optional<clientlib::KeySetCache> cache = {},
|
||||
std::optional<clientlib::KeySetCache> cache = {},
|
||||
LambdaSupport support = LambdaSupport()) {
|
||||
auto compilationResult = support.compile(program, options);
|
||||
if (auto err = compilationResult.takeError()) {
|
||||
|
||||
@@ -62,7 +62,7 @@ public:
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
if (!options.clientParametersFuncName.hasValue()) {
|
||||
if (!options.clientParametersFuncName.has_value()) {
|
||||
return StreamStringError("Need to have a funcname to compile library");
|
||||
}
|
||||
|
||||
|
||||
@@ -39,11 +39,11 @@ static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
|
||||
}
|
||||
|
||||
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
|
||||
static std::vector<Value> inlineRegionAndEmitStore(
|
||||
static llvm::SmallVector<Value> inlineRegionAndEmitStore(
|
||||
OpBuilder &b, Location loc, OpType op, ArrayRef<Value> indexedValues,
|
||||
ArrayRef<SmallVector<Value>> indexing, ArrayRef<Value> outputBuffers) {
|
||||
auto &block = op->getRegion(0).front();
|
||||
BlockAndValueMapping map;
|
||||
IRMapping map;
|
||||
map.map(block.getArguments(), indexedValues);
|
||||
for (auto &op : block.without_terminator()) {
|
||||
auto *newOp = b.clone(op, map);
|
||||
@@ -51,7 +51,7 @@ static std::vector<Value> inlineRegionAndEmitStore(
|
||||
}
|
||||
|
||||
Operation *terminator = block.getTerminator();
|
||||
std::vector<Value> retVals;
|
||||
llvm::SmallVector<Value> retVals;
|
||||
|
||||
for (OpOperand &operand : terminator->getOpOperands()) {
|
||||
Value toStore = map.lookupOrDefault(operand.get());
|
||||
@@ -91,40 +91,42 @@ static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
|
||||
LoopLikeOpInterface loopOp = loopOps.back();
|
||||
for (IndexOp indexOp :
|
||||
llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
|
||||
rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
|
||||
rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LoadOpTy, typename StoreOpTy>
|
||||
static std::vector<Value>
|
||||
static llvm::SmallVector<Value>
|
||||
emitScalarImplementation(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
|
||||
LinalgOp linalgOp, ValueRange operandValuesToUse) {
|
||||
assert(linalgOp.hasTensorSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
SmallVector<Value> indexedValues;
|
||||
indexedValues.reserve(linalgOp.getNumInputsAndOutputs());
|
||||
indexedValues.reserve(linalgOp->getNumOperands());
|
||||
|
||||
auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());
|
||||
|
||||
// TODO: Avoid the loads if the corresponding argument of the
|
||||
// region has no uses.
|
||||
// 1.a. Emit load from input operand or for scalars access the operand itself.
|
||||
for (OpOperand *inputOperand : linalgOp.getInputOperands()) {
|
||||
for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
|
||||
Value v = operandValuesToUse[inputOperand->getOperandNumber()];
|
||||
|
||||
if (linalgOp.isScalar(inputOperand)) {
|
||||
indexedValues.push_back(inputOperand->get());
|
||||
indexedValues.push_back(v);
|
||||
continue;
|
||||
}
|
||||
auto indexing = makeCanonicalAffineApplies(
|
||||
b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims);
|
||||
indexedValues.push_back(
|
||||
b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
|
||||
b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
|
||||
indexedValues.push_back(b.create<LoadOpTy>(loc, v, indexing));
|
||||
}
|
||||
// 1.b. Emit load from output views.
|
||||
for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
|
||||
for (OpOperand *outputOperand : linalgOp.getDpsInitOperands()) {
|
||||
Value v = operandValuesToUse[outputOperand->getOperandNumber()];
|
||||
|
||||
SmallVector<Value> indexing = makeCanonicalAffineApplies(
|
||||
b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims);
|
||||
indexedValues.push_back(
|
||||
b.create<LoadOpTy>(loc, outputOperand->get(), indexing));
|
||||
b, loc, linalgOp.getMatchingIndexingMap(outputOperand), allIvsPlusDims);
|
||||
indexedValues.push_back(b.create<LoadOpTy>(loc, v, indexing));
|
||||
}
|
||||
|
||||
// TODO: When a region inliner exists, use it.
|
||||
@@ -132,10 +134,13 @@ emitScalarImplementation(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
|
||||
// 3. Emit store.
|
||||
SmallVector<SmallVector<Value>, 8> indexing;
|
||||
SmallVector<Value> outputBuffers;
|
||||
for (OpOperand *outputOperand : linalgOp.getOutputTensorOperands()) {
|
||||
indexing.push_back(makeCanonicalAffineApplies(
|
||||
b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims));
|
||||
outputBuffers.push_back(operandValuesToUse.back());
|
||||
for (OpOperand *outputOperand : linalgOp.getDpsInitOperands()) {
|
||||
if (outputOperand->get().getType().isa<mlir::TensorType>()) {
|
||||
indexing.push_back(makeCanonicalAffineApplies(
|
||||
b, loc, linalgOp.getMatchingIndexingMap(outputOperand),
|
||||
allIvsPlusDims));
|
||||
outputBuffers.push_back(operandValuesToUse.back());
|
||||
}
|
||||
}
|
||||
return inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(
|
||||
b, loc, linalgOp, indexedValues, indexing, outputBuffers);
|
||||
@@ -151,7 +156,7 @@ linalgTensorOpToLoopsImpl(PatternRewriter &rewriter, LinalgOp linalgOp,
|
||||
"expected linalg op with value semantics");
|
||||
|
||||
auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
|
||||
auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
|
||||
auto iteratorTypes = llvm::to_vector<4>(linalgOp.getIteratorTypesArray());
|
||||
|
||||
SmallVector<Value> allIvs;
|
||||
GenerateLoopNest<LoopTy>::doit(
|
||||
|
||||
@@ -20,7 +20,7 @@ namespace pipeline {
|
||||
mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
llvm::Expected<std::map<std::string, llvm::Optional<optimizer::Description>>>
|
||||
llvm::Expected<std::map<std::string, std::optional<optimizer::Description>>>
|
||||
getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
optimizer::Config config,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
@@ -40,7 +40,7 @@ transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelize, bool batch);
|
||||
|
||||
@@ -55,12 +55,12 @@ transformFHEBigInt(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
|
||||
@@ -21,7 +21,7 @@ using ::concretelang::clientlib::ClientParameters;
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0FHEContext context, llvm::StringRef functionName,
|
||||
mlir::ModuleOp module, int bitsOfSecurity,
|
||||
llvm::Optional<ChunkInfo> chunkInfo = llvm::None);
|
||||
llvm::Optional<ChunkInfo> chunkInfo = std::nullopt);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -60,7 +60,7 @@ using DagSolution = concrete_optimizer::dag::DagSolution;
|
||||
/* Contains any circuit description usable by the concrete-optimizer */
|
||||
struct Description {
|
||||
V0FHEConstraint constraint;
|
||||
llvm::Optional<optimizer::Dag> dag;
|
||||
std::optional<optimizer::Dag> dag;
|
||||
};
|
||||
|
||||
} // namespace optimizer
|
||||
|
||||
@@ -115,7 +115,6 @@ add_mlir_python_common_capi_library(
|
||||
DECLARED_SOURCES
|
||||
# TODO: This can be chopped down significantly for size.
|
||||
MLIRPythonSources
|
||||
MLIRPythonExtension.AllPassesRegistration
|
||||
ConcretelangBindingsPythonSources
|
||||
ConcretelangBindingsPythonExtension)
|
||||
|
||||
@@ -130,7 +129,6 @@ add_mlir_python_modules(
|
||||
"python_packages/concretelang_core/mlir"
|
||||
DECLARED_SOURCES
|
||||
MLIRPythonSources
|
||||
MLIRPythonExtension.AllPassesRegistration
|
||||
# We need the circt extensions co-located with the MLIR extensions. When the namespace is unified, this moves to the
|
||||
# below.
|
||||
ConcretelangBindingsPythonExtension
|
||||
|
||||
@@ -203,10 +203,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
"key_set",
|
||||
[](clientlib::ClientParameters clientParameters,
|
||||
clientlib::KeySetCache *cache) {
|
||||
auto optCache =
|
||||
cache == nullptr
|
||||
? llvm::None
|
||||
: llvm::Optional<clientlib::KeySetCache>(*cache);
|
||||
auto optCache = cache == nullptr
|
||||
? std::nullopt
|
||||
: std::optional<clientlib::KeySetCache>(*cache);
|
||||
return key_set(clientParameters, optCache);
|
||||
},
|
||||
pybind11::arg().none(false), pybind11::arg().none(true))
|
||||
@@ -241,9 +240,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
[](mlir::concretelang::ClientParameters &clientParameters) {
|
||||
std::vector<bool> result;
|
||||
for (auto output : clientParameters.outputs) {
|
||||
if (output.encryption.hasValue()) {
|
||||
result.push_back(
|
||||
output.encryption.getValue().encoding.isSigned);
|
||||
if (output.encryption.has_value()) {
|
||||
result.push_back(output.encryption.value().encoding.isSigned);
|
||||
} else {
|
||||
result.push_back(true);
|
||||
}
|
||||
|
||||
@@ -23,8 +23,8 @@
|
||||
|
||||
MLIR_CAPI_EXPORTED JITSupport_Py jit_support(std::string runtimeLibPath) {
|
||||
auto opt = runtimeLibPath.empty()
|
||||
? llvm::None
|
||||
: llvm::Optional<std::string>(runtimeLibPath);
|
||||
? std::nullopt
|
||||
: std::optional<std::string>(runtimeLibPath);
|
||||
return JITSupport_Py{mlir::concretelang::JITSupport(opt)};
|
||||
}
|
||||
|
||||
@@ -139,7 +139,7 @@ library_get_client_parameters_path(LibrarySupport_Py support) {
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
|
||||
key_set(concretelang::clientlib::ClientParameters clientParameters,
|
||||
llvm::Optional<concretelang::clientlib::KeySetCache> cache) {
|
||||
std::optional<concretelang::clientlib::KeySetCache> cache) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
ks, (mlir::concretelang::LambdaSupport<int, int>::keySet(clientParameters,
|
||||
cache)));
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "concretelang/Bindings/Python/DialectModules.h"
|
||||
#include "concretelang/Support/Constants.h"
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
|
||||
#include "llvm-c/ErrorHandling.h"
|
||||
|
||||
@@ -30,6 +30,5 @@
|
||||
#include <mlir-c/IntegerSet.h>
|
||||
#include <mlir-c/Interfaces.h>
|
||||
#include <mlir-c/Pass.h>
|
||||
#include <mlir-c/Registration.h>
|
||||
#include <mlir-c/Support.h>
|
||||
#include <mlir-c/Transforms.h>
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::error::Error;
|
||||
use std::path::Path;
|
||||
use std::process::exit;
|
||||
|
||||
const MLIR_STATIC_LIBS: [&str; 179] = [
|
||||
const MLIR_STATIC_LIBS: [&str; 174] = [
|
||||
"MLIRMemRefDialect",
|
||||
"MLIRVectorToSPIRV",
|
||||
"MLIRControlFlowInterfaces",
|
||||
@@ -37,7 +37,7 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
|
||||
"MLIRPresburger",
|
||||
"MLIRFuncDialect",
|
||||
"MLIRPDLToPDLInterp",
|
||||
"MLIRArithmeticTransforms",
|
||||
"MLIRArithTransforms",
|
||||
"MLIRViewLikeInterface",
|
||||
"MLIRTargetCpp",
|
||||
"MLIROpenMPToLLVM",
|
||||
@@ -53,8 +53,8 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
|
||||
"MLIRTensorUtils",
|
||||
"MLIRSPIRVSerialization",
|
||||
"MLIRShapeToStandard",
|
||||
"MLIRArithmeticToSPIRV",
|
||||
"MLIRArithmeticDialect",
|
||||
"MLIRArithToSPIRV",
|
||||
"MLIRArithDialect",
|
||||
"MLIRFuncToSPIRV",
|
||||
"MLIRQuantUtils",
|
||||
"MLIRTensorTilingInterfaceImpl",
|
||||
@@ -95,7 +95,7 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
|
||||
"MLIRRewrite",
|
||||
"MLIRAMXToLLVMIRTranslation",
|
||||
"MLIRInferIntRangeInterface",
|
||||
"MLIRCAPIRegistration",
|
||||
"MLIRCAPIRegisterEverything",
|
||||
"MLIRNVVMToLLVMIRTranslation",
|
||||
"MLIRAsyncTransforms",
|
||||
"MLIRPDLInterpDialect",
|
||||
@@ -129,7 +129,6 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
|
||||
"MLIRSPIRVUtils",
|
||||
"MLIRCastInterfaces",
|
||||
"MLIRTosaToTensor",
|
||||
"MLIRMemRefUtils",
|
||||
"MLIRGPUToSPIRV",
|
||||
"MLIRBufferizationDialect",
|
||||
"MLIRSCFToControlFlow",
|
||||
@@ -139,7 +138,6 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
|
||||
"MLIRSparseTensorDialect",
|
||||
"MLIRTensorToSPIRV",
|
||||
"MLIRVectorToSCF",
|
||||
"MLIRQuantTransforms",
|
||||
"MLIRLLVMToLLVMIRTranslation",
|
||||
"MLIRNVGPUDialect",
|
||||
"MLIRAsyncToLLVM",
|
||||
@@ -158,11 +156,9 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
|
||||
"MLIRVectorToLLVM",
|
||||
"MLIRSPIRVDialect",
|
||||
"MLIRSideEffectInterfaces",
|
||||
"MLIRVectorToROCDL",
|
||||
"MLIRQuantDialect",
|
||||
"MLIRSCFTransforms",
|
||||
"MLIRMLProgramDialect",
|
||||
"MLIRLinalgToSPIRV",
|
||||
"MLIRDLTIDialect",
|
||||
"MLIRLinalgFrontend",
|
||||
"MLIRROCDLToLLVMIRTranslation",
|
||||
@@ -177,14 +173,13 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
|
||||
"MLIRSPIRVTransforms",
|
||||
"MLIRMemRefToLLVM",
|
||||
"MLIRSPIRVBinaryUtils",
|
||||
"MLIRLinalgAnalysis",
|
||||
"MLIRArithmeticUtils",
|
||||
"MLIRArithUtils",
|
||||
"MLIRVectorInterfaces",
|
||||
"MLIRGPUOps",
|
||||
"MLIRComplexToLLVM",
|
||||
"MLIRShapeOpsTransforms",
|
||||
"MLIRX86VectorTransforms",
|
||||
"MLIRArithmeticToLLVM",
|
||||
"MLIRArithToLLVM",
|
||||
];
|
||||
|
||||
const LLVM_STATIC_LIBS: [&str; 51] = [
|
||||
|
||||
@@ -29,7 +29,7 @@ ClientLambda::load(std::string functionName, std::string jsonPath) {
|
||||
<< std::to_string(param->outputs.size()) << ") != 1 is not supprted";
|
||||
}
|
||||
|
||||
if (!param->outputs[0].encryption.hasValue()) {
|
||||
if (!param->outputs[0].encryption.has_value()) {
|
||||
return StringError("ClientLambda: clear output is not yet supported");
|
||||
}
|
||||
ClientLambda lambda;
|
||||
|
||||
@@ -37,9 +37,9 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
|
||||
OUTCOME_TRY(CircuitGate input, keySet.clientParameters().input(currentPos));
|
||||
// a chunked input is represented as a tensor in lower levels, and need to to
|
||||
// splitted into chunks and encrypted as such
|
||||
if (input.chunkInfo.hasValue()) {
|
||||
if (input.chunkInfo.has_value()) {
|
||||
std::vector<uint64_t> chunks =
|
||||
chunkInput(arg, input.shape.size, input.chunkInfo.getPointer()->width);
|
||||
chunkInput(arg, input.shape.size, input.chunkInfo.value().width);
|
||||
return this->pushArg(chunks.data(), input.shape.size, keySet);
|
||||
}
|
||||
// we only increment if we don't forward the call to another pushArg method
|
||||
@@ -47,7 +47,7 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
|
||||
if (input.shape.size != 0) {
|
||||
return StringError("argument #") << pos << " is not a scalar";
|
||||
}
|
||||
if (!input.encryption.hasValue()) {
|
||||
if (!input.encryption.has_value()) {
|
||||
// clear scalar: just push the argument
|
||||
preparedArgs.push_back((void *)arg);
|
||||
return outcome::success();
|
||||
|
||||
@@ -45,15 +45,15 @@ outcome::checked<KeySet::SecretKeyGateMapping, StringError>
|
||||
KeySet::mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates) {
|
||||
SecretKeyGateMapping mapping;
|
||||
for (auto gate : gates) {
|
||||
if (gate.encryption.hasValue()) {
|
||||
if (gate.encryption.has_value()) {
|
||||
assert(gate.encryption->secretKeyID < this->secretKeys.size());
|
||||
auto skIt = this->secretKeys[gate.encryption->secretKeyID];
|
||||
|
||||
std::pair<CircuitGate, llvm::Optional<LweSecretKey>> input = {gate, skIt};
|
||||
std::pair<CircuitGate, std::optional<LweSecretKey>> input = {gate, skIt};
|
||||
mapping.push_back(input);
|
||||
} else {
|
||||
std::pair<CircuitGate, llvm::Optional<LweSecretKey>> input = {gate,
|
||||
llvm::None};
|
||||
std::pair<CircuitGate, std::optional<LweSecretKey>> input = {
|
||||
gate, std::nullopt};
|
||||
mapping.push_back(input);
|
||||
}
|
||||
}
|
||||
@@ -152,7 +152,7 @@ KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
|
||||
}
|
||||
auto inputSk = inputs[argPos];
|
||||
auto encryption = std::get<0>(inputSk).encryption;
|
||||
if (!encryption.hasValue()) {
|
||||
if (!encryption.has_value()) {
|
||||
return StringError("allocate_lwe argument #")
|
||||
<< argPos << "is not encypeted";
|
||||
}
|
||||
@@ -167,12 +167,12 @@ KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
|
||||
|
||||
bool KeySet::isInputEncrypted(size_t argPos) {
|
||||
return argPos < inputs.size() &&
|
||||
std::get<0>(inputs[argPos]).encryption.hasValue();
|
||||
std::get<0>(inputs[argPos]).encryption.has_value();
|
||||
}
|
||||
|
||||
bool KeySet::isOutputEncrypted(size_t argPos) {
|
||||
return argPos < outputs.size() &&
|
||||
std::get<0>(outputs[argPos]).encryption.hasValue();
|
||||
std::get<0>(outputs[argPos]).encryption.has_value();
|
||||
}
|
||||
|
||||
/// Return the number of bits to represents the given value
|
||||
@@ -183,9 +183,9 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
|
||||
if (argPos >= inputs.size()) {
|
||||
return StringError("encrypt_lwe position of argument is too high");
|
||||
}
|
||||
auto inputSk = inputs[argPos];
|
||||
const auto &inputSk = inputs[argPos];
|
||||
auto encryption = std::get<0>(inputSk).encryption;
|
||||
if (!encryption.hasValue()) {
|
||||
if (!encryption.has_value()) {
|
||||
return StringError("encrypt_lwe the positional argument is not encrypted");
|
||||
}
|
||||
auto encoding = encryption->encoding;
|
||||
@@ -221,7 +221,7 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
|
||||
auto lweSecretKey = *outputSk.second;
|
||||
auto lweSecretKeyParam = lweSecretKey.parameters();
|
||||
auto encryption = std::get<0>(outputSk).encryption;
|
||||
if (!encryption.hasValue()) {
|
||||
if (!encryption.has_value()) {
|
||||
return StringError("decrypt_lwe: the positional argument is not encrypted");
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ PublicArguments::serialize(std::ostream &ostream) {
|
||||
for (auto gate : clientParameters.inputs) {
|
||||
iGate++;
|
||||
size_t rank = gate.shape.dimensions.size();
|
||||
if (!gate.encryption.hasValue()) {
|
||||
if (!gate.encryption.has_value()) {
|
||||
return StringError("PublicArguments::serialize: Clear arguments "
|
||||
"are not yet supported. Argument ")
|
||||
<< iGate;
|
||||
@@ -78,7 +78,7 @@ PublicArguments::unserializeArgs(std::istream &istream) {
|
||||
int iGate = -1;
|
||||
for (auto gate : clientParameters.inputs) {
|
||||
iGate++;
|
||||
if (!gate.encryption.hasValue()) {
|
||||
if (!gate.encryption.has_value()) {
|
||||
return StringError("Clear values are not handled");
|
||||
}
|
||||
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
|
||||
@@ -135,7 +135,7 @@ PublicArguments::unserialize(ClientParameters &clientParameters,
|
||||
outcome::checked<void, StringError>
|
||||
PublicResult::unserialize(std::istream &istream) {
|
||||
for (auto gate : clientParameters.outputs) {
|
||||
if (!gate.encryption.hasValue()) {
|
||||
if (!gate.encryption.has_value()) {
|
||||
return StringError("Clear values are not handled");
|
||||
}
|
||||
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
|
||||
|
||||
@@ -52,7 +52,7 @@ char memref_trace[] = "memref_trace";
|
||||
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
size_t rank) {
|
||||
std::vector<int64_t> shape(rank, -1);
|
||||
std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
|
||||
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
|
||||
for (size_t i = 0; i < rank; i++) {
|
||||
expr = expr +
|
||||
@@ -263,16 +263,16 @@ void keyswitchAddOperands(KeySwitchOp op,
|
||||
mlir::RewriterBase &rewriter) {
|
||||
// level
|
||||
operands.push_back(
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.levelAttr()));
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getLevelAttr()));
|
||||
// base_log
|
||||
operands.push_back(
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getBaseLogAttr()));
|
||||
// lwe_dim_in
|
||||
operands.push_back(
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.lwe_dim_inAttr()));
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getLweDimInAttr()));
|
||||
// lwe_dim_out
|
||||
operands.push_back(
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.lwe_dim_outAttr()));
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getLweDimOutAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
@@ -283,22 +283,22 @@ void bootstrapAddOperands(BootstrapOp op,
|
||||
mlir::RewriterBase &rewriter) {
|
||||
// input_lwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.inputLweDimAttr()));
|
||||
op.getLoc(), op.getInputLweDimAttr()));
|
||||
// poly_size
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.getPolySizeAttr()));
|
||||
// level
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.getLevelAttr()));
|
||||
// base_log
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.getBaseLogAttr()));
|
||||
// glwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.glweDimensionAttr()));
|
||||
op.getLoc(), op.getGlweDimensionAttr()));
|
||||
// out_precision
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.outPrecisionAttr()));
|
||||
op.getLoc(), op.getOutPrecisionAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
@@ -307,9 +307,9 @@ void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
mlir::RewriterBase &rewriter) {
|
||||
mlir::Type crtType = mlir::RankedTensorType::get(
|
||||
{(int)op.crtDecompositionAttr().size()}, rewriter.getI64Type());
|
||||
{(int)op.getCrtDecompositionAttr().size()}, rewriter.getI64Type());
|
||||
std::vector<int64_t> values;
|
||||
for (auto a : op.crtDecomposition()) {
|
||||
for (auto a : op.getCrtDecomposition()) {
|
||||
values.push_back(a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
|
||||
}
|
||||
auto attr = rewriter.getI64TensorAttr(values);
|
||||
@@ -319,43 +319,43 @@ void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op,
|
||||
assert(!failed(globalMemref));
|
||||
|
||||
auto globalRef = rewriter.create<memref::GetGlobalOp>(
|
||||
op.getLoc(), (*globalMemref).type(), (*globalMemref).getName());
|
||||
op.getLoc(), (*globalMemref).getType(), (*globalMemref).getName());
|
||||
operands.push_back(getCastedMemRef(rewriter, globalRef));
|
||||
|
||||
// lwe_small_size
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.packingKeySwitchInputLweDimensionAttr()));
|
||||
op.getLoc(), op.getPackingKeySwitchInputLweDimensionAttr()));
|
||||
// cbs_level_count
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.circuitBootstrapLevelAttr()));
|
||||
op.getLoc(), op.getCircuitBootstrapLevelAttr()));
|
||||
// cbs_base_log
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.circuitBootstrapBaseLogAttr()));
|
||||
op.getLoc(), op.getCircuitBootstrapBaseLogAttr()));
|
||||
|
||||
// ksk_level_count
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.keyswitchLevelAttr()));
|
||||
op.getLoc(), op.getKeyswitchLevelAttr()));
|
||||
// ksk_base_log
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.keyswitchBaseLogAttr()));
|
||||
op.getLoc(), op.getKeyswitchBaseLogAttr()));
|
||||
|
||||
// bsk_level_count
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.bootstrapLevelAttr()));
|
||||
op.getLoc(), op.getBootstrapLevelAttr()));
|
||||
// bsk_base_log
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.bootstrapBaseLogAttr()));
|
||||
op.getLoc(), op.getBootstrapBaseLogAttr()));
|
||||
|
||||
// fpksk_level_count
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.packingKeySwitchLevelAttr()));
|
||||
op.getLoc(), op.getPackingKeySwitchLevelAttr()));
|
||||
// fpksk_base_log
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.packingKeySwitchBaseLogAttr()));
|
||||
op.getLoc(), op.getPackingKeySwitchBaseLogAttr()));
|
||||
|
||||
// polynomial_size
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.packingKeySwitchoutputPolynomialSizeAttr()));
|
||||
op.getLoc(), op.getPackingKeySwitchoutputPolynomialSizeAttr()));
|
||||
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
@@ -365,10 +365,10 @@ void encodePlaintextWithCrtAddOperands(
|
||||
Concrete::EncodePlaintextWithCrtBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
|
||||
// mods
|
||||
mlir::Type modsType = mlir::RankedTensorType::get({(int)op.modsAttr().size()},
|
||||
rewriter.getI64Type());
|
||||
mlir::Type modsType = mlir::RankedTensorType::get(
|
||||
{(int)op.getModsAttr().size()}, rewriter.getI64Type());
|
||||
std::vector<int64_t> modsValues;
|
||||
for (auto a : op.mods()) {
|
||||
for (auto a : op.getMods()) {
|
||||
modsValues.push_back(a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
|
||||
}
|
||||
auto modsAttr = rewriter.getI64TensorAttr(modsValues);
|
||||
@@ -378,26 +378,27 @@ void encodePlaintextWithCrtAddOperands(
|
||||
rewriter.eraseOp(modsOp);
|
||||
assert(!failed(modsGlobalMemref));
|
||||
auto modsGlobalRef = rewriter.create<memref::GetGlobalOp>(
|
||||
op.getLoc(), (*modsGlobalMemref).type(), (*modsGlobalMemref).getName());
|
||||
op.getLoc(), (*modsGlobalMemref).getType(),
|
||||
(*modsGlobalMemref).getName());
|
||||
operands.push_back(getCastedMemRef(rewriter, modsGlobalRef));
|
||||
|
||||
// mods_prod
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.modsProdAttr()));
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.getModsProdAttr()));
|
||||
}
|
||||
|
||||
void encodeExpandLutForBootstrapAddOperands(
|
||||
Concrete::EncodeExpandLutForBootstrapBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
|
||||
// poly_size
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.getPolySizeAttr()));
|
||||
// output bits
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.outputBitsAttr()));
|
||||
op.getLoc(), op.getOutputBitsAttr()));
|
||||
// is_signed
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.isSignedAttr()));
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.getIsSignedAttr()));
|
||||
}
|
||||
|
||||
void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
|
||||
@@ -406,9 +407,9 @@ void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
|
||||
|
||||
// crt_decomposition
|
||||
mlir::Type crtDecompositionType = mlir::RankedTensorType::get(
|
||||
{(int)op.crtDecompositionAttr().size()}, rewriter.getI64Type());
|
||||
{(int)op.getCrtDecompositionAttr().size()}, rewriter.getI64Type());
|
||||
std::vector<int64_t> crtDecompositionValues;
|
||||
for (auto a : op.crtDecomposition()) {
|
||||
for (auto a : op.getCrtDecomposition()) {
|
||||
crtDecompositionValues.push_back(
|
||||
a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
|
||||
}
|
||||
@@ -420,15 +421,15 @@ void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
|
||||
rewriter.eraseOp(crtDecompositionOp);
|
||||
assert(!failed(crtDecompositionGlobalMemref));
|
||||
auto crtDecompositionGlobalRef = rewriter.create<memref::GetGlobalOp>(
|
||||
op.getLoc(), (*crtDecompositionGlobalMemref).type(),
|
||||
op.getLoc(), (*crtDecompositionGlobalMemref).getType(),
|
||||
(*crtDecompositionGlobalMemref).getName());
|
||||
operands.push_back(getCastedMemRef(rewriter, crtDecompositionGlobalRef));
|
||||
|
||||
// crt_bits
|
||||
mlir::Type crtBitsType = mlir::RankedTensorType::get(
|
||||
{(int)op.crtBitsAttr().size()}, rewriter.getI64Type());
|
||||
{(int)op.getCrtBitsAttr().size()}, rewriter.getI64Type());
|
||||
std::vector<int64_t> crtBitsValues;
|
||||
for (auto a : op.crtBits()) {
|
||||
for (auto a : op.getCrtBits()) {
|
||||
crtBitsValues.push_back(
|
||||
a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
|
||||
}
|
||||
@@ -439,15 +440,15 @@ void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
|
||||
rewriter.eraseOp(crtBitsOp);
|
||||
assert(!failed(crtBitsGlobalMemref));
|
||||
auto crtBitsGlobalRef = rewriter.create<memref::GetGlobalOp>(
|
||||
op.getLoc(), (*crtBitsGlobalMemref).type(),
|
||||
op.getLoc(), (*crtBitsGlobalMemref).getType(),
|
||||
(*crtBitsGlobalMemref).getName());
|
||||
operands.push_back(getCastedMemRef(rewriter, crtBitsGlobalRef));
|
||||
// modulus_product
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.modulusProductAttr()));
|
||||
op.getLoc(), op.getModulusProductAttr()));
|
||||
// is_signed
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.isSignedAttr()));
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.getIsSignedAttr()));
|
||||
}
|
||||
|
||||
struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
|
||||
@@ -463,7 +464,7 @@ struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
|
||||
// Mark ops from the target dialect as legal operations
|
||||
target.addLegalDialect<func::FuncDialect>();
|
||||
target.addLegalDialect<memref::MemRefDialect>();
|
||||
target.addLegalDialect<arith::ArithmeticDialect>();
|
||||
target.addLegalDialect<arith::ArithDialect>();
|
||||
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
|
||||
|
||||
// Make sure that no ops from `FHE` remain after the lowering
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/SCF/IR/SCF.h>
|
||||
|
||||
namespace SDFG = mlir::concretelang::SDFG;
|
||||
|
||||
@@ -86,7 +86,7 @@ struct DotToLinalgGeneric
|
||||
|
||||
// Create `linalg.generic` op
|
||||
llvm::SmallVector<mlir::Type, 1> resTypes{zeroTensorOp.getType()};
|
||||
llvm::SmallVector<mlir::Value, 2> ins{dotOp.lhs(), dotOp.rhs()};
|
||||
llvm::SmallVector<mlir::Value, 2> ins{dotOp.getLhs(), dotOp.getRhs()};
|
||||
llvm::SmallVector<mlir::Value, 1> outs{zeroTensorOp};
|
||||
llvm::SmallVector<mlir::AffineMap, 3> maps{
|
||||
mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()),
|
||||
@@ -94,7 +94,8 @@ struct DotToLinalgGeneric
|
||||
mlir::AffineMap::get(1, 0, {rewriter.getAffineConstantExpr(0)},
|
||||
this->getContext())};
|
||||
|
||||
llvm::SmallVector<llvm::StringRef, 1> itTypes{"reduction"};
|
||||
llvm::SmallVector<mlir::utils::IteratorType, 1> itTypes{
|
||||
mlir::utils::IteratorType::reduction};
|
||||
llvm::StringRef doc{""};
|
||||
llvm::StringRef call{""};
|
||||
|
||||
@@ -236,10 +237,10 @@ struct FHELinalgOpToLinalgGeneric : public mlir::OpRewritePattern<FHELinalgOp> {
|
||||
mlir::RankedTensorType resultTy =
|
||||
((mlir::Type)linalgOp->getResult(0).getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
mlir::RankedTensorType lhsTy =
|
||||
((mlir::Type)linalgOp.lhs().getType()).cast<mlir::RankedTensorType>();
|
||||
mlir::RankedTensorType rhsTy =
|
||||
((mlir::Type)linalgOp.rhs().getType()).cast<mlir::RankedTensorType>();
|
||||
mlir::RankedTensorType lhsTy = ((mlir::Type)linalgOp.getLhs().getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
mlir::RankedTensorType rhsTy = ((mlir::Type)linalgOp.getRhs().getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
// linalg.init_tensor for initial value
|
||||
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
|
||||
linalgOp.getLoc(), resultTy, mlir::ValueRange{});
|
||||
@@ -252,8 +253,8 @@ struct FHELinalgOpToLinalgGeneric : public mlir::OpRewritePattern<FHELinalgOp> {
|
||||
};
|
||||
|
||||
// Create the iterator_types
|
||||
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
|
||||
"parallel");
|
||||
llvm::SmallVector<mlir::utils::IteratorType> iteratorTypes(
|
||||
resultTy.getShape().size(), mlir::utils::IteratorType::parallel);
|
||||
|
||||
// Create the body of the `linalg.generic` op
|
||||
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
|
||||
@@ -268,7 +269,7 @@ struct FHELinalgOpToLinalgGeneric : public mlir::OpRewritePattern<FHELinalgOp> {
|
||||
|
||||
// Create the `linalg.generic` op
|
||||
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
|
||||
llvm::SmallVector<mlir::Value, 2> ins{linalgOp.lhs(), linalgOp.rhs()};
|
||||
llvm::SmallVector<mlir::Value, 2> ins{linalgOp.getLhs(), linalgOp.getRhs()};
|
||||
llvm::SmallVector<mlir::Value, 1> outs{init};
|
||||
llvm::StringRef doc{""};
|
||||
llvm::StringRef call{""};
|
||||
@@ -288,8 +289,9 @@ template <class T> inline mlir::RankedTensorType getRankedTensorType(T v) {
|
||||
return ((mlir::Type)v.getType()).cast<mlir::RankedTensorType>();
|
||||
}
|
||||
|
||||
llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
|
||||
return llvm::SmallVector<llvm::StringRef>(n, "parallel");
|
||||
llvm::SmallVector<mlir::utils::IteratorType> parallelIteratorType(int n) {
|
||||
return llvm::SmallVector<mlir::utils::IteratorType>(
|
||||
n, mlir::utils::IteratorType::parallel);
|
||||
}
|
||||
|
||||
/// This class rewrite pattern transforms any instance of
|
||||
@@ -347,9 +349,9 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric
|
||||
using AffineMaps = llvm::SmallVector<mlir::AffineMap>;
|
||||
using sliceArg = llvm::SmallVector<mlir::OpFoldResult>;
|
||||
|
||||
auto input = mappedLookup.t();
|
||||
auto luts = mappedLookup.luts();
|
||||
auto map = mappedLookup.map();
|
||||
auto input = mappedLookup.getT();
|
||||
auto luts = mappedLookup.getLuts();
|
||||
auto map = mappedLookup.getMap();
|
||||
|
||||
auto loc = mappedLookup.getLoc();
|
||||
auto tensorTy = getRankedTensorType(input);
|
||||
@@ -471,9 +473,10 @@ struct FHELinalgApplyMultiLookupTableToLinalgGeneric
|
||||
mlir::RankedTensorType resultTy =
|
||||
((mlir::Type)fheLinalgLutOp->getResult(0).getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
mlir::RankedTensorType tensorTy = ((mlir::Type)fheLinalgLutOp.t().getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
auto luts = fheLinalgLutOp.luts();
|
||||
mlir::RankedTensorType tensorTy =
|
||||
((mlir::Type)fheLinalgLutOp.getT().getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
auto luts = fheLinalgLutOp.getLuts();
|
||||
mlir::RankedTensorType lutsTy = getRankedTensorType(luts);
|
||||
auto lutElmtTy = lutsTy.getElementType();
|
||||
// linalg.init_tensor for initial value
|
||||
@@ -549,7 +552,7 @@ struct FHELinalgApplyMultiLookupTableToLinalgGeneric
|
||||
|
||||
// Create the `linalg.generic` op
|
||||
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
|
||||
llvm::SmallVector<mlir::Value> ins{fheLinalgLutOp.t()};
|
||||
llvm::SmallVector<mlir::Value> ins{fheLinalgLutOp.getT()};
|
||||
llvm::SmallVector<mlir::Value, 1> outs{init};
|
||||
llvm::StringRef doc{""};
|
||||
llvm::StringRef call{""};
|
||||
@@ -619,7 +622,7 @@ struct FHELinalgApplyLookupTableToLinalgGeneric
|
||||
((mlir::Type)lutOp->getResult(0).getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
mlir::RankedTensorType tTy =
|
||||
((mlir::Type)lutOp.t().getType()).cast<mlir::RankedTensorType>();
|
||||
((mlir::Type)lutOp.getT().getType()).cast<mlir::RankedTensorType>();
|
||||
|
||||
// linalg.init_tensor for initial value
|
||||
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
|
||||
@@ -634,8 +637,8 @@ struct FHELinalgApplyLookupTableToLinalgGeneric
|
||||
};
|
||||
|
||||
// Create the iterator_types
|
||||
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
|
||||
"parallel");
|
||||
llvm::SmallVector<mlir::utils::IteratorType> iteratorTypes(
|
||||
resultTy.getShape().size(), mlir::utils::IteratorType::parallel);
|
||||
|
||||
// Create the body of the `linalg.generic` op
|
||||
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
|
||||
@@ -644,7 +647,7 @@ struct FHELinalgApplyLookupTableToLinalgGeneric
|
||||
mlir::concretelang::FHE::ApplyLookupTableEintOp fheOp =
|
||||
nestedBuilder.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
|
||||
lutOp.getLoc(), resultTy.getElementType(), blockArgs[0],
|
||||
lutOp.lut());
|
||||
lutOp.getLut());
|
||||
|
||||
nestedBuilder.create<mlir::linalg::YieldOp>(lutOp.getLoc(),
|
||||
fheOp.getResult());
|
||||
@@ -652,7 +655,7 @@ struct FHELinalgApplyLookupTableToLinalgGeneric
|
||||
|
||||
// Create the `linalg.generic` op
|
||||
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
|
||||
llvm::SmallVector<mlir::Value, 1> ins{lutOp.t()};
|
||||
llvm::SmallVector<mlir::Value, 1> ins{lutOp.getT()};
|
||||
llvm::SmallVector<mlir::Value, 1> outs{init};
|
||||
llvm::StringRef doc{""};
|
||||
llvm::StringRef call{""};
|
||||
@@ -716,8 +719,9 @@ struct FHELinalgNegEintToLinalgGeneric
|
||||
mlir::RankedTensorType resultTy =
|
||||
((mlir::Type)negEintOp->getResult(0).getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
mlir::RankedTensorType tensorTy = ((mlir::Type)negEintOp.tensor().getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
mlir::RankedTensorType tensorTy =
|
||||
((mlir::Type)negEintOp.getTensor().getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
|
||||
// linalg.init_tensor for initial value
|
||||
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
|
||||
@@ -732,8 +736,8 @@ struct FHELinalgNegEintToLinalgGeneric
|
||||
};
|
||||
|
||||
// Create the iterator_types
|
||||
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
|
||||
"parallel");
|
||||
llvm::SmallVector<mlir::utils::IteratorType> iteratorTypes(
|
||||
resultTy.getShape().size(), mlir::utils::IteratorType::parallel);
|
||||
|
||||
// Create the body of the `linalg.generic` op
|
||||
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
|
||||
@@ -749,7 +753,7 @@ struct FHELinalgNegEintToLinalgGeneric
|
||||
|
||||
// Create the `linalg.generic` op
|
||||
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
|
||||
llvm::SmallVector<mlir::Value, 1> ins{negEintOp.tensor()};
|
||||
llvm::SmallVector<mlir::Value, 1> ins{negEintOp.getTensor()};
|
||||
llvm::SmallVector<mlir::Value, 1> outs{init};
|
||||
llvm::StringRef doc{""};
|
||||
llvm::StringRef call{""};
|
||||
@@ -821,8 +825,8 @@ struct FHELinalgMatmulToLinalgGeneric
|
||||
|
||||
mlir::Location location = matmulOp.getLoc();
|
||||
|
||||
mlir::Value lhs = matmulOp.lhs();
|
||||
mlir::Value rhs = matmulOp.rhs();
|
||||
mlir::Value lhs = matmulOp.getLhs();
|
||||
mlir::Value rhs = matmulOp.getRhs();
|
||||
mlir::Value out = matmulOp.getResult();
|
||||
|
||||
auto lhsType = ((mlir::Type)lhs.getType()).cast<mlir::RankedTensorType>();
|
||||
@@ -843,7 +847,7 @@ struct FHELinalgMatmulToLinalgGeneric
|
||||
auto ins = llvm::SmallVector<mlir::Value, 2>{lhs, rhs};
|
||||
auto outs = llvm::SmallVector<mlir::Value, 1>{zeros};
|
||||
|
||||
auto iteratorTypes = llvm::SmallVector<llvm::StringRef, 3>{};
|
||||
auto iteratorTypes = llvm::SmallVector<mlir::utils::IteratorType, 3>{};
|
||||
|
||||
auto lhsAffineExpressions = llvm::SmallVector<mlir::AffineExpr, 2>{};
|
||||
auto rhsAffineExpressions = llvm::SmallVector<mlir::AffineExpr, 2>{};
|
||||
@@ -878,9 +882,9 @@ struct FHELinalgMatmulToLinalgGeneric
|
||||
// - Last iterator is for the reduced dimension (N in the examples)
|
||||
|
||||
for (int64_t i = 0; i < outDims; i++) {
|
||||
iteratorTypes.push_back(mlir::getParallelIteratorTypeName());
|
||||
iteratorTypes.push_back(mlir::utils::IteratorType::parallel);
|
||||
}
|
||||
iteratorTypes.push_back(mlir::getReductionIteratorTypeName());
|
||||
iteratorTypes.push_back(mlir::utils::IteratorType::reduction);
|
||||
|
||||
// we need to put appropriate affine dimension expressions
|
||||
// that match lhs.shape on iterator types array
|
||||
@@ -988,9 +992,9 @@ struct FHELinalgMatmulToLinalgGeneric
|
||||
int64_t commonDim = rhsDims - 2;
|
||||
for (int64_t i = 0; i < rhsDims; i++) {
|
||||
if (i == commonDim) {
|
||||
iteratorTypes.push_back(mlir::getReductionIteratorTypeName());
|
||||
iteratorTypes.push_back(mlir::utils::IteratorType::reduction);
|
||||
} else {
|
||||
iteratorTypes.push_back(mlir::getParallelIteratorTypeName());
|
||||
iteratorTypes.push_back(mlir::utils::IteratorType::parallel);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1016,9 +1020,9 @@ struct FHELinalgMatmulToLinalgGeneric
|
||||
// KxLxMxN @ N -> KxLxM
|
||||
|
||||
for (int64_t i = 0; i < lhsDims - 1; i++) {
|
||||
iteratorTypes.push_back(mlir::getParallelIteratorTypeName());
|
||||
iteratorTypes.push_back(mlir::utils::IteratorType::parallel);
|
||||
}
|
||||
iteratorTypes.push_back(mlir::getReductionIteratorTypeName());
|
||||
iteratorTypes.push_back(mlir::utils::IteratorType::reduction);
|
||||
|
||||
for (int64_t i = 0; i < lhsDims; i++) {
|
||||
lhsAffineExpressions.push_back(rewriter.getAffineDimExpr(i));
|
||||
@@ -1145,7 +1149,7 @@ struct SumToLinalgGeneric
|
||||
}
|
||||
|
||||
auto axesToDestroy = std::unordered_set<int64_t>{};
|
||||
for (mlir::Attribute axisAttribute : sumOp.axes()) {
|
||||
for (mlir::Attribute axisAttribute : sumOp.getAxes()) {
|
||||
int64_t axis = axisAttribute.cast<mlir::IntegerAttr>().getInt();
|
||||
axesToDestroy.insert(axis);
|
||||
}
|
||||
@@ -1178,7 +1182,7 @@ struct SumToLinalgGeneric
|
||||
bool ithAxisIsDestroyed = axesToDestroy.find(i) != axesToDestroy.end();
|
||||
if (!ithAxisIsDestroyed) {
|
||||
outputAffineExpressions.push_back(rewriter.getAffineDimExpr(i));
|
||||
} else if (sumOp.keep_dims()) {
|
||||
} else if (sumOp.getKeepDims()) {
|
||||
outputAffineExpressions.push_back(rewriter.getAffineConstantExpr(0));
|
||||
}
|
||||
}
|
||||
@@ -1191,11 +1195,11 @@ struct SumToLinalgGeneric
|
||||
|
||||
auto maps = llvm::SmallVector<mlir::AffineMap, 2>{inputMap, outputMap};
|
||||
|
||||
auto iteratorTypes = llvm::SmallVector<llvm::StringRef, 3>(
|
||||
inputDimensions, mlir::getParallelIteratorTypeName());
|
||||
auto iteratorTypes = llvm::SmallVector<mlir::utils::IteratorType, 3>(
|
||||
inputDimensions, mlir::utils::IteratorType::parallel);
|
||||
|
||||
for (int64_t axis : axesToDestroy) {
|
||||
iteratorTypes[axis] = mlir::getReductionIteratorTypeName();
|
||||
iteratorTypes[axis] = mlir::utils::IteratorType::reduction;
|
||||
}
|
||||
|
||||
auto regionBuilder = [&](mlir::OpBuilder &nestedBuilder,
|
||||
@@ -1210,11 +1214,10 @@ struct SumToLinalgGeneric
|
||||
};
|
||||
|
||||
auto resultTypes = llvm::SmallVector<mlir::Type, 1>{accumulatorType};
|
||||
mlir::Value accumulation =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(location, resultTypes, ins, outs, maps,
|
||||
iteratorTypes, regionBuilder)
|
||||
.getResult(0);
|
||||
linalg::GenericOp genericOp = rewriter.create<linalg::GenericOp>(
|
||||
location, resultTypes, ins, outs, maps, iteratorTypes, regionBuilder);
|
||||
|
||||
mlir::Value accumulation = genericOp.getResult(0);
|
||||
|
||||
mlir::Value result = accumulation;
|
||||
if (!outputIsTensor) {
|
||||
@@ -1282,7 +1285,7 @@ struct TransposeToLinalgGeneric
|
||||
|
||||
std::vector<unsigned int> perms = {};
|
||||
|
||||
mlir::ArrayAttr axes = transposeOp.axes();
|
||||
mlir::ArrayAttr axes = transposeOp.getAxes();
|
||||
if (axes.empty()) {
|
||||
for (int i = n_dim - 1; i >= 0; i--) {
|
||||
perms.push_back(i);
|
||||
@@ -1386,7 +1389,7 @@ struct ConcatRewritePattern
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Location location = op.getLoc();
|
||||
size_t axis = op.axis();
|
||||
size_t axis = op.getAxis();
|
||||
|
||||
mlir::Value output = op.getResult();
|
||||
auto outputType = output.getType().dyn_cast<mlir::TensorType>();
|
||||
@@ -1452,9 +1455,11 @@ struct ConcatRewritePattern
|
||||
sizes[axis] = axisSize;
|
||||
|
||||
// these arrays are copied, so it's fine to modify and use them again
|
||||
mlir::ArrayAttr offsetsAttr = rewriter.getI64ArrayAttr(offsets);
|
||||
mlir::ArrayAttr sizesAttr = rewriter.getI64ArrayAttr(sizes);
|
||||
mlir::ArrayAttr stridesAttr = rewriter.getI64ArrayAttr(strides);
|
||||
mlir::DenseI64ArrayAttr offsetsAttr =
|
||||
rewriter.getDenseI64ArrayAttr(offsets);
|
||||
mlir::DenseI64ArrayAttr sizesAttr = rewriter.getDenseI64ArrayAttr(sizes);
|
||||
mlir::DenseI64ArrayAttr stridesAttr =
|
||||
rewriter.getDenseI64ArrayAttr(strides);
|
||||
|
||||
offsets[axis] += axisSize;
|
||||
|
||||
@@ -1497,9 +1502,10 @@ getPaddedTensor(mlir::Operation *op, mlir::OpBuilder &b, mlir::Value &input,
|
||||
getAsOpFoldResult(b, loc, lowPaddingInts);
|
||||
mlir::SmallVector<mlir::OpFoldResult> highPaddings =
|
||||
getAsOpFoldResult(b, loc, highPaddingInts);
|
||||
mlir::Value paddedInput = mlir::tensor::createPadScalarOp(
|
||||
rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
|
||||
/*packing=*/false, loc, b);
|
||||
|
||||
mlir::Value paddedInput = b.create<mlir::tensor::PadOp>(
|
||||
loc, rankedTensorType, input, lowPaddings, highPaddings, pad);
|
||||
|
||||
return paddedInput;
|
||||
}
|
||||
|
||||
@@ -1665,9 +1671,9 @@ struct FHELinalgConv2dToLinalgConv2d
|
||||
|
||||
mlir::Location loc = conv2dOp->getLoc();
|
||||
mlir::Value input =
|
||||
conv2dOp.input(); /* shape: Batch*Channels*Height*Width */
|
||||
conv2dOp.getInput(); /* shape: Batch*Channels*Height*Width */
|
||||
mlir::Value weight =
|
||||
conv2dOp.weight(); /* shape: Filters*Channels*Height*Width */
|
||||
conv2dOp.getWeight(); /* shape: Filters*Channels*Height*Width */
|
||||
|
||||
mlir::Type inputElementType =
|
||||
input.getType().cast<mlir::RankedTensorType>().getElementType();
|
||||
@@ -1714,7 +1720,7 @@ struct FHELinalgConv2dToLinalgConv2d
|
||||
// Since linalg doesn't support a bias in the conv operation, we initialize
|
||||
// the output tensor to the bias values, so that conv results get
|
||||
// accumulated to it
|
||||
mlir::Value bias = conv2dOp.bias(); /* optional of shape: Filters */
|
||||
mlir::Value bias = conv2dOp.getBias(); /* optional of shape: Filters */
|
||||
mlir::Value biasInitTensor;
|
||||
if (!bias) { // no bias was used
|
||||
biasInitTensor = initTensor;
|
||||
@@ -1726,7 +1732,8 @@ struct FHELinalgConv2dToLinalgConv2d
|
||||
mlir::AffineMap::get(resultRank, 0, rewriter.getAffineDimExpr(1),
|
||||
rewriter.getContext()),
|
||||
rewriter.getMultiDimIdentityMap(resultRank)};
|
||||
mlir::SmallVector<llvm::StringRef> iteratorTypes(resultRank, "parallel");
|
||||
mlir::SmallVector<mlir::utils::IteratorType> iteratorTypes(
|
||||
resultRank, mlir::utils::IteratorType::parallel);
|
||||
biasInitTensor =
|
||||
rewriter
|
||||
.create<mlir::linalg::GenericOp>(
|
||||
@@ -1817,14 +1824,15 @@ struct FHELinalgMaxpool2dToLinalgMaxpool2d
|
||||
output = rewriter.create<FHELinalg::SubEintIntOp>(loc, output, offset);
|
||||
}
|
||||
|
||||
const mlir::DenseElementsAttr kernelShapeAttr = maxpool2dOp.kernel_shape();
|
||||
const mlir::DenseElementsAttr kernelShapeAttr =
|
||||
maxpool2dOp.getKernelShape();
|
||||
const auto kernelShape =
|
||||
llvm::SmallVector<int64_t, 2>(kernelShapeAttr.value_begin<int64_t>(),
|
||||
kernelShapeAttr.value_end<int64_t>());
|
||||
|
||||
const mlir::Value kernel =
|
||||
rewriter
|
||||
.create<mlir::linalg::InitTensorOp>(
|
||||
.create<mlir::tensor::EmptyOp>(
|
||||
loc, kernelShape,
|
||||
mlir::IntegerType::get(this->getContext(), 64))
|
||||
.getResult();
|
||||
@@ -1833,12 +1841,12 @@ struct FHELinalgMaxpool2dToLinalgMaxpool2d
|
||||
rewriter.getI64VectorAttr({1, 1});
|
||||
|
||||
const mlir::DenseIntElementsAttr stridesAttr =
|
||||
maxpool2dOp.dilations().getValueOr(defaultAttr);
|
||||
maxpool2dOp.getDilations().value_or(defaultAttr);
|
||||
const mlir::DenseIntElementsAttr dilationsAttr =
|
||||
maxpool2dOp.dilations().getValueOr(defaultAttr);
|
||||
maxpool2dOp.getDilations().value_or(defaultAttr);
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::linalg::PoolingNchwMaxOp>(
|
||||
maxpool2dOp, outputTy, mlir::ValueRange{maxpool2dOp.input(), kernel},
|
||||
maxpool2dOp, outputTy, mlir::ValueRange{maxpool2dOp.getInput(), kernel},
|
||||
output, stridesAttr, dilationsAttr,
|
||||
llvm::ArrayRef<mlir::NamedAttribute>({maxOpAttr}));
|
||||
|
||||
@@ -1892,7 +1900,7 @@ struct FHELinalgToSignedToLinalgGeneric
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::RankedTensorType inputTy =
|
||||
op.input().getType().cast<mlir::RankedTensorType>();
|
||||
op.getInput().getType().cast<mlir::RankedTensorType>();
|
||||
mlir::RankedTensorType resultTy =
|
||||
op->getResult(0).getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
@@ -1906,8 +1914,8 @@ struct FHELinalgToSignedToLinalgGeneric
|
||||
this->getContext()),
|
||||
};
|
||||
|
||||
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
|
||||
"parallel");
|
||||
llvm::SmallVector<mlir::utils::IteratorType> iteratorTypes(
|
||||
resultTy.getShape().size(), mlir::utils::IteratorType::parallel);
|
||||
|
||||
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
@@ -1920,7 +1928,7 @@ struct FHELinalgToSignedToLinalgGeneric
|
||||
};
|
||||
|
||||
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
|
||||
llvm::SmallVector<mlir::Value, 1> ins{op.input()};
|
||||
llvm::SmallVector<mlir::Value, 1> ins{op.getInput()};
|
||||
llvm::SmallVector<mlir::Value, 1> outs{init};
|
||||
|
||||
llvm::StringRef doc{""};
|
||||
@@ -1981,7 +1989,7 @@ struct FHELinalgToUnsignedToLinalgGeneric
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::RankedTensorType inputTy =
|
||||
op.input().getType().cast<mlir::RankedTensorType>();
|
||||
op.getInput().getType().cast<mlir::RankedTensorType>();
|
||||
mlir::RankedTensorType resultTy =
|
||||
op->getResult(0).getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
@@ -1995,8 +2003,8 @@ struct FHELinalgToUnsignedToLinalgGeneric
|
||||
this->getContext()),
|
||||
};
|
||||
|
||||
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
|
||||
"parallel");
|
||||
llvm::SmallVector<mlir::utils::IteratorType> iteratorTypes(
|
||||
resultTy.getShape().size(), mlir::utils::IteratorType::parallel);
|
||||
|
||||
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
@@ -2009,7 +2017,7 @@ struct FHELinalgToUnsignedToLinalgGeneric
|
||||
};
|
||||
|
||||
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
|
||||
llvm::SmallVector<mlir::Value, 1> ins{op.input()};
|
||||
llvm::SmallVector<mlir::Value, 1> ins{op.getInput()};
|
||||
llvm::SmallVector<mlir::Value, 1> outs{init};
|
||||
|
||||
llvm::StringRef doc{""};
|
||||
@@ -2040,7 +2048,7 @@ void FHETensorOpsToLinalg::runOnOperation() {
|
||||
target.addLegalDialect<mlir::memref::MemRefDialect>();
|
||||
target.addLegalDialect<mlir::concretelang::FHE::FHEDialect>();
|
||||
target.addLegalDialect<mlir::tensor::TensorDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithDialect>();
|
||||
target.addIllegalOp<mlir::concretelang::FHELinalg::Dot>();
|
||||
target.addIllegalDialect<mlir::concretelang::FHELinalg::FHELinalgDialect>();
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// for license information.
|
||||
|
||||
#include <iostream>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/IR/Operation.h>
|
||||
@@ -195,8 +195,8 @@ struct AddEintIntOpPattern : public CrtOpPattern<FHE::AddEintIntOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = adaptor.a();
|
||||
mlir::Value intOperand = adaptor.b();
|
||||
mlir::Value eintOperand = adaptor.getA();
|
||||
mlir::Value intOperand = adaptor.getB();
|
||||
|
||||
// Write plaintext encoding
|
||||
mlir::Value encodedPlaintextTensor =
|
||||
@@ -242,8 +242,8 @@ struct SubIntEintOpPattern : public CrtOpPattern<FHE::SubIntEintOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value intOperand = adaptor.a();
|
||||
mlir::Value eintOperand = adaptor.b();
|
||||
mlir::Value intOperand = adaptor.getA();
|
||||
mlir::Value eintOperand = adaptor.getB();
|
||||
|
||||
// Write plaintext encoding
|
||||
mlir::Value encodedPlaintextTensor =
|
||||
@@ -289,8 +289,8 @@ struct SubEintIntOpPattern : public CrtOpPattern<FHE::SubEintIntOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = adaptor.a();
|
||||
mlir::Value intOperand = adaptor.b();
|
||||
mlir::Value eintOperand = adaptor.getA();
|
||||
mlir::Value intOperand = adaptor.getB();
|
||||
|
||||
// Write plaintext negation
|
||||
mlir::Type intType = intOperand.getType();
|
||||
@@ -346,8 +346,8 @@ struct AddEintOpPattern : CrtOpPattern<FHE::AddEintOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value lhsOperand = adaptor.a();
|
||||
mlir::Value rhsOperand = adaptor.b();
|
||||
mlir::Value lhsOperand = adaptor.getA();
|
||||
mlir::Value rhsOperand = adaptor.getB();
|
||||
|
||||
// Write add loop.
|
||||
mlir::Type ciphertextScalarType =
|
||||
@@ -389,8 +389,8 @@ struct SubEintOpPattern : CrtOpPattern<FHE::SubEintOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value lhsOperand = adaptor.a();
|
||||
mlir::Value rhsOperand = adaptor.b();
|
||||
mlir::Value lhsOperand = adaptor.getA();
|
||||
mlir::Value rhsOperand = adaptor.getB();
|
||||
|
||||
// Write sub loop.
|
||||
mlir::Type ciphertextScalarType =
|
||||
@@ -434,7 +434,7 @@ struct NegEintOpPattern : CrtOpPattern<FHE::NegEintOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value operand = adaptor.a();
|
||||
mlir::Value operand = adaptor.getA();
|
||||
|
||||
// Write the loop nest.
|
||||
mlir::Type ciphertextScalarType = converter->convertType(operand.getType())
|
||||
@@ -472,7 +472,7 @@ struct ToSignedOpPattern : public CrtOpPattern<FHE::ToSignedOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
rewriter.replaceOp(op, {adaptor.input()});
|
||||
rewriter.replaceOp(op, {adaptor.getInput()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -490,7 +490,7 @@ struct ToUnsignedOpPattern : public CrtOpPattern<FHE::ToUnsignedOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
rewriter.replaceOp(op, {adaptor.input()});
|
||||
rewriter.replaceOp(op, {adaptor.getInput()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -509,8 +509,8 @@ struct MulEintIntOpPattern : CrtOpPattern<FHE::MulEintIntOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = adaptor.a();
|
||||
mlir::Value intOperand = adaptor.b();
|
||||
mlir::Value eintOperand = adaptor.getA();
|
||||
mlir::Value intOperand = adaptor.getB();
|
||||
|
||||
// Write cleartext "encoding"
|
||||
mlir::Value encodedCleartext = rewriter.create<mlir::arith::ExtSIOp>(
|
||||
@@ -556,7 +556,8 @@ struct ApplyLookupTableEintOpPattern
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
|
||||
auto originalInputType = op.a().getType().cast<FHE::FheIntegerInterface>();
|
||||
auto originalInputType =
|
||||
op.getA().getType().cast<FHE::FheIntegerInterface>();
|
||||
|
||||
mlir::Value newLut =
|
||||
rewriter
|
||||
@@ -567,7 +568,7 @@ struct ApplyLookupTableEintOpPattern
|
||||
(int64_t)loweringParameters.nMods,
|
||||
(int64_t)loweringParameters.singleLutSize},
|
||||
rewriter.getI64Type()),
|
||||
adaptor.lut(),
|
||||
adaptor.getLut(),
|
||||
rewriter.getI64ArrayAttr(
|
||||
mlir::ArrayRef<int64_t>(loweringParameters.mods)),
|
||||
rewriter.getI64ArrayAttr(
|
||||
@@ -578,8 +579,9 @@ struct ApplyLookupTableEintOpPattern
|
||||
|
||||
// Replace the lut with an encoded / expanded one.
|
||||
auto wopPBS = rewriter.create<TFHE::WopPBSGLWEOp>(
|
||||
op.getLoc(), converter->convertType(op.getType()), adaptor.a(), newLut,
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, rewriter.getI64ArrayAttr({}));
|
||||
op.getLoc(), converter->convertType(op.getType()), adaptor.getA(),
|
||||
newLut, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
rewriter.getI64ArrayAttr({}));
|
||||
|
||||
rewriter.replaceOp(op, {wopPBS.getResult()});
|
||||
return ::mlir::success();
|
||||
@@ -601,25 +603,25 @@ struct TraceCiphertextOpPattern : CrtOpPattern<Tracing::TraceCiphertextOp> {
|
||||
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(op.ciphertext().getType())
|
||||
converter.convertType(op.getCiphertext().getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
|
||||
for (size_t i = 0; i < (loweringParameters.nMods - 1); ++i) {
|
||||
auto extractedCiphertext = rewriter.create<mlir::tensor::ExtractOp>(
|
||||
op.getLoc(), ciphertextScalarType, adaptor.ciphertext(),
|
||||
op.getLoc(), ciphertextScalarType, adaptor.getCiphertext(),
|
||||
mlir::ValueRange{rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(i))});
|
||||
rewriter.create<Tracing::TraceCiphertextOp>(
|
||||
op.getLoc(), extractedCiphertext, op.msgAttr(), op.nmsbAttr());
|
||||
op.getLoc(), extractedCiphertext, op.getMsgAttr(), op.getNmsbAttr());
|
||||
}
|
||||
|
||||
auto extractedCiphertext = rewriter.create<mlir::tensor::ExtractOp>(
|
||||
op.getLoc(), ciphertextScalarType, adaptor.ciphertext(),
|
||||
op.getLoc(), ciphertextScalarType, adaptor.getCiphertext(),
|
||||
mlir::ValueRange{rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(loweringParameters.nMods - 1))});
|
||||
rewriter.replaceOpWithNewOp<Tracing::TraceCiphertextOp>(
|
||||
op, extractedCiphertext, op.msgAttr(), op.nmsbAttr());
|
||||
op, extractedCiphertext, op.getMsgAttr(), op.getNmsbAttr());
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -813,7 +815,7 @@ struct TensorReassociationOpPattern : public CrtOpPattern<Op> {
|
||||
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
|
||||
auto reassocVal = (inRank ? adaptor.src() : op.result());
|
||||
auto reassocVal = (inRank ? adaptor.getSrc() : op.getResult());
|
||||
auto reassocTy = reassocVal.getType();
|
||||
auto newReassocType = converter->convertType(reassocTy);
|
||||
|
||||
@@ -827,7 +829,7 @@ struct TensorReassociationOpPattern : public CrtOpPattern<Op> {
|
||||
|
||||
auto newOp = rewriter.create<Op>(
|
||||
op.getLoc(), converter->convertType(op.getResult().getType()),
|
||||
adaptor.src(), newReassocs);
|
||||
adaptor.getSrc(), newReassocs);
|
||||
rewriter.replaceOp(op, {newOp});
|
||||
|
||||
return mlir::success();
|
||||
@@ -848,25 +850,22 @@ struct ExtractSliceOpPattern
|
||||
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
|
||||
mlir::SmallVector<mlir::Attribute> newStaticOffsets{
|
||||
op.static_offsets().template getAsRange<mlir::IntegerAttr>()};
|
||||
mlir::SmallVector<mlir::Attribute> newStaticSizes{
|
||||
op.static_sizes().template getAsRange<mlir::IntegerAttr>()};
|
||||
mlir::SmallVector<mlir::Attribute> newStaticStrides{
|
||||
op.static_strides().template getAsRange<mlir::IntegerAttr>()};
|
||||
newStaticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
newStaticSizes.push_back(
|
||||
rewriter.getI64IntegerAttr(this->loweringParameters.nMods));
|
||||
newStaticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
mlir::SmallVector<int64_t> newStaticOffsets{op.static_offsets()};
|
||||
mlir::SmallVector<int64_t> newStaticSizes{op.static_sizes()};
|
||||
mlir::SmallVector<int64_t> newStaticStrides{op.static_strides()};
|
||||
newStaticOffsets.push_back(0);
|
||||
newStaticSizes.push_back(this->loweringParameters.nMods);
|
||||
newStaticStrides.push_back(1);
|
||||
|
||||
mlir::RankedTensorType newType =
|
||||
converter->convertType(op.getResult().getType())
|
||||
.template cast<mlir::RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::ExtractSliceOp>(
|
||||
op, newType, adaptor.source(), adaptor.getOffsets(), adaptor.getSizes(),
|
||||
adaptor.getStrides(), rewriter.getArrayAttr(newStaticOffsets),
|
||||
rewriter.getArrayAttr(newStaticSizes),
|
||||
rewriter.getArrayAttr(newStaticStrides));
|
||||
op, newType, adaptor.getSource(), adaptor.getOffsets(),
|
||||
adaptor.getSizes(), adaptor.getStrides(),
|
||||
rewriter.getDenseI64ArrayAttr(newStaticOffsets),
|
||||
rewriter.getDenseI64ArrayAttr(newStaticSizes),
|
||||
rewriter.getDenseI64ArrayAttr(newStaticStrides));
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -885,26 +884,22 @@ struct InsertSliceOpPattern : public CrtOpPattern<mlir::tensor::InsertSliceOp> {
|
||||
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
|
||||
mlir::SmallVector<mlir::Attribute> newStaticOffsets{
|
||||
op.static_offsets().template getAsRange<mlir::IntegerAttr>()};
|
||||
mlir::SmallVector<mlir::Attribute> newStaticSizes{
|
||||
op.static_sizes().template getAsRange<mlir::IntegerAttr>()};
|
||||
mlir::SmallVector<mlir::Attribute> newStaticStrides{
|
||||
op.static_strides().template getAsRange<mlir::IntegerAttr>()};
|
||||
newStaticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
newStaticSizes.push_back(
|
||||
rewriter.getI64IntegerAttr(this->loweringParameters.nMods));
|
||||
newStaticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
mlir::SmallVector<int64_t> newStaticOffsets{op.static_offsets()};
|
||||
mlir::SmallVector<int64_t> newStaticSizes{op.static_sizes()};
|
||||
mlir::SmallVector<int64_t> newStaticStrides{op.static_strides()};
|
||||
newStaticOffsets.push_back(0);
|
||||
newStaticSizes.push_back(this->loweringParameters.nMods);
|
||||
newStaticStrides.push_back(1);
|
||||
|
||||
mlir::RankedTensorType newType =
|
||||
converter->convertType(op.getResult().getType())
|
||||
.template cast<mlir::RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
|
||||
op, newType, adaptor.source(), adaptor.dest(), adaptor.getOffsets(),
|
||||
adaptor.getSizes(), adaptor.getStrides(),
|
||||
rewriter.getArrayAttr(newStaticOffsets),
|
||||
rewriter.getArrayAttr(newStaticSizes),
|
||||
rewriter.getArrayAttr(newStaticStrides));
|
||||
op, newType, adaptor.getSource(), adaptor.getDest(),
|
||||
adaptor.getOffsets(), adaptor.getSizes(), adaptor.getStrides(),
|
||||
rewriter.getDenseI64ArrayAttr(newStaticOffsets),
|
||||
rewriter.getDenseI64ArrayAttr(newStaticSizes),
|
||||
rewriter.getDenseI64ArrayAttr(newStaticStrides));
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -926,7 +921,7 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
//------------------------------------------- Marking legal/illegal dialects
|
||||
target.addIllegalDialect<FHE::FHEDialect>();
|
||||
target.addLegalDialect<TFHE::TFHEDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithDialect>();
|
||||
target.addDynamicallyLegalOp<mlir::tensor::GenerateOp, mlir::scf::ForOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return (
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// for license information.
|
||||
|
||||
#include <iostream>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Linalg/IR/Linalg.h>
|
||||
@@ -146,12 +146,12 @@ struct AddEintIntOpPattern : public ScalarOpPattern<FHE::AddEintIntOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
// Write the plaintext encoding
|
||||
mlir::Value encodedInt = writePlaintextShiftEncoding(
|
||||
op.getLoc(), adaptor.b(),
|
||||
op.getLoc(), adaptor.getB(),
|
||||
op.getType().cast<FHE::FheIntegerInterface>().getWidth(), rewriter);
|
||||
|
||||
// Write the new op
|
||||
rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.a(),
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
|
||||
encodedInt);
|
||||
|
||||
return mlir::success();
|
||||
@@ -169,8 +169,8 @@ struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
|
||||
matchAndRewrite(FHE::SubEintIntOp op, FHE::SubEintIntOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = op.a();
|
||||
mlir::Value intOperand = op.b();
|
||||
mlir::Value eintOperand = op.getA();
|
||||
mlir::Value intOperand = op.getB();
|
||||
|
||||
// Write the integer negation
|
||||
mlir::Type intType = intOperand.getType();
|
||||
@@ -190,7 +190,7 @@ struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
|
||||
|
||||
// Write the new op
|
||||
rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.a(),
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
|
||||
encodedInt);
|
||||
|
||||
return mlir::success();
|
||||
@@ -209,13 +209,14 @@ struct SubIntEintOpPattern : public ScalarOpPattern<FHE::SubIntEintOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
// Write the plaintext encoding
|
||||
mlir::Value encodedInt = writePlaintextShiftEncoding(
|
||||
op.getLoc(), adaptor.a(),
|
||||
op.b().getType().cast<FHE::FheIntegerInterface>().getWidth(), rewriter);
|
||||
op.getLoc(), adaptor.getA(),
|
||||
op.getB().getType().cast<FHE::FheIntegerInterface>().getWidth(),
|
||||
rewriter);
|
||||
|
||||
// Write the new op
|
||||
rewriter.replaceOpWithNewOp<TFHE::SubGLWEIntOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), encodedInt,
|
||||
adaptor.b());
|
||||
adaptor.getB());
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -231,8 +232,8 @@ struct SubEintOpPattern : public ScalarOpPattern<FHE::SubEintOp> {
|
||||
matchAndRewrite(FHE::SubEintOp op, FHE::SubEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value lhsOperand = adaptor.a();
|
||||
mlir::Value rhsOperand = adaptor.b();
|
||||
mlir::Value lhsOperand = adaptor.getA();
|
||||
mlir::Value rhsOperand = adaptor.getB();
|
||||
|
||||
// Write rhs negation
|
||||
auto negative = rewriter.create<TFHE::NegGLWEOp>(
|
||||
@@ -259,8 +260,8 @@ struct MulEintIntOpPattern : public ScalarOpPattern<FHE::MulEintIntOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = adaptor.a();
|
||||
mlir::Value intOperand = adaptor.b();
|
||||
mlir::Value eintOperand = adaptor.getA();
|
||||
mlir::Value intOperand = adaptor.getB();
|
||||
|
||||
// Write the cleartext "encoding"
|
||||
mlir::Value castedCleartext = rewriter.create<mlir::arith::ExtSIOp>(
|
||||
@@ -286,7 +287,7 @@ struct ToSignedOpPattern : public ScalarOpPattern<FHE::ToSignedOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
typing::TypeConverter converter;
|
||||
rewriter.replaceOp(op, {adaptor.input()});
|
||||
rewriter.replaceOp(op, {adaptor.getInput()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -304,7 +305,7 @@ struct ToUnsignedOpPattern : public ScalarOpPattern<FHE::ToUnsignedOp> {
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
typing::TypeConverter converter;
|
||||
rewriter.replaceOp(op, {adaptor.input()});
|
||||
rewriter.replaceOp(op, {adaptor.getInput()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -326,7 +327,7 @@ struct ApplyLookupTableEintOpPattern
|
||||
FHE::ApplyLookupTableEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto inputType = op.a().getType().cast<FHE::FheIntegerInterface>();
|
||||
auto inputType = op.getA().getType().cast<FHE::FheIntegerInterface>();
|
||||
size_t outputBits =
|
||||
op.getResult().getType().cast<FHE::FheIntegerInterface>().getWidth();
|
||||
mlir::Value newLut =
|
||||
@@ -336,14 +337,14 @@ struct ApplyLookupTableEintOpPattern
|
||||
mlir::RankedTensorType::get(
|
||||
mlir::ArrayRef<int64_t>(loweringParameters.polynomialSize),
|
||||
rewriter.getI64Type()),
|
||||
op.lut(),
|
||||
op.getLut(),
|
||||
rewriter.getI32IntegerAttr(loweringParameters.polynomialSize),
|
||||
rewriter.getI32IntegerAttr(outputBits),
|
||||
rewriter.getBoolAttr(inputType.isSigned()))
|
||||
.getResult();
|
||||
|
||||
typing::TypeConverter converter;
|
||||
mlir::Value input = adaptor.a();
|
||||
mlir::Value input = adaptor.getA();
|
||||
|
||||
if (inputType.isSigned()) {
|
||||
// If the input is a signed integer, it comes to the bootstrap with a
|
||||
@@ -367,7 +368,7 @@ struct ApplyLookupTableEintOpPattern
|
||||
|
||||
// Insert keyswitch
|
||||
auto ksOp = rewriter.create<TFHE::KeySwitchGLWEOp>(
|
||||
op.getLoc(), getTypeConverter()->convertType(adaptor.a().getType()),
|
||||
op.getLoc(), getTypeConverter()->convertType(adaptor.getA().getType()),
|
||||
input, -1, -1);
|
||||
|
||||
// Insert bootstrap
|
||||
@@ -409,8 +410,8 @@ struct RoundEintOpPattern : public ScalarOpPattern<FHE::RoundEintOp> {
|
||||
// -> Subtract this one from the input by performing a
|
||||
// homomorphic subtraction.
|
||||
|
||||
mlir::Value input = adaptor.input();
|
||||
auto inputType = op.input().getType().cast<FHE::FheIntegerInterface>();
|
||||
mlir::Value input = adaptor.getInput();
|
||||
auto inputType = op.getInput().getType().cast<FHE::FheIntegerInterface>();
|
||||
mlir::Value output = op.getResult();
|
||||
uint64_t inputBitwidth = inputType.getWidth();
|
||||
uint64_t outputBitwidth =
|
||||
@@ -594,12 +595,12 @@ struct ToBoolOpPattern : public mlir::OpRewritePattern<FHE::ToBoolOp> {
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ToBoolOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto width = op.input()
|
||||
auto width = op.getInput()
|
||||
.getType()
|
||||
.dyn_cast<mlir::concretelang::FHE::EncryptedIntegerType>()
|
||||
.getWidth();
|
||||
if (width == mlir::concretelang::FHE::EncryptedBooleanType::getWidth()) {
|
||||
rewriter.replaceOp(op, op.input());
|
||||
rewriter.replaceOp(op, op.getInput());
|
||||
return mlir::success();
|
||||
}
|
||||
// TODO
|
||||
@@ -622,7 +623,7 @@ struct FromBoolOpPattern : public mlir::OpRewritePattern<FHE::FromBoolOp> {
|
||||
.dyn_cast<mlir::concretelang::FHE::EncryptedIntegerType>()
|
||||
.getWidth();
|
||||
if (width == mlir::concretelang::FHE::EncryptedBooleanType::getWidth()) {
|
||||
rewriter.replaceOp(op, op.input());
|
||||
rewriter.replaceOp(op, op.getInput());
|
||||
return mlir::success();
|
||||
}
|
||||
// TODO
|
||||
@@ -647,7 +648,7 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
//------------------------------------------- Marking legal/illegal dialects
|
||||
target.addIllegalDialect<FHE::FHEDialect>();
|
||||
target.addLegalDialect<TFHE::TFHEDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithDialect>();
|
||||
target.addDynamicallyLegalOp<mlir::linalg::GenericOp,
|
||||
mlir::tensor::GenerateOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
|
||||
@@ -43,7 +43,7 @@ public:
|
||||
if (((mlir::LogicalResult)loops).failed() || loops->size() == 0)
|
||||
return mlir::failure();
|
||||
|
||||
rewriter.replaceOp(linalgOp, loops.getValue()[0]->getResult(0));
|
||||
rewriter.replaceOp(linalgOp, loops.value()[0]->getResult(0));
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
@@ -5,11 +5,13 @@
|
||||
|
||||
#include <iostream>
|
||||
#include <mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h>
|
||||
#include <mlir/IR/BuiltinTypeInterfaces.h>
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
|
||||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
|
||||
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
|
||||
@@ -21,6 +23,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
@@ -40,7 +43,7 @@ struct MLIRLowerableDialectsToLLVMPass
|
||||
void runOnOperation() final;
|
||||
|
||||
/// Convert types to the LLVM dialect-compatible type
|
||||
static llvm::Optional<mlir::Type> convertTypes(mlir::Type type);
|
||||
static std::optional<mlir::Type> convertTypes(mlir::Type type);
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@@ -73,11 +76,12 @@ struct Memref1DCopyOpPattern
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::memref::CopyOp copyOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (copyOp.source().getType().cast<mlir::MemRefType>().getRank() != 1 ||
|
||||
copyOp.source().getType().cast<mlir::MemRefType>().getRank() != 1) {
|
||||
if (copyOp.getSource().getType().cast<mlir::MemRefType>().getRank() != 1 ||
|
||||
copyOp.getSource().getType().cast<mlir::MemRefType>().getRank() != 1) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto opType = mlir::MemRefType::get({-1}, rewriter.getI64Type());
|
||||
auto opType = mlir::MemRefType::get({mlir::ShapedType::kDynamic},
|
||||
rewriter.getI64Type());
|
||||
// Insert forward declaration of the add_lwe_ciphertexts function
|
||||
{
|
||||
if (insertForwardDeclaration(
|
||||
@@ -89,9 +93,9 @@ struct Memref1DCopyOpPattern
|
||||
}
|
||||
}
|
||||
auto sourceOp = rewriter.create<mlir::memref::CastOp>(
|
||||
copyOp.getLoc(), opType, copyOp.source());
|
||||
copyOp.getLoc(), opType, copyOp.getSource());
|
||||
auto targetOp = rewriter.create<mlir::memref::CastOp>(
|
||||
copyOp.getLoc(), opType, copyOp.target());
|
||||
copyOp.getLoc(), opType, copyOp.getTarget());
|
||||
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
|
||||
copyOp, "memref_copy_one_rank", mlir::TypeRange{},
|
||||
mlir::ValueRange{sourceOp, targetOp});
|
||||
@@ -119,9 +123,10 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
|
||||
mlir::concretelang::populateRTToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
mlir::populateFuncToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
mlir::populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::memref::populateExpandStridedMetadataPatterns(patterns);
|
||||
mlir::populateAffineToStdConversionPatterns(patterns);
|
||||
mlir::populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::populateSCFToControlFlowConversionPatterns(patterns);
|
||||
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
@@ -143,7 +148,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Optional<mlir::Type>
|
||||
std::optional<mlir::Type>
|
||||
MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
|
||||
if (type.isa<mlir::concretelang::Concrete::ContextType>() ||
|
||||
type.isa<mlir::concretelang::RT::FutureType>() ||
|
||||
@@ -161,7 +166,7 @@ MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
|
||||
mlir::Type convertedSubtype = typeConverter.convertType(subtype);
|
||||
return mlir::LLVM::LLVMPointerType::get(convertedSubtype);
|
||||
}
|
||||
return llvm::None;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@@ -47,7 +47,7 @@ char stream_emulator_put_uint64[] = "stream_emulator_put_uint64";
|
||||
char stream_emulator_get_uint64[] = "stream_emulator_get_uint64";
|
||||
|
||||
mlir::Type getDynamicTensor(mlir::OpBuilder &rewriter, size_t rank) {
|
||||
std::vector<int64_t> shape(rank, -1);
|
||||
std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
|
||||
return mlir::RankedTensorType::get(shape, rewriter.getI64Type());
|
||||
}
|
||||
|
||||
@@ -166,7 +166,7 @@ struct LowerSDFGMakeProcess
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
const char *funcName;
|
||||
mlir::SmallVector<mlir::Value> operands(mpOp->getOperands());
|
||||
switch (mpOp.type()) {
|
||||
switch (mpOp.getType()) {
|
||||
case SDFG::ProcessKind::add_eint:
|
||||
funcName = stream_emulator_make_memref_add_lwe_ciphertexts_u64_process;
|
||||
break;
|
||||
@@ -249,7 +249,7 @@ struct LowerSDFGMakeStream
|
||||
const char *funcName;
|
||||
|
||||
stream_type t;
|
||||
switch (msOp.type()) {
|
||||
switch (msOp.getType()) {
|
||||
case SDFG::StreamKind::host_to_device:
|
||||
t = TS_STREAM_TYPE_X86_TO_TOPO_LSAP;
|
||||
break;
|
||||
@@ -370,7 +370,7 @@ void SDFGToStreamEmulatorPass::runOnOperation() {
|
||||
SDFG::MakeProcess, SDFG::MakeStream, SDFG::Put>();
|
||||
// All Concrete ops are legal after the conversion
|
||||
target.addLegalDialect<mlir::concretelang::Concrete::ConcreteDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithDialect>();
|
||||
target.addLegalOp<mlir::func::ReturnOp, mlir::func::FuncOp,
|
||||
mlir::func::CallOp, SDFG::Get, mlir::tensor::CastOp>();
|
||||
|
||||
|
||||
@@ -114,15 +114,17 @@ struct KeySwitchGLWEOpPattern
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TFHE::KeySwitchGLWEOp ksOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto inputTy = ksOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
|
||||
auto newInputTy = converter.convertType(inputTy);
|
||||
auto outputTy = ksOp.result().getType().cast<TFHE::GLWECipherTextType>();
|
||||
auto inputTy =
|
||||
ksOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();
|
||||
auto newInputTy = converter.convertType(inputTy)
|
||||
.cast<mlir::concretelang::TFHE::GLWECipherTextType>();
|
||||
auto outputTy = ksOp.getResult().getType().cast<TFHE::GLWECipherTextType>();
|
||||
auto newOutputTy = converter.glweIntraPBSType(outputTy);
|
||||
auto newOp = rewriter.replaceOpWithNewOp<TFHE::KeySwitchGLWEOp>(
|
||||
ksOp, newOutputTy, ksOp.ciphertext(), cryptoParameters.ksLevel,
|
||||
ksOp, newOutputTy, ksOp.getCiphertext(), cryptoParameters.ksLevel,
|
||||
cryptoParameters.ksLogBase);
|
||||
rewriter.startRootUpdate(newOp);
|
||||
newOp.ciphertext().setType(newInputTy);
|
||||
newOp.getCiphertext().setType(newInputTy);
|
||||
rewriter.finalizeRootUpdate(newOp);
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -145,16 +147,17 @@ struct BootstrapGLWEOpPattern
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TFHE::BootstrapGLWEOp bsOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto inputTy = bsOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
|
||||
auto inputTy =
|
||||
bsOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();
|
||||
auto newInputTy = converter.glweIntraPBSType(inputTy);
|
||||
auto outputTy = bsOp.result().getType().cast<TFHE::GLWECipherTextType>();
|
||||
auto outputTy = bsOp.getResult().getType().cast<TFHE::GLWECipherTextType>();
|
||||
auto newOutputTy = converter.convertType(outputTy);
|
||||
auto newOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
|
||||
bsOp, newOutputTy, bsOp.ciphertext(), bsOp.lookup_table(),
|
||||
bsOp, newOutputTy, bsOp.getCiphertext(), bsOp.getLookupTable(),
|
||||
cryptoParameters.brLevel, cryptoParameters.brLogBase,
|
||||
cryptoParameters.getPolynomialSize(), cryptoParameters.glweDimension);
|
||||
rewriter.startRootUpdate(newOp);
|
||||
newOp.ciphertext().setType(newInputTy);
|
||||
newOp.getCiphertext().setType(newInputTy);
|
||||
rewriter.finalizeRootUpdate(newOp);
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -177,8 +180,8 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
|
||||
matchAndRewrite(TFHE::WopPBSGLWEOp wopPBSOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto newOp = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
|
||||
wopPBSOp, converter.convertType(wopPBSOp.result().getType()),
|
||||
wopPBSOp.ciphertexts(), wopPBSOp.lookupTable(),
|
||||
wopPBSOp, converter.convertType(wopPBSOp.getResult().getType()),
|
||||
wopPBSOp.getCiphertexts(), wopPBSOp.getLookupTable(),
|
||||
// Bootstrap parameters
|
||||
cryptoParameters.brLevel, cryptoParameters.brLogBase,
|
||||
// Keyswitch parameters
|
||||
@@ -198,12 +201,12 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
|
||||
cryptoParameters.largeInteger->crtDecomposition));
|
||||
rewriter.startRootUpdate(newOp);
|
||||
auto ctType =
|
||||
wopPBSOp.ciphertexts().getType().cast<mlir::RankedTensorType>();
|
||||
wopPBSOp.getCiphertexts().getType().cast<mlir::RankedTensorType>();
|
||||
auto ciphertextType =
|
||||
ctType.getElementType().cast<TFHE::GLWECipherTextType>();
|
||||
auto newType = mlir::RankedTensorType::get(
|
||||
ctType.getShape(), converter.glweInterPBSType(ciphertextType));
|
||||
newOp.ciphertexts().setType(newType);
|
||||
newOp.getCiphertexts().setType(newType);
|
||||
rewriter.finalizeRootUpdate(newOp);
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -279,7 +282,8 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
cryptoParameters);
|
||||
target.addDynamicallyLegalOp<TFHE::KeySwitchGLWEOp>(
|
||||
[&](TFHE::KeySwitchGLWEOp op) {
|
||||
return op.level() != (uint32_t)-1 && op.baseLog() != (uint32_t)-1;
|
||||
return op.getLevel() != (uint32_t)-1 &&
|
||||
op.getBaseLog() != (uint32_t)-1;
|
||||
});
|
||||
patterns.add<BootstrapGLWEOpPattern>(&getContext(), converter,
|
||||
cryptoParameters);
|
||||
|
||||
@@ -96,11 +96,11 @@ struct SubIntGLWEOpPattern
|
||||
matchAndRewrite(TFHE::SubGLWEIntOp subOp, TFHE::SubGLWEIntOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::Value negated = rewriter.create<Concrete::NegateLweTensorOp>(
|
||||
subOp.getLoc(), adaptor.b().getType(), adaptor.b());
|
||||
subOp.getLoc(), adaptor.getB().getType(), adaptor.getB());
|
||||
|
||||
rewriter.replaceOpWithNewOp<Concrete::AddPlaintextLweTensorOp>(
|
||||
subOp, this->getTypeConverter()->convertType(subOp.getType()), negated,
|
||||
subOp.a());
|
||||
subOp.getA());
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -119,17 +119,16 @@ struct BootstrapGLWEOpPattern
|
||||
matchAndRewrite(TFHE::BootstrapGLWEOp bsOp,
|
||||
TFHE::BootstrapGLWEOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
TFHE::GLWECipherTextType resultType =
|
||||
bsOp.getType().cast<TFHE::GLWECipherTextType>();
|
||||
TFHE::GLWECipherTextType inputType =
|
||||
bsOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
|
||||
bsOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();
|
||||
|
||||
rewriter.replaceOpWithNewOp<Concrete::BootstrapLweTensorOp>(
|
||||
bsOp, this->getTypeConverter()->convertType(resultType),
|
||||
adaptor.ciphertext(), adaptor.lookup_table(), inputType.getDimension(),
|
||||
adaptor.polySize(), adaptor.level(), adaptor.baseLog(),
|
||||
adaptor.glweDimension(), resultType.getP());
|
||||
adaptor.getCiphertext(), adaptor.getLookupTable(),
|
||||
inputType.getDimension(), adaptor.getPolySize(), adaptor.getLevel(),
|
||||
adaptor.getBaseLog(), adaptor.getGlweDimension(), resultType.getP());
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -152,11 +151,11 @@ struct KeySwitchGLWEOpPattern
|
||||
TFHE::GLWECipherTextType resultType =
|
||||
ksOp.getType().cast<TFHE::GLWECipherTextType>();
|
||||
TFHE::GLWECipherTextType inputType =
|
||||
ksOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
|
||||
ksOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();
|
||||
|
||||
rewriter.replaceOpWithNewOp<Concrete::KeySwitchLweTensorOp>(
|
||||
ksOp, this->getTypeConverter()->convertType(resultType),
|
||||
adaptor.ciphertext(), adaptor.level(), adaptor.baseLog(),
|
||||
adaptor.getCiphertext(), adaptor.getLevel(), adaptor.getBaseLog(),
|
||||
inputType.getDimension(), resultType.getDimension());
|
||||
|
||||
return mlir::success();
|
||||
@@ -174,15 +173,15 @@ struct TracePlaintextOpPattern
|
||||
matchAndRewrite(Tracing::TracePlaintextOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto inputWidth =
|
||||
op.plaintext().getType().cast<mlir::IntegerType>().getWidth();
|
||||
op.getPlaintext().getType().cast<mlir::IntegerType>().getWidth();
|
||||
if (inputWidth == 64) {
|
||||
op->setAttr("input_width", rewriter.getI64IntegerAttr(inputWidth));
|
||||
return mlir::success();
|
||||
}
|
||||
auto extendedInput = rewriter.create<mlir::arith::ExtUIOp>(
|
||||
op.getLoc(), rewriter.getI64Type(), op.plaintext());
|
||||
op.getLoc(), rewriter.getI64Type(), op.getPlaintext());
|
||||
auto newOp = rewriter.replaceOpWithNewOp<Tracing::TracePlaintextOp>(
|
||||
op, extendedInput, op.msgAttr(), op.nmsbAttr());
|
||||
op, extendedInput, op.getMsgAttr(), op.getNmsbAttr());
|
||||
newOp->setAttr("input_width", rewriter.getI64IntegerAttr(inputWidth));
|
||||
return ::mlir::success();
|
||||
}
|
||||
@@ -235,37 +234,36 @@ struct ExtractSliceOpPattern
|
||||
if (this->getTypeConverter()->isLegal(extractSliceOp.getType())) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto resultTy = extractSliceOp.result().getType();
|
||||
auto resultTy = extractSliceOp.getResult().getType();
|
||||
auto newResultTy = this->getTypeConverter()
|
||||
->convertType(resultTy)
|
||||
.cast<mlir::RankedTensorType>();
|
||||
|
||||
// add 0 to the static_offsets
|
||||
mlir::SmallVector<mlir::Attribute> staticOffsets;
|
||||
staticOffsets.append(adaptor.static_offsets().begin(),
|
||||
adaptor.static_offsets().end());
|
||||
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
mlir::SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.append(adaptor.getStaticOffsets().begin(),
|
||||
adaptor.getStaticOffsets().end());
|
||||
staticOffsets.push_back(0);
|
||||
|
||||
// add the lweSize to the sizes
|
||||
mlir::SmallVector<mlir::Attribute> staticSizes;
|
||||
staticSizes.append(adaptor.static_sizes().begin(),
|
||||
adaptor.static_sizes().end());
|
||||
staticSizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 1)));
|
||||
mlir::SmallVector<int64_t> staticSizes;
|
||||
staticSizes.append(adaptor.getStaticSizes().begin(),
|
||||
adaptor.getStaticSizes().end());
|
||||
staticSizes.push_back(newResultTy.getDimSize(newResultTy.getRank() - 1));
|
||||
|
||||
// add 1 to the strides
|
||||
mlir::SmallVector<mlir::Attribute> staticStrides;
|
||||
staticStrides.append(adaptor.static_strides().begin(),
|
||||
adaptor.static_strides().end());
|
||||
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
mlir::SmallVector<int64_t> staticStrides;
|
||||
staticStrides.append(adaptor.getStaticStrides().begin(),
|
||||
adaptor.getStaticStrides().end());
|
||||
staticStrides.push_back(1);
|
||||
|
||||
// replace tensor.extract_slice to the new one
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::ExtractSliceOp>(
|
||||
extractSliceOp, newResultTy, adaptor.source(), adaptor.offsets(),
|
||||
adaptor.sizes(), adaptor.strides(),
|
||||
rewriter.getArrayAttr(staticOffsets),
|
||||
rewriter.getArrayAttr(staticSizes),
|
||||
rewriter.getArrayAttr(staticStrides));
|
||||
extractSliceOp, newResultTy, adaptor.getSource(), adaptor.getOffsets(),
|
||||
adaptor.getSizes(), adaptor.getStrides(),
|
||||
rewriter.getDenseI64ArrayAttr(staticOffsets),
|
||||
rewriter.getDenseI64ArrayAttr(staticSizes),
|
||||
rewriter.getDenseI64ArrayAttr(staticStrides));
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
@@ -294,31 +292,28 @@ struct ExtractOpPattern
|
||||
->convertType(extractOp.getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
auto tensorRank =
|
||||
adaptor.tensor().getType().cast<mlir::RankedTensorType>().getRank();
|
||||
adaptor.getTensor().getType().cast<mlir::RankedTensorType>().getRank();
|
||||
|
||||
// [min..., 0] for static_offsets ()
|
||||
mlir::SmallVector<mlir::Attribute> staticOffsets(
|
||||
tensorRank,
|
||||
rewriter.getI64IntegerAttr(std::numeric_limits<int64_t>::min()));
|
||||
staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0);
|
||||
mlir::SmallVector<int64_t> staticOffsets(
|
||||
tensorRank, std::numeric_limits<int64_t>::min());
|
||||
staticOffsets[staticOffsets.size() - 1] = 0;
|
||||
|
||||
// [1..., lweDimension+1] for static_sizes or
|
||||
// [1..., nbBlock, lweDimension+1]
|
||||
mlir::SmallVector<mlir::Attribute> staticSizes(
|
||||
tensorRank, rewriter.getI64IntegerAttr(1));
|
||||
staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr(
|
||||
newResultType.getDimSize(newResultType.getRank() - 1));
|
||||
mlir::SmallVector<int64_t> staticSizes(tensorRank, 1);
|
||||
staticSizes[staticSizes.size() - 1] =
|
||||
newResultType.getDimSize(newResultType.getRank() - 1);
|
||||
|
||||
// [1...] for static_strides
|
||||
mlir::SmallVector<mlir::Attribute> staticStrides(
|
||||
tensorRank, rewriter.getI64IntegerAttr(1));
|
||||
mlir::SmallVector<int64_t> staticStrides(tensorRank, 1);
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::ExtractSliceOp>(
|
||||
extractOp, newResultType, adaptor.tensor(), adaptor.indices(),
|
||||
extractOp, newResultType, adaptor.getTensor(), adaptor.getIndices(),
|
||||
mlir::SmallVector<mlir::Value>{}, mlir::SmallVector<mlir::Value>{},
|
||||
rewriter.getArrayAttr(staticOffsets),
|
||||
rewriter.getArrayAttr(staticSizes),
|
||||
rewriter.getArrayAttr(staticStrides));
|
||||
rewriter.getDenseI64ArrayAttr(staticOffsets),
|
||||
rewriter.getDenseI64ArrayAttr(staticSizes),
|
||||
rewriter.getDenseI64ArrayAttr(staticStrides));
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
@@ -344,35 +339,34 @@ struct InsertSliceOpPattern
|
||||
}
|
||||
|
||||
auto newResultTy = this->getTypeConverter()
|
||||
->convertType(insertSliceOp.result().getType())
|
||||
->convertType(insertSliceOp.getResult().getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
|
||||
// add 0 to static_offsets
|
||||
mlir::SmallVector<mlir::Attribute> staticOffsets;
|
||||
staticOffsets.append(adaptor.static_offsets().begin(),
|
||||
adaptor.static_offsets().end());
|
||||
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
// add 0 to static offsets
|
||||
mlir::SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.append(adaptor.getStaticOffsets().begin(),
|
||||
adaptor.getStaticOffsets().end());
|
||||
staticOffsets.push_back(0);
|
||||
|
||||
// add lweDimension+1 to static_sizes
|
||||
mlir::SmallVector<mlir::Attribute> staticSizes;
|
||||
staticSizes.append(adaptor.static_sizes().begin(),
|
||||
adaptor.static_sizes().end());
|
||||
staticSizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 1)));
|
||||
mlir::SmallVector<int64_t> staticSizes;
|
||||
staticSizes.append(adaptor.getStaticSizes().begin(),
|
||||
adaptor.getStaticSizes().end());
|
||||
staticSizes.push_back(newResultTy.getDimSize(newResultTy.getRank() - 1));
|
||||
|
||||
// add 1 to the strides
|
||||
mlir::SmallVector<mlir::Attribute> staticStrides;
|
||||
staticStrides.append(adaptor.static_strides().begin(),
|
||||
adaptor.static_strides().end());
|
||||
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
mlir::SmallVector<int64_t> staticStrides;
|
||||
staticStrides.append(adaptor.getStaticStrides().begin(),
|
||||
adaptor.getStaticStrides().end());
|
||||
staticStrides.push_back(1);
|
||||
|
||||
// replace tensor.insert_slice with the new one
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
|
||||
insertSliceOp, newResultTy, adaptor.source(), adaptor.dest(),
|
||||
adaptor.offsets(), adaptor.sizes(), adaptor.strides(),
|
||||
rewriter.getArrayAttr(staticOffsets),
|
||||
rewriter.getArrayAttr(staticSizes),
|
||||
rewriter.getArrayAttr(staticStrides));
|
||||
insertSliceOp, newResultTy, adaptor.getSource(), adaptor.getDest(),
|
||||
adaptor.getOffsets(), adaptor.getSizes(), adaptor.getStrides(),
|
||||
rewriter.getDenseI64ArrayAttr(staticOffsets),
|
||||
rewriter.getDenseI64ArrayAttr(staticSizes),
|
||||
rewriter.getDenseI64ArrayAttr(staticStrides));
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
@@ -399,18 +393,18 @@ struct InsertOpPattern
|
||||
|
||||
mlir::RankedTensorType newResultTy =
|
||||
this->getTypeConverter()
|
||||
->convertType(insertOp.result().getType())
|
||||
->convertType(insertOp.getResult().getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
|
||||
// add zeros to static_offsets
|
||||
// add zeros to static offsets
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets;
|
||||
offsets.append(adaptor.indices().begin(), adaptor.indices().end());
|
||||
offsets.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
|
||||
// Inserting a smaller tensor into a (potentially) bigger one. Set
|
||||
// dimensions for all leading dimensions of the target tensor not
|
||||
// present in the source to 1.
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes(adaptor.indices().size(),
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes(adaptor.getIndices().size(),
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
|
||||
// Add size for the bufferized source element
|
||||
@@ -423,7 +417,8 @@ struct InsertOpPattern
|
||||
|
||||
// replace tensor.insert_slice with the new one
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
|
||||
insertOp, adaptor.scalar(), adaptor.dest(), offsets, sizes, strides);
|
||||
insertOp, adaptor.getScalar(), adaptor.getDest(), offsets, sizes,
|
||||
strides);
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
@@ -453,7 +448,7 @@ struct FromElementsOpPattern
|
||||
|
||||
auto converter = this->getTypeConverter();
|
||||
|
||||
auto resultTy = fromElementsOp.result().getType();
|
||||
auto resultTy = fromElementsOp.getResult().getType();
|
||||
if (converter->isLegal(resultTy)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
@@ -483,7 +478,7 @@ struct FromElementsOpPattern
|
||||
llvm::SmallVector<int64_t> currentOffsets(newRank, 0);
|
||||
|
||||
// for each elements insert_slice with right offet
|
||||
for (auto elt : llvm::enumerate(adaptor.elements())) {
|
||||
for (auto elt : llvm::enumerate(adaptor.getElements())) {
|
||||
// Just create offsets as attributes
|
||||
llvm::SmallVector<mlir::OpFoldResult, 4> offsets;
|
||||
offsets.reserve(currentOffsets.size());
|
||||
@@ -560,7 +555,7 @@ struct TensorShapeOpPattern : public mlir::OpConversionPattern<ShapeOp> {
|
||||
|
||||
auto reassocTy =
|
||||
((mlir::Type)this->getTypeConverter()->convertType(
|
||||
(inRank ? shapeOp.src() : shapeOp.result()).getType()))
|
||||
(inRank ? shapeOp.getSrc() : shapeOp.getResult()).getType()))
|
||||
.cast<VecTy>();
|
||||
|
||||
auto oldReassocs = shapeOp.getReassociationIndices();
|
||||
@@ -574,7 +569,7 @@ struct TensorShapeOpPattern : public mlir::OpConversionPattern<ShapeOp> {
|
||||
newReassocs.push_back(lweAssoc);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<ShapeOp>(shapeOp, newResultTy, adaptor.src(),
|
||||
rewriter.replaceOpWithNewOp<ShapeOp>(shapeOp, newResultTy, adaptor.getSrc(),
|
||||
newReassocs);
|
||||
|
||||
return ::mlir::success();
|
||||
@@ -729,8 +724,9 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
target.addLegalOp<mlir::arith::ExtUIOp>();
|
||||
target.addDynamicallyLegalOp<Tracing::TracePlaintextOp>(
|
||||
[&](Tracing::TracePlaintextOp op) {
|
||||
return (op.plaintext().getType().cast<mlir::IntegerType>().getWidth() ==
|
||||
64);
|
||||
return (
|
||||
op.getPlaintext().getType().cast<mlir::IntegerType>().getWidth() ==
|
||||
64);
|
||||
});
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir/IR/BuiltinTypeInterfaces.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
|
||||
@@ -24,7 +25,7 @@ char memref_trace_message[] = "memref_trace_message";
|
||||
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
size_t rank) {
|
||||
std::vector<int64_t> shape(rank, -1);
|
||||
std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
|
||||
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
|
||||
for (size_t i = 0; i < rank; i++) {
|
||||
expr = expr +
|
||||
@@ -136,15 +137,15 @@ private:
|
||||
void traceCiphertextAddOperands(Tracing::TraceCiphertextOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
mlir::RewriterBase &rewriter) {
|
||||
auto msg = op.msg().getValueOr("");
|
||||
auto nmsb = op.nmsb().getValueOr(0);
|
||||
auto msg = op.getMsg().value_or("");
|
||||
auto nmsb = op.getNmsb().value_or(0);
|
||||
std::string msgName;
|
||||
std::stringstream stream;
|
||||
stream << rand();
|
||||
stream >> msgName;
|
||||
auto messageVal =
|
||||
mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg,
|
||||
mlir::LLVM::linkage::Linkage::Linkonce);
|
||||
auto messageVal = mlir::LLVM::createGlobalString(
|
||||
op.getLoc(), rewriter, msgName, msg,
|
||||
mlir::LLVM::linkage::Linkage::Linkonce, false);
|
||||
operands.push_back(messageVal);
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(msg.size())));
|
||||
@@ -155,15 +156,15 @@ void traceCiphertextAddOperands(Tracing::TraceCiphertextOp op,
|
||||
void tracePlaintextAddOperands(Tracing::TracePlaintextOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
mlir::RewriterBase &rewriter) {
|
||||
auto msg = op.msg().getValueOr("");
|
||||
auto nmsb = op.nmsb().getValueOr(0);
|
||||
auto msg = op.getMsg().value_or("");
|
||||
auto nmsb = op.getNmsb().value_or(0);
|
||||
std::string msgName;
|
||||
std::stringstream stream;
|
||||
stream << rand();
|
||||
stream >> msgName;
|
||||
auto messageVal =
|
||||
mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg,
|
||||
mlir::LLVM::linkage::Linkage::Linkonce);
|
||||
auto messageVal = mlir::LLVM::createGlobalString(
|
||||
op.getLoc(), rewriter, msgName, msg,
|
||||
mlir::LLVM::linkage::Linkage::Linkonce, false);
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op->getAttr("input_width")));
|
||||
operands.push_back(messageVal);
|
||||
@@ -176,14 +177,14 @@ void tracePlaintextAddOperands(Tracing::TracePlaintextOp op,
|
||||
void traceMessageAddOperands(Tracing::TraceMessageOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
mlir::RewriterBase &rewriter) {
|
||||
auto msg = op.msg().getValueOr("");
|
||||
auto msg = op.getMsg().value_or("");
|
||||
std::string msgName;
|
||||
std::stringstream stream;
|
||||
stream << rand();
|
||||
stream >> msgName;
|
||||
auto messageVal =
|
||||
mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg,
|
||||
mlir::LLVM::linkage::Linkage::Linkonce);
|
||||
auto messageVal = mlir::LLVM::createGlobalString(
|
||||
op.getLoc(), rewriter, msgName, msg,
|
||||
mlir::LLVM::linkage::Linkage::Linkonce, false);
|
||||
operands.push_back(messageVal);
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(msg.size())));
|
||||
@@ -202,7 +203,7 @@ struct TracingToCAPIPass : public TracingToCAPIBase<TracingToCAPIPass> {
|
||||
// Mark ops from the target dialect as legal operations
|
||||
target.addLegalDialect<func::FuncDialect>();
|
||||
target.addLegalDialect<memref::MemRefDialect>();
|
||||
target.addLegalDialect<arith::ArithmeticDialect>();
|
||||
target.addLegalDialect<arith::ArithDialect>();
|
||||
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
|
||||
|
||||
// Make sure that no ops from `Tracing` remain after the lowering
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
#include "concretelang/Conversion/Utils/Dialects/SCF.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
@@ -46,14 +46,14 @@ struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel<
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
return BufferRelation::Unknown;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
@@ -63,7 +63,7 @@ struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel<
|
||||
auto castOp = cast<TensorOp>(op);
|
||||
|
||||
auto resTensorType =
|
||||
castOp.result().getType().template cast<mlir::TensorType>();
|
||||
castOp.getResult().getType().template cast<mlir::TensorType>();
|
||||
|
||||
auto outMemrefType = MemRefType::get(resTensorType.getShape(),
|
||||
resTensorType.getElementType());
|
||||
@@ -81,7 +81,7 @@ struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel<
|
||||
operands.push_back(operand.get());
|
||||
} else {
|
||||
operands.push_back(
|
||||
bufferization::getBuffer(rewriter, operand.get(), options));
|
||||
*bufferization::getBuffer(rewriter, operand.get(), options));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ add_mlir_dialect_library(
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
ConcretelangConversion
|
||||
MLIRArithmeticDialect
|
||||
MLIRArithDialect
|
||||
MLIRBufferizationDialect
|
||||
MLIRBufferizationTransforms
|
||||
MLIRIR
|
||||
|
||||
@@ -6,11 +6,12 @@
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <initializer_list>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
@@ -65,7 +66,7 @@ struct FunctionToDag {
|
||||
mlir::concretelang::log_verbose() << MSG << "\n"; \
|
||||
}
|
||||
|
||||
outcome::checked<llvm::Optional<optimizer::Dag>,
|
||||
outcome::checked<std::optional<optimizer::Dag>,
|
||||
::concretelang::error::StringError>
|
||||
build() {
|
||||
auto dag = concrete_optimizer::dag::empty();
|
||||
@@ -88,7 +89,7 @@ struct FunctionToDag {
|
||||
// Dag is empty <=> classical function without encryption
|
||||
DEBUG("!!! concrete-optimizer: nothing to do in " << func.getName()
|
||||
<< "\n");
|
||||
return llvm::None;
|
||||
return std::nullopt;
|
||||
};
|
||||
DEBUG(std::string(dag->dump()));
|
||||
return std::move(dag);
|
||||
@@ -143,7 +144,7 @@ struct FunctionToDag {
|
||||
if (auto dot = asDot(op)) {
|
||||
auto weightsOpt = dotWeights(dot);
|
||||
if (weightsOpt) {
|
||||
addDot(dag, val, encrypted_inputs, weightsOpt.getValue());
|
||||
addDot(dag, val, encrypted_inputs, weightsOpt.value());
|
||||
return;
|
||||
}
|
||||
// If can't find weights return default leveled op
|
||||
@@ -231,8 +232,8 @@ struct FunctionToDag {
|
||||
mlir::Value result = mulOp.getResult();
|
||||
const std::vector<uint64_t> resultShape = getShape(result);
|
||||
|
||||
Operation *xOp = mulOp.a().getDefiningOp();
|
||||
Operation *yOp = mulOp.b().getDefiningOp();
|
||||
Operation *xOp = mulOp.getA().getDefiningOp();
|
||||
Operation *yOp = mulOp.getB().getDefiningOp();
|
||||
|
||||
const double fixedCost = NEGLIGIBLE_COMPLEXITY;
|
||||
const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;
|
||||
@@ -292,8 +293,8 @@ struct FunctionToDag {
|
||||
mlir::Value result = maxOp.getResult();
|
||||
const std::vector<uint64_t> resultShape = getShape(result);
|
||||
|
||||
Operation *xOp = maxOp.x().getDefiningOp();
|
||||
Operation *yOp = maxOp.y().getDefiningOp();
|
||||
Operation *xOp = maxOp.getX().getDefiningOp();
|
||||
Operation *yOp = maxOp.getY().getDefiningOp();
|
||||
|
||||
const double fixedCost = NEGLIGIBLE_COMPLEXITY;
|
||||
const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;
|
||||
@@ -344,12 +345,13 @@ struct FunctionToDag {
|
||||
std::vector<uint64_t> fakeShape = resultShape;
|
||||
|
||||
uint64_t numberOfComparisons = 1;
|
||||
for (auto dimensionSize : maxpool2dOp.kernel_shape().getValues<int64_t>()) {
|
||||
for (auto dimensionSize :
|
||||
maxpool2dOp.getKernelShape().getValues<int64_t>()) {
|
||||
numberOfComparisons *= dimensionSize;
|
||||
}
|
||||
fakeShape.push_back(numberOfComparisons);
|
||||
|
||||
Operation *inputOp = maxpool2dOp.input().getDefiningOp();
|
||||
Operation *inputOp = maxpool2dOp.getInput().getDefiningOp();
|
||||
|
||||
const double fixedCost = NEGLIGIBLE_COMPLEXITY;
|
||||
const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;
|
||||
@@ -438,7 +440,7 @@ struct FunctionToDag {
|
||||
return value.isa<mlir::BlockArgument>();
|
||||
}
|
||||
|
||||
llvm::Optional<std::vector<std::int64_t>>
|
||||
std::optional<std::vector<std::int64_t>>
|
||||
resolveConstantVectorWeights(mlir::arith::ConstantOp &cstOp) {
|
||||
std::vector<std::int64_t> values;
|
||||
mlir::DenseIntElementsAttr denseVals =
|
||||
@@ -446,14 +448,14 @@ struct FunctionToDag {
|
||||
|
||||
for (llvm::APInt val : denseVals.getValues<llvm::APInt>()) {
|
||||
if (val.getActiveBits() > 64) {
|
||||
return llvm::None;
|
||||
return std::nullopt;
|
||||
}
|
||||
values.push_back(val.getSExtValue());
|
||||
}
|
||||
return values;
|
||||
}
|
||||
|
||||
llvm::Optional<std::vector<std::int64_t>>
|
||||
std::optional<std::vector<std::int64_t>>
|
||||
resolveConstantWeights(mlir::Value &value) {
|
||||
if (auto cstOp = llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
|
||||
value.getDefiningOp())) {
|
||||
@@ -463,18 +465,18 @@ struct FunctionToDag {
|
||||
return resolveConstantVectorWeights(cstOp);
|
||||
default:
|
||||
DEBUG("High-Rank tensor: rely on MANP and levelledOp");
|
||||
return llvm::None;
|
||||
return std::nullopt;
|
||||
}
|
||||
} else {
|
||||
DEBUG("Dynamic Weights: rely on MANP and levelledOp");
|
||||
return llvm::None;
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Optional<std::vector<std::int64_t>>
|
||||
std::optional<std::vector<std::int64_t>>
|
||||
dotWeights(mlir::concretelang::FHELinalg::Dot &dot) {
|
||||
if (dot.getOperands().size() != 2) {
|
||||
return llvm::None;
|
||||
return std::nullopt;
|
||||
}
|
||||
auto weights = dot.getOperands()[1];
|
||||
return resolveConstantWeights(weights);
|
||||
|
||||
@@ -16,8 +16,9 @@
|
||||
#include <llvm/ADT/APInt.h>
|
||||
#include <llvm/ADT/Optional.h>
|
||||
#include <llvm/ADT/SmallString.h>
|
||||
#include <mlir/Analysis/DataFlowAnalysis.h>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
|
||||
#include <mlir/Analysis/DataFlow/SparseAnalysis.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Linalg/IR/Linalg.h>
|
||||
#include <mlir/Dialect/Tensor/IR/Tensor.h>
|
||||
@@ -55,7 +56,7 @@ static bool isEncryptedFunctionParameter(mlir::Value value) {
|
||||
/// values for its predecessors have been calculated beforehand or an
|
||||
/// unknown value otherwise.
|
||||
struct MANPLatticeValue {
|
||||
MANPLatticeValue(llvm::Optional<llvm::APInt> manp = {}) : manp(manp) {}
|
||||
MANPLatticeValue(std::optional<llvm::APInt> manp = {}) : manp(manp) {}
|
||||
|
||||
static MANPLatticeValue getPessimisticValueState(mlir::MLIRContext *context) {
|
||||
return MANPLatticeValue();
|
||||
@@ -81,21 +82,38 @@ struct MANPLatticeValue {
|
||||
return this->manp == rhs.manp;
|
||||
}
|
||||
|
||||
/// Required by `mlir::LatticeElement::join()`, but should never be
|
||||
/// invoked, as `MANPAnalysis::visitOperation()` takes care of
|
||||
/// combining the squared Minimal Arithmetic Noise Padding of
|
||||
/// operands into the Minimal Arithmetic Noise Padding of the result.
|
||||
static MANPLatticeValue join(const MANPLatticeValue &lhs,
|
||||
const MANPLatticeValue &rhs) {
|
||||
assert(false && "Minimal Arithmetic Noise Padding values can only be "
|
||||
"combined sensibly when the combining operation is known");
|
||||
if (!lhs.getMANP().has_value())
|
||||
return rhs;
|
||||
if (!rhs.getMANP().has_value())
|
||||
return lhs;
|
||||
|
||||
if (lhs.getMANP().value() == rhs.getMANP().value())
|
||||
return lhs;
|
||||
|
||||
assert(false && "Attempting to join two distinct initialized values");
|
||||
return MANPLatticeValue{};
|
||||
}
|
||||
|
||||
llvm::Optional<llvm::APInt> getMANP() { return manp; }
|
||||
void print(raw_ostream &os) const {
|
||||
if (manp.has_value())
|
||||
os << manp.value();
|
||||
else
|
||||
os << "(undefined)";
|
||||
}
|
||||
|
||||
std::optional<llvm::APInt> getMANP() const { return manp; }
|
||||
|
||||
protected:
|
||||
llvm::Optional<llvm::APInt> manp;
|
||||
std::optional<llvm::APInt> manp;
|
||||
};
|
||||
|
||||
class MANPLattice : public mlir::dataflow::Lattice<MANPLatticeValue> {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MANPLattice)
|
||||
|
||||
using Lattice::Lattice;
|
||||
};
|
||||
|
||||
/// Checks if `lhs` is less than `rhs`, where both values are assumed
|
||||
@@ -278,15 +296,14 @@ static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) {
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of an
|
||||
/// `FHELinalg.dot_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::Dot op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Dot op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
assert(operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
mlir::arith::ConstantOp cstOp =
|
||||
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
|
||||
@@ -312,150 +329,140 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of an
|
||||
/// `FHE.add_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::AddEintIntOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHE::AddEintIntOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.add_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::AddEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHE::AddEintOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
assert(operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
operandMANPs[1]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
llvm::APInt a = operandMANPs[0]->getValue().getMANP().value();
|
||||
llvm::APInt b = operandMANPs[1]->getValue().getMANP().value();
|
||||
|
||||
return APIntWidthExtendUAdd(a, b);
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.sub_int_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::SubIntEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHE::SubIntEintOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().value();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.sub_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::SubEintIntOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHE::SubEintIntOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.sub_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::SubEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHE::SubEintOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
assert(operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
operandMANPs[1]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
llvm::APInt a = operandMANPs[0]->getValue().getMANP().value();
|
||||
llvm::APInt b = operandMANPs[1]->getValue().getMANP().value();
|
||||
|
||||
return APIntWidthExtendUAdd(a, b);
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.neg_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::NegEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHE::NegEintOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.not` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::BoolNotOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHE::BoolNotOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::ToSignedOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHE::ToSignedOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::ToUnsignedOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHE::ToUnsignedOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.mul_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::MulEintIntOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHE::MulEintIntOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
mlir::Type iTy = op->getOpOperand(1).get().getType();
|
||||
|
||||
assert(iTy.isSignlessInteger() &&
|
||||
@@ -463,14 +470,14 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
mlir::arith::ConstantOp cstOp =
|
||||
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
|
||||
op->getOpOperand(1).get().getDefiningOp());
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
llvm::APInt sqNorm;
|
||||
|
||||
if (cstOp) {
|
||||
@@ -488,19 +495,18 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of
|
||||
/// `FHE.mul_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::MulEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHE::MulEintOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
assert(operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
operandMANPs[1]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
// x * y = ((x + y)^2 / 4) - ((x - y)^2 / 4) == tlu(x + y) - tlu(x - y)
|
||||
|
||||
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().value();
|
||||
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().value();
|
||||
|
||||
const llvm::APInt beforeTLUs = APIntWidthExtendUAdd(x, y);
|
||||
const llvm::APInt tlu = {1, 1, false};
|
||||
@@ -512,13 +518,12 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.round` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::RoundEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHE::RoundEintOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
uint64_t inputWidth =
|
||||
@@ -527,7 +532,7 @@ static llvm::APInt getSqMANP(
|
||||
op.getResult().getType().cast<FHE::FheIntegerInterface>().getWidth();
|
||||
uint64_t clearedBits = inputWidth - outputWidth;
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
eNorm += clearedBits;
|
||||
|
||||
return eNorm;
|
||||
@@ -535,19 +540,18 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of
|
||||
/// `FHE.max_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::MaxEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHE::MaxEintOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
assert(operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
operandMANPs[1]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
// max(x, y) = max(x - y, 0) + y
|
||||
|
||||
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().value();
|
||||
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().value();
|
||||
|
||||
const llvm::APInt sub = APIntWidthExtendUAdd(x, y);
|
||||
const llvm::APInt tlu = {1, 1, false};
|
||||
@@ -559,101 +563,94 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of an
|
||||
/// `FHELinalg.add_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::AddEintIntOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::AddEintIntOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::AddEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::AddEintOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
assert(operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
operandMANPs[1]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
llvm::APInt a = operandMANPs[0]->getValue().getMANP().value();
|
||||
llvm::APInt b = operandMANPs[1]->getValue().getMANP().value();
|
||||
|
||||
return APIntWidthExtendUAdd(a, b);
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHELinalg.sub_int_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::SubIntEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::SubIntEintOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().value();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::SubEintIntOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::SubEintIntOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::SubEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::SubEintOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
assert(operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
operandMANPs[1]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
llvm::APInt a = operandMANPs[0]->getValue().getMANP().value();
|
||||
llvm::APInt b = operandMANPs[1]->getValue().getMANP().value();
|
||||
|
||||
return APIntWidthExtendUAdd(a, b);
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHELinalg.neg_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::NegEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::NegEintOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.mul_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::MulEintIntOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MulEintIntOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
mlir::RankedTensorType op0Ty =
|
||||
op->getOpOperand(1).get().getType().cast<mlir::RankedTensorType>();
|
||||
@@ -665,10 +662,10 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
llvm::APInt sqNorm;
|
||||
|
||||
mlir::arith::ConstantOp cstOp =
|
||||
@@ -761,14 +758,13 @@ static llvm::APInt calculateSqManpForMatMulWithDenseValues(
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.mul_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::MatMulEintIntOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulEintIntOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
auto lhsType =
|
||||
((mlir::Type)op.lhs().getType()).cast<mlir::RankedTensorType>();
|
||||
((mlir::Type)op.getLhs().getType()).cast<mlir::RankedTensorType>();
|
||||
auto rhsType =
|
||||
((mlir::Type)op.rhs().getType()).cast<mlir::RankedTensorType>();
|
||||
((mlir::Type)op.getRhs().getType()).cast<mlir::RankedTensorType>();
|
||||
|
||||
llvm::ArrayRef<int64_t> lhsShape = lhsType.getShape();
|
||||
llvm::ArrayRef<int64_t> rhsShape = rhsType.getShape();
|
||||
@@ -782,10 +778,10 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
llvm::APInt accNorm = llvm::APInt{1, 1, false};
|
||||
|
||||
mlir::arith::ConstantOp cstOp =
|
||||
@@ -860,14 +856,13 @@ static llvm::APInt getSqMANP(
|
||||
return accNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::MatMulIntEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulIntEintOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
auto lhsType =
|
||||
((mlir::Type)op.lhs().getType()).cast<mlir::RankedTensorType>();
|
||||
((mlir::Type)op.getLhs().getType()).cast<mlir::RankedTensorType>();
|
||||
auto rhsType =
|
||||
((mlir::Type)op.rhs().getType()).cast<mlir::RankedTensorType>();
|
||||
((mlir::Type)op.getRhs().getType()).cast<mlir::RankedTensorType>();
|
||||
|
||||
llvm::ArrayRef<int64_t> lhsShape = lhsType.getShape();
|
||||
llvm::ArrayRef<int64_t> rhsShape = rhsType.getShape();
|
||||
@@ -881,10 +876,10 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().value();
|
||||
llvm::APInt accNorm = llvm::APInt{1, 1, false};
|
||||
|
||||
mlir::arith::ConstantOp cstOp =
|
||||
@@ -960,123 +955,112 @@ static llvm::APInt getSqMANP(
|
||||
return accNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::TransposeOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::TransposeOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
return operandMANPs[0]->getValue().getMANP().getValue();
|
||||
return operandMANPs[0]->getValue().getMANP().value();
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::ExtractOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::tensor::ExtractOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
FHELinalg::FromElementOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(FHELinalg::FromElementOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
auto manp = operandMANPs[0]->getValue().getMANP();
|
||||
if (manp.hasValue()) {
|
||||
return manp.getValue();
|
||||
if (manp.has_value()) {
|
||||
return manp.value();
|
||||
}
|
||||
|
||||
return llvm::APInt{1, 1, false};
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::FromElementsOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::tensor::FromElementsOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
auto max = std::max_element(
|
||||
operandMANPs.begin(), operandMANPs.end(),
|
||||
[](mlir::LatticeElement<MANPLatticeValue> *const a,
|
||||
mlir::LatticeElement<MANPLatticeValue> *const b) {
|
||||
return APIntWidthExtendULT(a->getValue().getMANP().getValue(),
|
||||
b->getValue().getMANP().getValue());
|
||||
});
|
||||
return (*max)->getValue().getMANP().getValue();
|
||||
auto max = std::max_element(operandMANPs.begin(), operandMANPs.end(),
|
||||
[](const MANPLattice *a, const MANPLattice *b) {
|
||||
return APIntWidthExtendULT(
|
||||
a->getValue().getMANP().value(),
|
||||
b->getValue().getMANP().value());
|
||||
});
|
||||
return (*max)->getValue().getMANP().value();
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::ExtractSliceOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::tensor::ExtractSliceOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
return operandMANPs[0]->getValue().getMANP().getValue();
|
||||
return operandMANPs[0]->getValue().getMANP().value();
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::InsertSliceOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::tensor::InsertSliceOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() >= 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
operandMANPs[1]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
return APIntUMax(operandMANPs[0]->getValue().getMANP().getValue(),
|
||||
operandMANPs[1]->getValue().getMANP().getValue());
|
||||
return APIntUMax(operandMANPs[0]->getValue().getMANP().value(),
|
||||
operandMANPs[1]->getValue().getMANP().value());
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::InsertOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::tensor::InsertOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() >= 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
operandMANPs[1]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
return APIntUMax(operandMANPs[0]->getValue().getMANP().getValue(),
|
||||
operandMANPs[1]->getValue().getMANP().getValue());
|
||||
return APIntUMax(operandMANPs[0]->getValue().getMANP().value(),
|
||||
operandMANPs[1]->getValue().getMANP().value());
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::CollapseShapeOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::tensor::CollapseShapeOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() >= 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
return operandMANPs[0]->getValue().getMANP().getValue();
|
||||
return operandMANPs[0]->getValue().getMANP().value();
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::ExpandShapeOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::tensor::ExpandShapeOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() >= 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
return operandMANPs[0]->getValue().getMANP().getValue();
|
||||
return operandMANPs[0]->getValue().getMANP().value();
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::SumOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::SumOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
auto inputType = op.getOperand().getType().dyn_cast<mlir::TensorType>();
|
||||
|
||||
@@ -1087,12 +1071,12 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
uint64_t numberOfElementsAddedTogetherInEachOutputCell = 1;
|
||||
|
||||
mlir::ArrayAttr axes = op.axes();
|
||||
mlir::ArrayAttr axes = op.getAxes();
|
||||
if (axes.empty()) {
|
||||
numberOfElementsAddedTogetherInEachOutputCell *= numberOfElementsInTheInput;
|
||||
} else {
|
||||
llvm::ArrayRef<int64_t> shape = inputType.getShape();
|
||||
for (mlir::Attribute axisAttribute : op.axes()) {
|
||||
for (mlir::Attribute axisAttribute : op.getAxes()) {
|
||||
int64_t axis = axisAttribute.cast<IntegerAttr>().getInt();
|
||||
numberOfElementsAddedTogetherInEachOutputCell *= shape[axis];
|
||||
}
|
||||
@@ -1108,22 +1092,21 @@ static llvm::APInt getSqMANP(
|
||||
};
|
||||
|
||||
assert(operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
llvm::APInt operandMANP = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt operandMANP = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
return APIntWidthExtendUMul(noiseMultiplier, operandMANP);
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::ConcatOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::ConcatOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
llvm::APInt result = llvm::APInt{1, 0, false};
|
||||
for (mlir::LatticeElement<MANPLatticeValue> *operandMANP : operandMANPs) {
|
||||
llvm::APInt candidate = operandMANP->getValue().getMANP().getValue();
|
||||
for (const MANPLattice *operandMANP : operandMANPs) {
|
||||
llvm::APInt candidate = operandMANP->getValue().getMANP().value();
|
||||
if (candidate.getLimitedValue() >= result.getLimitedValue()) {
|
||||
result = candidate;
|
||||
}
|
||||
@@ -1131,22 +1114,21 @@ static llvm::APInt getSqMANP(
|
||||
return result;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::Conv2dOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Conv2dOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
mlir::RankedTensorType weightTy =
|
||||
op.weight().getType().cast<mlir::RankedTensorType>();
|
||||
op.getWeight().getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
mlir::Type weightIntType = weightTy.getElementType();
|
||||
|
||||
// Bias is optional, so we can have both 2 or 3 operands
|
||||
assert((operandMANPs.size() == 2 || operandMANPs.size() == 3) &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[0]->getValue().getMANP().has_value() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operand");
|
||||
|
||||
llvm::APInt inputNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt inputNorm = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
mlir::arith::ConstantOp weightCstOp =
|
||||
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
|
||||
@@ -1200,9 +1182,8 @@ static llvm::APInt getSqMANP(
|
||||
return accNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::Maxpool2dOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Maxpool2dOp op,
|
||||
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
|
||||
|
||||
// maximum between two value is calculated using
|
||||
// - max(x - y, 0) + y
|
||||
@@ -1214,7 +1195,7 @@ static llvm::APInt getSqMANP(
|
||||
// so the resulting MANP is `{1, 1, false} + MANP input`
|
||||
|
||||
const llvm::APInt tlu = {1, 1, false};
|
||||
const llvm::APInt input = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
const llvm::APInt input = operandMANPs[0]->getValue().getMANP().value();
|
||||
|
||||
const llvm::APInt forResult = APIntWidthExtendUAdd(tlu, input);
|
||||
const llvm::APInt forIntermediate = APIntWidthExtendUAdd(forResult, input);
|
||||
@@ -1222,18 +1203,28 @@ static llvm::APInt getSqMANP(
|
||||
return APIntUMax(forIntermediate, forResult);
|
||||
}
|
||||
|
||||
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
|
||||
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
|
||||
: mlir::ForwardDataFlowAnalysis<MANPLatticeValue>(ctx), debug(debug) {}
|
||||
class MANPAnalysis
|
||||
: public mlir::dataflow::SparseDataFlowAnalysis<MANPLattice> {
|
||||
public:
|
||||
explicit MANPAnalysis(mlir::DataFlowSolver &solver, bool debug)
|
||||
: mlir::dataflow::SparseDataFlowAnalysis<MANPLattice>(solver),
|
||||
debug(debug) {}
|
||||
|
||||
~MANPAnalysis() override = default;
|
||||
void setToEntryState(MANPLattice *lattice) override {
|
||||
if (isEncryptedFunctionParameter(lattice->getPoint())) {
|
||||
// Set minimal MANP for encrypted function arguments
|
||||
propagateIfChanged(lattice, lattice->join(MANPLatticeValue{
|
||||
std::optional{llvm::APInt(1, 1)}}));
|
||||
} else {
|
||||
// Everything else is initialized with an unset value
|
||||
propagateIfChanged(lattice, lattice->join(MANPLatticeValue{}));
|
||||
}
|
||||
}
|
||||
|
||||
void visitOperation(Operation *op, ArrayRef<const MANPLattice *> operands,
|
||||
ArrayRef<MANPLattice *> results) override {
|
||||
MANPLattice *latticeRes = results[0];
|
||||
|
||||
mlir::ChangeResult visitOperation(
|
||||
mlir::Operation *op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operands) final {
|
||||
mlir::LatticeElement<MANPLatticeValue> &latticeRes =
|
||||
getLatticeElement(op->getResult(0));
|
||||
bool isDummy = false;
|
||||
llvm::APInt norm2SqEquiv;
|
||||
|
||||
@@ -1350,7 +1341,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
} else if (auto transposeOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::TransposeOp>(
|
||||
op)) {
|
||||
if (transposeOp.tensor()
|
||||
if (transposeOp.getTensor()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
@@ -1364,7 +1355,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
// Tensor Operators
|
||||
// ExtractOp
|
||||
else if (auto extractOp = llvm::dyn_cast<mlir::tensor::ExtractOp>(op)) {
|
||||
if (extractOp.result()
|
||||
if (extractOp.getResult()
|
||||
.getType()
|
||||
.isa<mlir::concretelang::FHE::FheIntegerInterface>()) {
|
||||
norm2SqEquiv = getSqMANP(extractOp, operands);
|
||||
@@ -1375,7 +1366,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
// ExtractSliceOp
|
||||
else if (auto extractSliceOp =
|
||||
llvm::dyn_cast<mlir::tensor::ExtractSliceOp>(op)) {
|
||||
if (extractSliceOp.result()
|
||||
if (extractSliceOp.getResult()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
@@ -1387,7 +1378,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
}
|
||||
// InsertOp
|
||||
else if (auto insertOp = llvm::dyn_cast<mlir::tensor::InsertOp>(op)) {
|
||||
if (insertOp.result()
|
||||
if (insertOp.getResult()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
@@ -1400,7 +1391,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
// InsertSliceOp
|
||||
else if (auto insertSliceOp =
|
||||
llvm::dyn_cast<mlir::tensor::InsertSliceOp>(op)) {
|
||||
if (insertSliceOp.result()
|
||||
if (insertSliceOp.getResult()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
@@ -1412,7 +1403,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
}
|
||||
// FromElementOp
|
||||
else if (auto fromOp = llvm::dyn_cast<mlir::tensor::FromElementsOp>(op)) {
|
||||
if (fromOp.result()
|
||||
if (fromOp.getResult()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
@@ -1425,7 +1416,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
// TensorCollapseShapeOp
|
||||
else if (auto reshapeOp =
|
||||
llvm::dyn_cast<mlir::tensor::CollapseShapeOp>(op)) {
|
||||
if (reshapeOp.result()
|
||||
if (reshapeOp.getResult()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
@@ -1437,7 +1428,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
}
|
||||
// TensorExpandShapeOp
|
||||
else if (auto reshapeOp = llvm::dyn_cast<mlir::tensor::ExpandShapeOp>(op)) {
|
||||
if (reshapeOp.result()
|
||||
if (reshapeOp.getResult()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
@@ -1459,8 +1450,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
}
|
||||
|
||||
if (!isDummy) {
|
||||
latticeRes.join(MANPLatticeValue{norm2SqEquiv});
|
||||
latticeRes.markOptimisticFixpoint();
|
||||
latticeRes->join(MANPLatticeValue{norm2SqEquiv});
|
||||
|
||||
op->setAttr("SMANP",
|
||||
mlir::IntegerAttr::get(
|
||||
@@ -1483,10 +1473,8 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
<< APIntToStringValUnsigned(norm2SqEquiv) << "\n";
|
||||
}
|
||||
} else {
|
||||
latticeRes.join(MANPLatticeValue{});
|
||||
latticeRes->join(MANPLatticeValue{});
|
||||
}
|
||||
|
||||
return mlir::ChangeResult::Change;
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -1500,8 +1488,12 @@ struct MANPPass : public MANPBase<MANPPass> {
|
||||
void runOnOperation() override {
|
||||
mlir::func::FuncOp func = getOperation();
|
||||
|
||||
MANPAnalysis analysis(func->getContext(), debug);
|
||||
analysis.run(func);
|
||||
mlir::DataFlowSolver solver;
|
||||
solver.load<mlir::dataflow::DeadCodeAnalysis>();
|
||||
solver.load<MANPAnalysis>(debug);
|
||||
|
||||
if (failed(solver.initializeAndRun(func)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
MANPPass() = delete;
|
||||
MANPPass(bool debug) : debug(debug){};
|
||||
|
||||
@@ -61,8 +61,8 @@ bool verifyEncryptedIntegerInputsConsistency(mlir::Operation &op,
|
||||
}
|
||||
|
||||
mlir::LogicalResult AddEintIntOp::verify() {
|
||||
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto b = this->b().getType().cast<IntegerType>();
|
||||
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto b = this->getB().getType().cast<IntegerType>();
|
||||
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
|
||||
|
||||
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
|
||||
@@ -79,8 +79,8 @@ mlir::LogicalResult AddEintIntOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult AddEintOp::verify() {
|
||||
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto b = this->b().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto b = this->getB().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
|
||||
|
||||
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
|
||||
@@ -96,8 +96,8 @@ mlir::LogicalResult AddEintOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult SubIntEintOp::verify() {
|
||||
auto a = this->a().getType().cast<IntegerType>();
|
||||
auto b = this->b().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto a = this->getA().getType().cast<IntegerType>();
|
||||
auto b = this->getB().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
|
||||
|
||||
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), b,
|
||||
@@ -114,8 +114,8 @@ mlir::LogicalResult SubIntEintOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult SubEintIntOp::verify() {
|
||||
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto b = this->b().getType().cast<IntegerType>();
|
||||
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto b = this->getB().getType().cast<IntegerType>();
|
||||
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
|
||||
|
||||
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
|
||||
@@ -132,8 +132,8 @@ mlir::LogicalResult SubEintIntOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult SubEintOp::verify() {
|
||||
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto b = this->b().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto b = this->getB().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
|
||||
|
||||
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
|
||||
@@ -149,7 +149,7 @@ mlir::LogicalResult SubEintOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult NegEintOp::verify() {
|
||||
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
|
||||
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
|
||||
out)) {
|
||||
@@ -159,8 +159,8 @@ mlir::LogicalResult NegEintOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult MulEintIntOp::verify() {
|
||||
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto b = this->b().getType().cast<IntegerType>();
|
||||
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto b = this->getB().getType().cast<IntegerType>();
|
||||
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
|
||||
|
||||
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
|
||||
@@ -177,8 +177,8 @@ mlir::LogicalResult MulEintIntOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult MulEintOp::verify() {
|
||||
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto b = this->b().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto b = this->getB().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
|
||||
|
||||
if (!verifyEncryptedIntegerInputsConsistency(*this->getOperation(), a, b)) {
|
||||
@@ -193,8 +193,8 @@ mlir::LogicalResult MulEintOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult MaxEintOp::verify() {
|
||||
auto xTy = this->x().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto yTy = this->y().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto xTy = this->getX().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto yTy = this->getY().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto outTy = this->getResult().getType().dyn_cast<FheIntegerInterface>();
|
||||
|
||||
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(),
|
||||
@@ -211,7 +211,7 @@ mlir::LogicalResult MaxEintOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult ToSignedOp::verify() {
|
||||
auto input = this->input().getType().cast<EncryptedIntegerType>();
|
||||
auto input = this->getInput().getType().cast<EncryptedIntegerType>();
|
||||
auto output = this->getResult().getType().cast<EncryptedSignedIntegerType>();
|
||||
|
||||
if (input.getWidth() != output.getWidth()) {
|
||||
@@ -224,7 +224,7 @@ mlir::LogicalResult ToSignedOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult ToUnsignedOp::verify() {
|
||||
auto input = this->input().getType().cast<EncryptedSignedIntegerType>();
|
||||
auto input = this->getInput().getType().cast<EncryptedSignedIntegerType>();
|
||||
auto output = this->getResult().getType().cast<EncryptedIntegerType>();
|
||||
|
||||
if (input.getWidth() != output.getWidth()) {
|
||||
@@ -237,7 +237,7 @@ mlir::LogicalResult ToUnsignedOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult ToBoolOp::verify() {
|
||||
auto input = this->input().getType().cast<EncryptedIntegerType>();
|
||||
auto input = this->getInput().getType().cast<EncryptedIntegerType>();
|
||||
|
||||
if (input.getWidth() != 1 && input.getWidth() != 2) {
|
||||
this->emitOpError("should have 1 or 2 as the width of encrypted input to "
|
||||
@@ -249,7 +249,7 @@ mlir::LogicalResult ToBoolOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult GenGateOp::verify() {
|
||||
auto truth_table = this->truth_table().getType().cast<TensorType>();
|
||||
auto truth_table = this->getTruthTable().getType().cast<TensorType>();
|
||||
|
||||
mlir::SmallVector<int64_t, 1> expectedShape{4};
|
||||
if (!truth_table.hasStaticShape(expectedShape)) {
|
||||
@@ -261,8 +261,8 @@ mlir::LogicalResult GenGateOp::verify() {
|
||||
}
|
||||
|
||||
::mlir::LogicalResult ApplyLookupTableEintOp::verify() {
|
||||
auto ct = this->a().getType().cast<FheIntegerInterface>();
|
||||
auto lut = this->lut().getType().cast<TensorType>();
|
||||
auto ct = this->getA().getType().cast<FheIntegerInterface>();
|
||||
auto lut = this->getLut().getType().cast<TensorType>();
|
||||
|
||||
// Check the shape of lut argument
|
||||
auto width = ct.getWidth();
|
||||
@@ -281,7 +281,7 @@ mlir::LogicalResult GenGateOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult RoundEintOp::verify() {
|
||||
auto input = this->input().getType().cast<FheIntegerInterface>();
|
||||
auto input = this->getInput().getType().cast<FheIntegerInterface>();
|
||||
auto output = this->getResult().getType().cast<FheIntegerInterface>();
|
||||
|
||||
if (input.getWidth() <= output.getWidth()) {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// for license information.
|
||||
|
||||
#include <mlir/Dialect/Affine/IR/AffineOps.h>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Tensor/IR/Tensor.h>
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
@@ -93,7 +93,8 @@ public:
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
auto tensorType = adaptor.a().getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto tensorType =
|
||||
adaptor.getA().getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto shape = tensorType.getShape();
|
||||
assert(shape.size() == 1 &&
|
||||
"chunked integer should be converted to flat tensors, but tensor "
|
||||
@@ -112,7 +113,8 @@ public:
|
||||
.getResult();
|
||||
|
||||
mlir::Value resultTensor =
|
||||
rewriter.create<FHE::ZeroTensorOp>(op.getLoc(), adaptor.a().getType())
|
||||
rewriter
|
||||
.create<FHE::ZeroTensorOp>(op.getLoc(), adaptor.getA().getType())
|
||||
.getResult();
|
||||
// used to shift the carry bit to the left
|
||||
mlir::Value twoPowerChunkSizeCst =
|
||||
@@ -127,10 +129,10 @@ public:
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
|
||||
mlir::ValueRange args) {
|
||||
// add inputs with the previous carry (init to 0)
|
||||
mlir::Value leftEint =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, adaptor.a(), iter);
|
||||
mlir::Value rightEint =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, adaptor.b(), iter);
|
||||
mlir::Value leftEint = builder.create<mlir::tensor::ExtractOp>(
|
||||
loc, adaptor.getA(), iter);
|
||||
mlir::Value rightEint = builder.create<mlir::tensor::ExtractOp>(
|
||||
loc, adaptor.getB(), iter);
|
||||
mlir::Value result =
|
||||
builder.create<FHE::AddEintOp>(loc, leftEint, rightEint)
|
||||
.getResult();
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
|
||||
|
||||
@@ -33,11 +33,11 @@ public:
|
||||
rewriter.getContext(), 2);
|
||||
auto left = rewriter
|
||||
.create<mlir::concretelang::FHE::FromBoolOp>(
|
||||
op.getLoc(), eint2, op.left())
|
||||
op.getLoc(), eint2, op.getLeft())
|
||||
.getResult();
|
||||
auto right = rewriter
|
||||
.create<mlir::concretelang::FHE::FromBoolOp>(
|
||||
op.getLoc(), eint2, op.right())
|
||||
op.getLoc(), eint2, op.getRight())
|
||||
.getResult();
|
||||
auto cst_two =
|
||||
rewriter.create<mlir::arith::ConstantIntOp>(op.getLoc(), 2, 3)
|
||||
@@ -52,7 +52,7 @@ public:
|
||||
.getResult();
|
||||
auto lut_result =
|
||||
rewriter.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
|
||||
op.getLoc(), eint2, newIndex, op.truth_table());
|
||||
op.getLoc(), eint2, newIndex, op.getTruthTable());
|
||||
rewriter.replaceOpWithNewOp<mlir::concretelang::FHE::ToBoolOp>(
|
||||
op,
|
||||
mlir::concretelang::FHE::EncryptedBooleanType::get(
|
||||
@@ -84,7 +84,7 @@ public:
|
||||
auto truth_table =
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), truth_table_attr);
|
||||
rewriter.replaceOpWithNewOp<mlir::concretelang::FHE::GenGateOp>(
|
||||
op, op.getResult().getType(), op.left(), op.right(), truth_table);
|
||||
op, op.getResult().getType(), op.getLeft(), op.getRight(), truth_table);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@@ -119,11 +119,11 @@ public:
|
||||
auto c1AndNotCond =
|
||||
rewriter
|
||||
.create<mlir::concretelang::FHE::GenGateOp>(
|
||||
op.getLoc(), boolType, op.c1(), op.cond(), truth_table)
|
||||
op.getLoc(), boolType, op.getC1(), op.getCond(), truth_table)
|
||||
.getResult();
|
||||
auto c2AndCond = rewriter
|
||||
.create<mlir::concretelang::FHE::BoolAndOp>(
|
||||
op.getLoc(), boolType, op.c2(), op.cond())
|
||||
op.getLoc(), boolType, op.getC2(), op.getCond())
|
||||
.getResult();
|
||||
|
||||
auto c1AndNotCondBool = rewriter
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h>
|
||||
#include <concretelang/Support/Constants.h>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Linalg/IR/Linalg.h>
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
@@ -31,7 +31,7 @@ public:
|
||||
matchAndRewrite(FHE::MulEintOp op, FHE::MulEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto inputType = adaptor.a().getType();
|
||||
auto inputType = adaptor.getA().getType();
|
||||
auto bitWidth = inputType.cast<FHE::FheIntegerInterface>().getWidth();
|
||||
auto isSigned = inputType.cast<FHE::FheIntegerInterface>().isSigned();
|
||||
mlir::Type signedType =
|
||||
@@ -50,8 +50,8 @@ public:
|
||||
// signedness for inputs and outputs.
|
||||
|
||||
// s = a + b
|
||||
mlir::Value sum =
|
||||
rewriter.create<FHE::AddEintOp>(op->getLoc(), adaptor.a(), adaptor.b());
|
||||
mlir::Value sum = rewriter.create<FHE::AddEintOp>(
|
||||
op->getLoc(), adaptor.getA(), adaptor.getB());
|
||||
|
||||
// se = (s)^2/4
|
||||
// Depending on whether a,b,s are signed or not, we need a different lut to
|
||||
@@ -71,8 +71,8 @@ public:
|
||||
op->getLoc(), inputType, sum, sumLut);
|
||||
|
||||
// d = a - b
|
||||
mlir::Value diff =
|
||||
rewriter.create<FHE::SubEintOp>(op->getLoc(), adaptor.a(), adaptor.b());
|
||||
mlir::Value diff = rewriter.create<FHE::SubEintOp>(
|
||||
op->getLoc(), adaptor.getA(), adaptor.getB());
|
||||
|
||||
// de = (d)^2/4
|
||||
// Here, the tlu must be performed with signed encoded lut, to properly
|
||||
@@ -155,7 +155,7 @@ public:
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithDialect>();
|
||||
target.addLegalDialect<FHE::FHEDialect>();
|
||||
target.addIllegalOp<FHE::MulEintOp>();
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
@@ -32,8 +32,8 @@ struct MaxEintPattern : public mlir::OpRewritePattern<FHE::MaxEintOp> {
|
||||
maxEintOp->getResult(0).getType().cast<FHE::FheIntegerInterface>();
|
||||
const int64_t outputBitWidth = outputTy.getWidth();
|
||||
|
||||
mlir::Value x = maxEintOp.x();
|
||||
mlir::Value y = maxEintOp.y();
|
||||
mlir::Value x = maxEintOp.getX();
|
||||
mlir::Value y = maxEintOp.getY();
|
||||
|
||||
const auto xTy = x.getType().cast<FHE::FheIntegerInterface>();
|
||||
const auto yTy = y.getType().cast<FHE::FheIntegerInterface>();
|
||||
@@ -70,7 +70,7 @@ struct MaxEintPattern : public mlir::OpRewritePattern<FHE::MaxEintOp> {
|
||||
.getResult();
|
||||
|
||||
const mlir::Value add =
|
||||
rewriter.create<FHE::AddEintOp>(loc, max, maxEintOp.y()).getResult();
|
||||
rewriter.create<FHE::AddEintOp>(loc, max, maxEintOp.getY()).getResult();
|
||||
|
||||
rewriter.replaceOp(maxEintOp, {add});
|
||||
return mlir::success();
|
||||
@@ -85,7 +85,7 @@ struct FHEMaxTransform : public FHEMaxTransformBase<FHEMaxTransform> {
|
||||
|
||||
void FHEMaxTransform::runOnOperation() {
|
||||
auto target = mlir::ConversionTarget(this->getContext());
|
||||
target.addLegalDialect<arith::ArithmeticDialect>();
|
||||
target.addLegalDialect<arith::ArithDialect>();
|
||||
target.addLegalDialect<FHE::FHEDialect>();
|
||||
target.addIllegalOp<FHE::MaxEintOp>();
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
@@ -272,10 +273,10 @@ namespace concretelang {
|
||||
namespace FHELinalg {
|
||||
|
||||
mlir::LogicalResult ApplyLookupTableEintOp::verify() {
|
||||
auto tTy = this->t().getType().cast<mlir::RankedTensorType>();
|
||||
auto tTy = this->getT().getType().cast<mlir::RankedTensorType>();
|
||||
auto tEltTy = tTy.getElementType()
|
||||
.cast<mlir::concretelang::FHE::EncryptedIntegerType>();
|
||||
auto lutTy = this->lut().getType().cast<mlir::RankedTensorType>();
|
||||
auto lutTy = this->getLut().getType().cast<mlir::RankedTensorType>();
|
||||
auto lutEltTy = lutTy.getElementType().cast<mlir::IntegerType>();
|
||||
auto resultTy = this->getResult().getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
@@ -297,10 +298,10 @@ mlir::LogicalResult ApplyLookupTableEintOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult ApplyMultiLookupTableEintOp::verify() {
|
||||
auto tTy = this->t().getType().cast<mlir::RankedTensorType>();
|
||||
auto tTy = this->getT().getType().cast<mlir::RankedTensorType>();
|
||||
auto tEltTy = tTy.getElementType()
|
||||
.cast<mlir::concretelang::FHE::EncryptedIntegerType>();
|
||||
auto lutTy = this->luts().getType().cast<mlir::RankedTensorType>();
|
||||
auto lutTy = this->getLuts().getType().cast<mlir::RankedTensorType>();
|
||||
auto lutEltTy = lutTy.getElementType().cast<mlir::IntegerType>();
|
||||
auto resultTy = this->getResult().getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
@@ -378,9 +379,9 @@ mlir::LogicalResult verifyLutsSize(ApplyMappedLookupTableEintOp &op,
|
||||
}
|
||||
|
||||
mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() {
|
||||
auto t = this->t();
|
||||
auto luts = this->luts();
|
||||
auto map = this->map();
|
||||
auto t = this->getT();
|
||||
auto luts = this->getLuts();
|
||||
auto map = this->getMap();
|
||||
auto result = this->getResult();
|
||||
|
||||
auto t_shape = getTensorType(t).getShape();
|
||||
@@ -401,16 +402,16 @@ mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() {
|
||||
}
|
||||
|
||||
::mlir::LogicalResult Dot::verify() {
|
||||
if (::mlir::failed(mlir::verifyCompatibleShape(this->lhs().getType(),
|
||||
this->rhs().getType()))) {
|
||||
if (::mlir::failed(mlir::verifyCompatibleShape(this->getLhs().getType(),
|
||||
this->getRhs().getType()))) {
|
||||
return this->emitOpError("arguments have incompatible shapes");
|
||||
}
|
||||
auto lhsEltType = this->lhs()
|
||||
auto lhsEltType = this->getLhs()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.dyn_cast<FHE::FheIntegerInterface>();
|
||||
auto rhsEltType = this->rhs()
|
||||
auto rhsEltType = this->getRhs()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
@@ -480,8 +481,8 @@ mlir::LogicalResult SumOp::verify() {
|
||||
llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
int64_t inputDimensions = (int64_t)inputShape.size();
|
||||
|
||||
mlir::ArrayAttr axes = this->axes();
|
||||
bool keepDims = this->keep_dims();
|
||||
mlir::ArrayAttr axes = this->getAxes();
|
||||
bool keepDims = this->getKeepDims();
|
||||
|
||||
auto axesToDestroy = std::unordered_set<int64_t>{};
|
||||
for (mlir::Attribute axisAttribute : axes) {
|
||||
@@ -543,8 +544,8 @@ mlir::LogicalResult ConcatOp::verify() {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
int64_t axis = this->axis();
|
||||
mlir::Value out = this->out();
|
||||
int64_t axis = this->getAxis();
|
||||
mlir::Value out = this->getOut();
|
||||
|
||||
auto outVectorType = out.getType().dyn_cast<mlir::TensorType>();
|
||||
auto outElementType =
|
||||
@@ -561,7 +562,7 @@ mlir::LogicalResult ConcatOp::verify() {
|
||||
int64_t expectedOutputElementsInAxis = 0;
|
||||
|
||||
size_t index = 0;
|
||||
for (mlir::Value in : this->ins()) {
|
||||
for (mlir::Value in : this->getIns()) {
|
||||
auto inVectorType = in.getType().dyn_cast<mlir::TensorType>();
|
||||
auto inElementType =
|
||||
inVectorType.getElementType().dyn_cast<FHE::FheIntegerInterface>();
|
||||
@@ -626,9 +627,9 @@ mlir::LogicalResult ConcatOp::verify() {
|
||||
/// something else
|
||||
template <typename MatMulOp> mlir::LogicalResult verifyMatmul(MatMulOp &op) {
|
||||
auto lhsType =
|
||||
((mlir::Type)op.lhs().getType()).cast<mlir::RankedTensorType>();
|
||||
((mlir::Type)op.getLhs().getType()).cast<mlir::RankedTensorType>();
|
||||
auto rhsType =
|
||||
((mlir::Type)op.rhs().getType()).cast<mlir::RankedTensorType>();
|
||||
((mlir::Type)op.getRhs().getType()).cast<mlir::RankedTensorType>();
|
||||
|
||||
llvm::ArrayRef<int64_t> lhsShape = lhsType.getShape();
|
||||
llvm::ArrayRef<int64_t> rhsShape = rhsType.getShape();
|
||||
@@ -786,9 +787,10 @@ mlir::LogicalResult MatMulIntEintOp::verify() {
|
||||
mlir::SmallVector<int64_t, 4>
|
||||
getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
||||
mlir::SmallVector<int64_t, 4> paddingInts;
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optionalPadding = convOp.padding();
|
||||
if (optionalPadding.hasValue()) {
|
||||
auto paddingAttr = optionalPadding.getValue();
|
||||
std::optional<mlir::DenseIntElementsAttr> optionalPadding =
|
||||
convOp.getPadding();
|
||||
if (optionalPadding.has_value()) {
|
||||
auto paddingAttr = optionalPadding.value();
|
||||
auto paddingAttrShape =
|
||||
paddingAttr.getType().cast<RankedTensorType>().getShape();
|
||||
assert(paddingAttrShape.size() == 1 && paddingAttrShape[0] == 4 &&
|
||||
@@ -804,9 +806,10 @@ getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
||||
mlir::SmallVector<int64_t, 2>
|
||||
getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
||||
mlir::SmallVector<int64_t, 2> stridesInts;
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optionalStrides = convOp.strides();
|
||||
if (optionalStrides.hasValue()) {
|
||||
auto stridesAttr = optionalStrides.getValue();
|
||||
std::optional<mlir::DenseIntElementsAttr> optionalStrides =
|
||||
convOp.getStrides();
|
||||
if (optionalStrides.has_value()) {
|
||||
auto stridesAttr = optionalStrides.value();
|
||||
auto stridesAttrShape =
|
||||
stridesAttr.getType().cast<RankedTensorType>().getShape();
|
||||
assert(stridesAttrShape.size() == 1 && stridesAttrShape[0] == 2 &&
|
||||
@@ -822,10 +825,10 @@ getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
||||
mlir::SmallVector<int64_t, 2>
|
||||
getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
||||
mlir::SmallVector<int64_t, 2> dilationsInts;
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optionalDilations =
|
||||
convOp.dilations();
|
||||
if (optionalDilations.hasValue()) {
|
||||
auto dilationsAttr = optionalDilations.getValue();
|
||||
std::optional<mlir::DenseIntElementsAttr> optionalDilations =
|
||||
convOp.getDilations();
|
||||
if (optionalDilations.has_value()) {
|
||||
auto dilationsAttr = optionalDilations.value();
|
||||
auto dilationsAttrShape =
|
||||
dilationsAttr.getType().cast<RankedTensorType>().getShape();
|
||||
assert(dilationsAttrShape.size() == 1 && dilationsAttrShape[0] == 2 &&
|
||||
@@ -840,18 +843,18 @@ getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
||||
}
|
||||
|
||||
int64_t getGroupFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
||||
llvm::Optional<uint64_t> optionalGroup = convOp.group();
|
||||
if (optionalGroup.hasValue())
|
||||
return optionalGroup.getValue();
|
||||
std::optional<uint64_t> optionalGroup = convOp.getGroup();
|
||||
if (optionalGroup.has_value())
|
||||
return optionalGroup.value();
|
||||
return 1;
|
||||
}
|
||||
|
||||
/// Verify the Conv2d shapes, attributes, and expected output dimensions
|
||||
mlir::LogicalResult Conv2dOp::verify() {
|
||||
auto inputTy =
|
||||
((mlir::Type)this->input().getType()).cast<mlir::RankedTensorType>();
|
||||
((mlir::Type)this->getInput().getType()).cast<mlir::RankedTensorType>();
|
||||
auto weightTy =
|
||||
((mlir::Type)this->weight().getType()).cast<mlir::RankedTensorType>();
|
||||
((mlir::Type)this->getWeight().getType()).cast<mlir::RankedTensorType>();
|
||||
auto resultTy =
|
||||
((mlir::Type)this->getResult().getType()).cast<mlir::RankedTensorType>();
|
||||
auto inputShape = inputTy.getShape();
|
||||
@@ -890,9 +893,10 @@ mlir::LogicalResult Conv2dOp::verify() {
|
||||
|
||||
// Checking attributes
|
||||
mlir::SmallVector<int64_t, 4> paddingInts = getPaddingFromConv2d(*this);
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optionalPadding = this->padding();
|
||||
if (optionalPadding.hasValue()) {
|
||||
auto paddingAttr = optionalPadding.getValue();
|
||||
std::optional<mlir::DenseIntElementsAttr> optionalPadding =
|
||||
this->getPadding();
|
||||
if (optionalPadding.has_value()) {
|
||||
auto paddingAttr = optionalPadding.value();
|
||||
auto paddingAttrShape =
|
||||
paddingAttr.getType().cast<RankedTensorType>().getShape();
|
||||
if (paddingAttrShape.size() != 1 || paddingAttrShape[0] != 4) {
|
||||
@@ -918,9 +922,10 @@ mlir::LogicalResult Conv2dOp::verify() {
|
||||
}
|
||||
}
|
||||
mlir::SmallVector<int64_t, 2> stridesInts = getStridesFromConv2d(*this);
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optionalStrides = this->strides();
|
||||
if (optionalStrides.hasValue()) {
|
||||
auto stridesAttr = optionalStrides.getValue();
|
||||
std::optional<mlir::DenseIntElementsAttr> optionalStrides =
|
||||
this->getStrides();
|
||||
if (optionalStrides.has_value()) {
|
||||
auto stridesAttr = optionalStrides.value();
|
||||
auto stridesAttrShape =
|
||||
stridesAttr.getType().cast<RankedTensorType>().getShape();
|
||||
if (stridesAttrShape.size() != 1 || stridesAttrShape[0] != 2) {
|
||||
@@ -939,10 +944,10 @@ mlir::LogicalResult Conv2dOp::verify() {
|
||||
}
|
||||
}
|
||||
mlir::SmallVector<int64_t, 2> dilationsInts = getDilationsFromConv2d(*this);
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optionalDilations =
|
||||
this->dilations();
|
||||
if (optionalDilations.hasValue()) {
|
||||
auto dilationsAttr = optionalDilations.getValue();
|
||||
std::optional<mlir::DenseIntElementsAttr> optionalDilations =
|
||||
this->getDilations();
|
||||
if (optionalDilations.has_value()) {
|
||||
auto dilationsAttr = optionalDilations.value();
|
||||
auto dilationsAttrShape =
|
||||
dilationsAttr.getType().cast<RankedTensorType>().getShape();
|
||||
if (dilationsAttrShape.size() != 1 || dilationsAttrShape[0] != 2) {
|
||||
@@ -975,7 +980,7 @@ mlir::LogicalResult Conv2dOp::verify() {
|
||||
resultH = resultShape[2], resultW = resultShape[3];
|
||||
|
||||
// Bias check if specified
|
||||
mlir::Value bias = this->bias();
|
||||
mlir::Value bias = this->getBias();
|
||||
if (bias) {
|
||||
auto biasTy = ((mlir::Type)bias.getType()).cast<mlir::RankedTensorType>();
|
||||
auto biasShape = biasTy.getShape();
|
||||
@@ -1050,7 +1055,7 @@ mlir::LogicalResult Conv2dOp::verify() {
|
||||
|
||||
mlir::LogicalResult Maxpool2dOp::verify() {
|
||||
const mlir::RankedTensorType inputTy =
|
||||
this->input().getType().cast<mlir::RankedTensorType>();
|
||||
this->getInput().getType().cast<mlir::RankedTensorType>();
|
||||
const mlir::RankedTensorType outputTy =
|
||||
this->getResult().getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
@@ -1085,7 +1090,7 @@ mlir::LogicalResult Maxpool2dOp::verify() {
|
||||
const int64_t inputH = inputShape[2];
|
||||
const int64_t inputW = inputShape[3];
|
||||
|
||||
const mlir::DenseIntElementsAttr kernelShapeAttr = this->kernel_shape();
|
||||
const mlir::DenseIntElementsAttr kernelShapeAttr = this->getKernelShape();
|
||||
const mlir::RankedTensorType kernelShapeAttrTy =
|
||||
kernelShapeAttr.getType().cast<mlir::RankedTensorType>();
|
||||
const llvm::ArrayRef<int64_t> kernelShapeAttrShape =
|
||||
@@ -1108,9 +1113,9 @@ mlir::LogicalResult Maxpool2dOp::verify() {
|
||||
|
||||
mlir::SmallVector<int64_t, 2> strides;
|
||||
const llvm::Optional<mlir::DenseIntElementsAttr> maybeStridesAttr =
|
||||
this->strides();
|
||||
if (maybeStridesAttr.hasValue()) {
|
||||
const mlir::DenseIntElementsAttr stridesAttr = maybeStridesAttr.getValue();
|
||||
this->getStrides();
|
||||
if (maybeStridesAttr.has_value()) {
|
||||
const mlir::DenseIntElementsAttr stridesAttr = maybeStridesAttr.value();
|
||||
const mlir::RankedTensorType stridesAttrTy =
|
||||
stridesAttr.getType().cast<mlir::RankedTensorType>();
|
||||
const llvm::ArrayRef<int64_t> stridesAttrShape = stridesAttrTy.getShape();
|
||||
@@ -1141,10 +1146,9 @@ mlir::LogicalResult Maxpool2dOp::verify() {
|
||||
|
||||
mlir::SmallVector<int64_t, 2> dilations;
|
||||
const llvm::Optional<mlir::DenseIntElementsAttr> maybeDilationsAttr =
|
||||
this->dilations();
|
||||
if (maybeDilationsAttr.hasValue()) {
|
||||
const mlir::DenseIntElementsAttr dilationsAttr =
|
||||
maybeDilationsAttr.getValue();
|
||||
this->getDilations();
|
||||
if (maybeDilationsAttr.has_value()) {
|
||||
const mlir::DenseIntElementsAttr dilationsAttr = maybeDilationsAttr.value();
|
||||
const mlir::RankedTensorType dilationsAttrTy =
|
||||
dilationsAttr.getType().cast<mlir::RankedTensorType>();
|
||||
const llvm::ArrayRef<int64_t> dilationsAttrShape =
|
||||
@@ -1185,7 +1189,7 @@ mlir::LogicalResult Maxpool2dOp::verify() {
|
||||
expectedOutputW,
|
||||
};
|
||||
|
||||
if (outputShape != llvm::makeArrayRef(expectedOutputShape)) {
|
||||
if (outputShape != llvm::ArrayRef(expectedOutputShape)) {
|
||||
this->emitOpError() << "expected output to be of shape "
|
||||
<< "(" << expectedOutputShape << ") "
|
||||
<< "but it is of shape "
|
||||
@@ -1203,7 +1207,9 @@ mlir::LogicalResult FromElementOp::verify() {
|
||||
auto inType = in.getType();
|
||||
auto outType = out.getType().dyn_cast<mlir::TensorType>();
|
||||
|
||||
auto expectedOutType = outType.cloneWith({1}, inType);
|
||||
llvm::SmallVector<int64_t> shape{1};
|
||||
auto expectedOutType = outType.cloneWith(std::optional{shape}, inType);
|
||||
|
||||
if (outType != expectedOutType) {
|
||||
this->emitOpError() << "has invalid output type (expected "
|
||||
<< expectedOutType << ", got " << outType << ")";
|
||||
@@ -1215,7 +1221,7 @@ mlir::LogicalResult FromElementOp::verify() {
|
||||
|
||||
/// Verify the transpose shapes
|
||||
mlir::LogicalResult TransposeOp::verify() {
|
||||
mlir::Type tensorTy = ((mlir::Type)this->tensor().getType());
|
||||
mlir::Type tensorTy = ((mlir::Type)this->getTensor().getType());
|
||||
if (!tensorTy.isa<RankedTensorType>()) {
|
||||
this->emitOpError() << "should have operand as tensor";
|
||||
return mlir::failure();
|
||||
@@ -1243,7 +1249,7 @@ mlir::LogicalResult TransposeOp::verify() {
|
||||
|
||||
int64_t inputDimensions = (int64_t)inShape.size();
|
||||
|
||||
mlir::ArrayAttr axes = this->axes();
|
||||
mlir::ArrayAttr axes = this->getAxes();
|
||||
if (axes.empty()) {
|
||||
for (int64_t i = 0; i < inputDimensions; i++) {
|
||||
if (inShape[i] != outShape[inputDimensions - (i + 1)]) {
|
||||
@@ -1294,7 +1300,7 @@ mlir::LogicalResult TransposeOp::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult ToSignedOp::verify() {
|
||||
auto inputType = this->input().getType().cast<mlir::ShapedType>();
|
||||
auto inputType = this->getInput().getType().cast<mlir::ShapedType>();
|
||||
auto outputType = this->getResult().getType().cast<mlir::ShapedType>();
|
||||
|
||||
llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
@@ -1322,7 +1328,7 @@ mlir::LogicalResult ToSignedOp::verify() {
|
||||
|
||||
mlir::LogicalResult ToUnsignedOp::verify() {
|
||||
mlir::ShapedType inputType =
|
||||
this->input().getType().dyn_cast_or_null<mlir::ShapedType>();
|
||||
this->getInput().getType().dyn_cast_or_null<mlir::ShapedType>();
|
||||
mlir::ShapedType outputType =
|
||||
this->getResult().getType().dyn_cast_or_null<mlir::ShapedType>();
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/SCF/IR/SCF.h>
|
||||
|
||||
@@ -22,8 +22,8 @@
|
||||
#include <mlir/Dialect/Bufferization/Transforms/Passes.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/IRMapping.h>
|
||||
#include <mlir/Transforms/RegionUtils.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
|
||||
@@ -16,13 +16,13 @@
|
||||
#include <concretelang/Support/Constants.h>
|
||||
#include <concretelang/Support/math.h>
|
||||
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/IR/Attributes.h>
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinAttributes.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/IRMapping.h>
|
||||
#include <mlir/IR/OperationSupport.h>
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
#include <mlir/Interfaces/ViewLikeInterface.h>
|
||||
@@ -94,7 +94,7 @@ aggregateBeneficiaryOps(Operation *op, SetVector<Operation *> &beneficiaryOps,
|
||||
}
|
||||
|
||||
LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) {
|
||||
Region &taskOpBody = taskOp.body();
|
||||
Region &taskOpBody = taskOp.getBody();
|
||||
|
||||
// Identify uses from values defined outside of the scope.
|
||||
SetVector<Value> sinkCandidates;
|
||||
@@ -111,7 +111,7 @@ LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) {
|
||||
}
|
||||
|
||||
// Insert operations so that the defs get cloned before uses.
|
||||
BlockAndValueMapping map;
|
||||
IRMapping map;
|
||||
OpBuilder builder(taskOpBody);
|
||||
for (Operation *op : toBeSunk) {
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
@@ -156,7 +156,7 @@ struct BuildDataflowTaskGraphPass
|
||||
protected:
|
||||
void processOperation(mlir::Operation *op) {
|
||||
if (isCandidateForTask(op)) {
|
||||
BlockAndValueMapping map;
|
||||
IRMapping map;
|
||||
Region &opBody = getOperation().getBody();
|
||||
OpBuilder builder(opBody);
|
||||
|
||||
@@ -166,7 +166,7 @@ protected:
|
||||
op->getLoc(), op->getResultTypes(), op->getOperands());
|
||||
|
||||
// Add the operation to the task
|
||||
OpBuilder tbbuilder(dftop.body());
|
||||
OpBuilder tbbuilder(dftop.getBody());
|
||||
Operation *clonedOp = tbbuilder.clone(*op, map);
|
||||
|
||||
// Coarsen granularity by aggregating all dependence related
|
||||
@@ -180,7 +180,7 @@ protected:
|
||||
// Replace the uses of defined values
|
||||
for (auto pair : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair),
|
||||
dftop.body());
|
||||
dftop.getBody());
|
||||
// Replace uses of the values defined by the task
|
||||
for (auto pair : llvm::zip(op->getResults(), dftop->getResults()))
|
||||
replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair),
|
||||
|
||||
@@ -21,12 +21,12 @@
|
||||
#include <concretelang/Support/math.h>
|
||||
|
||||
#include <llvm/IR/Instructions.h>
|
||||
#include <mlir/Analysis/DataFlowAnalysis.h>
|
||||
#include <mlir/Analysis/DataFlowFramework.h>
|
||||
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
|
||||
#include <mlir/Conversion/LLVMCommon/Pattern.h>
|
||||
#include <mlir/Conversion/LLVMCommon/VectorPattern.h>
|
||||
#include <mlir/Dialect/Affine/Utils.h>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Bufferization/Transforms/Passes.h>
|
||||
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
@@ -35,10 +35,10 @@
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/IR/Attributes.h>
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinAttributes.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/IRMapping.h>
|
||||
#include <mlir/IR/SymbolTable.h>
|
||||
#include <mlir/Interfaces/ViewLikeInterface.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
@@ -60,7 +60,7 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
|
||||
StringRef workFunctionName) {
|
||||
Location loc = DFTOp.getLoc();
|
||||
OpBuilder builder(DFTOp.getContext());
|
||||
Region &DFTOpBody = DFTOp.body();
|
||||
Region &DFTOpBody = DFTOp.getBody();
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
|
||||
// Instead of outlining with the same operands/results, we pass all
|
||||
@@ -82,7 +82,7 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
|
||||
outlinedEntryBlock->addArguments(type.getInputs(), locations);
|
||||
outlinedFuncBody.push_back(outlinedEntryBlock);
|
||||
|
||||
BlockAndValueMapping map;
|
||||
IRMapping map;
|
||||
int input_offset = DFTOp.getNumResults();
|
||||
Block &entryBlock = outlinedFuncBody.front();
|
||||
builder.setInsertionPointToStart(&entryBlock);
|
||||
@@ -241,7 +241,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
// unsupported even in the LLVMIR Dialect - this needs to use two
|
||||
// placeholders for each output, before and after the
|
||||
// CreateAsyncTaskOp.
|
||||
BlockAndValueMapping map;
|
||||
IRMapping map;
|
||||
for (auto result : DFTOp.getResults()) {
|
||||
Type futType = RT::PointerType::get(RT::FutureType::get(result.getType()));
|
||||
auto brpp = builder.create<RT::BuildReturnPtrPlaceholderOp>(DFTOp.getLoc(),
|
||||
|
||||
@@ -21,20 +21,20 @@
|
||||
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
|
||||
#include <llvm/IR/Instructions.h>
|
||||
#include <llvm/Support/Compiler.h>
|
||||
#include <mlir/Analysis/DataFlowAnalysis.h>
|
||||
#include <mlir/Analysis/DataFlowFramework.h>
|
||||
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
|
||||
#include <mlir/Conversion/LLVMCommon/Pattern.h>
|
||||
#include <mlir/Conversion/LLVMCommon/VectorPattern.h>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
|
||||
#include <mlir/Dialect/LLVMIR/FunctionCallUtils.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/IR/Attributes.h>
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinAttributes.h>
|
||||
#include <mlir/IR/IRMapping.h>
|
||||
#include <mlir/IR/SymbolTable.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
#include <mlir/Support/LLVM.h>
|
||||
@@ -113,16 +113,18 @@ struct MakeReadyFutureOpInterfaceLowering
|
||||
// explicitly space that we can reference as a base for the
|
||||
// future.
|
||||
auto allocFuncOp = mlir::LLVM::lookupOrCreateMallocFn(
|
||||
mrfOp->getParentOfType<ModuleOp>(), getIndexType());
|
||||
mrfOp->getParentOfType<ModuleOp>(), getIndexType(),
|
||||
getTypeConverter()->useOpaquePointers());
|
||||
auto sizeBytes = getSizeInBytes(
|
||||
mrfOp.getLoc(), adaptor.getOperands().getTypes().front(), rewriter);
|
||||
auto results = mlir::LLVM::createLLVMCall(
|
||||
rewriter, mrfOp.getLoc(), allocFuncOp, {sizeBytes}, getVoidPtrType());
|
||||
|
||||
auto results =
|
||||
rewriter.create<LLVM::CallOp>(mrfOp.getLoc(), allocFuncOp, sizeBytes);
|
||||
Value allocatedPtr = rewriter.create<mlir::LLVM::BitcastOp>(
|
||||
mrfOp.getLoc(),
|
||||
mlir::LLVM::LLVMPointerType::get(
|
||||
adaptor.getOperands().getTypes().front()),
|
||||
results[0]);
|
||||
results.getResult());
|
||||
rewriter.create<LLVM::StoreOp>(mrfOp.getLoc(),
|
||||
adaptor.getOperands().front(), allocatedPtr);
|
||||
SmallVector<Value, 4> mrfOperands = {adaptor.getOperands()};
|
||||
@@ -150,7 +152,7 @@ struct AwaitFutureOpInterfaceLowering
|
||||
afOp.getLoc(),
|
||||
mlir::LLVM::LLVMPointerType::get(
|
||||
(*getTypeConverter()).convertType(afOp.getResult().getType())),
|
||||
afCallOp.getResult(0));
|
||||
afCallOp.getResult());
|
||||
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(afOp, futVal);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -42,7 +42,8 @@ void RTDialect::initialize() {
|
||||
::mlir::Type RTDialect::parseType(::mlir::DialectAsmParser &parser) const {
|
||||
mlir::Type type;
|
||||
if (parser.parseOptionalKeyword("future").succeeded()) {
|
||||
generatedTypeParser(parser, "future", type);
|
||||
llvm::StringRef mnenomic;
|
||||
generatedTypeParser(parser, &mnenomic, type);
|
||||
return type;
|
||||
}
|
||||
return type;
|
||||
|
||||
@@ -35,65 +35,65 @@ void DataflowTaskOp::getSuccessorRegions(
|
||||
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {}
|
||||
|
||||
llvm::Optional<mlir::Operation *>
|
||||
std::optional<mlir::Operation *>
|
||||
DataflowTaskOp::buildDealloc(OpBuilder &builder, Value alloc) {
|
||||
return builder.create<DeallocateFutureOp>(alloc.getLoc(), alloc)
|
||||
.getOperation();
|
||||
}
|
||||
llvm::Optional<mlir::Value> DataflowTaskOp::buildClone(OpBuilder &builder,
|
||||
Value alloc) {
|
||||
std::optional<mlir::Value> DataflowTaskOp::buildClone(OpBuilder &builder,
|
||||
Value alloc) {
|
||||
return builder.create<CloneFutureOp>(alloc.getLoc(), alloc).getResult();
|
||||
}
|
||||
void DataflowTaskOp::getEffects(
|
||||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
||||
&effects) {
|
||||
for (auto input : inputs())
|
||||
for (auto input : getInputs())
|
||||
effects.emplace_back(MemoryEffects::Read::get(), input,
|
||||
SideEffects::DefaultResource::get());
|
||||
for (auto output : outputs())
|
||||
for (auto output : getOutputs())
|
||||
effects.emplace_back(MemoryEffects::Write::get(), output,
|
||||
SideEffects::DefaultResource::get());
|
||||
for (auto output : outputs())
|
||||
for (auto output : getOutputs())
|
||||
effects.emplace_back(MemoryEffects::Allocate::get(), output,
|
||||
SideEffects::DefaultResource::get());
|
||||
}
|
||||
|
||||
llvm::Optional<mlir::Operation *>
|
||||
CloneFutureOp::buildDealloc(OpBuilder &builder, Value alloc) {
|
||||
std::optional<mlir::Operation *> CloneFutureOp::buildDealloc(OpBuilder &builder,
|
||||
Value alloc) {
|
||||
return builder.create<DeallocateFutureOp>(alloc.getLoc(), alloc)
|
||||
.getOperation();
|
||||
}
|
||||
llvm::Optional<mlir::Value> CloneFutureOp::buildClone(OpBuilder &builder,
|
||||
Value alloc) {
|
||||
std::optional<mlir::Value> CloneFutureOp::buildClone(OpBuilder &builder,
|
||||
Value alloc) {
|
||||
return builder.create<CloneFutureOp>(alloc.getLoc(), alloc).getResult();
|
||||
}
|
||||
void CloneFutureOp::getEffects(
|
||||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
||||
&effects) {
|
||||
effects.emplace_back(MemoryEffects::Read::get(), input(),
|
||||
effects.emplace_back(MemoryEffects::Read::get(), getInput(),
|
||||
SideEffects::DefaultResource::get());
|
||||
effects.emplace_back(MemoryEffects::Write::get(), output(),
|
||||
effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
|
||||
SideEffects::DefaultResource::get());
|
||||
effects.emplace_back(MemoryEffects::Allocate::get(), output(),
|
||||
effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
|
||||
SideEffects::DefaultResource::get());
|
||||
}
|
||||
|
||||
llvm::Optional<mlir::Operation *>
|
||||
std::optional<mlir::Operation *>
|
||||
MakeReadyFutureOp::buildDealloc(OpBuilder &builder, Value alloc) {
|
||||
return builder.create<DeallocateFutureOp>(alloc.getLoc(), alloc)
|
||||
.getOperation();
|
||||
}
|
||||
llvm::Optional<mlir::Value> MakeReadyFutureOp::buildClone(OpBuilder &builder,
|
||||
Value alloc) {
|
||||
std::optional<mlir::Value> MakeReadyFutureOp::buildClone(OpBuilder &builder,
|
||||
Value alloc) {
|
||||
return builder.create<CloneFutureOp>(alloc.getLoc(), alloc).getResult();
|
||||
}
|
||||
void MakeReadyFutureOp::getEffects(
|
||||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
||||
&effects) {
|
||||
effects.emplace_back(MemoryEffects::Read::get(), input(),
|
||||
effects.emplace_back(MemoryEffects::Read::get(), getInput(),
|
||||
SideEffects::DefaultResource::get());
|
||||
effects.emplace_back(MemoryEffects::Write::get(), output(),
|
||||
effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
|
||||
SideEffects::DefaultResource::get());
|
||||
effects.emplace_back(MemoryEffects::Allocate::get(), output(),
|
||||
effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
|
||||
SideEffects::DefaultResource::get());
|
||||
}
|
||||
|
||||
@@ -36,14 +36,14 @@ struct DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
return BufferRelation::Unknown;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
|
||||
@@ -70,7 +70,7 @@ struct DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface
|
||||
if (failed(bufferOrErr))
|
||||
return failure();
|
||||
|
||||
Value buffer = bufferOrErr.getValue();
|
||||
Value buffer = bufferOrErr.value();
|
||||
newOperands.push_back(buffer);
|
||||
} else {
|
||||
newOperands.push_back(opOperand.get());
|
||||
@@ -81,7 +81,7 @@ struct DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface
|
||||
|
||||
for (OpResult res : op->getResults()) {
|
||||
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
|
||||
BaseMemRefType memrefType = getMemRefType(t, options);
|
||||
BaseMemRefType memrefType = getMemRefType(res, options);
|
||||
newResultTypes.push_back(memrefType);
|
||||
} else {
|
||||
newResultTypes.push_back(res.getType());
|
||||
@@ -112,14 +112,14 @@ struct MakeReadyFutureOpBufferizationInterface
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
return BufferRelation::Unknown;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
|
||||
@@ -145,7 +145,7 @@ struct MakeReadyFutureOpBufferizationInterface
|
||||
if (failed(bufferOrErr))
|
||||
return failure();
|
||||
|
||||
Value buffer = bufferOrErr.getValue();
|
||||
Value buffer = bufferOrErr.value();
|
||||
newOperands.push_back(buffer);
|
||||
} else {
|
||||
newOperands.push_back(opOperand.get());
|
||||
@@ -156,7 +156,7 @@ struct MakeReadyFutureOpBufferizationInterface
|
||||
|
||||
for (OpResult res : op->getResults()) {
|
||||
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
|
||||
BaseMemRefType memrefType = getMemRefType(t, options);
|
||||
BaseMemRefType memrefType = getMemRefType(res, options);
|
||||
newResultTypes.push_back(memrefType);
|
||||
} else {
|
||||
newResultTypes.push_back(res.getType());
|
||||
@@ -186,14 +186,14 @@ struct WorkFunctionReturnOpBufferizationInterface
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
return BufferRelation::Unknown;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
|
||||
@@ -219,7 +219,7 @@ struct WorkFunctionReturnOpBufferizationInterface
|
||||
if (failed(bufferOrErr))
|
||||
return failure();
|
||||
|
||||
Value buffer = bufferOrErr.getValue();
|
||||
Value buffer = bufferOrErr.value();
|
||||
newOperands.push_back(buffer);
|
||||
} else {
|
||||
newOperands.push_back(opOperand.get());
|
||||
@@ -230,7 +230,7 @@ struct WorkFunctionReturnOpBufferizationInterface
|
||||
|
||||
for (OpResult res : op->getResults()) {
|
||||
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
|
||||
BaseMemRefType memrefType = getMemRefType(t, options);
|
||||
BaseMemRefType memrefType = getMemRefType(res, options);
|
||||
newResultTypes.push_back(memrefType);
|
||||
} else {
|
||||
newResultTypes.push_back(res.getType());
|
||||
@@ -260,14 +260,14 @@ struct AwaitFutureOpBufferizationInterface
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
return BufferRelation::Unknown;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
|
||||
@@ -293,7 +293,7 @@ struct AwaitFutureOpBufferizationInterface
|
||||
if (failed(bufferOrErr))
|
||||
return failure();
|
||||
|
||||
Value buffer = bufferOrErr.getValue();
|
||||
Value buffer = bufferOrErr.value();
|
||||
newOperands.push_back(buffer);
|
||||
} else {
|
||||
newOperands.push_back(opOperand.get());
|
||||
@@ -304,7 +304,7 @@ struct AwaitFutureOpBufferizationInterface
|
||||
|
||||
for (OpResult res : op->getResults()) {
|
||||
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
|
||||
BaseMemRefType memrefType = getMemRefType(t, options);
|
||||
BaseMemRefType memrefType = getMemRefType(res, options);
|
||||
newResultTypes.push_back(memrefType);
|
||||
} else {
|
||||
newResultTypes.push_back(res.getType());
|
||||
|
||||
@@ -7,7 +7,7 @@ add_mlir_dialect_library(
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRArithmeticDialect
|
||||
MLIRArithDialect
|
||||
MLIRBufferizationDialect
|
||||
MLIRBufferizationTransforms
|
||||
MLIRIR
|
||||
|
||||
@@ -19,8 +19,8 @@ namespace concretelang {
|
||||
namespace SDFG {
|
||||
mlir::LogicalResult Put::verify() {
|
||||
mlir::Type streamElementType =
|
||||
stream().getType().cast<StreamType>().getElementType();
|
||||
mlir::Type elementType = data().getType();
|
||||
getStream().getType().cast<StreamType>().getElementType();
|
||||
mlir::Type elementType = getData().getType();
|
||||
|
||||
if (streamElementType != elementType) {
|
||||
emitError()
|
||||
@@ -34,10 +34,10 @@ mlir::LogicalResult Put::verify() {
|
||||
}
|
||||
|
||||
mlir::LogicalResult MakeProcess::checkStreams(size_t numIn, size_t numOut) {
|
||||
mlir::OperandRange streams = this->streams();
|
||||
mlir::OperandRange streams = this->getStreams();
|
||||
|
||||
if (streams.size() != numIn + numOut) {
|
||||
emitError() << "Process `" << stringifyProcessKind(type())
|
||||
emitError() << "Process `" << stringifyProcessKind(getType())
|
||||
<< "` expects 3 streams, but " << streams.size()
|
||||
<< " were given.";
|
||||
return mlir::failure();
|
||||
@@ -48,7 +48,7 @@ mlir::LogicalResult MakeProcess::checkStreams(size_t numIn, size_t numOut) {
|
||||
|
||||
if (in && !in.createsInputStream()) {
|
||||
emitError() << "Stream #" << (i + 1) << " of process `"
|
||||
<< stringifyProcessKind(type())
|
||||
<< stringifyProcessKind(getType())
|
||||
<< "` must be an input stream.";
|
||||
return mlir::failure();
|
||||
}
|
||||
@@ -59,7 +59,7 @@ mlir::LogicalResult MakeProcess::checkStreams(size_t numIn, size_t numOut) {
|
||||
|
||||
if (out && !out.createsOutputStream()) {
|
||||
emitError() << "Stream #" << (i + 1) << " of process `"
|
||||
<< stringifyProcessKind(type())
|
||||
<< stringifyProcessKind(getType())
|
||||
<< "` must be an output stream.";
|
||||
return mlir::failure();
|
||||
}
|
||||
@@ -69,7 +69,7 @@ mlir::LogicalResult MakeProcess::checkStreams(size_t numIn, size_t numOut) {
|
||||
}
|
||||
|
||||
mlir::LogicalResult MakeProcess::verify() {
|
||||
switch (type()) {
|
||||
switch (getType()) {
|
||||
case ProcessKind::add_eint:
|
||||
return checkStreams(2, 1);
|
||||
case ProcessKind::add_eint_int:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
@@ -43,7 +43,7 @@ namespace {} // namespace
|
||||
namespace {
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
size_t rank) {
|
||||
std::vector<int64_t> shape(rank, -1);
|
||||
std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
|
||||
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
|
||||
for (size_t i = 0; i < rank; i++) {
|
||||
expr = expr +
|
||||
@@ -84,14 +84,14 @@ struct BufferizableWithCallOpInterface
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
return BufferRelation::Unknown;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
|
||||
@@ -14,7 +14,7 @@ add_mlir_dialect_library(
|
||||
SDFGDialect
|
||||
ConcretelangSDFGInterfaces
|
||||
ConcretelangConversion
|
||||
MLIRArithmeticDialect
|
||||
MLIRArithDialect
|
||||
MLIRBufferizationDialect
|
||||
MLIRBufferizationTransforms
|
||||
MLIRIR
|
||||
|
||||
@@ -58,8 +58,8 @@ mlir::LogicalResult _verifyGLWEIntegerOperator(mlir::OpState &op,
|
||||
/// (!TFHE.glwe<{dim,poly,bits}{p}>))
|
||||
template <class Operator>
|
||||
mlir::LogicalResult verifyGLWEIntegerOperator(Operator &op) {
|
||||
auto a = ((mlir::Type)(op.a().getType())).cast<GLWECipherTextType>();
|
||||
auto b = ((mlir::Type)(op.b().getType())).cast<IntegerType>();
|
||||
auto a = ((mlir::Type)(op.getA().getType())).cast<GLWECipherTextType>();
|
||||
auto b = ((mlir::Type)(op.getB().getType())).cast<IntegerType>();
|
||||
auto result =
|
||||
((mlir::Type)(op.getResult().getType())).cast<GLWECipherTextType>();
|
||||
|
||||
@@ -71,8 +71,8 @@ mlir::LogicalResult verifyGLWEIntegerOperator(Operator &op) {
|
||||
/// (!TFHE.glwe<{dim,poly,bits}{p}>))
|
||||
template <class Operator>
|
||||
mlir::LogicalResult verifyIntegerGLWEOperator(Operator &op) {
|
||||
auto a = ((mlir::Type)(op.a().getType())).cast<IntegerType>();
|
||||
auto b = ((mlir::Type)(op.b().getType())).cast<GLWECipherTextType>();
|
||||
auto a = ((mlir::Type)(op.getA().getType())).cast<IntegerType>();
|
||||
auto b = ((mlir::Type)(op.getB().getType())).cast<GLWECipherTextType>();
|
||||
auto result =
|
||||
((mlir::Type)(op.getResult().getType())).cast<GLWECipherTextType>();
|
||||
|
||||
@@ -85,8 +85,8 @@ mlir::LogicalResult verifyIntegerGLWEOperator(Operator &op) {
|
||||
/// (!TFHE.glwe<{dim,poly,bits}{p}>))
|
||||
template <class Operator>
|
||||
mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) {
|
||||
auto a = ((mlir::Type)(op.a().getType())).cast<GLWECipherTextType>();
|
||||
auto b = ((mlir::Type)(op.b().getType())).cast<GLWECipherTextType>();
|
||||
auto a = ((mlir::Type)(op.getA().getType())).cast<GLWECipherTextType>();
|
||||
auto b = ((mlir::Type)(op.getB().getType())).cast<GLWECipherTextType>();
|
||||
auto result =
|
||||
((mlir::Type)(op.getResult().getType())).cast<GLWECipherTextType>();
|
||||
|
||||
@@ -114,7 +114,7 @@ mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) {
|
||||
/// (!TFHE.glwe<{dim,poly,bits}{p}>))
|
||||
template <class Operator>
|
||||
mlir::LogicalResult verifyUnaryGLWEOperator(Operator &op) {
|
||||
auto a = ((mlir::Type)(op.a().getType())).cast<GLWECipherTextType>();
|
||||
auto a = ((mlir::Type)(op.getA().getType())).cast<GLWECipherTextType>();
|
||||
auto result =
|
||||
((mlir::Type)(op.getResult().getType())).cast<GLWECipherTextType>();
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
|
||||
|
||||
@@ -17,7 +17,7 @@ namespace concretelang {
|
||||
namespace {
|
||||
|
||||
/// Get the constant integer that the cleartext was created from if it exists.
|
||||
llvm::Optional<IntegerAttr>
|
||||
std::optional<IntegerAttr>
|
||||
getConstantIntFromCleartextIfExists(mlir::Value cleartext) {
|
||||
auto constantOp = cleartext.getDefiningOp();
|
||||
if (constantOp == nullptr)
|
||||
@@ -46,8 +46,8 @@ public:
|
||||
auto cleartext = op.getOperand(1);
|
||||
auto constIntToMul = getConstantIntFromCleartextIfExists(cleartext);
|
||||
// Constant integer
|
||||
if (constIntToMul.hasValue()) {
|
||||
auto toMul = constIntToMul.getValue().getInt();
|
||||
if (constIntToMul.has_value()) {
|
||||
auto toMul = constIntToMul.value().getInt();
|
||||
if (toMul == 0) {
|
||||
rewriter.replaceOpWithNewOp<mlir::concretelang::TFHE::ZeroGLWEOp>(
|
||||
op, op.getResult().getType());
|
||||
|
||||
@@ -36,14 +36,14 @@ struct TrivialBufferizableInterface
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
return BufferRelation::Unknown;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
@@ -55,7 +55,7 @@ struct TrivialBufferizableInterface
|
||||
operands.push_back(operand.get());
|
||||
} else {
|
||||
operands.push_back(
|
||||
bufferization::getBuffer(rewriter, operand.get(), options));
|
||||
*bufferization::getBuffer(rewriter, operand.get(), options));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ add_mlir_dialect_library(
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRArithmeticDialect
|
||||
MLIRArithDialect
|
||||
MLIRBufferizationDialect
|
||||
MLIRBufferizationTransforms
|
||||
MLIRIR
|
||||
|
||||
@@ -28,8 +28,16 @@ if(APPLE)
|
||||
endif()
|
||||
|
||||
target_include_directories(ConcretelangRuntime PUBLIC ${CONCRETE_CPU_INCLUDE_DIR})
|
||||
target_link_libraries(ConcretelangRuntime PUBLIC concrete_cpu ConcretelangClientLib pthread m dl
|
||||
$<TARGET_OBJECTS:mlir_c_runner_utils>)
|
||||
target_link_libraries(
|
||||
ConcretelangRuntime
|
||||
PUBLIC concrete_cpu
|
||||
ConcretelangClientLib
|
||||
pthread
|
||||
m
|
||||
dl
|
||||
$<TARGET_OBJECTS:mlir_c_runner_utils>
|
||||
$<TARGET_OBJECTS:mlir_float16_utils>
|
||||
$<TARGET_OBJECTS:MLIRSparseTensorRuntime>)
|
||||
|
||||
install(TARGETS ConcretelangRuntime omp EXPORT ConcretelangRuntime)
|
||||
install(EXPORT ConcretelangRuntime DESTINATION "./")
|
||||
|
||||
@@ -57,7 +57,7 @@ void CompilationFeedback::fillFromClientParameters(
|
||||
crtDecompositionsOfOutputs = {};
|
||||
for (auto gate : params.outputs) {
|
||||
std::vector<int64_t> decomposition;
|
||||
if (gate.encryption.hasValue()) {
|
||||
if (gate.encryption.has_value()) {
|
||||
decomposition = gate.encryption->encoding.crt;
|
||||
}
|
||||
crtDecompositionsOfOutputs.push_back(decomposition);
|
||||
|
||||
@@ -5,8 +5,9 @@
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
#include <mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <stdio.h>
|
||||
@@ -92,6 +93,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() {
|
||||
registry);
|
||||
scf::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
tensor::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
linalg::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
RT::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
this->mlirContext = new mlir::MLIRContext();
|
||||
this->mlirContext->appendDialectRegistry(registry);
|
||||
@@ -136,17 +138,16 @@ void CompilerEngine::setEnablePass(
|
||||
}
|
||||
|
||||
/// Returns the optimizer::Description
|
||||
llvm::Expected<llvm::Optional<optimizer::Description>>
|
||||
llvm::Expected<std::optional<optimizer::Description>>
|
||||
CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) {
|
||||
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
|
||||
mlir::ModuleOp module = res.mlirModuleRef->get();
|
||||
// If the values has been overwritten returns
|
||||
if (this->overrideMaxEintPrecision.hasValue() &&
|
||||
this->overrideMaxMANP.hasValue()) {
|
||||
if (this->overrideMaxEintPrecision.has_value() &&
|
||||
this->overrideMaxMANP.has_value()) {
|
||||
auto constraint = mlir::concretelang::V0FHEConstraint{
|
||||
this->overrideMaxMANP.getValue(),
|
||||
this->overrideMaxEintPrecision.getValue()};
|
||||
return optimizer::Description{constraint, llvm::None};
|
||||
this->overrideMaxMANP.value(), this->overrideMaxEintPrecision.value()};
|
||||
return optimizer::Description{constraint, std::nullopt};
|
||||
}
|
||||
auto config = this->compilerOptions.optimizerConfig;
|
||||
auto descriptions = mlir::concretelang::pipeline::getFHEContextFromFHE(
|
||||
@@ -155,10 +156,10 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) {
|
||||
return std::move(err);
|
||||
}
|
||||
if (descriptions->empty()) { // The pass has not been run
|
||||
return llvm::None;
|
||||
return std::nullopt;
|
||||
}
|
||||
if (this->compilerOptions.clientParametersFuncName.hasValue()) {
|
||||
auto name = this->compilerOptions.clientParametersFuncName.getValue();
|
||||
if (this->compilerOptions.clientParametersFuncName.has_value()) {
|
||||
auto name = this->compilerOptions.clientParametersFuncName.value();
|
||||
auto description = descriptions->find(name);
|
||||
if (description == descriptions->end()) {
|
||||
std::string names;
|
||||
@@ -181,14 +182,14 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) {
|
||||
/// set the fheContext field if the v0Constraint can be computed
|
||||
/// set the fheContext field if the v0Constraint can be computed
|
||||
llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
|
||||
if (compilerOptions.v0Parameter.hasValue()) {
|
||||
if (compilerOptions.v0Parameter.has_value()) {
|
||||
// parameters come from the compiler options
|
||||
auto v0Params = compilerOptions.v0Parameter.value();
|
||||
if (compilerOptions.largeIntegerParameter.hasValue()) {
|
||||
if (compilerOptions.largeIntegerParameter.has_value()) {
|
||||
v0Params.largeInteger = compilerOptions.largeIntegerParameter;
|
||||
}
|
||||
V0FHEConstraint constraint;
|
||||
if (compilerOptions.v0FHEConstraints.hasValue()) {
|
||||
if (compilerOptions.v0FHEConstraints.has_value()) {
|
||||
constraint = compilerOptions.v0FHEConstraints.value();
|
||||
}
|
||||
res.fheContext.emplace(
|
||||
@@ -201,7 +202,7 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
|
||||
if (auto err = descr.takeError()) {
|
||||
return err;
|
||||
}
|
||||
if (!descr.get().hasValue()) {
|
||||
if (!descr.get().has_value()) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
CompilationFeedback feedback;
|
||||
@@ -222,7 +223,7 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
using OptionalLib = llvm::Optional<std::shared_ptr<CompilerEngine::Library>>;
|
||||
using OptionalLib = std::optional<std::shared_ptr<CompilerEngine::Library>>;
|
||||
// Compile the sources managed by the source manager `sm` to the
|
||||
// target dialect `target`. If successful, the result can be retrieved
|
||||
// using `getModule()` and `getLLVMModule()`, respectively depending
|
||||
@@ -347,28 +348,28 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
|
||||
// Generate client parameters if requested
|
||||
if (this->generateClientParameters) {
|
||||
if (!options.clientParametersFuncName.hasValue()) {
|
||||
if (!options.clientParametersFuncName.has_value()) {
|
||||
return StreamStringError(
|
||||
"Generation of client parameters requested, but no function name "
|
||||
"specified");
|
||||
}
|
||||
if (!res.fheContext.hasValue()) {
|
||||
if (!res.fheContext.has_value()) {
|
||||
return StreamStringError(
|
||||
"Cannot generate client parameters, the fhe context is empty for " +
|
||||
options.clientParametersFuncName.getValue());
|
||||
options.clientParametersFuncName.value());
|
||||
}
|
||||
}
|
||||
// Generate client parameters if requested
|
||||
auto funcName = options.clientParametersFuncName.getValueOr("main");
|
||||
auto funcName = options.clientParametersFuncName.value_or("main");
|
||||
if (this->generateClientParameters || target == Target::LIBRARY) {
|
||||
if (!res.fheContext.hasValue()) {
|
||||
if (!res.fheContext.has_value()) {
|
||||
// Some tests involve call a to non encrypted functions
|
||||
ClientParameters emptyParams;
|
||||
emptyParams.functionName = funcName;
|
||||
res.clientParameters = emptyParams;
|
||||
} else {
|
||||
llvm::Optional<::concretelang::clientlib::ChunkInfo> chunkInfo =
|
||||
llvm::None;
|
||||
std::nullopt;
|
||||
if (options.chunkIntegers) {
|
||||
chunkInfo = ::concretelang::clientlib::ChunkInfo{options.chunkSize,
|
||||
options.chunkWidth};
|
||||
@@ -483,7 +484,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
return StreamStringError(
|
||||
"Internal Error: Please provide a library parameter");
|
||||
}
|
||||
auto objPath = lib.getValue()->addCompilation(res);
|
||||
auto objPath = lib.value()->addCompilation(res);
|
||||
if (!objPath) {
|
||||
return StreamStringError(llvm::toString(objPath.takeError()));
|
||||
}
|
||||
@@ -764,11 +765,11 @@ CompilerEngine::Library::addCompilation(CompilationResult &compilation) {
|
||||
}
|
||||
|
||||
addExtraObjectFilePath(objectPath);
|
||||
if (compilation.clientParameters.hasValue()) {
|
||||
clientParametersList.push_back(compilation.clientParameters.getValue());
|
||||
if (compilation.clientParameters.has_value()) {
|
||||
clientParametersList.push_back(compilation.clientParameters.value());
|
||||
}
|
||||
if (compilation.feedback.hasValue()) {
|
||||
compilationFeedbackList.push_back(compilation.feedback.getValue());
|
||||
if (compilation.feedback.has_value()) {
|
||||
compilationFeedbackList.push_back(compilation.feedback.value());
|
||||
}
|
||||
return objectPath;
|
||||
}
|
||||
@@ -791,7 +792,7 @@ std::string ensureLibDotExt(std::string path, std::string dotExt) {
|
||||
|
||||
llvm::Expected<std::string> CompilerEngine::Library::emit(
|
||||
std::string path, std::string dotExt, std::string linker,
|
||||
llvm::Optional<std::vector<std::string>> extraArgs) {
|
||||
std::optional<std::vector<std::string>> extraArgs) {
|
||||
auto pathDotExt = ensureLibDotExt(path, dotExt);
|
||||
auto error = mlir::concretelang::emitLibrary(objectsPath, pathDotExt, linker,
|
||||
extraArgs);
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
JITSupport::JITSupport(llvm::Optional<std::string> runtimeLibPath)
|
||||
JITSupport::JITSupport(std::optional<std::string> runtimeLibPath)
|
||||
: runtimeLibPath(runtimeLibPath) {}
|
||||
|
||||
llvm::Expected<std::unique_ptr<JitCompilationResult>>
|
||||
@@ -29,7 +29,7 @@ JITSupport::compile(llvm::SourceMgr &program, CompilationOptions options) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
if (!options.clientParametersFuncName.hasValue()) {
|
||||
if (!options.clientParametersFuncName.has_value()) {
|
||||
return StreamStringError("Need to have a funcname to JIT compile");
|
||||
}
|
||||
// Compile from LLVM Dialect to JITLambda
|
||||
@@ -40,7 +40,7 @@ JITSupport::compile(llvm::SourceMgr &program, CompilationOptions options) {
|
||||
if (auto err = lambda.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
if (!compilationResult.get().clientParameters.hasValue()) {
|
||||
if (!compilationResult.get().clientParameters.has_value()) {
|
||||
// i.e. that should not occurs
|
||||
return StreamStringError("No client parameters has been generated");
|
||||
}
|
||||
@@ -52,9 +52,8 @@ JITSupport::compile(llvm::SourceMgr &program, CompilationOptions options) {
|
||||
if (!mlir::concretelang::dfr::_dfr_is_root_node()) {
|
||||
result->clientParameters = clientlib::ClientParameters();
|
||||
} else {
|
||||
result->clientParameters =
|
||||
compilationResult.get().clientParameters.getValue();
|
||||
result->feedback = compilationResult.get().feedback.getValue();
|
||||
result->clientParameters = compilationResult.get().clientParameters.value();
|
||||
result->feedback = compilationResult.get().feedback.value();
|
||||
}
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ namespace concretelang {
|
||||
llvm::Expected<std::unique_ptr<JITLambda>>
|
||||
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::Optional<std::string> runtimeLibPath) {
|
||||
std::optional<std::string> runtimeLibPath) {
|
||||
|
||||
// Looking for the function
|
||||
auto rangeOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
|
||||
@@ -46,13 +46,13 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
// JIT-compiles the module. If runtimeLibPath is specified, it's passed as a
|
||||
// shared library to the JIT compiler.
|
||||
std::vector<llvm::StringRef> sharedLibPaths;
|
||||
if (runtimeLibPath.hasValue())
|
||||
sharedLibPaths.push_back(runtimeLibPath.getValue());
|
||||
if (runtimeLibPath.has_value())
|
||||
sharedLibPaths.push_back(runtimeLibPath.value());
|
||||
|
||||
mlir::ExecutionEngineOptions execOptions;
|
||||
execOptions.transformer = optPipeline;
|
||||
execOptions.sharedLibPaths = sharedLibPaths;
|
||||
execOptions.jitCodeGenOptLevel = llvm::None;
|
||||
execOptions.jitCodeGenOptLevel = std::nullopt;
|
||||
execOptions.llvmModuleBuilder = nullptr;
|
||||
|
||||
auto maybeEngine = mlir::ExecutionEngine::create(module, execOptions);
|
||||
|
||||
@@ -7,10 +7,10 @@
|
||||
|
||||
#include <llvm/IR/LegacyPassManager.h>
|
||||
#include <llvm/MC/TargetRegistry.h>
|
||||
#include <llvm/Support/Host.h>
|
||||
#include <llvm/Support/ToolOutputFile.h>
|
||||
#include <llvm/Target/TargetMachine.h>
|
||||
#include <llvm/Target/TargetOptions.h>
|
||||
#include <llvm/TargetParser/Host.h>
|
||||
|
||||
#include <mlir/Support/FileUtilities.h>
|
||||
|
||||
@@ -72,13 +72,13 @@ llvm::Error emitObject(llvm::Module &module, string objectPath) {
|
||||
}
|
||||
|
||||
string linkerCmd(vector<string> objectsPath, string libraryPath, string linker,
|
||||
llvm::Optional<vector<string>> extraArgs) {
|
||||
std::optional<vector<string>> extraArgs) {
|
||||
string cmd = linker + libraryPath;
|
||||
for (auto objectPath : objectsPath) {
|
||||
cmd += " " + objectPath;
|
||||
}
|
||||
if (extraArgs.hasValue()) {
|
||||
for (auto extraArg : extraArgs.getValue()) {
|
||||
if (extraArgs.has_value()) {
|
||||
for (auto extraArg : extraArgs.value()) {
|
||||
cmd += " " + extraArg;
|
||||
}
|
||||
}
|
||||
@@ -116,7 +116,7 @@ llvm::Error callCmd(string cmd) {
|
||||
|
||||
llvm::Error emitLibrary(vector<string> objectsPath, string libraryPath,
|
||||
string linker,
|
||||
llvm::Optional<vector<string>> extraArgs) {
|
||||
std::optional<vector<string>> extraArgs) {
|
||||
auto cmd = linkerCmd(objectsPath, libraryPath, linker, extraArgs);
|
||||
return callCmd(cmd);
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#include <mlir/Transforms/Passes.h>
|
||||
|
||||
#include <mlir/Dialect/Affine/Passes.h>
|
||||
#include <mlir/Dialect/Arithmetic/Transforms/Passes.h>
|
||||
#include <mlir/Dialect/Arith/Transforms/Passes.h>
|
||||
#include <mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h>
|
||||
#include <mlir/Dialect/Bufferization/Transforms/Passes.h>
|
||||
#include <mlir/Dialect/Linalg/Passes.h>
|
||||
@@ -81,12 +81,12 @@ addPotentiallyNestedPass(mlir::PassManager &pm, std::unique_ptr<Pass> pass,
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Expected<std::map<std::string, llvm::Optional<optimizer::Description>>>
|
||||
llvm::Expected<std::map<std::string, std::optional<optimizer::Description>>>
|
||||
getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
optimizer::Config config,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
llvm::Optional<size_t> oMax2norm;
|
||||
llvm::Optional<size_t> oMaxWidth;
|
||||
std::optional<size_t> oMax2norm;
|
||||
std::optional<size_t> oMaxWidth;
|
||||
optimizer::FunctionsDag dags;
|
||||
|
||||
mlir::PassManager pm(&context);
|
||||
@@ -99,10 +99,10 @@ getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
pm,
|
||||
mlir::concretelang::createMaxMANPPass(
|
||||
[&](const uint64_t manp, unsigned width) {
|
||||
if (!oMax2norm.hasValue() || oMax2norm.getValue() < manp)
|
||||
if (!oMax2norm.has_value() || oMax2norm.value() < manp)
|
||||
oMax2norm.emplace(manp);
|
||||
|
||||
if (!oMaxWidth.hasValue() || oMaxWidth.getValue() < width)
|
||||
if (!oMaxWidth.has_value() || oMaxWidth.value() < width)
|
||||
oMaxWidth.emplace(width);
|
||||
}),
|
||||
enablePass);
|
||||
@@ -112,28 +112,28 @@ getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
" required precision",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
llvm::Optional<mlir::concretelang::V0FHEConstraint> constraint = llvm::None;
|
||||
std::optional<mlir::concretelang::V0FHEConstraint> constraint = std::nullopt;
|
||||
|
||||
if (oMax2norm.hasValue() && oMaxWidth.hasValue()) {
|
||||
constraint = llvm::Optional<mlir::concretelang::V0FHEConstraint>(
|
||||
{/*.norm2 = */ ceilLog2(oMax2norm.getValue()),
|
||||
/*.p = */ oMaxWidth.getValue()});
|
||||
if (oMax2norm.has_value() && oMaxWidth.has_value()) {
|
||||
constraint = std::optional<mlir::concretelang::V0FHEConstraint>(
|
||||
{/*.norm2 = */ ceilLog2(oMax2norm.value()),
|
||||
/*.p = */ oMaxWidth.value()});
|
||||
}
|
||||
addPotentiallyNestedPass(pm, optimizer::createDagPass(config, dags),
|
||||
enablePass);
|
||||
if (pm.run(module.getOperation()).failed()) {
|
||||
return StreamStringError() << "Failed to create concrete-optimizer dag\n";
|
||||
}
|
||||
std::map<std::string, llvm::Optional<optimizer::Description>> descriptions;
|
||||
std::map<std::string, std::optional<optimizer::Description>> descriptions;
|
||||
for (auto &entry_dag : dags) {
|
||||
if (!constraint) {
|
||||
descriptions.insert(
|
||||
decltype(descriptions)::value_type(entry_dag.first, llvm::None));
|
||||
decltype(descriptions)::value_type(entry_dag.first, std::nullopt));
|
||||
continue;
|
||||
}
|
||||
optimizer::Description description = {*constraint,
|
||||
std::move(entry_dag.second)};
|
||||
llvm::Optional<optimizer::Description> opt_description{
|
||||
std::optional<optimizer::Description> opt_description{
|
||||
std::move(description)};
|
||||
descriptions.insert(decltype(descriptions)::value_type(
|
||||
entry_dag.first, std::move(opt_description)));
|
||||
@@ -191,7 +191,7 @@ transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelizeLoops, bool batchOperations) {
|
||||
mlir::PassManager pm(&context);
|
||||
@@ -241,11 +241,12 @@ transformFHEBigInt(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
if (fheContext.hasValue() && fheContext->parameter.largeInteger.hasValue()) {
|
||||
if (fheContext.has_value() &&
|
||||
fheContext->parameter.largeInteger.has_value()) {
|
||||
pipelinePrinting("FHEToTFHECrt", pm, context);
|
||||
auto dec =
|
||||
fheContext.value().parameter.largeInteger.value().crtDecomposition;
|
||||
@@ -255,7 +256,7 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::concretelang::createConvertFHEToTFHECrtPass(
|
||||
mlir::concretelang::CrtLoweringParameters(mods)),
|
||||
enablePass);
|
||||
} else if (fheContext.hasValue()) {
|
||||
} else if (fheContext.has_value()) {
|
||||
pipelinePrinting("FHEToTFHEScalar", pm, context);
|
||||
size_t polySize = fheContext.value().parameter.getPolynomialSize();
|
||||
addPotentiallyNestedPass(
|
||||
@@ -270,16 +271,16 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("TFHEToConcrete", pm, context);
|
||||
|
||||
if (fheContext.hasValue()) {
|
||||
if (fheContext.has_value()) {
|
||||
addPotentiallyNestedPass(
|
||||
pm,
|
||||
mlir::concretelang::createConvertTFHEGlobalParametrizationPass(
|
||||
fheContext.getValue()),
|
||||
fheContext.value()),
|
||||
enablePass);
|
||||
}
|
||||
|
||||
@@ -345,8 +346,12 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::bufferization::OneShotBufferizationOptions bufferizationOptions;
|
||||
bufferizationOptions.allowReturnAllocs = true;
|
||||
bufferizationOptions.printConflicts = true;
|
||||
bufferizationOptions.unknownTypeConversion = mlir::bufferization::
|
||||
OneShotBufferizationOptions::LayoutMapOption::IdentityLayoutMap;
|
||||
bufferizationOptions.unknownTypeConverterFn =
|
||||
[](Value value, Attribute memorySpace,
|
||||
const mlir::bufferization::BufferizationOptions &options) {
|
||||
return mlir::bufferization::getMemRefTypeWithStaticIdentityLayout(
|
||||
value.getType().cast<TensorType>(), memorySpace);
|
||||
};
|
||||
bufferizationOptions.bufferizeFunctionBoundaries = true;
|
||||
bufferizationOptions.createDeallocs = true;
|
||||
|
||||
@@ -354,6 +359,12 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::bufferization::createOneShotBufferizePass(bufferizationOptions);
|
||||
|
||||
addPotentiallyNestedPass(pm, std::move(comprBuffPass), enablePass);
|
||||
|
||||
// The bufferization may create `linalg.map` operations; Add another
|
||||
// conversion pass from linalg to loops
|
||||
addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass(),
|
||||
enablePass);
|
||||
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createBufferizeDataflowTaskOpsPass(), enablePass);
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <optional>
|
||||
|
||||
#include "concrete/curves.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
@@ -51,13 +52,13 @@ gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
|
||||
bool sign = type.isSignedInteger();
|
||||
|
||||
return CircuitGate{
|
||||
/*.encryption = */ llvm::None,
|
||||
/*.encryption = */ std::nullopt,
|
||||
/*.shape = */
|
||||
{/*.width = */ width,
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
/* .sign */ sign},
|
||||
/*.chunkInfo = */ llvm::None,
|
||||
/*.chunkInfo = */ std::nullopt,
|
||||
};
|
||||
}
|
||||
if (auto lweTy = type.dyn_cast_or_null<
|
||||
@@ -70,7 +71,7 @@ gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
|
||||
size_t width;
|
||||
uint64_t size = 0;
|
||||
std::vector<int64_t> dims;
|
||||
if (chunkInfo.hasValue()) {
|
||||
if (chunkInfo.has_value()) {
|
||||
width = chunkInfo->size;
|
||||
assert(lweTy.getWidth() % chunkInfo->width == 0);
|
||||
size = lweTy.getWidth() / chunkInfo->width;
|
||||
@@ -79,7 +80,7 @@ gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
|
||||
width = (size_t)lweTy.getWidth();
|
||||
}
|
||||
return CircuitGate{
|
||||
/* .encryption = */ llvm::Optional<EncryptionGate>({
|
||||
/* .encryption = */ std::optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
@@ -103,7 +104,7 @@ gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
|
||||
mlir::concretelang::FHE::EncryptedBooleanType>()) {
|
||||
size_t width = mlir::concretelang::FHE::EncryptedBooleanType::getWidth();
|
||||
return CircuitGate{
|
||||
/* .encryption = */ llvm::Optional<EncryptionGate>({
|
||||
/* .encryption = */ std::optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
@@ -120,7 +121,7 @@ gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
|
||||
/*.size = */ 0,
|
||||
/*.sign = */ false,
|
||||
},
|
||||
/*.chunkInfo = */ llvm::None,
|
||||
/*.chunkInfo = */ std::nullopt,
|
||||
};
|
||||
}
|
||||
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
@@ -189,7 +190,7 @@ createClientParametersForV0(V0FHEContext fheContext,
|
||||
bskParam.inputLweDimension = v0Param.nSmall;
|
||||
c.bootstrapKeys.push_back(bskParam);
|
||||
}
|
||||
if (v0Param.largeInteger.hasValue()) {
|
||||
if (v0Param.largeInteger.has_value()) {
|
||||
clientlib::PackingKeyswitchKeyParam param;
|
||||
param.inputSecretKeyID = clientlib::BIG_KEY;
|
||||
param.outputSecretKeyID = clientlib::BIG_KEY;
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
@@ -200,7 +201,7 @@ llvm::Expected<V0Parameter> getParameter(optimizer::Description &descr,
|
||||
|
||||
auto sol = (!descr.dag || config.strategy_v0)
|
||||
? getV0Parameter(descr.constraint, config)
|
||||
: getV1Parameter(descr.dag.getValue(), config);
|
||||
: getV1Parameter(descr.dag.value(), config);
|
||||
|
||||
auto stop = chrono::high_resolution_clock::now();
|
||||
auto duration = chrono::duration_cast<chrono::milliseconds>(stop - start);
|
||||
@@ -235,7 +236,7 @@ llvm::Expected<V0Parameter> getParameter(optimizer::Description &descr,
|
||||
params.brLogBase = sol.br_decomposition_base_log;
|
||||
params.ksLevel = sol.ks_decomposition_level_count;
|
||||
params.ksLogBase = sol.ks_decomposition_base_log;
|
||||
params.largeInteger = llvm::None;
|
||||
params.largeInteger = std::nullopt;
|
||||
|
||||
if (sol.use_wop_pbs) {
|
||||
if (sol.crt_decomposition.empty()) {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// for license information.
|
||||
|
||||
#include <llvm/ADT/STLExtras.h>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/SCF/IR/SCF.h>
|
||||
#include <mlir/Dialect/Tensor/IR/Tensor.h>
|
||||
@@ -206,7 +206,7 @@ struct BoundsAndStep {
|
||||
/// Returns the lower bound, upper bound and step of the quasi-affine
|
||||
/// expression `expr` on the the induction variable from a for
|
||||
/// operation.
|
||||
static llvm::Optional<BoundsAndStep>
|
||||
static std::optional<BoundsAndStep>
|
||||
getBoundsOfQuasiAffineIVExpression(mlir::Value expr, mlir::scf::ForOp forOp) {
|
||||
// Base case: expression is the induction variable itself -> return
|
||||
// loop bounds
|
||||
@@ -220,13 +220,13 @@ getBoundsOfQuasiAffineIVExpression(mlir::Value expr, mlir::scf::ForOp forOp) {
|
||||
if (llvm::isa<mlir::arith::AddIOp, mlir::arith::SubIOp, mlir::arith::MulIOp,
|
||||
mlir::arith::DivSIOp>(op)) {
|
||||
|
||||
llvm::Optional<BoundsAndStep> lhs =
|
||||
std::optional<BoundsAndStep> lhs =
|
||||
getBoundsOfQuasiAffineIVExpression(op->getOperand(0), forOp);
|
||||
llvm::Optional<BoundsAndStep> rhs =
|
||||
std::optional<BoundsAndStep> rhs =
|
||||
getBoundsOfQuasiAffineIVExpression(op->getOperand(1), forOp);
|
||||
|
||||
if (!lhs.hasValue() || !rhs.hasValue())
|
||||
return llvm::None;
|
||||
if (!lhs.has_value() || !rhs.has_value())
|
||||
return std::nullopt;
|
||||
|
||||
if (llvm::isa<mlir::arith::AddIOp>(op))
|
||||
return *lhs + *rhs;
|
||||
@@ -245,7 +245,7 @@ getBoundsOfQuasiAffineIVExpression(mlir::Value expr, mlir::scf::ForOp forOp) {
|
||||
// the divisor, there may be two iterations with the same
|
||||
// value. Conservatively bail out.
|
||||
if (lhs->step % rhsVal != 0)
|
||||
return llvm::None;
|
||||
return std::nullopt;
|
||||
|
||||
return *lhs / rhsVal;
|
||||
}
|
||||
@@ -370,10 +370,10 @@ isQuasiAffineIVExpressionWithConstantStep(mlir::Value expr,
|
||||
mlir::scf::ForOp tmpForOp;
|
||||
|
||||
if (isQuasiAffineIVExpression(expr, &tmpForOp)) {
|
||||
llvm::Optional<BoundsAndStep> bas =
|
||||
std::optional<BoundsAndStep> bas =
|
||||
getBoundsOfQuasiAffineIVExpression(expr, tmpForOp);
|
||||
|
||||
if (bas.hasValue()) {
|
||||
if (bas.has_value()) {
|
||||
if (forOp != nullptr)
|
||||
*forOp = tmpForOp;
|
||||
return true;
|
||||
@@ -409,10 +409,10 @@ mlir::Value hoistIndexedOp(
|
||||
|
||||
if (isAffine && forOp) {
|
||||
|
||||
llvm::Optional<BoundsAndStep> bas =
|
||||
std::optional<BoundsAndStep> bas =
|
||||
getBoundsOfQuasiAffineIVExpression(idx, forOp);
|
||||
|
||||
assert(bas.hasValue());
|
||||
assert(bas.has_value());
|
||||
assert(bas->step != 0);
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(bas->lb));
|
||||
@@ -802,7 +802,7 @@ public:
|
||||
return mlir::WalkResult::skip();
|
||||
|
||||
if (!llvm::all_of(body, [&](mlir::Operation &op) {
|
||||
return MemoryEffectOpInterface::hasNoEffect(&op);
|
||||
return isMemoryEffectFree(&op);
|
||||
})) {
|
||||
return mlir::WalkResult::skip();
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
@@ -79,8 +79,7 @@ struct CollapseParallelLoopsPass
|
||||
if (maxPos > start)
|
||||
continue;
|
||||
|
||||
auto band =
|
||||
llvm::makeMutableArrayRef(loops.data() + start, end - start);
|
||||
auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
|
||||
(void)mlir::coalesceLoops(band);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
@@ -31,10 +31,11 @@ public:
|
||||
if (attr.getValue()) {
|
||||
rewriter.replaceOpWithNewOp<mlir::scf::ParallelOp>(
|
||||
forOp, mlir::ValueRange{forOp.getLowerBound()},
|
||||
mlir::ValueRange{forOp.getUpperBound()}, forOp.getStep(), llvm::None,
|
||||
mlir::ValueRange{forOp.getUpperBound()}, forOp.getStep(),
|
||||
std::nullopt,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location location,
|
||||
mlir::ValueRange indVar, mlir::ValueRange iterArgs) {
|
||||
mlir::BlockAndValueMapping map;
|
||||
mlir::IRMapping map;
|
||||
map.map(forOp.getInductionVar(), indVar.front());
|
||||
for (auto &op : forOp.getRegion().front()) {
|
||||
auto newOp = builder.clone(op, map);
|
||||
@@ -44,10 +45,10 @@ public:
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<mlir::scf::ForOp>(
|
||||
forOp, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(),
|
||||
llvm::None,
|
||||
std::nullopt,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location location,
|
||||
mlir::Value indVar, mlir::ValueRange iterArgs) {
|
||||
mlir::BlockAndValueMapping map;
|
||||
mlir::IRMapping map;
|
||||
map.map(forOp.getInductionVar(), indVar);
|
||||
for (auto &op : forOp.getRegion().front()) {
|
||||
auto newOp = builder.clone(op, map);
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include <mlir/Support/FileUtilities.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
#include <mlir/Support/ToolUtilities.h>
|
||||
#include <optional>
|
||||
#include <sstream>
|
||||
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
@@ -384,7 +385,7 @@ cmdlineCompilationOptions() {
|
||||
options.v0Parameter = {cmdline::v0Parameter[0], cmdline::v0Parameter[1],
|
||||
cmdline::v0Parameter[2], cmdline::v0Parameter[3],
|
||||
cmdline::v0Parameter[4], cmdline::v0Parameter[5],
|
||||
cmdline::v0Parameter[6], llvm::None};
|
||||
cmdline::v0Parameter[6], std::nullopt};
|
||||
}
|
||||
|
||||
// Setup the large integer options
|
||||
@@ -477,7 +478,7 @@ mlir::LogicalResult processInputBuffer(
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
|
||||
mlir::concretelang::CompilationContext::createShared();
|
||||
|
||||
std::string funcName = options.clientParametersFuncName.getValueOr("");
|
||||
std::string funcName = options.clientParametersFuncName.value_or("");
|
||||
if (action == Action::JIT_INVOKE) {
|
||||
auto lambdaOrErr =
|
||||
mlir::concretelang::ClientServer<mlir::concretelang::JITSupport>::
|
||||
|
||||
@@ -2,17 +2,16 @@
|
||||
|
||||
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-NEXT: module {
|
||||
// CHECK-NEXT: func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!FHE.eint<2>>) outs(%0 : tensor<2x3x4x!FHE.eint<2>>) {
|
||||
// CHECK-NEXT: ^bb0(%arg2: !FHE.eint<2>, %arg3: !FHE.eint<2>):
|
||||
// CHECK-NEXT: %2 = "FHE.apply_lookup_table"(%arg2, %arg1) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: linalg.yield %2 : !FHE.eint<2>
|
||||
// CHECK-NEXT: func.func @apply_lookup_table(%[[Varg0:.*]]: tensor<2x3x4x!FHE.eint<2>>, %[[Varg1:.*]]: tensor<4xi64>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = {{\[}}#map, #map{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Varg0]] : tensor<2x3x4x!FHE.eint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.eint<2>>) {
|
||||
// CHECK-NEXT: ^bb0(%[[Varg2:.*]]: !FHE.eint<2>, %[[Varg3:.*]]: !FHE.eint<2>):
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.apply_lookup_table"(%[[Varg2]], %[[Varg1]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: linalg.yield %[[V2]] : !FHE.eint<2>
|
||||
// CHECK-NEXT: } -> tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %1 : tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V1]] : tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<4xi64>) -> (tensor<2x3x4x!FHE.eint<2>>)
|
||||
return %1: tensor<2x3x4x!FHE.eint<2>>
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
// CHECK: func.func @main(%[[a0:.*]]: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x6x9x!FHE.eint<5>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = linalg.init_tensor [3, 2] : tensor<3x2xi64>
|
||||
// CHECK-NEXT: %[[v1:.*]] = tensor.empty() : tensor<3x2xi64>
|
||||
// CHECK-NEXT: %[[v2:.*]] = linalg.pooling_nchw_max {dilations = dense<1> : vector<2xi64>, max_signed = "FHE.max_eint", strides = dense<1> : vector<2xi64>} ins(%arg0, %1 : tensor<1x1x8x10x!FHE.eint<5>>, tensor<3x2xi64>) outs(%0 : tensor<1x1x6x9x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>>
|
||||
// CHECK-NEXT: return %[[v2]] : tensor<1x1x6x9x!FHE.eint<5>>
|
||||
// CHECK-NEXT: }
|
||||
@@ -23,12 +23,12 @@ func.func @main(%arg0: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.ein
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x5x3x!FHE.esint<6>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = arith.constant dense<16> : tensor<1xi7>
|
||||
// CHECK-NEXT: %[[v2:.*]] = bufferization.alloc_tensor() : tensor<1x1x5x3x!FHE.esint<6>>
|
||||
// CHECK-NEXT: %[[v3:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[v0]], %[[v1]] : tensor<1x1x5x3x!FHE.esint<6>>, tensor<1xi7>) outs(%[[v2]] : tensor<1x1x5x3x!FHE.esint<6>>) {
|
||||
// CHECK-NEXT: %[[v3:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[v0]], %[[v1]] : tensor<1x1x5x3x!FHE.esint<6>>, tensor<1xi7>) outs(%[[v2]] : tensor<1x1x5x3x!FHE.esint<6>>) {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.esint<6>, %[[aa1:.*]]: i7, %[[aa2:.*]]: !FHE.esint<6>):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.sub_eint_int"(%[[aa0]], %[[aa1]]) : (!FHE.esint<6>, i7) -> !FHE.esint<6>
|
||||
// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.esint<6>
|
||||
// CHECK-NEXT: } -> tensor<1x1x5x3x!FHE.esint<6>>
|
||||
// CHECK-NEXT: %[[v4:.*]] = linalg.init_tensor [2, 3] : tensor<2x3xi64>
|
||||
// CHECK-NEXT: %[[v4:.*]] = tensor.empty() : tensor<2x3xi64>
|
||||
// CHECK-NEXT: %[[v5:.*]] = linalg.pooling_nchw_max {dilations = dense<1> : vector<2xi64>, max_signed = "FHE.max_eint", strides = dense<1> : vector<2xi64>} ins(%arg0, %[[v4]] : tensor<1x1x6x5x!FHE.esint<6>>, tensor<2x3xi64>) outs(%[[v3]] : tensor<1x1x5x3x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>>
|
||||
// CHECK-NEXT: return %[[v5]] : tensor<1x1x5x3x!FHE.esint<6>>
|
||||
// CHECK-NEXT: }
|
||||
|
||||
@@ -2,17 +2,16 @@
|
||||
|
||||
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-NEXT: module {
|
||||
// CHECK-NEXT: func.func @neg_eint(%arg0: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!FHE.eint<2>>) outs(%0 : tensor<2x3x4x!FHE.eint<2>>) {
|
||||
// CHECK-NEXT: ^bb0(%arg1: !FHE.eint<2>, %arg2: !FHE.eint<2>):
|
||||
// CHECK-NEXT: %2 = "FHE.neg_eint"(%arg1) : (!FHE.eint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: linalg.yield %2 : !FHE.eint<2>
|
||||
// CHECK-NEXT: func.func @neg_eint(%[[Varg0:.*]]: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = {{\[}}#map, #map{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Varg0]] : tensor<2x3x4x!FHE.eint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.eint<2>>) {
|
||||
// CHECK-NEXT: ^bb0(%[[Varg1:.*]]: !FHE.eint<2>, %[[Varg2:.*]]: !FHE.eint<2>):
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.neg_eint"(%[[Varg1]]) : (!FHE.eint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: linalg.yield %[[V2]] : !FHE.eint<2>
|
||||
// CHECK-NEXT: } -> tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %1 : tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V1]] : tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
func.func @neg_eint(%arg0: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.neg_eint"(%arg0): (tensor<2x3x4x!FHE.eint<2>>) -> (tensor<2x3x4x!FHE.eint<2>>)
|
||||
return %1: tensor<2x3x4x!FHE.eint<2>>
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-NEXT: module {
|
||||
// CHECK-NEXT: func.func @main(%arg0: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.esint<2>> {
|
||||
// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.esint<2>>
|
||||
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!FHE.eint<2>>) outs(%0 : tensor<2x3x4x!FHE.esint<2>>) {
|
||||
// CHECK-NEXT: ^bb0(%arg1: !FHE.eint<2>, %arg2: !FHE.esint<2>):
|
||||
// CHECK-NEXT: %2 = "FHE.to_signed"(%arg1) : (!FHE.eint<2>) -> !FHE.esint<2>
|
||||
// CHECK-NEXT: linalg.yield %2 : !FHE.esint<2>
|
||||
// CHECK-NEXT: func.func @main(%[[Varg0:.*]]: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.esint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.esint<2>>
|
||||
// CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[Varg0]] : tensor<2x3x4x!FHE.eint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.esint<2>>) {
|
||||
// CHECK-NEXT: ^bb0(%[[Varg1:.*]]: !FHE.eint<2>, %[[Varg2:.*]]: !FHE.esint<2>):
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.to_signed"(%[[Varg1]]) : (!FHE.eint<2>) -> !FHE.esint<2>
|
||||
// CHECK-NEXT: linalg.yield %[[V2]] : !FHE.esint<2>
|
||||
// CHECK-NEXT: } -> tensor<2x3x4x!FHE.esint<2>>
|
||||
// CHECK-NEXT: return %1 : tensor<2x3x4x!FHE.esint<2>>
|
||||
// CHECK-NEXT: return %[[V1]] : tensor<2x3x4x!FHE.esint<2>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.esint<2>> {
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-NEXT: module {
|
||||
// CHECK-NEXT: func.func @main(%arg0: tensor<2x3x4x!FHE.esint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!FHE.esint<2>>) outs(%0 : tensor<2x3x4x!FHE.eint<2>>) {
|
||||
// CHECK-NEXT: ^bb0(%arg1: !FHE.esint<2>, %arg2: !FHE.eint<2>):
|
||||
// CHECK-NEXT: %2 = "FHE.to_unsigned"(%arg1) : (!FHE.esint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: linalg.yield %2 : !FHE.eint<2>
|
||||
// CHECK-NEXT: func.func @main(%[[Varg0:.*]]: tensor<2x3x4x!FHE.esint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[Varg0]] : tensor<2x3x4x!FHE.esint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.eint<2>>) {
|
||||
// CHECK-NEXT: ^bb0(%[[Varg1:.*]]: !FHE.esint<2>, %[[Varg2:.*]]: !FHE.eint<2>):
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.to_unsigned"(%[[Varg1]]) : (!FHE.esint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: linalg.yield %[[V2]] : !FHE.eint<2>
|
||||
// CHECK-NEXT: } -> tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %1 : tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V1]] : tensor<2x3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: tensor<2x3x4x!FHE.esint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
//CHECK-LABEL: func.func @add_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>, %arg1: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
|
||||
//CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
//CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
//CHECK-NEXT: %c1 = arith.constant 1 : index
|
||||
//CHECK-NEXT: %c5 = arith.constant 5 : index
|
||||
//CHECK-NEXT: %1 = scf.for %arg2 = %c0 to %c5 step %c1 iter_args(%arg3 = %0) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
|
||||
//CHECK-NEXT: %2 = tensor.extract %arg0[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
//CHECK-NEXT: %3 = tensor.extract %arg1[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
//CHECK-NEXT: %4 = "TFHE.add_glwe"(%2, %3) : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: %5 = tensor.insert %4 into %arg3[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
//CHECK-NEXT: scf.yield %5 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
//CHECK: func.func @add_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<{_,_,_}{7}>>, %[[Varg1:.*]]: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
|
||||
//CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
//CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
|
||||
//CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
|
||||
//CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
|
||||
//CHECK-NEXT: %[[V2:.*]] = tensor.extract %[[Varg0]][%[[Varg2]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
//CHECK-NEXT: %[[V3:.*]] = tensor.extract %[[Varg1]][%[[Varg2]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
//CHECK-NEXT: %[[V4:.*]] = "TFHE.add_glwe"(%[[V2]], %[[V3]]) : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: %[[V5:.*]] = tensor.insert %[[V4]] into %[[Varg3]][%[[Varg2]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
//CHECK-NEXT: scf.yield %[[V5]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
//CHECK-NEXT: return %[[V1]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @add_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
|
||||
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %0 = arith.extsi %c1_i8 : i8 to i64
|
||||
// CHECK-NEXT: %1 = "TFHE.encode_plaintext_with_crt"(%0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
// CHECK-NEXT: %2 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %c5 = arith.constant 5 : index
|
||||
// CHECK-NEXT: %3 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %2) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
|
||||
// CHECK-NEXT: %4 = tensor.extract %arg0[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %5 = tensor.extract %1[%arg1] : tensor<5xi64>
|
||||
// CHECK-NEXT: %6 = "TFHE.add_glwe_int"(%4, %5) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %7 = tensor.insert %6 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: scf.yield %7 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK: func.func @add_eint_int(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
|
||||
// CHECK-NEXT: %[[Vc1_i8:.*]] = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %[[V0:.*]] = arith.extsi %[[Vc1_i8]] : i8 to i64
|
||||
// CHECK-NEXT: %[[V1:.*]] = "TFHE.encode_plaintext_with_crt"(%[[V0]]) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
// CHECK-NEXT: %[[V2:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
|
||||
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V2]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
|
||||
// CHECK-NEXT: %[[V4:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %[[V5:.*]] = tensor.extract %[[V1]][%[[Varg1]]] : tensor<5xi64>
|
||||
// CHECK-NEXT: %[[V6:.*]] = "TFHE.add_glwe_int"(%[[V4]], %[[V5]]) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %[[V7:.*]] = tensor.insert %[[V6]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: scf.yield %[[V7]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %3 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: return %[[V3]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
|
||||
@@ -1,98 +1,97 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
|
||||
//CHECK-LABEL: func.func @conv2d(%arg0: tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> {
|
||||
//CHECK-NEXT: %c4 = arith.constant 4 : index
|
||||
//CHECK-NEXT: %c100 = arith.constant 100 : index
|
||||
//CHECK-NEXT: %c15 = arith.constant 15 : index
|
||||
//CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
//CHECK-NEXT: %c1 = arith.constant 1 : index
|
||||
//CHECK-NEXT: %c3 = arith.constant 3 : index
|
||||
//CHECK-NEXT: %c14 = arith.constant 14 : index
|
||||
//CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %1 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %0) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %6 = tensor.extract %arg2[%arg5] : tensor<4xi3>
|
||||
//CHECK-NEXT: %c0_0 = arith.constant 0 : index
|
||||
//CHECK-NEXT: %7 = tensor.extract_slice %0[%arg3, %arg5, %arg7, %arg9, %c0_0] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %8 = arith.extsi %6 : i3 to i64
|
||||
//CHECK-NEXT: %9 = "TFHE.encode_plaintext_with_crt"(%8) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
//CHECK-NEXT: %10 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %c0_1 = arith.constant 0 : index
|
||||
//CHECK-NEXT: %c1_2 = arith.constant 1 : index
|
||||
//CHECK-NEXT: %c5 = arith.constant 5 : index
|
||||
//CHECK-NEXT: %11 = scf.for %arg11 = %c0_1 to %c5 step %c1_2 iter_args(%arg12 = %10) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %13 = tensor.extract %7[%arg11] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %14 = tensor.extract %9[%arg11] : tensor<5xi64>
|
||||
//CHECK-NEXT: %15 = "TFHE.add_glwe_int"(%13, %14) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: %16 = tensor.insert %15 into %arg12[%arg11] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %16 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK: func.func @conv2d(%[[Varg0:.*]]: tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>>, %[[Varg1:.*]]: tensor<4x3x14x14xi3>, %[[Varg2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> {
|
||||
//CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vc100:.*]] = arith.constant 100 : index
|
||||
//CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
|
||||
//CHECK-NEXT: %[[Vc4:.*]] = arith.constant 4 : index
|
||||
//CHECK-NEXT: %[[Vc15:.*]] = arith.constant 15 : index
|
||||
//CHECK-NEXT: %[[Vc3:.*]] = arith.constant 3 : index
|
||||
//CHECK-NEXT: %[[Vc14:.*]] = arith.constant 14 : index
|
||||
//CHECK-NEXT: %[[V0:.*]] = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V0]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg2]]{{\[}}%[[Varg5]]{{\]}} : tensor<4xi3>
|
||||
//CHECK-NEXT: %[[Vc0_0:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_0]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[V6:.*]] = arith.extsi %[[Vextracted]] : i3 to i64
|
||||
//CHECK-NEXT: %[[V7:.*]] = "TFHE.encode_plaintext_with_crt"(%[[V6]]) {mods = {{\[2, 3, 5, 7, 11\], modsProd}} = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
//CHECK-NEXT: %[[V8:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[Vc0_1:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vc1_2:.*]] = arith.constant 1 : index
|
||||
//CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
|
||||
//CHECK-NEXT: %[[V9:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0_1]] to %[[Vc5]] step %[[Vc1_2]] iter_args(%[[Varg12:.*]] = %[[V8]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[Vextracted_4:.*]] = tensor.extract %[[Vextracted_slice]]{{\[}}%[[Varg11]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[Vextracted_5:.*]] = tensor.extract %[[V7]]{{\[}}%[[Varg11]]{{\]}} : tensor<5xi64>
|
||||
//CHECK-NEXT: %[[V10:.*]] = "TFHE.add_glwe_int"(%[[Vextracted_4]], %[[Vextracted_5]]) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V10]] into %[[Varg12]]{{\[}}%[[Varg11]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: %c0_3 = arith.constant 0 : index
|
||||
//CHECK-NEXT: %12 = tensor.insert_slice %11 into %arg10[%arg3, %arg5, %arg7, %arg9, %c0_3] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %12 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[Vc0_3:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vinserted_slice:.*]] = tensor.insert_slice %[[V9]] into %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_3]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted_slice]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: %2 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %1) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %6 = scf.for %arg11 = %c0 to %c3 step %c1 iter_args(%arg12 = %arg10) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %7 = scf.for %arg13 = %c0 to %c14 step %c1 iter_args(%arg14 = %arg12) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %8 = scf.for %arg15 = %c0 to %c14 step %c1 iter_args(%arg16 = %arg14) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %9 = affine.apply #map(%arg7, %arg13)
|
||||
//CHECK-NEXT: %10 = affine.apply #map(%arg9, %arg15)
|
||||
//CHECK-NEXT: %c0_0 = arith.constant 0 : index
|
||||
//CHECK-NEXT: %11 = tensor.extract_slice %arg0[%arg3, %arg11, %9, %10, %c0_0] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %12 = tensor.extract %arg1[%arg5, %arg11, %arg13, %arg15] : tensor<4x3x14x14xi3>
|
||||
//CHECK-NEXT: %c0_1 = arith.constant 0 : index
|
||||
//CHECK-NEXT: %13 = tensor.extract_slice %1[%arg3, %arg5, %arg7, %arg9, %c0_1] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %14 = arith.extsi %12 : i3 to i64
|
||||
//CHECK-NEXT: %15 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %c0_2 = arith.constant 0 : index
|
||||
//CHECK-NEXT: %c1_3 = arith.constant 1 : index
|
||||
//CHECK-NEXT: %c5 = arith.constant 5 : index
|
||||
//CHECK-NEXT: %16 = scf.for %arg17 = %c0_2 to %c5 step %c1_3 iter_args(%arg18 = %15) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %20 = tensor.extract %11[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %21 = "TFHE.mul_glwe_int"(%20, %14) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: %22 = tensor.insert %21 into %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %22 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V1]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[V6:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0]] to %[[Vc3]] step %[[Vc1]] iter_args(%[[Varg12:.*]] = %[[Varg10]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[V7:.*]] = scf.for %[[Varg13:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg14:.*]] = %[[Varg12]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[V8:.*]] = scf.for %[[Varg15:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg16:.*]] = %[[Varg14]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[V9:.*]] = affine.apply #map(%[[Varg7]], %[[Varg13]])
|
||||
//CHECK-NEXT: %[[V10:.*]] = affine.apply #map(%[[Varg9]], %[[Varg15]])
|
||||
//CHECK-NEXT: %[[Vc0_0:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[Varg3]], %[[Varg11]], %[[V9]], %[[V10]], %[[Vc0_0]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg1]]{{\[}}%[[Varg5]], %[[Varg11]], %[[Varg13]], %[[Varg15]]{{\]}} : tensor<4x3x14x14xi3>
|
||||
//CHECK-NEXT: %[[Vc0_1:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vextracted_slice_2:.*]] = tensor.extract_slice %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_1]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[V11:.*]] = arith.extsi %[[Vextracted]] : i3 to i64
|
||||
//CHECK-NEXT: %[[V12:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[Vc0_3:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vc1_4:.*]] = arith.constant 1 : index
|
||||
//CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
|
||||
//CHECK-NEXT: %[[V13:.*]] = scf.for %[[Varg17:.*]] = %[[Vc0_3]] to %[[Vc5]] step %[[Vc1_4]] iter_args(%[[Varg18:.*]] = %[[V12]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[Vextracted_9:.*]] = tensor.extract %[[Vextracted_slice]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[V16:.*]] = "TFHE.mul_glwe_int"(%[[Vextracted_9]], %[[V11]]) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V16]] into %[[Varg18]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: %17 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %c0_4 = arith.constant 0 : index
|
||||
//CHECK-NEXT: %c1_5 = arith.constant 1 : index
|
||||
//CHECK-NEXT: %c5_6 = arith.constant 5 : index
|
||||
//CHECK-NEXT: %18 = scf.for %arg17 = %c0_4 to %c5_6 step %c1_5 iter_args(%arg18 = %17) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %20 = tensor.extract %13[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %21 = tensor.extract %16[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %22 = "TFHE.add_glwe"(%20, %21) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: %23 = tensor.insert %22 into %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %23 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[V14:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[Vc0_5:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vc1_6:.*]] = arith.constant 1 : index
|
||||
//CHECK-NEXT: %[[Vc5_7:.*]] = arith.constant 5 : index
|
||||
//CHECK-NEXT: %[[V15:.*]] = scf.for %[[Varg17:.*]] = %[[Vc0_5]] to %[[Vc5_7]] step %[[Vc1_6]] iter_args(%[[Varg18:.*]] = %[[V14]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: %[[Vextracted_9:.*]] = tensor.extract %[[Vextracted_slice_2]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[Vextracted_10:.*]] = tensor.extract %[[V13]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[V16:.*]] = "TFHE.add_glwe"(%[[Vextracted_9]], %[[Vextracted_10]]) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V16]] into %[[Varg18]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: %c0_7 = arith.constant 0 : index
|
||||
//CHECK-NEXT: %19 = tensor.insert_slice %18 into %arg16[%arg3, %arg5, %arg7, %arg9, %c0_7] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %19 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[Vc0_8:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vinserted_slice:.*]] = tensor.insert_slice %[[V15]] into %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_8]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted_slice]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %8 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[V8]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %7 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[V7]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %6 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[V6]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: return %2 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: return %[[V2]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @mul_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
|
||||
// CHECK-NEXT: %c2_i8 = arith.constant 2 : i8
|
||||
// CHECK-NEXT: %0 = arith.extsi %c2_i8 : i8 to i64
|
||||
// CHECK-NEXT: %1 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %c5 = arith.constant 5 : index
|
||||
// CHECK-NEXT: %2 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %1) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
|
||||
// CHECK-NEXT: %3 = tensor.extract %arg0[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %4 = "TFHE.mul_glwe_int"(%3, %0) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %5 = tensor.insert %4 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: scf.yield %5 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK: func.func @mul_eint_int(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
|
||||
// CHECK-NEXT: %[[Vc2_i8:.*]] = arith.constant 2 : i8
|
||||
// CHECK-NEXT: %[[V0:.*]] = arith.extsi %[[Vc2_i8]] : i8 to i64
|
||||
// CHECK-NEXT: %[[V1:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
|
||||
// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V1]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
|
||||
// CHECK-NEXT: %[[V3:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %[[V4:.*]] = "TFHE.mul_glwe_int"(%[[V3]], %[[V0]]) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %[[V5:.*]] = tensor.insert %[[V4]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: scf.yield %[[V5]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %2 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: return %[[V2]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%0 = arith.constant 2 : i8
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @neg_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
|
||||
// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %c5 = arith.constant 5 : index
|
||||
// CHECK-NEXT: %1 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %0) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
|
||||
// CHECK-NEXT: %2 = tensor.extract %arg0[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %3 = "TFHE.neg_glwe"(%2) : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %4 = tensor.insert %3 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: scf.yield %4 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK: func.func @neg_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
|
||||
// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V0]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
|
||||
// CHECK-NEXT: %[[V2:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "TFHE.neg_glwe"(%[[V2]]) : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %[[V4:.*]] = tensor.insert %[[V3]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: scf.yield %[[V4]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: return %[[V1]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: }
|
||||
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user