From cb635f8a5597b85e8960852974f6b53513fc7831 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Mon, 19 Jul 2021 11:57:21 +0200 Subject: [PATCH] feat(compiler): HLFHE.apply_lookup_table (#54) --- .../zamalang/Dialect/HLFHE/IR/HLFHEOps.td | 10 ++++++ compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp | 31 ++++++++++++++++++- .../op_apply_lookup_table_bad_dimension.mlir | 7 +++++ ...okup_table_bad_integer_width_in_table.mlir | 7 +++++ 4 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir create mode 100644 compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_integer_width_in_table.mlir diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td index 895d66380..8281ad7a7 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td @@ -51,6 +51,16 @@ def MulEintIntOp : HLFHE_Op<"mul_eint_int"> { ]; } +def ApplyLookupTable : HLFHE_Op<"apply_lookup_table"> { + let arguments = (ins EncryptedIntegerType:$ct, + MemRefOf<[AnyInteger]>:$l_cst); + let results = (outs EncryptedIntegerType); + + let verifier = [{ + return ::mlir::zamalang::HLFHE::verifyApplyLookupTable(*this); + }]; +} + // Tensor operations // Dot product diff --git a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp index aefd4bec6..d563cc961 100644 --- a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp +++ b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp @@ -1,13 +1,42 @@ #include "mlir/IR/Region.h" +#include "mlir/IR/TypeUtilities.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h" #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" -#include namespace mlir { namespace zamalang { namespace HLFHE { +using mlir::zamalang::HLFHE::ApplyLookupTable; +using mlir::zamalang::HLFHE::EncryptedIntegerType; + +::mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTable &op) { + auto ct = op.ct().getType().cast(); + auto l_cst = op.l_cst().getType().cast(); + auto result = op.getResult().getType().cast(); + + // Check the shape of l_cst argument + auto width = ct.getWidth(); + auto lCstShape = l_cst.getShape(); + mlir::SmallVector expectedShape{1 << width}; + if (!l_cst.hasStaticShape(expectedShape)) { + op.emitOpError() << " should have as `l_cst` argument a shape of one " + "dimension equals to 2^p, where p is the width of the " + "`ct` argument."; + return mlir::failure(); + } + // Check the witdh of the encrypted integer and the integer of the tabulated + // lambda are equals + if (ct.getWidth() != l_cst.getElementType().cast().getWidth()) { + op.emitOpError() + << " should have equals width beetwen the encrypted integer result and " + "integers of the `tabulated_lambda` argument"; + return mlir::failure(); + } + return mlir::success(); +} + void Dot::getEffects( SmallVectorImpl> &effects) { diff --git a/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir new file mode 100644 index 000000000..b10b92383 --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir @@ -0,0 +1,7 @@ +// RUN: not zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument. +func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<8xi3>) -> !HLFHE.eint<2> { + %1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, memref<8xi3>) -> (!HLFHE.eint<2>) + return %1: !HLFHE.eint<2> +} \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_integer_width_in_table.mlir b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_integer_width_in_table.mlir new file mode 100644 index 000000000..ba6a3cbf1 --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_integer_width_in_table.mlir @@ -0,0 +1,7 @@ +// RUN: not zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have equals width beetwen the encrypted integer result and integers of the `tabulated_lambda` argument +func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi3>) -> !HLFHE.eint<2> { + %1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, memref<4xi3>) -> (!HLFHE.eint<2>) + return %1: !HLFHE.eint<2> +} \ No newline at end of file