From 75b70054b2f623beab663e4787f0eb5d151fd2ea Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Thu, 10 Nov 2022 16:52:27 +0100 Subject: [PATCH] feat(compiler): Make Concrete.bootstrap_lwe and Concrete.keyswitch_lwe batchable --- .../Dialect/Concrete/IR/ConcreteOps.h | 1 + .../Dialect/Concrete/IR/ConcreteOps.td | 75 ++++++++++++++++++- 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.h b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.h index 57c059e59..63828e42a 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.h +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.h @@ -13,6 +13,7 @@ #include #include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" +#include "concretelang/Interfaces/BatchableInterface.h" #define GET_OP_CLASSES #include "concretelang/Dialect/Concrete/IR/ConcreteOps.h.inc" diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td index 08ff2a34a..7ecda3537 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td @@ -6,6 +6,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "concretelang/Dialect/Concrete/IR/ConcreteDialect.td" include "concretelang/Dialect/Concrete/IR/ConcreteTypes.td" +include "concretelang/Interfaces/BatchableInterface.td" class Concrete_Op traits = []> : Op; @@ -52,7 +53,7 @@ def Concrete_NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> { let results = (outs Concrete_LweCiphertextType:$result); } -def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe"> { +def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe", [BatchableOpInterface]> { let summary = "Bootstraps a LWE ciphertext with a GLWE trivial encryption of the lookup table"; let arguments = (ins @@ -64,9 +65,46 @@ def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe"> { I32Attr:$glweDimension ); let results = (outs Concrete_LweCiphertextType:$result); + + let extraClassDeclaration = [{ + ::mlir::OpOperand& getBatchableOperand() { + return getOperation()->getOpOperand(0); + } + + ::mlir::OperandRange getNonBatchableOperands() { + return getOperation()->getOperands().drop_front(); + } + + ::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder, + ::mlir::Value batchedOperands) { + ::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get( + batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(), + getResult().getType()); + + return builder.create( + mlir::TypeRange{resType}, + mlir::ValueRange{batchedOperands, lookup_table()}, + getOperation()->getAttrs()); + } + }]; + } -def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> { +def Concrete_BatchedBootstrapLweOp : Concrete_Op<"batched_bootstrap_lwe"> { + let summary = "Batched version of BootstrapLweOp, which performs the same operation on a tensor of elements"; + + let arguments = (ins + 1DTensorOf<[Concrete_LweCiphertextType]>:$input_ciphertexts, + 1DTensorOf<[I64]>:$lookup_table, + I32Attr:$level, + I32Attr:$baseLog, + I32Attr:$polySize, + I32Attr:$glweDimension + ); + let results = (outs 1DTensorOf<[Concrete_LweCiphertextType]>:$result); +} + +def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe", [BatchableOpInterface]> { let summary = "Keyswitches a LWE ciphertext"; let arguments = (ins @@ -75,6 +113,39 @@ def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> { I32Attr:$baseLog ); let results = (outs Concrete_LweCiphertextType:$result); + + let extraClassDeclaration = [{ + ::mlir::OpOperand& getBatchableOperand() { + return getOperation()->getOpOperand(0); + } + + ::mlir::OperandRange getNonBatchableOperands() { + return getOperation()->getOperands().drop_front(); + } + + ::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder, + ::mlir::Value batchedOperands) { + ::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get( + batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(), + getResult().getType()); + + return builder.create( + mlir::TypeRange{resType}, + mlir::ValueRange{batchedOperands}, + getOperation()->getAttrs()); + } + }]; +} + +def Concrete_BatchedKeySwitchLweOp : Concrete_Op<"batched_keyswitch_lwe"> { + let summary = "Batched version of KeySwitchLweOp, which performs the same operation on a tensor of elements"; + + let arguments = (ins + 1DTensorOf<[Concrete_LweCiphertextType]>:$ciphertexts, + I32Attr:$level, + I32Attr:$baseLog + ); + let results = (outs 1DTensorOf<[Concrete_LweCiphertextType]>:$result); } // TODO(16bits): hack