diff --git a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp index dd693de1c..89087e5a4 100644 --- a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -141,6 +141,40 @@ getBroadcastedAffineMap(const mlir::RankedTensorType &resultType, rewriter.getContext()); } +// This creates an affine map following the broadcasting rules, but also takes +// out one specific element of the LUT from the LUT dimension, which should be +// the last. +// +// Example: +// +// resultType: 4x2x5, operandType: 4x2x8, lut_index: 3 +// return: affine_map<(d0, d1, d2) -> (d0, d1, 3)> +// The last dimension of the operand is the LUT size, and the map takes out +// the element at index 3 +mlir::AffineMap +getBroadcastedAffineMapMultiLUT(const mlir::RankedTensorType &resultType, + const mlir::RankedTensorType &operandType, + const int64_t lut_index, + ::mlir::PatternRewriter &rewriter) { + mlir::SmallVector<mlir::AffineExpr> affineExprs; + auto resultShape = resultType.getShape(); + auto operandShape = operandType.getShape(); + affineExprs.reserve(operandShape.size()); + // Don't take the lut dimension into account + size_t deltaNumDim = resultShape.size() - operandShape.size() + 1; + for (auto i = 0; i < operandShape.size() - 1; i++) { + if (operandShape[i] == 1 && resultShape[i + deltaNumDim] != 1) { + affineExprs.push_back(rewriter.getAffineConstantExpr(0)); + } else { + affineExprs.push_back(rewriter.getAffineDimExpr(i + deltaNumDim)); + } + } + // Index a specific element of the LUT + affineExprs.push_back(rewriter.getAffineConstantExpr(lut_index)); + return mlir::AffineMap::get(resultShape.size(), 0, affineExprs, + rewriter.getContext()); +} + // This template rewrite pattern transforms any instance of // operators `HLFHELinalgOp` that implements the broadasting rules to an // instance of `linalg.generic` with an
appropriate region using `HLFHEOp` @@ -240,6 +274,134 @@ struct HLFHELinalgOpToLinalgGeneric }; }; +// This class rewrite pattern transforms any instance of +// operators `HLFHELinalg.ApplyMultiLookupTableEintOp` that implements the +// broadcasting rules to an instance of `linalg.generic` with an appropriate +// region using `HLFHE.ApplyLookupTableEintOp` operation, an appropriate +// specification for the iteration dimensions and appropriate operations +// managing the accumulator of `linalg.generic`. +// +// Example: +// +// %res = "HLFHELinalg.apply_multi_lookup_table"(%t, %luts): +// (tensor<4x3x!HLFHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!HLFHE.eint<2>> +// +// becomes: +// +// #maps_0 = [ +// affine_map<(d0, d1) -> (d0, d1)> +// affine_map<(d0, d1) -> (d1, 0)> +// affine_map<(d0, d1) -> (d1, 1)> +// affine_map<(d0, d1) -> (d1, 2)> +// affine_map<(d0, d1) -> (d1, 3)> +// ] +// #attributes_0 { +// indexing_maps = #maps_0, +// iterator_types = ["parallel", "parallel"], +// } +// %init = linalg.init_tensor [4, 3] +// : tensor<4x3x!HLFHE.eint<2>> +// %res = linalg.generic { +// ins(%t, %luts, %luts, %luts, %luts: tensor<4x3x!HLFHE.eint

>, +// tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>) +// outs(%init : tensor<4x3x!HLFHE.eint<2>>) +// { +// ^bb0(%arg0: !HLFHE.eint<2>, %arg1: i64, %arg2: i64, %arg3: i64, +// %arg4: i64, %arg5: !HLFHE.eint<2>): +// %lut = tensor.from_elements %arg1, %arg2, %arg3, %arg4 : +// tensor<4xi64> %0 = "MidLFHE.apply_lookup_table"(%arg0, %lut) +// {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32, +// levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 : i32, +// polynomialSize = -1 : i32} : (!MidLFHE.glwe<{_,_,_}{2}>, +// tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}> +// linalg.yield %0 : !HLFHE.eint<2> +// } +// } +// +struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric + : public mlir::OpRewritePattern< + mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp> { + HLFHELinalgApplyMultiLookupTableToLinalgGeneric( + ::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern< + mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp>(context, + benefit) { + } + + ::mlir::LogicalResult matchAndRewrite( + mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp hlfheLinalgLutOp, + ::mlir::PatternRewriter &rewriter) const override { + mlir::RankedTensorType resultTy = + ((mlir::Type)hlfheLinalgLutOp->getResult(0).getType()) + .cast<mlir::RankedTensorType>(); + mlir::RankedTensorType tensorTy = + ((mlir::Type)hlfheLinalgLutOp.t().getType()) + .cast<mlir::RankedTensorType>(); + mlir::RankedTensorType lutsTy = + ((mlir::Type)hlfheLinalgLutOp.luts().getType()) + .cast<mlir::RankedTensorType>(); + // linalg.init_tensor for initial value + mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>( + hlfheLinalgLutOp.getLoc(), resultTy.getShape(), + resultTy.getElementType()); + + auto lutsShape = lutsTy.getShape(); + auto lut_size = lutsShape[lutsShape.size() - 1]; + // Create the affine maps + llvm::SmallVector<mlir::AffineMap> maps{ + // Input tensor map + getBroadcastedAffineMap(resultTy, tensorTy, rewriter)}; + maps.reserve(lut_size + 1); + // Create as many affine maps as the size of the lut dimension + for
(int64_t i = 0; i < lut_size; i++) + maps.push_back( + getBroadcastedAffineMapMultiLUT(resultTy, lutsTy, i, rewriter)); + // Result map + maps.push_back(getBroadcastedAffineMap(resultTy, resultTy, rewriter)); + + // Create the iterator_types + llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(), + "parallel"); + + // Create the body of the `linalg.generic` op + auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder, + mlir::Location nestedLoc, + mlir::ValueRange blockArgs) { + mlir::tensor::FromElementsOp lut = + nestedBuilder.create<mlir::tensor::FromElementsOp>( + hlfheLinalgLutOp.getLoc(), blockArgs.slice(1, lut_size)); + mlir::zamalang::HLFHE::ApplyLookupTableEintOp lutOp = + nestedBuilder.create<mlir::zamalang::HLFHE::ApplyLookupTableEintOp>( + hlfheLinalgLutOp.getLoc(), resultTy.getElementType(), + blockArgs[0], lut.result()); + + nestedBuilder.create<mlir::linalg::YieldOp>(hlfheLinalgLutOp.getLoc(), + lutOp.getResult()); + }; + + // Create the `linalg.generic` op + llvm::SmallVector<mlir::Type> resTypes{init.getType()}; + llvm::SmallVector<mlir::Value> ins{hlfheLinalgLutOp.t()}; + ins.reserve(lut_size + 2); + // We extract one value at a time from one LUT using different maps, so we + // need to pass the LUT `lut_size` times + for (auto i = 0; i < lut_size; i++) + ins.push_back(hlfheLinalgLutOp.luts()); + llvm::SmallVector<mlir::Value> outs{init}; + llvm::StringRef doc{""}; + llvm::StringRef call{""}; + + mlir::linalg::GenericOp genericOp = + rewriter.create<mlir::linalg::GenericOp>( + hlfheLinalgLutOp.getLoc(), resTypes, ins, outs, maps, iteratorTypes, + doc, call, bodyBuilder); + + rewriter.replaceOp(hlfheLinalgLutOp, {genericOp.getResult(0)}); + + return ::mlir::success(); + }; +}; + // This template rewrite pattern transforms any instance of // operators `HLFHELinalg.apply_lookup_table` that implements the broadasting // rules to an instance of `linalg.generic` with an appropriate region using @@ -610,6 +772,8 @@ void HLFHETensorOpsToLinalg::runOnFunction() { patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); + patterns.insert<HLFHELinalgApplyMultiLookupTableToLinalgGeneric>( + &getContext()); if
(mlir::applyPartialConversion(function, target, std::move(patterns)) .failed()) diff --git a/compiler/tests/Dialect/HLFHELinalg/apply_multi_lut_to_linalg.mlir b/compiler/tests/Dialect/HLFHELinalg/apply_multi_lut_to_linalg.mlir new file mode 100644 index 000000000..57739e962 --- /dev/null +++ b/compiler/tests/Dialect/HLFHELinalg/apply_multi_lut_to_linalg.mlir @@ -0,0 +1,23 @@ +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1 | FileCheck %s + +//CHECK-LABEL: #map0 = affine_map<(d0, d1) -> (d0, d1)> +//CHECK-NEXT: #map1 = affine_map<(d0, d1) -> (d0, d1, 0)> +//CHECK-NEXT: #map2 = affine_map<(d0, d1) -> (d0, d1, 1)> +//CHECK-NEXT: #map3 = affine_map<(d0, d1) -> (d0, d1, 2)> +//CHECK-NEXT: #map4 = affine_map<(d0, d1) -> (d0, d1, 3)> +//CHECK-NEXT: module { +//CHECK-NEXT: func @multi_lut(%arg0: tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x4x4xi64>) -> tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>> { +//CHECK-NEXT: %[[V0:.*]] = linalg.init_tensor [4, 4] : tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3, #map4, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1, %arg1, %arg1, %arg1 : tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<4x4x4xi64>, tensor<4x4x4xi64>, tensor<4x4x4xi64>, tensor<4x4x4xi64>) outs(%[[V0:.*]] : tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: ^bb0(%arg2: !MidLFHE.glwe<{_,_,_}{2}>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors +//CHECK-NEXT: %[[V2:.*]] = tensor.from_elements %arg3, %arg4, %arg5, %arg6 : tensor<4xi64> +//CHECK-NEXT: %[[V3:.*]] = "MidLFHE.apply_lookup_table"(%arg2, %[[V2:.*]]) {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 : i32, polynomialSize = -1 : i32} : (!MidLFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: linalg.yield %[[V3:.*]] : !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: } 
-> tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: return %[[V1:.*]] : tensor<4x4x!MidLFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: } +func @multi_lut(%arg0: tensor<4x4x!HLFHE.eint<2>>, %arg1: tensor<4x4x4xi64>) -> tensor<4x4x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<4x4x!HLFHE.eint<2>>, tensor<4x4x4xi64>) -> tensor<4x4x!HLFHE.eint<2>> + return %1: tensor<4x4x!HLFHE.eint<2>> +} \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHELinalg/apply_multi_lut_to_linalg_broadcast.mlir b/compiler/tests/Dialect/HLFHELinalg/apply_multi_lut_to_linalg_broadcast.mlir new file mode 100644 index 000000000..cf1c6be97 --- /dev/null +++ b/compiler/tests/Dialect/HLFHELinalg/apply_multi_lut_to_linalg_broadcast.mlir @@ -0,0 +1,23 @@ +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1 | FileCheck %s + +//CHECK-LABEL: #map0 = affine_map<(d0, d1) -> (d0, d1)> +//CHECK-NEXT: #map1 = affine_map<(d0, d1) -> (d1, 0)> +//CHECK-NEXT: #map2 = affine_map<(d0, d1) -> (d1, 1)> +//CHECK-NEXT: #map3 = affine_map<(d0, d1) -> (d1, 2)> +//CHECK-NEXT: #map4 = affine_map<(d0, d1) -> (d1, 3)> +//CHECK-NEXT: module { +//CHECK-NEXT: func @multi_lut(%arg0: tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<3x4xi64>) -> tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>> { +//CHECK-NEXT: %[[V0:.*]] = linalg.init_tensor [4, 3] : tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3, #map4, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1, %arg1, %arg1, %arg1 : tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>) outs(%[[V0:.*]] : tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: ^bb0(%arg2: !MidLFHE.glwe<{_,_,_}{2}>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors +//CHECK-NEXT: %[[V2:.*]] = tensor.from_elements %arg3, %arg4, %arg5, %arg6 : 
tensor<4xi64> +//CHECK-NEXT: %[[V3:.*]] = "MidLFHE.apply_lookup_table"(%arg2, %[[V2:.*]]) {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 : i32, polynomialSize = -1 : i32} : (!MidLFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: linalg.yield %[[V3:.*]] : !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: } -> tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: return %[[V1:.*]] : tensor<4x3x!MidLFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: } +func @multi_lut(%arg0: tensor<4x3x!HLFHE.eint<2>>, %arg1: tensor<3x4xi64>) -> tensor<4x3x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<4x3x!HLFHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!HLFHE.eint<2>> + return %1: tensor<4x3x!HLFHE.eint<2>> +} \ No newline at end of file diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc index 311c488ca..d9ae4ed0e 100644 --- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc @@ -1045,6 +1045,109 @@ TEST(End2EndJit_HLFHELinalg, apply_lookup_table) { } } +/////////////////////////////////////////////////////////////////////////////// +// HLFHELinalg apply_multi_lookup_table +// ///////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(End2EndJit_HLFHELinalg, apply_multi_lookup_table) { + + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + // Returns the lookup of 3x3 matrix of encrypted indices of width 2 on a 3x3 matrix of tables of size 4=2² of clear integers. 
+ func @main(%arg0: tensor<3x3x!HLFHE.eint<2>>, %arg1: tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!HLFHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<2>> + return %1: tensor<3x3x!HLFHE.eint<2>> + } +)XXX"); + const uint8_t t[3][3]{ + {0, 1, 2}, + {3, 0, 1}, + {2, 3, 0}, + }; + const uint64_t luts[3][3][4]{ + {{1, 3, 5, 7}, {0, 4, 1, 3}, {3, 2, 5, 0}}, + {{0, 2, 1, 2}, {7, 1, 0, 2}, {0, 1, 2, 3}}, + {{2, 1, 0, 3}, {0, 1, 2, 3}, {6, 5, 4, 3}}, + }; + const uint8_t expected[3][3]{ + {1, 4, 5}, + {2, 7, 1}, + {0, 3, 6}, + }; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument<uint8_t>> + tArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)t, 3 * 3), {3, 3}); + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument<uint64_t>> + lutsArg(llvm::MutableArrayRef<uint64_t>((uint64_t *)luts, 3 * 3 * 4), + {3, 3, 4}); + + llvm::Expected<std::vector<uint64_t>> res = + lambda.operator()<std::vector<uint64_t>>({&tArg, &lutsArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), 3 * 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 3; j++) { + EXPECT_EQ((*res)[i * 3 + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +TEST(End2EndJit_HLFHELinalg, apply_multi_lookup_table_with_broadcast) { + + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + // Returns the lookup of 3x3 matrix of encrypted indices of width 2 on a vector of 3 tables of size 4=2² of clear integers.
+ func @main(%arg0: tensor<3x3x!HLFHE.eint<2>>, %arg1: tensor<3x4xi64>) -> tensor<3x3x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!HLFHE.eint<2>>, tensor<3x4xi64>) -> tensor<3x3x!HLFHE.eint<2>> + return %1: tensor<3x3x!HLFHE.eint<2>> + } +)XXX"); + const uint8_t t[3][3]{ + {0, 1, 2}, + {3, 0, 1}, + {2, 3, 0}, + }; + const uint64_t luts[3][4]{ + {1, 3, 5, 7}, + {0, 2, 1, 3}, + {2, 1, 0, 6}, + }; + const uint8_t expected[3][3]{ + {1, 2, 0}, + {7, 0, 1}, + {5, 3, 2}, + }; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument<uint8_t>> + tArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)t, 3 * 3), {3, 3}); + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument<uint64_t>> + lutsArg(llvm::MutableArrayRef<uint64_t>((uint64_t *)luts, 3 * 4), + {3, 4}); + + llvm::Expected<std::vector<uint64_t>> res = + lambda.operator()<std::vector<uint64_t>>({&tArg, &lutsArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), 3 * 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 3; j++) { + EXPECT_EQ((*res)[i * 3 + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + /////////////////////////////////////////////////////////////////////////////// // HLFHELinalg dot_eint_int /////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////