diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index bda19b08c..35ef9cf8c 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -66,6 +66,11 @@ public: // `pos` or an error if the result is a scalar value llvm::Expected getResultVectorSize(size_t pos); + // Returns the width of the result scalar at position `pos` or the + // width of the scalar values of a vector if the result at + // position `pos` is a tensor. + llvm::Expected getResultWidth(size_t pos); + // Returns the dimensions of the result tensor at position `pos` or // an error if the result is a scalar value llvm::Expected> getResultDimensions(size_t pos); diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index a60f88ea7..7278aa0a7 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -379,6 +379,24 @@ JITLambda::Argument::getResultType(size_t pos) { } } +llvm::Expected JITLambda::Argument::getResultWidth(size_t pos) { + if (pos >= outputGates.size()) { + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "Requesting width for result at index %zu, " + "but lambda only generates %zu results", + pos, outputGates.size()); + } + + auto gate = outputGates[pos]; + auto info = std::get<0>(gate); + + // Encrypted values are always returned as 64-bit values for now + if (info.encryption.hasValue()) + return 64; + else + return info.shape.width; +} + llvm::Error JITLambda::Argument::getResult(size_t pos, void *res, size_t elementSize, size_t numElements) {