refactor(compiler): HLFHE.dot_eint_int: Switch from reference to value semantics

This changes the semantics of `HLFHE.dot_eint_int` from memref-based
reference semantics to tensor-based value semantics. The former:

  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
     (memref<Nx!HLFHE.eint<0>>, memref<Nxi32>, memref<!HLFHE.eint<0>>) -> ()

becomes:

  "HLFHE.dot_eint_int"(%arg0, %arg1) :
     (tensor<Nx!HLFHE.eint<0>>, tensor<Nxi32>) -> !HLFHE.eint<0>

As a side effect, data-flow analyses become much easier. With the
previous `memref` type of the plaintext argument, it is difficult to
check whether the plaintext values are statically defined constants or
originate from a memory region modified at execution time (e.g., for
analyses evaluating the impact on noise). Changing the plaintext type
from `memref` to `tensor` makes such analyses significantly easier.
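
For illustration, here is a minimal sketch of such a check under value
semantics (assuming MLIR's matcher helpers from `mlir/IR/Matchers.h`;
the helper name `isStaticPlaintext` is hypothetical and not part of
this commit):

  #include "mlir/IR/BuiltinAttributes.h"
  #include "mlir/IR/Matchers.h" // mlir::matchPattern, mlir::m_Constant
  #include "mlir/IR/Value.h"

  // Hypothetical analysis helper: with value semantics, deciding
  // whether the plaintext operand of a dot product is a statically
  // known constant reduces to matching the defining op of the tensor
  // value.
  static bool isStaticPlaintext(mlir::Value plaintext) {
    mlir::DenseIntElementsAttr cst;
    // Succeeds iff `plaintext` is produced by a constant-like op; with
    // the old memref operand this would instead require reasoning
    // about all writes to the underlying buffer.
    return mlir::matchPattern(plaintext, mlir::m_Constant(&cst));
  }
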
Author: Andi Drebes
Date: 2021-07-27 10:58:28 +02:00
Committed by: Quentin Bourgerie
Parent: cb580f16d2
Commit: 8b9c9f2da1
7 changed files with 108 additions and 104 deletions


@@ -91,13 +91,11 @@ def ApplyLookupTableEintOp : HLFHE_Op<"apply_lookup_table"> {
 // Tensor operations
 // Dot product
-def Dot : HLFHE_Op<"dot_eint_int", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  // Output memref is passed as the last argument; Input and output
-  // relationships are modeled through MemoryEffectsOpInterface`
+def Dot : HLFHE_Op<"dot_eint_int"> {
   let arguments = (ins
-    Type<And<[StaticShapeMemRefOf<[EncryptedIntegerType]>.predicate, HasAnyRankOfPred<[1]>]>>:$lhs,
-    Type<And<[StaticShapeMemRefOf<[AnyInteger]>.predicate, HasAnyRankOfPred<[1]>]>>:$rhs,
-    Type<And<[StaticShapeMemRefOf<[EncryptedIntegerType]>.predicate, HasAnyRankOfPred<[0]>]>>:$out);
+    Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$lhs,
+    Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$rhs);
+  let results = (outs EncryptedIntegerType:$out);
   let verifier = [{
     if(::mlir::failed(
         mlir::verifyCompatibleShape(


@@ -1,6 +1,8 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -20,31 +22,38 @@ struct DotToLinalgGeneric : public ::mlir::RewritePattern {
   // This rewrite pattern transforms any instance of
   // `HLFHE.dot_eint_int` to an instance of `linalg.generic` with an
   // appropriate region using `HLFHE.mul_eint_int` and
-  // `HLFHE.add_eint` operations and an appropriate specification for
-  // the iteration dimensions.
+  // `HLFHE.add_eint` operations, an appropriate specification for the
+  // iteration dimensions and appropriate operations managing the
+  // accumulator of `linalg.generic`.
   //
   // Example:
   //
-  //   "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-  //     (memref<?x!HLFHE.eint<0>>,
-  //      memref<?xi32>,
-  //      memref<!HLFHE.eint<0>>) -> ()
+  //   %o = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+  //     (tensor<4x!HLFHE.eint<0>>,
+  //      tensor<4xi32>) -> (!HLFHE.eint<0>)
   //
   // becomes:
   //
-  //   linalg.generic {
-  //     indexing_maps = [affine_map<(d0) -> (d0)>,
-  //                      affine_map<(d0) -> (d0)>,
-  //                      affine_map<(d0) -> ()>],
-  //     iterator_types = ["reduction"]
-  //   } ins(%arg0, %arg1 : memref<?x!HLFHE.eint<0>>, memref<?xi32>)
-  //     outs(%arg2: memref<!HLFHE.eint<0>>)
-  //   {
-  //     ^bb0(%arg3: !HLFHE.eint<0>, %arg4: i32, %arg5: !HLFHE.eint<0>):
-  //       %0 = "HLFHE.mul_eint_int"(%arg3, %arg4) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
-  //       %1 = "HLFHE.add_eint"(%0, %arg5) : (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0>
-  //       linalg.yield %1 : !HLFHE.eint<0>
-  //   }
+  //   %0 = "HLFHE.zero"() : () -> !HLFHE.eint<0>
+  //   %1 = tensor.from_elements %0 : tensor<1x!HLFHE.eint<0>>
+  //   %2 = linalg.generic {
+  //     indexing_maps = [#map0, #map0, #map1],
+  //     iterator_types = ["reduction"]
+  //   }
+  //   ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<0>>, tensor<2xi32>)
+  //   outs(%1 : tensor<1x!HLFHE.eint<0>>) {
+  //     ^bb0(%arg2: !HLFHE.eint<0>, %arg3: i32, %arg4: !HLFHE.eint<0>):
+  //       %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) :
+  //              (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
+  //
+  //       %5 = "HLFHE.add_eint"(%4, %arg4) :
+  //              (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0>
+  //
+  //       linalg.yield %5 : !HLFHE.eint<0>
+  //   } -> tensor<1x!HLFHE.eint<0>>
+  //
+  //   %c0 = constant 0 : index
+  //   %o = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<0>>
+  //
   ::mlir::LogicalResult
   matchAndRewrite(::mlir::Operation *op0,
@@ -52,14 +61,28 @@ struct DotToLinalgGeneric : public ::mlir::RewritePattern {
     ::mlir::zamalang::HLFHE::Dot &&dotOp =
         ::llvm::dyn_cast_or_null<::mlir::zamalang::HLFHE::Dot>(op0);
-    mlir::TypeRange resTypes{};
-    llvm::SmallVector<mlir::Value, 2> ins{dotOp.lhs(), dotOp.rhs()};
-    llvm::SmallVector<mlir::Value, 1> outs{dotOp.out()};
+    // Zero value to initialize accumulator
+    mlir::Value zeroCst = rewriter.create<mlir::zamalang::HLFHE::ZeroOp>(
+        dotOp.getLoc(),
+        dotOp.lhs().getType().cast<mlir::ShapedType>().getElementType());
+    // Create one-dimensional accumulator with a single element
+    // (`tensor.from_elements` does not allow for the creation of 0d
+    // tensors)
+    mlir::tensor::FromElementsOp feOp =
+        rewriter.create<mlir::tensor::FromElementsOp>(dotOp.getLoc(), zeroCst);
+    mlir::Value accu = feOp.getResult();
+    // Create `linalg.generic` op
+    llvm::SmallVector<mlir::Type, 1> resTypes{accu.getType()};
+    llvm::SmallVector<mlir::Value, 2> ins{dotOp.lhs(), dotOp.rhs()};
+    llvm::SmallVector<mlir::Value, 1> outs{accu};
     llvm::SmallVector<mlir::AffineMap, 3> maps{
         mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()),
         mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()),
-        mlir::AffineMap::get(1, 0, this->getContext())};
+        mlir::AffineMap::get(1, 0, {rewriter.getAffineConstantExpr(0)},
+                             this->getContext())};
     llvm::SmallVector<llvm::StringRef, 1> itTypes{"reduction"};
     llvm::StringRef doc{""};
@@ -83,7 +106,16 @@ struct DotToLinalgGeneric : public ::mlir::RewritePattern {
         dotOp.getLoc(), resTypes, ins, outs, maps, itTypes, doc, call,
         regBuilder);
-    rewriter.replaceOp(op0, {gop.getODSResults(0)});
+    // Return value is still a 1-dimensional tensor; extract first
+    // element and use it as a replacement for the result of the dot
+    // operation
+    mlir::Value idx0 =
+        rewriter.create<mlir::ConstantIndexOp>(dotOp.getLoc(), 0);
+    llvm::SmallVector<mlir::Value, 1> indexes{idx0};
+    mlir::Value res = rewriter.create<mlir::tensor::ExtractOp>(
+        dotOp.getLoc(), gop.getResult(0), indexes);
+    rewriter.replaceOp(op0, {res});
     return ::mlir::success();
   };
@@ -105,6 +137,7 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
target.addLegalDialect<mlir::StandardOpsDialect>();
target.addLegalDialect<mlir::memref::MemRefDialect>();
target.addLegalDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
target.addLegalDialect<mlir::tensor::TensorDialect>();
target.addIllegalOp<mlir::zamalang::HLFHE::Dot>();
mlir::OwningRewritePatternList patterns(&getContext());


@@ -118,19 +118,6 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op,
   return mlir::success();
 }
-void Dot::getEffects(
-    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-        &effects) {
-  // Side effects for Dot product: the first two operands are inputs,
-  // the last one is an output
-  effects.emplace_back(MemoryEffects::Read::get(), this->lhs(),
-                       SideEffects::DefaultResource::get());
-  effects.emplace_back(MemoryEffects::Read::get(), this->rhs(),
-                       SideEffects::DefaultResource::get());
-  effects.emplace_back(MemoryEffects::Write::get(), this->out(),
-                       SideEffects::DefaultResource::get());
-}
 } // namespace HLFHE
 } // namespace zamalang
 } // namespace mlir


@@ -250,6 +250,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
context.getOrLoadDialect<mlir::StandardOpsDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
context.getOrLoadDialect<mlir::tensor::TensorDialect>();
context.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
if (cmdline::verifyDiagnostics)


@@ -1,59 +1,41 @@
 // RUN: zamacompiler --split-input-file --verify-diagnostics %s
-// Unranked types
-func @dot_unranked(
-  %arg0: memref<?x!HLFHE.eint<2>>,
-  %arg1: memref<?xi32>,
-  %arg2: memref<!HLFHE.eint<2>>)
-{
-  // expected-error @+1 {{'HLFHE.dot_eint_int' op operand #0}}
-  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-    (memref<?x!HLFHE.eint<2>>, memref<?xi32>, memref<!HLFHE.eint<2>>) -> ()
-  return
-}
-// -----
 // Incompatible shapes
 func @dot_incompatible_shapes(
-  %arg0: memref<5x!HLFHE.eint<2>>,
-  %arg1: memref<4xi32>,
-  %arg2: memref<!HLFHE.eint<2>>)
+  %arg0: tensor<5x!HLFHE.eint<5>>,
+  %arg1: tensor<4xi32>) -> !HLFHE.eint<5>
 {
   // expected-error @+1 {{'HLFHE.dot_eint_int' op arguments have incompatible shapes}}
-  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-    (memref<5x!HLFHE.eint<2>>, memref<4xi32>, memref<!HLFHE.eint<2>>) -> ()
+  %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+    (tensor<5x!HLFHE.eint<5>>, tensor<4xi32>) -> !HLFHE.eint<5>
-  return
+  return %ret : !HLFHE.eint<5>
 }
 // -----
 // Incompatible input types
 func @dot_incompatible_input_types(
-  %arg0: memref<4x!HLFHE.eint<2>>,
-  %arg1: memref<4xf32>,
-  %arg2: memref<!HLFHE.eint<2>>)
+  %arg0: tensor<5x!HLFHE.eint<2>>,
+  %arg1: tensor<4xf32>) -> !HLFHE.eint<2>
 {
   // expected-error @+1 {{'HLFHE.dot_eint_int' op operand #1 must}}
-  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-    (memref<4x!HLFHE.eint<2>>, memref<4xf32>, memref<!HLFHE.eint<2>>) -> ()
+  %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+    (tensor<5x!HLFHE.eint<2>>, tensor<4xf32>) -> !HLFHE.eint<0>
-  return
+  return %ret : !HLFHE.eint<2>
 }
 // -----
 // Wrong number of dimensions
 func @dot_num_dims(
-  %arg0: memref<2x4x!HLFHE.eint<2>>,
-  %arg1: memref<2x4xi32>,
-  %arg2: memref<!HLFHE.eint<2>>)
+  %arg0: tensor<2x4x!HLFHE.eint<2>>,
+  %arg1: tensor<2x4xi32>) -> !HLFHE.eint<2>
 {
   // expected-error @+1 {{'HLFHE.dot_eint_int' op operand #0}}
-  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-    (memref<2x4x!HLFHE.eint<2>>, memref<2x4xi32>, memref<!HLFHE.eint<2>>) -> ()
+  %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+    (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi32>) -> !HLFHE.eint<2>
-  return
+  return %ret : !HLFHE.eint<2>
 }


@@ -60,15 +60,14 @@ func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.e
   return %1: !HLFHE.eint<2>
 }
-// CHECK-LABEL: func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>, %arg1: memref<2xi3>, %arg2: memref<!HLFHE.eint<2>>)
-func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>,
-                   %arg1: memref<2xi3>,
-                   %arg2: memref<!HLFHE.eint<2>>)
+// CHECK-LABEL: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi32>) -> !HLFHE.eint<2>
+func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>,
+                   %arg1: tensor<2xi32>) -> !HLFHE.eint<2>
 {
-  // CHECK-NEXT: "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : (memref<2x!HLFHE.eint<2>>, memref<2xi3>, memref<!HLFHE.eint<2>>) -> ()
-  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-    (memref<2x!HLFHE.eint<2>>, memref<2xi3>, memref<!HLFHE.eint<2>>) -> ()
+  // CHECK-NEXT: %[[RET:.*]] = "HLFHE.dot_eint_int"(%arg0, %arg1) : (tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2>
+  %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+    (tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2>
-  //CHECK-NEXT: return
-  return
+  //CHECK-NEXT: return %[[RET]] : !HLFHE.eint<2>
+  return %ret : !HLFHE.eint<2>
 }


@@ -1,22 +1,26 @@
-// RUN: zamacompiler %s --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
+// RUN: zamacompiler %s --convert-hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
-// CHECK: #map0 = affine_map<(d0) -> (d0)>
-// CHECK-NEXT: #map1 = affine_map<(d0) -> ()>
-// CHECK-NEXT: module {
-// CHECK-NEXT: func @dot_eint_int(%[[A0:.*]]: memref<2x!HLFHE.eint<2>>, %[[A1:.*]]: memref<2xi3>, %[[A2:.*]]: memref<!HLFHE.eint<2>>)
-// CHECK-NEXT: linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%[[A0]], %[[A1]] : memref<2x!HLFHE.eint<2>>, memref<2xi3>) outs(%arg2 : memref<!HLFHE.eint<2>>) {
-// CHECK-NEXT: ^bb0(%[[A3:.*]]: !HLFHE.eint<2>, %[[A4:.*]]: i3, %[[A5:.*]]: !HLFHE.eint<2>): // no predecessors
-// CHECK-NEXT: %[[T0:.*]] = "HLFHE.mul_eint_int"(%[[A3]], %[[A4]]) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
-// CHECK-NEXT: %[[T1:.*]] = "HLFHE.add_eint"(%[[T0]], %[[A5]]) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
-// CHECK-NEXT: linalg.yield %[[T1]] : !HLFHE.eint<2>
-// CHECK-NEXT: }
-// CHECK-NEXT: return
-// CHECK-NEXT: }
-func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>,
-                   %arg1: memref<2xi3>,
-                   %arg2: memref<!HLFHE.eint<2>>)
+//CHECK: #map0 = affine_map<(d0) -> (d0)>
+//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
+//CHECK-NEXT: module {
+//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi32>) -> !HLFHE.eint<2> {
+//CHECK-NEXT: %0 = "HLFHE.zero"() : () -> !HLFHE.eint<2>
+//CHECK-NEXT: %1 = tensor.from_elements %0 : tensor<1x!HLFHE.eint<2>>
+//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) outs(%1 : tensor<1x!HLFHE.eint<2>>) {
+//CHECK-NEXT: ^bb0(%arg2: !HLFHE.eint<2>, %arg3: i32, %arg4: !HLFHE.eint<2>): // no predecessors
+//CHECK-NEXT: %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2>
+//CHECK-NEXT: %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
+//CHECK-NEXT: linalg.yield %5 : !HLFHE.eint<2>
+//CHECK-NEXT: } -> tensor<1x!HLFHE.eint<2>>
+//CHECK-NEXT: %c0 = constant 0 : index
+//CHECK-NEXT: %3 = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<2>>
+//CHECK-NEXT: return %3 : !HLFHE.eint<2>
+//CHECK-NEXT: }
+//CHECK-NEXT: }
+func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>,
+                   %arg1: tensor<2xi32>) -> !HLFHE.eint<2>
 {
-  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
-    (memref<2x!HLFHE.eint<2>>, memref<2xi3>, memref<!HLFHE.eint<2>>) -> ()
-  return
+  %o = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+    (tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2>
+  return %o : !HLFHE.eint<2>
 }