enhance(compiler/runtime): Add runtime tools to handle tensor inputs and outputs

This commit is contained in:
Quentin Bourgerie
2021-08-24 15:02:45 +02:00
parent dba76a1e1b
commit af0789f128
11 changed files with 701 additions and 73 deletions

View File

@@ -1,5 +1,5 @@
#ifndef ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_
#define ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_
#ifndef ZAMALANG_CONVERSION_GLOBALFHECONTEXT_H_
#define ZAMALANG_CONVERSION_GLOBALFHECONTEXT_H_
#include <cstddef>
namespace mlir {

View File

@@ -56,7 +56,10 @@ struct EncryptionGate {
};
struct CircuitGateShape {
uint64_t size;
// Width of the scalar value
size_t width;
// Size of the buffer
size_t size;
};
struct CircuitGate {

View File

@@ -15,6 +15,9 @@
namespace mlir {
namespace zamalang {
/// CompilerEngine is an tools that provides tools to implements the compilation
/// flow and manage the compilation flow state.
class CompilerEngine {
public:
CompilerEngine() {
@@ -26,10 +29,16 @@ public:
delete context;
}
// Compile an MLIR input
llvm::Expected<mlir::LogicalResult> compileFHE(std::string mlir_input);
// Compile an mlir programs from it's textual representation.
llvm::Error compile(std::string mlirStr);
// Run the compiled module
// Build the jit lambda argument.
llvm::Expected<std::unique_ptr<JITLambda::Argument>> buildArgument();
// Call the compiled function with and argument object.
llvm::Error invoke(JITLambda::Argument &arg);
// Call the compiled function with a list of integer arguments.
llvm::Expected<uint64_t> run(std::vector<uint64_t> args);
// Get a printable representation of the compiled module

View File

@@ -51,17 +51,54 @@ public:
// and decryption operations.
static llvm::Expected<std::unique_ptr<Argument>> create(KeySet &keySet);
// Set the argument at the given pos as a uint64_t.
// Set a scalar argument at the given pos as a uint64_t.
llvm::Error setArg(size_t pos, uint64_t arg);
// Set a argument at the given pos as a tensor of int64.
llvm::Error setArg(size_t pos, uint64_t *data, size_t size) {
return setArg(pos, 64, (void *)data, size);
}
// Set a argument at the given pos as a tensor of int32.
llvm::Error setArg(size_t pos, uint32_t *data, size_t size) {
return setArg(pos, 32, (void *)data, size);
}
// Set a argument at the given pos as a tensor of int32.
llvm::Error setArg(size_t pos, uint16_t *data, size_t size) {
return setArg(pos, 16, (void *)data, size);
}
// Set a tensor argument at the given pos as a uint64_t.
llvm::Error setArg(size_t pos, uint8_t *data, size_t size) {
return setArg(pos, 8, (void *)data, size);
}
// Get the result at the given pos as an uint64_t.
llvm::Error getResult(size_t pos, uint64_t &res);
// Fill the result.
llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
private:
llvm::Error setArg(size_t pos, size_t width, void *data, size_t size);
friend JITLambda;
// Store the pointer on inputs values and outputs values
std::vector<void *> rawArg;
// Store the values of inputs
std::vector<void *> inputs;
std::vector<void *> results;
// Store the values of outputs
std::vector<void *> outputs;
// Store the input gates description and the offset of the argument.
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> inputGates;
// Store the outputs gates description and the offset of the argument.
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> outputGates;
// Store allocated lwe ciphertexts (for free)
std::vector<LweCiphertext_u64 *> allocatedCiphertexts;
// Store buffers of ciphertexts
std::vector<LweCiphertext_u64 **> ciphertextBuffers;
KeySet &keySet;
};
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)

View File

@@ -37,6 +37,9 @@ public:
size_t numInputs() { return inputs.size(); }
size_t numOutputs() { return outputs.size(); }
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
protected:
llvm::Error generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
SecretRandomGenerator *generator);