refactor(compiler): Move memref HLFHE and MidLFHE operators to tensor

This commit is contained in:
Quentin Bourgerie
2021-08-17 16:20:11 +02:00
parent 7d6738916b
commit fa62e1f0e5
14 changed files with 61 additions and 63 deletions

View File

@@ -80,7 +80,7 @@ def MulEintIntOp : HLFHE_Op<"mul_eint_int"> {
def ApplyLookupTableEintOp : HLFHE_Op<"apply_lookup_table"> {
let arguments = (ins EncryptedIntegerType:$ct,
MemRefOf<[AnyInteger]>:$l_cst);
TensorOf<[AnyInteger]>:$l_cst);
let results = (outs EncryptedIntegerType);
let verifier = [{

View File

@@ -59,7 +59,7 @@ def MulGLWEIntOp : MidLFHE_Op<"mul_glwe_int"> {
def ApplyLookupTable : MidLFHE_Op<"apply_lookup_table"> {
let arguments = (ins GLWECipherTextType:$ct,
MemRefOf<[AnyInteger]>:$l_cst,
TensorOf<[AnyInteger]>:$l_cst,
I32Attr:$k, I32Attr:$polynomialSize,
I32Attr:$levelKS, I32Attr:$baseLogKS,
I32Attr:$levelBS, I32Attr:$baseLogBS);

View File

@@ -31,17 +31,15 @@ public:
return mlir::zamalang::convertTypeEncryptedIntegerToGLWE(
type.getContext(), type);
});
addConversion([](mlir::MemRefType type) {
addConversion([](mlir::RankedTensorType type) {
auto eint =
type.getElementType().dyn_cast_or_null<EncryptedIntegerType>();
if (eint == nullptr) {
return (mlir::Type)(type);
}
mlir::Type r = mlir::MemRefType::get(
type.getShape(),
mlir::zamalang::convertTypeEncryptedIntegerToGLWE(eint.getContext(),
eint),
type.getAffineMaps(), type.getMemorySpace());
mlir::Type r = mlir::RankedTensorType::get(
type.getShape(), mlir::zamalang::convertTypeEncryptedIntegerToGLWE(
eint.getContext(), eint));
return r;
});
}

View File

@@ -31,15 +31,14 @@ public:
addConversion([&](GLWECipherTextType type) {
return mlir::zamalang::convertTypeGLWEToLWE(type.getContext(), type);
});
addConversion([&](mlir::MemRefType type) {
addConversion([&](mlir::RankedTensorType type) {
auto glwe = type.getElementType().dyn_cast_or_null<GLWECipherTextType>();
if (glwe == nullptr) {
return (mlir::Type)(type);
}
mlir::Type r = mlir::MemRefType::get(
mlir::Type r = mlir::RankedTensorType::get(
type.getShape(),
mlir::zamalang::convertTypeGLWEToLWE(glwe.getContext(), glwe),
type.getAffineMaps(), type.getMemorySpace());
mlir::zamalang::convertTypeGLWEToLWE(glwe.getContext(), glwe));
return r;
});
}

View File

@@ -94,7 +94,7 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op,
::mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTableEintOp &op) {
auto ct = op.ct().getType().cast<EncryptedIntegerType>();
auto l_cst = op.l_cst().getType().cast<MemRefType>();
auto l_cst = op.l_cst().getType().cast<TensorType>();
auto result = op.getResult().getType().cast<EncryptedIntegerType>();
// Check the shape of l_cst argument

View File

@@ -108,7 +108,7 @@ mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) {
/// - The lookup table contains integer values of the same width of the output
mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTable &op) {
auto ct = op.ct().getType().cast<GLWECipherTextType>();
auto l_cst = op.l_cst().getType().cast<MemRefType>();
auto l_cst = op.l_cst().getType().cast<RankedTensorType>();
auto result = op.getResult().getType().cast<GLWECipherTextType>();
// Check the shape of l_cst argument

View File

@@ -1,10 +1,10 @@
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: memref<4xi2>) -> !MidLFHE.glwe<{_,_,_}{2}>
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32, levelBS = -1 : i32, levelKS = -1 : i32, polynomialSize = -1 : i32} : (!MidLFHE.glwe<{_,_,_}{2}>, memref<4xi2>) -> !MidLFHE.glwe<{_,_,_}{2}>
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{_,_,_}{2}>
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi2>) -> !HLFHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32, levelBS = -1 : i32, levelKS = -1 : i32, polynomialSize = -1 : i32} : (!MidLFHE.glwe<{_,_,_}{2}>, tensor<4xi2>) -> !MidLFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: return %[[V1]] : !MidLFHE.glwe<{_,_,_}{2}>
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, memref<4xi2>) -> (!HLFHE.eint<2>)
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<4xi2>) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}

View File

@@ -1,29 +0,0 @@
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
//CHECK: #map0 = affine_map<(d0) -> (d0)>
// CHECK-NEXT #map1 = affine_map<(d0) -> ()>
// CHECK-NEXT module {
// CHECK-NEXT func @dot_eint_int(%arg0: memref<2x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: memref<2xi3>, %arg2: memref<!MidLFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : memref<2x!MidLFHE.glwe<{_,_,_}{2}>>, memref<2xi3>) outs(%arg2 : memref<!MidLFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors
// CHECK-NEXT %0 = "MidLFHE.mul_glwe_int"(%arg3, %arg4) : (!MidLFHE.glwe<{_,_,_}{2}>, i3) -> !MidLFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT %1 = "MidLFHE.add_glwe"(%0, %arg5) : (!MidLFHE.glwe<{_,_,_}{2}>, !MidLFHE.glwe<{_,_,_}{2}>) -> !MidLFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT linalg.yield %1 : !MidLFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT }
// CHECK-NEXT return
// CHECK-NEXT }
// CHECK-NEXT }
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0) -> ()>
module {
func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>, %arg1: memref<2xi3>, %arg2: memref<!HLFHE.eint<2>>) {
linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : memref<2x!HLFHE.eint<2>>, memref<2xi3>) outs(%arg2 : memref<!HLFHE.eint<2>>) {
^bb0(%arg3: !HLFHE.eint<2>, %arg4: i3, %arg5: !HLFHE.eint<2>): // no predecessors
%0 = "HLFHE.mul_eint_int"(%arg3, %arg4) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
%1 = "HLFHE.add_eint"(%0, %arg5) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
linalg.yield %1 : !HLFHE.eint<2>
}
return
}
}

View File

@@ -0,0 +1,30 @@
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
//CHECK: #map0 = affine_map<(d0) -> (d0)>
//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
//CHECK-NEXT: module {
//CHECK-NEXT: func @linalg_generic(%arg0: tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>, %arg2: tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %0 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%arg2 : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors
//CHECK-NEXT: %1 = "MidLFHE.mul_glwe_int"(%arg3, %arg4) : (!MidLFHE.glwe<{_,_,_}{2}>, i3) -> !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %2 = "MidLFHE.add_glwe"(%1, %arg5) : (!MidLFHE.glwe<{_,_,_}{2}>, !MidLFHE.glwe<{_,_,_}{2}>) -> !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: linalg.yield %2 : !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: } -> tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: return
//CHECK-NEXT: }
//CHECK-NEXT: }
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0) -> (0)>
module {
func @linalg_generic(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi3>, %acc: tensor<1x!HLFHE.eint<2>>) {
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) outs(%acc : tensor<1x!HLFHE.eint<2>>) {
^bb0(%arg2: !HLFHE.eint<2>, %arg3: i3, %arg4: !HLFHE.eint<2>): // no predecessors
%4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
%5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
linalg.yield %5 : !HLFHE.eint<2>
} -> tensor<1x!HLFHE.eint<2>>
return
}
}

View File

@@ -1,7 +1,7 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument.
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<8xi3>) -> !HLFHE.eint<2> {
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, memref<8xi3>) -> (!HLFHE.eint<2>)
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<8xi3>) -> !HLFHE.eint<2> {
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<8xi3>) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}

View File

@@ -1,7 +1,7 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have equals width beetwen the encrypted integer result and integers of the `tabulated_lambda` argument
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi3>) -> !HLFHE.eint<2> {
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, memref<4xi3>) -> (!HLFHE.eint<2>)
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi3>) -> !HLFHE.eint<2> {
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<4xi3>) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}

View File

@@ -51,12 +51,12 @@ func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
return %1: !HLFHE.eint<2>
}
// CHECK-LABEL: func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.eint<2>
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = "HLFHE.apply_lookup_table"(%arg0, %arg1) : (!HLFHE.eint<2>, memref<4xi2>) -> !HLFHE.eint<2>
// CHECK-LABEL: func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi2>) -> !HLFHE.eint<2>
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi2>) -> !HLFHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = "HLFHE.apply_lookup_table"(%arg0, %arg1) : (!HLFHE.eint<2>, tensor<4xi2>) -> !HLFHE.eint<2>
// CHECK-NEXT: return %[[V1]] : !HLFHE.eint<2>
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, memref<4xi2>) -> (!HLFHE.eint<2>)
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<4xi2>) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}

View File

@@ -1,17 +1,17 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics %s
// Bad dimension of the lookup table
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: memref<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> {
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> {
// expected-error @+1 {{'MidLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument}}
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {k = 1 : i32, polynomialSize = 1024 : i32, levelKS = 2 : i32, baseLogKS = -82 : i32, levelBS = 3 : i32, baseLogBS = -83 : i32}: (!MidLFHE.glwe<{1024,12,64}{7}>, memref<4xi2>) -> (!MidLFHE.glwe<{512,10,64}{2}>)
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {k = 1 : i32, polynomialSize = 1024 : i32, levelKS = 2 : i32, baseLogKS = -82 : i32, levelBS = 3 : i32, baseLogBS = -83 : i32}: (!MidLFHE.glwe<{1024,12,64}{7}>, tensor<4xi2>) -> (!MidLFHE.glwe<{512,10,64}{2}>)
return %1: !MidLFHE.glwe<{512,10,64}{2}>
}
// -----
// Bad dimension of integer in the lookup table
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: memref<128xi3>) -> !MidLFHE.glwe<{512,10,64}{2}> {
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi3>) -> !MidLFHE.glwe<{512,10,64}{2}> {
// expected-error @+1 {{'MidLFHE.apply_lookup_table' op should have equals width beetwen the encrypted integer result and integers of the `tabulated_lambda` argument}}
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {k = 1 : i32, polynomialSize = 1024 : i32, levelKS = 2 : i32, baseLogKS = -82 : i32, levelBS = 3 : i32, baseLogBS = -83 : i32}: (!MidLFHE.glwe<{1024,12,64}{7}>, memref<128xi3>) -> (!MidLFHE.glwe<{512,10,64}{2}>)
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {k = 1 : i32, polynomialSize = 1024 : i32, levelKS = 2 : i32, baseLogKS = -82 : i32, levelBS = 3 : i32, baseLogBS = -83 : i32}: (!MidLFHE.glwe<{1024,12,64}{7}>, tensor<128xi3>) -> (!MidLFHE.glwe<{512,10,64}{2}>)
return %1: !MidLFHE.glwe<{512,10,64}{2}>
}

View File

@@ -1,10 +1,10 @@
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: memref<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: memref<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> {
// CHECK-NEXT: %[[V1:.*]] = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {baseLogBS = -83 : i32, baseLogKS = -82 : i32, k = 1 : i32, levelBS = 3 : i32, levelKS = 2 : i32, polynomialSize = 1024 : i32} : (!MidLFHE.glwe<{1024,12,64}{7}>, memref<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}>
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> {
// CHECK-NEXT: %[[V1:.*]] = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {baseLogBS = -83 : i32, baseLogKS = -82 : i32, k = 1 : i32, levelBS = 3 : i32, levelKS = 2 : i32, polynomialSize = 1024 : i32} : (!MidLFHE.glwe<{1024,12,64}{7}>, tensor<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}>
// CHECK-NEXT: return %[[V1]] : !MidLFHE.glwe<{512,10,64}{2}>
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {k = 1 : i32, polynomialSize = 1024 : i32, levelKS = 2 : i32, baseLogKS = -82 : i32, levelBS = 3 : i32, baseLogBS = -83 : i32} : (!MidLFHE.glwe<{1024,12,64}{7}>, memref<128xi2>) -> (!MidLFHE.glwe<{512,10,64}{2}>)
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {k = 1 : i32, polynomialSize = 1024 : i32, levelKS = 2 : i32, baseLogKS = -82 : i32, levelBS = 3 : i32, baseLogBS = -83 : i32} : (!MidLFHE.glwe<{1024,12,64}{7}>, tensor<128xi2>) -> (!MidLFHE.glwe<{512,10,64}{2}>)
return %1: !MidLFHE.glwe<{512,10,64}{2}>
}