mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 04:35:03 -05:00
enhance(compiler/runtime): Add runtime tools to handle tensor inputs and outputs
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#ifndef ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_
|
||||
#define ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_
|
||||
#ifndef ZAMALANG_CONVERSION_GLOBALFHECONTEXT_H_
|
||||
#define ZAMALANG_CONVERSION_GLOBALFHECONTEXT_H_
|
||||
#include <cstddef>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@@ -56,7 +56,10 @@ struct EncryptionGate {
|
||||
};
|
||||
|
||||
struct CircuitGateShape {
|
||||
uint64_t size;
|
||||
// Width of the scalar value
|
||||
size_t width;
|
||||
// Size of the buffer
|
||||
size_t size;
|
||||
};
|
||||
|
||||
struct CircuitGate {
|
||||
|
||||
@@ -15,6 +15,9 @@
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
/// CompilerEngine is an tools that provides tools to implements the compilation
|
||||
/// flow and manage the compilation flow state.
|
||||
class CompilerEngine {
|
||||
public:
|
||||
CompilerEngine() {
|
||||
@@ -26,10 +29,16 @@ public:
|
||||
delete context;
|
||||
}
|
||||
|
||||
// Compile an MLIR input
|
||||
llvm::Expected<mlir::LogicalResult> compileFHE(std::string mlir_input);
|
||||
// Compile an mlir programs from it's textual representation.
|
||||
llvm::Error compile(std::string mlirStr);
|
||||
|
||||
// Run the compiled module
|
||||
// Build the jit lambda argument.
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>> buildArgument();
|
||||
|
||||
// Call the compiled function with and argument object.
|
||||
llvm::Error invoke(JITLambda::Argument &arg);
|
||||
|
||||
// Call the compiled function with a list of integer arguments.
|
||||
llvm::Expected<uint64_t> run(std::vector<uint64_t> args);
|
||||
|
||||
// Get a printable representation of the compiled module
|
||||
|
||||
@@ -51,17 +51,54 @@ public:
|
||||
// and decryption operations.
|
||||
static llvm::Expected<std::unique_ptr<Argument>> create(KeySet &keySet);
|
||||
|
||||
// Set the argument at the given pos as a uint64_t.
|
||||
// Set a scalar argument at the given pos as a uint64_t.
|
||||
llvm::Error setArg(size_t pos, uint64_t arg);
|
||||
|
||||
// Set a argument at the given pos as a tensor of int64.
|
||||
llvm::Error setArg(size_t pos, uint64_t *data, size_t size) {
|
||||
return setArg(pos, 64, (void *)data, size);
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of int32.
|
||||
llvm::Error setArg(size_t pos, uint32_t *data, size_t size) {
|
||||
return setArg(pos, 32, (void *)data, size);
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of int32.
|
||||
llvm::Error setArg(size_t pos, uint16_t *data, size_t size) {
|
||||
return setArg(pos, 16, (void *)data, size);
|
||||
}
|
||||
|
||||
// Set a tensor argument at the given pos as a uint64_t.
|
||||
llvm::Error setArg(size_t pos, uint8_t *data, size_t size) {
|
||||
return setArg(pos, 8, (void *)data, size);
|
||||
}
|
||||
|
||||
// Get the result at the given pos as an uint64_t.
|
||||
llvm::Error getResult(size_t pos, uint64_t &res);
|
||||
|
||||
// Fill the result.
|
||||
llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
|
||||
|
||||
private:
|
||||
llvm::Error setArg(size_t pos, size_t width, void *data, size_t size);
|
||||
|
||||
friend JITLambda;
|
||||
// Store the pointer on inputs values and outputs values
|
||||
std::vector<void *> rawArg;
|
||||
// Store the values of inputs
|
||||
std::vector<void *> inputs;
|
||||
std::vector<void *> results;
|
||||
// Store the values of outputs
|
||||
std::vector<void *> outputs;
|
||||
// Store the input gates description and the offset of the argument.
|
||||
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> inputGates;
|
||||
// Store the outputs gates description and the offset of the argument.
|
||||
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> outputGates;
|
||||
// Store allocated lwe ciphertexts (for free)
|
||||
std::vector<LweCiphertext_u64 *> allocatedCiphertexts;
|
||||
// Store buffers of ciphertexts
|
||||
std::vector<LweCiphertext_u64 **> ciphertextBuffers;
|
||||
|
||||
KeySet &keySet;
|
||||
};
|
||||
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
|
||||
|
||||
@@ -37,6 +37,9 @@ public:
|
||||
size_t numInputs() { return inputs.size(); }
|
||||
size_t numOutputs() { return outputs.size(); }
|
||||
|
||||
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
|
||||
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
|
||||
|
||||
protected:
|
||||
llvm::Error generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
|
||||
SecretRandomGenerator *generator);
|
||||
|
||||
Reference in New Issue
Block a user