mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): Make Concrete.bootstrap_lwe and Concrete.keyswitch_lwe batchable
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
|
||||
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Interfaces/BatchableInterface.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h.inc"
|
||||
|
||||
@@ -6,6 +6,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
|
||||
include "concretelang/Dialect/Concrete/IR/ConcreteDialect.td"
|
||||
include "concretelang/Dialect/Concrete/IR/ConcreteTypes.td"
|
||||
include "concretelang/Interfaces/BatchableInterface.td"
|
||||
|
||||
class Concrete_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<Concrete_Dialect, mnemonic, traits>;
|
||||
@@ -52,7 +53,7 @@ def Concrete_NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> {
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe"> {
|
||||
def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe", [BatchableOpInterface]> {
|
||||
let summary = "Bootstraps a LWE ciphertext with a GLWE trivial encryption of the lookup table";
|
||||
|
||||
let arguments = (ins
|
||||
@@ -64,9 +65,46 @@ def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe"> {
|
||||
I32Attr:$glweDimension
|
||||
);
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::OpOperand& getBatchableOperand() {
|
||||
return getOperation()->getOpOperand(0);
|
||||
}
|
||||
|
||||
::mlir::OperandRange getNonBatchableOperands() {
|
||||
return getOperation()->getOperands().drop_front();
|
||||
}
|
||||
|
||||
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
|
||||
::mlir::Value batchedOperands) {
|
||||
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
|
||||
batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(),
|
||||
getResult().getType());
|
||||
|
||||
return builder.create<BatchedBootstrapLweOp>(
|
||||
mlir::TypeRange{resType},
|
||||
mlir::ValueRange{batchedOperands, lookup_table()},
|
||||
getOperation()->getAttrs());
|
||||
}
|
||||
}];
|
||||
|
||||
}
|
||||
|
||||
def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> {
|
||||
def Concrete_BatchedBootstrapLweOp : Concrete_Op<"batched_bootstrap_lwe"> {
|
||||
let summary = "Batched version of BootstrapLweOp, which performs the same operation on a tensor of elements";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[Concrete_LweCiphertextType]>:$input_ciphertexts,
|
||||
1DTensorOf<[I64]>:$lookup_table,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$glweDimension
|
||||
);
|
||||
let results = (outs 1DTensorOf<[Concrete_LweCiphertextType]>:$result);
|
||||
}
|
||||
|
||||
def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe", [BatchableOpInterface]> {
|
||||
let summary = "Keyswitches a LWE ciphertext";
|
||||
|
||||
let arguments = (ins
|
||||
@@ -75,6 +113,39 @@ def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> {
|
||||
I32Attr:$baseLog
|
||||
);
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::OpOperand& getBatchableOperand() {
|
||||
return getOperation()->getOpOperand(0);
|
||||
}
|
||||
|
||||
::mlir::OperandRange getNonBatchableOperands() {
|
||||
return getOperation()->getOperands().drop_front();
|
||||
}
|
||||
|
||||
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
|
||||
::mlir::Value batchedOperands) {
|
||||
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
|
||||
batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(),
|
||||
getResult().getType());
|
||||
|
||||
return builder.create<BatchedKeySwitchLweOp>(
|
||||
mlir::TypeRange{resType},
|
||||
mlir::ValueRange{batchedOperands},
|
||||
getOperation()->getAttrs());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Concrete_BatchedKeySwitchLweOp : Concrete_Op<"batched_keyswitch_lwe"> {
|
||||
let summary = "Batched version of KeySwitchLweOp, which performs the same operation on a tensor of elements";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[Concrete_LweCiphertextType]>:$ciphertexts,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog
|
||||
);
|
||||
let results = (outs 1DTensorOf<[Concrete_LweCiphertextType]>:$result);
|
||||
}
|
||||
|
||||
// TODO(16bits): hack
|
||||
|
||||
Reference in New Issue
Block a user