mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
cleanup(compiler/jit): Removing dead code since the preparation of arguments has been factorized thanks the EncryptedArguments
This commit is contained in:
@@ -24,92 +24,6 @@ namespace clientlib = ::concretelang::clientlib;
|
||||
/// of the module.
|
||||
class JITLambda {
|
||||
public:
|
||||
class Argument {
|
||||
public:
|
||||
Argument(KeySet &keySet);
|
||||
~Argument();
|
||||
|
||||
// Create lambda Argument that use the given KeySet to perform encryption
|
||||
// and decryption operations.
|
||||
static llvm::Expected<std::unique_ptr<Argument>> create(KeySet &keySet);
|
||||
|
||||
// Set a scalar argument at the given pos as a uint64_t.
|
||||
llvm::Error setArg(size_t pos, uint64_t arg);
|
||||
|
||||
// Set a argument at the given pos as a 1D tensor of T.
|
||||
template <typename T>
|
||||
llvm::Error setArg(size_t pos, const T *data, int64_t dim1) {
|
||||
return setArg<T>(pos, data, llvm::ArrayRef<int64_t>(&dim1, 1));
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of T.
|
||||
template <typename T>
|
||||
llvm::Error setArg(size_t pos, const T *data,
|
||||
llvm::ArrayRef<int64_t> shape) {
|
||||
return setArg(pos, 8 * sizeof(T), static_cast<const void *>(data), shape);
|
||||
}
|
||||
|
||||
// Get the result at the given pos as an uint64_t.
|
||||
llvm::Error getResult(size_t pos, uint64_t &res);
|
||||
|
||||
// Specifies the type of a result
|
||||
enum ResultType { SCALAR, TENSOR };
|
||||
|
||||
// Returns the result type at position `pos`. If pos is invalid,
|
||||
// an error is returned.
|
||||
llvm::Expected<enum ResultType> getResultType(size_t pos);
|
||||
|
||||
// Get a result for tensors, fill the `res` buffer with the value of the
|
||||
// tensor result.
|
||||
// Returns an error:
|
||||
// - if the result is a scalar
|
||||
// - or the size of the `res` buffser doesn't match the size of the tensor.
|
||||
template <typename T>
|
||||
llvm::Error getResult(size_t pos, T *res, size_t size) {
|
||||
return std::move(this->getResult(pos, res, sizeof(T), size));
|
||||
}
|
||||
|
||||
llvm::Error getResult(size_t pos, void *res, size_t elementSize,
|
||||
size_t numElements);
|
||||
|
||||
// Returns the number of elements of the result vector at position
|
||||
// `pos` or an error if the result is a scalar value
|
||||
llvm::Expected<size_t> getResultVectorSize(size_t pos);
|
||||
|
||||
// Returns the width of the result scalar at position `pos` or the
|
||||
// width of the scalar values of a vector if the result at
|
||||
// position `pos` is a tensor.
|
||||
llvm::Expected<size_t> getResultWidth(size_t pos);
|
||||
|
||||
// Returns the dimensions of the result tensor at position `pos` or
|
||||
// an error if the result is a scalar value
|
||||
llvm::Expected<std::vector<int64_t>> getResultDimensions(size_t pos);
|
||||
|
||||
private:
|
||||
// Verify if lambda can accept a n-th argument.
|
||||
llvm::Error emitErrorIfTooManyArgs(size_t n);
|
||||
llvm::Error setArg(size_t pos, size_t width, const void *data,
|
||||
llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
friend JITLambda;
|
||||
// Store the pointer on inputs values and outputs values
|
||||
std::vector<void *> rawArg;
|
||||
// Store the values of inputs
|
||||
std::vector<const void *> inputs;
|
||||
// Store the values of outputs
|
||||
std::vector<void *> outputs;
|
||||
// Store the input gates description and the offset of the argument.
|
||||
std::vector<std::tuple<CircuitGate, size_t /*offset*/>> inputGates;
|
||||
// Store the outputs gates description and the offset of the argument.
|
||||
std::vector<std::tuple<CircuitGate, size_t /*offset*/>> outputGates;
|
||||
// Store allocated lwe ciphertexts (for free)
|
||||
std::vector<uint64_t *> allocatedCiphertexts;
|
||||
// Store buffers of ciphertexts
|
||||
std::vector<uint64_t *> ciphertextBuffers;
|
||||
|
||||
KeySet &keySet;
|
||||
RuntimeContext context;
|
||||
};
|
||||
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
|
||||
: type(type), name(name){};
|
||||
|
||||
|
||||
@@ -341,38 +341,6 @@ public:
|
||||
}
|
||||
|
||||
protected:
|
||||
template <int pos>
|
||||
inline llvm::Error addArgs(JITLambda::Argument *jitArgs) {
|
||||
// base case -- nothing to do
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
// Recursive case for scalars: extract first scalar argument from
|
||||
// parameter pack and forward rest
|
||||
template <int pos, typename ArgT, typename... Ts>
|
||||
inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT arg,
|
||||
Ts... remainder) {
|
||||
if (auto err = jitArgs->setArg(pos, arg)) {
|
||||
return StreamStringError()
|
||||
<< "Cannot push scalar argument " << pos << ": " << err;
|
||||
}
|
||||
|
||||
return this->addArgs<pos + 1>(jitArgs, remainder...);
|
||||
}
|
||||
|
||||
// Recursive case for tensors: extract pointer and size from
|
||||
// parameter pack and forward rest
|
||||
template <int pos, typename ArgT, typename... Ts>
|
||||
inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT *arg,
|
||||
size_t size, Ts... remainder) {
|
||||
if (auto err = jitArgs->setArg(pos, arg, size)) {
|
||||
return StreamStringError()
|
||||
<< "Cannot push tensor argument " << pos << ": " << err;
|
||||
}
|
||||
|
||||
return this->addArgs<pos + 1>(jitArgs, remainder...);
|
||||
}
|
||||
|
||||
std::unique_ptr<JITLambda> innerLambda;
|
||||
std::unique_ptr<KeySet> keySet;
|
||||
std::shared_ptr<CompilationContext> compilationContext;
|
||||
|
||||
Reference in New Issue
Block a user