chore(Lambda): simplify, extract, enhance message for bit width rounding

bit with rounding: 5bit element is widen to a standard 8bit word
This commit is contained in:
rudy
2021-11-19 10:35:42 +01:00
committed by rudy-6-4
parent 209463be22
commit cc58608589
4 changed files with 50 additions and 71 deletions

View File

@@ -61,7 +61,7 @@ build-end-to-end-jit-encrypted-tensor: build-initialized
build-end-to-end-jit-hlfhelinalg: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_hlfhelinalg
build-end-to-end-jit-lamnda: build-initialized
build-end-to-end-jit-lambda: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_lambda
build-end-to-end-jit: build-end-to-end-jit-test build-end-to-end-jit-clear-tensor build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-hlfhelinalg
@@ -78,7 +78,7 @@ test-end-to-end-jit-encrypted-tensor: build-end-to-end-jit-encrypted-tensor
test-end-to-end-jit-hlfhelinalg: build-end-to-end-jit-hlfhelinalg
$(BUILD_DIR)/bin/end_to_end_jit_hlfhelinalg
test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lamnda
test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lambda
$(BUILD_DIR)/bin/end_to_end_jit_lambda

View File

@@ -77,7 +77,7 @@ public:
private:
// Verify if lambda can accept a n-th argument.
llvm::Error acceptNthArg(size_t n);
llvm::Error emitErrorIfTooManyArgs(size_t n);
llvm::Error setArg(size_t pos, size_t width, const void *data,
llvm::ArrayRef<int64_t> shape);

View File

@@ -7,6 +7,7 @@
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include <zamalang/Support/Error.h>
#include <zamalang/Support/Jit.h>
#include <zamalang/Support/logging.h>
@@ -43,9 +44,7 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
module, /*llvmModuleBuilder=*/nullptr, optPipeline,
/*jitCodeGenOptLevel=*/llvm::None, sharedLibPaths);
if (!maybeEngine) {
return llvm::make_error<llvm::StringError>(
"failed to construct the MLIR ExecutionEngine",
llvm::inconvertibleErrorCode());
return StreamStringError("failed to construct the MLIR ExecutionEngine");
}
auto &engine = maybeEngine.get();
auto lambda = std::make_unique<JITLambda>((*funcOp).getType(), name);
@@ -54,37 +53,24 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
return std::move(lambda);
}
llvm::Error hasSomeNull(llvm::MutableArrayRef<void *> args) {
auto pos = 0;
for (auto arg : args) {
if (arg == nullptr) {
auto msg =
"invoke: argument at pos " + llvm::Twine(pos) + " is null or missing";
return llvm::make_error<llvm::StringError>(
msg, llvm::inconvertibleErrorCode());
}
pos++;
}
return llvm::Error::success();
}
llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
if (auto hasSomeNullError = hasSomeNull(args)) {
return hasSomeNullError;
auto found = std::find(args.begin(), args.end(), nullptr);
if (found == args.end()) {
return this->engine->invokePacked(this->name, args);
}
return this->engine->invokePacked(this->name, args);
int pos = found - args.begin();
return StreamStringError("invoke: argument at pos ")
<< pos << " is null or missing";
}
llvm::Error JITLambda::invoke(Argument &args) {
size_t expectedInputs = this->type.getNumParams();
size_t actualInputs = args.inputs.size();
if (expectedInputs != actualInputs) {
auto msg = "invokeRaw: received " + llvm::Twine(actualInputs) +
"arguments instead of " + llvm::Twine(expectedInputs);
return llvm::make_error<llvm::StringError>(msg,
llvm::inconvertibleErrorCode());
if (expectedInputs == actualInputs) {
return invokeRaw(args.rawArg);
}
return std::move(invokeRaw(args.rawArg));
return StreamStringError("invokeRaw: received ")
<< actualInputs << "arguments instead of " << expectedInputs;
}
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
@@ -163,20 +149,17 @@ JITLambda::Argument::create(KeySet &keySet) {
return std::move(args);
}
llvm::Error JITLambda::Argument::acceptNthArg(size_t pos) {
llvm::Error JITLambda::Argument::emitErrorIfTooManyArgs(size_t pos) {
size_t arity = inputGates.size();
if (pos >= arity) {
auto msg = "Call a function of arity " + llvm::Twine(arity) +
" with at least " + llvm::Twine(pos + 1) + " arguments";
return llvm::make_error<llvm::StringError>(msg,
llvm::inconvertibleErrorCode());
if (pos < arity) {
return llvm::Error::success();
}
return llvm::Error::success();
return StreamStringError("The function has arity ")
<< arity << " but is applied to too many arguments";
}
llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
auto error = acceptNthArg(pos);
if (error) {
if (auto error = emitErrorIfTooManyArgs(pos)) {
return error;
}
auto gate = inputGates[pos];
@@ -210,47 +193,43 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
return llvm::Error::success();
}
size_t bitWidthAsWord(size_t exactBitWidth) {
size_t sortedWordBitWidths[] = {8, 16, 32, 64};
size_t previousWidth = 0;
for (auto currentWidth : sortedWordBitWidths) {
if (previousWidth < exactBitWidth && exactBitWidth <= currentWidth) {
return currentWidth;
}
}
return exactBitWidth;
}
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
const void *data,
llvm::ArrayRef<int64_t> shape) {
auto error = acceptNthArg(pos);
if (error) {
if (auto error = emitErrorIfTooManyArgs(pos)) {
return error;
}
auto gate = inputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check if the width is compatible
// TODO - I found this rules empirically, they are a spec somewhere?
if (info.shape.width <= 8 && width != 8) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 8: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 8 && info.shape.width <= 16 && width != 16) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 16: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 16 && info.shape.width <= 32 && width != 32) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 32: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 32 && info.shape.width <= 64 && width != 64) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 64: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// TODO - I found this rules empirically, they are a spec somewhere?
if (info.shape.width > 64) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width not supported: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
auto msg = "Bad argument (pos=" + llvm::Twine(pos) + ") : a width of " +
llvm::Twine(info.shape.width) +
"bits > 64 is not supported: pos=" + llvm::Twine(pos);
return llvm::make_error<llvm::StringError>(msg,
llvm::inconvertibleErrorCode());
}
auto roundedSize = bitWidthAsWord(info.shape.width);
if (width != roundedSize) {
auto msg = "Bad argument (pos=" + llvm::Twine(pos) + ") : expected " +
llvm::Twine(roundedSize) + "bits" + " but received " +
llvm::Twine(width) + "bits (rounded from " +
llvm::Twine(info.shape.width) + ")";
return llvm::make_error<llvm::StringError>(msg,
llvm::inconvertibleErrorCode());
}
// Check the size
if (info.shape.dimensions.empty()) {

View File

@@ -47,7 +47,7 @@ static bool assert_expected_failure(llvm::Expected<T> &&val) {
#define ASSERT_EXPECTED_SUCCESS(val) \
do { \
if (!assert_expected_success(val)) \
GTEST_FATAL_FAILURE_("Expected<T> contained in error state"); \
GTEST_FATAL_FAILURE_("Expected<T> in error state"); \
} while (0)
// Checks that the value `val` of type `llvm::Expected<T>` is in
@@ -55,7 +55,7 @@ static bool assert_expected_failure(llvm::Expected<T> &&val) {
#define ASSERT_EXPECTED_FAILURE(val) \
do { \
if (assert_expected_success(val)) \
GTEST_FATAL_FAILURE_("Expected<T> contained not in error state"); \
GTEST_FATAL_FAILURE_("Expected<T> not in error state"); \
} while (0)
// Checks that the value `val` is not in an error state and is equal