diff --git a/compiler/include/zamalang/Support/LambdaArgument.h b/compiler/include/zamalang/Support/LambdaArgument.h index fc26d5d64..c2b15d600 100644 --- a/compiler/include/zamalang/Support/LambdaArgument.h +++ b/compiler/include/zamalang/Support/LambdaArgument.h @@ -124,14 +124,16 @@ public: // multi-dimensional tensor with the sizes of the dimensions // specified in `dimensions`. TensorLambdaArgument( - llvm::MutableArrayRef value, + llvm::ArrayRef value, llvm::ArrayRef dimensions) - : value(value), dimensions(dimensions.vec()) {} + : dimensions(dimensions.vec()) { + std::copy(value.begin(), value.end(), std::back_inserter(this->value)); + } // Construct a one-dimensional tensor argument from the // array `value`. TensorLambdaArgument( - llvm::MutableArrayRef value) + llvm::ArrayRef value) : TensorLambdaArgument(value, {(int64_t)value.size()}) {} const std::vector &getDimensions() const { return this->dimensions; } @@ -149,15 +151,22 @@ public: return accu; } - // Returns a bare pointer to the linearized values of the tensor. - typename ScalarArgumentT::value_type *getValue() const { + // Returns a bare pointer to the linearized values of the tensor + // (constant version). + const typename ScalarArgumentT::value_type *getValue() const { + return this->value.data(); + } + + // Returns a bare pointer to the linearized values of the tensor (mutable + // version). + typename ScalarArgumentT::value_type *getValue() { return this->value.data(); } static char ID; protected: - llvm::MutableArrayRef value; + std::vector value; std::vector dimensions; };