mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
enhance(compiler/runtime): Add runtime tools to handle tensor inputs and outputs
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#ifndef ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_
|
||||
#define ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_
|
||||
#ifndef ZAMALANG_CONVERSION_GLOBALFHECONTEXT_H_
|
||||
#define ZAMALANG_CONVERSION_GLOBALFHECONTEXT_H_
|
||||
#include <cstddef>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@@ -56,7 +56,10 @@ struct EncryptionGate {
|
||||
};
|
||||
|
||||
struct CircuitGateShape {
|
||||
uint64_t size;
|
||||
// Width of the scalar value
|
||||
size_t width;
|
||||
// Size of the buffer
|
||||
size_t size;
|
||||
};
|
||||
|
||||
struct CircuitGate {
|
||||
|
||||
@@ -15,6 +15,9 @@
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
/// CompilerEngine is an tools that provides tools to implements the compilation
|
||||
/// flow and manage the compilation flow state.
|
||||
class CompilerEngine {
|
||||
public:
|
||||
CompilerEngine() {
|
||||
@@ -26,10 +29,16 @@ public:
|
||||
delete context;
|
||||
}
|
||||
|
||||
// Compile an MLIR input
|
||||
llvm::Expected<mlir::LogicalResult> compileFHE(std::string mlir_input);
|
||||
// Compile an mlir programs from it's textual representation.
|
||||
llvm::Error compile(std::string mlirStr);
|
||||
|
||||
// Run the compiled module
|
||||
// Build the jit lambda argument.
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>> buildArgument();
|
||||
|
||||
// Call the compiled function with and argument object.
|
||||
llvm::Error invoke(JITLambda::Argument &arg);
|
||||
|
||||
// Call the compiled function with a list of integer arguments.
|
||||
llvm::Expected<uint64_t> run(std::vector<uint64_t> args);
|
||||
|
||||
// Get a printable representation of the compiled module
|
||||
|
||||
@@ -51,17 +51,54 @@ public:
|
||||
// and decryption operations.
|
||||
static llvm::Expected<std::unique_ptr<Argument>> create(KeySet &keySet);
|
||||
|
||||
// Set the argument at the given pos as a uint64_t.
|
||||
// Set a scalar argument at the given pos as a uint64_t.
|
||||
llvm::Error setArg(size_t pos, uint64_t arg);
|
||||
|
||||
// Set a argument at the given pos as a tensor of int64.
|
||||
llvm::Error setArg(size_t pos, uint64_t *data, size_t size) {
|
||||
return setArg(pos, 64, (void *)data, size);
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of int32.
|
||||
llvm::Error setArg(size_t pos, uint32_t *data, size_t size) {
|
||||
return setArg(pos, 32, (void *)data, size);
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of int32.
|
||||
llvm::Error setArg(size_t pos, uint16_t *data, size_t size) {
|
||||
return setArg(pos, 16, (void *)data, size);
|
||||
}
|
||||
|
||||
// Set a tensor argument at the given pos as a uint64_t.
|
||||
llvm::Error setArg(size_t pos, uint8_t *data, size_t size) {
|
||||
return setArg(pos, 8, (void *)data, size);
|
||||
}
|
||||
|
||||
// Get the result at the given pos as an uint64_t.
|
||||
llvm::Error getResult(size_t pos, uint64_t &res);
|
||||
|
||||
// Fill the result.
|
||||
llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
|
||||
|
||||
private:
|
||||
llvm::Error setArg(size_t pos, size_t width, void *data, size_t size);
|
||||
|
||||
friend JITLambda;
|
||||
// Store the pointer on inputs values and outputs values
|
||||
std::vector<void *> rawArg;
|
||||
// Store the values of inputs
|
||||
std::vector<void *> inputs;
|
||||
std::vector<void *> results;
|
||||
// Store the values of outputs
|
||||
std::vector<void *> outputs;
|
||||
// Store the input gates description and the offset of the argument.
|
||||
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> inputGates;
|
||||
// Store the outputs gates description and the offset of the argument.
|
||||
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> outputGates;
|
||||
// Store allocated lwe ciphertexts (for free)
|
||||
std::vector<LweCiphertext_u64 *> allocatedCiphertexts;
|
||||
// Store buffers of ciphertexts
|
||||
std::vector<LweCiphertext_u64 **> ciphertextBuffers;
|
||||
|
||||
KeySet &keySet;
|
||||
};
|
||||
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
|
||||
|
||||
@@ -37,6 +37,9 @@ public:
|
||||
size_t numInputs() { return inputs.size(); }
|
||||
size_t numOutputs() { return outputs.size(); }
|
||||
|
||||
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
|
||||
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
|
||||
|
||||
protected:
|
||||
llvm::Error generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
|
||||
SecretRandomGenerator *generator);
|
||||
|
||||
@@ -14,12 +14,25 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
Precision precision,
|
||||
mlir::Type type) {
|
||||
if (type.isIntOrIndex()) {
|
||||
// TODO - The index type is dependant of the target architecture, so
|
||||
// actually we assume we target only 64 bits, we need to have some the size
|
||||
// of the word of the target system.
|
||||
size_t width = 64;
|
||||
if (!type.isIndex()) {
|
||||
width = type.getIntOrFloatBitWidth();
|
||||
}
|
||||
return CircuitGate{
|
||||
.encryption = llvm::None,
|
||||
.shape = {.size = 0},
|
||||
.shape =
|
||||
{
|
||||
.width = width,
|
||||
.size = 0,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
|
||||
// TODO - Get the width from the LWECiphertextType instead of global
|
||||
// precision (could be possible after merge lowlfhe-ciphertext-parameter)
|
||||
return CircuitGate{
|
||||
.encryption = llvm::Optional<EncryptionGate>({
|
||||
.secretKeyID = secretKeyID,
|
||||
@@ -27,17 +40,17 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
.variance = 0.,
|
||||
.encoding = {.precision = precision},
|
||||
}),
|
||||
.shape = {.size = 0},
|
||||
.shape = {.width = precision, .size = 0},
|
||||
};
|
||||
}
|
||||
auto memref = type.dyn_cast_or_null<mlir::MemRefType>();
|
||||
if (memref != nullptr) {
|
||||
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
if (tensor != nullptr) {
|
||||
auto gate =
|
||||
gateFromMLIRType(secretKeyID, precision, memref.getElementType());
|
||||
gateFromMLIRType(secretKeyID, precision, tensor.getElementType());
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
gate->shape.size = memref.getDimSize(0);
|
||||
gate->shape.size = tensor.getDimSize(0);
|
||||
return gate;
|
||||
}
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
|
||||
@@ -23,9 +23,8 @@ std::string CompilerEngine::getCompiledModule() {
|
||||
return os.str();
|
||||
}
|
||||
|
||||
llvm::Expected<mlir::LogicalResult>
|
||||
CompilerEngine::compileFHE(std::string mlir_input) {
|
||||
module_ref = mlir::parseSourceString(mlir_input, context);
|
||||
llvm::Error CompilerEngine::compile(std::string mlirStr) {
|
||||
module_ref = mlir::parseSourceString(mlirStr, context);
|
||||
if (!module_ref) {
|
||||
return llvm::make_error<llvm::StringError>("mlir parsing failed",
|
||||
llvm::inconvertibleErrorCode());
|
||||
@@ -60,29 +59,44 @@ CompilerEngine::compileFHE(std::string mlir_input) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"failed to lower to LLVM dialect", llvm::inconvertibleErrorCode());
|
||||
}
|
||||
return mlir::success();
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Expected<uint64_t> CompilerEngine::run(std::vector<uint64_t> args) {
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>>
|
||||
CompilerEngine::buildArgument() {
|
||||
if (keySet.get() == nullptr) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"CompilerEngine::buildArgument: invalid engine state, the keySet has "
|
||||
"not be generated",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
return JITLambda::Argument::create(*keySet);
|
||||
}
|
||||
|
||||
llvm::Error CompilerEngine::invoke(JITLambda::Argument &arg) {
|
||||
// Create the JIT lambda
|
||||
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
auto module = module_ref.get();
|
||||
auto maybeLambda =
|
||||
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
|
||||
if (!maybeLambda) {
|
||||
return llvm::make_error<llvm::StringError>("couldn't create lambda",
|
||||
llvm::inconvertibleErrorCode());
|
||||
if (auto err = maybeLambda.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
auto lambda = std::move(maybeLambda.get());
|
||||
// Invoke the lambda
|
||||
if (auto err = maybeLambda.get()->invoke(arg)) {
|
||||
return std::move(err);
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
// Create the arguments of the JIT lambda
|
||||
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(*keySet);
|
||||
if (auto err = maybeArguments.takeError()) {
|
||||
return llvm::make_error<llvm::StringError>("cannot create lambda args",
|
||||
llvm::inconvertibleErrorCode());
|
||||
llvm::Expected<uint64_t> CompilerEngine::run(std::vector<uint64_t> args) {
|
||||
// Build the argument of the JIT lambda.
|
||||
auto maybeArgument = buildArgument();
|
||||
if (auto err = maybeArgument.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
// Set the arguments
|
||||
auto arguments = std::move(maybeArguments.get());
|
||||
// Set the integer arguments
|
||||
auto arguments = std::move(maybeArgument.get());
|
||||
for (auto i = 0; i < args.size(); i++) {
|
||||
if (auto err = arguments->setArg(i, args[i])) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
@@ -90,14 +104,12 @@ llvm::Expected<uint64_t> CompilerEngine::run(std::vector<uint64_t> args) {
|
||||
}
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (lambda->invoke(*arguments)) {
|
||||
return llvm::make_error<llvm::StringError>("failed execution",
|
||||
llvm::inconvertibleErrorCode());
|
||||
if (auto err = invoke(*arguments)) {
|
||||
return std::move(err);
|
||||
}
|
||||
uint64_t res = 0;
|
||||
if (auto err = arguments->getResult(0, res)) {
|
||||
return llvm::make_error<llvm::StringError>("cannot get result",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return std::move(err);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -141,10 +141,16 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
|
||||
if (this->type.getNumParams() != args.size() - 1) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"invokeRaw: wrong number of argument", llvm::inconvertibleErrorCode());
|
||||
}
|
||||
size_t nbReturn = 0;
|
||||
// TODO - This check break with memref as we have 5 returns args.
|
||||
// if (!this->type.getReturnType().isa<mlir::LLVM::LLVMVoidType>()) {
|
||||
// nbReturn = 1;
|
||||
// }
|
||||
// if (this->type.getNumParams() != args.size() - nbReturn) {
|
||||
// return llvm::make_error<llvm::StringError>(
|
||||
// "invokeRaw: wrong number of argument",
|
||||
// llvm::inconvertibleErrorCode());
|
||||
// }
|
||||
if (llvm::find(args, nullptr) != args.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"invoke: some arguments are null", llvm::inconvertibleErrorCode());
|
||||
@@ -157,24 +163,58 @@ llvm::Error JITLambda::invoke(Argument &args) {
|
||||
}
|
||||
|
||||
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
|
||||
inputs = std::vector<void *>(keySet.numInputs());
|
||||
results = std::vector<void *>(keySet.numOutputs());
|
||||
// Setting the inputs
|
||||
{
|
||||
auto numInputs = 0;
|
||||
for (size_t i = 0; i < keySet.numInputs(); i++) {
|
||||
auto offset = numInputs;
|
||||
auto gate = keySet.inputGate(i);
|
||||
inputGates.push_back({gate, offset});
|
||||
if (keySet.inputGate(i).shape.size == 0) {
|
||||
// scalar gate
|
||||
numInputs = numInputs + 1;
|
||||
continue;
|
||||
}
|
||||
// memref gate, as we follow the standard calling convention
|
||||
numInputs = numInputs + 5;
|
||||
}
|
||||
inputs = std::vector<void *>(numInputs);
|
||||
}
|
||||
|
||||
// Setting the outputs
|
||||
{
|
||||
auto numOutputs = 0;
|
||||
for (size_t i = 0; i < keySet.numOutputs(); i++) {
|
||||
auto offset = numOutputs;
|
||||
auto gate = keySet.outputGate(i);
|
||||
outputGates.push_back({gate, offset});
|
||||
if (gate.shape.size == 0) {
|
||||
// scalar gate
|
||||
numOutputs = numOutputs + 1;
|
||||
continue;
|
||||
}
|
||||
// memref gate, as we follow the standard calling convention
|
||||
numOutputs = numOutputs + 5;
|
||||
}
|
||||
outputs = std::vector<void *>(numOutputs);
|
||||
}
|
||||
|
||||
// The raw argument contains pointers to inputs and pointers to store the
|
||||
// results
|
||||
rawArg =
|
||||
std::vector<void *>(keySet.numInputs() + keySet.numOutputs(), nullptr);
|
||||
// Set the results pointer on the rawArg
|
||||
for (auto i = keySet.numInputs(); i < rawArg.size(); i++) {
|
||||
rawArg[i] = &results[i - keySet.numInputs()];
|
||||
rawArg = std::vector<void *>(inputs.size() + outputs.size(), nullptr);
|
||||
// Set the pointer on outputs on rawArg
|
||||
for (auto i = inputs.size(); i < rawArg.size(); i++) {
|
||||
rawArg[i] = &outputs[i - inputs.size()];
|
||||
}
|
||||
}
|
||||
|
||||
JITLambda::Argument::~Argument() {
|
||||
int err;
|
||||
for (auto i = 0; i < keySet.numInputs(); i++) {
|
||||
if (keySet.isInputEncrypted(i)) {
|
||||
free_lwe_ciphertext_u64(&err, (LweCiphertext_u64 *)(inputs[i]));
|
||||
}
|
||||
for (auto ct : allocatedCiphertexts) {
|
||||
free_lwe_ciphertext_u64(&err, ct);
|
||||
}
|
||||
for (auto buffer : ciphertextBuffers) {
|
||||
free(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,38 +225,206 @@ JITLambda::Argument::create(KeySet &keySet) {
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
|
||||
if (pos >= inputGates.size()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument index out of bound: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
auto gate = inputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
|
||||
// Check is the argument is a scalar
|
||||
if (info.shape.size != 0) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument is not a scalar: pos=").concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
// If argument is not encrypted, just save.
|
||||
if (!keySet.isInputEncrypted(pos)) {
|
||||
inputs[pos] = (void *)arg;
|
||||
rawArg[pos] = &inputs[pos];
|
||||
if (!info.encryption.hasValue()) {
|
||||
inputs[offset] = (void *)arg;
|
||||
rawArg[offset] = &inputs[offset];
|
||||
return llvm::Error::success();
|
||||
}
|
||||
// Else if is encryted, allocate ciphertext.
|
||||
// Else if is encryted, allocate ciphertext and encrypt.
|
||||
LweCiphertext_u64 *ctArg;
|
||||
if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) {
|
||||
return std::move(err);
|
||||
}
|
||||
allocatedCiphertexts.push_back(ctArg);
|
||||
if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) {
|
||||
return std::move(err);
|
||||
}
|
||||
inputs[pos] = ctArg;
|
||||
rawArg[pos] = &inputs[pos];
|
||||
inputs[offset] = ctArg;
|
||||
rawArg[offset] = &inputs[offset];
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data,
|
||||
size_t size) {
|
||||
auto gate = inputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
// Check if the width is compatible
|
||||
// TODO - I found this rules empirically, they are a spec somewhere?
|
||||
if (info.shape.width <= 8 && width != 8) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width should be 8: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (info.shape.width > 8 && info.shape.width <= 16 && width != 16) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width should be 16: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (info.shape.width > 16 && info.shape.width <= 32 && width != 32) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width should be 32: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (info.shape.width > 32 && info.shape.width <= 64 && width != 64) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width should be 64: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (info.shape.width > 64) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width not supported: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Check the size
|
||||
if (info.shape.size == 0) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (info.shape.size != size) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("vector argument has not the expected size")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// If argument is not encrypted, just save with the right calling convention.
|
||||
if (info.encryption.hasValue()) {
|
||||
// Else if is encrypted
|
||||
// For moment we support only 8 bits inputs
|
||||
uint8_t *data8 = (uint8_t *)data;
|
||||
if (width != 8) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine(
|
||||
"argument width > 8 for encrypted gates are not supported: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
// Allocate a buffer for ciphertexts.
|
||||
auto ctBuffer =
|
||||
(LweCiphertext_u64 **)malloc(size * sizeof(LweCiphertext_u64 *));
|
||||
ciphertextBuffers.push_back(ctBuffer);
|
||||
// Allocate ciphertexts and encrypt
|
||||
for (auto i = 0; i < size; i++) {
|
||||
if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) {
|
||||
return std::move(err);
|
||||
}
|
||||
allocatedCiphertexts.push_back(ctBuffer[i]);
|
||||
if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) {
|
||||
return std::move(err);
|
||||
}
|
||||
}
|
||||
// Replace the data by the buffer to ciphertext
|
||||
data = (void *)ctBuffer;
|
||||
}
|
||||
// Set the buffer as the memref calling convention expect.
|
||||
// allocated
|
||||
inputs[offset] = (void *)0; // TODO - Better understand how it is used.
|
||||
rawArg[offset] = &inputs[offset];
|
||||
// aligned
|
||||
inputs[offset + 1] = data;
|
||||
rawArg[offset + 1] = &inputs[offset + 1];
|
||||
// offset
|
||||
inputs[offset + 2] = (void *)0;
|
||||
rawArg[offset + 2] = &inputs[offset + 2];
|
||||
// size
|
||||
inputs[offset + 3] = (void *)size;
|
||||
rawArg[offset + 3] = &inputs[offset + 3];
|
||||
// stride
|
||||
inputs[offset + 4] = (void *)0;
|
||||
rawArg[offset + 4] = &inputs[offset + 4];
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
|
||||
// Check is the argument is a scalar
|
||||
if (info.shape.size != 0) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("output is not a scalar, pos=").concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// If result is not encrypted, just set the result
|
||||
if (!keySet.isOutputEncrypted(pos)) {
|
||||
res = (uint64_t)(results[pos]);
|
||||
if (!info.encryption.hasValue()) {
|
||||
res = (uint64_t)(outputs[offset]);
|
||||
return llvm::Error::success();
|
||||
}
|
||||
// Else if is encryted, decrypt
|
||||
LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(results[pos]);
|
||||
LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(outputs[offset]);
|
||||
if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) {
|
||||
return std::move(err);
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res,
|
||||
size_t size) {
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
|
||||
// Check is the argument is a scalar
|
||||
if (info.shape.size == 0) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("output is not a tensor, pos=").concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (!info.encryption.hasValue()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"unencrypted result as tensor output NYI",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Get the values as the memref calling convention expect.
|
||||
void *allocated = outputs[offset]; // TODO - Better understand how it is used.
|
||||
// aligned
|
||||
void *aligned = outputs[offset + 1];
|
||||
// offset
|
||||
size_t offset_r = (size_t)outputs[offset + 2];
|
||||
// size
|
||||
size_t size_r = (size_t)outputs[offset + 3];
|
||||
// stride
|
||||
size_t stride = (size_t)outputs[offset + 4];
|
||||
// Check the sizes
|
||||
if (info.shape.size != size || size_r != size) {
|
||||
return llvm::make_error<llvm::StringError>("output bad result buffer size",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// decrypt and fill the result buffer
|
||||
for (auto i = 0; i < size_r; i++) {
|
||||
LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)(aligned))[i];
|
||||
if (auto err = this->keySet.decrypt_lwe(pos, ct, res[i])) {
|
||||
return std::move(err);
|
||||
}
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
} // namespace mlir
|
||||
|
||||
@@ -42,7 +42,7 @@ KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
auto e = keySet->generateSecretKey(secretKeyParam.first,
|
||||
secretKeyParam.second, generator);
|
||||
if (e) {
|
||||
return e;
|
||||
return std::move(e);
|
||||
}
|
||||
}
|
||||
CAPI_ERR_TO_LLVM_ERROR(free_secret_generator(&err, generator),
|
||||
@@ -60,7 +60,7 @@ KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
bootstrapKeyParam.second,
|
||||
keySet->encryptionRandomGenerator);
|
||||
if (e) {
|
||||
return e;
|
||||
return std::move(e);
|
||||
}
|
||||
}
|
||||
for (auto keyswitchParam : params.keyswitchKeys) {
|
||||
@@ -68,7 +68,7 @@ KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
keyswitchParam.second,
|
||||
keySet->encryptionRandomGenerator);
|
||||
if (e) {
|
||||
return e;
|
||||
return std::move(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -112,9 +112,8 @@ llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
|
||||
LweSecretKeyParam param,
|
||||
SecretRandomGenerator *generator) {
|
||||
LweSecretKey_u64 *sk;
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
sk = allocate_lwe_secret_key_u64(&err, {_0 : param.size}),
|
||||
"cannot allocate secret key");
|
||||
CAPI_ERR_TO_LLVM_ERROR(sk = allocate_lwe_secret_key_u64(&err, {param.size}),
|
||||
"cannot allocate secret key");
|
||||
CAPI_ERR_TO_LLVM_ERROR(fill_lwe_secret_key_u64(&err, sk, generator),
|
||||
"cannot fill secret key with random generator")
|
||||
secretKeys[id] = {param, sk};
|
||||
@@ -250,6 +249,7 @@ llvm::Error KeySet::encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
|
||||
|
||||
llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
|
||||
uint64_t &output) {
|
||||
|
||||
if (argPos >= outputs.size()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"decrypt_lwe: position of argument is too high",
|
||||
@@ -262,13 +262,14 @@ llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Decrypt
|
||||
Plaintext_u64 plaintext;
|
||||
Plaintext_u64 plaintext = {0};
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
decrypt_lwe_u64(&err, std::get<2>(outputSk), ciphertext, &plaintext),
|
||||
"cannot decrypt");
|
||||
// Decode
|
||||
output = plaintext._0 >>
|
||||
(64 - (std::get<0>(outputSk).encryption->encoding.precision + 1));
|
||||
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
enable_testing()
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
|
||||
|
||||
add_executable(
|
||||
hello_test
|
||||
hello_test.cc
|
||||
@@ -7,6 +10,7 @@ add_executable(
|
||||
target_link_libraries(
|
||||
hello_test
|
||||
gtest_main
|
||||
ZamalangSupport
|
||||
)
|
||||
|
||||
include(GoogleTest)
|
||||
|
||||
@@ -1,9 +1,347 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
// Demonstrate some basic assertions.
|
||||
TEST(HelloTest, BasicAssertions) {
|
||||
// Expect two strings not to be equal.
|
||||
EXPECT_STRNE("hello", "world");
|
||||
// Expect equality.
|
||||
EXPECT_EQ(7 * 6, 42);
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
|
||||
#define ASSERT_LLVM_ERROR(err) \
|
||||
if (err) { \
|
||||
llvm::errs() << "error: " << std::move(err) << "\n"; \
|
||||
ASSERT_TRUE(false); \
|
||||
}
|
||||
|
||||
TEST(CompileAndRunHLFHE, add_eint) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
return %1: !HLFHE.eint<7>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_FALSE(engine.compile(mlirStr));
|
||||
auto maybeResult = engine.run({1, 2});
|
||||
ASSERT_TRUE((bool)maybeResult);
|
||||
uint64_t result = maybeResult.get();
|
||||
ASSERT_EQ(result, 3);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_64) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10xi64>, %i: index) -> i64{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi64>
|
||||
return %c : i64
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint64_t t_arg[size]{0xFFFFFFFFFFFFFFFF,
|
||||
0,
|
||||
8978,
|
||||
2587490,
|
||||
90,
|
||||
197864,
|
||||
698735,
|
||||
72132,
|
||||
87474,
|
||||
42};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_32) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10xi32>, %i: index) -> i32{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi32>
|
||||
return %c : i32
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint32_t t_arg[size]{0xFFFFFFFF, 0, 8978, 2587490, 90,
|
||||
197864, 698735, 72132, 87474, 42};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_16) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10xi16>, %i: index) -> i16{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi16>
|
||||
return %c : i16
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint16_t t_arg[size]{0xFFFF, 0, 59589, 47826, 16227,
|
||||
63269, 36435, 52380, 7401, 13313};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_8) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10xi8>, %i: index) -> i8{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi8>
|
||||
return %c : i8
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10xi5>, %i: index) -> i5{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi5>
|
||||
return %c : i5
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_1) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10xi1>, %i: index) -> i1{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi1>
|
||||
return %c : i1
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, extract_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{
|
||||
%c = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
|
||||
return %c : !HLFHE.eint<5>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, extract_twice_and_add_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> !HLFHE.eint<5>{
|
||||
%ti = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
|
||||
%tj = tensor.extract %t[%j] : tensor<10x!HLFHE.eint<5>>
|
||||
%c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) -> !HLFHE.eint<5>
|
||||
return %c : !HLFHE.eint<5>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
for (size_t j = 0; j < size; j++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Set the %j argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(2, j));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i] + t_arg[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, dim_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10x!HLFHE.eint<5>>) -> index{
|
||||
%c0 = constant 0 : index
|
||||
%c = tensor.dim %t, %c0 : tensor<10x!HLFHE.eint<5>>
|
||||
return %c : index
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, size);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, from_elements_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
|
||||
%t = tensor.from_elements %0 : tensor<1x!HLFHE.eint<5>>
|
||||
return %t: tensor<1x!HLFHE.eint<5>>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, 10));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
size_t size_res = 1;
|
||||
uint64_t t_res[size_res];
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
|
||||
ASSERT_EQ(t_res[0], 10);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> {
|
||||
%c_0 = constant 0 : index
|
||||
%c_1 = constant 1 : index
|
||||
%a = tensor.extract %in[%c_0] : tensor<2x!HLFHE.eint<5>>
|
||||
%b = tensor.extract %in[%c_1] : tensor<2x!HLFHE.eint<5>>
|
||||
%aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
|
||||
%aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
|
||||
%bplusb = "HLFHE.add_eint"(%b, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
|
||||
%out = tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>>
|
||||
return %out: tensor<3x!HLFHE.eint<5>>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the argument
|
||||
const size_t in_size = 2;
|
||||
uint8_t in[in_size] = {2, 16};
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, in, in_size));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
const size_t size_res = 3;
|
||||
uint64_t t_res[size_res];
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
|
||||
ASSERT_EQ(t_res[0], in[0] + in[0]);
|
||||
ASSERT_EQ(t_res[1], in[0] + in[1]);
|
||||
ASSERT_EQ(t_res[2], in[1] + in[1]);
|
||||
}
|
||||
Reference in New Issue
Block a user