From c92f04772132cc573c2bf0d503b37301e663b3c2 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Thu, 4 Nov 2021 17:06:40 +0100 Subject: [PATCH] fix(compiler): TensorLambdaArgument: Copy input data instead of using a reference `TensorLambdaArgument` uses an `llvm::MutableArrayRef` to reference the tensor values. This prevents temporary tensors from being used as an argument, due to the data of the `TensorLambdaArgument` being accessed after the destruction of the temporary. This patch changes the type of the data field of `TensorLambdaArgument` from `llvm::MutableArrayRef` to `std::vector` and causes input data to be copied in order to guarantee that all data remains available until invocation of the destructor. --- .../include/zamalang/Support/LambdaArgument.h | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/compiler/include/zamalang/Support/LambdaArgument.h b/compiler/include/zamalang/Support/LambdaArgument.h index fc26d5d64..c2b15d600 100644 --- a/compiler/include/zamalang/Support/LambdaArgument.h +++ b/compiler/include/zamalang/Support/LambdaArgument.h @@ -124,14 +124,16 @@ public: // multi-dimensional tensor with the sizes of the dimensions // specified in `dimensions`. TensorLambdaArgument( - llvm::MutableArrayRef value, + llvm::ArrayRef value, llvm::ArrayRef dimensions) - : value(value), dimensions(dimensions.vec()) {} + : dimensions(dimensions.vec()) { + std::copy(value.begin(), value.end(), std::back_inserter(this->value)); + } // Construct a one-dimensional tensor argument from the // array `value`. TensorLambdaArgument( - llvm::MutableArrayRef value) + llvm::ArrayRef value) : TensorLambdaArgument(value, {(int64_t)value.size()}) {} const std::vector &getDimensions() const { return this->dimensions; } @@ -149,15 +151,22 @@ public: return accu; } - // Returns a bare pointer to the linearized values of the tensor. - typename ScalarArgumentT::value_type *getValue() const { + // Returns a bare pointer to the linearized values of the tensor + // (constant version). + const typename ScalarArgumentT::value_type *getValue() const { + return this->value.data(); + } + + // Returns a bare pointer to the linearized values of the tensor (mutable + // version). + typename ScalarArgumentT::value_type *getValue() { return this->value.data(); } static char ID; protected: - llvm::MutableArrayRef value; + std::vector value; std::vector dimensions; };