mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(compiler): TensorLambdaArgument: Copy input data instead of using a reference
`TensorLambdaArgument` uses an `llvm::MutableArrayRef` to reference the tensor values. This prevents temporary tensors from being used as an argument, due to the data of the `TensorLambdaArgument` being accessed after the destruction of the temporary. This patch changes the type of the data field of `TensorLambdaArgument` from `llvm::MutableArrayRef` to `std::vector` and causes input data to be copied in order to guarantee that all data remains available until invocation of the destructor.
This commit is contained in:
committed by
Ayoub Benaissa
parent
a670ee3f85
commit
c92f047721
@@ -124,14 +124,16 @@ public:
|
||||
// multi-dimensional tensor with the sizes of the dimensions
|
||||
// specified in `dimensions`.
|
||||
TensorLambdaArgument(
|
||||
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value,
|
||||
llvm::ArrayRef<typename ScalarArgumentT::value_type> value,
|
||||
llvm::ArrayRef<int64_t> dimensions)
|
||||
: value(value), dimensions(dimensions.vec()) {}
|
||||
: dimensions(dimensions.vec()) {
|
||||
std::copy(value.begin(), value.end(), std::back_inserter(this->value));
|
||||
}
|
||||
|
||||
// Construct a one-dimensional tensor argument from the
|
||||
// array `value`.
|
||||
TensorLambdaArgument(
|
||||
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value)
|
||||
llvm::ArrayRef<typename ScalarArgumentT::value_type> value)
|
||||
: TensorLambdaArgument(value, {(int64_t)value.size()}) {}
|
||||
|
||||
const std::vector<int64_t> &getDimensions() const { return this->dimensions; }
|
||||
@@ -149,15 +151,22 @@ public:
|
||||
return accu;
|
||||
}
|
||||
|
||||
// Returns a bare pointer to the linearized values of the tensor.
|
||||
typename ScalarArgumentT::value_type *getValue() const {
|
||||
// Returns a bare pointer to the linearized values of the tensor
|
||||
// (constant version).
|
||||
const typename ScalarArgumentT::value_type *getValue() const {
|
||||
return this->value.data();
|
||||
}
|
||||
|
||||
// Returns a bare pointer to the linearized values of the tensor (mutable
|
||||
// version).
|
||||
typename ScalarArgumentT::value_type *getValue() {
|
||||
return this->value.data();
|
||||
}
|
||||
|
||||
static char ID;
|
||||
|
||||
protected:
|
||||
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value;
|
||||
std::vector<typename ScalarArgumentT::value_type> value;
|
||||
std::vector<int64_t> dimensions;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user