cleanup(compiler/jit): Removing dead code since the preparation of arguments has been factorized thanks the EncryptedArguments

This commit is contained in:
Quentin Bourgerie
2022-03-08 16:30:15 +01:00
parent e5cec23868
commit 1ffd480d07
10 changed files with 12 additions and 586 deletions

View File

@@ -24,92 +24,6 @@ namespace clientlib = ::concretelang::clientlib;
/// of the module.
class JITLambda {
public:
class Argument {
public:
Argument(KeySet &keySet);
~Argument();
// Create lambda Argument that use the given KeySet to perform encryption
// and decryption operations.
static llvm::Expected<std::unique_ptr<Argument>> create(KeySet &keySet);
// Set a scalar argument at the given pos as a uint64_t.
llvm::Error setArg(size_t pos, uint64_t arg);
// Set a argument at the given pos as a 1D tensor of T.
template <typename T>
llvm::Error setArg(size_t pos, const T *data, int64_t dim1) {
return setArg<T>(pos, data, llvm::ArrayRef<int64_t>(&dim1, 1));
}
// Set a argument at the given pos as a tensor of T.
template <typename T>
llvm::Error setArg(size_t pos, const T *data,
llvm::ArrayRef<int64_t> shape) {
return setArg(pos, 8 * sizeof(T), static_cast<const void *>(data), shape);
}
// Get the result at the given pos as an uint64_t.
llvm::Error getResult(size_t pos, uint64_t &res);
// Specifies the type of a result
enum ResultType { SCALAR, TENSOR };
// Returns the result type at position `pos`. If pos is invalid,
// an error is returned.
llvm::Expected<enum ResultType> getResultType(size_t pos);
// Get a result for tensors, fill the `res` buffer with the value of the
// tensor result.
// Returns an error:
// - if the result is a scalar
// - or the size of the `res` buffser doesn't match the size of the tensor.
template <typename T>
llvm::Error getResult(size_t pos, T *res, size_t size) {
return std::move(this->getResult(pos, res, sizeof(T), size));
}
llvm::Error getResult(size_t pos, void *res, size_t elementSize,
size_t numElements);
// Returns the number of elements of the result vector at position
// `pos` or an error if the result is a scalar value
llvm::Expected<size_t> getResultVectorSize(size_t pos);
// Returns the width of the result scalar at position `pos` or the
// width of the scalar values of a vector if the result at
// position `pos` is a tensor.
llvm::Expected<size_t> getResultWidth(size_t pos);
// Returns the dimensions of the result tensor at position `pos` or
// an error if the result is a scalar value
llvm::Expected<std::vector<int64_t>> getResultDimensions(size_t pos);
private:
// Verify if lambda can accept a n-th argument.
llvm::Error emitErrorIfTooManyArgs(size_t n);
llvm::Error setArg(size_t pos, size_t width, const void *data,
llvm::ArrayRef<int64_t> shape);
friend JITLambda;
// Store the pointer on inputs values and outputs values
std::vector<void *> rawArg;
// Store the values of inputs
std::vector<const void *> inputs;
// Store the values of outputs
std::vector<void *> outputs;
// Store the input gates description and the offset of the argument.
std::vector<std::tuple<CircuitGate, size_t /*offset*/>> inputGates;
// Store the outputs gates description and the offset of the argument.
std::vector<std::tuple<CircuitGate, size_t /*offset*/>> outputGates;
// Store allocated lwe ciphertexts (for free)
std::vector<uint64_t *> allocatedCiphertexts;
// Store buffers of ciphertexts
std::vector<uint64_t *> ciphertextBuffers;
KeySet &keySet;
RuntimeContext context;
};
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
: type(type), name(name){};

View File

@@ -341,38 +341,6 @@ public:
}
protected:
template <int pos>
inline llvm::Error addArgs(JITLambda::Argument *jitArgs) {
// base case -- nothing to do
return llvm::Error::success();
}
// Recursive case for scalars: extract first scalar argument from
// parameter pack and forward rest
template <int pos, typename ArgT, typename... Ts>
inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT arg,
Ts... remainder) {
if (auto err = jitArgs->setArg(pos, arg)) {
return StreamStringError()
<< "Cannot push scalar argument " << pos << ": " << err;
}
return this->addArgs<pos + 1>(jitArgs, remainder...);
}
// Recursive case for tensors: extract pointer and size from
// parameter pack and forward rest
template <int pos, typename ArgT, typename... Ts>
inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT *arg,
size_t size, Ts... remainder) {
if (auto err = jitArgs->setArg(pos, arg, size)) {
return StreamStringError()
<< "Cannot push tensor argument " << pos << ": " << err;
}
return this->addArgs<pos + 1>(jitArgs, remainder...);
}
std::unique_ptr<JITLambda> innerLambda;
std::unique_ptr<KeySet> keySet;
std::shared_ptr<CompilationContext> compilationContext;