feat(compiler): Lower HLFHELinalg binary operators

This commit is contained in:
Quentin Bourgerie
2021-10-25 15:37:48 +02:00
committed by Andi Drebes
parent dea1be9d52
commit ba54560680
4 changed files with 349 additions and 1 deletions

View File

@@ -6,10 +6,12 @@ add_mlir_dialect_library(HLFHETensorOpsToLinalg
DEPENDS
HLFHEDialect
HLFHELinalgDialect
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
HLFHEDialect)
HLFHEDialect
HLFHELinalgDialect)
target_link_libraries(HLFHEDialect PUBLIC MLIRIR)

View File

@@ -13,6 +13,8 @@
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h"
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h"
struct DotToLinalgGeneric : public ::mlir::RewritePattern {
DotToLinalgGeneric(::mlir::MLIRContext *context)
@@ -121,6 +123,125 @@ struct DotToLinalgGeneric : public ::mlir::RewritePattern {
};
};
mlir::AffineMap
getBroadcastedAffineMap(const mlir::RankedTensorType &resultType,
const mlir::RankedTensorType &operandType,
::mlir::PatternRewriter &rewriter) {
mlir::SmallVector<mlir::AffineExpr, 4> affineExprs;
auto resultShape = resultType.getShape();
auto operandShape = operandType.getShape();
affineExprs.reserve(resultShape.size());
size_t deltaNumDim = resultShape.size() - operandShape.size();
for (auto i = 0; i < operandShape.size(); i++) {
if (operandShape[i] == 1) {
affineExprs.push_back(rewriter.getAffineConstantExpr(0));
} else {
affineExprs.push_back(rewriter.getAffineDimExpr(i + deltaNumDim));
}
}
return mlir::AffineMap::get(resultShape.size(), 0, affineExprs,
rewriter.getContext());
}
// This template rewrite pattern transforms any instance of
// operators `HLFHELinalgOp` that implements the broadasting rules to an
// instance of `linalg.generic` with an appropriate region using `HLFHEOp`
// operation, an appropriate specification for the iteration dimensions and
// appropriate operaztions managing the accumulator of `linalg.generic`.
//
// Example:
//
// %res = HLFHELinalg.op(%lhs, %rhs):
// (tensor<D$Ax...xD1x!HLFHE.eint<p>>, tensor<D$B'x...xD1'xT>)
// -> tensor<DR"x...xD1"x!HLFHE.eint<p>>
//
// becomes:
//
// #maps_0 = [
// affine_map<(a$R", ..., a$A, ..., a1) ->
// (dim(lhs, $A) == 1 ? 0 : a$A,..., dim(lhs, 1) == 1 ? 0 : a1)>,
// affine_map<(a$R", ..., a1) ->
// (dim(rhs, $B') == 1 ? 0 : a$B', ..., dim(rhs, 1) == 1 ? 0 : a1)>,
// affine_map<(a$R", ..., a1) -> (a$R", ..., a1)
// ]
// #attributes_0 {
// indexing_maps = #maps_0,
// iterator_types = ["parallel", ..., "parallel"], // $R" parallel
// }
// %init = linalg.init_tensor [DR",...,D1"]
// : tensor<DR"x...xD1"x!HLFHE.eint<p>>
// %res = linalg.generic {
// ins(%lhs, %rhs: tensor<DAx...xD1x!HLFHE.eint<p>>,tensor<DB'x...xD1'xT>)
// outs(%init : tensor<DR"x...xD1"x!HLFHE.eint<p>>)
// {
// ^bb0(%arg0: !HLFHE.eint<p>, %arg1: T):
// %0 = HLFHE.op(%arg0, %arg1): !HLFHE.eint<p>, T ->
// !HLFHE.eint<p>
// linalg.yield %0 : !HLFHE.eint<p>
// }
// }
//
template <typename HLFHELinalgOp, typename HLFHEOp>
struct HLFHELinalgOpToLinalgGeneric
: public mlir::OpRewritePattern<HLFHELinalgOp> {
HLFHELinalgOpToLinalgGeneric(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<HLFHELinalgOp>(context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(HLFHELinalgOp linalgOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::RankedTensorType resultTy =
((mlir::Type)linalgOp->getResult(0).getType())
.cast<mlir::RankedTensorType>();
mlir::RankedTensorType lhsTy =
((mlir::Type)linalgOp.lhs().getType()).cast<mlir::RankedTensorType>();
mlir::RankedTensorType rhsTy =
((mlir::Type)linalgOp.rhs().getType()).cast<mlir::RankedTensorType>();
// linalg.init_tensor for initial value
mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
linalgOp.getLoc(), resultTy.getShape(), resultTy.getElementType());
// Create the affine #maps_0
llvm::SmallVector<mlir::AffineMap, 3> maps{
getBroadcastedAffineMap(resultTy, lhsTy, rewriter),
getBroadcastedAffineMap(resultTy, rhsTy, rewriter),
getBroadcastedAffineMap(resultTy, resultTy, rewriter),
};
// Create the iterator_types
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
"parallel");
// Create the body of the `linalg.generic` op
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
HLFHEOp hlfheOp = nestedBuilder.create<HLFHEOp>(
linalgOp.getLoc(), blockArgs[0], blockArgs[1]);
nestedBuilder.create<mlir::linalg::YieldOp>(linalgOp.getLoc(),
hlfheOp.getResult());
};
// Create the `linalg.generic` op
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
llvm::SmallVector<mlir::Value, 2> ins{linalgOp.lhs(), linalgOp.rhs()};
llvm::SmallVector<mlir::Value, 1> outs{init};
llvm::StringRef doc{""};
llvm::StringRef call{""};
mlir::linalg::GenericOp genericOp =
rewriter.create<mlir::linalg::GenericOp>(linalgOp.getLoc(), resTypes,
ins, outs, maps, iteratorTypes,
doc, call, bodyBuilder);
rewriter.replaceOp(linalgOp, {genericOp.getResult(0)});
return ::mlir::success();
};
};
namespace {
struct HLFHETensorOpsToLinalg
: public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {
@@ -139,9 +260,14 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
target.addLegalDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
target.addLegalDialect<mlir::tensor::TensorDialect>();
target.addIllegalOp<mlir::zamalang::HLFHE::Dot>();
target.addIllegalDialect<mlir::zamalang::HLFHELinalg::HLFHELinalgDialect>();
mlir::OwningRewritePatternList patterns(&getContext());
patterns.insert<DotToLinalgGeneric>(&getContext());
patterns.insert<
HLFHELinalgOpToLinalgGeneric<mlir::zamalang::HLFHELinalg::AddEintOp,
mlir::zamalang::HLFHE::AddEintOp>>(
&getContext());
if (mlir::applyPartialConversion(function, target, std::move(patterns))
.failed())

View File

@@ -6,6 +6,7 @@
#include <mlir/Parser.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
#include <zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h>
#include <zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h>
#include <zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h>
#include <zamalang/Support/CompilerEngine.h>
@@ -15,6 +16,7 @@ namespace mlir {
namespace zamalang {
void CompilerEngine::loadDialects() {
context->getOrLoadDialect<mlir::zamalang::HLFHELinalg::HLFHELinalgDialect>();
context->getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context->getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
context->getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();

View File

@@ -168,4 +168,222 @@ func @main(%t0: tensor<2x10x!HLFHE.eint<6>>, %t1: tensor<2x2x!HLFHE.eint<6>>) ->
}
}
}
}
///////////////////////////////////////////////////////////////////////////////
// HLFHELinalg add_eint ///////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(End2EndJit_HLFHELinalg, add_eint_term_to_term) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
// Returns the term to term addition of `%a0` with `%a1`
func @main(%a0: tensor<4x!HLFHE.eint<4>>, %a1: tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>> {
%res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
return %res : tensor<4x!HLFHE.eint<4>>
}
)XXX";
const uint8_t a0[4]{31, 6, 12, 9};
const uint8_t a1[4]{32, 9, 2, 3};
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %a0 and %a1 argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, 4));
ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, 4));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t result[4];
ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 4));
for (size_t i = 0; i < 4; i++) {
EXPECT_EQ(result[i], a0[i] + a1[i])
<< "result differ at pos " << i << ", expect " << a0[i] + a1[i]
<< " got " << result[i];
}
}
TEST(End2EndJit_HLFHELinalg, add_eint_term_to_term_broadcast) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
// Returns the term to term addition of `%a0` with `%a1`
func @main(%a0: tensor<4x1x4x!HLFHE.eint<4>>, %a1: tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>> {
%res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>>
return %res : tensor<4x4x4x!HLFHE.eint<4>>
}
)XXX";
const uint8_t a0[4][1][4]{
{{1, 2, 3, 4}},
{{5, 6, 7, 8}},
{{9, 10, 11, 12}},
{{13, 14, 15, 16}},
};
const uint8_t a1[1][4][4]{
{
{1, 2, 3, 4},
{5, 6, 7, 8},
{9, 10, 11, 12},
{13, 14, 15, 16},
},
};
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %a0 and %a1 argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, {4, 1, 4}));
ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, {1, 4, 4}));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t result[4][4][4];
ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 4 * 4 * 4));
for (size_t i = 0; i < 4; i++) {
for (size_t j = 0; j < 4; j++) {
for (size_t k = 0; k < 4; k++) {
EXPECT_EQ(result[i][j][k], a0[i][0][k] + a1[0][j][k])
<< "result differ at pos " << i << ", expect "
<< a0[i][0][k] + a1[0][j][k] << " got " << result[i];
}
}
}
}
TEST(End2EndJit_HLFHELinalg, add_eint_matrix_column) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
// Returns the addition of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of encrypted integers.
//
// [1,2,3] [1] [2,3,4]
// [4,5,6] + [2] = [6,7,8]
// [7,8,9] [3] [10,11,12]
//
// The dimension #1 of operand #2 is stretched as it is equals to 1.
func @main(%a0: tensor<3x3x!HLFHE.eint<4>>, %a1: tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> {
%res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX";
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
};
const uint8_t a1[3][1]{
{1},
{2},
{3},
};
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %a0 and %a1 argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, {3, 3}));
ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, {3, 1}));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t result[3][3];
ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 3 * 3));
for (size_t i = 0; i < 3; i++) {
for (size_t j = 0; j < 3; j++) {
EXPECT_EQ(result[i][j], a0[i][j] + a1[i][0])
<< "result differ at pos " << i << ", expect " << a0[i][j] + a1[i][0]
<< " got " << result[i];
}
}
}
TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
// Returns the addition of a 3x3 matrix of encrypted integers and a 1x3 matrix (a line) of encrypted integers.
//
// [1,2,3] [2,4,6]
// [4,5,6] + [1,2,3] = [5,7,9]
// [7,8,9] [8,10,12]
//
// The dimension #2 of operand #2 is stretched as it is equals to 1.
func @main(%a0: tensor<3x3x!HLFHE.eint<4>>, %a1: tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> {
%res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX";
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
};
const uint8_t a1[1][3]{
{1, 2, 3},
};
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %a0 and %a1 argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, {3, 3}));
ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, {1, 3}));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t result[3][3];
ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 3 * 3));
for (size_t i = 0; i < 3; i++) {
for (size_t j = 0; j < 3; j++) {
EXPECT_EQ(result[i][j], a0[i][j] + a1[0][j])
<< "result differ at pos (" << i << "," << j << "), expect "
<< a0[i][j] + a1[0][j] << " got " << result[i][j] << "\n";
}
}
}
TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line_missing_dim) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
// Same behavior than the previous one, but as the dimension #2 of operand #2 is missing.
func @main(%a0: tensor<3x3x!HLFHE.eint<4>>, %a1: tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> {
%res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX";
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
};
const uint8_t a1[1][3]{
{1, 2, 3},
};
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %a0 and %a1 argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, {3, 3}));
ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, {3}));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t result[3][3];
ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 3 * 3));
for (size_t i = 0; i < 3; i++) {
for (size_t j = 0; j < 3; j++) {
EXPECT_EQ(result[i][j], a0[i][j] + a1[0][j])
<< "result differ at pos (" << i << "," << j << "), expect "
<< a0[i][j] + a1[0][j] << " got " << result[i][j] << "\n";
}
}
}