feat: MidToLowLFHE lowering of apply_lut

This commit is contained in:
youben11
2021-08-19 09:08:46 +01:00
committed by Quentin Bourgerie
parent b6c3eceadd
commit 3b5ae0657d
3 changed files with 53 additions and 0 deletions

View File

@@ -169,6 +169,42 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter,
return op.getODSResults(0).front();
}
mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
mlir::Value ct, mlir::Value table, mlir::IntegerAttr k,
mlir::IntegerAttr polynomialSize,
mlir::IntegerAttr levelKS, mlir::IntegerAttr baseLogKS,
mlir::IntegerAttr levelBS, mlir::IntegerAttr baseLogBS,
mlir::OpResult result) {
// fill the the table in the GLWE accumulator
mlir::Value accumulator =
rewriter
.create<mlir::zamalang::LowLFHE::GlweFromTable>(
loc, LowLFHE::GlweCiphertextType::get(rewriter.getContext()),
table, polynomialSize, k)
.result();
// keyswitch
auto ct_type = ct.getType().cast<GLWECipherTextType>();
mlir::Value keyswitched =
rewriter
.create<mlir::zamalang::LowLFHE::KeySwitchLweOp>(
loc, convertTypeGLWEToLWE(rewriter.getContext(), ct_type), ct)
.result();
// convert result type
GLWECipherTextType glwe_type = result.getType().cast<GLWECipherTextType>();
LweCiphertextType lwe_type =
convertTypeGLWEToLWE(rewriter.getContext(), glwe_type);
// bootstrap operation
mlir::Value bootstrapped =
rewriter
.create<mlir::zamalang::LowLFHE::BootstrapLweOp>(
loc, lwe_type, keyswitched, accumulator)
.result();
return bootstrapped;
}
} // namespace zamalang
} // namespace mlir

View File

@@ -35,4 +35,10 @@ def SubIntGLWEPattern : Pat<
(SubIntGLWEOp:$result $arg0, $arg1),
(createSubIntLweOp $arg0, $arg1, $result)>;
def createPBS : NativeCodeCall<"mlir::zamalang::createPBS($_builder, $_loc, $0, $1, $2, $3, $4, $5, $6, $7, $8)">;
def ApplyLookupTableGLWEPattern : Pat<
(ApplyLookupTable:$result $ct, $table, $k, $polynomialSize, $levelKS, $baseLogKS, $levelBS, $baseLogBS),
(createPBS $ct, $table, $k, $polynomialSize, $levelKS, $baseLogKS, $levelBS, $baseLogBS, $result)>;
#endif

View File

@@ -0,0 +1,11 @@
// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> {
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.glwe_from_table"(%arg1) {k = 1 : i32, polynomialSize = 1024 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<1024,4>
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=-1:i32, baseLogKS=-1:i32, levelBS=-1:i32, baseLogBS=-1:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
return %1: !MidLFHE.glwe<{1024,1,64}{4}>
}