From a670ee3f8534c336e96b64c2780cfecb0fa18fb2 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Thu, 4 Nov 2021 17:00:40 +0100 Subject: [PATCH] enhance(compiler): Use const pointers in JITLambda::Arguments::setArg All results in code compiled by zamacompiler are passed as return values, which means that all tensors passed as function arguments are constant inputs that are never written. This patch changes the arguments used as data pointers for input tensors in `JITLambda::Arguments::setArg()` from `void*` to `const void*` to emphasize their use as inputs and to allow for constant arrays to be passed as function inputs. --- compiler/include/zamalang/Support/Jit.h | 11 ++++++----- compiler/lib/Support/Jit.cpp | 5 +++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index a6949b694..6af5074e8 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -28,14 +28,15 @@ public: // Set a argument at the given pos as a 1D tensor of T. template - llvm::Error setArg(size_t pos, T *data, int64_t dim1) { + llvm::Error setArg(size_t pos, const T *data, int64_t dim1) { return setArg(pos, data, llvm::ArrayRef(&dim1, 1)); } // Set a argument at the given pos as a tensor of T. template - llvm::Error setArg(size_t pos, T *data, llvm::ArrayRef shape) { - return setArg(pos, 8 * sizeof(T), static_cast(data), shape); + llvm::Error setArg(size_t pos, const T *data, + llvm::ArrayRef shape) { + return setArg(pos, 8 * sizeof(T), static_cast(data), shape); } // Get the result at the given pos as an uint64_t. @@ -60,14 +61,14 @@ public: llvm::Expected getResultVectorSize(size_t pos); private: - llvm::Error setArg(size_t pos, size_t width, void *data, + llvm::Error setArg(size_t pos, size_t width, const void *data, llvm::ArrayRef shape); friend JITLambda; // Store the pointer on inputs values and outputs values std::vector rawArg; // Store the values of inputs - std::vector inputs; + std::vector inputs; // Store the values of outputs std::vector outputs; // Store the input gates description and the offset of the argument. diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 91ab0b826..2b4c65234 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -89,7 +89,7 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { // dimension of the tensor. numInputs = numInputs + 2 * keySet.inputGate(i).shape.dimensions.size(); } - inputs = std::vector(numInputs); + inputs = std::vector(numInputs); } // Setting the outputs @@ -180,7 +180,8 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { return llvm::Error::success(); } -llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data, +llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, + const void *data, llvm::ArrayRef shape) { auto gate = inputGates[pos]; auto info = std::get<0>(gate);