feat(compiler): lower HLFHELinalg.apply_multi_lut

Support broadcasting
Author: youben11
Date: 2021-11-11 12:13:28 +01:00
Committed by: Ayoub Benaissa
Parent: ac7f1f5a6b
Commit: 36413235c5
4 changed files with 313 additions and 0 deletions

@@ -141,6 +141,40 @@ getBroadcastedAffineMap(const mlir::RankedTensorType &resultType,
rewriter.getContext());
}
// This creates an affine map following the broadcasting rules, but it also
// takes out one specific element of the LUT from the LUT dimension, which
// should be the last one.
//
// Example:
//
// resultType: 4x2, operandType: 4x2x8, lut_index: 3
// return: affine_map<(d0, d1) -> (d0, d1, 3)>
// The last dimension of the operand is the LUT size, and the map takes out
// the element at index 3.
mlir::AffineMap
getBroadcastedAffineMapMultiLUT(const mlir::RankedTensorType &resultType,
const mlir::RankedTensorType &operandType,
const int64_t lut_index,
::mlir::PatternRewriter &rewriter) {
mlir::SmallVector<mlir::AffineExpr, 4> affineExprs;
auto resultShape = resultType.getShape();
auto operandShape = operandType.getShape();
affineExprs.reserve(operandShape.size());
// Don't take the lut dimension into account
size_t deltaNumDim = resultShape.size() - operandShape.size() + 1;
for (auto i = 0; i < operandShape.size() - 1; i++) {
if (operandShape[i] == 1 && resultShape[i + deltaNumDim] != 1) {
affineExprs.push_back(rewriter.getAffineConstantExpr(0));
} else {
affineExprs.push_back(rewriter.getAffineDimExpr(i + deltaNumDim));
}
}
// Index a specific element of the LUT
affineExprs.push_back(rewriter.getAffineConstantExpr(lut_index));
return mlir::AffineMap::get(resultShape.size(), 0, affineExprs,
rewriter.getContext());
}
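// Illustrative sketch, not part of this lowering: the index arithmetic the
// map above encodes, written out over plain index vectors. The helper name
// `multiLutOperandIndex` is hypothetical and only documents the broadcasting
// behaviour.
static llvm::SmallVector<int64_t>
multiLutOperandIndex(llvm::ArrayRef<int64_t> resultIndex,
                     llvm::ArrayRef<int64_t> operandShape, int64_t lutIndex) {
  llvm::SmallVector<int64_t> operandIndex;
  operandIndex.reserve(operandShape.size());
  // Align trailing dimensions, ignoring the LUT dimension (the last one).
  size_t deltaNumDim = resultIndex.size() - operandShape.size() + 1;
  for (size_t i = 0; i + 1 < operandShape.size(); i++)
    // Broadcast dimensions of size 1 always read element 0.
    operandIndex.push_back(
        operandShape[i] == 1 ? 0 : resultIndex[i + deltaNumDim]);
  // The trailing dimension selects one entry out of each LUT.
  operandIndex.push_back(lutIndex);
  return operandIndex;
}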
// This template rewrite pattern transforms any instance of
// operators `HLFHELinalgOp` that implement the broadcasting rules to an
// instance of `linalg.generic` with an appropriate region using `HLFHEOp`
@@ -240,6 +274,134 @@ struct HLFHELinalgOpToLinalgGeneric
};
};
// This rewrite pattern transforms any instance of the
// `HLFHELinalg.ApplyMultiLookupTableEintOp` operator that implements the
// broadcasting rules into an instance of `linalg.generic` with an appropriate
// region using the `HLFHE.ApplyLookupTableEintOp` operation and an
// appropriate specification for the iteration dimensions.
//
// Example:
//
// %res = "HLFHELinalg.apply_multi_lookup_table"(%t, %luts):
// (tensor<4x3x!HLFHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!HLFHE.eint<2>>
//
// becomes:
//
// #maps_0 = [
// affine_map<(d0, d1) -> (d0, d1)>
// affine_map<(d0, d1) -> (d1, 0)>
// affine_map<(d0, d1) -> (d1, 1)>
// affine_map<(d0, d1) -> (d1, 2)>
// affine_map<(d0, d1) -> (d1, 3)>
// ]
// #attributes_0 {
// indexing_maps = #maps_0,
// iterator_types = ["parallel", "parallel"],
// }
// %init = linalg.init_tensor [4, 3]
// : tensor<4x3x!HLFHE.eint<2>>
// %res = linalg.generic {
// ins(%t, %luts, %luts, %luts, %luts: tensor<4x3x!HLFHE.eint<2>>,
// tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>)
// outs(%init : tensor<4x3x!HLFHE.eint<2>>)
// {
// ^bb0(%arg0: !HLFHE.eint<2>, %arg1: i64, %arg2: i64, %arg3: i64,
// %arg4: i64, %arg5: !HLFHE.eint<2>):
// %lut = tensor.from_elements %arg1, %arg2, %arg3, %arg4 : tensor<4xi64>
// %0 = "HLFHE.apply_lookup_table"(%arg0, %lut)
//      : (!HLFHE.eint<2>, tensor<4xi64>) -> !HLFHE.eint<2>
// linalg.yield %0 : !HLFHE.eint<2>
// }
// }
//
struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric
: public mlir::OpRewritePattern<
mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp> {
HLFHELinalgApplyMultiLookupTableToLinalgGeneric(
::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<
mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp>(context,
benefit) {
}
::mlir::LogicalResult matchAndRewrite(
mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp hlfheLinalgLutOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::RankedTensorType resultTy =
((mlir::Type)hlfheLinalgLutOp->getResult(0).getType())
.cast<mlir::RankedTensorType>();
mlir::RankedTensorType tensorTy =
((mlir::Type)hlfheLinalgLutOp.t().getType())
.cast<mlir::RankedTensorType>();
mlir::RankedTensorType lutsTy =
((mlir::Type)hlfheLinalgLutOp.luts().getType())
.cast<mlir::RankedTensorType>();
// linalg.init_tensor for initial value
mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
hlfheLinalgLutOp.getLoc(), resultTy.getShape(),
resultTy.getElementType());
auto lutsShape = lutsTy.getShape();
auto lut_size = lutsShape[lutsShape.size() - 1];
// Create the affine maps
llvm::SmallVector<mlir::AffineMap> maps{
// Input tensor map
getBroadcastedAffineMap(resultTy, tensorTy, rewriter)};
maps.reserve(lut_size + 1);
// Create as many affine maps as the size of the LUT dimension
for (int64_t i = 0; i < lut_size; i++)
maps.push_back(
getBroadcastedAffineMapMultiLUT(resultTy, lutsTy, i, rewriter));
// Result map
maps.push_back(getBroadcastedAffineMap(resultTy, resultTy, rewriter));
// Create the iterator_types
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
"parallel");
// Create the body of the `linalg.generic` op
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
mlir::tensor::FromElementsOp lut =
nestedBuilder.create<mlir::tensor::FromElementsOp>(
hlfheLinalgLutOp.getLoc(), blockArgs.slice(1, lut_size));
mlir::zamalang::HLFHE::ApplyLookupTableEintOp lutOp =
nestedBuilder.create<mlir::zamalang::HLFHE::ApplyLookupTableEintOp>(
hlfheLinalgLutOp.getLoc(), resultTy.getElementType(),
blockArgs[0], lut.result());
nestedBuilder.create<mlir::linalg::YieldOp>(hlfheLinalgLutOp.getLoc(),
lutOp.getResult());
};
// Create the `linalg.generic` op
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
llvm::SmallVector<mlir::Value> ins{hlfheLinalgLutOp.t()};
ins.reserve(lut_size + 2);
// We extract one value at a time from each LUT using a different map, so we
// need to pass the LUTs tensor `lut_size` times
for (auto i = 0; i < lut_size; i++)
ins.push_back(hlfheLinalgLutOp.luts());
llvm::SmallVector<mlir::Value, 1> outs{init};
llvm::StringRef doc{""};
llvm::StringRef call{""};
mlir::linalg::GenericOp genericOp =
rewriter.create<mlir::linalg::GenericOp>(
hlfheLinalgLutOp.getLoc(), resTypes, ins, outs, maps, iteratorTypes,
doc, call, bodyBuilder);
rewriter.replaceOp(hlfheLinalgLutOp, {genericOp.getResult(0)});
return ::mlir::success();
};
};
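// Illustrative sketch, not part of the pattern above: what one iteration of
// the generated `linalg.generic` computes for a single output element, with
// the encrypted value stood in by a plain integer. The region rebuilds the
// element's table from the `lut_size` scalars gathered through the maps and
// applies the scalar lookup to it. The helper name is hypothetical.
static uint64_t multiLutElementSemantics(uint64_t input,
                                         llvm::ArrayRef<uint64_t> lutEntries) {
  // `HLFHE.apply_lookup_table` on one element boils down to an indexed read
  // of the table rebuilt by `tensor.from_elements`.
  return lutEntries[input];
}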
// This template rewrite pattern transforms any instance of
// operators `HLFHELinalg.apply_lookup_table` that implement the broadcasting
// rules to an instance of `linalg.generic` with an appropriate region using
@@ -610,6 +772,8 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
patterns.insert<HLFHELinalgApplyLookupTableToLinalgGeneric>(&getContext());
patterns.insert<HLFHELinalgNegEintToLinalgGeneric>(&getContext());
patterns.insert<HLFHELinalgMatmulEintIntToLinalgGeneric>(&getContext());
patterns.insert<HLFHELinalgApplyMultiLookupTableToLinalgGeneric>(
&getContext());
if (mlir::applyPartialConversion(function, target, std::move(patterns))
.failed())

@@ -0,0 +1,23 @@
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1 | FileCheck %s
//CHECK-LABEL: #map0 = affine_map<(d0, d1) -> (d0, d1)>
//CHECK-NEXT: #map1 = affine_map<(d0, d1) -> (d0, d1, 0)>
//CHECK-NEXT: #map2 = affine_map<(d0, d1) -> (d0, d1, 1)>
//CHECK-NEXT: #map3 = affine_map<(d0, d1) -> (d0, d1, 2)>
//CHECK-NEXT: #map4 = affine_map<(d0, d1) -> (d0, d1, 3)>
//CHECK-NEXT: module {
//CHECK-NEXT: func @multi_lut(%arg0: tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x4x4xi64>) -> tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>> {
//CHECK-NEXT: %[[V0:.*]] = linalg.init_tensor [4, 4] : tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3, #map4, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1, %arg1, %arg1, %arg1 : tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<4x4x4xi64>, tensor<4x4x4xi64>, tensor<4x4x4xi64>, tensor<4x4x4xi64>) outs(%[[V0:.*]] : tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: ^bb0(%arg2: !MidLFHE.glwe<{_,_,_}{2}>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors
//CHECK-NEXT: %[[V2:.*]] = tensor.from_elements %arg3, %arg4, %arg5, %arg6 : tensor<4xi64>
//CHECK-NEXT: %[[V3:.*]] = "MidLFHE.apply_lookup_table"(%arg2, %[[V2:.*]]) {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 : i32, polynomialSize = -1 : i32} : (!MidLFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: linalg.yield %[[V3:.*]] : !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: } -> tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: return %[[V1:.*]] : tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: }
func @multi_lut(%arg0: tensor<4x4x!HLFHE.eint<2>>, %arg1: tensor<4x4x4xi64>) -> tensor<4x4x!HLFHE.eint<2>> {
%1 = "HLFHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<4x4x!HLFHE.eint<2>>, tensor<4x4x4xi64>) -> tensor<4x4x!HLFHE.eint<2>>
return %1: tensor<4x4x!HLFHE.eint<2>>
}

@@ -0,0 +1,23 @@
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1 | FileCheck %s
//CHECK-LABEL: #map0 = affine_map<(d0, d1) -> (d0, d1)>
//CHECK-NEXT: #map1 = affine_map<(d0, d1) -> (d1, 0)>
//CHECK-NEXT: #map2 = affine_map<(d0, d1) -> (d1, 1)>
//CHECK-NEXT: #map3 = affine_map<(d0, d1) -> (d1, 2)>
//CHECK-NEXT: #map4 = affine_map<(d0, d1) -> (d1, 3)>
//CHECK-NEXT: module {
//CHECK-NEXT: func @multi_lut(%arg0: tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<3x4xi64>) -> tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>> {
//CHECK-NEXT: %[[V0:.*]] = linalg.init_tensor [4, 3] : tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3, #map4, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1, %arg1, %arg1, %arg1 : tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>) outs(%[[V0:.*]] : tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: ^bb0(%arg2: !MidLFHE.glwe<{_,_,_}{2}>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors
//CHECK-NEXT: %[[V2:.*]] = tensor.from_elements %arg3, %arg4, %arg5, %arg6 : tensor<4xi64>
//CHECK-NEXT: %[[V3:.*]] = "MidLFHE.apply_lookup_table"(%arg2, %[[V2:.*]]) {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 : i32, polynomialSize = -1 : i32} : (!MidLFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: linalg.yield %[[V3:.*]] : !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: } -> tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: return %[[V1:.*]] : tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: }
func @multi_lut(%arg0: tensor<4x3x!HLFHE.eint<2>>, %arg1: tensor<3x4xi64>) -> tensor<4x3x!HLFHE.eint<2>> {
%1 = "HLFHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<4x3x!HLFHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!HLFHE.eint<2>>
return %1: tensor<4x3x!HLFHE.eint<2>>
}

@@ -1045,6 +1045,109 @@ TEST(End2EndJit_HLFHELinalg, apply_lookup_table) {
}
}
///////////////////////////////////////////////////////////////////////////////
// HLFHELinalg apply_multi_lookup_table ///////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(End2EndJit_HLFHELinalg, apply_multi_lookup_table) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the lookup of a 3x3 matrix of encrypted indices of width 2 in a 3x3 matrix of tables of size 4=2² of clear integers.
func @main(%arg0: tensor<3x3x!HLFHE.eint<2>>, %arg1: tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<2>> {
%1 = "HLFHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!HLFHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<2>>
return %1: tensor<3x3x!HLFHE.eint<2>>
}
)XXX");
const uint8_t t[3][3]{
{0, 1, 2},
{3, 0, 1},
{2, 3, 0},
};
const uint64_t luts[3][3][4]{
{{1, 3, 5, 7}, {0, 4, 1, 3}, {3, 2, 5, 0}},
{{0, 2, 1, 2}, {7, 1, 0, 2}, {0, 1, 2, 3}},
{{2, 1, 0, 3}, {0, 1, 2, 3}, {6, 5, 4, 3}},
};
const uint8_t expected[3][3]{
{1, 4, 5},
{2, 7, 1},
{0, 3, 6},
};
mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint8_t>>
tArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)t, 3 * 3), {3, 3});
mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint64_t>>
lutsArg(llvm::MutableArrayRef<uint64_t>((uint64_t *)luts, 3 * 3 * 4),
{3, 3, 4});
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&tArg, &lutsArg});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), 3 * 3);
for (size_t i = 0; i < 3; i++) {
for (size_t j = 0; j < 3; j++) {
EXPECT_EQ((*res)[i * 3 + j], expected[i][j])
<< ", at pos(" << i << "," << j << ")";
}
}
}
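// Illustrative reference, not part of this commit: the expected matrix above
// is just the element-wise lookup expected[i][j] = luts[i][j][t[i][j]], since
// every element of the 3x3 input carries its own 4-entry table. A plain C++
// sketch of that reference semantics (the helper name is hypothetical):
static void multiLutPlaintextReference(const uint8_t t[3][3],
                                       const uint64_t luts[3][3][4],
                                       uint8_t out[3][3]) {
  for (size_t i = 0; i < 3; i++)
    for (size_t j = 0; j < 3; j++)
      // Each output element is read from its own table at the input index.
      out[i][j] = (uint8_t)luts[i][j][t[i][j]];
}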
TEST(End2EndJit_HLFHELinalg, apply_multi_lookup_table_with_broadcast) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the lookup of a 3x3 matrix of encrypted indices of width 2 in a vector of 3 tables of size 4=2² of clear integers.
func @main(%arg0: tensor<3x3x!HLFHE.eint<2>>, %arg1: tensor<3x4xi64>) -> tensor<3x3x!HLFHE.eint<2>> {
%1 = "HLFHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!HLFHE.eint<2>>, tensor<3x4xi64>) -> tensor<3x3x!HLFHE.eint<2>>
return %1: tensor<3x3x!HLFHE.eint<2>>
}
)XXX");
const uint8_t t[3][3]{
{0, 1, 2},
{3, 0, 1},
{2, 3, 0},
};
const uint64_t luts[3][4]{
{1, 3, 5, 7},
{0, 2, 1, 3},
{2, 1, 0, 6},
};
const uint8_t expected[3][3]{
{1, 2, 0},
{7, 0, 1},
{5, 3, 2},
};
mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint8_t>>
tArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)t, 3 * 3), {3, 3});
mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint64_t>>
lutsArg(llvm::MutableArrayRef<uint64_t>((uint64_t *)luts, 3 * 4),
{3, 4});
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&tArg, &lutsArg});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), 3 * 3);
for (size_t i = 0; i < 3; i++) {
for (size_t j = 0; j < 3; j++) {
EXPECT_EQ((*res)[i * 3 + j], expected[i][j])
<< ", at pos(" << i << "," << j << ")";
}
}
}
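// Illustrative reference, not part of this commit: with LUTs of shape 3x4 the
// vector of tables is broadcast over the rows of the input, so
// expected[i][j] = luts[j][t[i][j]]. A plain C++ sketch of that reference
// semantics (the helper name is hypothetical):
static void multiLutBroadcastPlaintextReference(const uint8_t t[3][3],
                                                const uint64_t luts[3][4],
                                                uint8_t out[3][3]) {
  for (size_t i = 0; i < 3; i++)
    for (size_t j = 0; j < 3; j++)
      // Column j of the input uses table j, identically for every row.
      out[i][j] = (uint8_t)luts[j][t[i][j]];
}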
///////////////////////////////////////////////////////////////////////////////
// HLFHELinalg dot_eint_int ///////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////