diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h index fa65f50b5..b5daad82d 100644 --- a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h @@ -216,11 +216,10 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::IntegerAttr levelBS, mlir::IntegerAttr baseLogBS, mlir::IntegerAttr outputSizeKS, mlir::OpResult result) { // convert result type - GLWECipherTextType glwe_type = result.getType().cast(); LweCiphertextType lwe_type = - convertTypeToLWE(rewriter.getContext(), glwe_type); + convertTypeToLWE(rewriter.getContext(), result.getType()); // fill the the table in the GLWE accumulator - mlir::IntegerAttr precision = rewriter.getI32IntegerAttr(glwe_type.getP()); + mlir::IntegerAttr precision = rewriter.getI32IntegerAttr(lwe_type.getP()); mlir::Value accumulator = rewriter .create( @@ -229,7 +228,6 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc, .result(); // keyswitch - auto ct_type = ct.getType().cast(); mlir::SmallVector ksArgs{ct}; mlir::SmallVector ksAttrs{ mlir::NamedAttribute( @@ -242,8 +240,10 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::NamedAttribute( mlir::Identifier::get("baseLog", rewriter.getContext()), baseLogKS), }; - auto ksOutType = LweCiphertextType::get( - rewriter.getContext(), outputSizeKS.getInt(), ct_type.getP()); + // convert result type + LweCiphertextType ksOutType = LweCiphertextType::get( + rewriter.getContext(), outputSizeKS.getInt(), precision.getInt()); + convertTypeToLWE(rewriter.getContext(), result.getType()); mlir::Value keyswitched = rewriter .create(loc, ksOutType, diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td index f3ccbed2c..85c33a0ce 100644 --- a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td @@ -241,11 +241,10 @@ def ApplyLookupTableEintOp : HLFHELinalg_Op<"apply_lookup_table", []> { // Returns the lookup of 3x3 matrix of encrypted indices of with 2 on a table of size 4=2² of clear integers. // - // [0,1,2] [2,4,6] - // [3,0,1] lut [2,4,6,8] = [8,2,4] - // [2,3,0] [6,8,0] - // - "HLFHELinalg.apply_lookup_table(%t, %lut)" : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi4>) -> tensor<3x3x!HLFHE.eint<3>> + // [0,1,2] [1,3,5] + // [3,0,1] lut [1,3,5,7] = [7,1,3] + // [2,3,0] [5,7,1] + "HLFHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!HLFHE.eint<3>> ``` }]; diff --git a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp index b03ade271..16c3fc1b9 100644 --- a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -242,6 +242,107 @@ struct HLFHELinalgOpToLinalgGeneric }; }; +// This template rewrite pattern transforms any instance of +// operators `HLFHELinalg.apply_lookup_table` that implements the broadasting +// rules to an instance of `linalg.generic` with an appropriate region using +// `HLFHE.apply_lookup_table` operation, an appropriate specification for the +// iteration dimensions and appropriate operaztions managing the accumulator of +// `linalg.generic`. +// +// Example: +// +// HLFHELinalg.apply_lookup_table(%t, %lut): +// tensor>, tensor +// -> tensor> +// +// becomes: +// +// #maps_0 = [ +// affine_map<(aN, ..., a1) -> (aN, ..., a1)>, +// affine_map<(aN, ..., a1) -> (aN, ..., a1)> +// ] +// #attributes_0 { +// indexing_maps = #maps_0, +// iterator_types = ["parallel",..],//N parallel +// } +// %init = linalg.init_tensor [DN,...,D1] +// : tensor> +// %res = linalg.generic { +// ins(%t: tensor>) +// outs(%init : tensor>) +// { +// ^bb0(%arg0: !HLFHE.eint

): +// %0 = HLFHE.apply_lookup_table(%arg0, %lut): !HLFHE.eint

, +// tensor<4xi64> -> !HLFHE.eint +// linalg.yield %0 : !HLFHE.eint +// } +// } +// +struct HLFHELinalgApplyLookupTableToLinalgGeneric + : public mlir::OpRewritePattern< + mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp> { + HLFHELinalgApplyLookupTableToLinalgGeneric(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern< + mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp>(context, + benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp lutOp, + ::mlir::PatternRewriter &rewriter) const override { + mlir::RankedTensorType resultTy = + ((mlir::Type)lutOp->getResult(0).getType()) + .cast(); + mlir::RankedTensorType tTy = + ((mlir::Type)lutOp.t().getType()).cast(); + + // linalg.init_tensor for initial value + mlir::Value init = rewriter.create( + lutOp.getLoc(), resultTy.getShape(), resultTy.getElementType()); + + // Create the affine #maps_0 + llvm::SmallVector maps{ + mlir::AffineMap::getMultiDimIdentityMap(tTy.getShape().size(), + this->getContext()), + mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(), + this->getContext()), + }; + + // Create the iterator_types + llvm::SmallVector iteratorTypes(resultTy.getShape().size(), + "parallel"); + + // Create the body of the `linalg.generic` op + auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder, + mlir::Location nestedLoc, + mlir::ValueRange blockArgs) { + mlir::zamalang::HLFHE::ApplyLookupTableEintOp hlfheOp = + nestedBuilder.create( + lutOp.getLoc(), resultTy.getElementType(), blockArgs[0], + lutOp.lut()); + + nestedBuilder.create(lutOp.getLoc(), + hlfheOp.getResult()); + }; + + // Create the `linalg.generic` op + llvm::SmallVector resTypes{init.getType()}; + llvm::SmallVector ins{lutOp.t()}; + llvm::SmallVector outs{init}; + llvm::StringRef doc{""}; + llvm::StringRef call{""}; + + mlir::linalg::GenericOp genericOp = + rewriter.create(lutOp.getLoc(), resTypes, ins, + outs, maps, iteratorTypes, doc, + call, bodyBuilder); + + rewriter.replaceOp(lutOp, {genericOp.getResult(0)}); + + return ::mlir::success(); + }; +}; + namespace { struct HLFHETensorOpsToLinalg : public HLFHETensorOpsToLinalgBase { @@ -280,6 +381,7 @@ void HLFHETensorOpsToLinalg::runOnFunction() { HLFHELinalgOpToLinalgGeneric>( &getContext()); + patterns.insert(&getContext()); if (mlir::applyPartialConversion(function, target, std::move(patterns)) .failed()) diff --git a/compiler/tests/Conversion/HLFHELinalgToLinalg/apply_lookup_table.mlir b/compiler/tests/Conversion/HLFHELinalgToLinalg/apply_lookup_table.mlir new file mode 100644 index 000000000..d090bda0e --- /dev/null +++ b/compiler/tests/Conversion/HLFHELinalgToLinalg/apply_lookup_table.mlir @@ -0,0 +1,19 @@ +// RUN: zamacompiler %s --action=dump-midlfhe --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s + +// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-NEXT: module { +// CHECK-NEXT: func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> { +// CHECK-NEXT: %0 = linalg.init_tensor [2, 3, 4] : tensor<2x3x4x!HLFHE.eint<2>> +// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!HLFHE.eint<2>>) outs(%0 : tensor<2x3x4x!HLFHE.eint<2>>) { +// CHECK-NEXT: ^bb0(%arg2: !HLFHE.eint<2>, %arg3: !HLFHE.eint<2>): // no predecessors +// CHECK-NEXT: %2 = "HLFHE.apply_lookup_table"(%arg2, %arg1) : (!HLFHE.eint<2>, tensor<4xi64>) -> !HLFHE.eint<2> +// CHECK-NEXT: linalg.yield %2 : !HLFHE.eint<2> +// CHECK-NEXT: } -> tensor<2x3x4x!HLFHE.eint<2>> +// CHECK-NEXT: return %1 : tensor<2x3x4x!HLFHE.eint<2>> +// CHECK-NEXT: } +// CHECK-NEXT: } + +func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!HLFHE.eint<2>>, tensor<4xi64>) -> (tensor<2x3x4x!HLFHE.eint<2>>) + return %1: tensor<2x3x4x!HLFHE.eint<2>> +} \ No newline at end of file diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc index ab0eaed87..be82813f2 100644 --- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc @@ -867,3 +867,52 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_line_missing_dim) { } } } + +/////////////////////////////////////////////////////////////////////////////// +// HLFHELinalg apply_lookup_table ///////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(End2EndJit_HLFHELinalg, apply_lookup_table) { + + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + // Returns the lookup of 3x3 matrix of encrypted indices of with 2 on a table of size 4=2² of clear integers. + // + // [0,1,2] [1,3,5] + // [3,0,1] lut [1,3,5,7] = [7,1,3] + // [2,3,0] [5,7,1] + func @main(%t: tensor<3x3x!HLFHE.eint<2>>) -> tensor<3x3x!HLFHE.eint<3>> { + %lut = std.constant dense<[1,3,5,7]> : tensor<4xi64> + %res = "HLFHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!HLFHE.eint<3>> + return %res : tensor<3x3x!HLFHE.eint<3>> + } +)XXX", + "main", true); + const uint8_t t[3][3]{ + {0, 1, 2}, + {3, 0, 1}, + {2, 3, 0}, + }; + const uint8_t expected[3][3]{ + {1, 3, 5}, + {7, 1, 3}, + {5, 7, 1}, + }; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + tArg(llvm::MutableArrayRef((uint8_t *)t, 3 * 3), {3, 3}); + + llvm::Expected> res = + lambda.operator()>({&tArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), 3 * 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 3; j++) { + EXPECT_EQ((*res)[i * 3 + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +}