mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler): add support for multikey
This commit brings support for multiple secret keys in the TFHE dialect. In particular, a parameterized `TFHE` circuit can now be given as input, with any combination of (semantically valid) of ks/bs/woppbs mixing different secret keys, and compiled down to a valid executable function, with server keys properly looked up. Secret keys are now stateful objects which can be: -> none/unparameterized (syntax `sk?`): The keys are in state after the lowering from the `FHE` dialect. -> parameterized (syntax `sk<identifier, polysize, dimension>`): The keys were parameterized, either by user or by the optimizer. The `identifier` field can be used to disambiguate two keys with same `polysize` and `dimension`. -> normalized (syntax `sk[index]<polysize, dimension>`): The keys were attached to their index in the list of keys in the runtime context. The _normalization_ of key indices also acts on the ksk, bsk and pksk, which are given indices in the same spirit now. Finally, in order to allow parameterized `TFHE` circuit to be given as input and compiled down to executable functions, we added a way to pass the encodings that are used to encode/decode the circuit inputs/outputs. In the case of a compilation from the `FHE` dialect, those informations are automatically extracted from the higher level informations available in this dialect.
This commit is contained in:
committed by
Quentin Bourgerie
parent
823ea618af
commit
cacffadbd2
@@ -123,6 +123,9 @@ enum CompilationTarget {
|
||||
ROUND_TRIP,
|
||||
FHE,
|
||||
TFHE,
|
||||
PARAMETRIZED_TFHE,
|
||||
NORMALIZED_TFHE,
|
||||
BATCHED_TFHE,
|
||||
CONCRETE,
|
||||
STD,
|
||||
LLVM,
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
|
||||
#include "concretelang/Conversion/SDFGToStreamEmulator/Pass.h"
|
||||
#include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h"
|
||||
#include "concretelang/Conversion/TFHEKeyNormalization/Pass.h"
|
||||
#include "concretelang/Conversion/TFHEToConcrete/Pass.h"
|
||||
#include "concretelang/Conversion/TracingToCAPI/Pass.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
|
||||
@@ -32,6 +32,13 @@ def TFHEGlobalParametrization : Pass<"tfhe-global-parametrization", "mlir::Modul
|
||||
let dependentDialects = ["mlir::concretelang::TFHE::TFHEDialect"];
|
||||
}
|
||||
|
||||
def TFHEKeyNormalization : Pass<"tfhe-key-normalization", "mlir::ModuleOp"> {
|
||||
let summary = "Ensures the key ids form proper ranges of indices";
|
||||
let constructor = "mlir::concretelang::createTFHEKeyNormalizationPass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::concretelang::TFHE::TFHEDialect"];
|
||||
}
|
||||
|
||||
def TFHEToConcrete : Pass<"tfhe-to-concrete", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the TFHE dialect to Concrete";
|
||||
let description = [{ Lowers operations from the TFHE dialect to Concrete }];
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_TFHEKEYNORMALIZATION_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_TFHEKEYNORMALIZATION_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass that ensures that the ids of the keys form a proper index
|
||||
/// range.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createTFHEKeyNormalizationPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -191,7 +191,8 @@ def Concrete_BootstrapLweTensorOp : Concrete_Op<"bootstrap_lwe_tensor", [Pure]>
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$glweDimension
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$bskIndex
|
||||
);
|
||||
let results = (outs Concrete_LweTensor:$result);
|
||||
}
|
||||
@@ -207,7 +208,8 @@ def Concrete_BootstrapLweBufferOp : Concrete_Op<"bootstrap_lwe_buffer"> {
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$glweDimension
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$bskIndex
|
||||
);
|
||||
}
|
||||
|
||||
@@ -221,7 +223,8 @@ def Concrete_BatchedBootstrapLweTensorOp : Concrete_Op<"batched_bootstrap_lwe_te
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$glweDimension
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$bskIndex
|
||||
);
|
||||
let results = (outs Concrete_BatchLweTensor:$result);
|
||||
}
|
||||
@@ -237,7 +240,8 @@ def Concrete_BatchedBootstrapLweBufferOp : Concrete_Op<"batched_bootstrap_lwe_bu
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$glweDimension
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$bskIndex
|
||||
);
|
||||
}
|
||||
|
||||
@@ -250,7 +254,8 @@ def Concrete_KeySwitchLweTensorOp : Concrete_Op<"keyswitch_lwe_tensor", [Pure]>
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$lwe_dim_in,
|
||||
I32Attr:$lwe_dim_out
|
||||
I32Attr:$lwe_dim_out,
|
||||
I32Attr:$kskIndex
|
||||
);
|
||||
let results = (outs Concrete_LweTensor:$result);
|
||||
}
|
||||
@@ -264,7 +269,8 @@ def Concrete_KeySwitchLweBufferOp : Concrete_Op<"keyswitch_lwe_buffer"> {
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$lwe_dim_in,
|
||||
I32Attr:$lwe_dim_out
|
||||
I32Attr:$lwe_dim_out,
|
||||
I32Attr:$kskIndex
|
||||
);
|
||||
}
|
||||
|
||||
@@ -277,7 +283,8 @@ def Concrete_BatchedKeySwitchLweTensorOp : Concrete_Op<"batched_keyswitch_lwe_te
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$lwe_dim_in,
|
||||
I32Attr:$lwe_dim_out
|
||||
I32Attr:$lwe_dim_out,
|
||||
I32Attr:$kskIndex
|
||||
);
|
||||
let results = (outs Concrete_BatchLweTensor:$result);
|
||||
}
|
||||
@@ -291,7 +298,8 @@ def Concrete_BatchedKeySwitchLweBufferOp : Concrete_Op<"batched_keyswitch_lwe_bu
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$lwe_dim_in,
|
||||
I32Attr:$lwe_dim_out
|
||||
I32Attr:$lwe_dim_out,
|
||||
I32Attr:$kskIndex
|
||||
);
|
||||
}
|
||||
|
||||
@@ -313,7 +321,11 @@ def Concrete_WopPBSCRTLweTensorOp : Concrete_Op<"wop_pbs_crt_lwe_tensor", [Pure]
|
||||
// Circuit bootstrap parameters
|
||||
I32Attr : $circuitBootstrapLevel,
|
||||
I32Attr : $circuitBootstrapBaseLog,
|
||||
I64ArrayAttr:$crtDecomposition
|
||||
I64ArrayAttr:$crtDecomposition,
|
||||
// Key indices
|
||||
I32Attr:$kskIndex,
|
||||
I32Attr:$bskIndex,
|
||||
I32Attr:$pkskIndex
|
||||
);
|
||||
let results = (outs Concrete_LweCRTTensor:$result);
|
||||
}
|
||||
@@ -337,7 +349,11 @@ def Concrete_WopPBSCRTLweBufferOp : Concrete_Op<"wop_pbs_crt_lwe_buffer"> {
|
||||
// Circuit bootstrap parameters
|
||||
I32Attr : $circuitBootstrapLevel,
|
||||
I32Attr : $circuitBootstrapBaseLog,
|
||||
I64ArrayAttr:$crtDecomposition
|
||||
I64ArrayAttr:$crtDecomposition,
|
||||
// Key indices
|
||||
I32Attr:$kskIndex,
|
||||
I32Attr:$bskIndex,
|
||||
I32Attr:$pkskIndex
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -27,10 +27,11 @@ def TFHE_KeyswitchKeyAttr: TFHE_Attr<"GLWEKeyswitchKey", "ksk"> {
|
||||
"mlir::concretelang::TFHE::GLWESecretKey":$inputKey,
|
||||
"mlir::concretelang::TFHE::GLWESecretKey":$outputKey,
|
||||
"int":$levels,
|
||||
"int":$baseLog
|
||||
"int":$baseLog,
|
||||
DefaultValuedParameter<"int", "-1">: $index
|
||||
);
|
||||
|
||||
let assemblyFormat = "`<` $inputKey `,` $outputKey `,` $levels `,` $baseLog `>`";
|
||||
let assemblyFormat = " (`[` $index^ `]`)? `<` $inputKey `,` $outputKey `,` $levels `,` $baseLog `>`";
|
||||
}
|
||||
|
||||
def TFHE_BootstrapKeyAttr: TFHE_Attr<"GLWEBootstrapKey", "bsk"> {
|
||||
@@ -43,10 +44,11 @@ def TFHE_BootstrapKeyAttr: TFHE_Attr<"GLWEBootstrapKey", "bsk"> {
|
||||
"int":$polySize,
|
||||
"int":$glweDim,
|
||||
"int":$levels,
|
||||
"int":$baseLog
|
||||
"int":$baseLog,
|
||||
DefaultValuedParameter<"int", "-1">: $index
|
||||
);
|
||||
|
||||
let assemblyFormat = "`<` $inputKey `,` $outputKey `,` $polySize `,` $glweDim `,` $levels `,` $baseLog `>`";
|
||||
let assemblyFormat = "(`[` $index^ `]`)? `<` $inputKey `,` $outputKey `,` $polySize `,` $glweDim `,` $levels `,` $baseLog `>`";
|
||||
}
|
||||
|
||||
def TFHE_PackingKeyswitchKeyAttr: TFHE_Attr<"GLWEPackingKeyswitchKey", "pksk"> {
|
||||
@@ -58,11 +60,13 @@ def TFHE_PackingKeyswitchKeyAttr: TFHE_Attr<"GLWEPackingKeyswitchKey", "pksk"> {
|
||||
"mlir::concretelang::TFHE::GLWESecretKey":$outputKey,
|
||||
"int" : $outputPolySize,
|
||||
"int" : $inputLweDim,
|
||||
"int" : $glweDim,
|
||||
"int" : $levels,
|
||||
"int" : $baseLog
|
||||
"int" : $baseLog,
|
||||
DefaultValuedParameter<"int", "-1">: $index
|
||||
);
|
||||
|
||||
let assemblyFormat = "`<` $inputKey `,` $outputKey`,` $outputPolySize`,` $inputLweDim `,` $levels `,` $baseLog `>`";
|
||||
let assemblyFormat = " (`[` $index^ `]` )? `<` $inputKey `,` $outputKey`,` $outputPolySize`,` $inputLweDim `,` $glweDim `,` $levels `,` $baseLog `>`";
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -11,48 +11,63 @@
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/IR/DialectImplementation.h>
|
||||
#include <variant>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace TFHE {
|
||||
|
||||
/// A type parameter representing GLWE secret key.
|
||||
///
|
||||
/// A glwe secret key is basically a glwe dimension, a polynomial size, and an
|
||||
/// id that makes it possible to disambiguate potential keys with with same
|
||||
/// parameters.
|
||||
///
|
||||
/// Note that a key can be instantiated to a `none` key, to serve as a
|
||||
/// placeholder in the IR. In this case, none of its data are actually usable
|
||||
/// for lowering to the `Concrete` dialect. Once the
|
||||
/// `TFHEGlobalParameterization` was performed, there should remain no such
|
||||
/// `none` keys in the IR.
|
||||
class GLWESecretKey {
|
||||
public:
|
||||
/// Creates a new none key.
|
||||
GLWESecretKey();
|
||||
/// Create a new key from parameters.
|
||||
GLWESecretKey(int64_t dimension, int64_t polySize, int64_t id);
|
||||
bool operator==(GLWESecretKey other);
|
||||
bool operator==(const GLWESecretKey other) const;
|
||||
bool operator!=(GLWESecretKey other);
|
||||
/// Returns the dimension associated with this key, if the key is not none.
|
||||
mlir::Optional<int64_t> getDimension() const;
|
||||
/// Returns the polynomial size associated with this key, if the key is not
|
||||
/// none.
|
||||
mlir::Optional<int64_t> getPolySize() const;
|
||||
/// Returns the id associated with this key, if the key is not none.
|
||||
mlir::Optional<int64_t> getId() const;
|
||||
/// Returns true if the key was not filled with valid parameters.
|
||||
bool isNotParameterized() const;
|
||||
/// A placeholder.
|
||||
struct GLWESecretKeyNone {};
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
const GLWESecretKeyNone sk);
|
||||
|
||||
private:
|
||||
int64_t dimension;
|
||||
int64_t polySize;
|
||||
int64_t id;
|
||||
// The key was parameterized.
|
||||
struct GLWESecretKeyParameterized {
|
||||
uint64_t dimension;
|
||||
uint64_t polySize;
|
||||
uint64_t identifier;
|
||||
bool operator==(const GLWESecretKeyParameterized other) const;
|
||||
};
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
const GLWESecretKeyParameterized sk);
|
||||
|
||||
// The key was normalized
|
||||
struct GLWESecretKeyNormalized {
|
||||
uint64_t dimension;
|
||||
uint64_t polySize;
|
||||
uint64_t index;
|
||||
bool operator==(const GLWESecretKeyNormalized other) const;
|
||||
};
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
const GLWESecretKeyNormalized sk);
|
||||
|
||||
/// A sum type parameter representing GLWE secret keys in different states.
|
||||
struct GLWESecretKey {
|
||||
std::variant<GLWESecretKeyNone, GLWESecretKeyParameterized,
|
||||
GLWESecretKeyNormalized>
|
||||
inner;
|
||||
|
||||
static GLWESecretKey newNone();
|
||||
static GLWESecretKey newParameterized(uint64_t dimension, uint64_t polySize,
|
||||
uint64_t identifier);
|
||||
static GLWESecretKey newNormalized(uint64_t dimension, uint64_t polySize,
|
||||
uint64_t index);
|
||||
bool operator==(const GLWESecretKey other) const;
|
||||
bool operator!=(const GLWESecretKey other) const;
|
||||
template <typename V> bool is();
|
||||
bool isNone();
|
||||
bool isParameterized();
|
||||
bool isNormalized();
|
||||
template <typename V> std::optional<V> get();
|
||||
std::optional<GLWESecretKeyNone> getNone();
|
||||
std::optional<GLWESecretKeyParameterized> getParameterized();
|
||||
std::optional<GLWESecretKeyNormalized> getNormalized();
|
||||
};
|
||||
|
||||
llvm::hash_code hash_value(const GLWESecretKey &key);
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const GLWESecretKey sk);
|
||||
|
||||
} // namespace TFHE
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -25,13 +25,6 @@ def TFHE_GLWECipherTextType
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let genVerifyDecl = true;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns true if has an unparametrized parameters
|
||||
bool hasUnparametrizedParameters() {
|
||||
return getKey().isNotParameterized();
|
||||
};
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -36,11 +36,11 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,
|
||||
void stream_emulator_make_memref_keyswitch_lwe_u64_process(
|
||||
void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
|
||||
void *context);
|
||||
uint32_t ksk_index, void *context);
|
||||
void stream_emulator_make_memref_bootstrap_lwe_u64_process(
|
||||
void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
|
||||
uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
|
||||
uint32_t output_size, void *context);
|
||||
uint32_t output_size, uint32_t bsk_index, void *context);
|
||||
|
||||
void *stream_emulator_make_uint64_stream(const char *name, stream_type stype);
|
||||
void stream_emulator_put_uint64(void *stream, uint64_t e);
|
||||
|
||||
@@ -82,6 +82,7 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint64_t ct0_size, uint64_t ct0_stride,
|
||||
uint32_t level, uint32_t base_log,
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim,
|
||||
uint32_t ksk_index,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void memref_batched_keyswitch_lwe_u64(
|
||||
@@ -91,7 +92,7 @@ void memref_batched_keyswitch_lwe_u64(
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint32_t level,
|
||||
uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
uint32_t ksk_index, mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void *memref_keyswitch_async_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
@@ -99,17 +100,15 @@ void *memref_keyswitch_async_lwe_u64(
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void memref_bootstrap_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint64_t out_offset, uint64_t out_size,
|
||||
uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset,
|
||||
uint64_t ct0_size, uint64_t ct0_stride,
|
||||
uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size,
|
||||
uint64_t tlu_stride, uint32_t input_lwe_dim,
|
||||
uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
void memref_bootstrap_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void memref_batched_bootstrap_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
@@ -119,7 +118,7 @@ void memref_batched_bootstrap_lwe_u64(
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
|
||||
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
|
||||
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
|
||||
uint32_t level, uint32_t base_log, uint32_t glwe_dim,
|
||||
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void *memref_bootstrap_async_lwe_u64(
|
||||
@@ -129,7 +128,7 @@ void *memref_bootstrap_async_lwe_u64(
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void memref_await_future(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
@@ -163,6 +162,8 @@ void memref_wop_pbs_crt_buffer(
|
||||
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
|
||||
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
|
||||
uint32_t polynomial_size,
|
||||
// Key indices
|
||||
uint32_t ksk_index, uint32_t bsk_index, uint32_t pksk_index,
|
||||
// runtime context that hold evluation keys
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
@@ -183,7 +184,7 @@ void memref_keyswitch_lwe_cuda_u64(
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint32_t level, uint32_t base_log,
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim,
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t ksk_index,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
/// \brief Run bootstrapping on GPU.
|
||||
@@ -197,7 +198,7 @@ void memref_bootstrap_lwe_cuda_u64(
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
// Batched CUDA function //////////////////////////////////////////////////////
|
||||
@@ -209,7 +210,7 @@ void memref_batched_keyswitch_lwe_cuda_u64(
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint32_t level,
|
||||
uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
uint32_t ksk_index, mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void memref_batched_bootstrap_lwe_cuda_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
@@ -219,7 +220,7 @@ void memref_batched_bootstrap_lwe_cuda_u64(
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
|
||||
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
|
||||
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
|
||||
uint32_t level, uint32_t base_log, uint32_t glwe_dim,
|
||||
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
// Tracing ////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Support/Encodings.h"
|
||||
#include "concretelang/Support/V0Parameters.h"
|
||||
|
||||
namespace mlir {
|
||||
@@ -19,9 +20,10 @@ using ::concretelang::clientlib::ChunkInfo;
|
||||
using ::concretelang::clientlib::ClientParameters;
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0FHEContext context, llvm::StringRef functionName,
|
||||
mlir::ModuleOp module, int bitsOfSecurity,
|
||||
llvm::Optional<ChunkInfo> chunkInfo = std::nullopt);
|
||||
createClientParametersFromTFHE(mlir::ModuleOp module,
|
||||
llvm::StringRef functionName, int bitsOfSecurity,
|
||||
encodings::CircuitEncodings encodings,
|
||||
std::optional<CRTDecomposition> maybeCrt);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -7,7 +7,8 @@
|
||||
#define CONCRETELANG_SUPPORT_COMPILER_ENGINE_H
|
||||
|
||||
#include <concretelang/Conversion/Utils/GlobalFHEContext.h>
|
||||
#include <concretelang/Support/V0ClientParameters.h>
|
||||
#include <concretelang/Support/ClientParametersGeneration.h>
|
||||
#include <concretelang/Support/Encodings.h>
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <llvm/Support/Error.h>
|
||||
#include <llvm/Support/SourceMgr.h>
|
||||
@@ -78,6 +79,10 @@ struct CompilationOptions {
|
||||
unsigned int chunkSize;
|
||||
unsigned int chunkWidth;
|
||||
|
||||
/// When compiling from a dialect lower than FHE, one needs to provide
|
||||
/// encodings info manually to allow the client lib to be generated.
|
||||
std::optional<mlir::concretelang::encodings::CircuitEncodings> encodings;
|
||||
|
||||
CompilationOptions()
|
||||
: v0FHEConstraints(std::nullopt), verifyDiagnostics(false),
|
||||
autoParallelize(false), loopParallelize(false), batchTFHEOps(false),
|
||||
@@ -85,7 +90,7 @@ struct CompilationOptions {
|
||||
dataflowParallelize(false), optimizeTFHE(true), emitGPUOps(false),
|
||||
clientParametersFuncName(std::nullopt),
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkIntegers(false),
|
||||
chunkSize(4), chunkWidth(2){};
|
||||
chunkSize(4), chunkWidth(2), encodings(std::nullopt){};
|
||||
|
||||
CompilationOptions(std::string funcname) : CompilationOptions() {
|
||||
clientParametersFuncName = funcname;
|
||||
@@ -204,7 +209,7 @@ public:
|
||||
/// and scf loops
|
||||
FHE_NO_LINALG,
|
||||
|
||||
/// Read sources and lower all FHE operations to TFHE
|
||||
/// Read sources and lower all FHE operations to unparameterized TFHE
|
||||
/// operations
|
||||
TFHE,
|
||||
|
||||
@@ -215,6 +220,10 @@ public:
|
||||
/// Batch TFHE operations
|
||||
BATCHED_TFHE,
|
||||
|
||||
/// Read sources and lower all FHE operations to normalized TFHE
|
||||
/// operations
|
||||
NORMALIZED_TFHE,
|
||||
|
||||
/// Read sources and lower all FHE and TFHE operations to Concrete
|
||||
/// operations
|
||||
CONCRETE,
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_ENCODINGS_H_
|
||||
#define CONCRETELANG_SUPPORT_ENCODINGS_H_
|
||||
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
#include <llvm/ADT/Optional.h>
|
||||
#include <llvm/ADT/STLExtras.h>
|
||||
#include <llvm/Support/Error.h>
|
||||
#include <llvm/Support/JSON.h>
|
||||
#include <llvm/Support/raw_ostream.h>
|
||||
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace encodings {
|
||||
|
||||
/// Represents the encoding of a small (unchunked) `FHE::eint` type.
|
||||
struct EncryptedIntegerScalarEncoding {
|
||||
uint64_t width;
|
||||
bool isSigned;
|
||||
};
|
||||
bool fromJSON(const llvm::json::Value, EncryptedIntegerScalarEncoding &,
|
||||
llvm::json::Path);
|
||||
llvm::json::Value toJSON(const EncryptedIntegerScalarEncoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
EncryptedIntegerScalarEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encoding of a big (chunked) `FHE::eint` type.
|
||||
struct EncryptedChunkedIntegerScalarEncoding {
|
||||
uint64_t width;
|
||||
bool isSigned;
|
||||
uint64_t chunkSize;
|
||||
uint64_t chunkWidth;
|
||||
};
|
||||
bool fromJSON(const llvm::json::Value, EncryptedChunkedIntegerScalarEncoding &,
|
||||
llvm::json::Path);
|
||||
llvm::json::Value toJSON(const EncryptedChunkedIntegerScalarEncoding &);
|
||||
static inline llvm::raw_ostream &
|
||||
operator<<(llvm::raw_ostream &OS, EncryptedChunkedIntegerScalarEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encoding of a `FHE::ebool` type.
|
||||
struct EncryptedBoolScalarEncoding {};
|
||||
bool fromJSON(const llvm::json::Value, EncryptedBoolScalarEncoding &,
|
||||
llvm::json::Path);
|
||||
llvm::json::Value toJSON(const EncryptedBoolScalarEncoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
EncryptedBoolScalarEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encoding of a builtin integer type.
|
||||
struct PlaintextScalarEncoding {
|
||||
uint64_t width;
|
||||
};
|
||||
bool fromJSON(const llvm::json::Value, PlaintextScalarEncoding &,
|
||||
llvm::json::Path);
|
||||
llvm::json::Value toJSON(const PlaintextScalarEncoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
PlaintextScalarEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encoding of a builtin index type.
|
||||
struct IndexScalarEncoding {};
|
||||
bool fromJSON(const llvm::json::Value, IndexScalarEncoding &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const IndexScalarEncoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
IndexScalarEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encoding of a scalar value.
|
||||
using ScalarEncoding = std::variant<
|
||||
EncryptedIntegerScalarEncoding, EncryptedChunkedIntegerScalarEncoding,
|
||||
EncryptedBoolScalarEncoding, PlaintextScalarEncoding, IndexScalarEncoding>;
|
||||
bool fromJSON(const llvm::json::Value, ScalarEncoding &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const ScalarEncoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
ScalarEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encoding of a tensor value.
|
||||
struct TensorEncoding {
|
||||
ScalarEncoding scalarEncoding;
|
||||
};
|
||||
bool fromJSON(const llvm::json::Value, TensorEncoding &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const TensorEncoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
TensorEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encoding of either an input or output value of a circuit.
|
||||
using Encoding = std::variant<TensorEncoding, ScalarEncoding>;
|
||||
bool fromJSON(const llvm::json::Value, Encoding &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const Encoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, Encoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encodings of a circuit.
|
||||
struct CircuitEncodings {
|
||||
std::vector<Encoding> inputEncodings;
|
||||
std::vector<Encoding> outputEncodings;
|
||||
};
|
||||
bool fromJSON(const llvm::json::Value, CircuitEncodings &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const CircuitEncodings &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
CircuitEncodings e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
llvm::Expected<CircuitEncodings> getCircuitEncodings(
|
||||
llvm::StringRef functionName, mlir::ModuleOp module,
|
||||
std::optional<::concretelang::clientlib::ChunkInfo> maybeChunkInfo);
|
||||
|
||||
} // namespace encodings
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -67,6 +67,10 @@ mlir::LogicalResult batchTFHE(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
normalizeTFHEKeys(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_TFHECIRCUITKEYS_H_
|
||||
#define CONCRETELANG_SUPPORT_TFHECIRCUITKEYS_H_
|
||||
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace TFHE {
|
||||
|
||||
struct TFHECircuitKeys {
|
||||
llvm::SmallVector<TFHE::GLWESecretKey, 10> secretKeys;
|
||||
llvm::SmallVector<TFHE::GLWEBootstrapKeyAttr, 10> bootstrapKeys;
|
||||
llvm::SmallVector<TFHE::GLWEKeyswitchKeyAttr, 10> keyswitchKeys;
|
||||
llvm::SmallVector<TFHE::GLWEPackingKeyswitchKeyAttr, 10> packingKeyswitchKeys;
|
||||
|
||||
std::optional<uint64_t> getSecretKeyIndex(TFHE::GLWESecretKey key);
|
||||
std::optional<uint64_t> getKeyswitchKeyIndex(TFHE::GLWEKeyswitchKeyAttr key);
|
||||
std::optional<uint64_t> getBootstrapKeyIndex(TFHE::GLWEBootstrapKeyAttr key);
|
||||
std::optional<uint64_t>
|
||||
getPackingKeyswitchKeyIndex(TFHE::GLWEPackingKeyswitchKeyAttr key);
|
||||
};
|
||||
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const TFHECircuitKeys cks);
|
||||
|
||||
TFHECircuitKeys extractCircuitKeys(mlir::ModuleOp moduleOp);
|
||||
|
||||
} // namespace TFHE
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
#endif
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <concretelang/Runtime/context.h>
|
||||
#include <concretelang/ServerLib/ServerLambda.h>
|
||||
#include <concretelang/Support/Error.h>
|
||||
#include <llvm/ADT/SmallVector.h>
|
||||
|
||||
namespace concretelang {
|
||||
|
||||
@@ -112,6 +113,17 @@ invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters,
|
||||
return clientlib::PublicResult::fromBuffers(clientParameters,
|
||||
std::move(buffers));
|
||||
}
|
||||
|
||||
template <typename V, unsigned int N>
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
const llvm::SmallVector<V, N> vect) {
|
||||
OS << "[";
|
||||
for (auto v : vect) {
|
||||
OS << v << ",";
|
||||
}
|
||||
OS << "]";
|
||||
return OS;
|
||||
}
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_VARIANTS_H_
|
||||
#define CONCRETELANG_SUPPORT_VARIANTS_H_
|
||||
|
||||
template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
|
||||
template <class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
|
||||
|
||||
#endif
|
||||
@@ -131,6 +131,12 @@ llvm::Expected<mlir::concretelang::CompilerEngine::
|
||||
return mlir::concretelang::CompilerEngine::Target::FHE;
|
||||
case TFHE:
|
||||
return mlir::concretelang::CompilerEngine::Target::TFHE;
|
||||
case PARAMETRIZED_TFHE:
|
||||
return mlir::concretelang::CompilerEngine::Target::PARAMETRIZED_TFHE;
|
||||
case NORMALIZED_TFHE:
|
||||
return mlir::concretelang::CompilerEngine::Target::NORMALIZED_TFHE;
|
||||
case BATCHED_TFHE:
|
||||
return mlir::concretelang::CompilerEngine::Target::BATCHED_TFHE;
|
||||
case CONCRETE:
|
||||
return mlir::concretelang::CompilerEngine::Target::CONCRETE;
|
||||
case STD:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
add_subdirectory(FHEToTFHEScalar)
|
||||
add_subdirectory(FHEToTFHECrt)
|
||||
add_subdirectory(TFHEGlobalParametrization)
|
||||
add_subdirectory(TFHEKeyNormalization)
|
||||
add_subdirectory(TFHEToConcrete)
|
||||
add_subdirectory(FHETensorOpsToLinalg)
|
||||
add_subdirectory(TracingToCAPI)
|
||||
|
||||
@@ -106,18 +106,20 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
{memref1DType, memref1DType}, {});
|
||||
} else if (funcName == memref_keyswitch_lwe_u64 ||
|
||||
funcName == memref_keyswitch_lwe_cuda_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
funcType =
|
||||
mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType, i32Type, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_bootstrap_lwe_u64 ||
|
||||
funcName == memref_bootstrap_lwe_cuda_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType,
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i32Type, contextType},
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_keyswitch_async_lwe_u64) {
|
||||
// Todo Answer this question: Isn't it dead ?
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {memref1DType, memref1DType, contextType},
|
||||
{futureType});
|
||||
@@ -125,20 +127,21 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType,
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i32Type, contextType},
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{futureType});
|
||||
} else if (funcName == memref_batched_keyswitch_lwe_u64 ||
|
||||
funcName == memref_batched_keyswitch_lwe_cuda_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref2DType, memref2DType, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
funcType =
|
||||
mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref2DType, memref2DType, i32Type, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_batched_bootstrap_lwe_u64 ||
|
||||
funcName == memref_batched_bootstrap_lwe_cuda_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref2DType, memref2DType,
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i32Type, contextType},
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_await_future) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
@@ -171,6 +174,9 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
contextType,
|
||||
},
|
||||
{});
|
||||
@@ -273,6 +279,9 @@ void keyswitchAddOperands(KeySwitchOp op,
|
||||
// lwe_dim_out
|
||||
operands.push_back(
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getLweDimOutAttr()));
|
||||
// ksk_index
|
||||
operands.push_back(
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getKskIndexAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
@@ -296,6 +305,9 @@ void bootstrapAddOperands(BootstrapOp op,
|
||||
// glwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.getGlweDimensionAttr()));
|
||||
// bsk_index
|
||||
operands.push_back(
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getBskIndexAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
@@ -354,6 +366,15 @@ void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op,
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.getPackingKeySwitchoutputPolynomialSizeAttr()));
|
||||
|
||||
// ksk_index
|
||||
operands.push_back(
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getKskIndexAttr()));
|
||||
// bsk_index
|
||||
operands.push_back(
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getBskIndexAttr()));
|
||||
// pksk_index
|
||||
operands.push_back(
|
||||
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getPkskIndexAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
|
||||
@@ -580,12 +580,13 @@ struct ApplyLookupTableEintOpPattern
|
||||
op.getLoc(), converter->convertType(op.getType()), adaptor.getA(),
|
||||
newLut,
|
||||
TFHE::GLWEKeyswitchKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(),
|
||||
TFHE::GLWESecretKey(), -1, -1),
|
||||
TFHE::GLWESecretKey(), -1, -1, -1),
|
||||
TFHE::GLWEBootstrapKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(),
|
||||
TFHE::GLWESecretKey(), -1, -1, -1, -1),
|
||||
TFHE::GLWESecretKey(), -1, -1, -1, -1,
|
||||
-1),
|
||||
TFHE::GLWEPackingKeyswitchKeyAttr::get(
|
||||
op.getContext(), TFHE::GLWESecretKey(), TFHE::GLWESecretKey(), -1,
|
||||
-1, -1, -1),
|
||||
-1, -1, -1, -1, -1),
|
||||
rewriter.getI64ArrayAttr({}), rewriter.getI32IntegerAttr(-1),
|
||||
rewriter.getI32IntegerAttr(-1));
|
||||
|
||||
|
||||
@@ -369,13 +369,14 @@ struct ApplyLookupTableEintOpPattern
|
||||
op.getLoc(), getTypeConverter()->convertType(adaptor.getA().getType()),
|
||||
input,
|
||||
TFHE::GLWEKeyswitchKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(),
|
||||
TFHE::GLWESecretKey(), -1, -1));
|
||||
TFHE::GLWESecretKey(), -1, -1, -1));
|
||||
|
||||
// Insert bootstrap
|
||||
rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), ksOp, newLut,
|
||||
TFHE::GLWEBootstrapKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(),
|
||||
TFHE::GLWESecretKey(), -1, -1, -1, -1));
|
||||
TFHE::GLWESecretKey(), -1, -1, -1, -1,
|
||||
-1));
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -557,12 +558,12 @@ struct RoundEintOpPattern : public ScalarOpPattern<FHE::RoundEintOp> {
|
||||
op.getLoc(), truncationInputTy, shiftedRotatedInput,
|
||||
TFHE::GLWEKeyswitchKeyAttr::get(op->getContext(),
|
||||
TFHE::GLWESecretKey(),
|
||||
TFHE::GLWESecretKey(), -1, -1));
|
||||
TFHE::GLWESecretKey(), -1, -1, -1));
|
||||
mlir::Value bootstrapped = rewriter.create<TFHE::BootstrapGLWEOp>(
|
||||
op.getLoc(), truncationInputTy, keyswitched, lut,
|
||||
TFHE::GLWEBootstrapKeyAttr::get(
|
||||
op->getContext(), TFHE::GLWESecretKey(), TFHE::GLWESecretKey(),
|
||||
-1, -1, -1, -1));
|
||||
-1, -1, -1, -1, -1));
|
||||
|
||||
//------------------------------------------------------------- CORRECTION
|
||||
// The correction is performed to achieve our right shift semantic.
|
||||
|
||||
@@ -200,6 +200,9 @@ struct LowerSDFGMakeProcess
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(),
|
||||
mpOp->getAttrOfType<mlir::IntegerAttr>("output_size")));
|
||||
// ksk_index
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("kskIndex")));
|
||||
// context
|
||||
operands.push_back(getContextArgument(mpOp));
|
||||
break;
|
||||
@@ -226,6 +229,9 @@ struct LowerSDFGMakeProcess
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(),
|
||||
mpOp->getAttrOfType<mlir::IntegerAttr>("output_size")));
|
||||
// bsk_index
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("bskIndex")));
|
||||
// context
|
||||
operands.push_back(getContextArgument(mpOp));
|
||||
break;
|
||||
|
||||
@@ -35,7 +35,7 @@ struct TFHEGlobalParametrizationPass
|
||||
using mlir::concretelang::TFHE::GLWECipherTextType;
|
||||
|
||||
/// TFHEGlobalParametrizationTypeConverter is a TypeConverter that transform
|
||||
/// `TFHE.glwe<sk[?]>` to
|
||||
/// `TFHE.glwe<sk?>` to
|
||||
/// `TFHE.glwe<sk[id]<glweDimension,polynomialSize>>`
|
||||
class TFHEGlobalParametrizationTypeConverter : public mlir::TypeConverter {
|
||||
|
||||
@@ -44,11 +44,16 @@ public:
|
||||
mlir::concretelang::V0Parameter &cryptoParameters)
|
||||
: cryptoParameters(cryptoParameters) {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion(
|
||||
[&](GLWECipherTextType type) { return this->glweInterPBSType(type); });
|
||||
addConversion([&](GLWECipherTextType type) {
|
||||
if (type.getKey().isNone()) {
|
||||
return this->glweInterPBSType(type);
|
||||
} else {
|
||||
return type;
|
||||
}
|
||||
});
|
||||
addConversion([&](mlir::RankedTensorType type) {
|
||||
auto glwe = type.getElementType().dyn_cast_or_null<GLWECipherTextType>();
|
||||
if (glwe == nullptr) {
|
||||
if (glwe == nullptr || !glwe.getKey().isNone()) {
|
||||
return (mlir::Type)(type);
|
||||
}
|
||||
mlir::Type r = mlir::RankedTensorType::get(type.getShape(),
|
||||
@@ -70,11 +75,9 @@ public:
|
||||
TFHE::GLWESecretKey getInterPBSKey() {
|
||||
auto dimension = cryptoParameters.getNBigLweDimension();
|
||||
auto polynomialSize = 1;
|
||||
// Warning, for now we use hardcoded ids. Later on, we expect the optimizer
|
||||
// to give the id.
|
||||
auto id = 1;
|
||||
return mlir::concretelang::TFHE::GLWESecretKey(dimension, polynomialSize,
|
||||
id);
|
||||
auto identifier = 0;
|
||||
return mlir::concretelang::TFHE::GLWESecretKey::newParameterized(
|
||||
dimension, polynomialSize, identifier);
|
||||
}
|
||||
|
||||
TFHE::GLWECipherTextType glweInterPBSType(GLWECipherTextType &type) {
|
||||
@@ -84,11 +87,9 @@ public:
|
||||
TFHE::GLWESecretKey getIntraPBSKey() {
|
||||
auto dimension = cryptoParameters.nSmall;
|
||||
auto polynomialSize = 1;
|
||||
// Warning, for now we use hardcoded ids. Later on, we expect the optimizer
|
||||
// to give the id.
|
||||
auto id = 3;
|
||||
return mlir::concretelang::TFHE::GLWESecretKey(dimension, polynomialSize,
|
||||
id);
|
||||
auto identifier = 1;
|
||||
return mlir::concretelang::TFHE::GLWESecretKey::newParameterized(
|
||||
dimension, polynomialSize, identifier);
|
||||
}
|
||||
|
||||
TFHE::GLWECipherTextType glweIntraPBSType(GLWECipherTextType &type) {
|
||||
@@ -121,7 +122,7 @@ struct KeySwitchGLWEOpPattern
|
||||
auto newOutputKey = converter.getIntraPBSKey();
|
||||
auto keyswitchKey = TFHE::GLWEKeyswitchKeyAttr::get(
|
||||
ksOp->getContext(), newInputKey, newOutputKey, cryptoParameters.ksLevel,
|
||||
cryptoParameters.ksLogBase);
|
||||
cryptoParameters.ksLogBase, -1);
|
||||
auto newOp = rewriter.replaceOpWithNewOp<TFHE::KeySwitchGLWEOp>(
|
||||
ksOp, newOutputTy, ksOp.getCiphertext(), keyswitchKey);
|
||||
rewriter.startRootUpdate(newOp);
|
||||
@@ -159,7 +160,7 @@ struct BootstrapGLWEOpPattern
|
||||
auto bootstrapKey = TFHE::GLWEBootstrapKeyAttr::get(
|
||||
bsOp->getContext(), newInputKey, newOutputKey,
|
||||
cryptoParameters.getPolynomialSize(), cryptoParameters.glweDimension,
|
||||
cryptoParameters.brLevel, cryptoParameters.brLogBase);
|
||||
cryptoParameters.brLevel, cryptoParameters.brLogBase, -1);
|
||||
auto newOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
|
||||
bsOp, newOutputTy, bsOp.getCiphertext(), bsOp.getLookupTable(),
|
||||
bootstrapKey);
|
||||
@@ -196,19 +197,20 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
|
||||
auto intraKey = converter.getIntraPBSKey();
|
||||
auto keyswitchKey = TFHE::GLWEKeyswitchKeyAttr::get(
|
||||
wopPBSOp->getContext(), interKey, intraKey, cryptoParameters.ksLevel,
|
||||
cryptoParameters.ksLogBase);
|
||||
cryptoParameters.ksLogBase, -1);
|
||||
auto bootstrapKey = TFHE::GLWEBootstrapKeyAttr::get(
|
||||
wopPBSOp->getContext(), intraKey, interKey,
|
||||
cryptoParameters.getPolynomialSize(), cryptoParameters.glweDimension,
|
||||
cryptoParameters.brLevel, cryptoParameters.brLogBase);
|
||||
cryptoParameters.brLevel, cryptoParameters.brLogBase, -1);
|
||||
auto packingKeyswitchKey = TFHE::GLWEPackingKeyswitchKeyAttr::get(
|
||||
wopPBSOp->getContext(), interKey, interKey,
|
||||
cryptoParameters.largeInteger->wopPBS.packingKeySwitch
|
||||
.outputPolynomialSize,
|
||||
cryptoParameters.largeInteger->wopPBS.packingKeySwitch
|
||||
.inputLweDimension,
|
||||
cryptoParameters.glweDimension,
|
||||
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.level,
|
||||
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.baseLog);
|
||||
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.baseLog, -1);
|
||||
auto newOp = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
|
||||
wopPBSOp, newOutputType, wopPBSOp.getCiphertexts(),
|
||||
wopPBSOp.getLookupTable(), keyswitchKey, bootstrapKey,
|
||||
@@ -290,21 +292,29 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
||||
patterns, converter);
|
||||
|
||||
// Parametrize keyswitch bootstrap
|
||||
// Parametrize keyswitch
|
||||
target.addLegalOp<mlir::arith::ConstantOp>();
|
||||
patterns.add<KeySwitchGLWEOpPattern>(&getContext(), converter,
|
||||
cryptoParameters);
|
||||
target.addDynamicallyLegalOp<TFHE::KeySwitchGLWEOp>(
|
||||
[&](TFHE::KeySwitchGLWEOp op) {
|
||||
return !op.getKey().getInputKey().isNotParameterized() &&
|
||||
!op.getKey().getOutputKey().isNotParameterized() &&
|
||||
op.getKey().getBaseLog() != 0 && op.getKey().getLevels() != 0;
|
||||
return op.getKeyAttr().getInputKey().isParameterized() &&
|
||||
op.getKeyAttr().getOutputKey().isParameterized() &&
|
||||
op.getKeyAttr().getBaseLog() != -1 &&
|
||||
op.getKeyAttr().getLevels() != -1;
|
||||
});
|
||||
|
||||
// Parametrize bootstrap
|
||||
patterns.add<BootstrapGLWEOpPattern>(&getContext(), converter,
|
||||
cryptoParameters);
|
||||
target.addDynamicallyLegalOp<TFHE::BootstrapGLWEOp>(
|
||||
[&](TFHE::BootstrapGLWEOp op) {
|
||||
return converter.isLegal(op->getResultTypes());
|
||||
return op.getKeyAttr().getInputKey().isParameterized() &&
|
||||
op.getKeyAttr().getOutputKey().isParameterized() &&
|
||||
op.getKeyAttr().getLevels() != -1 &&
|
||||
op.getKeyAttr().getBaseLog() != -1 &&
|
||||
op.getKeyAttr().getGlweDim() != -1 &&
|
||||
op.getKeyAttr().getPolySize() != -1;
|
||||
});
|
||||
|
||||
// Parametrize wop pbs
|
||||
@@ -312,11 +322,20 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
cryptoParameters);
|
||||
target.addDynamicallyLegalOp<TFHE::WopPBSGLWEOp>(
|
||||
[&](TFHE::WopPBSGLWEOp op) {
|
||||
return !op.getType()
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType()
|
||||
.cast<TFHE::GLWECipherTextType>()
|
||||
.hasUnparametrizedParameters();
|
||||
return op.getKskAttr().getInputKey().isParameterized() &&
|
||||
op.getKskAttr().getOutputKey().isParameterized() &&
|
||||
op.getKskAttr().getBaseLog() != -1 &&
|
||||
op.getKskAttr().getLevels() != -1 &&
|
||||
op.getBskAttr().getInputKey().isParameterized() &&
|
||||
op.getBskAttr().getOutputKey().isParameterized() &&
|
||||
op.getBskAttr().getLevels() != -1 &&
|
||||
op.getBskAttr().getBaseLog() != -1 &&
|
||||
op.getBskAttr().getGlweDim() != -1 &&
|
||||
op.getBskAttr().getPolySize() != -1 &&
|
||||
op.getPkskAttr().getInputKey().isParameterized() &&
|
||||
op.getPkskAttr().getOutputKey().isParameterized() &&
|
||||
op.getPkskAttr().getLevels() != -1 &&
|
||||
op.getPkskAttr().getBaseLog() != -1;
|
||||
});
|
||||
|
||||
// Add all patterns to convert TFHE types
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
add_mlir_dialect_library(
|
||||
TFHEKeyNormalization
|
||||
TFHEKeyNormalization.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/TFHE
|
||||
DEPENDS
|
||||
TFHEDialect
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms)
|
||||
|
||||
target_link_libraries(TFHEKeyNormalization PUBLIC MLIRIR)
|
||||
@@ -0,0 +1,411 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include <llvm/ADT/SmallSet.h>
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
|
||||
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
|
||||
#include "concretelang/Support/Constants.h"
|
||||
#include "concretelang/Support/TFHECircuitKeys.h"
|
||||
#include <llvm/Support/raw_ostream.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
#include <variant>
|
||||
|
||||
namespace TFHE = mlir::concretelang::TFHE;
|
||||
|
||||
using mlir::concretelang::TFHE::GLWECipherTextType;
|
||||
|
||||
namespace conversion {
|
||||
|
||||
class KeyConverter {
|
||||
|
||||
public:
|
||||
KeyConverter(mlir::concretelang::TFHE::TFHECircuitKeys &circuitKeys)
|
||||
: circuitKeys(circuitKeys){};
|
||||
|
||||
TFHE::GLWESecretKey convertSecretKey(TFHE::GLWESecretKey sk) {
|
||||
auto parameterizedKey = sk.getParameterized().value();
|
||||
return TFHE::GLWESecretKey::newNormalized(
|
||||
parameterizedKey.dimension, parameterizedKey.polySize,
|
||||
circuitKeys.getSecretKeyIndex(sk).value());
|
||||
}
|
||||
|
||||
TFHE::GLWEBootstrapKeyAttr
|
||||
convertBootstrapKey(TFHE::GLWEBootstrapKeyAttr bsk) {
|
||||
return TFHE::GLWEBootstrapKeyAttr::get(
|
||||
bsk.getContext(), convertSecretKey(bsk.getInputKey()),
|
||||
convertSecretKey(bsk.getOutputKey()), bsk.getPolySize(),
|
||||
bsk.getGlweDim(), bsk.getLevels(), bsk.getBaseLog(),
|
||||
circuitKeys.getBootstrapKeyIndex(bsk).value());
|
||||
}
|
||||
|
||||
TFHE::GLWEKeyswitchKeyAttr
|
||||
convertKeyswitchKey(TFHE::GLWEKeyswitchKeyAttr ksk) {
|
||||
return TFHE::GLWEKeyswitchKeyAttr::get(
|
||||
ksk.getContext(), convertSecretKey(ksk.getInputKey()),
|
||||
convertSecretKey(ksk.getOutputKey()), ksk.getLevels(), ksk.getBaseLog(),
|
||||
circuitKeys.getKeyswitchKeyIndex(ksk).value());
|
||||
}
|
||||
|
||||
TFHE::GLWEPackingKeyswitchKeyAttr
|
||||
convertPackingKeyswitchKey(TFHE::GLWEPackingKeyswitchKeyAttr pksk) {
|
||||
return TFHE::GLWEPackingKeyswitchKeyAttr::get(
|
||||
pksk.getContext(), convertSecretKey(pksk.getInputKey()),
|
||||
convertSecretKey(pksk.getOutputKey()), pksk.getOutputPolySize(),
|
||||
pksk.getInputLweDim(), pksk.getGlweDim(), pksk.getLevels(),
|
||||
pksk.getBaseLog(),
|
||||
circuitKeys.getPackingKeyswitchKeyIndex(pksk).value());
|
||||
}
|
||||
|
||||
private:
|
||||
mlir::concretelang::TFHE::TFHECircuitKeys circuitKeys;
|
||||
};
|
||||
|
||||
class TypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
TypeConverter(KeyConverter &keyConverter) : keyConverter(keyConverter) {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([&](GLWECipherTextType type) {
|
||||
auto key = type.getKey();
|
||||
if (key.isParameterized()) {
|
||||
return GLWECipherTextType::get(type.getContext(),
|
||||
keyConverter.convertSecretKey(key));
|
||||
} else {
|
||||
return type;
|
||||
}
|
||||
});
|
||||
addConversion([&](mlir::RankedTensorType type) {
|
||||
mlir::Type r = mlir::RankedTensorType::get(
|
||||
type.getShape(), this->convertType(type.getElementType()));
|
||||
return r;
|
||||
});
|
||||
addConversion([&](mlir::concretelang::RT::FutureType type) {
|
||||
return mlir::concretelang::RT::FutureType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
|
||||
.getElementType()));
|
||||
});
|
||||
addConversion([&](mlir::concretelang::RT::PointerType type) {
|
||||
return mlir::concretelang::RT::PointerType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
|
||||
.getElementType()));
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
KeyConverter keyConverter;
|
||||
};
|
||||
} // namespace conversion
|
||||
|
||||
namespace patterns {
|
||||
struct KeySwitchGLWEOpPattern
|
||||
: public mlir::OpRewritePattern<TFHE::KeySwitchGLWEOp> {
|
||||
KeySwitchGLWEOpPattern(mlir::MLIRContext *context,
|
||||
conversion::TypeConverter &typeConverter,
|
||||
conversion::KeyConverter &keyConverter,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<TFHE::KeySwitchGLWEOp>(context, benefit),
|
||||
keyConverter(keyConverter), typeConverter(typeConverter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TFHE::KeySwitchGLWEOp ksOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto newInputTy = typeConverter.convertType(ksOp.getCiphertext().getType())
|
||||
.cast<GLWECipherTextType>();
|
||||
auto newOutputTy = typeConverter.convertType(ksOp.getResult().getType());
|
||||
auto newKeyswitchKey = keyConverter.convertKeyswitchKey(ksOp.getKeyAttr());
|
||||
auto newOp = rewriter.replaceOpWithNewOp<TFHE::KeySwitchGLWEOp>(
|
||||
ksOp, newOutputTy, ksOp.getCiphertext(), newKeyswitchKey);
|
||||
rewriter.startRootUpdate(newOp);
|
||||
newOp.getCiphertext().setType(newInputTy);
|
||||
rewriter.finalizeRootUpdate(newOp);
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
conversion::KeyConverter &keyConverter;
|
||||
conversion::TypeConverter &typeConverter;
|
||||
};
|
||||
|
||||
struct BootstrapGLWEOpPattern
|
||||
: public mlir::OpRewritePattern<TFHE::BootstrapGLWEOp> {
|
||||
BootstrapGLWEOpPattern(mlir::MLIRContext *context,
|
||||
conversion::TypeConverter &typeConverter,
|
||||
conversion::KeyConverter &keyConverter,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<TFHE::BootstrapGLWEOp>(context, benefit),
|
||||
keyConverter(keyConverter), typeConverter(typeConverter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TFHE::BootstrapGLWEOp bsOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto newInputTy = typeConverter.convertType(bsOp.getCiphertext().getType())
|
||||
.cast<GLWECipherTextType>();
|
||||
auto newOutputTy = typeConverter.convertType(bsOp.getResult().getType());
|
||||
auto newBootstrapKey = keyConverter.convertBootstrapKey(bsOp.getKeyAttr());
|
||||
auto newOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
|
||||
bsOp, newOutputTy, bsOp.getCiphertext(), bsOp.getLookupTable(),
|
||||
newBootstrapKey);
|
||||
rewriter.startRootUpdate(newOp);
|
||||
newOp.getCiphertext().setType(newInputTy.cast<GLWECipherTextType>());
|
||||
rewriter.finalizeRootUpdate(newOp);
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
conversion::KeyConverter &keyConverter;
|
||||
conversion::TypeConverter &typeConverter;
|
||||
};
|
||||
|
||||
struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
|
||||
WopPBSGLWEOpPattern(mlir::MLIRContext *context,
|
||||
conversion::TypeConverter &typeConverter,
|
||||
conversion::KeyConverter &keyConverter,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<TFHE::WopPBSGLWEOp>(context, benefit),
|
||||
keyConverter(keyConverter), typeConverter(typeConverter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TFHE::WopPBSGLWEOp wopPBSOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto newInputTy =
|
||||
typeConverter.convertType(wopPBSOp.getCiphertexts().getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
auto newOutputType = typeConverter.convertType(wopPBSOp.getType());
|
||||
auto newKeyswitchKey =
|
||||
keyConverter.convertKeyswitchKey(wopPBSOp.getKskAttr());
|
||||
auto newBootstrapKey =
|
||||
keyConverter.convertBootstrapKey(wopPBSOp.getBskAttr());
|
||||
auto newPackingKeyswitchKey =
|
||||
keyConverter.convertPackingKeyswitchKey(wopPBSOp.getPkskAttr());
|
||||
auto newOp = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
|
||||
wopPBSOp, newOutputType, wopPBSOp.getCiphertexts(),
|
||||
wopPBSOp.getLookupTable(), newKeyswitchKey, newBootstrapKey,
|
||||
newPackingKeyswitchKey, wopPBSOp.getCrtDecompositionAttr(),
|
||||
wopPBSOp.getCbsLevelsAttr(), wopPBSOp.getCbsBaseLogAttr());
|
||||
rewriter.startRootUpdate(newOp);
|
||||
newOp.getCiphertexts().setType(newInputTy);
|
||||
rewriter.finalizeRootUpdate(newOp);
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
conversion::KeyConverter &keyConverter;
|
||||
conversion::TypeConverter &typeConverter;
|
||||
};
|
||||
} // namespace patterns
|
||||
|
||||
namespace {
|
||||
struct TFHEKeyNormalizationPass
|
||||
: public TFHEKeyNormalizationBase<TFHEKeyNormalizationPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
|
||||
template <typename Op>
|
||||
void populateWithTFHEOpTypeConversionPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
|
||||
mlir::TypeConverter &typeConverter) {
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<Op>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
|
||||
target.addDynamicallyLegalOp<Op>(
|
||||
[&](Op op) { return typeConverter.isLegal(op->getResultTypes()); });
|
||||
}
|
||||
|
||||
/// Populate the RewritePatternSet with all patterns that rewrite Concrete
|
||||
/// operators to the corresponding function call to the `Concrete C API`.
|
||||
void populateWithTFHEOpTypeConversionPatterns(
|
||||
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
|
||||
mlir::TypeConverter &typeConverter) {
|
||||
populateWithTFHEOpTypeConversionPattern<mlir::concretelang::TFHE::ZeroGLWEOp>(
|
||||
patterns, target, typeConverter);
|
||||
populateWithTFHEOpTypeConversionPattern<
|
||||
mlir::concretelang::TFHE::ZeroTensorGLWEOp>(patterns, target,
|
||||
typeConverter);
|
||||
populateWithTFHEOpTypeConversionPattern<
|
||||
mlir::concretelang::TFHE::AddGLWEIntOp>(patterns, target, typeConverter);
|
||||
populateWithTFHEOpTypeConversionPattern<mlir::concretelang::TFHE::AddGLWEOp>(
|
||||
patterns, target, typeConverter);
|
||||
populateWithTFHEOpTypeConversionPattern<
|
||||
mlir::concretelang::TFHE::SubGLWEIntOp>(patterns, target, typeConverter);
|
||||
populateWithTFHEOpTypeConversionPattern<mlir::concretelang::TFHE::NegGLWEOp>(
|
||||
patterns, target, typeConverter);
|
||||
populateWithTFHEOpTypeConversionPattern<
|
||||
mlir::concretelang::TFHE::MulGLWEIntOp>(patterns, target, typeConverter);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void TFHEKeyNormalizationPass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
auto circuitKeys = TFHE::extractCircuitKeys(op);
|
||||
auto keyConverter = conversion::KeyConverter(circuitKeys);
|
||||
auto typeConverter = conversion::TypeConverter(keyConverter);
|
||||
|
||||
// Parametrize
|
||||
{
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
// function signature
|
||||
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
|
||||
[&](mlir::func::FuncOp funcOp) {
|
||||
return typeConverter.isSignatureLegal(funcOp.getFunctionType()) &&
|
||||
typeConverter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
|
||||
[&](mlir::func::ConstantOp op) {
|
||||
return FunctionConstantOpConversion<
|
||||
conversion::TypeConverter>::isLegal(op, typeConverter);
|
||||
});
|
||||
patterns.add<FunctionConstantOpConversion<conversion::TypeConverter>>(
|
||||
&getContext(), typeConverter);
|
||||
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
||||
patterns, typeConverter);
|
||||
|
||||
// Parametrize keyswitch
|
||||
target.addLegalOp<mlir::arith::ConstantOp>();
|
||||
patterns.add<patterns::KeySwitchGLWEOpPattern>(&getContext(), typeConverter,
|
||||
keyConverter);
|
||||
target.addDynamicallyLegalOp<TFHE::KeySwitchGLWEOp>(
|
||||
[&](TFHE::KeySwitchGLWEOp op) {
|
||||
return op.getKeyAttr().getInputKey().isNormalized() &&
|
||||
op.getKeyAttr().getOutputKey().isNormalized() &&
|
||||
op.getKeyAttr().getIndex() != -1;
|
||||
});
|
||||
|
||||
// Parametrize bootstrap
|
||||
patterns.add<patterns::BootstrapGLWEOpPattern>(&getContext(), typeConverter,
|
||||
keyConverter);
|
||||
target.addDynamicallyLegalOp<TFHE::BootstrapGLWEOp>(
|
||||
[&](TFHE::BootstrapGLWEOp op) {
|
||||
return op.getKeyAttr().getInputKey().isNormalized() &&
|
||||
op.getKeyAttr().getOutputKey().isNormalized() &&
|
||||
op.getKeyAttr().getIndex() != -1;
|
||||
});
|
||||
|
||||
// Parametrize wop pbs
|
||||
patterns.add<patterns::WopPBSGLWEOpPattern>(&getContext(), typeConverter,
|
||||
keyConverter);
|
||||
target.addDynamicallyLegalOp<TFHE::WopPBSGLWEOp>(
|
||||
[&](TFHE::WopPBSGLWEOp op) {
|
||||
return op.getKskAttr().getInputKey().isNormalized() &&
|
||||
op.getKskAttr().getOutputKey().isNormalized() &&
|
||||
op.getKskAttr().getIndex() != -1 &&
|
||||
op.getBskAttr().getInputKey().isNormalized() &&
|
||||
op.getBskAttr().getOutputKey().isNormalized() &&
|
||||
op.getBskAttr().getIndex() != -1 &&
|
||||
op.getPkskAttr().getInputKey().isNormalized() &&
|
||||
op.getPkskAttr().getOutputKey().isNormalized() &&
|
||||
op.getPkskAttr().getIndex() != -1;
|
||||
});
|
||||
|
||||
// Add all patterns to convert TFHE types
|
||||
populateWithTFHEOpTypeConversionPatterns(patterns, target, typeConverter);
|
||||
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::bufferization::AllocTensorOp>>(&getContext(), typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::bufferization::AllocTensorOp>(target, typeConverter);
|
||||
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
|
||||
conversion::TypeConverter>>(
|
||||
&getContext(), typeConverter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
conversion::TypeConverter>>(
|
||||
&getContext(), typeConverter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
|
||||
conversion::TypeConverter>>(
|
||||
&getContext(), typeConverter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::func::ReturnOp,
|
||||
conversion::TypeConverter>>(
|
||||
&getContext(), typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(
|
||||
target, typeConverter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::YieldOp,
|
||||
conversion::TypeConverter>>(
|
||||
&getContext(), typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::linalg::YieldOp>(
|
||||
target, typeConverter);
|
||||
|
||||
mlir::concretelang::populateWithTensorTypeConverterPatterns(
|
||||
patterns, target, typeConverter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::Tracing::TraceCiphertextOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::AwaitFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::Tracing::TraceCiphertextOp>(target, typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>(target, typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::AwaitFutureOp>(target, typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>(target, typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target,
|
||||
typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
|
||||
target, typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target,
|
||||
typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>(target, typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target,
|
||||
typeConverter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createTFHEKeyNormalizationPass() {
|
||||
return std::make_unique<TFHEKeyNormalizationPass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -46,11 +46,11 @@ public:
|
||||
TFHEToConcreteTypeConverter() {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([&](GLWECipherTextType type) {
|
||||
assert(!type.getKey().isNotParameterized());
|
||||
assert(type.getKey().getPolySize().value() == 1 &&
|
||||
assert(type.getKey().isNormalized() && "keys should be normalized");
|
||||
assert(type.getKey().getNormalized().value().polySize == 1 &&
|
||||
"converter doesn't support polynomialSize > 1");
|
||||
llvm::SmallVector<int64_t, 2> shape;
|
||||
shape.push_back(type.getKey().getDimension().value() + 1);
|
||||
shape.push_back(type.getKey().getNormalized().value().dimension + 1);
|
||||
return mlir::RankedTensorType::get(
|
||||
shape, mlir::IntegerType::get(type.getContext(), 64));
|
||||
});
|
||||
@@ -62,8 +62,8 @@ public:
|
||||
mlir::SmallVector<int64_t> newShape;
|
||||
newShape.reserve(type.getShape().size() + 1);
|
||||
newShape.append(type.getShape().begin(), type.getShape().end());
|
||||
assert(!glwe.getKey().isNotParameterized());
|
||||
newShape.push_back(glwe.getKey().getDimension().value() + 1);
|
||||
assert(glwe.getKey().isNormalized());
|
||||
newShape.push_back(glwe.getKey().getNormalized().value().dimension + 1);
|
||||
mlir::Type r = mlir::RankedTensorType::get(
|
||||
newShape, mlir::IntegerType::get(type.getContext(), 64));
|
||||
return r;
|
||||
@@ -128,12 +128,14 @@ struct BootstrapGLWEOpPattern
|
||||
auto glweDimension = adaptor.getKey().getGlweDim();
|
||||
auto levels = adaptor.getKey().getLevels();
|
||||
auto baseLog = adaptor.getKey().getBaseLog();
|
||||
auto inputLweDimension = inputType.getKey().getDimension().value();
|
||||
auto inputLweDimension =
|
||||
inputType.getKey().getNormalized().value().dimension;
|
||||
auto bskIndex = bsOp.getKeyAttr().getIndex();
|
||||
|
||||
rewriter.replaceOpWithNewOp<Concrete::BootstrapLweTensorOp>(
|
||||
bsOp, this->getTypeConverter()->convertType(resultType),
|
||||
adaptor.getCiphertext(), adaptor.getLookupTable(), inputLweDimension,
|
||||
polySize, levels, baseLog, glweDimension);
|
||||
polySize, levels, baseLog, glweDimension, bskIndex);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -164,12 +166,16 @@ struct WopPBSGLWEOpPattern
|
||||
auto pksOutputPolySize = adaptor.getPksk().getOutputPolySize();
|
||||
auto crtDecomposition = adaptor.getCrtDecompositionAttr();
|
||||
auto resultType = op.getType();
|
||||
auto kskIndex = op.getKskAttr().getIndex();
|
||||
auto bskIndex = op.getBskAttr().getIndex();
|
||||
auto pkskIndex = op.getPkskAttr().getIndex();
|
||||
|
||||
rewriter.replaceOpWithNewOp<Concrete::WopPBSCRTLweTensorOp>(
|
||||
op, this->getTypeConverter()->convertType(resultType),
|
||||
adaptor.getCiphertexts(), adaptor.getLookupTable(), bsLevels, bsBaseLog,
|
||||
ksLevels, ksBaseLog, pksInputLweDim, pksOutputPolySize, pksLevels,
|
||||
pksBaseLog, cbsLevels, cbsBaseLog, crtDecomposition);
|
||||
pksBaseLog, cbsLevels, cbsBaseLog, crtDecomposition, kskIndex, bskIndex,
|
||||
pkskIndex);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -199,12 +205,14 @@ struct BatchedBootstrapGLWEOpPattern
|
||||
auto glweDimension = adaptor.getKey().getGlweDim();
|
||||
auto levels = adaptor.getKey().getLevels();
|
||||
auto baseLog = adaptor.getKey().getBaseLog();
|
||||
auto inputLweDimension = inputElementType.getKey().getDimension().value();
|
||||
auto inputLweDimension =
|
||||
inputElementType.getKey().getNormalized().value().dimension;
|
||||
auto bskIndex = adaptor.getKey().getIndex();
|
||||
|
||||
rewriter.replaceOpWithNewOp<Concrete::BatchedBootstrapLweTensorOp>(
|
||||
bbsOp, this->getTypeConverter()->convertType(bbsOp.getType()),
|
||||
adaptor.getCiphertexts(), adaptor.getLookupTable(), inputLweDimension,
|
||||
polySize, levels, baseLog, glweDimension);
|
||||
polySize, levels, baseLog, glweDimension, bskIndex);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -231,12 +239,14 @@ struct KeySwitchGLWEOpPattern
|
||||
|
||||
auto levels = adaptor.getKey().getLevels();
|
||||
auto baseLog = adaptor.getKey().getBaseLog();
|
||||
auto inputDim = inputType.getKey().getDimension().value();
|
||||
auto outputDim = resultType.getKey().getDimension().value();
|
||||
auto inputDim = inputType.getKey().getNormalized().value().dimension;
|
||||
auto outputDim = resultType.getKey().getNormalized().value().dimension;
|
||||
auto kskIndex = ksOp.getKeyAttr().getIndex();
|
||||
|
||||
rewriter.replaceOpWithNewOp<Concrete::KeySwitchLweTensorOp>(
|
||||
ksOp, this->getTypeConverter()->convertType(resultType),
|
||||
adaptor.getCiphertext(), levels, baseLog, inputDim, outputDim);
|
||||
adaptor.getCiphertext(), levels, baseLog, inputDim, outputDim,
|
||||
kskIndex);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -270,12 +280,15 @@ struct BatchedKeySwitchGLWEOpPattern
|
||||
|
||||
auto levels = adaptor.getKey().getLevels();
|
||||
auto baseLog = adaptor.getKey().getBaseLog();
|
||||
auto inputDim = inputElementType.getKey().getDimension().value();
|
||||
auto outputDim = resultElementType.getKey().getDimension().value();
|
||||
auto inputDim = inputElementType.getKey().getNormalized().value().dimension;
|
||||
auto outputDim =
|
||||
resultElementType.getKey().getNormalized().value().dimension;
|
||||
auto kskIndex = adaptor.getKey().getIndex();
|
||||
|
||||
rewriter.replaceOpWithNewOp<Concrete::BatchedKeySwitchLweTensorOp>(
|
||||
bksOp, this->getTypeConverter()->convertType(bksOp.getType()),
|
||||
adaptor.getCiphertexts(), levels, baseLog, inputDim, outputDim);
|
||||
adaptor.getCiphertexts(), levels, baseLog, inputDim, outputDim,
|
||||
kskIndex);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.cpp.inc"
|
||||
@@ -15,6 +16,7 @@
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOpsDialect.cpp.inc"
|
||||
|
||||
#include "concretelang/Support/Constants.h"
|
||||
#include "concretelang/Support/Variants.h"
|
||||
|
||||
using namespace mlir::concretelang::TFHE;
|
||||
|
||||
@@ -36,17 +38,32 @@ void TFHEDialect::initialize() {
|
||||
}
|
||||
|
||||
/// Verify that GLWE parameter are consistant
|
||||
/// - The bits parameter is 64 (we support only this for v0)
|
||||
::mlir::LogicalResult GLWECipherTextType::verify(
|
||||
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
|
||||
GLWESecretKey key) {
|
||||
if (!key.isNotParameterized() && key.getPolySize().value() == 0) {
|
||||
emitError() << "GLWE key has zero poly size.";
|
||||
return ::mlir::failure();
|
||||
}
|
||||
if (!key.isNotParameterized() && key.getDimension().value() == 0) {
|
||||
emitError() << "GLWE key has zero dimension.";
|
||||
return ::mlir::failure();
|
||||
}
|
||||
return ::mlir::success();
|
||||
return std::visit(
|
||||
overloaded{[](GLWESecretKeyNone sk) { return mlir::success(); },
|
||||
[&](GLWESecretKeyParameterized sk) {
|
||||
if (sk.dimension == 0) {
|
||||
emitError() << "GLWE key has zero dimension.";
|
||||
return ::mlir::failure();
|
||||
}
|
||||
if (sk.polySize == 0) {
|
||||
emitError() << "GLWE key has zero poly size.";
|
||||
return ::mlir::failure();
|
||||
}
|
||||
return mlir::success();
|
||||
},
|
||||
[&](GLWESecretKeyNormalized sk) {
|
||||
if (sk.dimension == 0) {
|
||||
emitError() << "GLWE key has zero dimension.";
|
||||
return ::mlir::failure();
|
||||
}
|
||||
if (sk.polySize == 0) {
|
||||
emitError() << "GLWE key has zero poly size.";
|
||||
return ::mlir::failure();
|
||||
}
|
||||
return mlir::success();
|
||||
}},
|
||||
key.inner);
|
||||
}
|
||||
|
||||
@@ -4,70 +4,131 @@
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
|
||||
#include "concretelang/Support/Variants.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include <variant>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace TFHE {
|
||||
|
||||
GLWESecretKey::GLWESecretKey() {
|
||||
dimension = -1;
|
||||
polySize = -1;
|
||||
id = -1;
|
||||
}
|
||||
|
||||
GLWESecretKey::GLWESecretKey(int64_t dimension, int64_t polySize, int64_t id) {
|
||||
assert(dimension > 0);
|
||||
assert(polySize > 0);
|
||||
assert(id > 0);
|
||||
this->dimension = dimension;
|
||||
this->polySize = polySize;
|
||||
this->id = id;
|
||||
}
|
||||
|
||||
bool GLWESecretKey::operator==(GLWESecretKey other) {
|
||||
return this->id == other.id && this->dimension == other.dimension &&
|
||||
bool GLWESecretKeyParameterized::operator==(
|
||||
const GLWESecretKeyParameterized other) const {
|
||||
return this->dimension == other.dimension &&
|
||||
this->identifier == other.identifier &&
|
||||
this->polySize == other.polySize;
|
||||
}
|
||||
|
||||
bool GLWESecretKeyNormalized::operator==(
|
||||
const GLWESecretKeyNormalized other) const {
|
||||
return this->dimension == other.dimension && this->index == other.index &&
|
||||
this->polySize == other.polySize;
|
||||
}
|
||||
|
||||
GLWESecretKey GLWESecretKey::newNone() {
|
||||
return GLWESecretKey{GLWESecretKeyNone{}};
|
||||
}
|
||||
|
||||
GLWESecretKey GLWESecretKey::newNormalized(uint64_t dimension,
|
||||
uint64_t polySize, uint64_t index) {
|
||||
return GLWESecretKey{GLWESecretKeyNormalized{dimension, polySize, index}};
|
||||
}
|
||||
|
||||
GLWESecretKey GLWESecretKey::newParameterized(uint64_t dimension,
|
||||
uint64_t polySize,
|
||||
uint64_t identifier) {
|
||||
return GLWESecretKey{
|
||||
GLWESecretKeyParameterized{dimension, polySize, identifier}};
|
||||
}
|
||||
|
||||
bool GLWESecretKey::operator==(const GLWESecretKey other) const {
|
||||
return this->id == other.id && this->dimension == other.dimension &&
|
||||
this->polySize == other.polySize;
|
||||
return std::visit(
|
||||
overloaded{
|
||||
[](GLWESecretKeyNone thisK, GLWESecretKeyNone otherK) {
|
||||
return true;
|
||||
},
|
||||
[](GLWESecretKeyNormalized thisK, GLWESecretKeyNormalized otherK) {
|
||||
return thisK == otherK;
|
||||
},
|
||||
[](GLWESecretKeyParameterized thisK,
|
||||
GLWESecretKeyParameterized otherK) { return thisK == otherK; },
|
||||
[](auto _thisK, auto _otherK) { return false; }},
|
||||
this->inner, other.inner);
|
||||
}
|
||||
|
||||
bool GLWESecretKey::operator!=(GLWESecretKey other) {
|
||||
return this->id != other.id || this->dimension != other.dimension ||
|
||||
this->polySize != other.polySize;
|
||||
bool GLWESecretKey::operator!=(const GLWESecretKey other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
std::optional<int64_t> GLWESecretKey::getDimension() const {
|
||||
if (this->isNotParameterized()) {
|
||||
return std::nullopt;
|
||||
template <typename V> bool GLWESecretKey::is() {
|
||||
return std::holds_alternative<V>(this->inner);
|
||||
}
|
||||
|
||||
bool GLWESecretKey::isNone() { return is<GLWESecretKeyNone>(); }
|
||||
bool GLWESecretKey::isParameterized() {
|
||||
return is<GLWESecretKeyParameterized>();
|
||||
}
|
||||
bool GLWESecretKey::isNormalized() { return is<GLWESecretKeyNormalized>(); }
|
||||
|
||||
template <typename V> std::optional<V> GLWESecretKey::get() {
|
||||
if (this->is<V>()) {
|
||||
return std::get<V>(this->inner);
|
||||
} else {
|
||||
return this->dimension;
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<int64_t> GLWESecretKey::getPolySize() const {
|
||||
if (this->isNotParameterized()) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return this->polySize;
|
||||
}
|
||||
std::optional<GLWESecretKeyNone> GLWESecretKey::getNone() {
|
||||
return get<GLWESecretKeyNone>();
|
||||
}
|
||||
|
||||
mlir::Optional<int64_t> GLWESecretKey::getId() const {
|
||||
if (this->isNotParameterized()) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return this->id;
|
||||
}
|
||||
std::optional<GLWESecretKeyParameterized> GLWESecretKey::getParameterized() {
|
||||
return get<GLWESecretKeyParameterized>();
|
||||
}
|
||||
std::optional<GLWESecretKeyNormalized> GLWESecretKey::getNormalized() {
|
||||
return get<GLWESecretKeyNormalized>();
|
||||
}
|
||||
|
||||
bool GLWESecretKey::isNotParameterized() const { return id <= 0; }
|
||||
|
||||
llvm::hash_code hash_value(const GLWESecretKey &key) {
|
||||
return llvm::hash_combine("GlweSecretKey", key.getDimension(),
|
||||
key.getPolySize(), key.getId());
|
||||
return std::visit(overloaded{[](GLWESecretKeyNone sk) {
|
||||
return llvm::hash_value("GlweSecretKeyNone");
|
||||
},
|
||||
[](GLWESecretKeyParameterized sk) {
|
||||
return llvm::hash_combine(
|
||||
"GlweSecretKeyParameterized", sk.dimension,
|
||||
sk.polySize, sk.identifier);
|
||||
},
|
||||
[](GLWESecretKeyNormalized sk) {
|
||||
return llvm::hash_combine(
|
||||
"GlweSecretKeyNormalized", sk.dimension,
|
||||
sk.polySize, sk.index);
|
||||
}},
|
||||
key.inner);
|
||||
}
|
||||
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, GLWESecretKeyNone sk) {
|
||||
OS << "sk?";
|
||||
return OS;
|
||||
}
|
||||
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
GLWESecretKeyParameterized sk) {
|
||||
OS << "sk<" << sk.identifier << "," << sk.polySize << "," << sk.dimension
|
||||
<< ">";
|
||||
return OS;
|
||||
}
|
||||
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
GLWESecretKeyNormalized sk) {
|
||||
OS << "sk[" << sk.index << "]<" << sk.polySize << "," << sk.dimension << ">";
|
||||
return OS;
|
||||
}
|
||||
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, GLWESecretKey sk) {
|
||||
std::visit(overloaded{[&](GLWESecretKeyNone nsk) { OS << nsk; },
|
||||
[&](GLWESecretKeyNormalized nsk) { OS << nsk; },
|
||||
[&](GLWESecretKeyParameterized nsk) { OS << nsk; }},
|
||||
sk.inner);
|
||||
return OS;
|
||||
}
|
||||
|
||||
} // namespace TFHE
|
||||
@@ -77,37 +138,45 @@ llvm::hash_code hash_value(const GLWESecretKey &key) {
|
||||
namespace mlir {
|
||||
AsmPrinter &operator<<(AsmPrinter &p,
|
||||
mlir::concretelang::TFHE::GLWESecretKey key) {
|
||||
if (key.isNotParameterized()) {
|
||||
p << "sk[?]";
|
||||
} else {
|
||||
p << "sk[" << key.getId() << "]<" << key.getPolySize().value() << ","
|
||||
<< key.getDimension().value() << ">";
|
||||
}
|
||||
p.getStream() << key;
|
||||
return p;
|
||||
}
|
||||
|
||||
FailureOr<mlir::concretelang::TFHE::GLWESecretKey>
|
||||
FieldParser<mlir::concretelang::TFHE::GLWESecretKey>::parse(AsmParser &parser) {
|
||||
int64_t dimension = -1, polySize = -1, id = -1;
|
||||
if (parser.parseKeyword("sk") || parser.parseLSquare()) {
|
||||
uint64_t dimension = -1, polySize = -1, id = -1;
|
||||
|
||||
if (parser.parseKeyword("sk")) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto maybeId = parser.parseOptionalInteger(id);
|
||||
if (maybeId.has_value()) {
|
||||
if (maybeId.value() || parser.parseRSquare() || parser.parseLess() ||
|
||||
|
||||
if (parser.parseOptionalQuestion().succeeded()) { // Parsing none key
|
||||
return mlir::concretelang::TFHE::GLWESecretKey::newNone();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalLSquare().succeeded()) { // Parsing normalized key
|
||||
if (parser.parseInteger(id) || parser.parseRSquare() ||
|
||||
parser.parseLess() || parser.parseInteger(polySize) ||
|
||||
parser.parseComma() || parser.parseInteger(dimension) ||
|
||||
parser.parseGreater()) {
|
||||
return mlir::failure();
|
||||
} else {
|
||||
return mlir::concretelang::TFHE::GLWESecretKey::newNormalized(
|
||||
dimension, polySize, id);
|
||||
}
|
||||
}
|
||||
|
||||
if (parser.parseOptionalLess().succeeded()) { // Parsing parameterized key
|
||||
if (parser.parseInteger(id) || parser.parseComma() ||
|
||||
parser.parseInteger(polySize) || parser.parseComma() ||
|
||||
parser.parseInteger(dimension) || parser.parseGreater()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
} else {
|
||||
if (parser.parseQuestion() || parser.parseRSquare()) {
|
||||
return mlir::failure();
|
||||
} else {
|
||||
return mlir::concretelang::TFHE::GLWESecretKey::newParameterized(
|
||||
dimension, polySize, id);
|
||||
}
|
||||
}
|
||||
if (id <= 0) {
|
||||
return mlir::concretelang::TFHE::GLWESecretKey();
|
||||
} else {
|
||||
return mlir::concretelang::TFHE::GLWESecretKey(dimension, polySize, id);
|
||||
}
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
@@ -70,6 +70,8 @@ struct Process {
|
||||
Param glwe_dim;
|
||||
Param precision;
|
||||
Param output_size;
|
||||
Param ksk_index;
|
||||
Param bsk_index;
|
||||
Context ctx;
|
||||
void (*fun)(Process *);
|
||||
};
|
||||
@@ -102,7 +104,7 @@ void memref_keyswitch_lwe_u64_process(Process *p) {
|
||||
out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0],
|
||||
ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
|
||||
p->level.val, p->base_log.val, p->input_lwe_dim.val,
|
||||
p->output_lwe_dim.val, p->ctx.val);
|
||||
p->output_lwe_dim.val, p->ksk_index.val, p->ctx.val);
|
||||
(p->output_streams[0]).memref_stream->put(out);
|
||||
}
|
||||
delete p;
|
||||
@@ -123,7 +125,7 @@ void memref_bootstrap_lwe_u64_process(Process *p) {
|
||||
ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
|
||||
tlu.allocated, tlu.aligned, tlu.offset, tlu.sizes[0], tlu.strides[0],
|
||||
p->input_lwe_dim.val, p->poly_size.val, p->level.val, p->base_log.val,
|
||||
p->glwe_dim.val, p->ctx.val);
|
||||
p->glwe_dim.val, p->bsk_index.val, p->ctx.val);
|
||||
(p->output_streams[0]).memref_stream->put(out);
|
||||
}
|
||||
delete p;
|
||||
@@ -278,7 +280,7 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,
|
||||
void stream_emulator_make_memref_keyswitch_lwe_u64_process(
|
||||
void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
|
||||
void *context) {
|
||||
uint32_t ksk_index, void *context) {
|
||||
mlir::concretelang::stream_emulator::Process *p =
|
||||
new mlir::concretelang::stream_emulator::Process;
|
||||
p->input_streams.push_back(
|
||||
@@ -292,6 +294,7 @@ void stream_emulator_make_memref_keyswitch_lwe_u64_process(
|
||||
p->input_lwe_dim.val = input_lwe_dim;
|
||||
p->output_lwe_dim.val = output_lwe_dim;
|
||||
p->output_size.val = output_size;
|
||||
p->ksk_index.val = ksk_index;
|
||||
p->ctx.val = (mlir::concretelang::RuntimeContext *)context;
|
||||
p->fun =
|
||||
mlir::concretelang::stream_emulator::memref_keyswitch_lwe_u64_process;
|
||||
@@ -302,7 +305,7 @@ void stream_emulator_make_memref_keyswitch_lwe_u64_process(
|
||||
void stream_emulator_make_memref_bootstrap_lwe_u64_process(
|
||||
void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
|
||||
uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
|
||||
uint32_t output_size, void *context) {
|
||||
uint32_t output_size, uint32_t bsk_index, void *context) {
|
||||
mlir::concretelang::stream_emulator::Process *p =
|
||||
new mlir::concretelang::stream_emulator::Process;
|
||||
p->input_streams.push_back(
|
||||
@@ -320,6 +323,7 @@ void stream_emulator_make_memref_bootstrap_lwe_u64_process(
|
||||
p->base_log.val = base_log;
|
||||
p->glwe_dim.val = glwe_dim;
|
||||
p->output_size.val = output_size;
|
||||
p->bsk_index.val = bsk_index;
|
||||
p->ctx.val = (mlir::concretelang::RuntimeContext *)context;
|
||||
p->fun =
|
||||
mlir::concretelang::stream_emulator::memref_bootstrap_lwe_u64_process;
|
||||
|
||||
@@ -68,7 +68,7 @@ void memref_keyswitch_lwe_cuda_u64(
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint32_t level, uint32_t base_log,
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim,
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t ksk_index,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
assert(out_stride == 1);
|
||||
assert(ct0_stride == 1);
|
||||
@@ -78,7 +78,7 @@ void memref_keyswitch_lwe_cuda_u64(
|
||||
// Output 1D memref as 2D memref
|
||||
ct0_allocated, ct0_aligned, ct0_offset, 1, ct0_size, ct0_size, ct0_stride,
|
||||
// Keyswitch additional arguments
|
||||
level, base_log, input_lwe_dim, output_lwe_dim, context);
|
||||
level, base_log, input_lwe_dim, output_lwe_dim, ksk_index, context);
|
||||
}
|
||||
|
||||
void memref_bootstrap_lwe_cuda_u64(
|
||||
@@ -88,7 +88,7 @@ void memref_bootstrap_lwe_cuda_u64(
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
memref_batched_bootstrap_lwe_cuda_u64(
|
||||
// Output 1D memref as 2D memref
|
||||
@@ -98,7 +98,7 @@ void memref_bootstrap_lwe_cuda_u64(
|
||||
// Table lookup memref
|
||||
tlu_allocated, tlu_aligned, tlu_offset, tlu_size, tlu_stride,
|
||||
// Bootstrap additional arguments
|
||||
input_lwe_dim, poly_size, level, base_log, glwe_dim, context);
|
||||
input_lwe_dim, poly_size, level, base_log, glwe_dim, bsk_index, context);
|
||||
}
|
||||
|
||||
// Batched CUDA function //////////////////////////////////////////////////////
|
||||
@@ -110,7 +110,8 @@ void memref_batched_keyswitch_lwe_cuda_u64(
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint32_t level,
|
||||
uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
uint32_t ksk_index, mlir::concretelang::RuntimeContext *context) {
|
||||
assert(ksk_index == 0 && "multiple ksk is not yet implemented on GPU");
|
||||
assert(out_size0 == ct0_size0);
|
||||
assert(out_size1 == output_lwe_dim + 1);
|
||||
assert(ct0_size1 == input_lwe_dim + 1);
|
||||
@@ -154,8 +155,9 @@ void memref_batched_bootstrap_lwe_cuda_u64(
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
|
||||
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
|
||||
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
|
||||
uint32_t level, uint32_t base_log, uint32_t glwe_dim,
|
||||
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
assert(bsk_index == 0 && "multiple bsk is not yet implemented on GPU");
|
||||
assert(out_size0 == ct0_size0);
|
||||
assert(out_size1 == glwe_dim * poly_size + 1);
|
||||
// TODO: Multi GPU
|
||||
@@ -495,16 +497,19 @@ void memref_negate_lwe_ciphertext_u64(
|
||||
out_aligned + out_offset, ct0_aligned + ct0_offset, lwe_dimension);
|
||||
}
|
||||
|
||||
void memref_keyswitch_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint32_t decomposition_level_count,
|
||||
uint32_t decomposition_base_log, uint32_t input_dimension,
|
||||
uint32_t output_dimension, mlir::concretelang::RuntimeContext *context) {
|
||||
void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint64_t out_offset, uint64_t out_size,
|
||||
uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset,
|
||||
uint64_t ct0_size, uint64_t ct0_stride,
|
||||
uint32_t decomposition_level_count,
|
||||
uint32_t decomposition_base_log,
|
||||
uint32_t input_dimension,
|
||||
uint32_t output_dimension, uint32_t ksk_index,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
assert(out_stride == 1 && ct0_stride == 1);
|
||||
// Get keyswitch key - TODO Give a non hardcoded keyID
|
||||
const uint64_t *keyswitch_key = context->keyswitch_key_buffer(0);
|
||||
// Get keyswitch key
|
||||
const uint64_t *keyswitch_key = context->keyswitch_key_buffer(ksk_index);
|
||||
// Get stack parameter
|
||||
concrete_cpu_keyswitch_lwe_ciphertext_u64(
|
||||
out_aligned + out_offset, ct0_aligned + ct0_offset, keyswitch_key,
|
||||
@@ -519,13 +524,13 @@ void memref_batched_keyswitch_lwe_u64(
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint32_t level,
|
||||
uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
uint32_t ksk_index, mlir::concretelang::RuntimeContext *context) {
|
||||
for (size_t i = 0; i < ct0_size0; i++) {
|
||||
memref_keyswitch_lwe_u64(
|
||||
out_allocated + i * out_size1, out_aligned + i * out_size1, out_offset,
|
||||
out_size1, out_stride1, ct0_allocated + i * ct0_size1,
|
||||
ct0_aligned + i * ct0_size1, ct0_offset, ct0_size1, ct0_stride1, level,
|
||||
base_log, input_lwe_dim, output_lwe_dim, context);
|
||||
base_log, input_lwe_dim, output_lwe_dim, ksk_index, context);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -537,7 +542,8 @@ void memref_bootstrap_lwe_u64(
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t decomposition_level_count, uint32_t decomposition_base_log,
|
||||
uint32_t glwe_dimension, mlir::concretelang::RuntimeContext *context) {
|
||||
uint32_t glwe_dimension, uint32_t bsk_index,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
|
||||
uint64_t glwe_ct_size = polynomial_size * (glwe_dimension + 1);
|
||||
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t));
|
||||
@@ -551,10 +557,9 @@ void memref_bootstrap_lwe_u64(
|
||||
glwe_ct[polynomial_size * glwe_dimension + i] = tlu[i];
|
||||
}
|
||||
|
||||
// Get fourrier bootstrap key - TODO Give a non hardcoded keyID
|
||||
size_t keyId = 0;
|
||||
const auto &fft = context->fft(keyId);
|
||||
auto bootstrap_key = context->fourier_bootstrap_key_buffer(keyId);
|
||||
// Get fourrier bootstrap key
|
||||
const auto &fft = context->fft(bsk_index);
|
||||
auto bootstrap_key = context->fourier_bootstrap_key_buffer(bsk_index);
|
||||
// Get stack parameter
|
||||
size_t scratch_size;
|
||||
size_t scratch_align;
|
||||
@@ -582,7 +587,7 @@ void memref_batched_bootstrap_lwe_u64(
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
|
||||
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
|
||||
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
|
||||
uint32_t level, uint32_t base_log, uint32_t glwe_dim,
|
||||
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
|
||||
for (size_t i = 0; i < out_size0; i++) {
|
||||
@@ -591,7 +596,7 @@ void memref_batched_bootstrap_lwe_u64(
|
||||
out_size1, out_stride1, ct0_allocated, ct0_aligned + i * ct0_size1,
|
||||
ct0_offset, ct0_size1, ct0_stride1, tlu_allocated, tlu_aligned,
|
||||
tlu_offset, tlu_size, tlu_stride, input_lwe_dim, poly_size, level,
|
||||
base_log, glwe_dim, context);
|
||||
base_log, glwe_dim, bsk_index, context);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -617,10 +622,12 @@ void memref_wop_pbs_crt_buffer(
|
||||
uint64_t crt_decomp_offset, uint64_t crt_decomp_size,
|
||||
uint64_t crt_decomp_stride,
|
||||
// Additional crypto parameters
|
||||
uint32_t lwe_small_size, uint32_t cbs_level_count, uint32_t cbs_base_log,
|
||||
uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log,
|
||||
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
|
||||
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
|
||||
uint32_t polynomial_size,
|
||||
// Key Indices,
|
||||
uint32_t ksk_index, uint32_t bsk_index, uint32_t pksk_index,
|
||||
// runtime context that hold evluation keys
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
|
||||
@@ -635,7 +642,7 @@ void memref_wop_pbs_crt_buffer(
|
||||
// Check for the size S
|
||||
assert(out_size_1 == in_size_1);
|
||||
|
||||
uint64_t lwe_small_dim = lwe_small_size - 1;
|
||||
uint64_t lwe_small_size = lwe_small_dim + 1;
|
||||
|
||||
assert(out_size_1 == in_size_1);
|
||||
uint64_t lwe_big_size = in_size_1;
|
||||
@@ -672,13 +679,9 @@ void memref_wop_pbs_crt_buffer(
|
||||
std::vector<uint64_t> in_copy(first_ciphertext, first_ciphertext + copy_size);
|
||||
// Extraction of each bit for each block
|
||||
|
||||
size_t fftKeyId = 0;
|
||||
const auto &fft = context->fft(fftKeyId);
|
||||
size_t bskKeyId = 0;
|
||||
auto bootstrap_key = context->fourier_bootstrap_key_buffer(bskKeyId);
|
||||
|
||||
size_t kskKeyId = 0;
|
||||
auto keyswicth_key = context->keyswitch_key_buffer(kskKeyId);
|
||||
const auto &fft = context->fft(bsk_index);
|
||||
auto bootstrap_key = context->fourier_bootstrap_key_buffer(bsk_index);
|
||||
auto keyswicth_key = context->keyswitch_key_buffer(ksk_index);
|
||||
|
||||
for (int64_t i = crt_decomp_size - 1, extract_bits_output_offset = 0; i >= 0;
|
||||
extract_bits_output_offset += number_of_bits_per_block[i--]) {
|
||||
@@ -731,8 +734,7 @@ void memref_wop_pbs_crt_buffer(
|
||||
|
||||
auto *scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size);
|
||||
|
||||
size_t fpkskKeyId = 0;
|
||||
auto fp_keyswicth_key = context->fp_keyswitch_key_buffer(fpkskKeyId);
|
||||
auto fp_keyswicth_key = context->fp_keyswitch_key_buffer(pksk_index);
|
||||
|
||||
concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64(
|
||||
out_aligned + out_offset, extract_bits_output_buffer,
|
||||
|
||||
@@ -4,10 +4,12 @@ add_mlir_library(
|
||||
Jit.cpp
|
||||
CompilationFeedback.cpp
|
||||
CompilerEngine.cpp
|
||||
TFHECircuitKeys.cpp
|
||||
Encodings.cpp
|
||||
JITSupport.cpp
|
||||
LambdaArgument.cpp
|
||||
V0Parameters.cpp
|
||||
V0ClientParameters.cpp
|
||||
ClientParametersGeneration.cpp
|
||||
logging.cpp
|
||||
Jit.cpp
|
||||
LLVMEmitFile.cpp
|
||||
|
||||
@@ -0,0 +1,378 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
#include <cassert>
|
||||
#include <llvm/ADT/SmallVector.h>
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <unordered_set>
|
||||
#include <variant>
|
||||
|
||||
#include <llvm/ADT/Optional.h>
|
||||
#include <llvm/ADT/STLExtras.h>
|
||||
#include <llvm/ADT/SmallSet.h>
|
||||
#include <llvm/Support/Error.h>
|
||||
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
|
||||
#include "concrete/curves.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Support/Encodings.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/TFHECircuitKeys.h"
|
||||
#include "concretelang/Support/Variants.h"
|
||||
#include "llvm/Config/abi-breaking.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
using ::concretelang::clientlib::ChunkInfo;
|
||||
using ::concretelang::clientlib::CircuitGate;
|
||||
using ::concretelang::clientlib::ClientParameters;
|
||||
using ::concretelang::clientlib::Encoding;
|
||||
using ::concretelang::clientlib::EncryptionGate;
|
||||
using ::concretelang::clientlib::LweSecretKeyID;
|
||||
using ::concretelang::clientlib::Precision;
|
||||
using ::concretelang::clientlib::Variance;
|
||||
|
||||
const auto keyFormat = concrete::BINARY;
|
||||
|
||||
llvm::Expected<CircuitGate>
|
||||
generateGate(mlir::Type type, encodings::Encoding encoding,
|
||||
concrete::SecurityCurve curve,
|
||||
std::optional<CRTDecomposition> maybeCrt) {
|
||||
auto scalarVisitor = overloaded{
|
||||
[&](encodings::EncryptedIntegerScalarEncoding enc)
|
||||
-> llvm::Expected<CircuitGate> {
|
||||
TFHE::GLWESecretKeyNormalized normKey;
|
||||
if (type.isa<RankedTensorType>()) {
|
||||
normKey = type.cast<RankedTensorType>()
|
||||
.getElementType()
|
||||
.cast<TFHE::GLWECipherTextType>()
|
||||
.getKey()
|
||||
.getNormalized()
|
||||
.value();
|
||||
} else {
|
||||
normKey = type.cast<TFHE::GLWECipherTextType>()
|
||||
.getKey()
|
||||
.getNormalized()
|
||||
.value();
|
||||
}
|
||||
size_t width = enc.width;
|
||||
bool isSigned = enc.isSigned;
|
||||
uint64_t size = 0;
|
||||
std::vector<int64_t> dims{};
|
||||
LweSecretKeyID secretKeyID = normKey.index;
|
||||
Variance variance = curve.getVariance(1, normKey.dimension, 64);
|
||||
CRTDecomposition crt = maybeCrt.value_or(std::vector<int64_t>());
|
||||
return CircuitGate{
|
||||
/* .encryption = */ std::optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
{
|
||||
/* .precision = */ width,
|
||||
/* .crt = */ crt,
|
||||
/*.sign = */ isSigned,
|
||||
},
|
||||
}),
|
||||
/*.shape = */
|
||||
{
|
||||
/*.width = */ width,
|
||||
/*.dimensions = */ dims,
|
||||
/*.size = */ size,
|
||||
/*.sign = */ isSigned,
|
||||
},
|
||||
/*.chunkInfo = */ std::nullopt,
|
||||
};
|
||||
},
|
||||
[&](encodings::EncryptedChunkedIntegerScalarEncoding enc)
|
||||
-> llvm::Expected<CircuitGate> {
|
||||
auto tensorType = type.cast<mlir::RankedTensorType>();
|
||||
auto glweType =
|
||||
tensorType.getElementType().cast<TFHE::GLWECipherTextType>();
|
||||
auto normKey = glweType.getKey().getNormalized().value();
|
||||
size_t width = enc.chunkSize;
|
||||
assert(enc.width % enc.chunkWidth == 0);
|
||||
uint64_t size = enc.width / enc.chunkWidth;
|
||||
bool isSigned = enc.isSigned;
|
||||
std::vector<int64_t> dims{
|
||||
(int64_t)size,
|
||||
};
|
||||
LweSecretKeyID secretKeyID = normKey.index;
|
||||
Variance variance = curve.getVariance(1, normKey.dimension, 64);
|
||||
CRTDecomposition crt = maybeCrt.value_or(std::vector<int64_t>());
|
||||
return CircuitGate{
|
||||
/* .encryption = */ std::optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
{
|
||||
/* .precision = */ width,
|
||||
/* .crt = */ crt,
|
||||
/*.sign = */ isSigned,
|
||||
},
|
||||
}),
|
||||
/*.shape = */
|
||||
{
|
||||
/*.width = */ width,
|
||||
/*.dimensions = */ dims,
|
||||
/*.size = */ size,
|
||||
/*.sign = */ isSigned,
|
||||
},
|
||||
/*.chunkInfo = */
|
||||
std::optional<ChunkInfo>(
|
||||
{(unsigned int)enc.chunkSize, (unsigned int)enc.chunkWidth}),
|
||||
};
|
||||
},
|
||||
[&](encodings::EncryptedBoolScalarEncoding enc)
|
||||
-> llvm::Expected<CircuitGate> {
|
||||
auto glweType = type.cast<TFHE::GLWECipherTextType>();
|
||||
auto normKey = glweType.getKey().getNormalized().value();
|
||||
size_t width =
|
||||
mlir::concretelang::FHE::EncryptedBooleanType::getWidth();
|
||||
LweSecretKeyID secretKeyID = normKey.index;
|
||||
Variance variance = curve.getVariance(1, normKey.dimension, 64);
|
||||
return CircuitGate{
|
||||
/* .encryption = */ std::optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
{
|
||||
/* .precision = */ width,
|
||||
/* .crt = */ std::vector<int64_t>(),
|
||||
/* .sign = */ false,
|
||||
},
|
||||
}),
|
||||
/*.shape = */
|
||||
{
|
||||
/*.width = */ width,
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
/*.sign = */ false,
|
||||
},
|
||||
/*.chunkInfo = */ std::nullopt,
|
||||
};
|
||||
},
|
||||
[&](encodings::PlaintextScalarEncoding enc)
|
||||
-> llvm::Expected<CircuitGate> {
|
||||
size_t width = type.getIntOrFloatBitWidth();
|
||||
bool sign = type.isSignedInteger();
|
||||
return CircuitGate{
|
||||
/*.encryption = */ std::nullopt,
|
||||
/*.shape = */
|
||||
{/*.width = */ width,
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
/* .sign */ sign},
|
||||
/*.chunkInfo = */ std::nullopt,
|
||||
};
|
||||
},
|
||||
[&](encodings::IndexScalarEncoding enc) -> llvm::Expected<CircuitGate> {
|
||||
// TODO - The index type is dependant of the target architecture,
|
||||
// so actually we assume we target only 64 bits, we need to have
|
||||
// some the size of the word of the target system.
|
||||
size_t width = 64;
|
||||
bool sign = type.isSignedInteger();
|
||||
return CircuitGate{
|
||||
/*.encryption = */ std::nullopt,
|
||||
/*.shape = */
|
||||
{/*.width = */ width,
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
/* .sign */ sign},
|
||||
/*.chunkInfo = */ std::nullopt,
|
||||
};
|
||||
},
|
||||
[&](auto enc) -> llvm::Expected<CircuitGate> {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot convert MLIR type to shape there",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}};
|
||||
auto genericVisitor = overloaded{
|
||||
[&](encodings::ScalarEncoding enc) -> llvm::Expected<CircuitGate> {
|
||||
return std::visit(scalarVisitor, enc);
|
||||
},
|
||||
[&](encodings::TensorEncoding enc) -> llvm::Expected<CircuitGate> {
|
||||
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
auto scalarGate = generateGate(tensor.getElementType(),
|
||||
enc.scalarEncoding, curve, maybeCrt);
|
||||
if (auto err = scalarGate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
if (maybeCrt.has_value()) {
|
||||
// When using crt, the last dimension of the tensor is for the members
|
||||
// of the decomposition. It should not be used.
|
||||
scalarGate->shape.dimensions =
|
||||
tensor.getShape().take_front(tensor.getShape().size() - 1).vec();
|
||||
} else {
|
||||
scalarGate->shape.dimensions = tensor.getShape().vec();
|
||||
}
|
||||
scalarGate->shape.size = 1;
|
||||
for (auto dimSize : scalarGate->shape.dimensions) {
|
||||
scalarGate->shape.size *= dimSize;
|
||||
}
|
||||
return scalarGate;
|
||||
},
|
||||
[&](auto enc) -> llvm::Expected<CircuitGate> {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot convert MLIR type to shape here",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}};
|
||||
return std::visit(genericVisitor, encoding);
|
||||
}
|
||||
|
||||
template <typename V> struct HashValComparator {
|
||||
bool operator()(const V &lhs, const V &rhs) const {
|
||||
return hash_value(lhs) < hash_value(rhs);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename V> using Set = llvm::SmallSet<V, 10, HashValComparator<V>>;
|
||||
|
||||
void extractCircuitKeys(ClientParameters &output,
|
||||
TFHE::TFHECircuitKeys circuitKeys,
|
||||
concrete::SecurityCurve curve) {
|
||||
|
||||
// Pushing secret keys
|
||||
for (auto sk : circuitKeys.secretKeys) {
|
||||
clientlib::LweSecretKeyParam skParam;
|
||||
skParam.dimension = sk.getNormalized().value().dimension;
|
||||
output.secretKeys.push_back(skParam);
|
||||
}
|
||||
|
||||
// Pushing keyswitch keys
|
||||
for (auto ksk : circuitKeys.keyswitchKeys) {
|
||||
clientlib::KeyswitchKeyParam kskParam;
|
||||
auto inputNormKey = ksk.getInputKey().getNormalized().value();
|
||||
auto outputNormKey = ksk.getOutputKey().getNormalized().value();
|
||||
kskParam.inputSecretKeyID = inputNormKey.index;
|
||||
kskParam.outputSecretKeyID = outputNormKey.index;
|
||||
kskParam.level = ksk.getLevels();
|
||||
kskParam.baseLog = ksk.getBaseLog();
|
||||
kskParam.variance = curve.getVariance(1, outputNormKey.dimension, 64);
|
||||
output.keyswitchKeys.push_back(kskParam);
|
||||
}
|
||||
|
||||
// Pushing bootstrap keys
|
||||
for (auto bsk : circuitKeys.bootstrapKeys) {
|
||||
clientlib::BootstrapKeyParam bskParam;
|
||||
auto inputNormKey = bsk.getInputKey().getNormalized().value();
|
||||
auto outputNormKey = bsk.getOutputKey().getNormalized().value();
|
||||
bskParam.inputSecretKeyID = inputNormKey.index;
|
||||
bskParam.outputSecretKeyID = outputNormKey.index;
|
||||
bskParam.level = bsk.getLevels();
|
||||
bskParam.baseLog = bsk.getBaseLog();
|
||||
bskParam.glweDimension = bsk.getGlweDim();
|
||||
bskParam.polynomialSize = bsk.getPolySize();
|
||||
bskParam.variance =
|
||||
curve.getVariance(bsk.getGlweDim(), bsk.getPolySize(), 64);
|
||||
bskParam.inputLweDimension = inputNormKey.dimension;
|
||||
output.bootstrapKeys.push_back(bskParam);
|
||||
}
|
||||
|
||||
// Pushing circuit packing keyswitch keys
|
||||
for (auto pksk : circuitKeys.packingKeyswitchKeys) {
|
||||
clientlib::PackingKeyswitchKeyParam pkskParam;
|
||||
auto inputNormKey = pksk.getInputKey().getNormalized().value();
|
||||
auto outputNormKey = pksk.getOutputKey().getNormalized().value();
|
||||
pkskParam.inputSecretKeyID = inputNormKey.index;
|
||||
pkskParam.outputSecretKeyID = outputNormKey.index;
|
||||
pkskParam.level = pksk.getLevels();
|
||||
pkskParam.baseLog = pksk.getBaseLog();
|
||||
pkskParam.glweDimension = pksk.getGlweDim();
|
||||
pkskParam.polynomialSize = pksk.getOutputPolySize();
|
||||
pkskParam.inputLweDimension = inputNormKey.dimension;
|
||||
pkskParam.variance =
|
||||
curve.getVariance(outputNormKey.dimension, outputNormKey.polySize, 64);
|
||||
output.packingKeyswitchKeys.push_back(pkskParam);
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Expected<std::monostate>
|
||||
extractCircuitGates(ClientParameters &output, mlir::func::FuncOp funcOp,
|
||||
encodings::CircuitEncodings encodings,
|
||||
concrete::SecurityCurve curve,
|
||||
std::optional<CRTDecomposition> maybeCrt) {
|
||||
|
||||
// Create input and output circuit gate parameters
|
||||
auto funcType = funcOp.getFunctionType();
|
||||
|
||||
for (auto val : llvm::zip(funcType.getInputs(), encodings.inputEncodings)) {
|
||||
auto ty = std::get<0>(val);
|
||||
auto encoding = std::get<1>(val);
|
||||
auto gate = generateGate(ty, encoding, curve, maybeCrt);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
output.inputs.push_back(gate.get());
|
||||
}
|
||||
for (auto val : llvm::zip(funcType.getResults(), encodings.outputEncodings)) {
|
||||
auto ty = std::get<0>(val);
|
||||
auto encoding = std::get<1>(val);
|
||||
auto gate = generateGate(ty, encoding, curve, maybeCrt);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
output.outputs.push_back(gate.get());
|
||||
}
|
||||
|
||||
return std::monostate();
|
||||
}
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersFromTFHE(mlir::ModuleOp module,
|
||||
llvm::StringRef functionName, int bitsOfSecurity,
|
||||
encodings::CircuitEncodings encodings,
|
||||
std::optional<CRTDecomposition> maybeCrt) {
|
||||
|
||||
// Check that security curves exist
|
||||
const auto curve = concrete::getSecurityCurve(bitsOfSecurity, keyFormat);
|
||||
if (curve == nullptr) {
|
||||
return StreamStringError("Cannot find security curves for ")
|
||||
<< bitsOfSecurity << "bits";
|
||||
}
|
||||
|
||||
// Check that the specified function can be found
|
||||
auto rangeOps = module.getOps<mlir::func::FuncOp>();
|
||||
auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
|
||||
return op.getName() == functionName;
|
||||
});
|
||||
if (funcOp == rangeOps.end()) {
|
||||
return StreamStringError(
|
||||
"cannot find the function for generate client parameters: ")
|
||||
<< functionName;
|
||||
}
|
||||
|
||||
// Create client parameters
|
||||
ClientParameters output;
|
||||
output.functionName = (std::string)functionName;
|
||||
|
||||
// We extract the keys of the circuit
|
||||
auto circuitKeys = TFHE::extractCircuitKeys(module);
|
||||
|
||||
// We extract all the keys used in the circuit
|
||||
extractCircuitKeys(output, circuitKeys, *curve);
|
||||
|
||||
// We generate the gates for the inputs aud outputs
|
||||
if (auto err =
|
||||
extractCircuitGates(output, *funcOp, encodings, *curve, maybeCrt)
|
||||
.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <llvm/Support/Debug.h>
|
||||
#include <mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
#include <mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
@@ -42,6 +43,7 @@
|
||||
#include <concretelang/Dialect/Tracing/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
#include <concretelang/Support/CompilerEngine.h>
|
||||
#include <concretelang/Support/Encodings.h>
|
||||
#include <concretelang/Support/Error.h>
|
||||
#include <concretelang/Support/Jit.h>
|
||||
#include <concretelang/Support/LLVMEmitFile.h>
|
||||
@@ -163,12 +165,8 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) {
|
||||
auto description = descriptions->find(name);
|
||||
if (description == descriptions->end()) {
|
||||
std::string names;
|
||||
for (auto &entry : *descriptions) {
|
||||
names += "'" + entry.first + "' ";
|
||||
}
|
||||
return StreamStringError()
|
||||
<< "Could not find existing crypto parameters for function '"
|
||||
<< name << "' (known functions: " << names << ")";
|
||||
return StreamStringError("Function not found, name='")
|
||||
<< name << "', cannot get optimizer description";
|
||||
}
|
||||
return std::move(description->second);
|
||||
}
|
||||
@@ -194,6 +192,10 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
|
||||
}
|
||||
res.fheContext.emplace(
|
||||
mlir::concretelang::V0FHEContext{constraint, v0Params});
|
||||
|
||||
CompilationFeedback feedback;
|
||||
res.feedback.emplace(feedback);
|
||||
|
||||
return llvm::Error::success();
|
||||
}
|
||||
// compute parameters
|
||||
@@ -286,6 +288,25 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
if (target == Target::ROUND_TRIP)
|
||||
return std::move(res);
|
||||
|
||||
// Retrieves the encoding informations before any transformation is performed
|
||||
// on the `FHE` dialect.
|
||||
if ((this->generateClientParameters || target == Target::LIBRARY) &&
|
||||
!options.encodings.has_value()) {
|
||||
auto funcName = options.clientParametersFuncName.value_or("main");
|
||||
auto maybeChunkInfo =
|
||||
options.chunkIntegers
|
||||
? std::optional(concretelang::clientlib::ChunkInfo{
|
||||
options.chunkSize, options.chunkWidth})
|
||||
: std::nullopt;
|
||||
auto encodingInfosOrErr =
|
||||
mlir::concretelang::encodings::getCircuitEncodings(funcName, module,
|
||||
maybeChunkInfo);
|
||||
if (!encodingInfosOrErr) {
|
||||
return encodingInfosOrErr.takeError();
|
||||
}
|
||||
options.encodings = encodingInfosOrErr.get();
|
||||
}
|
||||
|
||||
if (mlir::concretelang::pipeline::transformFHEBoolean(mlirContext, module,
|
||||
enablePass)
|
||||
.failed()) {
|
||||
@@ -345,46 +366,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
if (target == Target::FHE_NO_LINALG)
|
||||
return std::move(res);
|
||||
|
||||
// Generate client parameters if requested
|
||||
if (this->generateClientParameters) {
|
||||
if (!options.clientParametersFuncName.has_value()) {
|
||||
return StreamStringError(
|
||||
"Generation of client parameters requested, but no function name "
|
||||
"specified");
|
||||
}
|
||||
if (!res.fheContext.has_value()) {
|
||||
return StreamStringError(
|
||||
"Cannot generate client parameters, the fhe context is empty for " +
|
||||
options.clientParametersFuncName.value());
|
||||
}
|
||||
}
|
||||
// Generate client parameters if requested
|
||||
auto funcName = options.clientParametersFuncName.value_or("main");
|
||||
if (this->generateClientParameters || target == Target::LIBRARY) {
|
||||
if (!res.fheContext.has_value()) {
|
||||
// Some tests involve call a to non encrypted functions
|
||||
ClientParameters emptyParams;
|
||||
emptyParams.functionName = funcName;
|
||||
res.clientParameters = emptyParams;
|
||||
} else {
|
||||
llvm::Optional<::concretelang::clientlib::ChunkInfo> chunkInfo =
|
||||
std::nullopt;
|
||||
if (options.chunkIntegers) {
|
||||
chunkInfo = ::concretelang::clientlib::ChunkInfo{options.chunkSize,
|
||||
options.chunkWidth};
|
||||
}
|
||||
auto clientParametersOrErr =
|
||||
mlir::concretelang::createClientParametersForV0(
|
||||
*res.fheContext, funcName, module,
|
||||
options.optimizerConfig.security, chunkInfo);
|
||||
if (!clientParametersOrErr)
|
||||
return clientParametersOrErr.takeError();
|
||||
|
||||
res.clientParameters = clientParametersOrErr.get();
|
||||
res.feedback->fillFromClientParameters(*res.clientParameters);
|
||||
}
|
||||
}
|
||||
|
||||
// FHE -> TFHE
|
||||
if (mlir::concretelang::pipeline::lowerFHEToTFHE(mlirContext, module,
|
||||
res.fheContext, enablePass)
|
||||
@@ -412,6 +393,57 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
if (target == Target::PARAMETRIZED_TFHE)
|
||||
return std::move(res);
|
||||
|
||||
// Normalize TFHE keys
|
||||
if (mlir::concretelang::pipeline::normalizeTFHEKeys(mlirContext, module,
|
||||
this->enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Normalizing TFHE keys failed");
|
||||
}
|
||||
|
||||
// Generate client parameters if requested
|
||||
if (this->generateClientParameters) {
|
||||
if (!options.clientParametersFuncName.has_value()) {
|
||||
return StreamStringError(
|
||||
"Generation of client parameters requested, but no function name "
|
||||
"specified");
|
||||
}
|
||||
if (!res.fheContext.has_value()) {
|
||||
return StreamStringError(
|
||||
"Cannot generate client parameters, the fhe context is empty for " +
|
||||
options.clientParametersFuncName.value());
|
||||
}
|
||||
}
|
||||
// Generate client parameters if requested
|
||||
if (this->generateClientParameters || target == Target::LIBRARY) {
|
||||
auto funcName = options.clientParametersFuncName.value_or("main");
|
||||
if (!res.fheContext.has_value()) {
|
||||
// Some tests involve call a to non encrypted functions
|
||||
ClientParameters emptyParams;
|
||||
emptyParams.functionName = funcName;
|
||||
res.clientParameters = emptyParams;
|
||||
} else {
|
||||
std::optional<CRTDecomposition> maybeCrt = std::nullopt;
|
||||
if (res.fheContext.value().parameter.largeInteger.has_value()) {
|
||||
maybeCrt = res.fheContext.value()
|
||||
.parameter.largeInteger.value()
|
||||
.crtDecomposition;
|
||||
}
|
||||
auto clientParametersOrErr =
|
||||
mlir::concretelang::createClientParametersFromTFHE(
|
||||
module, funcName, options.optimizerConfig.security,
|
||||
options.encodings.value(), maybeCrt);
|
||||
|
||||
if (!clientParametersOrErr)
|
||||
return clientParametersOrErr.takeError();
|
||||
|
||||
res.clientParameters = clientParametersOrErr.get();
|
||||
res.feedback->fillFromClientParameters(*res.clientParameters);
|
||||
}
|
||||
}
|
||||
|
||||
if (target == Target::NORMALIZED_TFHE)
|
||||
return std::move(res);
|
||||
|
||||
if (options.batchTFHEOps) {
|
||||
if (mlir::concretelang::pipeline::batchTFHE(mlirContext, module, enablePass)
|
||||
.failed()) {
|
||||
|
||||
247
compilers/concrete-compiler/compiler/lib/Support/Encodings.cpp
Normal file
247
compilers/concrete-compiler/compiler/lib/Support/Encodings.cpp
Normal file
@@ -0,0 +1,247 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <concretelang/ClientLib/ClientParameters.h>
|
||||
#include <concretelang/Dialect/FHE/IR/FHETypes.h>
|
||||
#include <concretelang/Support/Encodings.h>
|
||||
#include <concretelang/Support/Error.h>
|
||||
#include <concretelang/Support/Variants.h>
|
||||
#include <optional>
|
||||
#include <variant>
|
||||
|
||||
namespace FHE = mlir::concretelang::FHE;
|
||||
namespace clientlib = concretelang::clientlib;
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace encodings {
|
||||
|
||||
std::optional<Encoding>
|
||||
encodingFromType(mlir::Type ty,
|
||||
std::optional<clientlib::ChunkInfo> maybeChunkInfo) {
|
||||
if (auto eintTy = ty.dyn_cast<FHE::FheIntegerInterface>()) {
|
||||
if (maybeChunkInfo.has_value() &&
|
||||
eintTy.getWidth() > maybeChunkInfo.value().size) {
|
||||
auto chunkInfo = maybeChunkInfo.value();
|
||||
return EncryptedChunkedIntegerScalarEncoding{
|
||||
eintTy.getWidth(), eintTy.isSigned(), chunkInfo.width,
|
||||
chunkInfo.size};
|
||||
} else {
|
||||
return EncryptedIntegerScalarEncoding{eintTy.getWidth(),
|
||||
eintTy.isSigned()};
|
||||
}
|
||||
} else if (auto eboolTy = ty.dyn_cast<FHE::EncryptedBooleanType>()) {
|
||||
return EncryptedBoolScalarEncoding{};
|
||||
} else if (auto intTy = ty.dyn_cast<mlir::IntegerType>()) {
|
||||
return PlaintextScalarEncoding{intTy.getWidth()};
|
||||
} else if (auto indexTy = ty.dyn_cast<mlir::IndexType>()) {
|
||||
return IndexScalarEncoding{};
|
||||
} else if (auto tensor = ty.dyn_cast<mlir::RankedTensorType>()) {
|
||||
std::optional<Encoding> maybeEncoding =
|
||||
encodingFromType(tensor.getElementType(), maybeChunkInfo);
|
||||
if (maybeEncoding.has_value() &&
|
||||
std::holds_alternative<ScalarEncoding>(maybeEncoding.value())) {
|
||||
ScalarEncoding scalarEncoding =
|
||||
std::get<ScalarEncoding>(maybeEncoding.value());
|
||||
return TensorEncoding{scalarEncoding};
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
llvm::Expected<CircuitEncodings>
|
||||
getCircuitEncodings(llvm::StringRef functionName, mlir::ModuleOp module,
|
||||
std::optional<clientlib::ChunkInfo> maybeChunkInfo) {
|
||||
|
||||
// Find the input function
|
||||
auto rangeOps = module.getOps<mlir::func::FuncOp>();
|
||||
auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
|
||||
return op.getName() == functionName;
|
||||
});
|
||||
if (funcOp == rangeOps.end()) {
|
||||
return StreamStringError("Function not found, name='")
|
||||
<< functionName << "', cannot get circuit encodings";
|
||||
}
|
||||
auto funcType = (*funcOp).getFunctionType();
|
||||
|
||||
// Retrieve input/output encodings
|
||||
std::vector<Encoding> inputs;
|
||||
std::vector<Encoding> outputs;
|
||||
for (auto ty : funcType.getInputs()) {
|
||||
auto maybeGate = encodingFromType(ty, maybeChunkInfo);
|
||||
if (!maybeGate.has_value()) {
|
||||
return StreamStringError("Failed to recognize encoding for type : ")
|
||||
<< ty;
|
||||
}
|
||||
inputs.push_back(maybeGate.value());
|
||||
}
|
||||
for (auto ty : funcType.getResults()) {
|
||||
auto maybeGate = encodingFromType(ty, maybeChunkInfo);
|
||||
if (!maybeGate.has_value()) {
|
||||
return StreamStringError("Failed to recognize encoding for type : ")
|
||||
<< ty;
|
||||
}
|
||||
outputs.push_back(maybeGate.value());
|
||||
}
|
||||
|
||||
return CircuitEncodings{inputs, outputs};
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, EncryptedIntegerScalarEncoding &e,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("width", e.width) && O.map("isSigned", e.isSigned);
|
||||
}
|
||||
llvm::json::Value toJSON(const EncryptedIntegerScalarEncoding &e) {
|
||||
llvm::json::Object object{
|
||||
{"width", e.width},
|
||||
{"isSigned", e.isSigned},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j,
|
||||
EncryptedChunkedIntegerScalarEncoding &e, llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("width", e.width) && O.map("isSigned", e.isSigned) &&
|
||||
O.map("chunkSize", e.chunkSize) && O.map("chunkWidth", e.chunkWidth);
|
||||
}
|
||||
llvm::json::Value toJSON(const EncryptedChunkedIntegerScalarEncoding &e) {
|
||||
llvm::json::Object object{
|
||||
{"width", e.width},
|
||||
{"isSigned", e.isSigned},
|
||||
{"chunkSize", e.chunkSize},
|
||||
{"chunkWidth", e.chunkWidth},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, EncryptedBoolScalarEncoding &e,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O;
|
||||
}
|
||||
llvm::json::Value toJSON(const EncryptedBoolScalarEncoding &e) {
|
||||
llvm::json::Object object{};
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, PlaintextScalarEncoding &e,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("width", e.width);
|
||||
}
|
||||
llvm::json::Value toJSON(const PlaintextScalarEncoding &e) {
|
||||
llvm::json::Object object{{"width", e.width}};
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, IndexScalarEncoding &e,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O;
|
||||
}
|
||||
llvm::json::Value toJSON(const IndexScalarEncoding &e) {
|
||||
llvm::json::Object object{};
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, ScalarEncoding &e,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
if (j.getAsObject()->getObject("EncryptedIntegerScalarEncoding")) {
|
||||
return O && O.map("EncryptedIntegerScalarEncoding",
|
||||
std::get<EncryptedIntegerScalarEncoding>(e));
|
||||
} else if (j.getAsObject()->getObject(
|
||||
"EncryptedChunkedIntegerScalarEncoding")) {
|
||||
return O && O.map("EncryptedChunkedIntegerScalarEncoding",
|
||||
std::get<EncryptedChunkedIntegerScalarEncoding>(e));
|
||||
} else if (j.getAsObject()->getObject("EncryptedBoolScalarEncoding")) {
|
||||
return O && O.map("EncryptedBoolScalarEncoding",
|
||||
std::get<EncryptedBoolScalarEncoding>(e));
|
||||
} else if (j.getAsObject()->getObject("PlaintextScalarEncoding")) {
|
||||
return O && O.map("PlaintextScalarEncoding",
|
||||
std::get<PlaintextScalarEncoding>(e));
|
||||
} else if (j.getAsObject()->getObject("IndexScalarEncoding")) {
|
||||
return O && O.map("IndexScalarEncoding", std::get<IndexScalarEncoding>(e));
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
llvm::json::Value toJSON(const ScalarEncoding &e) {
|
||||
llvm::json::Object object = std::visit(
|
||||
overloaded{
|
||||
[](EncryptedIntegerScalarEncoding enc) {
|
||||
return llvm::json::Object{{"EncryptedIntegerScalarEncoding", enc}};
|
||||
},
|
||||
[](EncryptedChunkedIntegerScalarEncoding enc) {
|
||||
return llvm::json::Object{
|
||||
{"EncryptedChunkedIntegerScalarEncoding", enc}};
|
||||
},
|
||||
[](EncryptedBoolScalarEncoding enc) {
|
||||
return llvm::json::Object{{"EncryptedBoolScalarEncoding", enc}};
|
||||
},
|
||||
[](PlaintextScalarEncoding enc) {
|
||||
return llvm::json::Object{{"PlaintextScalarEncoding", enc}};
|
||||
},
|
||||
[](IndexScalarEncoding enc) {
|
||||
return llvm::json::Object{{"IndexScalarEncoding", enc}};
|
||||
},
|
||||
},
|
||||
e);
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, TensorEncoding &e,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("scalarEncoding", e.scalarEncoding);
|
||||
}
|
||||
llvm::json::Value toJSON(const TensorEncoding &e) {
|
||||
llvm::json::Object object{{"scalarEncoding", e.scalarEncoding}};
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, Encoding &e, llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
if (j.getAsObject()->getObject("ScalarEncoding")) {
|
||||
e = EncryptedIntegerScalarEncoding{0, false};
|
||||
return O && O.map("ScalarEncoding", std::get<ScalarEncoding>(e));
|
||||
} else if (j.getAsObject()->getObject("TensorEncoding")) {
|
||||
e = TensorEncoding{EncryptedIntegerScalarEncoding{0, false}};
|
||||
return O && O.map("TensorEncoding", std::get<TensorEncoding>(e));
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
llvm::json::Value toJSON(const Encoding &e) {
|
||||
llvm::json::Object object =
|
||||
std::visit(overloaded{
|
||||
[](ScalarEncoding enc) {
|
||||
return llvm::json::Object{{"ScalarEncoding", enc}};
|
||||
},
|
||||
[](TensorEncoding enc) {
|
||||
return llvm::json::Object{{"TensorEncoding", enc}};
|
||||
},
|
||||
},
|
||||
e);
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, CircuitEncodings &e,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("inputEncodings", e.inputEncodings) &&
|
||||
O.map("outputEncodings", e.outputEncodings);
|
||||
}
|
||||
llvm::json::Value toJSON(const CircuitEncodings &e) {
|
||||
llvm::json::Object object{{"inputEncodings", e.inputEncodings},
|
||||
{"outputEncodings", e.outputEncodings}};
|
||||
return object;
|
||||
}
|
||||
|
||||
} // namespace encodings
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -28,6 +28,7 @@
|
||||
#include <mlir/Target/LLVMIR/Export.h>
|
||||
#include <mlir/Transforms/Passes.h>
|
||||
|
||||
#include "concretelang/Conversion/TFHEKeyNormalization/Pass.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include <concretelang/Conversion/Passes.h>
|
||||
@@ -295,6 +296,18 @@ mlir::LogicalResult batchTFHE(mlir::MLIRContext &context,
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
normalizeTFHEKeys(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("TFHEKeyNormalization", pm, context);
|
||||
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createTFHEKeyNormalizationPass(), enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Support/TFHECircuitKeys.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include <llvm/ADT/SmallVector.h>
|
||||
#include <optional>
|
||||
|
||||
// Faster than using a full fledged hash-set for small sets, and the array can
|
||||
// be recovered right away.
|
||||
template <typename V> struct SmallSet {
|
||||
llvm::SmallVector<V, 10> vector;
|
||||
|
||||
void insert(V val) {
|
||||
for (auto vectorVal : vector) {
|
||||
if (vectorVal == val) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
vector.push_back(val);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename V, unsigned N>
|
||||
std::optional<size_t> vectorIndex(llvm::SmallVector<V, N> vector, V val) {
|
||||
for (size_t i = 0; i < vector.size(); i++) {
|
||||
auto potentialVal = vector[i];
|
||||
if (potentialVal == val) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace TFHE {
|
||||
|
||||
template <typename V, unsigned int N>
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
const mlir::SmallVector<V, N> vect) {
|
||||
OS << "[";
|
||||
for (auto v : vect) {
|
||||
OS << v << ",";
|
||||
}
|
||||
OS << "]";
|
||||
return OS;
|
||||
}
|
||||
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
const TFHECircuitKeys cks) {
|
||||
|
||||
OS << "TFHECircuitKeys{\n"
|
||||
<< " secretKeys:" << cks.secretKeys << "\n"
|
||||
<< " keyswitchKeys:" << cks.keyswitchKeys << "\n"
|
||||
<< " bootstrapKeys:" << cks.bootstrapKeys << "\n"
|
||||
<< " packingKeyswitchKeys:" << cks.packingKeyswitchKeys
|
||||
<< "\n"
|
||||
"}";
|
||||
return OS;
|
||||
}
|
||||
|
||||
TFHECircuitKeys extractCircuitKeys(mlir::ModuleOp moduleOp) {
|
||||
// Gathering circuit secret keys
|
||||
SmallSet<TFHE::GLWESecretKey> secretKeys;
|
||||
auto tryInsert = [&](mlir::Type type) {
|
||||
if (auto glweType = type.dyn_cast<TFHE::GLWECipherTextType>()) {
|
||||
secretKeys.insert(glweType.getKey());
|
||||
} else if (auto tensorType = type.dyn_cast<mlir::RankedTensorType>()) {
|
||||
if (auto elementType = tensorType.getElementType()
|
||||
.dyn_cast<TFHE::GLWECipherTextType>()) {
|
||||
secretKeys.insert(elementType.getKey());
|
||||
}
|
||||
}
|
||||
};
|
||||
moduleOp->walk([&](mlir::Operation *op) {
|
||||
for (auto operand : op->getOperands()) {
|
||||
tryInsert(operand.getType());
|
||||
}
|
||||
for (auto result : op->getResults()) {
|
||||
tryInsert(result.getType());
|
||||
}
|
||||
});
|
||||
moduleOp->walk([&](mlir::func::FuncOp op) {
|
||||
for (auto argType : op.getArgumentTypes()) {
|
||||
tryInsert(argType);
|
||||
}
|
||||
for (auto resultType : op.getResultTypes()) {
|
||||
tryInsert(resultType);
|
||||
}
|
||||
});
|
||||
|
||||
// Gathering circuit keyswitch keys
|
||||
SmallSet<TFHE::GLWEKeyswitchKeyAttr> keyswitchKeys;
|
||||
moduleOp->walk([&](TFHE::KeySwitchGLWEOp op) {
|
||||
keyswitchKeys.insert(op.getKeyAttr());
|
||||
secretKeys.insert(op.getKeyAttr().getInputKey());
|
||||
secretKeys.insert(op.getKeyAttr().getOutputKey());
|
||||
});
|
||||
|
||||
// Gathering circuit bootstrap keys
|
||||
SmallSet<TFHE::GLWEBootstrapKeyAttr> bootstrapKeys;
|
||||
moduleOp->walk([&](TFHE::BootstrapGLWEOp op) {
|
||||
bootstrapKeys.insert(op.getKeyAttr());
|
||||
secretKeys.insert(op.getKeyAttr().getInputKey());
|
||||
secretKeys.insert(op.getKeyAttr().getOutputKey());
|
||||
});
|
||||
|
||||
// Gathering circuit packing keyswitch keys
|
||||
SmallSet<TFHE::GLWEPackingKeyswitchKeyAttr> packingKeyswitchKeys;
|
||||
moduleOp->walk([&](TFHE::WopPBSGLWEOp op) {
|
||||
keyswitchKeys.insert(op.getKskAttr());
|
||||
secretKeys.insert(op.getKskAttr().getInputKey());
|
||||
secretKeys.insert(op.getKskAttr().getOutputKey());
|
||||
bootstrapKeys.insert(op.getBskAttr());
|
||||
secretKeys.insert(op.getBskAttr().getInputKey());
|
||||
secretKeys.insert(op.getBskAttr().getOutputKey());
|
||||
packingKeyswitchKeys.insert(op.getPkskAttr());
|
||||
secretKeys.insert(op.getPkskAttr().getInputKey());
|
||||
secretKeys.insert(op.getPkskAttr().getOutputKey());
|
||||
});
|
||||
|
||||
return TFHECircuitKeys{secretKeys.vector, bootstrapKeys.vector,
|
||||
keyswitchKeys.vector, packingKeyswitchKeys.vector};
|
||||
}
|
||||
|
||||
std::optional<uint64_t>
|
||||
TFHE::TFHECircuitKeys::getSecretKeyIndex(TFHE::GLWESecretKey key) {
|
||||
return vectorIndex(this->secretKeys, key);
|
||||
}
|
||||
|
||||
std::optional<uint64_t>
|
||||
TFHE::TFHECircuitKeys::getBootstrapKeyIndex(TFHE::GLWEBootstrapKeyAttr key) {
|
||||
return vectorIndex(this->bootstrapKeys, key);
|
||||
}
|
||||
|
||||
std::optional<uint64_t>
|
||||
TFHE::TFHECircuitKeys::getKeyswitchKeyIndex(TFHE::GLWEKeyswitchKeyAttr key) {
|
||||
return vectorIndex(this->keyswitchKeys, key);
|
||||
}
|
||||
|
||||
std::optional<uint64_t> TFHE::TFHECircuitKeys::getPackingKeyswitchKeyIndex(
|
||||
TFHE::GLWEPackingKeyswitchKeyAttr key) {
|
||||
return vectorIndex(this->packingKeyswitchKeys, key);
|
||||
}
|
||||
|
||||
} // namespace TFHE
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -1,257 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
#include <cassert>
|
||||
#include <map>
|
||||
|
||||
#include <llvm/ADT/Optional.h>
|
||||
#include <llvm/ADT/STLExtras.h>
|
||||
#include <llvm/Support/Error.h>
|
||||
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <optional>
|
||||
|
||||
#include "concrete/curves.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
using ::concretelang::clientlib::ChunkInfo;
|
||||
using ::concretelang::clientlib::CircuitGate;
|
||||
using ::concretelang::clientlib::ClientParameters;
|
||||
using ::concretelang::clientlib::Encoding;
|
||||
using ::concretelang::clientlib::EncryptionGate;
|
||||
using ::concretelang::clientlib::LweSecretKeyID;
|
||||
using ::concretelang::clientlib::Precision;
|
||||
using ::concretelang::clientlib::Variance;
|
||||
|
||||
const auto keyFormat = concrete::BINARY;
|
||||
|
||||
/// For the v0 the secretKeyID and precision are the same for all gates.
|
||||
llvm::Expected<CircuitGate>
|
||||
gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
|
||||
Variance variance, llvm::Optional<ChunkInfo> chunkInfo,
|
||||
mlir::Type type) {
|
||||
if (type.isIntOrIndex()) {
|
||||
// TODO - The index type is dependant of the target architecture, so
|
||||
// actually we assume we target only 64 bits, we need to have some the size
|
||||
// of the word of the target system.
|
||||
size_t width = 64;
|
||||
if (!type.isIndex()) {
|
||||
width = type.getIntOrFloatBitWidth();
|
||||
}
|
||||
|
||||
bool sign = type.isSignedInteger();
|
||||
|
||||
return CircuitGate{
|
||||
/*.encryption = */ std::nullopt,
|
||||
/*.shape = */
|
||||
{/*.width = */ width,
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
/* .sign */ sign},
|
||||
/*.chunkInfo = */ std::nullopt,
|
||||
};
|
||||
}
|
||||
if (auto lweTy = type.dyn_cast_or_null<
|
||||
mlir::concretelang::FHE::FheIntegerInterface>()) {
|
||||
bool sign = lweTy.isSigned();
|
||||
std::vector<int64_t> crt;
|
||||
if (fheContext.parameter.largeInteger.has_value()) {
|
||||
crt = fheContext.parameter.largeInteger.value().crtDecomposition;
|
||||
}
|
||||
size_t width;
|
||||
uint64_t size = 0;
|
||||
std::vector<int64_t> dims;
|
||||
if (chunkInfo.has_value()) {
|
||||
width = chunkInfo->size;
|
||||
assert(lweTy.getWidth() % chunkInfo->width == 0);
|
||||
size = lweTy.getWidth() / chunkInfo->width;
|
||||
dims.push_back(size);
|
||||
} else {
|
||||
width = (size_t)lweTy.getWidth();
|
||||
}
|
||||
return CircuitGate{
|
||||
/* .encryption = */ std::optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
{
|
||||
/* .precision = */ width,
|
||||
/* .crt = */ crt,
|
||||
/*.sign = */ sign,
|
||||
},
|
||||
}),
|
||||
/*.shape = */
|
||||
{
|
||||
/*.width = */ width,
|
||||
/*.dimensions = */ dims,
|
||||
/*.size = */ size,
|
||||
/*.sign = */ sign,
|
||||
},
|
||||
/*.chunkInfo = */ chunkInfo,
|
||||
};
|
||||
}
|
||||
if (auto lweTy = type.dyn_cast_or_null<
|
||||
mlir::concretelang::FHE::EncryptedBooleanType>()) {
|
||||
size_t width = mlir::concretelang::FHE::EncryptedBooleanType::getWidth();
|
||||
return CircuitGate{
|
||||
/* .encryption = */ std::optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
{
|
||||
/* .precision = */ width,
|
||||
/* .crt = */ std::vector<int64_t>(),
|
||||
/* .sign = */ false,
|
||||
},
|
||||
}),
|
||||
/*.shape = */
|
||||
{
|
||||
/*.width = */ width,
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
/*.sign = */ false,
|
||||
},
|
||||
/*.chunkInfo = */ std::nullopt,
|
||||
};
|
||||
}
|
||||
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
if (tensor != nullptr) {
|
||||
auto gate = gateFromMLIRType(fheContext, secretKeyID, variance, chunkInfo,
|
||||
tensor.getElementType());
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
gate->shape.dimensions = tensor.getShape().vec();
|
||||
gate->shape.size = 1;
|
||||
for (auto dimSize : gate->shape.dimensions) {
|
||||
gate->shape.size *= dimSize;
|
||||
}
|
||||
return gate;
|
||||
}
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot convert MLIR type to shape", llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0FHEContext fheContext,
|
||||
llvm::StringRef functionName, mlir::ModuleOp module,
|
||||
int bitsOfSecurity,
|
||||
llvm::Optional<ChunkInfo> chunkInfo) {
|
||||
const auto v0Curve = concrete::getSecurityCurve(bitsOfSecurity, keyFormat);
|
||||
|
||||
if (v0Curve == nullptr) {
|
||||
return StreamStringError("Cannot find security curves for ")
|
||||
<< bitsOfSecurity << "bits";
|
||||
}
|
||||
|
||||
V0Parameter &v0Param = fheContext.parameter;
|
||||
Variance inputVariance =
|
||||
v0Curve->getVariance(1, v0Param.getNBigLweDimension(), 64);
|
||||
|
||||
Variance bootstrapKeyVariance = v0Curve->getVariance(
|
||||
v0Param.glweDimension, v0Param.getPolynomialSize(), 64);
|
||||
Variance keyswitchKeyVariance = v0Curve->getVariance(1, v0Param.nSmall, 64);
|
||||
// Static client parameters from global parameters for v0
|
||||
ClientParameters c;
|
||||
|
||||
assert(c.secretKeys.size() == clientlib::BIG_KEY);
|
||||
clientlib::LweSecretKeyParam skParam;
|
||||
skParam.dimension = v0Param.getNBigLweDimension();
|
||||
c.secretKeys.push_back(skParam);
|
||||
|
||||
bool has_small_key = v0Param.nSmall != 0;
|
||||
bool has_bootstrap = v0Param.brLevel != 0;
|
||||
if (has_small_key) {
|
||||
assert(c.secretKeys.size() == clientlib::SMALL_KEY);
|
||||
clientlib::LweSecretKeyParam skParam2;
|
||||
skParam2.dimension = v0Param.nSmall;
|
||||
c.secretKeys.push_back(skParam2);
|
||||
}
|
||||
if (has_bootstrap) {
|
||||
auto inputKey = (has_small_key) ? clientlib::SMALL_KEY : clientlib::BIG_KEY;
|
||||
clientlib::BootstrapKeyParam bskParam;
|
||||
bskParam.inputSecretKeyID = inputKey;
|
||||
bskParam.outputSecretKeyID = clientlib::BIG_KEY;
|
||||
bskParam.level = v0Param.brLevel;
|
||||
bskParam.baseLog = v0Param.brLogBase;
|
||||
bskParam.glweDimension = v0Param.glweDimension;
|
||||
bskParam.variance = bootstrapKeyVariance;
|
||||
bskParam.polynomialSize = v0Param.getPolynomialSize();
|
||||
bskParam.inputLweDimension = v0Param.nSmall;
|
||||
c.bootstrapKeys.push_back(bskParam);
|
||||
}
|
||||
if (v0Param.largeInteger.has_value()) {
|
||||
clientlib::PackingKeyswitchKeyParam param;
|
||||
param.inputSecretKeyID = clientlib::BIG_KEY;
|
||||
param.outputSecretKeyID = clientlib::BIG_KEY;
|
||||
param.level = v0Param.largeInteger->wopPBS.packingKeySwitch.level;
|
||||
param.baseLog = v0Param.largeInteger->wopPBS.packingKeySwitch.baseLog;
|
||||
|
||||
param.glweDimension = v0Param.glweDimension;
|
||||
param.polynomialSize = v0Param.getPolynomialSize();
|
||||
param.inputLweDimension = v0Param.getNBigLweDimension();
|
||||
param.variance = v0Curve->getVariance(v0Param.glweDimension,
|
||||
v0Param.getPolynomialSize(), 64);
|
||||
|
||||
c.packingKeyswitchKeys.push_back(param);
|
||||
}
|
||||
if (has_small_key) {
|
||||
clientlib::KeyswitchKeyParam kskParam;
|
||||
kskParam.inputSecretKeyID = clientlib::BIG_KEY;
|
||||
kskParam.outputSecretKeyID = clientlib::SMALL_KEY;
|
||||
kskParam.level = v0Param.ksLevel;
|
||||
kskParam.baseLog = v0Param.ksLogBase;
|
||||
kskParam.variance = keyswitchKeyVariance;
|
||||
c.keyswitchKeys.push_back(kskParam);
|
||||
}
|
||||
|
||||
c.functionName = (std::string)functionName;
|
||||
// Find the input function
|
||||
auto rangeOps = module.getOps<mlir::func::FuncOp>();
|
||||
auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
|
||||
return op.getName() == functionName;
|
||||
});
|
||||
if (funcOp == rangeOps.end()) {
|
||||
return StreamStringError(
|
||||
"cannot find the function for generate client parameters: ")
|
||||
<< functionName;
|
||||
}
|
||||
|
||||
// Create input and output circuit gate parameters
|
||||
auto funcType = (*funcOp).getFunctionType();
|
||||
|
||||
auto inputs = funcType.getInputs();
|
||||
|
||||
auto gateFromType = [&](mlir::Type ty) {
|
||||
return gateFromMLIRType(fheContext, clientlib::BIG_KEY, inputVariance,
|
||||
chunkInfo, ty);
|
||||
};
|
||||
for (auto inType : inputs) {
|
||||
auto gate = gateFromType(inType);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
c.inputs.push_back(gate.get());
|
||||
}
|
||||
for (auto outType : funcType.getResults()) {
|
||||
auto gate = gateFromType(outType);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
c.outputs.push_back(gate.get());
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -254,7 +254,7 @@ llvm::Expected<V0Parameter> getParameter(optimizer::Description &descr,
|
||||
lParams.wopPBS.circuitBootstrap.baseLog = sol.cb_decomposition_base_log;
|
||||
lParams.wopPBS.circuitBootstrap.level = sol.cb_decomposition_level_count;
|
||||
lParams.wopPBS.packingKeySwitch.inputLweDimension =
|
||||
sol.internal_ks_output_lwe_dimension + 1;
|
||||
sol.internal_ks_output_lwe_dimension;
|
||||
lParams.wopPBS.packingKeySwitch.outputPolynomialSize =
|
||||
sol.glwe_polynomial_size;
|
||||
lParams.wopPBS.packingKeySwitch.level = sol.pp_decomposition_level_count;
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <iostream>
|
||||
|
||||
#include <llvm/Support/CommandLine.h>
|
||||
#include <llvm/Support/JSON.h>
|
||||
#include <llvm/Support/SourceMgr.h>
|
||||
#include <llvm/Support/ToolOutputFile.h>
|
||||
#include <mlir/Dialect/Linalg/IR/Linalg.h>
|
||||
@@ -34,6 +35,7 @@
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Encodings.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/JITSupport.h"
|
||||
#include "concretelang/Support/LLVMEmitFile.h"
|
||||
@@ -43,12 +45,14 @@
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
namespace clientlib = concretelang::clientlib;
|
||||
namespace encodings = mlir::concretelang::encodings;
|
||||
|
||||
enum Action {
|
||||
ROUND_TRIP,
|
||||
DUMP_FHE,
|
||||
DUMP_FHE_NO_LINALG,
|
||||
DUMP_TFHE,
|
||||
DUMP_NORMALIZED_TFHE,
|
||||
DUMP_PARAMETRIZED_TFHE,
|
||||
DUMP_BATCHED_TFHE,
|
||||
DUMP_CONCRETE,
|
||||
@@ -124,8 +128,12 @@ static llvm::cl::opt<enum Action> action(
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_FHE_NO_LINALG,
|
||||
"dump-fhe-no-linalg",
|
||||
"Lower FHELinalg to FHE and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_TFHE, "dump-tfhe",
|
||||
"Lower to TFHE and dump result")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(Action::DUMP_TFHE, "dump-tfhe",
|
||||
"Lower to unparameterized TFHE and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_NORMALIZED_TFHE,
|
||||
"dump-normalized-tfhe",
|
||||
"Lower to normalized TFHE and dump result")),
|
||||
llvm::cl::values(clEnumValN(
|
||||
Action::DUMP_PARAMETRIZED_TFHE, "dump-parametrized-tfhe",
|
||||
"Lower to TFHE, parametrize TFHE operations and dump result")),
|
||||
@@ -327,6 +335,12 @@ llvm::cl::list<int64_t> largeIntegerCircuitBootstrap(
|
||||
"(experimental) [level, baseLog]"),
|
||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
|
||||
|
||||
llvm::cl::opt<std::string> circuitEncodings(
|
||||
"circuit-encodings",
|
||||
llvm::cl::desc("Specify the input and output encodings of the circuit, "
|
||||
"using the JSON representation."),
|
||||
llvm::cl::init(std::string{}));
|
||||
|
||||
} // namespace cmdline
|
||||
|
||||
namespace llvm {
|
||||
@@ -447,6 +461,20 @@ cmdlineCompilationOptions() {
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
if (!cmdline::circuitEncodings.empty()) {
|
||||
auto jsonString = cmdline::circuitEncodings.getValue();
|
||||
auto maybeEncodings =
|
||||
llvm::json::parse<encodings::CircuitEncodings>(jsonString);
|
||||
if (auto err = maybeEncodings.takeError()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"Failed to parse the --circuit-encodings option.",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
options.encodings = maybeEncodings.get();
|
||||
} else {
|
||||
options.encodings = std::nullopt;
|
||||
}
|
||||
|
||||
return options;
|
||||
}
|
||||
|
||||
@@ -532,6 +560,9 @@ mlir::LogicalResult processInputBuffer(
|
||||
case Action::DUMP_TFHE:
|
||||
target = mlir::concretelang::CompilerEngine::Target::TFHE;
|
||||
break;
|
||||
case Action::DUMP_NORMALIZED_TFHE:
|
||||
target = mlir::concretelang::CompilerEngine::Target::NORMALIZED_TFHE;
|
||||
break;
|
||||
case Action::DUMP_PARAMETRIZED_TFHE:
|
||||
target = mlir::concretelang::CompilerEngine::Target::PARAMETRIZED_TFHE;
|
||||
break;
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
//CHECK: llvm.call @memref_bootstrap_lwe_cuda_u64
|
||||
func.func @main(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
|
||||
%cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
|
||||
%0 = "Concrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, level = 5 : i32, lwe_dim_in = 1025 : i32, lwe_dim_out = 576 : i32} : (tensor<1025xi64>) -> tensor<576xi64>
|
||||
%1 = "Concrete.bootstrap_lwe_tensor"(%0, %cst) {baseLog = 2 : i32, level = 5 : i32, polySize = 1024: i32, glweDimension = 1 : i32, inputLweDim = 576 : i32, outPrecision = 2 : i32} : (tensor<576xi64>, tensor<4xi64>) -> tensor<1025xi64>
|
||||
%0 = "Concrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, kskIndex = 0 : i32, level = 5 : i32, lwe_dim_in = 1025 : i32, lwe_dim_out = 576 : i32} : (tensor<1025xi64>) -> tensor<576xi64>
|
||||
%1 = "Concrete.bootstrap_lwe_tensor"(%0, %cst) {baseLog = 2 : i32, bskIndex = 0 : i32, level = 5 : i32, polySize = 1024: i32, glweDimension = 1 : i32, inputLweDim = 576 : i32, outPrecision = 2 : i32} : (tensor<576xi64>, tensor<4xi64>) -> tensor<1025xi64>
|
||||
return %1 : tensor<1025xi64>
|
||||
}
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @add_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk[?]>>, %[[Varg1:.*]]: tensor<5x!TFHE.glwe<sk[?]>>) -> tensor<5x!TFHE.glwe<sk[?]>> {
|
||||
//CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK: func.func @add_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk?>>, %[[Varg1:.*]]: tensor<5x!TFHE.glwe<sk?>>) -> tensor<5x!TFHE.glwe<sk?>> {
|
||||
//CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
|
||||
//CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
|
||||
//CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[V2:.*]] = tensor.extract %[[Varg0]][%[[Varg2]]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[V3:.*]] = tensor.extract %[[Varg1]][%[[Varg2]]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[V4:.*]] = "TFHE.add_glwe"(%[[V2]], %[[V3]]) : (!TFHE.glwe<sk[?]>, !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
//CHECK-NEXT: %[[V5:.*]] = tensor.insert %[[V4]] into %[[Varg3]][%[[Varg2]]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[V5]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[V2:.*]] = tensor.extract %[[Varg0]][%[[Varg2]]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[V3:.*]] = tensor.extract %[[Varg1]][%[[Varg2]]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[V4:.*]] = "TFHE.add_glwe"(%[[V2]], %[[V3]]) : (!TFHE.glwe<sk?>, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
//CHECK-NEXT: %[[V5:.*]] = tensor.insert %[[V4]] into %[[Varg3]][%[[Varg2]]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: scf.yield %[[V5]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: return %[[V1]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: return %[[V1]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @add_eint_int(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk[?]>>) -> tensor<5x!TFHE.glwe<sk[?]>> {
|
||||
// CHECK: func.func @add_eint_int(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk?>>) -> tensor<5x!TFHE.glwe<sk?>> {
|
||||
// CHECK-NEXT: %[[Vc1_i8:.*]] = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %[[V0:.*]] = arith.extsi %[[Vc1_i8]] : i8 to i64
|
||||
// CHECK-NEXT: %[[V1:.*]] = "TFHE.encode_plaintext_with_crt"(%[[V0]]) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
// CHECK-NEXT: %[[V2:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V2:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
|
||||
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V2]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V4:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V2]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V4:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: %[[V5:.*]] = tensor.extract %[[V1]][%[[Varg1]]] : tensor<5xi64>
|
||||
// CHECK-NEXT: %[[V6:.*]] = "TFHE.add_glwe_int"(%[[V4]], %[[V5]]) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %[[V7:.*]] = tensor.insert %[[V6]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[V7]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V6:.*]] = "TFHE.add_glwe_int"(%[[V4]], %[[V5]]) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: %[[V7:.*]] = tensor.insert %[[V6]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: scf.yield %[[V7]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[V3]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: return %[[V3]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: tensor<5x!TFHE.glwe<sk[?]>>, %arg1: tensor<4xi64>) -> tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: tensor<5x!TFHE.glwe<sk?>>, %arg1: tensor<4xi64>) -> tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_lut_for_crt_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64>
|
||||
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bsk = #TFHE.bsk<sk[?], sk[?], -1, -1, -1, -1>, cbsBaseLog = -1 : i32, cbsLevels = -1 : i32, crtDecomposition = [], ksk = #TFHE.ksk<sk[?], sk[?], -1, -1>, pksk = #TFHE.pksk<sk[?], sk[?], -1, -1, -1, -1>} : (tensor<5x!TFHE.glwe<sk[?]>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bsk = #TFHE.bsk<sk?, sk?, -1, -1, -1, -1>, cbsBaseLog = -1 : i32, cbsLevels = -1 : i32, crtDecomposition = [], ksk = #TFHE.ksk<sk?, sk?, -1, -1>, pksk = #TFHE.pksk<sk?, sk?, -1, -1, -1, -1, -1>} : (tensor<5x!TFHE.glwe<sk?>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<sk?>>
|
||||
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
|
||||
return %1: !FHE.eint<3>
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table_cst(%arg0: tensor<5x!TFHE.glwe<sk[?]>>) -> tensor<5x!TFHE.glwe<sk[?]>> {
|
||||
// CHECK: func.func @apply_lookup_table_cst(%arg0: tensor<5x!TFHE.glwe<sk?>>) -> tensor<5x!TFHE.glwe<sk?>> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_lut_for_crt_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<128xi64>) -> tensor<5x8192xi64>
|
||||
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bsk = #TFHE.bsk<sk[?], sk[?], -1, -1, -1, -1>, cbsBaseLog = -1 : i32, cbsLevels = -1 : i32, crtDecomposition = [], ksk = #TFHE.ksk<sk[?], sk[?], -1, -1>, pksk = #TFHE.pksk<sk[?], sk[?], -1, -1, -1, -1>} : (tensor<5x!TFHE.glwe<sk[?]>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bsk = #TFHE.bsk<sk?, sk?, -1, -1, -1, -1>, cbsBaseLog = -1 : i32, cbsLevels = -1 : i32, crtDecomposition = [], ksk = #TFHE.ksk<sk?, sk?, -1, -1>, pksk = #TFHE.pksk<sk?, sk?, -1, -1, -1, -1, -1>} : (tensor<5x!TFHE.glwe<sk?>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<sk?>>
|
||||
func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @conv2d(%[[Varg0:.*]]: tensor<100x3x28x28x5x!TFHE.glwe<sk[?]>>, %[[Varg1:.*]]: tensor<4x3x14x14xi3>, %[[Varg2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>> {
|
||||
//CHECK: func.func @conv2d(%[[Varg0:.*]]: tensor<100x3x28x28x5x!TFHE.glwe<sk?>>, %[[Varg1:.*]]: tensor<4x3x14x14xi3>, %[[Varg2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x5x!TFHE.glwe<sk?>> {
|
||||
//CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vc100:.*]] = arith.constant 100 : index
|
||||
//CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
|
||||
@@ -8,90 +8,90 @@
|
||||
//CHECK-NEXT: %[[Vc15:.*]] = arith.constant 15 : index
|
||||
//CHECK-NEXT: %[[Vc3:.*]] = arith.constant 3 : index
|
||||
//CHECK-NEXT: %[[Vc14:.*]] = arith.constant 14 : index
|
||||
//CHECK-NEXT: %[[V0:.*]] = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V0]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[V0:.*]] = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V0]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg2]]{{\[}}%[[Varg5]]{{\]}} : tensor<4xi3>
|
||||
//CHECK-NEXT: %[[Vc0_0:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_0]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>> to tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_0]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x4x15x15x5x!TFHE.glwe<sk?>> to tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[V6:.*]] = arith.extsi %[[Vextracted]] : i3 to i64
|
||||
//CHECK-NEXT: %[[V7:.*]] = "TFHE.encode_plaintext_with_crt"(%[[V6]]) {mods = {{\[2, 3, 5, 7, 11\], modsProd}} = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
//CHECK-NEXT: %[[V8:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[V8:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[Vc0_1:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vc1_2:.*]] = arith.constant 1 : index
|
||||
//CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
|
||||
//CHECK-NEXT: %[[V9:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0_1]] to %[[Vc5]] step %[[Vc1_2]] iter_args(%[[Varg12:.*]] = %[[V8]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[Vextracted_4:.*]] = tensor.extract %[[Vextracted_slice]]{{\[}}%[[Varg11]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[V9:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0_1]] to %[[Vc5]] step %[[Vc1_2]] iter_args(%[[Varg12:.*]] = %[[V8]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[Vextracted_4:.*]] = tensor.extract %[[Vextracted_slice]]{{\[}}%[[Varg11]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[Vextracted_5:.*]] = tensor.extract %[[V7]]{{\[}}%[[Varg11]]{{\]}} : tensor<5xi64>
|
||||
//CHECK-NEXT: %[[V10:.*]] = "TFHE.add_glwe_int"(%[[Vextracted_4]], %[[Vextracted_5]]) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
|
||||
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V10]] into %[[Varg12]]{{\[}}%[[Varg11]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[V10:.*]] = "TFHE.add_glwe_int"(%[[Vextracted_4]], %[[Vextracted_5]]) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
|
||||
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V10]] into %[[Varg12]]{{\[}}%[[Varg11]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: %[[Vc0_3:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vinserted_slice:.*]] = tensor.insert_slice %[[V9]] into %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_3]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<5x!TFHE.glwe<sk[?]>> into tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted_slice]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[Vinserted_slice:.*]] = tensor.insert_slice %[[V9]] into %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_3]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<5x!TFHE.glwe<sk?>> into tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted_slice]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V1]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[V6:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0]] to %[[Vc3]] step %[[Vc1]] iter_args(%[[Varg12:.*]] = %[[Varg10]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[V7:.*]] = scf.for %[[Varg13:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg14:.*]] = %[[Varg12]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[V8:.*]] = scf.for %[[Varg15:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg16:.*]] = %[[Varg14]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V1]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[V6:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0]] to %[[Vc3]] step %[[Vc1]] iter_args(%[[Varg12:.*]] = %[[Varg10]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[V7:.*]] = scf.for %[[Varg13:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg14:.*]] = %[[Varg12]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[V8:.*]] = scf.for %[[Varg15:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg16:.*]] = %[[Varg14]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[V9:.*]] = affine.apply #map(%[[Varg7]], %[[Varg13]])
|
||||
//CHECK-NEXT: %[[V10:.*]] = affine.apply #map(%[[Varg9]], %[[Varg15]])
|
||||
//CHECK-NEXT: %[[Vc0_0:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[Varg3]], %[[Varg11]], %[[V9]], %[[V10]], %[[Vc0_0]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x3x28x28x5x!TFHE.glwe<sk[?]>> to tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[Varg3]], %[[Varg11]], %[[V9]], %[[V10]], %[[Vc0_0]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x3x28x28x5x!TFHE.glwe<sk?>> to tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg1]]{{\[}}%[[Varg5]], %[[Varg11]], %[[Varg13]], %[[Varg15]]{{\]}} : tensor<4x3x14x14xi3>
|
||||
//CHECK-NEXT: %[[Vc0_1:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vextracted_slice_2:.*]] = tensor.extract_slice %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_1]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>> to tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[Vextracted_slice_2:.*]] = tensor.extract_slice %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_1]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x4x15x15x5x!TFHE.glwe<sk?>> to tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[V11:.*]] = arith.extsi %[[Vextracted]] : i3 to i64
|
||||
//CHECK-NEXT: %[[V12:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[V12:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[Vc0_3:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vc1_4:.*]] = arith.constant 1 : index
|
||||
//CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
|
||||
//CHECK-NEXT: %[[V13:.*]] = scf.for %[[Varg17:.*]] = %[[Vc0_3]] to %[[Vc5]] step %[[Vc1_4]] iter_args(%[[Varg18:.*]] = %[[V12]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[Vextracted_9:.*]] = tensor.extract %[[Vextracted_slice]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[V16:.*]] = "TFHE.mul_glwe_int"(%[[Vextracted_9]], %[[V11]]) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
|
||||
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V16]] into %[[Varg18]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[V13:.*]] = scf.for %[[Varg17:.*]] = %[[Vc0_3]] to %[[Vc5]] step %[[Vc1_4]] iter_args(%[[Varg18:.*]] = %[[V12]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[Vextracted_9:.*]] = tensor.extract %[[Vextracted_slice]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[V16:.*]] = "TFHE.mul_glwe_int"(%[[Vextracted_9]], %[[V11]]) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
|
||||
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V16]] into %[[Varg18]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: %[[V14:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[V14:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[Vc0_5:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vc1_6:.*]] = arith.constant 1 : index
|
||||
//CHECK-NEXT: %[[Vc5_7:.*]] = arith.constant 5 : index
|
||||
//CHECK-NEXT: %[[V15:.*]] = scf.for %[[Varg17:.*]] = %[[Vc0_5]] to %[[Vc5_7]] step %[[Vc1_6]] iter_args(%[[Varg18:.*]] = %[[V14]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
|
||||
//CHECK-NEXT: %[[Vextracted_9:.*]] = tensor.extract %[[Vextracted_slice_2]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[Vextracted_10:.*]] = tensor.extract %[[V13]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[V16:.*]] = "TFHE.add_glwe"(%[[Vextracted_9]], %[[Vextracted_10]]) : (!TFHE.glwe<sk[?]>, !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V16]] into %[[Varg18]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[V15:.*]] = scf.for %[[Varg17:.*]] = %[[Vc0_5]] to %[[Vc5_7]] step %[[Vc1_6]] iter_args(%[[Varg18:.*]] = %[[V14]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
|
||||
//CHECK-NEXT: %[[Vextracted_9:.*]] = tensor.extract %[[Vextracted_slice_2]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[Vextracted_10:.*]] = tensor.extract %[[V13]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: %[[V16:.*]] = "TFHE.add_glwe"(%[[Vextracted_9]], %[[Vextracted_10]]) : (!TFHE.glwe<sk?>, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V16]] into %[[Varg18]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: %[[Vc0_8:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[Vinserted_slice:.*]] = tensor.insert_slice %[[V15]] into %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_8]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<5x!TFHE.glwe<sk[?]>> into tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted_slice]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: %[[Vinserted_slice:.*]] = tensor.insert_slice %[[V15]] into %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_8]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<5x!TFHE.glwe<sk?>> into tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: scf.yield %[[Vinserted_slice]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %[[V8]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[V8]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %[[V7]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[V7]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %[[V6]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[V6]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: return %[[V2]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
|
||||
//CHECK-NEXT: return %[[V2]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
|
||||
//CHECK-NEXT: }
|
||||
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @mul_eint_int(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk[?]>>) -> tensor<5x!TFHE.glwe<sk[?]>> {
|
||||
// CHECK: func.func @mul_eint_int(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk?>>) -> tensor<5x!TFHE.glwe<sk?>> {
|
||||
// CHECK-NEXT: %[[Vc2_i8:.*]] = arith.constant 2 : i8
|
||||
// CHECK-NEXT: %[[V0:.*]] = arith.extsi %[[Vc2_i8]] : i8 to i64
|
||||
// CHECK-NEXT: %[[V1:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V1:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
|
||||
// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V1]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V3:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V4:.*]] = "TFHE.mul_glwe_int"(%[[V3]], %[[V0]]) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %[[V5:.*]] = tensor.insert %[[V4]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[V5]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V1]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V3:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: %[[V4:.*]] = "TFHE.mul_glwe_int"(%[[V3]], %[[V0]]) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: %[[V5:.*]] = tensor.insert %[[V4]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: scf.yield %[[V5]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[V2]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: return %[[V2]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%0 = arith.constant 2 : i8
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @neg_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk[?]>>) -> tensor<5x!TFHE.glwe<sk[?]>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK: func.func @neg_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk?>>) -> tensor<5x!TFHE.glwe<sk?>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
|
||||
// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V0]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V2:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "TFHE.neg_glwe"(%[[V2]]) : (!TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %[[V4:.*]] = tensor.insert %[[V3]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[V4]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V0]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V2:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "TFHE.neg_glwe"(%[[V2]]) : (!TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: %[[V4:.*]] = tensor.insert %[[V3]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: scf.yield %[[V4]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[V1]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: return %[[V1]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @sub_int_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk[?]>>) -> tensor<5x!TFHE.glwe<sk[?]>> {
|
||||
// CHECK: func.func @sub_int_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk?>>) -> tensor<5x!TFHE.glwe<sk?>> {
|
||||
// CHECK-NEXT: %[[Vc1_i8:.*]] = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %[[V0:.*]] = arith.extsi %[[Vc1_i8]] : i8 to i64
|
||||
// CHECK-NEXT: %[[V1:.*]] = "TFHE.encode_plaintext_with_crt"(%[[V0]]) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
// CHECK-NEXT: %[[V2:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V2:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
|
||||
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V2]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V4:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V2]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V4:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: %[[V5:.*]] = tensor.extract %[[V1]][%[[Varg1]]] : tensor<5xi64>
|
||||
// CHECK-NEXT: %[[V6:.*]] = "TFHE.sub_int_glwe"(%[[V5]], %[[V4]]) : (i64, !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %[[V7:.*]] = tensor.insert %[[V6]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[V7]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V6:.*]] = "TFHE.sub_int_glwe"(%[[V5]], %[[V4]]) : (i64, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: %[[V7:.*]] = tensor.insert %[[V6]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: scf.yield %[[V7]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[V3]] : tensor<5x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: return %[[V3]] : tensor<5x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
func.func @sub_int_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%0 = arith.constant 1 : i8
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @add_eint(%arg0: !TFHE.glwe<sk[?]>, %arg1: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-LABEL: func.func @add_eint(%arg0: !TFHE.glwe<sk?>, %arg1: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) : (!TFHE.glwe<sk[?]>, !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: return %[[V1]] : !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) : (!TFHE.glwe<sk?>, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: return %[[V1]] : !TFHE.glwe<sk?>
|
||||
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @add_eint_int(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-LABEL: func.func @add_eint_int(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %0 = arith.extsi %c1_i8 : i8 to i64
|
||||
// CHECK-NEXT: %c56_i64 = arith.constant 56 : i64
|
||||
// CHECK-NEXT: %1 = arith.shli %0, %c56_i64 : i64
|
||||
// CHECK-NEXT: %2 = "TFHE.add_glwe_int"(%arg0, %1) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: return %2 : !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %2 = "TFHE.add_glwe_int"(%arg0, %1) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: return %2 : !TFHE.glwe<sk?>
|
||||
|
||||
|
||||
%0 = arith.constant 1 : i8
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: !TFHE.glwe<sk[?]>, %arg1: tensor<4xi64>) -> !TFHE.glwe<sk[?]> {
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: !TFHE.glwe<sk?>, %arg1: tensor<4xi64>) -> !TFHE.glwe<sk?> {
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {isSigned = false, outputBits = 3 : i32, polySize = 256 : i32} : (tensor<4xi64>) -> tensor<256xi64>
|
||||
// CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk<sk[?], sk[?], -1, -1>} : (!TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {key = #TFHE.bsk<sk[?], sk[?], -1, -1, -1, -1>} : (!TFHE.glwe<sk[?]>, tensor<256xi64>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: return %2 : !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk<sk?, sk?, -1, -1>} : (!TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {key = #TFHE.bsk<sk?, sk?, -1, -1, -1, -1>} : (!TFHE.glwe<sk?>, tensor<256xi64>) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: return %2 : !TFHE.glwe<sk?>
|
||||
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
|
||||
return %1: !FHE.eint<3>
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]> {
|
||||
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?> {
|
||||
|
||||
//CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
//CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%cst) {isSigned = false, outputBits = 7 : i32, polySize = 8192 : i32} : (tensor<128xi64>) -> tensor<8192xi64>
|
||||
//CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk<sk[?], sk[?], -1, -1>} : (!TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
//CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {key = #TFHE.bsk<sk[?], sk[?], -1, -1, -1, -1>} : (!TFHE.glwe<sk[?]>, tensor<8192xi64>) -> !TFHE.glwe<sk[?]>
|
||||
//CHECK-NEXT: return %2 : !TFHE.glwe<sk[?]>
|
||||
//CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk<sk?, sk?, -1, -1>} : (!TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
//CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {key = #TFHE.bsk<sk?, sk?, -1, -1, -1, -1>} : (!TFHE.glwe<sk?>, tensor<8192xi64>) -> !TFHE.glwe<sk?>
|
||||
//CHECK-NEXT: return %2 : !TFHE.glwe<sk?>
|
||||
func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @conv2d(%[[Varg0:.*]]: tensor<100x3x28x28x!TFHE.glwe<sk[?]>>, %[[Varg1:.*]]: tensor<4x3x14x14xi3>, %[[Varg2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x!TFHE.glwe<sk[?]>> {
|
||||
// CHECK: func.func @conv2d(%[[Varg0:.*]]: tensor<100x3x28x28x!TFHE.glwe<sk?>>, %[[Varg1:.*]]: tensor<4x3x14x14xi3>, %[[Varg2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x!TFHE.glwe<sk?>> {
|
||||
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[Vc100:.*]] = arith.constant 100 : index
|
||||
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
|
||||
@@ -8,57 +8,57 @@
|
||||
// CHECK-NEXT: %[[Vc15:.*]] = arith.constant 15 : index
|
||||
// CHECK-NEXT: %[[Vc3:.*]] = arith.constant 3 : index
|
||||
// CHECK-NEXT: %[[Vc14:.*]] = arith.constant 14 : index
|
||||
// CHECK-NEXT: %[[V0:.*]] = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V0]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V0]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg2]]{{\[}}%[[Varg5]]{{\]}} : tensor<4xi3>
|
||||
// CHECK-NEXT: %[[Vextracted_0:.*]] = tensor.extract %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[Vextracted_0:.*]] = tensor.extract %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: %[[V6:.*]] = arith.extsi %[[Vextracted]] : i3 to i64
|
||||
// CHECK-NEXT: %[[Vc61_i64:.*]] = arith.constant 61 : i64
|
||||
// CHECK-NEXT: %[[V7:.*]] = arith.shli %[[V6]], %[[Vc61_i64]] : i64
|
||||
// CHECK-NEXT: %[[V8:.*]] = "TFHE.add_glwe_int"(%[[Vextracted_0]], %[[V7]]) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V8]] into %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V8:.*]] = "TFHE.add_glwe_int"(%[[Vextracted_0]], %[[V7]]) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V8]] into %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V1]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V6:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0]] to %[[Vc3]] step %[[Vc1]] iter_args(%[[Varg12:.*]] = %[[Varg10]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V7:.*]] = scf.for %[[Varg13:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg14:.*]] = %[[Varg12]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V8:.*]] = scf.for %[[Varg15:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg16:.*]] = %[[Varg14]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
|
||||
// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V1]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V6:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0]] to %[[Vc3]] step %[[Vc1]] iter_args(%[[Varg12:.*]] = %[[Varg10]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V7:.*]] = scf.for %[[Varg13:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg14:.*]] = %[[Varg12]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V8:.*]] = scf.for %[[Varg15:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg16:.*]] = %[[Varg14]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
|
||||
// CHECK-NEXT: %[[V9:.*]] = affine.apply #map(%[[Varg7]], %[[Varg13]])
|
||||
// CHECK-NEXT: %[[V10:.*]] = affine.apply #map(%[[Varg9]], %[[Varg15]])
|
||||
// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg0]]{{\[}}%[[Varg3]], %[[Varg11]], %[[V9]], %[[V10]]{{\]}} : tensor<100x3x28x28x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg0]]{{\[}}%[[Varg3]], %[[Varg11]], %[[V9]], %[[V10]]{{\]}} : tensor<100x3x28x28x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: %[[Vextracted_0:.*]] = tensor.extract %[[Varg1]]{{\[}}%[[Varg5]], %[[Varg11]], %[[Varg13]], %[[Varg15]]{{\]}} : tensor<4x3x14x14xi3>
|
||||
// CHECK-NEXT: %[[Vextracted_1:.*]] = tensor.extract %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[Vextracted_1:.*]] = tensor.extract %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: %[[V11:.*]] = arith.extsi %[[Vextracted_0]] : i3 to i64
|
||||
// CHECK-NEXT: %[[V12:.*]] = "TFHE.mul_glwe_int"(%[[Vextracted]], %[[V11]]) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %[[V13:.*]] = "TFHE.add_glwe"(%[[Vextracted_1]], %[[V12]]) : (!TFHE.glwe<sk[?]>, !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V13]] into %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: %[[V12:.*]] = "TFHE.mul_glwe_int"(%[[Vextracted]], %[[V11]]) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: %[[V13:.*]] = "TFHE.add_glwe"(%[[Vextracted_1]], %[[V12]]) : (!TFHE.glwe<sk?>, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V13]] into %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[V8]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[V8]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[V7]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[V7]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[V6]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[V6]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[V2]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
|
||||
// CHECK-NEXT: return %[[V2]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
|
||||
// CHECK-NEXT: }
|
||||
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @mul_eint_int(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-LABEL: func.func @mul_eint_int(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %c2_i8 = arith.constant 2 : i8
|
||||
// CHECK-NEXT: %0 = arith.extsi %c2_i8 : i8 to i64
|
||||
// CHECK-NEXT: %1 = "TFHE.mul_glwe_int"(%arg0, %0) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: return %1 : !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %1 = "TFHE.mul_glwe_int"(%arg0, %0) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: return %1 : !TFHE.glwe<sk?>
|
||||
|
||||
%0 = arith.constant 2 : i8
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @neg_eint(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-LABEL: func.func @neg_eint(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: return %0 : !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: return %0 : !TFHE.glwe<sk?>
|
||||
|
||||
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @not(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-LABEL: func.func @not(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
func.func @not(%arg0: !FHE.ebool) -> !FHE.ebool {
|
||||
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: return %0 : !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: return %0 : !TFHE.glwe<sk?>
|
||||
%1 = "FHE.not"(%arg0) : (!FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @sub_int_eint(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-LABEL: func.func @sub_int_eint(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
func.func @sub_int_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %0 = arith.extsi %c1_i8 : i8 to i64
|
||||
// CHECK-NEXT: %c56_i64 = arith.constant 56 : i64
|
||||
// CHECK-NEXT: %1 = arith.shli %0, %c56_i64 : i64
|
||||
// CHECK-NEXT: %2 = "TFHE.sub_int_glwe"(%1, %arg0) : (i64, !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: return %2 : !TFHE.glwe<sk[?]>
|
||||
// CHECK-NEXT: %2 = "TFHE.sub_int_glwe"(%1, %arg0) : (i64, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
// CHECK-NEXT: return %2 : !TFHE.glwe<sk?>
|
||||
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "FHE.sub_int_eint"(%0, %arg0): (i8, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
// RUN: concretecompiler --passes tfhe-global-parametrization --action=dump-std --optimizer-v0 --v0-parameter=2,10,750,1,23,3,4 --v0-constraint=4,0 %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --action=dump-tfhe-parameterized --optimizer-v0 --v0-parameter=2,10,750,1,23,3,4 --v0-constraint=4,0 %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @main(%[[A0:.*]]: !TFHE.glwe<sk[1]<1,2048>>) -> !TFHE.glwe<sk[1]<1,2048>> {
|
||||
//CHECK: func.func @main(%[[A0:.*]]: !TFHE.glwe<sk<0,1,2048>>) -> !TFHE.glwe<sk<0,1,2048>> {
|
||||
//CHECK-NEXT: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
//CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {key = #TFHE.ksk<sk[1]<1,2048>, sk[3]<1,750>, 3, 4>} : (!TFHE.glwe<sk[1]<1,2048>>) -> !TFHE.glwe<sk[3]<1,750>>
|
||||
//CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {key = #TFHE.bsk<sk[3]<1,750>, sk[1]<1,2048>, 1024, 2, 1, 23>} : (!TFHE.glwe<sk[3]<1,750>>, tensor<16xi64>) -> !TFHE.glwe<sk[1]<1,2048>>
|
||||
//CHECK-NEXT: return %[[V2]] : !TFHE.glwe<sk[1]<1,2048>>
|
||||
//CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {key = #TFHE.ksk<sk<0,1,2048>, sk<1,1,750>, 3, 4>} : (!TFHE.glwe<sk<0,1,2048>>) -> !TFHE.glwe<sk<1,1,750>>
|
||||
//CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {key = #TFHE.bsk<sk<1,1,750>, sk<0,1,2048>, 1024, 2, 1, 23>} : (!TFHE.glwe<sk<1,1,750>>, tensor<16xi64>) -> !TFHE.glwe<sk<0,1,2048>>
|
||||
//CHECK-NEXT: return %[[V2]] : !TFHE.glwe<sk<0,1,2048>>
|
||||
//CHECK-NEXT: }
|
||||
func.func @main(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]> {
|
||||
func.func @main(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?> {
|
||||
%cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
%1 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk<sk[?], sk[?], -1, -1>} : (!TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
%2 = "TFHE.bootstrap_glwe"(%1, %cst) {key = #TFHE.bsk<sk[?], sk[?], -1, -1, -1, -1>} : (!TFHE.glwe<sk[?]>, tensor<16xi64>) -> !TFHE.glwe<sk[?]>
|
||||
return %2 : !TFHE.glwe<sk[?]>
|
||||
%1 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk<sk?, sk?, -1, -1>} : (!TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
%2 = "TFHE.bootstrap_glwe"(%1, %cst) {key = #TFHE.bsk<sk?, sk?, -1, -1, -1, -1>} : (!TFHE.glwe<sk?>, tensor<16xi64>) -> !TFHE.glwe<sk?>
|
||||
return %2 : !TFHE.glwe<sk?>
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
//CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: tensor<601xi64>) -> tensor<1025xi64> {
|
||||
//CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
//CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe_tensor"(%arg0, %cst) {baseLog = 1 : i32, glweDimension = 1 : i32, inputLweDim = 600 : i32, level = 3 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<128xi64>) -> tensor<1025xi64>
|
||||
//CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe_tensor"(%arg0, %cst) {baseLog = 1 : i32, bskIndex = -1 : i32, glweDimension = 1 : i32, inputLweDim = 600 : i32, level = 3 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<128xi64>) -> tensor<1025xi64>
|
||||
//CHECK-NEXT: return %[[V1]] : tensor<1025xi64>
|
||||
//CHECK-NEXT: }
|
||||
func.func @bootstrap_lwe(%ciphertext: !TFHE.glwe<sk[1]<1,600>>) -> !TFHE.glwe<sk[5]<1,1024>> {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @keyswitch_glwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<568xi64> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "Concrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 3 : i32, level = 2 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 567 : i32} : (tensor<1025xi64>) -> tensor<568xi64>
|
||||
// CHECK-NEXT: %[[V0:.*]] = "Concrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 3 : i32, kskIndex = -1 : i32, level = 2 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 567 : i32} : (tensor<1025xi64>) -> tensor<568xi64>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<568xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @keyswitch_glwe(%arg0: !TFHE.glwe<sk[1]<1,1024>>) -> !TFHE.glwe<sk[3]<1,567>> {
|
||||
|
||||
@@ -39,20 +39,20 @@ func.func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
}
|
||||
|
||||
//CHECK: func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "Concrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64>
|
||||
//CHECK: %[[V0:.*]] = "Concrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, bskIndex = 0 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
//CHECK: }
|
||||
func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> {
|
||||
%0 = "Concrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> (tensor<2049xi64>)
|
||||
%0 = "Concrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, bskIndex = 0 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> (tensor<2049xi64>)
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @keyswitch_lwe(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "Concrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> tensor<2049xi64>
|
||||
//CHECK: %[[V0:.*]] = "Concrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, kskIndex = 0 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
//CHECK: }
|
||||
func.func @keyswitch_lwe(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
%0 = "Concrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>)
|
||||
%0 = "Concrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, kskIndex = 0 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>)
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
@@ -83,13 +83,13 @@ func.func @negate_lwe_ciphertext_buffer(%arg0: memref<2049xi64>, %result: memref
|
||||
}
|
||||
|
||||
func.func @bootstrap_lwe_buffer(%arg0: memref<2049xi64>, %arg1: memref<16xi64>, %result: memref<2049xi64>) {
|
||||
//CHECK: "Concrete.bootstrap_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> ()
|
||||
"Concrete.bootstrap_lwe_buffer"(%result, %arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> ()
|
||||
//CHECK: "Concrete.bootstrap_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) {baseLog = 2 : i32, bskIndex = 0 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> ()
|
||||
"Concrete.bootstrap_lwe_buffer"(%result, %arg0, %arg1) {baseLog = 2 : i32, bskIndex = 0 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
func.func @keyswitch_lwe_buffer(%arg0: memref<2049xi64>, %result: memref<2049xi64>) {
|
||||
//CHECK: "Concrete.keyswitch_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> ()
|
||||
"Concrete.keyswitch_lwe_buffer"(%result, %arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> ()
|
||||
//CHECK: "Concrete.keyswitch_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) {baseLog = 2 : i32, kskIndex = 0 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> ()
|
||||
"Concrete.keyswitch_lwe_buffer"(%result, %arg0) {baseLog = 2 : i32, kskIndex = 0 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ func.func @glwe_0(%arg0: !TFHE.glwe<sk[1]<12,1024>>) -> !TFHE.glwe<sk[1]<12,1024
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @glwe_1(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
|
||||
func.func @glwe_1(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]> {
|
||||
// CHECK-LABEL: return %arg0 : !TFHE.glwe<sk[?]>
|
||||
return %arg0: !TFHE.glwe<sk[?]>
|
||||
// CHECK-LABEL: func.func @glwe_1(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
|
||||
func.func @glwe_1(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?> {
|
||||
// CHECK-LABEL: return %arg0 : !TFHE.glwe<sk?>
|
||||
return %arg0: !TFHE.glwe<sk?>
|
||||
}
|
||||
|
||||
@@ -418,9 +418,7 @@ def test_compile_and_run_invalid_arg_number(
|
||||
)
|
||||
def test_compile_invalid(mlir_input):
|
||||
engine = JITSupport.new()
|
||||
with pytest.raises(
|
||||
RuntimeError, match=r"Could not find existing crypto parameters for"
|
||||
):
|
||||
with pytest.raises(RuntimeError, match=r"Function not found, name='main'"):
|
||||
engine.compile(mlir_input)
|
||||
|
||||
|
||||
|
||||
@@ -3,3 +3,4 @@ add_custom_target(ConcretelangUnitTests)
|
||||
add_subdirectory(ClientLib)
|
||||
add_subdirectory(SDFG)
|
||||
add_subdirectory(TestLib)
|
||||
add_subdirectory(Encodings)
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
add_custom_target(EncodingsUnitTests)
|
||||
|
||||
add_dependencies(ConcretelangUnitTests EncodingsUnitTests)
|
||||
|
||||
function(add_concretecompiler_lib_test test_name)
|
||||
add_unittest(EncodingsUnitTests ${test_name} ${ARGN})
|
||||
target_link_libraries(${test_name} PRIVATE ConcretelangSupport)
|
||||
set_source_files_properties(${ARGN} PROPERTIES COMPILE_FLAGS "-fno-rtti")
|
||||
endfunction()
|
||||
|
||||
if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
link_libraries(
|
||||
# usefull for old gcc versions
|
||||
-Wl,--allow-multiple-definition # static concrete-optimizer and concrete shares some code
|
||||
)
|
||||
endif()
|
||||
|
||||
if(CONCRETELANG_DATAFLOW_EXECUTION_ENABLED)
|
||||
add_compile_options(-DCONCRETELANG_DATAFLOW_TESTING_ENABLED)
|
||||
endif()
|
||||
|
||||
add_concretecompiler_lib_test(unit_tests_concretelang_Encodings Encodings_unit_tests.cpp)
|
||||
@@ -0,0 +1,98 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientLambda.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Encodings.h"
|
||||
#include "concretelang/TestLib/TestTypedLambda.h"
|
||||
|
||||
#include "tests_tools/GtestEnvironment.h"
|
||||
#include "tests_tools/assert.h"
|
||||
#include "tests_tools/keySetCache.h"
|
||||
|
||||
testing::Environment *const dfr_env =
|
||||
testing::AddGlobalTestEnvironment(new DFREnvironment);
|
||||
|
||||
const std::string FUNCNAME = "main";
|
||||
|
||||
using namespace concretelang::testlib;
|
||||
namespace encodings = mlir::concretelang::encodings;
|
||||
using concretelang::clientlib::scalar_in;
|
||||
using concretelang::clientlib::scalar_out;
|
||||
using concretelang::clientlib::tensor1_in;
|
||||
using concretelang::clientlib::tensor1_out;
|
||||
using concretelang::clientlib::tensor2_in;
|
||||
using concretelang::clientlib::tensor2_out;
|
||||
using concretelang::clientlib::tensor3_out;
|
||||
|
||||
mlir::concretelang::CompilerEngine::Library
|
||||
compile(std::string outputLib, std::string source,
|
||||
std::string funcname = FUNCNAME) {
|
||||
std::vector<std::string> sources = {source};
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
|
||||
mlir::concretelang::CompilationContext::createShared();
|
||||
mlir::concretelang::CompilerEngine ce{ccx};
|
||||
mlir::concretelang::CompilationOptions options(funcname);
|
||||
options.encodings = encodings::CircuitEncodings{
|
||||
{
|
||||
encodings::EncryptedIntegerScalarEncoding{3, false},
|
||||
encodings::EncryptedIntegerScalarEncoding{3, false},
|
||||
},
|
||||
{
|
||||
encodings::EncryptedIntegerScalarEncoding{3, false},
|
||||
}};
|
||||
options.v0Parameter = {2, 10, 693, 4, 9, 7, 2, std::nullopt};
|
||||
ce.setCompilationOptions(options);
|
||||
auto result = ce.compile(sources, outputLib);
|
||||
if (!result) {
|
||||
llvm::errs() << result.takeError();
|
||||
assert(false);
|
||||
}
|
||||
assert(result);
|
||||
return result.get();
|
||||
}
|
||||
|
||||
static const std::string CURRENT_FILE = __FILE__;
|
||||
static const std::string THIS_TEST_DIRECTORY =
|
||||
CURRENT_FILE.substr(0, CURRENT_FILE.find_last_of("/\\"));
|
||||
static const std::string OUT_DIRECTORY = "/tmp";
|
||||
|
||||
template <typename Info> std::string outputLibFromThis(Info *info) {
|
||||
return OUT_DIRECTORY + "/" + std::string(info->name());
|
||||
}
|
||||
|
||||
template <typename Lambda> Lambda load(std::string outputLib) {
|
||||
auto l = Lambda::load(FUNCNAME, outputLib, 0, 0, getTestKeySetCachePtr());
|
||||
assert(l.has_value());
|
||||
return l.value();
|
||||
}
|
||||
|
||||
TEST(Encodings_unit_tests, multi_key) {
|
||||
std::string source = R"(
|
||||
func.func @main(
|
||||
%arg0: !TFHE.glwe<sk<1,1,2048>>,
|
||||
%arg1: !TFHE.glwe<sk<2,1,2048>>
|
||||
) -> !TFHE.glwe<sk<2,1,2048>> {
|
||||
|
||||
%0 = "TFHE.keyswitch_glwe"(%arg0) {key=#TFHE.ksk<sk<1,1,2048>, sk<2, 1,2048>, 7, 2>} : (!TFHE.glwe<sk<1, 1, 2048>>) -> !TFHE.glwe<sk<2, 1, 2048>>
|
||||
%1 = "TFHE.add_glwe"(%arg1, %0) : (!TFHE.glwe<sk<2,1,2048>>, !TFHE.glwe<sk<2,1,2048>>) -> !TFHE.glwe<sk<2,1,2048>>
|
||||
return %1 : !TFHE.glwe<sk<2,1,2048>>
|
||||
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
auto lambda =
|
||||
load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
|
||||
scalar_in a = 5;
|
||||
scalar_in b = 5;
|
||||
auto res = lambda.call(a, b);
|
||||
ASSERT_EQ_OUTCOME(res, (scalar_out)a + b);
|
||||
}
|
||||
Reference in New Issue
Block a user