mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 04:35:03 -05:00
enhance(compiler/runtime): Add runtime tools to handle tensor inputs and outputs
This commit is contained in:
@@ -141,10 +141,16 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
|
||||
if (this->type.getNumParams() != args.size() - 1) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"invokeRaw: wrong number of argument", llvm::inconvertibleErrorCode());
|
||||
}
|
||||
size_t nbReturn = 0;
|
||||
// TODO - This check break with memref as we have 5 returns args.
|
||||
// if (!this->type.getReturnType().isa<mlir::LLVM::LLVMVoidType>()) {
|
||||
// nbReturn = 1;
|
||||
// }
|
||||
// if (this->type.getNumParams() != args.size() - nbReturn) {
|
||||
// return llvm::make_error<llvm::StringError>(
|
||||
// "invokeRaw: wrong number of argument",
|
||||
// llvm::inconvertibleErrorCode());
|
||||
// }
|
||||
if (llvm::find(args, nullptr) != args.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"invoke: some arguments are null", llvm::inconvertibleErrorCode());
|
||||
@@ -157,24 +163,58 @@ llvm::Error JITLambda::invoke(Argument &args) {
|
||||
}
|
||||
|
||||
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
|
||||
inputs = std::vector<void *>(keySet.numInputs());
|
||||
results = std::vector<void *>(keySet.numOutputs());
|
||||
// Setting the inputs
|
||||
{
|
||||
auto numInputs = 0;
|
||||
for (size_t i = 0; i < keySet.numInputs(); i++) {
|
||||
auto offset = numInputs;
|
||||
auto gate = keySet.inputGate(i);
|
||||
inputGates.push_back({gate, offset});
|
||||
if (keySet.inputGate(i).shape.size == 0) {
|
||||
// scalar gate
|
||||
numInputs = numInputs + 1;
|
||||
continue;
|
||||
}
|
||||
// memref gate, as we follow the standard calling convention
|
||||
numInputs = numInputs + 5;
|
||||
}
|
||||
inputs = std::vector<void *>(numInputs);
|
||||
}
|
||||
|
||||
// Setting the outputs
|
||||
{
|
||||
auto numOutputs = 0;
|
||||
for (size_t i = 0; i < keySet.numOutputs(); i++) {
|
||||
auto offset = numOutputs;
|
||||
auto gate = keySet.outputGate(i);
|
||||
outputGates.push_back({gate, offset});
|
||||
if (gate.shape.size == 0) {
|
||||
// scalar gate
|
||||
numOutputs = numOutputs + 1;
|
||||
continue;
|
||||
}
|
||||
// memref gate, as we follow the standard calling convention
|
||||
numOutputs = numOutputs + 5;
|
||||
}
|
||||
outputs = std::vector<void *>(numOutputs);
|
||||
}
|
||||
|
||||
// The raw argument contains pointers to inputs and pointers to store the
|
||||
// results
|
||||
rawArg =
|
||||
std::vector<void *>(keySet.numInputs() + keySet.numOutputs(), nullptr);
|
||||
// Set the results pointer on the rawArg
|
||||
for (auto i = keySet.numInputs(); i < rawArg.size(); i++) {
|
||||
rawArg[i] = &results[i - keySet.numInputs()];
|
||||
rawArg = std::vector<void *>(inputs.size() + outputs.size(), nullptr);
|
||||
// Set the pointer on outputs on rawArg
|
||||
for (auto i = inputs.size(); i < rawArg.size(); i++) {
|
||||
rawArg[i] = &outputs[i - inputs.size()];
|
||||
}
|
||||
}
|
||||
|
||||
JITLambda::Argument::~Argument() {
|
||||
int err;
|
||||
for (auto i = 0; i < keySet.numInputs(); i++) {
|
||||
if (keySet.isInputEncrypted(i)) {
|
||||
free_lwe_ciphertext_u64(&err, (LweCiphertext_u64 *)(inputs[i]));
|
||||
}
|
||||
for (auto ct : allocatedCiphertexts) {
|
||||
free_lwe_ciphertext_u64(&err, ct);
|
||||
}
|
||||
for (auto buffer : ciphertextBuffers) {
|
||||
free(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,38 +225,206 @@ JITLambda::Argument::create(KeySet &keySet) {
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
|
||||
if (pos >= inputGates.size()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument index out of bound: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
auto gate = inputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
|
||||
// Check is the argument is a scalar
|
||||
if (info.shape.size != 0) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument is not a scalar: pos=").concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
// If argument is not encrypted, just save.
|
||||
if (!keySet.isInputEncrypted(pos)) {
|
||||
inputs[pos] = (void *)arg;
|
||||
rawArg[pos] = &inputs[pos];
|
||||
if (!info.encryption.hasValue()) {
|
||||
inputs[offset] = (void *)arg;
|
||||
rawArg[offset] = &inputs[offset];
|
||||
return llvm::Error::success();
|
||||
}
|
||||
// Else if is encryted, allocate ciphertext.
|
||||
// Else if is encryted, allocate ciphertext and encrypt.
|
||||
LweCiphertext_u64 *ctArg;
|
||||
if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) {
|
||||
return std::move(err);
|
||||
}
|
||||
allocatedCiphertexts.push_back(ctArg);
|
||||
if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) {
|
||||
return std::move(err);
|
||||
}
|
||||
inputs[pos] = ctArg;
|
||||
rawArg[pos] = &inputs[pos];
|
||||
inputs[offset] = ctArg;
|
||||
rawArg[offset] = &inputs[offset];
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data,
|
||||
size_t size) {
|
||||
auto gate = inputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
// Check if the width is compatible
|
||||
// TODO - I found this rules empirically, they are a spec somewhere?
|
||||
if (info.shape.width <= 8 && width != 8) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width should be 8: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (info.shape.width > 8 && info.shape.width <= 16 && width != 16) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width should be 16: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (info.shape.width > 16 && info.shape.width <= 32 && width != 32) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width should be 32: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (info.shape.width > 32 && info.shape.width <= 64 && width != 64) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width should be 64: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (info.shape.width > 64) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width not supported: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Check the size
|
||||
if (info.shape.size == 0) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (info.shape.size != size) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("vector argument has not the expected size")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// If argument is not encrypted, just save with the right calling convention.
|
||||
if (info.encryption.hasValue()) {
|
||||
// Else if is encrypted
|
||||
// For moment we support only 8 bits inputs
|
||||
uint8_t *data8 = (uint8_t *)data;
|
||||
if (width != 8) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine(
|
||||
"argument width > 8 for encrypted gates are not supported: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
// Allocate a buffer for ciphertexts.
|
||||
auto ctBuffer =
|
||||
(LweCiphertext_u64 **)malloc(size * sizeof(LweCiphertext_u64 *));
|
||||
ciphertextBuffers.push_back(ctBuffer);
|
||||
// Allocate ciphertexts and encrypt
|
||||
for (auto i = 0; i < size; i++) {
|
||||
if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) {
|
||||
return std::move(err);
|
||||
}
|
||||
allocatedCiphertexts.push_back(ctBuffer[i]);
|
||||
if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) {
|
||||
return std::move(err);
|
||||
}
|
||||
}
|
||||
// Replace the data by the buffer to ciphertext
|
||||
data = (void *)ctBuffer;
|
||||
}
|
||||
// Set the buffer as the memref calling convention expect.
|
||||
// allocated
|
||||
inputs[offset] = (void *)0; // TODO - Better understand how it is used.
|
||||
rawArg[offset] = &inputs[offset];
|
||||
// aligned
|
||||
inputs[offset + 1] = data;
|
||||
rawArg[offset + 1] = &inputs[offset + 1];
|
||||
// offset
|
||||
inputs[offset + 2] = (void *)0;
|
||||
rawArg[offset + 2] = &inputs[offset + 2];
|
||||
// size
|
||||
inputs[offset + 3] = (void *)size;
|
||||
rawArg[offset + 3] = &inputs[offset + 3];
|
||||
// stride
|
||||
inputs[offset + 4] = (void *)0;
|
||||
rawArg[offset + 4] = &inputs[offset + 4];
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
|
||||
// Check is the argument is a scalar
|
||||
if (info.shape.size != 0) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("output is not a scalar, pos=").concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// If result is not encrypted, just set the result
|
||||
if (!keySet.isOutputEncrypted(pos)) {
|
||||
res = (uint64_t)(results[pos]);
|
||||
if (!info.encryption.hasValue()) {
|
||||
res = (uint64_t)(outputs[offset]);
|
||||
return llvm::Error::success();
|
||||
}
|
||||
// Else if is encryted, decrypt
|
||||
LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(results[pos]);
|
||||
LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(outputs[offset]);
|
||||
if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) {
|
||||
return std::move(err);
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res,
|
||||
size_t size) {
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
|
||||
// Check is the argument is a scalar
|
||||
if (info.shape.size == 0) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("output is not a tensor, pos=").concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (!info.encryption.hasValue()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"unencrypted result as tensor output NYI",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Get the values as the memref calling convention expect.
|
||||
void *allocated = outputs[offset]; // TODO - Better understand how it is used.
|
||||
// aligned
|
||||
void *aligned = outputs[offset + 1];
|
||||
// offset
|
||||
size_t offset_r = (size_t)outputs[offset + 2];
|
||||
// size
|
||||
size_t size_r = (size_t)outputs[offset + 3];
|
||||
// stride
|
||||
size_t stride = (size_t)outputs[offset + 4];
|
||||
// Check the sizes
|
||||
if (info.shape.size != size || size_r != size) {
|
||||
return llvm::make_error<llvm::StringError>("output bad result buffer size",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// decrypt and fill the result buffer
|
||||
for (auto i = 0; i < size_r; i++) {
|
||||
LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)(aligned))[i];
|
||||
if (auto err = this->keySet.decrypt_lwe(pos, ct, res[i])) {
|
||||
return std::move(err);
|
||||
}
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
} // namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user