fix(compiler): TensorLambdaArgument: Copy input data instead of using a reference

`TensorLambdaArgument` uses an `llvm::MutableArrayRef` to reference
the tensor values. This prevents temporary tensors from being used as
an argument, due to the data of the `TensorLambdaArgument` being
accessed after the destruction of the temporary.

This patch changes the type of the data field of
`TensorLambdaArgument` from `llvm::MutableArrayRef` to `std::vector`
and causes input data to be copied in order to guarantee that all data
remains available until invocation of the destructor.
This commit is contained in:
Andi Drebes
2021-11-04 17:06:40 +01:00
committed by Ayoub Benaissa
parent a670ee3f85
commit c92f047721

View File

@@ -124,14 +124,16 @@ public:
// multi-dimensional tensor with the sizes of the dimensions
// specified in `dimensions`.
TensorLambdaArgument(
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value,
llvm::ArrayRef<typename ScalarArgumentT::value_type> value,
llvm::ArrayRef<int64_t> dimensions)
: value(value), dimensions(dimensions.vec()) {}
: dimensions(dimensions.vec()) {
std::copy(value.begin(), value.end(), std::back_inserter(this->value));
}
// Construct a one-dimensional tensor argument from the
// array `value`.
TensorLambdaArgument(
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value)
llvm::ArrayRef<typename ScalarArgumentT::value_type> value)
: TensorLambdaArgument(value, {(int64_t)value.size()}) {}
const std::vector<int64_t> &getDimensions() const { return this->dimensions; }
@@ -149,15 +151,22 @@ public:
return accu;
}
// Returns a bare pointer to the linearized values of the tensor.
typename ScalarArgumentT::value_type *getValue() const {
// Returns a bare pointer to the linearized values of the tensor
// (constant version).
const typename ScalarArgumentT::value_type *getValue() const {
return this->value.data();
}
// Returns a bare pointer to the linearized values of the tensor (mutable
// version).
typename ScalarArgumentT::value_type *getValue() {
return this->value.data();
}
static char ID;
protected:
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value;
std::vector<typename ScalarArgumentT::value_type> value;
std::vector<int64_t> dimensions;
};