From 9040e5ab008b535e79c9ca408b16365ee0526eed Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 3 Nov 2021 17:03:16 +0100 Subject: [PATCH] feat(compiler): Add generic overload for result of JitCompilerEngine::Lambda This adds a new overload for `JitCompilerEngine::Lambda::operator()`, returning a result wrapped in a `std::unique_ptr` with meta information about the result. This allows for generic invocations of JitCompilerEngine::Lambda::operator(), where the result type is unknown before the invocation. --- .../zamalang/Support/JitCompilerEngine.h | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/compiler/include/zamalang/Support/JitCompilerEngine.h b/compiler/include/zamalang/Support/JitCompilerEngine.h index 0ef9308a4..174d37878 100644 --- a/compiler/include/zamalang/Support/JitCompilerEngine.h +++ b/compiler/include/zamalang/Support/JitCompilerEngine.h @@ -51,6 +51,43 @@ typedResult(JITLambda::Argument &arguments) { return std::move(res); } +// Specialization of `typedResult()` for a single result wrapped into +// a `LambdaArgument`. +template <> +inline llvm::Expected> +typedResult(JITLambda::Argument &arguments) { + llvm::Expected resTy = + arguments.getResultType(0); + + if (!resTy) + return std::move(resTy.takeError()); + + switch (*resTy) { + case JITLambda::Argument::ResultType::SCALAR: { + uint64_t res; + + if (llvm::Error err = arguments.getResult(0, res)) + return std::move(err); + + return std::move(std::make_unique>(res)); + } + + case JITLambda::Argument::ResultType::TENSOR: { + llvm::Expected> tensorOrError = + typedResult>(arguments); + + if (!tensorOrError) + return std::move(tensorOrError.takeError()); + + return std::move( + std::make_unique>>( + *tensorOrError)); + } + } + + return StreamStringError("Unknown result type"); +} + // Adaptor class that adds arguments specified as instances of // `LambdaArgument` to `JitLambda::Argument`. class JITLambdaArgumentAdaptor {