Rebase onto llvm-project 465ee9bfb26d with local changes

This commit rebases the compiler onto commit 465ee9bfb26d from
llvm-project with locally maintained patches on top, i.e.:

  * 5d8669d669ee: Fix the element alignment (size) for memrefCopy
  * 4239163ea337: fix: Do not fold the memref.subview if the offset are
                  != 0 and strides != 1
  * 72c5decfcc21: remove github stuff from llvm
  * 8d0ce8f9eca1: Support arbitrary element types in named operations
                  via attributes
  * 94f64805c38c: Copy attributes of scf.for on bufferization and make
                  it an allocation hoisting barrier

Main upstream changes from llvm-project that required modification of
concretecompiler:

  * Switch to C++17
  * Various changes in the interfaces for linalg named operations
  * Transition from `llvm::Optional` to `std::optional`
  * Use of enums instead of string values for iterator types in linalg
  * Changed default naming convention of getter methods in
    ODS-generated operation classes from `some_value()` to
    `getSomeValue()`
  * Renaming of Arithmetic dialect to Arith
  * Refactoring of side effect interfaces (i.e., renaming from
    `NoSideEffect` to `Pure`)
  * Re-design of the data flow analysis framework
  * Refactoring of build targets for Python bindings
  * Refactoring of array attributes with integer values
  * Renaming of `linalg.init_tensor` to `tensor.empty`
  * Emission of `linalg.map` operations in bufferization of the Tensor
    dialect requiring another linalg conversion pass and registration
    of the bufferization op interfaces for linalg operations
  * Refactoring of the one-shot bufferizer
  * Necessity to run the expand-strided-metadata, affine-to-std and
    finalize-memref-to-llvm passes before conversion to the LLVM
    dialect
  * Renaming of `BlockAndValueMapping` to `IRMapping`
  * Changes in the build function of `LLVM::CallOp`
  * Refactoring of the construction of `llvm::ArrayRef` and
    `llvm::MutableArrayRef` (direct invocation of constructor instead
    of builder functions for some cases)
  * New naming conventions for generated SSA values requiring rewrite
    of some check tests
  * Refactoring of `mlir::LLVM::lookupOrCreateMallocFn()`
  * Interface changes in generated type parsers
  * New dependencies on mlir_float16_utils and
    MLIRSparseTensorRuntime for the runtime
  * Overhaul of MLIR-c deleting `mlir-c/Registration.h`
  * Deletion of library MLIRLinalgToSPIRV
  * Deletion of library MLIRLinalgAnalysis
  * Deletion of library MLIRMemRefUtils
  * Deletion of library MLIRQuantTransforms
  * Deletion of library MLIRVectorToROCDL
This commit is contained in:
Andi Drebes
2023-02-20 12:04:43 +01:00
committed by Quentin Bourgerie
parent 8ebfccd9a7
commit c8c969773e
115 changed files with 1383 additions and 1355 deletions

View File

@@ -4,7 +4,7 @@ project(concretecompiler LANGUAGES C CXX)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# Needed on linux with clang 15 and on MacOS because cxx emits dollars in the optimizer C++ API

View File

@@ -416,7 +416,7 @@ rust-format:
# libraries we want to have in the installation that aren't already a deps of other targets
install-deps:
cmake --build $(BUILD_DIR) --target MLIRCAPIRegistration
cmake --build $(BUILD_DIR) --target MLIRCAPIRegisterEverything
ifeq ($(OS), darwin)
# rsync should normally come pre-installed on macOS

View File

@@ -7,7 +7,6 @@
#define CONCRETELANG_C_DIALECT_FHE_H
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#ifdef __cplusplus
extern "C" {

View File

@@ -7,7 +7,6 @@
#define CONCRETELANG_C_DIALECT_FHELINALG_H
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#ifdef __cplusplus
extern "C" {

View File

@@ -104,7 +104,7 @@ library_get_client_parameters_path(LibrarySupport_Py support);
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
key_set(concretelang::clientlib::ClientParameters clientParameters,
llvm::Optional<concretelang::clientlib::KeySetCache> cache);
std::optional<concretelang::clientlib::KeySetCache> cache);
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,

View File

@@ -184,11 +184,11 @@ static inline bool operator==(const ChunkInfo &lhs, const ChunkInfo &rhs) {
}
struct CircuitGate {
llvm::Optional<EncryptionGate> encryption;
std::optional<EncryptionGate> encryption;
CircuitGateShape shape;
llvm::Optional<ChunkInfo> chunkInfo;
std::optional<ChunkInfo> chunkInfo;
bool isEncrypted() { return encryption.hasValue(); }
bool isEncrypted() { return encryption.has_value(); }
/// byteSize returns the size in bytes for this gate.
size_t byteSize(std::vector<LweSecretKeyParam> secretKeys) {
@@ -240,7 +240,7 @@ struct ClientParameters {
outcome::checked<LweSecretKeyParam, StringError>
lweSecretKeyParam(CircuitGate gate) {
if (!gate.encryption.hasValue()) {
if (!gate.encryption.has_value()) {
return StringError("gate is not encrypted");
}
assert(gate.encryption->secretKeyID < secretKeys.size());
@@ -250,7 +250,7 @@ struct ClientParameters {
/// bufferSize returns the size of the whole buffer of a gate.
int64_t bufferSize(CircuitGate gate) {
if (!gate.encryption.hasValue()) {
if (!gate.encryption.has_value()) {
// Value is not encrypted just returns the tensor size
return gate.shape.size;
}
@@ -261,7 +261,7 @@ struct ClientParameters {
/// lweBufferSize returns the size of one ciphertext of a gate.
int64_t lweBufferSize(CircuitGate gate) {
assert(gate.encryption.hasValue());
assert(gate.encryption.has_value());
auto nbBlocks = gate.encryption->encoding.crt.size();
nbBlocks = nbBlocks == 0 ? 1 : nbBlocks;
@@ -273,7 +273,7 @@ struct ClientParameters {
/// bufferShape returns the shape of the tensor for the given gate. It returns
/// the shape used at low-level, i.e. contains the dimensions for ciphertexts.
std::vector<int64_t> bufferShape(CircuitGate gate) {
if (!gate.encryption.hasValue()) {
if (!gate.encryption.has_value()) {
// Value is not encrypted just returns the tensor shape
return gate.shape.dimensions;
}

View File

@@ -159,7 +159,7 @@ public:
// Set sizes
std::vector<int64_t> sizes = keySet.clientParameters().bufferShape(input);
if (input.encryption.hasValue()) {
if (input.encryption.has_value()) {
TensorData td(sizes, EncryptedScalarElementType,
EncryptedScalarElementWidth);

View File

@@ -111,7 +111,7 @@ private:
///////////////////////////////////////////////
// Convenient positional mapping between positional gate en secret key
typedef std::vector<std::pair<CircuitGate, llvm::Optional<LweSecretKey>>>
typedef std::vector<std::pair<CircuitGate, std::optional<LweSecretKey>>>
SecretKeyGateMapping;
outcome::checked<SecretKeyGateMapping, StringError>
mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates);

View File

@@ -113,7 +113,7 @@ struct PublicResult {
// Chunked integers are represented as tensors at a lower level, so we need
// to deal with them as tensors, then build the resulting scalar out of the
// tensor values
if (gate.chunkInfo.hasValue()) {
if (gate.chunkInfo.has_value()) {
OUTCOME_TRY(std::vector<uint64_t> decryptedChunks,
this->asClearTextVector<uint64_t>(keySet, pos));
uint64_t decrypted = fromChunks(decryptedChunks, gate.chunkInfo->width);

View File

@@ -6,7 +6,7 @@
#ifndef CONCRETELANG_TRANSFORMS_PASSES_H
#define CONCRETELANG_TRANSFORMS_PASSES_H
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

View File

@@ -79,7 +79,7 @@ def SDFGToStreamEmulator : Pass<"sdfg-to-stream-emulator", "mlir::ModuleOp"> {
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
let constructor = "mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass()";
let dependentDialects = ["mlir::func::FuncDialect", "mlir::arith::ArithmeticDialect", "mlir::scf::SCFDialect", "mlir::LLVM::LLVMDialect"];
let dependentDialects = ["mlir::func::FuncDialect", "mlir::arith::ArithDialect", "mlir::scf::SCFDialect", "mlir::LLVM::LLVMDialect"];
let options = [];
}

View File

@@ -52,7 +52,7 @@ struct V0Parameter {
size_t ksLevel;
size_t ksLogBase;
llvm::Optional<LargeIntegerParameter> largeInteger;
std::optional<LargeIntegerParameter> largeInteger;
// TODO remove the shift when we have true polynomial size
size_t getPolynomialSize() { return 1 << logPolynomialSize; }

View File

@@ -31,7 +31,7 @@ class Concrete_Op<string mnemonic, list<Trait> traits = []> :
Op<Concrete_Dialect, mnemonic, traits>;
def Concrete_AddLweTensorOp : Concrete_Op<"add_lwe_tensor", [NoSideEffect]> {
def Concrete_AddLweTensorOp : Concrete_Op<"add_lwe_tensor", [Pure]> {
let summary = "Returns the sum of 2 lwe ciphertexts";
let arguments = (ins
@@ -51,7 +51,7 @@ def Concrete_AddLweBufferOp : Concrete_Op<"add_lwe_buffer"> {
);
}
def Concrete_AddPlaintextLweTensorOp : Concrete_Op<"add_plaintext_lwe_tensor", [NoSideEffect]> {
def Concrete_AddPlaintextLweTensorOp : Concrete_Op<"add_plaintext_lwe_tensor", [Pure]> {
let summary = "Returns the sum of a clear integer and an lwe ciphertext";
let arguments = (ins Concrete_LweTensor:$lhs, I64:$rhs);
@@ -68,7 +68,7 @@ def Concrete_AddPlaintextLweBufferOp : Concrete_Op<"add_plaintext_lwe_buffer"> {
);
}
def Concrete_MulCleartextLweTensorOp : Concrete_Op<"mul_cleartext_lwe_tensor", [NoSideEffect]> {
def Concrete_MulCleartextLweTensorOp : Concrete_Op<"mul_cleartext_lwe_tensor", [Pure]> {
let summary = "Returns the product of a clear integer and a lwe ciphertext";
let arguments = (ins Concrete_LweTensor:$lhs, I64:$rhs);
@@ -85,7 +85,7 @@ def Concrete_MulCleartextLweBufferOp : Concrete_Op<"mul_cleartext_lwe_buffer"> {
);
}
def Concrete_NegateLweTensorOp : Concrete_Op<"negate_lwe_tensor", [NoSideEffect]> {
def Concrete_NegateLweTensorOp : Concrete_Op<"negate_lwe_tensor", [Pure]> {
let summary = "Negates a lwe ciphertext";
let arguments = (ins Concrete_LweTensor:$ciphertext);
@@ -101,7 +101,7 @@ def Concrete_NegateLweBufferOp : Concrete_Op<"negate_lwe_buffer"> {
);
}
def Concrete_EncodeExpandLutForBootstrapTensorOp : Concrete_Op<"encode_expand_lut_for_bootstrap_tensor", [NoSideEffect]> {
def Concrete_EncodeExpandLutForBootstrapTensorOp : Concrete_Op<"encode_expand_lut_for_bootstrap_tensor", [Pure]> {
let summary =
"Encode and expand a lookup table so that it can be used for a bootstrap";
@@ -128,7 +128,7 @@ def Concrete_EncodeExpandLutForBootstrapBufferOp : Concrete_Op<"encode_expand_lu
);
}
def Concrete_EncodeLutForCrtWopPBSTensorOp : Concrete_Op<"encode_lut_for_crt_woppbs_tensor", [NoSideEffect]> {
def Concrete_EncodeLutForCrtWopPBSTensorOp : Concrete_Op<"encode_lut_for_crt_woppbs_tensor", [Pure]> {
let summary =
"Encode and expand a lookup table so that it can be used for a wop pbs";
@@ -157,7 +157,7 @@ def Concrete_EncodeLutForCrtWopPBSBufferOp : Concrete_Op<"encode_lut_for_crt_wop
);
}
def Concrete_EncodePlaintextWithCrtTensorOp : Concrete_Op<"encode_plaintext_with_crt_tensor", [NoSideEffect]> {
def Concrete_EncodePlaintextWithCrtTensorOp : Concrete_Op<"encode_plaintext_with_crt_tensor", [Pure]> {
let summary =
"Encodes a plaintext by decomposing it on a crt basis";
@@ -182,7 +182,7 @@ def Concrete_EncodePlaintextWithCrtBufferOp : Concrete_Op<"encode_plaintext_with
);
}
def Concrete_BootstrapLweTensorOp : Concrete_Op<"bootstrap_lwe_tensor", [NoSideEffect]> {
def Concrete_BootstrapLweTensorOp : Concrete_Op<"bootstrap_lwe_tensor", [Pure]> {
let summary = "Bootstraps an LWE ciphertext with a GLWE trivial encryption of the lookup table";
let arguments = (ins
@@ -214,7 +214,7 @@ def Concrete_BootstrapLweTensorOp : Concrete_Op<"bootstrap_lwe_tensor", [NoSideE
return builder.create<BatchedBootstrapLweTensorOp>(
mlir::TypeRange{resType},
mlir::ValueRange{batchedOperands, lookup_table()},
mlir::ValueRange{batchedOperands, getLookupTable()},
getOperation()->getAttrs());
}
}];
@@ -236,7 +236,7 @@ def Concrete_BootstrapLweBufferOp : Concrete_Op<"bootstrap_lwe_buffer"> {
);
}
def Concrete_BatchedBootstrapLweTensorOp : Concrete_Op<"batched_bootstrap_lwe_tensor", [NoSideEffect]> {
def Concrete_BatchedBootstrapLweTensorOp : Concrete_Op<"batched_bootstrap_lwe_tensor", [Pure]> {
let summary = "Batched version of BootstrapLweOp, which performs the same operation on multiple elements";
let arguments = (ins
@@ -268,7 +268,7 @@ def Concrete_BatchedBootstrapLweBufferOp : Concrete_Op<"batched_bootstrap_lwe_bu
);
}
def Concrete_KeySwitchLweTensorOp : Concrete_Op<"keyswitch_lwe_tensor", [NoSideEffect]> {
def Concrete_KeySwitchLweTensorOp : Concrete_Op<"keyswitch_lwe_tensor", [Pure]> {
let summary = "Keyswitches an LWE ciphertext";
let arguments = (ins
@@ -317,7 +317,7 @@ def Concrete_KeySwitchLweBufferOp : Concrete_Op<"keyswitch_lwe_buffer"> {
);
}
def Concrete_BatchedKeySwitchLweTensorOp : Concrete_Op<"batched_keyswitch_lwe_tensor", [NoSideEffect]> {
def Concrete_BatchedKeySwitchLweTensorOp : Concrete_Op<"batched_keyswitch_lwe_tensor", [Pure]> {
let summary = "Batched version of KeySwitchLweOp, which performs the same operation on multiple elements";
let arguments = (ins
@@ -344,7 +344,7 @@ def Concrete_BatchedKeySwitchLweBufferOp : Concrete_Op<"batched_keyswitch_lwe_bu
);
}
def Concrete_WopPBSCRTLweTensorOp : Concrete_Op<"wop_pbs_crt_lwe_tensor", [NoSideEffect]> {
def Concrete_WopPBSCRTLweTensorOp : Concrete_Op<"wop_pbs_crt_lwe_tensor", [Pure]> {
let arguments = (ins
Concrete_LweCRTTensor:$ciphertext,
Concrete_CrtLutsTensor:$lookupTable,

View File

@@ -17,7 +17,7 @@ namespace mlir {
namespace concretelang {
namespace optimizer {
using FunctionsDag = std::map<std::string, llvm::Optional<Dag>>;
using FunctionsDag = std::map<std::string, std::optional<Dag>>;
std::unique_ptr<mlir::Pass> createDagPass(optimizer::Config config,
optimizer::FunctionsDag &dags);

View File

@@ -19,6 +19,7 @@ def FHE_Dialect : Dialect {
}];
let cppNamespace = "::mlir::concretelang::FHE";
let useDefaultTypePrinterParser = 1;
let useFoldAPI = kEmitRawAttributesFolder;
}
#endif

View File

@@ -18,7 +18,7 @@ include "concretelang/Dialect/FHE/IR/FHETypes.td"
class FHE_Op<string mnemonic, list<Trait> traits = []> :
Op<FHE_Dialect, mnemonic, traits>;
def FHE_ZeroEintOp : FHE_Op<"zero", [NoSideEffect]> {
def FHE_ZeroEintOp : FHE_Op<"zero", [Pure]> {
let summary = "Returns a trivial encrypted integer of 0";
let description = [{
@@ -35,7 +35,7 @@ def FHE_ZeroEintOp : FHE_Op<"zero", [NoSideEffect]> {
let results = (outs FHE_AnyEncryptedInteger:$out);
}
def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [NoSideEffect]> {
def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [Pure]> {
let summary = "Creates a new tensor with all elements initialized to an encrypted zero.";
let description = [{
@@ -53,7 +53,7 @@ def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [NoSideEffect]> {
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$tensor);
}
def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [NoSideEffect]> {
def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure]> {
let summary = "Adds an encrypted integer and a clear integer";
let description = [{
@@ -87,7 +87,7 @@ def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [NoSideEffect]> {
let hasFolder = 1;
}
def FHE_AddEintOp : FHE_Op<"add_eint", [NoSideEffect]> {
def FHE_AddEintOp : FHE_Op<"add_eint", [Pure]> {
let summary = "Adds two encrypted integers";
let description = [{
@@ -120,7 +120,7 @@ def FHE_AddEintOp : FHE_Op<"add_eint", [NoSideEffect]> {
let hasVerifier = 1;
}
def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [NoSideEffect]> {
def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure]> {
let summary = "Subtract an encrypted integer from a clear integer";
let description = [{
@@ -153,7 +153,7 @@ def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [NoSideEffect]> {
let hasVerifier = 1;
}
def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [NoSideEffect]> {
def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure]> {
let summary = "Subtract a clear integer from an encrypted integer";
let description = [{
@@ -187,7 +187,7 @@ def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [NoSideEffect]> {
let hasFolder = 1;
}
def FHE_SubEintOp : FHE_Op<"sub_eint", [NoSideEffect]> {
def FHE_SubEintOp : FHE_Op<"sub_eint", [Pure]> {
let summary = "Subtract an encrypted integer from an encrypted integer";
let description = [{
@@ -220,7 +220,7 @@ def FHE_SubEintOp : FHE_Op<"sub_eint", [NoSideEffect]> {
let hasVerifier = 1;
}
def FHE_NegEintOp : FHE_Op<"neg_eint", [NoSideEffect]> {
def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure]> {
let summary = "Negates an encrypted integer";
@@ -251,7 +251,7 @@ def FHE_NegEintOp : FHE_Op<"neg_eint", [NoSideEffect]> {
let hasVerifier = 1;
}
def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [NoSideEffect]> {
def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [Pure]> {
let summary = "Multiply an encrypted integer with a clear integer";
let description = [{
@@ -285,7 +285,7 @@ def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [NoSideEffect]> {
let hasFolder = 1;
}
def FHE_MulEintOp : FHE_Op<"mul_eint", [NoSideEffect]> {
def FHE_MulEintOp : FHE_Op<"mul_eint", [Pure]> {
let summary = "Multiplies two encrypted integers";
let description = [{
@@ -323,7 +323,7 @@ def FHE_MulEintOp : FHE_Op<"mul_eint", [NoSideEffect]> {
let hasVerifier = 1;
}
def FHE_MaxEintOp : FHE_Op<"max_eint", [NoSideEffect]> {
def FHE_MaxEintOp : FHE_Op<"max_eint", [Pure]> {
let summary = "Get maximum of two encrypted integers.";
let description = [{
@@ -358,7 +358,7 @@ def FHE_MaxEintOp : FHE_Op<"max_eint", [NoSideEffect]> {
let hasVerifier = 1;
}
def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> {
def FHE_ToSignedOp : FHE_Op<"to_signed", [Pure]> {
let summary = "Cast an unsigned integer to a signed one";
let description = [{
@@ -383,7 +383,7 @@ def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> {
let hasVerifier = 1;
}
def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [NoSideEffect]> {
def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure]> {
let summary = "Cast a signed integer to an unsigned one";
let description = [{
@@ -408,7 +408,7 @@ def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [NoSideEffect]> {
let hasVerifier = 1;
}
def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [NoSideEffect]> {
def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [Pure]> {
let summary = "Applies a clear lookup table to an encrypted integer";
@@ -434,7 +434,7 @@ def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [NoSideEffect]> {
let hasVerifier = 1;
}
def FHE_RoundEintOp: FHE_Op<"round", [NoSideEffect]> {
def FHE_RoundEintOp: FHE_Op<"round", [Pure]> {
let summary = "Rounds a ciphertext to a smaller precision.";
@@ -465,7 +465,7 @@ def FHE_RoundEintOp: FHE_Op<"round", [NoSideEffect]> {
// FHE Boolean Operations
def FHE_GenGateOp : FHE_Op<"gen_gate", [NoSideEffect]> {
def FHE_GenGateOp : FHE_Op<"gen_gate", [Pure]> {
let summary = "Applies a truth table based on two boolean inputs";
@@ -492,7 +492,7 @@ def FHE_GenGateOp : FHE_Op<"gen_gate", [NoSideEffect]> {
let hasVerifier = 1;
}
def FHE_MuxOp : FHE_Op<"mux", [NoSideEffect]> {
def FHE_MuxOp : FHE_Op<"mux", [Pure]> {
let summary = "Multiplexer for two encrypted boolean inputs, based on an encrypted condition";
@@ -509,7 +509,7 @@ def FHE_MuxOp : FHE_Op<"mux", [NoSideEffect]> {
let results = (outs FHE_EncryptedBooleanType);
}
def FHE_BoolAndOp : FHE_Op<"and", [NoSideEffect]> {
def FHE_BoolAndOp : FHE_Op<"and", [Pure]> {
let summary = "Applies an AND gate to two encrypted boolean values";
@@ -526,7 +526,7 @@ def FHE_BoolAndOp : FHE_Op<"and", [NoSideEffect]> {
let results = (outs FHE_EncryptedBooleanType);
}
def FHE_BoolOrOp : FHE_Op<"or", [NoSideEffect]> {
def FHE_BoolOrOp : FHE_Op<"or", [Pure]> {
let summary = "Applies an OR gate to two encrypted boolean values";
@@ -543,7 +543,7 @@ def FHE_BoolOrOp : FHE_Op<"or", [NoSideEffect]> {
let results = (outs FHE_EncryptedBooleanType);
}
def FHE_BoolNandOp : FHE_Op<"nand", [NoSideEffect]> {
def FHE_BoolNandOp : FHE_Op<"nand", [Pure]> {
let summary = "Applies a NAND gate to two encrypted boolean values";
@@ -560,7 +560,7 @@ def FHE_BoolNandOp : FHE_Op<"nand", [NoSideEffect]> {
let results = (outs FHE_EncryptedBooleanType);
}
def FHE_BoolXorOp : FHE_Op<"xor", [NoSideEffect]> {
def FHE_BoolXorOp : FHE_Op<"xor", [Pure]> {
let summary = "Applies a XOR gate to two encrypted boolean values";
@@ -577,7 +577,7 @@ def FHE_BoolXorOp : FHE_Op<"xor", [NoSideEffect]> {
let results = (outs FHE_EncryptedBooleanType);
}
def FHE_BoolNotOp : FHE_Op<"not", [NoSideEffect]> {
def FHE_BoolNotOp : FHE_Op<"not", [Pure]> {
let summary = "Applies a NOT gate to an encrypted boolean value";
@@ -594,7 +594,7 @@ def FHE_BoolNotOp : FHE_Op<"not", [NoSideEffect]> {
let results = (outs FHE_EncryptedBooleanType);
}
def FHE_ToBoolOp : FHE_Op<"to_bool", [NoSideEffect]> {
def FHE_ToBoolOp : FHE_Op<"to_bool", [Pure]> {
let summary = "Cast an unsigned integer to a boolean";
let description = [{
@@ -620,7 +620,7 @@ def FHE_ToBoolOp : FHE_Op<"to_bool", [NoSideEffect]> {
let hasVerifier = 1;
}
def FHE_FromBoolOp : FHE_Op<"from_bool", [NoSideEffect]> {
def FHE_FromBoolOp : FHE_Op<"from_bool", [Pure]> {
let summary = "Cast a boolean to an unsigned integer";
let description = [{

View File

@@ -10,6 +10,7 @@ def FHELinalg_Dialect : Dialect {
A dialect for representation of high level linalg operations on fully homomorphic ciphertexts.
}];
let cppNamespace = "::mlir::concretelang::FHELinalg";
let useFoldAPI = kEmitRawAttributesFolder;
}
#endif

View File

@@ -72,13 +72,13 @@ def SDFG_MakeStream : SDFG_Op<"make_stream"> {
let extraClassDeclaration = [{
bool createsInputStream() {
return type() == StreamKind::host_to_device ||
type() == StreamKind::on_device;
return getType() == StreamKind::host_to_device ||
getType() == StreamKind::on_device;
}
bool createsOutputStream() {
return type() == StreamKind::device_to_host ||
type() == StreamKind::on_device;
return getType() == StreamKind::device_to_host ||
getType() == StreamKind::on_device;
}
}];
}

View File

@@ -45,13 +45,13 @@ enum Backend {
/// Compilation options allows to configure the compilation pipeline.
struct CompilationOptions {
llvm::Optional<mlir::concretelang::V0FHEConstraint> v0FHEConstraints;
std::optional<mlir::concretelang::V0FHEConstraint> v0FHEConstraints;
llvm::Optional<mlir::concretelang::V0Parameter> v0Parameter;
std::optional<mlir::concretelang::V0Parameter> v0Parameter;
/// largeIntegerParameter force the compiler engine to lower FHE.eint using
/// the large integers strategy with the given parameters.
llvm::Optional<mlir::concretelang::LargeIntegerParameter>
std::optional<mlir::concretelang::LargeIntegerParameter>
largeIntegerParameter;
bool verifyDiagnostics;
@@ -65,9 +65,9 @@ struct CompilationOptions {
bool optimizeTFHE;
/// use GPU during execution by generating GPU operations if possible
bool emitGPUOps;
llvm::Optional<std::vector<int64_t>> fhelinalgTileSizes;
std::optional<std::vector<int64_t>> fhelinalgTileSizes;
llvm::Optional<std::string> clientParametersFuncName;
std::optional<std::string> clientParametersFuncName;
optimizer::Config optimizerConfig;
@@ -79,11 +79,11 @@ struct CompilationOptions {
unsigned int chunkWidth;
CompilationOptions()
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
: v0FHEConstraints(std::nullopt), verifyDiagnostics(false),
autoParallelize(false), loopParallelize(false), batchConcreteOps(false),
emitSDFGOps(false), unrollLoopsWithSDFGConvertibleOps(false),
dataflowParallelize(false), optimizeTFHE(true), emitGPUOps(false),
clientParametersFuncName(llvm::None),
clientParametersFuncName(std::nullopt),
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkIntegers(false),
chunkSize(4), chunkWidth(2){};
@@ -119,11 +119,11 @@ public:
CompilationContext::createShared())
: compilationContext(compilationContext) {}
llvm::Optional<mlir::OwningOpRef<mlir::ModuleOp>> mlirModuleRef;
llvm::Optional<mlir::concretelang::ClientParameters> clientParameters;
llvm::Optional<CompilationFeedback> feedback;
std::optional<mlir::OwningOpRef<mlir::ModuleOp>> mlirModuleRef;
std::optional<mlir::concretelang::ClientParameters> clientParameters;
std::optional<CompilationFeedback> feedback;
std::unique_ptr<llvm::Module> llvmModule;
llvm::Optional<mlir::concretelang::V0FHEContext> fheContext;
std::optional<mlir::concretelang::V0FHEContext> fheContext;
protected:
std::shared_ptr<CompilationContext> compilationContext;
@@ -176,7 +176,7 @@ public:
void addExtraObjectFilePath(std::string objectFilePath);
llvm::Expected<std::string>
emit(std::string path, std::string dotExt, std::string linker,
llvm::Optional<std::vector<std::string>> extraArgs = {});
std::optional<std::vector<std::string>> extraArgs = {});
~Library();
private:
@@ -242,21 +242,21 @@ public:
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
: overrideMaxEintPrecision(), overrideMaxMANP(), compilerOptions(),
generateClientParameters(
compilerOptions.clientParametersFuncName.hasValue()),
compilerOptions.clientParametersFuncName.has_value()),
enablePass([](mlir::Pass *pass) { return true; }),
compilationContext(compilationContext) {}
llvm::Expected<CompilationResult>
compile(llvm::StringRef s, Target target,
llvm::Optional<std::shared_ptr<Library>> lib = {});
std::optional<std::shared_ptr<Library>> lib = {});
llvm::Expected<CompilationResult>
compile(std::unique_ptr<llvm::MemoryBuffer> buffer, Target target,
llvm::Optional<std::shared_ptr<Library>> lib = {});
std::optional<std::shared_ptr<Library>> lib = {});
llvm::Expected<CompilationResult>
compile(llvm::SourceMgr &sm, Target target,
llvm::Optional<std::shared_ptr<Library>> lib = {});
std::optional<std::shared_ptr<Library>> lib = {});
llvm::Expected<CompilerEngine::Library>
compile(std::vector<std::string> inputs, std::string outputDirPath,
@@ -276,11 +276,11 @@ public:
void setCompilationOptions(CompilationOptions &options) {
compilerOptions = options;
if (options.v0FHEConstraints.hasValue()) {
if (options.v0FHEConstraints.has_value()) {
setFHEConstraints(*options.v0FHEConstraints);
}
if (options.clientParametersFuncName.hasValue()) {
if (options.clientParametersFuncName.has_value()) {
setGenerateClientParameters(true);
}
}
@@ -292,8 +292,8 @@ public:
void setEnablePass(std::function<bool(mlir::Pass *)> enablePass);
protected:
llvm::Optional<size_t> overrideMaxEintPrecision;
llvm::Optional<size_t> overrideMaxMANP;
std::optional<size_t> overrideMaxEintPrecision;
std::optional<size_t> overrideMaxMANP;
CompilationOptions compilerOptions;
bool generateClientParameters;
std::function<bool(mlir::Pass *)> enablePass;
@@ -301,7 +301,7 @@ protected:
std::shared_ptr<CompilationContext> compilationContext;
private:
llvm::Expected<llvm::Optional<optimizer::Description>>
llvm::Expected<std::optional<optimizer::Description>>
getConcreteOptimizerDescription(CompilationResult &res);
llvm::Error determineFHEParameters(CompilationResult &res);
};

View File

@@ -34,7 +34,7 @@ class JITSupport
JitCompilationResult> {
public:
JITSupport(llvm::Optional<std::string> runtimeLibPath = llvm::None);
JITSupport(std::optional<std::string> runtimeLibPath = std::nullopt);
llvm::Expected<std::unique_ptr<JitCompilationResult>>
compile(llvm::SourceMgr &program, CompilationOptions options) override;
@@ -63,7 +63,7 @@ public:
}
private:
llvm::Optional<std::string> runtimeLibPath;
std::optional<std::string> runtimeLibPath;
llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline;
};

View File

@@ -32,7 +32,7 @@ public:
static llvm::Expected<std::unique_ptr<JITLambda>>
create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
llvm::Optional<std::string> runtimeLibPath = {});
std::optional<std::string> runtimeLibPath = {});
/// Call the JIT lambda with the public arguments.
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>

View File

@@ -13,10 +13,9 @@ llvm::Error emitObject(llvm::Module &module, std::string objectPath);
llvm::Error callCmd(std::string cmd);
llvm::Error
emitLibrary(std::vector<std::string> objectsPath, std::string libraryPath,
std::string linker,
llvm::Optional<std::vector<std::string>> extraArgs = {});
llvm::Error emitLibrary(std::vector<std::string> objectsPath,
std::string libraryPath, std::string linker,
std::optional<std::vector<std::string>> extraArgs = {});
} // namespace concretelang
} // namespace mlir

View File

@@ -206,7 +206,7 @@ typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return (sign) ? buildScalarLambdaResult<int8_t>(keySet, result)
: buildScalarLambdaResult<uint8_t>(keySet, result);
}
} else if (gate.chunkInfo.hasValue()) {
} else if (gate.chunkInfo.has_value()) {
// chunked scalar case
assert(gate.shape.dimensions.size() == 1);
width = gate.shape.size * gate.chunkInfo->width;
@@ -392,10 +392,10 @@ public:
/// Build the client KeySet from the client parameters.
static llvm::Expected<std::unique_ptr<clientlib::KeySet>>
keySet(clientlib::ClientParameters clientParameters,
llvm::Optional<clientlib::KeySetCache> cache) {
std::optional<clientlib::KeySetCache> cache) {
std::shared_ptr<clientlib::KeySetCache> cachePtr;
if (cache.hasValue()) {
cachePtr = std::make_shared<clientlib::KeySetCache>(cache.getValue());
if (cache.has_value()) {
cachePtr = std::make_shared<clientlib::KeySetCache>(cache.value());
}
auto keySet =
clientlib::KeySetCache::generate(cachePtr, clientParameters, 0, 0);
@@ -434,7 +434,7 @@ public:
static llvm::Expected<ClientServer>
create(llvm::StringRef program,
CompilationOptions options = CompilationOptions("main"),
llvm::Optional<clientlib::KeySetCache> cache = {},
std::optional<clientlib::KeySetCache> cache = {},
LambdaSupport support = LambdaSupport()) {
auto compilationResult = support.compile(program, options);
if (auto err = compilationResult.takeError()) {

View File

@@ -62,7 +62,7 @@ public:
return std::move(err);
}
if (!options.clientParametersFuncName.hasValue()) {
if (!options.clientParametersFuncName.has_value()) {
return StreamStringError("Need to have a funcname to compile library");
}

View File

@@ -39,11 +39,11 @@ static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
}
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static std::vector<Value> inlineRegionAndEmitStore(
static llvm::SmallVector<Value> inlineRegionAndEmitStore(
OpBuilder &b, Location loc, OpType op, ArrayRef<Value> indexedValues,
ArrayRef<SmallVector<Value>> indexing, ArrayRef<Value> outputBuffers) {
auto &block = op->getRegion(0).front();
BlockAndValueMapping map;
IRMapping map;
map.map(block.getArguments(), indexedValues);
for (auto &op : block.without_terminator()) {
auto *newOp = b.clone(op, map);
@@ -51,7 +51,7 @@ static std::vector<Value> inlineRegionAndEmitStore(
}
Operation *terminator = block.getTerminator();
std::vector<Value> retVals;
llvm::SmallVector<Value> retVals;
for (OpOperand &operand : terminator->getOpOperands()) {
Value toStore = map.lookupOrDefault(operand.get());
@@ -91,40 +91,42 @@ static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
LoopLikeOpInterface loopOp = loopOps.back();
for (IndexOp indexOp :
llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
}
}
template <typename LoadOpTy, typename StoreOpTy>
static std::vector<Value>
static llvm::SmallVector<Value>
emitScalarImplementation(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
LinalgOp linalgOp, ValueRange operandValuesToUse) {
assert(linalgOp.hasTensorSemantics() &&
"expected linalg op with buffer semantics");
SmallVector<Value> indexedValues;
indexedValues.reserve(linalgOp.getNumInputsAndOutputs());
indexedValues.reserve(linalgOp->getNumOperands());
auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());
// TODO: Avoid the loads if the corresponding argument of the
// region has no uses.
// 1.a. Emit load from input operand or for scalars access the operand itself.
for (OpOperand *inputOperand : linalgOp.getInputOperands()) {
for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
Value v = operandValuesToUse[inputOperand->getOperandNumber()];
if (linalgOp.isScalar(inputOperand)) {
indexedValues.push_back(inputOperand->get());
indexedValues.push_back(v);
continue;
}
auto indexing = makeCanonicalAffineApplies(
b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims);
indexedValues.push_back(
b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
indexedValues.push_back(b.create<LoadOpTy>(loc, v, indexing));
}
// 1.b. Emit load from output views.
for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
for (OpOperand *outputOperand : linalgOp.getDpsInitOperands()) {
Value v = operandValuesToUse[outputOperand->getOperandNumber()];
SmallVector<Value> indexing = makeCanonicalAffineApplies(
b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims);
indexedValues.push_back(
b.create<LoadOpTy>(loc, outputOperand->get(), indexing));
b, loc, linalgOp.getMatchingIndexingMap(outputOperand), allIvsPlusDims);
indexedValues.push_back(b.create<LoadOpTy>(loc, v, indexing));
}
// TODO: When a region inliner exists, use it.
@@ -132,10 +134,13 @@ emitScalarImplementation(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
// 3. Emit store.
SmallVector<SmallVector<Value>, 8> indexing;
SmallVector<Value> outputBuffers;
for (OpOperand *outputOperand : linalgOp.getOutputTensorOperands()) {
indexing.push_back(makeCanonicalAffineApplies(
b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims));
outputBuffers.push_back(operandValuesToUse.back());
for (OpOperand *outputOperand : linalgOp.getDpsInitOperands()) {
if (outputOperand->get().getType().isa<mlir::TensorType>()) {
indexing.push_back(makeCanonicalAffineApplies(
b, loc, linalgOp.getMatchingIndexingMap(outputOperand),
allIvsPlusDims));
outputBuffers.push_back(operandValuesToUse.back());
}
}
return inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(
b, loc, linalgOp, indexedValues, indexing, outputBuffers);
@@ -151,7 +156,7 @@ linalgTensorOpToLoopsImpl(PatternRewriter &rewriter, LinalgOp linalgOp,
"expected linalg op with value semantics");
auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
auto iteratorTypes = llvm::to_vector<4>(linalgOp.getIteratorTypesArray());
SmallVector<Value> allIvs;
GenerateLoopNest<LoopTy>::doit(

View File

@@ -20,7 +20,7 @@ namespace pipeline {
mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
llvm::Expected<std::map<std::string, llvm::Optional<optimizer::Description>>>
llvm::Expected<std::map<std::string, std::optional<optimizer::Description>>>
getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
optimizer::Config config,
std::function<bool(mlir::Pass *)> enablePass);
@@ -40,7 +40,7 @@ transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelize, bool batch);
@@ -55,12 +55,12 @@ transformFHEBigInt(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult

View File

@@ -21,7 +21,7 @@ using ::concretelang::clientlib::ClientParameters;
llvm::Expected<ClientParameters>
createClientParametersForV0(V0FHEContext context, llvm::StringRef functionName,
mlir::ModuleOp module, int bitsOfSecurity,
llvm::Optional<ChunkInfo> chunkInfo = llvm::None);
llvm::Optional<ChunkInfo> chunkInfo = std::nullopt);
} // namespace concretelang
} // namespace mlir

View File

@@ -60,7 +60,7 @@ using DagSolution = concrete_optimizer::dag::DagSolution;
/* Contains any circuit description usable by the concrete-optimizer */
struct Description {
V0FHEConstraint constraint;
llvm::Optional<optimizer::Dag> dag;
std::optional<optimizer::Dag> dag;
};
} // namespace optimizer

View File

@@ -115,7 +115,6 @@ add_mlir_python_common_capi_library(
DECLARED_SOURCES
# TODO: This can be chopped down significantly for size.
MLIRPythonSources
MLIRPythonExtension.AllPassesRegistration
ConcretelangBindingsPythonSources
ConcretelangBindingsPythonExtension)
@@ -130,7 +129,6 @@ add_mlir_python_modules(
"python_packages/concretelang_core/mlir"
DECLARED_SOURCES
MLIRPythonSources
MLIRPythonExtension.AllPassesRegistration
# We need the circt extensions co-located with the MLIR extensions. When the namespace is unified, this moves to the
# below.
ConcretelangBindingsPythonExtension

View File

@@ -203,10 +203,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
"key_set",
[](clientlib::ClientParameters clientParameters,
clientlib::KeySetCache *cache) {
auto optCache =
cache == nullptr
? llvm::None
: llvm::Optional<clientlib::KeySetCache>(*cache);
auto optCache = cache == nullptr
? std::nullopt
: std::optional<clientlib::KeySetCache>(*cache);
return key_set(clientParameters, optCache);
},
pybind11::arg().none(false), pybind11::arg().none(true))
@@ -241,9 +240,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](mlir::concretelang::ClientParameters &clientParameters) {
std::vector<bool> result;
for (auto output : clientParameters.outputs) {
if (output.encryption.hasValue()) {
result.push_back(
output.encryption.getValue().encoding.isSigned);
if (output.encryption.has_value()) {
result.push_back(output.encryption.value().encoding.isSigned);
} else {
result.push_back(true);
}

View File

@@ -23,8 +23,8 @@
MLIR_CAPI_EXPORTED JITSupport_Py jit_support(std::string runtimeLibPath) {
auto opt = runtimeLibPath.empty()
? llvm::None
: llvm::Optional<std::string>(runtimeLibPath);
? std::nullopt
: std::optional<std::string>(runtimeLibPath);
return JITSupport_Py{mlir::concretelang::JITSupport(opt)};
}
@@ -139,7 +139,7 @@ library_get_client_parameters_path(LibrarySupport_Py support) {
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
key_set(concretelang::clientlib::ClientParameters clientParameters,
llvm::Optional<concretelang::clientlib::KeySetCache> cache) {
std::optional<concretelang::clientlib::KeySetCache> cache) {
GET_OR_THROW_LLVM_EXPECTED(
ks, (mlir::concretelang::LambdaSupport<int, int>::keySet(clientParameters,
cache)));

View File

@@ -9,7 +9,7 @@
#include "concretelang/Bindings/Python/DialectModules.h"
#include "concretelang/Support/Constants.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Registration.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "llvm-c/ErrorHandling.h"

View File

@@ -30,6 +30,5 @@
#include <mlir-c/IntegerSet.h>
#include <mlir-c/Interfaces.h>
#include <mlir-c/Pass.h>
#include <mlir-c/Registration.h>
#include <mlir-c/Support.h>
#include <mlir-c/Transforms.h>

View File

@@ -5,7 +5,7 @@ use std::error::Error;
use std::path::Path;
use std::process::exit;
const MLIR_STATIC_LIBS: [&str; 179] = [
const MLIR_STATIC_LIBS: [&str; 174] = [
"MLIRMemRefDialect",
"MLIRVectorToSPIRV",
"MLIRControlFlowInterfaces",
@@ -37,7 +37,7 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
"MLIRPresburger",
"MLIRFuncDialect",
"MLIRPDLToPDLInterp",
"MLIRArithmeticTransforms",
"MLIRArithTransforms",
"MLIRViewLikeInterface",
"MLIRTargetCpp",
"MLIROpenMPToLLVM",
@@ -53,8 +53,8 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
"MLIRTensorUtils",
"MLIRSPIRVSerialization",
"MLIRShapeToStandard",
"MLIRArithmeticToSPIRV",
"MLIRArithmeticDialect",
"MLIRArithToSPIRV",
"MLIRArithDialect",
"MLIRFuncToSPIRV",
"MLIRQuantUtils",
"MLIRTensorTilingInterfaceImpl",
@@ -95,7 +95,7 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
"MLIRRewrite",
"MLIRAMXToLLVMIRTranslation",
"MLIRInferIntRangeInterface",
"MLIRCAPIRegistration",
"MLIRCAPIRegisterEverything",
"MLIRNVVMToLLVMIRTranslation",
"MLIRAsyncTransforms",
"MLIRPDLInterpDialect",
@@ -129,7 +129,6 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
"MLIRSPIRVUtils",
"MLIRCastInterfaces",
"MLIRTosaToTensor",
"MLIRMemRefUtils",
"MLIRGPUToSPIRV",
"MLIRBufferizationDialect",
"MLIRSCFToControlFlow",
@@ -139,7 +138,6 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
"MLIRSparseTensorDialect",
"MLIRTensorToSPIRV",
"MLIRVectorToSCF",
"MLIRQuantTransforms",
"MLIRLLVMToLLVMIRTranslation",
"MLIRNVGPUDialect",
"MLIRAsyncToLLVM",
@@ -158,11 +156,9 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
"MLIRVectorToLLVM",
"MLIRSPIRVDialect",
"MLIRSideEffectInterfaces",
"MLIRVectorToROCDL",
"MLIRQuantDialect",
"MLIRSCFTransforms",
"MLIRMLProgramDialect",
"MLIRLinalgToSPIRV",
"MLIRDLTIDialect",
"MLIRLinalgFrontend",
"MLIRROCDLToLLVMIRTranslation",
@@ -177,14 +173,13 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
"MLIRSPIRVTransforms",
"MLIRMemRefToLLVM",
"MLIRSPIRVBinaryUtils",
"MLIRLinalgAnalysis",
"MLIRArithmeticUtils",
"MLIRArithUtils",
"MLIRVectorInterfaces",
"MLIRGPUOps",
"MLIRComplexToLLVM",
"MLIRShapeOpsTransforms",
"MLIRX86VectorTransforms",
"MLIRArithmeticToLLVM",
"MLIRArithToLLVM",
];
const LLVM_STATIC_LIBS: [&str; 51] = [

View File

@@ -29,7 +29,7 @@ ClientLambda::load(std::string functionName, std::string jsonPath) {
<< std::to_string(param->outputs.size()) << ") != 1 is not supprted";
}
if (!param->outputs[0].encryption.hasValue()) {
if (!param->outputs[0].encryption.has_value()) {
return StringError("ClientLambda: clear output is not yet supported");
}
ClientLambda lambda;

View File

@@ -37,9 +37,9 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
OUTCOME_TRY(CircuitGate input, keySet.clientParameters().input(currentPos));
// a chunked input is represented as a tensor in lower levels, and need to to
// splitted into chunks and encrypted as such
if (input.chunkInfo.hasValue()) {
if (input.chunkInfo.has_value()) {
std::vector<uint64_t> chunks =
chunkInput(arg, input.shape.size, input.chunkInfo.getPointer()->width);
chunkInput(arg, input.shape.size, input.chunkInfo.value().width);
return this->pushArg(chunks.data(), input.shape.size, keySet);
}
// we only increment if we don't forward the call to another pushArg method
@@ -47,7 +47,7 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
if (input.shape.size != 0) {
return StringError("argument #") << pos << " is not a scalar";
}
if (!input.encryption.hasValue()) {
if (!input.encryption.has_value()) {
// clear scalar: just push the argument
preparedArgs.push_back((void *)arg);
return outcome::success();

View File

@@ -45,15 +45,15 @@ outcome::checked<KeySet::SecretKeyGateMapping, StringError>
KeySet::mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates) {
SecretKeyGateMapping mapping;
for (auto gate : gates) {
if (gate.encryption.hasValue()) {
if (gate.encryption.has_value()) {
assert(gate.encryption->secretKeyID < this->secretKeys.size());
auto skIt = this->secretKeys[gate.encryption->secretKeyID];
std::pair<CircuitGate, llvm::Optional<LweSecretKey>> input = {gate, skIt};
std::pair<CircuitGate, std::optional<LweSecretKey>> input = {gate, skIt};
mapping.push_back(input);
} else {
std::pair<CircuitGate, llvm::Optional<LweSecretKey>> input = {gate,
llvm::None};
std::pair<CircuitGate, std::optional<LweSecretKey>> input = {
gate, std::nullopt};
mapping.push_back(input);
}
}
@@ -152,7 +152,7 @@ KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
}
auto inputSk = inputs[argPos];
auto encryption = std::get<0>(inputSk).encryption;
if (!encryption.hasValue()) {
if (!encryption.has_value()) {
return StringError("allocate_lwe argument #")
<< argPos << "is not encypeted";
}
@@ -167,12 +167,12 @@ KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
bool KeySet::isInputEncrypted(size_t argPos) {
return argPos < inputs.size() &&
std::get<0>(inputs[argPos]).encryption.hasValue();
std::get<0>(inputs[argPos]).encryption.has_value();
}
bool KeySet::isOutputEncrypted(size_t argPos) {
return argPos < outputs.size() &&
std::get<0>(outputs[argPos]).encryption.hasValue();
std::get<0>(outputs[argPos]).encryption.has_value();
}
/// Return the number of bits to represents the given value
@@ -183,9 +183,9 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
if (argPos >= inputs.size()) {
return StringError("encrypt_lwe position of argument is too high");
}
auto inputSk = inputs[argPos];
const auto &inputSk = inputs[argPos];
auto encryption = std::get<0>(inputSk).encryption;
if (!encryption.hasValue()) {
if (!encryption.has_value()) {
return StringError("encrypt_lwe the positional argument is not encrypted");
}
auto encoding = encryption->encoding;
@@ -221,7 +221,7 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
auto lweSecretKey = *outputSk.second;
auto lweSecretKeyParam = lweSecretKey.parameters();
auto encryption = std::get<0>(outputSk).encryption;
if (!encryption.hasValue()) {
if (!encryption.has_value()) {
return StringError("decrypt_lwe: the positional argument is not encrypted");
}

View File

@@ -37,7 +37,7 @@ PublicArguments::serialize(std::ostream &ostream) {
for (auto gate : clientParameters.inputs) {
iGate++;
size_t rank = gate.shape.dimensions.size();
if (!gate.encryption.hasValue()) {
if (!gate.encryption.has_value()) {
return StringError("PublicArguments::serialize: Clear arguments "
"are not yet supported. Argument ")
<< iGate;
@@ -78,7 +78,7 @@ PublicArguments::unserializeArgs(std::istream &istream) {
int iGate = -1;
for (auto gate : clientParameters.inputs) {
iGate++;
if (!gate.encryption.hasValue()) {
if (!gate.encryption.has_value()) {
return StringError("Clear values are not handled");
}
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
@@ -135,7 +135,7 @@ PublicArguments::unserialize(ClientParameters &clientParameters,
outcome::checked<void, StringError>
PublicResult::unserialize(std::istream &istream) {
for (auto gate : clientParameters.outputs) {
if (!gate.encryption.hasValue()) {
if (!gate.encryption.has_value()) {
return StringError("Clear values are not handled");
}
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();

View File

@@ -52,7 +52,7 @@ char memref_trace[] = "memref_trace";
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank) {
std::vector<int64_t> shape(rank, -1);
std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
for (size_t i = 0; i < rank; i++) {
expr = expr +
@@ -263,16 +263,16 @@ void keyswitchAddOperands(KeySwitchOp op,
mlir::RewriterBase &rewriter) {
// level
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.levelAttr()));
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getLevelAttr()));
// base_log
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getBaseLogAttr()));
// lwe_dim_in
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.lwe_dim_inAttr()));
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getLweDimInAttr()));
// lwe_dim_out
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.lwe_dim_outAttr()));
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getLweDimOutAttr()));
// context
operands.push_back(getContextArgument(op));
}
@@ -283,22 +283,22 @@ void bootstrapAddOperands(BootstrapOp op,
mlir::RewriterBase &rewriter) {
// input_lwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.inputLweDimAttr()));
op.getLoc(), op.getInputLweDimAttr()));
// poly_size
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.getPolySizeAttr()));
// level
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.getLevelAttr()));
// base_log
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.getBaseLogAttr()));
// glwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.glweDimensionAttr()));
op.getLoc(), op.getGlweDimensionAttr()));
// out_precision
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.outPrecisionAttr()));
op.getLoc(), op.getOutPrecisionAttr()));
// context
operands.push_back(getContextArgument(op));
}
@@ -307,9 +307,9 @@ void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op,
mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {
mlir::Type crtType = mlir::RankedTensorType::get(
{(int)op.crtDecompositionAttr().size()}, rewriter.getI64Type());
{(int)op.getCrtDecompositionAttr().size()}, rewriter.getI64Type());
std::vector<int64_t> values;
for (auto a : op.crtDecomposition()) {
for (auto a : op.getCrtDecomposition()) {
values.push_back(a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
}
auto attr = rewriter.getI64TensorAttr(values);
@@ -319,43 +319,43 @@ void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op,
assert(!failed(globalMemref));
auto globalRef = rewriter.create<memref::GetGlobalOp>(
op.getLoc(), (*globalMemref).type(), (*globalMemref).getName());
op.getLoc(), (*globalMemref).getType(), (*globalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, globalRef));
// lwe_small_size
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.packingKeySwitchInputLweDimensionAttr()));
op.getLoc(), op.getPackingKeySwitchInputLweDimensionAttr()));
// cbs_level_count
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.circuitBootstrapLevelAttr()));
op.getLoc(), op.getCircuitBootstrapLevelAttr()));
// cbs_base_log
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.circuitBootstrapBaseLogAttr()));
op.getLoc(), op.getCircuitBootstrapBaseLogAttr()));
// ksk_level_count
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.keyswitchLevelAttr()));
op.getLoc(), op.getKeyswitchLevelAttr()));
// ksk_base_log
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.keyswitchBaseLogAttr()));
op.getLoc(), op.getKeyswitchBaseLogAttr()));
// bsk_level_count
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.bootstrapLevelAttr()));
op.getLoc(), op.getBootstrapLevelAttr()));
// bsk_base_log
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.bootstrapBaseLogAttr()));
op.getLoc(), op.getBootstrapBaseLogAttr()));
// fpksk_level_count
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.packingKeySwitchLevelAttr()));
op.getLoc(), op.getPackingKeySwitchLevelAttr()));
// fpksk_base_log
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.packingKeySwitchBaseLogAttr()));
op.getLoc(), op.getPackingKeySwitchBaseLogAttr()));
// polynomial_size
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.packingKeySwitchoutputPolynomialSizeAttr()));
op.getLoc(), op.getPackingKeySwitchoutputPolynomialSizeAttr()));
// context
operands.push_back(getContextArgument(op));
@@ -365,10 +365,10 @@ void encodePlaintextWithCrtAddOperands(
Concrete::EncodePlaintextWithCrtBufferOp op,
mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
// mods
mlir::Type modsType = mlir::RankedTensorType::get({(int)op.modsAttr().size()},
rewriter.getI64Type());
mlir::Type modsType = mlir::RankedTensorType::get(
{(int)op.getModsAttr().size()}, rewriter.getI64Type());
std::vector<int64_t> modsValues;
for (auto a : op.mods()) {
for (auto a : op.getMods()) {
modsValues.push_back(a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
}
auto modsAttr = rewriter.getI64TensorAttr(modsValues);
@@ -378,26 +378,27 @@ void encodePlaintextWithCrtAddOperands(
rewriter.eraseOp(modsOp);
assert(!failed(modsGlobalMemref));
auto modsGlobalRef = rewriter.create<memref::GetGlobalOp>(
op.getLoc(), (*modsGlobalMemref).type(), (*modsGlobalMemref).getName());
op.getLoc(), (*modsGlobalMemref).getType(),
(*modsGlobalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, modsGlobalRef));
// mods_prod
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.modsProdAttr()));
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.getModsProdAttr()));
}
void encodeExpandLutForBootstrapAddOperands(
Concrete::EncodeExpandLutForBootstrapBufferOp op,
mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
// poly_size
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.getPolySizeAttr()));
// output bits
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.outputBitsAttr()));
op.getLoc(), op.getOutputBitsAttr()));
// is_signed
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.isSignedAttr()));
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.getIsSignedAttr()));
}
void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
@@ -406,9 +407,9 @@ void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
// crt_decomposition
mlir::Type crtDecompositionType = mlir::RankedTensorType::get(
{(int)op.crtDecompositionAttr().size()}, rewriter.getI64Type());
{(int)op.getCrtDecompositionAttr().size()}, rewriter.getI64Type());
std::vector<int64_t> crtDecompositionValues;
for (auto a : op.crtDecomposition()) {
for (auto a : op.getCrtDecomposition()) {
crtDecompositionValues.push_back(
a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
}
@@ -420,15 +421,15 @@ void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
rewriter.eraseOp(crtDecompositionOp);
assert(!failed(crtDecompositionGlobalMemref));
auto crtDecompositionGlobalRef = rewriter.create<memref::GetGlobalOp>(
op.getLoc(), (*crtDecompositionGlobalMemref).type(),
op.getLoc(), (*crtDecompositionGlobalMemref).getType(),
(*crtDecompositionGlobalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, crtDecompositionGlobalRef));
// crt_bits
mlir::Type crtBitsType = mlir::RankedTensorType::get(
{(int)op.crtBitsAttr().size()}, rewriter.getI64Type());
{(int)op.getCrtBitsAttr().size()}, rewriter.getI64Type());
std::vector<int64_t> crtBitsValues;
for (auto a : op.crtBits()) {
for (auto a : op.getCrtBits()) {
crtBitsValues.push_back(
a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
}
@@ -439,15 +440,15 @@ void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
rewriter.eraseOp(crtBitsOp);
assert(!failed(crtBitsGlobalMemref));
auto crtBitsGlobalRef = rewriter.create<memref::GetGlobalOp>(
op.getLoc(), (*crtBitsGlobalMemref).type(),
op.getLoc(), (*crtBitsGlobalMemref).getType(),
(*crtBitsGlobalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, crtBitsGlobalRef));
// modulus_product
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.modulusProductAttr()));
op.getLoc(), op.getModulusProductAttr()));
// is_signed
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.isSignedAttr()));
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.getIsSignedAttr()));
}
struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
@@ -463,7 +464,7 @@ struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
// Mark ops from the target dialect as legal operations
target.addLegalDialect<func::FuncDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<arith::ArithmeticDialect>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
// Make sure that no ops from `FHE` remain after the lowering

View File

@@ -19,7 +19,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
namespace SDFG = mlir::concretelang::SDFG;

View File

@@ -86,7 +86,7 @@ struct DotToLinalgGeneric
// Create `linalg.generic` op
llvm::SmallVector<mlir::Type, 1> resTypes{zeroTensorOp.getType()};
llvm::SmallVector<mlir::Value, 2> ins{dotOp.lhs(), dotOp.rhs()};
llvm::SmallVector<mlir::Value, 2> ins{dotOp.getLhs(), dotOp.getRhs()};
llvm::SmallVector<mlir::Value, 1> outs{zeroTensorOp};
llvm::SmallVector<mlir::AffineMap, 3> maps{
mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()),
@@ -94,7 +94,8 @@ struct DotToLinalgGeneric
mlir::AffineMap::get(1, 0, {rewriter.getAffineConstantExpr(0)},
this->getContext())};
llvm::SmallVector<llvm::StringRef, 1> itTypes{"reduction"};
llvm::SmallVector<mlir::utils::IteratorType, 1> itTypes{
mlir::utils::IteratorType::reduction};
llvm::StringRef doc{""};
llvm::StringRef call{""};
@@ -236,10 +237,10 @@ struct FHELinalgOpToLinalgGeneric : public mlir::OpRewritePattern<FHELinalgOp> {
mlir::RankedTensorType resultTy =
((mlir::Type)linalgOp->getResult(0).getType())
.cast<mlir::RankedTensorType>();
mlir::RankedTensorType lhsTy =
((mlir::Type)linalgOp.lhs().getType()).cast<mlir::RankedTensorType>();
mlir::RankedTensorType rhsTy =
((mlir::Type)linalgOp.rhs().getType()).cast<mlir::RankedTensorType>();
mlir::RankedTensorType lhsTy = ((mlir::Type)linalgOp.getLhs().getType())
.cast<mlir::RankedTensorType>();
mlir::RankedTensorType rhsTy = ((mlir::Type)linalgOp.getRhs().getType())
.cast<mlir::RankedTensorType>();
// linalg.init_tensor for initial value
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
linalgOp.getLoc(), resultTy, mlir::ValueRange{});
@@ -252,8 +253,8 @@ struct FHELinalgOpToLinalgGeneric : public mlir::OpRewritePattern<FHELinalgOp> {
};
// Create the iterator_types
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
"parallel");
llvm::SmallVector<mlir::utils::IteratorType> iteratorTypes(
resultTy.getShape().size(), mlir::utils::IteratorType::parallel);
// Create the body of the `linalg.generic` op
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
@@ -268,7 +269,7 @@ struct FHELinalgOpToLinalgGeneric : public mlir::OpRewritePattern<FHELinalgOp> {
// Create the `linalg.generic` op
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
llvm::SmallVector<mlir::Value, 2> ins{linalgOp.lhs(), linalgOp.rhs()};
llvm::SmallVector<mlir::Value, 2> ins{linalgOp.getLhs(), linalgOp.getRhs()};
llvm::SmallVector<mlir::Value, 1> outs{init};
llvm::StringRef doc{""};
llvm::StringRef call{""};
@@ -288,8 +289,9 @@ template <class T> inline mlir::RankedTensorType getRankedTensorType(T v) {
return ((mlir::Type)v.getType()).cast<mlir::RankedTensorType>();
}
llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
return llvm::SmallVector<llvm::StringRef>(n, "parallel");
llvm::SmallVector<mlir::utils::IteratorType> parallelIteratorType(int n) {
return llvm::SmallVector<mlir::utils::IteratorType>(
n, mlir::utils::IteratorType::parallel);
}
/// This class rewrite pattern transforms any instance of
@@ -347,9 +349,9 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric
using AffineMaps = llvm::SmallVector<mlir::AffineMap>;
using sliceArg = llvm::SmallVector<mlir::OpFoldResult>;
auto input = mappedLookup.t();
auto luts = mappedLookup.luts();
auto map = mappedLookup.map();
auto input = mappedLookup.getT();
auto luts = mappedLookup.getLuts();
auto map = mappedLookup.getMap();
auto loc = mappedLookup.getLoc();
auto tensorTy = getRankedTensorType(input);
@@ -471,9 +473,10 @@ struct FHELinalgApplyMultiLookupTableToLinalgGeneric
mlir::RankedTensorType resultTy =
((mlir::Type)fheLinalgLutOp->getResult(0).getType())
.cast<mlir::RankedTensorType>();
mlir::RankedTensorType tensorTy = ((mlir::Type)fheLinalgLutOp.t().getType())
.cast<mlir::RankedTensorType>();
auto luts = fheLinalgLutOp.luts();
mlir::RankedTensorType tensorTy =
((mlir::Type)fheLinalgLutOp.getT().getType())
.cast<mlir::RankedTensorType>();
auto luts = fheLinalgLutOp.getLuts();
mlir::RankedTensorType lutsTy = getRankedTensorType(luts);
auto lutElmtTy = lutsTy.getElementType();
// linalg.init_tensor for initial value
@@ -549,7 +552,7 @@ struct FHELinalgApplyMultiLookupTableToLinalgGeneric
// Create the `linalg.generic` op
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
llvm::SmallVector<mlir::Value> ins{fheLinalgLutOp.t()};
llvm::SmallVector<mlir::Value> ins{fheLinalgLutOp.getT()};
llvm::SmallVector<mlir::Value, 1> outs{init};
llvm::StringRef doc{""};
llvm::StringRef call{""};
@@ -619,7 +622,7 @@ struct FHELinalgApplyLookupTableToLinalgGeneric
((mlir::Type)lutOp->getResult(0).getType())
.cast<mlir::RankedTensorType>();
mlir::RankedTensorType tTy =
((mlir::Type)lutOp.t().getType()).cast<mlir::RankedTensorType>();
((mlir::Type)lutOp.getT().getType()).cast<mlir::RankedTensorType>();
// linalg.init_tensor for initial value
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
@@ -634,8 +637,8 @@ struct FHELinalgApplyLookupTableToLinalgGeneric
};
// Create the iterator_types
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
"parallel");
llvm::SmallVector<mlir::utils::IteratorType> iteratorTypes(
resultTy.getShape().size(), mlir::utils::IteratorType::parallel);
// Create the body of the `linalg.generic` op
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
@@ -644,7 +647,7 @@ struct FHELinalgApplyLookupTableToLinalgGeneric
mlir::concretelang::FHE::ApplyLookupTableEintOp fheOp =
nestedBuilder.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
lutOp.getLoc(), resultTy.getElementType(), blockArgs[0],
lutOp.lut());
lutOp.getLut());
nestedBuilder.create<mlir::linalg::YieldOp>(lutOp.getLoc(),
fheOp.getResult());
@@ -652,7 +655,7 @@ struct FHELinalgApplyLookupTableToLinalgGeneric
// Create the `linalg.generic` op
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
llvm::SmallVector<mlir::Value, 1> ins{lutOp.t()};
llvm::SmallVector<mlir::Value, 1> ins{lutOp.getT()};
llvm::SmallVector<mlir::Value, 1> outs{init};
llvm::StringRef doc{""};
llvm::StringRef call{""};
@@ -716,8 +719,9 @@ struct FHELinalgNegEintToLinalgGeneric
mlir::RankedTensorType resultTy =
((mlir::Type)negEintOp->getResult(0).getType())
.cast<mlir::RankedTensorType>();
mlir::RankedTensorType tensorTy = ((mlir::Type)negEintOp.tensor().getType())
.cast<mlir::RankedTensorType>();
mlir::RankedTensorType tensorTy =
((mlir::Type)negEintOp.getTensor().getType())
.cast<mlir::RankedTensorType>();
// linalg.init_tensor for initial value
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
@@ -732,8 +736,8 @@ struct FHELinalgNegEintToLinalgGeneric
};
// Create the iterator_types
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
"parallel");
llvm::SmallVector<mlir::utils::IteratorType> iteratorTypes(
resultTy.getShape().size(), mlir::utils::IteratorType::parallel);
// Create the body of the `linalg.generic` op
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
@@ -749,7 +753,7 @@ struct FHELinalgNegEintToLinalgGeneric
// Create the `linalg.generic` op
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
llvm::SmallVector<mlir::Value, 1> ins{negEintOp.tensor()};
llvm::SmallVector<mlir::Value, 1> ins{negEintOp.getTensor()};
llvm::SmallVector<mlir::Value, 1> outs{init};
llvm::StringRef doc{""};
llvm::StringRef call{""};
@@ -821,8 +825,8 @@ struct FHELinalgMatmulToLinalgGeneric
mlir::Location location = matmulOp.getLoc();
mlir::Value lhs = matmulOp.lhs();
mlir::Value rhs = matmulOp.rhs();
mlir::Value lhs = matmulOp.getLhs();
mlir::Value rhs = matmulOp.getRhs();
mlir::Value out = matmulOp.getResult();
auto lhsType = ((mlir::Type)lhs.getType()).cast<mlir::RankedTensorType>();
@@ -843,7 +847,7 @@ struct FHELinalgMatmulToLinalgGeneric
auto ins = llvm::SmallVector<mlir::Value, 2>{lhs, rhs};
auto outs = llvm::SmallVector<mlir::Value, 1>{zeros};
auto iteratorTypes = llvm::SmallVector<llvm::StringRef, 3>{};
auto iteratorTypes = llvm::SmallVector<mlir::utils::IteratorType, 3>{};
auto lhsAffineExpressions = llvm::SmallVector<mlir::AffineExpr, 2>{};
auto rhsAffineExpressions = llvm::SmallVector<mlir::AffineExpr, 2>{};
@@ -878,9 +882,9 @@ struct FHELinalgMatmulToLinalgGeneric
// - Last iterator is for the reduced dimension (N in the examples)
for (int64_t i = 0; i < outDims; i++) {
iteratorTypes.push_back(mlir::getParallelIteratorTypeName());
iteratorTypes.push_back(mlir::utils::IteratorType::parallel);
}
iteratorTypes.push_back(mlir::getReductionIteratorTypeName());
iteratorTypes.push_back(mlir::utils::IteratorType::reduction);
// we need to put appropriate affine dimension expressions
// that match lhs.shape on iterator types array
@@ -988,9 +992,9 @@ struct FHELinalgMatmulToLinalgGeneric
int64_t commonDim = rhsDims - 2;
for (int64_t i = 0; i < rhsDims; i++) {
if (i == commonDim) {
iteratorTypes.push_back(mlir::getReductionIteratorTypeName());
iteratorTypes.push_back(mlir::utils::IteratorType::reduction);
} else {
iteratorTypes.push_back(mlir::getParallelIteratorTypeName());
iteratorTypes.push_back(mlir::utils::IteratorType::parallel);
}
}
@@ -1016,9 +1020,9 @@ struct FHELinalgMatmulToLinalgGeneric
// KxLxMxN @ N -> KxLxM
for (int64_t i = 0; i < lhsDims - 1; i++) {
iteratorTypes.push_back(mlir::getParallelIteratorTypeName());
iteratorTypes.push_back(mlir::utils::IteratorType::parallel);
}
iteratorTypes.push_back(mlir::getReductionIteratorTypeName());
iteratorTypes.push_back(mlir::utils::IteratorType::reduction);
for (int64_t i = 0; i < lhsDims; i++) {
lhsAffineExpressions.push_back(rewriter.getAffineDimExpr(i));
@@ -1145,7 +1149,7 @@ struct SumToLinalgGeneric
}
auto axesToDestroy = std::unordered_set<int64_t>{};
for (mlir::Attribute axisAttribute : sumOp.axes()) {
for (mlir::Attribute axisAttribute : sumOp.getAxes()) {
int64_t axis = axisAttribute.cast<mlir::IntegerAttr>().getInt();
axesToDestroy.insert(axis);
}
@@ -1178,7 +1182,7 @@ struct SumToLinalgGeneric
bool ithAxisIsDestroyed = axesToDestroy.find(i) != axesToDestroy.end();
if (!ithAxisIsDestroyed) {
outputAffineExpressions.push_back(rewriter.getAffineDimExpr(i));
} else if (sumOp.keep_dims()) {
} else if (sumOp.getKeepDims()) {
outputAffineExpressions.push_back(rewriter.getAffineConstantExpr(0));
}
}
@@ -1191,11 +1195,11 @@ struct SumToLinalgGeneric
auto maps = llvm::SmallVector<mlir::AffineMap, 2>{inputMap, outputMap};
auto iteratorTypes = llvm::SmallVector<llvm::StringRef, 3>(
inputDimensions, mlir::getParallelIteratorTypeName());
auto iteratorTypes = llvm::SmallVector<mlir::utils::IteratorType, 3>(
inputDimensions, mlir::utils::IteratorType::parallel);
for (int64_t axis : axesToDestroy) {
iteratorTypes[axis] = mlir::getReductionIteratorTypeName();
iteratorTypes[axis] = mlir::utils::IteratorType::reduction;
}
auto regionBuilder = [&](mlir::OpBuilder &nestedBuilder,
@@ -1210,11 +1214,10 @@ struct SumToLinalgGeneric
};
auto resultTypes = llvm::SmallVector<mlir::Type, 1>{accumulatorType};
mlir::Value accumulation =
rewriter
.create<linalg::GenericOp>(location, resultTypes, ins, outs, maps,
iteratorTypes, regionBuilder)
.getResult(0);
linalg::GenericOp genericOp = rewriter.create<linalg::GenericOp>(
location, resultTypes, ins, outs, maps, iteratorTypes, regionBuilder);
mlir::Value accumulation = genericOp.getResult(0);
mlir::Value result = accumulation;
if (!outputIsTensor) {
@@ -1282,7 +1285,7 @@ struct TransposeToLinalgGeneric
std::vector<unsigned int> perms = {};
mlir::ArrayAttr axes = transposeOp.axes();
mlir::ArrayAttr axes = transposeOp.getAxes();
if (axes.empty()) {
for (int i = n_dim - 1; i >= 0; i--) {
perms.push_back(i);
@@ -1386,7 +1389,7 @@ struct ConcatRewritePattern
mlir::PatternRewriter &rewriter) const override {
mlir::Location location = op.getLoc();
size_t axis = op.axis();
size_t axis = op.getAxis();
mlir::Value output = op.getResult();
auto outputType = output.getType().dyn_cast<mlir::TensorType>();
@@ -1452,9 +1455,11 @@ struct ConcatRewritePattern
sizes[axis] = axisSize;
// these arrays are copied, so it's fine to modify and use them again
mlir::ArrayAttr offsetsAttr = rewriter.getI64ArrayAttr(offsets);
mlir::ArrayAttr sizesAttr = rewriter.getI64ArrayAttr(sizes);
mlir::ArrayAttr stridesAttr = rewriter.getI64ArrayAttr(strides);
mlir::DenseI64ArrayAttr offsetsAttr =
rewriter.getDenseI64ArrayAttr(offsets);
mlir::DenseI64ArrayAttr sizesAttr = rewriter.getDenseI64ArrayAttr(sizes);
mlir::DenseI64ArrayAttr stridesAttr =
rewriter.getDenseI64ArrayAttr(strides);
offsets[axis] += axisSize;
@@ -1497,9 +1502,10 @@ getPaddedTensor(mlir::Operation *op, mlir::OpBuilder &b, mlir::Value &input,
getAsOpFoldResult(b, loc, lowPaddingInts);
mlir::SmallVector<mlir::OpFoldResult> highPaddings =
getAsOpFoldResult(b, loc, highPaddingInts);
mlir::Value paddedInput = mlir::tensor::createPadScalarOp(
rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
/*packing=*/false, loc, b);
mlir::Value paddedInput = b.create<mlir::tensor::PadOp>(
loc, rankedTensorType, input, lowPaddings, highPaddings, pad);
return paddedInput;
}
@@ -1665,9 +1671,9 @@ struct FHELinalgConv2dToLinalgConv2d
mlir::Location loc = conv2dOp->getLoc();
mlir::Value input =
conv2dOp.input(); /* shape: Batch*Channels*Height*Width */
conv2dOp.getInput(); /* shape: Batch*Channels*Height*Width */
mlir::Value weight =
conv2dOp.weight(); /* shape: Filters*Channels*Height*Width */
conv2dOp.getWeight(); /* shape: Filters*Channels*Height*Width */
mlir::Type inputElementType =
input.getType().cast<mlir::RankedTensorType>().getElementType();
@@ -1714,7 +1720,7 @@ struct FHELinalgConv2dToLinalgConv2d
// Since linalg doesn't support a bias in the conv operation, we initialize
// the output tensor to the bias values, so that conv results get
// accumulated to it
mlir::Value bias = conv2dOp.bias(); /* optional of shape: Filters */
mlir::Value bias = conv2dOp.getBias(); /* optional of shape: Filters */
mlir::Value biasInitTensor;
if (!bias) { // no bias was used
biasInitTensor = initTensor;
@@ -1726,7 +1732,8 @@ struct FHELinalgConv2dToLinalgConv2d
mlir::AffineMap::get(resultRank, 0, rewriter.getAffineDimExpr(1),
rewriter.getContext()),
rewriter.getMultiDimIdentityMap(resultRank)};
mlir::SmallVector<llvm::StringRef> iteratorTypes(resultRank, "parallel");
mlir::SmallVector<mlir::utils::IteratorType> iteratorTypes(
resultRank, mlir::utils::IteratorType::parallel);
biasInitTensor =
rewriter
.create<mlir::linalg::GenericOp>(
@@ -1817,14 +1824,15 @@ struct FHELinalgMaxpool2dToLinalgMaxpool2d
output = rewriter.create<FHELinalg::SubEintIntOp>(loc, output, offset);
}
const mlir::DenseElementsAttr kernelShapeAttr = maxpool2dOp.kernel_shape();
const mlir::DenseElementsAttr kernelShapeAttr =
maxpool2dOp.getKernelShape();
const auto kernelShape =
llvm::SmallVector<int64_t, 2>(kernelShapeAttr.value_begin<int64_t>(),
kernelShapeAttr.value_end<int64_t>());
const mlir::Value kernel =
rewriter
.create<mlir::linalg::InitTensorOp>(
.create<mlir::tensor::EmptyOp>(
loc, kernelShape,
mlir::IntegerType::get(this->getContext(), 64))
.getResult();
@@ -1833,12 +1841,12 @@ struct FHELinalgMaxpool2dToLinalgMaxpool2d
rewriter.getI64VectorAttr({1, 1});
const mlir::DenseIntElementsAttr stridesAttr =
maxpool2dOp.dilations().getValueOr(defaultAttr);
maxpool2dOp.getDilations().value_or(defaultAttr);
const mlir::DenseIntElementsAttr dilationsAttr =
maxpool2dOp.dilations().getValueOr(defaultAttr);
maxpool2dOp.getDilations().value_or(defaultAttr);
rewriter.replaceOpWithNewOp<mlir::linalg::PoolingNchwMaxOp>(
maxpool2dOp, outputTy, mlir::ValueRange{maxpool2dOp.input(), kernel},
maxpool2dOp, outputTy, mlir::ValueRange{maxpool2dOp.getInput(), kernel},
output, stridesAttr, dilationsAttr,
llvm::ArrayRef<mlir::NamedAttribute>({maxOpAttr}));
@@ -1892,7 +1900,7 @@ struct FHELinalgToSignedToLinalgGeneric
mlir::PatternRewriter &rewriter) const override {
mlir::RankedTensorType inputTy =
op.input().getType().cast<mlir::RankedTensorType>();
op.getInput().getType().cast<mlir::RankedTensorType>();
mlir::RankedTensorType resultTy =
op->getResult(0).getType().cast<mlir::RankedTensorType>();
@@ -1906,8 +1914,8 @@ struct FHELinalgToSignedToLinalgGeneric
this->getContext()),
};
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
"parallel");
llvm::SmallVector<mlir::utils::IteratorType> iteratorTypes(
resultTy.getShape().size(), mlir::utils::IteratorType::parallel);
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
@@ -1920,7 +1928,7 @@ struct FHELinalgToSignedToLinalgGeneric
};
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
llvm::SmallVector<mlir::Value, 1> ins{op.input()};
llvm::SmallVector<mlir::Value, 1> ins{op.getInput()};
llvm::SmallVector<mlir::Value, 1> outs{init};
llvm::StringRef doc{""};
@@ -1981,7 +1989,7 @@ struct FHELinalgToUnsignedToLinalgGeneric
mlir::PatternRewriter &rewriter) const override {
mlir::RankedTensorType inputTy =
op.input().getType().cast<mlir::RankedTensorType>();
op.getInput().getType().cast<mlir::RankedTensorType>();
mlir::RankedTensorType resultTy =
op->getResult(0).getType().cast<mlir::RankedTensorType>();
@@ -1995,8 +2003,8 @@ struct FHELinalgToUnsignedToLinalgGeneric
this->getContext()),
};
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
"parallel");
llvm::SmallVector<mlir::utils::IteratorType> iteratorTypes(
resultTy.getShape().size(), mlir::utils::IteratorType::parallel);
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
@@ -2009,7 +2017,7 @@ struct FHELinalgToUnsignedToLinalgGeneric
};
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
llvm::SmallVector<mlir::Value, 1> ins{op.input()};
llvm::SmallVector<mlir::Value, 1> ins{op.getInput()};
llvm::SmallVector<mlir::Value, 1> outs{init};
llvm::StringRef doc{""};
@@ -2040,7 +2048,7 @@ void FHETensorOpsToLinalg::runOnOperation() {
target.addLegalDialect<mlir::memref::MemRefDialect>();
target.addLegalDialect<mlir::concretelang::FHE::FHEDialect>();
target.addLegalDialect<mlir::tensor::TensorDialect>();
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
target.addLegalDialect<mlir::arith::ArithDialect>();
target.addIllegalOp<mlir::concretelang::FHELinalg::Dot>();
target.addIllegalDialect<mlir::concretelang::FHELinalg::FHELinalgDialect>();

View File

@@ -4,7 +4,7 @@
// for license information.
#include <iostream>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/IR/Operation.h>
@@ -195,8 +195,8 @@ struct AddEintIntOpPattern : public CrtOpPattern<FHE::AddEintIntOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value eintOperand = adaptor.a();
mlir::Value intOperand = adaptor.b();
mlir::Value eintOperand = adaptor.getA();
mlir::Value intOperand = adaptor.getB();
// Write plaintext encoding
mlir::Value encodedPlaintextTensor =
@@ -242,8 +242,8 @@ struct SubIntEintOpPattern : public CrtOpPattern<FHE::SubIntEintOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value intOperand = adaptor.a();
mlir::Value eintOperand = adaptor.b();
mlir::Value intOperand = adaptor.getA();
mlir::Value eintOperand = adaptor.getB();
// Write plaintext encoding
mlir::Value encodedPlaintextTensor =
@@ -289,8 +289,8 @@ struct SubEintIntOpPattern : public CrtOpPattern<FHE::SubEintIntOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value eintOperand = adaptor.a();
mlir::Value intOperand = adaptor.b();
mlir::Value eintOperand = adaptor.getA();
mlir::Value intOperand = adaptor.getB();
// Write plaintext negation
mlir::Type intType = intOperand.getType();
@@ -346,8 +346,8 @@ struct AddEintOpPattern : CrtOpPattern<FHE::AddEintOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value lhsOperand = adaptor.a();
mlir::Value rhsOperand = adaptor.b();
mlir::Value lhsOperand = adaptor.getA();
mlir::Value rhsOperand = adaptor.getB();
// Write add loop.
mlir::Type ciphertextScalarType =
@@ -389,8 +389,8 @@ struct SubEintOpPattern : CrtOpPattern<FHE::SubEintOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value lhsOperand = adaptor.a();
mlir::Value rhsOperand = adaptor.b();
mlir::Value lhsOperand = adaptor.getA();
mlir::Value rhsOperand = adaptor.getB();
// Write sub loop.
mlir::Type ciphertextScalarType =
@@ -434,7 +434,7 @@ struct NegEintOpPattern : CrtOpPattern<FHE::NegEintOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value operand = adaptor.a();
mlir::Value operand = adaptor.getA();
// Write the loop nest.
mlir::Type ciphertextScalarType = converter->convertType(operand.getType())
@@ -472,7 +472,7 @@ struct ToSignedOpPattern : public CrtOpPattern<FHE::ToSignedOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
typing::TypeConverter converter{loweringParameters};
rewriter.replaceOp(op, {adaptor.input()});
rewriter.replaceOp(op, {adaptor.getInput()});
return mlir::success();
}
@@ -490,7 +490,7 @@ struct ToUnsignedOpPattern : public CrtOpPattern<FHE::ToUnsignedOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
typing::TypeConverter converter{loweringParameters};
rewriter.replaceOp(op, {adaptor.input()});
rewriter.replaceOp(op, {adaptor.getInput()});
return mlir::success();
}
@@ -509,8 +509,8 @@ struct MulEintIntOpPattern : CrtOpPattern<FHE::MulEintIntOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value eintOperand = adaptor.a();
mlir::Value intOperand = adaptor.b();
mlir::Value eintOperand = adaptor.getA();
mlir::Value intOperand = adaptor.getB();
// Write cleartext "encoding"
mlir::Value encodedCleartext = rewriter.create<mlir::arith::ExtSIOp>(
@@ -556,7 +556,8 @@ struct ApplyLookupTableEintOpPattern
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
auto originalInputType = op.a().getType().cast<FHE::FheIntegerInterface>();
auto originalInputType =
op.getA().getType().cast<FHE::FheIntegerInterface>();
mlir::Value newLut =
rewriter
@@ -567,7 +568,7 @@ struct ApplyLookupTableEintOpPattern
(int64_t)loweringParameters.nMods,
(int64_t)loweringParameters.singleLutSize},
rewriter.getI64Type()),
adaptor.lut(),
adaptor.getLut(),
rewriter.getI64ArrayAttr(
mlir::ArrayRef<int64_t>(loweringParameters.mods)),
rewriter.getI64ArrayAttr(
@@ -578,8 +579,9 @@ struct ApplyLookupTableEintOpPattern
// Replace the lut with an encoded / expanded one.
auto wopPBS = rewriter.create<TFHE::WopPBSGLWEOp>(
op.getLoc(), converter->convertType(op.getType()), adaptor.a(), newLut,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, rewriter.getI64ArrayAttr({}));
op.getLoc(), converter->convertType(op.getType()), adaptor.getA(),
newLut, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
rewriter.getI64ArrayAttr({}));
rewriter.replaceOp(op, {wopPBS.getResult()});
return ::mlir::success();
@@ -601,25 +603,25 @@ struct TraceCiphertextOpPattern : CrtOpPattern<Tracing::TraceCiphertextOp> {
typing::TypeConverter converter{loweringParameters};
mlir::Type ciphertextScalarType =
converter.convertType(op.ciphertext().getType())
converter.convertType(op.getCiphertext().getType())
.cast<mlir::RankedTensorType>()
.getElementType();
for (size_t i = 0; i < (loweringParameters.nMods - 1); ++i) {
auto extractedCiphertext = rewriter.create<mlir::tensor::ExtractOp>(
op.getLoc(), ciphertextScalarType, adaptor.ciphertext(),
op.getLoc(), ciphertextScalarType, adaptor.getCiphertext(),
mlir::ValueRange{rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(i))});
rewriter.create<Tracing::TraceCiphertextOp>(
op.getLoc(), extractedCiphertext, op.msgAttr(), op.nmsbAttr());
op.getLoc(), extractedCiphertext, op.getMsgAttr(), op.getNmsbAttr());
}
auto extractedCiphertext = rewriter.create<mlir::tensor::ExtractOp>(
op.getLoc(), ciphertextScalarType, adaptor.ciphertext(),
op.getLoc(), ciphertextScalarType, adaptor.getCiphertext(),
mlir::ValueRange{rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(loweringParameters.nMods - 1))});
rewriter.replaceOpWithNewOp<Tracing::TraceCiphertextOp>(
op, extractedCiphertext, op.msgAttr(), op.nmsbAttr());
op, extractedCiphertext, op.getMsgAttr(), op.getNmsbAttr());
return mlir::success();
}
@@ -813,7 +815,7 @@ struct TensorReassociationOpPattern : public CrtOpPattern<Op> {
mlir::TypeConverter *converter = this->getTypeConverter();
auto reassocVal = (inRank ? adaptor.src() : op.result());
auto reassocVal = (inRank ? adaptor.getSrc() : op.getResult());
auto reassocTy = reassocVal.getType();
auto newReassocType = converter->convertType(reassocTy);
@@ -827,7 +829,7 @@ struct TensorReassociationOpPattern : public CrtOpPattern<Op> {
auto newOp = rewriter.create<Op>(
op.getLoc(), converter->convertType(op.getResult().getType()),
adaptor.src(), newReassocs);
adaptor.getSrc(), newReassocs);
rewriter.replaceOp(op, {newOp});
return mlir::success();
@@ -848,25 +850,22 @@ struct ExtractSliceOpPattern
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::SmallVector<mlir::Attribute> newStaticOffsets{
op.static_offsets().template getAsRange<mlir::IntegerAttr>()};
mlir::SmallVector<mlir::Attribute> newStaticSizes{
op.static_sizes().template getAsRange<mlir::IntegerAttr>()};
mlir::SmallVector<mlir::Attribute> newStaticStrides{
op.static_strides().template getAsRange<mlir::IntegerAttr>()};
newStaticOffsets.push_back(rewriter.getI64IntegerAttr(0));
newStaticSizes.push_back(
rewriter.getI64IntegerAttr(this->loweringParameters.nMods));
newStaticStrides.push_back(rewriter.getI64IntegerAttr(1));
mlir::SmallVector<int64_t> newStaticOffsets{op.static_offsets()};
mlir::SmallVector<int64_t> newStaticSizes{op.static_sizes()};
mlir::SmallVector<int64_t> newStaticStrides{op.static_strides()};
newStaticOffsets.push_back(0);
newStaticSizes.push_back(this->loweringParameters.nMods);
newStaticStrides.push_back(1);
mlir::RankedTensorType newType =
converter->convertType(op.getResult().getType())
.template cast<mlir::RankedTensorType>();
rewriter.replaceOpWithNewOp<mlir::tensor::ExtractSliceOp>(
op, newType, adaptor.source(), adaptor.getOffsets(), adaptor.getSizes(),
adaptor.getStrides(), rewriter.getArrayAttr(newStaticOffsets),
rewriter.getArrayAttr(newStaticSizes),
rewriter.getArrayAttr(newStaticStrides));
op, newType, adaptor.getSource(), adaptor.getOffsets(),
adaptor.getSizes(), adaptor.getStrides(),
rewriter.getDenseI64ArrayAttr(newStaticOffsets),
rewriter.getDenseI64ArrayAttr(newStaticSizes),
rewriter.getDenseI64ArrayAttr(newStaticStrides));
return mlir::success();
};
@@ -885,26 +884,22 @@ struct InsertSliceOpPattern : public CrtOpPattern<mlir::tensor::InsertSliceOp> {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::SmallVector<mlir::Attribute> newStaticOffsets{
op.static_offsets().template getAsRange<mlir::IntegerAttr>()};
mlir::SmallVector<mlir::Attribute> newStaticSizes{
op.static_sizes().template getAsRange<mlir::IntegerAttr>()};
mlir::SmallVector<mlir::Attribute> newStaticStrides{
op.static_strides().template getAsRange<mlir::IntegerAttr>()};
newStaticOffsets.push_back(rewriter.getI64IntegerAttr(0));
newStaticSizes.push_back(
rewriter.getI64IntegerAttr(this->loweringParameters.nMods));
newStaticStrides.push_back(rewriter.getI64IntegerAttr(1));
mlir::SmallVector<int64_t> newStaticOffsets{op.static_offsets()};
mlir::SmallVector<int64_t> newStaticSizes{op.static_sizes()};
mlir::SmallVector<int64_t> newStaticStrides{op.static_strides()};
newStaticOffsets.push_back(0);
newStaticSizes.push_back(this->loweringParameters.nMods);
newStaticStrides.push_back(1);
mlir::RankedTensorType newType =
converter->convertType(op.getResult().getType())
.template cast<mlir::RankedTensorType>();
rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
op, newType, adaptor.source(), adaptor.dest(), adaptor.getOffsets(),
adaptor.getSizes(), adaptor.getStrides(),
rewriter.getArrayAttr(newStaticOffsets),
rewriter.getArrayAttr(newStaticSizes),
rewriter.getArrayAttr(newStaticStrides));
op, newType, adaptor.getSource(), adaptor.getDest(),
adaptor.getOffsets(), adaptor.getSizes(), adaptor.getStrides(),
rewriter.getDenseI64ArrayAttr(newStaticOffsets),
rewriter.getDenseI64ArrayAttr(newStaticSizes),
rewriter.getDenseI64ArrayAttr(newStaticStrides));
return mlir::success();
};
@@ -926,7 +921,7 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
//------------------------------------------- Marking legal/illegal dialects
target.addIllegalDialect<FHE::FHEDialect>();
target.addLegalDialect<TFHE::TFHEDialect>();
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
target.addLegalDialect<mlir::arith::ArithDialect>();
target.addDynamicallyLegalOp<mlir::tensor::GenerateOp, mlir::scf::ForOp>(
[&](mlir::Operation *op) {
return (

View File

@@ -4,7 +4,7 @@
// for license information.
#include <iostream>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
@@ -146,12 +146,12 @@ struct AddEintIntOpPattern : public ScalarOpPattern<FHE::AddEintIntOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
// Write the plaintext encoding
mlir::Value encodedInt = writePlaintextShiftEncoding(
op.getLoc(), adaptor.b(),
op.getLoc(), adaptor.getB(),
op.getType().cast<FHE::FheIntegerInterface>().getWidth(), rewriter);
// Write the new op
rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.a(),
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
encodedInt);
return mlir::success();
@@ -169,8 +169,8 @@ struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
matchAndRewrite(FHE::SubEintIntOp op, FHE::SubEintIntOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location location = op.getLoc();
mlir::Value eintOperand = op.a();
mlir::Value intOperand = op.b();
mlir::Value eintOperand = op.getA();
mlir::Value intOperand = op.getB();
// Write the integer negation
mlir::Type intType = intOperand.getType();
@@ -190,7 +190,7 @@ struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
// Write the new op
rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.a(),
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
encodedInt);
return mlir::success();
@@ -209,13 +209,14 @@ struct SubIntEintOpPattern : public ScalarOpPattern<FHE::SubIntEintOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
// Write the plaintext encoding
mlir::Value encodedInt = writePlaintextShiftEncoding(
op.getLoc(), adaptor.a(),
op.b().getType().cast<FHE::FheIntegerInterface>().getWidth(), rewriter);
op.getLoc(), adaptor.getA(),
op.getB().getType().cast<FHE::FheIntegerInterface>().getWidth(),
rewriter);
// Write the new op
rewriter.replaceOpWithNewOp<TFHE::SubGLWEIntOp>(
op, getTypeConverter()->convertType(op.getType()), encodedInt,
adaptor.b());
adaptor.getB());
return mlir::success();
};
@@ -231,8 +232,8 @@ struct SubEintOpPattern : public ScalarOpPattern<FHE::SubEintOp> {
matchAndRewrite(FHE::SubEintOp op, FHE::SubEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location location = op.getLoc();
mlir::Value lhsOperand = adaptor.a();
mlir::Value rhsOperand = adaptor.b();
mlir::Value lhsOperand = adaptor.getA();
mlir::Value rhsOperand = adaptor.getB();
// Write rhs negation
auto negative = rewriter.create<TFHE::NegGLWEOp>(
@@ -259,8 +260,8 @@ struct MulEintIntOpPattern : public ScalarOpPattern<FHE::MulEintIntOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location location = op.getLoc();
mlir::Value eintOperand = adaptor.a();
mlir::Value intOperand = adaptor.b();
mlir::Value eintOperand = adaptor.getA();
mlir::Value intOperand = adaptor.getB();
// Write the cleartext "encoding"
mlir::Value castedCleartext = rewriter.create<mlir::arith::ExtSIOp>(
@@ -286,7 +287,7 @@ struct ToSignedOpPattern : public ScalarOpPattern<FHE::ToSignedOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
typing::TypeConverter converter;
rewriter.replaceOp(op, {adaptor.input()});
rewriter.replaceOp(op, {adaptor.getInput()});
return mlir::success();
}
@@ -304,7 +305,7 @@ struct ToUnsignedOpPattern : public ScalarOpPattern<FHE::ToUnsignedOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
typing::TypeConverter converter;
rewriter.replaceOp(op, {adaptor.input()});
rewriter.replaceOp(op, {adaptor.getInput()});
return mlir::success();
}
@@ -326,7 +327,7 @@ struct ApplyLookupTableEintOpPattern
FHE::ApplyLookupTableEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto inputType = op.a().getType().cast<FHE::FheIntegerInterface>();
auto inputType = op.getA().getType().cast<FHE::FheIntegerInterface>();
size_t outputBits =
op.getResult().getType().cast<FHE::FheIntegerInterface>().getWidth();
mlir::Value newLut =
@@ -336,14 +337,14 @@ struct ApplyLookupTableEintOpPattern
mlir::RankedTensorType::get(
mlir::ArrayRef<int64_t>(loweringParameters.polynomialSize),
rewriter.getI64Type()),
op.lut(),
op.getLut(),
rewriter.getI32IntegerAttr(loweringParameters.polynomialSize),
rewriter.getI32IntegerAttr(outputBits),
rewriter.getBoolAttr(inputType.isSigned()))
.getResult();
typing::TypeConverter converter;
mlir::Value input = adaptor.a();
mlir::Value input = adaptor.getA();
if (inputType.isSigned()) {
// If the input is a signed integer, it comes to the bootstrap with a
@@ -367,7 +368,7 @@ struct ApplyLookupTableEintOpPattern
// Insert keyswitch
auto ksOp = rewriter.create<TFHE::KeySwitchGLWEOp>(
op.getLoc(), getTypeConverter()->convertType(adaptor.a().getType()),
op.getLoc(), getTypeConverter()->convertType(adaptor.getA().getType()),
input, -1, -1);
// Insert bootstrap
@@ -409,8 +410,8 @@ struct RoundEintOpPattern : public ScalarOpPattern<FHE::RoundEintOp> {
// -> Subtract this one from the input by performing a
// homomorphic subtraction.
mlir::Value input = adaptor.input();
auto inputType = op.input().getType().cast<FHE::FheIntegerInterface>();
mlir::Value input = adaptor.getInput();
auto inputType = op.getInput().getType().cast<FHE::FheIntegerInterface>();
mlir::Value output = op.getResult();
uint64_t inputBitwidth = inputType.getWidth();
uint64_t outputBitwidth =
@@ -594,12 +595,12 @@ struct ToBoolOpPattern : public mlir::OpRewritePattern<FHE::ToBoolOp> {
mlir::LogicalResult
matchAndRewrite(FHE::ToBoolOp op,
mlir::PatternRewriter &rewriter) const override {
auto width = op.input()
auto width = op.getInput()
.getType()
.dyn_cast<mlir::concretelang::FHE::EncryptedIntegerType>()
.getWidth();
if (width == mlir::concretelang::FHE::EncryptedBooleanType::getWidth()) {
rewriter.replaceOp(op, op.input());
rewriter.replaceOp(op, op.getInput());
return mlir::success();
}
// TODO
@@ -622,7 +623,7 @@ struct FromBoolOpPattern : public mlir::OpRewritePattern<FHE::FromBoolOp> {
.dyn_cast<mlir::concretelang::FHE::EncryptedIntegerType>()
.getWidth();
if (width == mlir::concretelang::FHE::EncryptedBooleanType::getWidth()) {
rewriter.replaceOp(op, op.input());
rewriter.replaceOp(op, op.getInput());
return mlir::success();
}
// TODO
@@ -647,7 +648,7 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
//------------------------------------------- Marking legal/illegal dialects
target.addIllegalDialect<FHE::FHEDialect>();
target.addLegalDialect<TFHE::TFHEDialect>();
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
target.addLegalDialect<mlir::arith::ArithDialect>();
target.addDynamicallyLegalOp<mlir::linalg::GenericOp,
mlir::tensor::GenerateOp>(
[&](mlir::Operation *op) {

View File

@@ -43,7 +43,7 @@ public:
if (((mlir::LogicalResult)loops).failed() || loops->size() == 0)
return mlir::failure();
rewriter.replaceOp(linalgOp, loops.getValue()[0]->getResult(0));
rewriter.replaceOp(linalgOp, loops.value()[0]->getResult(0));
return mlir::success();
};

View File

@@ -5,11 +5,13 @@
#include <iostream>
#include <mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h>
#include <mlir/IR/BuiltinTypeInterfaces.h>
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
@@ -21,6 +23,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"
@@ -40,7 +43,7 @@ struct MLIRLowerableDialectsToLLVMPass
void runOnOperation() final;
/// Convert types to the LLVM dialect-compatible type
static llvm::Optional<mlir::Type> convertTypes(mlir::Type type);
static std::optional<mlir::Type> convertTypes(mlir::Type type);
};
} // namespace
@@ -73,11 +76,12 @@ struct Memref1DCopyOpPattern
mlir::LogicalResult
matchAndRewrite(mlir::memref::CopyOp copyOp,
mlir::PatternRewriter &rewriter) const override {
if (copyOp.source().getType().cast<mlir::MemRefType>().getRank() != 1 ||
copyOp.source().getType().cast<mlir::MemRefType>().getRank() != 1) {
if (copyOp.getSource().getType().cast<mlir::MemRefType>().getRank() != 1 ||
copyOp.getSource().getType().cast<mlir::MemRefType>().getRank() != 1) {
return mlir::failure();
}
auto opType = mlir::MemRefType::get({-1}, rewriter.getI64Type());
auto opType = mlir::MemRefType::get({mlir::ShapedType::kDynamic},
rewriter.getI64Type());
// Insert forward declaration of the add_lwe_ciphertexts function
{
if (insertForwardDeclaration(
@@ -89,9 +93,9 @@ struct Memref1DCopyOpPattern
}
}
auto sourceOp = rewriter.create<mlir::memref::CastOp>(
copyOp.getLoc(), opType, copyOp.source());
copyOp.getLoc(), opType, copyOp.getSource());
auto targetOp = rewriter.create<mlir::memref::CastOp>(
copyOp.getLoc(), opType, copyOp.target());
copyOp.getLoc(), opType, copyOp.getTarget());
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
copyOp, "memref_copy_one_rank", mlir::TypeRange{},
mlir::ValueRange{sourceOp, targetOp});
@@ -119,9 +123,10 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
mlir::concretelang::populateRTToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateFuncToLLVMConversionPatterns(typeConverter, patterns);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
mlir::memref::populateExpandStridedMetadataPatterns(patterns);
mlir::populateAffineToStdConversionPatterns(patterns);
mlir::populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateSCFToControlFlowConversionPatterns(patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
@@ -143,7 +148,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
}
}
llvm::Optional<mlir::Type>
std::optional<mlir::Type>
MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
if (type.isa<mlir::concretelang::Concrete::ContextType>() ||
type.isa<mlir::concretelang::RT::FutureType>() ||
@@ -161,7 +166,7 @@ MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
mlir::Type convertedSubtype = typeConverter.convertType(subtype);
return mlir::LLVM::LLVMPointerType::get(convertedSubtype);
}
return llvm::None;
return std::nullopt;
}
namespace mlir {

View File

@@ -47,7 +47,7 @@ char stream_emulator_put_uint64[] = "stream_emulator_put_uint64";
char stream_emulator_get_uint64[] = "stream_emulator_get_uint64";
mlir::Type getDynamicTensor(mlir::OpBuilder &rewriter, size_t rank) {
std::vector<int64_t> shape(rank, -1);
std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
return mlir::RankedTensorType::get(shape, rewriter.getI64Type());
}
@@ -166,7 +166,7 @@ struct LowerSDFGMakeProcess
::mlir::PatternRewriter &rewriter) const override {
const char *funcName;
mlir::SmallVector<mlir::Value> operands(mpOp->getOperands());
switch (mpOp.type()) {
switch (mpOp.getType()) {
case SDFG::ProcessKind::add_eint:
funcName = stream_emulator_make_memref_add_lwe_ciphertexts_u64_process;
break;
@@ -249,7 +249,7 @@ struct LowerSDFGMakeStream
const char *funcName;
stream_type t;
switch (msOp.type()) {
switch (msOp.getType()) {
case SDFG::StreamKind::host_to_device:
t = TS_STREAM_TYPE_X86_TO_TOPO_LSAP;
break;
@@ -370,7 +370,7 @@ void SDFGToStreamEmulatorPass::runOnOperation() {
SDFG::MakeProcess, SDFG::MakeStream, SDFG::Put>();
// All Concrete ops are legal after the conversion
target.addLegalDialect<mlir::concretelang::Concrete::ConcreteDialect>();
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
target.addLegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::func::ReturnOp, mlir::func::FuncOp,
mlir::func::CallOp, SDFG::Get, mlir::tensor::CastOp>();

View File

@@ -114,15 +114,17 @@ struct KeySwitchGLWEOpPattern
mlir::LogicalResult
matchAndRewrite(TFHE::KeySwitchGLWEOp ksOp,
mlir::PatternRewriter &rewriter) const override {
auto inputTy = ksOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
auto newInputTy = converter.convertType(inputTy);
auto outputTy = ksOp.result().getType().cast<TFHE::GLWECipherTextType>();
auto inputTy =
ksOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();
auto newInputTy = converter.convertType(inputTy)
.cast<mlir::concretelang::TFHE::GLWECipherTextType>();
auto outputTy = ksOp.getResult().getType().cast<TFHE::GLWECipherTextType>();
auto newOutputTy = converter.glweIntraPBSType(outputTy);
auto newOp = rewriter.replaceOpWithNewOp<TFHE::KeySwitchGLWEOp>(
ksOp, newOutputTy, ksOp.ciphertext(), cryptoParameters.ksLevel,
ksOp, newOutputTy, ksOp.getCiphertext(), cryptoParameters.ksLevel,
cryptoParameters.ksLogBase);
rewriter.startRootUpdate(newOp);
newOp.ciphertext().setType(newInputTy);
newOp.getCiphertext().setType(newInputTy);
rewriter.finalizeRootUpdate(newOp);
return mlir::success();
};
@@ -145,16 +147,17 @@ struct BootstrapGLWEOpPattern
mlir::LogicalResult
matchAndRewrite(TFHE::BootstrapGLWEOp bsOp,
mlir::PatternRewriter &rewriter) const override {
auto inputTy = bsOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
auto inputTy =
bsOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();
auto newInputTy = converter.glweIntraPBSType(inputTy);
auto outputTy = bsOp.result().getType().cast<TFHE::GLWECipherTextType>();
auto outputTy = bsOp.getResult().getType().cast<TFHE::GLWECipherTextType>();
auto newOutputTy = converter.convertType(outputTy);
auto newOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
bsOp, newOutputTy, bsOp.ciphertext(), bsOp.lookup_table(),
bsOp, newOutputTy, bsOp.getCiphertext(), bsOp.getLookupTable(),
cryptoParameters.brLevel, cryptoParameters.brLogBase,
cryptoParameters.getPolynomialSize(), cryptoParameters.glweDimension);
rewriter.startRootUpdate(newOp);
newOp.ciphertext().setType(newInputTy);
newOp.getCiphertext().setType(newInputTy);
rewriter.finalizeRootUpdate(newOp);
return mlir::success();
};
@@ -177,8 +180,8 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
matchAndRewrite(TFHE::WopPBSGLWEOp wopPBSOp,
mlir::PatternRewriter &rewriter) const override {
auto newOp = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
wopPBSOp, converter.convertType(wopPBSOp.result().getType()),
wopPBSOp.ciphertexts(), wopPBSOp.lookupTable(),
wopPBSOp, converter.convertType(wopPBSOp.getResult().getType()),
wopPBSOp.getCiphertexts(), wopPBSOp.getLookupTable(),
// Bootstrap parameters
cryptoParameters.brLevel, cryptoParameters.brLogBase,
// Keyswitch parameters
@@ -198,12 +201,12 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
cryptoParameters.largeInteger->crtDecomposition));
rewriter.startRootUpdate(newOp);
auto ctType =
wopPBSOp.ciphertexts().getType().cast<mlir::RankedTensorType>();
wopPBSOp.getCiphertexts().getType().cast<mlir::RankedTensorType>();
auto ciphertextType =
ctType.getElementType().cast<TFHE::GLWECipherTextType>();
auto newType = mlir::RankedTensorType::get(
ctType.getShape(), converter.glweInterPBSType(ciphertextType));
newOp.ciphertexts().setType(newType);
newOp.getCiphertexts().setType(newType);
rewriter.finalizeRootUpdate(newOp);
return mlir::success();
};
@@ -279,7 +282,8 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
cryptoParameters);
target.addDynamicallyLegalOp<TFHE::KeySwitchGLWEOp>(
[&](TFHE::KeySwitchGLWEOp op) {
return op.level() != (uint32_t)-1 && op.baseLog() != (uint32_t)-1;
return op.getLevel() != (uint32_t)-1 &&
op.getBaseLog() != (uint32_t)-1;
});
patterns.add<BootstrapGLWEOpPattern>(&getContext(), converter,
cryptoParameters);

View File

@@ -96,11 +96,11 @@ struct SubIntGLWEOpPattern
matchAndRewrite(TFHE::SubGLWEIntOp subOp, TFHE::SubGLWEIntOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Value negated = rewriter.create<Concrete::NegateLweTensorOp>(
subOp.getLoc(), adaptor.b().getType(), adaptor.b());
subOp.getLoc(), adaptor.getB().getType(), adaptor.getB());
rewriter.replaceOpWithNewOp<Concrete::AddPlaintextLweTensorOp>(
subOp, this->getTypeConverter()->convertType(subOp.getType()), negated,
subOp.a());
subOp.getA());
return mlir::success();
}
@@ -119,17 +119,16 @@ struct BootstrapGLWEOpPattern
matchAndRewrite(TFHE::BootstrapGLWEOp bsOp,
TFHE::BootstrapGLWEOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
TFHE::GLWECipherTextType resultType =
bsOp.getType().cast<TFHE::GLWECipherTextType>();
TFHE::GLWECipherTextType inputType =
bsOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
bsOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();
rewriter.replaceOpWithNewOp<Concrete::BootstrapLweTensorOp>(
bsOp, this->getTypeConverter()->convertType(resultType),
adaptor.ciphertext(), adaptor.lookup_table(), inputType.getDimension(),
adaptor.polySize(), adaptor.level(), adaptor.baseLog(),
adaptor.glweDimension(), resultType.getP());
adaptor.getCiphertext(), adaptor.getLookupTable(),
inputType.getDimension(), adaptor.getPolySize(), adaptor.getLevel(),
adaptor.getBaseLog(), adaptor.getGlweDimension(), resultType.getP());
return mlir::success();
}
@@ -152,11 +151,11 @@ struct KeySwitchGLWEOpPattern
TFHE::GLWECipherTextType resultType =
ksOp.getType().cast<TFHE::GLWECipherTextType>();
TFHE::GLWECipherTextType inputType =
ksOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
ksOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();
rewriter.replaceOpWithNewOp<Concrete::KeySwitchLweTensorOp>(
ksOp, this->getTypeConverter()->convertType(resultType),
adaptor.ciphertext(), adaptor.level(), adaptor.baseLog(),
adaptor.getCiphertext(), adaptor.getLevel(), adaptor.getBaseLog(),
inputType.getDimension(), resultType.getDimension());
return mlir::success();
@@ -174,15 +173,15 @@ struct TracePlaintextOpPattern
matchAndRewrite(Tracing::TracePlaintextOp op,
mlir::PatternRewriter &rewriter) const override {
auto inputWidth =
op.plaintext().getType().cast<mlir::IntegerType>().getWidth();
op.getPlaintext().getType().cast<mlir::IntegerType>().getWidth();
if (inputWidth == 64) {
op->setAttr("input_width", rewriter.getI64IntegerAttr(inputWidth));
return mlir::success();
}
auto extendedInput = rewriter.create<mlir::arith::ExtUIOp>(
op.getLoc(), rewriter.getI64Type(), op.plaintext());
op.getLoc(), rewriter.getI64Type(), op.getPlaintext());
auto newOp = rewriter.replaceOpWithNewOp<Tracing::TracePlaintextOp>(
op, extendedInput, op.msgAttr(), op.nmsbAttr());
op, extendedInput, op.getMsgAttr(), op.getNmsbAttr());
newOp->setAttr("input_width", rewriter.getI64IntegerAttr(inputWidth));
return ::mlir::success();
}
@@ -235,37 +234,36 @@ struct ExtractSliceOpPattern
if (this->getTypeConverter()->isLegal(extractSliceOp.getType())) {
return mlir::failure();
}
auto resultTy = extractSliceOp.result().getType();
auto resultTy = extractSliceOp.getResult().getType();
auto newResultTy = this->getTypeConverter()
->convertType(resultTy)
.cast<mlir::RankedTensorType>();
// add 0 to the static_offsets
mlir::SmallVector<mlir::Attribute> staticOffsets;
staticOffsets.append(adaptor.static_offsets().begin(),
adaptor.static_offsets().end());
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
mlir::SmallVector<int64_t> staticOffsets;
staticOffsets.append(adaptor.getStaticOffsets().begin(),
adaptor.getStaticOffsets().end());
staticOffsets.push_back(0);
// add the lweSize to the sizes
mlir::SmallVector<mlir::Attribute> staticSizes;
staticSizes.append(adaptor.static_sizes().begin(),
adaptor.static_sizes().end());
staticSizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 1)));
mlir::SmallVector<int64_t> staticSizes;
staticSizes.append(adaptor.getStaticSizes().begin(),
adaptor.getStaticSizes().end());
staticSizes.push_back(newResultTy.getDimSize(newResultTy.getRank() - 1));
// add 1 to the strides
mlir::SmallVector<mlir::Attribute> staticStrides;
staticStrides.append(adaptor.static_strides().begin(),
adaptor.static_strides().end());
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
mlir::SmallVector<int64_t> staticStrides;
staticStrides.append(adaptor.getStaticStrides().begin(),
adaptor.getStaticStrides().end());
staticStrides.push_back(1);
// replace tensor.extract_slice to the new one
rewriter.replaceOpWithNewOp<mlir::tensor::ExtractSliceOp>(
extractSliceOp, newResultTy, adaptor.source(), adaptor.offsets(),
adaptor.sizes(), adaptor.strides(),
rewriter.getArrayAttr(staticOffsets),
rewriter.getArrayAttr(staticSizes),
rewriter.getArrayAttr(staticStrides));
extractSliceOp, newResultTy, adaptor.getSource(), adaptor.getOffsets(),
adaptor.getSizes(), adaptor.getStrides(),
rewriter.getDenseI64ArrayAttr(staticOffsets),
rewriter.getDenseI64ArrayAttr(staticSizes),
rewriter.getDenseI64ArrayAttr(staticStrides));
return ::mlir::success();
};
@@ -294,31 +292,28 @@ struct ExtractOpPattern
->convertType(extractOp.getType())
.cast<mlir::RankedTensorType>();
auto tensorRank =
adaptor.tensor().getType().cast<mlir::RankedTensorType>().getRank();
adaptor.getTensor().getType().cast<mlir::RankedTensorType>().getRank();
// [min..., 0] for static_offsets ()
mlir::SmallVector<mlir::Attribute> staticOffsets(
tensorRank,
rewriter.getI64IntegerAttr(std::numeric_limits<int64_t>::min()));
staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0);
mlir::SmallVector<int64_t> staticOffsets(
tensorRank, std::numeric_limits<int64_t>::min());
staticOffsets[staticOffsets.size() - 1] = 0;
// [1..., lweDimension+1] for static_sizes or
// [1..., nbBlock, lweDimension+1]
mlir::SmallVector<mlir::Attribute> staticSizes(
tensorRank, rewriter.getI64IntegerAttr(1));
staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr(
newResultType.getDimSize(newResultType.getRank() - 1));
mlir::SmallVector<int64_t> staticSizes(tensorRank, 1);
staticSizes[staticSizes.size() - 1] =
newResultType.getDimSize(newResultType.getRank() - 1);
// [1...] for static_strides
mlir::SmallVector<mlir::Attribute> staticStrides(
tensorRank, rewriter.getI64IntegerAttr(1));
mlir::SmallVector<int64_t> staticStrides(tensorRank, 1);
rewriter.replaceOpWithNewOp<mlir::tensor::ExtractSliceOp>(
extractOp, newResultType, adaptor.tensor(), adaptor.indices(),
extractOp, newResultType, adaptor.getTensor(), adaptor.getIndices(),
mlir::SmallVector<mlir::Value>{}, mlir::SmallVector<mlir::Value>{},
rewriter.getArrayAttr(staticOffsets),
rewriter.getArrayAttr(staticSizes),
rewriter.getArrayAttr(staticStrides));
rewriter.getDenseI64ArrayAttr(staticOffsets),
rewriter.getDenseI64ArrayAttr(staticSizes),
rewriter.getDenseI64ArrayAttr(staticStrides));
return ::mlir::success();
};
@@ -344,35 +339,34 @@ struct InsertSliceOpPattern
}
auto newResultTy = this->getTypeConverter()
->convertType(insertSliceOp.result().getType())
->convertType(insertSliceOp.getResult().getType())
.cast<mlir::RankedTensorType>();
// add 0 to static_offsets
mlir::SmallVector<mlir::Attribute> staticOffsets;
staticOffsets.append(adaptor.static_offsets().begin(),
adaptor.static_offsets().end());
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
// add 0 to static offsets
mlir::SmallVector<int64_t> staticOffsets;
staticOffsets.append(adaptor.getStaticOffsets().begin(),
adaptor.getStaticOffsets().end());
staticOffsets.push_back(0);
// add lweDimension+1 to static_sizes
mlir::SmallVector<mlir::Attribute> staticSizes;
staticSizes.append(adaptor.static_sizes().begin(),
adaptor.static_sizes().end());
staticSizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 1)));
mlir::SmallVector<int64_t> staticSizes;
staticSizes.append(adaptor.getStaticSizes().begin(),
adaptor.getStaticSizes().end());
staticSizes.push_back(newResultTy.getDimSize(newResultTy.getRank() - 1));
// add 1 to the strides
mlir::SmallVector<mlir::Attribute> staticStrides;
staticStrides.append(adaptor.static_strides().begin(),
adaptor.static_strides().end());
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
mlir::SmallVector<int64_t> staticStrides;
staticStrides.append(adaptor.getStaticStrides().begin(),
adaptor.getStaticStrides().end());
staticStrides.push_back(1);
// replace tensor.insert_slice with the new one
rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
insertSliceOp, newResultTy, adaptor.source(), adaptor.dest(),
adaptor.offsets(), adaptor.sizes(), adaptor.strides(),
rewriter.getArrayAttr(staticOffsets),
rewriter.getArrayAttr(staticSizes),
rewriter.getArrayAttr(staticStrides));
insertSliceOp, newResultTy, adaptor.getSource(), adaptor.getDest(),
adaptor.getOffsets(), adaptor.getSizes(), adaptor.getStrides(),
rewriter.getDenseI64ArrayAttr(staticOffsets),
rewriter.getDenseI64ArrayAttr(staticSizes),
rewriter.getDenseI64ArrayAttr(staticStrides));
return ::mlir::success();
};
@@ -399,18 +393,18 @@ struct InsertOpPattern
mlir::RankedTensorType newResultTy =
this->getTypeConverter()
->convertType(insertOp.result().getType())
->convertType(insertOp.getResult().getType())
.cast<mlir::RankedTensorType>();
// add zeros to static_offsets
// add zeros to static offsets
mlir::SmallVector<mlir::OpFoldResult> offsets;
offsets.append(adaptor.indices().begin(), adaptor.indices().end());
offsets.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
offsets.push_back(rewriter.getIndexAttr(0));
// Inserting a smaller tensor into a (potentially) bigger one. Set
// dimensions for all leading dimensions of the target tensor not
// present in the source to 1.
mlir::SmallVector<mlir::OpFoldResult> sizes(adaptor.indices().size(),
mlir::SmallVector<mlir::OpFoldResult> sizes(adaptor.getIndices().size(),
rewriter.getI64IntegerAttr(1));
// Add size for the bufferized source element
@@ -423,7 +417,8 @@ struct InsertOpPattern
// replace tensor.insert_slice with the new one
rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
insertOp, adaptor.scalar(), adaptor.dest(), offsets, sizes, strides);
insertOp, adaptor.getScalar(), adaptor.getDest(), offsets, sizes,
strides);
return ::mlir::success();
};
@@ -453,7 +448,7 @@ struct FromElementsOpPattern
auto converter = this->getTypeConverter();
auto resultTy = fromElementsOp.result().getType();
auto resultTy = fromElementsOp.getResult().getType();
if (converter->isLegal(resultTy)) {
return mlir::failure();
}
@@ -483,7 +478,7 @@ struct FromElementsOpPattern
llvm::SmallVector<int64_t> currentOffsets(newRank, 0);
// for each elements insert_slice with right offet
for (auto elt : llvm::enumerate(adaptor.elements())) {
for (auto elt : llvm::enumerate(adaptor.getElements())) {
// Just create offsets as attributes
llvm::SmallVector<mlir::OpFoldResult, 4> offsets;
offsets.reserve(currentOffsets.size());
@@ -560,7 +555,7 @@ struct TensorShapeOpPattern : public mlir::OpConversionPattern<ShapeOp> {
auto reassocTy =
((mlir::Type)this->getTypeConverter()->convertType(
(inRank ? shapeOp.src() : shapeOp.result()).getType()))
(inRank ? shapeOp.getSrc() : shapeOp.getResult()).getType()))
.cast<VecTy>();
auto oldReassocs = shapeOp.getReassociationIndices();
@@ -574,7 +569,7 @@ struct TensorShapeOpPattern : public mlir::OpConversionPattern<ShapeOp> {
newReassocs.push_back(lweAssoc);
}
rewriter.replaceOpWithNewOp<ShapeOp>(shapeOp, newResultTy, adaptor.src(),
rewriter.replaceOpWithNewOp<ShapeOp>(shapeOp, newResultTy, adaptor.getSrc(),
newReassocs);
return ::mlir::success();
@@ -729,8 +724,9 @@ void TFHEToConcretePass::runOnOperation() {
target.addLegalOp<mlir::arith::ExtUIOp>();
target.addDynamicallyLegalOp<Tracing::TracePlaintextOp>(
[&](Tracing::TracePlaintextOp op) {
return (op.plaintext().getType().cast<mlir::IntegerType>().getWidth() ==
64);
return (
op.getPlaintext().getType().cast<mlir::IntegerType>().getWidth() ==
64);
});
// Conversion of RT Dialect Ops

View File

@@ -3,6 +3,7 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <mlir/IR/BuiltinTypeInterfaces.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>
@@ -24,7 +25,7 @@ char memref_trace_message[] = "memref_trace_message";
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank) {
std::vector<int64_t> shape(rank, -1);
std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
for (size_t i = 0; i < rank; i++) {
expr = expr +
@@ -136,15 +137,15 @@ private:
void traceCiphertextAddOperands(Tracing::TraceCiphertextOp op,
mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {
auto msg = op.msg().getValueOr("");
auto nmsb = op.nmsb().getValueOr(0);
auto msg = op.getMsg().value_or("");
auto nmsb = op.getNmsb().value_or(0);
std::string msgName;
std::stringstream stream;
stream << rand();
stream >> msgName;
auto messageVal =
mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg,
mlir::LLVM::linkage::Linkage::Linkonce);
auto messageVal = mlir::LLVM::createGlobalString(
op.getLoc(), rewriter, msgName, msg,
mlir::LLVM::linkage::Linkage::Linkonce, false);
operands.push_back(messageVal);
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(msg.size())));
@@ -155,15 +156,15 @@ void traceCiphertextAddOperands(Tracing::TraceCiphertextOp op,
void tracePlaintextAddOperands(Tracing::TracePlaintextOp op,
mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {
auto msg = op.msg().getValueOr("");
auto nmsb = op.nmsb().getValueOr(0);
auto msg = op.getMsg().value_or("");
auto nmsb = op.getNmsb().value_or(0);
std::string msgName;
std::stringstream stream;
stream << rand();
stream >> msgName;
auto messageVal =
mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg,
mlir::LLVM::linkage::Linkage::Linkonce);
auto messageVal = mlir::LLVM::createGlobalString(
op.getLoc(), rewriter, msgName, msg,
mlir::LLVM::linkage::Linkage::Linkonce, false);
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op->getAttr("input_width")));
operands.push_back(messageVal);
@@ -176,14 +177,14 @@ void tracePlaintextAddOperands(Tracing::TracePlaintextOp op,
void traceMessageAddOperands(Tracing::TraceMessageOp op,
mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {
auto msg = op.msg().getValueOr("");
auto msg = op.getMsg().value_or("");
std::string msgName;
std::stringstream stream;
stream << rand();
stream >> msgName;
auto messageVal =
mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg,
mlir::LLVM::linkage::Linkage::Linkonce);
auto messageVal = mlir::LLVM::createGlobalString(
op.getLoc(), rewriter, msgName, msg,
mlir::LLVM::linkage::Linkage::Linkonce, false);
operands.push_back(messageVal);
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(msg.size())));
@@ -202,7 +203,7 @@ struct TracingToCAPIPass : public TracingToCAPIBase<TracingToCAPIPass> {
// Mark ops from the target dialect as legal operations
target.addLegalDialect<func::FuncDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<arith::ArithmeticDialect>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
// Make sure that no ops from `Tracing` remain after the lowering

View File

@@ -5,7 +5,6 @@
#include "concretelang/Conversion/Utils/Dialects/SCF.h"
#include "mlir/Transforms/RegionUtils.h"
#include <mlir/IR/BlockAndValueMapping.h>
namespace mlir {
namespace concretelang {

View File

@@ -4,7 +4,7 @@
// for license information.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

View File

@@ -3,7 +3,7 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -46,14 +46,14 @@ struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel<
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
return BufferRelation::Unknown;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -63,7 +63,7 @@ struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel<
auto castOp = cast<TensorOp>(op);
auto resTensorType =
castOp.result().getType().template cast<mlir::TensorType>();
castOp.getResult().getType().template cast<mlir::TensorType>();
auto outMemrefType = MemRefType::get(resTensorType.getShape(),
resTensorType.getElementType());
@@ -81,7 +81,7 @@ struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel<
operands.push_back(operand.get());
} else {
operands.push_back(
bufferization::getBuffer(rewriter, operand.get(), options));
*bufferization::getBuffer(rewriter, operand.get(), options));
}
}

View File

@@ -10,7 +10,7 @@ add_mlir_dialect_library(
LINK_LIBS
PUBLIC
ConcretelangConversion
MLIRArithmeticDialect
MLIRArithDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRIR

View File

@@ -6,11 +6,12 @@
#include <chrono>
#include <cmath>
#include <initializer_list>
#include <optional>
#include <vector>
#include "boost/outcome.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Pass/PassManager.h"
@@ -65,7 +66,7 @@ struct FunctionToDag {
mlir::concretelang::log_verbose() << MSG << "\n"; \
}
outcome::checked<llvm::Optional<optimizer::Dag>,
outcome::checked<std::optional<optimizer::Dag>,
::concretelang::error::StringError>
build() {
auto dag = concrete_optimizer::dag::empty();
@@ -88,7 +89,7 @@ struct FunctionToDag {
// Dag is empty <=> classical function without encryption
DEBUG("!!! concrete-optimizer: nothing to do in " << func.getName()
<< "\n");
return llvm::None;
return std::nullopt;
};
DEBUG(std::string(dag->dump()));
return std::move(dag);
@@ -143,7 +144,7 @@ struct FunctionToDag {
if (auto dot = asDot(op)) {
auto weightsOpt = dotWeights(dot);
if (weightsOpt) {
addDot(dag, val, encrypted_inputs, weightsOpt.getValue());
addDot(dag, val, encrypted_inputs, weightsOpt.value());
return;
}
// If can't find weights return default leveled op
@@ -231,8 +232,8 @@ struct FunctionToDag {
mlir::Value result = mulOp.getResult();
const std::vector<uint64_t> resultShape = getShape(result);
Operation *xOp = mulOp.a().getDefiningOp();
Operation *yOp = mulOp.b().getDefiningOp();
Operation *xOp = mulOp.getA().getDefiningOp();
Operation *yOp = mulOp.getB().getDefiningOp();
const double fixedCost = NEGLIGIBLE_COMPLEXITY;
const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;
@@ -292,8 +293,8 @@ struct FunctionToDag {
mlir::Value result = maxOp.getResult();
const std::vector<uint64_t> resultShape = getShape(result);
Operation *xOp = maxOp.x().getDefiningOp();
Operation *yOp = maxOp.y().getDefiningOp();
Operation *xOp = maxOp.getX().getDefiningOp();
Operation *yOp = maxOp.getY().getDefiningOp();
const double fixedCost = NEGLIGIBLE_COMPLEXITY;
const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;
@@ -344,12 +345,13 @@ struct FunctionToDag {
std::vector<uint64_t> fakeShape = resultShape;
uint64_t numberOfComparisons = 1;
for (auto dimensionSize : maxpool2dOp.kernel_shape().getValues<int64_t>()) {
for (auto dimensionSize :
maxpool2dOp.getKernelShape().getValues<int64_t>()) {
numberOfComparisons *= dimensionSize;
}
fakeShape.push_back(numberOfComparisons);
Operation *inputOp = maxpool2dOp.input().getDefiningOp();
Operation *inputOp = maxpool2dOp.getInput().getDefiningOp();
const double fixedCost = NEGLIGIBLE_COMPLEXITY;
const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;
@@ -438,7 +440,7 @@ struct FunctionToDag {
return value.isa<mlir::BlockArgument>();
}
llvm::Optional<std::vector<std::int64_t>>
std::optional<std::vector<std::int64_t>>
resolveConstantVectorWeights(mlir::arith::ConstantOp &cstOp) {
std::vector<std::int64_t> values;
mlir::DenseIntElementsAttr denseVals =
@@ -446,14 +448,14 @@ struct FunctionToDag {
for (llvm::APInt val : denseVals.getValues<llvm::APInt>()) {
if (val.getActiveBits() > 64) {
return llvm::None;
return std::nullopt;
}
values.push_back(val.getSExtValue());
}
return values;
}
llvm::Optional<std::vector<std::int64_t>>
std::optional<std::vector<std::int64_t>>
resolveConstantWeights(mlir::Value &value) {
if (auto cstOp = llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
value.getDefiningOp())) {
@@ -463,18 +465,18 @@ struct FunctionToDag {
return resolveConstantVectorWeights(cstOp);
default:
DEBUG("High-Rank tensor: rely on MANP and levelledOp");
return llvm::None;
return std::nullopt;
}
} else {
DEBUG("Dynamic Weights: rely on MANP and levelledOp");
return llvm::None;
return std::nullopt;
}
}
llvm::Optional<std::vector<std::int64_t>>
std::optional<std::vector<std::int64_t>>
dotWeights(mlir::concretelang::FHELinalg::Dot &dot) {
if (dot.getOperands().size() != 2) {
return llvm::None;
return std::nullopt;
}
auto weights = dot.getOperands()[1];
return resolveConstantWeights(weights);

View File

@@ -16,8 +16,9 @@
#include <llvm/ADT/APInt.h>
#include <llvm/ADT/Optional.h>
#include <llvm/ADT/SmallString.h>
#include <mlir/Analysis/DataFlowAnalysis.h>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
#include <mlir/Analysis/DataFlow/SparseAnalysis.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
@@ -55,7 +56,7 @@ static bool isEncryptedFunctionParameter(mlir::Value value) {
/// values for its predecessors have been calculated beforehand or an
/// unknown value otherwise.
struct MANPLatticeValue {
MANPLatticeValue(llvm::Optional<llvm::APInt> manp = {}) : manp(manp) {}
MANPLatticeValue(std::optional<llvm::APInt> manp = {}) : manp(manp) {}
static MANPLatticeValue getPessimisticValueState(mlir::MLIRContext *context) {
return MANPLatticeValue();
@@ -81,21 +82,38 @@ struct MANPLatticeValue {
return this->manp == rhs.manp;
}
/// Required by `mlir::LatticeElement::join()`, but should never be
/// invoked, as `MANPAnalysis::visitOperation()` takes care of
/// combining the squared Minimal Arithmetic Noise Padding of
/// operands into the Minimal Arithmetic Noise Padding of the result.
static MANPLatticeValue join(const MANPLatticeValue &lhs,
const MANPLatticeValue &rhs) {
assert(false && "Minimal Arithmetic Noise Padding values can only be "
"combined sensibly when the combining operation is known");
if (!lhs.getMANP().has_value())
return rhs;
if (!rhs.getMANP().has_value())
return lhs;
if (lhs.getMANP().value() == rhs.getMANP().value())
return lhs;
assert(false && "Attempting to join two distinct initialized values");
return MANPLatticeValue{};
}
llvm::Optional<llvm::APInt> getMANP() { return manp; }
void print(raw_ostream &os) const {
if (manp.has_value())
os << manp.value();
else
os << "(undefined)";
}
std::optional<llvm::APInt> getMANP() const { return manp; }
protected:
llvm::Optional<llvm::APInt> manp;
std::optional<llvm::APInt> manp;
};
class MANPLattice : public mlir::dataflow::Lattice<MANPLatticeValue> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MANPLattice)
using Lattice::Lattice;
};
/// Checks if `lhs` is less than `rhs`, where both values are assumed
@@ -278,15 +296,14 @@ static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) {
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHELinalg.dot_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::Dot op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Dot op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
assert(operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[0]->getValue().getMANP().has_value() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
"operands");
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
mlir::arith::ConstantOp cstOp =
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
@@ -312,150 +329,140 @@ static llvm::APInt getSqMANP(
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHE.add_eint_int` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHE::AddEintIntOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() == 2 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  // The MANP of the encrypted operand (operand 0) is propagated unchanged.
  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
  return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.add_eint` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHE::AddEintOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(operandMANPs.size() == 2 &&
         operandMANPs[0]->getValue().getMANP().has_value() &&
         operandMANPs[1]->getValue().getMANP().has_value() &&
         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
         "operands");

  // The MANPs of both encrypted operands are added (with width extension).
  llvm::APInt a = operandMANPs[0]->getValue().getMANP().value();
  llvm::APInt b = operandMANPs[1]->getValue().getMANP().value();

  return APIntWidthExtendUAdd(a, b);
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.sub_int_eint` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHE::SubIntEintOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() == 2 &&
      operandMANPs[1]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  // The encrypted operand is operand 1; its MANP is propagated unchanged.
  llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().value();
  return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.sub_eint_int` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHE::SubEintIntOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() == 2 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  // The encrypted operand is operand 0; its MANP is propagated unchanged.
  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
  return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.sub_eint` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHE::SubEintOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(operandMANPs.size() == 2 &&
         operandMANPs[0]->getValue().getMANP().has_value() &&
         operandMANPs[1]->getValue().getMANP().has_value() &&
         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
         "operands");

  // The MANPs of both encrypted operands are added (with width extension).
  llvm::APInt a = operandMANPs[0]->getValue().getMANP().value();
  llvm::APInt b = operandMANPs[1]->getValue().getMANP().value();

  return APIntWidthExtendUAdd(a, b);
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.neg_eint` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHE::NegEintOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() == 1 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  // Negation propagates the operand's MANP unchanged.
  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
  return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.not` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHE::BoolNotOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() == 1 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  // Boolean negation propagates the operand's MANP unchanged.
  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
  return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHE::ToSignedOp`; the operand's MANP is propagated unchanged.
static llvm::APInt getSqMANP(mlir::concretelang::FHE::ToSignedOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() == 1 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
  return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHE::ToUnsignedOp`; the operand's MANP is propagated unchanged.
static llvm::APInt getSqMANP(mlir::concretelang::FHE::ToUnsignedOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() == 1 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
  return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.mul_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::MulEintIntOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
static llvm::APInt getSqMANP(mlir::concretelang::FHE::MulEintIntOp op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
mlir::Type iTy = op->getOpOperand(1).get().getType();
assert(iTy.isSignlessInteger() &&
@@ -463,14 +470,14 @@ static llvm::APInt getSqMANP(
assert(
operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[0]->getValue().getMANP().has_value() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
mlir::arith::ConstantOp cstOp =
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
op->getOpOperand(1).get().getDefiningOp());
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
llvm::APInt sqNorm;
if (cstOp) {
@@ -488,19 +495,18 @@ static llvm::APInt getSqMANP(
/// Calculates the squared Minimal Arithmetic Noise Padding of
/// `FHE.mul_eint` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::MulEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
static llvm::APInt getSqMANP(mlir::concretelang::FHE::MulEintOp op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
assert(operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[1]->getValue().getMANP().hasValue() &&
operandMANPs[0]->getValue().getMANP().has_value() &&
operandMANPs[1]->getValue().getMANP().has_value() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
"operands");
// x * y = ((x + y)^2 / 4) - ((x - y)^2 / 4) == tlu(x + y) - tlu(x - y)
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().getValue();
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().getValue();
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().value();
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().value();
const llvm::APInt beforeTLUs = APIntWidthExtendUAdd(x, y);
const llvm::APInt tlu = {1, 1, false};
@@ -512,13 +518,12 @@ static llvm::APInt getSqMANP(
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.round` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::RoundEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
static llvm::APInt getSqMANP(mlir::concretelang::FHE::RoundEintOp op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
assert(
operandMANPs.size() == 1 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[0]->getValue().getMANP().has_value() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
uint64_t inputWidth =
@@ -527,7 +532,7 @@ static llvm::APInt getSqMANP(
op.getResult().getType().cast<FHE::FheIntegerInterface>().getWidth();
uint64_t clearedBits = inputWidth - outputWidth;
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
eNorm += clearedBits;
return eNorm;
@@ -535,19 +540,18 @@ static llvm::APInt getSqMANP(
/// Calculates the squared Minimal Arithmetic Noise Padding of
/// `FHE.max_eint` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::MaxEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
static llvm::APInt getSqMANP(mlir::concretelang::FHE::MaxEintOp op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
assert(operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[1]->getValue().getMANP().hasValue() &&
operandMANPs[0]->getValue().getMANP().has_value() &&
operandMANPs[1]->getValue().getMANP().has_value() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
"operands");
// max(x, y) = max(x - y, 0) + y
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().getValue();
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().getValue();
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().value();
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().value();
const llvm::APInt sub = APIntWidthExtendUAdd(x, y);
const llvm::APInt tlu = {1, 1, false};
@@ -559,101 +563,94 @@ static llvm::APInt getSqMANP(
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHELinalg.add_eint_int` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::AddEintIntOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() == 2 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  // The MANP of the encrypted operand (operand 0) is propagated unchanged.
  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
  return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHELinalg::AddEintOp` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::AddEintOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(operandMANPs.size() == 2 &&
         operandMANPs[0]->getValue().getMANP().has_value() &&
         operandMANPs[1]->getValue().getMANP().has_value() &&
         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
         "operands");

  // The MANPs of both encrypted operands are added (with width extension).
  llvm::APInt a = operandMANPs[0]->getValue().getMANP().value();
  llvm::APInt b = operandMANPs[1]->getValue().getMANP().value();

  return APIntWidthExtendUAdd(a, b);
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHELinalg.sub_int_eint` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::SubIntEintOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() == 2 &&
      operandMANPs[1]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  // The encrypted operand is operand 1; its MANP is propagated unchanged.
  llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().value();
  return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHELinalg::SubEintIntOp` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::SubEintIntOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() == 2 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  // The encrypted operand is operand 0; its MANP is propagated unchanged.
  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
  return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHELinalg::SubEintOp` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::SubEintOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(operandMANPs.size() == 2 &&
         operandMANPs[0]->getValue().getMANP().has_value() &&
         operandMANPs[1]->getValue().getMANP().has_value() &&
         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
         "operands");

  // The MANPs of both encrypted operands are added (with width extension).
  llvm::APInt a = operandMANPs[0]->getValue().getMANP().value();
  llvm::APInt b = operandMANPs[1]->getValue().getMANP().value();

  return APIntWidthExtendUAdd(a, b);
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHELinalg.neg_eint` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::NegEintOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() == 1 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  // Negation propagates the operand's MANP unchanged.
  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
  return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.mul_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::MulEintIntOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MulEintIntOp op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
mlir::RankedTensorType op0Ty =
op->getOpOperand(1).get().getType().cast<mlir::RankedTensorType>();
@@ -665,10 +662,10 @@ static llvm::APInt getSqMANP(
assert(
operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[0]->getValue().getMANP().has_value() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
llvm::APInt sqNorm;
mlir::arith::ConstantOp cstOp =
@@ -761,14 +758,13 @@ static llvm::APInt calculateSqManpForMatMulWithDenseValues(
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.mul_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::MatMulEintIntOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulEintIntOp op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
auto lhsType =
((mlir::Type)op.lhs().getType()).cast<mlir::RankedTensorType>();
((mlir::Type)op.getLhs().getType()).cast<mlir::RankedTensorType>();
auto rhsType =
((mlir::Type)op.rhs().getType()).cast<mlir::RankedTensorType>();
((mlir::Type)op.getRhs().getType()).cast<mlir::RankedTensorType>();
llvm::ArrayRef<int64_t> lhsShape = lhsType.getShape();
llvm::ArrayRef<int64_t> rhsShape = rhsType.getShape();
@@ -782,10 +778,10 @@ static llvm::APInt getSqMANP(
assert(
operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[0]->getValue().getMANP().has_value() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().getValue();
llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().value();
llvm::APInt accNorm = llvm::APInt{1, 1, false};
mlir::arith::ConstantOp cstOp =
@@ -860,14 +856,13 @@ static llvm::APInt getSqMANP(
return accNorm;
}
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::MatMulIntEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulIntEintOp op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
auto lhsType =
((mlir::Type)op.lhs().getType()).cast<mlir::RankedTensorType>();
((mlir::Type)op.getLhs().getType()).cast<mlir::RankedTensorType>();
auto rhsType =
((mlir::Type)op.rhs().getType()).cast<mlir::RankedTensorType>();
((mlir::Type)op.getRhs().getType()).cast<mlir::RankedTensorType>();
llvm::ArrayRef<int64_t> lhsShape = lhsType.getShape();
llvm::ArrayRef<int64_t> rhsShape = rhsType.getShape();
@@ -881,10 +876,10 @@ static llvm::APInt getSqMANP(
assert(
operandMANPs.size() == 2 &&
operandMANPs[1]->getValue().getMANP().hasValue() &&
operandMANPs[1]->getValue().getMANP().has_value() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().getValue();
llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().value();
llvm::APInt accNorm = llvm::APInt{1, 1, false};
mlir::arith::ConstantOp cstOp =
@@ -960,123 +955,112 @@ static llvm::APInt getSqMANP(
return accNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHELinalg::TransposeOp`; the operand's MANP is propagated unchanged.
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::TransposeOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() == 1 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  return operandMANPs[0]->getValue().getMANP().value();
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a
/// `tensor.extract` operation; the MANP of the source tensor is propagated
/// unchanged.
static llvm::APInt getSqMANP(mlir::tensor::ExtractOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value();
  return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHELinalg::FromElementOp`; falls back to a unit MANP when the operand
/// carries no MANP value.
static llvm::APInt getSqMANP(FHELinalg::FromElementOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  auto manp = operandMANPs[0]->getValue().getMANP();

  if (manp.has_value()) {
    return manp.value();
  }

  return llvm::APInt{1, 1, false};
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a
/// `tensor.from_elements` operation as the maximum MANP over all operands
/// (compared with width extension).
static llvm::APInt getSqMANP(mlir::tensor::FromElementsOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  auto max = std::max_element(operandMANPs.begin(), operandMANPs.end(),
                              [](const MANPLattice *a, const MANPLattice *b) {
                                return APIntWidthExtendULT(
                                    a->getValue().getMANP().value(),
                                    b->getValue().getMANP().value());
                              });

  return (*max)->getValue().getMANP().value();
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a
/// `tensor.extract_slice` operation; the MANP of the source tensor is
/// propagated unchanged.
static llvm::APInt getSqMANP(mlir::tensor::ExtractSliceOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  return operandMANPs[0]->getValue().getMANP().value();
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a
/// `tensor.insert_slice` operation as the maximum of the MANPs of the
/// inserted slice and the destination tensor.
static llvm::APInt getSqMANP(mlir::tensor::InsertSliceOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() >= 2 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      operandMANPs[1]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  return APIntUMax(operandMANPs[0]->getValue().getMANP().value(),
                   operandMANPs[1]->getValue().getMANP().value());
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a
/// `tensor.insert` operation as the maximum of the MANPs of the inserted
/// element and the destination tensor.
static llvm::APInt getSqMANP(mlir::tensor::InsertOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() >= 2 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      operandMANPs[1]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  return APIntUMax(operandMANPs[0]->getValue().getMANP().value(),
                   operandMANPs[1]->getValue().getMANP().value());
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a
/// `tensor.collapse_shape` operation; the MANP of the source tensor is
/// propagated unchanged.
static llvm::APInt getSqMANP(mlir::tensor::CollapseShapeOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() >= 1 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  return operandMANPs[0]->getValue().getMANP().value();
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a
/// `tensor.expand_shape` operation; the MANP of the source tensor is
/// propagated unchanged.
static llvm::APInt getSqMANP(mlir::tensor::ExpandShapeOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(
      operandMANPs.size() >= 1 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  return operandMANPs[0]->getValue().getMANP().value();
}
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::SumOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::SumOp op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
auto inputType = op.getOperand().getType().dyn_cast<mlir::TensorType>();
@@ -1087,12 +1071,12 @@ static llvm::APInt getSqMANP(
uint64_t numberOfElementsAddedTogetherInEachOutputCell = 1;
mlir::ArrayAttr axes = op.axes();
mlir::ArrayAttr axes = op.getAxes();
if (axes.empty()) {
numberOfElementsAddedTogetherInEachOutputCell *= numberOfElementsInTheInput;
} else {
llvm::ArrayRef<int64_t> shape = inputType.getShape();
for (mlir::Attribute axisAttribute : op.axes()) {
for (mlir::Attribute axisAttribute : op.getAxes()) {
int64_t axis = axisAttribute.cast<IntegerAttr>().getInt();
numberOfElementsAddedTogetherInEachOutputCell *= shape[axis];
}
@@ -1108,22 +1092,21 @@ static llvm::APInt getSqMANP(
};
assert(operandMANPs.size() == 1 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[0]->getValue().getMANP().has_value() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
"operands");
llvm::APInt operandMANP = operandMANPs[0]->getValue().getMANP().getValue();
llvm::APInt operandMANP = operandMANPs[0]->getValue().getMANP().value();
return APIntWidthExtendUMul(noiseMultiplier, operandMANP);
}
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::ConcatOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::ConcatOp op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
llvm::APInt result = llvm::APInt{1, 0, false};
for (mlir::LatticeElement<MANPLatticeValue> *operandMANP : operandMANPs) {
llvm::APInt candidate = operandMANP->getValue().getMANP().getValue();
for (const MANPLattice *operandMANP : operandMANPs) {
llvm::APInt candidate = operandMANP->getValue().getMANP().value();
if (candidate.getLimitedValue() >= result.getLimitedValue()) {
result = candidate;
}
@@ -1131,22 +1114,21 @@ static llvm::APInt getSqMANP(
return result;
}
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::Conv2dOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Conv2dOp op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
mlir::RankedTensorType weightTy =
op.weight().getType().cast<mlir::RankedTensorType>();
op.getWeight().getType().cast<mlir::RankedTensorType>();
mlir::Type weightIntType = weightTy.getElementType();
// Bias is optional, so we can have both 2 or 3 operands
assert((operandMANPs.size() == 2 || operandMANPs.size() == 3) &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[0]->getValue().getMANP().has_value() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
"operand");
llvm::APInt inputNorm = operandMANPs[0]->getValue().getMANP().getValue();
llvm::APInt inputNorm = operandMANPs[0]->getValue().getMANP().value();
mlir::arith::ConstantOp weightCstOp =
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
@@ -1200,9 +1182,8 @@ static llvm::APInt getSqMANP(
return accNorm;
}
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::Maxpool2dOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Maxpool2dOp op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
// maximum between two value is calculated using
// - max(x - y, 0) + y
@@ -1214,7 +1195,7 @@ static llvm::APInt getSqMANP(
// so the resulting MANP is `{1, 1, false} + MANP input`
const llvm::APInt tlu = {1, 1, false};
const llvm::APInt input = operandMANPs[0]->getValue().getMANP().getValue();
const llvm::APInt input = operandMANPs[0]->getValue().getMANP().value();
const llvm::APInt forResult = APIntWidthExtendUAdd(tlu, input);
const llvm::APInt forIntermediate = APIntWidthExtendUAdd(forResult, input);
@@ -1222,18 +1203,28 @@ static llvm::APInt getSqMANP(
return APIntUMax(forIntermediate, forResult);
}
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
: mlir::ForwardDataFlowAnalysis<MANPLatticeValue>(ctx), debug(debug) {}
class MANPAnalysis
: public mlir::dataflow::SparseDataFlowAnalysis<MANPLattice> {
public:
explicit MANPAnalysis(mlir::DataFlowSolver &solver, bool debug)
: mlir::dataflow::SparseDataFlowAnalysis<MANPLattice>(solver),
debug(debug) {}
~MANPAnalysis() override = default;
void setToEntryState(MANPLattice *lattice) override {
if (isEncryptedFunctionParameter(lattice->getPoint())) {
// Set minimal MANP for encrypted function arguments
propagateIfChanged(lattice, lattice->join(MANPLatticeValue{
std::optional{llvm::APInt(1, 1)}}));
} else {
// Everything else is initialized with an unset value
propagateIfChanged(lattice, lattice->join(MANPLatticeValue{}));
}
}
void visitOperation(Operation *op, ArrayRef<const MANPLattice *> operands,
ArrayRef<MANPLattice *> results) override {
MANPLattice *latticeRes = results[0];
mlir::ChangeResult visitOperation(
mlir::Operation *op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operands) final {
mlir::LatticeElement<MANPLatticeValue> &latticeRes =
getLatticeElement(op->getResult(0));
bool isDummy = false;
llvm::APInt norm2SqEquiv;
@@ -1350,7 +1341,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
} else if (auto transposeOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::TransposeOp>(
op)) {
if (transposeOp.tensor()
if (transposeOp.getTensor()
.getType()
.cast<mlir::TensorType>()
.getElementType()
@@ -1364,7 +1355,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
// Tensor Operators
// ExtractOp
else if (auto extractOp = llvm::dyn_cast<mlir::tensor::ExtractOp>(op)) {
if (extractOp.result()
if (extractOp.getResult()
.getType()
.isa<mlir::concretelang::FHE::FheIntegerInterface>()) {
norm2SqEquiv = getSqMANP(extractOp, operands);
@@ -1375,7 +1366,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
// ExtractSliceOp
else if (auto extractSliceOp =
llvm::dyn_cast<mlir::tensor::ExtractSliceOp>(op)) {
if (extractSliceOp.result()
if (extractSliceOp.getResult()
.getType()
.cast<mlir::TensorType>()
.getElementType()
@@ -1387,7 +1378,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
}
// InsertOp
else if (auto insertOp = llvm::dyn_cast<mlir::tensor::InsertOp>(op)) {
if (insertOp.result()
if (insertOp.getResult()
.getType()
.cast<mlir::TensorType>()
.getElementType()
@@ -1400,7 +1391,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
// InsertSliceOp
else if (auto insertSliceOp =
llvm::dyn_cast<mlir::tensor::InsertSliceOp>(op)) {
if (insertSliceOp.result()
if (insertSliceOp.getResult()
.getType()
.cast<mlir::TensorType>()
.getElementType()
@@ -1412,7 +1403,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
}
// FromElementOp
else if (auto fromOp = llvm::dyn_cast<mlir::tensor::FromElementsOp>(op)) {
if (fromOp.result()
if (fromOp.getResult()
.getType()
.cast<mlir::TensorType>()
.getElementType()
@@ -1425,7 +1416,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
// TensorCollapseShapeOp
else if (auto reshapeOp =
llvm::dyn_cast<mlir::tensor::CollapseShapeOp>(op)) {
if (reshapeOp.result()
if (reshapeOp.getResult()
.getType()
.cast<mlir::TensorType>()
.getElementType()
@@ -1437,7 +1428,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
}
// TensorExpandShapeOp
else if (auto reshapeOp = llvm::dyn_cast<mlir::tensor::ExpandShapeOp>(op)) {
if (reshapeOp.result()
if (reshapeOp.getResult()
.getType()
.cast<mlir::TensorType>()
.getElementType()
@@ -1459,8 +1450,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
}
if (!isDummy) {
latticeRes.join(MANPLatticeValue{norm2SqEquiv});
latticeRes.markOptimisticFixpoint();
latticeRes->join(MANPLatticeValue{norm2SqEquiv});
op->setAttr("SMANP",
mlir::IntegerAttr::get(
@@ -1483,10 +1473,8 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
<< APIntToStringValUnsigned(norm2SqEquiv) << "\n";
}
} else {
latticeRes.join(MANPLatticeValue{});
latticeRes->join(MANPLatticeValue{});
}
return mlir::ChangeResult::Change;
}
private:
@@ -1500,8 +1488,12 @@ struct MANPPass : public MANPBase<MANPPass> {
void runOnOperation() override {
mlir::func::FuncOp func = getOperation();
MANPAnalysis analysis(func->getContext(), debug);
analysis.run(func);
mlir::DataFlowSolver solver;
solver.load<mlir::dataflow::DeadCodeAnalysis>();
solver.load<MANPAnalysis>(debug);
if (failed(solver.initializeAndRun(func)))
return signalPassFailure();
}
MANPPass() = delete;
MANPPass(bool debug) : debug(debug){};

View File

@@ -61,8 +61,8 @@ bool verifyEncryptedIntegerInputsConsistency(mlir::Operation &op,
}
mlir::LogicalResult AddEintIntOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto b = this->b().getType().cast<IntegerType>();
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
auto b = this->getB().getType().cast<IntegerType>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
@@ -79,8 +79,8 @@ mlir::LogicalResult AddEintIntOp::verify() {
}
mlir::LogicalResult AddEintOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto b = this->b().getType().dyn_cast<FheIntegerInterface>();
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
auto b = this->getB().getType().dyn_cast<FheIntegerInterface>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
@@ -96,8 +96,8 @@ mlir::LogicalResult AddEintOp::verify() {
}
mlir::LogicalResult SubIntEintOp::verify() {
auto a = this->a().getType().cast<IntegerType>();
auto b = this->b().getType().dyn_cast<FheIntegerInterface>();
auto a = this->getA().getType().cast<IntegerType>();
auto b = this->getB().getType().dyn_cast<FheIntegerInterface>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), b,
@@ -114,8 +114,8 @@ mlir::LogicalResult SubIntEintOp::verify() {
}
mlir::LogicalResult SubEintIntOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto b = this->b().getType().cast<IntegerType>();
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
auto b = this->getB().getType().cast<IntegerType>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
@@ -132,8 +132,8 @@ mlir::LogicalResult SubEintIntOp::verify() {
}
mlir::LogicalResult SubEintOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto b = this->b().getType().dyn_cast<FheIntegerInterface>();
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
auto b = this->getB().getType().dyn_cast<FheIntegerInterface>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
@@ -149,7 +149,7 @@ mlir::LogicalResult SubEintOp::verify() {
}
mlir::LogicalResult NegEintOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
out)) {
@@ -159,8 +159,8 @@ mlir::LogicalResult NegEintOp::verify() {
}
mlir::LogicalResult MulEintIntOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto b = this->b().getType().cast<IntegerType>();
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
auto b = this->getB().getType().cast<IntegerType>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
@@ -177,8 +177,8 @@ mlir::LogicalResult MulEintIntOp::verify() {
}
mlir::LogicalResult MulEintOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto b = this->b().getType().dyn_cast<FheIntegerInterface>();
auto a = this->getA().getType().dyn_cast<FheIntegerInterface>();
auto b = this->getB().getType().dyn_cast<FheIntegerInterface>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputsConsistency(*this->getOperation(), a, b)) {
@@ -193,8 +193,8 @@ mlir::LogicalResult MulEintOp::verify() {
}
mlir::LogicalResult MaxEintOp::verify() {
auto xTy = this->x().getType().dyn_cast<FheIntegerInterface>();
auto yTy = this->y().getType().dyn_cast<FheIntegerInterface>();
auto xTy = this->getX().getType().dyn_cast<FheIntegerInterface>();
auto yTy = this->getY().getType().dyn_cast<FheIntegerInterface>();
auto outTy = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(),
@@ -211,7 +211,7 @@ mlir::LogicalResult MaxEintOp::verify() {
}
mlir::LogicalResult ToSignedOp::verify() {
auto input = this->input().getType().cast<EncryptedIntegerType>();
auto input = this->getInput().getType().cast<EncryptedIntegerType>();
auto output = this->getResult().getType().cast<EncryptedSignedIntegerType>();
if (input.getWidth() != output.getWidth()) {
@@ -224,7 +224,7 @@ mlir::LogicalResult ToSignedOp::verify() {
}
mlir::LogicalResult ToUnsignedOp::verify() {
auto input = this->input().getType().cast<EncryptedSignedIntegerType>();
auto input = this->getInput().getType().cast<EncryptedSignedIntegerType>();
auto output = this->getResult().getType().cast<EncryptedIntegerType>();
if (input.getWidth() != output.getWidth()) {
@@ -237,7 +237,7 @@ mlir::LogicalResult ToUnsignedOp::verify() {
}
mlir::LogicalResult ToBoolOp::verify() {
auto input = this->input().getType().cast<EncryptedIntegerType>();
auto input = this->getInput().getType().cast<EncryptedIntegerType>();
if (input.getWidth() != 1 && input.getWidth() != 2) {
this->emitOpError("should have 1 or 2 as the width of encrypted input to "
@@ -249,7 +249,7 @@ mlir::LogicalResult ToBoolOp::verify() {
}
mlir::LogicalResult GenGateOp::verify() {
auto truth_table = this->truth_table().getType().cast<TensorType>();
auto truth_table = this->getTruthTable().getType().cast<TensorType>();
mlir::SmallVector<int64_t, 1> expectedShape{4};
if (!truth_table.hasStaticShape(expectedShape)) {
@@ -261,8 +261,8 @@ mlir::LogicalResult GenGateOp::verify() {
}
::mlir::LogicalResult ApplyLookupTableEintOp::verify() {
auto ct = this->a().getType().cast<FheIntegerInterface>();
auto lut = this->lut().getType().cast<TensorType>();
auto ct = this->getA().getType().cast<FheIntegerInterface>();
auto lut = this->getLut().getType().cast<TensorType>();
// Check the shape of lut argument
auto width = ct.getWidth();
@@ -281,7 +281,7 @@ mlir::LogicalResult GenGateOp::verify() {
}
mlir::LogicalResult RoundEintOp::verify() {
auto input = this->input().getType().cast<FheIntegerInterface>();
auto input = this->getInput().getType().cast<FheIntegerInterface>();
auto output = this->getResult().getType().cast<FheIntegerInterface>();
if (input.getWidth() <= output.getWidth()) {

View File

@@ -4,7 +4,7 @@
// for license information.
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/PatternMatch.h>
@@ -93,7 +93,8 @@ public:
mlir::LogicalResult
matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto tensorType = adaptor.a().getType().dyn_cast<mlir::RankedTensorType>();
auto tensorType =
adaptor.getA().getType().dyn_cast<mlir::RankedTensorType>();
auto shape = tensorType.getShape();
assert(shape.size() == 1 &&
"chunked integer should be converted to flat tensors, but tensor "
@@ -112,7 +113,8 @@ public:
.getResult();
mlir::Value resultTensor =
rewriter.create<FHE::ZeroTensorOp>(op.getLoc(), adaptor.a().getType())
rewriter
.create<FHE::ZeroTensorOp>(op.getLoc(), adaptor.getA().getType())
.getResult();
// used to shift the carry bit to the left
mlir::Value twoPowerChunkSizeCst =
@@ -127,10 +129,10 @@ public:
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
mlir::ValueRange args) {
// add inputs with the previous carry (init to 0)
mlir::Value leftEint =
builder.create<mlir::tensor::ExtractOp>(loc, adaptor.a(), iter);
mlir::Value rightEint =
builder.create<mlir::tensor::ExtractOp>(loc, adaptor.b(), iter);
mlir::Value leftEint = builder.create<mlir::tensor::ExtractOp>(
loc, adaptor.getA(), iter);
mlir::Value rightEint = builder.create<mlir::tensor::ExtractOp>(
loc, adaptor.getB(), iter);
mlir::Value result =
builder.create<FHE::AddEintOp>(loc, leftEint, rightEint)
.getResult();

View File

@@ -3,7 +3,7 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
@@ -33,11 +33,11 @@ public:
rewriter.getContext(), 2);
auto left = rewriter
.create<mlir::concretelang::FHE::FromBoolOp>(
op.getLoc(), eint2, op.left())
op.getLoc(), eint2, op.getLeft())
.getResult();
auto right = rewriter
.create<mlir::concretelang::FHE::FromBoolOp>(
op.getLoc(), eint2, op.right())
op.getLoc(), eint2, op.getRight())
.getResult();
auto cst_two =
rewriter.create<mlir::arith::ConstantIntOp>(op.getLoc(), 2, 3)
@@ -52,7 +52,7 @@ public:
.getResult();
auto lut_result =
rewriter.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
op.getLoc(), eint2, newIndex, op.truth_table());
op.getLoc(), eint2, newIndex, op.getTruthTable());
rewriter.replaceOpWithNewOp<mlir::concretelang::FHE::ToBoolOp>(
op,
mlir::concretelang::FHE::EncryptedBooleanType::get(
@@ -84,7 +84,7 @@ public:
auto truth_table =
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), truth_table_attr);
rewriter.replaceOpWithNewOp<mlir::concretelang::FHE::GenGateOp>(
op, op.getResult().getType(), op.left(), op.right(), truth_table);
op, op.getResult().getType(), op.getLeft(), op.getRight(), truth_table);
return mlir::success();
}
@@ -119,11 +119,11 @@ public:
auto c1AndNotCond =
rewriter
.create<mlir::concretelang::FHE::GenGateOp>(
op.getLoc(), boolType, op.c1(), op.cond(), truth_table)
op.getLoc(), boolType, op.getC1(), op.getCond(), truth_table)
.getResult();
auto c2AndCond = rewriter
.create<mlir::concretelang::FHE::BoolAndOp>(
op.getLoc(), boolType, op.c2(), op.cond())
op.getLoc(), boolType, op.getC2(), op.getCond())
.getResult();
auto c1AndNotCondBool = rewriter

View File

@@ -7,7 +7,7 @@
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
#include <concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h>
#include <concretelang/Support/Constants.h>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/IR/PatternMatch.h>
@@ -31,7 +31,7 @@ public:
matchAndRewrite(FHE::MulEintOp op, FHE::MulEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto inputType = adaptor.a().getType();
auto inputType = adaptor.getA().getType();
auto bitWidth = inputType.cast<FHE::FheIntegerInterface>().getWidth();
auto isSigned = inputType.cast<FHE::FheIntegerInterface>().isSigned();
mlir::Type signedType =
@@ -50,8 +50,8 @@ public:
// signedness for inputs and outputs.
// s = a + b
mlir::Value sum =
rewriter.create<FHE::AddEintOp>(op->getLoc(), adaptor.a(), adaptor.b());
mlir::Value sum = rewriter.create<FHE::AddEintOp>(
op->getLoc(), adaptor.getA(), adaptor.getB());
// se = (s)^2/4
// Depending on whether a,b,s are signed or not, we need a different lut to
@@ -71,8 +71,8 @@ public:
op->getLoc(), inputType, sum, sumLut);
// d = a - b
mlir::Value diff =
rewriter.create<FHE::SubEintOp>(op->getLoc(), adaptor.a(), adaptor.b());
mlir::Value diff = rewriter.create<FHE::SubEintOp>(
op->getLoc(), adaptor.getA(), adaptor.getB());
// de = (d)^2/4
// Here, the tlu must be performed with signed encoded lut, to properly
@@ -155,7 +155,7 @@ public:
mlir::ConversionTarget target(getContext());
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
target.addLegalDialect<mlir::arith::ArithDialect>();
target.addLegalDialect<FHE::FHEDialect>();
target.addIllegalOp<FHE::MulEintOp>();

View File

@@ -3,7 +3,7 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -32,8 +32,8 @@ struct MaxEintPattern : public mlir::OpRewritePattern<FHE::MaxEintOp> {
maxEintOp->getResult(0).getType().cast<FHE::FheIntegerInterface>();
const int64_t outputBitWidth = outputTy.getWidth();
mlir::Value x = maxEintOp.x();
mlir::Value y = maxEintOp.y();
mlir::Value x = maxEintOp.getX();
mlir::Value y = maxEintOp.getY();
const auto xTy = x.getType().cast<FHE::FheIntegerInterface>();
const auto yTy = y.getType().cast<FHE::FheIntegerInterface>();
@@ -70,7 +70,7 @@ struct MaxEintPattern : public mlir::OpRewritePattern<FHE::MaxEintOp> {
.getResult();
const mlir::Value add =
rewriter.create<FHE::AddEintOp>(loc, max, maxEintOp.y()).getResult();
rewriter.create<FHE::AddEintOp>(loc, max, maxEintOp.getY()).getResult();
rewriter.replaceOp(maxEintOp, {add});
return mlir::success();
@@ -85,7 +85,7 @@ struct FHEMaxTransform : public FHEMaxTransformBase<FHEMaxTransform> {
void FHEMaxTransform::runOnOperation() {
auto target = mlir::ConversionTarget(this->getContext());
target.addLegalDialect<arith::ArithmeticDialect>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<FHE::FHEDialect>();
target.addIllegalOp<FHE::MaxEintOp>();

View File

@@ -5,7 +5,7 @@
#include <unordered_set>
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/TypeUtilities.h"
@@ -13,6 +13,7 @@
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
namespace OpTrait {
@@ -272,10 +273,10 @@ namespace concretelang {
namespace FHELinalg {
mlir::LogicalResult ApplyLookupTableEintOp::verify() {
auto tTy = this->t().getType().cast<mlir::RankedTensorType>();
auto tTy = this->getT().getType().cast<mlir::RankedTensorType>();
auto tEltTy = tTy.getElementType()
.cast<mlir::concretelang::FHE::EncryptedIntegerType>();
auto lutTy = this->lut().getType().cast<mlir::RankedTensorType>();
auto lutTy = this->getLut().getType().cast<mlir::RankedTensorType>();
auto lutEltTy = lutTy.getElementType().cast<mlir::IntegerType>();
auto resultTy = this->getResult().getType().cast<mlir::RankedTensorType>();
@@ -297,10 +298,10 @@ mlir::LogicalResult ApplyLookupTableEintOp::verify() {
}
mlir::LogicalResult ApplyMultiLookupTableEintOp::verify() {
auto tTy = this->t().getType().cast<mlir::RankedTensorType>();
auto tTy = this->getT().getType().cast<mlir::RankedTensorType>();
auto tEltTy = tTy.getElementType()
.cast<mlir::concretelang::FHE::EncryptedIntegerType>();
auto lutTy = this->luts().getType().cast<mlir::RankedTensorType>();
auto lutTy = this->getLuts().getType().cast<mlir::RankedTensorType>();
auto lutEltTy = lutTy.getElementType().cast<mlir::IntegerType>();
auto resultTy = this->getResult().getType().cast<mlir::RankedTensorType>();
@@ -378,9 +379,9 @@ mlir::LogicalResult verifyLutsSize(ApplyMappedLookupTableEintOp &op,
}
mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() {
auto t = this->t();
auto luts = this->luts();
auto map = this->map();
auto t = this->getT();
auto luts = this->getLuts();
auto map = this->getMap();
auto result = this->getResult();
auto t_shape = getTensorType(t).getShape();
@@ -401,16 +402,16 @@ mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() {
}
::mlir::LogicalResult Dot::verify() {
if (::mlir::failed(mlir::verifyCompatibleShape(this->lhs().getType(),
this->rhs().getType()))) {
if (::mlir::failed(mlir::verifyCompatibleShape(this->getLhs().getType(),
this->getRhs().getType()))) {
return this->emitOpError("arguments have incompatible shapes");
}
auto lhsEltType = this->lhs()
auto lhsEltType = this->getLhs()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.dyn_cast<FHE::FheIntegerInterface>();
auto rhsEltType = this->rhs()
auto rhsEltType = this->getRhs()
.getType()
.cast<mlir::TensorType>()
.getElementType()
@@ -480,8 +481,8 @@ mlir::LogicalResult SumOp::verify() {
llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t inputDimensions = (int64_t)inputShape.size();
mlir::ArrayAttr axes = this->axes();
bool keepDims = this->keep_dims();
mlir::ArrayAttr axes = this->getAxes();
bool keepDims = this->getKeepDims();
auto axesToDestroy = std::unordered_set<int64_t>{};
for (mlir::Attribute axisAttribute : axes) {
@@ -543,8 +544,8 @@ mlir::LogicalResult ConcatOp::verify() {
return mlir::failure();
}
int64_t axis = this->axis();
mlir::Value out = this->out();
int64_t axis = this->getAxis();
mlir::Value out = this->getOut();
auto outVectorType = out.getType().dyn_cast<mlir::TensorType>();
auto outElementType =
@@ -561,7 +562,7 @@ mlir::LogicalResult ConcatOp::verify() {
int64_t expectedOutputElementsInAxis = 0;
size_t index = 0;
for (mlir::Value in : this->ins()) {
for (mlir::Value in : this->getIns()) {
auto inVectorType = in.getType().dyn_cast<mlir::TensorType>();
auto inElementType =
inVectorType.getElementType().dyn_cast<FHE::FheIntegerInterface>();
@@ -626,9 +627,9 @@ mlir::LogicalResult ConcatOp::verify() {
/// something else
template <typename MatMulOp> mlir::LogicalResult verifyMatmul(MatMulOp &op) {
auto lhsType =
((mlir::Type)op.lhs().getType()).cast<mlir::RankedTensorType>();
((mlir::Type)op.getLhs().getType()).cast<mlir::RankedTensorType>();
auto rhsType =
((mlir::Type)op.rhs().getType()).cast<mlir::RankedTensorType>();
((mlir::Type)op.getRhs().getType()).cast<mlir::RankedTensorType>();
llvm::ArrayRef<int64_t> lhsShape = lhsType.getShape();
llvm::ArrayRef<int64_t> rhsShape = rhsType.getShape();
@@ -786,9 +787,10 @@ mlir::LogicalResult MatMulIntEintOp::verify() {
mlir::SmallVector<int64_t, 4>
getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
mlir::SmallVector<int64_t, 4> paddingInts;
llvm::Optional<mlir::DenseIntElementsAttr> optionalPadding = convOp.padding();
if (optionalPadding.hasValue()) {
auto paddingAttr = optionalPadding.getValue();
std::optional<mlir::DenseIntElementsAttr> optionalPadding =
convOp.getPadding();
if (optionalPadding.has_value()) {
auto paddingAttr = optionalPadding.value();
auto paddingAttrShape =
paddingAttr.getType().cast<RankedTensorType>().getShape();
assert(paddingAttrShape.size() == 1 && paddingAttrShape[0] == 4 &&
@@ -804,9 +806,10 @@ getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
mlir::SmallVector<int64_t, 2>
getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
mlir::SmallVector<int64_t, 2> stridesInts;
llvm::Optional<mlir::DenseIntElementsAttr> optionalStrides = convOp.strides();
if (optionalStrides.hasValue()) {
auto stridesAttr = optionalStrides.getValue();
std::optional<mlir::DenseIntElementsAttr> optionalStrides =
convOp.getStrides();
if (optionalStrides.has_value()) {
auto stridesAttr = optionalStrides.value();
auto stridesAttrShape =
stridesAttr.getType().cast<RankedTensorType>().getShape();
assert(stridesAttrShape.size() == 1 && stridesAttrShape[0] == 2 &&
@@ -822,10 +825,10 @@ getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
mlir::SmallVector<int64_t, 2>
getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
mlir::SmallVector<int64_t, 2> dilationsInts;
llvm::Optional<mlir::DenseIntElementsAttr> optionalDilations =
convOp.dilations();
if (optionalDilations.hasValue()) {
auto dilationsAttr = optionalDilations.getValue();
std::optional<mlir::DenseIntElementsAttr> optionalDilations =
convOp.getDilations();
if (optionalDilations.has_value()) {
auto dilationsAttr = optionalDilations.value();
auto dilationsAttrShape =
dilationsAttr.getType().cast<RankedTensorType>().getShape();
assert(dilationsAttrShape.size() == 1 && dilationsAttrShape[0] == 2 &&
@@ -840,18 +843,18 @@ getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
}
int64_t getGroupFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
llvm::Optional<uint64_t> optionalGroup = convOp.group();
if (optionalGroup.hasValue())
return optionalGroup.getValue();
std::optional<uint64_t> optionalGroup = convOp.getGroup();
if (optionalGroup.has_value())
return optionalGroup.value();
return 1;
}
/// Verify the Conv2d shapes, attributes, and expected output dimensions
mlir::LogicalResult Conv2dOp::verify() {
auto inputTy =
((mlir::Type)this->input().getType()).cast<mlir::RankedTensorType>();
((mlir::Type)this->getInput().getType()).cast<mlir::RankedTensorType>();
auto weightTy =
((mlir::Type)this->weight().getType()).cast<mlir::RankedTensorType>();
((mlir::Type)this->getWeight().getType()).cast<mlir::RankedTensorType>();
auto resultTy =
((mlir::Type)this->getResult().getType()).cast<mlir::RankedTensorType>();
auto inputShape = inputTy.getShape();
@@ -890,9 +893,10 @@ mlir::LogicalResult Conv2dOp::verify() {
// Checking attributes
mlir::SmallVector<int64_t, 4> paddingInts = getPaddingFromConv2d(*this);
llvm::Optional<mlir::DenseIntElementsAttr> optionalPadding = this->padding();
if (optionalPadding.hasValue()) {
auto paddingAttr = optionalPadding.getValue();
std::optional<mlir::DenseIntElementsAttr> optionalPadding =
this->getPadding();
if (optionalPadding.has_value()) {
auto paddingAttr = optionalPadding.value();
auto paddingAttrShape =
paddingAttr.getType().cast<RankedTensorType>().getShape();
if (paddingAttrShape.size() != 1 || paddingAttrShape[0] != 4) {
@@ -918,9 +922,10 @@ mlir::LogicalResult Conv2dOp::verify() {
}
}
mlir::SmallVector<int64_t, 2> stridesInts = getStridesFromConv2d(*this);
llvm::Optional<mlir::DenseIntElementsAttr> optionalStrides = this->strides();
if (optionalStrides.hasValue()) {
auto stridesAttr = optionalStrides.getValue();
std::optional<mlir::DenseIntElementsAttr> optionalStrides =
this->getStrides();
if (optionalStrides.has_value()) {
auto stridesAttr = optionalStrides.value();
auto stridesAttrShape =
stridesAttr.getType().cast<RankedTensorType>().getShape();
if (stridesAttrShape.size() != 1 || stridesAttrShape[0] != 2) {
@@ -939,10 +944,10 @@ mlir::LogicalResult Conv2dOp::verify() {
}
}
mlir::SmallVector<int64_t, 2> dilationsInts = getDilationsFromConv2d(*this);
llvm::Optional<mlir::DenseIntElementsAttr> optionalDilations =
this->dilations();
if (optionalDilations.hasValue()) {
auto dilationsAttr = optionalDilations.getValue();
std::optional<mlir::DenseIntElementsAttr> optionalDilations =
this->getDilations();
if (optionalDilations.has_value()) {
auto dilationsAttr = optionalDilations.value();
auto dilationsAttrShape =
dilationsAttr.getType().cast<RankedTensorType>().getShape();
if (dilationsAttrShape.size() != 1 || dilationsAttrShape[0] != 2) {
@@ -975,7 +980,7 @@ mlir::LogicalResult Conv2dOp::verify() {
resultH = resultShape[2], resultW = resultShape[3];
// Bias check if specified
mlir::Value bias = this->bias();
mlir::Value bias = this->getBias();
if (bias) {
auto biasTy = ((mlir::Type)bias.getType()).cast<mlir::RankedTensorType>();
auto biasShape = biasTy.getShape();
@@ -1050,7 +1055,7 @@ mlir::LogicalResult Conv2dOp::verify() {
mlir::LogicalResult Maxpool2dOp::verify() {
const mlir::RankedTensorType inputTy =
this->input().getType().cast<mlir::RankedTensorType>();
this->getInput().getType().cast<mlir::RankedTensorType>();
const mlir::RankedTensorType outputTy =
this->getResult().getType().cast<mlir::RankedTensorType>();
@@ -1085,7 +1090,7 @@ mlir::LogicalResult Maxpool2dOp::verify() {
const int64_t inputH = inputShape[2];
const int64_t inputW = inputShape[3];
const mlir::DenseIntElementsAttr kernelShapeAttr = this->kernel_shape();
const mlir::DenseIntElementsAttr kernelShapeAttr = this->getKernelShape();
const mlir::RankedTensorType kernelShapeAttrTy =
kernelShapeAttr.getType().cast<mlir::RankedTensorType>();
const llvm::ArrayRef<int64_t> kernelShapeAttrShape =
@@ -1108,9 +1113,9 @@ mlir::LogicalResult Maxpool2dOp::verify() {
mlir::SmallVector<int64_t, 2> strides;
const llvm::Optional<mlir::DenseIntElementsAttr> maybeStridesAttr =
this->strides();
if (maybeStridesAttr.hasValue()) {
const mlir::DenseIntElementsAttr stridesAttr = maybeStridesAttr.getValue();
this->getStrides();
if (maybeStridesAttr.has_value()) {
const mlir::DenseIntElementsAttr stridesAttr = maybeStridesAttr.value();
const mlir::RankedTensorType stridesAttrTy =
stridesAttr.getType().cast<mlir::RankedTensorType>();
const llvm::ArrayRef<int64_t> stridesAttrShape = stridesAttrTy.getShape();
@@ -1141,10 +1146,9 @@ mlir::LogicalResult Maxpool2dOp::verify() {
mlir::SmallVector<int64_t, 2> dilations;
const llvm::Optional<mlir::DenseIntElementsAttr> maybeDilationsAttr =
this->dilations();
if (maybeDilationsAttr.hasValue()) {
const mlir::DenseIntElementsAttr dilationsAttr =
maybeDilationsAttr.getValue();
this->getDilations();
if (maybeDilationsAttr.has_value()) {
const mlir::DenseIntElementsAttr dilationsAttr = maybeDilationsAttr.value();
const mlir::RankedTensorType dilationsAttrTy =
dilationsAttr.getType().cast<mlir::RankedTensorType>();
const llvm::ArrayRef<int64_t> dilationsAttrShape =
@@ -1185,7 +1189,7 @@ mlir::LogicalResult Maxpool2dOp::verify() {
expectedOutputW,
};
if (outputShape != llvm::makeArrayRef(expectedOutputShape)) {
if (outputShape != llvm::ArrayRef(expectedOutputShape)) {
this->emitOpError() << "expected output to be of shape "
<< "(" << expectedOutputShape << ") "
<< "but it is of shape "
@@ -1203,7 +1207,9 @@ mlir::LogicalResult FromElementOp::verify() {
auto inType = in.getType();
auto outType = out.getType().dyn_cast<mlir::TensorType>();
auto expectedOutType = outType.cloneWith({1}, inType);
llvm::SmallVector<int64_t> shape{1};
auto expectedOutType = outType.cloneWith(std::optional{shape}, inType);
if (outType != expectedOutType) {
this->emitOpError() << "has invalid output type (expected "
<< expectedOutType << ", got " << outType << ")";
@@ -1215,7 +1221,7 @@ mlir::LogicalResult FromElementOp::verify() {
/// Verify the transpose shapes
mlir::LogicalResult TransposeOp::verify() {
mlir::Type tensorTy = ((mlir::Type)this->tensor().getType());
mlir::Type tensorTy = ((mlir::Type)this->getTensor().getType());
if (!tensorTy.isa<RankedTensorType>()) {
this->emitOpError() << "should have operand as tensor";
return mlir::failure();
@@ -1243,7 +1249,7 @@ mlir::LogicalResult TransposeOp::verify() {
int64_t inputDimensions = (int64_t)inShape.size();
mlir::ArrayAttr axes = this->axes();
mlir::ArrayAttr axes = this->getAxes();
if (axes.empty()) {
for (int64_t i = 0; i < inputDimensions; i++) {
if (inShape[i] != outShape[inputDimensions - (i + 1)]) {
@@ -1294,7 +1300,7 @@ mlir::LogicalResult TransposeOp::verify() {
}
mlir::LogicalResult ToSignedOp::verify() {
auto inputType = this->input().getType().cast<mlir::ShapedType>();
auto inputType = this->getInput().getType().cast<mlir::ShapedType>();
auto outputType = this->getResult().getType().cast<mlir::ShapedType>();
llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -1322,7 +1328,7 @@ mlir::LogicalResult ToSignedOp::verify() {
mlir::LogicalResult ToUnsignedOp::verify() {
mlir::ShapedType inputType =
this->input().getType().dyn_cast_or_null<mlir::ShapedType>();
this->getInput().getType().dyn_cast_or_null<mlir::ShapedType>();
mlir::ShapedType outputType =
this->getResult().getType().dyn_cast_or_null<mlir::ShapedType>();

View File

@@ -3,7 +3,7 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/SCF/IR/SCF.h>

View File

@@ -22,8 +22,8 @@
#include <mlir/Dialect/Bufferization/Transforms/Passes.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/BlockAndValueMapping.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/IRMapping.h>
#include <mlir/Transforms/RegionUtils.h>
#define GEN_PASS_CLASSES

View File

@@ -16,13 +16,13 @@
#include <concretelang/Support/Constants.h>
#include <concretelang/Support/math.h>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BlockAndValueMapping.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/IRMapping.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Interfaces/ViewLikeInterface.h>
@@ -94,7 +94,7 @@ aggregateBeneficiaryOps(Operation *op, SetVector<Operation *> &beneficiaryOps,
}
LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) {
Region &taskOpBody = taskOp.body();
Region &taskOpBody = taskOp.getBody();
// Identify uses from values defined outside of the scope.
SetVector<Value> sinkCandidates;
@@ -111,7 +111,7 @@ LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) {
}
// Insert operations so that the defs get cloned before uses.
BlockAndValueMapping map;
IRMapping map;
OpBuilder builder(taskOpBody);
for (Operation *op : toBeSunk) {
OpBuilder::InsertionGuard guard(builder);
@@ -156,7 +156,7 @@ struct BuildDataflowTaskGraphPass
protected:
void processOperation(mlir::Operation *op) {
if (isCandidateForTask(op)) {
BlockAndValueMapping map;
IRMapping map;
Region &opBody = getOperation().getBody();
OpBuilder builder(opBody);
@@ -166,7 +166,7 @@ protected:
op->getLoc(), op->getResultTypes(), op->getOperands());
// Add the operation to the task
OpBuilder tbbuilder(dftop.body());
OpBuilder tbbuilder(dftop.getBody());
Operation *clonedOp = tbbuilder.clone(*op, map);
// Coarsen granularity by aggregating all dependence related
@@ -180,7 +180,7 @@ protected:
// Replace the uses of defined values
for (auto pair : llvm::zip(op->getResults(), clonedOp->getResults()))
replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair),
dftop.body());
dftop.getBody());
// Replace uses of the values defined by the task
for (auto pair : llvm::zip(op->getResults(), dftop->getResults()))
replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair),

View File

@@ -21,12 +21,12 @@
#include <concretelang/Support/math.h>
#include <llvm/IR/Instructions.h>
#include <mlir/Analysis/DataFlowAnalysis.h>
#include <mlir/Analysis/DataFlowFramework.h>
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
#include <mlir/Conversion/LLVMCommon/Pattern.h>
#include <mlir/Conversion/LLVMCommon/VectorPattern.h>
#include <mlir/Dialect/Affine/Utils.h>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Bufferization/Transforms/Passes.h>
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
@@ -35,10 +35,10 @@
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BlockAndValueMapping.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/IRMapping.h>
#include <mlir/IR/SymbolTable.h>
#include <mlir/Interfaces/ViewLikeInterface.h>
#include <mlir/Pass/PassManager.h>
@@ -60,7 +60,7 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
StringRef workFunctionName) {
Location loc = DFTOp.getLoc();
OpBuilder builder(DFTOp.getContext());
Region &DFTOpBody = DFTOp.body();
Region &DFTOpBody = DFTOp.getBody();
OpBuilder::InsertionGuard guard(builder);
// Instead of outlining with the same operands/results, we pass all
@@ -82,7 +82,7 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
outlinedEntryBlock->addArguments(type.getInputs(), locations);
outlinedFuncBody.push_back(outlinedEntryBlock);
BlockAndValueMapping map;
IRMapping map;
int input_offset = DFTOp.getNumResults();
Block &entryBlock = outlinedFuncBody.front();
builder.setInsertionPointToStart(&entryBlock);
@@ -241,7 +241,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
// unsupported even in the LLVMIR Dialect - this needs to use two
// placeholders for each output, before and after the
// CreateAsyncTaskOp.
BlockAndValueMapping map;
IRMapping map;
for (auto result : DFTOp.getResults()) {
Type futType = RT::PointerType::get(RT::FutureType::get(result.getType()));
auto brpp = builder.create<RT::BuildReturnPtrPlaceholderOp>(DFTOp.getLoc(),

View File

@@ -21,20 +21,20 @@
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
#include <llvm/IR/Instructions.h>
#include <llvm/Support/Compiler.h>
#include <mlir/Analysis/DataFlowAnalysis.h>
#include <mlir/Analysis/DataFlowFramework.h>
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
#include <mlir/Conversion/LLVMCommon/Pattern.h>
#include <mlir/Conversion/LLVMCommon/VectorPattern.h>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
#include <mlir/Dialect/LLVMIR/FunctionCallUtils.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BlockAndValueMapping.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/IRMapping.h>
#include <mlir/IR/SymbolTable.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/LLVM.h>
@@ -113,16 +113,18 @@ struct MakeReadyFutureOpInterfaceLowering
// explicitly space that we can reference as a base for the
// future.
auto allocFuncOp = mlir::LLVM::lookupOrCreateMallocFn(
mrfOp->getParentOfType<ModuleOp>(), getIndexType());
mrfOp->getParentOfType<ModuleOp>(), getIndexType(),
getTypeConverter()->useOpaquePointers());
auto sizeBytes = getSizeInBytes(
mrfOp.getLoc(), adaptor.getOperands().getTypes().front(), rewriter);
auto results = mlir::LLVM::createLLVMCall(
rewriter, mrfOp.getLoc(), allocFuncOp, {sizeBytes}, getVoidPtrType());
auto results =
rewriter.create<LLVM::CallOp>(mrfOp.getLoc(), allocFuncOp, sizeBytes);
Value allocatedPtr = rewriter.create<mlir::LLVM::BitcastOp>(
mrfOp.getLoc(),
mlir::LLVM::LLVMPointerType::get(
adaptor.getOperands().getTypes().front()),
results[0]);
results.getResult());
rewriter.create<LLVM::StoreOp>(mrfOp.getLoc(),
adaptor.getOperands().front(), allocatedPtr);
SmallVector<Value, 4> mrfOperands = {adaptor.getOperands()};
@@ -150,7 +152,7 @@ struct AwaitFutureOpInterfaceLowering
afOp.getLoc(),
mlir::LLVM::LLVMPointerType::get(
(*getTypeConverter()).convertType(afOp.getResult().getType())),
afCallOp.getResult(0));
afCallOp.getResult());
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(afOp, futVal);
return success();
}

View File

@@ -42,7 +42,8 @@ void RTDialect::initialize() {
::mlir::Type RTDialect::parseType(::mlir::DialectAsmParser &parser) const {
mlir::Type type;
if (parser.parseOptionalKeyword("future").succeeded()) {
generatedTypeParser(parser, "future", type);
llvm::StringRef mnenomic;
generatedTypeParser(parser, &mnenomic, type);
return type;
}
return type;

View File

@@ -35,65 +35,65 @@ void DataflowTaskOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {}
llvm::Optional<mlir::Operation *>
std::optional<mlir::Operation *>
DataflowTaskOp::buildDealloc(OpBuilder &builder, Value alloc) {
return builder.create<DeallocateFutureOp>(alloc.getLoc(), alloc)
.getOperation();
}
llvm::Optional<mlir::Value> DataflowTaskOp::buildClone(OpBuilder &builder,
Value alloc) {
std::optional<mlir::Value> DataflowTaskOp::buildClone(OpBuilder &builder,
Value alloc) {
return builder.create<CloneFutureOp>(alloc.getLoc(), alloc).getResult();
}
void DataflowTaskOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
for (auto input : inputs())
for (auto input : getInputs())
effects.emplace_back(MemoryEffects::Read::get(), input,
SideEffects::DefaultResource::get());
for (auto output : outputs())
for (auto output : getOutputs())
effects.emplace_back(MemoryEffects::Write::get(), output,
SideEffects::DefaultResource::get());
for (auto output : outputs())
for (auto output : getOutputs())
effects.emplace_back(MemoryEffects::Allocate::get(), output,
SideEffects::DefaultResource::get());
}
llvm::Optional<mlir::Operation *>
CloneFutureOp::buildDealloc(OpBuilder &builder, Value alloc) {
std::optional<mlir::Operation *> CloneFutureOp::buildDealloc(OpBuilder &builder,
Value alloc) {
return builder.create<DeallocateFutureOp>(alloc.getLoc(), alloc)
.getOperation();
}
llvm::Optional<mlir::Value> CloneFutureOp::buildClone(OpBuilder &builder,
Value alloc) {
std::optional<mlir::Value> CloneFutureOp::buildClone(OpBuilder &builder,
Value alloc) {
return builder.create<CloneFutureOp>(alloc.getLoc(), alloc).getResult();
}
void CloneFutureOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Read::get(), input(),
effects.emplace_back(MemoryEffects::Read::get(), getInput(),
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Write::get(), output(),
effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Allocate::get(), output(),
effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
SideEffects::DefaultResource::get());
}
llvm::Optional<mlir::Operation *>
std::optional<mlir::Operation *>
MakeReadyFutureOp::buildDealloc(OpBuilder &builder, Value alloc) {
return builder.create<DeallocateFutureOp>(alloc.getLoc(), alloc)
.getOperation();
}
llvm::Optional<mlir::Value> MakeReadyFutureOp::buildClone(OpBuilder &builder,
Value alloc) {
std::optional<mlir::Value> MakeReadyFutureOp::buildClone(OpBuilder &builder,
Value alloc) {
return builder.create<CloneFutureOp>(alloc.getLoc(), alloc).getResult();
}
void MakeReadyFutureOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Read::get(), input(),
effects.emplace_back(MemoryEffects::Read::get(), getInput(),
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Write::get(), output(),
effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Allocate::get(), output(),
effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
SideEffects::DefaultResource::get());
}

View File

@@ -36,14 +36,14 @@ struct DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
return BufferRelation::Unknown;
}
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
@@ -70,7 +70,7 @@ struct DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface
if (failed(bufferOrErr))
return failure();
Value buffer = bufferOrErr.getValue();
Value buffer = bufferOrErr.value();
newOperands.push_back(buffer);
} else {
newOperands.push_back(opOperand.get());
@@ -81,7 +81,7 @@ struct DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface
for (OpResult res : op->getResults()) {
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
BaseMemRefType memrefType = getMemRefType(t, options);
BaseMemRefType memrefType = getMemRefType(res, options);
newResultTypes.push_back(memrefType);
} else {
newResultTypes.push_back(res.getType());
@@ -112,14 +112,14 @@ struct MakeReadyFutureOpBufferizationInterface
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
return BufferRelation::Unknown;
}
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
@@ -145,7 +145,7 @@ struct MakeReadyFutureOpBufferizationInterface
if (failed(bufferOrErr))
return failure();
Value buffer = bufferOrErr.getValue();
Value buffer = bufferOrErr.value();
newOperands.push_back(buffer);
} else {
newOperands.push_back(opOperand.get());
@@ -156,7 +156,7 @@ struct MakeReadyFutureOpBufferizationInterface
for (OpResult res : op->getResults()) {
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
BaseMemRefType memrefType = getMemRefType(t, options);
BaseMemRefType memrefType = getMemRefType(res, options);
newResultTypes.push_back(memrefType);
} else {
newResultTypes.push_back(res.getType());
@@ -186,14 +186,14 @@ struct WorkFunctionReturnOpBufferizationInterface
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
return BufferRelation::Unknown;
}
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
@@ -219,7 +219,7 @@ struct WorkFunctionReturnOpBufferizationInterface
if (failed(bufferOrErr))
return failure();
Value buffer = bufferOrErr.getValue();
Value buffer = bufferOrErr.value();
newOperands.push_back(buffer);
} else {
newOperands.push_back(opOperand.get());
@@ -230,7 +230,7 @@ struct WorkFunctionReturnOpBufferizationInterface
for (OpResult res : op->getResults()) {
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
BaseMemRefType memrefType = getMemRefType(t, options);
BaseMemRefType memrefType = getMemRefType(res, options);
newResultTypes.push_back(memrefType);
} else {
newResultTypes.push_back(res.getType());
@@ -260,14 +260,14 @@ struct AwaitFutureOpBufferizationInterface
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
return BufferRelation::Unknown;
}
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
@@ -293,7 +293,7 @@ struct AwaitFutureOpBufferizationInterface
if (failed(bufferOrErr))
return failure();
Value buffer = bufferOrErr.getValue();
Value buffer = bufferOrErr.value();
newOperands.push_back(buffer);
} else {
newOperands.push_back(opOperand.get());
@@ -304,7 +304,7 @@ struct AwaitFutureOpBufferizationInterface
for (OpResult res : op->getResults()) {
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
BaseMemRefType memrefType = getMemRefType(t, options);
BaseMemRefType memrefType = getMemRefType(res, options);
newResultTypes.push_back(memrefType);
} else {
newResultTypes.push_back(res.getType());

View File

@@ -7,7 +7,7 @@ add_mlir_dialect_library(
mlir-headers
LINK_LIBS
PUBLIC
MLIRArithmeticDialect
MLIRArithDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRIR

View File

@@ -19,8 +19,8 @@ namespace concretelang {
namespace SDFG {
mlir::LogicalResult Put::verify() {
mlir::Type streamElementType =
stream().getType().cast<StreamType>().getElementType();
mlir::Type elementType = data().getType();
getStream().getType().cast<StreamType>().getElementType();
mlir::Type elementType = getData().getType();
if (streamElementType != elementType) {
emitError()
@@ -34,10 +34,10 @@ mlir::LogicalResult Put::verify() {
}
mlir::LogicalResult MakeProcess::checkStreams(size_t numIn, size_t numOut) {
mlir::OperandRange streams = this->streams();
mlir::OperandRange streams = this->getStreams();
if (streams.size() != numIn + numOut) {
emitError() << "Process `" << stringifyProcessKind(type())
emitError() << "Process `" << stringifyProcessKind(getType())
<< "` expects 3 streams, but " << streams.size()
<< " were given.";
return mlir::failure();
@@ -48,7 +48,7 @@ mlir::LogicalResult MakeProcess::checkStreams(size_t numIn, size_t numOut) {
if (in && !in.createsInputStream()) {
emitError() << "Stream #" << (i + 1) << " of process `"
<< stringifyProcessKind(type())
<< stringifyProcessKind(getType())
<< "` must be an input stream.";
return mlir::failure();
}
@@ -59,7 +59,7 @@ mlir::LogicalResult MakeProcess::checkStreams(size_t numIn, size_t numOut) {
if (out && !out.createsOutputStream()) {
emitError() << "Stream #" << (i + 1) << " of process `"
<< stringifyProcessKind(type())
<< stringifyProcessKind(getType())
<< "` must be an output stream.";
return mlir::failure();
}
@@ -69,7 +69,7 @@ mlir::LogicalResult MakeProcess::checkStreams(size_t numIn, size_t numOut) {
}
mlir::LogicalResult MakeProcess::verify() {
switch (type()) {
switch (getType()) {
case ProcessKind::add_eint:
return checkStreams(2, 1);
case ProcessKind::add_eint_int:

View File

@@ -3,7 +3,7 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -43,7 +43,7 @@ namespace {} // namespace
namespace {
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank) {
std::vector<int64_t> shape(rank, -1);
std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
for (size_t i = 0; i < rank; i++) {
expr = expr +
@@ -84,14 +84,14 @@ struct BufferizableWithCallOpInterface
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
return BufferRelation::Unknown;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,

View File

@@ -14,7 +14,7 @@ add_mlir_dialect_library(
SDFGDialect
ConcretelangSDFGInterfaces
ConcretelangConversion
MLIRArithmeticDialect
MLIRArithDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRIR

View File

@@ -58,8 +58,8 @@ mlir::LogicalResult _verifyGLWEIntegerOperator(mlir::OpState &op,
/// (!TFHE.glwe<{dim,poly,bits}{p}>))
template <class Operator>
mlir::LogicalResult verifyGLWEIntegerOperator(Operator &op) {
auto a = ((mlir::Type)(op.a().getType())).cast<GLWECipherTextType>();
auto b = ((mlir::Type)(op.b().getType())).cast<IntegerType>();
auto a = ((mlir::Type)(op.getA().getType())).cast<GLWECipherTextType>();
auto b = ((mlir::Type)(op.getB().getType())).cast<IntegerType>();
auto result =
((mlir::Type)(op.getResult().getType())).cast<GLWECipherTextType>();
@@ -71,8 +71,8 @@ mlir::LogicalResult verifyGLWEIntegerOperator(Operator &op) {
/// (!TFHE.glwe<{dim,poly,bits}{p}>))
template <class Operator>
mlir::LogicalResult verifyIntegerGLWEOperator(Operator &op) {
auto a = ((mlir::Type)(op.a().getType())).cast<IntegerType>();
auto b = ((mlir::Type)(op.b().getType())).cast<GLWECipherTextType>();
auto a = ((mlir::Type)(op.getA().getType())).cast<IntegerType>();
auto b = ((mlir::Type)(op.getB().getType())).cast<GLWECipherTextType>();
auto result =
((mlir::Type)(op.getResult().getType())).cast<GLWECipherTextType>();
@@ -85,8 +85,8 @@ mlir::LogicalResult verifyIntegerGLWEOperator(Operator &op) {
/// (!TFHE.glwe<{dim,poly,bits}{p}>))
template <class Operator>
mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) {
auto a = ((mlir::Type)(op.a().getType())).cast<GLWECipherTextType>();
auto b = ((mlir::Type)(op.b().getType())).cast<GLWECipherTextType>();
auto a = ((mlir::Type)(op.getA().getType())).cast<GLWECipherTextType>();
auto b = ((mlir::Type)(op.getB().getType())).cast<GLWECipherTextType>();
auto result =
((mlir::Type)(op.getResult().getType())).cast<GLWECipherTextType>();
@@ -114,7 +114,7 @@ mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) {
/// (!TFHE.glwe<{dim,poly,bits}{p}>))
template <class Operator>
mlir::LogicalResult verifyUnaryGLWEOperator(Operator &op) {
auto a = ((mlir::Type)(op.a().getType())).cast<GLWECipherTextType>();
auto a = ((mlir::Type)(op.getA().getType())).cast<GLWECipherTextType>();
auto result =
((mlir::Type)(op.getResult().getType())).cast<GLWECipherTextType>();

View File

@@ -3,7 +3,7 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
@@ -17,7 +17,7 @@ namespace concretelang {
namespace {
/// Get the constant integer that the cleartext was created from if it exists.
llvm::Optional<IntegerAttr>
std::optional<IntegerAttr>
getConstantIntFromCleartextIfExists(mlir::Value cleartext) {
auto constantOp = cleartext.getDefiningOp();
if (constantOp == nullptr)
@@ -46,8 +46,8 @@ public:
auto cleartext = op.getOperand(1);
auto constIntToMul = getConstantIntFromCleartextIfExists(cleartext);
// Constant integer
if (constIntToMul.hasValue()) {
auto toMul = constIntToMul.getValue().getInt();
if (constIntToMul.has_value()) {
auto toMul = constIntToMul.value().getInt();
if (toMul == 0) {
rewriter.replaceOpWithNewOp<mlir::concretelang::TFHE::ZeroGLWEOp>(
op, op.getResult().getType());

View File

@@ -36,14 +36,14 @@ struct TrivialBufferizableInterface
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
return BufferRelation::Unknown;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -55,7 +55,7 @@ struct TrivialBufferizableInterface
operands.push_back(operand.get());
} else {
operands.push_back(
bufferization::getBuffer(rewriter, operand.get(), options));
*bufferization::getBuffer(rewriter, operand.get(), options));
}
}

View File

@@ -7,7 +7,7 @@ add_mlir_dialect_library(
mlir-headers
LINK_LIBS
PUBLIC
MLIRArithmeticDialect
MLIRArithDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRIR

View File

@@ -28,8 +28,16 @@ if(APPLE)
endif()
target_include_directories(ConcretelangRuntime PUBLIC ${CONCRETE_CPU_INCLUDE_DIR})
target_link_libraries(ConcretelangRuntime PUBLIC concrete_cpu ConcretelangClientLib pthread m dl
$<TARGET_OBJECTS:mlir_c_runner_utils>)
target_link_libraries(
ConcretelangRuntime
PUBLIC concrete_cpu
ConcretelangClientLib
pthread
m
dl
$<TARGET_OBJECTS:mlir_c_runner_utils>
$<TARGET_OBJECTS:mlir_float16_utils>
$<TARGET_OBJECTS:MLIRSparseTensorRuntime>)
install(TARGETS ConcretelangRuntime omp EXPORT ConcretelangRuntime)
install(EXPORT ConcretelangRuntime DESTINATION "./")

View File

@@ -57,7 +57,7 @@ void CompilationFeedback::fillFromClientParameters(
crtDecompositionsOfOutputs = {};
for (auto gate : params.outputs) {
std::vector<int64_t> decomposition;
if (gate.encryption.hasValue()) {
if (gate.encryption.has_value()) {
decomposition = gate.encryption->encoding.crt;
}
crtDecompositionsOfOutputs.push_back(decomposition);

View File

@@ -5,8 +5,9 @@
#include <fstream>
#include <iostream>
#include <mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h>
#include <mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h>
#include <mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h>
#include <mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h>
#include <stdio.h>
@@ -92,6 +93,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() {
registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);
linalg::registerBufferizableOpInterfaceExternalModels(registry);
RT::registerBufferizableOpInterfaceExternalModels(registry);
this->mlirContext = new mlir::MLIRContext();
this->mlirContext->appendDialectRegistry(registry);
@@ -136,17 +138,16 @@ void CompilerEngine::setEnablePass(
}
/// Returns the optimizer::Description
llvm::Expected<llvm::Optional<optimizer::Description>>
llvm::Expected<std::optional<optimizer::Description>>
CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) {
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
mlir::ModuleOp module = res.mlirModuleRef->get();
// If the values has been overwritten returns
if (this->overrideMaxEintPrecision.hasValue() &&
this->overrideMaxMANP.hasValue()) {
if (this->overrideMaxEintPrecision.has_value() &&
this->overrideMaxMANP.has_value()) {
auto constraint = mlir::concretelang::V0FHEConstraint{
this->overrideMaxMANP.getValue(),
this->overrideMaxEintPrecision.getValue()};
return optimizer::Description{constraint, llvm::None};
this->overrideMaxMANP.value(), this->overrideMaxEintPrecision.value()};
return optimizer::Description{constraint, std::nullopt};
}
auto config = this->compilerOptions.optimizerConfig;
auto descriptions = mlir::concretelang::pipeline::getFHEContextFromFHE(
@@ -155,10 +156,10 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) {
return std::move(err);
}
if (descriptions->empty()) { // The pass has not been run
return llvm::None;
return std::nullopt;
}
if (this->compilerOptions.clientParametersFuncName.hasValue()) {
auto name = this->compilerOptions.clientParametersFuncName.getValue();
if (this->compilerOptions.clientParametersFuncName.has_value()) {
auto name = this->compilerOptions.clientParametersFuncName.value();
auto description = descriptions->find(name);
if (description == descriptions->end()) {
std::string names;
@@ -181,14 +182,14 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) {
/// set the fheContext field if the v0Constraint can be computed
/// set the fheContext field if the v0Constraint can be computed
llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
if (compilerOptions.v0Parameter.hasValue()) {
if (compilerOptions.v0Parameter.has_value()) {
// parameters come from the compiler options
auto v0Params = compilerOptions.v0Parameter.value();
if (compilerOptions.largeIntegerParameter.hasValue()) {
if (compilerOptions.largeIntegerParameter.has_value()) {
v0Params.largeInteger = compilerOptions.largeIntegerParameter;
}
V0FHEConstraint constraint;
if (compilerOptions.v0FHEConstraints.hasValue()) {
if (compilerOptions.v0FHEConstraints.has_value()) {
constraint = compilerOptions.v0FHEConstraints.value();
}
res.fheContext.emplace(
@@ -201,7 +202,7 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
if (auto err = descr.takeError()) {
return err;
}
if (!descr.get().hasValue()) {
if (!descr.get().has_value()) {
return llvm::Error::success();
}
CompilationFeedback feedback;
@@ -222,7 +223,7 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
return llvm::Error::success();
}
using OptionalLib = llvm::Optional<std::shared_ptr<CompilerEngine::Library>>;
using OptionalLib = std::optional<std::shared_ptr<CompilerEngine::Library>>;
// Compile the sources managed by the source manager `sm` to the
// target dialect `target`. If successful, the result can be retrieved
// using `getModule()` and `getLLVMModule()`, respectively depending
@@ -347,28 +348,28 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
// Generate client parameters if requested
if (this->generateClientParameters) {
if (!options.clientParametersFuncName.hasValue()) {
if (!options.clientParametersFuncName.has_value()) {
return StreamStringError(
"Generation of client parameters requested, but no function name "
"specified");
}
if (!res.fheContext.hasValue()) {
if (!res.fheContext.has_value()) {
return StreamStringError(
"Cannot generate client parameters, the fhe context is empty for " +
options.clientParametersFuncName.getValue());
options.clientParametersFuncName.value());
}
}
// Generate client parameters if requested
auto funcName = options.clientParametersFuncName.getValueOr("main");
auto funcName = options.clientParametersFuncName.value_or("main");
if (this->generateClientParameters || target == Target::LIBRARY) {
if (!res.fheContext.hasValue()) {
if (!res.fheContext.has_value()) {
// Some tests involve call a to non encrypted functions
ClientParameters emptyParams;
emptyParams.functionName = funcName;
res.clientParameters = emptyParams;
} else {
llvm::Optional<::concretelang::clientlib::ChunkInfo> chunkInfo =
llvm::None;
std::nullopt;
if (options.chunkIntegers) {
chunkInfo = ::concretelang::clientlib::ChunkInfo{options.chunkSize,
options.chunkWidth};
@@ -483,7 +484,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
return StreamStringError(
"Internal Error: Please provide a library parameter");
}
auto objPath = lib.getValue()->addCompilation(res);
auto objPath = lib.value()->addCompilation(res);
if (!objPath) {
return StreamStringError(llvm::toString(objPath.takeError()));
}
@@ -764,11 +765,11 @@ CompilerEngine::Library::addCompilation(CompilationResult &compilation) {
}
addExtraObjectFilePath(objectPath);
if (compilation.clientParameters.hasValue()) {
clientParametersList.push_back(compilation.clientParameters.getValue());
if (compilation.clientParameters.has_value()) {
clientParametersList.push_back(compilation.clientParameters.value());
}
if (compilation.feedback.hasValue()) {
compilationFeedbackList.push_back(compilation.feedback.getValue());
if (compilation.feedback.has_value()) {
compilationFeedbackList.push_back(compilation.feedback.value());
}
return objectPath;
}
@@ -791,7 +792,7 @@ std::string ensureLibDotExt(std::string path, std::string dotExt) {
llvm::Expected<std::string> CompilerEngine::Library::emit(
std::string path, std::string dotExt, std::string linker,
llvm::Optional<std::vector<std::string>> extraArgs) {
std::optional<std::vector<std::string>> extraArgs) {
auto pathDotExt = ensureLibDotExt(path, dotExt);
auto error = mlir::concretelang::emitLibrary(objectsPath, pathDotExt, linker,
extraArgs);

View File

@@ -11,7 +11,7 @@
namespace mlir {
namespace concretelang {
JITSupport::JITSupport(llvm::Optional<std::string> runtimeLibPath)
JITSupport::JITSupport(std::optional<std::string> runtimeLibPath)
: runtimeLibPath(runtimeLibPath) {}
llvm::Expected<std::unique_ptr<JitCompilationResult>>
@@ -29,7 +29,7 @@ JITSupport::compile(llvm::SourceMgr &program, CompilationOptions options) {
return std::move(err);
}
if (!options.clientParametersFuncName.hasValue()) {
if (!options.clientParametersFuncName.has_value()) {
return StreamStringError("Need to have a funcname to JIT compile");
}
// Compile from LLVM Dialect to JITLambda
@@ -40,7 +40,7 @@ JITSupport::compile(llvm::SourceMgr &program, CompilationOptions options) {
if (auto err = lambda.takeError()) {
return std::move(err);
}
if (!compilationResult.get().clientParameters.hasValue()) {
if (!compilationResult.get().clientParameters.has_value()) {
// i.e. that should not occurs
return StreamStringError("No client parameters has been generated");
}
@@ -52,9 +52,8 @@ JITSupport::compile(llvm::SourceMgr &program, CompilationOptions options) {
if (!mlir::concretelang::dfr::_dfr_is_root_node()) {
result->clientParameters = clientlib::ClientParameters();
} else {
result->clientParameters =
compilationResult.get().clientParameters.getValue();
result->feedback = compilationResult.get().feedback.getValue();
result->clientParameters = compilationResult.get().clientParameters.value();
result->feedback = compilationResult.get().feedback.value();
}
return std::move(result);
}

View File

@@ -25,7 +25,7 @@ namespace concretelang {
llvm::Expected<std::unique_ptr<JITLambda>>
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
llvm::Optional<std::string> runtimeLibPath) {
std::optional<std::string> runtimeLibPath) {
// Looking for the function
auto rangeOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
@@ -46,13 +46,13 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
// JIT-compiles the module. If runtimeLibPath is specified, it's passed as a
// shared library to the JIT compiler.
std::vector<llvm::StringRef> sharedLibPaths;
if (runtimeLibPath.hasValue())
sharedLibPaths.push_back(runtimeLibPath.getValue());
if (runtimeLibPath.has_value())
sharedLibPaths.push_back(runtimeLibPath.value());
mlir::ExecutionEngineOptions execOptions;
execOptions.transformer = optPipeline;
execOptions.sharedLibPaths = sharedLibPaths;
execOptions.jitCodeGenOptLevel = llvm::None;
execOptions.jitCodeGenOptLevel = std::nullopt;
execOptions.llvmModuleBuilder = nullptr;
auto maybeEngine = mlir::ExecutionEngine::create(module, execOptions);

View File

@@ -7,10 +7,10 @@
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/MC/TargetRegistry.h>
#include <llvm/Support/Host.h>
#include <llvm/Support/ToolOutputFile.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Target/TargetOptions.h>
#include <llvm/TargetParser/Host.h>
#include <mlir/Support/FileUtilities.h>
@@ -72,13 +72,13 @@ llvm::Error emitObject(llvm::Module &module, string objectPath) {
}
string linkerCmd(vector<string> objectsPath, string libraryPath, string linker,
llvm::Optional<vector<string>> extraArgs) {
std::optional<vector<string>> extraArgs) {
string cmd = linker + libraryPath;
for (auto objectPath : objectsPath) {
cmd += " " + objectPath;
}
if (extraArgs.hasValue()) {
for (auto extraArg : extraArgs.getValue()) {
if (extraArgs.has_value()) {
for (auto extraArg : extraArgs.value()) {
cmd += " " + extraArg;
}
}
@@ -116,7 +116,7 @@ llvm::Error callCmd(string cmd) {
llvm::Error emitLibrary(vector<string> objectsPath, string libraryPath,
string linker,
llvm::Optional<vector<string>> extraArgs) {
std::optional<vector<string>> extraArgs) {
auto cmd = linkerCmd(objectsPath, libraryPath, linker, extraArgs);
return callCmd(cmd);
}

View File

@@ -13,7 +13,7 @@
#include <mlir/Transforms/Passes.h>
#include <mlir/Dialect/Affine/Passes.h>
#include <mlir/Dialect/Arithmetic/Transforms/Passes.h>
#include <mlir/Dialect/Arith/Transforms/Passes.h>
#include <mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h>
#include <mlir/Dialect/Bufferization/Transforms/Passes.h>
#include <mlir/Dialect/Linalg/Passes.h>
@@ -81,12 +81,12 @@ addPotentiallyNestedPass(mlir::PassManager &pm, std::unique_ptr<Pass> pass,
}
}
llvm::Expected<std::map<std::string, llvm::Optional<optimizer::Description>>>
llvm::Expected<std::map<std::string, std::optional<optimizer::Description>>>
getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
optimizer::Config config,
std::function<bool(mlir::Pass *)> enablePass) {
llvm::Optional<size_t> oMax2norm;
llvm::Optional<size_t> oMaxWidth;
std::optional<size_t> oMax2norm;
std::optional<size_t> oMaxWidth;
optimizer::FunctionsDag dags;
mlir::PassManager pm(&context);
@@ -99,10 +99,10 @@ getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
pm,
mlir::concretelang::createMaxMANPPass(
[&](const uint64_t manp, unsigned width) {
if (!oMax2norm.hasValue() || oMax2norm.getValue() < manp)
if (!oMax2norm.has_value() || oMax2norm.value() < manp)
oMax2norm.emplace(manp);
if (!oMaxWidth.hasValue() || oMaxWidth.getValue() < width)
if (!oMaxWidth.has_value() || oMaxWidth.value() < width)
oMaxWidth.emplace(width);
}),
enablePass);
@@ -112,28 +112,28 @@ getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
" required precision",
llvm::inconvertibleErrorCode());
}
llvm::Optional<mlir::concretelang::V0FHEConstraint> constraint = llvm::None;
std::optional<mlir::concretelang::V0FHEConstraint> constraint = std::nullopt;
if (oMax2norm.hasValue() && oMaxWidth.hasValue()) {
constraint = llvm::Optional<mlir::concretelang::V0FHEConstraint>(
{/*.norm2 = */ ceilLog2(oMax2norm.getValue()),
/*.p = */ oMaxWidth.getValue()});
if (oMax2norm.has_value() && oMaxWidth.has_value()) {
constraint = std::optional<mlir::concretelang::V0FHEConstraint>(
{/*.norm2 = */ ceilLog2(oMax2norm.value()),
/*.p = */ oMaxWidth.value()});
}
addPotentiallyNestedPass(pm, optimizer::createDagPass(config, dags),
enablePass);
if (pm.run(module.getOperation()).failed()) {
return StreamStringError() << "Failed to create concrete-optimizer dag\n";
}
std::map<std::string, llvm::Optional<optimizer::Description>> descriptions;
std::map<std::string, std::optional<optimizer::Description>> descriptions;
for (auto &entry_dag : dags) {
if (!constraint) {
descriptions.insert(
decltype(descriptions)::value_type(entry_dag.first, llvm::None));
decltype(descriptions)::value_type(entry_dag.first, std::nullopt));
continue;
}
optimizer::Description description = {*constraint,
std::move(entry_dag.second)};
llvm::Optional<optimizer::Description> opt_description{
std::optional<optimizer::Description> opt_description{
std::move(description)};
descriptions.insert(decltype(descriptions)::value_type(
entry_dag.first, std::move(opt_description)));
@@ -191,7 +191,7 @@ transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops, bool batchOperations) {
mlir::PassManager pm(&context);
@@ -241,11 +241,12 @@ transformFHEBigInt(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
if (fheContext.hasValue() && fheContext->parameter.largeInteger.hasValue()) {
if (fheContext.has_value() &&
fheContext->parameter.largeInteger.has_value()) {
pipelinePrinting("FHEToTFHECrt", pm, context);
auto dec =
fheContext.value().parameter.largeInteger.value().crtDecomposition;
@@ -255,7 +256,7 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::concretelang::createConvertFHEToTFHECrtPass(
mlir::concretelang::CrtLoweringParameters(mods)),
enablePass);
} else if (fheContext.hasValue()) {
} else if (fheContext.has_value()) {
pipelinePrinting("FHEToTFHEScalar", pm, context);
size_t polySize = fheContext.value().parameter.getPolynomialSize();
addPotentiallyNestedPass(
@@ -270,16 +271,16 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("TFHEToConcrete", pm, context);
if (fheContext.hasValue()) {
if (fheContext.has_value()) {
addPotentiallyNestedPass(
pm,
mlir::concretelang::createConvertTFHEGlobalParametrizationPass(
fheContext.getValue()),
fheContext.value()),
enablePass);
}
@@ -345,8 +346,12 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::bufferization::OneShotBufferizationOptions bufferizationOptions;
bufferizationOptions.allowReturnAllocs = true;
bufferizationOptions.printConflicts = true;
bufferizationOptions.unknownTypeConversion = mlir::bufferization::
OneShotBufferizationOptions::LayoutMapOption::IdentityLayoutMap;
bufferizationOptions.unknownTypeConverterFn =
[](Value value, Attribute memorySpace,
const mlir::bufferization::BufferizationOptions &options) {
return mlir::bufferization::getMemRefTypeWithStaticIdentityLayout(
value.getType().cast<TensorType>(), memorySpace);
};
bufferizationOptions.bufferizeFunctionBoundaries = true;
bufferizationOptions.createDeallocs = true;
@@ -354,6 +359,12 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::bufferization::createOneShotBufferizePass(bufferizationOptions);
addPotentiallyNestedPass(pm, std::move(comprBuffPass), enablePass);
// The bufferization may create `linalg.map` operations; Add another
// conversion pass from linalg to loops
addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass(),
enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createBufferizeDataflowTaskOpsPass(), enablePass);

View File

@@ -11,6 +11,7 @@
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <optional>
#include "concrete/curves.h"
#include "concretelang/ClientLib/ClientParameters.h"
@@ -51,13 +52,13 @@ gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
bool sign = type.isSignedInteger();
return CircuitGate{
/*.encryption = */ llvm::None,
/*.encryption = */ std::nullopt,
/*.shape = */
{/*.width = */ width,
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/* .sign */ sign},
/*.chunkInfo = */ llvm::None,
/*.chunkInfo = */ std::nullopt,
};
}
if (auto lweTy = type.dyn_cast_or_null<
@@ -70,7 +71,7 @@ gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
size_t width;
uint64_t size = 0;
std::vector<int64_t> dims;
if (chunkInfo.hasValue()) {
if (chunkInfo.has_value()) {
width = chunkInfo->size;
assert(lweTy.getWidth() % chunkInfo->width == 0);
size = lweTy.getWidth() / chunkInfo->width;
@@ -79,7 +80,7 @@ gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
width = (size_t)lweTy.getWidth();
}
return CircuitGate{
/* .encryption = */ llvm::Optional<EncryptionGate>({
/* .encryption = */ std::optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
@@ -103,7 +104,7 @@ gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
mlir::concretelang::FHE::EncryptedBooleanType>()) {
size_t width = mlir::concretelang::FHE::EncryptedBooleanType::getWidth();
return CircuitGate{
/* .encryption = */ llvm::Optional<EncryptionGate>({
/* .encryption = */ std::optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
@@ -120,7 +121,7 @@ gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
/*.size = */ 0,
/*.sign = */ false,
},
/*.chunkInfo = */ llvm::None,
/*.chunkInfo = */ std::nullopt,
};
}
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
@@ -189,7 +190,7 @@ createClientParametersForV0(V0FHEContext fheContext,
bskParam.inputLweDimension = v0Param.nSmall;
c.bootstrapKeys.push_back(bskParam);
}
if (v0Param.largeInteger.hasValue()) {
if (v0Param.largeInteger.has_value()) {
clientlib::PackingKeyswitchKeyParam param;
param.inputSecretKeyID = clientlib::BIG_KEY;
param.outputSecretKeyID = clientlib::BIG_KEY;

View File

@@ -12,6 +12,7 @@
#include <chrono>
#include <cmath>
#include <iostream>
#include <optional>
#include "llvm/Support/raw_ostream.h"
@@ -200,7 +201,7 @@ llvm::Expected<V0Parameter> getParameter(optimizer::Description &descr,
auto sol = (!descr.dag || config.strategy_v0)
? getV0Parameter(descr.constraint, config)
: getV1Parameter(descr.dag.getValue(), config);
: getV1Parameter(descr.dag.value(), config);
auto stop = chrono::high_resolution_clock::now();
auto duration = chrono::duration_cast<chrono::milliseconds>(stop - start);
@@ -235,7 +236,7 @@ llvm::Expected<V0Parameter> getParameter(optimizer::Description &descr,
params.brLogBase = sol.br_decomposition_base_log;
params.ksLevel = sol.ks_decomposition_level_count;
params.ksLogBase = sol.ks_decomposition_base_log;
params.largeInteger = llvm::None;
params.largeInteger = std::nullopt;
if (sol.use_wop_pbs) {
if (sol.crt_decomposition.empty()) {

View File

@@ -4,7 +4,7 @@
// for license information.
#include <llvm/ADT/STLExtras.h>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
@@ -206,7 +206,7 @@ struct BoundsAndStep {
/// Returns the lower bound, upper bound and step of the quasi-affine
/// expression `expr` on the the induction variable from a for
/// operation.
static llvm::Optional<BoundsAndStep>
static std::optional<BoundsAndStep>
getBoundsOfQuasiAffineIVExpression(mlir::Value expr, mlir::scf::ForOp forOp) {
// Base case: expression is the induction variable itself -> return
// loop bounds
@@ -220,13 +220,13 @@ getBoundsOfQuasiAffineIVExpression(mlir::Value expr, mlir::scf::ForOp forOp) {
if (llvm::isa<mlir::arith::AddIOp, mlir::arith::SubIOp, mlir::arith::MulIOp,
mlir::arith::DivSIOp>(op)) {
llvm::Optional<BoundsAndStep> lhs =
std::optional<BoundsAndStep> lhs =
getBoundsOfQuasiAffineIVExpression(op->getOperand(0), forOp);
llvm::Optional<BoundsAndStep> rhs =
std::optional<BoundsAndStep> rhs =
getBoundsOfQuasiAffineIVExpression(op->getOperand(1), forOp);
if (!lhs.hasValue() || !rhs.hasValue())
return llvm::None;
if (!lhs.has_value() || !rhs.has_value())
return std::nullopt;
if (llvm::isa<mlir::arith::AddIOp>(op))
return *lhs + *rhs;
@@ -245,7 +245,7 @@ getBoundsOfQuasiAffineIVExpression(mlir::Value expr, mlir::scf::ForOp forOp) {
// the divisor, there may be two iterations with the same
// value. Conservatively bail out.
if (lhs->step % rhsVal != 0)
return llvm::None;
return std::nullopt;
return *lhs / rhsVal;
}
@@ -370,10 +370,10 @@ isQuasiAffineIVExpressionWithConstantStep(mlir::Value expr,
mlir::scf::ForOp tmpForOp;
if (isQuasiAffineIVExpression(expr, &tmpForOp)) {
llvm::Optional<BoundsAndStep> bas =
std::optional<BoundsAndStep> bas =
getBoundsOfQuasiAffineIVExpression(expr, tmpForOp);
if (bas.hasValue()) {
if (bas.has_value()) {
if (forOp != nullptr)
*forOp = tmpForOp;
return true;
@@ -409,10 +409,10 @@ mlir::Value hoistIndexedOp(
if (isAffine && forOp) {
llvm::Optional<BoundsAndStep> bas =
std::optional<BoundsAndStep> bas =
getBoundsOfQuasiAffineIVExpression(idx, forOp);
assert(bas.hasValue());
assert(bas.has_value());
assert(bas->step != 0);
offsets.push_back(rewriter.getIndexAttr(bas->lb));
@@ -802,7 +802,7 @@ public:
return mlir::WalkResult::skip();
if (!llvm::all_of(body, [&](mlir::Operation &op) {
return MemoryEffectOpInterface::hasNoEffect(&op);
return isMemoryEffectFree(&op);
})) {
return mlir::WalkResult::skip();
}

View File

@@ -8,7 +8,7 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
@@ -79,8 +79,7 @@ struct CollapseParallelLoopsPass
if (maxPos > start)
continue;
auto band =
llvm::makeMutableArrayRef(loops.data() + start, end - start);
auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
(void)mlir::coalesceLoops(band);
break;
}

View File

@@ -7,7 +7,7 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
@@ -31,10 +31,11 @@ public:
if (attr.getValue()) {
rewriter.replaceOpWithNewOp<mlir::scf::ParallelOp>(
forOp, mlir::ValueRange{forOp.getLowerBound()},
mlir::ValueRange{forOp.getUpperBound()}, forOp.getStep(), llvm::None,
mlir::ValueRange{forOp.getUpperBound()}, forOp.getStep(),
std::nullopt,
[&](mlir::OpBuilder &builder, mlir::Location location,
mlir::ValueRange indVar, mlir::ValueRange iterArgs) {
mlir::BlockAndValueMapping map;
mlir::IRMapping map;
map.map(forOp.getInductionVar(), indVar.front());
for (auto &op : forOp.getRegion().front()) {
auto newOp = builder.clone(op, map);
@@ -44,10 +45,10 @@ public:
} else {
rewriter.replaceOpWithNewOp<mlir::scf::ForOp>(
forOp, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(),
llvm::None,
std::nullopt,
[&](mlir::OpBuilder &builder, mlir::Location location,
mlir::Value indVar, mlir::ValueRange iterArgs) {
mlir::BlockAndValueMapping map;
mlir::IRMapping map;
map.map(forOp.getInductionVar(), indVar);
for (auto &op : forOp.getRegion().front()) {
auto newOp = builder.clone(op, map);

View File

@@ -17,6 +17,7 @@
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Support/ToolUtilities.h>
#include <optional>
#include <sstream>
#include "concretelang/ClientLib/KeySet.h"
@@ -384,7 +385,7 @@ cmdlineCompilationOptions() {
options.v0Parameter = {cmdline::v0Parameter[0], cmdline::v0Parameter[1],
cmdline::v0Parameter[2], cmdline::v0Parameter[3],
cmdline::v0Parameter[4], cmdline::v0Parameter[5],
cmdline::v0Parameter[6], llvm::None};
cmdline::v0Parameter[6], std::nullopt};
}
// Setup the large integer options
@@ -477,7 +478,7 @@ mlir::LogicalResult processInputBuffer(
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
mlir::concretelang::CompilationContext::createShared();
std::string funcName = options.clientParametersFuncName.getValueOr("");
std::string funcName = options.clientParametersFuncName.value_or("");
if (action == Action::JIT_INVOKE) {
auto lambdaOrErr =
mlir::concretelang::ClientServer<mlir::concretelang::JITSupport>::

View File

@@ -2,17 +2,16 @@
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-NEXT: module {
// CHECK-NEXT: func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!FHE.eint<2>> {
// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!FHE.eint<2>>) outs(%0 : tensor<2x3x4x!FHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%arg2: !FHE.eint<2>, %arg3: !FHE.eint<2>):
// CHECK-NEXT: %2 = "FHE.apply_lookup_table"(%arg2, %arg1) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
// CHECK-NEXT: linalg.yield %2 : !FHE.eint<2>
// CHECK-NEXT: func.func @apply_lookup_table(%[[Varg0:.*]]: tensor<2x3x4x!FHE.eint<2>>, %[[Varg1:.*]]: tensor<4xi64>) -> tensor<2x3x4x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = {{\[}}#map, #map{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Varg0]] : tensor<2x3x4x!FHE.eint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%[[Varg2:.*]]: !FHE.eint<2>, %[[Varg3:.*]]: !FHE.eint<2>):
// CHECK-NEXT: %[[V2:.*]] = "FHE.apply_lookup_table"(%[[Varg2]], %[[Varg1]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
// CHECK-NEXT: linalg.yield %[[V2]] : !FHE.eint<2>
// CHECK-NEXT: } -> tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: return %1 : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: return %[[V1]] : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: }
// CHECK-NEXT: }
func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!FHE.eint<2>> {
%1 = "FHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<4xi64>) -> (tensor<2x3x4x!FHE.eint<2>>)
return %1: tensor<2x3x4x!FHE.eint<2>>

View File

@@ -4,7 +4,7 @@
// CHECK: func.func @main(%[[a0:.*]]: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x6x9x!FHE.eint<5>>
// CHECK-NEXT: %[[v1:.*]] = linalg.init_tensor [3, 2] : tensor<3x2xi64>
// CHECK-NEXT: %[[v1:.*]] = tensor.empty() : tensor<3x2xi64>
// CHECK-NEXT: %[[v2:.*]] = linalg.pooling_nchw_max {dilations = dense<1> : vector<2xi64>, max_signed = "FHE.max_eint", strides = dense<1> : vector<2xi64>} ins(%arg0, %1 : tensor<1x1x8x10x!FHE.eint<5>>, tensor<3x2xi64>) outs(%0 : tensor<1x1x6x9x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>>
// CHECK-NEXT: return %[[v2]] : tensor<1x1x6x9x!FHE.eint<5>>
// CHECK-NEXT: }
@@ -23,12 +23,12 @@ func.func @main(%arg0: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.ein
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x5x3x!FHE.esint<6>>
// CHECK-NEXT: %[[v1:.*]] = arith.constant dense<16> : tensor<1xi7>
// CHECK-NEXT: %[[v2:.*]] = bufferization.alloc_tensor() : tensor<1x1x5x3x!FHE.esint<6>>
// CHECK-NEXT: %[[v3:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[v0]], %[[v1]] : tensor<1x1x5x3x!FHE.esint<6>>, tensor<1xi7>) outs(%[[v2]] : tensor<1x1x5x3x!FHE.esint<6>>) {
// CHECK-NEXT: %[[v3:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[v0]], %[[v1]] : tensor<1x1x5x3x!FHE.esint<6>>, tensor<1xi7>) outs(%[[v2]] : tensor<1x1x5x3x!FHE.esint<6>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.esint<6>, %[[aa1:.*]]: i7, %[[aa2:.*]]: !FHE.esint<6>):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.sub_eint_int"(%[[aa0]], %[[aa1]]) : (!FHE.esint<6>, i7) -> !FHE.esint<6>
// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.esint<6>
// CHECK-NEXT: } -> tensor<1x1x5x3x!FHE.esint<6>>
// CHECK-NEXT: %[[v4:.*]] = linalg.init_tensor [2, 3] : tensor<2x3xi64>
// CHECK-NEXT: %[[v4:.*]] = tensor.empty() : tensor<2x3xi64>
// CHECK-NEXT: %[[v5:.*]] = linalg.pooling_nchw_max {dilations = dense<1> : vector<2xi64>, max_signed = "FHE.max_eint", strides = dense<1> : vector<2xi64>} ins(%arg0, %[[v4]] : tensor<1x1x6x5x!FHE.esint<6>>, tensor<2x3xi64>) outs(%[[v3]] : tensor<1x1x5x3x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>>
// CHECK-NEXT: return %[[v5]] : tensor<1x1x5x3x!FHE.esint<6>>
// CHECK-NEXT: }

View File

@@ -2,17 +2,16 @@
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-NEXT: module {
// CHECK-NEXT: func.func @neg_eint(%arg0: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!FHE.eint<2>>) outs(%0 : tensor<2x3x4x!FHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%arg1: !FHE.eint<2>, %arg2: !FHE.eint<2>):
// CHECK-NEXT: %2 = "FHE.neg_eint"(%arg1) : (!FHE.eint<2>) -> !FHE.eint<2>
// CHECK-NEXT: linalg.yield %2 : !FHE.eint<2>
// CHECK-NEXT: func.func @neg_eint(%[[Varg0:.*]]: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = {{\[}}#map, #map{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Varg0]] : tensor<2x3x4x!FHE.eint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%[[Varg1:.*]]: !FHE.eint<2>, %[[Varg2:.*]]: !FHE.eint<2>):
// CHECK-NEXT: %[[V2:.*]] = "FHE.neg_eint"(%[[Varg1]]) : (!FHE.eint<2>) -> !FHE.eint<2>
// CHECK-NEXT: linalg.yield %[[V2]] : !FHE.eint<2>
// CHECK-NEXT: } -> tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: return %1 : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: return %[[V1]] : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: }
// CHECK-NEXT: }
func.func @neg_eint(%arg0: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
%1 = "FHELinalg.neg_eint"(%arg0): (tensor<2x3x4x!FHE.eint<2>>) -> (tensor<2x3x4x!FHE.eint<2>>)
return %1: tensor<2x3x4x!FHE.eint<2>>

View File

@@ -2,14 +2,14 @@
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-NEXT: module {
// CHECK-NEXT: func.func @main(%arg0: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.esint<2>> {
// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.esint<2>>
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!FHE.eint<2>>) outs(%0 : tensor<2x3x4x!FHE.esint<2>>) {
// CHECK-NEXT: ^bb0(%arg1: !FHE.eint<2>, %arg2: !FHE.esint<2>):
// CHECK-NEXT: %2 = "FHE.to_signed"(%arg1) : (!FHE.eint<2>) -> !FHE.esint<2>
// CHECK-NEXT: linalg.yield %2 : !FHE.esint<2>
// CHECK-NEXT: func.func @main(%[[Varg0:.*]]: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.esint<2>> {
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.esint<2>>
// CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[Varg0]] : tensor<2x3x4x!FHE.eint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.esint<2>>) {
// CHECK-NEXT: ^bb0(%[[Varg1:.*]]: !FHE.eint<2>, %[[Varg2:.*]]: !FHE.esint<2>):
// CHECK-NEXT: %[[V2:.*]] = "FHE.to_signed"(%[[Varg1]]) : (!FHE.eint<2>) -> !FHE.esint<2>
// CHECK-NEXT: linalg.yield %[[V2]] : !FHE.esint<2>
// CHECK-NEXT: } -> tensor<2x3x4x!FHE.esint<2>>
// CHECK-NEXT: return %1 : tensor<2x3x4x!FHE.esint<2>>
// CHECK-NEXT: return %[[V1]] : tensor<2x3x4x!FHE.esint<2>>
// CHECK-NEXT: }
// CHECK-NEXT: }
func.func @main(%arg0: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.esint<2>> {

View File

@@ -2,14 +2,14 @@
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-NEXT: module {
// CHECK-NEXT: func.func @main(%arg0: tensor<2x3x4x!FHE.esint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!FHE.esint<2>>) outs(%0 : tensor<2x3x4x!FHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%arg1: !FHE.esint<2>, %arg2: !FHE.eint<2>):
// CHECK-NEXT: %2 = "FHE.to_unsigned"(%arg1) : (!FHE.esint<2>) -> !FHE.eint<2>
// CHECK-NEXT: linalg.yield %2 : !FHE.eint<2>
// CHECK-NEXT: func.func @main(%[[Varg0:.*]]: tensor<2x3x4x!FHE.esint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[Varg0]] : tensor<2x3x4x!FHE.esint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%[[Varg1:.*]]: !FHE.esint<2>, %[[Varg2:.*]]: !FHE.eint<2>):
// CHECK-NEXT: %[[V2:.*]] = "FHE.to_unsigned"(%[[Varg1]]) : (!FHE.esint<2>) -> !FHE.eint<2>
// CHECK-NEXT: linalg.yield %[[V2]] : !FHE.eint<2>
// CHECK-NEXT: } -> tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: return %1 : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: return %[[V1]] : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: }
// CHECK-NEXT: }
func.func @main(%arg0: tensor<2x3x4x!FHE.esint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {

View File

@@ -1,18 +1,18 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
//CHECK-LABEL: func.func @add_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>, %arg1: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
//CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
//CHECK-NEXT: %c0 = arith.constant 0 : index
//CHECK-NEXT: %c1 = arith.constant 1 : index
//CHECK-NEXT: %c5 = arith.constant 5 : index
//CHECK-NEXT: %1 = scf.for %arg2 = %c0 to %c5 step %c1 iter_args(%arg3 = %0) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
//CHECK-NEXT: %2 = tensor.extract %arg0[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
//CHECK-NEXT: %3 = tensor.extract %arg1[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
//CHECK-NEXT: %4 = "TFHE.add_glwe"(%2, %3) : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
//CHECK-NEXT: %5 = tensor.insert %4 into %arg3[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
//CHECK-NEXT: scf.yield %5 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
//CHECK: func.func @add_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<{_,_,_}{7}>>, %[[Varg1:.*]]: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
//CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
//CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
//CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
//CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
//CHECK-NEXT: %[[V2:.*]] = tensor.extract %[[Varg0]][%[[Varg2]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
//CHECK-NEXT: %[[V3:.*]] = tensor.extract %[[Varg1]][%[[Varg2]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
//CHECK-NEXT: %[[V4:.*]] = "TFHE.add_glwe"(%[[V2]], %[[V3]]) : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
//CHECK-NEXT: %[[V5:.*]] = tensor.insert %[[V4]] into %[[Varg3]][%[[Varg2]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
//CHECK-NEXT: scf.yield %[[V5]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
//CHECK-NEXT: }
//CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
//CHECK-NEXT: return %[[V1]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>

View File

@@ -1,21 +1,21 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK-LABEL: func.func @add_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
// CHECK-NEXT: %0 = arith.extsi %c1_i8 : i8 to i64
// CHECK-NEXT: %1 = "TFHE.encode_plaintext_with_crt"(%0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
// CHECK-NEXT: %2 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c5 = arith.constant 5 : index
// CHECK-NEXT: %3 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %2) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
// CHECK-NEXT: %4 = tensor.extract %arg0[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %5 = tensor.extract %1[%arg1] : tensor<5xi64>
// CHECK-NEXT: %6 = "TFHE.add_glwe_int"(%4, %5) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %7 = tensor.insert %6 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: scf.yield %7 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK: func.func @add_eint_int(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
// CHECK-NEXT: %[[Vc1_i8:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V0:.*]] = arith.extsi %[[Vc1_i8]] : i8 to i64
// CHECK-NEXT: %[[V1:.*]] = "TFHE.encode_plaintext_with_crt"(%[[V0]]) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
// CHECK-NEXT: %[[V2:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V2]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
// CHECK-NEXT: %[[V4:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %[[V5:.*]] = tensor.extract %[[V1]][%[[Varg1]]] : tensor<5xi64>
// CHECK-NEXT: %[[V6:.*]] = "TFHE.add_glwe_int"(%[[V4]], %[[V5]]) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %[[V7:.*]] = tensor.insert %[[V6]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: scf.yield %[[V7]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: }
// CHECK-NEXT: return %3 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: return %[[V3]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%0 = arith.constant 1 : i8
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)

View File

@@ -1,98 +1,97 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
//CHECK-LABEL: func.func @conv2d(%arg0: tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> {
//CHECK-NEXT: %c4 = arith.constant 4 : index
//CHECK-NEXT: %c100 = arith.constant 100 : index
//CHECK-NEXT: %c15 = arith.constant 15 : index
//CHECK-NEXT: %c0 = arith.constant 0 : index
//CHECK-NEXT: %c1 = arith.constant 1 : index
//CHECK-NEXT: %c3 = arith.constant 3 : index
//CHECK-NEXT: %c14 = arith.constant 14 : index
//CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %1 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %0) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %6 = tensor.extract %arg2[%arg5] : tensor<4xi3>
//CHECK-NEXT: %c0_0 = arith.constant 0 : index
//CHECK-NEXT: %7 = tensor.extract_slice %0[%arg3, %arg5, %arg7, %arg9, %c0_0] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %8 = arith.extsi %6 : i3 to i64
//CHECK-NEXT: %9 = "TFHE.encode_plaintext_with_crt"(%8) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
//CHECK-NEXT: %10 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %c0_1 = arith.constant 0 : index
//CHECK-NEXT: %c1_2 = arith.constant 1 : index
//CHECK-NEXT: %c5 = arith.constant 5 : index
//CHECK-NEXT: %11 = scf.for %arg11 = %c0_1 to %c5 step %c1_2 iter_args(%arg12 = %10) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %13 = tensor.extract %7[%arg11] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %14 = tensor.extract %9[%arg11] : tensor<5xi64>
//CHECK-NEXT: %15 = "TFHE.add_glwe_int"(%13, %14) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %16 = tensor.insert %15 into %arg12[%arg11] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %16 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK: func.func @conv2d(%[[Varg0:.*]]: tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>>, %[[Varg1:.*]]: tensor<4x3x14x14xi3>, %[[Varg2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> {
//CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vc100:.*]] = arith.constant 100 : index
//CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
//CHECK-NEXT: %[[Vc4:.*]] = arith.constant 4 : index
//CHECK-NEXT: %[[Vc15:.*]] = arith.constant 15 : index
//CHECK-NEXT: %[[Vc3:.*]] = arith.constant 3 : index
//CHECK-NEXT: %[[Vc14:.*]] = arith.constant 14 : index
//CHECK-NEXT: %[[V0:.*]] = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V0]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg2]]{{\[}}%[[Varg5]]{{\]}} : tensor<4xi3>
//CHECK-NEXT: %[[Vc0_0:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_0]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[V6:.*]] = arith.extsi %[[Vextracted]] : i3 to i64
//CHECK-NEXT: %[[V7:.*]] = "TFHE.encode_plaintext_with_crt"(%[[V6]]) {mods = {{\[2, 3, 5, 7, 11\], modsProd}} = 2310 : i64} : (i64) -> tensor<5xi64>
//CHECK-NEXT: %[[V8:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[Vc0_1:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vc1_2:.*]] = arith.constant 1 : index
//CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
//CHECK-NEXT: %[[V9:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0_1]] to %[[Vc5]] step %[[Vc1_2]] iter_args(%[[Varg12:.*]] = %[[V8]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[Vextracted_4:.*]] = tensor.extract %[[Vextracted_slice]]{{\[}}%[[Varg11]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[Vextracted_5:.*]] = tensor.extract %[[V7]]{{\[}}%[[Varg11]]{{\]}} : tensor<5xi64>
//CHECK-NEXT: %[[V10:.*]] = "TFHE.add_glwe_int"(%[[Vextracted_4]], %[[Vextracted_5]]) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V10]] into %[[Varg12]]{{\[}}%[[Varg11]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: %c0_3 = arith.constant 0 : index
//CHECK-NEXT: %12 = tensor.insert_slice %11 into %arg10[%arg3, %arg5, %arg7, %arg9, %c0_3] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %12 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[Vc0_3:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vinserted_slice:.*]] = tensor.insert_slice %[[V9]] into %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_3]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[Vinserted_slice]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: %2 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %1) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %6 = scf.for %arg11 = %c0 to %c3 step %c1 iter_args(%arg12 = %arg10) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %7 = scf.for %arg13 = %c0 to %c14 step %c1 iter_args(%arg14 = %arg12) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %8 = scf.for %arg15 = %c0 to %c14 step %c1 iter_args(%arg16 = %arg14) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %9 = affine.apply #map(%arg7, %arg13)
//CHECK-NEXT: %10 = affine.apply #map(%arg9, %arg15)
//CHECK-NEXT: %c0_0 = arith.constant 0 : index
//CHECK-NEXT: %11 = tensor.extract_slice %arg0[%arg3, %arg11, %9, %10, %c0_0] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %12 = tensor.extract %arg1[%arg5, %arg11, %arg13, %arg15] : tensor<4x3x14x14xi3>
//CHECK-NEXT: %c0_1 = arith.constant 0 : index
//CHECK-NEXT: %13 = tensor.extract_slice %1[%arg3, %arg5, %arg7, %arg9, %c0_1] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %14 = arith.extsi %12 : i3 to i64
//CHECK-NEXT: %15 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %c0_2 = arith.constant 0 : index
//CHECK-NEXT: %c1_3 = arith.constant 1 : index
//CHECK-NEXT: %c5 = arith.constant 5 : index
//CHECK-NEXT: %16 = scf.for %arg17 = %c0_2 to %c5 step %c1_3 iter_args(%arg18 = %15) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %20 = tensor.extract %11[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %21 = "TFHE.mul_glwe_int"(%20, %14) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %22 = tensor.insert %21 into %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %22 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V1]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[V6:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0]] to %[[Vc3]] step %[[Vc1]] iter_args(%[[Varg12:.*]] = %[[Varg10]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[V7:.*]] = scf.for %[[Varg13:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg14:.*]] = %[[Varg12]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[V8:.*]] = scf.for %[[Varg15:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg16:.*]] = %[[Varg14]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[V9:.*]] = affine.apply #map(%[[Varg7]], %[[Varg13]])
//CHECK-NEXT: %[[V10:.*]] = affine.apply #map(%[[Varg9]], %[[Varg15]])
//CHECK-NEXT: %[[Vc0_0:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[Varg3]], %[[Varg11]], %[[V9]], %[[V10]], %[[Vc0_0]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg1]]{{\[}}%[[Varg5]], %[[Varg11]], %[[Varg13]], %[[Varg15]]{{\]}} : tensor<4x3x14x14xi3>
//CHECK-NEXT: %[[Vc0_1:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vextracted_slice_2:.*]] = tensor.extract_slice %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_1]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[V11:.*]] = arith.extsi %[[Vextracted]] : i3 to i64
//CHECK-NEXT: %[[V12:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[Vc0_3:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vc1_4:.*]] = arith.constant 1 : index
//CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
//CHECK-NEXT: %[[V13:.*]] = scf.for %[[Varg17:.*]] = %[[Vc0_3]] to %[[Vc5]] step %[[Vc1_4]] iter_args(%[[Varg18:.*]] = %[[V12]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[Vextracted_9:.*]] = tensor.extract %[[Vextracted_slice]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[V16:.*]] = "TFHE.mul_glwe_int"(%[[Vextracted_9]], %[[V11]]) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V16]] into %[[Varg18]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: %17 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %c0_4 = arith.constant 0 : index
//CHECK-NEXT: %c1_5 = arith.constant 1 : index
//CHECK-NEXT: %c5_6 = arith.constant 5 : index
//CHECK-NEXT: %18 = scf.for %arg17 = %c0_4 to %c5_6 step %c1_5 iter_args(%arg18 = %17) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %20 = tensor.extract %13[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %21 = tensor.extract %16[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %22 = "TFHE.add_glwe"(%20, %21) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %23 = tensor.insert %22 into %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %23 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[V14:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[Vc0_5:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vc1_6:.*]] = arith.constant 1 : index
//CHECK-NEXT: %[[Vc5_7:.*]] = arith.constant 5 : index
//CHECK-NEXT: %[[V15:.*]] = scf.for %[[Varg17:.*]] = %[[Vc0_5]] to %[[Vc5_7]] step %[[Vc1_6]] iter_args(%[[Varg18:.*]] = %[[V14]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %[[Vextracted_9:.*]] = tensor.extract %[[Vextracted_slice_2]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[Vextracted_10:.*]] = tensor.extract %[[V13]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[V16:.*]] = "TFHE.add_glwe"(%[[Vextracted_9]], %[[Vextracted_10]]) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V16]] into %[[Varg18]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: %c0_7 = arith.constant 0 : index
//CHECK-NEXT: %19 = tensor.insert_slice %18 into %arg16[%arg3, %arg5, %arg7, %arg9, %c0_7] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %19 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[Vc0_8:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vinserted_slice:.*]] = tensor.insert_slice %[[V15]] into %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_8]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[Vinserted_slice]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %8 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[V8]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %7 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[V7]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %6 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[V6]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: return %2 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: return %[[V2]] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>

View File

@@ -1,19 +1,19 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK-LABEL: func.func @mul_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
// CHECK-NEXT: %c2_i8 = arith.constant 2 : i8
// CHECK-NEXT: %0 = arith.extsi %c2_i8 : i8 to i64
// CHECK-NEXT: %1 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c5 = arith.constant 5 : index
// CHECK-NEXT: %2 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %1) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
// CHECK-NEXT: %3 = tensor.extract %arg0[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %4 = "TFHE.mul_glwe_int"(%3, %0) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %5 = tensor.insert %4 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: scf.yield %5 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK: func.func @mul_eint_int(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
// CHECK-NEXT: %[[Vc2_i8:.*]] = arith.constant 2 : i8
// CHECK-NEXT: %[[V0:.*]] = arith.extsi %[[Vc2_i8]] : i8 to i64
// CHECK-NEXT: %[[V1:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V1]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
// CHECK-NEXT: %[[V3:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %[[V4:.*]] = "TFHE.mul_glwe_int"(%[[V3]], %[[V0]]) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %[[V5:.*]] = tensor.insert %[[V4]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: scf.yield %[[V5]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: }
// CHECK-NEXT: return %2 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: return %[[V2]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%0 = arith.constant 2 : i8
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)

View File

@@ -1,17 +1,17 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK-LABEL: func.func @neg_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c5 = arith.constant 5 : index
// CHECK-NEXT: %1 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %0) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
// CHECK-NEXT: %2 = tensor.extract %arg0[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %3 = "TFHE.neg_glwe"(%2) : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %4 = tensor.insert %3 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: scf.yield %4 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK: func.func @neg_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V0]]) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
// CHECK-NEXT: %[[V2:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %[[V3:.*]] = "TFHE.neg_glwe"(%[[V2]]) : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %[[V4:.*]] = tensor.insert %[[V3]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: scf.yield %[[V4]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: }
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: return %[[V1]] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: }
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)

Some files were not shown because too many files have changed in this diff Show More