feat(compiler): HLFHELinalg.apply_lookup_table definition

This commit is contained in:
Quentin Bourgerie
2021-10-22 17:22:31 +02:00
committed by Andi Drebes
parent f72d51d98d
commit dea1be9d52
4 changed files with 111 additions and 1 deletions

View File

@@ -220,7 +220,45 @@ def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tens
);
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
}
def ApplyLookupTableEintOp : HLFHELinalg_Op<"apply_lookup_table", []> {
let summary = "Returns a tensor that contains the result of the lookup on a table.";
let description = [{
Performs for each encrypted indices a lookup on a table of clear integers.
```mlir
// The result of this operation, is a tensor that contains the result of the lookup on a table.
// i.e. %res[i, ..., k] = %lut[%t[i, ..., k]]
%res = HLFHELinalg.apply_lookup_table(%t, %lut): tensor<DNx...xD1x!HLFHE.eint<$p>>, tensor<D2^$pxi64> -> tensor<DNx...xD1x!HLFHE.eint<$p>>
```
The `%lut` argument should be a tensor with one dimension, where its dimension is equals to `2^p` where `p` is the width of the encrypted integers.
Examples:
```mlir
// Returns the lookup of 3x3 matrix of encrypted indices of with 2 on a table of size 4=2² of clear integers.
//
// [0,1,2] [2,4,6]
// [3,0,1] lut [2,4,6,8] = [8,2,4]
// [2,3,0] [6,8,0]
//
"HLFHELinalg.apply_lookup_table(%t, %lut)" : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi4>) -> tensor<3x3x!HLFHE.eint<3>>
```
}];
let arguments = (ins
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$t,
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$lut
);
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let verifier = [{
return ::mlir::zamalang::HLFHELinalg::verifyApplyLookupTable(*this);
}];
}
#endif

View File

@@ -202,5 +202,38 @@ LogicalResult verifyTensorBinaryEint(mlir::Operation *op) {
} // namespace OpTrait
} // namespace mlir
namespace mlir {
namespace zamalang {
namespace HLFHELinalg {
mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTableEintOp &op) {
auto tTy = op.t().getType().cast<mlir::RankedTensorType>();
auto tEltTy =
tTy.getElementType().cast<mlir::zamalang::HLFHE::EncryptedIntegerType>();
auto lutTy = op.lut().getType().cast<mlir::RankedTensorType>();
auto lutEltTy = lutTy.getElementType().cast<mlir::IntegerType>();
auto resultTy = op.getResult().getType().cast<mlir::RankedTensorType>();
// Check the shape of lut argument
auto tEltwidth = tEltTy.getWidth();
mlir::SmallVector<int64_t, 1> expectedShape{1 << tEltwidth};
if (!lutTy.hasStaticShape(expectedShape) || !lutEltTy.isInteger(64)) {
op.emitOpError()
<< "should have as operand #2 a tensor<2^pxi64>, where p is the width "
"of the encrypted integer of the operand #1,"
<< "expect tensor <" << expectedShape[0] << "xi64>";
return mlir::failure();
}
if (!resultTy.hasStaticShape(tTy.getShape())) {
op.emitOpError()
<< " should have same shapes for operand #1 and the result";
}
return mlir::success();
}
} // namespace HLFHELinalg
} // namespace zamalang
} // namespace mlir
#define GET_OP_CLASSES
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp.inc"

View File

@@ -118,4 +118,30 @@ func @main(%a0: tensor<2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x3x4xi4>) -> tensor<2
return %1 : tensor<2x3x4x!HLFHE.eint<2>>
}
// -----
// -----
/////////////////////////////////////////////////
// HLFHELinalg.apply_lookup_table
/////////////////////////////////////////////////
func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi32>) -> tensor<2x3x4x!HLFHE.eint<2>> {
// expected-error @+1 {{'HLFHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi64>}}
%1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!HLFHE.eint<2>>, tensor<4xi32>) -> (tensor<2x3x4x!HLFHE.eint<2>>)
return %1: tensor<2x3x4x!HLFHE.eint<2>>
}
// -----
func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<12xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> {
// expected-error @+1 {{'HLFHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi64>}}
%1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!HLFHE.eint<2>>, tensor<12xi64>) -> (tensor<2x3x4x!HLFHE.eint<2>>)
return %1: tensor<2x3x4x!HLFHE.eint<2>>
}
// -----
func @apply_lookup_table(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> {
// expected-error @+1 {{'HLFHELinalg.apply_lookup_table' op should have same shapes for operand #1 and the result}}
%1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<3x4x!HLFHE.eint<2>>, tensor<4xi64>) -> (tensor<2x3x4x!HLFHE.eint<2>>)
return %1: tensor<2x3x4x!HLFHE.eint<2>>
}

View File

@@ -216,4 +216,17 @@ func @mul_eint_int_broadcast_1(%a0: tensor<1x4x5x!HLFHE.eint<2>>, %a1: tensor<3x
func @mul_eint_int_broadcast_2(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> {
%1 ="HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>>
return %1: tensor<3x4x!HLFHE.eint<2>>
}
/////////////////////////////////////////////////
// HLFHELinalg.apply_lookup_table
/////////////////////////////////////////////////
// CHECK-LABEL: func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>>
func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> {
// CHECK-NEXT: %[[V1:.*]] = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1) : (tensor<2x3x4x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>>
// CHECK-NEXT: return %[[V1]] : tensor<2x3x4x!HLFHE.eint<2>>
%1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!HLFHE.eint<2>>, tensor<4xi64>) -> (tensor<2x3x4x!HLFHE.eint<2>>)
return %1: tensor<2x3x4x!HLFHE.eint<2>>
}