enhance(compiler): Add function JITLambda::Arguments::getResultWidth

Add a new method `JITLambda::Arguments::getResultWidth` returning the
width of a scalar result or the element type of a tensor result at a
given position.
This commit is contained in:
Andi Drebes
2021-11-09 12:02:39 +01:00
committed by Quentin Bourgerie
parent e4cd340e36
commit 3118983287
2 changed files with 23 additions and 0 deletions

View File

@@ -66,6 +66,11 @@ public:
// `pos` or an error if the result is a scalar value
llvm::Expected<size_t> getResultVectorSize(size_t pos);
// Returns the width of the result scalar at position `pos` or the
// width of the scalar values of a vector if the result at
// position `pos` is a tensor.
llvm::Expected<size_t> getResultWidth(size_t pos);
// Returns the dimensions of the result tensor at position `pos` or
// an error if the result is a scalar value
llvm::Expected<std::vector<int64_t>> getResultDimensions(size_t pos);

View File

@@ -379,6 +379,24 @@ JITLambda::Argument::getResultType(size_t pos) {
}
}
llvm::Expected<size_t> JITLambda::Argument::getResultWidth(size_t pos) {
if (pos >= outputGates.size()) {
return llvm::createStringError(llvm::inconvertibleErrorCode(),
"Requesting width for result at index %zu, "
"but lambda only generates %zu results",
pos, outputGates.size());
}
auto gate = outputGates[pos];
auto info = std::get<0>(gate);
// Encrypted values are always returned as 64-bit values for now
if (info.encryption.hasValue())
return 64;
else
return info.shape.width;
}
llvm::Error JITLambda::Argument::getResult(size_t pos, void *res,
size_t elementSize,
size_t numElements) {