mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
chore(Lambda): simplify, extract, enhance message for bit width rounding
bit with rounding: 5bit element is widen to a standard 8bit word
This commit is contained in:
@@ -61,7 +61,7 @@ build-end-to-end-jit-encrypted-tensor: build-initialized
|
||||
build-end-to-end-jit-hlfhelinalg: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_hlfhelinalg
|
||||
|
||||
build-end-to-end-jit-lamnda: build-initialized
|
||||
build-end-to-end-jit-lambda: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_lambda
|
||||
|
||||
build-end-to-end-jit: build-end-to-end-jit-test build-end-to-end-jit-clear-tensor build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-hlfhelinalg
|
||||
@@ -78,7 +78,7 @@ test-end-to-end-jit-encrypted-tensor: build-end-to-end-jit-encrypted-tensor
|
||||
test-end-to-end-jit-hlfhelinalg: build-end-to-end-jit-hlfhelinalg
|
||||
$(BUILD_DIR)/bin/end_to_end_jit_hlfhelinalg
|
||||
|
||||
test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lamnda
|
||||
test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lambda
|
||||
$(BUILD_DIR)/bin/end_to_end_jit_lambda
|
||||
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ public:
|
||||
|
||||
private:
|
||||
// Verify if lambda can accept a n-th argument.
|
||||
llvm::Error acceptNthArg(size_t n);
|
||||
llvm::Error emitErrorIfTooManyArgs(size_t n);
|
||||
llvm::Error setArg(size_t pos, size_t width, const void *data,
|
||||
llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
||||
|
||||
#include <zamalang/Support/Error.h>
|
||||
#include <zamalang/Support/Jit.h>
|
||||
#include <zamalang/Support/logging.h>
|
||||
|
||||
@@ -43,9 +44,7 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
module, /*llvmModuleBuilder=*/nullptr, optPipeline,
|
||||
/*jitCodeGenOptLevel=*/llvm::None, sharedLibPaths);
|
||||
if (!maybeEngine) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"failed to construct the MLIR ExecutionEngine",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return StreamStringError("failed to construct the MLIR ExecutionEngine");
|
||||
}
|
||||
auto &engine = maybeEngine.get();
|
||||
auto lambda = std::make_unique<JITLambda>((*funcOp).getType(), name);
|
||||
@@ -54,37 +53,24 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
return std::move(lambda);
|
||||
}
|
||||
|
||||
llvm::Error hasSomeNull(llvm::MutableArrayRef<void *> args) {
|
||||
auto pos = 0;
|
||||
for (auto arg : args) {
|
||||
if (arg == nullptr) {
|
||||
auto msg =
|
||||
"invoke: argument at pos " + llvm::Twine(pos) + " is null or missing";
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
msg, llvm::inconvertibleErrorCode());
|
||||
}
|
||||
pos++;
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
|
||||
if (auto hasSomeNullError = hasSomeNull(args)) {
|
||||
return hasSomeNullError;
|
||||
auto found = std::find(args.begin(), args.end(), nullptr);
|
||||
if (found == args.end()) {
|
||||
return this->engine->invokePacked(this->name, args);
|
||||
}
|
||||
return this->engine->invokePacked(this->name, args);
|
||||
int pos = found - args.begin();
|
||||
return StreamStringError("invoke: argument at pos ")
|
||||
<< pos << " is null or missing";
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::invoke(Argument &args) {
|
||||
size_t expectedInputs = this->type.getNumParams();
|
||||
size_t actualInputs = args.inputs.size();
|
||||
if (expectedInputs != actualInputs) {
|
||||
auto msg = "invokeRaw: received " + llvm::Twine(actualInputs) +
|
||||
"arguments instead of " + llvm::Twine(expectedInputs);
|
||||
return llvm::make_error<llvm::StringError>(msg,
|
||||
llvm::inconvertibleErrorCode());
|
||||
if (expectedInputs == actualInputs) {
|
||||
return invokeRaw(args.rawArg);
|
||||
}
|
||||
return std::move(invokeRaw(args.rawArg));
|
||||
return StreamStringError("invokeRaw: received ")
|
||||
<< actualInputs << "arguments instead of " << expectedInputs;
|
||||
}
|
||||
|
||||
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
|
||||
@@ -163,20 +149,17 @@ JITLambda::Argument::create(KeySet &keySet) {
|
||||
return std::move(args);
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::acceptNthArg(size_t pos) {
|
||||
llvm::Error JITLambda::Argument::emitErrorIfTooManyArgs(size_t pos) {
|
||||
size_t arity = inputGates.size();
|
||||
if (pos >= arity) {
|
||||
auto msg = "Call a function of arity " + llvm::Twine(arity) +
|
||||
" with at least " + llvm::Twine(pos + 1) + " arguments";
|
||||
return llvm::make_error<llvm::StringError>(msg,
|
||||
llvm::inconvertibleErrorCode());
|
||||
if (pos < arity) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
return llvm::Error::success();
|
||||
return StreamStringError("The function has arity ")
|
||||
<< arity << " but is applied to too many arguments";
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
|
||||
auto error = acceptNthArg(pos);
|
||||
if (error) {
|
||||
if (auto error = emitErrorIfTooManyArgs(pos)) {
|
||||
return error;
|
||||
}
|
||||
auto gate = inputGates[pos];
|
||||
@@ -210,47 +193,43 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
size_t bitWidthAsWord(size_t exactBitWidth) {
|
||||
size_t sortedWordBitWidths[] = {8, 16, 32, 64};
|
||||
size_t previousWidth = 0;
|
||||
for (auto currentWidth : sortedWordBitWidths) {
|
||||
if (previousWidth < exactBitWidth && exactBitWidth <= currentWidth) {
|
||||
return currentWidth;
|
||||
}
|
||||
}
|
||||
return exactBitWidth;
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
|
||||
const void *data,
|
||||
llvm::ArrayRef<int64_t> shape) {
|
||||
auto error = acceptNthArg(pos);
|
||||
if (error) {
|
||||
if (auto error = emitErrorIfTooManyArgs(pos)) {
|
||||
return error;
|
||||
}
|
||||
auto gate = inputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
// Check if the width is compatible
|
||||
// TODO - I found this rules empirically, they are a spec somewhere?
|
||||
if (info.shape.width <= 8 && width != 8) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width should be 8: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (info.shape.width > 8 && info.shape.width <= 16 && width != 16) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width should be 16: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (info.shape.width > 16 && info.shape.width <= 32 && width != 32) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width should be 32: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (info.shape.width > 32 && info.shape.width <= 64 && width != 64) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width should be 64: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// TODO - I found this rules empirically, they are a spec somewhere?
|
||||
if (info.shape.width > 64) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width not supported: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
auto msg = "Bad argument (pos=" + llvm::Twine(pos) + ") : a width of " +
|
||||
llvm::Twine(info.shape.width) +
|
||||
"bits > 64 is not supported: pos=" + llvm::Twine(pos);
|
||||
return llvm::make_error<llvm::StringError>(msg,
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
auto roundedSize = bitWidthAsWord(info.shape.width);
|
||||
if (width != roundedSize) {
|
||||
auto msg = "Bad argument (pos=" + llvm::Twine(pos) + ") : expected " +
|
||||
llvm::Twine(roundedSize) + "bits" + " but received " +
|
||||
llvm::Twine(width) + "bits (rounded from " +
|
||||
llvm::Twine(info.shape.width) + ")";
|
||||
return llvm::make_error<llvm::StringError>(msg,
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Check the size
|
||||
if (info.shape.dimensions.empty()) {
|
||||
|
||||
@@ -47,7 +47,7 @@ static bool assert_expected_failure(llvm::Expected<T> &&val) {
|
||||
#define ASSERT_EXPECTED_SUCCESS(val) \
|
||||
do { \
|
||||
if (!assert_expected_success(val)) \
|
||||
GTEST_FATAL_FAILURE_("Expected<T> contained in error state"); \
|
||||
GTEST_FATAL_FAILURE_("Expected<T> in error state"); \
|
||||
} while (0)
|
||||
|
||||
// Checks that the value `val` of type `llvm::Expected<T>` is in
|
||||
@@ -55,7 +55,7 @@ static bool assert_expected_failure(llvm::Expected<T> &&val) {
|
||||
#define ASSERT_EXPECTED_FAILURE(val) \
|
||||
do { \
|
||||
if (assert_expected_success(val)) \
|
||||
GTEST_FATAL_FAILURE_("Expected<T> contained not in error state"); \
|
||||
GTEST_FATAL_FAILURE_("Expected<T> not in error state"); \
|
||||
} while (0)
|
||||
|
||||
// Checks that the value `val` is not in an error state and is equal
|
||||
|
||||
Reference in New Issue
Block a user