mirror of https://github.com/zama-ai/concrete.git
feat(compiler): Lower HLFHELinalg.apply_lookup_table (close #174)
commit ccaf1bff15
parent 2900c9a2a1
committed by Andi Drebes
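In short, the new pattern rewrites the tensor-level `HLFHELinalg.apply_lookup_table` into a `linalg.generic` whose scalar body applies `HLFHE.apply_lookup_table` to every element. Condensed from the FileCheck test added below (same shapes and types; `%arg0` is the input tensor, `%arg1` the table):

    #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

    // before
    %1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1) : (tensor<2x3x4x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>>

    // after
    %0 = linalg.init_tensor [2, 3, 4] : tensor<2x3x4x!HLFHE.eint<2>>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!HLFHE.eint<2>>) outs(%0 : tensor<2x3x4x!HLFHE.eint<2>>) {
    ^bb0(%arg2: !HLFHE.eint<2>, %arg3: !HLFHE.eint<2>):
      %2 = "HLFHE.apply_lookup_table"(%arg2, %arg1) : (!HLFHE.eint<2>, tensor<4xi64>) -> !HLFHE.eint<2>
      linalg.yield %2 : !HLFHE.eint<2>
    } -> tensor<2x3x4x!HLFHE.eint<2>>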
@@ -216,11 +216,10 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
                       mlir::IntegerAttr levelBS, mlir::IntegerAttr baseLogBS,
                       mlir::IntegerAttr outputSizeKS, mlir::OpResult result) {
   // convert result type
-  GLWECipherTextType glwe_type = result.getType().cast<GLWECipherTextType>();
   LweCiphertextType lwe_type =
-      convertTypeToLWE(rewriter.getContext(), glwe_type);
+      convertTypeToLWE(rewriter.getContext(), result.getType());
   // fill the table in the GLWE accumulator
-  mlir::IntegerAttr precision = rewriter.getI32IntegerAttr(glwe_type.getP());
+  mlir::IntegerAttr precision = rewriter.getI32IntegerAttr(lwe_type.getP());
   mlir::Value accumulator =
       rewriter
           .create<mlir::zamalang::LowLFHE::GlweFromTable>(
@@ -229,7 +228,6 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
           .result();

   // keyswitch
-  auto ct_type = ct.getType().cast<GLWECipherTextType>();
   mlir::SmallVector<mlir::Value> ksArgs{ct};
   mlir::SmallVector<mlir::NamedAttribute> ksAttrs{
       mlir::NamedAttribute(
@@ -242,8 +240,10 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
       mlir::NamedAttribute(
           mlir::Identifier::get("baseLog", rewriter.getContext()), baseLogKS),
   };
-  auto ksOutType = LweCiphertextType::get(
-      rewriter.getContext(), outputSizeKS.getInt(), ct_type.getP());
+  // convert result type
+  LweCiphertextType ksOutType = LweCiphertextType::get(
+      rewriter.getContext(), outputSizeKS.getInt(), precision.getInt());
+  convertTypeToLWE(rewriter.getContext(), result.getType());
   mlir::Value keyswitched =
       rewriter
           .create<mlir::zamalang::LowLFHE::KeySwitchLweOp>(loc, ksOutType,
@@ -241,11 +241,10 @@ def ApplyLookupTableEintOp : HLFHELinalg_Op<"apply_lookup_table", []> {

     // Returns the lookup of a 3x3 matrix of encrypted indices of width 2 on a table of size 4=2² of clear integers.
     //
-    // [0,1,2]                 [2,4,6]
-    // [3,0,1] lut [2,4,6,8] = [8,2,4]
-    // [2,3,0]                 [6,8,0]
-    //
-    "HLFHELinalg.apply_lookup_table(%t, %lut)" : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi4>) -> tensor<3x3x!HLFHE.eint<3>>
+    // [0,1,2]                 [1,3,5]
+    // [3,0,1] lut [1,3,5,7] = [7,1,3]
+    // [2,3,0]                 [5,7,1]
+    "HLFHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!HLFHE.eint<3>>
     ```
   }];

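For reference, the documented behaviour is a purely elementwise table lookup: res[i][j] = lut[t[i][j]]. In the new example above, t[0][2] = 2, so res[0][2] = lut[2] = 5; applying the same rule to every entry gives the result matrix shown.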
@@ -242,6 +242,107 @@ struct HLFHELinalgOpToLinalgGeneric
   };
 };

// This rewrite pattern transforms any instance of the operator
// `HLFHELinalg.apply_lookup_table`, which implements the broadcasting rules,
// into an instance of `linalg.generic` with an appropriate region using the
// `HLFHE.apply_lookup_table` operation, an appropriate specification for the
// iteration dimensions and appropriate operations managing the accumulator of
// `linalg.generic`.
//
// Example:
//
//   HLFHELinalg.apply_lookup_table(%t, %lut):
//     tensor<DNx...xD1x!HLFHE.eint<p>>, tensor<DAxi64>
//       -> tensor<DNx...xD1x!HLFHE.eint<p'>>
//
// becomes:
//
//   #maps_0 = [
//     affine_map<(aN, ..., a1) -> (aN, ..., a1)>,
//     affine_map<(aN, ..., a1) -> (aN, ..., a1)>
//   ]
//   #attributes_0 {
//     indexing_maps = #maps_0,
//     iterator_types = ["parallel", ...], // N parallel
//   }
//   %init = linalg.init_tensor [DN,...,D1]
//     : tensor<DNx...xD1x!HLFHE.eint<p'>>
//   %res = linalg.generic {
//     ins(%t: tensor<DNx...xD1x!HLFHE.eint<p>>)
//     outs(%init : tensor<DNx...xD1x!HLFHE.eint<p'>>)
//     {
//       ^bb0(%arg0: !HLFHE.eint<p>):
//         %0 = HLFHE.apply_lookup_table(%arg0, %lut): !HLFHE.eint<p>,
//           tensor<4xi64> -> !HLFHE.eint<p'>
//         linalg.yield %0 : !HLFHE.eint<p'>
//     }
//   }
//
struct HLFHELinalgApplyLookupTableToLinalgGeneric
    : public mlir::OpRewritePattern<
          mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp> {
  HLFHELinalgApplyLookupTableToLinalgGeneric(::mlir::MLIRContext *context,
                                             mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<
            mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp>(context,
                                                                 benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp lutOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    mlir::RankedTensorType resultTy =
        ((mlir::Type)lutOp->getResult(0).getType())
            .cast<mlir::RankedTensorType>();
    mlir::RankedTensorType tTy =
        ((mlir::Type)lutOp.t().getType()).cast<mlir::RankedTensorType>();

    // linalg.init_tensor for initial value
    mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
        lutOp.getLoc(), resultTy.getShape(), resultTy.getElementType());

    // Create the affine #maps_0
    llvm::SmallVector<mlir::AffineMap, 2> maps{
        mlir::AffineMap::getMultiDimIdentityMap(tTy.getShape().size(),
                                                this->getContext()),
        mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(),
                                                this->getContext()),
    };

    // Create the iterator_types
    llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
                                                     "parallel");

    // Create the body of the `linalg.generic` op
    auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
                           mlir::Location nestedLoc,
                           mlir::ValueRange blockArgs) {
      mlir::zamalang::HLFHE::ApplyLookupTableEintOp hlfheOp =
          nestedBuilder.create<mlir::zamalang::HLFHE::ApplyLookupTableEintOp>(
              lutOp.getLoc(), resultTy.getElementType(), blockArgs[0],
              lutOp.lut());

      nestedBuilder.create<mlir::linalg::YieldOp>(lutOp.getLoc(),
                                                  hlfheOp.getResult());
    };

    // Create the `linalg.generic` op
    llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
    llvm::SmallVector<mlir::Value, 1> ins{lutOp.t()};
    llvm::SmallVector<mlir::Value, 1> outs{init};
    llvm::StringRef doc{""};
    llvm::StringRef call{""};

    mlir::linalg::GenericOp genericOp =
        rewriter.create<mlir::linalg::GenericOp>(lutOp.getLoc(), resTypes, ins,
                                                 outs, maps, iteratorTypes, doc,
                                                 call, bodyBuilder);

    rewriter.replaceOp(lutOp, {genericOp.getResult(0)});

    return ::mlir::success();
  };
};

namespace {
struct HLFHETensorOpsToLinalg
    : public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {
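For the 3x3, width-2 example from the HLFHELinalg documentation above (output precision p' = 3), the pattern should produce IR along these lines. This is a hand-written sketch obtained by instantiating the template in the comment, not compiler output; the FileCheck test added below gives the exact form for a 3-D, same-precision case:

    #map = affine_map<(d0, d1) -> (d0, d1)>
    %init = linalg.init_tensor [3, 3] : tensor<3x3x!HLFHE.eint<3>>
    %res = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%t : tensor<3x3x!HLFHE.eint<2>>) outs(%init : tensor<3x3x!HLFHE.eint<3>>) {
    ^bb0(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<3>):
      %0 = "HLFHE.apply_lookup_table"(%arg0, %lut) : (!HLFHE.eint<2>, tensor<4xi64>) -> !HLFHE.eint<3>
      linalg.yield %0 : !HLFHE.eint<3>
    } -> tensor<3x3x!HLFHE.eint<3>>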
@@ -280,6 +381,7 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
       HLFHELinalgOpToLinalgGeneric<mlir::zamalang::HLFHELinalg::MulEintIntOp,
                                    mlir::zamalang::HLFHE::MulEintIntOp>>(
       &getContext());
+  patterns.insert<HLFHELinalgApplyLookupTableToLinalgGeneric>(&getContext());

   if (mlir::applyPartialConversion(function, target, std::move(patterns))
           .failed())
@@ -0,0 +1,19 @@
// RUN: zamacompiler %s --action=dump-midlfhe --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s

// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-NEXT: module {
// CHECK-NEXT: func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> {
// CHECK-NEXT: %0 = linalg.init_tensor [2, 3, 4] : tensor<2x3x4x!HLFHE.eint<2>>
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!HLFHE.eint<2>>) outs(%0 : tensor<2x3x4x!HLFHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%arg2: !HLFHE.eint<2>, %arg3: !HLFHE.eint<2>): // no predecessors
// CHECK-NEXT: %2 = "HLFHE.apply_lookup_table"(%arg2, %arg1) : (!HLFHE.eint<2>, tensor<4xi64>) -> !HLFHE.eint<2>
// CHECK-NEXT: linalg.yield %2 : !HLFHE.eint<2>
// CHECK-NEXT: } -> tensor<2x3x4x!HLFHE.eint<2>>
// CHECK-NEXT: return %1 : tensor<2x3x4x!HLFHE.eint<2>>
// CHECK-NEXT: }
// CHECK-NEXT: }

func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> {
  %1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!HLFHE.eint<2>>, tensor<4xi64>) -> (tensor<2x3x4x!HLFHE.eint<2>>)
  return %1: tensor<2x3x4x!HLFHE.eint<2>>
}
@@ -867,3 +867,52 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_line_missing_dim) {
     }
   }
 }

///////////////////////////////////////////////////////////////////////////////
// HLFHELinalg apply_lookup_table /////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////

TEST(End2EndJit_HLFHELinalg, apply_lookup_table) {

  mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the lookup of a 3x3 matrix of encrypted indices of width 2 on a table of size 4=2² of clear integers.
//
// [0,1,2]                 [1,3,5]
// [3,0,1] lut [1,3,5,7] = [7,1,3]
// [2,3,0]                 [5,7,1]
func @main(%t: tensor<3x3x!HLFHE.eint<2>>) -> tensor<3x3x!HLFHE.eint<3>> {
  %lut = std.constant dense<[1,3,5,7]> : tensor<4xi64>
  %res = "HLFHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!HLFHE.eint<3>>
  return %res : tensor<3x3x!HLFHE.eint<3>>
}
)XXX",
                                                                "main", true);
  const uint8_t t[3][3]{
      {0, 1, 2},
      {3, 0, 1},
      {2, 3, 0},
  };
  const uint8_t expected[3][3]{
      {1, 3, 5},
      {7, 1, 3},
      {5, 7, 1},
  };

  mlir::zamalang::TensorLambdaArgument<
      mlir::zamalang::IntLambdaArgument<uint8_t>>
      tArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)t, 3 * 3), {3, 3});

  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&tArg});

  ASSERT_EXPECTED_SUCCESS(res);

  ASSERT_EQ(res->size(), 3 * 3);

  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ((*res)[i * 3 + j], expected[i][j])
          << ", at pos(" << i << "," << j << ")";
    }
  }
}