#include "zamalang-c/Support/CompilerEngine.h"
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/Jit.h"
#include "zamalang/Support/JitCompilerEngine.h"

#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

using mlir::zamalang::JitCompilerEngine;

// NOTE(review): this file had lost its line breaks and every `<...>` template
// argument list. The nesting of the template arguments below follows the
// closing brackets that were left behind; the concrete type names (the
// `uint8_t` tensor element type, `int64_t` dimensions, `llvm::StringRef`
// optional payload) are reconstructed from how the values are used and must
// be confirmed against zamalang-c/Support/CompilerEngine.h before merging.

namespace {

/// Base class of every value passed to / returned from a JIT lambda.
using LambdaArg = mlir::zamalang::LambdaArgument;

/// Tensor-of-integers argument handled by the tensor accessors below.
using TensorArg = mlir::zamalang::TensorLambdaArgument<
    mlir::zamalang::IntLambdaArgument<uint8_t>>;

/// Scalar integer argument handled by the scalar accessors below.
using ScalarArg = mlir::zamalang::IntLambdaArgument<uint64_t>;

/// Consumes \p err and throws a std::runtime_error whose message is
/// \p prefix followed by the textual form of the error.
[[noreturn]] void throwRuntimeError(const char *prefix, llvm::Error err) {
  throw std::runtime_error(prefix + llvm::toString(std::move(err)));
}

/// Returns the tensor held by \p lambda_arg.
/// Throws std::invalid_argument when the argument is not a TensorArg.
TensorArg &requireTensor(lambdaArgument &lambda_arg) {
  TensorArg *tensor = lambda_arg.ptr->dyn_cast<TensorArg>();
  if (tensor == nullptr) {
    throw std::invalid_argument(
        "LambdaArgument isn't a tensor, should "
        "be a TensorLambdaArgument<IntLambdaArgument<uint8_t>>");
  }
  return *tensor;
}

} // namespace

/// Compiles \p module with a fresh JitCompilerEngine and returns the lambda
/// for the function named \p funcName.
///
/// \param runtimeLibPath forwarded to the engine when non-null; a null
///        pointer leaves the optional empty.
/// \throws std::runtime_error ("Compilation failed: ...") when the engine
///         reports an error.
mlir::zamalang::JitCompilerEngine::Lambda
buildLambda(const char *module, const char *funcName,
            const char *runtimeLibPath) {
  // Set the runtime library path if not nullptr
  llvm::Optional<llvm::StringRef> runtimeLibPathOptional = {};
  if (runtimeLibPath != nullptr)
    runtimeLibPathOptional = runtimeLibPath;

  mlir::zamalang::JitCompilerEngine engine;
  llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
      engine.buildLambda(module, funcName, runtimeLibPathOptional);
  if (!lambdaOrErr)
    throwRuntimeError("Compilation failed: ", lambdaOrErr.takeError());

  // Not a local object: the value has to be moved out of the Expected.
  return std::move(*lambdaOrErr);
}

/// Invokes the lambda referenced by \p l with the arguments in \p args.
///
/// \throws std::invalid_argument when \p l holds a null pointer or when
///         args.size differs from the lambda's argument count.
/// \throws std::runtime_error ("Lambda invocation failed: ...") when the
///         invocation itself reports an error.
lambdaArgument invokeLambda(lambda l, executionArguments args) {
  mlir::zamalang::JitCompilerEngine::Lambda *lambda_ptr =
      (mlir::zamalang::JitCompilerEngine::Lambda *)l.ptr;
  if (lambda_ptr == nullptr) {
    throw std::invalid_argument("lambda pointer is null");
  }
  if (args.size != lambda_ptr->getNumArguments()) {
    throw std::invalid_argument("wrong number of arguments");
  }

  // Set the integer/tensor arguments. Only non-owning pointers are collected;
  // ownership stays with `args`.
  std::vector<LambdaArg *> lambdaArgumentsRef;
  lambdaArgumentsRef.reserve(args.size);
  // The index uses the type of args.size to avoid a signed/unsigned compare.
  for (decltype(args.size) i = 0; i < args.size; ++i) {
    lambdaArgumentsRef.push_back(args.data[i].ptr.get());
  }

  // Run lambda
  llvm::Expected<std::unique_ptr<LambdaArg>> resOrError =
      lambda_ptr->operator()<std::unique_ptr<LambdaArg>>(
          llvm::ArrayRef<LambdaArg *>(lambdaArgumentsRef));
  if (!resOrError)
    throwRuntimeError("Lambda invocation failed: ", resOrError.takeError());

  lambdaArgument result{std::move(*resOrError)};
  return result; // plain return so copy elision / implicit move applies
}

/// Parses \p module with the ROUND_TRIP target and returns the printed form
/// of the resulting MLIR module.
///
/// \throws std::runtime_error ("MLIR parsing failed: ...") when compilation
///         reports an error.
std::string roundTrip(const char *module) {
  std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
      mlir::zamalang::CompilationContext::createShared();
  mlir::zamalang::JitCompilerEngine ce{ccx};

  // `auto` keeps whatever llvm::Expected<...> type compile() declares.
  auto retOrErr =
      ce.compile(module, mlir::zamalang::CompilerEngine::Target::ROUND_TRIP);
  if (!retOrErr)
    throwRuntimeError("MLIR parsing failed: ", retOrErr.takeError());

  std::string backingString;
  llvm::raw_string_ostream os(backingString);
  retOrErr->mlirModuleRef->get().print(os);
  return os.str();
}

/// Returns true when \p lambda_arg holds a TensorArg.
bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) {
  return lambda_arg.ptr->isa<TensorArg>();
}

/// Copies the elements of the tensor held by \p lambda_arg into a vector.
///
/// \throws std::invalid_argument when \p lambda_arg is not a tensor.
/// \throws std::runtime_error ("Couldn't get size of tensor: ...") when the
///         element count cannot be obtained.
std::vector<uint8_t> lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
  TensorArg &tensor = requireTensor(lambda_arg);

  auto sizeOrErr = tensor.getNumElements();
  if (!sizeOrErr)
    throwRuntimeError("Couldn't get size of tensor: ", sizeOrErr.takeError());

  auto *first = tensor.getValue();
  std::vector<uint8_t> data(first, first + *sizeOrErr);
  return data;
}

/// Returns the dimensions of the tensor held by \p lambda_arg.
///
/// \throws std::invalid_argument when \p lambda_arg is not a tensor.
std::vector<int64_t>
lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) {
  return requireTensor(lambda_arg).getDimensions();
}

/// Returns true when \p lambda_arg holds a ScalarArg.
bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) {
  return lambda_arg.ptr->isa<ScalarArg>();
}

/// Returns the integer value held by \p lambda_arg.
///
/// \throws std::invalid_argument when \p lambda_arg is not a scalar.
uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg) {
  ScalarArg *arg = lambda_arg.ptr->dyn_cast<ScalarArg>();
  if (arg == nullptr) {
    throw std::invalid_argument("LambdaArgument isn't a scalar, should "
                                "be an IntLambdaArgument<uint64_t>");
  }
  return arg->getValue();
}

/// Wraps \p data and \p dimensions in a newly allocated TensorArg.
lambdaArgument lambdaArgumentFromTensor(std::vector<uint8_t> data,
                                        std::vector<int64_t> dimensions) {
  lambdaArgument tensor_arg{std::make_shared<TensorArg>(data, dimensions)};
  return tensor_arg;
}

/// Wraps \p scalar in a newly allocated ScalarArg.
lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) {
  lambdaArgument scalar_arg{std::make_shared<ScalarArg>(scalar)};
  return scalar_arg;
}