diff --git a/compiler/include/zamalang/Support/JitCompilerEngine.h b/compiler/include/zamalang/Support/JitCompilerEngine.h index 0ef9308a4..174d37878 100644 --- a/compiler/include/zamalang/Support/JitCompilerEngine.h +++ b/compiler/include/zamalang/Support/JitCompilerEngine.h @@ -51,6 +51,43 @@ typedResult(JITLambda::Argument &arguments) { return std::move(res); } +// Specialization of `typedResult()` for a single result wrapped into +// a `LambdaArgument`. +template <> +inline llvm::Expected> +typedResult(JITLambda::Argument &arguments) { + llvm::Expected resTy = + arguments.getResultType(0); + + if (!resTy) + return std::move(resTy.takeError()); + + switch (*resTy) { + case JITLambda::Argument::ResultType::SCALAR: { + uint64_t res; + + if (llvm::Error err = arguments.getResult(0, res)) + return std::move(err); + + return std::move(std::make_unique>(res)); + } + + case JITLambda::Argument::ResultType::TENSOR: { + llvm::Expected> tensorOrError = + typedResult>(arguments); + + if (!tensorOrError) + return std::move(tensorOrError.takeError()); + + return std::move( + std::make_unique>>( + *tensorOrError)); + } + } + + return StreamStringError("Unknown result type"); +} + // Adaptor class that adds arguments specified as instances of // `LambdaArgument` to `JitLambda::Argument`. class JITLambdaArgumentAdaptor {