mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-09 12:15:09 -05:00)
feat(compiler): Lower HLFHELinalg.matmul_eint_int to linalg.generic (close #177)
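In short, the new rewrite pattern replaces each `HLFHELinalg.matmul_eint_int` op with a `tensor.generate` op that zero-initializes an accumulator, followed by a `linalg.generic` op that reduces over the shared dimension with `HLFHE.mul_eint_int` / `HLFHE.add_eint`. A condensed sketch of the lowering, using the concrete shapes exercised by the FileCheck test added in this commit (value names are illustrative):

  // The three indexing maps realize C[m, n] += A[m, p] * B[p, n].
  #map0 = affine_map<(m, n, p) -> (m, p)>   // read A
  #map1 = affine_map<(m, n, p) -> (p, n)>   // read B
  #map2 = affine_map<(m, n, p) -> (m, n)>   // read/write the accumulator

  // Before:
  %c = "HLFHELinalg.matmul_eint_int"(%a, %b)
         : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>

  // After (condensed): a zero-filled accumulator ...
  %acc = tensor.generate {
  ^bb0(%i: index, %j: index):
    %z = "HLFHE.zero"() : () -> !HLFHE.eint<2>
    tensor.yield %z : !HLFHE.eint<2>
  } : tensor<3x2x!HLFHE.eint<2>>
  // ... reduced over the shared dimension; the region multiplies an encrypted
  // element by a clear one and accumulates into the output.
  %c = linalg.generic {indexing_maps = [#map0, #map1, #map2],
                       iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%a, %b : tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>)
      outs(%acc : tensor<3x2x!HLFHE.eint<2>>) {
  ^bb0(%lhs: !HLFHE.eint<2>, %rhs: i3, %out: !HLFHE.eint<2>):
    %m = "HLFHE.mul_eint_int"(%lhs, %rhs) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
    %s = "HLFHE.add_eint"(%out, %m) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
    linalg.yield %s : !HLFHE.eint<2>
  } -> tensor<3x2x!HLFHE.eint<2>>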
@@ -436,6 +436,138 @@ struct HLFHELinalgNegEintToLinalgGeneric
  };
};

// This rewrite pattern transforms any instance of the
// `HLFHELinalg.matmul_eint_int` operator into an instance of `linalg.generic`
// with an appropriate region using the `HLFHE.mul_eint_int` and
// `HLFHE.add_eint` operations, an appropriate specification of the iteration
// dimensions, and appropriate operations managing the accumulator of
// `linalg.generic`.
//
// Example:
//
//   %c = "HLFHELinalg.matmul_eint_int"(%a, %b) :
//     (tensor<MxPx!HLFHE.eint<p>>, tensor<PxNxip'>) ->
//     tensor<MxNx!HLFHE.eint<p>>
//
// becomes:
//
//   #maps_0 = [
//     (m, n, p) -> (m, p),
//     (m, n, p) -> (p, n),
//     (m, n, p) -> (m, n)
//   ]
//   #attributes_0 = {
//     indexing_maps = #maps_0,
//     iterator_types = ["parallel", "parallel", "reduction"]
//   }
//   %init = tensor.generate {
//     ^bb0(%i : index, %j : index):
//       %z = "HLFHE.zero"() : () -> !HLFHE.eint<p>
//       tensor.yield %z : !HLFHE.eint<p>
//   } : tensor<MxNx!HLFHE.eint<p>>
//   linalg.generic #attributes_0
//     ins(%A, %B : tensor<MxPx!HLFHE.eint<p>>,
//                  tensor<PxNxip'>)
//     outs(%init : tensor<MxNx!HLFHE.eint<p>>)
//     {
//       ^bb0(%a: !HLFHE.eint<p>, %b: ip', %c: !HLFHE.eint<p>):
//         %d = "HLFHE.mul_eint_int"(%a, %b) :
//               (!HLFHE.eint<p>, ip') -> !HLFHE.eint<p>
//         %e = "HLFHE.add_eint"(%c, %d) :
//               (!HLFHE.eint<p>, !HLFHE.eint<p>) -> !HLFHE.eint<p>
//         linalg.yield %e : !HLFHE.eint<p>
//     }
//
struct HLFHELinalgMatmulEintIntToLinalgGeneric
    : public mlir::OpRewritePattern<
          mlir::zamalang::HLFHELinalg::MatMulEintIntOp> {
  HLFHELinalgMatmulEintIntToLinalgGeneric(::mlir::MLIRContext *context,
                                          mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::MatMulEintIntOp>(
            context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::zamalang::HLFHELinalg::MatMulEintIntOp matmulOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    mlir::RankedTensorType resultTy =
        ((mlir::Type)matmulOp->getResult(0).getType())
            .cast<mlir::RankedTensorType>();

    // Create the `tensor.generate` op that zero-initializes the accumulator.
    auto generateBody = [&](mlir::OpBuilder &nestedBuilder,
                            mlir::Location nestedLoc,
                            mlir::ValueRange blockArgs) {
      // %z = "HLFHE.zero"() : () -> !HLFHE.eint<p>
      mlir::zamalang::HLFHE::ZeroEintOp zeroOp =
          nestedBuilder.create<mlir::zamalang::HLFHE::ZeroEintOp>(
              matmulOp.getLoc(), resultTy.getElementType());
      // tensor.yield %z : !HLFHE.eint<p>
      nestedBuilder.create<mlir::tensor::YieldOp>(matmulOp.getLoc(),
                                                  zeroOp.getResult());
    };
    mlir::tensor::GenerateOp init = rewriter.create<mlir::tensor::GenerateOp>(
        matmulOp.getLoc(), (mlir::Type)resultTy, mlir::ValueRange{},
        generateBody);
    // Alternative: use `linalg.init_tensor` for the initial value.
    // mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
    //     matmulOp.getLoc(), resultTy.getShape(), resultTy.getElementType());

    // Create the affine maps #maps_0
    llvm::SmallVector<mlir::AffineMap> maps{
        // (m, n, p) -> (m, p)
        mlir::AffineMap::get(
            3, 0, {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2)},
            rewriter.getContext()),
        // (m, n, p) -> (p, n)
        mlir::AffineMap::get(
            3, 0, {rewriter.getAffineDimExpr(2), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        // (m, n, p) -> (m, n)
        mlir::AffineMap::get(
            3, 0, {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
    };

    // Create the iterator_types
    llvm::SmallVector<llvm::StringRef> iteratorTypes{"parallel", "parallel",
                                                     "reduction"};

    // Create the body of the `linalg.generic` op
    auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
                           mlir::Location nestedLoc,
                           mlir::ValueRange blockArgs) {
      // "HLFHE.mul_eint_int"(%a, %b) : (!HLFHE.eint<p>, ip') -> !HLFHE.eint<p>
      mlir::zamalang::HLFHE::MulEintIntOp mulEintIntOp =
          nestedBuilder.create<mlir::zamalang::HLFHE::MulEintIntOp>(
              matmulOp.getLoc(), resultTy.getElementType(), blockArgs[0],
              blockArgs[1]);
      // "HLFHE.add_eint"(%c, %d) : (!HLFHE.eint<p>, !HLFHE.eint<p>) ->
      // !HLFHE.eint<p>
      mlir::zamalang::HLFHE::AddEintOp addEintOp =
          nestedBuilder.create<mlir::zamalang::HLFHE::AddEintOp>(
              matmulOp.getLoc(), resultTy.getElementType(), blockArgs[2],
              mulEintIntOp);
      // linalg.yield %e : !HLFHE.eint<p>
      nestedBuilder.create<mlir::linalg::YieldOp>(matmulOp.getLoc(),
                                                  addEintOp.getResult());
    };

    // Create the `linalg.generic` op
    llvm::SmallVector<mlir::Type> resTypes{init.getType()};
    llvm::SmallVector<mlir::Value> ins{matmulOp.lhs(), matmulOp.rhs()};
    llvm::SmallVector<mlir::Value> outs{init};
    llvm::StringRef doc{""};
    llvm::StringRef call{""};

    mlir::linalg::GenericOp genericOp =
        rewriter.create<mlir::linalg::GenericOp>(matmulOp.getLoc(), resTypes,
                                                 ins, outs, maps, iteratorTypes,
                                                 doc, call, bodyBuilder);

    rewriter.replaceOp(matmulOp, {genericOp.getResult(0)});

    return ::mlir::success();
  };
};

namespace {
struct HLFHETensorOpsToLinalg
    : public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {
@@ -477,6 +609,7 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
      &getContext());
  patterns.insert<HLFHELinalgApplyLookupTableToLinalgGeneric>(&getContext());
  patterns.insert<HLFHELinalgNegEintToLinalgGeneric>(&getContext());
  patterns.insert<HLFHELinalgMatmulEintIntToLinalgGeneric>(&getContext());

  if (mlir::applyPartialConversion(function, target, std::move(patterns))
          .failed())
@@ -0,0 +1,25 @@
// RUN: zamacompiler %s --action=dump-midlfhe --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s

// CHECK: #map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-NEXT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-NEXT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-NEXT: module {
// CHECK-NEXT:   func @matmul_eint_int(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
// CHECK-NEXT:     %0 = tensor.generate {
// CHECK-NEXT:     ^bb0(%arg2: index, %arg3: index):  // no predecessors
// CHECK-NEXT:       %2 = "HLFHE.zero"() : () -> !HLFHE.eint<2>
// CHECK-NEXT:       tensor.yield %2 : !HLFHE.eint<2>
// CHECK-NEXT:     } : tensor<3x2x!HLFHE.eint<2>>
// CHECK-NEXT:     %1 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) outs(%0 : tensor<3x2x!HLFHE.eint<2>>) {
// CHECK-NEXT:     ^bb0(%arg2: !HLFHE.eint<2>, %arg3: i3, %arg4: !HLFHE.eint<2>):  // no predecessors
// CHECK-NEXT:       %2 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
// CHECK-NEXT:       %3 = "HLFHE.add_eint"(%arg4, %2) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
// CHECK-NEXT:       linalg.yield %3 : !HLFHE.eint<2>
// CHECK-NEXT:     } -> tensor<3x2x!HLFHE.eint<2>>
// CHECK-NEXT:     return %1 : tensor<3x2x!HLFHE.eint<2>>
// CHECK-NEXT:   }
// CHECK-NEXT: }

func @matmul_eint_int(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
  %1 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
  return %1 : tensor<3x2x!HLFHE.eint<2>>
}
@@ -1114,3 +1114,60 @@ TEST(End2EndJit_HLFHELinalg, neg_eint) {
    }
  }
}

///////////////////////////////////////////////////////////////////////////////
// HLFHELinalg matmul_eint_int ////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////

TEST(End2EndJit_HLFHELinalg, matmul_eint_int) {

  mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the matrix multiplication of a 3x2 matrix of encrypted integers and a 2x3 matrix of integers.
//
//          [ 1, 2, 3]
//          [ 2, 3, 4]
//        *
// [1,2]    [ 5, 8,11]
// [3,4]  = [11,18,25]
// [5,6]    [17,28,39]
func @main(%a: tensor<3x2x!HLFHE.eint<6>>, %b: tensor<2x3xi7>) -> tensor<3x3x!HLFHE.eint<6>> {
  %0 = "HLFHELinalg.matmul_eint_int"(%a, %b) : (tensor<3x2x!HLFHE.eint<6>>, tensor<2x3xi7>) -> tensor<3x3x!HLFHE.eint<6>>
  return %0 : tensor<3x3x!HLFHE.eint<6>>
}
)XXX",
                                                                "main", true);
  const uint8_t A[3][2]{
      {1, 2},
      {3, 4},
      {5, 6},
  };
  const uint8_t B[2][3]{
      {1, 2, 3},
      {2, 3, 4},
  };
  const uint8_t expected[3][3]{
      {5, 8, 11},
      {11, 18, 25},
      {17, 28, 39},
  };
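  // Each entry of `expected` is the dot product of a row of A with a column of
  // B, e.g. expected[0][0] = 1*1 + 2*2 = 5 and expected[2][2] = 5*3 + 6*4 = 39.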

  mlir::zamalang::TensorLambdaArgument<
      mlir::zamalang::IntLambdaArgument<uint8_t>>
      aArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)A, 3 * 2), {3, 2});
  mlir::zamalang::TensorLambdaArgument<
      mlir::zamalang::IntLambdaArgument<uint8_t>>
      bArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)B, 2 * 3), {2, 3});

  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});

  ASSERT_EXPECTED_SUCCESS(res);

  ASSERT_EQ(res->size(), 3 * 3);

  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ((*res)[i * 3 + j], expected[i][j])
          << ", at pos(" << i << "," << j << ")";
    }
  }
}