feat(compiler): add support for multikey

This commit brings support for multiple secret keys in the TFHE
dialect. In particular, a parameterized `TFHE` circuit can now be
given as input, with any combination of (semantically valid) of
ks/bs/woppbs mixing different secret keys, and compiled down to a
valid executable function, with server keys properly looked up.

Secret keys are now stateful objects which can be:
-> none/unparameterized (syntax `sk?`): The keys are in state after
   the lowering from the `FHE` dialect.
-> parameterized (syntax `sk<identifier, polysize, dimension>`): The
   keys were parameterized, either by user or by the optimizer. The
   `identifier` field can be used to disambiguate two keys with same
   `polysize` and `dimension`.
-> normalized (syntax `sk[index]<polysize, dimension>`): The keys were
   attached to their index in the list of keys in the runtime context.

The _normalization_ of key indices also acts on the ksk, bsk and pksk,
which are given indices in the same spirit now.

Finally, in order to allow parameterized `TFHE` circuit to be given
as input and compiled down to executable functions, we added a way to
pass the encodings that are used to encode/decode the circuit
inputs/outputs. In the case of a compilation from the `FHE` dialect,
those informations are automatically extracted from the higher level
informations available in this dialect.
This commit is contained in:
aPere3
2023-03-27 15:37:56 +02:00
committed by Quentin Bourgerie
parent 823ea618af
commit cacffadbd2
66 changed files with 2329 additions and 747 deletions

View File

@@ -123,6 +123,9 @@ enum CompilationTarget {
ROUND_TRIP,
FHE,
TFHE,
PARAMETRIZED_TFHE,
NORMALIZED_TFHE,
BATCHED_TFHE,
CONCRETE,
STD,
LLVM,

View File

@@ -22,6 +22,7 @@
#include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
#include "concretelang/Conversion/SDFGToStreamEmulator/Pass.h"
#include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h"
#include "concretelang/Conversion/TFHEKeyNormalization/Pass.h"
#include "concretelang/Conversion/TFHEToConcrete/Pass.h"
#include "concretelang/Conversion/TracingToCAPI/Pass.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"

View File

@@ -32,6 +32,13 @@ def TFHEGlobalParametrization : Pass<"tfhe-global-parametrization", "mlir::Modul
let dependentDialects = ["mlir::concretelang::TFHE::TFHEDialect"];
}
def TFHEKeyNormalization : Pass<"tfhe-key-normalization", "mlir::ModuleOp"> {
let summary = "Ensures the key ids form proper ranges of indices";
let constructor = "mlir::concretelang::createTFHEKeyNormalizationPass()";
let options = [];
let dependentDialects = ["mlir::concretelang::TFHE::TFHEDialect"];
}
def TFHEToConcrete : Pass<"tfhe-to-concrete", "mlir::ModuleOp"> {
let summary = "Lowers operations from the TFHE dialect to Concrete";
let description = [{ Lowers operations from the TFHE dialect to Concrete }];

View File

@@ -0,0 +1,19 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_TFHEKEYNORMALIZATION_PASS_H_
#define CONCRETELANG_CONVERSION_TFHEKEYNORMALIZATION_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace concretelang {
/// Create a pass that ensures that the ids of the keys form a proper index
/// range.
std::unique_ptr<OperationPass<ModuleOp>> createTFHEKeyNormalizationPass();
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -191,7 +191,8 @@ def Concrete_BootstrapLweTensorOp : Concrete_Op<"bootstrap_lwe_tensor", [Pure]>
I32Attr:$polySize,
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$glweDimension
I32Attr:$glweDimension,
I32Attr:$bskIndex
);
let results = (outs Concrete_LweTensor:$result);
}
@@ -207,7 +208,8 @@ def Concrete_BootstrapLweBufferOp : Concrete_Op<"bootstrap_lwe_buffer"> {
I32Attr:$polySize,
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$glweDimension
I32Attr:$glweDimension,
I32Attr:$bskIndex
);
}
@@ -221,7 +223,8 @@ def Concrete_BatchedBootstrapLweTensorOp : Concrete_Op<"batched_bootstrap_lwe_te
I32Attr:$polySize,
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$glweDimension
I32Attr:$glweDimension,
I32Attr:$bskIndex
);
let results = (outs Concrete_BatchLweTensor:$result);
}
@@ -237,7 +240,8 @@ def Concrete_BatchedBootstrapLweBufferOp : Concrete_Op<"batched_bootstrap_lwe_bu
I32Attr:$polySize,
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$glweDimension
I32Attr:$glweDimension,
I32Attr:$bskIndex
);
}
@@ -250,7 +254,8 @@ def Concrete_KeySwitchLweTensorOp : Concrete_Op<"keyswitch_lwe_tensor", [Pure]>
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$lwe_dim_in,
I32Attr:$lwe_dim_out
I32Attr:$lwe_dim_out,
I32Attr:$kskIndex
);
let results = (outs Concrete_LweTensor:$result);
}
@@ -264,7 +269,8 @@ def Concrete_KeySwitchLweBufferOp : Concrete_Op<"keyswitch_lwe_buffer"> {
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$lwe_dim_in,
I32Attr:$lwe_dim_out
I32Attr:$lwe_dim_out,
I32Attr:$kskIndex
);
}
@@ -277,7 +283,8 @@ def Concrete_BatchedKeySwitchLweTensorOp : Concrete_Op<"batched_keyswitch_lwe_te
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$lwe_dim_in,
I32Attr:$lwe_dim_out
I32Attr:$lwe_dim_out,
I32Attr:$kskIndex
);
let results = (outs Concrete_BatchLweTensor:$result);
}
@@ -291,7 +298,8 @@ def Concrete_BatchedKeySwitchLweBufferOp : Concrete_Op<"batched_keyswitch_lwe_bu
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$lwe_dim_in,
I32Attr:$lwe_dim_out
I32Attr:$lwe_dim_out,
I32Attr:$kskIndex
);
}
@@ -313,7 +321,11 @@ def Concrete_WopPBSCRTLweTensorOp : Concrete_Op<"wop_pbs_crt_lwe_tensor", [Pure]
// Circuit bootstrap parameters
I32Attr : $circuitBootstrapLevel,
I32Attr : $circuitBootstrapBaseLog,
I64ArrayAttr:$crtDecomposition
I64ArrayAttr:$crtDecomposition,
// Key indices
I32Attr:$kskIndex,
I32Attr:$bskIndex,
I32Attr:$pkskIndex
);
let results = (outs Concrete_LweCRTTensor:$result);
}
@@ -337,7 +349,11 @@ def Concrete_WopPBSCRTLweBufferOp : Concrete_Op<"wop_pbs_crt_lwe_buffer"> {
// Circuit bootstrap parameters
I32Attr : $circuitBootstrapLevel,
I32Attr : $circuitBootstrapBaseLog,
I64ArrayAttr:$crtDecomposition
I64ArrayAttr:$crtDecomposition,
// Key indices
I32Attr:$kskIndex,
I32Attr:$bskIndex,
I32Attr:$pkskIndex
);
}

View File

@@ -27,10 +27,11 @@ def TFHE_KeyswitchKeyAttr: TFHE_Attr<"GLWEKeyswitchKey", "ksk"> {
"mlir::concretelang::TFHE::GLWESecretKey":$inputKey,
"mlir::concretelang::TFHE::GLWESecretKey":$outputKey,
"int":$levels,
"int":$baseLog
"int":$baseLog,
DefaultValuedParameter<"int", "-1">: $index
);
let assemblyFormat = "`<` $inputKey `,` $outputKey `,` $levels `,` $baseLog `>`";
let assemblyFormat = " (`[` $index^ `]`)? `<` $inputKey `,` $outputKey `,` $levels `,` $baseLog `>`";
}
def TFHE_BootstrapKeyAttr: TFHE_Attr<"GLWEBootstrapKey", "bsk"> {
@@ -43,10 +44,11 @@ def TFHE_BootstrapKeyAttr: TFHE_Attr<"GLWEBootstrapKey", "bsk"> {
"int":$polySize,
"int":$glweDim,
"int":$levels,
"int":$baseLog
"int":$baseLog,
DefaultValuedParameter<"int", "-1">: $index
);
let assemblyFormat = "`<` $inputKey `,` $outputKey `,` $polySize `,` $glweDim `,` $levels `,` $baseLog `>`";
let assemblyFormat = "(`[` $index^ `]`)? `<` $inputKey `,` $outputKey `,` $polySize `,` $glweDim `,` $levels `,` $baseLog `>`";
}
def TFHE_PackingKeyswitchKeyAttr: TFHE_Attr<"GLWEPackingKeyswitchKey", "pksk"> {
@@ -58,11 +60,13 @@ def TFHE_PackingKeyswitchKeyAttr: TFHE_Attr<"GLWEPackingKeyswitchKey", "pksk"> {
"mlir::concretelang::TFHE::GLWESecretKey":$outputKey,
"int" : $outputPolySize,
"int" : $inputLweDim,
"int" : $glweDim,
"int" : $levels,
"int" : $baseLog
"int" : $baseLog,
DefaultValuedParameter<"int", "-1">: $index
);
let assemblyFormat = "`<` $inputKey `,` $outputKey`,` $outputPolySize`,` $inputLweDim `,` $levels `,` $baseLog `>`";
let assemblyFormat = " (`[` $index^ `]` )? `<` $inputKey `,` $outputKey`,` $outputPolySize`,` $inputLweDim `,` $glweDim `,` $levels `,` $baseLog `>`";
}

View File

@@ -11,48 +11,63 @@
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>
#include <variant>
namespace mlir {
namespace concretelang {
namespace TFHE {
/// A type parameter representing GLWE secret key.
///
/// A glwe secret key is basically a glwe dimension, a polynomial size, and an
/// id that makes it possible to disambiguate potential keys with with same
/// parameters.
///
/// Note that a key can be instantiated to a `none` key, to serve as a
/// placeholder in the IR. In this case, none of its data are actually usable
/// for lowering to the `Concrete` dialect. Once the
/// `TFHEGlobalParameterization` was performed, there should remain no such
/// `none` keys in the IR.
class GLWESecretKey {
public:
/// Creates a new none key.
GLWESecretKey();
/// Create a new key from parameters.
GLWESecretKey(int64_t dimension, int64_t polySize, int64_t id);
bool operator==(GLWESecretKey other);
bool operator==(const GLWESecretKey other) const;
bool operator!=(GLWESecretKey other);
/// Returns the dimension associated with this key, if the key is not none.
mlir::Optional<int64_t> getDimension() const;
/// Returns the polynomial size associated with this key, if the key is not
/// none.
mlir::Optional<int64_t> getPolySize() const;
/// Returns the id associated with this key, if the key is not none.
mlir::Optional<int64_t> getId() const;
/// Returns true if the key was not filled with valid parameters.
bool isNotParameterized() const;
/// A placeholder.
struct GLWESecretKeyNone {};
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
const GLWESecretKeyNone sk);
private:
int64_t dimension;
int64_t polySize;
int64_t id;
// The key was parameterized.
struct GLWESecretKeyParameterized {
uint64_t dimension;
uint64_t polySize;
uint64_t identifier;
bool operator==(const GLWESecretKeyParameterized other) const;
};
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
const GLWESecretKeyParameterized sk);
// The key was normalized
struct GLWESecretKeyNormalized {
uint64_t dimension;
uint64_t polySize;
uint64_t index;
bool operator==(const GLWESecretKeyNormalized other) const;
};
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
const GLWESecretKeyNormalized sk);
/// A sum type parameter representing GLWE secret keys in different states.
struct GLWESecretKey {
std::variant<GLWESecretKeyNone, GLWESecretKeyParameterized,
GLWESecretKeyNormalized>
inner;
static GLWESecretKey newNone();
static GLWESecretKey newParameterized(uint64_t dimension, uint64_t polySize,
uint64_t identifier);
static GLWESecretKey newNormalized(uint64_t dimension, uint64_t polySize,
uint64_t index);
bool operator==(const GLWESecretKey other) const;
bool operator!=(const GLWESecretKey other) const;
template <typename V> bool is();
bool isNone();
bool isParameterized();
bool isNormalized();
template <typename V> std::optional<V> get();
std::optional<GLWESecretKeyNone> getNone();
std::optional<GLWESecretKeyParameterized> getParameterized();
std::optional<GLWESecretKeyNormalized> getNormalized();
};
llvm::hash_code hash_value(const GLWESecretKey &key);
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const GLWESecretKey sk);
} // namespace TFHE
} // namespace concretelang
} // namespace mlir

View File

@@ -25,13 +25,6 @@ def TFHE_GLWECipherTextType
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = true;
let extraClassDeclaration = [{
// Returns true if has an unparametrized parameters
bool hasUnparametrizedParameters() {
return getKey().isNotParameterized();
};
}];
}
#endif

View File

@@ -36,11 +36,11 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,
void stream_emulator_make_memref_keyswitch_lwe_u64_process(
void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
void *context);
uint32_t ksk_index, void *context);
void stream_emulator_make_memref_bootstrap_lwe_u64_process(
void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
uint32_t output_size, void *context);
uint32_t output_size, uint32_t bsk_index, void *context);
void *stream_emulator_make_uint64_stream(const char *name, stream_type stype);
void stream_emulator_put_uint64(void *stream, uint64_t e);

View File

@@ -82,6 +82,7 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
uint64_t ct0_size, uint64_t ct0_stride,
uint32_t level, uint32_t base_log,
uint32_t input_lwe_dim, uint32_t output_lwe_dim,
uint32_t ksk_index,
mlir::concretelang::RuntimeContext *context);
void memref_batched_keyswitch_lwe_u64(
@@ -91,7 +92,7 @@ void memref_batched_keyswitch_lwe_u64(
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
uint64_t ct0_stride0, uint64_t ct0_stride1, uint32_t level,
uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim,
mlir::concretelang::RuntimeContext *context);
uint32_t ksk_index, mlir::concretelang::RuntimeContext *context);
void *memref_keyswitch_async_lwe_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
@@ -99,17 +100,15 @@ void *memref_keyswitch_async_lwe_u64(
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, mlir::concretelang::RuntimeContext *context);
void memref_bootstrap_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
uint64_t out_offset, uint64_t out_size,
uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset,
uint64_t ct0_size, uint64_t ct0_stride,
uint64_t *tlu_allocated, uint64_t *tlu_aligned,
uint64_t tlu_offset, uint64_t tlu_size,
uint64_t tlu_stride, uint32_t input_lwe_dim,
uint32_t poly_size, uint32_t level,
uint32_t base_log, uint32_t glwe_dim,
mlir::concretelang::RuntimeContext *context);
void memref_bootstrap_lwe_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
mlir::concretelang::RuntimeContext *context);
void memref_batched_bootstrap_lwe_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
@@ -119,7 +118,7 @@ void memref_batched_bootstrap_lwe_u64(
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
uint32_t level, uint32_t base_log, uint32_t glwe_dim,
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
mlir::concretelang::RuntimeContext *context);
void *memref_bootstrap_async_lwe_u64(
@@ -129,7 +128,7 @@ void *memref_bootstrap_async_lwe_u64(
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
uint32_t base_log, uint32_t glwe_dim,
uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
mlir::concretelang::RuntimeContext *context);
void memref_await_future(uint64_t *out_allocated, uint64_t *out_aligned,
@@ -163,6 +162,8 @@ void memref_wop_pbs_crt_buffer(
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
uint32_t polynomial_size,
// Key indices
uint32_t ksk_index, uint32_t bsk_index, uint32_t pksk_index,
// runtime context that hold evluation keys
mlir::concretelang::RuntimeContext *context);
@@ -183,7 +184,7 @@ void memref_keyswitch_lwe_cuda_u64(
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint32_t level, uint32_t base_log,
uint32_t input_lwe_dim, uint32_t output_lwe_dim,
uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t ksk_index,
mlir::concretelang::RuntimeContext *context);
/// \brief Run bootstrapping on GPU.
@@ -197,7 +198,7 @@ void memref_bootstrap_lwe_cuda_u64(
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
uint32_t base_log, uint32_t glwe_dim,
uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
mlir::concretelang::RuntimeContext *context);
// Batched CUDA function //////////////////////////////////////////////////////
@@ -209,7 +210,7 @@ void memref_batched_keyswitch_lwe_cuda_u64(
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
uint64_t ct0_stride0, uint64_t ct0_stride1, uint32_t level,
uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim,
mlir::concretelang::RuntimeContext *context);
uint32_t ksk_index, mlir::concretelang::RuntimeContext *context);
void memref_batched_bootstrap_lwe_cuda_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
@@ -219,7 +220,7 @@ void memref_batched_bootstrap_lwe_cuda_u64(
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
uint32_t level, uint32_t base_log, uint32_t glwe_dim,
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
mlir::concretelang::RuntimeContext *context);
// Tracing ////////////////////////////////////////////////////////////////////

View File

@@ -10,6 +10,7 @@
#include <mlir/IR/BuiltinOps.h>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Support/Encodings.h"
#include "concretelang/Support/V0Parameters.h"
namespace mlir {
@@ -19,9 +20,10 @@ using ::concretelang::clientlib::ChunkInfo;
using ::concretelang::clientlib::ClientParameters;
llvm::Expected<ClientParameters>
createClientParametersForV0(V0FHEContext context, llvm::StringRef functionName,
mlir::ModuleOp module, int bitsOfSecurity,
llvm::Optional<ChunkInfo> chunkInfo = std::nullopt);
createClientParametersFromTFHE(mlir::ModuleOp module,
llvm::StringRef functionName, int bitsOfSecurity,
encodings::CircuitEncodings encodings,
std::optional<CRTDecomposition> maybeCrt);
} // namespace concretelang
} // namespace mlir

View File

@@ -7,7 +7,8 @@
#define CONCRETELANG_SUPPORT_COMPILER_ENGINE_H
#include <concretelang/Conversion/Utils/GlobalFHEContext.h>
#include <concretelang/Support/V0ClientParameters.h>
#include <concretelang/Support/ClientParametersGeneration.h>
#include <concretelang/Support/Encodings.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/SourceMgr.h>
@@ -78,6 +79,10 @@ struct CompilationOptions {
unsigned int chunkSize;
unsigned int chunkWidth;
/// When compiling from a dialect lower than FHE, one needs to provide
/// encodings info manually to allow the client lib to be generated.
std::optional<mlir::concretelang::encodings::CircuitEncodings> encodings;
CompilationOptions()
: v0FHEConstraints(std::nullopt), verifyDiagnostics(false),
autoParallelize(false), loopParallelize(false), batchTFHEOps(false),
@@ -85,7 +90,7 @@ struct CompilationOptions {
dataflowParallelize(false), optimizeTFHE(true), emitGPUOps(false),
clientParametersFuncName(std::nullopt),
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkIntegers(false),
chunkSize(4), chunkWidth(2){};
chunkSize(4), chunkWidth(2), encodings(std::nullopt){};
CompilationOptions(std::string funcname) : CompilationOptions() {
clientParametersFuncName = funcname;
@@ -204,7 +209,7 @@ public:
/// and scf loops
FHE_NO_LINALG,
/// Read sources and lower all FHE operations to TFHE
/// Read sources and lower all FHE operations to unparameterized TFHE
/// operations
TFHE,
@@ -215,6 +220,10 @@ public:
/// Batch TFHE operations
BATCHED_TFHE,
/// Read sources and lower all FHE operations to normalized TFHE
/// operations
NORMALIZED_TFHE,
/// Read sources and lower all FHE and TFHE operations to Concrete
/// operations
CONCRETE,

View File

@@ -0,0 +1,141 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_ENCODINGS_H_
#define CONCRETELANG_SUPPORT_ENCODINGS_H_
#include <map>
#include <optional>
#include <string>
#include <vector>
#include "boost/outcome.h"
#include <llvm/ADT/Optional.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/JSON.h>
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
namespace mlir {
namespace concretelang {
namespace encodings {
/// Represents the encoding of a small (unchunked) `FHE::eint` type.
struct EncryptedIntegerScalarEncoding {
uint64_t width;
bool isSigned;
};
bool fromJSON(const llvm::json::Value, EncryptedIntegerScalarEncoding &,
llvm::json::Path);
llvm::json::Value toJSON(const EncryptedIntegerScalarEncoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
EncryptedIntegerScalarEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encoding of a big (chunked) `FHE::eint` type.
struct EncryptedChunkedIntegerScalarEncoding {
uint64_t width;
bool isSigned;
uint64_t chunkSize;
uint64_t chunkWidth;
};
bool fromJSON(const llvm::json::Value, EncryptedChunkedIntegerScalarEncoding &,
llvm::json::Path);
llvm::json::Value toJSON(const EncryptedChunkedIntegerScalarEncoding &);
static inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &OS, EncryptedChunkedIntegerScalarEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encoding of a `FHE::ebool` type.
struct EncryptedBoolScalarEncoding {};
bool fromJSON(const llvm::json::Value, EncryptedBoolScalarEncoding &,
llvm::json::Path);
llvm::json::Value toJSON(const EncryptedBoolScalarEncoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
EncryptedBoolScalarEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encoding of a builtin integer type.
struct PlaintextScalarEncoding {
uint64_t width;
};
bool fromJSON(const llvm::json::Value, PlaintextScalarEncoding &,
llvm::json::Path);
llvm::json::Value toJSON(const PlaintextScalarEncoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
PlaintextScalarEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encoding of a builtin index type.
struct IndexScalarEncoding {};
bool fromJSON(const llvm::json::Value, IndexScalarEncoding &, llvm::json::Path);
llvm::json::Value toJSON(const IndexScalarEncoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
IndexScalarEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encoding of a scalar value.
using ScalarEncoding = std::variant<
EncryptedIntegerScalarEncoding, EncryptedChunkedIntegerScalarEncoding,
EncryptedBoolScalarEncoding, PlaintextScalarEncoding, IndexScalarEncoding>;
bool fromJSON(const llvm::json::Value, ScalarEncoding &, llvm::json::Path);
llvm::json::Value toJSON(const ScalarEncoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
ScalarEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encoding of a tensor value.
struct TensorEncoding {
ScalarEncoding scalarEncoding;
};
bool fromJSON(const llvm::json::Value, TensorEncoding &, llvm::json::Path);
llvm::json::Value toJSON(const TensorEncoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
TensorEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encoding of either an input or output value of a circuit.
using Encoding = std::variant<TensorEncoding, ScalarEncoding>;
bool fromJSON(const llvm::json::Value, Encoding &, llvm::json::Path);
llvm::json::Value toJSON(const Encoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, Encoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encodings of a circuit.
struct CircuitEncodings {
std::vector<Encoding> inputEncodings;
std::vector<Encoding> outputEncodings;
};
bool fromJSON(const llvm::json::Value, CircuitEncodings &, llvm::json::Path);
llvm::json::Value toJSON(const CircuitEncodings &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
CircuitEncodings e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
llvm::Expected<CircuitEncodings> getCircuitEncodings(
llvm::StringRef functionName, mlir::ModuleOp module,
std::optional<::concretelang::clientlib::ChunkInfo> maybeChunkInfo);
} // namespace encodings
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -67,6 +67,10 @@ mlir::LogicalResult batchTFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
normalizeTFHEKeys(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);

View File

@@ -0,0 +1,37 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_TFHECIRCUITKEYS_H_
#define CONCRETELANG_SUPPORT_TFHECIRCUITKEYS_H_
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
namespace mlir {
namespace concretelang {
namespace TFHE {
struct TFHECircuitKeys {
llvm::SmallVector<TFHE::GLWESecretKey, 10> secretKeys;
llvm::SmallVector<TFHE::GLWEBootstrapKeyAttr, 10> bootstrapKeys;
llvm::SmallVector<TFHE::GLWEKeyswitchKeyAttr, 10> keyswitchKeys;
llvm::SmallVector<TFHE::GLWEPackingKeyswitchKeyAttr, 10> packingKeyswitchKeys;
std::optional<uint64_t> getSecretKeyIndex(TFHE::GLWESecretKey key);
std::optional<uint64_t> getKeyswitchKeyIndex(TFHE::GLWEKeyswitchKeyAttr key);
std::optional<uint64_t> getBootstrapKeyIndex(TFHE::GLWEBootstrapKeyAttr key);
std::optional<uint64_t>
getPackingKeyswitchKeyIndex(TFHE::GLWEPackingKeyswitchKeyAttr key);
};
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const TFHECircuitKeys cks);
TFHECircuitKeys extractCircuitKeys(mlir::ModuleOp moduleOp);
} // namespace TFHE
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -13,6 +13,7 @@
#include <concretelang/Runtime/context.h>
#include <concretelang/ServerLib/ServerLambda.h>
#include <concretelang/Support/Error.h>
#include <llvm/ADT/SmallVector.h>
namespace concretelang {
@@ -112,6 +113,17 @@ invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters,
return clientlib::PublicResult::fromBuffers(clientParameters,
std::move(buffers));
}
template <typename V, unsigned int N>
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
const llvm::SmallVector<V, N> vect) {
OS << "[";
for (auto v : vect) {
OS << v << ",";
}
OS << "]";
return OS;
}
} // namespace concretelang
#endif

View File

@@ -0,0 +1,12 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_VARIANTS_H_
#define CONCRETELANG_SUPPORT_VARIANTS_H_
template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template <class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
#endif

View File

@@ -131,6 +131,12 @@ llvm::Expected<mlir::concretelang::CompilerEngine::
return mlir::concretelang::CompilerEngine::Target::FHE;
case TFHE:
return mlir::concretelang::CompilerEngine::Target::TFHE;
case PARAMETRIZED_TFHE:
return mlir::concretelang::CompilerEngine::Target::PARAMETRIZED_TFHE;
case NORMALIZED_TFHE:
return mlir::concretelang::CompilerEngine::Target::NORMALIZED_TFHE;
case BATCHED_TFHE:
return mlir::concretelang::CompilerEngine::Target::BATCHED_TFHE;
case CONCRETE:
return mlir::concretelang::CompilerEngine::Target::CONCRETE;
case STD:

View File

@@ -1,6 +1,7 @@
add_subdirectory(FHEToTFHEScalar)
add_subdirectory(FHEToTFHECrt)
add_subdirectory(TFHEGlobalParametrization)
add_subdirectory(TFHEKeyNormalization)
add_subdirectory(TFHEToConcrete)
add_subdirectory(FHETensorOpsToLinalg)
add_subdirectory(TracingToCAPI)

View File

@@ -106,18 +106,20 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
{memref1DType, memref1DType}, {});
} else if (funcName == memref_keyswitch_lwe_u64 ||
funcName == memref_keyswitch_lwe_cuda_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType, i32Type,
i32Type, i32Type, i32Type, contextType},
{});
funcType =
mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType, i32Type, i32Type,
i32Type, i32Type, i32Type, contextType},
{});
} else if (funcName == memref_bootstrap_lwe_u64 ||
funcName == memref_bootstrap_lwe_cuda_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType,
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, contextType},
i32Type, i32Type, i32Type, contextType},
{});
} else if (funcName == memref_keyswitch_async_lwe_u64) {
// Todo Answer this question: Isn't it dead ?
funcType = mlir::FunctionType::get(
rewriter.getContext(), {memref1DType, memref1DType, contextType},
{futureType});
@@ -125,20 +127,21 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType,
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, contextType},
i32Type, i32Type, i32Type, contextType},
{futureType});
} else if (funcName == memref_batched_keyswitch_lwe_u64 ||
funcName == memref_batched_keyswitch_lwe_cuda_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref2DType, memref2DType, i32Type,
i32Type, i32Type, i32Type, contextType},
{});
funcType =
mlir::FunctionType::get(rewriter.getContext(),
{memref2DType, memref2DType, i32Type, i32Type,
i32Type, i32Type, i32Type, contextType},
{});
} else if (funcName == memref_batched_bootstrap_lwe_u64 ||
funcName == memref_batched_bootstrap_lwe_cuda_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref2DType, memref2DType,
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, contextType},
i32Type, i32Type, i32Type, contextType},
{});
} else if (funcName == memref_await_future) {
funcType = mlir::FunctionType::get(
@@ -171,6 +174,9 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
contextType,
},
{});
@@ -273,6 +279,9 @@ void keyswitchAddOperands(KeySwitchOp op,
// lwe_dim_out
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getLweDimOutAttr()));
// ksk_index
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getKskIndexAttr()));
// context
operands.push_back(getContextArgument(op));
}
@@ -296,6 +305,9 @@ void bootstrapAddOperands(BootstrapOp op,
// glwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.getGlweDimensionAttr()));
// bsk_index
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getBskIndexAttr()));
// context
operands.push_back(getContextArgument(op));
}
@@ -354,6 +366,15 @@ void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op,
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.getPackingKeySwitchoutputPolynomialSizeAttr()));
// ksk_index
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getKskIndexAttr()));
// bsk_index
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getBskIndexAttr()));
// pksk_index
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.getPkskIndexAttr()));
// context
operands.push_back(getContextArgument(op));
}

View File

@@ -580,12 +580,13 @@ struct ApplyLookupTableEintOpPattern
op.getLoc(), converter->convertType(op.getType()), adaptor.getA(),
newLut,
TFHE::GLWEKeyswitchKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(),
TFHE::GLWESecretKey(), -1, -1),
TFHE::GLWESecretKey(), -1, -1, -1),
TFHE::GLWEBootstrapKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(),
TFHE::GLWESecretKey(), -1, -1, -1, -1),
TFHE::GLWESecretKey(), -1, -1, -1, -1,
-1),
TFHE::GLWEPackingKeyswitchKeyAttr::get(
op.getContext(), TFHE::GLWESecretKey(), TFHE::GLWESecretKey(), -1,
-1, -1, -1),
-1, -1, -1, -1, -1),
rewriter.getI64ArrayAttr({}), rewriter.getI32IntegerAttr(-1),
rewriter.getI32IntegerAttr(-1));

View File

@@ -369,13 +369,14 @@ struct ApplyLookupTableEintOpPattern
op.getLoc(), getTypeConverter()->convertType(adaptor.getA().getType()),
input,
TFHE::GLWEKeyswitchKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(),
TFHE::GLWESecretKey(), -1, -1));
TFHE::GLWESecretKey(), -1, -1, -1));
// Insert bootstrap
rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
op, getTypeConverter()->convertType(op.getType()), ksOp, newLut,
TFHE::GLWEBootstrapKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(),
TFHE::GLWESecretKey(), -1, -1, -1, -1));
TFHE::GLWESecretKey(), -1, -1, -1, -1,
-1));
return mlir::success();
};
@@ -557,12 +558,12 @@ struct RoundEintOpPattern : public ScalarOpPattern<FHE::RoundEintOp> {
op.getLoc(), truncationInputTy, shiftedRotatedInput,
TFHE::GLWEKeyswitchKeyAttr::get(op->getContext(),
TFHE::GLWESecretKey(),
TFHE::GLWESecretKey(), -1, -1));
TFHE::GLWESecretKey(), -1, -1, -1));
mlir::Value bootstrapped = rewriter.create<TFHE::BootstrapGLWEOp>(
op.getLoc(), truncationInputTy, keyswitched, lut,
TFHE::GLWEBootstrapKeyAttr::get(
op->getContext(), TFHE::GLWESecretKey(), TFHE::GLWESecretKey(),
-1, -1, -1, -1));
-1, -1, -1, -1, -1));
//------------------------------------------------------------- CORRECTION
// The correction is performed to achieve our right shift semantic.

View File

@@ -200,6 +200,9 @@ struct LowerSDFGMakeProcess
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),
mpOp->getAttrOfType<mlir::IntegerAttr>("output_size")));
// ksk_index
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("kskIndex")));
// context
operands.push_back(getContextArgument(mpOp));
break;
@@ -226,6 +229,9 @@ struct LowerSDFGMakeProcess
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),
mpOp->getAttrOfType<mlir::IntegerAttr>("output_size")));
// bsk_index
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("bskIndex")));
// context
operands.push_back(getContextArgument(mpOp));
break;

View File

@@ -35,7 +35,7 @@ struct TFHEGlobalParametrizationPass
using mlir::concretelang::TFHE::GLWECipherTextType;
/// TFHEGlobalParametrizationTypeConverter is a TypeConverter that transform
/// `TFHE.glwe<sk[?]>` to
/// `TFHE.glwe<sk?>` to
/// `TFHE.glwe<sk[id]<glweDimension,polynomialSize>>`
class TFHEGlobalParametrizationTypeConverter : public mlir::TypeConverter {
@@ -44,11 +44,16 @@ public:
mlir::concretelang::V0Parameter &cryptoParameters)
: cryptoParameters(cryptoParameters) {
addConversion([](mlir::Type type) { return type; });
addConversion(
[&](GLWECipherTextType type) { return this->glweInterPBSType(type); });
addConversion([&](GLWECipherTextType type) {
if (type.getKey().isNone()) {
return this->glweInterPBSType(type);
} else {
return type;
}
});
addConversion([&](mlir::RankedTensorType type) {
auto glwe = type.getElementType().dyn_cast_or_null<GLWECipherTextType>();
if (glwe == nullptr) {
if (glwe == nullptr || !glwe.getKey().isNone()) {
return (mlir::Type)(type);
}
mlir::Type r = mlir::RankedTensorType::get(type.getShape(),
@@ -70,11 +75,9 @@ public:
TFHE::GLWESecretKey getInterPBSKey() {
auto dimension = cryptoParameters.getNBigLweDimension();
auto polynomialSize = 1;
// Warning, for now we use hardcoded ids. Later on, we expect the optimizer
// to give the id.
auto id = 1;
return mlir::concretelang::TFHE::GLWESecretKey(dimension, polynomialSize,
id);
auto identifier = 0;
return mlir::concretelang::TFHE::GLWESecretKey::newParameterized(
dimension, polynomialSize, identifier);
}
TFHE::GLWECipherTextType glweInterPBSType(GLWECipherTextType &type) {
@@ -84,11 +87,9 @@ public:
TFHE::GLWESecretKey getIntraPBSKey() {
auto dimension = cryptoParameters.nSmall;
auto polynomialSize = 1;
// Warning, for now we use hardcoded ids. Later on, we expect the optimizer
// to give the id.
auto id = 3;
return mlir::concretelang::TFHE::GLWESecretKey(dimension, polynomialSize,
id);
auto identifier = 1;
return mlir::concretelang::TFHE::GLWESecretKey::newParameterized(
dimension, polynomialSize, identifier);
}
TFHE::GLWECipherTextType glweIntraPBSType(GLWECipherTextType &type) {
@@ -121,7 +122,7 @@ struct KeySwitchGLWEOpPattern
auto newOutputKey = converter.getIntraPBSKey();
auto keyswitchKey = TFHE::GLWEKeyswitchKeyAttr::get(
ksOp->getContext(), newInputKey, newOutputKey, cryptoParameters.ksLevel,
cryptoParameters.ksLogBase);
cryptoParameters.ksLogBase, -1);
auto newOp = rewriter.replaceOpWithNewOp<TFHE::KeySwitchGLWEOp>(
ksOp, newOutputTy, ksOp.getCiphertext(), keyswitchKey);
rewriter.startRootUpdate(newOp);
@@ -159,7 +160,7 @@ struct BootstrapGLWEOpPattern
auto bootstrapKey = TFHE::GLWEBootstrapKeyAttr::get(
bsOp->getContext(), newInputKey, newOutputKey,
cryptoParameters.getPolynomialSize(), cryptoParameters.glweDimension,
cryptoParameters.brLevel, cryptoParameters.brLogBase);
cryptoParameters.brLevel, cryptoParameters.brLogBase, -1);
auto newOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
bsOp, newOutputTy, bsOp.getCiphertext(), bsOp.getLookupTable(),
bootstrapKey);
@@ -196,19 +197,20 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
auto intraKey = converter.getIntraPBSKey();
auto keyswitchKey = TFHE::GLWEKeyswitchKeyAttr::get(
wopPBSOp->getContext(), interKey, intraKey, cryptoParameters.ksLevel,
cryptoParameters.ksLogBase);
cryptoParameters.ksLogBase, -1);
auto bootstrapKey = TFHE::GLWEBootstrapKeyAttr::get(
wopPBSOp->getContext(), intraKey, interKey,
cryptoParameters.getPolynomialSize(), cryptoParameters.glweDimension,
cryptoParameters.brLevel, cryptoParameters.brLogBase);
cryptoParameters.brLevel, cryptoParameters.brLogBase, -1);
auto packingKeyswitchKey = TFHE::GLWEPackingKeyswitchKeyAttr::get(
wopPBSOp->getContext(), interKey, interKey,
cryptoParameters.largeInteger->wopPBS.packingKeySwitch
.outputPolynomialSize,
cryptoParameters.largeInteger->wopPBS.packingKeySwitch
.inputLweDimension,
cryptoParameters.glweDimension,
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.level,
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.baseLog);
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.baseLog, -1);
auto newOp = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
wopPBSOp, newOutputType, wopPBSOp.getCiphertexts(),
wopPBSOp.getLookupTable(), keyswitchKey, bootstrapKey,
@@ -290,21 +292,29 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
patterns, converter);
// Parametrize keyswitch bootstrap
// Parametrize keyswitch
target.addLegalOp<mlir::arith::ConstantOp>();
patterns.add<KeySwitchGLWEOpPattern>(&getContext(), converter,
cryptoParameters);
target.addDynamicallyLegalOp<TFHE::KeySwitchGLWEOp>(
[&](TFHE::KeySwitchGLWEOp op) {
return !op.getKey().getInputKey().isNotParameterized() &&
!op.getKey().getOutputKey().isNotParameterized() &&
op.getKey().getBaseLog() != 0 && op.getKey().getLevels() != 0;
return op.getKeyAttr().getInputKey().isParameterized() &&
op.getKeyAttr().getOutputKey().isParameterized() &&
op.getKeyAttr().getBaseLog() != -1 &&
op.getKeyAttr().getLevels() != -1;
});
// Parametrize bootstrap
patterns.add<BootstrapGLWEOpPattern>(&getContext(), converter,
cryptoParameters);
target.addDynamicallyLegalOp<TFHE::BootstrapGLWEOp>(
[&](TFHE::BootstrapGLWEOp op) {
return converter.isLegal(op->getResultTypes());
return op.getKeyAttr().getInputKey().isParameterized() &&
op.getKeyAttr().getOutputKey().isParameterized() &&
op.getKeyAttr().getLevels() != -1 &&
op.getKeyAttr().getBaseLog() != -1 &&
op.getKeyAttr().getGlweDim() != -1 &&
op.getKeyAttr().getPolySize() != -1;
});
// Parametrize wop pbs
@@ -312,11 +322,20 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
cryptoParameters);
target.addDynamicallyLegalOp<TFHE::WopPBSGLWEOp>(
[&](TFHE::WopPBSGLWEOp op) {
return !op.getType()
.cast<mlir::RankedTensorType>()
.getElementType()
.cast<TFHE::GLWECipherTextType>()
.hasUnparametrizedParameters();
return op.getKskAttr().getInputKey().isParameterized() &&
op.getKskAttr().getOutputKey().isParameterized() &&
op.getKskAttr().getBaseLog() != -1 &&
op.getKskAttr().getLevels() != -1 &&
op.getBskAttr().getInputKey().isParameterized() &&
op.getBskAttr().getOutputKey().isParameterized() &&
op.getBskAttr().getLevels() != -1 &&
op.getBskAttr().getBaseLog() != -1 &&
op.getBskAttr().getGlweDim() != -1 &&
op.getBskAttr().getPolySize() != -1 &&
op.getPkskAttr().getInputKey().isParameterized() &&
op.getPkskAttr().getOutputKey().isParameterized() &&
op.getPkskAttr().getLevels() != -1 &&
op.getPkskAttr().getBaseLog() != -1;
});
// Add all patterns to convert TFHE types

View File

@@ -0,0 +1,14 @@
add_mlir_dialect_library(
TFHEKeyNormalization
TFHEKeyNormalization.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/TFHE
DEPENDS
TFHEDialect
mlir-headers
LINK_LIBS
PUBLIC
MLIRIR
MLIRTransforms)
target_link_libraries(TFHEKeyNormalization PUBLIC MLIRIR)

View File

@@ -0,0 +1,411 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include <llvm/ADT/SmallSet.h>
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
#include "concretelang/Support/Constants.h"
#include "concretelang/Support/TFHECircuitKeys.h"
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <variant>
namespace TFHE = mlir::concretelang::TFHE;
using mlir::concretelang::TFHE::GLWECipherTextType;
namespace conversion {
class KeyConverter {
public:
KeyConverter(mlir::concretelang::TFHE::TFHECircuitKeys &circuitKeys)
: circuitKeys(circuitKeys){};
TFHE::GLWESecretKey convertSecretKey(TFHE::GLWESecretKey sk) {
auto parameterizedKey = sk.getParameterized().value();
return TFHE::GLWESecretKey::newNormalized(
parameterizedKey.dimension, parameterizedKey.polySize,
circuitKeys.getSecretKeyIndex(sk).value());
}
TFHE::GLWEBootstrapKeyAttr
convertBootstrapKey(TFHE::GLWEBootstrapKeyAttr bsk) {
return TFHE::GLWEBootstrapKeyAttr::get(
bsk.getContext(), convertSecretKey(bsk.getInputKey()),
convertSecretKey(bsk.getOutputKey()), bsk.getPolySize(),
bsk.getGlweDim(), bsk.getLevels(), bsk.getBaseLog(),
circuitKeys.getBootstrapKeyIndex(bsk).value());
}
TFHE::GLWEKeyswitchKeyAttr
convertKeyswitchKey(TFHE::GLWEKeyswitchKeyAttr ksk) {
return TFHE::GLWEKeyswitchKeyAttr::get(
ksk.getContext(), convertSecretKey(ksk.getInputKey()),
convertSecretKey(ksk.getOutputKey()), ksk.getLevels(), ksk.getBaseLog(),
circuitKeys.getKeyswitchKeyIndex(ksk).value());
}
TFHE::GLWEPackingKeyswitchKeyAttr
convertPackingKeyswitchKey(TFHE::GLWEPackingKeyswitchKeyAttr pksk) {
return TFHE::GLWEPackingKeyswitchKeyAttr::get(
pksk.getContext(), convertSecretKey(pksk.getInputKey()),
convertSecretKey(pksk.getOutputKey()), pksk.getOutputPolySize(),
pksk.getInputLweDim(), pksk.getGlweDim(), pksk.getLevels(),
pksk.getBaseLog(),
circuitKeys.getPackingKeyswitchKeyIndex(pksk).value());
}
private:
mlir::concretelang::TFHE::TFHECircuitKeys circuitKeys;
};
class TypeConverter : public mlir::TypeConverter {
public:
TypeConverter(KeyConverter &keyConverter) : keyConverter(keyConverter) {
addConversion([](mlir::Type type) { return type; });
addConversion([&](GLWECipherTextType type) {
auto key = type.getKey();
if (key.isParameterized()) {
return GLWECipherTextType::get(type.getContext(),
keyConverter.convertSecretKey(key));
} else {
return type;
}
});
addConversion([&](mlir::RankedTensorType type) {
mlir::Type r = mlir::RankedTensorType::get(
type.getShape(), this->convertType(type.getElementType()));
return r;
});
addConversion([&](mlir::concretelang::RT::FutureType type) {
return mlir::concretelang::RT::FutureType::get(
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
.getElementType()));
});
addConversion([&](mlir::concretelang::RT::PointerType type) {
return mlir::concretelang::RT::PointerType::get(
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
.getElementType()));
});
}
private:
KeyConverter keyConverter;
};
} // namespace conversion
namespace patterns {
struct KeySwitchGLWEOpPattern
: public mlir::OpRewritePattern<TFHE::KeySwitchGLWEOp> {
KeySwitchGLWEOpPattern(mlir::MLIRContext *context,
conversion::TypeConverter &typeConverter,
conversion::KeyConverter &keyConverter,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<TFHE::KeySwitchGLWEOp>(context, benefit),
keyConverter(keyConverter), typeConverter(typeConverter) {}
mlir::LogicalResult
matchAndRewrite(TFHE::KeySwitchGLWEOp ksOp,
mlir::PatternRewriter &rewriter) const override {
auto newInputTy = typeConverter.convertType(ksOp.getCiphertext().getType())
.cast<GLWECipherTextType>();
auto newOutputTy = typeConverter.convertType(ksOp.getResult().getType());
auto newKeyswitchKey = keyConverter.convertKeyswitchKey(ksOp.getKeyAttr());
auto newOp = rewriter.replaceOpWithNewOp<TFHE::KeySwitchGLWEOp>(
ksOp, newOutputTy, ksOp.getCiphertext(), newKeyswitchKey);
rewriter.startRootUpdate(newOp);
newOp.getCiphertext().setType(newInputTy);
rewriter.finalizeRootUpdate(newOp);
return mlir::success();
};
private:
conversion::KeyConverter &keyConverter;
conversion::TypeConverter &typeConverter;
};
struct BootstrapGLWEOpPattern
: public mlir::OpRewritePattern<TFHE::BootstrapGLWEOp> {
BootstrapGLWEOpPattern(mlir::MLIRContext *context,
conversion::TypeConverter &typeConverter,
conversion::KeyConverter &keyConverter,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<TFHE::BootstrapGLWEOp>(context, benefit),
keyConverter(keyConverter), typeConverter(typeConverter) {}
mlir::LogicalResult
matchAndRewrite(TFHE::BootstrapGLWEOp bsOp,
mlir::PatternRewriter &rewriter) const override {
auto newInputTy = typeConverter.convertType(bsOp.getCiphertext().getType())
.cast<GLWECipherTextType>();
auto newOutputTy = typeConverter.convertType(bsOp.getResult().getType());
auto newBootstrapKey = keyConverter.convertBootstrapKey(bsOp.getKeyAttr());
auto newOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
bsOp, newOutputTy, bsOp.getCiphertext(), bsOp.getLookupTable(),
newBootstrapKey);
rewriter.startRootUpdate(newOp);
newOp.getCiphertext().setType(newInputTy.cast<GLWECipherTextType>());
rewriter.finalizeRootUpdate(newOp);
return mlir::success();
};
private:
conversion::KeyConverter &keyConverter;
conversion::TypeConverter &typeConverter;
};
struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
WopPBSGLWEOpPattern(mlir::MLIRContext *context,
conversion::TypeConverter &typeConverter,
conversion::KeyConverter &keyConverter,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<TFHE::WopPBSGLWEOp>(context, benefit),
keyConverter(keyConverter), typeConverter(typeConverter) {}
mlir::LogicalResult
matchAndRewrite(TFHE::WopPBSGLWEOp wopPBSOp,
mlir::PatternRewriter &rewriter) const override {
auto newInputTy =
typeConverter.convertType(wopPBSOp.getCiphertexts().getType())
.cast<mlir::RankedTensorType>();
auto newOutputType = typeConverter.convertType(wopPBSOp.getType());
auto newKeyswitchKey =
keyConverter.convertKeyswitchKey(wopPBSOp.getKskAttr());
auto newBootstrapKey =
keyConverter.convertBootstrapKey(wopPBSOp.getBskAttr());
auto newPackingKeyswitchKey =
keyConverter.convertPackingKeyswitchKey(wopPBSOp.getPkskAttr());
auto newOp = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
wopPBSOp, newOutputType, wopPBSOp.getCiphertexts(),
wopPBSOp.getLookupTable(), newKeyswitchKey, newBootstrapKey,
newPackingKeyswitchKey, wopPBSOp.getCrtDecompositionAttr(),
wopPBSOp.getCbsLevelsAttr(), wopPBSOp.getCbsBaseLogAttr());
rewriter.startRootUpdate(newOp);
newOp.getCiphertexts().setType(newInputTy);
rewriter.finalizeRootUpdate(newOp);
return mlir::success();
};
private:
conversion::KeyConverter &keyConverter;
conversion::TypeConverter &typeConverter;
};
} // namespace patterns
namespace {
struct TFHEKeyNormalizationPass
: public TFHEKeyNormalizationBase<TFHEKeyNormalizationPass> {
void runOnOperation() final;
};
template <typename Op>
void populateWithTFHEOpTypeConversionPattern(
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
mlir::TypeConverter &typeConverter) {
patterns.add<mlir::concretelang::GenericTypeConverterPattern<Op>>(
patterns.getContext(), typeConverter);
target.addDynamicallyLegalOp<Op>(
[&](Op op) { return typeConverter.isLegal(op->getResultTypes()); });
}
/// Populate the RewritePatternSet with all patterns that rewrite Concrete
/// operators to the corresponding function call to the `Concrete C API`.
void populateWithTFHEOpTypeConversionPatterns(
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
mlir::TypeConverter &typeConverter) {
populateWithTFHEOpTypeConversionPattern<mlir::concretelang::TFHE::ZeroGLWEOp>(
patterns, target, typeConverter);
populateWithTFHEOpTypeConversionPattern<
mlir::concretelang::TFHE::ZeroTensorGLWEOp>(patterns, target,
typeConverter);
populateWithTFHEOpTypeConversionPattern<
mlir::concretelang::TFHE::AddGLWEIntOp>(patterns, target, typeConverter);
populateWithTFHEOpTypeConversionPattern<mlir::concretelang::TFHE::AddGLWEOp>(
patterns, target, typeConverter);
populateWithTFHEOpTypeConversionPattern<
mlir::concretelang::TFHE::SubGLWEIntOp>(patterns, target, typeConverter);
populateWithTFHEOpTypeConversionPattern<mlir::concretelang::TFHE::NegGLWEOp>(
patterns, target, typeConverter);
populateWithTFHEOpTypeConversionPattern<
mlir::concretelang::TFHE::MulGLWEIntOp>(patterns, target, typeConverter);
}
} // namespace
void TFHEKeyNormalizationPass::runOnOperation() {
auto op = this->getOperation();
auto circuitKeys = TFHE::extractCircuitKeys(op);
auto keyConverter = conversion::KeyConverter(circuitKeys);
auto typeConverter = conversion::TypeConverter(keyConverter);
// Parametrize
{
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
// function signature
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
[&](mlir::func::FuncOp funcOp) {
return typeConverter.isSignatureLegal(funcOp.getFunctionType()) &&
typeConverter.isLegal(&funcOp.getBody());
});
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
[&](mlir::func::ConstantOp op) {
return FunctionConstantOpConversion<
conversion::TypeConverter>::isLegal(op, typeConverter);
});
patterns.add<FunctionConstantOpConversion<conversion::TypeConverter>>(
&getContext(), typeConverter);
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
patterns, typeConverter);
// Parametrize keyswitch
target.addLegalOp<mlir::arith::ConstantOp>();
patterns.add<patterns::KeySwitchGLWEOpPattern>(&getContext(), typeConverter,
keyConverter);
target.addDynamicallyLegalOp<TFHE::KeySwitchGLWEOp>(
[&](TFHE::KeySwitchGLWEOp op) {
return op.getKeyAttr().getInputKey().isNormalized() &&
op.getKeyAttr().getOutputKey().isNormalized() &&
op.getKeyAttr().getIndex() != -1;
});
// Parametrize bootstrap
patterns.add<patterns::BootstrapGLWEOpPattern>(&getContext(), typeConverter,
keyConverter);
target.addDynamicallyLegalOp<TFHE::BootstrapGLWEOp>(
[&](TFHE::BootstrapGLWEOp op) {
return op.getKeyAttr().getInputKey().isNormalized() &&
op.getKeyAttr().getOutputKey().isNormalized() &&
op.getKeyAttr().getIndex() != -1;
});
// Parametrize wop pbs
patterns.add<patterns::WopPBSGLWEOpPattern>(&getContext(), typeConverter,
keyConverter);
target.addDynamicallyLegalOp<TFHE::WopPBSGLWEOp>(
[&](TFHE::WopPBSGLWEOp op) {
return op.getKskAttr().getInputKey().isNormalized() &&
op.getKskAttr().getOutputKey().isNormalized() &&
op.getKskAttr().getIndex() != -1 &&
op.getBskAttr().getInputKey().isNormalized() &&
op.getBskAttr().getOutputKey().isNormalized() &&
op.getBskAttr().getIndex() != -1 &&
op.getPkskAttr().getInputKey().isNormalized() &&
op.getPkskAttr().getOutputKey().isNormalized() &&
op.getPkskAttr().getIndex() != -1;
});
// Add all patterns to convert TFHE types
populateWithTFHEOpTypeConversionPatterns(patterns, target, typeConverter);
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::bufferization::AllocTensorOp>>(&getContext(), typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::bufferization::AllocTensorOp>(target, typeConverter);
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
conversion::TypeConverter>>(
&getContext(), typeConverter);
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
conversion::TypeConverter>>(
&getContext(), typeConverter);
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
conversion::TypeConverter>>(
&getContext(), typeConverter);
patterns.add<RegionOpTypeConverterPattern<mlir::func::ReturnOp,
conversion::TypeConverter>>(
&getContext(), typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(
target, typeConverter);
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::YieldOp,
conversion::TypeConverter>>(
&getContext(), typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::linalg::YieldOp>(
target, typeConverter);
mlir::concretelang::populateWithTensorTypeConverterPatterns(
patterns, target, typeConverter);
// Conversion of RT Dialect Ops
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::Tracing::TraceCiphertextOp>,
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::MakeReadyFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::AwaitFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::CreateAsyncTaskOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::WorkFunctionReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::Tracing::TraceCiphertextOp>(target, typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::MakeReadyFutureOp>(target, typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::AwaitFutureOp>(target, typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::CreateAsyncTaskOp>(target, typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target,
typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
target, typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target,
typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::WorkFunctionReturnOp>(target, typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target,
typeConverter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
}
}
}
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>> createTFHEKeyNormalizationPass() {
return std::make_unique<TFHEKeyNormalizationPass>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -46,11 +46,11 @@ public:
TFHEToConcreteTypeConverter() {
addConversion([](mlir::Type type) { return type; });
addConversion([&](GLWECipherTextType type) {
assert(!type.getKey().isNotParameterized());
assert(type.getKey().getPolySize().value() == 1 &&
assert(type.getKey().isNormalized() && "keys should be normalized");
assert(type.getKey().getNormalized().value().polySize == 1 &&
"converter doesn't support polynomialSize > 1");
llvm::SmallVector<int64_t, 2> shape;
shape.push_back(type.getKey().getDimension().value() + 1);
shape.push_back(type.getKey().getNormalized().value().dimension + 1);
return mlir::RankedTensorType::get(
shape, mlir::IntegerType::get(type.getContext(), 64));
});
@@ -62,8 +62,8 @@ public:
mlir::SmallVector<int64_t> newShape;
newShape.reserve(type.getShape().size() + 1);
newShape.append(type.getShape().begin(), type.getShape().end());
assert(!glwe.getKey().isNotParameterized());
newShape.push_back(glwe.getKey().getDimension().value() + 1);
assert(glwe.getKey().isNormalized());
newShape.push_back(glwe.getKey().getNormalized().value().dimension + 1);
mlir::Type r = mlir::RankedTensorType::get(
newShape, mlir::IntegerType::get(type.getContext(), 64));
return r;
@@ -128,12 +128,14 @@ struct BootstrapGLWEOpPattern
auto glweDimension = adaptor.getKey().getGlweDim();
auto levels = adaptor.getKey().getLevels();
auto baseLog = adaptor.getKey().getBaseLog();
auto inputLweDimension = inputType.getKey().getDimension().value();
auto inputLweDimension =
inputType.getKey().getNormalized().value().dimension;
auto bskIndex = bsOp.getKeyAttr().getIndex();
rewriter.replaceOpWithNewOp<Concrete::BootstrapLweTensorOp>(
bsOp, this->getTypeConverter()->convertType(resultType),
adaptor.getCiphertext(), adaptor.getLookupTable(), inputLweDimension,
polySize, levels, baseLog, glweDimension);
polySize, levels, baseLog, glweDimension, bskIndex);
return mlir::success();
}
@@ -164,12 +166,16 @@ struct WopPBSGLWEOpPattern
auto pksOutputPolySize = adaptor.getPksk().getOutputPolySize();
auto crtDecomposition = adaptor.getCrtDecompositionAttr();
auto resultType = op.getType();
auto kskIndex = op.getKskAttr().getIndex();
auto bskIndex = op.getBskAttr().getIndex();
auto pkskIndex = op.getPkskAttr().getIndex();
rewriter.replaceOpWithNewOp<Concrete::WopPBSCRTLweTensorOp>(
op, this->getTypeConverter()->convertType(resultType),
adaptor.getCiphertexts(), adaptor.getLookupTable(), bsLevels, bsBaseLog,
ksLevels, ksBaseLog, pksInputLweDim, pksOutputPolySize, pksLevels,
pksBaseLog, cbsLevels, cbsBaseLog, crtDecomposition);
pksBaseLog, cbsLevels, cbsBaseLog, crtDecomposition, kskIndex, bskIndex,
pkskIndex);
return mlir::success();
}
@@ -199,12 +205,14 @@ struct BatchedBootstrapGLWEOpPattern
auto glweDimension = adaptor.getKey().getGlweDim();
auto levels = adaptor.getKey().getLevels();
auto baseLog = adaptor.getKey().getBaseLog();
auto inputLweDimension = inputElementType.getKey().getDimension().value();
auto inputLweDimension =
inputElementType.getKey().getNormalized().value().dimension;
auto bskIndex = adaptor.getKey().getIndex();
rewriter.replaceOpWithNewOp<Concrete::BatchedBootstrapLweTensorOp>(
bbsOp, this->getTypeConverter()->convertType(bbsOp.getType()),
adaptor.getCiphertexts(), adaptor.getLookupTable(), inputLweDimension,
polySize, levels, baseLog, glweDimension);
polySize, levels, baseLog, glweDimension, bskIndex);
return mlir::success();
}
@@ -231,12 +239,14 @@ struct KeySwitchGLWEOpPattern
auto levels = adaptor.getKey().getLevels();
auto baseLog = adaptor.getKey().getBaseLog();
auto inputDim = inputType.getKey().getDimension().value();
auto outputDim = resultType.getKey().getDimension().value();
auto inputDim = inputType.getKey().getNormalized().value().dimension;
auto outputDim = resultType.getKey().getNormalized().value().dimension;
auto kskIndex = ksOp.getKeyAttr().getIndex();
rewriter.replaceOpWithNewOp<Concrete::KeySwitchLweTensorOp>(
ksOp, this->getTypeConverter()->convertType(resultType),
adaptor.getCiphertext(), levels, baseLog, inputDim, outputDim);
adaptor.getCiphertext(), levels, baseLog, inputDim, outputDim,
kskIndex);
return mlir::success();
}
@@ -270,12 +280,15 @@ struct BatchedKeySwitchGLWEOpPattern
auto levels = adaptor.getKey().getLevels();
auto baseLog = adaptor.getKey().getBaseLog();
auto inputDim = inputElementType.getKey().getDimension().value();
auto outputDim = resultElementType.getKey().getDimension().value();
auto inputDim = inputElementType.getKey().getNormalized().value().dimension;
auto outputDim =
resultElementType.getKey().getNormalized().value().dimension;
auto kskIndex = adaptor.getKey().getIndex();
rewriter.replaceOpWithNewOp<Concrete::BatchedKeySwitchLweTensorOp>(
bksOp, this->getTypeConverter()->convertType(bksOp.getType()),
adaptor.getCiphertexts(), levels, baseLog, inputDim, outputDim);
adaptor.getCiphertexts(), levels, baseLog, inputDim, outputDim,
kskIndex);
return mlir::success();
}

View File

@@ -5,6 +5,7 @@
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
#define GET_ATTRDEF_CLASSES
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.cpp.inc"
@@ -15,6 +16,7 @@
#include "concretelang/Dialect/TFHE/IR/TFHEOpsDialect.cpp.inc"
#include "concretelang/Support/Constants.h"
#include "concretelang/Support/Variants.h"
using namespace mlir::concretelang::TFHE;
@@ -36,17 +38,32 @@ void TFHEDialect::initialize() {
}
/// Verify that GLWE parameter are consistant
/// - The bits parameter is 64 (we support only this for v0)
::mlir::LogicalResult GLWECipherTextType::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
GLWESecretKey key) {
if (!key.isNotParameterized() && key.getPolySize().value() == 0) {
emitError() << "GLWE key has zero poly size.";
return ::mlir::failure();
}
if (!key.isNotParameterized() && key.getDimension().value() == 0) {
emitError() << "GLWE key has zero dimension.";
return ::mlir::failure();
}
return ::mlir::success();
return std::visit(
overloaded{[](GLWESecretKeyNone sk) { return mlir::success(); },
[&](GLWESecretKeyParameterized sk) {
if (sk.dimension == 0) {
emitError() << "GLWE key has zero dimension.";
return ::mlir::failure();
}
if (sk.polySize == 0) {
emitError() << "GLWE key has zero poly size.";
return ::mlir::failure();
}
return mlir::success();
},
[&](GLWESecretKeyNormalized sk) {
if (sk.dimension == 0) {
emitError() << "GLWE key has zero dimension.";
return ::mlir::failure();
}
if (sk.polySize == 0) {
emitError() << "GLWE key has zero poly size.";
return ::mlir::failure();
}
return mlir::success();
}},
key.inner);
}

View File

@@ -4,70 +4,131 @@
// for license information.
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
#include "concretelang/Support/Variants.h"
#include "mlir/IR/OpImplementation.h"
#include <variant>
namespace mlir {
namespace concretelang {
namespace TFHE {
GLWESecretKey::GLWESecretKey() {
dimension = -1;
polySize = -1;
id = -1;
}
GLWESecretKey::GLWESecretKey(int64_t dimension, int64_t polySize, int64_t id) {
assert(dimension > 0);
assert(polySize > 0);
assert(id > 0);
this->dimension = dimension;
this->polySize = polySize;
this->id = id;
}
bool GLWESecretKey::operator==(GLWESecretKey other) {
return this->id == other.id && this->dimension == other.dimension &&
bool GLWESecretKeyParameterized::operator==(
const GLWESecretKeyParameterized other) const {
return this->dimension == other.dimension &&
this->identifier == other.identifier &&
this->polySize == other.polySize;
}
bool GLWESecretKeyNormalized::operator==(
const GLWESecretKeyNormalized other) const {
return this->dimension == other.dimension && this->index == other.index &&
this->polySize == other.polySize;
}
GLWESecretKey GLWESecretKey::newNone() {
return GLWESecretKey{GLWESecretKeyNone{}};
}
GLWESecretKey GLWESecretKey::newNormalized(uint64_t dimension,
uint64_t polySize, uint64_t index) {
return GLWESecretKey{GLWESecretKeyNormalized{dimension, polySize, index}};
}
GLWESecretKey GLWESecretKey::newParameterized(uint64_t dimension,
uint64_t polySize,
uint64_t identifier) {
return GLWESecretKey{
GLWESecretKeyParameterized{dimension, polySize, identifier}};
}
bool GLWESecretKey::operator==(const GLWESecretKey other) const {
return this->id == other.id && this->dimension == other.dimension &&
this->polySize == other.polySize;
return std::visit(
overloaded{
[](GLWESecretKeyNone thisK, GLWESecretKeyNone otherK) {
return true;
},
[](GLWESecretKeyNormalized thisK, GLWESecretKeyNormalized otherK) {
return thisK == otherK;
},
[](GLWESecretKeyParameterized thisK,
GLWESecretKeyParameterized otherK) { return thisK == otherK; },
[](auto _thisK, auto _otherK) { return false; }},
this->inner, other.inner);
}
bool GLWESecretKey::operator!=(GLWESecretKey other) {
return this->id != other.id || this->dimension != other.dimension ||
this->polySize != other.polySize;
bool GLWESecretKey::operator!=(const GLWESecretKey other) const {
return !(*this == other);
}
std::optional<int64_t> GLWESecretKey::getDimension() const {
if (this->isNotParameterized()) {
return std::nullopt;
template <typename V> bool GLWESecretKey::is() {
return std::holds_alternative<V>(this->inner);
}
bool GLWESecretKey::isNone() { return is<GLWESecretKeyNone>(); }
bool GLWESecretKey::isParameterized() {
return is<GLWESecretKeyParameterized>();
}
bool GLWESecretKey::isNormalized() { return is<GLWESecretKeyNormalized>(); }
template <typename V> std::optional<V> GLWESecretKey::get() {
if (this->is<V>()) {
return std::get<V>(this->inner);
} else {
return this->dimension;
return std::nullopt;
}
}
std::optional<int64_t> GLWESecretKey::getPolySize() const {
if (this->isNotParameterized()) {
return std::nullopt;
} else {
return this->polySize;
}
std::optional<GLWESecretKeyNone> GLWESecretKey::getNone() {
return get<GLWESecretKeyNone>();
}
mlir::Optional<int64_t> GLWESecretKey::getId() const {
if (this->isNotParameterized()) {
return std::nullopt;
} else {
return this->id;
}
std::optional<GLWESecretKeyParameterized> GLWESecretKey::getParameterized() {
return get<GLWESecretKeyParameterized>();
}
std::optional<GLWESecretKeyNormalized> GLWESecretKey::getNormalized() {
return get<GLWESecretKeyNormalized>();
}
bool GLWESecretKey::isNotParameterized() const { return id <= 0; }
llvm::hash_code hash_value(const GLWESecretKey &key) {
return llvm::hash_combine("GlweSecretKey", key.getDimension(),
key.getPolySize(), key.getId());
return std::visit(overloaded{[](GLWESecretKeyNone sk) {
return llvm::hash_value("GlweSecretKeyNone");
},
[](GLWESecretKeyParameterized sk) {
return llvm::hash_combine(
"GlweSecretKeyParameterized", sk.dimension,
sk.polySize, sk.identifier);
},
[](GLWESecretKeyNormalized sk) {
return llvm::hash_combine(
"GlweSecretKeyNormalized", sk.dimension,
sk.polySize, sk.index);
}},
key.inner);
}
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, GLWESecretKeyNone sk) {
OS << "sk?";
return OS;
}
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
GLWESecretKeyParameterized sk) {
OS << "sk<" << sk.identifier << "," << sk.polySize << "," << sk.dimension
<< ">";
return OS;
}
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
GLWESecretKeyNormalized sk) {
OS << "sk[" << sk.index << "]<" << sk.polySize << "," << sk.dimension << ">";
return OS;
}
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, GLWESecretKey sk) {
std::visit(overloaded{[&](GLWESecretKeyNone nsk) { OS << nsk; },
[&](GLWESecretKeyNormalized nsk) { OS << nsk; },
[&](GLWESecretKeyParameterized nsk) { OS << nsk; }},
sk.inner);
return OS;
}
} // namespace TFHE
@@ -77,37 +138,45 @@ llvm::hash_code hash_value(const GLWESecretKey &key) {
namespace mlir {
AsmPrinter &operator<<(AsmPrinter &p,
mlir::concretelang::TFHE::GLWESecretKey key) {
if (key.isNotParameterized()) {
p << "sk[?]";
} else {
p << "sk[" << key.getId() << "]<" << key.getPolySize().value() << ","
<< key.getDimension().value() << ">";
}
p.getStream() << key;
return p;
}
FailureOr<mlir::concretelang::TFHE::GLWESecretKey>
FieldParser<mlir::concretelang::TFHE::GLWESecretKey>::parse(AsmParser &parser) {
int64_t dimension = -1, polySize = -1, id = -1;
if (parser.parseKeyword("sk") || parser.parseLSquare()) {
uint64_t dimension = -1, polySize = -1, id = -1;
if (parser.parseKeyword("sk")) {
return mlir::failure();
}
auto maybeId = parser.parseOptionalInteger(id);
if (maybeId.has_value()) {
if (maybeId.value() || parser.parseRSquare() || parser.parseLess() ||
if (parser.parseOptionalQuestion().succeeded()) { // Parsing none key
return mlir::concretelang::TFHE::GLWESecretKey::newNone();
}
if (parser.parseOptionalLSquare().succeeded()) { // Parsing normalized key
if (parser.parseInteger(id) || parser.parseRSquare() ||
parser.parseLess() || parser.parseInteger(polySize) ||
parser.parseComma() || parser.parseInteger(dimension) ||
parser.parseGreater()) {
return mlir::failure();
} else {
return mlir::concretelang::TFHE::GLWESecretKey::newNormalized(
dimension, polySize, id);
}
}
if (parser.parseOptionalLess().succeeded()) { // Parsing parameterized key
if (parser.parseInteger(id) || parser.parseComma() ||
parser.parseInteger(polySize) || parser.parseComma() ||
parser.parseInteger(dimension) || parser.parseGreater()) {
return mlir::failure();
}
} else {
if (parser.parseQuestion() || parser.parseRSquare()) {
return mlir::failure();
} else {
return mlir::concretelang::TFHE::GLWESecretKey::newParameterized(
dimension, polySize, id);
}
}
if (id <= 0) {
return mlir::concretelang::TFHE::GLWESecretKey();
} else {
return mlir::concretelang::TFHE::GLWESecretKey(dimension, polySize, id);
}
return mlir::failure();
}
} // namespace mlir

View File

@@ -70,6 +70,8 @@ struct Process {
Param glwe_dim;
Param precision;
Param output_size;
Param ksk_index;
Param bsk_index;
Context ctx;
void (*fun)(Process *);
};
@@ -102,7 +104,7 @@ void memref_keyswitch_lwe_u64_process(Process *p) {
out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0],
ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
p->level.val, p->base_log.val, p->input_lwe_dim.val,
p->output_lwe_dim.val, p->ctx.val);
p->output_lwe_dim.val, p->ksk_index.val, p->ctx.val);
(p->output_streams[0]).memref_stream->put(out);
}
delete p;
@@ -123,7 +125,7 @@ void memref_bootstrap_lwe_u64_process(Process *p) {
ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
tlu.allocated, tlu.aligned, tlu.offset, tlu.sizes[0], tlu.strides[0],
p->input_lwe_dim.val, p->poly_size.val, p->level.val, p->base_log.val,
p->glwe_dim.val, p->ctx.val);
p->glwe_dim.val, p->bsk_index.val, p->ctx.val);
(p->output_streams[0]).memref_stream->put(out);
}
delete p;
@@ -278,7 +280,7 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,
void stream_emulator_make_memref_keyswitch_lwe_u64_process(
void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
void *context) {
uint32_t ksk_index, void *context) {
mlir::concretelang::stream_emulator::Process *p =
new mlir::concretelang::stream_emulator::Process;
p->input_streams.push_back(
@@ -292,6 +294,7 @@ void stream_emulator_make_memref_keyswitch_lwe_u64_process(
p->input_lwe_dim.val = input_lwe_dim;
p->output_lwe_dim.val = output_lwe_dim;
p->output_size.val = output_size;
p->ksk_index.val = ksk_index;
p->ctx.val = (mlir::concretelang::RuntimeContext *)context;
p->fun =
mlir::concretelang::stream_emulator::memref_keyswitch_lwe_u64_process;
@@ -302,7 +305,7 @@ void stream_emulator_make_memref_keyswitch_lwe_u64_process(
void stream_emulator_make_memref_bootstrap_lwe_u64_process(
void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
uint32_t output_size, void *context) {
uint32_t output_size, uint32_t bsk_index, void *context) {
mlir::concretelang::stream_emulator::Process *p =
new mlir::concretelang::stream_emulator::Process;
p->input_streams.push_back(
@@ -320,6 +323,7 @@ void stream_emulator_make_memref_bootstrap_lwe_u64_process(
p->base_log.val = base_log;
p->glwe_dim.val = glwe_dim;
p->output_size.val = output_size;
p->bsk_index.val = bsk_index;
p->ctx.val = (mlir::concretelang::RuntimeContext *)context;
p->fun =
mlir::concretelang::stream_emulator::memref_bootstrap_lwe_u64_process;

View File

@@ -68,7 +68,7 @@ void memref_keyswitch_lwe_cuda_u64(
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint32_t level, uint32_t base_log,
uint32_t input_lwe_dim, uint32_t output_lwe_dim,
uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t ksk_index,
mlir::concretelang::RuntimeContext *context) {
assert(out_stride == 1);
assert(ct0_stride == 1);
@@ -78,7 +78,7 @@ void memref_keyswitch_lwe_cuda_u64(
// Output 1D memref as 2D memref
ct0_allocated, ct0_aligned, ct0_offset, 1, ct0_size, ct0_size, ct0_stride,
// Keyswitch additional arguments
level, base_log, input_lwe_dim, output_lwe_dim, context);
level, base_log, input_lwe_dim, output_lwe_dim, ksk_index, context);
}
void memref_bootstrap_lwe_cuda_u64(
@@ -88,7 +88,7 @@ void memref_bootstrap_lwe_cuda_u64(
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
uint32_t base_log, uint32_t glwe_dim,
uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
mlir::concretelang::RuntimeContext *context) {
memref_batched_bootstrap_lwe_cuda_u64(
// Output 1D memref as 2D memref
@@ -98,7 +98,7 @@ void memref_bootstrap_lwe_cuda_u64(
// Table lookup memref
tlu_allocated, tlu_aligned, tlu_offset, tlu_size, tlu_stride,
// Bootstrap additional arguments
input_lwe_dim, poly_size, level, base_log, glwe_dim, context);
input_lwe_dim, poly_size, level, base_log, glwe_dim, bsk_index, context);
}
// Batched CUDA function //////////////////////////////////////////////////////
@@ -110,7 +110,8 @@ void memref_batched_keyswitch_lwe_cuda_u64(
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
uint64_t ct0_stride0, uint64_t ct0_stride1, uint32_t level,
uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim,
mlir::concretelang::RuntimeContext *context) {
uint32_t ksk_index, mlir::concretelang::RuntimeContext *context) {
assert(ksk_index == 0 && "multiple ksk is not yet implemented on GPU");
assert(out_size0 == ct0_size0);
assert(out_size1 == output_lwe_dim + 1);
assert(ct0_size1 == input_lwe_dim + 1);
@@ -154,8 +155,9 @@ void memref_batched_bootstrap_lwe_cuda_u64(
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
uint32_t level, uint32_t base_log, uint32_t glwe_dim,
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
mlir::concretelang::RuntimeContext *context) {
assert(bsk_index == 0 && "multiple bsk is not yet implemented on GPU");
assert(out_size0 == ct0_size0);
assert(out_size1 == glwe_dim * poly_size + 1);
// TODO: Multi GPU
@@ -495,16 +497,19 @@ void memref_negate_lwe_ciphertext_u64(
out_aligned + out_offset, ct0_aligned + ct0_offset, lwe_dimension);
}
void memref_keyswitch_lwe_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint32_t decomposition_level_count,
uint32_t decomposition_base_log, uint32_t input_dimension,
uint32_t output_dimension, mlir::concretelang::RuntimeContext *context) {
void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
uint64_t out_offset, uint64_t out_size,
uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset,
uint64_t ct0_size, uint64_t ct0_stride,
uint32_t decomposition_level_count,
uint32_t decomposition_base_log,
uint32_t input_dimension,
uint32_t output_dimension, uint32_t ksk_index,
mlir::concretelang::RuntimeContext *context) {
assert(out_stride == 1 && ct0_stride == 1);
// Get keyswitch key - TODO Give a non hardcoded keyID
const uint64_t *keyswitch_key = context->keyswitch_key_buffer(0);
// Get keyswitch key
const uint64_t *keyswitch_key = context->keyswitch_key_buffer(ksk_index);
// Get stack parameter
concrete_cpu_keyswitch_lwe_ciphertext_u64(
out_aligned + out_offset, ct0_aligned + ct0_offset, keyswitch_key,
@@ -519,13 +524,13 @@ void memref_batched_keyswitch_lwe_u64(
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
uint64_t ct0_stride0, uint64_t ct0_stride1, uint32_t level,
uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim,
mlir::concretelang::RuntimeContext *context) {
uint32_t ksk_index, mlir::concretelang::RuntimeContext *context) {
for (size_t i = 0; i < ct0_size0; i++) {
memref_keyswitch_lwe_u64(
out_allocated + i * out_size1, out_aligned + i * out_size1, out_offset,
out_size1, out_stride1, ct0_allocated + i * ct0_size1,
ct0_aligned + i * ct0_size1, ct0_offset, ct0_size1, ct0_stride1, level,
base_log, input_lwe_dim, output_lwe_dim, context);
base_log, input_lwe_dim, output_lwe_dim, ksk_index, context);
}
}
@@ -537,7 +542,8 @@ void memref_bootstrap_lwe_u64(
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dimension, uint32_t polynomial_size,
uint32_t decomposition_level_count, uint32_t decomposition_base_log,
uint32_t glwe_dimension, mlir::concretelang::RuntimeContext *context) {
uint32_t glwe_dimension, uint32_t bsk_index,
mlir::concretelang::RuntimeContext *context) {
uint64_t glwe_ct_size = polynomial_size * (glwe_dimension + 1);
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t));
@@ -551,10 +557,9 @@ void memref_bootstrap_lwe_u64(
glwe_ct[polynomial_size * glwe_dimension + i] = tlu[i];
}
// Get fourrier bootstrap key - TODO Give a non hardcoded keyID
size_t keyId = 0;
const auto &fft = context->fft(keyId);
auto bootstrap_key = context->fourier_bootstrap_key_buffer(keyId);
// Get fourrier bootstrap key
const auto &fft = context->fft(bsk_index);
auto bootstrap_key = context->fourier_bootstrap_key_buffer(bsk_index);
// Get stack parameter
size_t scratch_size;
size_t scratch_align;
@@ -582,7 +587,7 @@ void memref_batched_bootstrap_lwe_u64(
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
uint32_t level, uint32_t base_log, uint32_t glwe_dim,
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
mlir::concretelang::RuntimeContext *context) {
for (size_t i = 0; i < out_size0; i++) {
@@ -591,7 +596,7 @@ void memref_batched_bootstrap_lwe_u64(
out_size1, out_stride1, ct0_allocated, ct0_aligned + i * ct0_size1,
ct0_offset, ct0_size1, ct0_stride1, tlu_allocated, tlu_aligned,
tlu_offset, tlu_size, tlu_stride, input_lwe_dim, poly_size, level,
base_log, glwe_dim, context);
base_log, glwe_dim, bsk_index, context);
}
}
@@ -617,10 +622,12 @@ void memref_wop_pbs_crt_buffer(
uint64_t crt_decomp_offset, uint64_t crt_decomp_size,
uint64_t crt_decomp_stride,
// Additional crypto parameters
uint32_t lwe_small_size, uint32_t cbs_level_count, uint32_t cbs_base_log,
uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log,
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
uint32_t polynomial_size,
// Key Indices,
uint32_t ksk_index, uint32_t bsk_index, uint32_t pksk_index,
// runtime context that hold evluation keys
mlir::concretelang::RuntimeContext *context) {
@@ -635,7 +642,7 @@ void memref_wop_pbs_crt_buffer(
// Check for the size S
assert(out_size_1 == in_size_1);
uint64_t lwe_small_dim = lwe_small_size - 1;
uint64_t lwe_small_size = lwe_small_dim + 1;
assert(out_size_1 == in_size_1);
uint64_t lwe_big_size = in_size_1;
@@ -672,13 +679,9 @@ void memref_wop_pbs_crt_buffer(
std::vector<uint64_t> in_copy(first_ciphertext, first_ciphertext + copy_size);
// Extraction of each bit for each block
size_t fftKeyId = 0;
const auto &fft = context->fft(fftKeyId);
size_t bskKeyId = 0;
auto bootstrap_key = context->fourier_bootstrap_key_buffer(bskKeyId);
size_t kskKeyId = 0;
auto keyswicth_key = context->keyswitch_key_buffer(kskKeyId);
const auto &fft = context->fft(bsk_index);
auto bootstrap_key = context->fourier_bootstrap_key_buffer(bsk_index);
auto keyswicth_key = context->keyswitch_key_buffer(ksk_index);
for (int64_t i = crt_decomp_size - 1, extract_bits_output_offset = 0; i >= 0;
extract_bits_output_offset += number_of_bits_per_block[i--]) {
@@ -731,8 +734,7 @@ void memref_wop_pbs_crt_buffer(
auto *scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size);
size_t fpkskKeyId = 0;
auto fp_keyswicth_key = context->fp_keyswitch_key_buffer(fpkskKeyId);
auto fp_keyswicth_key = context->fp_keyswitch_key_buffer(pksk_index);
concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64(
out_aligned + out_offset, extract_bits_output_buffer,

View File

@@ -4,10 +4,12 @@ add_mlir_library(
Jit.cpp
CompilationFeedback.cpp
CompilerEngine.cpp
TFHECircuitKeys.cpp
Encodings.cpp
JITSupport.cpp
LambdaArgument.cpp
V0Parameters.cpp
V0ClientParameters.cpp
ClientParametersGeneration.cpp
logging.cpp
Jit.cpp
LLVMEmitFile.cpp

View File

@@ -0,0 +1,378 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <cassert>
#include <llvm/ADT/SmallVector.h>
#include <map>
#include <optional>
#include <unordered_set>
#include <variant>
#include <llvm/ADT/Optional.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/SmallSet.h>
#include <llvm/Support/Error.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include "concrete/curves.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Support/Encodings.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/TFHECircuitKeys.h"
#include "concretelang/Support/Variants.h"
#include "llvm/Config/abi-breaking.h"
namespace mlir {
namespace concretelang {
namespace clientlib = ::concretelang::clientlib;
using ::concretelang::clientlib::ChunkInfo;
using ::concretelang::clientlib::CircuitGate;
using ::concretelang::clientlib::ClientParameters;
using ::concretelang::clientlib::Encoding;
using ::concretelang::clientlib::EncryptionGate;
using ::concretelang::clientlib::LweSecretKeyID;
using ::concretelang::clientlib::Precision;
using ::concretelang::clientlib::Variance;
const auto keyFormat = concrete::BINARY;
llvm::Expected<CircuitGate>
generateGate(mlir::Type type, encodings::Encoding encoding,
concrete::SecurityCurve curve,
std::optional<CRTDecomposition> maybeCrt) {
auto scalarVisitor = overloaded{
[&](encodings::EncryptedIntegerScalarEncoding enc)
-> llvm::Expected<CircuitGate> {
TFHE::GLWESecretKeyNormalized normKey;
if (type.isa<RankedTensorType>()) {
normKey = type.cast<RankedTensorType>()
.getElementType()
.cast<TFHE::GLWECipherTextType>()
.getKey()
.getNormalized()
.value();
} else {
normKey = type.cast<TFHE::GLWECipherTextType>()
.getKey()
.getNormalized()
.value();
}
size_t width = enc.width;
bool isSigned = enc.isSigned;
uint64_t size = 0;
std::vector<int64_t> dims{};
LweSecretKeyID secretKeyID = normKey.index;
Variance variance = curve.getVariance(1, normKey.dimension, 64);
CRTDecomposition crt = maybeCrt.value_or(std::vector<int64_t>());
return CircuitGate{
/* .encryption = */ std::optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ width,
/* .crt = */ crt,
/*.sign = */ isSigned,
},
}),
/*.shape = */
{
/*.width = */ width,
/*.dimensions = */ dims,
/*.size = */ size,
/*.sign = */ isSigned,
},
/*.chunkInfo = */ std::nullopt,
};
},
[&](encodings::EncryptedChunkedIntegerScalarEncoding enc)
-> llvm::Expected<CircuitGate> {
auto tensorType = type.cast<mlir::RankedTensorType>();
auto glweType =
tensorType.getElementType().cast<TFHE::GLWECipherTextType>();
auto normKey = glweType.getKey().getNormalized().value();
size_t width = enc.chunkSize;
assert(enc.width % enc.chunkWidth == 0);
uint64_t size = enc.width / enc.chunkWidth;
bool isSigned = enc.isSigned;
std::vector<int64_t> dims{
(int64_t)size,
};
LweSecretKeyID secretKeyID = normKey.index;
Variance variance = curve.getVariance(1, normKey.dimension, 64);
CRTDecomposition crt = maybeCrt.value_or(std::vector<int64_t>());
return CircuitGate{
/* .encryption = */ std::optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ width,
/* .crt = */ crt,
/*.sign = */ isSigned,
},
}),
/*.shape = */
{
/*.width = */ width,
/*.dimensions = */ dims,
/*.size = */ size,
/*.sign = */ isSigned,
},
/*.chunkInfo = */
std::optional<ChunkInfo>(
{(unsigned int)enc.chunkSize, (unsigned int)enc.chunkWidth}),
};
},
[&](encodings::EncryptedBoolScalarEncoding enc)
-> llvm::Expected<CircuitGate> {
auto glweType = type.cast<TFHE::GLWECipherTextType>();
auto normKey = glweType.getKey().getNormalized().value();
size_t width =
mlir::concretelang::FHE::EncryptedBooleanType::getWidth();
LweSecretKeyID secretKeyID = normKey.index;
Variance variance = curve.getVariance(1, normKey.dimension, 64);
return CircuitGate{
/* .encryption = */ std::optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ width,
/* .crt = */ std::vector<int64_t>(),
/* .sign = */ false,
},
}),
/*.shape = */
{
/*.width = */ width,
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/*.sign = */ false,
},
/*.chunkInfo = */ std::nullopt,
};
},
[&](encodings::PlaintextScalarEncoding enc)
-> llvm::Expected<CircuitGate> {
size_t width = type.getIntOrFloatBitWidth();
bool sign = type.isSignedInteger();
return CircuitGate{
/*.encryption = */ std::nullopt,
/*.shape = */
{/*.width = */ width,
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/* .sign */ sign},
/*.chunkInfo = */ std::nullopt,
};
},
[&](encodings::IndexScalarEncoding enc) -> llvm::Expected<CircuitGate> {
// TODO - The index type is dependant of the target architecture,
// so actually we assume we target only 64 bits, we need to have
// some the size of the word of the target system.
size_t width = 64;
bool sign = type.isSignedInteger();
return CircuitGate{
/*.encryption = */ std::nullopt,
/*.shape = */
{/*.width = */ width,
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/* .sign */ sign},
/*.chunkInfo = */ std::nullopt,
};
},
[&](auto enc) -> llvm::Expected<CircuitGate> {
return llvm::make_error<llvm::StringError>(
"cannot convert MLIR type to shape there",
llvm::inconvertibleErrorCode());
}};
auto genericVisitor = overloaded{
[&](encodings::ScalarEncoding enc) -> llvm::Expected<CircuitGate> {
return std::visit(scalarVisitor, enc);
},
[&](encodings::TensorEncoding enc) -> llvm::Expected<CircuitGate> {
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
auto scalarGate = generateGate(tensor.getElementType(),
enc.scalarEncoding, curve, maybeCrt);
if (auto err = scalarGate.takeError()) {
return std::move(err);
}
if (maybeCrt.has_value()) {
// When using crt, the last dimension of the tensor is for the members
// of the decomposition. It should not be used.
scalarGate->shape.dimensions =
tensor.getShape().take_front(tensor.getShape().size() - 1).vec();
} else {
scalarGate->shape.dimensions = tensor.getShape().vec();
}
scalarGate->shape.size = 1;
for (auto dimSize : scalarGate->shape.dimensions) {
scalarGate->shape.size *= dimSize;
}
return scalarGate;
},
[&](auto enc) -> llvm::Expected<CircuitGate> {
return llvm::make_error<llvm::StringError>(
"cannot convert MLIR type to shape here",
llvm::inconvertibleErrorCode());
}};
return std::visit(genericVisitor, encoding);
}
template <typename V> struct HashValComparator {
bool operator()(const V &lhs, const V &rhs) const {
return hash_value(lhs) < hash_value(rhs);
}
};
template <typename V> using Set = llvm::SmallSet<V, 10, HashValComparator<V>>;
void extractCircuitKeys(ClientParameters &output,
TFHE::TFHECircuitKeys circuitKeys,
concrete::SecurityCurve curve) {
// Pushing secret keys
for (auto sk : circuitKeys.secretKeys) {
clientlib::LweSecretKeyParam skParam;
skParam.dimension = sk.getNormalized().value().dimension;
output.secretKeys.push_back(skParam);
}
// Pushing keyswitch keys
for (auto ksk : circuitKeys.keyswitchKeys) {
clientlib::KeyswitchKeyParam kskParam;
auto inputNormKey = ksk.getInputKey().getNormalized().value();
auto outputNormKey = ksk.getOutputKey().getNormalized().value();
kskParam.inputSecretKeyID = inputNormKey.index;
kskParam.outputSecretKeyID = outputNormKey.index;
kskParam.level = ksk.getLevels();
kskParam.baseLog = ksk.getBaseLog();
kskParam.variance = curve.getVariance(1, outputNormKey.dimension, 64);
output.keyswitchKeys.push_back(kskParam);
}
// Pushing bootstrap keys
for (auto bsk : circuitKeys.bootstrapKeys) {
clientlib::BootstrapKeyParam bskParam;
auto inputNormKey = bsk.getInputKey().getNormalized().value();
auto outputNormKey = bsk.getOutputKey().getNormalized().value();
bskParam.inputSecretKeyID = inputNormKey.index;
bskParam.outputSecretKeyID = outputNormKey.index;
bskParam.level = bsk.getLevels();
bskParam.baseLog = bsk.getBaseLog();
bskParam.glweDimension = bsk.getGlweDim();
bskParam.polynomialSize = bsk.getPolySize();
bskParam.variance =
curve.getVariance(bsk.getGlweDim(), bsk.getPolySize(), 64);
bskParam.inputLweDimension = inputNormKey.dimension;
output.bootstrapKeys.push_back(bskParam);
}
// Pushing circuit packing keyswitch keys
for (auto pksk : circuitKeys.packingKeyswitchKeys) {
clientlib::PackingKeyswitchKeyParam pkskParam;
auto inputNormKey = pksk.getInputKey().getNormalized().value();
auto outputNormKey = pksk.getOutputKey().getNormalized().value();
pkskParam.inputSecretKeyID = inputNormKey.index;
pkskParam.outputSecretKeyID = outputNormKey.index;
pkskParam.level = pksk.getLevels();
pkskParam.baseLog = pksk.getBaseLog();
pkskParam.glweDimension = pksk.getGlweDim();
pkskParam.polynomialSize = pksk.getOutputPolySize();
pkskParam.inputLweDimension = inputNormKey.dimension;
pkskParam.variance =
curve.getVariance(outputNormKey.dimension, outputNormKey.polySize, 64);
output.packingKeyswitchKeys.push_back(pkskParam);
}
}
llvm::Expected<std::monostate>
extractCircuitGates(ClientParameters &output, mlir::func::FuncOp funcOp,
encodings::CircuitEncodings encodings,
concrete::SecurityCurve curve,
std::optional<CRTDecomposition> maybeCrt) {
// Create input and output circuit gate parameters
auto funcType = funcOp.getFunctionType();
for (auto val : llvm::zip(funcType.getInputs(), encodings.inputEncodings)) {
auto ty = std::get<0>(val);
auto encoding = std::get<1>(val);
auto gate = generateGate(ty, encoding, curve, maybeCrt);
if (auto err = gate.takeError()) {
return std::move(err);
}
output.inputs.push_back(gate.get());
}
for (auto val : llvm::zip(funcType.getResults(), encodings.outputEncodings)) {
auto ty = std::get<0>(val);
auto encoding = std::get<1>(val);
auto gate = generateGate(ty, encoding, curve, maybeCrt);
if (auto err = gate.takeError()) {
return std::move(err);
}
output.outputs.push_back(gate.get());
}
return std::monostate();
}
llvm::Expected<ClientParameters>
createClientParametersFromTFHE(mlir::ModuleOp module,
llvm::StringRef functionName, int bitsOfSecurity,
encodings::CircuitEncodings encodings,
std::optional<CRTDecomposition> maybeCrt) {
// Check that security curves exist
const auto curve = concrete::getSecurityCurve(bitsOfSecurity, keyFormat);
if (curve == nullptr) {
return StreamStringError("Cannot find security curves for ")
<< bitsOfSecurity << "bits";
}
// Check that the specified function can be found
auto rangeOps = module.getOps<mlir::func::FuncOp>();
auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
return op.getName() == functionName;
});
if (funcOp == rangeOps.end()) {
return StreamStringError(
"cannot find the function for generate client parameters: ")
<< functionName;
}
// Create client parameters
ClientParameters output;
output.functionName = (std::string)functionName;
// We extract the keys of the circuit
auto circuitKeys = TFHE::extractCircuitKeys(module);
// We extract all the keys used in the circuit
extractCircuitKeys(output, circuitKeys, *curve);
// We generate the gates for the inputs aud outputs
if (auto err =
extractCircuitGates(output, *funcOp, encodings, *curve, maybeCrt)
.takeError()) {
return std::move(err);
}
return output;
}
} // namespace concretelang
} // namespace mlir

View File

@@ -5,6 +5,7 @@
#include <fstream>
#include <iostream>
#include <llvm/Support/Debug.h>
#include <mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h>
@@ -42,6 +43,7 @@
#include <concretelang/Dialect/Tracing/Transforms/BufferizableOpInterfaceImpl.h>
#include <concretelang/Runtime/DFRuntime.hpp>
#include <concretelang/Support/CompilerEngine.h>
#include <concretelang/Support/Encodings.h>
#include <concretelang/Support/Error.h>
#include <concretelang/Support/Jit.h>
#include <concretelang/Support/LLVMEmitFile.h>
@@ -163,12 +165,8 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) {
auto description = descriptions->find(name);
if (description == descriptions->end()) {
std::string names;
for (auto &entry : *descriptions) {
names += "'" + entry.first + "' ";
}
return StreamStringError()
<< "Could not find existing crypto parameters for function '"
<< name << "' (known functions: " << names << ")";
return StreamStringError("Function not found, name='")
<< name << "', cannot get optimizer description";
}
return std::move(description->second);
}
@@ -194,6 +192,10 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
}
res.fheContext.emplace(
mlir::concretelang::V0FHEContext{constraint, v0Params});
CompilationFeedback feedback;
res.feedback.emplace(feedback);
return llvm::Error::success();
}
// compute parameters
@@ -286,6 +288,25 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
if (target == Target::ROUND_TRIP)
return std::move(res);
// Retrieves the encoding informations before any transformation is performed
// on the `FHE` dialect.
if ((this->generateClientParameters || target == Target::LIBRARY) &&
!options.encodings.has_value()) {
auto funcName = options.clientParametersFuncName.value_or("main");
auto maybeChunkInfo =
options.chunkIntegers
? std::optional(concretelang::clientlib::ChunkInfo{
options.chunkSize, options.chunkWidth})
: std::nullopt;
auto encodingInfosOrErr =
mlir::concretelang::encodings::getCircuitEncodings(funcName, module,
maybeChunkInfo);
if (!encodingInfosOrErr) {
return encodingInfosOrErr.takeError();
}
options.encodings = encodingInfosOrErr.get();
}
if (mlir::concretelang::pipeline::transformFHEBoolean(mlirContext, module,
enablePass)
.failed()) {
@@ -345,46 +366,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
if (target == Target::FHE_NO_LINALG)
return std::move(res);
// Generate client parameters if requested
if (this->generateClientParameters) {
if (!options.clientParametersFuncName.has_value()) {
return StreamStringError(
"Generation of client parameters requested, but no function name "
"specified");
}
if (!res.fheContext.has_value()) {
return StreamStringError(
"Cannot generate client parameters, the fhe context is empty for " +
options.clientParametersFuncName.value());
}
}
// Generate client parameters if requested
auto funcName = options.clientParametersFuncName.value_or("main");
if (this->generateClientParameters || target == Target::LIBRARY) {
if (!res.fheContext.has_value()) {
// Some tests involve call a to non encrypted functions
ClientParameters emptyParams;
emptyParams.functionName = funcName;
res.clientParameters = emptyParams;
} else {
llvm::Optional<::concretelang::clientlib::ChunkInfo> chunkInfo =
std::nullopt;
if (options.chunkIntegers) {
chunkInfo = ::concretelang::clientlib::ChunkInfo{options.chunkSize,
options.chunkWidth};
}
auto clientParametersOrErr =
mlir::concretelang::createClientParametersForV0(
*res.fheContext, funcName, module,
options.optimizerConfig.security, chunkInfo);
if (!clientParametersOrErr)
return clientParametersOrErr.takeError();
res.clientParameters = clientParametersOrErr.get();
res.feedback->fillFromClientParameters(*res.clientParameters);
}
}
// FHE -> TFHE
if (mlir::concretelang::pipeline::lowerFHEToTFHE(mlirContext, module,
res.fheContext, enablePass)
@@ -412,6 +393,57 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
if (target == Target::PARAMETRIZED_TFHE)
return std::move(res);
// Normalize TFHE keys
if (mlir::concretelang::pipeline::normalizeTFHEKeys(mlirContext, module,
this->enablePass)
.failed()) {
return errorDiag("Normalizing TFHE keys failed");
}
// Generate client parameters if requested
if (this->generateClientParameters) {
if (!options.clientParametersFuncName.has_value()) {
return StreamStringError(
"Generation of client parameters requested, but no function name "
"specified");
}
if (!res.fheContext.has_value()) {
return StreamStringError(
"Cannot generate client parameters, the fhe context is empty for " +
options.clientParametersFuncName.value());
}
}
// Generate client parameters if requested
if (this->generateClientParameters || target == Target::LIBRARY) {
auto funcName = options.clientParametersFuncName.value_or("main");
if (!res.fheContext.has_value()) {
// Some tests involve call a to non encrypted functions
ClientParameters emptyParams;
emptyParams.functionName = funcName;
res.clientParameters = emptyParams;
} else {
std::optional<CRTDecomposition> maybeCrt = std::nullopt;
if (res.fheContext.value().parameter.largeInteger.has_value()) {
maybeCrt = res.fheContext.value()
.parameter.largeInteger.value()
.crtDecomposition;
}
auto clientParametersOrErr =
mlir::concretelang::createClientParametersFromTFHE(
module, funcName, options.optimizerConfig.security,
options.encodings.value(), maybeCrt);
if (!clientParametersOrErr)
return clientParametersOrErr.takeError();
res.clientParameters = clientParametersOrErr.get();
res.feedback->fillFromClientParameters(*res.clientParameters);
}
}
if (target == Target::NORMALIZED_TFHE)
return std::move(res);
if (options.batchTFHEOps) {
if (mlir::concretelang::pipeline::batchTFHE(mlirContext, module, enablePass)
.failed()) {

View File

@@ -0,0 +1,247 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <concretelang/ClientLib/ClientParameters.h>
#include <concretelang/Dialect/FHE/IR/FHETypes.h>
#include <concretelang/Support/Encodings.h>
#include <concretelang/Support/Error.h>
#include <concretelang/Support/Variants.h>
#include <optional>
#include <variant>
namespace FHE = mlir::concretelang::FHE;
namespace clientlib = concretelang::clientlib;
namespace mlir {
namespace concretelang {
namespace encodings {
std::optional<Encoding>
encodingFromType(mlir::Type ty,
std::optional<clientlib::ChunkInfo> maybeChunkInfo) {
if (auto eintTy = ty.dyn_cast<FHE::FheIntegerInterface>()) {
if (maybeChunkInfo.has_value() &&
eintTy.getWidth() > maybeChunkInfo.value().size) {
auto chunkInfo = maybeChunkInfo.value();
return EncryptedChunkedIntegerScalarEncoding{
eintTy.getWidth(), eintTy.isSigned(), chunkInfo.width,
chunkInfo.size};
} else {
return EncryptedIntegerScalarEncoding{eintTy.getWidth(),
eintTy.isSigned()};
}
} else if (auto eboolTy = ty.dyn_cast<FHE::EncryptedBooleanType>()) {
return EncryptedBoolScalarEncoding{};
} else if (auto intTy = ty.dyn_cast<mlir::IntegerType>()) {
return PlaintextScalarEncoding{intTy.getWidth()};
} else if (auto indexTy = ty.dyn_cast<mlir::IndexType>()) {
return IndexScalarEncoding{};
} else if (auto tensor = ty.dyn_cast<mlir::RankedTensorType>()) {
std::optional<Encoding> maybeEncoding =
encodingFromType(tensor.getElementType(), maybeChunkInfo);
if (maybeEncoding.has_value() &&
std::holds_alternative<ScalarEncoding>(maybeEncoding.value())) {
ScalarEncoding scalarEncoding =
std::get<ScalarEncoding>(maybeEncoding.value());
return TensorEncoding{scalarEncoding};
}
}
return std::nullopt;
}
llvm::Expected<CircuitEncodings>
getCircuitEncodings(llvm::StringRef functionName, mlir::ModuleOp module,
std::optional<clientlib::ChunkInfo> maybeChunkInfo) {
// Find the input function
auto rangeOps = module.getOps<mlir::func::FuncOp>();
auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
return op.getName() == functionName;
});
if (funcOp == rangeOps.end()) {
return StreamStringError("Function not found, name='")
<< functionName << "', cannot get circuit encodings";
}
auto funcType = (*funcOp).getFunctionType();
// Retrieve input/output encodings
std::vector<Encoding> inputs;
std::vector<Encoding> outputs;
for (auto ty : funcType.getInputs()) {
auto maybeGate = encodingFromType(ty, maybeChunkInfo);
if (!maybeGate.has_value()) {
return StreamStringError("Failed to recognize encoding for type : ")
<< ty;
}
inputs.push_back(maybeGate.value());
}
for (auto ty : funcType.getResults()) {
auto maybeGate = encodingFromType(ty, maybeChunkInfo);
if (!maybeGate.has_value()) {
return StreamStringError("Failed to recognize encoding for type : ")
<< ty;
}
outputs.push_back(maybeGate.value());
}
return CircuitEncodings{inputs, outputs};
}
bool fromJSON(const llvm::json::Value j, EncryptedIntegerScalarEncoding &e,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("width", e.width) && O.map("isSigned", e.isSigned);
}
llvm::json::Value toJSON(const EncryptedIntegerScalarEncoding &e) {
llvm::json::Object object{
{"width", e.width},
{"isSigned", e.isSigned},
};
return object;
}
bool fromJSON(const llvm::json::Value j,
EncryptedChunkedIntegerScalarEncoding &e, llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("width", e.width) && O.map("isSigned", e.isSigned) &&
O.map("chunkSize", e.chunkSize) && O.map("chunkWidth", e.chunkWidth);
}
llvm::json::Value toJSON(const EncryptedChunkedIntegerScalarEncoding &e) {
llvm::json::Object object{
{"width", e.width},
{"isSigned", e.isSigned},
{"chunkSize", e.chunkSize},
{"chunkWidth", e.chunkWidth},
};
return object;
}
bool fromJSON(const llvm::json::Value j, EncryptedBoolScalarEncoding &e,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O;
}
llvm::json::Value toJSON(const EncryptedBoolScalarEncoding &e) {
llvm::json::Object object{};
return object;
}
bool fromJSON(const llvm::json::Value j, PlaintextScalarEncoding &e,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("width", e.width);
}
llvm::json::Value toJSON(const PlaintextScalarEncoding &e) {
llvm::json::Object object{{"width", e.width}};
return object;
}
bool fromJSON(const llvm::json::Value j, IndexScalarEncoding &e,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O;
}
llvm::json::Value toJSON(const IndexScalarEncoding &e) {
llvm::json::Object object{};
return object;
}
bool fromJSON(const llvm::json::Value j, ScalarEncoding &e,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
if (j.getAsObject()->getObject("EncryptedIntegerScalarEncoding")) {
return O && O.map("EncryptedIntegerScalarEncoding",
std::get<EncryptedIntegerScalarEncoding>(e));
} else if (j.getAsObject()->getObject(
"EncryptedChunkedIntegerScalarEncoding")) {
return O && O.map("EncryptedChunkedIntegerScalarEncoding",
std::get<EncryptedChunkedIntegerScalarEncoding>(e));
} else if (j.getAsObject()->getObject("EncryptedBoolScalarEncoding")) {
return O && O.map("EncryptedBoolScalarEncoding",
std::get<EncryptedBoolScalarEncoding>(e));
} else if (j.getAsObject()->getObject("PlaintextScalarEncoding")) {
return O && O.map("PlaintextScalarEncoding",
std::get<PlaintextScalarEncoding>(e));
} else if (j.getAsObject()->getObject("IndexScalarEncoding")) {
return O && O.map("IndexScalarEncoding", std::get<IndexScalarEncoding>(e));
} else {
return false;
}
}
llvm::json::Value toJSON(const ScalarEncoding &e) {
llvm::json::Object object = std::visit(
overloaded{
[](EncryptedIntegerScalarEncoding enc) {
return llvm::json::Object{{"EncryptedIntegerScalarEncoding", enc}};
},
[](EncryptedChunkedIntegerScalarEncoding enc) {
return llvm::json::Object{
{"EncryptedChunkedIntegerScalarEncoding", enc}};
},
[](EncryptedBoolScalarEncoding enc) {
return llvm::json::Object{{"EncryptedBoolScalarEncoding", enc}};
},
[](PlaintextScalarEncoding enc) {
return llvm::json::Object{{"PlaintextScalarEncoding", enc}};
},
[](IndexScalarEncoding enc) {
return llvm::json::Object{{"IndexScalarEncoding", enc}};
},
},
e);
return object;
}
bool fromJSON(const llvm::json::Value j, TensorEncoding &e,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("scalarEncoding", e.scalarEncoding);
}
llvm::json::Value toJSON(const TensorEncoding &e) {
llvm::json::Object object{{"scalarEncoding", e.scalarEncoding}};
return object;
}
bool fromJSON(const llvm::json::Value j, Encoding &e, llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
if (j.getAsObject()->getObject("ScalarEncoding")) {
e = EncryptedIntegerScalarEncoding{0, false};
return O && O.map("ScalarEncoding", std::get<ScalarEncoding>(e));
} else if (j.getAsObject()->getObject("TensorEncoding")) {
e = TensorEncoding{EncryptedIntegerScalarEncoding{0, false}};
return O && O.map("TensorEncoding", std::get<TensorEncoding>(e));
} else {
return false;
}
}
llvm::json::Value toJSON(const Encoding &e) {
llvm::json::Object object =
std::visit(overloaded{
[](ScalarEncoding enc) {
return llvm::json::Object{{"ScalarEncoding", enc}};
},
[](TensorEncoding enc) {
return llvm::json::Object{{"TensorEncoding", enc}};
},
},
e);
return object;
}
bool fromJSON(const llvm::json::Value j, CircuitEncodings &e,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("inputEncodings", e.inputEncodings) &&
O.map("outputEncodings", e.outputEncodings);
}
llvm::json::Value toJSON(const CircuitEncodings &e) {
llvm::json::Object object{{"inputEncodings", e.inputEncodings},
{"outputEncodings", e.outputEncodings}};
return object;
}
} // namespace encodings
} // namespace concretelang
} // namespace mlir

View File

@@ -28,6 +28,7 @@
#include <mlir/Target/LLVMIR/Export.h>
#include <mlir/Transforms/Passes.h>
#include "concretelang/Conversion/TFHEKeyNormalization/Pass.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include <concretelang/Conversion/Passes.h>
@@ -295,6 +296,18 @@ mlir::LogicalResult batchTFHE(mlir::MLIRContext &context,
return pm.run(module.getOperation());
}
mlir::LogicalResult
normalizeTFHEKeys(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("TFHEKeyNormalization", pm, context);
addPotentiallyNestedPass(
pm, mlir::concretelang::createTFHEKeyNormalizationPass(), enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {

View File

@@ -0,0 +1,156 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Support/TFHECircuitKeys.h"
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "llvm/ADT/SmallVector.h"
#include <llvm/ADT/SmallVector.h>
#include <optional>
// Faster than using a full fledged hash-set for small sets, and the array can
// be recovered right away.
template <typename V> struct SmallSet {
llvm::SmallVector<V, 10> vector;
void insert(V val) {
for (auto vectorVal : vector) {
if (vectorVal == val) {
return;
}
}
vector.push_back(val);
}
};
template <typename V, unsigned N>
std::optional<size_t> vectorIndex(llvm::SmallVector<V, N> vector, V val) {
for (size_t i = 0; i < vector.size(); i++) {
auto potentialVal = vector[i];
if (potentialVal == val) {
return i;
}
}
return std::nullopt;
}
namespace mlir {
namespace concretelang {
namespace TFHE {
template <typename V, unsigned int N>
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
const mlir::SmallVector<V, N> vect) {
OS << "[";
for (auto v : vect) {
OS << v << ",";
}
OS << "]";
return OS;
}
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
const TFHECircuitKeys cks) {
OS << "TFHECircuitKeys{\n"
<< " secretKeys:" << cks.secretKeys << "\n"
<< " keyswitchKeys:" << cks.keyswitchKeys << "\n"
<< " bootstrapKeys:" << cks.bootstrapKeys << "\n"
<< " packingKeyswitchKeys:" << cks.packingKeyswitchKeys
<< "\n"
"}";
return OS;
}
TFHECircuitKeys extractCircuitKeys(mlir::ModuleOp moduleOp) {
// Gathering circuit secret keys
SmallSet<TFHE::GLWESecretKey> secretKeys;
auto tryInsert = [&](mlir::Type type) {
if (auto glweType = type.dyn_cast<TFHE::GLWECipherTextType>()) {
secretKeys.insert(glweType.getKey());
} else if (auto tensorType = type.dyn_cast<mlir::RankedTensorType>()) {
if (auto elementType = tensorType.getElementType()
.dyn_cast<TFHE::GLWECipherTextType>()) {
secretKeys.insert(elementType.getKey());
}
}
};
moduleOp->walk([&](mlir::Operation *op) {
for (auto operand : op->getOperands()) {
tryInsert(operand.getType());
}
for (auto result : op->getResults()) {
tryInsert(result.getType());
}
});
moduleOp->walk([&](mlir::func::FuncOp op) {
for (auto argType : op.getArgumentTypes()) {
tryInsert(argType);
}
for (auto resultType : op.getResultTypes()) {
tryInsert(resultType);
}
});
// Gathering circuit keyswitch keys
SmallSet<TFHE::GLWEKeyswitchKeyAttr> keyswitchKeys;
moduleOp->walk([&](TFHE::KeySwitchGLWEOp op) {
keyswitchKeys.insert(op.getKeyAttr());
secretKeys.insert(op.getKeyAttr().getInputKey());
secretKeys.insert(op.getKeyAttr().getOutputKey());
});
// Gathering circuit bootstrap keys
SmallSet<TFHE::GLWEBootstrapKeyAttr> bootstrapKeys;
moduleOp->walk([&](TFHE::BootstrapGLWEOp op) {
bootstrapKeys.insert(op.getKeyAttr());
secretKeys.insert(op.getKeyAttr().getInputKey());
secretKeys.insert(op.getKeyAttr().getOutputKey());
});
// Gathering circuit packing keyswitch keys
SmallSet<TFHE::GLWEPackingKeyswitchKeyAttr> packingKeyswitchKeys;
moduleOp->walk([&](TFHE::WopPBSGLWEOp op) {
keyswitchKeys.insert(op.getKskAttr());
secretKeys.insert(op.getKskAttr().getInputKey());
secretKeys.insert(op.getKskAttr().getOutputKey());
bootstrapKeys.insert(op.getBskAttr());
secretKeys.insert(op.getBskAttr().getInputKey());
secretKeys.insert(op.getBskAttr().getOutputKey());
packingKeyswitchKeys.insert(op.getPkskAttr());
secretKeys.insert(op.getPkskAttr().getInputKey());
secretKeys.insert(op.getPkskAttr().getOutputKey());
});
return TFHECircuitKeys{secretKeys.vector, bootstrapKeys.vector,
keyswitchKeys.vector, packingKeyswitchKeys.vector};
}
std::optional<uint64_t>
TFHE::TFHECircuitKeys::getSecretKeyIndex(TFHE::GLWESecretKey key) {
return vectorIndex(this->secretKeys, key);
}
std::optional<uint64_t>
TFHE::TFHECircuitKeys::getBootstrapKeyIndex(TFHE::GLWEBootstrapKeyAttr key) {
return vectorIndex(this->bootstrapKeys, key);
}
std::optional<uint64_t>
TFHE::TFHECircuitKeys::getKeyswitchKeyIndex(TFHE::GLWEKeyswitchKeyAttr key) {
return vectorIndex(this->keyswitchKeys, key);
}
std::optional<uint64_t> TFHE::TFHECircuitKeys::getPackingKeyswitchKeyIndex(
TFHE::GLWEPackingKeyswitchKeyAttr key) {
return vectorIndex(this->packingKeyswitchKeys, key);
}
} // namespace TFHE
} // namespace concretelang
} // namespace mlir

View File

@@ -1,257 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <cassert>
#include <map>
#include <llvm/ADT/Optional.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/Support/Error.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <optional>
#include "concrete/curves.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Support/Error.h"
namespace mlir {
namespace concretelang {
namespace clientlib = ::concretelang::clientlib;
using ::concretelang::clientlib::ChunkInfo;
using ::concretelang::clientlib::CircuitGate;
using ::concretelang::clientlib::ClientParameters;
using ::concretelang::clientlib::Encoding;
using ::concretelang::clientlib::EncryptionGate;
using ::concretelang::clientlib::LweSecretKeyID;
using ::concretelang::clientlib::Precision;
using ::concretelang::clientlib::Variance;
const auto keyFormat = concrete::BINARY;
/// For the v0 the secretKeyID and precision are the same for all gates.
llvm::Expected<CircuitGate>
gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
Variance variance, llvm::Optional<ChunkInfo> chunkInfo,
mlir::Type type) {
if (type.isIntOrIndex()) {
// TODO - The index type is dependant of the target architecture, so
// actually we assume we target only 64 bits, we need to have some the size
// of the word of the target system.
size_t width = 64;
if (!type.isIndex()) {
width = type.getIntOrFloatBitWidth();
}
bool sign = type.isSignedInteger();
return CircuitGate{
/*.encryption = */ std::nullopt,
/*.shape = */
{/*.width = */ width,
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/* .sign */ sign},
/*.chunkInfo = */ std::nullopt,
};
}
if (auto lweTy = type.dyn_cast_or_null<
mlir::concretelang::FHE::FheIntegerInterface>()) {
bool sign = lweTy.isSigned();
std::vector<int64_t> crt;
if (fheContext.parameter.largeInteger.has_value()) {
crt = fheContext.parameter.largeInteger.value().crtDecomposition;
}
size_t width;
uint64_t size = 0;
std::vector<int64_t> dims;
if (chunkInfo.has_value()) {
width = chunkInfo->size;
assert(lweTy.getWidth() % chunkInfo->width == 0);
size = lweTy.getWidth() / chunkInfo->width;
dims.push_back(size);
} else {
width = (size_t)lweTy.getWidth();
}
return CircuitGate{
/* .encryption = */ std::optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ width,
/* .crt = */ crt,
/*.sign = */ sign,
},
}),
/*.shape = */
{
/*.width = */ width,
/*.dimensions = */ dims,
/*.size = */ size,
/*.sign = */ sign,
},
/*.chunkInfo = */ chunkInfo,
};
}
if (auto lweTy = type.dyn_cast_or_null<
mlir::concretelang::FHE::EncryptedBooleanType>()) {
size_t width = mlir::concretelang::FHE::EncryptedBooleanType::getWidth();
return CircuitGate{
/* .encryption = */ std::optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ width,
/* .crt = */ std::vector<int64_t>(),
/* .sign = */ false,
},
}),
/*.shape = */
{
/*.width = */ width,
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/*.sign = */ false,
},
/*.chunkInfo = */ std::nullopt,
};
}
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
if (tensor != nullptr) {
auto gate = gateFromMLIRType(fheContext, secretKeyID, variance, chunkInfo,
tensor.getElementType());
if (auto err = gate.takeError()) {
return std::move(err);
}
gate->shape.dimensions = tensor.getShape().vec();
gate->shape.size = 1;
for (auto dimSize : gate->shape.dimensions) {
gate->shape.size *= dimSize;
}
return gate;
}
return llvm::make_error<llvm::StringError>(
"cannot convert MLIR type to shape", llvm::inconvertibleErrorCode());
}
llvm::Expected<ClientParameters>
createClientParametersForV0(V0FHEContext fheContext,
llvm::StringRef functionName, mlir::ModuleOp module,
int bitsOfSecurity,
llvm::Optional<ChunkInfo> chunkInfo) {
const auto v0Curve = concrete::getSecurityCurve(bitsOfSecurity, keyFormat);
if (v0Curve == nullptr) {
return StreamStringError("Cannot find security curves for ")
<< bitsOfSecurity << "bits";
}
V0Parameter &v0Param = fheContext.parameter;
Variance inputVariance =
v0Curve->getVariance(1, v0Param.getNBigLweDimension(), 64);
Variance bootstrapKeyVariance = v0Curve->getVariance(
v0Param.glweDimension, v0Param.getPolynomialSize(), 64);
Variance keyswitchKeyVariance = v0Curve->getVariance(1, v0Param.nSmall, 64);
// Static client parameters from global parameters for v0
ClientParameters c;
assert(c.secretKeys.size() == clientlib::BIG_KEY);
clientlib::LweSecretKeyParam skParam;
skParam.dimension = v0Param.getNBigLweDimension();
c.secretKeys.push_back(skParam);
bool has_small_key = v0Param.nSmall != 0;
bool has_bootstrap = v0Param.brLevel != 0;
if (has_small_key) {
assert(c.secretKeys.size() == clientlib::SMALL_KEY);
clientlib::LweSecretKeyParam skParam2;
skParam2.dimension = v0Param.nSmall;
c.secretKeys.push_back(skParam2);
}
if (has_bootstrap) {
auto inputKey = (has_small_key) ? clientlib::SMALL_KEY : clientlib::BIG_KEY;
clientlib::BootstrapKeyParam bskParam;
bskParam.inputSecretKeyID = inputKey;
bskParam.outputSecretKeyID = clientlib::BIG_KEY;
bskParam.level = v0Param.brLevel;
bskParam.baseLog = v0Param.brLogBase;
bskParam.glweDimension = v0Param.glweDimension;
bskParam.variance = bootstrapKeyVariance;
bskParam.polynomialSize = v0Param.getPolynomialSize();
bskParam.inputLweDimension = v0Param.nSmall;
c.bootstrapKeys.push_back(bskParam);
}
if (v0Param.largeInteger.has_value()) {
clientlib::PackingKeyswitchKeyParam param;
param.inputSecretKeyID = clientlib::BIG_KEY;
param.outputSecretKeyID = clientlib::BIG_KEY;
param.level = v0Param.largeInteger->wopPBS.packingKeySwitch.level;
param.baseLog = v0Param.largeInteger->wopPBS.packingKeySwitch.baseLog;
param.glweDimension = v0Param.glweDimension;
param.polynomialSize = v0Param.getPolynomialSize();
param.inputLweDimension = v0Param.getNBigLweDimension();
param.variance = v0Curve->getVariance(v0Param.glweDimension,
v0Param.getPolynomialSize(), 64);
c.packingKeyswitchKeys.push_back(param);
}
if (has_small_key) {
clientlib::KeyswitchKeyParam kskParam;
kskParam.inputSecretKeyID = clientlib::BIG_KEY;
kskParam.outputSecretKeyID = clientlib::SMALL_KEY;
kskParam.level = v0Param.ksLevel;
kskParam.baseLog = v0Param.ksLogBase;
kskParam.variance = keyswitchKeyVariance;
c.keyswitchKeys.push_back(kskParam);
}
c.functionName = (std::string)functionName;
// Find the input function
auto rangeOps = module.getOps<mlir::func::FuncOp>();
auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
return op.getName() == functionName;
});
if (funcOp == rangeOps.end()) {
return StreamStringError(
"cannot find the function for generate client parameters: ")
<< functionName;
}
// Create input and output circuit gate parameters
auto funcType = (*funcOp).getFunctionType();
auto inputs = funcType.getInputs();
auto gateFromType = [&](mlir::Type ty) {
return gateFromMLIRType(fheContext, clientlib::BIG_KEY, inputVariance,
chunkInfo, ty);
};
for (auto inType : inputs) {
auto gate = gateFromType(inType);
if (auto err = gate.takeError()) {
return std::move(err);
}
c.inputs.push_back(gate.get());
}
for (auto outType : funcType.getResults()) {
auto gate = gateFromType(outType);
if (auto err = gate.takeError()) {
return std::move(err);
}
c.outputs.push_back(gate.get());
}
return c;
}
} // namespace concretelang
} // namespace mlir

View File

@@ -254,7 +254,7 @@ llvm::Expected<V0Parameter> getParameter(optimizer::Description &descr,
lParams.wopPBS.circuitBootstrap.baseLog = sol.cb_decomposition_base_log;
lParams.wopPBS.circuitBootstrap.level = sol.cb_decomposition_level_count;
lParams.wopPBS.packingKeySwitch.inputLweDimension =
sol.internal_ks_output_lwe_dimension + 1;
sol.internal_ks_output_lwe_dimension;
lParams.wopPBS.packingKeySwitch.outputPolynomialSize =
sol.glwe_polynomial_size;
lParams.wopPBS.packingKeySwitch.level = sol.pp_decomposition_level_count;

View File

@@ -7,6 +7,7 @@
#include <iostream>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/JSON.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/ToolOutputFile.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
@@ -34,6 +35,7 @@
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Encodings.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/JITSupport.h"
#include "concretelang/Support/LLVMEmitFile.h"
@@ -43,12 +45,14 @@
#include "mlir/IR/BuiltinOps.h"
namespace clientlib = concretelang::clientlib;
namespace encodings = mlir::concretelang::encodings;
enum Action {
ROUND_TRIP,
DUMP_FHE,
DUMP_FHE_NO_LINALG,
DUMP_TFHE,
DUMP_NORMALIZED_TFHE,
DUMP_PARAMETRIZED_TFHE,
DUMP_BATCHED_TFHE,
DUMP_CONCRETE,
@@ -124,8 +128,12 @@ static llvm::cl::opt<enum Action> action(
llvm::cl::values(clEnumValN(Action::DUMP_FHE_NO_LINALG,
"dump-fhe-no-linalg",
"Lower FHELinalg to FHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_TFHE, "dump-tfhe",
"Lower to TFHE and dump result")),
llvm::cl::values(
clEnumValN(Action::DUMP_TFHE, "dump-tfhe",
"Lower to unparameterized TFHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_NORMALIZED_TFHE,
"dump-normalized-tfhe",
"Lower to normalized TFHE and dump result")),
llvm::cl::values(clEnumValN(
Action::DUMP_PARAMETRIZED_TFHE, "dump-parametrized-tfhe",
"Lower to TFHE, parametrize TFHE operations and dump result")),
@@ -327,6 +335,12 @@ llvm::cl::list<int64_t> largeIntegerCircuitBootstrap(
"(experimental) [level, baseLog]"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
llvm::cl::opt<std::string> circuitEncodings(
"circuit-encodings",
llvm::cl::desc("Specify the input and output encodings of the circuit, "
"using the JSON representation."),
llvm::cl::init(std::string{}));
} // namespace cmdline
namespace llvm {
@@ -447,6 +461,20 @@ cmdlineCompilationOptions() {
llvm::inconvertibleErrorCode());
}
if (!cmdline::circuitEncodings.empty()) {
auto jsonString = cmdline::circuitEncodings.getValue();
auto maybeEncodings =
llvm::json::parse<encodings::CircuitEncodings>(jsonString);
if (auto err = maybeEncodings.takeError()) {
return llvm::make_error<llvm::StringError>(
"Failed to parse the --circuit-encodings option.",
llvm::inconvertibleErrorCode());
}
options.encodings = maybeEncodings.get();
} else {
options.encodings = std::nullopt;
}
return options;
}
@@ -532,6 +560,9 @@ mlir::LogicalResult processInputBuffer(
case Action::DUMP_TFHE:
target = mlir::concretelang::CompilerEngine::Target::TFHE;
break;
case Action::DUMP_NORMALIZED_TFHE:
target = mlir::concretelang::CompilerEngine::Target::NORMALIZED_TFHE;
break;
case Action::DUMP_PARAMETRIZED_TFHE:
target = mlir::concretelang::CompilerEngine::Target::PARAMETRIZED_TFHE;
break;

View File

@@ -4,7 +4,7 @@
//CHECK: llvm.call @memref_bootstrap_lwe_cuda_u64
func.func @main(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
%cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%0 = "Concrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, level = 5 : i32, lwe_dim_in = 1025 : i32, lwe_dim_out = 576 : i32} : (tensor<1025xi64>) -> tensor<576xi64>
%1 = "Concrete.bootstrap_lwe_tensor"(%0, %cst) {baseLog = 2 : i32, level = 5 : i32, polySize = 1024: i32, glweDimension = 1 : i32, inputLweDim = 576 : i32, outPrecision = 2 : i32} : (tensor<576xi64>, tensor<4xi64>) -> tensor<1025xi64>
%0 = "Concrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, kskIndex = 0 : i32, level = 5 : i32, lwe_dim_in = 1025 : i32, lwe_dim_out = 576 : i32} : (tensor<1025xi64>) -> tensor<576xi64>
%1 = "Concrete.bootstrap_lwe_tensor"(%0, %cst) {baseLog = 2 : i32, bskIndex = 0 : i32, level = 5 : i32, polySize = 1024: i32, glweDimension = 1 : i32, inputLweDim = 576 : i32, outPrecision = 2 : i32} : (tensor<576xi64>, tensor<4xi64>) -> tensor<1025xi64>
return %1 : tensor<1025xi64>
}

View File

@@ -1,18 +1,18 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
//CHECK: func.func @add_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk[?]>>, %[[Varg1:.*]]: tensor<5x!TFHE.glwe<sk[?]>>) -> tensor<5x!TFHE.glwe<sk[?]>> {
//CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK: func.func @add_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk?>>, %[[Varg1:.*]]: tensor<5x!TFHE.glwe<sk?>>) -> tensor<5x!TFHE.glwe<sk?>> {
//CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
//CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
//CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[V2:.*]] = tensor.extract %[[Varg0]][%[[Varg2]]] : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[V3:.*]] = tensor.extract %[[Varg1]][%[[Varg2]]] : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[V4:.*]] = "TFHE.add_glwe"(%[[V2]], %[[V3]]) : (!TFHE.glwe<sk[?]>, !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
//CHECK-NEXT: %[[V5:.*]] = tensor.insert %[[V4]] into %[[Varg3]][%[[Varg2]]] : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[V5]] : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[V2:.*]] = tensor.extract %[[Varg0]][%[[Varg2]]] : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[V3:.*]] = tensor.extract %[[Varg1]][%[[Varg2]]] : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[V4:.*]] = "TFHE.add_glwe"(%[[V2]], %[[V3]]) : (!TFHE.glwe<sk?>, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
//CHECK-NEXT: %[[V5:.*]] = tensor.insert %[[V4]] into %[[Varg3]][%[[Varg2]]] : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: scf.yield %[[V5]] : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: return %[[V1]] : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: return %[[V1]] : tensor<5x!TFHE.glwe<sk?>>
func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>

View File

@@ -1,21 +1,21 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK: func.func @add_eint_int(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk[?]>>) -> tensor<5x!TFHE.glwe<sk[?]>> {
// CHECK: func.func @add_eint_int(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk?>>) -> tensor<5x!TFHE.glwe<sk?>> {
// CHECK-NEXT: %[[Vc1_i8:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V0:.*]] = arith.extsi %[[Vc1_i8]] : i8 to i64
// CHECK-NEXT: %[[V1:.*]] = "TFHE.encode_plaintext_with_crt"(%[[V0]]) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
// CHECK-NEXT: %[[V2:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V2:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V2]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V4:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V2]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V4:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: %[[V5:.*]] = tensor.extract %[[V1]][%[[Varg1]]] : tensor<5xi64>
// CHECK-NEXT: %[[V6:.*]] = "TFHE.add_glwe_int"(%[[V4]], %[[V5]]) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: %[[V7:.*]] = tensor.insert %[[V6]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[V7]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V6:.*]] = "TFHE.add_glwe_int"(%[[V4]], %[[V5]]) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
// CHECK-NEXT: %[[V7:.*]] = tensor.insert %[[V6]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: scf.yield %[[V7]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V3]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: return %[[V3]] : tensor<5x!TFHE.glwe<sk?>>
func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%0 = arith.constant 1 : i8
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)

View File

@@ -1,9 +1,9 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK: func.func @apply_lookup_table(%arg0: tensor<5x!TFHE.glwe<sk[?]>>, %arg1: tensor<4xi64>) -> tensor<5x!TFHE.glwe<sk[?]>>
// CHECK: func.func @apply_lookup_table(%arg0: tensor<5x!TFHE.glwe<sk?>>, %arg1: tensor<4xi64>) -> tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: %0 = "TFHE.encode_lut_for_crt_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64>
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bsk = #TFHE.bsk<sk[?], sk[?], -1, -1, -1, -1>, cbsBaseLog = -1 : i32, cbsLevels = -1 : i32, crtDecomposition = [], ksk = #TFHE.ksk<sk[?], sk[?], -1, -1>, pksk = #TFHE.pksk<sk[?], sk[?], -1, -1, -1, -1>} : (tensor<5x!TFHE.glwe<sk[?]>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bsk = #TFHE.bsk<sk?, sk?, -1, -1, -1, -1>, cbsBaseLog = -1 : i32, cbsLevels = -1 : i32, crtDecomposition = [], ksk = #TFHE.ksk<sk?, sk?, -1, -1>, pksk = #TFHE.pksk<sk?, sk?, -1, -1, -1, -1, -1>} : (tensor<5x!TFHE.glwe<sk?>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<sk?>>
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>

View File

@@ -1,10 +1,10 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK: func.func @apply_lookup_table_cst(%arg0: tensor<5x!TFHE.glwe<sk[?]>>) -> tensor<5x!TFHE.glwe<sk[?]>> {
// CHECK: func.func @apply_lookup_table_cst(%arg0: tensor<5x!TFHE.glwe<sk?>>) -> tensor<5x!TFHE.glwe<sk?>> {
// CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
// CHECK-NEXT: %0 = "TFHE.encode_lut_for_crt_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<128xi64>) -> tensor<5x8192xi64>
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bsk = #TFHE.bsk<sk[?], sk[?], -1, -1, -1, -1>, cbsBaseLog = -1 : i32, cbsLevels = -1 : i32, crtDecomposition = [], ksk = #TFHE.ksk<sk[?], sk[?], -1, -1>, pksk = #TFHE.pksk<sk[?], sk[?], -1, -1, -1, -1>} : (tensor<5x!TFHE.glwe<sk[?]>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bsk = #TFHE.bsk<sk?, sk?, -1, -1, -1, -1>, cbsBaseLog = -1 : i32, cbsLevels = -1 : i32, crtDecomposition = [], ksk = #TFHE.ksk<sk?, sk?, -1, -1>, pksk = #TFHE.pksk<sk?, sk?, -1, -1, -1, -1, -1>} : (tensor<5x!TFHE.glwe<sk?>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<sk?>>
func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>)

View File

@@ -1,6 +1,6 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
//CHECK: func.func @conv2d(%[[Varg0:.*]]: tensor<100x3x28x28x5x!TFHE.glwe<sk[?]>>, %[[Varg1:.*]]: tensor<4x3x14x14xi3>, %[[Varg2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>> {
//CHECK: func.func @conv2d(%[[Varg0:.*]]: tensor<100x3x28x28x5x!TFHE.glwe<sk?>>, %[[Varg1:.*]]: tensor<4x3x14x14xi3>, %[[Varg2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x5x!TFHE.glwe<sk?>> {
//CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vc100:.*]] = arith.constant 100 : index
//CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
@@ -8,90 +8,90 @@
//CHECK-NEXT: %[[Vc15:.*]] = arith.constant 15 : index
//CHECK-NEXT: %[[Vc3:.*]] = arith.constant 3 : index
//CHECK-NEXT: %[[Vc14:.*]] = arith.constant 14 : index
//CHECK-NEXT: %[[V0:.*]] = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V0]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[V0:.*]] = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V0]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg2]]{{\[}}%[[Varg5]]{{\]}} : tensor<4xi3>
//CHECK-NEXT: %[[Vc0_0:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_0]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>> to tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_0]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x4x15x15x5x!TFHE.glwe<sk?>> to tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[V6:.*]] = arith.extsi %[[Vextracted]] : i3 to i64
//CHECK-NEXT: %[[V7:.*]] = "TFHE.encode_plaintext_with_crt"(%[[V6]]) {mods = {{\[2, 3, 5, 7, 11\], modsProd}} = 2310 : i64} : (i64) -> tensor<5xi64>
//CHECK-NEXT: %[[V8:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[V8:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[Vc0_1:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vc1_2:.*]] = arith.constant 1 : index
//CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
//CHECK-NEXT: %[[V9:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0_1]] to %[[Vc5]] step %[[Vc1_2]] iter_args(%[[Varg12:.*]] = %[[V8]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[Vextracted_4:.*]] = tensor.extract %[[Vextracted_slice]]{{\[}}%[[Varg11]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[V9:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0_1]] to %[[Vc5]] step %[[Vc1_2]] iter_args(%[[Varg12:.*]] = %[[V8]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[Vextracted_4:.*]] = tensor.extract %[[Vextracted_slice]]{{\[}}%[[Varg11]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[Vextracted_5:.*]] = tensor.extract %[[V7]]{{\[}}%[[Varg11]]{{\]}} : tensor<5xi64>
//CHECK-NEXT: %[[V10:.*]] = "TFHE.add_glwe_int"(%[[Vextracted_4]], %[[Vextracted_5]]) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V10]] into %[[Varg12]]{{\[}}%[[Varg11]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[V10:.*]] = "TFHE.add_glwe_int"(%[[Vextracted_4]], %[[Vextracted_5]]) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V10]] into %[[Varg12]]{{\[}}%[[Varg11]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: %[[Vc0_3:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vinserted_slice:.*]] = tensor.insert_slice %[[V9]] into %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_3]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<5x!TFHE.glwe<sk[?]>> into tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[Vinserted_slice]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[Vinserted_slice:.*]] = tensor.insert_slice %[[V9]] into %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_3]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<5x!TFHE.glwe<sk?>> into tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: scf.yield %[[Vinserted_slice]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V1]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[V6:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0]] to %[[Vc3]] step %[[Vc1]] iter_args(%[[Varg12:.*]] = %[[Varg10]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[V7:.*]] = scf.for %[[Varg13:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg14:.*]] = %[[Varg12]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[V8:.*]] = scf.for %[[Varg15:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg16:.*]] = %[[Varg14]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V1]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[V6:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0]] to %[[Vc3]] step %[[Vc1]] iter_args(%[[Varg12:.*]] = %[[Varg10]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[V7:.*]] = scf.for %[[Varg13:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg14:.*]] = %[[Varg12]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[V8:.*]] = scf.for %[[Varg15:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg16:.*]] = %[[Varg14]]) -> (tensor<100x4x15x15x5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[V9:.*]] = affine.apply #map(%[[Varg7]], %[[Varg13]])
//CHECK-NEXT: %[[V10:.*]] = affine.apply #map(%[[Varg9]], %[[Varg15]])
//CHECK-NEXT: %[[Vc0_0:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[Varg3]], %[[Varg11]], %[[V9]], %[[V10]], %[[Vc0_0]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x3x28x28x5x!TFHE.glwe<sk[?]>> to tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[Varg3]], %[[Varg11]], %[[V9]], %[[V10]], %[[Vc0_0]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x3x28x28x5x!TFHE.glwe<sk?>> to tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg1]]{{\[}}%[[Varg5]], %[[Varg11]], %[[Varg13]], %[[Varg15]]{{\]}} : tensor<4x3x14x14xi3>
//CHECK-NEXT: %[[Vc0_1:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vextracted_slice_2:.*]] = tensor.extract_slice %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_1]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>> to tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[Vextracted_slice_2:.*]] = tensor.extract_slice %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_1]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<100x4x15x15x5x!TFHE.glwe<sk?>> to tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[V11:.*]] = arith.extsi %[[Vextracted]] : i3 to i64
//CHECK-NEXT: %[[V12:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[V12:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[Vc0_3:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vc1_4:.*]] = arith.constant 1 : index
//CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
//CHECK-NEXT: %[[V13:.*]] = scf.for %[[Varg17:.*]] = %[[Vc0_3]] to %[[Vc5]] step %[[Vc1_4]] iter_args(%[[Varg18:.*]] = %[[V12]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[Vextracted_9:.*]] = tensor.extract %[[Vextracted_slice]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[V16:.*]] = "TFHE.mul_glwe_int"(%[[Vextracted_9]], %[[V11]]) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V16]] into %[[Varg18]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[V13:.*]] = scf.for %[[Varg17:.*]] = %[[Vc0_3]] to %[[Vc5]] step %[[Vc1_4]] iter_args(%[[Varg18:.*]] = %[[V12]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[Vextracted_9:.*]] = tensor.extract %[[Vextracted_slice]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[V16:.*]] = "TFHE.mul_glwe_int"(%[[Vextracted_9]], %[[V11]]) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V16]] into %[[Varg18]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: %[[V14:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[V14:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[Vc0_5:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vc1_6:.*]] = arith.constant 1 : index
//CHECK-NEXT: %[[Vc5_7:.*]] = arith.constant 5 : index
//CHECK-NEXT: %[[V15:.*]] = scf.for %[[Varg17:.*]] = %[[Vc0_5]] to %[[Vc5_7]] step %[[Vc1_6]] iter_args(%[[Varg18:.*]] = %[[V14]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
//CHECK-NEXT: %[[Vextracted_9:.*]] = tensor.extract %[[Vextracted_slice_2]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[Vextracted_10:.*]] = tensor.extract %[[V13]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[V16:.*]] = "TFHE.add_glwe"(%[[Vextracted_9]], %[[Vextracted_10]]) : (!TFHE.glwe<sk[?]>, !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V16]] into %[[Varg18]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[V15:.*]] = scf.for %[[Varg17:.*]] = %[[Vc0_5]] to %[[Vc5_7]] step %[[Vc1_6]] iter_args(%[[Varg18:.*]] = %[[V14]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
//CHECK-NEXT: %[[Vextracted_9:.*]] = tensor.extract %[[Vextracted_slice_2]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[Vextracted_10:.*]] = tensor.extract %[[V13]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: %[[V16:.*]] = "TFHE.add_glwe"(%[[Vextracted_9]], %[[Vextracted_10]]) : (!TFHE.glwe<sk?>, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
//CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V16]] into %[[Varg18]]{{\[}}%[[Varg17]]{{\]}} : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: %[[Vc0_8:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[Vinserted_slice:.*]] = tensor.insert_slice %[[V15]] into %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_8]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<5x!TFHE.glwe<sk[?]>> into tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[Vinserted_slice]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: %[[Vinserted_slice:.*]] = tensor.insert_slice %[[V15]] into %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]], %[[Vc0_8]]{{\] \[1, 1, 1, 1, 5\] \[1, 1, 1, 1, 1\]}} : tensor<5x!TFHE.glwe<sk?>> into tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: scf.yield %[[Vinserted_slice]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %[[V8]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[V8]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %[[V7]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[V7]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %[[V6]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[V6]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
//CHECK-NEXT: return %[[V2]] : tensor<100x4x15x15x5x!TFHE.glwe<sk[?]>>
//CHECK-NEXT: return %[[V2]] : tensor<100x4x15x15x5x!TFHE.glwe<sk?>>
//CHECK-NEXT: }
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>

View File

@@ -1,19 +1,19 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK: func.func @mul_eint_int(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk[?]>>) -> tensor<5x!TFHE.glwe<sk[?]>> {
// CHECK: func.func @mul_eint_int(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk?>>) -> tensor<5x!TFHE.glwe<sk?>> {
// CHECK-NEXT: %[[Vc2_i8:.*]] = arith.constant 2 : i8
// CHECK-NEXT: %[[V0:.*]] = arith.extsi %[[Vc2_i8]] : i8 to i64
// CHECK-NEXT: %[[V1:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V1:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V1]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V3:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V4:.*]] = "TFHE.mul_glwe_int"(%[[V3]], %[[V0]]) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: %[[V5:.*]] = tensor.insert %[[V4]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[V5]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V1]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V3:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: %[[V4:.*]] = "TFHE.mul_glwe_int"(%[[V3]], %[[V0]]) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
// CHECK-NEXT: %[[V5:.*]] = tensor.insert %[[V4]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: scf.yield %[[V5]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V2]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: return %[[V2]] : tensor<5x!TFHE.glwe<sk?>>
func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%0 = arith.constant 2 : i8
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)

View File

@@ -1,17 +1,17 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK: func.func @neg_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk[?]>>) -> tensor<5x!TFHE.glwe<sk[?]>> {
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK: func.func @neg_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk?>>) -> tensor<5x!TFHE.glwe<sk?>> {
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V0]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V2:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V3:.*]] = "TFHE.neg_glwe"(%[[V2]]) : (!TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: %[[V4:.*]] = tensor.insert %[[V3]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[V4]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V0]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V2:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: %[[V3:.*]] = "TFHE.neg_glwe"(%[[V2]]) : (!TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
// CHECK-NEXT: %[[V4:.*]] = tensor.insert %[[V3]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: scf.yield %[[V4]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V1]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: return %[[V1]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)

View File

@@ -1,21 +1,21 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK: func.func @sub_int_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk[?]>>) -> tensor<5x!TFHE.glwe<sk[?]>> {
// CHECK: func.func @sub_int_eint(%[[Varg0:.*]]: tensor<5x!TFHE.glwe<sk?>>) -> tensor<5x!TFHE.glwe<sk?>> {
// CHECK-NEXT: %[[Vc1_i8:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V0:.*]] = arith.extsi %[[Vc1_i8]] : i8 to i64
// CHECK-NEXT: %[[V1:.*]] = "TFHE.encode_plaintext_with_crt"(%[[V0]]) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
// CHECK-NEXT: %[[V2:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V2:.*]] = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[Vc5:.*]] = arith.constant 5 : index
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V2]]) -> (tensor<5x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V4:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg1:.*]] = %[[Vc0]] to %[[Vc5]] step %[[Vc1]] iter_args(%[[Varg2:.*]] = %[[V2]]) -> (tensor<5x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V4:.*]] = tensor.extract %[[Varg0]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: %[[V5:.*]] = tensor.extract %[[V1]][%[[Varg1]]] : tensor<5xi64>
// CHECK-NEXT: %[[V6:.*]] = "TFHE.sub_int_glwe"(%[[V5]], %[[V4]]) : (i64, !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: %[[V7:.*]] = tensor.insert %[[V6]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[V7]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V6:.*]] = "TFHE.sub_int_glwe"(%[[V5]], %[[V4]]) : (i64, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
// CHECK-NEXT: %[[V7:.*]] = tensor.insert %[[V6]] into %[[Varg2]][%[[Varg1]]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: scf.yield %[[V7]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V3]] : tensor<5x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: return %[[V3]] : tensor<5x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
func.func @sub_int_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%0 = arith.constant 1 : i8

View File

@@ -1,9 +1,9 @@
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @add_eint(%arg0: !TFHE.glwe<sk[?]>, %arg1: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-LABEL: func.func @add_eint(%arg0: !TFHE.glwe<sk?>, %arg1: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) : (!TFHE.glwe<sk[?]>, !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: return %[[V1]] : !TFHE.glwe<sk[?]>
// CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) : (!TFHE.glwe<sk?>, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
// CHECK-NEXT: return %[[V1]] : !TFHE.glwe<sk?>
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>

View File

@@ -1,13 +1,13 @@
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @add_eint_int(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-LABEL: func.func @add_eint_int(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
// CHECK-NEXT: %0 = arith.extsi %c1_i8 : i8 to i64
// CHECK-NEXT: %c56_i64 = arith.constant 56 : i64
// CHECK-NEXT: %1 = arith.shli %0, %c56_i64 : i64
// CHECK-NEXT: %2 = "TFHE.add_glwe_int"(%arg0, %1) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: return %2 : !TFHE.glwe<sk[?]>
// CHECK-NEXT: %2 = "TFHE.add_glwe_int"(%arg0, %1) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
// CHECK-NEXT: return %2 : !TFHE.glwe<sk?>
%0 = arith.constant 1 : i8

View File

@@ -1,10 +1,10 @@
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
// CHECK: func.func @apply_lookup_table(%arg0: !TFHE.glwe<sk[?]>, %arg1: tensor<4xi64>) -> !TFHE.glwe<sk[?]> {
// CHECK: func.func @apply_lookup_table(%arg0: !TFHE.glwe<sk?>, %arg1: tensor<4xi64>) -> !TFHE.glwe<sk?> {
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {isSigned = false, outputBits = 3 : i32, polySize = 256 : i32} : (tensor<4xi64>) -> tensor<256xi64>
// CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk<sk[?], sk[?], -1, -1>} : (!TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {key = #TFHE.bsk<sk[?], sk[?], -1, -1, -1, -1>} : (!TFHE.glwe<sk[?]>, tensor<256xi64>) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: return %2 : !TFHE.glwe<sk[?]>
// CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk<sk?, sk?, -1, -1>} : (!TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
// CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {key = #TFHE.bsk<sk?, sk?, -1, -1, -1, -1>} : (!TFHE.glwe<sk?>, tensor<256xi64>) -> !TFHE.glwe<sk?>
// CHECK-NEXT: return %2 : !TFHE.glwe<sk?>
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>

View File

@@ -1,12 +1,12 @@
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]> {
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?> {
//CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
//CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%cst) {isSigned = false, outputBits = 7 : i32, polySize = 8192 : i32} : (tensor<128xi64>) -> tensor<8192xi64>
//CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk<sk[?], sk[?], -1, -1>} : (!TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
//CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {key = #TFHE.bsk<sk[?], sk[?], -1, -1, -1, -1>} : (!TFHE.glwe<sk[?]>, tensor<8192xi64>) -> !TFHE.glwe<sk[?]>
//CHECK-NEXT: return %2 : !TFHE.glwe<sk[?]>
//CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk<sk?, sk?, -1, -1>} : (!TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
//CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {key = #TFHE.bsk<sk?, sk?, -1, -1, -1, -1>} : (!TFHE.glwe<sk?>, tensor<8192xi64>) -> !TFHE.glwe<sk?>
//CHECK-NEXT: return %2 : !TFHE.glwe<sk?>
func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>)

View File

@@ -1,6 +1,6 @@
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
// CHECK: func.func @conv2d(%[[Varg0:.*]]: tensor<100x3x28x28x!TFHE.glwe<sk[?]>>, %[[Varg1:.*]]: tensor<4x3x14x14xi3>, %[[Varg2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x!TFHE.glwe<sk[?]>> {
// CHECK: func.func @conv2d(%[[Varg0:.*]]: tensor<100x3x28x28x!TFHE.glwe<sk?>>, %[[Varg1:.*]]: tensor<4x3x14x14xi3>, %[[Varg2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x!TFHE.glwe<sk?>> {
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[Vc100:.*]] = arith.constant 100 : index
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
@@ -8,57 +8,57 @@
// CHECK-NEXT: %[[Vc15:.*]] = arith.constant 15 : index
// CHECK-NEXT: %[[Vc3:.*]] = arith.constant 3 : index
// CHECK-NEXT: %[[Vc14:.*]] = arith.constant 14 : index
// CHECK-NEXT: %[[V0:.*]] = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V0]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V0:.*]] = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V0]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg2]]{{\[}}%[[Varg5]]{{\]}} : tensor<4xi3>
// CHECK-NEXT: %[[Vextracted_0:.*]] = tensor.extract %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[Vextracted_0:.*]] = tensor.extract %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: %[[V6:.*]] = arith.extsi %[[Vextracted]] : i3 to i64
// CHECK-NEXT: %[[Vc61_i64:.*]] = arith.constant 61 : i64
// CHECK-NEXT: %[[V7:.*]] = arith.shli %[[V6]], %[[Vc61_i64]] : i64
// CHECK-NEXT: %[[V8:.*]] = "TFHE.add_glwe_int"(%[[Vextracted_0]], %[[V7]]) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V8]] into %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V8:.*]] = "TFHE.add_glwe_int"(%[[Vextracted_0]], %[[V7]]) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
// CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V8]] into %[[Varg10]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V1]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V6:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0]] to %[[Vc3]] step %[[Vc1]] iter_args(%[[Varg12:.*]] = %[[Varg10]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V7:.*]] = scf.for %[[Varg13:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg14:.*]] = %[[Varg12]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V8:.*]] = scf.for %[[Varg15:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg16:.*]] = %[[Varg14]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk[?]>>) {
// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg3:.*]] = %[[Vc0]] to %[[Vc100]] step %[[Vc1]] iter_args(%[[Varg4:.*]] = %[[V1]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg5:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc1]] iter_args(%[[Varg6:.*]] = %[[Varg4]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V4:.*]] = scf.for %[[Varg7:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg8:.*]] = %[[Varg6]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg9:.*]] = %[[Vc0]] to %[[Vc15]] step %[[Vc1]] iter_args(%[[Varg10:.*]] = %[[Varg8]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V6:.*]] = scf.for %[[Varg11:.*]] = %[[Vc0]] to %[[Vc3]] step %[[Vc1]] iter_args(%[[Varg12:.*]] = %[[Varg10]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V7:.*]] = scf.for %[[Varg13:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg14:.*]] = %[[Varg12]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V8:.*]] = scf.for %[[Varg15:.*]] = %[[Vc0]] to %[[Vc14]] step %[[Vc1]] iter_args(%[[Varg16:.*]] = %[[Varg14]]) -> (tensor<100x4x15x15x!TFHE.glwe<sk?>>) {
// CHECK-NEXT: %[[V9:.*]] = affine.apply #map(%[[Varg7]], %[[Varg13]])
// CHECK-NEXT: %[[V10:.*]] = affine.apply #map(%[[Varg9]], %[[Varg15]])
// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg0]]{{\[}}%[[Varg3]], %[[Varg11]], %[[V9]], %[[V10]]{{\]}} : tensor<100x3x28x28x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg0]]{{\[}}%[[Varg3]], %[[Varg11]], %[[V9]], %[[V10]]{{\]}} : tensor<100x3x28x28x!TFHE.glwe<sk?>>
// CHECK-NEXT: %[[Vextracted_0:.*]] = tensor.extract %[[Varg1]]{{\[}}%[[Varg5]], %[[Varg11]], %[[Varg13]], %[[Varg15]]{{\]}} : tensor<4x3x14x14xi3>
// CHECK-NEXT: %[[Vextracted_1:.*]] = tensor.extract %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[Vextracted_1:.*]] = tensor.extract %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: %[[V11:.*]] = arith.extsi %[[Vextracted_0]] : i3 to i64
// CHECK-NEXT: %[[V12:.*]] = "TFHE.mul_glwe_int"(%[[Vextracted]], %[[V11]]) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: %[[V13:.*]] = "TFHE.add_glwe"(%[[Vextracted_1]], %[[V12]]) : (!TFHE.glwe<sk[?]>, !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V13]] into %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: %[[V12:.*]] = "TFHE.mul_glwe_int"(%[[Vextracted]], %[[V11]]) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
// CHECK-NEXT: %[[V13:.*]] = "TFHE.add_glwe"(%[[Vextracted_1]], %[[V12]]) : (!TFHE.glwe<sk?>, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
// CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V13]] into %[[Varg16]]{{\[}}%[[Varg3]], %[[Varg5]], %[[Varg7]], %[[Varg9]]{{\]}} : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[V8]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[V8]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[V7]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[V7]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[V6]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[V6]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[V5]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[V4]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: scf.yield %[[V3]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V2]] : tensor<100x4x15x15x!TFHE.glwe<sk[?]>>
// CHECK-NEXT: return %[[V2]] : tensor<100x4x15x15x!TFHE.glwe<sk?>>
// CHECK-NEXT: }
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>

View File

@@ -1,11 +1,11 @@
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @mul_eint_int(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-LABEL: func.func @mul_eint_int(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %c2_i8 = arith.constant 2 : i8
// CHECK-NEXT: %0 = arith.extsi %c2_i8 : i8 to i64
// CHECK-NEXT: %1 = "TFHE.mul_glwe_int"(%arg0, %0) : (!TFHE.glwe<sk[?]>, i64) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: return %1 : !TFHE.glwe<sk[?]>
// CHECK-NEXT: %1 = "TFHE.mul_glwe_int"(%arg0, %0) : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
// CHECK-NEXT: return %1 : !TFHE.glwe<sk?>
%0 = arith.constant 2 : i8
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)

View File

@@ -1,18 +1,18 @@
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @neg_eint(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-LABEL: func.func @neg_eint(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: return %0 : !TFHE.glwe<sk[?]>
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
// CHECK-NEXT: return %0 : !TFHE.glwe<sk?>
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
// CHECK-LABEL: func.func @not(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-LABEL: func.func @not(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
func.func @not(%arg0: !FHE.ebool) -> !FHE.ebool {
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: return %0 : !TFHE.glwe<sk[?]>
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
// CHECK-NEXT: return %0 : !TFHE.glwe<sk?>
%1 = "FHE.not"(%arg0) : (!FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool
}

View File

@@ -1,13 +1,13 @@
// RUN: concretecompiler %s --optimize-tfhe=false --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @sub_int_eint(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-LABEL: func.func @sub_int_eint(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
func.func @sub_int_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
// CHECK-NEXT: %0 = arith.extsi %c1_i8 : i8 to i64
// CHECK-NEXT: %c56_i64 = arith.constant 56 : i64
// CHECK-NEXT: %1 = arith.shli %0, %c56_i64 : i64
// CHECK-NEXT: %2 = "TFHE.sub_int_glwe"(%1, %arg0) : (i64, !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
// CHECK-NEXT: return %2 : !TFHE.glwe<sk[?]>
// CHECK-NEXT: %2 = "TFHE.sub_int_glwe"(%1, %arg0) : (i64, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
// CHECK-NEXT: return %2 : !TFHE.glwe<sk?>
%0 = arith.constant 1 : i8
%1 = "FHE.sub_int_eint"(%0, %arg0): (i8, !FHE.eint<7>) -> (!FHE.eint<7>)

View File

@@ -1,14 +1,14 @@
// RUN: concretecompiler --passes tfhe-global-parametrization --action=dump-std --optimizer-v0 --v0-parameter=2,10,750,1,23,3,4 --v0-constraint=4,0 %s 2>&1| FileCheck %s
// RUN: concretecompiler --action=dump-tfhe-parameterized --optimizer-v0 --v0-parameter=2,10,750,1,23,3,4 --v0-constraint=4,0 %s 2>&1| FileCheck %s
//CHECK: func.func @main(%[[A0:.*]]: !TFHE.glwe<sk[1]<1,2048>>) -> !TFHE.glwe<sk[1]<1,2048>> {
//CHECK: func.func @main(%[[A0:.*]]: !TFHE.glwe<sk<0,1,2048>>) -> !TFHE.glwe<sk<0,1,2048>> {
//CHECK-NEXT: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
//CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {key = #TFHE.ksk<sk[1]<1,2048>, sk[3]<1,750>, 3, 4>} : (!TFHE.glwe<sk[1]<1,2048>>) -> !TFHE.glwe<sk[3]<1,750>>
//CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {key = #TFHE.bsk<sk[3]<1,750>, sk[1]<1,2048>, 1024, 2, 1, 23>} : (!TFHE.glwe<sk[3]<1,750>>, tensor<16xi64>) -> !TFHE.glwe<sk[1]<1,2048>>
//CHECK-NEXT: return %[[V2]] : !TFHE.glwe<sk[1]<1,2048>>
//CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {key = #TFHE.ksk<sk<0,1,2048>, sk<1,1,750>, 3, 4>} : (!TFHE.glwe<sk<0,1,2048>>) -> !TFHE.glwe<sk<1,1,750>>
//CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {key = #TFHE.bsk<sk<1,1,750>, sk<0,1,2048>, 1024, 2, 1, 23>} : (!TFHE.glwe<sk<1,1,750>>, tensor<16xi64>) -> !TFHE.glwe<sk<0,1,2048>>
//CHECK-NEXT: return %[[V2]] : !TFHE.glwe<sk<0,1,2048>>
//CHECK-NEXT: }
func.func @main(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]> {
func.func @main(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?> {
%cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
%1 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk<sk[?], sk[?], -1, -1>} : (!TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
%2 = "TFHE.bootstrap_glwe"(%1, %cst) {key = #TFHE.bsk<sk[?], sk[?], -1, -1, -1, -1>} : (!TFHE.glwe<sk[?]>, tensor<16xi64>) -> !TFHE.glwe<sk[?]>
return %2 : !TFHE.glwe<sk[?]>
%1 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk<sk?, sk?, -1, -1>} : (!TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
%2 = "TFHE.bootstrap_glwe"(%1, %cst) {key = #TFHE.bsk<sk?, sk?, -1, -1, -1, -1>} : (!TFHE.glwe<sk?>, tensor<16xi64>) -> !TFHE.glwe<sk?>
return %2 : !TFHE.glwe<sk?>
}

View File

@@ -2,7 +2,7 @@
//CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: tensor<601xi64>) -> tensor<1025xi64> {
//CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
//CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe_tensor"(%arg0, %cst) {baseLog = 1 : i32, glweDimension = 1 : i32, inputLweDim = 600 : i32, level = 3 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<128xi64>) -> tensor<1025xi64>
//CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe_tensor"(%arg0, %cst) {baseLog = 1 : i32, bskIndex = -1 : i32, glweDimension = 1 : i32, inputLweDim = 600 : i32, level = 3 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<128xi64>) -> tensor<1025xi64>
//CHECK-NEXT: return %[[V1]] : tensor<1025xi64>
//CHECK-NEXT: }
func.func @bootstrap_lwe(%ciphertext: !TFHE.glwe<sk[1]<1,600>>) -> !TFHE.glwe<sk[5]<1,1024>> {

View File

@@ -1,7 +1,7 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// CHECK: func.func @keyswitch_glwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<568xi64> {
// CHECK-NEXT: %[[V0:.*]] = "Concrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 3 : i32, level = 2 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 567 : i32} : (tensor<1025xi64>) -> tensor<568xi64>
// CHECK-NEXT: %[[V0:.*]] = "Concrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 3 : i32, kskIndex = -1 : i32, level = 2 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 567 : i32} : (tensor<1025xi64>) -> tensor<568xi64>
// CHECK-NEXT: return %[[V0]] : tensor<568xi64>
// CHECK-NEXT: }
func.func @keyswitch_glwe(%arg0: !TFHE.glwe<sk[1]<1,1024>>) -> !TFHE.glwe<sk[3]<1,567>> {

View File

@@ -39,20 +39,20 @@ func.func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
}
//CHECK: func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "Concrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "Concrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, bskIndex = 0 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> {
%0 = "Concrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> (tensor<2049xi64>)
%0 = "Concrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, bskIndex = 0 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> (tensor<2049xi64>)
return %0 : tensor<2049xi64>
}
//CHECK: func.func @keyswitch_lwe(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "Concrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "Concrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, kskIndex = 0 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @keyswitch_lwe(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
%0 = "Concrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>)
%0 = "Concrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, kskIndex = 0 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>)
return %0 : tensor<2049xi64>
}
@@ -83,13 +83,13 @@ func.func @negate_lwe_ciphertext_buffer(%arg0: memref<2049xi64>, %result: memref
}
func.func @bootstrap_lwe_buffer(%arg0: memref<2049xi64>, %arg1: memref<16xi64>, %result: memref<2049xi64>) {
//CHECK: "Concrete.bootstrap_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> ()
"Concrete.bootstrap_lwe_buffer"(%result, %arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> ()
//CHECK: "Concrete.bootstrap_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) {baseLog = 2 : i32, bskIndex = 0 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> ()
"Concrete.bootstrap_lwe_buffer"(%result, %arg0, %arg1) {baseLog = 2 : i32, bskIndex = 0 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> ()
return
}
func.func @keyswitch_lwe_buffer(%arg0: memref<2049xi64>, %result: memref<2049xi64>) {
//CHECK: "Concrete.keyswitch_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> ()
"Concrete.keyswitch_lwe_buffer"(%result, %arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> ()
//CHECK: "Concrete.keyswitch_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) {baseLog = 2 : i32, kskIndex = 0 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> ()
"Concrete.keyswitch_lwe_buffer"(%result, %arg0) {baseLog = 2 : i32, kskIndex = 0 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> ()
return
}

View File

@@ -8,8 +8,8 @@ func.func @glwe_0(%arg0: !TFHE.glwe<sk[1]<12,1024>>) -> !TFHE.glwe<sk[1]<12,1024
// -----
// CHECK-LABEL: func.func @glwe_1(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]>
func.func @glwe_1(%arg0: !TFHE.glwe<sk[?]>) -> !TFHE.glwe<sk[?]> {
// CHECK-LABEL: return %arg0 : !TFHE.glwe<sk[?]>
return %arg0: !TFHE.glwe<sk[?]>
// CHECK-LABEL: func.func @glwe_1(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
func.func @glwe_1(%arg0: !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?> {
// CHECK-LABEL: return %arg0 : !TFHE.glwe<sk?>
return %arg0: !TFHE.glwe<sk?>
}

View File

@@ -418,9 +418,7 @@ def test_compile_and_run_invalid_arg_number(
)
def test_compile_invalid(mlir_input):
engine = JITSupport.new()
with pytest.raises(
RuntimeError, match=r"Could not find existing crypto parameters for"
):
with pytest.raises(RuntimeError, match=r"Function not found, name='main'"):
engine.compile(mlir_input)

View File

@@ -3,3 +3,4 @@ add_custom_target(ConcretelangUnitTests)
add_subdirectory(ClientLib)
add_subdirectory(SDFG)
add_subdirectory(TestLib)
add_subdirectory(Encodings)

View File

@@ -0,0 +1,22 @@
add_custom_target(EncodingsUnitTests)
add_dependencies(ConcretelangUnitTests EncodingsUnitTests)
function(add_concretecompiler_lib_test test_name)
add_unittest(EncodingsUnitTests ${test_name} ${ARGN})
target_link_libraries(${test_name} PRIVATE ConcretelangSupport)
set_source_files_properties(${ARGN} PROPERTIES COMPILE_FLAGS "-fno-rtti")
endfunction()
if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
link_libraries(
# usefull for old gcc versions
-Wl,--allow-multiple-definition # static concrete-optimizer and concrete shares some code
)
endif()
if(CONCRETELANG_DATAFLOW_EXECUTION_ENABLED)
add_compile_options(-DCONCRETELANG_DATAFLOW_TESTING_ENABLED)
endif()
add_concretecompiler_lib_test(unit_tests_concretelang_Encodings Encodings_unit_tests.cpp)

View File

@@ -0,0 +1,98 @@
#include <gtest/gtest.h>
#include <cassert>
#include <chrono>
#include <iostream>
#include <thread>
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientLambda.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Encodings.h"
#include "concretelang/TestLib/TestTypedLambda.h"
#include "tests_tools/GtestEnvironment.h"
#include "tests_tools/assert.h"
#include "tests_tools/keySetCache.h"
testing::Environment *const dfr_env =
testing::AddGlobalTestEnvironment(new DFREnvironment);
const std::string FUNCNAME = "main";
using namespace concretelang::testlib;
namespace encodings = mlir::concretelang::encodings;
using concretelang::clientlib::scalar_in;
using concretelang::clientlib::scalar_out;
using concretelang::clientlib::tensor1_in;
using concretelang::clientlib::tensor1_out;
using concretelang::clientlib::tensor2_in;
using concretelang::clientlib::tensor2_out;
using concretelang::clientlib::tensor3_out;
mlir::concretelang::CompilerEngine::Library
compile(std::string outputLib, std::string source,
std::string funcname = FUNCNAME) {
std::vector<std::string> sources = {source};
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
mlir::concretelang::CompilationContext::createShared();
mlir::concretelang::CompilerEngine ce{ccx};
mlir::concretelang::CompilationOptions options(funcname);
options.encodings = encodings::CircuitEncodings{
{
encodings::EncryptedIntegerScalarEncoding{3, false},
encodings::EncryptedIntegerScalarEncoding{3, false},
},
{
encodings::EncryptedIntegerScalarEncoding{3, false},
}};
options.v0Parameter = {2, 10, 693, 4, 9, 7, 2, std::nullopt};
ce.setCompilationOptions(options);
auto result = ce.compile(sources, outputLib);
if (!result) {
llvm::errs() << result.takeError();
assert(false);
}
assert(result);
return result.get();
}
static const std::string CURRENT_FILE = __FILE__;
static const std::string THIS_TEST_DIRECTORY =
CURRENT_FILE.substr(0, CURRENT_FILE.find_last_of("/\\"));
static const std::string OUT_DIRECTORY = "/tmp";
template <typename Info> std::string outputLibFromThis(Info *info) {
return OUT_DIRECTORY + "/" + std::string(info->name());
}
template <typename Lambda> Lambda load(std::string outputLib) {
auto l = Lambda::load(FUNCNAME, outputLib, 0, 0, getTestKeySetCachePtr());
assert(l.has_value());
return l.value();
}
TEST(Encodings_unit_tests, multi_key) {
std::string source = R"(
func.func @main(
%arg0: !TFHE.glwe<sk<1,1,2048>>,
%arg1: !TFHE.glwe<sk<2,1,2048>>
) -> !TFHE.glwe<sk<2,1,2048>> {
%0 = "TFHE.keyswitch_glwe"(%arg0) {key=#TFHE.ksk<sk<1,1,2048>, sk<2, 1,2048>, 7, 2>} : (!TFHE.glwe<sk<1, 1, 2048>>) -> !TFHE.glwe<sk<2, 1, 2048>>
%1 = "TFHE.add_glwe"(%arg1, %0) : (!TFHE.glwe<sk<2,1,2048>>, !TFHE.glwe<sk<2,1,2048>>) -> !TFHE.glwe<sk<2,1,2048>>
return %1 : !TFHE.glwe<sk<2,1,2048>>
}
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda =
load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
scalar_in a = 5;
scalar_in b = 5;
auto res = lambda.call(a, b);
ASSERT_EQ_OUTCOME(res, (scalar_out)a + b);
}