From 8b9c9f2da1e81eb530f56fe9edd5b8817fbb7372 Mon Sep 17 00:00:00 2001
From: Andi Drebes
Date: Tue, 27 Jul 2021 10:58:28 +0200
Subject: [PATCH] refactor(compiler): HLFHE.dot_eint_int: Switch from
 reference to value semantics

This changes the semantics of `HLFHE.dot_eint_int` from memref-based
reference semantics to tensor-based value semantics. The former:

  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
    (memref<4x!HLFHE.eint<0>>, memref<4xi32>, memref<!HLFHE.eint<0>>) -> ()

becomes:

  "HLFHE.dot_eint_int"(%arg0, %arg1) :
    (tensor<4x!HLFHE.eint<0>>, tensor<4xi32>) -> !HLFHE.eint<0>

As a side effect, data-flow analyses become much easier. With the
previous memref type of the plaintext argument it is difficult to check
whether the plaintext values are statically defined constants or
originate from a memory region changed at execution time (e.g., for
analyses evaluating the impact on noise). Changing the plaintext type
from `memref` to `tensor` makes such analyses significantly easier.
---
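Illustration (a sketch, not part of this patch): with value semantics,
checking whether the plaintext operand is a statically defined constant
reduces to a match on its defining operation. The helper below is
hypothetical; it only relies on the standard matchers from
`mlir/IR/Matchers.h`, on the ODS-generated `rhs()` accessor of
`HLFHE.dot_eint_int`, and assumes the HLFHE dialect headers are
available.

  #include "mlir/IR/Matchers.h"

  // Hypothetical helper (illustration only): returns true iff the
  // plaintext operand of a `HLFHE.dot_eint_int` op is a compile-time
  // constant, binding its values to `cst` on success. With the old
  // memref-based form this would require a memory/alias analysis
  // instead of a plain SSA def-use query.
  static bool hasConstantPlaintext(mlir::zamalang::HLFHE::Dot dotOp,
                                   mlir::DenseIntElementsAttr &cst) {
    return mlir::matchPattern(dotOp.rhs(), mlir::m_Constant(&cst));
  }
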
 .../zamalang/Dialect/HLFHE/IR/HLFHEOps.td     | 10 +--
 .../TensorOpsToLinalg.cpp                     | 81 +++++++++++++------
 compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp    | 13 ---
 compiler/src/main.cpp                         |  1 +
 compiler/tests/Dialect/HLFHE/dot.invalid.mlir | 48 ++++-------
 compiler/tests/Dialect/HLFHE/ops.mlir         | 17 ++--
 .../Dialect/HLFHE/tensor-ops-to-linalg.mlir   | 42 +++++-----
 7 files changed, 108 insertions(+), 104 deletions(-)

diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td
index 4d6882772..69fa3558d 100644
--- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td
+++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td
@@ -91,13 +91,11 @@ def ApplyLookupTableEintOp : HLFHE_Op<"apply_lookup_table"> {
 // Tensor operations

 // Dot product
-def Dot : HLFHE_Op<"dot_eint_int", [DeclareOpInterfaceMethods]> {
-  // Output memref is passed as the last argument; Input and output
-  // relationships are modeled through MemoryEffectsOpInterface`
+def Dot : HLFHE_Op<"dot_eint_int"> {
   let arguments = (ins
-    Type.predicate, HasAnyRankOfPred<[1]>]>>:$lhs,
-    Type.predicate, HasAnyRankOfPred<[1]>]>>:$rhs,
-    Type.predicate, HasAnyRankOfPred<[0]>]>>:$out);
+    Type.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$lhs,
+    Type.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$rhs);
+  let results = (outs EncryptedIntegerType:$out);
   let verifier = [{
     if(::mlir::failed(
       mlir::verifyCompatibleShape(
diff --git a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp
index ecf03ff24..fb333cc03 100644
--- a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp
+++ b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp
@@ -1,6 +1,8 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -20,31 +22,38 @@ struct DotToLinalgGeneric : public ::mlir::RewritePattern {
   // This rewrite pattern transforms any instance of
   // `HLFHE.dot_eint_int` to an instance of `linalg.generic` with an
   // appropriate region using `HLFHE.mul_eint_int` and
-  // `HLFHE.add_eint` operations and an appropriate specification for
-  // the iteration dimensions.
+  // `HLFHE.add_eint` operations, an appropriate specification for the
+  // iteration dimensions and appropriate operations managing the
+  // accumulator of `linalg.generic`.
   //
   // Example:
   //
-  //   "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-  //     (memref<4x!HLFHE.eint<0>>,
-  //      memref<4xi32>,
-  //      memref<!HLFHE.eint<0>>) -> ()
+  //   %o = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+  //     (tensor<4x!HLFHE.eint<0>>,
+  //      tensor<4xi32>) -> (!HLFHE.eint<0>)
   //
   // becomes:
   //
-  //   linalg.generic {
-  //      indexing_maps = [affine_map<(d0) -> (d0)>,
-  //                       affine_map<(d0) -> (d0)>,
-  //                       affine_map<(d0) -> ()>],
-  //      iterator_types = ["reduction"]
-  //   } ins(%arg0, %arg1 : memref<4x!HLFHE.eint<0>>, memref<4xi32>)
-  //   outs(%arg2: memref<!HLFHE.eint<0>>)
-  //   {
-  //      ^bb0(%arg3: !HLFHE.eint<0>, %arg4: i32, %arg5: !HLFHE.eint<0>):
-  //        %0 = "HLFHE.mul_eint_int"(%arg3, %arg4) : (!HLFHE.eint<0>, i32) ->
-  //        !HLFHE.eint<0> %1 = "HLFHE.add_eint"(%0, %arg5) : (!HLFHE.eint<0>,
-  //        !HLFHE.eint<0>) -> !HLFHE.eint<0> linalg.yield %1 : !HLFHE.eint<0>
-  //   }
+  //   %0 = "HLFHE.zero"() : () -> !HLFHE.eint<0>
+  //   %1 = tensor.from_elements %0 : tensor<1x!HLFHE.eint<0>>
+  //   %2 = linalg.generic {
+  //          indexing_maps = [#map0, #map0, #map1],
+  //          iterator_types = ["reduction"]
+  //        }
+  //        ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<0>>, tensor<2xi32>)
+  //        outs(%1 : tensor<1x!HLFHE.eint<0>>) {
+  //          ^bb0(%arg2: !HLFHE.eint<0>, %arg3: i32, %arg4: !HLFHE.eint<0>):
+  //            %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) :
+  //                   (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
+  //
+  //            %5 = "HLFHE.add_eint"(%4, %arg4) :
+  //                   (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0>
+  //
+  //            linalg.yield %5 : !HLFHE.eint<0>
+  //        } -> tensor<1x!HLFHE.eint<0>>
+  //
+  //   %c0 = constant 0 : index
+  //   %o = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<0>>
   //
   ::mlir::LogicalResult
   matchAndRewrite(::mlir::Operation *op0,
@@ -52,14 +61,28 @@ struct DotToLinalgGeneric : public ::mlir::RewritePattern {
     ::mlir::zamalang::HLFHE::Dot &&dotOp =
        ::llvm::dyn_cast_or_null<::mlir::zamalang::HLFHE::Dot>(op0);

-    mlir::TypeRange resTypes{};
-    llvm::SmallVector ins{dotOp.lhs(), dotOp.rhs()};
-    llvm::SmallVector outs{dotOp.out()};
+    // Zero value to initialize accumulator
+    mlir::Value zeroCst = rewriter.create(
+        dotOp.getLoc(),
+        dotOp.lhs().getType().cast().getElementType());
+    // Create one-dimensional accumulator with a single element
+    // (`tensor.from_elements` does not allow for the creation of 0d
+    // tensors)
+    mlir::tensor::FromElementsOp feOp =
+        rewriter.create(dotOp.getLoc(), zeroCst);
+
+    mlir::Value accu = feOp.getResult();
+
+    // Create `linalg.generic` op
+    llvm::SmallVector resTypes{accu.getType()};
+    llvm::SmallVector ins{dotOp.lhs(), dotOp.rhs()};
+    llvm::SmallVector outs{accu};

     llvm::SmallVector maps{
         mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()),
         mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()),
-        mlir::AffineMap::get(1, 0, this->getContext())};
+        mlir::AffineMap::get(1, 0, {rewriter.getAffineConstantExpr(0)},
+                             this->getContext())};

     llvm::SmallVector itTypes{"reduction"};
     llvm::StringRef doc{""};
@@ -83,7 +106,16 @@ struct DotToLinalgGeneric : public ::mlir::RewritePattern {
         dotOp.getLoc(), resTypes, ins, outs, maps, itTypes, doc, call,
         regBuilder);

-    rewriter.replaceOp(op0, {gop.getODSResults(0)});
+    // Return value is still a 1-dimensional tensor; extract first
+    // element and use it as a replacement for the result of the dot
+    // operation
+    mlir::Value idx0 =
+        rewriter.create(dotOp.getLoc(), 0);
+    llvm::SmallVector indexes{idx0};
+    mlir::Value res = rewriter.create(
+        dotOp.getLoc(), gop.getResult(0), indexes);
+
+    rewriter.replaceOp(op0, {res});

     return ::mlir::success();
   };
@@ -105,6 +137,7 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
   target.addLegalDialect();
   target.addLegalDialect();
   target.addLegalDialect();
+  target.addLegalDialect();
   target.addIllegalOp();

   mlir::OwningRewritePatternList patterns(&getContext());
diff --git a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp
index 61ed1d6fc..b2a3e0de5 100644
--- a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp
+++ b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp
@@ -118,19 +118,6 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op,
   return mlir::success();
 }

-void Dot::getEffects(
-    SmallVectorImpl>
-        &effects) {
-  // Side effects for Dot product: the first two operands are inputs,
-  // the last one is an output
-  effects.emplace_back(MemoryEffects::Read::get(), this->lhs(),
-                       SideEffects::DefaultResource::get());
-  effects.emplace_back(MemoryEffects::Read::get(), this->rhs(),
-                       SideEffects::DefaultResource::get());
-  effects.emplace_back(MemoryEffects::Write::get(), this->out(),
-                       SideEffects::DefaultResource::get());
-}
-
 } // namespace HLFHE
 } // namespace zamalang
 } // namespace mlir
diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp
index 69c2d4c5e..9f40c2439 100644
--- a/compiler/src/main.cpp
+++ b/compiler/src/main.cpp
@@ -250,6 +250,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
   context.getOrLoadDialect();
   context.getOrLoadDialect();
   context.getOrLoadDialect();
+  context.getOrLoadDialect();
   context.getOrLoadDialect();

   if (cmdline::verifyDiagnostics)
diff --git a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir
index 7127999c2..9345c8b91 100644
--- a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir
+++ b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir
@@ -1,59 +1,41 @@
 // RUN: zamacompiler --split-input-file --verify-diagnostics %s

-// Unranked types
-func @dot_unranked(
-  %arg0: memref>,
-  %arg1: memref,
-  %arg2: memref>)
-{
-  // expected-error @+1 {{'HLFHE.dot_eint_int' op operand #0}}
-  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-    (memref>, memref, memref>) -> ()
-
-  return
-}
-
-// -----
-
 // Incompatible shapes
 func @dot_incompatible_shapes(
-  %arg0: memref<5x!HLFHE.eint<2>>,
-  %arg1: memref<4xi32>,
-  %arg2: memref>)
+  %arg0: tensor<5x!HLFHE.eint<5>>,
+  %arg1: tensor<4xi32>) -> !HLFHE.eint<5>
 {
   // expected-error @+1 {{'HLFHE.dot_eint_int' op arguments have incompatible shapes}}
-  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-    (memref<5x!HLFHE.eint<2>>, memref<4xi32>, memref>) -> ()
+  %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+    (tensor<5x!HLFHE.eint<5>>, tensor<4xi32>) -> !HLFHE.eint<5>

-  return
+  return %ret : !HLFHE.eint<5>
 }

 // -----

 // Incompatible input types
 func @dot_incompatible_input_types(
-  %arg0: memref<4x!HLFHE.eint<2>>,
-  %arg1: memref<4xf32>,
-  %arg2: memref>)
+  %arg0: tensor<5x!HLFHE.eint<2>>,
+  %arg1: tensor<4xf32>) -> !HLFHE.eint<2>
 {
   // expected-error @+1 {{'HLFHE.dot_eint_int' op operand #1 must}}
-  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-    (memref<4x!HLFHE.eint<2>>, memref<4xf32>, memref>) -> ()
+  %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+    (tensor<5x!HLFHE.eint<2>>, tensor<4xf32>) -> !HLFHE.eint<2>

-  return
+  return %ret : !HLFHE.eint<2>
 }

 // -----

 // Wrong number of dimensions
 func @dot_num_dims(
-  %arg0: memref<2x4x!HLFHE.eint<2>>,
-  %arg1: memref<2x4xi32>,
-  %arg2: memref>)
+  %arg0: tensor<2x4x!HLFHE.eint<2>>,
+  %arg1: tensor<2x4xi32>) -> !HLFHE.eint<2>
 {
   // expected-error @+1 {{'HLFHE.dot_eint_int' op operand #0}}
-  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-    (memref<2x4x!HLFHE.eint<2>>, memref<2x4xi32>, memref>) -> ()
+  %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+    (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi32>) -> !HLFHE.eint<2>

-  return
+  return %ret : !HLFHE.eint<2>
 }
diff --git a/compiler/tests/Dialect/HLFHE/ops.mlir b/compiler/tests/Dialect/HLFHE/ops.mlir
index 2afd85ba9..31fda9dc9 100644
--- a/compiler/tests/Dialect/HLFHE/ops.mlir
+++ b/compiler/tests/Dialect/HLFHE/ops.mlir
@@ -60,15 +60,14 @@ func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.e
   return %1: !HLFHE.eint<2>
 }

-// CHECK-LABEL: func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>, %arg1: memref<2xi3>, %arg2: memref>)
-func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>,
-                   %arg1: memref<2xi3>,
-                   %arg2: memref>)
+// CHECK-LABEL: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi32>) -> !HLFHE.eint<2>
+func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>,
+                   %arg1: tensor<2xi32>) -> !HLFHE.eint<2>
 {
-  // CHECK-NEXT: "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : (memref<2x!HLFHE.eint<2>>, memref<2xi3>, memref>) -> ()
-  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-    (memref<2x!HLFHE.eint<2>>, memref<2xi3>, memref>) -> ()
+  // CHECK-NEXT: %[[RET:.*]] = "HLFHE.dot_eint_int"(%arg0, %arg1) : (tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2>
+  %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+    (tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2>

-  //CHECK-NEXT: return
-  return
+  //CHECK-NEXT: return %[[RET]] : !HLFHE.eint<2>
+  return %ret : !HLFHE.eint<2>
 }
diff --git a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir
index fb93456ad..752200b71 100644
--- a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir
+++ b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir
@@ -1,22 +1,26 @@
-// RUN: zamacompiler %s --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
+// RUN: zamacompiler %s --convert-hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s

-// CHECK: #map0 = affine_map<(d0) -> (d0)>
-// CHECK-NEXT: #map1 = affine_map<(d0) -> ()>
-// CHECK-NEXT: module {
-// CHECK-NEXT:   func @dot_eint_int(%[[A0:.*]]: memref<2x!HLFHE.eint<2>>, %[[A1:.*]]: memref<2xi3>, %[[A2:.*]]: memref>)
-// CHECK-NEXT:     linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%[[A0]], %[[A1]] : memref<2x!HLFHE.eint<2>>, memref<2xi3>) outs(%arg2 : memref>) {
-// CHECK-NEXT:     ^bb0(%[[A3:.*]]: !HLFHE.eint<2>, %[[A4:.*]]: i3, %[[A5:.*]]: !HLFHE.eint<2>):  // no predecessors
-// CHECK-NEXT:       %[[T0:.*]] = "HLFHE.mul_eint_int"(%[[A3]], %[[A4]]) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
-// CHECK-NEXT:       %[[T1:.*]] = "HLFHE.add_eint"(%[[T0]], %[[A5]]) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
-// CHECK-NEXT:       linalg.yield %[[T1]] : !HLFHE.eint<2>
-// CHECK-NEXT:     }
-// CHECK-NEXT:     return
-// CHECK-NEXT: }
-func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>,
-                   %arg1: memref<2xi3>,
-                   %arg2: memref>)
+//CHECK: #map0 = affine_map<(d0) -> (d0)>
+//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
+//CHECK-NEXT: module {
+//CHECK-NEXT:   func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi32>) -> !HLFHE.eint<2> {
+//CHECK-NEXT:     %0 = "HLFHE.zero"() : () -> !HLFHE.eint<2>
+//CHECK-NEXT:     %1 = tensor.from_elements %0 : tensor<1x!HLFHE.eint<2>>
+//CHECK-NEXT:     %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) outs(%1 : tensor<1x!HLFHE.eint<2>>) {
+//CHECK-NEXT:     ^bb0(%arg2: !HLFHE.eint<2>, %arg3: i32, %arg4: !HLFHE.eint<2>):  // no predecessors
+//CHECK-NEXT:       %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2>
+//CHECK-NEXT:       %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
+//CHECK-NEXT:       linalg.yield %5 : !HLFHE.eint<2>
+//CHECK-NEXT:     } -> tensor<1x!HLFHE.eint<2>>
+//CHECK-NEXT:     %c0 = constant 0 : index
+//CHECK-NEXT:     %3 = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<2>>
+//CHECK-NEXT:     return %3 : !HLFHE.eint<2>
+//CHECK-NEXT:   }
+//CHECK-NEXT: }
+func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>,
+                   %arg1: tensor<2xi32>) -> !HLFHE.eint<2>
 {
-  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-    (memref<2x!HLFHE.eint<2>>, memref<2xi3>, memref>) -> ()
-  return
+  %o = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+    (tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2>
+  return %o : !HLFHE.eint<2>
 }
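
A minimal usage sketch (assumed, not taken from the patch): since the op
now declares `let results = (outs EncryptedIntegerType:$out)`, client
code can create it through the ODS-generated default builder and consume
the result as an ordinary SSA value; `builder`, `loc`, `resultType`,
`lhs` and `rhs` are assumed to be in scope.

  // Hypothetical builder call: no output memref has to be allocated and
  // passed in; the dot product result is simply an SSA value.
  mlir::Value dotResult = builder.create<mlir::zamalang::HLFHE::Dot>(
      loc, resultType, lhs, rhs);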