feat(compiler): Make Concrete.bootstrap_lwe and Concrete.keyswitch_lwe batchable

This commit is contained in:
Andi Drebes
2022-11-10 16:52:27 +01:00
parent c367a4b6fd
commit 75b70054b2
2 changed files with 74 additions and 2 deletions

View File

@@ -13,6 +13,7 @@
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Interfaces/BatchableInterface.h"
#define GET_OP_CLASSES
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h.inc"

View File

@@ -6,6 +6,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
include "concretelang/Dialect/Concrete/IR/ConcreteDialect.td"
include "concretelang/Dialect/Concrete/IR/ConcreteTypes.td"
include "concretelang/Interfaces/BatchableInterface.td"
class Concrete_Op<string mnemonic, list<Trait> traits = []> :
Op<Concrete_Dialect, mnemonic, traits>;
@@ -52,7 +53,7 @@ def Concrete_NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> {
let results = (outs Concrete_LweCiphertextType:$result);
}
def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe"> {
def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe", [BatchableOpInterface]> {
let summary = "Bootstraps a LWE ciphertext with a GLWE trivial encryption of the lookup table";
let arguments = (ins
@@ -64,9 +65,46 @@ def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe"> {
I32Attr:$glweDimension
);
let results = (outs Concrete_LweCiphertextType:$result);
let extraClassDeclaration = [{
::mlir::OpOperand& getBatchableOperand() {
return getOperation()->getOpOperand(0);
}
::mlir::OperandRange getNonBatchableOperands() {
return getOperation()->getOperands().drop_front();
}
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
::mlir::Value batchedOperands) {
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(),
getResult().getType());
return builder.create<BatchedBootstrapLweOp>(
mlir::TypeRange{resType},
mlir::ValueRange{batchedOperands, lookup_table()},
getOperation()->getAttrs());
}
}];
}
def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> {
def Concrete_BatchedBootstrapLweOp : Concrete_Op<"batched_bootstrap_lwe"> {
let summary = "Batched version of BootstrapLweOp, which performs the same operation on a tensor of elements";
let arguments = (ins
1DTensorOf<[Concrete_LweCiphertextType]>:$input_ciphertexts,
1DTensorOf<[I64]>:$lookup_table,
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$polySize,
I32Attr:$glweDimension
);
let results = (outs 1DTensorOf<[Concrete_LweCiphertextType]>:$result);
}
def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe", [BatchableOpInterface]> {
let summary = "Keyswitches a LWE ciphertext";
let arguments = (ins
@@ -75,6 +113,39 @@ def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> {
I32Attr:$baseLog
);
let results = (outs Concrete_LweCiphertextType:$result);
let extraClassDeclaration = [{
::mlir::OpOperand& getBatchableOperand() {
return getOperation()->getOpOperand(0);
}
::mlir::OperandRange getNonBatchableOperands() {
return getOperation()->getOperands().drop_front();
}
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
::mlir::Value batchedOperands) {
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(),
getResult().getType());
return builder.create<BatchedKeySwitchLweOp>(
mlir::TypeRange{resType},
mlir::ValueRange{batchedOperands},
getOperation()->getAttrs());
}
}];
}
def Concrete_BatchedKeySwitchLweOp : Concrete_Op<"batched_keyswitch_lwe"> {
let summary = "Batched version of KeySwitchLweOp, which performs the same operation on a tensor of elements";
let arguments = (ins
1DTensorOf<[Concrete_LweCiphertextType]>:$ciphertexts,
I32Attr:$level,
I32Attr:$baseLog
);
let results = (outs 1DTensorOf<[Concrete_LweCiphertextType]>:$result);
}
// TODO(16bits): hack