mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
enhance(compiler): Use const pointers in JITLambda::Arguments::setArg
All results in code compiled by zamacompiler are passed as return values, which means that all tensors passed as function arguments are constant inputs that are never written. This patch changes the arguments used as data pointers for input tensors in `JITLambda::Arguments::setArg()` from `void*` to `const void*` to emphasize their use as inputs and to allow for constant arrays to be passed as function inputs.
This commit is contained in:
committed by
Ayoub Benaissa
parent
2033a70ad2
commit
a670ee3f85
@@ -28,14 +28,15 @@ public:
|
||||
|
||||
// Set a argument at the given pos as a 1D tensor of T.
|
||||
template <typename T>
|
||||
llvm::Error setArg(size_t pos, T *data, int64_t dim1) {
|
||||
llvm::Error setArg(size_t pos, const T *data, int64_t dim1) {
|
||||
return setArg<T>(pos, data, llvm::ArrayRef<int64_t>(&dim1, 1));
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of T.
|
||||
template <typename T>
|
||||
llvm::Error setArg(size_t pos, T *data, llvm::ArrayRef<int64_t> shape) {
|
||||
return setArg(pos, 8 * sizeof(T), static_cast<void *>(data), shape);
|
||||
llvm::Error setArg(size_t pos, const T *data,
|
||||
llvm::ArrayRef<int64_t> shape) {
|
||||
return setArg(pos, 8 * sizeof(T), static_cast<const void *>(data), shape);
|
||||
}
|
||||
|
||||
// Get the result at the given pos as an uint64_t.
|
||||
@@ -60,14 +61,14 @@ public:
|
||||
llvm::Expected<size_t> getResultVectorSize(size_t pos);
|
||||
|
||||
private:
|
||||
llvm::Error setArg(size_t pos, size_t width, void *data,
|
||||
llvm::Error setArg(size_t pos, size_t width, const void *data,
|
||||
llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
friend JITLambda;
|
||||
// Store the pointer on inputs values and outputs values
|
||||
std::vector<void *> rawArg;
|
||||
// Store the values of inputs
|
||||
std::vector<void *> inputs;
|
||||
std::vector<const void *> inputs;
|
||||
// Store the values of outputs
|
||||
std::vector<void *> outputs;
|
||||
// Store the input gates description and the offset of the argument.
|
||||
|
||||
@@ -89,7 +89,7 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
|
||||
// dimension of the tensor.
|
||||
numInputs = numInputs + 2 * keySet.inputGate(i).shape.dimensions.size();
|
||||
}
|
||||
inputs = std::vector<void *>(numInputs);
|
||||
inputs = std::vector<const void *>(numInputs);
|
||||
}
|
||||
|
||||
// Setting the outputs
|
||||
@@ -180,7 +180,8 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data,
|
||||
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
|
||||
const void *data,
|
||||
llvm::ArrayRef<int64_t> shape) {
|
||||
auto gate = inputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
|
||||
Reference in New Issue
Block a user