From cc58608589e297e7dd03e873c66ca6e873679d84 Mon Sep 17 00:00:00 2001 From: rudy Date: Fri, 19 Nov 2021 10:35:42 +0100 Subject: [PATCH] chore(Lambda): simplify, extract, enhance message for bit width rounding bit with rounding: 5bit element is widen to a standard 8bit word --- compiler/Makefile | 4 +- compiler/include/zamalang/Support/Jit.h | 2 +- compiler/lib/Support/Jit.cpp | 111 +++++++----------- compiler/tests/unittest/end_to_end_jit_test.h | 4 +- 4 files changed, 50 insertions(+), 71 deletions(-) diff --git a/compiler/Makefile b/compiler/Makefile index e35cf5d72..75ba847f2 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -61,7 +61,7 @@ build-end-to-end-jit-encrypted-tensor: build-initialized build-end-to-end-jit-hlfhelinalg: build-initialized cmake --build $(BUILD_DIR) --target end_to_end_jit_hlfhelinalg -build-end-to-end-jit-lamnda: build-initialized +build-end-to-end-jit-lambda: build-initialized cmake --build $(BUILD_DIR) --target end_to_end_jit_lambda build-end-to-end-jit: build-end-to-end-jit-test build-end-to-end-jit-clear-tensor build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-hlfhelinalg @@ -78,7 +78,7 @@ test-end-to-end-jit-encrypted-tensor: build-end-to-end-jit-encrypted-tensor test-end-to-end-jit-hlfhelinalg: build-end-to-end-jit-hlfhelinalg $(BUILD_DIR)/bin/end_to_end_jit_hlfhelinalg -test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lamnda +test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lambda $(BUILD_DIR)/bin/end_to_end_jit_lambda diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index 272eeac42..e94542160 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -77,7 +77,7 @@ public: private: // Verify if lambda can accept a n-th argument. - llvm::Error acceptNthArg(size_t n); + llvm::Error emitErrorIfTooManyArgs(size_t n); llvm::Error setArg(size_t pos, size_t width, const void *data, llvm::ArrayRef shape); diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 02ea96a6e..4fe1140dc 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -43,9 +44,7 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, module, /*llvmModuleBuilder=*/nullptr, optPipeline, /*jitCodeGenOptLevel=*/llvm::None, sharedLibPaths); if (!maybeEngine) { - return llvm::make_error( - "failed to construct the MLIR ExecutionEngine", - llvm::inconvertibleErrorCode()); + return StreamStringError("failed to construct the MLIR ExecutionEngine"); } auto &engine = maybeEngine.get(); auto lambda = std::make_unique((*funcOp).getType(), name); @@ -54,37 +53,24 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, return std::move(lambda); } -llvm::Error hasSomeNull(llvm::MutableArrayRef args) { - auto pos = 0; - for (auto arg : args) { - if (arg == nullptr) { - auto msg = - "invoke: argument at pos " + llvm::Twine(pos) + " is null or missing"; - return llvm::make_error( - msg, llvm::inconvertibleErrorCode()); - } - pos++; - } - return llvm::Error::success(); -} - llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef args) { - if (auto hasSomeNullError = hasSomeNull(args)) { - return hasSomeNullError; + auto found = std::find(args.begin(), args.end(), nullptr); + if (found == args.end()) { + return this->engine->invokePacked(this->name, args); } - return this->engine->invokePacked(this->name, args); + int pos = found - args.begin(); + return StreamStringError("invoke: argument at pos ") + << pos << " is null or missing"; } llvm::Error JITLambda::invoke(Argument &args) { size_t expectedInputs = this->type.getNumParams(); size_t actualInputs = args.inputs.size(); - if (expectedInputs != actualInputs) { - auto msg = "invokeRaw: received " + llvm::Twine(actualInputs) + - "arguments instead of " + llvm::Twine(expectedInputs); - return llvm::make_error(msg, - llvm::inconvertibleErrorCode()); + if (expectedInputs == actualInputs) { + return invokeRaw(args.rawArg); } - return std::move(invokeRaw(args.rawArg)); + return StreamStringError("invokeRaw: received ") + << actualInputs << "arguments instead of " << expectedInputs; } JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { @@ -163,20 +149,17 @@ JITLambda::Argument::create(KeySet &keySet) { return std::move(args); } -llvm::Error JITLambda::Argument::acceptNthArg(size_t pos) { +llvm::Error JITLambda::Argument::emitErrorIfTooManyArgs(size_t pos) { size_t arity = inputGates.size(); - if (pos >= arity) { - auto msg = "Call a function of arity " + llvm::Twine(arity) + - " with at least " + llvm::Twine(pos + 1) + " arguments"; - return llvm::make_error(msg, - llvm::inconvertibleErrorCode()); + if (pos < arity) { + return llvm::Error::success(); } - return llvm::Error::success(); + return StreamStringError("The function has arity ") + << arity << " but is applied to too many arguments"; } llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { - auto error = acceptNthArg(pos); - if (error) { + if (auto error = emitErrorIfTooManyArgs(pos)) { return error; } auto gate = inputGates[pos]; @@ -210,47 +193,43 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { return llvm::Error::success(); } +size_t bitWidthAsWord(size_t exactBitWidth) { + size_t sortedWordBitWidths[] = {8, 16, 32, 64}; + size_t previousWidth = 0; + for (auto currentWidth : sortedWordBitWidths) { + if (previousWidth < exactBitWidth && exactBitWidth <= currentWidth) { + return currentWidth; + } + } + return exactBitWidth; +} + llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, const void *data, llvm::ArrayRef shape) { - auto error = acceptNthArg(pos); - if (error) { + if (auto error = emitErrorIfTooManyArgs(pos)) { return error; } auto gate = inputGates[pos]; auto info = std::get<0>(gate); auto offset = std::get<1>(gate); // Check if the width is compatible - // TODO - I found this rules empirically, they are a spec somewhere? - if (info.shape.width <= 8 && width != 8) { - return llvm::make_error( - llvm::Twine("argument width should be 8: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (info.shape.width > 8 && info.shape.width <= 16 && width != 16) { - return llvm::make_error( - llvm::Twine("argument width should be 16: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (info.shape.width > 16 && info.shape.width <= 32 && width != 32) { - return llvm::make_error( - llvm::Twine("argument width should be 32: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (info.shape.width > 32 && info.shape.width <= 64 && width != 64) { - return llvm::make_error( - llvm::Twine("argument width should be 64: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } + // TODO - I found this rules empirically, they are a spec somewhere? if (info.shape.width > 64) { - return llvm::make_error( - llvm::Twine("argument width not supported: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); + auto msg = "Bad argument (pos=" + llvm::Twine(pos) + ") : a width of " + + llvm::Twine(info.shape.width) + + "bits > 64 is not supported: pos=" + llvm::Twine(pos); + return llvm::make_error(msg, + llvm::inconvertibleErrorCode()); + } + auto roundedSize = bitWidthAsWord(info.shape.width); + if (width != roundedSize) { + auto msg = "Bad argument (pos=" + llvm::Twine(pos) + ") : expected " + + llvm::Twine(roundedSize) + "bits" + " but received " + + llvm::Twine(width) + "bits (rounded from " + + llvm::Twine(info.shape.width) + ")"; + return llvm::make_error(msg, + llvm::inconvertibleErrorCode()); } // Check the size if (info.shape.dimensions.empty()) { diff --git a/compiler/tests/unittest/end_to_end_jit_test.h b/compiler/tests/unittest/end_to_end_jit_test.h index 613550997..8a1e5d6b0 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.h +++ b/compiler/tests/unittest/end_to_end_jit_test.h @@ -47,7 +47,7 @@ static bool assert_expected_failure(llvm::Expected &&val) { #define ASSERT_EXPECTED_SUCCESS(val) \ do { \ if (!assert_expected_success(val)) \ - GTEST_FATAL_FAILURE_("Expected contained in error state"); \ + GTEST_FATAL_FAILURE_("Expected in error state"); \ } while (0) // Checks that the value `val` of type `llvm::Expected` is in @@ -55,7 +55,7 @@ static bool assert_expected_failure(llvm::Expected &&val) { #define ASSERT_EXPECTED_FAILURE(val) \ do { \ if (assert_expected_success(val)) \ - GTEST_FATAL_FAILURE_("Expected contained not in error state"); \ + GTEST_FATAL_FAILURE_("Expected not in error state"); \ } while (0) // Checks that the value `val` is not in an error state and is equal