mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-09 12:15:09 -05:00)

feat(compiler): Lower HLFHELinalg binary operators

committed by Andi Drebes
parent dea1be9d52
commit ba54560680

@@ -6,10 +6,12 @@ add_mlir_dialect_library(HLFHETensorOpsToLinalg

  DEPENDS
  HLFHEDialect
  HLFHELinalgDialect
  MLIRConversionPassIncGen

  LINK_LIBS PUBLIC
  MLIRIR
  HLFHEDialect)
  HLFHEDialect
  HLFHELinalgDialect)

target_link_libraries(HLFHEDialect PUBLIC MLIRIR)

@@ -13,6 +13,8 @@
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h"
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h"

struct DotToLinalgGeneric : public ::mlir::RewritePattern {
  DotToLinalgGeneric(::mlir::MLIRContext *context)

@@ -121,6 +123,125 @@ struct DotToLinalgGeneric : public ::mlir::RewritePattern {
  };
};

mlir::AffineMap
getBroadcastedAffineMap(const mlir::RankedTensorType &resultType,
                        const mlir::RankedTensorType &operandType,
                        ::mlir::PatternRewriter &rewriter) {
  mlir::SmallVector<mlir::AffineExpr, 4> affineExprs;
  auto resultShape = resultType.getShape();
  auto operandShape = operandType.getShape();
  affineExprs.reserve(resultShape.size());
  size_t deltaNumDim = resultShape.size() - operandShape.size();
  for (auto i = 0; i < operandShape.size(); i++) {
    if (operandShape[i] == 1) {
      affineExprs.push_back(rewriter.getAffineConstantExpr(0));
    } else {
      affineExprs.push_back(rewriter.getAffineDimExpr(i + deltaNumDim));
    }
  }
  return mlir::AffineMap::get(resultShape.size(), 0, affineExprs,
                              rewriter.getContext());
}
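// Worked example of the rule above (added for illustration): for a result of
// type tensor<4x4x4x!HLFHE.eint<4>> and an operand of type
// tensor<1x4x!HLFHE.eint<4>>, deltaNumDim is 1 and the helper returns
//
//   affine_map<(d0, d1, d2) -> (0, d2)>
//
// i.e. the operand's missing leading dimension is skipped and its size-1
// dimension is read at the constant index 0, which is how `linalg.generic`
// consumes a broadcast operand.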

// This template rewrite pattern transforms any instance of an operator
// `HLFHELinalgOp` that implements the broadcasting rules into an instance of
// `linalg.generic` with an appropriate region using an `HLFHEOp` operation, an
// appropriate specification for the iteration dimensions and appropriate
// operations managing the accumulator of `linalg.generic`.
//
// Example:
//
//   %res = HLFHELinalg.op(%lhs, %rhs):
//     (tensor<D$Ax...xD1x!HLFHE.eint<p>>, tensor<D$B'x...xD1'xT>)
//     -> tensor<DR"x...xD1"x!HLFHE.eint<p>>
//
// becomes:
//
//   #maps_0 = [
//     affine_map<(a$R", ..., a$A, ..., a1) ->
//       (dim(lhs, $A) == 1 ? 0 : a$A, ..., dim(lhs, 1) == 1 ? 0 : a1)>,
//     affine_map<(a$R", ..., a1) ->
//       (dim(rhs, $B') == 1 ? 0 : a$B', ..., dim(rhs, 1) == 1 ? 0 : a1)>,
//     affine_map<(a$R", ..., a1) -> (a$R", ..., a1)>
//   ]
//   #attributes_0 {
//     indexing_maps = #maps_0,
//     iterator_types = ["parallel", ..., "parallel"], // $R" parallel iterators
//   }
//   %init = linalg.init_tensor [DR", ..., D1"]
//     : tensor<DR"x...xD1"x!HLFHE.eint<p>>
//   %res = linalg.generic {
//     ins(%lhs, %rhs : tensor<DAx...xD1x!HLFHE.eint<p>>, tensor<DB'x...xD1'xT>)
//     outs(%init : tensor<DR"x...xD1"x!HLFHE.eint<p>>)
//     {
//       ^bb0(%arg0: !HLFHE.eint<p>, %arg1: T):
//         %0 = HLFHE.op(%arg0, %arg1) : (!HLFHE.eint<p>, T) -> !HLFHE.eint<p>
//         linalg.yield %0 : !HLFHE.eint<p>
//     }
//   }
//
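// A concrete instance of the schema above (illustrative): lowering
//
//   %res = "HLFHELinalg.add_eint"(%lhs, %rhs)
//     : (tensor<3x1x!HLFHE.eint<4>>, tensor<3x!HLFHE.eint<4>>)
//     -> tensor<3x3x!HLFHE.eint<4>>
//
// produces the indexing maps
//
//   affine_map<(d0, d1) -> (d0, 0)>    // lhs: size-1 dimension pinned to 0
//   affine_map<(d0, d1) -> (d1)>       // rhs: missing leading dimension dropped
//   affine_map<(d0, d1) -> (d0, d1)>   // result
//
// with iterator_types = ["parallel", "parallel"] and an HLFHE `AddEintOp`
// applied to the two block arguments in the region.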
template <typename HLFHELinalgOp, typename HLFHEOp>
struct HLFHELinalgOpToLinalgGeneric
    : public mlir::OpRewritePattern<HLFHELinalgOp> {
  HLFHELinalgOpToLinalgGeneric(::mlir::MLIRContext *context,
                               mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<HLFHELinalgOp>(context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(HLFHELinalgOp linalgOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    mlir::RankedTensorType resultTy =
        ((mlir::Type)linalgOp->getResult(0).getType())
            .cast<mlir::RankedTensorType>();
    mlir::RankedTensorType lhsTy =
        ((mlir::Type)linalgOp.lhs().getType()).cast<mlir::RankedTensorType>();
    mlir::RankedTensorType rhsTy =
        ((mlir::Type)linalgOp.rhs().getType()).cast<mlir::RankedTensorType>();
    // linalg.init_tensor for initial value
    mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
        linalgOp.getLoc(), resultTy.getShape(), resultTy.getElementType());

    // Create the affine #maps_0
    llvm::SmallVector<mlir::AffineMap, 3> maps{
        getBroadcastedAffineMap(resultTy, lhsTy, rewriter),
        getBroadcastedAffineMap(resultTy, rhsTy, rewriter),
        getBroadcastedAffineMap(resultTy, resultTy, rewriter),
    };

    // Create the iterator_types
    llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
                                                     "parallel");

    // Create the body of the `linalg.generic` op
    auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
                           mlir::Location nestedLoc,
                           mlir::ValueRange blockArgs) {
      HLFHEOp hlfheOp = nestedBuilder.create<HLFHEOp>(
          linalgOp.getLoc(), blockArgs[0], blockArgs[1]);

      nestedBuilder.create<mlir::linalg::YieldOp>(linalgOp.getLoc(),
                                                  hlfheOp.getResult());
    };

    // Create the `linalg.generic` op
    llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
    llvm::SmallVector<mlir::Value, 2> ins{linalgOp.lhs(), linalgOp.rhs()};
    llvm::SmallVector<mlir::Value, 1> outs{init};
    llvm::StringRef doc{""};
    llvm::StringRef call{""};

    mlir::linalg::GenericOp genericOp =
        rewriter.create<mlir::linalg::GenericOp>(linalgOp.getLoc(), resTypes,
                                                 ins, outs, maps, iteratorTypes,
                                                 doc, call, bodyBuilder);

    rewriter.replaceOp(linalgOp, {genericOp.getResult(0)});

    return ::mlir::success();
  };
};

namespace {
struct HLFHETensorOpsToLinalg
    : public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {

@@ -139,9 +260,14 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
  target.addLegalDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
  target.addLegalDialect<mlir::tensor::TensorDialect>();
  target.addIllegalOp<mlir::zamalang::HLFHE::Dot>();
  target.addIllegalDialect<mlir::zamalang::HLFHELinalg::HLFHELinalgDialect>();

  mlir::OwningRewritePatternList patterns(&getContext());
  patterns.insert<DotToLinalgGeneric>(&getContext());
  patterns.insert<
      HLFHELinalgOpToLinalgGeneric<mlir::zamalang::HLFHELinalg::AddEintOp,
                                   mlir::zamalang::HLFHE::AddEintOp>>(
      &getContext());
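
  // Sketch (assumption, not part of this hunk): further HLFHELinalg binary
  // operators would be lowered by instantiating the same template, e.g. for a
  // hypothetical multiplication-by-cleartext operator:
  //
  //   patterns.insert<
  //       HLFHELinalgOpToLinalgGeneric<mlir::zamalang::HLFHELinalg::MulEintIntOp,
  //                                    mlir::zamalang::HLFHE::MulEintIntOp>>(
  //       &getContext());
  //
  // Only AddEintOp is registered in the code shown above; the operator names
  // in this sketch are illustrative.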

  if (mlir::applyPartialConversion(function, target, std::move(patterns))
          .failed())

@@ -6,6 +6,7 @@
#include <mlir/Parser.h>

#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
#include <zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h>
#include <zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h>
#include <zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h>
#include <zamalang/Support/CompilerEngine.h>

@@ -15,6 +16,7 @@ namespace mlir {
namespace zamalang {

void CompilerEngine::loadDialects() {
  context->getOrLoadDialect<mlir::zamalang::HLFHELinalg::HLFHELinalgDialect>();
  context->getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
  context->getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
  context->getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();

@@ -168,4 +168,222 @@ func @main(%t0: tensor<2x10x!HLFHE.eint<6>>, %t1: tensor<2x2x!HLFHE.eint<6>>) ->
}
}
}
}

///////////////////////////////////////////////////////////////////////////////
// HLFHELinalg add_eint ///////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////

TEST(End2EndJit_HLFHELinalg, add_eint_term_to_term) {
  mlir::zamalang::CompilerEngine engine;
  auto mlirStr = R"XXX(
// Returns the term-to-term addition of `%a0` with `%a1`
func @main(%a0: tensor<4x!HLFHE.eint<4>>, %a1: tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>> {
  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
  return %res : tensor<4x!HLFHE.eint<4>>
}
)XXX";
  const uint8_t a0[4]{31, 6, 12, 9};
  const uint8_t a1[4]{32, 9, 2, 3};

  ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));

  auto maybeArgument = engine.buildArgument();
  ASSERT_LLVM_ERROR(maybeArgument.takeError());
  auto argument = std::move(maybeArgument.get());
  // Set the %a0 and %a1 arguments
  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, 4));
  ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, 4));
  // Invoke the function
  ASSERT_LLVM_ERROR(engine.invoke(*argument));
  // Get and assert the result
  uint64_t result[4];
  ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 4));
  for (size_t i = 0; i < 4; i++) {
    EXPECT_EQ(result[i], a0[i] + a1[i])
        << "results differ at pos " << i << ", expected " << a0[i] + a1[i]
        << " got " << result[i];
  }
}

TEST(End2EndJit_HLFHELinalg, add_eint_term_to_term_broadcast) {
  mlir::zamalang::CompilerEngine engine;
  auto mlirStr = R"XXX(
// Returns the term-to-term addition of `%a0` with `%a1`, broadcasting the
// size-1 dimensions of both operands
func @main(%a0: tensor<4x1x4x!HLFHE.eint<4>>, %a1: tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>> {
  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>>
  return %res : tensor<4x4x4x!HLFHE.eint<4>>
}
)XXX";
  const uint8_t a0[4][1][4]{
      {{1, 2, 3, 4}},
      {{5, 6, 7, 8}},
      {{9, 10, 11, 12}},
      {{13, 14, 15, 16}},
  };
  const uint8_t a1[1][4][4]{
      {
          {1, 2, 3, 4},
          {5, 6, 7, 8},
          {9, 10, 11, 12},
          {13, 14, 15, 16},
      },
  };

  ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));

  auto maybeArgument = engine.buildArgument();
  ASSERT_LLVM_ERROR(maybeArgument.takeError());
  auto argument = std::move(maybeArgument.get());
  // Set the %a0 and %a1 arguments
  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, {4, 1, 4}));
  ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, {1, 4, 4}));
  // Invoke the function
  ASSERT_LLVM_ERROR(engine.invoke(*argument));
  // Get and assert the result
  uint64_t result[4][4][4];
  ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 4 * 4 * 4));
  for (size_t i = 0; i < 4; i++) {
    for (size_t j = 0; j < 4; j++) {
      for (size_t k = 0; k < 4; k++) {
        EXPECT_EQ(result[i][j][k], a0[i][0][k] + a1[0][j][k])
            << "results differ at pos (" << i << "," << j << "," << k
            << "), expected " << a0[i][0][k] + a1[0][j][k] << " got "
            << result[i][j][k];
      }
    }
  }
}

TEST(End2EndJit_HLFHELinalg, add_eint_matrix_column) {
  mlir::zamalang::CompilerEngine engine;
  auto mlirStr = R"XXX(
// Returns the addition of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of encrypted integers.
//
//  [1,2,3]   [1]   [2,3,4]
//  [4,5,6] + [2] = [6,7,8]
//  [7,8,9]   [3]   [10,11,12]
//
// The dimension #1 of operand #2 is stretched as it is equal to 1.
func @main(%a0: tensor<3x3x!HLFHE.eint<4>>, %a1: tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> {
  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
  return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX";
  const uint8_t a0[3][3]{
      {1, 2, 3},
      {4, 5, 6},
      {7, 8, 9},
  };
  const uint8_t a1[3][1]{
      {1},
      {2},
      {3},
  };

  ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));

  auto maybeArgument = engine.buildArgument();
  ASSERT_LLVM_ERROR(maybeArgument.takeError());
  auto argument = std::move(maybeArgument.get());
  // Set the %a0 and %a1 arguments
  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, {3, 3}));
  ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, {3, 1}));
  // Invoke the function
  ASSERT_LLVM_ERROR(engine.invoke(*argument));
  // Get and assert the result
  uint64_t result[3][3];
  ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 3 * 3));
  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ(result[i][j], a0[i][j] + a1[i][0])
          << "results differ at pos (" << i << "," << j << "), expected "
          << a0[i][j] + a1[i][0] << " got " << result[i][j];
    }
  }
}

TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line) {
  mlir::zamalang::CompilerEngine engine;
  auto mlirStr = R"XXX(
// Returns the addition of a 3x3 matrix of encrypted integers and a 1x3 matrix (a line) of encrypted integers.
//
//  [1,2,3]             [2,4,6]
//  [4,5,6] + [1,2,3] = [5,7,9]
//  [7,8,9]             [8,10,12]
//
// The dimension #2 of operand #2 is stretched as it is equal to 1.
func @main(%a0: tensor<3x3x!HLFHE.eint<4>>, %a1: tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> {
  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
  return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX";
  const uint8_t a0[3][3]{
      {1, 2, 3},
      {4, 5, 6},
      {7, 8, 9},
  };
  const uint8_t a1[1][3]{
      {1, 2, 3},
  };

  ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));

  auto maybeArgument = engine.buildArgument();
  ASSERT_LLVM_ERROR(maybeArgument.takeError());
  auto argument = std::move(maybeArgument.get());
  // Set the %a0 and %a1 arguments
  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, {3, 3}));
  ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, {1, 3}));
  // Invoke the function
  ASSERT_LLVM_ERROR(engine.invoke(*argument));
  // Get and assert the result
  uint64_t result[3][3];
  ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 3 * 3));
  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ(result[i][j], a0[i][j] + a1[0][j])
          << "results differ at pos (" << i << "," << j << "), expected "
          << a0[i][j] + a1[0][j] << " got " << result[i][j] << "\n";
    }
  }
}

TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line_missing_dim) {
  mlir::zamalang::CompilerEngine engine;
  auto mlirStr = R"XXX(
// Same behavior as the previous one, but the dimension #2 of operand #2 is missing.
func @main(%a0: tensor<3x3x!HLFHE.eint<4>>, %a1: tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> {
  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
  return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX";
  const uint8_t a0[3][3]{
      {1, 2, 3},
      {4, 5, 6},
      {7, 8, 9},
  };
  const uint8_t a1[1][3]{
      {1, 2, 3},
  };

  ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));

  auto maybeArgument = engine.buildArgument();
  ASSERT_LLVM_ERROR(maybeArgument.takeError());
  auto argument = std::move(maybeArgument.get());
  // Set the %a0 and %a1 arguments
  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, {3, 3}));
  ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, {3}));
  // Invoke the function
  ASSERT_LLVM_ERROR(engine.invoke(*argument));
  // Get and assert the result
  uint64_t result[3][3];
  ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 3 * 3));
  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ(result[i][j], a0[i][j] + a1[0][j])
          << "results differ at pos (" << i << "," << j << "), expected "
          << a0[i][j] + a1[0][j] << " got " << result[i][j] << "\n";
    }
  }
}