diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index a6949b694..6af5074e8 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -28,14 +28,15 @@ public: // Set a argument at the given pos as a 1D tensor of T. template - llvm::Error setArg(size_t pos, T *data, int64_t dim1) { + llvm::Error setArg(size_t pos, const T *data, int64_t dim1) { return setArg(pos, data, llvm::ArrayRef(&dim1, 1)); } // Set a argument at the given pos as a tensor of T. template - llvm::Error setArg(size_t pos, T *data, llvm::ArrayRef shape) { - return setArg(pos, 8 * sizeof(T), static_cast(data), shape); + llvm::Error setArg(size_t pos, const T *data, + llvm::ArrayRef shape) { + return setArg(pos, 8 * sizeof(T), static_cast(data), shape); } // Get the result at the given pos as an uint64_t. @@ -60,14 +61,14 @@ public: llvm::Expected getResultVectorSize(size_t pos); private: - llvm::Error setArg(size_t pos, size_t width, void *data, + llvm::Error setArg(size_t pos, size_t width, const void *data, llvm::ArrayRef shape); friend JITLambda; // Store the pointer on inputs values and outputs values std::vector rawArg; // Store the values of inputs - std::vector inputs; + std::vector inputs; // Store the values of outputs std::vector outputs; // Store the input gates description and the offset of the argument. diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 91ab0b826..2b4c65234 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -89,7 +89,7 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { // dimension of the tensor. numInputs = numInputs + 2 * keySet.inputGate(i).shape.dimensions.size(); } - inputs = std::vector(numInputs); + inputs = std::vector(numInputs); } // Setting the outputs @@ -180,7 +180,8 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { return llvm::Error::success(); } -llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data, +llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, + const void *data, llvm::ArrayRef shape) { auto gate = inputGates[pos]; auto info = std::get<0>(gate);