From f54c0dd8d8ea23d097e5865fc3352bbb46cb3ee4 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 9 Nov 2021 14:23:52 +0100 Subject: [PATCH] enhance(compiler): JitCompilerEngine: Add support for arbitrary int result tensors Add support for result tensors composed of uint8_t, uint16_t, uint32_t and uint64_t elements, replacing the current implementation, which only supports uint64_t tensors. --- .../zamalang/Support/JitCompilerEngine.h | 86 ++++++++++++++----- 1 file changed, 66 insertions(+), 20 deletions(-) diff --git a/compiler/include/zamalang/Support/JitCompilerEngine.h b/compiler/include/zamalang/Support/JitCompilerEngine.h index ac2cef3ef..49de95e2f 100644 --- a/compiler/include/zamalang/Support/JitCompilerEngine.h +++ b/compiler/include/zamalang/Support/JitCompilerEngine.h @@ -32,18 +32,15 @@ inline llvm::Expected typedResult(JITLambda::Argument &arguments) { return res; } -// Specialization of `typedResult()` for vector results, initializing -// an `std::vector` of the right size with the results and forwarding -// it to the caller with move semantics. -template <> -inline llvm::Expected> -typedResult(JITLambda::Argument &arguments) { +template +inline llvm::Expected> +typedVectorResult(JITLambda::Argument &arguments) { llvm::Expected n = arguments.getResultVectorSize(0); if (auto err = n.takeError()) return std::move(err); - std::vector res(*n); + std::vector res(*n); if (auto err = arguments.getResult(0, res.data(), res.size())) return StreamStringError() << "Cannot retrieve result:" << err; @@ -51,6 +48,54 @@ typedResult(JITLambda::Argument &arguments) { return std::move(res); } +// Specializations of `typedResult()` for vector results, initializing +// an `std::vector` of the right size with the results and forwarding +// it to the caller with move semantics. +// +// Cannot factor out into a template template inline +// llvm::Expected> +// typedResult(JITLambda::Argument &arguments); due to ambiguity with +// scalar template +template <> +inline llvm::Expected> +typedResult(JITLambda::Argument &arguments) { + return std::move(typedVectorResult(arguments)); +} +template <> +inline llvm::Expected> +typedResult(JITLambda::Argument &arguments) { + return std::move(typedVectorResult(arguments)); +} +template <> +inline llvm::Expected> +typedResult(JITLambda::Argument &arguments) { + return std::move(typedVectorResult(arguments)); +} +template <> +inline llvm::Expected> +typedResult(JITLambda::Argument &arguments) { + return std::move(typedVectorResult(arguments)); +} + +template +llvm::Expected> +buildTensorLambdaResult(JITLambda::Argument &arguments) { + llvm::Expected> tensorOrError = + typedResult>(arguments); + + if (!tensorOrError) + return std::move(tensorOrError.takeError()); + + llvm::Expected> tensorDimOrError = + arguments.getResultDimensions(0); + + if (!tensorDimOrError) + return std::move(tensorDimOrError.takeError()); + + return std::move(std::make_unique>>( + *tensorOrError, *tensorDimOrError)); +} + // Specialization of `typedResult()` for a single result wrapped into // a `LambdaArgument`. template <> @@ -73,25 +118,26 @@ typedResult(JITLambda::Argument &arguments) { } case JITLambda::Argument::ResultType::TENSOR: { - llvm::Expected> tensorOrError = - typedResult>(arguments); + llvm::Expected width = arguments.getResultWidth(0); - if (!tensorOrError) - return std::move(tensorOrError.takeError()); + if (!width) + return std::move(width.takeError()); - llvm::Expected> tensorDimOrError = - arguments.getResultDimensions(0); - if (!tensorDimOrError) - return std::move(tensorDimOrError.takeError()); - - return std::move( - std::make_unique>>( - *tensorOrError, *tensorDimOrError)); + if (*width > 64) + return StreamStringError("Cannot handle scalars with more than 64 bits"); + if (*width > 32) + return buildTensorLambdaResult(arguments); + else if (*width > 16) + return buildTensorLambdaResult(arguments); + else if (*width > 8) + return buildTensorLambdaResult(arguments); + else if (*width <= 8) + return buildTensorLambdaResult(arguments); } } return StreamStringError("Unknown result type"); -} +} // namespace // Adaptor class that adds arguments specified as instances of // `LambdaArgument` to `JitLambda::Argument`.