fix(compiler): pass dimensions at TensorLambdaArg creation

Author:       youben11
Date:         2021-11-05 14:52:10 +01:00
Committed by: Ayoub Benaissa
Parent:       b501e3d6c0
Commit:       56e261d140

4 changed files with 25 additions and 48 deletions
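
In short: a flat result buffer alone does not say what shape the tensor had, so the result's dimensions are now queried from the JIT lambda and handed to the TensorLambdaArgument at construction time. A minimal, hypothetical sketch of the idea (FlatTensor is illustrative only, not a project class):

    // Illustrative stand-in showing why dimensions must travel with the data.
    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    struct FlatTensor {
      std::vector<uint64_t> data;      // row-major values
      std::vector<int64_t> dimensions; // e.g. {2, 3} for a 2x3 tensor

      FlatTensor(std::vector<uint64_t> d, std::vector<int64_t> dims)
          : data(std::move(d)), dimensions(std::move(dims)) {
        // The flat size must equal the product of the dimensions.
        int64_t expected =
            std::accumulate(dimensions.begin(), dimensions.end(), int64_t{1},
                            std::multiplies<int64_t>());
        assert(static_cast<int64_t>(data.size()) == expected);
      }
    };

    int main() {
      // Without {2, 3}, the buffer {1, 2, 3, 4, 5, 6} could be any shape.
      FlatTensor t({1, 2, 3, 4, 5, 6}, {2, 3});
      (void)t;
      return 0;
    }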


@@ -1,47 +0,0 @@
#ifndef ZAMALANG_SUPPORT_EXECUTION_ARGUMENT_H
#define ZAMALANG_SUPPORT_EXECUTION_ARGUMENT_H

#include <cstdint>
#include <memory>
#include <vector>

namespace mlir {
namespace zamalang {

// Frontend object to abstract the different types of possible arguments,
// namely integers and tensors.
class ExecutionArgument {
public:
  // There are two possible underlying types for the execution argument:
  // either an int or a tensor.
  bool isTensor() { return isTensorArg; }
  bool isInt() { return !isTensorArg; }

  uint8_t *getTensorArgument() { return tensorArg.data(); }
  size_t getTensorSize() { return tensorArg.size(); }
  uint64_t getIntegerArgument() { return intArg; }

  // Create an execution argument from an integer
  static std::shared_ptr<ExecutionArgument> create(uint64_t arg) {
    return std::shared_ptr<ExecutionArgument>(new ExecutionArgument(arg));
  }

  // Create an execution argument from a tensor
  static std::shared_ptr<ExecutionArgument> create(std::vector<uint8_t> arg) {
    return std::shared_ptr<ExecutionArgument>(new ExecutionArgument(arg));
  }

private:
  ExecutionArgument(uint64_t arg) : isTensorArg(false), intArg(arg) {}
  ExecutionArgument(std::vector<uint8_t> tensor)
      : isTensorArg(true), tensorArg(tensor) {}

  uint64_t intArg;
  std::vector<uint8_t> tensorArg;
  bool isTensorArg;
};

} // namespace zamalang
} // namespace mlir

#endif
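
For context, the removed class was consumed roughly as follows; this is a sketch based only on the header above, and the include path is an assumption derived from the header guard:

    #include <cstdint>
    #include <memory>
    #include <vector>
    // Assumed path, derived from ZAMALANG_SUPPORT_EXECUTION_ARGUMENT_H:
    // #include "zamalang/Support/ExecutionArgument.h"

    void example() {
      using mlir::zamalang::ExecutionArgument;
      // Scalar argument
      auto scalar = ExecutionArgument::create(uint64_t{42});
      // Tensor argument
      auto tensor = ExecutionArgument::create(std::vector<uint8_t>{1, 2, 3});
      if (tensor->isTensor()) {
        size_t n = tensor->getTensorSize();          // 3
        uint8_t *data = tensor->getTensorArgument(); // flat byte buffer
        (void)n;
        (void)data;
      }
      (void)scalar;
    }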


@@ -60,6 +60,10 @@ public:
   // `pos` or an error if the result is a scalar value
   llvm::Expected<size_t> getResultVectorSize(size_t pos);
+  // Returns the dimensions of the result tensor at position `pos` or
+  // an error if the result is a scalar value
+  llvm::Expected<std::vector<int64_t>> getResultDimensions(size_t pos);
 private:
   llvm::Error setArg(size_t pos, size_t width, const void *data,
                      llvm::ArrayRef<int64_t> shape);
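
The new accessor follows the same llvm::Expected error-handling convention as getResultVectorSize. A self-contained sketch of the caller-side pattern (getDims is a hypothetical stand-in, not the project API):

    #include <cstdint>
    #include <vector>
    #include "llvm/Support/Error.h"

    // Hypothetical stand-in for getResultDimensions(pos): dimensions on
    // success, an error for scalar results.
    static llvm::Expected<std::vector<int64_t>> getDims(bool isTensor) {
      if (!isTensor)
        return llvm::createStringError(llvm::inconvertibleErrorCode(),
                                       "result is a scalar, not a tensor");
      return std::vector<int64_t>{2, 3};
    }

    llvm::Error consume() {
      auto dimsOrErr = getDims(/*isTensor=*/true);
      if (!dimsOrErr)
        return dimsOrErr.takeError(); // propagate, as typedResult does below
      for (int64_t d : *dimsOrErr)
        (void)d; // use each dimension
      return llvm::Error::success();
    }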


@@ -79,9 +79,14 @@ typedResult(JITLambda::Argument &arguments) {
     if (!tensorOrError)
       return std::move(tensorOrError.takeError());
+    llvm::Expected<std::vector<int64_t>> tensorDimOrError =
+        arguments.getResultDimensions(0);
+    if (!tensorDimOrError)
+      return std::move(tensorDimOrError.takeError());
     return std::move(
         std::make_unique<TensorLambdaArgument<IntLambdaArgument<uint64_t>>>(
-            *tensorOrError));
+            *tensorOrError, *tensorDimOrError));
   }
 }
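
Downstream, the dimensions are what let a consumer view the flat result as an N-D tensor again. A hypothetical sketch of row-major indexing over a returned buffer (names are illustrative, not the compiler's API):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<uint64_t> flat = {1, 2, 3, 4, 5, 6}; // flattened result values
      std::vector<int64_t> dims = {2, 3};              // 2 rows x 3 columns
      // Element (i, j) of the 2-D view lives at flat[i * dims[1] + j].
      for (int64_t i = 0; i < dims[0]; ++i)
        for (int64_t j = 0; j < dims[1]; ++j)
          std::printf("result[%lld][%lld] = %llu\n", (long long)i, (long long)j,
                      (unsigned long long)flat[i * dims[1] + j]);
      return 0;
    }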


@@ -345,6 +345,21 @@ llvm::Expected<size_t> JITLambda::Argument::getResultVectorSize(size_t pos) {
   return info.shape.size;
 }
+// Returns the dimensions of the result tensor at position `pos` or
+// an error if the result is a scalar value
+llvm::Expected<std::vector<int64_t>>
+JITLambda::Argument::getResultDimensions(size_t pos) {
+  auto gate = outputGates[pos];
+  auto info = std::get<0>(gate);
+  if (info.shape.size == 0) {
+    return llvm::createStringError(llvm::inconvertibleErrorCode(),
+                                   "Result at pos %zu is not a tensor", pos);
+  }
+  return info.shape.dimensions;
+}
 llvm::Expected<enum JITLambda::Argument::ResultType>
 JITLambda::Argument::getResultType(size_t pos) {
   if (pos >= outputGates.size()) {