mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-16 23:51:36 -05:00
feat(compiler): support multi-circuit compilation
This commit is contained in:
committed by
Alexandre Péré
parent
3247a28d9d
commit
9b5a2e46da
@@ -286,7 +286,6 @@ struct LibraryCompilationResult {
|
||||
/// The output directory path where the compilation artifacts have been
|
||||
/// generated.
|
||||
std::string outputDirPath;
|
||||
std::string funcName;
|
||||
};
|
||||
|
||||
class LibrarySupport {
|
||||
@@ -318,14 +317,8 @@ public:
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
if (!options.mainFuncName.has_value()) {
|
||||
return StreamStringError("Need to have a funcname to compile library");
|
||||
}
|
||||
this->funcName = options.mainFuncName.value();
|
||||
|
||||
auto result = std::make_unique<LibraryCompilationResult>();
|
||||
result->outputDirPath = outputPath;
|
||||
result->funcName = *options.mainFuncName;
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
@@ -356,20 +349,15 @@ public:
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
if (!options.mainFuncName.has_value()) {
|
||||
return StreamStringError("Need to have a funcname to compile library");
|
||||
}
|
||||
this->funcName = options.mainFuncName.value();
|
||||
|
||||
auto result = std::make_unique<LibraryCompilationResult>();
|
||||
result->outputDirPath = outputPath;
|
||||
result->funcName = *options.mainFuncName;
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
/// Load the server lambda from the compilation result.
|
||||
llvm::Expected<::concretelang::serverlib::ServerLambda>
|
||||
loadServerLambda(LibraryCompilationResult &result, bool useSimulation) {
|
||||
loadServerLambda(LibraryCompilationResult &result, std::string circuitName,
|
||||
bool useSimulation) {
|
||||
EXPECTED_TRY(auto programInfo, getProgramInfo());
|
||||
EXPECTED_TRY(ServerProgram serverProgram,
|
||||
outcomeToExpected(ServerProgram::load(programInfo.asReader(),
|
||||
@@ -377,7 +365,7 @@ public:
|
||||
useSimulation)));
|
||||
EXPECTED_TRY(
|
||||
ServerCircuit serverCircuit,
|
||||
outcomeToExpected(serverProgram.getServerCircuit(result.funcName)));
|
||||
outcomeToExpected(serverProgram.getServerCircuit(circuitName)));
|
||||
return ::concretelang::serverlib::ServerLambda{serverCircuit,
|
||||
useSimulation};
|
||||
}
|
||||
@@ -386,13 +374,6 @@ public:
|
||||
llvm::Expected<::concretelang::clientlib::ClientParameters>
|
||||
loadClientParameters(LibraryCompilationResult &result) {
|
||||
EXPECTED_TRY(auto programInfo, getProgramInfo());
|
||||
if (programInfo.asReader().getCircuits().size() > 1) {
|
||||
return StreamStringError("ClientLambda: Provided program info contains "
|
||||
"more than one circuit.");
|
||||
}
|
||||
if (programInfo.asReader().getCircuits()[0].getName() != result.funcName) {
|
||||
return StreamStringError("Unexpected circuit name in program info");
|
||||
}
|
||||
auto secretKeys =
|
||||
std::vector<::concretelang::clientlib::LweSecretKeyParam>();
|
||||
for (auto key : programInfo.asReader().getKeyset().getLweSecretKeys()) {
|
||||
@@ -441,15 +422,14 @@ public:
|
||||
loadCompilationResult() {
|
||||
auto result = std::make_unique<LibraryCompilationResult>();
|
||||
result->outputDirPath = outputPath;
|
||||
result->funcName = funcName;
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
llvm::Expected<CompilationFeedback>
|
||||
llvm::Expected<ProgramCompilationFeedback>
|
||||
loadCompilationFeedback(LibraryCompilationResult &result) {
|
||||
auto path = CompilerEngine::Library::getCompilationFeedbackPath(
|
||||
result.outputDirPath);
|
||||
auto feedback = CompilationFeedback::load(path);
|
||||
auto feedback = ProgramCompilationFeedback::load(path);
|
||||
if (feedback.has_error()) {
|
||||
return StreamStringError(feedback.error().mesg);
|
||||
}
|
||||
@@ -499,7 +479,6 @@ public:
|
||||
|
||||
private:
|
||||
std::string outputPath;
|
||||
std::string funcName;
|
||||
std::string runtimeLibraryPath;
|
||||
/// Flags to select generated artifacts
|
||||
bool generateSharedLib;
|
||||
|
||||
@@ -240,6 +240,7 @@ private:
|
||||
|
||||
template struct Message<concreteprotocol::ProgramInfo>;
|
||||
template struct Message<concreteprotocol::CircuitEncodingInfo>;
|
||||
template struct Message<concreteprotocol::ProgramEncodingInfo>;
|
||||
template struct Message<concreteprotocol::Value>;
|
||||
template struct Message<concreteprotocol::GateInfo>;
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
|
||||
createMemoryUsagePass(CompilationFeedback &feedback);
|
||||
createMemoryUsagePass(ProgramCompilationFeedback &feedback);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -15,7 +15,7 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
|
||||
createStatisticExtractionPass(CompilationFeedback &feedback);
|
||||
createStatisticExtractionPass(ProgramCompilationFeedback &feedback);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -44,27 +44,13 @@ enum class KeyType {
|
||||
struct Statistic {
|
||||
std::string location;
|
||||
PrimitiveOperation operation;
|
||||
std::vector<std::pair<KeyType, size_t>> keys;
|
||||
size_t count;
|
||||
std::vector<std::pair<KeyType, int64_t>> keys;
|
||||
int64_t count;
|
||||
};
|
||||
|
||||
struct CompilationFeedback {
|
||||
double complexity;
|
||||
|
||||
/// @brief Probability of error for every PBS.
|
||||
double pError;
|
||||
|
||||
/// @brief Probability of error for the whole programs.
|
||||
double globalPError;
|
||||
|
||||
/// @brief the total number of bytes of secret keys
|
||||
uint64_t totalSecretKeysSize;
|
||||
|
||||
/// @brief the total number of bytes of bootstrap keys
|
||||
uint64_t totalBootstrapKeysSize;
|
||||
|
||||
/// @brief the total number of bytes of keyswitch keys
|
||||
uint64_t totalKeyswitchKeysSize;
|
||||
struct CircuitCompilationFeedback {
|
||||
/// @brief the name of circuit.
|
||||
std::string name;
|
||||
|
||||
/// @brief the total number of bytes of inputs
|
||||
uint64_t totalInputsSize;
|
||||
@@ -81,25 +67,52 @@ struct CompilationFeedback {
|
||||
/// @brief memory usage per location
|
||||
std::map<std::string, int64_t> memoryUsagePerLoc;
|
||||
|
||||
/// Fill the sizes from the program info.
|
||||
void fillFromCircuitInfo(concreteprotocol::CircuitInfo::Reader params);
|
||||
};
|
||||
|
||||
struct ProgramCompilationFeedback {
|
||||
double complexity;
|
||||
|
||||
/// @brief Probability of error for every PBS.
|
||||
double pError;
|
||||
|
||||
/// @brief Probability of error for the whole programs.
|
||||
double globalPError;
|
||||
|
||||
/// @brief the total number of bytes of secret keys
|
||||
uint64_t totalSecretKeysSize;
|
||||
|
||||
/// @brief the total number of bytes of bootstrap keys
|
||||
uint64_t totalBootstrapKeysSize;
|
||||
|
||||
/// @brief the total number of bytes of keyswitch keys
|
||||
uint64_t totalKeyswitchKeysSize;
|
||||
|
||||
/// @brief the feedback for each circuit
|
||||
std::vector<CircuitCompilationFeedback> circuitFeedbacks;
|
||||
|
||||
/// Fill the sizes from the program info.
|
||||
void fillFromProgramInfo(const Message<protocol::ProgramInfo> ¶ms);
|
||||
|
||||
/// Load the compilation feedback from a path
|
||||
static outcome::checked<CompilationFeedback, StringError>
|
||||
static outcome::checked<ProgramCompilationFeedback, StringError>
|
||||
load(std::string path);
|
||||
};
|
||||
|
||||
llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &);
|
||||
llvm::json::Value
|
||||
toJSON(const mlir::concretelang::ProgramCompilationFeedback &);
|
||||
|
||||
bool fromJSON(const llvm::json::Value,
|
||||
mlir::concretelang::CompilationFeedback &, llvm::json::Path);
|
||||
mlir::concretelang::ProgramCompilationFeedback &,
|
||||
llvm::json::Path);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
static inline llvm::raw_ostream &
|
||||
operator<<(llvm::raw_string_ostream &OS,
|
||||
mlir::concretelang::CompilationFeedback cp) {
|
||||
mlir::concretelang::ProgramCompilationFeedback cp) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(cp));
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -79,8 +79,6 @@ struct CompilationOptions {
|
||||
|
||||
std::optional<std::vector<int64_t>> fhelinalgTileSizes;
|
||||
|
||||
std::optional<std::string> mainFuncName;
|
||||
|
||||
optimizer::Config optimizerConfig;
|
||||
|
||||
/// When decomposing big integers into chunks, chunkSize is the total number
|
||||
@@ -92,7 +90,9 @@ struct CompilationOptions {
|
||||
|
||||
/// When compiling from a dialect lower than FHE, one needs to provide
|
||||
/// encodings info manually to allow the client lib to be generated.
|
||||
std::optional<Message<concreteprotocol::CircuitEncodingInfo>> encodings;
|
||||
std::optional<Message<concreteprotocol::ProgramEncodingInfo>> encodings;
|
||||
|
||||
bool skipProgramInfo;
|
||||
|
||||
bool compressEvaluationKeys;
|
||||
|
||||
@@ -102,20 +102,14 @@ struct CompilationOptions {
|
||||
maxBatchSize(std::numeric_limits<int64_t>::max()), emitSDFGOps(false),
|
||||
unrollLoopsWithSDFGConvertibleOps(false), dataflowParallelize(false),
|
||||
optimizeTFHE(true), simulate(false), emitGPUOps(false),
|
||||
mainFuncName(std::nullopt), optimizerConfig(optimizer::DEFAULT_CONFIG),
|
||||
chunkIntegers(false), chunkSize(4), chunkWidth(2),
|
||||
encodings(std::nullopt), compressEvaluationKeys(false){};
|
||||
|
||||
CompilationOptions(std::string funcname) : CompilationOptions() {
|
||||
mainFuncName = funcname;
|
||||
}
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkIntegers(false),
|
||||
chunkSize(4), chunkWidth(2), encodings(std::nullopt),
|
||||
skipProgramInfo(false), compressEvaluationKeys(false){};
|
||||
|
||||
/// @brief Constructor for CompilationOptions with default parameters for a
|
||||
/// specific backend.
|
||||
/// @param funcname The name of the function to compile.
|
||||
/// @param backend The backend to target.
|
||||
CompilationOptions(std::string funcname, enum Backend backend)
|
||||
: CompilationOptions(funcname) {
|
||||
CompilationOptions(enum Backend backend) : CompilationOptions() {
|
||||
switch (backend) {
|
||||
case Backend::CPU:
|
||||
loopParallelize = true;
|
||||
@@ -143,7 +137,7 @@ public:
|
||||
|
||||
std::optional<mlir::OwningOpRef<mlir::ModuleOp>> mlirModuleRef;
|
||||
std::optional<Message<concreteprotocol::ProgramInfo>> programInfo;
|
||||
std::optional<CompilationFeedback> feedback;
|
||||
std::optional<ProgramCompilationFeedback> feedback;
|
||||
std::unique_ptr<llvm::Module> llvmModule;
|
||||
std::optional<mlir::concretelang::V0FHEContext> fheContext;
|
||||
|
||||
@@ -157,7 +151,7 @@ public:
|
||||
/// Path to the runtime library. Will be linked to the output library if set
|
||||
std::string runtimeLibraryPath;
|
||||
bool cleanUp;
|
||||
mlir::concretelang::CompilationFeedback compilationFeedback;
|
||||
mlir::concretelang::ProgramCompilationFeedback compilationFeedback;
|
||||
Message<concreteprotocol::ProgramInfo> programInfo;
|
||||
|
||||
public:
|
||||
@@ -280,7 +274,7 @@ public:
|
||||
|
||||
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
|
||||
: overrideMaxEintPrecision(), overrideMaxMANP(), compilerOptions(),
|
||||
generateProgramInfo(compilerOptions.mainFuncName.has_value()),
|
||||
generateProgramInfo(true),
|
||||
enablePass([](mlir::Pass *pass) { return true; }),
|
||||
compilationContext(compilationContext) {}
|
||||
|
||||
@@ -325,10 +319,6 @@ public:
|
||||
if (compilerOptions.v0FHEConstraints.has_value()) {
|
||||
setFHEConstraints(*compilerOptions.v0FHEConstraints);
|
||||
}
|
||||
|
||||
if (compilerOptions.mainFuncName.has_value()) {
|
||||
setGenerateProgramInfo(true);
|
||||
}
|
||||
}
|
||||
|
||||
CompilationOptions &getCompilationOptions() { return compilerOptions; }
|
||||
|
||||
@@ -35,11 +35,11 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace encodings {
|
||||
|
||||
llvm::Expected<Message<concreteprotocol::CircuitEncodingInfo>>
|
||||
getCircuitEncodings(llvm::StringRef functionName, mlir::ModuleOp module);
|
||||
llvm::Expected<Message<concreteprotocol::ProgramEncodingInfo>>
|
||||
getProgramEncoding(mlir::ModuleOp module);
|
||||
|
||||
void setCircuitEncodingModes(
|
||||
Message<concreteprotocol::CircuitEncodingInfo> &info,
|
||||
void setProgramEncodingModes(
|
||||
Message<concreteprotocol::ProgramEncodingInfo> &info,
|
||||
std::optional<
|
||||
Message<concreteprotocol::IntegerCiphertextEncodingInfo::ChunkedMode>>
|
||||
maybeChunk,
|
||||
|
||||
@@ -77,7 +77,7 @@ normalizeTFHEKeys(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::LogicalResult
|
||||
extractTFHEStatistics(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
CompilationFeedback &feedback);
|
||||
ProgramCompilationFeedback &feedback);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
@@ -86,7 +86,7 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::LogicalResult
|
||||
computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
CompilationFeedback &feedback);
|
||||
ProgramCompilationFeedback &feedback);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
@@ -21,8 +21,8 @@ namespace concretelang {
|
||||
|
||||
llvm::Expected<Message<concreteprotocol::ProgramInfo>>
|
||||
createProgramInfoFromTfheDialect(
|
||||
mlir::ModuleOp module, llvm::StringRef functionName, int bitsOfSecurity,
|
||||
Message<concreteprotocol::CircuitEncodingInfo> &encodings,
|
||||
mlir::ModuleOp module, int bitsOfSecurity,
|
||||
const Message<concreteprotocol::ProgramEncodingInfo> &encodings,
|
||||
bool compressEvaluationKeys);
|
||||
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -151,10 +151,10 @@ typedef std::variant<V0Parameter, CircuitSolution> Solution;
|
||||
|
||||
} // namespace optimizer
|
||||
|
||||
struct CompilationFeedback;
|
||||
struct ProgramCompilationFeedback;
|
||||
|
||||
llvm::Expected<optimizer::Solution>
|
||||
getSolution(optimizer::Description &descr, CompilationFeedback &feedback,
|
||||
getSolution(optimizer::Description &descr, ProgramCompilationFeedback &feedback,
|
||||
optimizer::Config optimizerConfig);
|
||||
|
||||
// As for now the solution which contains a crt encoding is mono parameter only
|
||||
|
||||
@@ -36,26 +36,26 @@ using concretelang::values::Value;
|
||||
namespace concretelang {
|
||||
namespace testlib {
|
||||
|
||||
class TestCircuit {
|
||||
class TestProgram {
|
||||
|
||||
public:
|
||||
TestCircuit(mlir::concretelang::CompilationOptions options)
|
||||
TestProgram(mlir::concretelang::CompilationOptions options)
|
||||
: artifactDirectory(createTempFolderIn(getSystemTempFolderPath())),
|
||||
compiler(mlir::concretelang::CompilationContext::createShared()),
|
||||
encryptionCsprng(std::make_shared<csprng::EncryptionCSPRNG>(0)) {
|
||||
compiler.setCompilationOptions(options);
|
||||
}
|
||||
|
||||
TestCircuit(TestCircuit &&tc)
|
||||
TestProgram(TestProgram &&tc)
|
||||
: artifactDirectory(tc.artifactDirectory), compiler(tc.compiler),
|
||||
library(tc.library), keyset(tc.keyset),
|
||||
encryptionCsprng(tc.encryptionCsprng) {
|
||||
tc.artifactDirectory = "";
|
||||
};
|
||||
|
||||
TestCircuit(TestCircuit &tc) = delete;
|
||||
TestProgram(TestProgram &tc) = delete;
|
||||
|
||||
~TestCircuit() {
|
||||
~TestProgram() {
|
||||
auto d = getArtifactDirectory();
|
||||
if (d.empty())
|
||||
return;
|
||||
@@ -93,16 +93,17 @@ public:
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
Result<std::vector<Value>> call(std::vector<Value> inputs) {
|
||||
Result<std::vector<Value>> call(std::vector<Value> inputs,
|
||||
std::string name = "main") {
|
||||
// preprocess arguments
|
||||
auto preparedArgs = std::vector<TransportValue>();
|
||||
OUTCOME_TRY(auto clientCircuit, getClientCircuit());
|
||||
OUTCOME_TRY(auto clientCircuit, getClientCircuit(name));
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
OUTCOME_TRY(auto preparedInput, clientCircuit.prepareInput(inputs[i], i));
|
||||
preparedArgs.push_back(preparedInput);
|
||||
}
|
||||
// Call server
|
||||
OUTCOME_TRY(auto returns, callServer(preparedArgs));
|
||||
OUTCOME_TRY(auto returns, callServer(preparedArgs, name));
|
||||
// postprocess arguments
|
||||
std::vector<Value> processedOutputs(returns.size());
|
||||
for (size_t i = 0; i < processedOutputs.size(); i++) {
|
||||
@@ -113,17 +114,18 @@ public:
|
||||
}
|
||||
|
||||
Result<std::vector<Value>> compose_n_times(std::vector<Value> inputs,
|
||||
size_t n) {
|
||||
size_t n,
|
||||
std::string name = "main") {
|
||||
// preprocess arguments
|
||||
auto preparedArgs = std::vector<TransportValue>();
|
||||
OUTCOME_TRY(auto clientCircuit, getClientCircuit());
|
||||
OUTCOME_TRY(auto clientCircuit, getClientCircuit(name));
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
OUTCOME_TRY(auto preparedInput, clientCircuit.prepareInput(inputs[i], i));
|
||||
preparedArgs.push_back(preparedInput);
|
||||
}
|
||||
// Call server multiple times in a row
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
OUTCOME_TRY(preparedArgs, callServer(preparedArgs));
|
||||
OUTCOME_TRY(preparedArgs, callServer(preparedArgs, name));
|
||||
}
|
||||
// postprocess arguments
|
||||
std::vector<Value> processedOutputs(preparedArgs.size());
|
||||
@@ -135,9 +137,9 @@ public:
|
||||
}
|
||||
|
||||
Result<std::vector<TransportValue>>
|
||||
callServer(std::vector<TransportValue> inputs) {
|
||||
callServer(std::vector<TransportValue> inputs, std::string name = "main") {
|
||||
std::vector<TransportValue> returns;
|
||||
OUTCOME_TRY(auto serverCircuit, getServerCircuit());
|
||||
OUTCOME_TRY(auto serverCircuit, getServerCircuit(name));
|
||||
if (compiler.getCompilationOptions().simulate) {
|
||||
OUTCOME_TRY(returns, serverCircuit.simulate(inputs));
|
||||
} else {
|
||||
@@ -146,29 +148,25 @@ public:
|
||||
return returns;
|
||||
}
|
||||
|
||||
Result<ClientCircuit> getClientCircuit() {
|
||||
Result<ClientCircuit> getClientCircuit(std::string name = "main") {
|
||||
OUTCOME_TRY(auto lib, getLibrary());
|
||||
OUTCOME_TRY(auto ks, getKeyset());
|
||||
auto programInfo = lib.getProgramInfo();
|
||||
OUTCOME_TRY(auto clientProgram,
|
||||
ClientProgram::create(programInfo, ks.client, encryptionCsprng,
|
||||
isSimulation()));
|
||||
OUTCOME_TRY(auto clientCircuit,
|
||||
clientProgram.getClientCircuit(
|
||||
programInfo.asReader().getCircuits()[0].getName()));
|
||||
OUTCOME_TRY(auto clientCircuit, clientProgram.getClientCircuit(name));
|
||||
return clientCircuit;
|
||||
}
|
||||
|
||||
Result<ServerCircuit> getServerCircuit() {
|
||||
Result<ServerCircuit> getServerCircuit(std::string name = "main") {
|
||||
OUTCOME_TRY(auto lib, getLibrary());
|
||||
auto programInfo = lib.getProgramInfo();
|
||||
OUTCOME_TRY(auto serverProgram,
|
||||
ServerProgram::load(programInfo,
|
||||
lib.getSharedLibraryPath(artifactDirectory),
|
||||
isSimulation()));
|
||||
OUTCOME_TRY(auto serverCircuit,
|
||||
serverProgram.getServerCircuit(
|
||||
programInfo.asReader().getCircuits()[0].getName()));
|
||||
OUTCOME_TRY(auto serverCircuit, serverProgram.getServerCircuit(name));
|
||||
return serverCircuit;
|
||||
}
|
||||
|
||||
@@ -177,14 +175,14 @@ private:
|
||||
|
||||
Result<mlir::concretelang::CompilerEngine::Library> getLibrary() {
|
||||
if (!library.has_value()) {
|
||||
return StringError("TestCircuit: compilation has not been done\n");
|
||||
return StringError("TestProgram: compilation has not been done\n");
|
||||
}
|
||||
return *library;
|
||||
}
|
||||
|
||||
Result<Keyset> getKeyset() {
|
||||
if (!keyset.has_value()) {
|
||||
return StringError("TestCircuit: keyset has not been generated\n");
|
||||
return StringError("TestProgram: keyset has not been generated\n");
|
||||
}
|
||||
return *keyset;
|
||||
}
|
||||
@@ -206,10 +204,10 @@ private:
|
||||
|
||||
void deleteFolder(const std::string &folder) {
|
||||
auto ec = std::error_code();
|
||||
llvm::errs() << "TestCircuit: delete artifact directory(" << folder
|
||||
llvm::errs() << "TestProgram: delete artifact directory(" << folder
|
||||
<< ")\n";
|
||||
if (!std::filesystem::remove_all(folder, ec)) {
|
||||
llvm::errs() << "TestCircuit: fail to delete directory(" << folder
|
||||
llvm::errs() << "TestProgram: fail to delete directory(" << folder
|
||||
<< "), error(" << ec.message() << ")\n";
|
||||
}
|
||||
}
|
||||
@@ -232,10 +230,10 @@ private:
|
||||
for (size_t i = 0; i < 5; i++) {
|
||||
auto pathString = new_path();
|
||||
auto ec = std::error_code();
|
||||
llvm::errs() << "TestCircuit: create temporary directory(" << pathString
|
||||
llvm::errs() << "TestProgram: create temporary directory(" << pathString
|
||||
<< ")\n";
|
||||
if (!std::filesystem::create_directory(pathString, ec)) {
|
||||
llvm::errs() << "TestCircuit: fail to create temporary directory("
|
||||
llvm::errs() << "TestProgram: fail to create temporary directory("
|
||||
<< pathString << "), ";
|
||||
if (ec) {
|
||||
llvm::errs() << "already exists";
|
||||
@@ -243,7 +241,7 @@ private:
|
||||
llvm::errs() << "error(" << ec.message() << ")";
|
||||
}
|
||||
} else {
|
||||
llvm::errs() << "TestCircuit: directory(" << pathString
|
||||
llvm::errs() << "TestProgram: directory(" << pathString
|
||||
<< ") successfuly created\n";
|
||||
return pathString;
|
||||
}
|
||||
@@ -102,7 +102,8 @@ concretelang::clientlib::ClientParameters library_load_client_parameters(
|
||||
return *clientParameters;
|
||||
}
|
||||
|
||||
mlir::concretelang::CompilationFeedback library_load_compilation_feedback(
|
||||
mlir::concretelang::ProgramCompilationFeedback
|
||||
library_load_compilation_feedback(
|
||||
LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationFeedback,
|
||||
@@ -113,9 +114,10 @@ mlir::concretelang::CompilationFeedback library_load_compilation_feedback(
|
||||
concretelang::serverlib::ServerLambda
|
||||
library_load_server_lambda(LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &result,
|
||||
bool useSimulation) {
|
||||
std::string circuitName, bool useSimulation) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
serverLambda, support.support.loadServerLambda(result, useSimulation));
|
||||
serverLambda,
|
||||
support.support.loadServerLambda(result, circuitName, useSimulation));
|
||||
return *serverLambda;
|
||||
}
|
||||
|
||||
@@ -176,7 +178,8 @@ key_set(concretelang::clientlib::ClientParameters clientParameters,
|
||||
std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args) {
|
||||
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args,
|
||||
const std::string &circuitName) {
|
||||
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
|
||||
clientParameters.programInfo.asReader(), keySet.keyset.client,
|
||||
std::make_shared<::concretelang::csprng::EncryptionCSPRNG>(
|
||||
@@ -185,15 +188,10 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
|
||||
if (maybeProgram.has_failure()) {
|
||||
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
|
||||
}
|
||||
auto circuit = maybeProgram.value()
|
||||
.getClientCircuit(clientParameters.programInfo.asReader()
|
||||
.getCircuits()[0]
|
||||
.getName())
|
||||
.value();
|
||||
auto circuit = maybeProgram.value().getClientCircuit(circuitName).value();
|
||||
std::vector<TransportValue> output;
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
auto info =
|
||||
clientParameters.programInfo.asReader().getCircuits()[0].getInputs()[i];
|
||||
auto info = circuit.getCircuitInfo().asReader().getInputs()[i];
|
||||
auto typeTransformer = getPythonTypeTransformer(info);
|
||||
auto input = typeTransformer(args[i]->value);
|
||||
auto maybePrepared = circuit.prepareInput(input, i);
|
||||
@@ -211,7 +209,8 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
|
||||
std::vector<lambdaArgument>
|
||||
decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::PublicResult &publicResult) {
|
||||
concretelang::clientlib::PublicResult &publicResult,
|
||||
const std::string &circuitName) {
|
||||
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
|
||||
clientParameters.programInfo.asReader(), keySet.keyset.client,
|
||||
std::make_shared<::concretelang::csprng::EncryptionCSPRNG>(
|
||||
@@ -220,11 +219,7 @@ decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
|
||||
if (maybeProgram.has_failure()) {
|
||||
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
|
||||
}
|
||||
auto circuit = maybeProgram.value()
|
||||
.getClientCircuit(clientParameters.programInfo.asReader()
|
||||
.getCircuits()[0]
|
||||
.getName())
|
||||
.value();
|
||||
auto circuit = maybeProgram.value().getClientCircuit(circuitName).value();
|
||||
std::vector<lambdaArgument> results;
|
||||
for (auto e : llvm::enumerate(publicResult.values)) {
|
||||
auto maybeProcessed = circuit.processOutput(e.value(), e.index());
|
||||
@@ -370,9 +365,10 @@ valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value) {
|
||||
return maybeString.value();
|
||||
}
|
||||
|
||||
concretelang::clientlib::ValueExporter createValueExporter(
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
concretelang::clientlib::ValueExporter
|
||||
createValueExporter(concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const std::string &circuitName) {
|
||||
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
|
||||
clientParameters.programInfo.asReader(), keySet.keyset.client,
|
||||
std::make_shared<::concretelang::csprng::EncryptionCSPRNG>(
|
||||
@@ -381,13 +377,13 @@ concretelang::clientlib::ValueExporter createValueExporter(
|
||||
if (maybeProgram.has_failure()) {
|
||||
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
|
||||
}
|
||||
auto maybeCircuit = maybeProgram.value().getClientCircuit(
|
||||
clientParameters.programInfo.asReader().getCircuits()[0].getName());
|
||||
auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName);
|
||||
return ::concretelang::clientlib::ValueExporter{maybeCircuit.value()};
|
||||
}
|
||||
|
||||
concretelang::clientlib::SimulatedValueExporter createSimulatedValueExporter(
|
||||
concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const std::string &circuitName) {
|
||||
|
||||
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
|
||||
clientParameters.programInfo, ::concretelang::keysets::ClientKeyset(),
|
||||
@@ -397,15 +393,15 @@ concretelang::clientlib::SimulatedValueExporter createSimulatedValueExporter(
|
||||
if (maybeProgram.has_failure()) {
|
||||
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
|
||||
}
|
||||
auto maybeCircuit = maybeProgram.value().getClientCircuit(
|
||||
clientParameters.programInfo.asReader().getCircuits()[0].getName());
|
||||
auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName);
|
||||
return ::concretelang::clientlib::SimulatedValueExporter{
|
||||
maybeCircuit.value()};
|
||||
}
|
||||
|
||||
concretelang::clientlib::ValueDecrypter createValueDecrypter(
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const std::string &circuitName) {
|
||||
|
||||
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
|
||||
clientParameters.programInfo.asReader(), keySet.keyset.client,
|
||||
@@ -415,13 +411,13 @@ concretelang::clientlib::ValueDecrypter createValueDecrypter(
|
||||
if (maybeProgram.has_failure()) {
|
||||
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
|
||||
}
|
||||
auto maybeCircuit = maybeProgram.value().getClientCircuit(
|
||||
clientParameters.programInfo.asReader().getCircuits()[0].getName());
|
||||
auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName);
|
||||
return ::concretelang::clientlib::ValueDecrypter{maybeCircuit.value()};
|
||||
}
|
||||
|
||||
concretelang::clientlib::SimulatedValueDecrypter createSimulatedValueDecrypter(
|
||||
concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const std::string &circuitName) {
|
||||
|
||||
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
|
||||
clientParameters.programInfo.asReader(),
|
||||
@@ -432,8 +428,7 @@ concretelang::clientlib::SimulatedValueDecrypter createSimulatedValueDecrypter(
|
||||
if (maybeProgram.has_failure()) {
|
||||
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
|
||||
}
|
||||
auto maybeCircuit = maybeProgram.value().getClientCircuit(
|
||||
clientParameters.programInfo.asReader().getCircuits()[0].getName());
|
||||
auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName);
|
||||
return ::concretelang::clientlib::SimulatedValueDecrypter{
|
||||
maybeCircuit.value()};
|
||||
}
|
||||
@@ -699,14 +694,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.export_values();
|
||||
|
||||
pybind11::class_<CompilationOptions>(m, "CompilationOptions")
|
||||
.def(pybind11::init(
|
||||
[](std::string funcname, mlir::concretelang::Backend backend) {
|
||||
return CompilationOptions(funcname, backend);
|
||||
}))
|
||||
.def("set_funcname",
|
||||
[](CompilationOptions &options, std::string funcname) {
|
||||
options.mainFuncName = funcname;
|
||||
})
|
||||
.def(pybind11::init([](mlir::concretelang::Backend backend) {
|
||||
return CompilationOptions(backend);
|
||||
}))
|
||||
.def("set_verify_diagnostics",
|
||||
[](CompilationOptions &options, bool b) {
|
||||
options.verifyDiagnostics = b;
|
||||
@@ -828,34 +818,46 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.def_readonly("keys", &mlir::concretelang::Statistic::keys)
|
||||
.def_readonly("count", &mlir::concretelang::Statistic::count);
|
||||
|
||||
pybind11::class_<mlir::concretelang::CompilationFeedback>(
|
||||
m, "CompilationFeedback")
|
||||
pybind11::class_<mlir::concretelang::ProgramCompilationFeedback>(
|
||||
m, "ProgramCompilationFeedback")
|
||||
.def_readonly("complexity",
|
||||
&mlir::concretelang::CompilationFeedback::complexity)
|
||||
.def_readonly("p_error", &mlir::concretelang::CompilationFeedback::pError)
|
||||
.def_readonly("global_p_error",
|
||||
&mlir::concretelang::CompilationFeedback::globalPError)
|
||||
&mlir::concretelang::ProgramCompilationFeedback::complexity)
|
||||
.def_readonly("p_error",
|
||||
&mlir::concretelang::ProgramCompilationFeedback::pError)
|
||||
.def_readonly(
|
||||
"global_p_error",
|
||||
&mlir::concretelang::ProgramCompilationFeedback::globalPError)
|
||||
.def_readonly(
|
||||
"total_secret_keys_size",
|
||||
&mlir::concretelang::CompilationFeedback::totalSecretKeysSize)
|
||||
&mlir::concretelang::ProgramCompilationFeedback::totalSecretKeysSize)
|
||||
.def_readonly("total_bootstrap_keys_size",
|
||||
&mlir::concretelang::ProgramCompilationFeedback::
|
||||
totalBootstrapKeysSize)
|
||||
.def_readonly("total_keyswitch_keys_size",
|
||||
&mlir::concretelang::ProgramCompilationFeedback::
|
||||
totalKeyswitchKeysSize)
|
||||
.def_readonly(
|
||||
"total_bootstrap_keys_size",
|
||||
&mlir::concretelang::CompilationFeedback::totalBootstrapKeysSize)
|
||||
"circuit_feedbacks",
|
||||
&mlir::concretelang::ProgramCompilationFeedback::circuitFeedbacks);
|
||||
|
||||
pybind11::class_<mlir::concretelang::CircuitCompilationFeedback>(
|
||||
m, "CircuitCompilationFeedback")
|
||||
.def_readonly("name",
|
||||
&mlir::concretelang::CircuitCompilationFeedback::name)
|
||||
.def_readonly(
|
||||
"total_keyswitch_keys_size",
|
||||
&mlir::concretelang::CompilationFeedback::totalKeyswitchKeysSize)
|
||||
.def_readonly("total_inputs_size",
|
||||
&mlir::concretelang::CompilationFeedback::totalInputsSize)
|
||||
.def_readonly("total_output_size",
|
||||
&mlir::concretelang::CompilationFeedback::totalOutputsSize)
|
||||
"total_inputs_size",
|
||||
&mlir::concretelang::CircuitCompilationFeedback::totalInputsSize)
|
||||
.def_readonly(
|
||||
"crt_decompositions_of_outputs",
|
||||
&mlir::concretelang::CompilationFeedback::crtDecompositionsOfOutputs)
|
||||
"total_output_size",
|
||||
&mlir::concretelang::CircuitCompilationFeedback::totalOutputsSize)
|
||||
.def_readonly("crt_decompositions_of_outputs",
|
||||
&mlir::concretelang::CircuitCompilationFeedback::
|
||||
crtDecompositionsOfOutputs)
|
||||
.def_readonly("statistics",
|
||||
&mlir::concretelang::CompilationFeedback::statistics)
|
||||
&mlir::concretelang::CircuitCompilationFeedback::statistics)
|
||||
.def_readonly(
|
||||
"memory_usage_per_location",
|
||||
&mlir::concretelang::CompilationFeedback::memoryUsagePerLoc);
|
||||
&mlir::concretelang::CircuitCompilationFeedback::memoryUsagePerLoc);
|
||||
|
||||
pybind11::class_<mlir::concretelang::CompilationContext,
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext>>(
|
||||
@@ -872,11 +874,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
|
||||
pybind11::class_<mlir::concretelang::LibraryCompilationResult>(
|
||||
m, "LibraryCompilationResult")
|
||||
.def(pybind11::init([](std::string outputDirPath, std::string funcname) {
|
||||
return mlir::concretelang::LibraryCompilationResult{
|
||||
outputDirPath,
|
||||
funcname,
|
||||
};
|
||||
.def(pybind11::init([](std::string outputDirPath) {
|
||||
return mlir::concretelang::LibraryCompilationResult{outputDirPath};
|
||||
}));
|
||||
pybind11::class_<::concretelang::serverlib::ServerLambda>(m, "LibraryLambda");
|
||||
pybind11::class_<LibrarySupport_Py>(m, "LibrarySupport")
|
||||
@@ -920,8 +919,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
"load_server_lambda",
|
||||
[](LibrarySupport_Py &support,
|
||||
mlir::concretelang::LibraryCompilationResult &result,
|
||||
bool useSimulation) {
|
||||
return library_load_server_lambda(support, result, useSimulation);
|
||||
std::string circuitName, bool useSimulation) {
|
||||
return library_load_server_lambda(support, result, circuitName,
|
||||
useSimulation);
|
||||
},
|
||||
pybind11::return_value_policy::reference)
|
||||
.def("server_call",
|
||||
@@ -974,19 +974,22 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
"encrypt_arguments",
|
||||
[](::concretelang::clientlib::ClientParameters clientParameters,
|
||||
::concretelang::clientlib::KeySet &keySet,
|
||||
std::vector<lambdaArgument> args) {
|
||||
std::vector<lambdaArgument> args, const std::string &circuitName) {
|
||||
std::vector<mlir::concretelang::LambdaArgument *> argsRef;
|
||||
for (auto i = 0u; i < args.size(); i++) {
|
||||
argsRef.push_back(args[i].ptr.get());
|
||||
}
|
||||
return encrypt_arguments(clientParameters, keySet, argsRef);
|
||||
return encrypt_arguments(clientParameters, keySet, argsRef,
|
||||
circuitName);
|
||||
})
|
||||
.def_static(
|
||||
"decrypt_result",
|
||||
[](::concretelang::clientlib::ClientParameters clientParameters,
|
||||
::concretelang::clientlib::KeySet &keySet,
|
||||
::concretelang::clientlib::PublicResult &publicResult) {
|
||||
return decrypt_result(clientParameters, keySet, publicResult);
|
||||
::concretelang::clientlib::PublicResult &publicResult,
|
||||
const std::string &circuitName) {
|
||||
return decrypt_result(clientParameters, keySet, publicResult,
|
||||
circuitName);
|
||||
});
|
||||
pybind11::class_<::concretelang::clientlib::KeySetCache>(m, "KeySetCache")
|
||||
.def(pybind11::init<std::string &>());
|
||||
@@ -1187,8 +1190,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.def_static(
|
||||
"create",
|
||||
[](::concretelang::clientlib::KeySet &keySet,
|
||||
::concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
return createValueExporter(keySet, clientParameters);
|
||||
::concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const std::string &circuitName) {
|
||||
return createValueExporter(keySet, clientParameters, circuitName);
|
||||
})
|
||||
.def("export_scalar",
|
||||
[](::concretelang::clientlib::ValueExporter &exporter,
|
||||
@@ -1233,8 +1237,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
m, "SimulatedValueExporter")
|
||||
.def_static(
|
||||
"create",
|
||||
[](::concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
return createSimulatedValueExporter(clientParameters);
|
||||
[](::concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const std::string &circuitName) {
|
||||
return createSimulatedValueExporter(clientParameters, circuitName);
|
||||
})
|
||||
.def("export_scalar",
|
||||
[](::concretelang::clientlib::SimulatedValueExporter &exporter,
|
||||
@@ -1279,8 +1284,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.def_static(
|
||||
"create",
|
||||
[](::concretelang::clientlib::KeySet &keySet,
|
||||
::concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
return createValueDecrypter(keySet, clientParameters);
|
||||
::concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const std::string &circuitName) {
|
||||
return createValueDecrypter(keySet, clientParameters, circuitName);
|
||||
})
|
||||
.def("decrypt",
|
||||
[](::concretelang::clientlib::ValueDecrypter &decrypter,
|
||||
@@ -1303,8 +1309,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
m, "SimulatedValueDecrypter")
|
||||
.def_static(
|
||||
"create",
|
||||
[](::concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
return createSimulatedValueDecrypter(clientParameters);
|
||||
[](::concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const std::string &circuitName) {
|
||||
return createSimulatedValueDecrypter(clientParameters, circuitName);
|
||||
})
|
||||
.def("decrypt",
|
||||
[](::concretelang::clientlib::SimulatedValueDecrypter &decrypter,
|
||||
|
||||
@@ -21,7 +21,7 @@ from .compilation_options import CompilationOptions, Encoding
|
||||
from .compilation_context import CompilationContext
|
||||
from .key_set_cache import KeySetCache
|
||||
from .client_parameters import ClientParameters
|
||||
from .compilation_feedback import CompilationFeedback
|
||||
from .compilation_feedback import ProgramCompilationFeedback, CircuitCompilationFeedback
|
||||
from .key_set import KeySet
|
||||
from .public_result import PublicResult
|
||||
from .public_arguments import PublicArguments
|
||||
|
||||
@@ -116,6 +116,7 @@ class ClientSupport(WrapperCpp):
|
||||
client_parameters: ClientParameters,
|
||||
keyset: KeySet,
|
||||
args: List[Union[int, np.ndarray]],
|
||||
circuit_name: str = "main",
|
||||
) -> PublicArguments:
|
||||
"""Prepare arguments for encrypted computation.
|
||||
|
||||
@@ -126,10 +127,12 @@ class ClientSupport(WrapperCpp):
|
||||
client_parameters (ClientParameters): client parameters specification
|
||||
keyset (KeySet): keyset used to encrypt arguments that require encryption
|
||||
args (List[Union[int, np.ndarray]]): list of scalar or tensor arguments
|
||||
circuit_name(str): the name of the circuit for which to encrypt
|
||||
|
||||
Raises:
|
||||
TypeError: if client_parameters is not of type ClientParameters
|
||||
TypeError: if keyset is not of type KeySet
|
||||
TypeError: if circuit_name is not of type str
|
||||
|
||||
Returns:
|
||||
PublicArguments: public arguments for execution
|
||||
@@ -140,6 +143,10 @@ class ClientSupport(WrapperCpp):
|
||||
)
|
||||
if not isinstance(keyset, KeySet):
|
||||
raise TypeError(f"keyset must be of type KeySet, not {type(keyset)}")
|
||||
if not isinstance(circuit_name, str):
|
||||
raise TypeError(
|
||||
f"circuit_name must be of type str, not {type(circuit_name)}"
|
||||
)
|
||||
|
||||
signs = client_parameters.input_signs()
|
||||
if len(signs) != len(args):
|
||||
@@ -156,6 +163,7 @@ class ClientSupport(WrapperCpp):
|
||||
client_parameters.cpp(),
|
||||
keyset.cpp(),
|
||||
[arg.cpp() for arg in lambda_arguments],
|
||||
circuit_name,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -164,17 +172,20 @@ class ClientSupport(WrapperCpp):
|
||||
client_parameters: ClientParameters,
|
||||
keyset: KeySet,
|
||||
public_result: PublicResult,
|
||||
circuit_name: str = "main",
|
||||
) -> Union[int, np.ndarray]:
|
||||
"""Decrypt a public result using the keyset.
|
||||
|
||||
Args:
|
||||
client_parameters (ClientParameters): client parameters for decryption
|
||||
keyset (KeySet): keyset used for decryption
|
||||
public_result: public result to decrypt
|
||||
public_result (PublicResult): public result to decrypt
|
||||
circuit_name (str): name of the circuit for which to decrypt
|
||||
|
||||
Raises:
|
||||
TypeError: if keyset is not of type KeySet
|
||||
TypeError: if public_result is not of type PublicResult
|
||||
TypeError: if circuit_name is not of type str
|
||||
RuntimeError: if the result is of an unknown type
|
||||
|
||||
Returns:
|
||||
@@ -186,8 +197,12 @@ class ClientSupport(WrapperCpp):
|
||||
raise TypeError(
|
||||
f"public_result must be of type PublicResult, not {type(public_result)}"
|
||||
)
|
||||
if not isinstance(circuit_name, str):
|
||||
raise TypeError(
|
||||
f"circuit_name must be of type str, not {type(circuit_name)}"
|
||||
)
|
||||
results = _ClientSupport.decrypt_result(
|
||||
client_parameters.cpp(), keyset.cpp(), public_result.cpp()
|
||||
client_parameters.cpp(), keyset.cpp(), public_result.cpp(), circuit_name
|
||||
)
|
||||
|
||||
def process_result(result):
|
||||
|
||||
@@ -8,7 +8,8 @@ from typing import Dict, Set
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error,too-many-instance-attributes
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
CompilationFeedback as _CompilationFeedback,
|
||||
ProgramCompilationFeedback as _ProgramCompilationFeedback,
|
||||
CircuitCompilationFeedback as _CircuitCompilationFeedback,
|
||||
KeyType,
|
||||
PrimitiveOperation,
|
||||
)
|
||||
@@ -39,38 +40,36 @@ def tag_from_location(location):
|
||||
return tag
|
||||
|
||||
|
||||
class CompilationFeedback(WrapperCpp):
|
||||
"""CompilationFeedback is a set of hint computed by the compiler engine."""
|
||||
class CircuitCompilationFeedback(WrapperCpp):
|
||||
"""CircuitCompilationFeedback is a set of hint computed by the compiler engine for a circuit."""
|
||||
|
||||
def __init__(self, compilation_feedback: _CompilationFeedback):
|
||||
def __init__(self, circuit_compilation_feedback: _CircuitCompilationFeedback):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
compilation_feeback (_CompilationFeedback): object to wrap
|
||||
circuit_compilation_feeback (_CircuitCompilationFeedback): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if compilation_feedback is not of type _CompilationFeedback
|
||||
TypeError: if circuit_compilation_feedback is not of type _CircuitCompilationFeedback
|
||||
"""
|
||||
if not isinstance(compilation_feedback, _CompilationFeedback):
|
||||
if not isinstance(circuit_compilation_feedback, _CircuitCompilationFeedback):
|
||||
raise TypeError(
|
||||
f"compilation_feedback must be of type _CompilationFeedback, not {type(compilation_feedback)}"
|
||||
"circuit_compilation_feedback must be of type "
|
||||
f"_CircuitCompilationFeedback, not {type(circuit_compilation_feedback)}"
|
||||
)
|
||||
|
||||
self.complexity = compilation_feedback.complexity
|
||||
self.p_error = compilation_feedback.p_error
|
||||
self.global_p_error = compilation_feedback.global_p_error
|
||||
self.total_secret_keys_size = compilation_feedback.total_secret_keys_size
|
||||
self.total_bootstrap_keys_size = compilation_feedback.total_bootstrap_keys_size
|
||||
self.total_keyswitch_keys_size = compilation_feedback.total_keyswitch_keys_size
|
||||
self.total_inputs_size = compilation_feedback.total_inputs_size
|
||||
self.total_output_size = compilation_feedback.total_output_size
|
||||
self.name = circuit_compilation_feedback.name
|
||||
self.total_inputs_size = circuit_compilation_feedback.total_inputs_size
|
||||
self.total_output_size = circuit_compilation_feedback.total_output_size
|
||||
self.crt_decompositions_of_outputs = (
|
||||
compilation_feedback.crt_decompositions_of_outputs
|
||||
circuit_compilation_feedback.crt_decompositions_of_outputs
|
||||
)
|
||||
self.statistics = circuit_compilation_feedback.statistics
|
||||
self.memory_usage_per_location = (
|
||||
circuit_compilation_feedback.memory_usage_per_location
|
||||
)
|
||||
self.statistics = compilation_feedback.statistics
|
||||
self.memory_usage_per_location = compilation_feedback.memory_usage_per_location
|
||||
|
||||
super().__init__(compilation_feedback)
|
||||
super().__init__(circuit_compilation_feedback)
|
||||
|
||||
def count(self, *, operations: Set[PrimitiveOperation]) -> int:
|
||||
"""
|
||||
@@ -216,3 +215,68 @@ class CompilationFeedback(WrapperCpp):
|
||||
result[current_tag][parameter] += statistic.count
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ProgramCompilationFeedback(WrapperCpp):
|
||||
"""CompilationFeedback is a set of hint computed by the compiler engine."""
|
||||
|
||||
def __init__(self, program_compilation_feedback: _ProgramCompilationFeedback):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
compilation_feeback (_CompilationFeedback): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if program_compilation_feedback is not of type _CompilationFeedback
|
||||
"""
|
||||
if not isinstance(program_compilation_feedback, _ProgramCompilationFeedback):
|
||||
raise TypeError(
|
||||
"program_compilation_feedback must be of type "
|
||||
f"_CompilationFeedback, not {type(program_compilation_feedback)}"
|
||||
)
|
||||
|
||||
self.complexity = program_compilation_feedback.complexity
|
||||
self.p_error = program_compilation_feedback.p_error
|
||||
self.global_p_error = program_compilation_feedback.global_p_error
|
||||
self.total_secret_keys_size = (
|
||||
program_compilation_feedback.total_secret_keys_size
|
||||
)
|
||||
self.total_bootstrap_keys_size = (
|
||||
program_compilation_feedback.total_bootstrap_keys_size
|
||||
)
|
||||
self.total_keyswitch_keys_size = (
|
||||
program_compilation_feedback.total_keyswitch_keys_size
|
||||
)
|
||||
self.circuit_feedbacks = [
|
||||
CircuitCompilationFeedback(c)
|
||||
for c in program_compilation_feedback.circuit_feedbacks
|
||||
]
|
||||
|
||||
super().__init__(program_compilation_feedback)
|
||||
|
||||
def circuit(self, circuit_name: str) -> CircuitCompilationFeedback:
|
||||
"""
|
||||
Returns the feedback for the circuit circuit_name.
|
||||
|
||||
Args:
|
||||
circuit_name (str):
|
||||
the name of the circuit.
|
||||
|
||||
Returns:
|
||||
CircuitCompilationFeedback:
|
||||
the feedback for the circuit.
|
||||
|
||||
Raises:
|
||||
TypeError: if the circuit_name is not a string
|
||||
ValueError: if there is no circuit with name circuit_name
|
||||
"""
|
||||
if not isinstance(circuit_name, str):
|
||||
raise TypeError(
|
||||
f"circuit_name must be of type str, not {type(circuit_name)}"
|
||||
)
|
||||
|
||||
for circuit in self.circuit_feedbacks:
|
||||
if circuit.name == circuit_name:
|
||||
return circuit
|
||||
|
||||
raise ValueError(f"no circuit with name {circuit_name} found")
|
||||
|
||||
@@ -43,11 +43,11 @@ class CompilationOptions(WrapperCpp):
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(function_name="main", backend=_Backend.CPU) -> "CompilationOptions":
|
||||
def new(backend=_Backend.CPU) -> "CompilationOptions":
|
||||
"""Build a CompilationOptions.
|
||||
|
||||
Args:
|
||||
function_name (str, optional): name of the entrypoint function. Defaults to "main".
|
||||
backend (_Backend): backend to use.
|
||||
|
||||
Raises:
|
||||
TypeError: if function_name is not an str
|
||||
@@ -55,15 +55,9 @@ class CompilationOptions(WrapperCpp):
|
||||
Returns:
|
||||
CompilationOptions
|
||||
"""
|
||||
if not isinstance(function_name, str):
|
||||
raise TypeError(
|
||||
f"function_name must be of type str not {type(function_name)}"
|
||||
)
|
||||
if not isinstance(backend, _Backend):
|
||||
raise TypeError(
|
||||
f"backend must be of type Backend not {type(function_name)}"
|
||||
)
|
||||
return CompilationOptions.wrap(_CompilationOptions(function_name, backend))
|
||||
raise TypeError(f"backend must be of type Backend not {type(backend)}")
|
||||
return CompilationOptions.wrap(_CompilationOptions(backend))
|
||||
|
||||
# pylint: enable=arguments-differ
|
||||
|
||||
|
||||
@@ -33,16 +33,14 @@ class LibraryCompilationResult(WrapperCpp):
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(output_dir_path: str, func_name: str) -> "LibraryCompilationResult":
|
||||
"""Build a LibraryCompilationResult at output_dir_path, with func_name as entrypoint.
|
||||
def new(output_dir_path: str) -> "LibraryCompilationResult":
|
||||
"""Build a LibraryCompilationResult at output_dir_path.
|
||||
|
||||
Args:
|
||||
output_dir_path (str): path to the compilation artifacts
|
||||
func_name (str): entrypoint function name
|
||||
|
||||
Raises:
|
||||
TypeError: if output_dir_path is not of type str
|
||||
TypeError: if func_name is not of type str
|
||||
|
||||
Returns:
|
||||
LibraryCompilationResult
|
||||
@@ -51,10 +49,6 @@ class LibraryCompilationResult(WrapperCpp):
|
||||
raise TypeError(
|
||||
f"output_dir_path must be of type str, not {type(output_dir_path)}"
|
||||
)
|
||||
if not isinstance(func_name, str):
|
||||
raise TypeError(f"func_name must be of type str, not {type(func_name)}")
|
||||
return LibraryCompilationResult.wrap(
|
||||
_LibraryCompilationResult(output_dir_path, func_name)
|
||||
)
|
||||
return LibraryCompilationResult.wrap(_LibraryCompilationResult(output_dir_path))
|
||||
|
||||
# pylint: enable=arguments-differ
|
||||
|
||||
@@ -23,7 +23,7 @@ from .public_arguments import PublicArguments
|
||||
from .library_lambda import LibraryLambda
|
||||
from .public_result import PublicResult
|
||||
from .client_parameters import ClientParameters
|
||||
from .compilation_feedback import CompilationFeedback
|
||||
from .compilation_feedback import ProgramCompilationFeedback
|
||||
from .wrapper import WrapperCpp
|
||||
from .utils import lookup_runtime_lib
|
||||
from .evaluation_keys import EvaluationKeys
|
||||
@@ -132,7 +132,7 @@ class LibrarySupport(WrapperCpp):
|
||||
def compile(
|
||||
self,
|
||||
mlir_program: Union[str, MlirModule],
|
||||
options: CompilationOptions = CompilationOptions.new("main"),
|
||||
options: CompilationOptions = CompilationOptions.new(),
|
||||
compilation_context: Optional[CompilationContext] = None,
|
||||
) -> LibraryCompilationResult:
|
||||
"""Compile an MLIR program using Concrete dialects into a library.
|
||||
@@ -178,18 +178,13 @@ class LibrarySupport(WrapperCpp):
|
||||
self.cpp().compile(mlir_program, options.cpp())
|
||||
)
|
||||
|
||||
def reload(self, func_name: str = "main") -> LibraryCompilationResult:
|
||||
def reload(self) -> LibraryCompilationResult:
|
||||
"""Reload the library compilation result from the output_dir_path.
|
||||
|
||||
Args:
|
||||
func_name: entrypoint function name
|
||||
|
||||
Returns:
|
||||
LibraryCompilationResult: loaded library
|
||||
"""
|
||||
if not isinstance(func_name, str):
|
||||
raise TypeError(f"func_name must be of type str, not {type(func_name)}")
|
||||
return LibraryCompilationResult.new(self.output_dir_path, func_name)
|
||||
return LibraryCompilationResult.new(self.output_dir_path)
|
||||
|
||||
def load_client_parameters(
|
||||
self, library_compilation_result: LibraryCompilationResult
|
||||
@@ -217,7 +212,7 @@ class LibrarySupport(WrapperCpp):
|
||||
|
||||
def load_compilation_feedback(
|
||||
self, compilation_result: LibraryCompilationResult
|
||||
) -> CompilationFeedback:
|
||||
) -> ProgramCompilationFeedback:
|
||||
"""Load the compilation feedback from the compilation result.
|
||||
|
||||
Args:
|
||||
@@ -227,13 +222,13 @@ class LibrarySupport(WrapperCpp):
|
||||
TypeError: if compilation_result is not of type LibraryCompilationResult
|
||||
|
||||
Returns:
|
||||
CompilationFeedback: the compilation feedback for the compiled program
|
||||
ProgramCompilationFeedback: the compilation feedback for the compiled program
|
||||
"""
|
||||
if not isinstance(compilation_result, LibraryCompilationResult):
|
||||
raise TypeError(
|
||||
f"compilation_result must be of type LibraryCompilationResult, not {type(compilation_result)}"
|
||||
)
|
||||
return CompilationFeedback.wrap(
|
||||
return ProgramCompilationFeedback.wrap(
|
||||
self.cpp().load_compilation_feedback(compilation_result.cpp())
|
||||
)
|
||||
|
||||
@@ -241,14 +236,18 @@ class LibrarySupport(WrapperCpp):
|
||||
self,
|
||||
library_compilation_result: LibraryCompilationResult,
|
||||
simulation: bool,
|
||||
circuit_name: str = "main",
|
||||
) -> LibraryLambda:
|
||||
"""Load the server lambda from the library compilation result.
|
||||
"""Load the server lambda for a given circuit from the library compilation result.
|
||||
|
||||
Args:
|
||||
library_compilation_result (LibraryCompilationResult): compilation result of the library
|
||||
simulation (bool): use simulation for execution
|
||||
circuit_name (str): name of the circuit to be loaded
|
||||
|
||||
Raises:
|
||||
TypeError: if library_compilation_result is not of type LibraryCompilationResult
|
||||
TypeError: if library_compilation_result is not of type LibraryCompilationResult, if
|
||||
circuit_name is not of type str or
|
||||
|
||||
Returns:
|
||||
LibraryLambda: executable reference to the library
|
||||
@@ -258,8 +257,18 @@ class LibrarySupport(WrapperCpp):
|
||||
f"library_compilation_result must be of type LibraryCompilationResult, not "
|
||||
f"{type(library_compilation_result)}"
|
||||
)
|
||||
if not isinstance(circuit_name, str):
|
||||
raise TypeError(
|
||||
f"circuit_name must be of type str, not " f"{type(circuit_name)}"
|
||||
)
|
||||
if not isinstance(simulation, bool):
|
||||
raise TypeError(
|
||||
f"simulation must be of type bool, not " f"{type(simulation)}"
|
||||
)
|
||||
return LibraryLambda.wrap(
|
||||
self.cpp().load_server_lambda(library_compilation_result.cpp(), simulation)
|
||||
self.cpp().load_server_lambda(
|
||||
library_compilation_result.cpp(), circuit_name, simulation
|
||||
)
|
||||
)
|
||||
|
||||
def server_call(
|
||||
|
||||
@@ -41,12 +41,12 @@ class SimulatedValueDecrypter(WrapperCpp):
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(client_parameters: ClientParameters):
|
||||
def new(client_parameters: ClientParameters, circuit_name: str = "main"):
|
||||
"""
|
||||
Create a value decrypter.
|
||||
"""
|
||||
return SimulatedValueDecrypter(
|
||||
_SimulatedValueDecrypter.create(client_parameters.cpp())
|
||||
_SimulatedValueDecrypter.create(client_parameters.cpp(), circuit_name)
|
||||
)
|
||||
|
||||
def decrypt(self, position: int, value: Value) -> Union[int, np.ndarray]:
|
||||
|
||||
@@ -40,12 +40,14 @@ class SimulatedValueExporter(WrapperCpp):
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(client_parameters: ClientParameters) -> "SimulatedValueExporter":
|
||||
def new(
|
||||
client_parameters: ClientParameters, circuitName: str = "main"
|
||||
) -> "SimulatedValueExporter":
|
||||
"""
|
||||
Create a value exporter.
|
||||
"""
|
||||
return SimulatedValueExporter(
|
||||
_SimulatedValueExporter.create(client_parameters.cpp())
|
||||
_SimulatedValueExporter.create(client_parameters.cpp(), circuitName)
|
||||
)
|
||||
|
||||
def export_scalar(self, position: int, value: int) -> Value:
|
||||
|
||||
@@ -42,12 +42,14 @@ class ValueDecrypter(WrapperCpp):
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(keyset: KeySet, client_parameters: ClientParameters):
|
||||
def new(
|
||||
keyset: KeySet, client_parameters: ClientParameters, circuit_name: str = "main"
|
||||
):
|
||||
"""
|
||||
Create a value decrypter.
|
||||
"""
|
||||
return ValueDecrypter(
|
||||
_ValueDecrypter.create(keyset.cpp(), client_parameters.cpp())
|
||||
_ValueDecrypter.create(keyset.cpp(), client_parameters.cpp(), circuit_name)
|
||||
)
|
||||
|
||||
def decrypt(self, position: int, value: Value) -> Union[int, np.ndarray]:
|
||||
|
||||
@@ -41,12 +41,14 @@ class ValueExporter(WrapperCpp):
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(keyset: KeySet, client_parameters: ClientParameters) -> "ValueExporter":
|
||||
def new(
|
||||
keyset: KeySet, client_parameters: ClientParameters, circuit_name: str = "main"
|
||||
) -> "ValueExporter":
|
||||
"""
|
||||
Create a value exporter.
|
||||
"""
|
||||
return ValueExporter(
|
||||
_ValueExporter.create(keyset.cpp(), client_parameters.cpp())
|
||||
_ValueExporter.create(keyset.cpp(), client_parameters.cpp(), circuit_name)
|
||||
)
|
||||
|
||||
def export_scalar(self, position: int, value: int) -> Value:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <concretelang/Dialect/Concrete/IR/ConcreteOps.h>
|
||||
#include <concretelang/Support/logging.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/SCF/IR/SCF.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
@@ -66,34 +67,48 @@ namespace Concrete {
|
||||
struct MemoryUsagePass
|
||||
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>> {
|
||||
|
||||
CompilationFeedback &feedback;
|
||||
ProgramCompilationFeedback &feedback;
|
||||
CircuitCompilationFeedback *circuitFeedback;
|
||||
|
||||
MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {};
|
||||
MemoryUsagePass(ProgramCompilationFeedback &feedback)
|
||||
: feedback{feedback}, circuitFeedback{nullptr} {};
|
||||
|
||||
void runOnOperation() override {
|
||||
WalkResult walk =
|
||||
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
|
||||
if (stage.isBeforeAllRegions()) {
|
||||
std::optional<StringError> error = this->enter(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
auto module = getOperation();
|
||||
auto funcs = module.getOps<mlir::func::FuncOp>();
|
||||
for (CircuitCompilationFeedback &circuitFeedback :
|
||||
feedback.circuitFeedbacks) {
|
||||
auto funcOp = llvm::find_if(funcs, [&](mlir::func::FuncOp op) {
|
||||
return op.getName() == circuitFeedback.name;
|
||||
});
|
||||
assert(funcOp != funcs.end());
|
||||
this->circuitFeedback = &circuitFeedback;
|
||||
|
||||
WalkResult walk =
|
||||
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
|
||||
if (stage.isBeforeAllRegions()) {
|
||||
std::optional<StringError> error = this->enter(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (stage.isAfterAllRegions()) {
|
||||
std::optional<StringError> error = this->exit(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
if (stage.isAfterAllRegions()) {
|
||||
std::optional<StringError> error = this->exit(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (walk.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
if (walk.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,7 +187,7 @@ struct MemoryUsagePass
|
||||
// element_size
|
||||
auto memoryUsage = numberOfAlloc * maybeBufferSize.value();
|
||||
|
||||
pass.feedback.memoryUsagePerLoc[location] += memoryUsage;
|
||||
pass.circuitFeedback->memoryUsagePerLoc[location] += memoryUsage;
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
@@ -218,7 +233,7 @@ struct MemoryUsagePass
|
||||
}
|
||||
auto bufferSize = maybeBufferSize.value();
|
||||
|
||||
pass.feedback.memoryUsagePerLoc[location] += bufferSize;
|
||||
pass.circuitFeedback->memoryUsagePerLoc[location] += bufferSize;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,7 +248,7 @@ struct MemoryUsagePass
|
||||
} // namespace Concrete
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createMemoryUsagePass(CompilationFeedback &feedback) {
|
||||
createMemoryUsagePass(ProgramCompilationFeedback &feedback) {
|
||||
return std::make_unique<Concrete::MemoryUsagePass>(feedback);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "concretelang/Support/CompilationFeedback.h"
|
||||
#include <concretelang/Analysis/Utils.h>
|
||||
#include <concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h>
|
||||
|
||||
@@ -34,35 +35,48 @@ namespace TFHE {
|
||||
struct ExtractTFHEStatisticsPass
|
||||
: public PassWrapper<ExtractTFHEStatisticsPass, OperationPass<ModuleOp>> {
|
||||
|
||||
CompilationFeedback &feedback;
|
||||
ProgramCompilationFeedback &feedback;
|
||||
CircuitCompilationFeedback *circuitFeedback;
|
||||
|
||||
ExtractTFHEStatisticsPass(CompilationFeedback &feedback)
|
||||
: feedback{feedback} {};
|
||||
ExtractTFHEStatisticsPass(ProgramCompilationFeedback &feedback)
|
||||
: feedback{feedback}, circuitFeedback{nullptr} {};
|
||||
|
||||
void runOnOperation() override {
|
||||
WalkResult walk =
|
||||
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
|
||||
if (stage.isBeforeAllRegions()) {
|
||||
std::optional<StringError> error = this->enter(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
auto module = getOperation();
|
||||
auto funcs = module.getOps<mlir::func::FuncOp>();
|
||||
for (CircuitCompilationFeedback &circuitFeedback :
|
||||
feedback.circuitFeedbacks) {
|
||||
auto funcOp = llvm::find_if(funcs, [&](mlir::func::FuncOp op) {
|
||||
return op.getName() == circuitFeedback.name;
|
||||
});
|
||||
assert(funcOp != funcs.end());
|
||||
this->circuitFeedback = &circuitFeedback;
|
||||
|
||||
WalkResult walk =
|
||||
(*funcOp)->walk([&](Operation *op, const WalkStage &stage) {
|
||||
if (stage.isBeforeAllRegions()) {
|
||||
std::optional<StringError> error = this->enter(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (stage.isAfterAllRegions()) {
|
||||
std::optional<StringError> error = this->exit(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
if (stage.isAfterAllRegions()) {
|
||||
std::optional<StringError> error = this->exit(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (walk.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
if (walk.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,14 +132,14 @@ struct ExtractTFHEStatisticsPass
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_ADDITION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
@@ -145,14 +159,14 @@ struct ExtractTFHEStatisticsPass
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::CLEAR_ADDITION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
@@ -172,14 +186,14 @@ struct ExtractTFHEStatisticsPass
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::PBS;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::BOOTSTRAP, (int64_t)bsk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
@@ -199,14 +213,14 @@ struct ExtractTFHEStatisticsPass
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::KEY_SWITCH;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::KEY_SWITCH, (int64_t)ksk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
@@ -226,14 +240,14 @@ struct ExtractTFHEStatisticsPass
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
@@ -253,14 +267,14 @@ struct ExtractTFHEStatisticsPass
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
@@ -279,17 +293,17 @@ struct ExtractTFHEStatisticsPass
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
// clear - encrypted = clear + neg(encrypted)
|
||||
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
@@ -297,7 +311,7 @@ struct ExtractTFHEStatisticsPass
|
||||
});
|
||||
|
||||
operation = PrimitiveOperation::CLEAR_ADDITION;
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
@@ -319,20 +333,20 @@ struct ExtractTFHEStatisticsPass
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::WOP_PBS;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::BOOTSTRAP, (int64_t)bsk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
key = std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
|
||||
key = std::make_pair(KeyType::KEY_SWITCH, (int64_t)ksk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (size_t)pksk.getIndex());
|
||||
key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (int64_t)pksk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
@@ -342,13 +356,13 @@ struct ExtractTFHEStatisticsPass
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
size_t iterations = 1;
|
||||
int64_t iterations = 1;
|
||||
};
|
||||
|
||||
} // namespace TFHE
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createStatisticExtractionPass(CompilationFeedback &feedback) {
|
||||
createStatisticExtractionPass(ProgramCompilationFeedback &feedback) {
|
||||
return std::make_unique<TFHE::ExtractTFHEStatisticsPass>(feedback);
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,9 @@
|
||||
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Support/CompilationFeedback.h"
|
||||
|
||||
using concretelang::protocol::Message;
|
||||
@@ -13,60 +15,11 @@ using concretelang::protocol::Message;
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
void CompilationFeedback::fillFromProgramInfo(
|
||||
const Message<concreteprotocol::ProgramInfo> &programInfo) {
|
||||
auto params = programInfo.asReader();
|
||||
|
||||
// Compute the size of secret keys
|
||||
totalSecretKeysSize = 0;
|
||||
for (auto skInfo : params.getKeyset().getLweSecretKeys()) {
|
||||
assert(skInfo.getParams().getIntegerPrecision() % 8 == 0);
|
||||
auto byteSize = skInfo.getParams().getIntegerPrecision() / 8;
|
||||
totalSecretKeysSize += skInfo.getParams().getLweDimension() * byteSize;
|
||||
}
|
||||
// Compute the boostrap keys size
|
||||
totalBootstrapKeysSize = 0;
|
||||
for (auto bskInfo : params.getKeyset().getLweBootstrapKeys()) {
|
||||
assert(bskInfo.getInputId() <
|
||||
(uint32_t)params.getKeyset().getLweSecretKeys().size());
|
||||
auto inputKeyInfo =
|
||||
params.getKeyset().getLweSecretKeys()[bskInfo.getInputId()];
|
||||
assert(bskInfo.getOutputId() <
|
||||
(uint32_t)params.getKeyset().getLweSecretKeys().size());
|
||||
auto outputKeyInfo =
|
||||
params.getKeyset().getLweSecretKeys()[bskInfo.getOutputId()];
|
||||
assert(bskInfo.getParams().getIntegerPrecision() % 8 == 0);
|
||||
auto byteSize = bskInfo.getParams().getIntegerPrecision() % 8;
|
||||
auto inputLweSize = inputKeyInfo.getParams().getLweDimension() + 1;
|
||||
auto outputLweSize = outputKeyInfo.getParams().getLweDimension() + 1;
|
||||
auto level = bskInfo.getParams().getLevelCount();
|
||||
auto glweDimension = bskInfo.getParams().getGlweDimension();
|
||||
totalBootstrapKeysSize += inputLweSize * level * (glweDimension + 1) *
|
||||
(glweDimension + 1) * outputLweSize * byteSize;
|
||||
}
|
||||
// Compute the keyswitch keys size
|
||||
totalKeyswitchKeysSize = 0;
|
||||
for (auto kskInfo : params.getKeyset().getLweKeyswitchKeys()) {
|
||||
assert(kskInfo.getInputId() <
|
||||
(uint32_t)params.getKeyset().getLweSecretKeys().size());
|
||||
auto inputKeyInfo =
|
||||
params.getKeyset().getLweSecretKeys()[kskInfo.getInputId()];
|
||||
assert(kskInfo.getOutputId() <
|
||||
(uint32_t)params.getKeyset().getLweSecretKeys().size());
|
||||
auto outputKeyInfo =
|
||||
params.getKeyset().getLweSecretKeys()[kskInfo.getOutputId()];
|
||||
assert(kskInfo.getParams().getIntegerPrecision() % 8 == 0);
|
||||
auto byteSize = kskInfo.getParams().getIntegerPrecision() % 8;
|
||||
auto inputLweSize = inputKeyInfo.getParams().getLweDimension() + 1;
|
||||
auto outputLweSize = outputKeyInfo.getParams().getLweDimension() + 1;
|
||||
auto level = kskInfo.getParams().getLevelCount();
|
||||
totalKeyswitchKeysSize += level * inputLweSize * outputLweSize * byteSize;
|
||||
}
|
||||
auto circuitInfo = params.getCircuits()[0];
|
||||
void CircuitCompilationFeedback::fillFromCircuitInfo(
|
||||
concreteprotocol::CircuitInfo::Reader circuitInfo) {
|
||||
auto computeGateSize =
|
||||
[&](const Message<concreteprotocol::GateInfo> &gateInfo) {
|
||||
unsigned int nElements = 1;
|
||||
// TODO: CHANGE THAT ITS WRONG
|
||||
for (auto dimension :
|
||||
gateInfo.asReader().getRawInfo().getShape().getDimensions()) {
|
||||
nElements *= dimension;
|
||||
@@ -104,17 +57,77 @@ void CompilationFeedback::fillFromProgramInfo(
|
||||
}
|
||||
}
|
||||
}
|
||||
// Sets name
|
||||
name = circuitInfo.getName().cStr();
|
||||
}
|
||||
|
||||
outcome::checked<CompilationFeedback, StringError>
|
||||
CompilationFeedback::load(std::string jsonPath) {
|
||||
void ProgramCompilationFeedback::fillFromProgramInfo(
|
||||
const Message<concreteprotocol::ProgramInfo> &programInfo) {
|
||||
auto params = programInfo.asReader();
|
||||
|
||||
// Compute the size of secret keys
|
||||
totalSecretKeysSize = 0;
|
||||
for (auto skInfo : params.getKeyset().getLweSecretKeys()) {
|
||||
assert(skInfo.getParams().getIntegerPrecision() % 8 == 0);
|
||||
auto byteSize = skInfo.getParams().getIntegerPrecision() / 8;
|
||||
totalSecretKeysSize += skInfo.getParams().getLweDimension() * byteSize;
|
||||
}
|
||||
// Compute the boostrap keys size
|
||||
totalBootstrapKeysSize = 0;
|
||||
for (auto bskInfo : params.getKeyset().getLweBootstrapKeys()) {
|
||||
assert(bskInfo.getInputId() <
|
||||
(uint32_t)params.getKeyset().getLweSecretKeys().size());
|
||||
auto inputKeyInfo =
|
||||
params.getKeyset().getLweSecretKeys()[bskInfo.getInputId()];
|
||||
assert(bskInfo.getOutputId() <
|
||||
(uint32_t)params.getKeyset().getLweSecretKeys().size());
|
||||
auto outputKeyInfo =
|
||||
params.getKeyset().getLweSecretKeys()[bskInfo.getOutputId()];
|
||||
assert(bskInfo.getParams().getIntegerPrecision() % 8 == 0);
|
||||
auto byteSize = bskInfo.getParams().getIntegerPrecision() / 8;
|
||||
auto inputLweSize = inputKeyInfo.getParams().getLweDimension() + 1;
|
||||
auto outputLweSize = outputKeyInfo.getParams().getLweDimension() + 1;
|
||||
auto level = bskInfo.getParams().getLevelCount();
|
||||
auto glweDimension = bskInfo.getParams().getGlweDimension();
|
||||
totalBootstrapKeysSize += inputLweSize * level * (glweDimension + 1) *
|
||||
(glweDimension + 1) * outputLweSize * byteSize;
|
||||
}
|
||||
// Compute the keyswitch keys size
|
||||
totalKeyswitchKeysSize = 0;
|
||||
for (auto kskInfo : params.getKeyset().getLweKeyswitchKeys()) {
|
||||
assert(kskInfo.getInputId() <
|
||||
(uint32_t)params.getKeyset().getLweSecretKeys().size());
|
||||
auto inputKeyInfo =
|
||||
params.getKeyset().getLweSecretKeys()[kskInfo.getInputId()];
|
||||
assert(kskInfo.getOutputId() <
|
||||
(uint32_t)params.getKeyset().getLweSecretKeys().size());
|
||||
auto outputKeyInfo =
|
||||
params.getKeyset().getLweSecretKeys()[kskInfo.getOutputId()];
|
||||
assert(kskInfo.getParams().getIntegerPrecision() % 8 == 0);
|
||||
auto byteSize = kskInfo.getParams().getIntegerPrecision() / 8;
|
||||
auto inputLweSize = inputKeyInfo.getParams().getLweDimension() + 1;
|
||||
auto outputLweSize = outputKeyInfo.getParams().getLweDimension() + 1;
|
||||
auto level = kskInfo.getParams().getLevelCount();
|
||||
totalKeyswitchKeysSize += level * inputLweSize * outputLweSize * byteSize;
|
||||
}
|
||||
// Compute the circuit feedbacks
|
||||
for (auto circuitInfo : params.getCircuits()) {
|
||||
CircuitCompilationFeedback feedback;
|
||||
feedback.fillFromCircuitInfo(circuitInfo);
|
||||
circuitFeedbacks.push_back(feedback);
|
||||
}
|
||||
}
|
||||
|
||||
outcome::checked<ProgramCompilationFeedback, StringError>
|
||||
ProgramCompilationFeedback::load(std::string jsonPath) {
|
||||
std::ifstream file(jsonPath);
|
||||
std::string content((std::istreambuf_iterator<char>(file)),
|
||||
(std::istreambuf_iterator<char>()));
|
||||
if (file.fail()) {
|
||||
return StringError("Cannot read file: ") << jsonPath;
|
||||
}
|
||||
auto expectedCompFeedback = llvm::json::parse<CompilationFeedback>(content);
|
||||
auto expectedCompFeedback =
|
||||
llvm::json::parse<ProgramCompilationFeedback>(content);
|
||||
if (auto err = expectedCompFeedback.takeError()) {
|
||||
return StringError("Cannot open compilation feedback: ")
|
||||
<< llvm::toString(std::move(err)) << "\n"
|
||||
@@ -123,196 +136,228 @@ CompilationFeedback::load(std::string jsonPath) {
|
||||
return expectedCompFeedback.get();
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &v) {
|
||||
llvm::json::Object object{
|
||||
{"complexity", v.complexity},
|
||||
{"pError", v.pError},
|
||||
{"globalPError", v.globalPError},
|
||||
{"totalSecretKeysSize", v.totalSecretKeysSize},
|
||||
{"totalBootstrapKeysSize", v.totalBootstrapKeysSize},
|
||||
{"totalKeyswitchKeysSize", v.totalKeyswitchKeysSize},
|
||||
{"totalInputsSize", v.totalInputsSize},
|
||||
{"totalOutputsSize", v.totalOutputsSize},
|
||||
{"crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs},
|
||||
};
|
||||
|
||||
auto memoryUsageObject = llvm::json::Object();
|
||||
for (auto key : v.memoryUsagePerLoc) {
|
||||
memoryUsageObject.insert({key.first, key.second});
|
||||
llvm::json::Object
|
||||
memoryUsageToJson(const std::map<std::string, int64_t> &memoryUsagePerLoc) {
|
||||
auto object = llvm::json::Object();
|
||||
for (auto key : memoryUsagePerLoc) {
|
||||
object.insert({key.first, key.second});
|
||||
}
|
||||
object.insert({"memoryUsagePerLoc", std::move(memoryUsageObject)});
|
||||
|
||||
auto statisticsJson = llvm::json::Array();
|
||||
for (auto statistic : v.statistics) {
|
||||
auto statisticJson = llvm::json::Object();
|
||||
statisticJson.insert({"location", statistic.location});
|
||||
switch (statistic.operation) {
|
||||
case PrimitiveOperation::PBS:
|
||||
statisticJson.insert({"operation", "PBS"});
|
||||
break;
|
||||
case PrimitiveOperation::WOP_PBS:
|
||||
statisticJson.insert({"operation", "WOP_PBS"});
|
||||
break;
|
||||
case PrimitiveOperation::KEY_SWITCH:
|
||||
statisticJson.insert({"operation", "KEY_SWITCH"});
|
||||
break;
|
||||
case PrimitiveOperation::CLEAR_ADDITION:
|
||||
statisticJson.insert({"operation", "CLEAR_ADDITION"});
|
||||
break;
|
||||
case PrimitiveOperation::ENCRYPTED_ADDITION:
|
||||
statisticJson.insert({"operation", "ENCRYPTED_ADDITION"});
|
||||
break;
|
||||
case PrimitiveOperation::CLEAR_MULTIPLICATION:
|
||||
statisticJson.insert({"operation", "CLEAR_MULTIPLICATION"});
|
||||
break;
|
||||
case PrimitiveOperation::ENCRYPTED_NEGATION:
|
||||
statisticJson.insert({"operation", "ENCRYPTED_NEGATION"});
|
||||
break;
|
||||
}
|
||||
auto keysJson = llvm::json::Array();
|
||||
for (auto &key : statistic.keys) {
|
||||
KeyType type = key.first;
|
||||
size_t index = key.second;
|
||||
|
||||
auto keyJson = llvm::json::Array();
|
||||
switch (type) {
|
||||
case KeyType::SECRET:
|
||||
keyJson.push_back("SECRET");
|
||||
break;
|
||||
case KeyType::BOOTSTRAP:
|
||||
keyJson.push_back("BOOTSTRAP");
|
||||
break;
|
||||
case KeyType::KEY_SWITCH:
|
||||
keyJson.push_back("KEY_SWITCH");
|
||||
break;
|
||||
case KeyType::PACKING_KEY_SWITCH:
|
||||
keyJson.push_back("PACKING_KEY_SWITCH");
|
||||
break;
|
||||
}
|
||||
keyJson.push_back((int64_t)index);
|
||||
|
||||
keysJson.push_back(std::move(keyJson));
|
||||
}
|
||||
statisticJson.insert({"keys", std::move(keysJson)});
|
||||
statisticJson.insert({"count", (int64_t)statistic.count});
|
||||
|
||||
statisticsJson.push_back(std::move(statisticJson));
|
||||
}
|
||||
object.insert({"statistics", std::move(statisticsJson)});
|
||||
|
||||
return object;
|
||||
}
|
||||
|
||||
llvm::json::Object statisticToJson(const Statistic &statistic) {
|
||||
auto object = llvm::json::Object();
|
||||
object.insert({"location", statistic.location});
|
||||
object.insert({"count", statistic.count});
|
||||
switch (statistic.operation) {
|
||||
case PrimitiveOperation::PBS:
|
||||
object.insert({"operation", "PBS"});
|
||||
break;
|
||||
case PrimitiveOperation::WOP_PBS:
|
||||
object.insert({"operation", "WOP_PBS"});
|
||||
break;
|
||||
case PrimitiveOperation::KEY_SWITCH:
|
||||
object.insert({"operation", "KEY_SWITCH"});
|
||||
break;
|
||||
case PrimitiveOperation::CLEAR_ADDITION:
|
||||
object.insert({"operation", "CLEAR_ADDITION"});
|
||||
break;
|
||||
case PrimitiveOperation::ENCRYPTED_ADDITION:
|
||||
object.insert({"operation", "ENCRYPTED_ADDITION"});
|
||||
break;
|
||||
case PrimitiveOperation::CLEAR_MULTIPLICATION:
|
||||
object.insert({"operation", "CLEAR_MULTIPLICATION"});
|
||||
break;
|
||||
case PrimitiveOperation::ENCRYPTED_NEGATION:
|
||||
object.insert({"operation", "ENCRYPTED_NEGATION"});
|
||||
break;
|
||||
}
|
||||
auto keysJson = llvm::json::Array();
|
||||
for (auto &key : statistic.keys) {
|
||||
KeyType type = key.first;
|
||||
size_t index = key.second;
|
||||
|
||||
auto keyJson = llvm::json::Array();
|
||||
switch (type) {
|
||||
case KeyType::SECRET:
|
||||
keyJson.push_back("SECRET");
|
||||
break;
|
||||
case KeyType::BOOTSTRAP:
|
||||
keyJson.push_back("BOOTSTRAP");
|
||||
break;
|
||||
case KeyType::KEY_SWITCH:
|
||||
keyJson.push_back("KEY_SWITCH");
|
||||
break;
|
||||
case KeyType::PACKING_KEY_SWITCH:
|
||||
keyJson.push_back("PACKING_KEY_SWITCH");
|
||||
break;
|
||||
}
|
||||
keyJson.push_back((int64_t)index);
|
||||
|
||||
keysJson.push_back(std::move(keyJson));
|
||||
}
|
||||
object.insert({"keys", std::move(keysJson)});
|
||||
return object;
|
||||
}
|
||||
|
||||
llvm::json::Array statisticsToJson(const std::vector<Statistic> &statistics) {
|
||||
auto object = llvm::json::Array();
|
||||
for (auto statistic : statistics) {
|
||||
object.push_back(statisticToJson(statistic));
|
||||
}
|
||||
return object;
|
||||
}
|
||||
|
||||
llvm::json::Array crtDecompositionToJson(
|
||||
const std::vector<std::vector<int64_t>> &crtDecompositionsOfOutputs) {
|
||||
auto object = llvm::json::Array();
|
||||
for (auto crtDec : crtDecompositionsOfOutputs) {
|
||||
auto inner = llvm::json::Array();
|
||||
for (auto val : crtDec) {
|
||||
inner.push_back(val);
|
||||
}
|
||||
object.push_back(std::move(inner));
|
||||
}
|
||||
return object;
|
||||
}
|
||||
|
||||
llvm::json::Array circuitFeedbacksToJson(
|
||||
const std::vector<CircuitCompilationFeedback> &circuitFeedbacks) {
|
||||
auto object = llvm::json::Array();
|
||||
for (auto circuit : circuitFeedbacks) {
|
||||
llvm::json::Object circuitObject{
|
||||
{"name", circuit.name},
|
||||
{"totalInputsSize", circuit.totalInputsSize},
|
||||
{"totalOutputsSize", circuit.totalOutputsSize},
|
||||
{"crtDecompositionsOfOutputs",
|
||||
crtDecompositionToJson(circuit.crtDecompositionsOfOutputs)},
|
||||
{"statistics", statisticsToJson(circuit.statistics)},
|
||||
{"memoryUsagePerLoc", memoryUsageToJson(circuit.memoryUsagePerLoc)},
|
||||
};
|
||||
object.push_back(std::move(circuitObject));
|
||||
}
|
||||
return object;
|
||||
}
|
||||
|
||||
llvm::json::Value
|
||||
toJSON(const mlir::concretelang::ProgramCompilationFeedback &program) {
|
||||
llvm::json::Object programObject{
|
||||
{"complexity", program.complexity},
|
||||
{"pError", program.pError},
|
||||
{"globalPError", program.globalPError},
|
||||
{"totalSecretKeysSize", program.totalSecretKeysSize},
|
||||
{"totalBootstrapKeysSize", program.totalBootstrapKeysSize},
|
||||
{"totalKeyswitchKeysSize", program.totalKeyswitchKeysSize},
|
||||
{"circuitFeedbacks", circuitFeedbacksToJson(program.circuitFeedbacks)}};
|
||||
return programObject;
|
||||
}
|
||||
|
||||
template <typename K, typename V>
|
||||
bool fromJSON(const llvm::json::Value &j, std::pair<K, V> &v,
|
||||
llvm::json::Path p) {
|
||||
if (auto *array = j.getAsArray()) {
|
||||
if (!fromJSON((*array)[0], v.first, p.index(0)))
|
||||
return false;
|
||||
if (!fromJSON((*array)[1], v.second, p.index(1)))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
p.report("expected array");
|
||||
return false;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j,
|
||||
mlir::concretelang::CompilationFeedback &v, llvm::json::Path p) {
|
||||
mlir::concretelang::PrimitiveOperation &v, llvm::json::Path p) {
|
||||
if (auto operationString = j.getAsString()) {
|
||||
if (operationString == "PBS") {
|
||||
v = PrimitiveOperation::PBS;
|
||||
return true;
|
||||
} else if (operationString == "KEY_SWITCH") {
|
||||
v = PrimitiveOperation::KEY_SWITCH;
|
||||
return true;
|
||||
} else if (operationString == "WOP_PBS") {
|
||||
v = PrimitiveOperation::WOP_PBS;
|
||||
return true;
|
||||
} else if (operationString == "CLEAR_ADDITION") {
|
||||
v = PrimitiveOperation::CLEAR_ADDITION;
|
||||
return true;
|
||||
} else if (operationString == "ENCRYPTED_ADDITION") {
|
||||
v = PrimitiveOperation::ENCRYPTED_ADDITION;
|
||||
return true;
|
||||
} else if (operationString == "CLEAR_MULTIPLICATION") {
|
||||
v = PrimitiveOperation::CLEAR_MULTIPLICATION;
|
||||
return true;
|
||||
} else if (operationString == "ENCRYPTED_NEGATION") {
|
||||
v = PrimitiveOperation::ENCRYPTED_NEGATION;
|
||||
return true;
|
||||
} else {
|
||||
p.report("expected one of "
|
||||
"(PBS|KEY_SWITCH|WOP_PBS|CLEAR_ADDITION|ENCRYPTED_ADDITION|"
|
||||
"CLEAR_MULTIPLICATION|ENCRYPTED_NEGATION)");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
p.report("expected string");
|
||||
return false;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, mlir::concretelang::KeyType &v,
|
||||
llvm::json::Path p) {
|
||||
if (auto keyTypeString = j.getAsString()) {
|
||||
if (keyTypeString == "SECRET") {
|
||||
v = KeyType::SECRET;
|
||||
return true;
|
||||
} else if (keyTypeString == "BOOTSTRAP") {
|
||||
v = KeyType::BOOTSTRAP;
|
||||
return true;
|
||||
} else if (keyTypeString == "KEY_SWITCH") {
|
||||
v = KeyType::KEY_SWITCH;
|
||||
return true;
|
||||
} else if (keyTypeString == "PACKING_KEY_SWITCH") {
|
||||
v = KeyType::PACKING_KEY_SWITCH;
|
||||
return true;
|
||||
} else {
|
||||
p.report(
|
||||
"expected one of (SECRET|BOOTSTRAP|KEY_SWITCH|PACKING_KEY_SWITCH)");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
p.report("expected string");
|
||||
return false;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, mlir::concretelang::Statistic &v,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
|
||||
bool is_success =
|
||||
O && O.map("complexity", v.complexity) && O.map("pError", v.pError) &&
|
||||
O.map("globalPError", v.globalPError) &&
|
||||
O.map("totalSecretKeysSize", v.totalSecretKeysSize) &&
|
||||
O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) &&
|
||||
O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) &&
|
||||
O.map("totalInputsSize", v.totalInputsSize) &&
|
||||
O.map("totalOutputsSize", v.totalOutputsSize) &&
|
||||
O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs);
|
||||
return O && O.map("location", v.location) &&
|
||||
O.map("operation", v.operation) && O.map("operation", v.operation) &&
|
||||
O.map("keys", v.keys) && O.map("count", v.count);
|
||||
}
|
||||
|
||||
if (!is_success) {
|
||||
return false;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j,
|
||||
mlir::concretelang::CircuitCompilationFeedback &v,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("name", v.name) &&
|
||||
O.map("totalInputsSize", v.totalInputsSize) &&
|
||||
O.map("totalOutputsSize", v.totalOutputsSize) &&
|
||||
O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs) &&
|
||||
O.map("statistics", v.statistics) &&
|
||||
O.map("memoryUsagePerLoc", v.memoryUsagePerLoc);
|
||||
}
|
||||
|
||||
auto object = j.getAsObject();
|
||||
if (!object) {
|
||||
return false;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j,
|
||||
mlir::concretelang::ProgramCompilationFeedback &v,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
|
||||
auto memoryUsageObject = object->getObject("memoryUsagePerLoc");
|
||||
if (!memoryUsageObject) {
|
||||
return false;
|
||||
}
|
||||
for (auto entry : *memoryUsageObject) {
|
||||
auto loc = entry.getFirst().str();
|
||||
auto maybeUsage = entry.getSecond().getAsInteger();
|
||||
if (!maybeUsage.has_value()) {
|
||||
return false;
|
||||
}
|
||||
v.memoryUsagePerLoc[loc] = *maybeUsage;
|
||||
}
|
||||
|
||||
auto statistics = object->getArray("statistics");
|
||||
if (!statistics) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto statisticValue : *statistics) {
|
||||
auto statistic = statisticValue.getAsObject();
|
||||
if (!statistic) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto location = statistic->getString("location");
|
||||
auto operationStr = statistic->getString("operation");
|
||||
auto keysArray = statistic->getArray("keys");
|
||||
auto count = statistic->getInteger("count");
|
||||
|
||||
if (!operationStr || !location || !keysArray || !count) {
|
||||
return false;
|
||||
}
|
||||
|
||||
PrimitiveOperation operation;
|
||||
if (operationStr.value() == "PBS") {
|
||||
operation = PrimitiveOperation::PBS;
|
||||
} else if (operationStr.value() == "KEY_SWITCH") {
|
||||
operation = PrimitiveOperation::KEY_SWITCH;
|
||||
} else if (operationStr.value() == "WOP_PBS") {
|
||||
operation = PrimitiveOperation::WOP_PBS;
|
||||
} else if (operationStr.value() == "CLEAR_ADDITION") {
|
||||
operation = PrimitiveOperation::CLEAR_ADDITION;
|
||||
} else if (operationStr.value() == "ENCRYPTED_ADDITION") {
|
||||
operation = PrimitiveOperation::ENCRYPTED_ADDITION;
|
||||
} else if (operationStr.value() == "CLEAR_MULTIPLICATION") {
|
||||
operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
|
||||
} else if (operationStr.value() == "ENCRYPTED_NEGATION") {
|
||||
operation = PrimitiveOperation::ENCRYPTED_NEGATION;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
for (auto keyValue : *keysArray) {
|
||||
llvm::json::Array *keyArray = keyValue.getAsArray();
|
||||
if (!keyArray || keyArray->size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto typeStr = keyArray->front().getAsString();
|
||||
auto index = keyArray->back().getAsInteger();
|
||||
|
||||
if (!typeStr || !index) {
|
||||
return false;
|
||||
}
|
||||
|
||||
KeyType type;
|
||||
if (typeStr.value() == "SECRET") {
|
||||
type = KeyType::SECRET;
|
||||
} else if (typeStr.value() == "BOOTSTRAP") {
|
||||
type = KeyType::BOOTSTRAP;
|
||||
} else if (typeStr.value() == "KEY_SWITCH") {
|
||||
type = KeyType::KEY_SWITCH;
|
||||
} else if (typeStr.value() == "PACKING_KEY_SWITCH") {
|
||||
type = KeyType::PACKING_KEY_SWITCH;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
keys.push_back(std::make_pair(type, (size_t)*index));
|
||||
}
|
||||
|
||||
v.statistics.push_back(
|
||||
Statistic{location->str(), operation, keys, (uint64_t)*count});
|
||||
}
|
||||
|
||||
return true;
|
||||
return O && O.map("complexity", v.complexity) && O.map("pError", v.pError) &&
|
||||
O.map("globalPError", v.globalPError) &&
|
||||
O.map("totalSecretKeysSize", v.totalSecretKeysSize) &&
|
||||
O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) &&
|
||||
O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) &&
|
||||
O.map("circuitFeedbacks", v.circuitFeedbacks);
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Support/V0Parameters.h"
|
||||
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
@@ -166,19 +167,33 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) {
|
||||
if (descriptions->empty()) { // The pass has not been run
|
||||
return std::nullopt;
|
||||
}
|
||||
if (this->compilerOptions.mainFuncName.has_value()) {
|
||||
auto name = this->compilerOptions.mainFuncName.value();
|
||||
auto description = descriptions->find(name);
|
||||
if (description == descriptions->end()) {
|
||||
std::string names;
|
||||
return StreamStringError("Function not found, name='")
|
||||
<< name << "', cannot get optimizer description";
|
||||
}
|
||||
return std::move(description->second);
|
||||
if (descriptions->size() > 1 &&
|
||||
config.strategy !=
|
||||
mlir::concretelang::optimizer::V0) { // Multi circuits without V0
|
||||
return StreamStringError(
|
||||
"Multi-circuits is only supported for V0 optimization.");
|
||||
}
|
||||
if (descriptions->size() != 1) {
|
||||
llvm::errs() << "Several crypto parameters exists: the function need to be "
|
||||
"specified, taking the first one";
|
||||
if (descriptions->size() > 1) {
|
||||
auto iter = descriptions->begin();
|
||||
auto desc = std::move(iter->second);
|
||||
if (!desc.has_value()) {
|
||||
return StreamStringError("Expected description.");
|
||||
}
|
||||
if (!desc.value().dag.has_value()) {
|
||||
return StreamStringError("Expected dag in description.");
|
||||
}
|
||||
iter++;
|
||||
while (iter != descriptions->end()) {
|
||||
if (!iter->second.has_value()) {
|
||||
return StreamStringError("Expected description.");
|
||||
}
|
||||
if (!iter->second.value().dag.has_value()) {
|
||||
return StreamStringError("Expected dag in description.");
|
||||
}
|
||||
desc->dag.value()->concat(*iter->second.value().dag.value());
|
||||
iter++;
|
||||
}
|
||||
return std::move(desc);
|
||||
}
|
||||
return std::move(descriptions->begin()->second);
|
||||
}
|
||||
@@ -199,7 +214,7 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
|
||||
res.fheContext.emplace(
|
||||
mlir::concretelang::V0FHEContext{constraint, v0Params});
|
||||
|
||||
CompilationFeedback feedback;
|
||||
ProgramCompilationFeedback feedback;
|
||||
res.feedback.emplace(feedback);
|
||||
|
||||
return llvm::Error::success();
|
||||
@@ -213,7 +228,7 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
|
||||
if (!descr.get().has_value()) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
CompilationFeedback feedback;
|
||||
ProgramCompilationFeedback feedback;
|
||||
// Make sure to use the gpu constraint of the optimizer if we use gpu
|
||||
// backend.
|
||||
compilerOptions.optimizerConfig.use_gpu_constraints =
|
||||
@@ -322,9 +337,8 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,
|
||||
// on the `FHE` dialect.
|
||||
if ((this->generateProgramInfo || target == Target::LIBRARY) &&
|
||||
!options.encodings) {
|
||||
auto funcName = options.mainFuncName.value_or("main");
|
||||
auto encodingInfosOrErr =
|
||||
mlir::concretelang::encodings::getCircuitEncodings(funcName, module);
|
||||
mlir::concretelang::encodings::getProgramEncoding(module);
|
||||
if (!encodingInfosOrErr) {
|
||||
return encodingInfosOrErr.takeError();
|
||||
}
|
||||
@@ -363,7 +377,7 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,
|
||||
chunkedMode.asBuilder().setWidth(options.chunkWidth);
|
||||
maybeChunkInfo = chunkedMode;
|
||||
}
|
||||
mlir::concretelang::encodings::setCircuitEncodingModes(
|
||||
mlir::concretelang::encodings::setProgramEncodingModes(
|
||||
*options.encodings, maybeChunkInfo, res.fheContext);
|
||||
}
|
||||
|
||||
@@ -457,41 +471,29 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,
|
||||
|
||||
// Generate client parameters if requested
|
||||
if (this->generateProgramInfo) {
|
||||
if (!options.mainFuncName.has_value()) {
|
||||
return StreamStringError(
|
||||
"Generation of client parameters requested, but no function name "
|
||||
"specified");
|
||||
}
|
||||
if (!res.fheContext.has_value()) {
|
||||
return StreamStringError(
|
||||
"Cannot generate client parameters, the fhe context is empty for " +
|
||||
options.mainFuncName.value());
|
||||
"Cannot generate client parameters, the fhe context is empty");
|
||||
}
|
||||
}
|
||||
// Generate program info if requested
|
||||
if (this->generateProgramInfo || target == Target::LIBRARY) {
|
||||
auto funcName = options.mainFuncName.value_or("main");
|
||||
if (!res.fheContext.has_value()) {
|
||||
// Some tests involve call a to non encrypted functions
|
||||
auto programInfo = Message<concreteprotocol::ProgramInfo>();
|
||||
programInfo.asBuilder().initCircuits(1);
|
||||
programInfo.asBuilder().getCircuits()[0].setName(std::string(funcName));
|
||||
programInfo.asBuilder().getCircuits()[0].setName(std::string("main"));
|
||||
res.programInfo = programInfo;
|
||||
} else {
|
||||
auto programInfoOrErr =
|
||||
mlir::concretelang::createProgramInfoFromTfheDialect(
|
||||
module, funcName, options.optimizerConfig.security,
|
||||
module, options.optimizerConfig.security,
|
||||
options.encodings.value(), options.compressEvaluationKeys);
|
||||
|
||||
if (!programInfoOrErr)
|
||||
return programInfoOrErr.takeError();
|
||||
|
||||
res.programInfo = std::move(*programInfoOrErr);
|
||||
// If more than one circuit, feedback can not be generated for now ..
|
||||
if (res.programInfo->asReader().getCircuits().size() != 1) {
|
||||
return StreamStringError(
|
||||
"Cannot generate feedback for program with more than one circuit.");
|
||||
}
|
||||
res.feedback->fillFromProgramInfo(*res.programInfo);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
namespace FHE = mlir::concretelang::FHE;
|
||||
using concretelang::protocol::Message;
|
||||
@@ -67,20 +68,13 @@ encodingFromType(mlir::Type ty) {
|
||||
}
|
||||
|
||||
llvm::Expected<Message<concreteprotocol::CircuitEncodingInfo>>
|
||||
getCircuitEncodings(llvm::StringRef functionName, mlir::ModuleOp module) {
|
||||
// Find the input function
|
||||
auto rangeOps = module.getOps<mlir::func::FuncOp>();
|
||||
auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
|
||||
return op.getName() == functionName;
|
||||
});
|
||||
if (funcOp == rangeOps.end()) {
|
||||
return StreamStringError("Function not found, name='")
|
||||
<< functionName << "', cannot get circuit encodings";
|
||||
}
|
||||
auto funcType = (*funcOp).getFunctionType();
|
||||
getCircuitEncodings(mlir::func::FuncOp funcOp) {
|
||||
|
||||
auto funcType = funcOp.getFunctionType();
|
||||
|
||||
// Retrieve input/output encodings
|
||||
auto circuitEncodings = Message<concreteprotocol::CircuitEncodingInfo>();
|
||||
circuitEncodings.asBuilder().setName(funcOp.getSymName().str());
|
||||
auto inputsBuilder =
|
||||
circuitEncodings.asBuilder().initInputs(funcType.getNumInputs());
|
||||
for (size_t i = 0; i < funcType.getNumInputs(); i++) {
|
||||
@@ -105,8 +99,32 @@ getCircuitEncodings(llvm::StringRef functionName, mlir::ModuleOp module) {
|
||||
return std::move(circuitEncodings);
|
||||
}
|
||||
|
||||
llvm::Expected<Message<concreteprotocol::ProgramEncodingInfo>>
|
||||
getProgramEncoding(mlir::ModuleOp module) {
|
||||
|
||||
auto funcs = module.getOps<mlir::func::FuncOp>();
|
||||
auto circuitEncodings =
|
||||
std::vector<Message<concreteprotocol::CircuitEncodingInfo>>();
|
||||
for (auto func : funcs) {
|
||||
auto encodingInfosOrErr = getCircuitEncodings(func);
|
||||
if (!encodingInfosOrErr) {
|
||||
return encodingInfosOrErr.takeError();
|
||||
}
|
||||
circuitEncodings.push_back(*encodingInfosOrErr);
|
||||
}
|
||||
|
||||
auto programEncoding = Message<concreteprotocol::ProgramEncodingInfo>();
|
||||
auto circuitBuilder =
|
||||
programEncoding.asBuilder().initCircuits(circuitEncodings.size());
|
||||
for (size_t i = 0; i < circuitEncodings.size(); i++) {
|
||||
circuitBuilder.setWithCaveats(i, circuitEncodings[i].asReader());
|
||||
}
|
||||
|
||||
return std::move(programEncoding);
|
||||
}
|
||||
|
||||
void setCircuitEncodingModes(
|
||||
Message<concreteprotocol::CircuitEncodingInfo> &info,
|
||||
concreteprotocol::CircuitEncodingInfo::Builder info,
|
||||
std::optional<
|
||||
Message<concreteprotocol::IntegerCiphertextEncodingInfo::ChunkedMode>>
|
||||
maybeChunk,
|
||||
@@ -164,13 +182,25 @@ void setCircuitEncodingModes(
|
||||
// Got nothing particular. Setting encoding mode to native.
|
||||
integerEncodingBuilder.getMode().initNative();
|
||||
};
|
||||
for (auto encInfoBuilder : info.asBuilder().getInputs()) {
|
||||
for (auto encInfoBuilder : info.getInputs()) {
|
||||
setMode(encInfoBuilder);
|
||||
}
|
||||
for (auto encInfoBuilder : info.asBuilder().getOutputs()) {
|
||||
for (auto encInfoBuilder : info.getOutputs()) {
|
||||
setMode(encInfoBuilder);
|
||||
}
|
||||
}
|
||||
|
||||
void setProgramEncodingModes(
|
||||
Message<concreteprotocol::ProgramEncodingInfo> &info,
|
||||
std::optional<
|
||||
Message<concreteprotocol::IntegerCiphertextEncodingInfo::ChunkedMode>>
|
||||
maybeChunk,
|
||||
std::optional<V0FHEContext> maybeFheContext) {
|
||||
for (auto circuitInfo : info.asBuilder().getCircuits()) {
|
||||
setCircuitEncodingModes(circuitInfo, maybeChunk, maybeFheContext);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace encodings
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
|
||||
#include "concretelang/Support/CompilationFeedback.h"
|
||||
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
|
||||
#include "mlir/Conversion/Passes.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
|
||||
@@ -345,7 +346,7 @@ normalizeTFHEKeys(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::LogicalResult
|
||||
extractTFHEStatistics(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
CompilationFeedback &feedback) {
|
||||
ProgramCompilationFeedback &feedback) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("TFHEStatistics", pm, context);
|
||||
|
||||
@@ -371,7 +372,7 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::LogicalResult
|
||||
computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
CompilationFeedback &feedback) {
|
||||
ProgramCompilationFeedback &feedback) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("Computing Memory Usage", pm, context);
|
||||
|
||||
|
||||
@@ -296,32 +296,22 @@ extractKeysetInfo(TFHE::TFHECircuitKeys circuitKeys,
|
||||
}
|
||||
|
||||
llvm::Expected<Message<concreteprotocol::CircuitInfo>>
|
||||
extractCircuitInfo(mlir::ModuleOp module, llvm::StringRef functionName,
|
||||
Message<concreteprotocol::CircuitEncodingInfo> &encodings,
|
||||
extractCircuitInfo(mlir::func::FuncOp funcOp,
|
||||
concreteprotocol::CircuitEncodingInfo::Reader encodings,
|
||||
concrete::SecurityCurve curve) {
|
||||
|
||||
auto output = Message<concreteprotocol::CircuitInfo>();
|
||||
|
||||
// Check that the specified function can be found
|
||||
auto rangeOps = module.getOps<mlir::func::FuncOp>();
|
||||
auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
|
||||
return op.getName() == functionName;
|
||||
});
|
||||
if (funcOp == rangeOps.end()) {
|
||||
return StreamStringError(
|
||||
"cannot find the function for generate client parameters: ")
|
||||
<< functionName;
|
||||
}
|
||||
// Create input and output circuit gate parameters
|
||||
auto funcType = (*funcOp).getFunctionType();
|
||||
auto funcType = funcOp.getFunctionType();
|
||||
|
||||
output.asBuilder().setName(functionName.str());
|
||||
output.asBuilder().setName(encodings.getName().cStr());
|
||||
output.asBuilder().initInputs(funcType.getNumInputs());
|
||||
output.asBuilder().initOutputs(funcType.getNumResults());
|
||||
|
||||
for (unsigned int i = 0; i < funcType.getNumInputs(); i++) {
|
||||
auto ty = funcType.getInput(i);
|
||||
auto encoding = encodings.asReader().getInputs()[i];
|
||||
auto encoding = encodings.getInputs()[i];
|
||||
auto maybeGate = generateGate(ty, encoding, curve);
|
||||
if (!maybeGate) {
|
||||
return maybeGate.takeError();
|
||||
@@ -330,7 +320,7 @@ extractCircuitInfo(mlir::ModuleOp module, llvm::StringRef functionName,
|
||||
}
|
||||
for (unsigned int i = 0; i < funcType.getNumResults(); i++) {
|
||||
auto ty = funcType.getResult(i);
|
||||
auto encoding = encodings.asReader().getOutputs()[i];
|
||||
auto encoding = encodings.getOutputs()[i];
|
||||
auto maybeGate = generateGate(ty, encoding, curve);
|
||||
if (!maybeGate) {
|
||||
return maybeGate.takeError();
|
||||
@@ -341,10 +331,43 @@ extractCircuitInfo(mlir::ModuleOp module, llvm::StringRef functionName,
|
||||
return output;
|
||||
}
|
||||
|
||||
llvm::Expected<Message<concreteprotocol::ProgramInfo>> extractProgramInfo(
|
||||
mlir::ModuleOp module,
|
||||
const Message<concreteprotocol::ProgramEncodingInfo> &encodings,
|
||||
concrete::SecurityCurve curve) {
|
||||
|
||||
auto output = Message<concreteprotocol::ProgramInfo>();
|
||||
auto circuitsCount = encodings.asReader().getCircuits().size();
|
||||
auto circuitsBuilder = output.asBuilder().initCircuits(circuitsCount);
|
||||
auto rangeOps = module.getOps<mlir::func::FuncOp>();
|
||||
|
||||
for (size_t i = 0; i < circuitsCount; i++) {
|
||||
auto circuitEncoding = encodings.asReader().getCircuits()[i];
|
||||
auto functionName = circuitEncoding.getName();
|
||||
auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
|
||||
return op.getName() == functionName.cStr();
|
||||
});
|
||||
if (funcOp == rangeOps.end()) {
|
||||
return StreamStringError("cannot find the following function to generate "
|
||||
"program info: ")
|
||||
<< functionName.cStr();
|
||||
}
|
||||
|
||||
auto maybeCircuitInfo = extractCircuitInfo(*funcOp, circuitEncoding, curve);
|
||||
if (!maybeCircuitInfo) {
|
||||
return maybeCircuitInfo.takeError();
|
||||
}
|
||||
|
||||
circuitsBuilder.setWithCaveats(i, (*maybeCircuitInfo).asReader());
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
llvm::Expected<Message<concreteprotocol::ProgramInfo>>
|
||||
createProgramInfoFromTfheDialect(
|
||||
mlir::ModuleOp module, llvm::StringRef functionName, int bitsOfSecurity,
|
||||
Message<concreteprotocol::CircuitEncodingInfo> &encodings,
|
||||
mlir::ModuleOp module, int bitsOfSecurity,
|
||||
const Message<concreteprotocol::ProgramEncodingInfo> &encodings,
|
||||
bool compressEvaluationKeys) {
|
||||
|
||||
// Check that security curves exist
|
||||
@@ -354,23 +377,20 @@ createProgramInfoFromTfheDialect(
|
||||
<< bitsOfSecurity << "bits";
|
||||
}
|
||||
|
||||
// Create the output Program Info.
|
||||
auto output = Message<concreteprotocol::ProgramInfo>();
|
||||
// We generate the circuit infos from the module.
|
||||
auto maybeProgramInfo = extractProgramInfo(module, encodings, *curve);
|
||||
if (!maybeProgramInfo) {
|
||||
return maybeProgramInfo.takeError();
|
||||
}
|
||||
|
||||
// Extract the output Program Info.
|
||||
Message<concreteprotocol::ProgramInfo> output = *maybeProgramInfo;
|
||||
|
||||
// We extract the keys of the circuit
|
||||
auto keysetInfo = extractKeysetInfo(TFHE::extractCircuitKeys(module), *curve,
|
||||
compressEvaluationKeys);
|
||||
output.asBuilder().setKeyset(keysetInfo.asReader());
|
||||
|
||||
// We generate the gates for the inputs aud outputs
|
||||
auto maybeCircuitInfo =
|
||||
extractCircuitInfo(module, functionName, encodings, *curve);
|
||||
if (!maybeCircuitInfo) {
|
||||
return maybeCircuitInfo.takeError();
|
||||
}
|
||||
output.asBuilder().initCircuits(1).setWithCaveats(
|
||||
0, maybeCircuitInfo->asReader());
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
|
||||
@@ -268,7 +268,7 @@ optimizer::Solution convertSolution(optimizer::CircuitSolution sol) {
|
||||
|
||||
/// Fill the compilation `feedback` from a `solution` returned by the optmizer.
|
||||
template <typename Solution>
|
||||
void fillFeedback(Solution solution, CompilationFeedback &feedback) {
|
||||
void fillFeedback(Solution solution, ProgramCompilationFeedback &feedback) {
|
||||
feedback.complexity = solution.complexity;
|
||||
feedback.pError = solution.p_error;
|
||||
feedback.globalPError =
|
||||
@@ -314,7 +314,7 @@ llvm::Error checkPErrorSolution(Solution solution, optimizer::Config config) {
|
||||
/// optimizer::Solution, and fill the `feedback`.
|
||||
template <typename Solution>
|
||||
llvm::Expected<optimizer::Solution>
|
||||
toCompilerSolution(Solution solution, CompilationFeedback &feedback,
|
||||
toCompilerSolution(Solution solution, ProgramCompilationFeedback &feedback,
|
||||
optimizer::Config config) {
|
||||
// display(descr, config, sol, naive_user, duration);
|
||||
if (auto err = checkPErrorSolution(solution, config); err) {
|
||||
@@ -334,9 +334,9 @@ optimizer::Solution emptySolution() {
|
||||
return solution;
|
||||
}
|
||||
|
||||
llvm::Expected<optimizer::Solution> getSolution(optimizer::Description &descr,
|
||||
CompilationFeedback &feedback,
|
||||
optimizer::Config config) {
|
||||
llvm::Expected<optimizer::Solution>
|
||||
getSolution(optimizer::Description &descr, ProgramCompilationFeedback &feedback,
|
||||
optimizer::Config config) {
|
||||
namespace chrono = std::chrono;
|
||||
// auto start = chrono::high_resolution_clock::now();
|
||||
auto naive_user =
|
||||
|
||||
@@ -223,11 +223,6 @@ llvm::cl::opt<bool> dataflowParallelize(
|
||||
llvm::cl::desc("Generate the program as a dataflow graph"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<std::string>
|
||||
funcName("funcname",
|
||||
llvm::cl::desc("Name of the function to compile, default 'main'"),
|
||||
llvm::cl::init<std::string>(""));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
chunkIntegers("chunk-integers",
|
||||
llvm::cl::desc("Whether to decompose integer into chunks or "
|
||||
@@ -379,12 +374,17 @@ llvm::cl::list<int64_t> largeIntegerCircuitBootstrap(
|
||||
"(experimental) [level, baseLog]"),
|
||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
|
||||
|
||||
llvm::cl::opt<std::string> circuitEncodings(
|
||||
"circuit-encodings",
|
||||
llvm::cl::desc("Specify the input and output encodings of the circuit, "
|
||||
llvm::cl::opt<std::string> programEncoding(
|
||||
"program-encoding",
|
||||
llvm::cl::desc("Specify the encodings to use for the program, "
|
||||
"using the JSON representation."),
|
||||
llvm::cl::init(std::string{}));
|
||||
|
||||
llvm::cl::opt<bool> skipProgramInfo(
|
||||
"skip-program-info",
|
||||
llvm::cl::desc("Skip generating the program info artefacts."),
|
||||
llvm::cl::init(false));
|
||||
|
||||
} // namespace cmdline
|
||||
|
||||
namespace llvm {
|
||||
@@ -424,6 +424,7 @@ cmdlineCompilationOptions() {
|
||||
options.chunkIntegers = cmdline::chunkIntegers;
|
||||
options.chunkSize = cmdline::chunkSize;
|
||||
options.chunkWidth = cmdline::chunkWidth;
|
||||
options.skipProgramInfo = cmdline::skipProgramInfo;
|
||||
|
||||
if (!cmdline::v0Constraint.empty()) {
|
||||
if (cmdline::v0Constraint.size() != 2) {
|
||||
@@ -435,10 +436,6 @@ cmdlineCompilationOptions() {
|
||||
cmdline::v0Constraint[1], cmdline::v0Constraint[0]};
|
||||
}
|
||||
|
||||
if (!cmdline::funcName.empty()) {
|
||||
options.mainFuncName = cmdline::funcName;
|
||||
}
|
||||
|
||||
// Convert tile sizes to `Optional`
|
||||
if (!cmdline::fhelinalgTileSizes.empty())
|
||||
options.fhelinalgTileSizes.emplace(cmdline::fhelinalgTileSizes);
|
||||
@@ -512,12 +509,12 @@ cmdlineCompilationOptions() {
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
if (!cmdline::circuitEncodings.empty()) {
|
||||
auto jsonString = cmdline::circuitEncodings.getValue();
|
||||
auto encodings = Message<concreteprotocol::CircuitEncodingInfo>();
|
||||
if (!cmdline::programEncoding.empty()) {
|
||||
auto jsonString = cmdline::programEncoding.getValue();
|
||||
auto encodings = Message<concreteprotocol::ProgramEncodingInfo>();
|
||||
if (encodings.readJsonFromString(jsonString).has_failure()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"Failed to parse the --circuit-encodings option",
|
||||
"Failed to parse the --program-encoding option",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
options.encodings = encodings;
|
||||
@@ -556,10 +553,9 @@ mlir::LogicalResult processInputBuffer(
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
|
||||
mlir::concretelang::CompilationContext::createShared();
|
||||
|
||||
std::string funcName = options.mainFuncName.value_or("");
|
||||
|
||||
mlir::concretelang::CompilerEngine ce{ccx};
|
||||
ce.setCompilationOptions(std::move(options));
|
||||
ce.setGenerateProgramInfo(!options.skipProgramInfo);
|
||||
|
||||
if (cmdline::passes.size() != 0) {
|
||||
ce.setEnablePass([](mlir::Pass *pass) {
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
// RUN: concretecompiler --action=dump-llvm-ir %s
|
||||
// RUN: concretecompiler --action=dump-llvm-ir --optimizer-strategy=V0 --skip-program-info %s
|
||||
// Just ensure that compile
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/issues/785
|
||||
func.func @main(%arg0: !FHE.eint<15>, %cst: tensor<32768xi64>) -> tensor<1x!FHE.eint<15>> {
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<15>, tensor<32768xi64>) -> !FHE.eint<15>
|
||||
%6 = tensor.from_elements %1 : tensor<1x!FHE.eint<15>> // ERROR HERE line 4
|
||||
return %6 : tensor<1x!FHE.eint<15>>
|
||||
func.func @main(%arg0: !FHE.eint<5>, %cst: tensor<32xi64>) -> tensor<1x!FHE.eint<5>> {
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<5>, tensor<32xi64>) -> !FHE.eint<5>
|
||||
%6 = tensor.from_elements %1 : tensor<1x!FHE.eint<5>> // ERROR HERE line 4
|
||||
return %6 : tensor<1x!FHE.eint<5>>
|
||||
}
|
||||
|
||||
// Ensures that tensors of multiple elements can be constructed as well.
|
||||
func.func @main2(%arg0: !FHE.eint<15>, %cst: tensor<32768xi64>) -> tensor<2x!FHE.eint<15>> {
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<15>, tensor<32768xi64>) -> !FHE.eint<15>
|
||||
%6 = tensor.from_elements %1, %arg0 : tensor<2x!FHE.eint<15>> // ERROR HERE line 4
|
||||
return %6 : tensor<2x!FHE.eint<15>>
|
||||
func.func @main2(%arg0: !FHE.eint<5>, %cst: tensor<32xi64>) -> tensor<2x!FHE.eint<5>> {
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<5>, tensor<32xi64>) -> !FHE.eint<5>
|
||||
%6 = tensor.from_elements %1, %arg0 : tensor<2x!FHE.eint<5>> // ERROR HERE line 4
|
||||
return %6 : tensor<2x!FHE.eint<5>>
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --action=dump-tfhe --force-encoding crt %s
|
||||
// RUN: concretecompiler --action=dump-tfhe --optimizer-strategy=V0 --force-encoding crt --skip-program-info %s
|
||||
func.func @main(%arg0: tensor<32x!FHE.eint<8>>) -> tensor<16x!FHE.eint<8>>{
|
||||
%0 = tensor.extract_slice %arg0[16] [16] [1] : tensor<32x!FHE.eint<8>> to tensor<16x!FHE.eint<8>>
|
||||
return %0 : tensor<16x!FHE.eint<8>>
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
// RUN: concretecompiler --action=dump-tfhe --force-encoding crt %s
|
||||
func.func @main(%2: tensor<1x1x!FHE.eint<16>>) -> tensor<1x1x1x!FHE.eint<16>> {
|
||||
%3 = tensor.expand_shape %2 [[0], [1, 2]] : tensor<1x1x!FHE.eint<16>> into tensor<1x1x1x!FHE.eint<16>>
|
||||
return %3 : tensor<1x1x1x!FHE.eint<16>>
|
||||
}
|
||||
|
||||
func.func @main2(%2: tensor<1x1x1x!FHE.eint<16>>) -> tensor<1x1x!FHE.eint<16>> {
|
||||
%3 = tensor.collapse_shape %2 [[0], [1, 2]] : tensor<1x1x1x!FHE.eint<16>> into tensor<1x1x!FHE.eint<16>>
|
||||
return %3 : tensor<1x1x!FHE.eint<16>>
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --action=dump-llvm-dialect --emit-gpu-ops %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --action=dump-llvm-dialect --emit-gpu-ops --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: llvm.call @memref_keyswitch_lwe_cuda_u64
|
||||
//CHECK: llvm.call @memref_bootstrap_lwe_cuda_u64
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --action=dump-parametrized-tfhe --optimizer-strategy=V0 --v0-parameter=2,10,750,1,23,3,4 --v0-constraint=4,0 %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --action=dump-parametrized-tfhe --optimizer-strategy=V0 --v0-parameter=2,10,750,1,23,3,4 --v0-constraint=4,0 --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @main(%[[A0:.*]]: !TFHE.glwe<sk<0,1,2048>>) -> !TFHE.glwe<sk<0,1,2048>> {
|
||||
//CHECK-NEXT: %cst = arith.constant dense<0> : tensor<1024xi64>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @add_glwe(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64>
|
||||
func.func @add_glwe(%arg0: !TFHE.glwe<sk[1]<1,2048>>, %arg1: !TFHE.glwe<sk[1]<1,2048>>) -> !TFHE.glwe<sk[1]<1,2048>> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
|
||||
//CHECK: %c1_i64 = arith.constant 1 : i64
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> {
|
||||
// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_bootstrap_tensor"(%arg0) {isSigned = true, outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<5x8192xi64> {
|
||||
// CHECK-NEXT: %0 = "Concrete.encode_lut_for_crt_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @main(%arg0: i64) -> tensor<5xi64> {
|
||||
// CHECK-NEXT: %0 = "Concrete.encode_plaintext_with_crt_tensor"(%arg0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @keyswitch_glwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<568xi64> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "Concrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 3 : i32, kskIndex = -1 : i32, level = 2 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 567 : i32} : (tensor<1025xi64>) -> tensor<568xi64>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @mul_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
|
||||
//CHECK: %c1_i64 = arith.constant 1 : i64
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @neg_glwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64>
|
||||
func.func @neg_glwe(%arg0: !TFHE.glwe<sk[1]<1,1024>>) -> !TFHE.glwe<sk[1]<1,1024>> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @sub_const_int_glwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
|
||||
//CHECK: %c1_i64 = arith.constant 1 : i64
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --split-input-file --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --split-input-file --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @tensor_collapse_shape(%[[A0:.*]]: tensor<2x3x4x5x6x1025xi64>) -> tensor<720x1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2, 3, 4\], \[5\]\]]] : tensor<2x3x4x5x6x1025xi64> into tensor<720x1025xi64>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --split-input-file %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --split-input-file --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @main(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>, %[[A2:.*]]: tensor<2049xi64>, %[[A3:.*]]: tensor<2049xi64>, %[[A4:.*]]: tensor<2049xi64>, %[[A5:.*]]: tensor<2049xi64>) -> tensor<6x2049xi64> {
|
||||
// CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<6x2049xi64>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @tensor_identity(%arg0: tensor<2x3x4x1025xi64>) -> tensor<2x3x4x1025xi64> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x3x4x1025xi64>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --action=dump-fhe %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --action=dump-fhe --optimizer-strategy=V0 --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --action=dump-fhe %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --action=dump-fhe --optimizer-strategy=V0 --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @add_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes tfhe-optimization --optimize-tfhe=false --action=dump-tfhe %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --passes tfhe-optimization --optimize-tfhe=false --action=dump-tfhe --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @mul_cleartext_glwe_ciphertext_0(%[[A0:.*]]: !TFHE.glwe<sk[1]<527,1>>) -> !TFHE.glwe<sk[1]<527,1>> {
|
||||
//CHECK: %c0_i64 = arith.constant 0 : i64
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes tfhe-optimization --action=dump-tfhe %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --passes tfhe-optimization --action=dump-tfhe --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
|
||||
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !TFHE.glwe<sk[1]<527,1>>, %arg1: i64) -> !TFHE.glwe<sk[1]<527,1>>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --split-input-file --action=dump-batched-tfhe --batch-tfhe-ops %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --split-input-file --action=dump-batched-tfhe --batch-tfhe-ops --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @batch_continuous_slice_keyswitch
|
||||
// CHECK: (%arg0: tensor<2x3x4x!TFHE.glwe<sk{{\[}}[[SK_IN:.*]]{{\]}}<1,2048>>>) -> tensor<2x3x4x!TFHE.glwe<sk{{\[}}[[SK_OUT:.*]]{{\]}}<1,750>>> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --split-input-file --action=dump-parametrized-tfhe --optimizer-strategy=dag-multi %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --split-input-file --action=dump-parametrized-tfhe --optimizer-strategy=dag-multi --skip-program-info %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @funconly_fwd(%arg0: !TFHE.glwe<sk[1]<12,1024>>) -> !TFHE.glwe<sk[1]<12,1024>> {
|
||||
// CHECK-NEXT: return %arg0 : !TFHE.glwe<sk[1]<12,1024>>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "../end_to_end_tests/end_to_end_test.h"
|
||||
#include "concretelang/Common/Compat.h"
|
||||
#include "concretelang/TestLib/TestCircuit.h"
|
||||
#include "concretelang/TestLib/TestProgram.h"
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
|
||||
#include <benchmark/benchmark.h>
|
||||
@@ -23,7 +23,7 @@ using namespace concretelang::testlib;
|
||||
/// Benchmark time of the compilation
|
||||
static void BM_Compile(benchmark::State &state, EndToEndDesc description,
|
||||
mlir::concretelang::CompilationOptions options) {
|
||||
TestCircuit tc(options);
|
||||
TestProgram tc(options);
|
||||
for (auto _ : state) {
|
||||
assert(tc.compile(description.program));
|
||||
}
|
||||
@@ -32,7 +32,7 @@ static void BM_Compile(benchmark::State &state, EndToEndDesc description,
|
||||
/// Benchmark time of the key generation
|
||||
static void BM_KeyGen(benchmark::State &state, EndToEndDesc description,
|
||||
mlir::concretelang::CompilationOptions options) {
|
||||
TestCircuit tc(options);
|
||||
TestProgram tc(options);
|
||||
assert(tc.compile(description.program));
|
||||
|
||||
for (auto _ : state) {
|
||||
@@ -44,7 +44,7 @@ static void BM_KeyGen(benchmark::State &state, EndToEndDesc description,
|
||||
static void BM_ExportArguments(benchmark::State &state,
|
||||
EndToEndDesc description,
|
||||
mlir::concretelang::CompilationOptions options) {
|
||||
TestCircuit tc(options);
|
||||
TestProgram tc(options);
|
||||
assert(tc.compile(description.program));
|
||||
assert(tc.generateKeyset());
|
||||
|
||||
@@ -68,7 +68,7 @@ static void BM_ExportArguments(benchmark::State &state,
|
||||
/// Benchmark time of the program evaluation
|
||||
static void BM_Evaluate(benchmark::State &state, EndToEndDesc description,
|
||||
mlir::concretelang::CompilationOptions options) {
|
||||
TestCircuit tc(options);
|
||||
TestProgram tc(options);
|
||||
assert(tc.compile(description.program));
|
||||
assert(tc.generateKeyset());
|
||||
auto clientCircuit = tc.getClientCircuit().value();
|
||||
@@ -111,7 +111,6 @@ void registerEndToEndBenchmark(std::string suiteName,
|
||||
int num_iterations = 0) {
|
||||
auto optionsName = getOptionsName(options);
|
||||
for (auto description : descriptions) {
|
||||
options.mainFuncName = "main";
|
||||
if (description.p_error) {
|
||||
assert(std::isnan(options.optimizerConfig.global_p_error));
|
||||
options.optimizerConfig.p_error = description.p_error.value();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "concretelang/Common/Compat.h"
|
||||
#include "concretelang/TestLib/TestCircuit.h"
|
||||
#include "concretelang/TestLib/TestProgram.h"
|
||||
#include "end_to_end_fixture/EndToEndFixture.h"
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
#define BENCHMARK_HAS_CXX11
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "concretelang/TestLib/TestCircuit.h"
|
||||
#include "concretelang/TestLib/TestProgram.h"
|
||||
#include "end_to_end_jit_test.h"
|
||||
#include "tests_tools/GtestEnvironment.h"
|
||||
std::vector<uint64_t> distributed_results;
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "concretelang/TestLib/TestCircuit.h"
|
||||
#include "concretelang/TestLib/TestProgram.h"
|
||||
#include "end_to_end_jit_test.h"
|
||||
#include "tests_tools/GtestEnvironment.h"
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "concretelang/TestLib/TestCircuit.h"
|
||||
#include "concretelang/TestLib/TestProgram.h"
|
||||
#include "end_to_end_jit_test.h"
|
||||
#include "tests_tools/GtestEnvironment.h"
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "concretelang/TestLib/TestCircuit.h"
|
||||
#include "concretelang/TestLib/TestProgram.h"
|
||||
#include "end_to_end_jit_test.h"
|
||||
#include "tests_tools/GtestEnvironment.h"
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "concretelang/TestLib/TestCircuit.h"
|
||||
#include "concretelang/TestLib/TestProgram.h"
|
||||
#include "end_to_end_jit_test.h"
|
||||
#include "tests_tools/GtestEnvironment.h"
|
||||
#include "tests_tools/assert.h"
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include "concretelang/TestLib/TestCircuit.h"
|
||||
#include "concretelang/TestLib/TestProgram.h"
|
||||
#include "end_to_end_jit_test.h"
|
||||
#include "tests_tools/GtestEnvironment.h"
|
||||
|
||||
@@ -395,10 +395,10 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
}
|
||||
|
||||
TEST(CompileNotComposable, not_composable_1) {
|
||||
mlir::concretelang::CompilationOptions options("main");
|
||||
mlir::concretelang::CompilationOptions options;
|
||||
options.optimizerConfig.composable = true;
|
||||
options.optimizerConfig.strategy = mlir::concretelang::optimizer::DAG_MULTI;
|
||||
TestCircuit circuit(options);
|
||||
TestProgram circuit(options);
|
||||
auto err = circuit.compile(R"XXX(
|
||||
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
%cst_1 = arith.constant 1 : i4
|
||||
@@ -411,11 +411,11 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
}
|
||||
|
||||
TEST(CompileNotComposable, not_composable_2) {
|
||||
mlir::concretelang::CompilationOptions options("main");
|
||||
mlir::concretelang::CompilationOptions options;
|
||||
options.optimizerConfig.composable = true;
|
||||
options.optimizerConfig.display = true;
|
||||
options.optimizerConfig.strategy = mlir::concretelang::optimizer::DAG_MULTI;
|
||||
TestCircuit circuit(options);
|
||||
TestProgram circuit(options);
|
||||
auto err = circuit.compile(R"XXX(
|
||||
func.func @main(%arg0: !FHE.eint<3>) -> (!FHE.eint<3>, !FHE.eint<3>) {
|
||||
%cst_1 = arith.constant 1 : i4
|
||||
@@ -430,11 +430,11 @@ func.func @main(%arg0: !FHE.eint<3>) -> (!FHE.eint<3>, !FHE.eint<3>) {
|
||||
}
|
||||
|
||||
TEST(CompileComposable, composable_supported_dag_mono) {
|
||||
mlir::concretelang::CompilationOptions options("main");
|
||||
mlir::concretelang::CompilationOptions options;
|
||||
options.optimizerConfig.composable = true;
|
||||
options.optimizerConfig.display = true;
|
||||
options.optimizerConfig.strategy = mlir::concretelang::optimizer::DAG_MONO;
|
||||
TestCircuit circuit(options);
|
||||
TestProgram circuit(options);
|
||||
auto err = circuit.compile(R"XXX(
|
||||
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
%cst_1 = arith.constant 1 : i4
|
||||
@@ -446,11 +446,11 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
}
|
||||
|
||||
TEST(CompileComposable, composable_supported_v0) {
|
||||
mlir::concretelang::CompilationOptions options("main");
|
||||
mlir::concretelang::CompilationOptions options;
|
||||
options.optimizerConfig.composable = true;
|
||||
options.optimizerConfig.display = true;
|
||||
options.optimizerConfig.strategy = mlir::concretelang::optimizer::V0;
|
||||
TestCircuit circuit(options);
|
||||
TestProgram circuit(options);
|
||||
auto err = circuit.compile(R"XXX(
|
||||
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
%cst_1 = arith.constant 1 : i4
|
||||
@@ -460,3 +460,39 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
)XXX");
|
||||
assert(err.has_value());
|
||||
}
|
||||
|
||||
TEST(CompileMultiFunctions, multi_functions_v0) {
|
||||
mlir::concretelang::CompilationOptions options;
|
||||
options.optimizerConfig.strategy = mlir::concretelang::optimizer::V0;
|
||||
TestProgram circuit(options);
|
||||
auto err = circuit.compile(R"XXX(
|
||||
func.func @inc(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
%cst_1 = arith.constant 1 : i4
|
||||
%1 = "FHE.add_eint_int"(%arg0, %cst_1) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
|
||||
return %1: !FHE.eint<3>
|
||||
}
|
||||
func.func @dec(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
%cst_1 = arith.constant 1 : i4
|
||||
%1 = "FHE.sub_eint_int"(%arg0, %cst_1) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
|
||||
return %1: !FHE.eint<3>
|
||||
}
|
||||
)XXX");
|
||||
assert(err.has_value());
|
||||
assert(circuit.generateKeyset().has_value());
|
||||
auto lambda_inc = [&](std::vector<concretelang::values::Value> args) {
|
||||
return circuit.call(args, "inc")
|
||||
.value()[0]
|
||||
.template getTensor<uint64_t>()
|
||||
.value()[0];
|
||||
};
|
||||
auto lambda_dec = [&](std::vector<concretelang::values::Value> args) {
|
||||
return circuit.call(args, "dec")
|
||||
.value()[0]
|
||||
.template getTensor<uint64_t>()
|
||||
.value()[0];
|
||||
};
|
||||
ASSERT_EQ(lambda_inc({Tensor<uint64_t>(1)}), (uint64_t)2);
|
||||
ASSERT_EQ(lambda_inc({Tensor<uint64_t>(4)}), (uint64_t)5);
|
||||
ASSERT_EQ(lambda_dec({Tensor<uint64_t>(1)}), (uint64_t)0);
|
||||
ASSERT_EQ(lambda_dec({Tensor<uint64_t>(4)}), (uint64_t)3);
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/V0Parameters.h"
|
||||
#include "concretelang/TestLib/TestCircuit.h"
|
||||
#include "concretelang/TestLib/TestProgram.h"
|
||||
#include "cstdlib"
|
||||
#include "end_to_end_test.h"
|
||||
#include "globals.h"
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
using concretelang::error::Result;
|
||||
using concretelang::error::StringError;
|
||||
using concretelang::testlib::TestCircuit;
|
||||
using concretelang::testlib::TestProgram;
|
||||
|
||||
llvm::StringRef DEFAULT_func = "main";
|
||||
bool DEFAULT_useDefaultFHEConstraints = false;
|
||||
@@ -30,7 +30,7 @@ bool DEFAULT_composable = false;
|
||||
// Jit-compiles the function specified by `func` from `src` and
|
||||
// returns the corresponding lambda. Any compilation errors are caught
|
||||
// and reult in abnormal termination.
|
||||
inline Result<TestCircuit> internalCheckedJit(
|
||||
inline Result<TestProgram> internalCheckedJit(
|
||||
llvm::StringRef src, llvm::StringRef func = DEFAULT_func,
|
||||
bool useDefaultFHEConstraints = DEFAULT_useDefaultFHEConstraints,
|
||||
bool dataflowParallelize = DEFAULT_dataflowParallelize,
|
||||
@@ -42,8 +42,7 @@ inline Result<TestCircuit> internalCheckedJit(
|
||||
unsigned int chunkWidth = DEFAULT_chunkWidth,
|
||||
bool composable = DEFAULT_composable) {
|
||||
|
||||
auto options =
|
||||
mlir::concretelang::CompilationOptions(std::string(func.data()));
|
||||
auto options = mlir::concretelang::CompilationOptions();
|
||||
options.optimizerConfig.global_p_error = global_p_error;
|
||||
options.chunkIntegers = chunkedIntegers;
|
||||
options.chunkSize = chunkSize;
|
||||
@@ -70,10 +69,10 @@ inline Result<TestCircuit> internalCheckedJit(
|
||||
}
|
||||
|
||||
std::vector<std::string> sources = {src.str()};
|
||||
TestCircuit testCircuit(options);
|
||||
OUTCOME_TRYV(testCircuit.compile({src.str()}));
|
||||
OUTCOME_TRYV(testCircuit.generateKeyset());
|
||||
return std::move(testCircuit);
|
||||
TestProgram testProgram(options);
|
||||
OUTCOME_TRYV(testProgram.compile({src.str()}));
|
||||
OUTCOME_TRYV(testProgram.generateKeyset());
|
||||
return std::move(testProgram);
|
||||
}
|
||||
|
||||
// Wrapper around `internalCheckedJit` that causes
|
||||
|
||||
@@ -7,14 +7,14 @@
|
||||
|
||||
#include "concretelang/Common/Values.h"
|
||||
#include "concretelang/Support/CompilationFeedback.h"
|
||||
#include "concretelang/TestLib/TestCircuit.h"
|
||||
#include "concretelang/TestLib/TestProgram.h"
|
||||
#include "end_to_end_fixture/EndToEndFixture.h"
|
||||
#include "end_to_end_jit_test.h"
|
||||
#include "tests_tools/GtestEnvironment.h"
|
||||
#include "tests_tools/assert.h"
|
||||
#include "tests_tools/keySetCache.h"
|
||||
|
||||
using concretelang::testlib::TestCircuit;
|
||||
using concretelang::testlib::TestProgram;
|
||||
using concretelang::values::Value;
|
||||
|
||||
/// @brief EndToEndTest is a template that allows testing for one program for a
|
||||
@@ -35,7 +35,7 @@ public:
|
||||
};
|
||||
|
||||
void SetUp() override {
|
||||
TestCircuit tc(options.compilationOptions);
|
||||
TestProgram tc(options.compilationOptions);
|
||||
ASSERT_OUTCOME_HAS_VALUE(tc.compile({program}));
|
||||
ASSERT_OUTCOME_HAS_VALUE(tc.generateKeyset());
|
||||
testCircuit.emplace(std::move(tc));
|
||||
@@ -122,7 +122,7 @@ private:
|
||||
TestDescription desc;
|
||||
std::optional<TestErrorRate> errorRate;
|
||||
std::optional<mlir::concretelang::CompilerEngine::Library> library;
|
||||
std::optional<TestCircuit> testCircuit;
|
||||
std::optional<TestProgram> testCircuit;
|
||||
EndToEndTestOptions options;
|
||||
std::vector<Value> args;
|
||||
};
|
||||
|
||||
@@ -135,8 +135,7 @@ parseEndToEndCommandLine(int argc, char **argv) {
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv);
|
||||
|
||||
// Build compilation options
|
||||
mlir::concretelang::CompilationOptions compilationOptions("main",
|
||||
backend.getValue());
|
||||
mlir::concretelang::CompilationOptions compilationOptions(backend.getValue());
|
||||
if (loopParallelize.getValue() != -1)
|
||||
compilationOptions.loopParallelize = loopParallelize.getValue();
|
||||
|
||||
|
||||
@@ -7,7 +7,8 @@ from concrete.compiler import (
|
||||
LibrarySupport,
|
||||
ClientSupport,
|
||||
CompilationOptions,
|
||||
CompilationFeedback,
|
||||
ProgramCompilationFeedback,
|
||||
CircuitCompilationFeedback,
|
||||
)
|
||||
|
||||
|
||||
@@ -23,31 +24,40 @@ def assert_result(result, expected_result):
|
||||
assert np.all(result == expected_result)
|
||||
|
||||
|
||||
def run(engine, args, compilation_result, keyset_cache):
|
||||
def run(engine, args, compilation_result, keyset_cache, circuit_name="main"):
|
||||
"""Execute engine on the given arguments.
|
||||
|
||||
Perform required loading, encryption, execution, and decryption."""
|
||||
# Dev
|
||||
compilation_feedback = engine.load_compilation_feedback(compilation_result)
|
||||
assert isinstance(compilation_feedback, CompilationFeedback)
|
||||
assert isinstance(compilation_feedback, ProgramCompilationFeedback)
|
||||
assert isinstance(compilation_feedback.complexity, float)
|
||||
assert isinstance(compilation_feedback.p_error, float)
|
||||
assert isinstance(compilation_feedback.global_p_error, float)
|
||||
assert isinstance(compilation_feedback.total_secret_keys_size, int)
|
||||
assert isinstance(compilation_feedback.total_bootstrap_keys_size, int)
|
||||
assert isinstance(compilation_feedback.total_inputs_size, int)
|
||||
assert isinstance(compilation_feedback.total_output_size, int)
|
||||
assert isinstance(compilation_feedback.circuit_feedbacks, list)
|
||||
circuit_feedback = next(
|
||||
filter(lambda x: x.name == circuit_name, compilation_feedback.circuit_feedbacks)
|
||||
)
|
||||
assert isinstance(circuit_feedback, CircuitCompilationFeedback)
|
||||
assert isinstance(circuit_feedback.total_inputs_size, int)
|
||||
assert isinstance(circuit_feedback.total_output_size, int)
|
||||
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
public_arguments = ClientSupport.encrypt_arguments(
|
||||
client_parameters, key_set, args, circuit_name
|
||||
)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result, False)
|
||||
server_lambda = engine.load_server_lambda(compilation_result, False, circuit_name)
|
||||
evaluation_keys = key_set.get_evaluation_keys()
|
||||
public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys)
|
||||
# Client
|
||||
result = ClientSupport.decrypt_result(client_parameters, key_set, public_result)
|
||||
result = ClientSupport.decrypt_result(
|
||||
client_parameters, key_set, public_result, circuit_name
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@@ -57,11 +67,12 @@ def compile_run_assert(
|
||||
args,
|
||||
expected_result,
|
||||
keyset_cache,
|
||||
options=CompilationOptions.new("main"),
|
||||
options=CompilationOptions.new(),
|
||||
circuit_name="main",
|
||||
):
|
||||
"""Compile run and assert result."""
|
||||
compilation_result = engine.compile(mlir_input, options)
|
||||
result = run(engine, args, compilation_result, keyset_cache)
|
||||
result = run(engine, args, compilation_result, keyset_cache, circuit_name)
|
||||
assert_result(result, expected_result)
|
||||
|
||||
|
||||
@@ -231,6 +242,33 @@ def test_lib_compilation_artifacts():
|
||||
assert not os.path.exists(engine.get_shared_lib_path())
|
||||
|
||||
|
||||
def test_multi_circuits(keyset_cache):
|
||||
from mlir._mlir_libs._concretelang._compiler import OptimizerStrategy
|
||||
|
||||
mlir_str = """
|
||||
func.func @add(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
func.func @sub(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
"""
|
||||
args = (10, 3)
|
||||
expected_add_result = 13
|
||||
expected_sub_result = 7
|
||||
engine = LibrarySupport.new("./py_test_multi_circuits")
|
||||
options = CompilationOptions.new()
|
||||
options.set_optimizer_strategy(OptimizerStrategy.V0)
|
||||
compile_run_assert(
|
||||
engine, mlir_str, args, expected_add_result, keyset_cache, options, "add"
|
||||
)
|
||||
compile_run_assert(
|
||||
engine, mlir_str, args, expected_sub_result, keyset_cache, options, "sub"
|
||||
)
|
||||
|
||||
|
||||
def _test_lib_compile_and_run_with_options(keyset_cache, options):
|
||||
mlir_input = """
|
||||
func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
@@ -247,21 +285,21 @@ def _test_lib_compile_and_run_with_options(keyset_cache, options):
|
||||
|
||||
|
||||
def test_lib_compile_and_run_p_error(keyset_cache):
|
||||
options = CompilationOptions.new("main")
|
||||
options = CompilationOptions.new()
|
||||
options.set_p_error(0.00001)
|
||||
options.set_display_optimizer_choice(True)
|
||||
_test_lib_compile_and_run_with_options(keyset_cache, options)
|
||||
|
||||
|
||||
def test_lib_compile_and_run_global_p_error(keyset_cache):
|
||||
options = CompilationOptions.new("main")
|
||||
options = CompilationOptions.new()
|
||||
options.set_global_p_error(0.00001)
|
||||
options.set_display_optimizer_choice(True)
|
||||
_test_lib_compile_and_run_with_options(keyset_cache, options)
|
||||
|
||||
|
||||
def test_lib_compile_and_run_security_level(keyset_cache):
|
||||
options = CompilationOptions.new("main")
|
||||
options = CompilationOptions.new()
|
||||
options.set_security_level(80)
|
||||
options.set_display_optimizer_choice(True)
|
||||
_test_lib_compile_and_run_with_options(keyset_cache, options)
|
||||
@@ -276,7 +314,7 @@ def test_compile_and_run_auto_parallelize(
|
||||
):
|
||||
artifact_dir = "./py_test_compile_and_run_auto_parallelize"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
options = CompilationOptions.new("main")
|
||||
options = CompilationOptions.new()
|
||||
options.set_auto_parallelize(True)
|
||||
compile_run_assert(
|
||||
engine, mlir_input, args, expected_result, keyset_cache, options=options
|
||||
@@ -301,7 +339,7 @@ def test_compile_and_run_auto_parallelize(
|
||||
# if no_parallel:
|
||||
# artifact_dir = "./py_test_compile_dataflow_and_fail_run"
|
||||
# engine = LibrarySupport.new(artifact_dir)
|
||||
# options = CompilationOptions.new("main")
|
||||
# options = CompilationOptions.new()
|
||||
# options.set_auto_parallelize(True)
|
||||
# with pytest.raises(
|
||||
# RuntimeError,
|
||||
@@ -334,7 +372,7 @@ def test_compile_and_run_loop_parallelize(
|
||||
):
|
||||
artifact_dir = "./py_test_compile_and_run_loop_parallelize"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
options = CompilationOptions.new("main")
|
||||
options = CompilationOptions.new()
|
||||
options.set_loop_parallelize(True)
|
||||
compile_run_assert(
|
||||
engine, mlir_input, args, expected_result, keyset_cache, options=options
|
||||
@@ -365,29 +403,6 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args, keyset_cache):
|
||||
compile_run_assert(engine, mlir_input, args, None, keyset_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input",
|
||||
[
|
||||
pytest.param(
|
||||
"""
|
||||
func.func @test(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
|
||||
{
|
||||
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
|
||||
(tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7>
|
||||
return %ret : !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
id="not @main",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_invalid(mlir_input):
|
||||
artifact_dir = "./py_test_compile_invalid"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
with pytest.raises(RuntimeError, match=r"Function not found, name='main'"):
|
||||
engine.compile(mlir_input)
|
||||
|
||||
|
||||
def test_crt_decomposition_feedback():
|
||||
mlir = """
|
||||
|
||||
@@ -401,11 +416,27 @@ func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
|
||||
artifact_dir = "./py_test_crt_decomposition_feedback"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
compilation_result = engine.compile(mlir, options=CompilationOptions.new("main"))
|
||||
compilation_result = engine.compile(mlir, options=CompilationOptions.new())
|
||||
compilation_feedback = engine.load_compilation_feedback(compilation_result)
|
||||
|
||||
assert isinstance(compilation_feedback, CompilationFeedback)
|
||||
assert compilation_feedback.crt_decompositions_of_outputs == [[7, 8, 9, 11, 13]]
|
||||
assert isinstance(compilation_feedback, ProgramCompilationFeedback)
|
||||
assert isinstance(compilation_feedback.complexity, float)
|
||||
assert isinstance(compilation_feedback.p_error, float)
|
||||
assert isinstance(compilation_feedback.global_p_error, float)
|
||||
assert isinstance(compilation_feedback.total_secret_keys_size, int)
|
||||
assert isinstance(compilation_feedback.total_bootstrap_keys_size, int)
|
||||
assert isinstance(compilation_feedback.circuit_feedbacks, list)
|
||||
assert isinstance(
|
||||
compilation_feedback.circuit_feedbacks[0], CircuitCompilationFeedback
|
||||
)
|
||||
assert isinstance(compilation_feedback.circuit_feedbacks[0].total_inputs_size, int)
|
||||
assert isinstance(compilation_feedback.circuit_feedbacks[0].total_output_size, int)
|
||||
assert isinstance(
|
||||
compilation_feedback.circuit_feedbacks[0].crt_decompositions_of_outputs, list
|
||||
)
|
||||
assert compilation_feedback.circuit_feedbacks[0].crt_decompositions_of_outputs == [
|
||||
[7, 8, 9, 11, 13]
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -450,10 +481,11 @@ def test_memory_usage(mlir: str, expected_memory_usage_per_loc: dict):
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
compilation_result = engine.compile(mlir)
|
||||
compilation_feedback = engine.load_compilation_feedback(compilation_result)
|
||||
assert isinstance(compilation_feedback, CompilationFeedback)
|
||||
assert isinstance(compilation_feedback, ProgramCompilationFeedback)
|
||||
|
||||
assert (
|
||||
expected_memory_usage_per_loc == compilation_feedback.memory_usage_per_location
|
||||
expected_memory_usage_per_loc
|
||||
== compilation_feedback.circuit_feedbacks[0].memory_usage_per_location
|
||||
)
|
||||
|
||||
shutil.rmtree(artifact_dir)
|
||||
|
||||
@@ -49,7 +49,7 @@ def compile_run_assert(
|
||||
mlir_input,
|
||||
args_and_shape,
|
||||
expected_result,
|
||||
options=CompilationOptions.new("main"),
|
||||
options=CompilationOptions.new(),
|
||||
):
|
||||
# compile with simulation
|
||||
options.simulation(True)
|
||||
|
||||
@@ -32,7 +32,10 @@ module {
|
||||
compilation_result = support.compile(mlir)
|
||||
|
||||
client_parameters = support.load_client_parameters(compilation_result)
|
||||
compilation_feedback = support.load_compilation_feedback(compilation_result)
|
||||
program_compilation_feedback = support.load_compilation_feedback(
|
||||
compilation_result
|
||||
)
|
||||
compilation_feedback = program_compilation_feedback.circuit("main")
|
||||
|
||||
pbs_count = compilation_feedback.count(
|
||||
operations={
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Encodings.h"
|
||||
#include "concretelang/TestLib/TestCircuit.h"
|
||||
#include "concretelang/TestLib/TestProgram.h"
|
||||
|
||||
#include "tests_tools/GtestEnvironment.h"
|
||||
#include "tests_tools/assert.h"
|
||||
@@ -23,17 +23,18 @@ testing::Environment *const dfr_env =
|
||||
using namespace concretelang::testlib;
|
||||
namespace encodings = mlir::concretelang::encodings;
|
||||
|
||||
Result<TestCircuit> setupTestCircuit(std::string source,
|
||||
Result<TestProgram> setupTestProgram(std::string source,
|
||||
std::string funcname = FUNCNAME) {
|
||||
std::vector<std::string> sources = {source};
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
|
||||
mlir::concretelang::CompilationContext::createShared();
|
||||
mlir::concretelang::CompilerEngine ce{ccx};
|
||||
mlir::concretelang::CompilationOptions options(funcname);
|
||||
mlir::concretelang::CompilationOptions options;
|
||||
|
||||
options.encodings = Message<concreteprotocol::CircuitEncodingInfo>();
|
||||
auto inputs = options.encodings->asBuilder().initInputs(2);
|
||||
auto outputs = options.encodings->asBuilder().initOutputs(1);
|
||||
auto circuitEncoding = Message<concreteprotocol::CircuitEncodingInfo>();
|
||||
auto inputs = circuitEncoding.asBuilder().initInputs(2);
|
||||
auto outputs = circuitEncoding.asBuilder().initOutputs(1);
|
||||
circuitEncoding.asBuilder().setName(funcname);
|
||||
|
||||
auto encodingInfo = Message<concreteprotocol::EncodingInfo>().asBuilder();
|
||||
encodingInfo.initShape();
|
||||
@@ -46,9 +47,12 @@ Result<TestCircuit> setupTestCircuit(std::string source,
|
||||
inputs.setWithCaveats(1, encodingInfo);
|
||||
outputs.setWithCaveats(0, encodingInfo);
|
||||
|
||||
options.encodings->asBuilder().setName("main");
|
||||
options.encodings = Message<concreteprotocol::ProgramEncodingInfo>();
|
||||
options.encodings->asBuilder().initCircuits(1).setWithCaveats(
|
||||
0, circuitEncoding.asReader());
|
||||
|
||||
options.v0Parameter = {2, 10, 693, 4, 9, 7, 2, std::nullopt};
|
||||
TestCircuit testCircuit(options);
|
||||
TestProgram testCircuit(options);
|
||||
OUTCOME_TRYV(testCircuit.compile({source}));
|
||||
OUTCOME_TRYV(testCircuit.generateKeyset());
|
||||
return std::move(testCircuit);
|
||||
@@ -67,7 +71,7 @@ func.func @main(
|
||||
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
uint64_t a = 5;
|
||||
uint64_t b = 5;
|
||||
auto res = circuit.call({Tensor<uint64_t>(a), Tensor<uint64_t>(b)});
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/TestLib/TestCircuit.h"
|
||||
#include "concretelang/TestLib/TestProgram.h"
|
||||
|
||||
#include "tests_tools/GtestEnvironment.h"
|
||||
#include "tests_tools/assert.h"
|
||||
@@ -20,15 +20,15 @@ using namespace concretelang::testlib;
|
||||
testing::Environment *const dfr_env =
|
||||
testing::AddGlobalTestEnvironment(new DFREnvironment);
|
||||
|
||||
Result<TestCircuit> setupTestCircuit(std::string source,
|
||||
Result<TestProgram> setupTestProgram(std::string source,
|
||||
std::string funcname = FUNCNAME) {
|
||||
mlir::concretelang::CompilationOptions options(funcname);
|
||||
mlir::concretelang::CompilationOptions options;
|
||||
#ifdef CONCRETELANG_CUDA_SUPPORT
|
||||
options.emitGPUOps = true;
|
||||
options.emitSDFGOps = true;
|
||||
#endif
|
||||
options.batchTFHEOps = true;
|
||||
TestCircuit testCircuit(options);
|
||||
TestProgram testCircuit(options);
|
||||
OUTCOME_TRYV(testCircuit.compile({source}));
|
||||
OUTCOME_TRYV(testCircuit.generateKeyset());
|
||||
return std::move(testCircuit);
|
||||
@@ -41,7 +41,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (auto a : values_7bits())
|
||||
for (auto b : values_7bits()) {
|
||||
if (a > b) {
|
||||
@@ -61,7 +61,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (auto a : values_7bits())
|
||||
for (auto b : values_7bits()) {
|
||||
if (a > b) {
|
||||
@@ -81,7 +81,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (auto a : values_3bits())
|
||||
for (auto b : values_3bits()) {
|
||||
if (a > b) {
|
||||
@@ -101,7 +101,7 @@ func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (auto a : values_7bits()) {
|
||||
auto res = circuit.call({Tensor<uint64_t>(a)});
|
||||
ASSERT_TRUE(res.has_value());
|
||||
@@ -119,7 +119,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>, %arg2: !FHE.eint<7>, %
|
||||
return %3: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (auto a : values_3bits()) {
|
||||
for (auto b : values_3bits()) {
|
||||
auto res = circuit.call({
|
||||
@@ -143,7 +143,7 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
return %1: !FHE.eint<3>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (auto a : values_3bits()) {
|
||||
auto res = circuit.call({Tensor<uint64_t>(a)});
|
||||
ASSERT_TRUE(res.has_value());
|
||||
@@ -165,7 +165,7 @@ func.func @main(%arg0: !FHE.eint<4>) -> !FHE.eint<4> {
|
||||
return %6: !FHE.eint<4>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (auto a : values_3bits()) {
|
||||
auto res = circuit.call({Tensor<uint64_t>(a)});
|
||||
ASSERT_TRUE(res.has_value());
|
||||
@@ -182,7 +182,7 @@ TEST(SDFG_unit_tests, tlu_batched) {
|
||||
return %res : tensor<3x3x!FHE.eint<3>>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
auto t = Tensor<uint64_t>({0, 1, 2, 3, 0, 1, 2, 3, 0}, {3, 3});
|
||||
auto expected = Tensor<uint64_t>({1, 3, 5, 7, 1, 3, 5, 7, 1}, {3, 3});
|
||||
auto res = circuit.call({t});
|
||||
@@ -201,7 +201,7 @@ TEST(SDFG_unit_tests, batched_tree) {
|
||||
return %res : tensor<3x3x!FHE.eint<4>>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
auto t = Tensor<uint64_t>({0, 1, 2, 3, 0, 1, 2, 3, 0}, {3, 3});
|
||||
auto a1 = Tensor<uint8_t>({0, 1, 0, 0, 1, 0, 0, 1, 0}, {3, 3});
|
||||
auto a2 = Tensor<uint8_t>({1, 0, 1, 1, 0, 1, 1, 0, 1}, {3, 3});
|
||||
@@ -226,7 +226,7 @@ TEST(SDFG_unit_tests, batched_tree_mapped_tlu) {
|
||||
return %res : tensor<3x3x!FHE.eint<4>>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
auto t = Tensor<uint64_t>({0, 1, 2, 3, 0, 1, 2, 3, 0}, {3, 3});
|
||||
auto a1 = Tensor<uint8_t>({0, 1, 0, 0, 1, 0, 0, 1, 0}, {3, 3});
|
||||
auto a2 = Tensor<uint8_t>({1, 0, 1, 1, 0, 1, 1, 0, 1}, {3, 3});
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/TestLib/TestCircuit.h"
|
||||
#include "concretelang/TestLib/TestProgram.h"
|
||||
|
||||
#include "tests_tools/GtestEnvironment.h"
|
||||
#include "tests_tools/assert.h"
|
||||
@@ -18,17 +18,17 @@ using namespace concretelang::testlib;
|
||||
testing::Environment *const dfr_env =
|
||||
testing::AddGlobalTestEnvironment(new DFREnvironment);
|
||||
|
||||
Result<TestCircuit> setupTestCircuit(std::string source,
|
||||
Result<TestProgram> setupTestProgram(std::string source,
|
||||
std::string funcname = FUNCNAME) {
|
||||
std::vector<std::string> sources = {source};
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
|
||||
mlir::concretelang::CompilationContext::createShared();
|
||||
mlir::concretelang::CompilerEngine ce{ccx};
|
||||
mlir::concretelang::CompilationOptions options(funcname);
|
||||
mlir::concretelang::CompilationOptions options;
|
||||
#ifdef CONCRETELANG_DATAFLOW_TESTING_ENABLED
|
||||
options.dataflowParallelize = true;
|
||||
#endif
|
||||
TestCircuit testCircuit(options);
|
||||
TestProgram testCircuit(options);
|
||||
OUTCOME_TRYV(testCircuit.compile({source}));
|
||||
OUTCOME_TRYV(testCircuit.generateKeyset());
|
||||
return std::move(testCircuit);
|
||||
@@ -67,7 +67,7 @@ func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
return %arg0: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (auto a : values_7bits()) {
|
||||
auto res = circuit.call({Tensor<uint64_t>(a)});
|
||||
ASSERT_TRUE(res.has_value());
|
||||
@@ -82,7 +82,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
return %arg0: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (auto a : values_7bits())
|
||||
for (auto b : values_7bits()) {
|
||||
if (a > b) {
|
||||
@@ -102,7 +102,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (auto a : values_7bits())
|
||||
for (auto b : values_7bits()) {
|
||||
if (a > b) {
|
||||
@@ -122,7 +122,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
auto res = circuit.call({Tensor<uint64_t>(1)});
|
||||
ASSERT_FALSE(res.has_value());
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func.func @main(%arg0: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> {
|
||||
return %1: tensor<1x!FHE.eint<7>>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (auto a : values_7bits()) {
|
||||
auto res = circuit.call({Tensor<uint64_t>(a)});
|
||||
EXPECT_TRUE(res);
|
||||
@@ -150,7 +150,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> tensor<2x!FHE.eint<
|
||||
return %1: tensor<2x!FHE.eint<7>>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (auto a : values_7bits()) {
|
||||
auto res = circuit.call({Tensor<uint64_t>(a), Tensor<uint64_t>(a + 1)});
|
||||
EXPECT_TRUE(res);
|
||||
@@ -168,7 +168,7 @@ func.func @main(%arg0: tensor<1x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (uint8_t a : values_7bits()) {
|
||||
auto ta = Tensor<uint64_t>({a}, {1});
|
||||
auto res = circuit.call({ta});
|
||||
@@ -184,7 +184,7 @@ func.func @main(%arg0: tensor<3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> {
|
||||
return %arg0: tensor<3x!FHE.eint<7>>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
auto ta = Tensor<uint64_t>({1, 2, 3}, {3});
|
||||
auto res = circuit.call({ta});
|
||||
ASSERT_TRUE(res);
|
||||
@@ -202,7 +202,7 @@ func.func @main(%arg0: tensor<3x!FHE.eint<7>>, %arg1: tensor<3x!FHE.eint<7>>) ->
|
||||
return %3: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
auto ta = Tensor<uint64_t>({1, 2, 3}, {3});
|
||||
auto tb = Tensor<uint64_t>({5, 7, 9}, {3});
|
||||
auto res = circuit.call({ta, tb});
|
||||
@@ -219,7 +219,7 @@ func.func @main(%arg0: tensor<2x3x!FHE.eint<7>>) -> tensor<2x3x!FHE.eint<7>> {
|
||||
return %arg0: tensor<2x3x!FHE.eint<7>>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
auto ta = Tensor<uint64_t>({1, 2, 3, 4, 5, 6}, {2, 3});
|
||||
auto res = circuit.call({ta});
|
||||
ASSERT_TRUE(res);
|
||||
@@ -233,7 +233,7 @@ func.func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>>
|
||||
return %arg0: tensor<2x3x1x!FHE.eint<7>>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
auto ta = Tensor<uint64_t>({1, 2, 3, 4, 5, 6}, {2, 3, 1});
|
||||
auto res = circuit.call({ta});
|
||||
ASSERT_TRUE(res);
|
||||
@@ -248,7 +248,7 @@ func.func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>, %arg1: tensor<2x3x1x!FHE.eint
|
||||
return %1: tensor<2x3x1x!FHE.eint<7>>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
auto ta = Tensor<uint64_t>({1, 2, 3, 4, 5, 6}, {2, 3, 1});
|
||||
auto res = circuit.call({ta, ta});
|
||||
ASSERT_TRUE(res);
|
||||
@@ -312,7 +312,7 @@ func.func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<3>) -> !FHE.eint<6> {
|
||||
return %a_plus_b: !FHE.eint<6>
|
||||
}
|
||||
)";
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
|
||||
for (auto a : values_6bits())
|
||||
for (auto b : values_3bits()) {
|
||||
auto res = circuit.call({Tensor<uint64_t>(a), Tensor<uint64_t>(b)});
|
||||
|
||||
@@ -599,6 +599,10 @@ impl OperationDag {
|
||||
self.0.dump()
|
||||
}
|
||||
|
||||
fn concat(&mut self, other: &Self) {
|
||||
self.0.concat(&other.0);
|
||||
}
|
||||
|
||||
fn tag_operator_as_output(&mut self, op: ffi::OperatorIndex) {
|
||||
self.0.tag_operator_as_output(op.into());
|
||||
}
|
||||
@@ -741,6 +745,8 @@ mod ffi {
|
||||
|
||||
fn dump(self: &OperationDag) -> String;
|
||||
|
||||
fn concat(self: &mut OperationDag, other: &OperationDag);
|
||||
|
||||
#[namespace = "concrete_optimizer::dag"]
|
||||
fn dump(self: &CircuitSolution) -> String;
|
||||
|
||||
|
||||
@@ -976,6 +976,7 @@ struct OperationDag final : public ::rust::Opaque {
|
||||
::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept;
|
||||
::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept;
|
||||
::rust::String dump() const noexcept;
|
||||
void concat(::concrete_optimizer::OperationDag const &other) noexcept;
|
||||
void tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept;
|
||||
::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept;
|
||||
~OperationDag() = delete;
|
||||
@@ -1289,6 +1290,8 @@ extern "C" {
|
||||
void concrete_optimizer$cxxbridge1$OperationDag$optimize(::concrete_optimizer::OperationDag const &self, ::concrete_optimizer::Options options, ::concrete_optimizer::dag::DagSolution *return$) noexcept;
|
||||
|
||||
void concrete_optimizer$cxxbridge1$OperationDag$dump(::concrete_optimizer::OperationDag const &self, ::rust::String *return$) noexcept;
|
||||
|
||||
void concrete_optimizer$cxxbridge1$OperationDag$concat(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::OperationDag const &other) noexcept;
|
||||
} // extern "C"
|
||||
|
||||
namespace dag {
|
||||
@@ -1390,6 +1393,10 @@ namespace dag {
|
||||
return ::std::move(return$.value);
|
||||
}
|
||||
|
||||
void OperationDag::concat(::concrete_optimizer::OperationDag const &other) noexcept {
|
||||
concrete_optimizer$cxxbridge1$OperationDag$concat(*this, other);
|
||||
}
|
||||
|
||||
namespace dag {
|
||||
::rust::String CircuitSolution::dump() const noexcept {
|
||||
::rust::MaybeUninit<::rust::String> return$;
|
||||
|
||||
@@ -957,6 +957,7 @@ struct OperationDag final : public ::rust::Opaque {
|
||||
::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept;
|
||||
::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept;
|
||||
::rust::String dump() const noexcept;
|
||||
void concat(::concrete_optimizer::OperationDag const &other) noexcept;
|
||||
void tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept;
|
||||
::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept;
|
||||
~OperationDag() = delete;
|
||||
|
||||
@@ -252,6 +252,29 @@ impl OperationDag {
|
||||
self.add_lut(rounded, table, out_precision)
|
||||
}
|
||||
|
||||
/// Concatenates two dags into a single one (with two disconnected clusters).
|
||||
pub fn concat(&mut self, other: &Self) {
|
||||
let length = self.len();
|
||||
self.operators.extend(other.operators.iter().cloned());
|
||||
self.out_precisions.extend(other.out_precisions.iter());
|
||||
self.out_shapes.extend(other.out_shapes.iter().cloned());
|
||||
self.output_tags.extend(other.output_tags.iter());
|
||||
self.operators[length..]
|
||||
.iter_mut()
|
||||
.for_each(|node| match node {
|
||||
Operator::Lut { ref mut input, .. }
|
||||
| Operator::UnsafeCast { ref mut input, .. }
|
||||
| Operator::Round { ref mut input, .. } => {
|
||||
input.i += length;
|
||||
}
|
||||
Operator::Dot { ref mut inputs, .. }
|
||||
| Operator::LevelledOp { ref mut inputs, .. } => {
|
||||
inputs.iter_mut().for_each(|inp| inp.i += length);
|
||||
}
|
||||
_ => (),
|
||||
});
|
||||
}
|
||||
|
||||
/// Returns an iterator over input nodes indices.
|
||||
pub(crate) fn get_input_index_iter(&self) -> impl Iterator<Item = usize> + '_ {
|
||||
self.operators
|
||||
@@ -315,6 +338,7 @@ impl OperationDag {
|
||||
DotKind::Broadcast { shape } => shape,
|
||||
DotKind::Unsupported { .. } => {
|
||||
let weights_shape = &weights.shape;
|
||||
|
||||
println!();
|
||||
println!();
|
||||
println!("Error diagnostic on dot operation:");
|
||||
@@ -347,6 +371,33 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::dag::operator::Shape;
|
||||
|
||||
#[test]
|
||||
fn graph_concat() {
|
||||
let mut graph1 = OperationDag::new();
|
||||
let a = graph1.add_input(1, Shape::number());
|
||||
let b = graph1.add_input(1, Shape::number());
|
||||
let c = graph1.add_dot([a, b], [1, 1]);
|
||||
let _d = graph1.add_lut(c, FunctionTable::UNKWOWN, 1);
|
||||
let mut graph2 = OperationDag::new();
|
||||
let a = graph2.add_input(2, Shape::number());
|
||||
let b = graph2.add_input(2, Shape::number());
|
||||
let c = graph2.add_dot([a, b], [2, 2]);
|
||||
let _d = graph2.add_lut(c, FunctionTable::UNKWOWN, 2);
|
||||
graph1.concat(&graph2);
|
||||
|
||||
let mut graph3 = OperationDag::new();
|
||||
let a = graph3.add_input(1, Shape::number());
|
||||
let b = graph3.add_input(1, Shape::number());
|
||||
let c = graph3.add_dot([a, b], [1, 1]);
|
||||
let _d = graph3.add_lut(c, FunctionTable::UNKWOWN, 1);
|
||||
let a = graph3.add_input(2, Shape::number());
|
||||
let b = graph3.add_input(2, Shape::number());
|
||||
let c = graph3.add_dot([a, b], [2, 2]);
|
||||
let _d = graph3.add_lut(c, FunctionTable::UNKWOWN, 2);
|
||||
|
||||
assert_eq!(graph1, graph3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn graph_creation() {
|
||||
let mut graph = OperationDag::new();
|
||||
|
||||
Reference in New Issue
Block a user